数学公式(对照正向): 正向: $$y = \begin{cases} F_{\text{then}}(x, \theta_{\text{then}}) & \text{if cond = true}\\ F_{\text{else}}(x, \theta_{\text{else}}) & \text{if cond = false} \end{cases}$$
反向: $$grad\_x, grad\_{\theta_{\text{then/else}}} = \begin{cases} G_{\text{then}}(grad\_y, x, \theta_{\text{then}}) & \text{if cond = true}\\ G_{\text{else}}(grad\_y, x, \theta_{\text{else}}) & \text{if cond = false} \end{cases}$$
对照伪代码:
结论:反向 2.3 对
|
| 正向步骤 | 反向对应步骤 |
|---|---|
| 第 $t$ 次迭代:$s_{t+1} = F(s_t, \theta)$ | 第 $t$ 次反向迭代: 计算 $\partial L/\partial s_t$ 和 $\partial L/\partial \theta$ 的贡献 |
| 正向顺序:$t=0 \rightarrow 1 \rightarrow ... \rightarrow T-1$ | 反向顺序:$t=T-1 \rightarrow T-2 \rightarrow ... \rightarrow 0$ |
| 正向使用 $s_t$ 计算 $s_{t+1}$ | 反向使用 $s_t$ 和 $\partial L/\partial s_{t+1}$ 计算 $\partial L/\partial s_t$ 和 $\partial L/\partial \theta$ |
数学公式(对照正向迭代):
正向迭代:
反向递推(从 $t = T-1$ 向下到 0):
对照伪代码:
# 正向循环
s = s0
for t in range(T):
s = F(s, theta) # s 更新为 s_{t+1}
# 反向循环(注意逆序)
grad_s = grad_y # ∂L/∂s_T
grad_theta = 0
for t in reversed(range(T)): # t = T-1, ..., 0
# 需要知道正向时的 s_t(必须缓存或重算)
s_t = cached_s[t]
grad_s_prev, grad_theta_t = F_gradient(s_t, grad_s, theta)
grad_theta += grad_theta_t
grad_s = grad_s_prev # 传递给前一时间步
# 最终 grad_s 即为 ∂L/∂s_0
Loop 算子对照正向的 Loop 算子,反向图也用一个 Loop 算子实现,但迭代方向隐含在循环体内部。构造对照如下:
正向 Loop 属性 |
反向 Loop 对照属性 |
|---|---|
| 最大迭代次数 $M$ | 相同的 $M$, 但实际步数取正向执行步数 $T$ |
初始条件 cond |
恒为 true(因为我们必须执行完所有反向步) |
| 初始状态 $s_0$ | 初始梯度为 $\partial L/\partial s_T$ (来自损失) |
| 循环体 $F$ | 循环体 $G$, 其内部包含 $F$ 的梯度计算 |
| 循环体输入:$(t, s_t)$ | 循环体输入:$(t_{rev}, grad\_s_{t+1}, acc\_grad\_theta)$ |
| 循环体输出:$(cond\_out, s_{t+1})$ | 循环体输出:$(true, grad\_s_t, new\_acc\_grad\_theta)$ |
| 正向迭代顺序: $t$ 递增 |
反向迭代顺序: 通过 $t = T-1 - t_{rev}$ 映射到 $t$,实现递减 |
形式化对照(ONNX 风格伪代码):
# 正向 Loop 节点
Loop(M, cond_initial, s0, body_F) -> (sT, cond_final)
# 反向 Loop 节点(自动构造)
# 输入:grad_y = ∂L/∂sT, 缓存的正向状态序列 [s0, s1, ..., sT]
# 输出:grad_s0, grad_theta
Loop(T, true, (grad_y, 0, cache), body_G) -> (grad_s0, grad_theta, _)
其中 body_G 的对照内容:
def body_G(t_rev, cond_in, grad_s_next, acc_grad_theta, cache):
t = T - 1 - t_rev # 映射到正向时间步
s_t = cache[t] # 从缓存获取正向状态
# 调用 F 的梯度子图
grad_s_t, grad_theta_t = F_grad(s_t, grad_s_next, theta)
new_acc = acc_grad_theta + grad_theta_t
return (true, grad_s_t, new_acc, cache) # cond_out = true
cond 的梯度对照(再次强调)正向 Loop 中的 cond(初始条件和每次迭代输出的 cond_out)仅控制循环是否继续。反向中,这些布尔变量同样不产生梯度:
cond 不参与数值计算,反向中它们被完全忽略。
| 概念 | 正向 If |
反向 If |
|---|---|---|
| 条件 | cond(布尔) |
相同的 cond |
| 分支 | then_branch 或 else_branch |
对应的 then_gradient 或 else_gradient |
| 执行 | 根据 cond 选一个 |
根据同一个 cond选对应的梯度分支 |
| 概念 | 正向 Loop |
反向 Loop |
|---|---|---|
| 迭代次数 | 动态 $T$ (由 cond 和 $M$ 决定) |
相同 $T$ |
| 初始条件 | 用户提供的 cond |
恒为 true |
| 初始状态 | $s_0$ | $\partial L/\partial s_T$ (来自损失) |
| 循环体 | $F$ | $G$ |
| 状态更新 | 顺序:$s_t \rightarrow s_{t+1}$ | 逆序:$\partial L/\partial s_{t+1} \rightarrow \partial L/\partial s_t$ (通过索引映射) |
| 输出 | 最终状态 $s_T$ | 初始梯度 $\partial L/\partial s_0$ 和参数梯度 $\partial L/\partial \theta$ |
在实际的 ONNX GradientBuilder 中,构造反向图时:
If 或 Loop 节点时,递归处理。If:
cond 的节点输出。then_branch 和 else_branch 子图,构建它们的梯度子图(递归)。If 节点,条件输入为记录的 cond,then_branch 和 else_branch 分别设置为刚刚构建的两个梯度子图。Loop:
cond 输出或额外保存)。Loop 节点,最大迭代次数设为 $T$,初始条件为 true,初始状态为 $(\partial L/\partial s_T, 0)$,循环体为 $G$。cond_final)均被忽略或置零。这种构造方式保证了反向图与正向图在控制流结构上完全对照,只是计算内容从“求值”变成了“求梯度”。
考虑一个简单例子:正向中根据 cond 决定对张量做平方或开方,然后循环多次。对照关系如下:
正向伪代码:
x = input
if cond:
y = x * x # then_branch
else:
y = sqrt(x) # else_branch
s = y
for i in range(3):
s = s + 1 # 简单循环体
反向图伪代码(对照):
grad_y = grad_output
if cond: # 相同的 cond
grad_x = 2 * x * grad_y # then_gradient
else:
grad_x = 0.5 / sqrt(x) * grad_y # else_gradient
grad_s = grad_y # 循环初始梯度
grad_theta = 0
for i in reversed(range(3)): # 反向循环
# 循环体梯度:加法梯度就是 1
grad_s_prev = grad_s * 1
# 无参数,所以无 grad_theta 更新
grad_s = grad_s_prev
# 最终 grad_s 即为 grad_y 不变,因为加法的梯度是恒等映射
可以看到,反向完全遵循正向的控制流结构,只是内部运算替换为梯度运算。
本文通过正向与反向的逐项对照,揭示了 ONNX If 和 Loop 算子自动微分的本质:
If/Loop),且控制条件(cond)和迭代次数与正向执行一致。Loop 算子。cond)的梯度恒为零,对照正向中它们仅用于决策。理解这种对照关系,不仅有助于正确导出可训练的 ONNX 模型,也为设计支持动态控制流的自定义算子梯度提供了理论指导。随着 ONNX 训练生态的不断完善,掌握控制流的自动微分将成为深度学习系统工程师的核心能力之一。若想了解更多关于算子实现与框架集成的技术细节,欢迎在云栈社区交流探讨。