找回密码
立即注册
搜索
热搜: Java Python Linux Go
发回帖 发新帖

3280

积分

0

好友

438

主题
发表于 4 小时前 | 查看: 7| 回复: 0

故事是这样的。2022 年夏天,斯坦福的一组研究者发了篇论文,标题叫《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》。

这名字听着挺无聊的,对吧?一堆缩写、一个技术名词,跟每天几百篇论文里随便哪篇长得都一样。但我读完却觉得,这件事挺有意思的。因为它解决了一个几乎所有做 Transformer 的人都踩过坑的问题:大模型推理慢,到底慢在哪里。

很多人第一反应是“算不动”,觉得算力不够,换更大的 GPU 就行。但这篇文章的作者 Tri Dao 一帮人说了句不一样的话:不是算不动,是数据搬不动。他们管这个叫 IO-Awareness ——在写注意力算法的时候,不要只盯着 FLOP 数看,要把 GPU 显存层级之间的读写次数也算进去。

听起来像废话对吧,但就是这个“废话”,让 Transformer 的训练速度提升了 3 倍,内存消耗降低了 20 倍,还让 Transformer 第一次在 64K 长度的序列上跑出了好结果。64K。之前所有方法,要么跑不了,要么跑出来跟随机猜差不多。

这篇文章,我想把这个论文的核心内容,用人话讲清楚。


1. 背景:Transformer 的瓶颈到底在哪里

问题:注意力机制的平方复杂度

先看最基本的注意力公式:

O = softmax(QK^T / sqrt(d)) V

三个矩阵相乘,再加一个 softmax。看起来很简单对吧。但问题在于,Q 和 K 的矩阵乘法会生成一个 N×N 的注意力矩阵,其中 N 是序列长度。这意味着时间和空间复杂度都是 O(N²)。序列长度翻倍,计算量翻四倍。这个平方复杂度是 Transformer 的天然缺陷,从 2017 年那篇“Attention Is All You Need”出来就没变过。

现有方案的局限

过去几年,一堆人想解决这个问题。方案主要分两类:

第一类,近似注意力。 用稀疏化、低秩分解、核函数近似等等手段,把注意力矩阵从 N×N 压缩到接近 N×1。这类方法理论上把计算复杂度降到了线性或近线性,但实际跑起来,墙钟时间(wall‑clock time)并没有明显加速。为什么?因为很多方案只关注减少 FLOP,忽略了内存访问的开销

第二类,稀疏注意力。 让每个 token 只关注有限的其他 token,直接剪掉大量注意力连接。这种方法确实减少了计算量,但稀疏模式本身也有内存访问的 overhead,而且效果往往不如 dense attention。

作者的判断

这篇文章的作者认为,现有方案没效果的根本原因不是算法不行,而是没有考虑 GPU 内存层级的 IO 特性。现代 GPU 的内存层级是这样的:

GPU 内存层级:SRAM、HBM、DRAM 带宽与容量对比

  • DRAM(系统内存):容量最大(几十 GB 到几百 GB),速度最慢(大约 12.8 GB/s)
  • HBM(高带宽内存):GPU 显存,容量中等(40–80 GB),速度中等(大约 1.5–2.0 TB/s)
  • SRAM(片上缓存):容量最小(A100 每个 SM 大约 192 KB),速度最快(大约 19 TB/s)

从 HBM 到 SRAM,带宽差了一个数量级。但现代 GPU 的计算速度已经超过内存速度了,操作越来越被内存访问(IO)而不是计算本身瓶颈住。所以,关键问题不是 FLOP 多不多,而是有多少数据在 HBM 和 SRAM 之间来回搬。这就是“IO‑Awareness”的核心思想。


2. 标准注意力实现的问题

标准算法:三步走

标准的注意力实现通常是三步:

第一步:计算 QK^T
把 Q 和 K 从 HBM 读到 SRAM,在芯片上算 QK^T,结果写回 HBM。这一步产生了一个 N×N 的注意力分数矩阵 S。

第二步:Softmax
把 S 从 HBM 读出来,逐行做 softmax,得到 P 矩阵。P 再写回 HBM。

第三步:PV
把 P 和 V 从 HBM 读出来,在芯片上算 PV,结果写回 HBM。

这三步看起来很自然,但每步都在做一件事:把中间结果从 HBM 写出去,再从 HBM 读进来

HBM 访问次数的分析

