JAX高性能运行的秘诀其实很直接:通过巧妙的变换组合,让底层的XLA编译器能够“看见”并优化大块连续的计算逻辑。无论是批处理、算子融合还是张量分片,目标都是让计算在单设备或多设备同步时,都能像一个高效的内核一样执行。今天,我们就来深入总结7个能够显著提升JAX程序运行速度的核心变换组合。

1、 jit 优先,确保形状稳定
jit 变换通过对函数进行一次追踪,将内部操作交由XLA进行算子融合与优化,从而将Python层的开销分摊掉。要实现最佳效果,关键是要保持被追踪函数的“形状稳定”,并且没有副作用。
应将动态的形状创建和静态参数移到 step 函数外部,或者使用 static_argnums 显式标记为静态。donate_argnums 参数允许JAX复用输入缓冲区,避免不必要的内存拷贝。此外,确保各步骤之间输入输出的 dtype 和 shape 保持一致,这样编译结果才能被有效缓存。
import jax, jax.numpy as jnp
@jax.jit(donate_argnums=(0,))
def sgd_step(params, batch, lr):
x, y = batch
def loss_fn(p):
preds = model_apply(p, x) # 纯函数
return jnp.mean((preds - y) ** 2)
grads = jax.grad(loss_fn)(params)
return jax.tree_map(lambda p, g: p - lr * g, params, grads)
每个独特的 (shape, dtype, static-arg) 组合只会被追踪编译一次。如果发现频繁的重新追踪(retrace),通常是因为输入形状在变化,或者有Python逻辑“泄露”到了计算图中。
2、用 vmap 替代 Python 循环
vmap 在指定的前导轴(leading axis)上自动进行向量化,XLA能够将整个批处理过程融合进一个内核中执行。这不仅消除了显式的Python for 循环,减少了设备启动开销,还使得内存访问模式更加连续。
# 单个样本的损失计算
def example_loss(params, x, y):
pred = model_apply(params, x)
return jnp.mean((pred - y) ** 2)
# 无需手动写循环,直接向量化批次计算
batched_loss = jax.vmap(example_loss, in_axes=(None, 0, 0)) # params 被广播
你甚至可以嵌套使用 vmap 来处理二维批次,例如时间步 × 批量大小,只要确保不超过设备HBM容量。vmap 非常适合用于内层的微批处理,例如集成学习或蒙特卡洛采样等场景,而外层维度则可以留给分片并行。
3、处理长循环的融合利器:Scan
对于RNN、展开式解码、迭代求解器等包含长序列循环的场景,使用 lax.scan 比Python循环要高效得多。scan 只编译一次循环体,然后在XLA的 while-loop 结构中运行,几乎消除了Python开销,并且能进行更激进的算子融合和内存复用。
from jax import lax
def rnn_cell(carry, x):
h = carry
h = jnp.tanh(W_hh @ h + W_xh @ x + b)
y = W_hy @ h
return h, y # 返回 (新的状态, 输出)
def rnn_forward(h0, xs):
hT, ys = lax.scan(rnn_cell, h0, xs) # xs 形状: [时间步T, 批次B, 特征D]
return hT, ys
循环状态通过 carry 传递,循环体应保持小巧、纯净。保持循环过程中张量形状的稳定是关键。这种方法适用于序列模型、扩散模型步进循环、定点迭代以及形状稳定的束搜索解码等场景。
4、remat:以计算换取内存
更大的批次尺寸(batch size)往往能更好地“喂饱”TPU/GPU,提升FLOPs利用率。remat(也称为梯度检查点)会丢弃部分前向传播的中间激活值,在反向传播需要时重新计算它们。这样峰值显存占用得以降低,从而允许使用更大的批次。
from jax import remat
def block(params, x):
x = jax.nn.gelu(x @ params['w1'])
x = x @ params['w2']
return x
fast_block = remat(block) # 对该块启用重计算
@jax.jit
def forward(params, x):
for _ in range(6):
x = x + fast_block(params, x) # 重计算被应用在这里
return x
通常只需要对计算最重的子模块(例如Transformer中的注意力层或大型MLP层)应用 remat。同时配合 vmap 或分片,可以进一步提升全局批次大小。虽然这会引入额外的FLOPs开销,但如果能换来1.3到2倍的批次提升,实际运行时间(wall-clock time)往往更短。
5、pmap:实现单机多卡数据并行
pmap 将函数复制到单台主机的多个设备上(例如8卡工作站或单节点8核TPU),自动处理梯度的跨设备规约(all-reduce),并且每个设备只编译一次。
from jax import pmap, lax
@pmap(axis_name='d')
def train_step(params, batch, lr):
x, y = batch # 每个设备看到的是 [local_B, ...]
def loss_fn(p):
pred = model_apply(p, x)
loss = jnp.mean((pred - y) ** 2)
return loss
loss, grads = jax.value_and_grad(loss_fn)(params)
loss = lax.pmean(loss, axis_name='d') # 聚合损失
grads = lax.pmean(grads, axis_name='d') # 聚合梯度
params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
return params, loss
批次数据沿前导轴进行分片,lax.pmean 负责聚合各设备上的损失和梯度。在单机多卡场景下,pmap 简单可靠。如果需要跨主机扩展,或希望进行张量级别的细粒度分片,则可以转而使用 pjit。
6、pjit + 命名分片:灵活的SPMD并行
pjit 可以编译出单一的SPMD(单程序多数据)程序,运行在跨设备甚至跨主机的集群上。通过定义网格(Mesh)和分区规约(PartitionSpec),你可以精确描述张量如何被切分到各设备,JAX会自动处理所需的集体通信。这使得数据并行、张量模型并行以及混合并行策略成为可能。
import jax
from jax.sharding import Mesh, PartitionSpec as P
import numpy as np
# 创建一个 2 × 4 的网格 (数据并行 × 模型并行)
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, ('dp', 'mp'))
from jax.experimental.pjit import pjit
with mesh:
# 定义输入输出的分片方式 (需根据具体形状调整)
step = pjit(model_apply,
in_shardings=(P('mp',), P('dp',)), # 参数按模型并行分片,输入按数据并行分片
out_shardings=P('dp',)) # 输出按数据并行分片
y = step(params_sharded, x_sharded)
常见的策略是让批次维度走数据并行(dp),让大矩阵维度(如隐藏层大小、注意力头数)走模型并行(mp)。分片方案需要与设备拓扑对齐,以最小化跨主机通信流量。
7、value_and_grad 的正确堆叠方式
规范的做法是将 value_and_grad 包裹在 jit 内部,即 jit(value_and_grad(loss, has_aux=True)),外层再根据需要套用 pmap 或 pjit。这样,前向传播只执行一次,额外的指标(metrics)可以通过 aux 参数带出,无需再次计算。
def loss_with_aux(params, batch):
x, y = batch
pred = model_apply(params, x)
loss = jnp.mean((pred - y) ** 2)
aux = {'mse': loss, 'mean_pred': jnp.mean(pred)} # 将指标放在 aux 中
return loss, aux
@jax.jit
def train_step(params, opt_state, batch, lr):
(loss, aux), grads = jax.value_and_grad(loss_with_aux, has_aux=True)(params, batch)
# ... 使用优化器更新参数 ...
return params, opt_state, loss, aux
将 value_and_grad 放在 jit 内部,JAX会将前向和反向计算一同交给XLA进行优化。返回的 (loss, aux) 使得记录日志指标时无需重新运行一次前向传播,这在深度学习模型训练中能有效提升效率。
这套组合拳非常灵活:内部用 vmap 处理微批次,用 scan 处理时序循环,外层套上 pmap 或 pjit 进行扩展,并用 donate_argnums 标记可回收的缓冲区。
总结
要最大化JAX性能,首先确保计算图形状稳定(例如,变长序列需填充并配合掩码)。被追踪的代码中应避免引入Python侧的随机性,PRNG密钥应在外部提前切分好。矩阵乘法可考虑使用 bfloat16 数据类型,在保证数值稳定性的同时,在TPU/GPU上获得更高的吞吐量。性能剖析(profile)应重点关注预热(warm-up)后的 tokens/sec 或 samples/sec 指标。日志记录仅需标量指标,切勿在每个训练步将大数组传回主机端,这是性能杀手。
本质上,JAX的性能优化是透明可控的:以稳定的 jit 为基础,用 vmap 扩展批次,用 scan 融合长循环,用 remat 平衡内存,用 pmap 或 pjit 进行横向扩展,最后用 value_and_grad(..., has_aux=True) 确保每一步只执行一次前向和一次反向传播。掌握这些在 Python 高性能计算中的核心变换,你将能更充分地利用硬件算力。更多关于前沿技术实践的讨论,欢迎访问云栈社区与广大开发者交流。