只需百行代码,让H100提速30%,斯坦福开源全新AI加速框架

提高 GPU 利用率,就是这么简单。AI 的快速发展,伴随而来的是大计算量。这就自然而然的引出了一个问题:如何减少 AI 对计算的需求,并提高现有 AI 计算效率。为了回答这一问题,来自斯坦福的钻研者在博客《GPUs Go Brrr》中给出了答案。博客地址::一是硬件真正需要什么?二是如何满足硬件需求?文章用大量篇幅讨论了如何让 GPU 更快的运行,并发布了一个库 ThunderKittens,用户可以很容易地在 CUDA 上编辑快速的深度学习内核。其具有以下特点:简单,ThunderKittens 写起来非常简单

提高 GPU 利用率,就是这么简单。

AI 的快速发展,伴随而来的是大计算量。这就自然而然的引出了一个问题:如何减少 AI 对计算的需求,并提高现有 AI 计算效率。

为了回答这一问题,来自斯坦福的钻研者在博客《GPUs Go Brrr》中给出了答案。

只需百行代码,让H100提速30%,斯坦福开源全新AI加速框架

博客地址:https://hazyresearch.stanford.edu/blog/2024-05-12-tk

文章主要专注于两个问题:一是硬件真正需要什么?二是如何满足硬件需求?

文章用大量篇幅讨论了如何让 GPU 更快的运行,并发布了一个库 ThunderKittens,用户可以很容易地在 CUDA 上编辑快速的深度学习内核。其具有以下特点:

简单,ThunderKittens 写起来非常简单。

可扩展性,如果用户需要 ThunderKittens 无法提供的功能,可以进行功能扩展。

速度快。

只需百行代码,让H100提速30%,斯坦福开源全新AI加速框架

GitHub 链接:https://github.com/HazyResearch/ThunderKittens

ThunderKittens 使得一些棘手的事情变得非常简单,从而在现代硬件上实现了非常高的利用率。项目中,作者用ThunderKittens 编辑了一个 RTX 4090 简单的 FlashAttention-2 内核,代码总共有 58 行代码(不包括空格),结果显示,ThunderKittens 在 RTX 4090 上实现了大约 122 TFLOP(理论最大值的 74%)。此外,内核程序只有 100 行的情况下,ThunderKittens 在 H100 上的性能比 FlashAttention-2 高出约 30%。

英伟达 H100 有些小怪癖

该钻研重点关注 NVIDIA H100,不过所介绍的内容也适用于其他 GPU。

只需百行代码,让H100提速30%,斯坦福开源全新AI加速框架

H100 SXM GPU 包含:

80 GB HBM3,带宽为 3 TB/s(实际上带宽会少一些);

50 MB 二级缓存,带宽 12 TB/s,在 GPU 上分成两个 25MB 的部分,通过 crossbar 连接;

132 个流多处理器 (SM,streaming multiprocessors)。

除了上述这些,H100 SXM GPU 还有很多可关注的东西,例如内存控制器、指令缓存等。

钻研者表示保持张量核心的运行流畅并不容易。他们发现了一些 AI 硬件上的怪癖,这些怪癖中的很多内容也适用于非 H100 GPU,但 H100 尤其棘手。(相比之下,RTX 4090 则非常容易使用),这些怪癖包括:

WGMMA 指令是必需的,但使用起来也非常令人恼火;

共享内存实际上并没有那么快,并且需要非常小心;

地址生成成本很高;

占用率仍然有帮助,寄存器通常是关键资源。

只需百行代码,让H100提速30%,斯坦福开源全新AI加速框架

                                文章进一步描述了 GPU 这些怪癖的具体内容。

WGMMA 指令令人恼火

H100 有一组新指令,称为「warp group matrix multiply accumulate,WGMMA」(PTX 中的 wgmma.mma_async,或 SASS 中的 HGMMA/IGMMA/QGMMA/BGMMA)。以前的 GPU 上可用的张量核心指令是 wmma.mma.sync 和 mma.sync 。通过这些指令,SM 单个象限上的 32 个线程将同步地将其数据块馈送到张量核心并等待结果。 