前向传播:  

  • 第一步:读 Q、K,写 S → O(N²d + N²) 次 HBM 访问  
  • 第二步:读 S,写 P → O(N²) 次 HBM 访问  
  • 第三步:读 P、V,写 O → O(N²d + N²) 次 HBM 访问  

总共 O(N²d + N²) 次,序列长度的平方级

反向传播:
反向传播需要用到前向计算的 S 和 P 矩阵,同样要读 S、P 写回 dQ、dK、dV,也大约 O(N²d + N²) 次 HBM 访问。

核心矛盾

整个前向 + 反向传播的 HBM 访问总量大约是 O(N²d) 次。但输入 Q、K、V 本身的总大小只有 O(Nd),输出 O 也只有 O(Nd)。数据量是 O(Nd) 的东西,为什么要做 O(N²d) 的 HBM 访问? 多出来的 O(N²) 次访问,全是用在那个大得离谱的 N×N 注意力矩阵上。这个矩阵太大,放不下 SRAM,只能在 HBM 和 SRAM 之间反复搬。这就是标准实现的根本问题。


3. FlashAttention 算法:核心思路

两个关键技术

FlashAttention 的思路很简单:用两个经典技术,避免把 N×N 注意力矩阵写到 HBM 上。 这两个技术是:

  1. Tiling(分块计算)  
  2. Recomputation(重计算)

Tiling:分块做 Softmax

标准 softmax 需要对整行做归一化,看起来必须把整行读进来才能算。但数学上有个技巧:softmax 可以分块计算。具体来说,如果我把一个向量 x 拆成两段 x¹ 和 x²,那么整个向量 x 的 softmax 结果,可以用 x¹ 和 x² 各自的 softmax 统计量(最大值 m 和归一化因子 ℓ)来逐步合并。

FlashAttention 分块计算与数据流示意图

公式大概是:

Softmax 分块计算统计量合并公式

这样,每次处理一个块,只需要记录两个小值(m 和 ℓ),就能把结果正确合并起来。

所以 FlashAttention 的做法是:把 K、V 分成多个块;每次只把一个块加载到 SRAM;对 Q 的每个块,和 K 的这个块算 QK^T;在 SRAM 里算 softmax,更新 m 和 ℓ;逐步累积输出,最后写回 HBM。关键:整个过程中,N×N 注意力矩阵从来没有完整地出现在 HBM 上。

Recomputation:反向传播时不再读矩阵

反向传播怎么办?反向传播需要用到前向的 S 和 P 矩阵。标准做法是把前向的 S 和 P 存在 HBM 上,反向时直接读。但 FlashAttention 说:不存了,反向时重新算。 它只存前向的输出 Osoftmax 的统计量 m、ℓ,这两个东西很小,O(Nd) 的大小。反向传播时,从 HBM 读 Q、K、V,重新在 SRAM 里算 S 和 P,然后再算梯度。虽然多算了一些 FLOP,但因为避免了从 HBM 读 N×N 矩阵的开销,实际运行时间反而更快。这在学术上叫选择性梯度检查点(selective gradient checkpointing)


4. IO 复杂度分析

理论保证

文章给出了严格分析。标准注意力的 HBM 访问次数是 Θ(N²d + N²)。FlashAttention 的 HBM 访问次数是 Θ(N²d / M),其中 M 是 SRAM 的大小。

为什么是 N²d / M?因为 SRAM 能放下大小为 Θ(M) 的 K、V 块,每次能处理 Θ(M/d) 个 K 行。对于 N 行的 Q,需要 N / (M/d) = Nd/M 次扫描。每次扫描加载 O(Nd) 数据,所以总共 O(N²d / M) 次 HBM 访问。

实际差距

拿 A100 来算:d = 64(head 维度),M ≈ 100KB(每个 SM 的 SRAM)。标准注意力:O(N² × 64) 次 HBM 访问;FlashAttention:O(N² × 64 / 100000) ≈ O(N² × 0.00064) 次 HBM 访问。HBM 访问量减少了大约 100 倍。 虽然实际不可能完全达到理论极限,但文章实验显示前向传播减少了约 8 倍,反向传播减少了约 7 倍,合计约 9 倍的 HBM 访问量降低。

下界证明

文章还证明了一个有意思的结论:对于任何精确注意力算法,在所有可能的 SRAM 大小范围内,不可能渐近地优于 O(N²d / M) 的 HBM 访问下界。 换句话说,FlashAttention 在这个意义上是最优的。


5. Block‑Sparse FlashAttention

