找回密码
立即注册
搜索
热搜: Java Python Linux Go
发回帖 发新帖

4607

积分

0

好友

604

主题
发表于 昨天 08:09 | 查看: 9| 回复: 0

ADM-v2论文标题页

想快速了解这项前沿研究?一张漫画带你把握核心!

ADM-v2核心原理漫画解读

你是否也在为离线强化学习中“模型跑几步就出错”的老大难问题而烦恼?想让智能体在学到的环境模型里自由探索,却总被不断累积的预测误差“打回原形”?最近,一项来自南京大学的研究ADM-v2,通过一个巧妙的架构“解耦”设计,成功让动力学模型实现了从1步到1000步的稳定精准预测,并在多个标准测试集上刷新了纪录,最高性能提升超过12.8%。

❓ 核心问题:为什么动力学模型不敢“看远”?

在离线强化学习(Offline RL)中,我们的目标是让智能体仅凭一批固定的历史交互数据,就能学会优秀的策略。基于模型的方法提供了一个诱人的思路:先学习一个环境动力学模型,然后在这个“模拟器”里安全、高效地进行策略探索和评估。

理想很丰满,现实却很骨感。绝大多数现有的动力学模型都有一个致命弱点:长期预测能力极差。它们通常采用“引导预测”的方式:用当前预测出的状态,作为下一步预测的输入。这就像用一张复印了多次、已经模糊的复印件去复印下一次,误差会像滚雪球一样累积,导致模型仅仅展开(roll-out)几步后,预测结果就与现实严重偏离。

因此,现有方法大多只能进行短时间或分支化的roll-out,这严重限制了策略的探索深度和评估的准确性。有研究表明,被截断的roll-out会阻碍策略探索到一些关键的状态边界。我们迫切需要一种能够进行“全时域”精准预测的动力学模型。

ADM-v2并行任意步长Roll-out架构示意图
图:ADM-v2的核心推理流程PARoll,实现了并行的多步直接预测,是支撑全时域 rollout 的关键。

🚀 原理拆解:从ADM到ADM-v2的“解耦”进化

💡 ADM初代:直接预测的曙光

要理解ADM-v2,得先看它的前身ADM。它的核心思想是用“直接预测”替代“引导预测”。传统模型预测 $s_{t+1}$ 需要 $s_t$,而ADM直接学习一个函数:给定初始状态 $s_t$ 和一个长度为 $m$ 的动作序列 ${a_t, ..., a_{t+m-1}}$,直接预测 $m$ 步之后的状态 $s_{t+m}$ 和奖励 $r_{t+m}$

这相当于跳过了中间所有可能出错的步骤,从源头上避免了误差的逐步累积。ADM使用GRU(门控循环单元)来建模这个时序过程,输入是重复拼接的 $s_t$ 与动作序列。

💡 ADM-v2的“神来之笔”:状态与GRU解耦

ADM的设计有一个小瑕疵:为了匹配动作序列长度,初始状态 $s_t$多次复制并与每个动作拼接,再一起送入GRU。这导致GRU的隐藏状态与初始状态强耦合,不够灵活,也难以实现并行化计算。

ADM-v2做出了一个关键改进:将初始状态编码为一个固定的隐藏向量 $h_t$,并直接将其作为GRU的初始隐藏状态。 此后,GRU的每一步前向计算只接收当前动作 $a_{t+k}$,而不再重复接收 $s_t$

这个“解耦”设计带来了三大好处:

  1. 提升预测鲁棒性:减少了初始状态 $s_t$ 的微小扰动对多步预测结果的连锁影响。
  2. 支持并行计算:为后续高效的并行推理算法 PARoll 铺平了道路。
  3. 结构更清晰:模型学习到的隐藏表示 $h$ 更能专注于刻画时序动态,而非重复记忆初始状态。

💡 核心架构:编码器-ADM2单元-解码器

ADM-v2由三个核心组件构成:

  • 状态编码器:一个5层MLP,将观测状态 $s_t$ 编码为隐藏向量 $h_t$
  • ADM2单元:核心是GRU,接收上一步的隐藏状态 $h_{t+k-1}$ 和当前动作 $a_{t+k}$,输出新的隐藏状态 $h_{t+k}$
  • 转移解码器:另一个5层MLP,将 $h_{t+m}$ 解码为下一状态 $s_{t+m}$ 和奖励 $r_{t+m}$ 的预测分布(均值和方差)。

编码器与解码器网络结构图
图:编码器/解码器采用的5层MLP结构,包含残差连接和层归一化,是模型稳定训练的基础。

模型的训练目标非常直观:最大化所有可能预测步长 $m$ (从1到预设的最大步长 $M$)的预测对数似然。这迫使模型同时学好短期和长期的动态。

💡 并行推理引擎:PARoll

拥有了支持直接预测的ADM-v2,如何高效地进行全时域 roll-out 并估计不确定性呢?论文提出了 PARoll(并行任意步 Roll-out)算法

其核心思想是并行维护 $K$ 条独立的预测时间线。每条时间线从一个不同的历史状态开始(例如时间线1从 $s_{t-0}$ 开始,时间线2从 $s_{t-1}$ 开始...),然后并行地执行ADM2单元的前向计算。这样,在每一步,我们都能同时得到 $K$ 个基于不同历史起点的预测结果。