不同的是,wgmma.mma_async 指令并非如此,128 个连续线程(分布在 SM 的所有象限中)协作同步,并直接从共享内存(也可以选择寄存器)异步启动矩阵乘法。

在基准测试中,钻研团队发现这些指令对于提取 H100 的完整计算是必要的。如果没有它们,GPU 的峰值利用率似乎只能达到峰值利用率的 63% 左右。

只需百行代码,让H100提速30%,斯坦福开源全新AI加速框架

共享内存

共享内存的单次访问延迟约为 30 个周期,这听起来似乎不算多,但在这段时间内,SM 的张量核心几乎可以完成两个完整的 32×32 矩阵乘法运算。

共享内存处理起来有些棘手,因为它被存储(banked)在 32 个独立的内存存储中。如果不小心,这可能会导致所谓的 bank 冲突,即同一内存 bank 被要求同时提供多个不同的内存片段,导致请求被串行化,这可能会不成比例地减慢内核的速度 – 而 wgmma 和 mma 指令所需的寄存器结构会受到这些 bank 冲突的影响。解决方法是使用各种交错模式重新排列共享内存,以避免这些冲突。

地址生成

H100 其中一个特点是张量核心和内存都足够快,以至于仅仅生成用于获取数据的内存地址就占据了芯片资源的相当一部分。

NVIDIA 似乎已经意识到了这一点,因为他们赋予了 GPU 张量内存加速器(或称之为 TMA)。TMA 允许用户在全局和共享内存中指定多维张量结构,这节省了所有的地址生成成本,并且还使得构建 pipeline 更加容易。

钻研团队还发现 TMA 和 wgmma.mma_async 一样,在实现 H100 的全部潜力方面是完全不可或缺的。

占用

在某些方面,与前几代硬件相比,H100 对占用率的依赖程度较低。NVIDIA 确实在设计 GPU 时考虑了占用率。虽然对于 H100 来说,占用率只能说有用,但作用不大。钻研者发现在 A100 和 RTX 4090 上它变得越来越重要。

ThunderKittens

那么,如何才能更轻松地编辑内核,同时仍兼具硬件的全部功能?

钻研团队设计了一个嵌入 CUDA 中的 DSL,被命名为 ThunderKittens。

只需百行代码,让H100提速30%,斯坦福开源全新AI加速框架

ThunderKittens 旨在尽可能简单,并包含四种模板类型:

寄存器 tile—— 寄存器文件中的 2D 张量。

寄存器向量 —— 寄存器文件中的 1D 张量。

共享 tile—— 共享内存中的 2D 张量。

共享向量 —— 共享内存中的 1D 张量。

tile 通过高度、宽度和结构进行参数化,寄存器向量由长度和结构参数化,共享向量仅由长度参数化。这样通常不会遭受 bank 冲突的困扰。

钻研团队还提供了一些必要操作:

初始化,如将共享向量清零

一元运算,如 exp

二元运算,如 mul

行 / 列操作,如 row_sum

该钻研给出了一个用 ThunderKittens 编辑的,用于 RTX 4090 的简单前向 flash attention 内核:

