近来,几种长上下文谈话模型陆续问世,包括 GPT-4(上下文长度为 32k)、MosaicML 的 MPT(上下文长度为 65k)Anthropic 的 Claude(上下文长度为 100k)。长文档盘问和故事写作等新兴用例已经表明扩展谈话模型上下文窗口是非常必要的。
然而,扩大 Transformer 的上下文长度是一个挑战,因为其核心的注意力层在时间复杂度和空间复杂度与输入序列长度的平方成正比。
一年前,来自斯坦福大学、纽约州立大学布法罗分校的研究者共同提出一种快速、内存高效的注意力算法 ——FlashAttention。该算法无需任何近似即可减速注意力并削减内存占用。现在,已经有许多机构和研究实验室采用 FlashAttention 来减速训练和推理。
FlashAttention 示意图。
尽管 FlashAttention 的速率已经是优化基线的 2-4 倍,但它仍然有相当大的改进空间。FlashAttention 仍然不如优化过的矩阵乘法 (GEMM) 运算快,仅达到理论最大 FLOPs/s 的 25-40%。
现在,研究团队宣布推出 FlashAttention-2。FlashAttention-2 完全从头开始重写,运用 Nvidia 的 CUTLASS 3.x 及其核心库 CuTe 的原语(primitive)。
FlashAttention-2 开发者 Tri Dao。他是斯坦福大学博士生,还是 Together.AI 首席科学家,并将于 2024 年 9 月开始任职普林斯顿大学计算机科学助理教授。
FlashAttention-2 的速率是 FlashAttention 的 2 倍,在 A100 GPU 上达到 230 TFLOPs/s。在端到端训练 GPT 类谈话模型时,FlashAttention-2 可让训练速率高达 225 TFLOPs/s(模型 FLOP 利用率为 72%)。
FlashAttention-2 将减速现有模型的训练、微调和推理。这意味着我们可以用相同成本训练 2 倍上下文长度的谈话模型。这将有助于谈话模型理解长篇书籍和报告、高分辨率图像、音频和视频。
项目地址:https://github.com/Dao-AILab/flash-attention
技术报告:https://tridao.me/publications/flash2/flash2.pdf
FlashAttention 是什么?
FlashAttention 是一种重新排序注意力计算的算法,它利用平铺、重计算等经典技术来显著提升计算速率,并将序列长度中的内存运用实现从二次到线性削减。其中平铺意味着将输入块从 HBM(GPU 内存)加载到 SRAM(快速缓存),并对该块执行注意力操作,更新 HBM 中的输出。
此外通过不将大型中间注意力矩阵写入 HBM,内存读写量削减,带来了 2-4 倍的时钟时间减速。
下图为 FlashAttention 的前向传递图:通过平铺和 softmax 重新缩放,研究者按块进行操作,避免从 HBM 中读取 / 写入,同时获得正确的输出,无需近似操作。
然而,FlashAttention 仍然存在一些低效率问题,原因在于不同线程块之间的工作分区不理想以及 GPU 上的 warp。这些导致低占用率或不必要的共享内存读写。
FlashAttention-2
更好的算法、并行化和工作分区
更少的非矩阵乘法 Flops
研究者调整了 FlashAttention 的算法,从而削减了非矩阵乘法(non-matmul)的 Flops 数量。这点很重要,因为现代 GPU 具有专门的计算单元(例如 Nvidia GPU 上的张量核心),使得矩阵乘法速率更快。
举例而言,A100 GPU 的 FP16/BF16 矩阵乘法的最大理论吞吐量为 312 TFLOPs/s,但非矩阵乘法 FP32 的理论吞吐量仅为 19.5 TFLOPs/s。
换一种思考方式,每一个非矩阵乘法 FLOP 比矩阵乘法 FLOP 的代价高 16 倍。为了保持高吞吐量,研究者希望在矩阵乘法 FLOP 上花费尽可能多的时间。因此他们重写了 FlashAttention 中运用的在线 softmax 技巧,以削减重新缩放操作、边界检查和因果掩码操作的数量,而无需更改输出。
更好的并行化
FlashAttention v1 在批大小和头(head)数量上进行并行化。研究者运用 1 个线程块来处理一个注意力头,总共有(批大小 * 头数量)个线程块。每一个线程块都计划在流式多处理器(SM)上运行,例如 A100 GPU 上有 108 个这样的 SM。当这个数字非常大(如 >= 80)时,这种调度是有效的,这时可以高效地运用 GPU 上几乎所有计算资源。
在长序列的情况下(通常意味着小批量或少量头),为了更好地利用 GPU 上的多处理器,现在研究者在序列长度维数上额外地进行并行化,使该机制显著减速。
更好的工作分区
即使在每一个线程块内,研究者也必须决定如何在不同的 warp 之间划分工作(一组 32 个线程一起工作)。通常情况下,每一个线程块运用 4 或 8 个 warp,分区方案如下图所述。
研究者改进了 FlashAttention-2 中的这种分区,削减不同 warp 之间的同步和通信量,进而削减共享内存读写。
对于每一个块,FlashAttention 将 K 和 V 分割到 4 个 warp 上,同时保持 Q 可被所有 warp 访问。这被称为「sliced-K」方案。不过,这种方案是低效的,原因在于所有 warp 都需要将它们的中间结果写入共享内存,并同步,然后将中间结果相加。这些共享内存读写会减慢 FlashAttention 中的前向传递速率。
在 FlashAttention-2 中,研究者将 Q 分割在 4 个 warp 上,同时保持 K 和 V 可被所有的 warp 访问。每一个 warp 执行矩阵乘法以获得 Q K^T 的切片,然后只需与 V 的共享切片相乘就能获得相应的输出切片。warp 之间不需要通信。共享内存读写的削减也可以提升速率。
新特性:头维数高达 256、多盘问注意力
我们知道,FlashAttention 仅支持最高 128 的头维数,这适用于大多数模型,但有一些模型被遗漏了。
因此,FlashAttention-2 支持了高达 256 的头维数,这意味着 GPT-J、CodeGen 和 CodeGen2、StableDiffusion 1.x 等模型可以运用 FlashAttention-2 来获得减速和节省内存。
此外,FlashAttention-2 还支持了多盘问注意力(multi-query attention, MQA)以及分组盘问注意力(grouped-query attention, GQA)。它们是注意力的变体,其中多个盘问头关注相同的键和值头,以削减推理过程中 KV 缓存的大小,并可以显著提高推理吞吐量。
注意力基准结果
研究者在 A100 80GB SXM4 GPU 上,测量不同设置(无 / 有因果掩码、头维数 64 或 128)下不同注意力方法的运行时。
结果发现, FlashAttention-2 的速率是 FlashAttention(以及 xformers 库和 Triton 中的其他实现)的 2 倍。与 PyTorch 中的标准注意力实现相比,FlashAttention-2 的速率最高是它们的 9 倍。
A100 GPU 上的注意力前向 + 后向速率。
此外只需要在 H100 GPU 上 运行相同的实现(不运用特殊指令来利用 TMA 和第四代 Tensor Core 等新硬件功能),研究者最高获得了 335 TFLOPs/s。
H100 GPU 上的注意力前向 + 后向速率。
当用于端到端 GPT 类模型训练时,FlashAttention-2 有助于在 A100 GPU 上实现最高 225 TFLOPs/s(模型 FLOPs 利用率为 72%)。与优化良好的 FlashAttention 模型相比,端到端实现 1.3 倍减速。
这里的基线是不运用 FlashAttention 的 Megatron-LM,它现在也可以选择运用 FlashAttention 了。不久的将来,FlashAttention-2 也将集成到 Megatron-LM 中。
研究团队表示:下一步将针对 H100 GPU 优化 FlashAttention-2,以运用新的硬件功能。
参考链接:
https://princeton-nlp.github.io/flash-atttention-2/