这带来了两个巨大优势:

  1. 高效不确定性估计:这 $K$ 个预测结果的方差,自然构成了对模型在当前状态下预测不确定性的估计 $\hat{u}$,无需额外集成多个模型。
  2. 高吞吐量:所有 $K$ 条时间线的计算都是并行的,极大提升了采样效率。

📊 实验验证:数据与可视化双重震撼

🏆 SOTA性能:全面领先

理论再精妙,也需要数据验证。ADM-v2在权威的D4RL和更具挑战性的NeoRL基准测试上,与众多基于模型和无模型的SOTA方法进行了对比。

D4RL MuJoCo任务离线学习归一化得分表
表:在D4RL MuJoCo任务上,基于ADM-v2的离线策略优化方法ADM2PO-fh取得了全面领先的归一化得分。

NeoRL MuJoCo任务离线学习归一化得分表
表:在NeoRL任务上,ADM2PO-fh同样在所有难度级别上大幅领先,验证了其强大的泛化能力。

关键结论:基于ADM-v2并使用全时域Roll-out的策略优化方法ADM2PO-fh,在D4RL和NeoRL上均达到了新的SOTA,平均性能分别显著提升。更重要的是,只有ADM-v2能够在全时域(长达1000步)Roll-out下实现稳定的性能提升,其他模型一旦进行长时域Roll-out,性能都会因误差累积而严重下降。

🔬 不确定性量化:可靠性的标尺

一个优秀的动力学模型,不仅要预测得准,还要知道自己什么时候可能不准。ADM-v2通过PARoll自然产生的多个预测,可以计算出不确定性估计 $\hat{u}$

各模型不确定性量化对比散点图
图:在hopper-medium-v2任务上,ADM-v2的模型预测误差与其不确定性估计表现出高达0.928的相关系数,远超其他基线模型。

这个高相关性意味着:当模型预测误差大时,其不确定性估计值也高;当预测准确时,不确定性估计值则低。这使得 $\hat{u}$ 可以作为一个可靠的惩罚项,在策略学习时自动规避模型不熟悉的危险区域,极大地提升了离线学习的安全性。

🎨 隐藏状态可视化:学到的“世界模型”

为了深入理解ADM-v2为何优秀,作者对模型学习到的环境状态编码进行了可视化。

真实状态与模型隐藏状态t-SNE可视化
图:使用t-SNE对真实状态和模型隐藏状态进行降维可视化。ADM-v2学习到的隐藏状态与真实状态具有高度相似的结构化聚类,且时间演化轨迹清晰。

这证明ADM-v2的编码器成功学习到了环境的结构化、时序化的语义表示,这是其能进行精准长期预测的内在原因。

⚖️ 客观评价与未来展望

优势总结:

  1. 全时域预测:首次在标准基准上实现可靠的长达1000步的模型Roll-out。
  2. 高效并行:PARoll算法在保证精度的同时,提供了高推理吞吐量。
  3. 即插即用:ADM-v2学到的模型可直接用于策略评估和策略优化,均能取得领先性能。
  4. 理论扎实:提供了比单步模型更紧的性能边界理论证明。

局限与挑战:

  • 计算开销:相比最简单的集成模型,ADM-v2的参数量稍大,训练和推理成本更高,但其带来的性能提升是显著的。
  • 任务范围:当前实验集中于MuJoCo连续控制任务,在更复杂的视觉输入或离散动作空间中的表现有待进一步验证。
  • 超参数:最大直接预测步长 $M$ 是一个需要根据任务调节的超参数。

🌟 总结与思考

ADM-v2的研究为强化学习领域,特别是基于模型的离线学习方向,提供了新的思路。它证明,通过精巧的架构设计(状态与GRU解耦),动力学模型完全有能力进行精准的长期预测,从而解锁全时域的策略探索与评估。

这不仅是一个模型的改进,更是一种范式的推进:从过去的“如何限制模型在短时内少犯错”,转向“如何让模型具备长远视野并保持准确”。这项研究对机器人学习、自动驾驶等需要在安全仿真环境中进行大量训练的场景,具有重要的应用潜力。一个高保真、长视界的“模拟世界”,将是训练更强大、更安全AI智能体的关键基础设施。

你认为ADM-v2这种“直接预测”的范式,最可能率先在哪个实际AI应用场景中产生重要影响?是让机械臂学习更复杂的操作序列,还是让游戏AI进行更长期的战略规划?欢迎在云栈社区的相关板块分享你的见解,与更多开发者一同探讨人工智能的未来。




上一篇:人形机器人公司高薪招聘引热议:日薪50万招具身智能首席科学家
下一篇:深度剖析Go信号处理:从os/signal到runtime的集成全景
您需要登录后才可以回帖 登录 | 立即注册

手机版|小黑屋|网站地图|云栈社区 ( 苏ICP备2022046150号-2 )

GMT+8, 2026-4-7 18:23 , Processed in 0.862603 second(s), 41 queries , Gzip On.

Powered by Discuz! X3.5

© 2025-2026 云栈社区.

快速回复 返回顶部 返回列表