#define NUM_WORKERS 16 // This kernel uses 16 workers in parallel per block, to help issue instructions more quickly.
using namespace kittens; // this kernel only handles headdim=64 for simplicity. Also n should be a multiple of 256 here.
__global__ void attend_ker64(int n, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) {

    auto warpid        = kittens::warpid();
    auto block_start   = blockIdx.x*(n*64);
    const bf16 *_q = __q__ + block_start, *_k = __k__ + block_start, *_v = __v__ + block_start;
          bf16 *_o = __o__ + block_start;

    extern __shared__ alignment_dummy __shm[]; // this is the CUDA shared memory
    shared_allocator al((int*)&__shm[0]);     

    // K and V live in shared memory -- this is about all that will fit.
    st_bf_1x4<ducks::st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();
    st_bf_1x4<ducks::st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();

    // Initialize all of the register tiles.
    rt_bf_1x4<> q_reg, k_reg, v_reg; // v_reg need to be swapped into col_l
    rt_fl_1x1<> att_block;
    rt_bf_1x1<> att_block_mma;
    rt_fl_1x4<> o_reg;
    rt_fl_1x1<>::col_vec max_vec_last, max_vec; // these are column vectors for the attention block 
    rt_fl_1x1<>::col_vec norm_vec_last, norm_vec; // these are column vectors for the attention block        

    int qo_blocks = n / (q_reg.rows*NUM_WORKERS), kv_blocks = n / (q_reg.rows*NUM_WORKERS);
    for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) {

        // each warp loads its own Q tile of 16x64, and then multiplies by 1/sqrt(d)
        load(q_reg, _q + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
        mul(q_reg, q_reg, __float2bfloat16(0.125f)); // temperature adjustment

        // zero flash attention L, M, and O registers.
        neg_infty(max_vec); // zero registers for the Q chunk
        zero(norm_vec);
        zero(o_reg);

        // iterate over k, v for these q's that have been loaded
        for(auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++) {

            // each warp loads its own chunk of k, v into shared memory
            load(v_smem[warpid], _v + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);  
            load(k_smem[warpid], _k + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
            __syncthreads(); // we need to make sure all memory is loaded before we can begin the compute phase

            // now each warp goes through all of the subtiles, loads them, and then does the flash attention internal alg.
            for(int subtile = 0; subtile < NUM_WORKERS; subtile++) {

                load(k_reg, k_smem[subtile]); // load k from shared into registers
                zero(att_block); // zero 16x16 attention tile
                mma_ABt(att_block, q_reg, k_reg, att_block); // [email protected]

                copy(norm_vec_last, norm_vec);
                copy(max_vec_last,  max_vec);

                row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec
                sub_row(att_block, att_block, max_vec); // subtract max from attention -- now all <=0
                exp(att_block, att_block); // exponentiate the block in-place.

                sub(max_vec_last, max_vec_last, max_vec); // subtract new max from old max to find the new normalization.
                exp(max_vec_last, max_vec_last); // exponentiate this vector -- this is what we need to normalize by.
                mul(norm_vec, norm_vec, max_vec_last); // and the norm vec is now normalized.

                row_sum(norm_vec, att_block, norm_vec); // accumulate the new attention block onto the now-rescaled norm_vec
                div_row(att_block, att_block, norm_vec); // now the attention block is correctly normalized

                mul(norm_vec_last, norm_vec_last, max_vec_last); // normalize the previous norm vec according to the new max
                div(norm_vec_last, norm_vec_last, norm_vec); // normalize the previous norm vec according to the new norm

                copy(att_block_mma, att_block); // convert to bf16 for mma_AB

                load(v_reg, v_smem[subtile]); // load v from shared into registers.  
                rt_bf_1x4<ducks::rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // this is a reference and the call has invalidated v_reg

                mul_row(o_reg, o_reg, norm_vec_last); // normalize o_reg in advance of mma_AB'ing onto it
                mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // mfma onto o_reg with the local attention@V matmul.
            }
            __syncthreads(); // we need to make sure all warps are done before we can start loading the next kv chunk
        }

        store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_reg, q_reg.cols); // write out o. compiler has an issue with register usage if d is made constexpr q_reg.rows :/
    }
}

总共大约有 60 行 CUDA 代码,硬件利用率为 75%,虽然非常密集,但大部分复杂性在于算法,而不是混合模式或寄存器结构。

TMA、WGMMA、swizzling 模式和描述符的复杂度又如何呢?如下是用 ThunderKittens 编辑的, H100 的 FlashAttention-2 前向传递:

