找回密码
立即注册
搜索
热搜: Java Python Linux Go
发回帖 发新帖

3925

积分

0

好友

539

主题
发表于 前天 08:55 | 查看: 16| 回复: 0

随着大语言模型(LLM)处理超长上下文(例如百万级token)的需求日益增长,一个严峻的性能瓶颈随之浮现——KV缓存的数据搬移

在基于Transformer的解码阶段,每个生成步骤都需要从高带宽内存(HBM)中重新加载所有历史token的Key-Value缓存。随着上下文长度线性增长,KV缓存的大小也线性增加,导致数据搬移而非计算成为主要延迟来源,GPU计算单元的利用率因此变得极低。

为了缓解这一问题,业界提出了多种高效注意力机制。其中,多头潜在注意力(Multi-Head Latent Attention, MLA) 通过将KV缓存压缩到一个低维的潜在头(latent head)中,显著减少了缓存大小,并被DeepSeek-V2/V3等模型采用。

然而,MLA在分布式推理中遇到了新的挑战:虽然其上投影权重可以按头切分,但其核心的KV缓存是一个单一的、不可分割的潜在向量($c_t^\\text{KV}$。这导致在使用张量并行时,每个设备仍需加载完整的 $c_t^\\text{KV}$,KV缓存加载量不会随TP设备数增加而减少,严重制约了系统扩展性。

MULTI-HEAD LOW-RANK ATTENTION 论文标题与作者信息

针对这一局限,研究者提出了多头低秩注意力(Multi-Head Low-Rank Attention, MLRA),一种原生支持张量并行的新型注意力机制。MLRA不仅保持了MLA的高性能,还通过将潜在头分解为多个可并行的分支,大幅降低了分布式解码时的KV加载开销。实验表明,MLRA-4在2.9B规模模型上取得了优于MLA的性能,并在长上下文解码中实现高达2.8倍的加速。其代码和模型已在GitHubHugging Face开源,相关论文可查看arXiv

一、背景知识:从MHA到MLA

1.1 标准多头注意力(MHA)与KV缓存

标准的Transformer使用多头注意力(MHA),每个头独立计算查询(Q)、键(K)、值(V)。假设输入序列长度为 $n$,模型隐藏维度为 $d$,头数为 $h$,每个头的维度为 $d_h$,通常 $d = h \\times d_h$

MHA在解码阶段会缓存已生成token的K和V(即KV缓存)。每个token的KV缓存大小为 $2hd_h$,即 $2d$当上下文长度 $n$ 达到百万级别时,KV缓存的大小可达数十GB,成为内存带宽的瓶颈。

1.2 分组查询注意力(GQA)与多查询注意力(MQA)

为了减少KV缓存,GQA将查询头分组,组内共享同一个KV头;MQA则极端地使用单个KV头共享给所有查询头。设分组数为 $g$,则缓存大小降为 $2gd_h$这虽然降低了缓存大小,但可能影响模型质量。

1.3 多头潜在注意力(MLA)

MLA的核心思想是:将KV进一步压缩到一个更小的潜在空间,存储时只保留压缩后的潜在向量,在计算注意力时再通过上投影恢复。其高效解码三步算法简述如下:

  • Step 1:将查询与上投影矩阵结合,得到 $\\tilde{q}_t$
  • Step 2:将 $\\tilde{q}_t$ 和缓存的潜在向量 $c_t^\\text{KV}$ 视为共享的KV,执行MQA风格的注意力,得到中间输出。
  • Step 3:用值上投影矩阵将中间输出映射回最终输出。

1.4 MLA的张量并行瓶颈

尽管MLA在单卡上高效,但在张量并行(TP)推理时,其单一的潜在头 $c_t^\\text{KV}$ 无法被切分。因此,无论TP并行度如何,每个设备都必须完整加载整个 $c_t^\\text{KV}$,导致大量冗余的数据搬移。

不同注意力机制的参数与KV缓存加载量对比表
表1:不同注意力机制的参数与KV缓存加载量对比(以Qwen3-32B架构为基础)。MLRA-4在4路TP下每设备加载量显著降低。

二、块分解的洞察:MLA的另一种视角

为了设计可切分的潜在注意力,论文作者首先对MLA的KV上投影进行块分解分析。

将上投影矩阵 $W^K$ 按行切分为4个块(因为 $d_k^R = d_h/4$),每个块大小为 $d_h \\times d_k^R$。分析发现,每个头的NoPE键可以表示为四个块投影之和。这一观察启发了MLRA的核心设计:能否将求和操作从KV计算挪到注意力输出之后,从而让每个分支独立?

三、MLRA的核心设计:将求和移到注意力之外

基于上述洞察,MLRA的核心创新在于将块求和从KV构造阶段移到注意力输出阶段。具体来说,对于每个块,独立地计算该块的注意力输出(使用该块的KV投影),然后将所有块的输出相加。 这样做的好处是:每个块的潜在状态可以独立地分配到不同设备上,实现真正的切分。

3.1 MLRA-4

对于4路切分,MLRA-4的公式为:

$$o_t = \\frac{1}{\\sqrt{4}} \\sum_{b=1}^{4} \\text{Attention}(\\tilde{q}_t, [k_t^R; c_t^{b,\\text{KV}}W_b^K], c_t^{b,\\text{KV}}W_b^V)$$

其中 $\\tilde{q}_t$ 来自查询潜在向量的上投影。每个分支的RoPE键 $k_t^R$ 是共享的。注意,每个分支的Softmax是独立计算的,这意味着不同分支的注意力权重可能不同,从而增加了模型的表达能力。

3.2 MLRA-2

类似地,可以构造2路版本(MLRA-2),它将头分为两组,每组使用两个块分支。其架构直观展示了分支处理与求和的过程。

MLRA-2架构图
图:MLRA-2架构,两个分支分别处理部分潜在块,输出求和。

MLRA-4架构图
图:MLRA-4架构,四个分支并行计算后求和。

这种设计使得MLRA原生支持张量并行。对于4路TP,因为 $c_t^\\text{KV}$ 被切分到4个设备,每个设备负责一个块。每个设备实际加载量为 $d_k^R + d_h/4$,相比MLA的 $d_k^R + d_h$ 大幅降低。

3.3 方差分析与缩放:保证训练稳定性

论文作者发现,MLA及类似结构中,不同路径产生的张量(如NoPE键和RoPE键)可能存在方差失配问题,影响训练稳定性。

通过理论推导发现,由于潜空间维度 $d_k^R$ 通常远小于原始隐藏维度 $d_h$,导致NoPE键的方差可能远小于RoPE键。因此,需要对潜在状态进行缩放,使得后续组件的方差匹配。论文引入了缩放因子,对于MLRA,由于多个分支求和,还需要对最终输出进行缩放。

缩放对训练损失影响的消融实验图
图:缩放策略能有效降低训练损失,提升模型收敛效果。

四、实验验证:性能与效率的双重提升

4.1 主要结果

作者在2.9B参数规模下预训练了多种注意力机制。MLRA-4在语言建模困惑度和下游常识推理任务准确率上均达到最优或次优水平,证明了其强大的表示能力。

各模型在多个数据集上的困惑度对比表
表3:MLRA-4在所有数据集上均取得最优或次优困惑度。

各模型在下游任务上的平均准确率对比表
表4:MLRA-4以58.84%的平均准确率位居第一。

4.2 解码效率

实验表明,MLRA-4在长上下文解码延迟和吞吐量上均有显著优势。

解码延迟与序列长度的关系图
图5:MLRA-4解码延迟随序列长度增长最慢,相比MLA稳定快约2.8倍。

解码吞吐量与序列长度的关系图
图6:MLRA-4(TP=4/DP=2)在各种序列长度下吞吐量最高。

此外,MLRA-4的算术强度达到 $\\approx 2h$,远高于GQA等,表明其计算密集度更高,更接近计算瓶颈而非内存瓶颈,能更好地利用GPU算力。

不同注意力机制的算术强度对比表

4.3 消融实验

多项消融实验验证了MLRA设计选择的合理性:

  • 初始化:将输出投影初始化为零比高斯初始化效果更好。
  • 双倍头数:单纯增加头数而不改变KV缓存大小无益。
  • 门控机制:引入门控可进一步提升性能,MLRA-4结合门控后平均困惑度降至13.621。

初始化策略消融实验图
图:零初始化在多数情况下能带来更低的训练损失。

五、结论

Multi-Head Low-Rank Attention(MLRA)通过将潜在头分解为多个可并行的分支,巧妙地绕过了MLA在张量并行中的冗余加载问题。MLRA-4在2.9B规模下达到SOTA性能,并在长上下文解码中实现2.8倍加速。这一设计融合了低秩压缩与并行设计的优势,为未来超长上下文模型的分布式推理提供了新的高效解决方案,是Transformer架构演进中的重要探索。对这类前沿开源实战项目感兴趣的朋友,可以持续关注云栈社区的技术动态分享。




上一篇:TARA方法如何让多模态大模型掌握层级视觉识别能力
下一篇:Helios视频生成模型技术解析:14B参数如何实现单卡19.5 FPS实时推理
您需要登录后才可以回帖 登录 | 立即注册

手机版|小黑屋|网站地图|云栈社区 ( 苏ICP备2022046150号-2 )

GMT+8, 2026-3-10 10:18 , Processed in 0.580765 second(s), 40 queries , Gzip On.

Powered by Discuz! X3.5

© 2025-2026 云栈社区.

快速回复 返回顶部 返回列表