找回密码
立即注册
搜索
热搜: Java Python Linux Go
发回帖 发新帖

2824

积分

0

好友

384

主题
发表于 前天 00:15 | 查看: 13| 回复: 0

在深度学习模型训练中,自动微分是核心引擎。对于卷积、矩阵乘法等无控制流的算子,反向传播往往对应一个固定的梯度算子(如 AddGradMatMulGrad)。然而,当模型包含动态控制流——ONNX 中的 IfLoop 算子——情况变得复杂:没有预定义的反向算子,反向图必须动态构造,且其结构必须与前向执行的控制流严格对照

本文将从正向执行出发,逐步骤对照反向计算,用数学公式、伪代码和构造规则揭示 IfLoop 的反向图如何保持“同态”。

一、对照原则:反向图是正向图的“结构镜像”

自动微分中,反向计算图通常由正向图的逆序执行构成。对于控制流,这意味着:

  • 反向图复用正向的控制条件cond 的值或迭代次数),而不是重新计算。
  • 反向图的算子结构与正向图相同(If 对应 IfLoop 对应 Loop),但内部子图替换为对应的梯度计算子图。
  • 反向的执行顺序严格逆向于正向:正向先执行 then 分支,反向也先计算 then 分支的梯度;正向先执行第一次循环迭代,反向则先计算最后一次迭代的梯度。

下文将对每个算子展开这种对照。

二、If 算子:正向分支选择与反向梯度分支选择

2.1 正向执行

ONNX If 算子执行:

  • cond 是布尔标量(不可微)。
  • then_branchelse_branch 是两个子图,具有相同的输入/输出签名。
  • 正向执行时,根据 cond 的值,只执行其中一个子图

2.2 反向计算(对照说明)

反向传播需要计算损失 $L$ 对输入 $x$ 和子图内部参数的梯度。关键对照点:

