轨迹感知强化学习:为扩散语言模型优化推理路径
速览
针对扩散语言模型(dLLMs)强化学习中树搜索计算昂贵的问题,研究者提出CAPR算法。该方法通过总结去噪轨迹为紧凑路径状态,并利用缓存轨迹生成廉价分支,实现块级价值监督。实验显示,CAPR在Sudoku、GSM8K等任务上达到新SOTA,且计算成本显著低于传统树搜索方法。
AI 深度解读
阅读轨迹,引导路径:扩散语言模型的轨迹感知强化学习
背景
扩散大型语言模型(Diffusion Large Language Models, dLLMs)代表了一种不同于传统自回归(Autoregressive, AR)语言模型的新兴生成范式。与 AR 模型逐个 token 地生成文本不同,dLLMs 通过并行地对序列中的多个位置进行“去掩码”(unmasking)和修订来生成响应。这种迭代式的去噪过程会在模型内部留下丰富的“去噪轨迹”(denoising trace),详细记录了哪些 token 变得确定(confident),哪些仍然不稳定,以及承诺(commitments)是在何时形成的。
尽管这一轨迹蕴含了丰富的训练信号,但现有的 dLLM 强化学习(RL)方法对其利用非常有限。目前主流的两种轨迹生成策略各有优劣:
- 平坦轨迹(Flat Rollouts):计算成本低,但通常只给整个轨迹分配一个单一的结果奖励,缺乏细粒度的反馈。
- 树状轨迹(Tree Rollouts):通过分支部分轨迹并将叶子节点的奖励向上传播,提供更精细、可验证的训练信号。然而,这种方法计算密集,开销巨大。
这就引出了一个核心问题:能否利用去噪轨迹本身提供类似树状的监督信号,而无需承担树状展开的高昂计算成本?
核心内容
为了解决上述效率与精度的权衡问题,研究团队提出了 CAPR(Cached-Amortized Path Refinement,缓存-摊销路径细化)算法。这是一种专为 dLLM 设计的强化学习算法,旨在通过总结去噪轨迹来生成紧凑的路径状态,从而在保持低计算成本的同时恢复树搜索的粒度。
CAPR 的工作原理
CAPR 的核心机制包含以下几个关键步骤:
- 路径状态摘要:算法将去噪轨迹总结为一个紧凑的“路径状态”(path state)。
- 缓存与兄弟节点生成:利用缓存的轨迹状态,生成廉价的“兄弟节点”(sibling)延续,避免了重复的高成本计算。
- 块级价值头训练:训练一个块级(block-level)价值头,用于局部块级监督。
块级去掩码与奖励重分配
在具体的执行层面,CAPR 采用了一种块级去掩码调度(block-wise unmasking schedule)。在此过程中:
- 算法记录路径状态和块进度特征。
- 根据每个块中揭示(revealed)的 token,将最终的结果奖励重新分配给各个块。
这种机制使得价值头能够将一个稀疏的最终奖励转化为块级的 PPO(Proximal Policy Optimization)权重。通过这种方式,CAPR 成功恢复了大部分树搜索的粒度,同时避免了完整的树展开。
性能表现
在标准设置下,CAPR 显著降低了 rollout 生成的成本:
- 约为平坦轨迹成本的 0.75 倍。
- 约为树状轨迹成本的 0.6 倍。
在多个基准测试中,包括 4x4 数独(Sudoku)、Countdown 游戏、GSM8K 和 Math500,CAPR 在密集(dense)和混合专家(Mixture-of-Experts, MoE)的 LLaDA 骨干网络上,于 256 和 512 token 的预算限制下,为 RL 微调的 dLLMs 设立了新的最先进(State-of-the-Art)水平。特别是在数独任务中,CAPR 以不到每步计算量三分之一的代价,达到了最强的树状基线性能。
关键要点
- 填补技术空白:现有 dLLM 强化学习方法未能充分利用去噪轨迹中的丰富信号,CAPR 通过引入轨迹感知机制解决了这一问题。
- 效率与精度的平衡:CAPR 通过“缓存-摊销”机制,以低于树状轨迹 40% 的成本(即 0.6x),实现了接近树状搜索的细粒度监督能力。
- 块级奖励重分配:不同于传统方法给整个序列单一奖励,CAPR 根据块级去掩码进度,将最终奖励动态分配给各个块,训练出更精准的价值头。
- 广泛的适用性:该方法在逻辑推理(数独、Countdown)和数学推理(GSM8K, Math500)任务上均表现出 SOTA 性能,验证了其在不同推理场景下的有效性。
- 模型兼容性:CAPR 不仅适用于密集架构,也适用于混合专家(MoE)架构的 LLaDA 模型,展示了良好的架构适应性。
意义与影响
CAPR 的提出对扩散语言模型的发展具有重要意义。首先,它证明了去噪轨迹本身可以作为高效的监督信号来源,无需依赖昂贵的树搜索结构即可获得细粒度的训练反馈。这为 dLLM 的强化学习训练提供了一条新的、更具成本效益的路径。
其次,CAPR 显著降低了 RL 微调的计算门槛。通过将 rollout 生成成本降低至树状方法的 60%,使得在更大规模模型或更长序列上进行精细化的 RL 训练成为可能,有助于推动 dLLM 在复杂推理任务中的实际应用。
最后,这一工作丰富了 dLLM 的训练方法论。它表明,通过巧妙地利用模型内部生成的动态信息(如去噪轨迹),可以设计出比传统平坦或树状方法更智能的奖励分配机制,从而提升模型在需要多步推理和决策的任务中的表现。随着 dLLM 架构的成熟,此类轨迹感知强化学习技术有望成为标准训练流程的重要组成部分。
