英伟达又赚到了!FlashAttention3来了:H100利用率飙升至75%

740 TFLOPS!迄今最强 FlashAttention 来了。随着大型语言模型(LLM)加速落地,扩展模型上下文窗口变得越来越重要。然而,Transformer 架构的核心 —— 注意力层的时间复杂度和空间复杂度与输入序列长度的平方成正比。这使得扩展模型上下文窗口存在挑战。2022 年,一种快速、内存高效的注意力算法 ——FlashAttention 问世,该算法无需任何近似即可加速注意力并减少内存占用。FlashAttention 对注意力计算进行重新排序的算法,并利用 tiling 和重计算来显著加快计算

740 TFLOPS!迄今最强 FlashAttention 来了。

随着大型语言模型(LLM)加速落地,扩展模型上下文窗口变得越来越重要。然而,Transformer 架构的核心 —— 注意力层的时间复杂度和空间复杂度与输入序列长度的平方成正比。这使得扩展模型上下文窗口存在挑战。

2022 年,一种快速、内存高效的注意力算法 ——FlashAttention 问世,该算法无需任何近似即可加速注意力并减少内存占用。

FlashAttention 对注意力计算进行重新排序的算法,并利用 tiling 和重计算来显著加快计算速度,将内存使用量从序列长度的二次减少到线性。

图片

2023 年,研究团队宣布推出 FlashAttention-2,在算法、并行化和工作分区等方面有了显著改进。

现在,来自 Meta、英伟达、Together AI 等机构的研究者宣布推出 FlashAttention-3,它采用了加速 Hopper GPU 注意力的三种主要技术:

通过 warp-specialization 重叠整体计算和数据移动;

交错分块 matmul 和 softmax 运算;

利用硬件支持 FP8 低精度的不连贯处理。

FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍,高达 740 TFLOPS,即 H100 理论最大 FLOPS 利用率为 75%。使用 FP8,FlashAttention-3 的速度更是接近 1.2 PFLOPS。

FlashAttention-3 的改进将带来:

更高效的 GPU 利用率:H100 理论最大 FLOPS 利用率为 75%,而之前仅为 35%。这使得 LLM 的训练和运行速度比以前的版本快得多。

较低精度下更好的性能:FlashAttention-3 可以在保持精度的同时使用较低精度的数字 (FP8)。这可以实现更快的处理速度并可能降低内存使用量,从而为运行大规模人工智能操作的客户节省成本并提高效率。

能够在 LLM 中使用更长的上下文:通过加速注意力机制,FlashAttention-3 使 AI 模型能够更有效地处理更长的文本片段。这使得应用程序能够理解并生成更长、更复杂的内容而不会减慢速度。

图片

论文标题:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision

论文地址:https://tridao.me/publications/flash3/flash3.pdf

论文作者之一 、FlashAttention1-3 版本的参与者 Tri Dao 表示:FlashAttention 被广泛用于加速 Transformers,已经使注意力速度提高了 4-8 倍,但尚未利用现代 GPU。因而他们发布了 FlashAttention-3:在 FP16 上速度提高了 1.5-2 倍,在 H100 上高达 740 TFLOPS(75% 实用性),FP8 接近 1.2 PFLOPS!

图片

Hopper GPU 硬件特性:WGMMA、TMA、FP8

虽然 FlashAttention-2 在 Ampere (A100) GPU 上可以实现 70% 的理论最大 FLOPS,但它尚未利用 Hopper GPU 上的新功能来最大限度地提高性能。接下来文章描述了一些新的 Hopper 特定功能,以及它们为何如此重要。

首先是 WGMMA(Warpgroup Matrix Multiply-Accumulate),该功能利用了 Hopper 架构上新的张量内核,比 Ampere 架构具有更高的吞吐量。

图片

然后是 TMA(Tensor Memory Accelerator),这是一个特殊的硬件单元,可以加速全局内存和共享内存之间的数据传输,用于处理所有索引计算和边界外预测。这样一来寄存器就释放了,寄存器是增加 tile 大小和效率的宝贵资源。