正向要素 反向对应
cond 的值(true/false) 完全相同cond
用于决定反向走哪个分支
执行的子图(如 then_branch 对应的梯度子图(如 GradThen
其输入是 $grad\_y = \partial L/\partial y$
未执行的子图(如 else_branch 对应的梯度子图不会被反向执行
因为正向没有走过

数学公式(对照正向)

正向:

$$y = \begin{cases} F_{\text{then}}(x, \theta_{\text{then}}) & \text{if cond = true}\\ F_{\text{else}}(x, \theta_{\text{else}}) & \text{if cond = false} \end{cases}$$

反向:

$$grad\_x, grad\_{\theta_{\text{then/else}}} = \begin{cases} G_{\text{then}}(grad\_y, x, \theta_{\text{then}}) & \text{if cond = true}\\ G_{\text{else}}(grad\_y, x, \theta_{\text{else}}) & \text{if cond = false} \end{cases}$$

对照伪代码

# 正向
if cond:
    y = then_branch(x)
else:
    y = else_branch(x)

# 反向(构建时复用 cond)
if cond:
    grad_x = then_gradient(grad_y, x)
else:
    grad_x = else_gradient(grad_y, x)

结论:反向 If 算子的结构与正向完全一致,只是内部子图从 then_branch 换成了 then_gradient,从 else_branch 换成了 else_gradientcond 本身不参与梯度计算,其梯度为零。

2.3 对 cond 的梯度对照

正向中 cond 用于选择路径,不参与数值计算。反向中,任何关于 cond 的偏导都没有定义:

$$\frac{\partial L}{\partial \text{cond}} = 0$$

在 ONNX 图中,cond 的输入边被标记为 stop_gradient,这对照了正向中 cond 的离散决策角色。

三、Loop 算子:正向顺序迭代与反向逆序迭代的对照

3.1 正向执行

ONNX Loop 算子执行循环体多次。简化模型(忽略动态条件退出,假设固定迭代 $T$ 次):

$$s_{t+1} = F(s_t, \theta), \quad t = 0, 1, ..., T-1$$

初始状态 $s_0$ 给定,最终输出 $y = s_T$。这里 $F$ 是循环体 body 子图,$\theta$ 是循环体内部的可训练参数(跨迭代共享)。

正向执行过程(按时间顺序):

  • 迭代 0: $s_1 = F(s_0, \theta)$
  • 迭代 1: $s_2 = F(s_1, \theta)$
  • ...
  • 迭代 $T-1$: $s_T = F(s_{T-1}, \theta)$

3.2 反向计算(对照说明)

反向传播需要计算 $\partial L/\partial s_0$$\partial L/\partial \theta$。对照正向的迭代顺序,反向必须逆向遍历时间步:

正向步骤 反向对应步骤
$t$ 次迭代:$s_{t+1} = F(s_t, \theta)$ $t$ 次反向迭代:
计算 $\partial L/\partial s_t$$\partial L/\partial \theta$ 的贡献
正向顺序:$t=0 \rightarrow 1 \rightarrow ... \rightarrow T-1$ 反向顺序:$t=T-1 \rightarrow T-2 \rightarrow ... \rightarrow 0$
正向使用 $s_t$ 计算 $s_{t+1}$ 反向使用 $s_t$$\partial L/\partial s_{t+1}$
计算 $\partial L/\partial s_t$$\partial L/\partial \theta$

数学公式(对照正向迭代)

正向迭代:

$$s_{t+1} = F(s_t, \theta)$$

反向递推(从 $t = T-1$ 向下到 0):

$$\frac{\partial L}{\partial s_t} = \frac{\partial L}{\partial s_{t+1}} \cdot \frac{\partial F}{\partial s_t}$$

$$\frac{\partial L}{\partial \theta} \text{ (累计) } += \frac{\partial L}{\partial s_{t+1}} \cdot \frac{\partial F}{\partial \theta}$$

其中 $\partial L/\partial s_T$ 由损失函数提供。

对照伪代码

# 正向循环
s = s0
for t in range(T):
    s = F(s, theta)          # s 更新为 s_{t+1}

# 反向循环(注意逆序)
grad_s = grad_y              # ∂L/∂s_T
grad_theta = 0
for t in reversed(range(T)): # t = T-1, ..., 0
    # 需要知道正向时的 s_t(必须缓存或重算)
    s_t = cached_s[t]
    grad_s_prev, grad_theta_t = F_gradient(s_t, grad_s, theta)
    grad_theta += grad_theta_t
    grad_s = grad_s_prev     # 传递给前一时间步
# 最终 grad_s 即为 ∂L/∂s_0

3.3 反向图结构:仍然是一个 Loop 算子

对照正向的 Loop 算子,反向图也用一个 Loop 算子实现,但迭代方向隐含在循环体内部。构造对照如下:

正向 Loop 属性 反向 Loop 对照属性
最大迭代次数 $M$ 相同$M$
但实际步数取正向执行步数 $T$
初始条件 cond 恒为 true
(因为我们必须执行完所有反向步)
初始状态 $s_0$ 初始梯度为 $\partial L/\partial s_T$
(来自损失)
循环体 $F$ 循环体 $G$
其内部包含 $F$ 的梯度计算
循环体输入:$(t, s_t)$ 循环体输入:$(t_{rev}, grad\_s_{t+1}, acc\_grad\_theta)$
循环体输出:$(cond\_out, s_{t+1})$ 循环体输出:$(true, grad\_s_t, new\_acc\_grad\_theta)$
正向迭代顺序:
$t$ 递增
反向迭代顺序:
通过 $t = T-1 - t_{rev}$ 映射到 $t$,实现递减

形式化对照(ONNX 风格伪代码)

# 正向 Loop 节点
Loop(M, cond_initial, s0, body_F) -> (sT, cond_final)

# 反向 Loop 节点(自动构造)
# 输入:grad_y = ∂L/∂sT, 缓存的正向状态序列 [s0, s1, ..., sT]
# 输出:grad_s0, grad_theta
Loop(T, true, (grad_y, 0, cache), body_G) -> (grad_s0, grad_theta, _)

其中 body_G 的对照内容:

def body_G(t_rev, cond_in, grad_s_next, acc_grad_theta, cache):
    t = T - 1 - t_rev                # 映射到正向时间步
    s_t = cache[t]                   # 从缓存获取正向状态
    # 调用 F 的梯度子图
    grad_s_t, grad_theta_t = F_grad(s_t, grad_s_next, theta)
    new_acc = acc_grad_theta + grad_theta_t
    return (true, grad_s_t, new_acc, cache)  # cond_out = true

3.4 对 cond 的梯度对照(再次强调)

正向 Loop 中的 cond(初始条件和每次迭代输出的 cond_out)仅控制循环是否继续。反向中,这些布尔变量同样不产生梯度:

$$\frac{\partial L}{\partial \text{cond}} = 0, \quad \frac{\partial L}{\partial \text{cond\_out}} = 0$$

对照正向中 cond 不参与数值计算,反向中它们被完全忽略。

四、对照表总览

概念 正向 If 反向 If
条件 cond(布尔) 相同的 cond
分支 then_branch
else_branch
对应的 then_gradient
else_gradient
执行 根据 cond 选一个 根据同一个 cond
选对应的梯度分支
概念 正向 Loop 反向 Loop
迭代次数 动态 $T$
(由 cond$M$ 决定)
相同 $T$
初始条件 用户提供的 cond 恒为 true
初始状态 $s_0$ $\partial L/\partial s_T$
(来自损失)
循环体 $F$ $G$
状态更新 顺序:$s_t \rightarrow s_{t+1}$ 逆序:$\partial L/\partial s_{t+1} \rightarrow \partial L/\partial s_t$
(通过索引映射)
输出 最终状态 $s_T$ 初始梯度 $\partial L/\partial s_0$
和参数梯度 $\partial L/\partial \theta$

五、构造方法中的对照细节

在实际的 ONNX GradientBuilder 中,构造反向图时:

  1. 遍历正向图节点,遇到 IfLoop 节点时,递归处理。
  2. 对于 If
    • 记录正向 cond 的节点输出。
    • 分别进入 then_branchelse_branch 子图,构建它们的梯度子图(递归)。
    • 生成一个新的 If 节点,条件输入为记录的 condthen_branchelse_branch 分别设置为刚刚构建的两个梯度子图。
  3. 对于 Loop
    • 记录正向循环的实际迭代次数 $T$(可以通过分析 cond 输出或额外保存)。
    • 记录正向循环中每个迭代的中间状态(或标记哪些需要缓存)。
    • 构建反向循环体子图 $G$,其内部包含:
      • 根据反向迭代计数计算正向时间步索引。
      • 从缓存中读取正向状态 $s_t$
      • 调用正向循环体 $F$ 的梯度子图(递归构建)。
      • 累加参数梯度。
    • 生成一个新的 Loop 节点,最大迭代次数设为 $T$,初始条件为 true,初始状态为 $(\partial L/\partial s_T, 0)$,循环体为 $G$
  4. 所有控制流节点输出的梯度(如 cond_final)均被忽略或置零。

这种构造方式保证了反向图与正向图在控制流结构上完全对照,只是计算内容从“求值”变成了“求梯度”。

六、实际应用中的对照示例

考虑一个简单例子:正向中根据 cond 决定对张量做平方或开方,然后循环多次。对照关系如下:

正向伪代码

x = input
if cond:
    y = x * x          # then_branch
else:
    y = sqrt(x)        # else_branch

s = y
for i in range(3):
    s = s + 1          # 简单循环体

反向图伪代码(对照)

grad_y = grad_output
if cond:               # 相同的 cond
    grad_x = 2 * x * grad_y   # then_gradient
else:
    grad_x = 0.5 / sqrt(x) * grad_y  # else_gradient

grad_s = grad_y        # 循环初始梯度
grad_theta = 0
for i in reversed(range(3)):   # 反向循环
    # 循环体梯度:加法梯度就是 1
    grad_s_prev = grad_s * 1
    # 无参数,所以无 grad_theta 更新
    grad_s = grad_s_prev
# 最终 grad_s 即为 grad_y 不变,因为加法的梯度是恒等映射

可以看到,反向完全遵循正向的控制流结构,只是内部运算替换为梯度运算。

七、总结

本文通过正向与反向的逐项对照,揭示了 ONNX IfLoop 算子自动微分的本质:

  • 结构镜像:反向图使用相同的控制流算子(If/Loop),且控制条件(cond)和迭代次数与正向执行一致。
  • 子图替换:正向子图被替换为对应的梯度子图,梯度子图的构造递归进行。
  • 方向逆转:循环的反向是时间反序的,但通过循环体内部的索引映射实现,外部仍表现为一个 Loop 算子。
  • 不可微输入:所有布尔控制变量(cond)的梯度恒为零,对照正向中它们仅用于决策。

理解这种对照关系,不仅有助于正确导出可训练的 ONNX 模型,也为设计支持动态控制流的自定义算子梯度提供了理论指导。随着 ONNX 训练生态的不断完善,掌握控制流的自动微分将成为深度学习系统工程师的核心能力之一。若想了解更多关于算子实现与框架集成的技术细节,欢迎在云栈社区交流探讨。




上一篇:Claude Code /loop 命令实战指南:实现自动化监控与自主迭代编程
下一篇:咨询业AI转型指南:2025-2027生存法则与三大现实路径
您需要登录后才可以回帖 登录 | 立即注册

手机版|小黑屋|网站地图|云栈社区 ( 苏ICP备2022046150号-2 )

GMT+8, 2026-4-7 18:15 , Processed in 0.965103 second(s), 42 queries , Gzip On.

Powered by Discuz! X3.5

© 2025-2026 云栈社区.

快速回复 返回顶部 返回列表