template<int D>
__global__  __launch_bounds__((NUM_WORKERS)*kittens::WARP_THREADS, 2)
void fwd_attend_ker_dim(int N, const CUtensorMap* tma_q, const CUtensorMap* tma_k, const CUtensorMap* tma_v, CUtensorMap* tma_o) {
    extern __shared__ int __shm[]; // this is the CUDA shared memory
    tma_swizzle_allocator al((int*)&__shm[0]);

    constexpr int tile_width = fwd_attend_ker_tile_dims<D>::tile_width; // constants
    constexpr int qo_height  = fwd_attend_ker_tile_dims<D>::qo_height;
    constexpr int kv_height  = fwd_attend_ker_tile_dims<D>::kv_height;

    st_bf<qo_height, tile_width, layout_q>          (&q_smem)   [NUM_WARPGROUPS] = al.allocate<st_bf<qo_height, tile_width, layout_q>,          NUM_WARPGROUPS>();
    st_bf<kv_height, tile_width, layout_k>          (&k_smem)[2][NUM_WORKERS_KV] = al.allocate<st_bf<kv_height, tile_width, layout_k>, 2,       NUM_WORKERS_KV>();
    st_bf<kv_height, tile_width, layout_v>          (&v_smem)[2][NUM_WORKERS_KV] = al.allocate<st_bf<kv_height, tile_width, layout_v>, 2,       NUM_WORKERS_KV>();

    int tic = 0, toc = 1;

    rt_fl<1, kv_height> att_block;
    rt_bf<1, kv_height> att_block_mma;
    rt_fl<1, qo_height> o_prev;
    col_vec<rt_fl<1, kv_height>> max_vec_last, max_vec;
    col_vec<rt_fl<1, kv_height>> norm_vec_last, norm_vec;

    int warpid      = kittens::warpid();
    int warpgroupid = warpid/kittens::WARPGROUP_WARPS;

    int kv_blocks = N / (NUM_WORKERS_KV*k_smem[0][0].rows);

    __shared__ uint64_t qsmem_barrier, kvsmem_barrier;//, vsmem_barrier;

    int q_phasebit = 0;
    int kv_phasebit = 0;

    if (threadIdx.x == 0) {
        tma::init_barrier<st_bf<qo_height, tile_width, layout_q>, NUM_WARPGROUPS>(qsmem_barrier, 1);
        tma::init_barrier<st_bf<kv_height, tile_width, layout_k>, NUM_WORKERS_KV*2>(kvsmem_barrier, 1);
     }

    if (warpid == 0) {
        for (int wg = 0; wg < NUM_WORKERS/kittens::WARPGROUP_WARPS; wg++) { // load q
             int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + wg;  
             tma::load_async((q_smem[wg]), tma_q, qsmem_barrier, tile_idx);
         }
        for (int w = 0; w < NUM_WORKERS_KV; w++) { // load k, v
             int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + (0 * NUM_WORKERS_KV) + w;
             tma::load_async((k_smem[tic][w]), tma_k, kvsmem_barrier, tile_idx);
             tma::load_async((v_smem[tic][w]), tma_v, kvsmem_barrier, tile_idx);
         }
    }

    neg_infty(max_vec); // zero registers for the Q chunk
    zero(norm_vec);
    zero(o_prev);
    __syncthreads();

    tma::arrive_and_wait(qsmem_barrier, q_phasebit);
    q_phasebit ^= 1;

    if constexpr (D == 64) { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.125f)); }
     else { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.08838834764f)); }

    for (auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++, tic ^= 1, toc ^= 1) {
        tma::arrive_and_wait(kvsmem_barrier, kv_phasebit);
        kv_phasebit ^= 1;

        __syncthreads();
        if (warpid == 0) {
            tma::set_bytes(kvsmem_barrier, 2 * NUM_WORKERS_KV * k_smem[0][0].num_elements * sizeof(bf16));

            if (kv_idx + 1 < kv_blocks) {
                for (int w = 0; w < NUM_WORKERS_KV; w++) {
                     int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + ((kv_idx + 1) * NUM_WORKERS_KV) + w;
                     tma::load_async((k_smem[toc][w]), tma_k, kvsmem_barrier, tile_idx);
                     tma::load_async((v_smem[toc][w]), tma_v, kvsmem_barrier, tile_idx);
                }
            }
        }

        warpgroup::mma_fence(att_block);
        warpgroup::mm_ABt(att_block, q_smem[warpgroupid], k_smem[tic][0]);
        warpgroup::mma_commit_group();

        copy(norm_vec_last, norm_vec);
        copy(max_vec_last,  max_vec);

        warpgroup::mma_async_wait();

        row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec
        sub_row(att_block, att_block, max_vec);
        exp(att_block, att_block);

        sub(max_vec_last, max_vec_last, max_vec);
        exp(max_vec_last, max_vec_last);
        mul(norm_vec, norm_vec, max_vec_last);

        row_sum(norm_vec, att_block, norm_vec); // accumulate onto the norm_vec
        div_row(att_block, att_block, norm_vec);

        mul(norm_vec_last, norm_vec_last, max_vec_last);
        div(norm_vec_last, norm_vec_last, norm_vec);

        copy(att_block_mma, att_block); // convert to bf16 for mma
        mul_row(o_prev, o_prev, norm_vec_last); // normalize o_prev in advance of mma'ing onto it

        warpgroup::mma_fence(o_prev);
        warpgroup::mma_AB(o_prev, att_block_mma, v_smem[tic][0]);
        warpgroup::mma_commit_group();
    }

    auto (*o_smem) = reinterpret_cast<st_bf<qo_height, tile_width, layout_o>(*)>(q_smem); // reuse q memory
    warpgroup::store(o_smem[warpgroupid], o_prev);
     __syncthreads();

        if (warpid % 4 == 0) { // store o
        int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + warpgroupid;
        tma::store_async(tma_o, (o_smem[warpgroupid]), tile_idx);
        tma::store_commit_group();
     }

    tma::store_async_wait();
}

