本文将深入解析 linear_recurrence.cu 中的 CUDA 核函数实现,这是 LeetCode GPU Challenge 中的一道经典题目,同时也是 Mamba、S4、H3 等最新深度学习架构中状态空间模型的核心计算原语。
题目要求
给定两个形状为 (B, L) 的矩阵 a 和 x(其中 B 为 batch size,L 为 sequence length),需要计算形状同为 (B, L) 的线性递推序列 h:
$$h[t] = a[t] \cdot h[t-1] + x[t]$$
所有值均为 float32 类型,且初始状态 $h[-1] = 0$。
性能约束
- Batch Size: $B \le 10^4$
- Sequence Length: $L \le 10^4$
- 性能测试条件: NVIDIA V100 GPU
核心挑战
这是一个典型的顺序依赖问题。计算公式中,第 $t$ 时刻的结果必须等待第 $t-1$ 时刻的值才能计算,这与 GPU 追求的大规模并行形成了根本性冲突。解决思路是采用 Parallel Scan 算法,将序列分割成多个块,先在块内并行处理,再通过特殊的 combine 操作实现块间合并,这正是优化此类动态规划问题的关键。
问题分析
数学本质:Exclusive Prefix Scan
为了并行化,我们定义每个位置的二元组状态 $(p, s)$:
- $p$ (Product): 累积乘积 $\prod_{k=0}^{i} a[k]$
- $s$ (Sum): 加权累加 $\sum_{k=0}^{i} (x[k] \cdot \prod_{j=k+1}^{i} a[j])$
状态组合规则定义为:
$$(p_1, s_1) \otimes (p_2, s_2) = (p_1 \cdot p_2, s_1 + p_1 \cdot s_2)$$
可以验证该操作满足结合律,因此能够使用 Parallel Scan 算法进行并行化加速。
代码结构分析
通过分析源代码,可以识别出以下关键 kernel 函数:
| Kernel 函数 |
功能 |
Grid 配置 |
Block 配置 |
scan_mul_add_inter_block |
块内局部扫描 |
(num_chunks, B) |
128 |
scan_mul_add_across_block_warp |
Warp 级跨块扫描 |
div_up(B, 4) |
128 |
scan_mul_add_across_block_thread |
线程级跨块扫描 |
div_up(B, 128) |
128 |
scan_mul_add_across_block |
Shared Memory 跨块扫描 |
B |
nextPowerOfTwo(num_chunks) |
scan_mul_add_finalize |
最终结果计算 |
(num_chunks, B) |
128 |
linear_recurrence_naive_kernel |
朴素实现 (fallback) |
div_up(B, 128) |
128 |
优化策略
代码采用了三级并行优化策略,充分挖掘了现代 GPU 的计算潜力:
- 块内并行: 使用 Kogge-Stone 扫描算法在 block 内并行处理。
- 跨块优化: 根据 chunk 数量自适应选择最优实现,体现了对硬件层次结构的深刻理解。
num_chunks ≤ 32: Warp-Level Scan (寄存器通信)
32 < num_chunks < 128: Thread-Level Serial Scan (串行 work-efficient)
num_chunks ≥ 128: Shared Memory Multi-Warp Scan
- Fallback 机制: 短序列 ($L < 128$) 回退到朴素实现。
代码详解
1. 块内局部扫描 (scan_mul_add_inter_block)
__global__ void scan_mul_add_inter_block(
const float* a, const float* x, float* a_local, float* x_local,
float* p_chunk, float* s_chunk,
int B, int L, int num_chunks)
内核配置:
- Grid:
(num_chunks, B)
- Block:
CHUNK_SIZE = 128
Shared Memory 组织:
__shared__ float s_p[CHUNK_SIZE]; // prefix product
__shared__ float s_s[CHUNK_SIZE]; // prefix sum
Kogge-Stone 扫描算法:
for (int offset = 1; offset < CHUNK_SIZE; offset <<= 1) {
p_curr = s_p[threadIdx.x];
s_curr = s_s[threadIdx.x];
if (threadIdx.x >= offset) {
p_prev = s_p[threadIdx.x - offset];
s_prev = s_s[threadIdx.x - offset];
}
__syncthreads();
if (threadIdx.x >= offset) {
s_p[threadIdx.x] = p_curr * p_prev;
s_s[threadIdx.x] = s_curr + p_curr * s_prev;
}
__syncthreads();
}
技术要点:
- 每轮迭代都调用
__syncthreads() 确保正确性。
- 边界处理:超出序列长度的位置用单位元填充($p=1.0f, s=0.0f$)。
- 输出每个 chunk 的最终状态到
p_chunk 和 s_chunk,供下一阶段使用。
2. Warp-Level 跨块扫描 (scan_mul_add_across_block_warp)
适用场景: num_chunks ≤ 32
__global__ void scan_mul_add_across_block_warp(
float* p_chunk, float* s_chunk, int B, int num_chunks)
内核配置:
- Grid:
div_up(B, 4)
- Block:
128 threads (对应 4 warps)
核心思想: 每个 warp 独立负责一个 batch,使用寄存器级 shuffle而非 shared memory!
const int warp_id = threadIdx.x / 32;
const int lane_id = threadIdx.x % 32;
const int batch_id = blockIdx.x * 4 + warp_id;
Warp-Level Kogge-Stone (Inclusive):
for (int shift = 1; shift < 32; shift <<= 1) {
int src_lane = lane_id - shift;
int safe_src = (src_lane >= 0) ? src_lane : lane_id;
float p_prev = __shfl_sync(0xffffffff, p_curr, safe_src);
float s_prev = __shfl_sync(0xffffffff, s_curr, safe_src);
if (lane_id >= shift && lane_id < num_chunks) {
s_curr = s_curr + p_curr * s_prev;
p_curr = p_curr * p_prev;
}
__syncwarp();
}
技术优势:
- ✅ 零共享内存访问,完全使用寄存器通信。
- ✅ 无需 thread-block 级别的
__syncthreads(),只需 __syncwarp()。
- ✅ 延迟显著降低(寄存器访问 ~1 cycle vs shared memory ~30-100 cycles)。
Exclusive Scan 转换:
// Inclusive → Exclusive 转换
float p_out = 1.0f, s_out = 0.0f;
int src_lane = lane_id - 1;
float p_inclusive_prev = __shfl_sync(0xffffffff, p_curr, src_lane);
float s_inclusive_prev = __shfl_sync(0xffffffff, s_curr, src_lane);
if (lane_id > 0 && lane_id < num_chunks) {
p_out = p_inclusive_prev;
s_out = s_inclusive_prev;
}
3. Thread-Level 串行扫描 (scan_mul_add_across_block_thread)
适用场景: 32 < num_chunks < 128
__global__ void scan_mul_add_across_block_thread(
float* p_chunk, float* s_chunk, int B, int num_chunks)
内核配置:
- Grid:
div_up(B, 128)
- Block:
128 threads
Work-Efficient 串行扫描:
const int batch_id = blockIdx.x * blockDim.x + threadIdx.x;
int offset = batch_id * num_chunks;
float running_p = 1.0f;
float running_s = 0.0f;
for (int i = 0; i < num_chunks; ++i) {
float local_p = p_chunk[offset + i];
float local_s = s_chunk[offset + i];
// 写入 exclusive prefix
p_chunk[offset + i] = running_p;
s_chunk[offset + i] = running_s;
// 更新 running state
running_s = running_s * local_p + local_s;
running_p = running_p * local_p;
}
技术优势:
- ✅ Work-Efficient: 时间复杂度 $O(n)$,空间复杂度 $O(1)$。
- ✅ 无同步开销:单 thread 内部操作,无需任何同步指令。
4. Shared Memory 跨块扫描 (scan_mul_add_across_block)
适用场景: num_chunks ≥ 128
__global__ void scan_mul_add_across_block(
float* p_chunk, float* s_chunk, int B, int num_chunks)
内核配置:
- Grid:
B
- Block:
nextPowerOfTwo_32(num_chunks)
回到传统的 shared memory 实现,支持更大的 num_chunks。这种根据问题规模自适应的策略,是高效 C++ 与 CUDA 编程的典范。
5. 最终结果计算 (scan_mul_add_finalize)
__global__ void scan_mul_add_finalize(
const float* a_local, const float* x_local,
const float* s_local, float* h, int L, int num_chunks)
内核配置:
- Grid:
(num_chunks, B)
- Block:
128
计算公式:
$$h[i] = x\_local[i] + a\_local[i] \cdot prefix\_s[chunk\_id]$$
代码实现:
if (global_id < batch_id * L + L) {
float local_p = a_local[global_id];
float local_s = x_local[global_id];
float prefix_s = s_local[batch_id * num_chunks + chunk_id];
h[global_id] = local_s + local_p * prefix_s;
}
6. 总控函数 (linear_recurrence_cuda)
void linear_recurrence_cuda(const float* a, const float* x, float* h, int B, int L)
{
if (L < 128) {
linear_recurrence_naive_cuda(a, x, h, B, L);
} else {
linear_recurrence_scan_cuda(a, x, h, B, L);
}
}
设计哲学:
- 短序列 ($L < 128$): 朴素实现更简单高效。
- 长序列 ($L \ge 128$): Scan 算法充分发挥 GPU 并行能力。
总结
核心技术亮点
| 技巧 |
说明 |
适用场景 |
| 分级扫描策略 |
根据 chunk 数量自动选择最优算法 |
通用 |
| Warp Shuffle |
寄存器通信替代 shared memory |
num_chunks ≤ 32 |
| Smart Chunking |
CHUNK_SIZE=128 平衡负载 |
长序列 |
| Fallback 机制 |
短序列回退到 naive 实现 |
L < 128 |
数学原理回顾
通过引入二元组状态 $(p, s)$ 和组合算子 $\otimes$,我们将原问题转化为可并行的形式:
$$(p_i, s_i) = (a[i], x[i]) \otimes (p_{i-1}, s_{i-1})$$
这一转化是 Parallel Scan 算法能够成功的关键。掌握这种将顺序依赖问题转化为可并行形式的技巧,对于解决更多复杂的数据结构与算法问题大有裨益。
希望这篇深度解析能帮助你理解如何利用 CUDA 高效地并行化线性递推计算。如果你对 GPU加速 或 并行计算 的其他主题感兴趣,欢迎在云栈社区继续交流探讨。
生成时间: 2026-03-30
工具: CUDA Blog Writer v1.0