← 返回信息流
技术博客arXiv cs.CL·2 天前

DLLM-JEPA:结合掩码扩散与联合嵌入预测架构

原标题:DLLM-JEPA: Joint Embedding Predictive Architectures for Masked Diffusion Language Models

速览

DLLM-JEPA将联合嵌入预测架构(JEPA)与掩码扩散语言模型结合,解决了LLM-JEPA需要显式多视图数据和两次梯度前向传播的高成本问题。利用扩散模型的双向注意力机制,该方法通过不同掩码率生成语义视图,单次前向传播即可降低33%的训练FLOPs。实验显示,该方法在GSM8K等任务上显著优于仅微调扩散模型,并展现出精度与泛化能力的双重优势。

AI 深度解读

DLLM-JEPA:联合嵌入预测架构赋能掩码扩散语言模型

背景

在计算机视觉领域,联合嵌入预测架构(Joint Embedding Predictive Architectures, JEPAs)已经彻底重塑了自监督表示学习的方式。JEPAs 的核心思想是通过预测嵌入空间中的潜在表示来学习数据的结构化特征,而非直接重建原始像素。

近期,研究人员尝试将这一强大的范式迁移到自然语言处理领域,推出了 LLM-JEPA。然而,LLM-JEPA 直接沿用了因果注意力(causal-attention)子结构,这导致其继承了两个显著的计算与数据成本:

  1. 显式多视图数据依赖:它需要成对的多视图数据(例如文本-代码对)来构建预测任务。
  2. 高昂的梯度计算开销:在每一步训练中,它需要执行两次携带梯度的前向传播(forward passes)。

这些限制阻碍了 JEPAs 在大规模语言模型中的高效应用。为了解决这一问题,本文提出了 DLLM-JEPA,旨在通过结合掩码扩散语言模型(Masked Diffusion Language Models)来一次性消除上述两项成本。

核心内容

DLLM-JEPA 的核心创新在于将 JEPAs 与掩码扩散语言模型相结合。这种结合利用了扩散模型的双向注意力机制,从而在无需显式多视图数据的情况下,通过不同的掩码率(masking rates)从同一输入中生成两个语义上截然不同的视图。

1. 架构优势与效率提升

  • 消除数据依赖:由于扩散模型的双向特性,模型可以通过改变掩码比例自然产生不同的输入视角,无需像 LLM-JEPA 那样依赖额外的文本-代码对等显式配对数据。
  • 降低计算成本:该架构支持单次携带梯度的前向传播。相比于 LLM-JEPA,DLLM-JEPA 将训练所需的浮点运算次数(FLOPs)降低了 33%

2. 性能表现

在广泛的评估中,DLLM-JEPA 在所有测试的(任务,架构)组合中均优于仅使用扩散模型进行微调(diffusion-only fine-tuning)的方法。具体性能提升如下:

  • GSM8K 数学推理任务
    • 在 LLaDA-8B 模型上,准确率提升了 +18.7 个百分点(pp)。
    • 在 Dream-7B 模型上,准确率提升了 +11.4 个百分点
  • 其他任务:在 Spider(数据库生成)、NL-RX-SYNTH(自然语言到代码合成)和 Django(Web 开发)任务中,也观察到了持续的正向收益。

3. “双重胜利”(Dual-win)特性与机制分析

DLLM-JEPA 展现出一种罕见的“双重胜利”属性,即在提升特定任务性能的同时,保持甚至增强模型的通用能力。以 LLaDA-8B(Wide-t 配置)为例:

  • 任务精度提升:GSM8K 准确率从 65.2% 提升至 67.1%(+1.8 pp)。
  • 通用语言建模能力增强:在保留的 Wikitext 数据集上,困惑度(loss)低于预训练基线模型。
  • 知识保留:在三个不同的微调种子下,MMLU(大规模多任务语言理解)的准确率保持在基线水平。

相比之下,使用 L2-to-base 参数锚点(parameter anchor)的方法虽然能匹配基线准确率,但在特定任务上没有任何增益。

4. 层级的几何-功能漂移解耦

通过逐层探测(layer-wise probing),研究揭示了 DLLM-JEPA 的工作机制:

  • 几何-功能漂移解耦:微调后的骨干网络在几何空间上距离预训练权重更远(意味着发生了更大的表示变化),但在功能上却“遗忘”得更少(即在 Wikitext 上的表现更好)。
  • 中间层集中效应:这种解耦现象主要集中在 Transformer 的中间层。
  • 泛化性:这一模式在 Dream-7B 模型中同样出现,表明该现象并非特定于某一种骨干网络,而是 DLLM-JEPA 架构的普遍特性。

关键要点

  • 架构创新:DLLM-JEPA 首次将 JEPAs 与掩码扩散语言模型结合,利用双向注意力机制替代因果注意力。
  • 成本削减:通过单次前向传播和无需显式多视图数据,将训练 FLOPs 降低 33%。
  • 显著性能增益:在 LLaDA-8B 和 Dream-7B 上,GSM8K 准确率分别提升 18.7 pp 和 11.4 pp。
  • 双重胜利:实现了任务特定精度提升与通用语言建模能力(Wikitext loss)增强及知识保留(MMLU)的同步优化。
  • 机制洞察:揭示了微调过程中“几何漂移”与“功能遗忘”的解耦现象,且该现象主要发生在 Transformer 中间层。

意义与影响

DLLM-JEPA 的提出标志着自监督表示学习在语言模型领域的一个重要转折点。它证明了扩散模型的双向特性不仅可以用于生成任务,还可以高效地支持预测性表示学习。

  1. 打破数据瓶颈:通过消除对显式多视图数据(如文本-代码对)的依赖,DLLM-JEPA 降低了高质量配对数据获取的难度,使得自监督学习更加通用和可扩展。
  2. 计算效率优化:33% 的训练 FLOPs 节省对于大规模语言模型的训练至关重要,降低了算力门槛。
  3. 解决灾难性遗忘难题:其“双重胜利”特性为解决微调过程中常见的“能力退化”或“灾难性遗忘”问题提供了新的架构思路。通过几何-功能漂移的解耦,模型能够在适应新任务的同时更好地保留预训练知识。
  4. 范式拓展:这项工作为 JEPAs 从视觉向语言领域的迁移提供了更优的解决方案,可能推动后续更多基于扩散架构的自监督学习方法的发展。
查看原文 →arxiv.org