记录在大模型推理优化与相关岗位面试中经常被考察手写实现的一些核心CUDA算子,并探讨其关键优化策略。本文包含可运行的完整代码示例,代码仓库地址位于文末。

01 Softmax
手写Safe Softmax是常见考察点,除了掌握基础的安全实现(防止数值溢出),了解Online Softmax的实现思路也是加分项。其核心优化在于Warp级别和Block级别的归约操作。
// Warp Reduce Sum
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
#pragma unroll
for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
val += __shfl_xor_sync(0xffffffff, val, mask);
}
return val;
}
// Warp Reduce Max
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_max_f32(float val) {
#pragma unroll
for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
template<const int NUM_THREADS=256>
__device__ float block_reduce_sum_f32(float val) {
constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
static __shared__ float shared[NUM_WARPS];
float value = warp_reduce_sum_f32<WARP_SIZE>(val);
if (lane == 0) shared[warp] = value;
__syncthreads();
value = (lane < NUM_WARPS) ? shared[lane] : 0.0f;
value = warp_reduce_sum_f32<NUM_WARPS>(value);
value = __shfl_sync(0xffffffff, value, 0, 32);
return value;
}
// 向量化FLOAT4版本的Safe Softmax
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
template <const int NUM_THREADS = 256 / 4>
__global__ void safe_softmax_f32x4_per_token_kernel(float *x, float *y, int N) {
const int tid = threadIdx.x;
const int idx = (blockIdx.x * blockDim.x + tid) * 4;
float4 reg_x = FLOAT4(x[idx]);
reg_x.x = (idx + 0 < N) ? reg_x.x : -FLT_MAX;
reg_x.y = (idx + 1 < N) ? reg_x.y : -FLT_MAX;
reg_x.z = (idx + 2 < N) ? reg_x.z : -FLT_MAX;
reg_x.w = (idx + 3 < N) ? reg_x.w : -FLT_MAX;
float val = reg_x.x;
val = fmaxf(val, reg_x.y);
val = fmaxf(val, reg_x.z);
val = fmaxf(val, reg_x.w);
float max_val = block_reduce_max_f32<NUM_THREADS>(val); // block max
float4 reg_exp;
reg_exp.x = (idx + 0 < N) ? expf(reg_x.x - max_val) : 0.0f;
reg_exp.y = (idx + 1 < N) ? expf(reg_x.y - max_val) : 0.0f;
reg_exp.z = (idx + 2 < N) ? expf(reg_x.z - max_val) : 0.0f;
reg_exp.w = (idx + 3 < N) ? expf(reg_x.w - max_val) : 0.0f;
float exp_val = (reg_exp.x + reg_exp.y + reg_exp.z + reg_exp.w);
float exp_sum = block_reduce_sum_f32<NUM_THREADS>(exp_val); // block sum
// e^x_i/sum(e^x_0,...,e^x_n-1)
if (idx + 3 < N) {
float4 reg_y;
reg_y.x = reg_exp.x / (exp_sum);
reg_y.y = reg_exp.y / (exp_sum);
reg_y.z = reg_exp.z / (exp_sum);
reg_y.w = reg_exp.w / (exp_sum);
FLOAT4(y[idx]) = reg_y;
}
}
02 GEMM
通用矩阵乘法是CUDA性能优化的经典案例。虽然面试通常不会要求写出极致优化的版本,但理解从朴素实现到分块优化、共享内存利用、以及线程块分块的演进思路至关重要。这直接关系到你对GPU编程和内存层次结构的理解深度。
// 朴素GEMM内核
__global__ void naive_sgemm(const float* A, const float* B, float* C, int M, int N, int K) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= M || col >= N) return;
float acc = 0.0f;
for (int k = 0; k < K; ++k) {
acc += A[row * K + k] * B[k * N + col];
}
C[row * N + col] = acc;
}
// 使用共享内存的Block-Tile优化
#define BLOCK_SIZE 32
__global__ void sgemm(float* A, float* B, float* C, int M, int N, int K) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int idy = blockDim.y * blockIdx.y + threadIdx.y;
if (idx >= M || idy >= N) return;
int bx = blockIdx.x;
int by = blockIdx.y;
int tx = threadIdx.x;
int ty = threadIdx.y;
const int BM = BLOCK_SIZE;
const int BN = BLOCK_SIZE;
const int BK = BLOCK_SIZE;
__shared__ float As[BM * BK];
__shared__ float Bs[BK * BN];
// 初始化block tile起始位置
A = &A[(by * BM) * K];
B = &B[bx * BN];
C = &C[(by * BM) * N + bx * BN];
float accum = 0.0f;
for (int k = 0; k < K; k += BK) {
// 搬运 global ==> shared
As[ty * BK + tx] = A[ty * K + tx];
Bs[ty * BN + tx] = B[ty * N + tx];
__syncthreads();
A = A + BK;
B = B + BK * N;
for (int i = 0; i < BK; i++) {
accum += As[ty * BK + i] * Bs[i * BN + tx];
}
__syncthreads();
}
C[ty * N + tx] = accum;
}
03 Reduce
归约操作是许多算子的基础。一个典型的面试题是:给定一个形状为[batch, 4096]的输入矩阵,如何高效地实现按行求和(Reduce Sum)得到[batch]的输出?关键在于处理长序列时的并行策略与层次化归约。
template <typename T>
__global__ void reduce_sum_kernel(const T* __restrict__ x, T* __restrict__ y, int num_elements) {
extern __shared__ T sdata[];
int tid = threadIdx.x;
int row = blockIdx.x; // 每个 block 处理一行
T sum = 0;
// 每个线程对该行的多个元素求部分和; 跨步读取
for (int i = tid; i < num_elements; i += blockDim.x) {
sum += x[row * num_elements + i];
}
sdata[tid] = sum;
__syncthreads();
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
sdata[tid] += sdata[tid + s];
}
__syncthreads();
}
if (tid == 0) {
y[row] = sdata[0];
}
}
// 启动配置示例
int batch = 10; // 批次数量
int num_elements = 4096; // 每行元素数
int threadsPerBlock = 256;
dim3 grid(batch); // 每个 block 对应一行
dim3 block(threadsPerBlock);
reduce_sum_kernel<float><<<grid, block, threadsPerBlock * sizeof(float)>>>(x, y, num_elements);
04 LayerNorm
层归一化是Transformer架构中的关键组件。其CUDA实现的核心依然是高效的均值和方差计算,这依赖于高性能的归约操作。
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
template<const int NUM_THREADS=128>
__global__ void layer_norm(float* x,float* y,float g,float b,int N,int K){
int tid = threadIdx.x;
int idx = tid + blockIdx.x * NUM_THREADS;
int bid = blockIdx.x;
const float epsilon = 1e-5f;
__shared__ float s_mean; // shared within block
__shared__ float s_variance; // shared within block
float value = (idx < N) ? x[idx] : 0.0f;
float sum = block_reduce_sum<NUM_THREADS>(value);
if(tid == 0) s_mean = sum / (float)K; // 均值
__syncthreads();
float variance = (value-s_mean) * (value-s_mean);
variance = block_reduce_sum<NUM_THREADS>(variance);
if(tid == 0) s_variance = rsqrtf(variance/(float)K+epsilon);
__syncthreads();
if(idx<N*K) y[idx] = ((value-s_mean)*s_variance)*g+b;
}
05 RmsNorm
RMSNorm是LLM中常用的归一化方法,相比LayerNorm省去了均值计算。其优化重点同样在于平方和的归约。
// Warp级别的归约辅助函数
#define WARP_SIZE 32
__device__ float warpReduce(float x) {
float val = x;
for (int activeThreads = WARP_SIZE >> 1; activeThreads > 0;
activeThreads >>= 1) {
val += __shfl_down_sync(0xffffffff, val, activeThreads);
}
return val;
}
// 结合Warp归约与Shared Memory的RMSNorm内核
template <int hiddenDim, int threadsPerBlock>
__global__ void rmsNormKernelWarp(float *x, float *w, float eps, float *y) {
__shared__ float squaredPerThread[threadsPerBlock];
__shared__ float xShared[hiddenDim];
__shared__ float sumPerWarp[WARP_SIZE];
__shared__ float rms_;
const int tid = threadIdx.x;
const int laneId = tid & 31;
const int warpId = tid >> 5;
const int warpsPerBlock = threadsPerBlock >> 5;
const int bid = blockIdx.x;
float sum = 0.0f;
for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
float x_ = x[bid * hiddenDim + i];
xShared[i] = x_;
sum += x_ * x_;
}
squaredPerThread[tid] = sum;
__syncthreads();
float warpSum = warpReduce(squaredPerThread[tid]);
if (laneId == 0) {
sumPerWarp[warpId] = warpSum;
}
__syncthreads();
if (tid < WARP_SIZE) {
sumPerWarp[tid] = warpReduce(tid < warpsPerBlock ? sumPerWarp[tid] : 0);
if (tid == 0) {
rms_ = rsqrtf(sumPerWarp[tid] / hiddenDim + eps);
}
}
__syncthreads();
for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
y[bid * hiddenDim + i] = xShared[i] * rms_ * w[i];
}
}
06 Sigmoid
Sigmoid是经典的非线性激活函数,输出范围在(0,1),常用于二分类问题的概率输出。其CUDA实现相对简单,但向量化访存能有效提升性能。
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
// 标准Sigmoid
__global__ void sigmoid(float* x, float* y, int N) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < N) y[idx] = 1.0f / (1.0f + expf(-x[idx]));
}
// 向量化版本
__global__ void sigmoid_vec4(float* x, float* y, int N) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
if (idx < N) {
float4 reg_x = FLOAT4(x[idx]);
float4 reg_y;
reg_y.x = 1.0f / (1.0f + expf(-reg_x.x));
reg_y.y = 1.0f / (1.0f + expf(-reg_x.y));
reg_y.z = 1.0f / (1.0f + expf(-reg_x.z));
reg_y.w = 1.0f / (1.0f + expf(-reg_x.w));
FLOAT4(y[idx]) = reg_y;
}
}
07 Silu
SiLU(Swish)激活函数在深度网络中表现良好,具备自门控特性。其实现为x * sigmoid(x),需要注意sigmoid计算中的数值稳定性。
// 优化后的Sigmoid函数,避免数值溢出
__device__ __forceinline__ float sigmoid(float x) {
if (x >= 0.0f) {
return 1.0f / (1.0f + __expf(-x));
} else {
float exp_x = __expf(x);
return exp_x / (1.0f + exp_x);
}
}
// float4向量化优化
__global__ void silu_kernel_vectorized(float4* input, float4* output, int n_float4) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n_float4) {
float4 vec = input[idx];
vec.x *= sigmoid(vec.x);
vec.y *= sigmoid(vec.y);
vec.z *= sigmoid(vec.z);
vec.w *= sigmoid(vec.w);
output[idx] = vec;
}
}
08 Gemv
在大模型解码阶段,KV Cache使得注意力计算从GEMM退化为GEMV(矩阵向量乘)。优化GEMV的关键在于充分利用内存带宽和Warp内的归约。
// 基础GEMV实现,每个Warp处理一行
template<unsigned int WarpSize>
__device__ __forceinline__ float warpReduceSum(float sum) {
if(WarpSize >= 32) sum += __shfl_down_sync(0xffffffff,sum,16);
if(WarpSize >= 16) sum += __shfl_down_sync(0xffffffff,sum,8);
if(WarpSize >= 8) sum += __shfl_down_sync(0xffffffff,sum,4);
if(WarpSize >= 4) sum += __shfl_down_sync(0xffffffff,sum,2);
if(WarpSize >= 2) sum += __shfl_down_sync(0xffffffff,sum,1);
return sum;
}
__global__ void sgemv_v0(
float* __restrict__ A,
float* __restrict__ x,
float* __restrict__ y,
const int M,const int N)
{
int bx = blockIdx.x;
int tx = threadIdx.x;
int ty = threadIdx.y;
const int warp_size = 32;
int laneId = tx % warp_size;
int current_row = blockDim.y * bx + ty;
if(current_row < M) {
float res = 0;
int kIteration = (N + warp_size - 1) / warp_size; // 向上取整
#pragma unroll
for(int i=0;i<kIteration;i++) {
int current_col = i*warp_size + laneId;
if(current_col < N) {
res += A[current_row*N+current_col] * x[current_col];
}
}
res = warpReduceSum<warp_size>(res);
if(laneId==0) y[current_row] = res;
}
}
09 Scan
前缀和(Scan)是一种略微复杂的并行模式。高效的Scan实现通常采用层次化方法:先进行Warp内扫描,然后进行Block内Warp结果的扫描,最后组合。理解这种并行算法设计对优化复杂算子很有帮助。
#define WARP_SIZE 32
#define LOG_WARP_SIZE 5
#define WARP_MASK (WARP_SIZE - 1)
__device__ inline int lane_id(void) { return threadIdx.x & WARP_MASK; }
__device__ inline int warp_id(void) { return threadIdx.x >> LOG_WARP_SIZE; }
// Warp内的并行前缀和(Exclusive Scan)
__device__ __forceinline__ int warp_scan(int val) {
int x = val;
#pragma unroll
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
int y = __shfl_up_sync(0xffffffff, x, offset);
if (lane_id() >= offset) x += y;
}
return x - val; // 返回 exclusive scan 结果
}
// Block级别的扫描
template <int threadsPerBlock>
__device__ int block_scan(int in) {
__shared__ int sdata[threadsPerBlock >> LOG_WARP_SIZE];
// A. Exclusive scan within each warp
int warpPrefix = warp_scan(in);
// B. Store in shared memory
if (lane_id() == WARP_SIZE - 1) sdata[warp_id()] = warpPrefix + in;
__syncthreads();
// C. One warp scans in shared memory
if (threadIdx.x < WARP_SIZE)
sdata[threadIdx.x] = warp_scan(sdata[threadIdx.x]);
__syncthreads();
// D. Each thread calculates its final value
int thread_out_element = warpPrefix + sdata[warp_id()];
return thread_out_element;
}
代码仓库
文中涉及的所有可运行完整代码示例,均已整理至GitHub仓库:
https://github.com/XiaoDiandian-623/CUDA-Demo