这个内核只有 100 行代码,它在 H100 上的性能比 FlashAttention-2 高出约 30%。ThunderKittens 负责 wrap up 结构和指令,并提供一个可以在 GPU 上使用的 mini-pytorch。

只需百行代码,让H100提速30%,斯坦福开源全新AI加速框架                             H100 SXM 上各种配置的 FlashAttention-2(Pytorch)与 ThunderKittens 的比较。

此外,钻研团队还发布了基于线性注意力的内核和其他架构。基于线性注意力内核的运行速度为 215 TFLOP(如果考虑算法中固有的重计算,则运行速度超过 300 TFLOP)。

虽然理论上线性注意力更高效,但从实践经验来看,线性注意力在硬件上的效率大大降低。因此,ThunderKittens 有望开辟广泛的高吞吐量应用。只需百行代码,让H100提速30%,斯坦福开源全新AI加速框架                            使用 ThunderKittens 可以非常快地实现线性注意力。

tile 看起来是个好点子

在钻研团队看来,ThunderKittens 之所以运行良好,是因为它不会试图做所有事情。CUDA 确实比 ThunderKittens 更有表现力,而 ThunderKittens 又小又简单。

不过,ThunderKittens 具有很好的抽象能力,它具有小的 tile,这与 AI 和硬件的发展相匹配。ThunderKittens 不支持任何少于 16 的维数。但在钻研团队看来,这一点并不重要,尤其对于硬件而言。如果你的矩阵乘法小于 16×16,你确定自己做的还是 AI 吗?

从哲学的视角来看,钻研团队认为框架迁移是合理的。「寄存器」当然不应该像旧 CPU 那样的 32 位。CUDA 使用的 1024 位宽向量寄存器无疑朝着正确方向迈出了一步。但对钻研团队而言,「寄存器」是 16×16 的数据 tile。他们认为 AI 想要这样,它仍然只是矩阵乘法、规约和重塑。当然硬件也想要这样,小的矩阵乘法寻求硬件支持,而不仅仅是 systolic mma。

实际上,从更广泛的视角来看,钻研团队认为应该围绕硬件的良好映射来重新调整 AI 思路。比如,循环状态应该有多大?SM 能够容纳多大尺寸?计算密度是多少?这些都不亚于硬件的要求。

钻研团队表示,这项工作未来的一个重要方向是利用他们对硬件的了解来帮助设计与硬件相匹配的 AI。

最后,AMD 硬件上适配的 ThunderKittens 也将很快推出。

给TA打赏
共{{data.count}}人
人已打赏
工程

字节开源大模型量化新思绪,2-bit量化模型精度齐平fp16

2024-5-13 14:47:00

工程

西浦、利物浦大学提出:点云数据巩固首个周全综述

2024-5-14 7:04:00

0 条回复 A文章作者 M管理员
    暂无讨论,说说你的看法吧
个人中心
今日签到
搜索