FlashAttention的核心思想是让注意力计算变得IO感知,其性能提升并非源于算法改变,而是源于对GPU实际工作方式的深刻理解和利用。本文将从零开始,探讨如何使用自定义的Triton内核实现FlashAttention-2的前向传播,涵盖关键的分块策略、流式处理、在线softmax算法以及性能优化。
标准注意力为何是内存瓶颈
标准缩放点积注意力(Standard Scaled Dot-Product Attention)的计算流程为:
S = QKᵀ ∈ ℝᴺˣᴺ,P = softmax(S) ∈ ℝᴺˣᴺ,O = PV ∈ ℝᴺˣᵈ
其瓶颈不在于浮点运算量(FLOPs),而在于内存带宽。通常的计算中,需要先将完整的 N×N 注意力矩阵 S 写入高速带宽内存(HBM),再读回计算softmax并存储P,最后再次读取P与V相乘。矩阵中的每个元素被反复访问多次,每次都需经过HBM。
当序列长度(N)达到16K时,注意力矩阵包含约2.56亿个元素。在A100 GPU上,从HBM读取数据比从片上静态随机存取存储器(SRAM)读取大约慢15倍。反复在HBM与计算单元之间搬运如此庞大的中间结果,使得标准注意力成为典型的内存受限操作,计算单元大量时间处于等待数据的状态。
FlashAttention的核心:IO感知计算
FlashAttention的解决方案是重新组织计算调度,避免在HBM中物化(Materialize)巨大的N×N注意力矩阵。其核心是利用GPU的内存层次结构:片上SRAM比HBM快几个数量级。以NVIDIA A100为例,其HBM带宽约为1.5-2.0 TB/s,而SRAM的估计带宽高达约19 TB/s。
FlashAttention遵循一条黄金法则:将高频访问的数据尽可能地保留在内存层次的上层(如寄存器和共享内存),避免不必要的HBM往返。
具体实现方式是分块处理。不再一次性计算完整的注意力矩阵,而是将查询(Q)序列分块,对于每一块Q_block,流式地迭代读取键(K)和值(V)序列的对应块。在迭代过程中,使用在线softmax算法增量地计算部分结果,并逐步构建最终输出矩阵O。这样,注意力计算的内存复杂度就从O(N²)降低到了O(N)。
技术难点:在线Softmax
标准softmax计算需要获取一个序列(这里是注意力分数的一行)的所有元素,以求得最大值和求和项,再进行归一化。其公式如下:
softmax(s_i) = e^(s_i - max_j s_j) / Σ_j e^(s_j - max_j s_j)
在FlashAttention的分块流式处理中,内核一次只能看到部分注意力分数块,无法获得完整的分数集。因此,必须采用在线softmax算法。该算法为每个查询行维护三个状态变量,在迭代K/V块的过程中持续更新:
- 运行最大值(m_i):用于保证数值稳定性。
- 运行归一化项(l_i):累积的softmax分母部分和。
- 运行输出累加器(O_i):累积的未归一化注意力输出。
每处理一个新的K/V块,都会计算该块的局部注意力分数,并更新上述状态。关键的挑战在于,当处理新块发现更大的运行最大值(m_new > m_old)时,之前基于旧最大值累积的l_i和O_i需要进行尺度校正。校正因子为:
α = e^(m_old - m_new)
然后进行更新:
- 新的分母项:
l_new = l_old * α + Σ_j e^(score_ij_new - m_new)
- 新的输出累加器:
O_new = O_old * α + Σ_j e^(score_ij_new - m_new) · V_j
最终,在所有K/V块处理完毕后,对输出累加器进行归一化,得到最终的注意力输出:
O_i = O_i_raw / l_i
这一套更新机制确保了流式处理的结果与一次性对整个序列进行标准softmax计算的结果在数学上完全等价。
代码实现分解
整体的实现架构如下:
for each (batch, head):
for each Q_block:
initialize m_i, l_i, O_block
for each K/V block:
compute partial scores
update online softmax state
accumulate output
write O_block to memory
所有逻辑融合在一个内核中,中间状态全程驻留在SRAM等快速内存中。
Host端包装器与内核启动
Python端的包装器负责准备输入、定义执行网格并启动Triton内核。
class TritonFlashAttention(torch.autograd.Function):
@staticmethod
def flash_attention(Q, K, V, causal):
assert Q.is_cuda
assert K.is_cuda
assert V.is_cuda
B, H, Lq, D = Q.shape
B, H, Lk, D = K.shape
B, H, Lk, D = V.shape
O = torch.empty_like(Q)
BLOCK_SIZE_Q = 128
BLOCK_SIZE_KV = 32
stage = 3 if causal else 1
grid = lambda x: (triton.cdiv(Lq, x["BLOCK_SIZE_Q"]), B * H, 1)
M = torch.empty((B, H, Lq), device=Q.device, dtype=torch.float32)
scaling_factor = 1 / math.sqrt(D)
fwd_flash_attn_kernel[grid](Q, K, V, O, M, scaling_factor,
Q.stride(0), Q.stride(1), Q.stride(2), Q.stride(3),
K.stride(0), K.stride(1), K.stride(2), K.stride(3),
V.stride(0), V.stride(1), V.stride(2), V.stride(3),
O.stride(0), O.stride(1), O.stride(2), O.stride(3),
B, NUM_HEADS=H, SEQ_LEN=Lq, HEAD_DIM=D, STAGE=stage,)
return O
执行网格与并行策略
包装器中定义的2D执行网格决定了GPU上的工作分配:
- 第0维 (
program_id(0)):标识处理哪个查询序列块。总数为 ceil(序列长度 / BLOCK_SIZE_Q)。
- 第1维 (
program_id(1)):标识处理哪个(batch, head)对。总数为 batch_size * 头数。
这种设计实现了序列维度和batch/head维度的并行,且各程序实例间无需同步。
前向传播内核框架
我们将前向传播逻辑分为两个内核:fwd_flash_attn_kernel 负责协调(加载查询块、处理因果逻辑、写回输出),_attn_fwd_inner 实现核心的流式注意力计算。
fwd_flash_attn_kernel 的主要步骤:
- 网格映射与指针计算:根据
program_id计算出当前实例负责的batch、head和查询块范围。通过指针算术,构建指向输入输出张量特定区域的指针。
index_batch = index_batch_head // NUM_HEADS
index_head = index_batch_head % NUM_HEADS
qkv_offset = index_batch * qb_stride + index_head * qh_stride
- 初始化状态:为当前查询块初始化在线softmax的状态变量。
m_i = tl.zeros((BLOCK_SIZE_Q,), dtype= tl.float32) - float("inf")
l_i = tl.zeros((BLOCK_SIZE_Q,), dtype=tl.float32) + 1.0
O_block = tl.zeros((BLOCK_SIZE_Q, HEAD_DIM), dtype=tl.float32)
- 分派计算:根据是否因果注意力(
STAGE参数),调用_attn_fwd_inner内核处理相应的K/V块范围。
- 最终归一化与写回:在所有K/V块处理完毕后,对累积的输出进行归一化,并写回HBM。
O_block = O_block / l_i[:, None]
tl.store(O_block_ptr, O_block.to(tl.float16))
流式注意力核心内核
_attn_fwd_inner 内核实现了FlashAttention-2算法的核心循环。
- 确定K/V块范围:根据
STAGE参数决定当前需要关注哪些位置的K/V,以支持因果注意力。
STAGE=1:处理对角线左侧的块(仅用于因果注意力)。
STAGE=2:处理对角线块自身(仅用于因果注意力)。
STAGE=3:处理所有块(用于非因果注意力,或因果注意力中需要mask的部分)。
- 流式循环:在确定的
[lo, hi)范围内,以BLOCK_SIZE_KV为步长循环。
for start_kv in range(lo, hi, BLOCK_SIZE_KV):
- 加载K/V块:为当前循环的K/V块构建指针,并从HBM加载到SRAM。
K_block = tl.load(K_block_ptr, mask=mask_k, other=0.0)
V_block = tl.load(V_block_ptr, mask=mask_v, other=0.0)
- 计算与更新:
- 计算分块点积注意力分数:
QK_block = tl.dot(Q_block, K_block) * scale
- 应用因果mask(如需)。
- 更新运行最大值
m_ij。
- 计算当前块的softmax概率
P_block = exp(QK_block - m_ij)。
- 计算当前块的分母项
l_ij = sum(P_block, axis=1)。
- 计算缩放因子
alpha = exp(m_i - m_ij)。
- 更新运行分母
l_i = l_i * alpha + l_ij。
- 更新输出累加器
O_block = O_block * alpha[:, None] + tl.dot(P_block, V_block)。
- 更新运行最大值
m_i = m_ij。
性能验证
性能基准测试对比了三种实现:标准注意力(Standard)、自定义Triton FlashAttention内核(Triton Flash)以及PyTorch 2.2官方实现的FlashAttention(PyTorch Official),指标为在不同序列长度下达到的TFLOPS/sec。
| 性能数据对比: |
序列长度 |
Standard (TFLOPS/s) |
Triton Flash (TFLOPS/s) |
PyTorch Official (TFLOPS/s) |
| 512 |
3.35 |
6.77 |
22.36 |
| 1024 |
3.60 |
27.13 |
65.88 |
| 2048 |
3.72 |
95.07 |
83.68 |
| 4096 |
3.79 |
140.38 |
132.43 |
| 8192 |
3.73 |
174.24 |
166.54 |
| 16384 |
3.54 |
190.09 |
177.67 |
从数据中可以看出两个关键结论:
- 标准注意力是内存瓶颈:无论序列长度如何增加,其性能始终被限制在约3-4 TFLOPS/s。计算单元大部分时间在等待HBM的数据搬运。
- FlashAttention实现了计算受限:Triton Flash和PyTorch Official实现的性能随序列长度增长而显著提升。在长序列(如16K)下,性能达到约180-190 TFLOPS/s,接近GPU的理论计算峰值。这证明了通过避免物化大矩阵、将数据保留在SRAM的策略是成功的。
- 自定义Triton内核具备竞争力:在长序列场景下,手写的Triton内核性能与PyTorch官方优化实现持平甚至略有优势,显示了Triton在编写高性能定制内核方面的强大能力。
另一个针对GPT-2的注意力性能分解图显示,相较于标准PyTorch实现将时间消耗在多个独立操作(矩阵乘、Dropout、Softmax、Mask等)上,FlashAttention通过内核融合,将绝大部分计算合并到一个高效的内核中执行,显著减少了整体耗时。
总结与展望
通过实现自定义Triton内核的FlashAttention-2,我们验证了几个核心观点:
- 在现代GPU上,内核性能往往由内存访问模式而非纯粹的浮点运算能力决定。
- 内核融合和片上内存驻留是提升性能的关键策略,其效果可能超过纯粹的数学优化。
- 在线softmax等数值算法是实现IO感知计算的关键组件。
- Triton语言提供了足够的底层控制能力,同时保持了较好的可读性,使得开发者能够编写出与厂商优化库竞争力相当的高性能内核。
本文仅实现了前向传播。一个完整的训练级FlashAttention还需要实现高效的反向传播、支持Dropout以及更复杂的masking机制,这些将是后续工作的方向。对于从事人工智能模型底层优化的开发者而言,深入理解GPU内存层次和计算调度是提升性能的必经之路。
通过开源实战项目亲手实现核心算法,是加深对Transformer架构及硬件协同理解的有效方式。本文的完整代码已公开,可供进一步研究和改进。
参考资料
[1] 从零开始用自定义 Triton 内核编写 FlashAttention-2, 微信公众号:mp.weixin.qq.com/s/--GFiKSga3DE7G4WxG3kyw
版权声明:本文由 云栈社区 整理发布,版权归原作者所有。