扩展:稀疏注意力

FlashAttention 不只是精确注意力,还可以扩展到稀疏注意力。思路很简单:如果注意力矩阵是块稀疏的(比如某些块全是零),那么在 Tiling 循环中直接跳过这些块就行。算法跟 FlashAttention 几乎一样,只是加了一个 if 判断:如果当前块 M_ij = 0,跳过计算。文章证明了 Block‑Sparse FlashAttention 的 HBM 访问次数是 Θ(N²d · s / M),其中 s 是非零块的比例。s 越小,加速越多。实验显示,在 LRA benchmark 上,Block‑Sparse FlashAttention 相对于标准 FlashAttention 有 2.8 倍的加速,同时精度相当。

FlashAttention 块大小影响与稀疏加速实验


6. 实验结果

训练速度

  • BERT‑large:在 MLPerf 1.1 上,FlashAttention 比 NVIDIA 记录快了 15%(从 20.0 分钟降到 17.4 分钟)。
  • GPT‑2 small:比 HuggingFace 实现快 3.5 倍(从 9.5 天降到 2.7 天),比 Megatron‑LM 快 2.0 倍(从 4.7 天降到 2.7 天)。
  • GPT‑2 medium:比 HuggingFace 快 3.0 倍(从 21.0 天降到 6.9 天),比 Megatron‑LM 快 1.7 倍(从 11.5 天降到 6.9 天)。
  • Long‑Range Arena:平均加速 2.4 倍

模型质量提升

FlashAttention 不只是更快,还能训练出更好的模型。

  • GPT‑2 长上下文:用 FlashAttention 训练 GPT‑2 small,上下文长度从 1K 提升到 4K,仍然比 Megatron 的 1K 版本快 30%,且 perplexity 低了 0.7
  • 长文档分类:在 MIMIC‑III(医疗文本分类)和 ECtHR(法律判决分类)上,增加序列长度带来显著提升:MIMIC‑III 16K 序列比 512 序列提升 4.3 分;ECtHR 8K 序列比 512 序列提升 8.5 分
  • PathFinder 挑战:Path‑X(16K 序列)FlashAttention 的 Transformer 达到 61.4% 准确率,是第一个在这个任务上超过随机猜测的 Transformer;Path‑256(64K 序列)Block‑Sparse FlashAttention 达到 63.1% 准确率。

基准测试

  • 序列长度 128–512:FlashAttention 比 PyTorch 标准实现快 2–3 倍
  • 序列长度 1024–2048:FlashAttention 比所有近似注意力方法都快。
  • 内存占用:FlashAttention 比 PyTorch 标准实现低 20 倍,比 Linformer 低 2 倍

不同硬件上的表现

  • A100:2–4 倍加速
  • RTX 3090:2.5–4.5 倍加速(HBM 带宽更低,加速效果更明显)
  • T4:加速较少(SRAM 更小,块大小需要更小)

7. 总结

这篇文章的核心贡献,可以用一句话概括:写注意力算法的时候,要把 GPU 内存层级的读写开销也算进去。 这个“IO‑Awareness”的思想听起来简单,但在深度学习这个领域里,很少有人真正认真对待过。大家习惯了看 FLOP 数,看理论复杂度,看 benchmark 上的 accuracy。但 FLOP 不等于 wall‑clock time,不等于内存使用量,不等于实际训练出来的模型质量。FlashAttention 用两个经典技术——Tiling 和 Recomputation——把注意力机制的 HBM 访问量从 O(N²) 降到了 O(N² / M),在保持精确计算的同时实现了 3 倍的加速和 20 倍的内存节省。而且它不只是更快,还让 Transformer 第一次真正具备了建模 64K 长度上下文的能力。这就是 IO‑Awareness 的力量。

如果你对 Transformer 模型的训练优化有更多兴趣,欢迎访问 云栈社区 与开发者交流实战经验。




上一篇:项目管理必知的9条经典法则:从木桶到热炉的实务解读
下一篇:深度学习硬件加速器完全指南:从GPU到TPU/FPGA的架构演进与LLM推理瓶颈
您需要登录后才可以回帖 登录 | 立即注册

手机版|小黑屋|网站地图|云栈社区 ( 苏ICP备2022046150号-2 )

GMT+8, 2026-5-2 22:08 , Processed in 0.787066 second(s), 41 queries , Gzip On.

Powered by Discuz! X3.5

© 2025-2026 云栈社区.

快速回复 返回顶部 返回列表