图片

低精度 FP8,让 Tensor Core 吞吐量翻了一倍。

图片

FlashAttention-3 充分利用了 Hopper 架构的所有这些新功能。

异步:GEMM 和 Softmax 重叠

注意力机制主要有两个操作,GEMM 和 softmax。为什么要将它们重叠?

问题在于在现代加速器上,非矩阵乘法(matmul)运算比矩阵乘法运算慢。特殊函数如指数运算(如 softmax 函数)的吞吐量甚至低于浮点乘加操作;这些运算是由多功能单元处理的,这是一个与浮点乘加或矩阵乘加不同的单元。

理想情况下,研究者希望矩阵乘法和 softmax 能够并行操作。当 Tensor Cores 忙于矩阵乘法时,多功能单元应当在计算指数运算! 

Inter-warpgroup 重叠

重叠 GEMM 和 softmax 最简单的方法是什么都不做,warp 调度程序会免费完成部分重叠。下图说明了 pingpong 调度,其中相同的颜色表示相同的迭代。

图片

Intra-warpgroup 重叠

即使在一个 warpgroup 中,研究者也可以在运行该 warpgroup 的 GEMM 时运行 softmax 的某些部分。如图所示,相同的颜色表示相同的迭代。

图片

这种 pipeline 流程可以将 FP16 注意力前向传播的吞吐量从大约 620 TFLOPS 提高到 640-660 TFLOPS,但代价是更高的寄存器压力,因而需要更多的寄存器来同时保存 GEMM 的累加器以及 Softmax 的输入 / 输出。

低精度:使用非相干处理减少量化误差

激活 LLM 可能存在一些极端值,导致量化困难,从而产生较大的量化误差。本文采用非相干处理(incoherent processing),该技术通过将查询和键与一个随机正交矩阵相乘来「分散(spread out)」极端值,从而减少量化误差。特别地,该研究使用了 Hadamard 变换,它可以在每个注意力头中以 O (d log d) 的时间复杂度完成,而不是 O (d^2),其中 d 是头部维度。

研究者发现非相干处理可以将量化误差减少很多,具体的数值误差比较见下表。

图片

实验

文中展示了 FlashAttention-3 的一些结果,并将其与 FlashAttention-2 以及 Triton 和 cuDNN 中的实现进行了比较(两者都已经使用了 Hopper GPU 的新硬件功能)。

在 FP16 精度下,FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍。

图片

对于 FP8,FlashAttention-3 接近 1.2 PFLOPS。

图片

扩展阅读:

斯坦福提出新型Attention算法!提速2-4倍,BERT单节点训练最快 

比标准Attention提速5-9倍,大模型都在用的FlashAttention v2来了

参考链接:

https://tridao.me/blog/2024/flash3/

相关资讯

比标准Attention提速5-9倍,大模型都在用的FlashAttention v2来了

一年时间,斯坦福大学提出的新型 Attention 算法 ——FlashAttention 完成了进化。这次在算法、并行化和工作分区等方面都有了显著改进,对大模型的适用性也更强了。

新PyTorch API:几行代码实现不同注意力变体,兼具FlashAttention性能和PyTorch灵活性

用 FlexAttention 尝试一种新的注意力模式。理论上,注意力机制就是你所需要的一切。然而在实际操作中,我们还需要优化像 FlashAttention 这样的注意力机制的实现。尽管这些融合的注意力机制大大提高了性能,且支持长上下文,但这种效率的提升也伴随着灵活性的丧失。对于机器学习研究人员来说,这就像是一种「软件彩票」—— 如果你的注意力变体不适合现有的优化内核,你将面临运行缓慢和 CUDA 内存不足的困境。 一些注意力变体包括因果注意力、相对位置嵌入、Alibi、滑动窗口注意力、PrefixLM、文档掩码