Cross-Attention 是连接文本描述与视觉内容生成的核心技术。它允许图像的每个局部区域“查询”并融合相关的文本语义,从而实现从“一只猫在雨中行走”这样的提示词到具体图像或视频的精确生成。本文将以xDiT框架为例,深入解析其原理、实现与优化策略。
一、为何需要 Cross-Attention?
1.1 扩散模型的条件生成难题
扩散模型的生成是一个迭代去噪过程:纯噪声 → 去噪 → ... → 清晰图像。关键问题在于:如何让模型知道我们想生成什么?
传统方法存在局限:
- 简单拼接:文本和图像特征维度不匹配,无法精确对齐语义。
- 条件嵌入相加:文本信息容易被平均稀释,无法建立细粒度对应关系。
1.2 Cross-Attention 的突破
其核心思想是:让图像的每一个 patch 都能“查询”文本中相关的信息。
举个例子,对于文本“一只猫在雨中行走”:
- 图像 Patch 1(左上角):查询后可能更关注“雨中”(权重0.7),从而生成下雨的背景。
- 图像 Patch 2(中心):查询后可能更关注“猫”(权重0.9),从而生成猫的主体。
这种机制的优势在于:
- 细粒度对齐:每个图像区块可与不同文本部分建立联系。
- 动态权重:注意力权重可学习,能自适应调整。
- 信息保留:文本语义不会被简单稀释。
二、Cross-Attention 理论基础
2.1 Self-Attention 与 Cross-Attention 的区别
- Self-Attention(自注意力):在单一模态内部进行信息交互。Query、Key、Value 均来自同一输入(如图像特征本身)。
- Cross-Attention(交叉注意力):实现跨模态信息融合。Query 来自目标模态(如图像),而 Key 和 Value 来自条件模态(如文本)。这是实现文本条件控制的关键。
2.2 数学原理详解
以图像特征 $X$($N$ 个 token)和文本特征 $C$($M$ 个 token)为例:
-
计算 Q, K, V:
$$
Q = X W^Q \in \mathbb{R}^{N \times d_k}, \quad K = C W^K \in \mathbb{R}^{M \times d_k}, \quad V = C W^V \in \mathbb{R}^{M \times d_v}
$$
$W^Q, W^K, W^V$ 是可学习的投影矩阵。
-
计算注意力权重:
$$
\text{Score} = \frac{Q K^T}{\sqrt{d_k}} \in \mathbb{R}^{N \times M}
$$
除以 $\sqrt{d_k}$ 是为了防止点积结果过大导致梯度消失。
$$
\text{Attention} = \text{softmax}(\text{Score}) \in \mathbb{R}^{N \times M}
$$
经过 softmax 归一化后,每个图像 token 对所有文本 token 的注意力权重和为 1。
-
加权求和:
$$
\text{Output} = \text{Attention} \cdot V \in \mathbb{R}^{N \times d_v}
$$
最终,每个图像 token 都聚合了与其最相关的文本语义信息。
2.3 直观理解:图书馆检索
可以将其类比为图书馆检索:
- 你的问题(Query):“如何训练深度学习模型?”(图像patch的疑问)
- 书籍索引(Key):“深度学习基础”、“模型训练技巧”等(文本token的索引)
- 相关性匹配(QK^T):计算问题与各书籍索引的相关性。
- 归一化权重(Softmax):得到应参考各书籍的权重比例。
- 获取内容(Attention × V):按照权重综合各书籍(Value)的内容,得到最终答案。
在 Cross-Attention 中,图像的每个 patch 通过此过程,从文本描述中获取自己应该生成什么内容的指导。
三、xDiT 中的 Cross-Attention 实现
3.1 整体架构
在 xDiT 的文生视频流程中,Cross-Attention 嵌入在每一次去噪迭代的 Transformer Block 中:
- 文本编码:提示词通过 T5 等编码器转为特征。
- 噪声初始化:生成初始的噪声潜在表示(latent)。
- 迭代去噪(核心):
- Self-Attention:图像 latent 内部进行信息交互。
- Cross-Attention:latent(Query)查询文本特征(Key/Value),注入条件。
- FeedForward:进行非线性变换。
- VAE 解码:将去噪后的 latent 解码为最终视频。
3.2 代码流程与 QKV 生成
关键代码位于 wan_attention.py 的 _get_qkv_projections 函数中,它清晰地体现了 Cross-Attention 的本质:
def _get_qkv_projections(attn, hidden_states, encoder_hidden_states):
# Query 始终来自图像 latent
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
# Self-Attention 模式:Key/Value 也来自图像
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
else:
# Cross-Attention 模式:Key/Value 来自文本条件
key = attn.to_k(encoder_hidden_states) # 关键区别
value = attn.to_v(encoder_hidden_states) # 关键区别
return query, key, value
这里的 to_q, to_k, to_v 是三个独立的线性投影层,负责将输入映射到统一的维度,为后续计算做准备。想要深入理解此类模型中的张量操作与投影逻辑,可以进一步学习Python编程与深度学习框架的相关知识。
四、实例剖析:720p 视频生成的 Shape 变化
让我们以 WAN 2.1 模型生成 81 帧 720p 视频为例,跟踪 Cross-Attention 中张量的完整变化。
4.1 输入准备
- 文本特征:经过编码和投影后,形状为
[1, 512, 3072]。
- 视频 Latent:VAE 将
[3, 81, 720, 1280] 的视频压缩为 [1, 16, 81, 90, 160] 的潜在表示。经过 Patchify(块化)和嵌入投影后,转换为 [1, 291600, 3072] 的图像 token 序列(共 81×45×80 个 patch)。
4.2 Cross-Attention 的完整 Shape 变化轨迹
-
输入:
hidden_states (图像): [1, 291600, 3072]
encoder_hidden_states (文本): [1, 512, 3072]
-
生成 Q, K, V:
Q = to_q(hidden_states): [1, 291600, 3072]
K = to_k(encoder_hidden_states): [1, 512, 3072]
V = to_v(encoder_hidden_states): [1, 512, 3072]
-
重塑为多头格式(假设 40 个头,头维度 77):
Q: [1, 40, 291600, 77]
K: [1, 40, 512, 77]
V: [1, 40, 512, 77]
-
计算注意力分数与输出:
Scores = Q @ K.T / √77: [1, 40, 291600, 512]。每个图像 token 对应一个长度为 512 的向量,表示它与所有文本 token 的相关性。
Attention = softmax(Scores): [1, 40, 291600, 512]。权重归一化。
Output = Attention @ V: [1, 40, 291600, 77]。加权聚合文本信息。
- 拼接多头并投影输出:
[1, 291600, 3072]。形状与输入图像特征相同,但已融入文本语义。
4.3 可视化理解
假设视频中心的一个图像 token(对应猫的身体部分),其注意力权重可能分布如下:
- “猫”:0.65
- “雨中”:0.20
- “行走”:0.09
- 其他词:0.06
这意味着在生成这个 patch 时,模型主要融合了“猫”和“雨中”的语义,从而可能生成一个带有雨水效果的猫的身体部分。而另一个表示背景的 token,则可能更关注“雨中”和“行走”。这种细粒度的、动态的注意力分配,是生成高质量、符台提示内容的关键。
五、不同模型的 Cross-Attention 策略对比
除了 xDiT 使用的标准 Cross-Attention,其他先进生成模型采用了不同的策略:
| 模型 |
策略 |
原理 |
特点 |
| FLUX |
Joint Attention |
将图像 token 和文本 token 拼接成一个长序列,然后在整体上进行 Self-Attention。 |
实现图像与文本的双向交互,融合更强,但计算复杂度更高 (O((N+M)²))。 |
| CogVideoX |
双编码器 |
同时使用 CLIP 和 T5 编码文本,将两者的特征拼接后作为 Cross-Attention 的 Key/Value。 |
结合 CLIP 的视觉对齐能力和 T5 的深层语义理解。 |
| HunyuanVideo |
时序分离 Attention |
在 Transformer Block 中,分别进行时间维度的 Self-Attention、空间维度的 Self-Attention,最后再进行 Cross-Attention。 |
显式地分别建模时序一致性和空间细节,通常能获得更好的视频连贯性。 |
六、实战代码解析
6.1 简化的 Cross-Attention PyTorch 实现
以下是一个清晰、完整的 Cross-Attention 层实现,包含了多头处理、缩放点积注意力和 Mask 支持:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class CrossAttention(nn.Module):
def __init__(self, query_dim, cross_attention_dim, num_heads=8, head_dim=64, dropout=0.0):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.inner_dim = num_heads * head_dim
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=False)
self.to_k = nn.Linear(cross_attention_dim, self.inner_dim, bias=False)
self.to_v = nn.Linear(cross_attention_dim, self.inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(self.inner_dim, query_dim), nn.Dropout(dropout))
def forward(self, hidden_states, encoder_hidden_states, attention_mask=None):
batch_size = hidden_states.shape[0]
# 1. 投影得到 Q, K, V
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
# 2. 重塑为多头格式 [batch, heads, seq_len, head_dim]
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# 3. 计算缩放点积注意力
scale = 1.0 / math.sqrt(self.head_dim)
attn_scores = (q @ k.transpose(-2, -1)) * scale # [B, H, N_img, N_txt]
if attention_mask is not None:
# 将 mask 中为0的位置填充为负无穷,softmax后权重为0
attn_scores = attn_scores.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2) == 0, float('-inf'))
attn_probs = F.softmax(attn_scores, dim=-1)
# 4. 加权求和并输出
out = attn_probs @ v # [B, H, N_img, head_dim]
out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
return self.to_out(out)
6.2 使用示例
# 初始化
cross_attn = CrossAttention(query_dim=3072, cross_attention_dim=4096, num_heads=40, head_dim=77).cuda()
# 模拟输入:图像特征 (291600个token) 和文本特征 (512个token)
hidden_states = torch.randn(1, 291600, 3072).cuda()
encoder_hidden_states = torch.randn(1, 512, 4096).cuda()
# 前向传播
output = cross_attn(hidden_states, encoder_hidden_states)
print(output.shape) # torch.Size([1, 291600, 3072])
七、Cross-Attention 的优化技巧
为了处理文生视频中庞大的序列长度(如29万个图像token),xDiT等框架采用了多种优化技术:
-
并行化策略:
- Tensor Parallel (TP):将多头注意力中的“头”切分到不同GPU上计算。
- Ulysses Parallel (USP):将序列长度维度切分到不同GPU上计算。两者可组合使用,极大降低单卡内存消耗。
-
FlashAttention:通过 Kernel 融合和 Tiling 技术,避免计算过程中存储巨大的 N × M 注意力矩阵,将内存复杂度从 O(NM) 降至 O(N),并获得2-4倍的加速。
-
SageAttention (INT8 量化):将计算注意力分数时的 Query 和 Key 投影量化至 INT8 格式,在保证精度的同时减少50%的内存占用并提升计算速度。
八、常见问题 FAQ
Q1: Cross-Attention 和 Self-Attention 可以同时使用吗?
A:可以,而且这是标准做法。在 DiT/xDiT 的每个 Transformer Block 中,通常是先进行 Self-Attention(建立图像内部结构),再进行 Cross-Attention(注入文本条件),最后通过 FeedForward 层。
Q2: Attention Mask 的作用是什么?
A:用于屏蔽无效的输入,如文本序列中的填充(padding)token。在计算 softmax 前,将 mask 位置对应的分数设为负无穷,使其权重为0,防止模型关注无意义信息。
Q3: 为什么要除以 √d_k?
A:为了稳定训练。当 Query 和 Key 的维度 d_k 较大时,点积的结果方差也会变大,导致 softmax 后的梯度非常小(梯度消失)。除以 √d_k 可以将方差缩放回1左右。
Q4: xDiT 如何支持超长文本?
A:通常采用分块(chunking)策略。将超长文本特征分割成多个长度固定的块,让 Cross-Attention 依次处理每个块,最后聚合所有块的输出结果(如取平均)。
Q5: 多头注意力中“头”的意义是什么?
A:多头机制允许模型在多个不同的子表示空间中并行学习信息。不同的“头”可能专注于不同类型的关系(如物体、纹理、动作等),最后将它们的输出组合起来,增强了模型的表示能力。这正是现代人工智能模型,特别是Transformer架构的强大之处。
九、核心要点总结
- 核心角色:Cross-Attention 是文生图/文生视频中实现文本条件控制的核心桥梁,它通过 Query(图像)与 Key/Value(文本)的交互,实现细粒度的跨模态语义融合。
- 实现关键:在代码上,其与 Self-Attention 的唯一关键区别在于 Key 和 Value 的来源是条件模态(如文本)。
- 计算规律:计算过程中,图像 token 数
N 和文本 token 数 M 决定了注意力矩阵的大小(N×M),输出序列长度与图像 token 数 N 一致。
- 优化必需:面对生成视频时巨大的
N,必须依赖 FlashAttention、模型并行、量化 等优化技术来实现高效推理。