Jiahong 的个人博客

凡事预则立,不预则废


  • Home

  • Tags

  • Archives

  • Navigation

  • Search

RS——生成式推荐

  • 参考链接:
    • (TIGER)Recommender Systems with Generative Retrieval, NeurIPS 2023, Google: 一篇比较早,也比较基础的工作
    • OneRec: Unifying Retrieve and Rank with Generative Recommender and Preference Alignment, 202502, KuaiShou
    • (COBRA)Sparse Meets Dense: Unified Generative Recommendations with Cascaded Sparse-Dense Representations, 202503, Baidu
    • (RARE)Real-time Ad retrieval via LLM-generative Commercial Intention for Sponsored Search Advertising, 202504, Tencent
      • 相关博客:腾讯搜索广告生成式检索
    • 一文汇总:LLM应用到推荐系统的各类玩法总结
    • Slow Thinking for Sequential Recommendation, 2025, RUC
      • 相关博客:STREAM-Rec: 推荐系统实现慢思考推理
    • LEARN: Knowledge Adaptation from Large Language Model to Recommendation for Practical Industrial Application, AAAI 2025, Kuaishou
      • 相关博客:AAAI’25 | 快手LEARN:使用LLM做特征增强用于电商广告推荐
    • Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for Generative Recommendations
    • HLLM: Enhancing Sequential Recommendations via Hierarchical Large Language Models for Item and User Modeling, 202409, ByteDance
    • Unlocking Scaling Law in Industrial Recommendation Systems with a Three-step Paradigm based Large User Model, 202502, Alibaba

整体说明

  • 传统深度学习推荐模型(Deep Learning Recommendation Models,DLRMs)范式主要是“Retrieval-Ranking”模式
  • 目前大模型在推荐领域的实践大致分为三类:
    • 第一类:保留 DLRMs 范式,仅利用通用的 LLM 做知识增强,或从领域数据中提取表征信息等
    • 第二类:仍保留 DLRMs 范式,部分重构 Retrieval 阶段或者 Ranking 阶段(或者在 Retrieval 阶段增加一路生成式召回通道)
    • 第二类:直接修改 DLRMs 范式,将输入和输出都重新构造为 LLM 模式,直接利用重构后的领域数据做预训练和微调

NLP——EfficientAttention-Survey(THU-2025)

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:(Attention Survey of THU 2025)Efficient Attention Mechanisms for Large Language Models: A Survey, THU, 20250807
      • 有豆包基金的支持

Paper Summary

  • 整体总结:
    • 本文是一篇高效 Attention 的 Survey
    • 整体涉及到了 Linear Attention 和 Sparse Attention
  • 问题提出:
    • Transformer-based 架构已成为 LLM 的主流 backbone
    • 自注意力机制存在的二次时间和内存复杂度,是实现高效长上下文建模的根本障碍
  • 当前解法:
    • 为解决这一限制,近期研究引入了两大类高效注意力机制
    • 机制一:Linear Attention 方法:
      • 通过核近似、循环公式或快速权重动态来实现线性复杂度,从而以更低计算开销实现可扩展推理
    • 机制二:稀疏注意力 (Sparse Attention) 技术:
      • 基于固定模式、分块路由或聚类策略,将注意力计算限制在选定的 Token 子集上,在保持上下文覆盖的同时提升效率
  • 本综述主要工作如下:
    • 第一:对上述机制的发展进行了系统而全面的概述,结合了算法创新和硬件
    • 第二:分析了高效注意力在大规模预训练语言模型中的整合方式
      • 包括完全基于高效注意力构建的架构,以及结合局部和全局组件的混合设计

综述引言和一些讨论

  • Transformer-based 架构 (2017) 已成为现代 LLM 事实上的 backbone 选择
    • 标准的自注意力机制仍然是一个显著的计算瓶颈,其时间和内存复杂度相对于输入序列长度呈二次方增长
    • 这一限制对 LLM 在处理日益增长的长上下文时,同时保持高性能和高效率,提出了巨大挑战
  • 为解决此问题,出现了两个主要方向来降低 softmax 注意力 (softmax Attention) 的时间和空间复杂度
    • 第一种机制是Linear Attention (2020; 2020; 2021; 2021; 2023; 2024):
      • 通过将 softmax 注意力重新参数化或近似为线性操作来降低注意力复杂度
    • 第二种候选方案是稀疏注意力 (Sparse Attention) (2019; 2021; 2021; 2022; 2024):
      • 基于固定或动态的稀疏模式,将注意力计算限制在完整键空间的某个子集上
    • 虽然这两种方法都旨在提高效率,但它们在公式、设计选择和硬件影响方面存在显著差异
  • 本综述全面回顾了高效注意力 (Efficient Attention) 机制的最新进展,同时关注算法原理和系统级实现
    • 在此基础上,论文还研究了采用这些高效注意力的预训练 LLM
  • 论文将线性注意力方法分为三大范式
    • 范式一:核化线性注意力 (kernelized linear attention):
      • 通过特征空间内的内积来近似 softmax 核,借助随机特征映射 (2020; 2021) 或固定正映射 (2020) 实现线性复杂度
    • 范式二:带有遗忘机制 (forgetting mechanism) 的循环线性注意力 (recurrent linear attention):
      • 引入了位置感知的循环,通过数据无关 (2021) 或数据相关 (2023; 2023) 的衰减机制来控制过去信息随时间衰减的方式,从而实现对长序列的建模
    • 范式三:基于快速权重 (fast-weight) 和元学习 (meta-learning) 的公式:
      • 将线性注意力重新解释为在线优化的记忆更新过程,其中如 DeltaNet (2021; 2024) 和 TTT (2024) 等模型将快速学习动态直接整合到状态演化中
  • 作者还探讨了线性注意力在硬件友好的表示形式(包括并行、循环和分块形式)重点分析了它们在计算复杂度、内存占用以及与训练或推理工作流兼容性方面的权衡
  • 论文将稀疏注意力分为固定模式稀疏度 (fixed-pattern sparsity)、分块稀疏度 (block sparsity) 和基于聚类的稀疏度 (clustering-based sparsity)
    • 固定模式稀疏度采用静态的 Token-level 掩码,如滑动窗口、扩张位置或指定的全局 Token ,提供了简单性和硬件友好性 (2020; 2019; 2023; 2023)
    • 分块稀疏度在分块粒度上选择或路由注意力,可以通过启发式评分 (2021; 2022; 2025)、可训练的选通机制 (2024; 2024) 来实现,从而实现结构化的内存访问和高效的 GPU 利用率
    • 基于聚类的稀疏度使用基于内容或位置感知的分组方法(如 k-means 或 LSH)来组织键值对,以降低内存开销并促进基于语义感知的检索 (2023; 2020; 2024)
  • 最后,本综述还讨论了将稀疏模式扩展到编码器风格模型的双向稀疏设计
    • 这些方法在稀疏粒度、选择机制以及与 FlashAttention (2023) 等硬件原语的契合度上有所不同,共同代表了现代 Transformer 中实现高效长上下文建模的基础
  • 近期已有文章将高效注意力机制整合到工业级预训练语言模型中
    • 这包括纯高效架构:如线性注意力和状态空间模型,以及结合了局部和全局注意力模式的混合设计
      • 像 EAGLE (2024)、Falcon Mamba (2024) 和 MiniCPM4 (2025) 这样的模型展示了纯线性或稀疏方法在数十亿参数规模上的可扩展性,在提供强性能的同时实现了恒定时间推理
    • 同时,混合模型 (2020; 2022; 2024; 2025; 2025; 2025) 交错使用密集、稀疏和局部注意力,以平衡计算效率和上下文建模能力
      • 反映了现代 LLM 中朝着组合化、硬件感知注意力设计发展的趋势
  • 论文的目标是为理解注意力机制在算法和硬件双重约束下的演进,以及这些设计如何被整合到可扩展的 LLM 架构中,提供一个统一的框架
    • 通过将理论洞见与实际实现相结合,作者希望本综述能为致力于高效且可部署模型设计的研究者和从业者提供有价值的参考
  • 论文将讨论安排如下:
    • 第 2 节:介绍Linear Attention,涵盖其在不同模型世代中的演变、相关的设计原理以及对硬件实现的影响
    • 第 3 节:介绍稀疏注意力 (Sparse Attention),对稀疏模式进行分类,分析部署场景,并提供实用的系统级设计建议
    • 第 4 节:回顾整合了高效注意力机制的预训练语言模型 (Pretrained Language Models),包括统一高效架构和整合了局部、稀疏、密集注意力的混合模型
    • 第 5 节:对未来的发展方向进行展望 (Outlook),讨论算法和硬件对齐研究中的开放挑战和潜在进展

Linear Attention

  • 传统的线性注意力方法旨在以与序列长度成线性关系的方式近似基于 softmax 的注意力机制
    • 核心思想是用一种基于核函数(kernel)的注意力权重近似来替代计算代价高昂的 softmax 计算
  • 在标准的自注意力中,每个输出是值 \(V\) 的加权和,权重由查询-键相似度经过 softmax 得到:
    $$
    \text{Attn}(Q,K,V)=\text{softmax}(QK^{\top})V, \tag{1}
    $$
    • \(Q,K,V \in \mathbb{R}^{L \times d}\)
    • (\(L\) 是序列长度
    • \(d\) 是每个头的模型维度)
  • softmax 为查询 \(q_i\) 和键 \(k_j\) 产生权重 \(\propto \exp(q_i^{\top}k_j)\)
  • 核化线性注意力则寻找一个特征映射 \(\phi(\cdot)\),使得 softmax 核函数可以在诱导出的特征空间中通过一个简单的点积来近似:\(\exp(q^{\top}k) \approx \phi(q)^{\top} \phi(k)\)(2019)
    • 给定这样的 \(\phi\),可以将注意力重写为:
      $$
      O = \frac{\phi(Q)(\phi(K)^{\top}V)}{\phi(Q)(\phi(K)^{\top}\mathbf{I})} \tag{2}
      $$
  • 由于 \(\exp(\cdot)\) 的值域是非负的,\(\phi(\cdot)\) 通常被选择为产生非负输出,同时应用归一化除数来模拟 softmax 概率
  • 这种重新表述将复杂度从 \(O(L^2 d)\) 降低到 \(O(L d^2)\)(或者在适当的特征降维下甚至可以达到 \(O(L d)\)),因为昂贵的 \(L \times L\) 注意力矩阵从未显式形成
  • Linear Transformer (2020)
    • 用一个固定的正特征映射替换了 softmax 核函数
    • 在实践中,他们设置 \(\phi(x) = \operatorname{ELU}(x) + 1\)
    • \(\operatorname{ELU}(\cdot)\) 在整个定义域内可微,与朴素的 \(\operatorname{ReLU}(\cdot)\) 函数相比表现出更好的性能
  • Performer (2020)
    • 引入了 FAVOR+(一种能够无偏估计 softmax 核函数的随机特征方案)
    • 它对随机特征映射 \(\phi\) 进行采样,使得 \(E[\phi(Q)\phi(K)^{\top}] = \exp(QK^{\top})\)
    • 这产生了一个仅需 \(O(N)\) 操作即可证明是完整 softmax 注意力的无偏估计量
    • 特别地,Performer 使用正交互随机特征,降低了近似中的方差
  • 随机特征注意力 (Random Feature Attention,RFA) (2021)
    • 一种通过为 softmax 核函数使用随机傅里叶特征构建的线性注意力
    • 与 Performer 类似,RFA 利用随机映射和三角激活来近似 softmax
    • RFA 在随机投影之前进一步对查询和键进行归一化以减少方差
    • RFA 还有一个变体 RFA-Gate,它增加了一个可选的选通机制以引入近因偏差
  • cosFormer (2022)
    • 使用余弦函数来近似 softmax
    • 由于 \(\cos(a+b) = \cos a \cos b - \sin a \sin b\),cosFormer 将余弦重加权注意力 \(S_{ij} = Q_i’ K_j’ \cos(\frac{\pi}{2} \times \frac{i-j}{M})\) 分解为线性注意力的形式
  • HedgeDog (2024)
    • 利用了一个尖峰核函数 \(\phi(x) = \exp(Wx + b)\),因为他们观察到 Transformer 和线性 Transformer 之间的性能差距源于缺乏尖峰和单调性属性
    • HedgeDog 展示了更好的注意力熵和单调性

Linear Attention with Forgetting Mechanism

  • 最近的一系列工作通过循环神经网络或连续状态空间模型的视角来解释注意力
  • 传统的线性注意力通常是无位置感知的,其中循环顺序对输出没有影响,但现代的线性注意力表现得更像具有状态追踪和隐藏记忆的 RNN
  • 因此,这些模型明确地结合了循环、门控或状态动力学,以线性复杂度处理长序列
  • 衰减因子是引入遗忘机制的最重要因素
Data-Independent Decay
  • Retentive Networks(RetNet) (2023)
    • RetNet 引入了一种 Retention 机制,它使用固定的衰减系数,以一种循环风格的更新来取代注意力
    • 在 RetNet 层中,每个时间步 \(t\) 维护一个状态向量 \(s_t\),该向量以指数遗忘的方式聚合过去的输入
    • 循环可以写为:
      $$
      S_t = \gamma S_{t-1} + k_t^{\top} \nu_t \tag{3}
      $$
      • \(\gamma \in (0,1)\) 是一个学习的衰减因子(每个保持头)
      • \(k_t^{\top} \nu_t\) 是来自当前 Token 的新贡献(\(\nu_t\) 是 \(x_t\) 的值投影,\(k_t\) 是键投影)
    • 然后通过一个线性“查询”投影获得输出:\(o_t = q_t s_t\);展开方程 3 得到保持的显式公式:
      $$
      o_t = q_t S_t = \sum_{n=1}^{t} \gamma^{t-n} q_t k_t^{\top} \nu_t \tag{4}
      $$
    • 这表明 Token \(n\) 的贡献在时间步 \(t\) 时以因子 \(\gamma^{t-n}\) 呈指数衰减
    • Crucially,\(\gamma\) 是一个数据无关的衰减,是层的固定参数(在多头保持中通常每个头一个),而不是输入内容的函数
      • 这使得 RetNet 能够像 RNN 一样进行 \(O(1)\) 的内存更新,同时仍然允许通过等效的矩阵公式在训练期间进行并行计算(例如,可以证明方程 3 等价于一个“保持矩阵”形式,\(\text{Retention}(X) = (Q K^{\top} \odot D) V\),其中对于 \(t \geq n\),\(D_{t,n} = \gamma^{t-n}\) 实现了衰减和因果掩码)
    • RetNet 的保持机制与其他数据无关的循环模型有共同的主题
  • Eagle (2024)
    • 通过外积记忆改进了 RWKV 设计,这等效于线性注意力
    • 在 RWKV 系列中,衰减因子参数化为 \(\gamma = \exp(-\exp(w))\),其中 \(w\) 是一个数据无关的可学习因子
    • 在实践中,RetNet 和 Eagle 都使用固定衰减来遗忘旧信息,实现了线性推理扩展和具有竞争力的性能
    • 经验上,RetNet 每个头使用一个固定标量 \(\gamma\)(通常每层有多个具有不同 \(\gamma\) 值的保持头,形成一种多尺度衰减),而 Eagle 使用可学习的标量 \(w\) 来参数化衰减因子
  • Lightning Attention (2023; 2024)
    • Lightning Attention 也提出了一种线性注意力层 ,每个头增加了一个固定标量衰减,以实现长度不变的计算速度
    • 在 Lightning Attention 中,对于某个常数 \(\lambda\)(\(\lambda\) 由模型学习或设置),隐藏状态本质上是 \(s_t = \lambda s_{t-1} + k_t^{\top} \nu_t\),这与 RetNet 的 \(\gamma\) 精神相同,但针对硬件效率进行了优化
  • H3 (2022)
    • H3 将循环状态空间模型 (2021) 引入线性注意力,通过 SSM 为键值外积隐藏状态使用学习的、数据无关的指数衰减
    • 线性注意力通过分块计算实现了高效的训练,但 H3 需要为 SSM 计算显式状态扩展,从而限制了头维度,导致表达能力有限
  • In summary,数据无关的衰减方法维持一个随时间以预定速率衰减的持久状态,实现了 \(O(1)\) 循环和每步恒定内存
    • 它们牺牲了一些适应性 ,这促使了在更近期的模型中引入数据相关机制
Data-Dependent Decay
  • 虽然固定衰减提供了简单性和速度,但它们可能未能充分利用输入流中的信息
  • 门控或数据相关的方法使遗忘因子本身成为当前输入的学习函数
  • 这种循环更新的一般形式是:
    $$
    S_t = G_t S_{t-1} + k_t^{\top} \nu_t \tag{5}
    $$
    • \(S_{t-1}\) 是前一个状态
    • \(G_t\) 是由 Token \(x_t\) 确定的门控张量
      • 如果 \(G_t\) 在某个分量上接近 0,则该分量中的过去状态在时间 \(t\) 被大量遗忘;
      • 如果 \(G_t \approx 1\),则过去状态被保留
    • 与 RetNet 中的常数 \(\gamma\) 不同,这里的 \(G_t\) 通过 \(x_t\) 随 \(t\) 变化
  • 在大语言模型设计中,这种策略的两个显著例子是 Mamba (2023) 和门控线性注意力 (Gated Linear Attention, GLA) (2023)
  • Mamba
    • 一种循环状态空间模型,它赋予状态衰减率以输入依赖性
    • 在每个 Mamba 层中,基本的状态演化类似于 S4 (2021),但状态矩阵实际上变得动态
    • \(G_t\) 是一个范围在 0 到 1 之间的分组向量,作为动态遗忘门
      • 这弥合了注意力和纯 SSM 之间的差距
    • 实证结果表明,Mamba2 在语言建模任务上可以超越相似甚至更大规模的 Transformer,凸显了数据相关衰减在长序列建模中的能力
  • GLA
    • 直接在线性注意力中引入了门控机制,将一个门控函数嵌入到线性化的注意力层中以提高其表达能力
    • GLA 通过一个可学习的 Element-wise 遗忘门 \(G_t\) 来修改保持循环
  • 除此之外,其他几个模型同样赋予了其循环以内容相关的门控
  • xLSTM (2024)
    • 用线性门控信号(带归一化)的指数变换取代了标准的 sigmoid 遗忘门,对其单元状态产生了平滑的、输入条件化的衰减
  • GateLoop (2023)
    • 在保持机制上应用了头级门控,实现了一个简单但有效的数据相关衰减,同时保持了高效的硬件实现
  • HGRN (2023)
    • 在线性 RNN 中引入了门控循环
    • HGRN2 (2024) 进一步在 HGRN 框架中增加了状态扩展
    • 状态扩展等效于线性注意力中的键值外积
  • Finch (2024)
    • 在 Eagle 上使用了数据相关的门控
    • 由于 Eagle 与其他正交修改的保持机制相似,Finch 也显示出与上述模型的深厚联系
  • In summary,数据相关衰减模型通过基于内容的门控来增强线性注意力或 RNN 风格的架构,这些门控控制着信息的流动
    • 论文中的结果表明,这些模型在语言任务上通常能够匹配或超越 Transformer 的性能 ,同时能够扩展到非常长的输入

Linear Attention as In-Context Learners

  • 除了线性注意力机制带来的效率提升外,一项重大进展在于它们应用于增强上下文学习能力
    • 这指的是模型能够从给定的 Prompt 中快速适应或学习,而无需对其预训练权重进行显式的梯度更新
    • 理解:这里是说注意力机制本身是上下文学习期,即增加一些 Prompt Token,就可以借助修改注意力内容来实现对模型的快速学习
  • 大型 Transformer 模型通过将 Prompt 解释为一种训练数据的形式,固有地表现出上下文学习
    • 但最近的创新已将快速学习规则直接整合到注意力机制中,有效地将序列处理视为一个在线训练过程
    • 问题:这里的方法具体是什么?
  • FWP (2021) 建立了现有线性注意力机制与快速权重编程器之间的形式等价性
    • 在 FWP 范式中,一个慢速神经网络学习去编程另一个神经网络的“快速权重”,通常通过自发明的键和值模式的外积加法来实现
  • 表 1:
    • 不同线性注意力变体的更新规则
    • 每个模型都是关于矩阵记忆 \(S_t\) 的循环
Learning Objective
  • 从元学习的角度来看,这些模型定义了一个在推理过程中优化的隐式学习目标
  • 用 \(q_t, k_t, \nu_t\) 表示时间步 t 的查询、键和值,上下文记忆 \(S_t\) 通过以下目标进行优化:
    $$
    \mathcal{L}_t(S) = \frac{1}{2} || f_S(k_t) - \nu_t ||^2 \tag{6}
    $$
  • DeltaNet
    • DeltaNet 融合了经典的 delta 规则 (2021),其中 \(f_S(k_t) = S k_t\)
    • DeltaNet 更新规则为
      $$ S_t = S_{t-1} + \eta_t (\nu_t - S_{t-1} k_t) k_t^{\top} $$
      • 可以通过最小化当前记忆检索 \(S_{t-1} k_t\) 与新值 \(\nu_t\) 之间的误差推导出来
      • 这标志着朝着在线学习键值映射、基于即时上下文改进记忆迈出了一步
  • TTT (2024)
    • TTT 用不同的建模架构概括了元学习目标:
      $$
      f_S(k_t) = \begin{cases}
      \text{LM}(S k_t) + k_t, & \text{TTT-Linear} \\
      \text{LM}(\text{MLP}_S(k_t)) + k_t, & \text{TTT-MLP}
      \end{cases} \tag{7}
      $$
    • 上下文网络 \(f_S\) 增强了上下文元学习的能力
    • 但由于 \(f_S\) 的梯度比简单的线性投影复杂得多,在线更新不能写成一个简单的规则
  • 批量更新 (Batch Update)
    • 批量更新试图解决当 \(f_S\) 作为神经网络工作时训练并行性的困难
    • 通常,上下文记忆是以批大小为 1 进行元学习的,这对于一般的 TTT 模型来说不可行
    • 相反,类似于分块并行,TTT 将整个块视为一个批次
    • 在批次内没有状态更新(即 \(S\) 保持不变)
    • 处理完批次后,\(S\) 使用来自批次中所有样本的聚合梯度或更新信号进行一次更新
    • 这种策略在保持并行效率的同时,适应了更复杂架构的训练要求
  • Momentum Titans (2025)
    • Titans 中引入了优化中常用的动量,以加强记忆更新机制的能力:
      $$
      \mathcal{M}_t = (1 - \alpha_t) \mathcal{M}_{t-1} + S_t \\
      o_t = q_t \mathcal{M}_t \tag{8}
      $$
    • 动量项允许记忆通过对状态 \(S\) 进行指数移动平均来逐渐积累信息
    • 这可以看作是元学习的一种形式,其中更新规则本身学会了在长序列上更加稳定和鲁棒
  • 权重衰减 (Weight Decay)
    • 权重衰减是训练中的另一种正则化技术,对应于线性注意力模型中的遗忘机制
    • Gated DeltaNet (2024) 和 Titans 在其记忆更新中使用了权重衰减,作为学习的遗忘门来限制非常旧或嘈杂数据的影响
    • 它对应于在 RetNet (2023) 和 Mamba (2023) 等架构中发现的选择性状态保持机制,其中衰减机制被证明对语言建模性能至关重要:
      $$
      S_n = \gamma_n S_{t-1} + \eta_t (v_t - S_{t-1} k_t) k_t^{\top} \tag{9}
      $$
  • In summary,线性注意力机制中的这些进步通过明确地将元学习原则整合到其架构中,正在推动上下文学习的边界
    • 通过快速权重更新、复杂的内存管理技术和在线学习规则,这些模型正朝着一种范式发展,在这种范式中,训练和推理之间的区别变得越来越模糊,从而产生了能够直接从上下文中学习和利用知识的更高效、适应性更强的大语言模型

Discussion on Other Designs

Element-wise Linear Attention
  • 无需注意力的 Transformer (Attention-Free Transformer) (2021)
    • 利用一个简单的权重 \(\exp(K_{\nu’} + w_{t,\nu’})\) 来代替 \(\exp(QK^{\top})\):
      $$
      O_t = \sigma_q(Q_t) \odot \frac{\sum_{\nu’=1}^{t} \exp(K_{\nu’} + w_{t,\nu’}) \odot V_{\nu’} }{\sum_{\nu’=1}^{t} \exp(K_{\nu’} + w_{t,\nu’})} \tag{10}
      $$
    • 其中 \(w_{t,\nu’}\) 是学习的成对位置偏置
    • 在 AFT 的变体中,AFT-Simple 移除了 \(w_{t,t’}\),实现了线性化的推理模式
    • 由于 \(K\) 和 \(V\) 的乘积是 Element-wise 的,循环状态大小是 \(\mathbb{R}^d\) 而不是外积状态 \(\mathbb{R}^{d \times d}\)
  • RWKV (2023)
    • 在 AFT-Simple 上利用了衰减机制。具体来说,RWKV 通过指数衰减 \(w_{t,i} = -(t-i)w\) 改进了 AFT 的位置偏置
    • 指数形式在引入位置偏置的同时保留了循环属性
  • Element-wise 线性注意力带来了强大的推理优势,但它受到状态大小瓶颈的影响,性能低于基于矩阵的状态大小
    • Besides,尽管 Element-wise 内存比外积内存快得多,但由于其他组件在拥有外积内存的情况下占据了超过 95% 的延迟 (2023),端到端的优势仍然有限
Multi-Pass Linear Attention
  • 带有限内存控制的注意力 (Attention with Bounded-memory Control)
    • 将线性注意力视为一个有界内存模块:
      $$
      \begin{split}
      \tilde{K}_n &= \sum_{i=1}^{n} K_i \otimes \phi_i, \quad \tilde{V}_n = \sum_{i=1}^{n} V_i \otimes \phi_i \\
      O_n &= \text{softmax}(Q_n \tilde{K}_n^{\top}) \tilde{V}_n
      \end{split} \tag{11}
      $$
    • 其中 \(\tilde{K}_n, \tilde{V}_n\) 是在线更新的、大小受限的键和值。在实现中,ABC 可以简化为两遍线性注意力
  • 门控 Slot 注意力 (Gated Slot Attention) (2024)
    • 进一步将 GLA 引入到 ABC 框架 (2021) 中
    • 由于 \(\tilde{K}_n, \tilde{V}_n\) 作为一个隐式的线性注意力工作,GSA 将更新改进为门控形式:
      $$
      \tilde{K}_n = \text{Diag}(\alpha_n) \tilde{K}_{n-1} + (1 - \alpha_n) \otimes K_n, \quad \tilde{V}_n = \text{Diag}(\alpha_n) \tilde{V}_{n-1} + (1 - \alpha_n) \otimes V_n \tag{12}
      $$
  • Multi-Pass 是增强线性注意力表达能力的一种有效方式,但它也带来了额外的计算开销,这使得架构设计需要在训练效率和性能之间进行权衡
Bidirectional Linear Attention
  • 双向注意力在编码器风格的架构(如 BERT (2019))中扮演着重要角色
  • 单向和双向注意力在线性表述中的关键区别在于推理瓶颈和计算模式
  • Decoder-only 模型通常表现出 \(O(N^2)\) 的复杂度,且 Decoder-only 模型中的每个 Token 都可以访问全局信息
    • 因此双向线性注意力通常维护一个恒定长度的全局 Token 池以降低复杂度,同时保留 softmax 函数的使用
  • 例如
    • Linformer (2020) 通过一个额外的矩阵投影将键和值的数量减少到一个恒定长度
    • Luna (2021) 通过跨模型层编码全局 Token 池进一步扩展了 Linformer 设计
  • 虽然双向线性注意力对 Decoder-only 架构有效,但这些设计在应用于因果设置时面临重大挑战,因为基于全局池的方法往往计算成本高昂
    • 因此,此类架构不太适合大语言模型

Hardware Implementation

  • 并行表示 (The Parallel Representation)
    • 可以将带门控衰减的因果线性注意力定义为:
      $$
      \begin{split}
      Q &= \phi(XW_Q), \quad K = \phi(XW_K), \quad V = XW_V, \quad \gamma = f_\gamma(X) \\
      D_{nm} &= \begin{cases}
      \prod_{i=m+1}^{n} \gamma_i, & n \geq m \\
      0, & n < m
      \end{cases}, \quad O(X) = \text{LN}((Q K^{\top} \odot D) V)
      \end{split} \tag{13}
      $$
    • 其中
      • \(W_Q, W_K, W_V \in \mathbb{R}^{d \times d}\)
      • \(f_\gamma\) 控制衰减的锐度
      • 矩阵 \(D \in \mathbb{R}^{N \times N}\) 编码了具有衰减模式的因果掩码,确保信息的单向流动
    • 当衰减是数据无关时,\(f_\gamma(\cdot) = \text{const} \in (0,1]\)
      • 注意,线性注意力后的分组归一化 (GroupNorm) (2018) 已经是一个强制性的组件 (2023),方程 2 中核化线性注意力的显式除数变得不必要
    • 并行表示简单易懂,但有两个缺点
      • 第一:并行形式仍然保持了与 softmax 注意力相同的 \(O(N^2)\) 复杂度
      • 第二:当表示第 2.3 节中的 ICL 风格线性注意力时,其复杂性会增加
  • 循环表示 (The Recurrent Representation)
    • 上述并行公式可以等效地表达为逐步解码的循环形式,如图 2(b) 所示,在每个时间步 \(n\),输出计算为:
      $$
      S_n = f_{\text{update} }(S_{n-1}, K_n, V_n), \quad O_n = Q_n S_n \\
      f_{\text{update} }(S_{n-1}, K_n, V_n) = \begin{cases}
      \gamma_n S_{n-1} + K_n^{\top} V_n, & \text{Linear Attention with Decay} \\
      \gamma_n S_{n-1} + \eta_n (V_n - S_{n-1} K_n) K_n^{\top}, & \text{ICL-style Linear Attention}
      \end{cases} \tag{14}
      $$
    • 这种循环表示通过维护单个状态向量 \(S_n\),实现了具有恒定内存的高效自回归生成
    • 虽然循环表示将计算复杂度从 \(O(N^2)\) 降低到 \(O(N)\),但它在训练期间会产生大量的内存开销
      • 因为 \(S_n\) 涉及存储 \(K_n\) 和 \(V_n\) 的外积,这对于长序列来说是极其昂贵的
      • 因此,循环形式通常仅限于解码阶段
  • 分块循环表示 (The Chunkwise Recurrent Representation)
    • 分块表示结合了线性复杂度和硬件友好的并行性的优势 (2022; 2023)
    • 如图 2(a) 所示,以衰减风格线性注意力为例,给定一个块大小 \(B\),令 \(x_{[i]}\) 表示第 \(i\) 块。定义块内的累积衰减为:
      $$
      \beta_{(i-1)B+j} = \prod_{k=(i-1)B+1}^{(i-1)B+j} \gamma_k, \quad D_{[i]}(j,k) = \begin{cases}
      \frac{\beta_{(i-1)B+k} }{\beta_{(i-1)B+j} }, & j \leq k \\
      0, & \text{Other wise}
      \end{cases} \tag{15}
      $$
    • 块级记忆状态 \(R_i\) 计算为:
      $$
      R_i = K_{[i]}^{\top} (V_{[i]} \odot \frac{\beta_{iB} }{\beta_{[i]} }) + \beta_{iB} R_{i-1} \tag{16}
      $$
    • 第 \(i\) 块的输出为:
      $$
      O_{[i]} = (Q_{[i]} K_{[i]}^{\top} \odot D_{[i]}) V_{[i]} + (Q_{[i]} R_{i-1}) \odot \beta_{[i]} \tag{17}
      $$
    • 这种表述提供了循环和并行的统一视图:
      • 第一项捕获块内依赖关系
      • 第二项通过单个矩阵-向量乘积传播块间记忆
      • 由于其效率和可并行性,分块表示通常在训练和预填充阶段被采用
  • 对于 ICL 风格的线性注意力,已经使用 Householder 变换 (2024) 开发了硬件友好的分块表示
    • 但对于更复杂的变体,如 TTT 和 Titans,构建显式的分块形式仍然具有挑战性
    • 相反,这些架构通常依赖于大的批次大小进行记忆更新,通过固定超参数有效地模拟分块计算
  • 内核级优化对于实现高性能至关重要
    • 广泛采用的 FLA (2024) 为许多常见的线性注意力模块提供了基于 Triton 的实现
    • 或者,开发者还提供了用 CUDA 或 TileLang (2024) 编写的自定义实现,可以用于进一步加速
  • 图 2:线性注意力的双重形式

Sparse Attention

  • 稀疏注意力(Sparse Attention)方法利用注意力计算中固有的稀疏特性,通过以下公式近似完整注意力计算:
    $$
    \text{Attn}(Q,K,V) = \text{softmax}(QK_{[\mathcal{S}]}^T)V_{[\mathcal{S}]} \tag{18}
    $$
    • 其中 \(\mathcal{S}(t)\) 是查询向量 \(Q(t)\) 所关注索引的子集
    • 不同的方法基于选择准确性和硬件效率的考量,设计了不同的选择标准 \(\mathcal{S}(t)\)
    • 目标是在预填充(prefilling)阶段实现亚线性或线性复杂度,或在解码(decoding)阶段控制固定的计算预算

Fixed-pattern Sparse Attention (固定模式稀疏注意力)

  • 一些研究工作利用 Token-level 别稀疏性的结构化模式,为注意力计算构建固定的稀疏掩码
  • Local Window Attention (局部窗口注意力)
    • 局部窗口注意力将每个查询限制在仅与固定滑动窗口 \(w\) 内的相邻 Token 交互,从而在保留局部上下文的同时降低内存和计算需求
    • Sparse Transformer (2019) 首先应用局部窗口(行)注意力,其中 \(w\) 接近 \(\sqrt{N}\),然后通过额外的列注意力来总结先前位置并在全局传播信息
    • GPT-3 (2020) 也采用了与 Sparse Transformer 类似的稀疏注意力模式
    • StreamingLLM (2023) 发现大量的注意力得分被分配给了输入序列的初始 Token ,他们称之为“注意力汇集点(attention sink)”。他们提出了一种简单的固定模式注意力,只保留汇集点 Token 和滑动窗口内的 Token 。例如,给定一个长度为 \(n\) 的输入序列,StreamingLLM 中查询 Token \(q_t\) 的选定 Token 子集 \(\mathcal{S}(t)\) 被定义为:
      $$
      \mathcal{S}(t) = \left\{ j \mid 0 \leq j \leq s \ \lor \ t-w \leq j \leq t \right\}, \ \forall t \in [1, n] \tag{19}
      $$
      • 其中 \(s\) 是汇集点 Token 的大小,\(w\) 是滑动窗口的大小
      • 为了获得更好的硬件效率,采用块粒度的 StreamingLLM (2024) 以块方式保留汇集点 Token 和局部 Token ,从而实现高效的内存加载和计算
  • Dilated Attention (膨胀注意力)
    • LongNet (2023)
      • 引入了膨胀注意力作为长上下文训练和推理的固定稀疏模式。膨胀注意力随着距离的增长呈指数级扩展注意力范围,从而将注意力的复杂度从 \(O(n^2)\) 降低到 \(O(n)\)。具体来说,在沿序列维度将输入分割成长度为 \(w\) 的片段后,从每个片段中以间隔 \(r\) 选择膨胀稀疏索引。第 \(i\) 个片段的选择索引为:
        $$
        \hat{I}_i = \left[iw, iw+r, iw+2r, …, (i+1)w-1 \right] \tag{20}
        $$
      • 稀疏化的片段 \(Q_{\hat{I}_i}, K_{\hat{I}_i}, V_{\hat{I}_i},\ i \in \{0, 1, …, \frac{n}{w}\}\) 被并行输入到注意力计算中,得到注意力输出 \(O\)。结合不同片段大小和膨胀率 \(\{r_i, w_i\}^{k}\) 的注意力输出,最终注意力计算如下:
        $$
        O = \sum_{i=1}^{K} \alpha_i O|_{r_i, w_i}, \quad \alpha_i = \frac{s_i}{\sum_{j} s_j} \tag{21}
        $$
      • 其中 \(s_i\) 表示 \(O|_{r_i, w_i}\) 的注意力 softmax 的分母
  • LogSparse (2019)
    • 采用指数稀疏注意力方案,其中每个位置仅关注 \(\log N\) 个 Token ,这可以看作是指数膨胀注意力的一个实例

Block Sparse Attention (块稀疏注意力)

  • 给定一个长度为 \(n\)、块大小为 \(b\) 的输入序列,可以将 \(Q, K, V \in \mathcal{R}^{n \times d}\) 各自划分为 \(\frac{n}{b}\) 个块,每个块大小为 \(b \times d\)
  • 目标是近似一个块级掩码 \(M \in \{0, 1\}^{n/b \times n/b}\),用于选择关键块进行计算,如图 3 所示
    $$
    \text{Attn}(Q, K, V)_i = \sum_{j=1}^{n/b} M_{ij} \cdot \text{softmax}(Q_i K_j^T) V_j \tag{22}
    $$
  • 块级选择对于在现代 GPU 上实现高效计算至关重要
  • 图 3: 块稀疏注意力:长序列被分成若干块,每个 Token 仅关注其局部窗口和 top-k 相关块
Block-Sparse Attention for Prefill (用于预填充的块稀疏注意力)
  • 使用块稀疏注意力进行预填充的方法,近似选择覆盖大部分注意力得分且具有高召回率的 Top-K 块,从而将注意力计算复杂度从 \(O(n^2)\) 降低到 \(O(K)\)
    $$
    S = \text{softmax}(QK^T - c(1 - M)) \\
    \min \ |S(M) - S_{\text{dense} }|
    $$
    • 其中 \(M\) 是如上定义的块级稀疏掩码,\(c\) 是一个大常数(例如 1e5),确保重要性较低的注意力权重在 softmax 计算后趋近于零
    • 块稀疏注意力的目标是在最小开销下实现更大的加速,同时尽可能保留更多的注意力权重
  • MInference (2024)
    • 观察到注意力权重中存在三种模式:
      • 流式(A 型)模式(Streaming (A Shape) Pattern)
      • 垂直斜线模式(Vertical-Slash Pattern)
      • 块稀疏模式(Block-Sparse Pattern)
    • 它离线确定每个注意力头的最佳模式,并在推理过程中基于指定的模式动态构建稀疏索引
  • FlexPrefill (2025)
    & 提出了一种上下文感知的稀疏注意力机制,能够实时动态调整注意力模式和计算预算
  • XAttention (2025)
    • 提出了一个块稀疏注意力框架,利用反对角线(antidiagonal)评分来预测注意力块的重要性,从而能够高效识别并剪除非必要块,实现高稀疏性和显著的计算增益
  • SpargeAttn (2025)
    • 同样在预填充阶段采用块级稀疏注意力,通过一个双阶段在线过滤过程完成:
      • 第一阶段快速预测注意力图以跳过某些矩阵乘法
      • 第二阶段应用 softmax 感知的过滤器以进一步消除不必要的计算
Block-Sparse Attention for Decode (用于解码的块稀疏注意力)
  • 用于解码的块稀疏注意力方法动态选择包含每个解码步骤最关键的 Token 的子集 \(S\) 的 \(K, V\) 向量,从而减少内存加载并提高效率
  • Quest (2024) 通过计算注意力权重的上界来近似每个块的关键性。对于块 \(K_i\),论文通过以下公式维护 Element-wise 的 Min 和 Max Key \(m_i\) 和 \(M_i\):
    $$
    m_{i,d} = \min(K_{i,d}), \quad M_{i,d} = \max(K_{i,d})
    $$
    • 其中 \(\min(\cdot)\) 和 \(\max(\cdot)\) 在每个维度 \(d\) 上 Element-wise 应用
  • 给定查询 \(q\),块 \(i\) 的近似注意力得分由下式给出:
    $$
    \text{score}_i = \sum_{j=1}^{d} \max(q_j \times M_{i,j}, q_j \times m_{i,j})
    $$
  • 然后选择得分最高的 Top-K 块作为注意力计算的稀疏子集 \(S\):
    $$
    S = \text{argtopk}(\text{score}, k)
    $$
  • DoubleSparsity (2024) 通过降低计算 \(QK^T\) 乘积的矩阵乘法维度来高效近似关键 Token
    • 它首先离线计算 \(QK^T\) 中的离群通道,记为 \(C\)
    • 然后选择具有最高近似注意力得分 \(\hat{s}\) 的 Top-K Token 作为稀疏子集 \(S\):
      $$
      Q_{\text{label} } = Q_{\{C\} }, \quad \hat{s} = Q_{\text{label} } K_{\text{label} }^T, \quad S = \text{argtopk}(\hat{s}, k)
      $$
  • ReSA (2025)
    • 结合了无需训练的块稀疏估计和 GQA 共享,有助于提高效率
    • ReSA 还提出了一个修正阶段来控制 KV 缓存累积误差
    • ReSA 在长序列生成任务上显示出优势
Routing-based Block-Sparse Attention (基于路由的块稀疏注意力)
  • 基于路由的块稀疏注意力通过可训练的 MLP 层学习每个 Token 块的重要性,该层在推理期间充当门控网络以选择关键块
  • Learnable Sparsity on Pretrained Models (预训练模型上的可学习稀疏性)
  • SeerAttention (2024)
    • 通过自蒸馏(self-distillation)的方式在预训练的 LLM 上训练门控网络
    • 为了获得每个块的重要性分数,它首先沿序列维度对 \(Q\) 和 \(K\) 进行池化,记为 \(P_q\) 和 \(P_k\)。下采样后的 \(Q, K\) 然后通过一个可学习的线性层 \(W_q\) 和 \(W_k\)
    • 投影后的 \(W_q P_q(Q)\) 和 \(W_k P_k(K)\) 的矩阵相乘结果通过 softmax 运算符作为门控过程:
      $$
      \text{score} = \text{softmax}((W_q P_q(Q)) \cdot (W_k P_k(K)))
      $$
    • 可学习的线性层通过自蒸馏的方式训练,以与原始 LLM 的 2D 最大池化结果对齐。蒸馏损失计算如下:
      $$
      gt = \text{MaxPool2D}(\text{softmax}(QK^T)), \quad loss = D_{KL}(gt \ || \ \text{score})
      $$
    • 在推理过程中,门控分数通过 Top-K 或阈值化来预测块级稀疏性,用于稀疏计算和效率提升
  • Training-aware Sparse Attention (训练感知的稀疏注意力)
    • Landmark (2023)
      • 提出使用特殊的“地标”(landmark) Token 来表示每个块,并训练注意力机制通过这些地标 Token 直接检索 Top-K 块
      • 然而,它没有在大型预训练模型上进行实验
    • MoBA (2025)
      • 将可训练的稀疏注意力集成到预训练阶段
      • 它提出了 Mixture of Block Attention,应用来自 MoE(专家混合)的 Top-K 机制作为门控机制,为每个查询 Token 决定关键块
      • 每个块的重要性分数通过查询 Token \(q\) 和块 \(K_i\) 沿 Token 维度的平均池化结果的内积计算:
        $$
        s_i = \langle q, P_{\text{mean} }(K_i) \rangle
        $$
      • 然后选择具有最高 \(s\) 分数的 Top-K 块用于计算 \(q\) 的注意力
      • Notably,MoBA 使用的 Top-K 块选择是不可微分的
        • 因此,在预训练阶段,稀疏模式仍然以无需训练的模式进行估计,从而实现高效的推理和加速的训练
    • NSA (2025)
      • 引入了一种训练感知的多粒度稀疏注意力机制,包含三个分支 \(C \in \{\text{cmp}, \text{slc}, \text{win}\}\),分别对应压缩(compression)、选择(selection)和滑动窗口(sliding window)策略
      • NSA 利用一个可微分的压缩分支来学习块选择分数。结合三个分支,NSA 的注意力输出由下式给出:
        $$
        o = \sum_{c \in \mathcal{C} } g^c \cdot \text{Attn}(q, K^c, V^c), \quad g_c \in [0, 1]
        $$
      • 对于压缩分支 \(c = \text{cmp}\),块 \(i\) 的键 \(K_i \in \mathcal{R}^{d_i \times b}\) 通过一个可学习的 MLP 层 \(\varphi\) 被压缩为单个键 \(K_i^{\text{cmp} } \in \mathcal{R}^{d_i \times 1}\)
      • 对于选择分支 \(c = \text{slc}\),基于块重要性分数 \(p\) 选择 Top-K 块,该分数可以直接从压缩分支获得
    • InfLLM-v2 (2025)
      • 采用了与 MoBA 类似的训练感知 Top-K 块稀疏注意力机制
      • 为了提高 Top-K 块选择的准确性,它将块划分为具有重叠的小粒度内核,并在每个块内对内核重要性分数进行聚合
System-level Design Choices (系统级设计选择)
  • 训练感知的稀疏注意力 (2025; 2025; 2025) 开始考虑内核实现和高效执行
  • 为了实现块稀疏注意力的高效实现,FlashAttention (2023) 被用于在高效的平铺机制中进行注意力计算,这为更好地利用硬件资源带来了要求和机会,包括:
    • 为避免内存访问不一致,在 SeerAttention (2024) 和 MInference (2024) 中,块大小 \(b\) 通常设置为至少 64 的相对较大值
    • 为了与 GPU 张量核心上分组矩阵乘法指令(Grouped Matrix Multiplication)的最低要求对齐,在 NSA (2025) 和 InfLLM-v2 (2025) 中,一个查询组内的 K、V 头数被设置为至少 16
    • 为了减少内存访问,NSA (2025) 和 InfLLM-v2 (2025) 强制查询组之间共享选定的块,这是通过在查询组内对块级重要性分数进行池化来完成的

Clustering Attention (聚类注意力)

  • 与块稀疏注意力类似,聚类注意力旨在为解码选择最关键地 Token ,但将 Token 组织在数据结构中以获得更好的语义属性或采样效率
  • RetrievalAttention (2024)
    • 采用近似最近邻搜索(Approximate Nearest Neighbor Search, ANNS)来选择关键的 K 个聚类
    • 为了解决注意力机制中查询向量和键向量之间的分布外(out-of-distribution)性质带来的挑战,它引入了一种适应查询向量分布的注意力感知向量搜索算法
  • ClusterKV (2024)
    • 在语义聚类的粒度上选择 Token ,克服了诸如 Quest 等页面级检索方法内部碎片化的问题
    • 在预填充阶段之后, Token 通过 K-means 算法进行聚类
    • Token \(i\) 和 \(j\) 之间的语义相似性通过键向量的余弦相似性度量:\(\mathcal{D}(i, j) = 1 - \frac{\langle k_i, k_j \rangle}{|k_i| \cdot |k_j|}\)
    • 语义聚类由其质心 \(\mu_1, \mu_2, …, \mu_C \in \mathcal{R}^d\) 表示
    • 在每个解码步骤,基于查询 Token \(q\) 和质心 \(\mu_i\) 的注意力权重(即 \(q \mu_i^T\))的排名来选择聚类
  • MagicPIG (2024)
    • 利用局部敏感哈希(Locality Sensitive Hashing, LSH)采样来高效近似注意力计算
    • 它使用 LSH 将相似的查询向量和键向量映射到相同的哈希桶中,并将存储和部分计算卸载到 CPU,以解决 KV 缓存的瓶颈
    • 它还引入了 Oracle Top-K 采样作为比暴力 Top-K 更好的策略

Bidirectional Sparse Attention (双向稀疏注意力)

  • 双向稀疏注意力建立在编码器风格的架构上,使用静态模式或块级稀疏性来加速注意力计算
  • 块稀疏性在双向稀疏注意力中被广泛使用
    • BigBird (2020)
      • 使用块级随机注意力,作为缩短 Token 之间间接路径的桥梁
    • Longformer (2020)
      • 使用静态的全局-局部混合注意力
      • Longformer 也依赖于块级稀疏性,并带有额外的全局和随机链接,以促进结构化计算和内存高效的并行性
  • 基于聚类的方法也用于双向稀疏注意力
    • Reformer (2020)
      • 使用局部敏感哈希(LSH)将相似的 Token 分配到同一个桶中
    • Routing Transformer (2020)
      • 在每一层执行在线 \(k\)-means 聚类
    • ClusterFormer (2022)
      • 引入了一个与下游目标共同训练的可微分聚类模块
      • 这些方法通过分组相关 Token 来减少计算,同时通过学习的适应性来保持性能

采用高效注意力机制的预训练大语言模型

采用统一高效注意力机制的预训练模型

  • 尽管早期的线性注意力探索通常局限于小规模模型,但近期的进展已证明其能够成功扩展到数十亿参数的规模,使其成为标准 Transformer 的一种可行且高效的替代方案
  • 这些模型完全基于线性注意力或其架构等价物,如状态空间模型和循环神经网络,即使在大规模下也能保持其标志性的推理效率
  • 基于 RWKV 的模型
    • RWKV 项目代表了一项持续且有影响力的努力,旨在创建一个可扩展的循环神经网络架构,它结合了 Transformer 的可并行化训练与传统 RNN 的高效推理 (2023)
    • 例如,EAGLE 系列引入了矩阵值状态以增加容量,而后续迭代如 Finch (RWKV-6) (2024) 和 Goose (RWKV-7) (2025) 则结合了动态循环和更具表达力的状态演化机制,以支持更复杂、数据依赖的状态转换
  • 基于 Mamba 的模型
    • Mamba 架构的成功及其数据依赖的选择机制,已引发了主要研究实验室的一波采用和扩展浪潮
    • Falcon Mamba (2024)
      • 基于纯 Mamba 架构,在一系列通用语言基准测试中展现出与领先 Transformer 模型相竞争的性能,验证了该架构在此类任务上的可行性,同时保持了其标志性的恒定时间推理
    • Codestral Mamba (2024)
      • 基于 Mamba-2 架构,进一步证明了该范式的潜力
      • 虽然专门用于代码生成,但它在相关基准测试中取得了 SOTA 结果,并支持 256K 个 Token 的上下文长度,展示了 SSM 方法在复杂结构化领域内的可扩展性和有效性
  • 基于稀疏注意力的模型
    • MiniCPM-4 (2025) 引入了一种两阶段稀疏注意力机制,根据语义相似性为每个查询 Token 动态选择相关的键值块
    • MiniCPM-4 利用 InfLLM-v2(一种块稀疏注意力变体)来替代标准注意力机制
    • 此外,一种轻量级的 LogSumExp 近似实现了高效的 top-k 选择,使得该方法能够扩展到极长序列
    • 这些技术共同使 MiniCPM-4 能够在细粒度上下文感知能力与可控的内存和计算需求之间取得平衡,使其成为长上下文建模的有力候选者

采用混合高效注意力机制的预训练模型

  • 随着对高效长上下文建模和多样化计算范式需求的增长,最近的研究广泛探索了混合注意力机制
  • 此类策略结合了全局和局部注意力组件,通常交错使用专门设计的层,以平衡计算成本和性能
  • 候选模型架构示意图如图 4 所示
  • 稀疏混合模型
    • GPT-3 (2020) 通过交错使用稠密注意力和局部带状稀疏注意力层,集成了一种混合注意力机制,其灵感来自 Sparse Transformer (2019)
    • 稠密注意力提供全上下文建模,而稀疏层则采用固定或跨步模式来减少被关注的 Token 数量
    • 这种设计使得 GPT-3 能够在 2048 个 Token 的固定上下文窗口内高效扩展到大规模模型,平衡了建模能力和计算效率
  • 线性-全注意力混合模型
    • Jamba (2024) 和 MiniMax-01 (2025) 结合了线性注意力和全注意力层,以在吞吐量和表达力之间实现高效的权衡
    • MiniMax-01 在大多数层中使用 Lightning Attention,并每八层插入一次基于 Softmax 的全注意力
    • Jamba 采用了相似的比例,在每八层的 Mamba 块中插入一个 Transformer 层
    • 两者都通过限制计算密集的全注意力的使用,实现了更快的解码和改善的长序列性能
  • 局部-全注意力混合模型
    • Gemma 3 (2025)、Command A (2025) 和 LLaMA-4-Maverick (2025) 在局部和全局注意力层之间交替使用,其共享的设计理念是稀疏地使用全局层(例如,每 4-6 层一次)以提升效率
    • 虽然局部层采用滑动窗口模式,但关键区别在于位置编码策略
    • Gemma 3 调制 RoPE 的基础频率——为局部层分配 10K,为全局层分配 1M(以更好地捕获长距离依赖关系)
    • Command A 和 LLaMA-4-Maverick 混合了基于 RoPE 的局部层与完全省略位置嵌入的全注意力层,从而实现了更强的长序列性能
  • 先进混合模型
    • Character.AI (2024) 将滑动窗口的局部注意力与每六层应用一次的稀疏全局注意力层交错使用
    • 特别是,它们在多个非相邻层之间复用全局注意力层的键值表示
    • 这种 KV 共享机制能以减少的内存和延迟开销实现高效的长上下文处理
  • YOCO (2024) 和 Phi-4-mini-flash (2025) 采用了一种双解码器架构,将预填充阶段和生成阶段分离开来
    • 自解码器在预填充和生成中都使用 RetNet 和滑动窗口注意力等线性注意力机制,而交叉解码器仅在生成期间激活
    • 整个过程中使用单层全局 KV 缓存,实现了线性时间的预填充和高效的解码,同时 GPU 内存消耗最小
  • In summary,这些最新进展强调了混合注意力机制以在不同计算约束和序列长度下实现平衡性能的趋势。每种架构都独特地贡献了关于如何有效结合局部细节管理与全局上下文整合的见解,从而为未来注意力机制的发展提供了有价值的框架

Outlook

  • 本综述全面概述了高效注意力机制,重点关注其算法基础、实际实现以及在大规模预训练语言模型中的集成
  • 通过将线性和稀疏注意力分类为明确定义的范式,论文识别了实现可扩展性、计算效率和长上下文能力的关键设计原则
  • 论文还分析了这些机制在最先进模型中如何部署,无论是作为独立架构还是作为平衡局部和全局计算的混合设计的一部分
  • 展望未来,论文强调几个预计将塑造该领域未来研究的关键方向:
  • 对混合模型的架构理解
    • 虽然先前关于线性注意力的工作主要集中于独立的线性架构,但混合模型通常通过结合现成的线性注意力模块与稠密或局部组件来构建
      • 但更强的线性主干是否直接转化为改进的混合性能尚不清楚
    • 未来的工作应将混合模型作为一个独特的架构类别来研究,试图理解它们的组成、相互作用效应和优化动态
  • 无损稀疏注意力与扩展上下文
    • 稀疏注意力仍然受到精度和计算增益之间权衡的挑战。完全训练的稀疏模型通常性能不如稠密模型,而训练后稀疏近似则由于缺乏端到端训练而面临限制
    • 一个主要的研究前沿在于开发能够保持稠密注意力的表达力和精度,同时扩展到更长上下文的稀疏注意力机制
    • Besides,稀疏预算与上下文长度之间的关系尚不明确,固定的 top-k 方案在序列更长时可能会退化,这需要更具适应性的策略
  • 对稀疏和混合注意力的机制性洞见
    • 实证研究反复证明,混合注意力模型可以用更少的注意力计算匹配甚至超越稠密模型,但其有效性的根本原因仍未得到充分探索
    • 此外,研究在合成基准测试中表现良好的稀疏模式是否适用于现实世界任务,以及描述基于稀疏性的泛化极限,尤为重要
  • 随着基于注意力的模型不断发展,论文预计架构创新、理论洞见和硬件感知设计之间将进一步融合
  • 作者希望本综述能为未来高效、高性能语言建模系统的研究奠定坚实的基础

NLP——Gated-Delta-Net

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始文档:(Gated DeltaNet)Gated Delta Networks: Improving Mamba2 with Delta Rule, ICLR 2025, NVIDIA
    • GitHub:github.com/NVlabs/GatedDeltaNet

Paper Summary

  • 整体总结:
    • 本文提出了一种新的框架:Gated DeltaNet
      • vs Mamba2:能够实现更好的键值关联学习
      • vs DeltaNet:具有更强的自适应内存清除能力
  • 背景 & 问题:
    • 线性 Transformer (Linear Transformers) 作为标准 Transformer 的高效替代方案获得了关注,但它们在检索和长上下文任务中的性能有限
    • 为了解决这些限制,最近的工作探索了两种不同的机制:
      • 用于自适应内存控制的门控 (gating for adaptive memory control)
      • 用于精确内存修改的 Delta 更新规则 (delta update rule for precise memory modifications)
    • 论文观察到这些机制是互补的,门控支持快速内存擦除,而 Delta 规则促进定向更新
      • 基于这一见解,论文引入了门控 Delta 规则 (gated delta rule),并开发了一种针对现代硬件优化的并行训练算法
  • 论文提出新的架构 Gated DeltaNet
    • 在包括语言建模、常识推理、上下文检索、长度外推和长上下文理解在内的多个基准测试中,超越了 Mamba2 和 DeltaNet 等现有模型
  • 作者通过开发将 Gated DeltaNet 层与滑动窗口注意力 (sliding window attention) 或 Mamba2 层相结合的混合架构,进一步提升了性能,实现了改进的训练效率和卓越的任务性能

Introduction and Discussion

  • Transformer 架构显著提升了 LLM 的能力,由于其有效的注意力机制,在各种任务上展现出卓越的性能
    • 该机制在精确序列建模方面表现出色,并在训练期间利用了现代 GPU 的并行处理能力
    • 但自注意力 (self-attention) 组件的计算复杂度随序列长度呈二次方增长,导致巨大的计算需求,给训练和推理都带来了挑战
  • 为了缓解这些问题,研究人员探索了诸如线性 Transformer (2020a) 等替代方案,它们用基于核化点积的线性注意力 (kernelized dot-product-based linear attention) 取代了传统的基于 softmax 的注意力,通过将其重构为具有矩阵值状态的线性 RNN,显著减少了推理期间的内存需求
  • 虽然早期版本的线性 Transformer 在语言建模任务中表现不如标准 Transformer,但最近的增强已经显示出有前景的改进
    • 例如结合类似于 LSTM 中的数据依赖门控机制,以 GLA (2024a) 和 Mamba2 (2024a) 等模型为例
  • 然而,在管理长序列信息方面仍然存在挑战,特别是在上下文检索任务中,传统 Transformer 仍保持其优势 (2023a; 2024;)
  • 这种现象并不令人惊讶:
    • 线性 Transformer 可以被解释为实现了一种基于外积 (outer-product) 的键值关联记忆,让人联想到张量积表示 (1990)
    • 但它们可以存储的正交键值对的数量受到模型维度的限制
      • 当序列长度超过这个维度时,“记忆碰撞 (memory collisions)”将不可避免,阻碍精确检索 (2021a)
  • Mamba2 通过引入一个简单的门控更新规则来解决这个限制:
    $$ \mathbf{S}_{t}=\alpha_{t}\mathbf{S}_{t-1}+\boldsymbol{v}_{t}\boldsymbol{k}^{\mathrm{T} }_{t}$$
    • 该规则在每个时间步通过一个动态比率 \(\alpha_t \in (0,1)\) 统一衰减所有键值关联
  • 但这种方法没有考虑不同键值关联的重要性差异,可能导致内存利用效率低下
    • 如果模型需要忘记一个特定的键值关联,所有键值关联都会被同等程度地遗忘 ,使得这个过程缺乏针对性和效率
  • 相比之下,采用 Delta 规则 (1960) 的线性 Transformer,即 DeltaNet (2021a; 2024b),通过(软性地)用传入的新键值对替换旧的键值对来选择性更新记忆
    • 这种方法在上下文检索的合成基准测试中展示了令人印象深刻的性能
    • 但由于这个过程((一次只修改一个键值对)),模型缺乏快速清除过时或无关信息的能力 ,尤其是在需要擦除先前数据的上下文切换期间
    • 因此,人们发现 DeltaNet 在现实世界任务中表现一般 (2024b),这很可能是因为缺乏强大的内存清除机制
  • 认识到门控更新规则和 Delta 规则在内存管理方面的互补优势,论文提出了门控 Delta 规则 (gated delta rule)
    • 这是一种简单直观的机制,结合了两种方法
    • 这个统一的规则实现了灵活的内存控制:
      • 可以通过设置 \(\alpha_{t}\to 0\) 来迅速清除内存
      • 同时可以通过设置 \(\alpha_{t}\to 1\) 来选择性更新特定内容而不影响其他信息(有效地切换到纯 Delta 规则)
  • 剩下的挑战在于以硬件高效的方式实现门控 Delta 规则
    • 基于 (2024b) 使用 WY 表示 (1985) 并行化 Delta 规则计算的高效算法,论文仔细扩展了他们的方法以纳入门控项
    • 论文的扩展保留了分块并行 (chunkwise parallelism) 的优势 (2022a; 2023a; 2024a),实现了硬件高效的训练
  • 论文最终的架构 Gated DeltaNet,在一套全面的基准测试中,包括语言建模、常识推理、上下文检索、长度外推和长上下文理解,持续优于 Mamba2 和 DeltaNet
  • 基于这些结果,作者还开发了混合架构,策略性地将 Gated DeltaNet 层与滑动窗口注意力或 Mamba2 层相结合,进一步提升了训练效率和模型性能

Preliminary

Mamba2:带衰减的线性注意力 (Mamba2: Linear Attention with decay)

  • 众所周知,线性 Transformer (2020a) 在排除归一化和查询/键激活的情况下,可以表述为以下线性递归:
    $$
    \begin{align}
    \mathbf{S}_{t}&=\mathbf{S}_{t-1}+\boldsymbol{v}_{t}\boldsymbol{k}_{t}^{\intercal} \in\mathbb{R}^{d_{v}\times d_{k} },\\
    \boldsymbol{o}_{t}&=\mathbf{S}_{t}\boldsymbol{q}_ {t}\in\mathbb{R}^{d_{v} }
    \end{align}
    $$
    • 其中 \(d_{k}\) 和 \(d_{v}\) 分别代表查询/键 (query/key) 和值 (value) 的(头)维度
  • 通过展开递归,我们可以将其表示为向量形式(左)和矩阵形式(右)如下:
    $$
    \begin{align}
    \boldsymbol{o}_{t}&=\sum_{i=1}^{t}(\boldsymbol{v}_{i}\boldsymbol{k}_{i}^{\intercal})\boldsymbol{q }_{t}=\sum_{i=1}^{t}\boldsymbol{v}_{i}(\boldsymbol{k}_{i}^{\intercal}\boldsymbol{q}_{t})\in\mathbb{R}^{d_{v} },\\
    \mathbf{O}&=(\mathbf{Q}\mathbf{K}^{\intercal}\odot\mathbf{M}) \mathbf{V}\in\mathbb{R}^{L\times d_{v} }
    \end{align}
    $$
    • 其中 \(L\) 是序列长度,\(\mathbf{M}\in\mathbb{R}^{L\times L}\) 是由 \(\mathbf{M}_{ij}=0\)(当 \(i< j\))和 \(1\)(其他情况)定义的因果掩码 (causal mask)
  • 然而,这种普通的线性注意力在语言建模中表现远不如 Transformer
    • 理解:本质上 \(\sum_{i=1}^{t}\boldsymbol{v}_{i}(\boldsymbol{k}_{i}^{\intercal}\boldsymbol{q}_{t})\in\mathbb{R}^{d_{v} }\) 也有了一定的按照 Q 和 K 的相似度对 V 进行加权的思想,但这里主要差异在于 Transformer 是 Softmax 的 Attention 权重:
      $$Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\frac{\boldsymbol{Q}\boldsymbol{K}^{\top}}{\sqrt{d_k}}\right)\boldsymbol{V}$$
    • 理解:线性 Attention 和 Softmax Attention 都是 QK 内积越大,V 权重越大,但两者有本质区别:
      • 线性 Attention 未实现归一化,仅仅累加所有 Token 的 V,理论上会导致越靠后的 Token,方差是越大的
      • Softmax Attention 的权重是 \(e^{qk}\) 加权归一化,本质与线性归一化是不同的
      • 问题:如果线性 Attention 中使用 \(e^{qk}\) 作为累加,是否基本上可以实践 Softmax 的等价实现?
        • 回答:不可以,因为 Softmax Attention 中的 权重是 \(e^{qk}\),而线性 Attention 累加的对象是 \(kv\),重点:即使使用 \(e^{kv}\) 累加,累加对象也不同,反而导致含义变了!线性 Attention 的形式导致了他们无法实现 Softmax Attention 这样的 \(e^{qk}\) 加权平均
  • 为了解决这个问题,通常添加一个衰减项来遗忘历史信息
    • 这里论文以 Mamba2 (2024a) 为例,它可以表示为以下线性递归(取决于具体的参数化):
      $$
      \begin{align}
      \mathbf{S}_{t}&=\alpha_{t}\mathbf{S}_{t-1}+\boldsymbol{v}_{t}\boldsymbol{k}_{t}^ {\intercal},\\
      \boldsymbol{o}_{t}&=\mathbf{S}_{t}\boldsymbol{q}_{t}
      \end{align}
      $$
    • 其中 \(\alpha_{t}\in(0,1)\) 是一个数据依赖的标量值衰减项,随 \(t\) 变化
  • 定义累积衰减乘积
    $$\gamma_{j}=\prod_{i=1}^{j}\alpha_{i}$$
  • 并通过展开递归,我们可以将结果表示为向量形式(左)和矩阵并行形式(右):
    $$
    \begin{align}
    \boldsymbol{o}_{t}&=\sum_{i=1}^{t}\left(\frac{\gamma_{t} }{\gamma_{i} }\boldsymbol{v }_{i}\boldsymbol{k}_{i}^{\intercal}\right)\boldsymbol{q}_{t}=\sum_{i=1}^{t}\boldsymbol{v}_{i}\left( \frac{\gamma_{t} }{\gamma_{i} }\boldsymbol{k}_{i}^{\intercal}\boldsymbol{q}_{t}\right),\\
    \mathbf{O}&=((\mathbf{Q}\mathbf{K}^{\intercal})\odot\Gamma),\mathbf{V}
    \end{align}
    $$
    • 这里,\(\Gamma\in\mathbb{R}^{L\times L}\) 是一个衰减感知的因果掩码,其中 \(\Gamma_{ij}=\frac{\gamma_{i} }{\gamma_{j} }\)(如果 \(i\geq j\)),否则 \(\Gamma_{ij}=0\)
  • 这种并行形式和递归形式之间的等价性也被称为 状态空间对偶性 (state space duality, SSD) (2024a)
    • 这种递归结构也出现在其他几种架构中,包括 Gated RFA (2021)、xLSTM (2024) 和 Gated RetNet (2023a)
    • 当 \(\gamma_{t}\) 与数据无关时(data-independent),该公式简化为 RetNet (2023a) 和 Lightning-Attention (2024a)
  • 此外,如果 \(\gamma_{t}\) 扩展为矩阵值而非标量值,当使用外积结构参数化时,高效的训练算法仍然是可能的,正如 (2024a) 所展示并被 (2024b;2024;2025 等) 所使用的 分块训练 (Chunkwise training)
    • 但递归形式和并行形式对于高效训练来说都不是理想的 (2022a; 2024a),这促使了使用分块并行形式 (2022a; 2023a) 进行硬件高效的线性时间训练,如下所述
  • 总结来说,分块并行形式将输入和输出分割成几个大小为 \(C\) 的块 (chunk),并根据前一个块的最终状态以及当前块的查询/键/值块来计算每个块的输出
    • 遵循 (2023a); (2024a) 的符号,论文以查询块 \(\boldsymbol{q}\) 为例
  • 论文将
    • \(\mathbf{Q}_{[t]}:={\boldsymbol{q} }_{tC+1:(t+1)C+1}\) 表示为块 \(t\) 的查询块
    • \({\boldsymbol{q} }_{[t]}^{r}:={\boldsymbol{q} }_{tC+r}\) 表示为块 \(t\) 内的第 \(r\) 个查询块
    • \(t\) 的初始状态定义为 \(\mathbf{S}_{[t]}:=\mathbf{S}_{[t]}^{0}=\mathbf{S}_{[t-1]}^{C}\)
  • 通过部分展开递归,论文有
    $$
    \begin{align}
    \mathbf{S}_{[t]}^{r}&=\mathbf{S}_{[t]}+\sum_{i=1}^{r}\boldsymbol{v}_{[t]}^{i}\boldsymbol{k}_{[t ]}^{i\tau}\in\mathbb{R}^{d_{v}\times d_{k} },\\
    \boldsymbol{\sigma}_{[t]}^{r}&=\mathbf{S }_{[t]}^{r}{\boldsymbol{q} }_{[t]}^{r}=\mathbf{S}_{[t]}{\boldsymbol{q} }_{[t]}^{r}+\sum_{i=1}^{r} \boldsymbol{v}_{[t]}^{i}\left(\boldsymbol{k}_{[t]}^{i\tau}{\boldsymbol{q} }_{[t]}^{r}\right)\in\mathbb{R}^ {d_{v} }
    \end{align}
    $$
  • 等价地,以矩阵形式表示:
    $$
    \begin{align}
    \mathbf{S}_{[t+1]}&=\mathbf{S}_{[t]}+\mathbf{V}_{[t]}\mathbf{K}_{[t]}^{\tau} \in\mathbb{R}^{d_{v}\times d_{k} },\\
    \mathbf{O}_{[t]}&=\mathbf{Q}_{[t]}{ \mathbf{S} }_{[t]}^{\tau}+\left(\mathbf{Q}_{[t]}\mathbf{K}_{[t]}^{\tau}\odot \mathbf{M}\right)\mathbf{V}_{[t]}\in\mathbb{R}^{C\times d_{v} }
    \end{align}
    $$
    • 其中 \(\mathbf{M}\in\mathbb{R}^{C\times C}\) 是因果掩码
    • 上述等式富含矩阵乘法 (matmuls),允许基于张量核心 (tensor-core) 的硬件优化
  • 这个分块算法可以轻松扩展到带衰减的线性注意力:
    $$
    \begin{align}
    \mathbf{S}_{[t+1]}&=\overrightarrow{\mathbf{S}_{[t]} }+\mathbf{V}_{[t]}^{\tau} \overrightarrow{\mathbf{K}_{[t]} }\in\mathbb{R}^{d_{v}\times d_{k} },\\
    \mathbf {O}_{[t]}&=\overleftarrow{\mathbf{Q}_{[t]}\mathbf{S}_{[t]}^{\tau} }+\left( \mathbf{Q}_{[t]}\mathbf{K}_{[t]}^{\tau}\odot\Gamma_{[t]}\right)\mathbf{V}_{[t ]}\in\mathbb{R}^{C\times d_{v} }
    \end{align} \tag{1}
    $$
    • 其中有
      $$(\Gamma_{[t]})_{ij}=\frac{\gamma_{[t]}^{i} }{\gamma_{[t]}^{j} },\gamma_{[t]}^{j}= \prod_{j=tC+1}^{tC+j}\alpha_{j} $$
      • 注:这里论文稍微滥用了 \(\gamma\) 的符号来表示每个块的累积乘积(分别从每个块的第一个位置开始),而不是整个序列
      • 这里论文使用左箭头 (\(\overset{\leftarrow}{\cdot}\)) 或右箭头 (\(\overset{\rightarrow}{\cdot}\)) 分别表示衰减到每个块第一个位置和最后一个位置的变量,
        $$
        \begin{align}
        \overleftarrow{ {\boldsymbol{q} }_{[t]}^{r} }&=\gamma_{[t]}^{r}{\boldsymbol{q} }_{[t]}^{r} \text{ 将每个向量衰减到块 t 的第一个位置 (decaying each vector to the first position of chunk t)}\\
        \overleftarrow{ {\boldsymbol{k} }_{[t]}^{r} }&=\frac{\gamma_{[t]}^{C} }{ \gamma_{[t]}^{c} }\boldsymbol{k}_{[t]}^{r} \text{ 将每个向量衰减到块 t 的最后一个位置 (decaying each vector to the last position of chunk t)} \\
        \overrightarrow{\mathbf{S}_{[t]} }&=\gamma_{[t]}^{C}\mathbf{S}_{[t]} \text{ 在整个块 t 上衰减状态矩阵 (decaying the state matrix over the entire chunk t)} \tag{2}
        \end{align}
        $$
        • 其他变量(例如 \(\overrightarrow{\boldsymbol{v} }\))也类似
  • Mamba2 中引入的 SSD 分解算法在很大程度上等同于这种分块算法
    • 对于更通用的方法, (2024a) 提出了一种扩展的分块算法,用于线性注意力,该算法结合了细粒度的衰减机制

Delta Networks: Linear Attention with Delta Rule

  • Delta 更新规则(Delta Update Rule) (1960; 2021a) 动态地擦除与当前输入键 (\(\boldsymbol{k}_{t}\)) 关联的旧值 (\(\boldsymbol{v}_{t}^{\text{old} }\)) ,并写入一个新值 (\(\boldsymbol{v}_{t}^{\text{new} }\)) ,该新值是基于“写入强度” \(\beta_{t}\in(0,1)\) 的当前输入值和旧值的线性组合 (可以将 \(\beta_{t}\in(0,2)\) 设置为允许负特征值,以解锁 DeltaNet 的状态跟踪能力 (2024; 2025))
    $$
    \begin{align}
    \mathbf{S}_{t}&=\mathbf{S}_{t-1}-\underbrace{(\mathbf{S}_{t-1}\boldsymbol{k}_{t})}_{\boldsymbol {v}_{t}^{\text{old} } }\boldsymbol{k}_{t}^{\top}+\underbrace{(\beta_{t}\boldsymbol{v}_{t}+(1-\beta_{t})\mathbf{S}_{t-1}\boldsymbol{k}_{t}))}_{\boldsymbol{v}_{t}^{\text{new} } }\boldsymbol{k}_{t}^{\top} \\
    &= \mathbf{S}_{t-1}\left(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t}\boldsymbol{k}_{t}^{\top}\right)+ \beta_{t}\boldsymbol{v}_{t}\boldsymbol{k}_{t}^{\top}
    \end{align}
    $$
    • 如上所示,DeltaNet 实现了具有广义 Householder 转移矩阵 (\(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t}\boldsymbol{k}_{t}^{\mathsf{T} }\)) 的一阶线性递归
    • 尽管在关联回忆和语言建模性能上表现出色 (2021a),但由于计算效率低下,DeltaNet 受到的关注有限,直到 (2024b) 引入了一种硬件高效的分块训练算法,详情如下
  • 分块并行形式 (Chunkwise parallel form)。通过部分展开递归,论文有
    $$\mathbf{S}_{[t]}^{r}=\mathbf{S}_{[t]}\underbrace{\left(\prod_{i=1}^{r} \mathbf{I}-\beta_{[t]}^{i}\boldsymbol{k}_{[t]}^{i}\boldsymbol{k}_{[t]}^{i\mathsf{T} }\right)}_{:= \mathbf{P}_{[t]}^{r} }+\underbrace{\sum_{i=1}^{r}\left(\beta_{[t]}^{i}\boldsymbol{v}_{[t ]}^{i}\boldsymbol{k}_{[t]}^{i\mathsf{T} }\prod_{j=i+1}^{r}\left(\mathbf{I}-\beta_{[t]}^{ j}\boldsymbol{k}_{[t]}^{j}\boldsymbol{k}_{[t]}^{j\mathsf{T} }\right)\right)}_{:=\mathbf{H}_{[t]}^{r} } \tag{3}$$
  • 其中 \(\mathbf{P}_{[t]}^{j}\) 涉及广义 Householder 矩阵的累积乘积,可以通过经典的 WY 表示 (1985) 进行优化:
    $$\mathbf{P}_{[t]}^{r}=\mathbf{I}-\sum_{i=1}^{r}\mathbf{w}_{[t]}^{i}\boldsymbol{k}_{[t]}^{ i\mathsf{T} }\in\mathbb{R}^{d_{k}\times d_{k} }\qquad\mathbf{w}_{[t]}^{r}=\beta_{ [t]}^{r}\left(\boldsymbol{k}_{[t]}^{r}-\sum_{i=1}^{r-1}\left(\mathbf{w}_{[t]}^{i}(\boldsymbol{k}_ {[t]}^{i\mathsf{T} }\boldsymbol{k}_{[t]}^{r})\right)\right)\in\mathbb{R}^{d_{k} } \tag{4}$$
  • 同样,\(\mathbf{H}_{[t]}^{r}\) 可以表示为:
    $$\mathbf{H}_{[t]}^{r}=\sum_{i=1}^{r}\mathbf{u}_{[t]}^{i}\boldsymbol{k}_{[t]}^{i\mathsf{T } }\in\mathbb{R}^{d_{v}\times d_{k} }\qquad\mathbf{u}_{[t]}^{r}=\beta_{[t]}^{r} \left(\boldsymbol{v}_{[t]}^{r}-\sum_{i=1}^{r-1}\left(\mathbf{u}_{[t]}^{i}(\boldsymbol{k}_{[t]}^{ i\mathsf{T} }\boldsymbol{k}_{[t]}^{r})\right)\right)\in\mathbb{R}^{d_{v} } \tag{5}$$
  • 并以矩阵形式表示:\(\mathbf{P}_{[t]}=\mathbf{I}-\mathbf{W}_{[t]}^{\top}\mathbf{K}_{[t]}\in\mathbb{R }^{d_{k}\times d_{k} }\),\(\mathbf{H}_{[t]}=\mathbf{U}_{[t]}^{\top}\mathbf{K}_{[t]}\in\mathbb{R}^{d_{v} \times d_{k} }\)。通过使用 UT 变换 (2006),我们可以进一步将 \(\mathbf{W}\) 和 \(\mathbf{U}\) 写成矩阵形式:
    $$\mathbf{T}_{[t]} =\left[\mathbf{I}+\text{strictLower}\left(\text{diag}(\beta_{[t]} )\mathbf{K}_{[t]}\mathbf{K}_{[t]}^{\mathsf{T} }\right)\right]^{-1}\text{ diag}\left(\beta_{[t]}\right)\in\mathbb{R}^{C\times C} \\
    \mathbf{W}_{[t]} =\mathbf{T}_{[t]}\mathbf{K}_{[t]}\in\mathbb{R}^{C\times d_{k} }, \qquad\mathbf{U}_{[t]}=\mathbf{T}_{[t]}\mathbf{V}_{[t]}\in\mathbb{R}^{C \times d_{v} }\tag{6-7}$$
  • 将这些代回方程 3,得到了一个硬件高效的 DeltaNet 分块算法,该算法利用了矩阵乘法,实现了基于张量核心的 GPU 优化:
    $$\mathbf{S}_{[t+1]} =\mathbf{S}_{[t]}\mathbf{P}_{[t]}+\mathbf{H}_{[t]}=\mathbf{S}_{[t ]}+\left(\mathbf{U}_{[t]}-\mathbf{W}_{[t]}\mathbf{S}_{[t]}^{\mathsf{T} }\right)^ {\mathsf{T} }\mathbf{K}_{[t]} \in\mathbb{R}^{d_{v}\times d_{k} } \\
    \mathbf{O}_{[t]} =\mathbf{Q}_{[t]}\mathbf{S}_{[t]}^{\mathsf{T} }+(\mathbf{Q}_{[t]} \mathbf{K}_{[t]}^{\mathsf{T} }\odot\mathbf{M})\left(\mathbf{U}_{[t]}-\mathbf{W }_{[t]}\mathbf{S}_{[t]}^{\mathsf{T} }\right) \in\mathbb{R}^{C\times d_{v} } \tag{8-9}$$

Gated Delta Networks

Formulation: Gated Delta Rule

  • 论文提出的门控 Delta 规则简单而有效:
    $$\mathbf{S}_{t}=\mathbf{S}_{t-1}\left(\alpha_{t}(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t }\boldsymbol{k}_{t}^{\mathsf{T} })\right)+\beta_{t}\boldsymbol{v}_{t}\boldsymbol{k}_{t}^{\mathsf{T} } \tag{10}$$
    • 其中数据依赖的门控项 \(\alpha_{t}\in(0,1)\) 控制状态衰减
    • 这个公式统一了门控机制和 Delta 规则的优势:门控项支持自适应内存管理,而 Delta 更新结构促进有效的键值关联学习
  • 论文通过 Liu 等 (2024) 引入的在线学习框架的视角,对门控 Delta 规则进行了正式分析
    • 在这个框架中,循环状态更新作为在线学习问题的闭式解出现,如表 1 所示
    • 最近的线性 RNN 架构通常在其在线学习目标中包含一个正则化项,以防止状态偏离先前的值,从而实现记忆保留
      • 但当状态被信息饱和时,这种保留机制会变得有问题。在这种情况下,每个状态将对多个信息片段进行编码,使得精确检索变得困难
    • 为了解决这个限制,Mamba2 和 Gated DeltaNet 引入了一个自适应缩放因子 \(\alpha_{t}\),它放松了正则化项,允许 \(\mathbf{S}_{t}\) 和 \(\mathbf{S}_{t-1}\) 之间的受控偏差
      • 这个修改通过选择性遗忘实现了动态内存管理,这在过滤无关信息时可能很有用(见 章节3.2)
  • 另一方面,线性注意力 (Linear Attention, LA) 和 Mamba2 使用简单的负内积损失 -\(\langle\mathbf{S}_{t}\boldsymbol{k}_{t},\boldsymbol{v}_{t}\rangle\),而 Longhorn (2024) 使用更具表达力的在线回归目标 \(|\mathbf{S}_{t}\boldsymbol{k}_{t}-\boldsymbol{v}_{t}|^{2}\) 来更好地建模键值关联。由此产生的 Longhorn 更新规则与 Delta 更新规则非常相似,这表明(门控)Delta 规则在上下文关联回忆方面优于 Mamba2
    • Longhorn 更新规则与 Delta 更新规则理论上的区别在于优化方法:Longhorn 使用隐式在线学习 (2010) 来推导闭式的全局最优更新,而 DeltaNet 通过一步显式梯度下降来优化相同的目标,正如 (2024) 所指出的
  • 从快速权重编程 (fast weight programming) (2022a)、测试时训练 (test-time training) (2024a) 和回归 (2025) 的角度来看,隐藏状态 \(\mathbf{S}\) 可以被解释为一个(快速)权重矩阵,Delta 规则通过测试时随机梯度下降 (stochastic gradient descent, SGD) 优化在线回归目标 \(\mathcal{L}(\mathbf{S}_{t})=\frac{1}{2}|\mathbf{S}_{t}\boldsymbol{k}_{t}-\boldsymbol{v}_{t}|^{2}\):
    $$\mathbf{S}_{t+1}=\mathbf{S}_{t}-\beta_{t}\nabla\mathcal{L}(\mathbf{S}_{t})= \mathbf{S}_{t}-\beta_{t}(\mathbf{S}_{t}\boldsymbol{k}_{t}-\boldsymbol{v}_{t})\boldsymbol{k}_{t}^{\top}= \mathbf{S}_{t}\left(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t}\boldsymbol{k}_{t}^{\top}\right)+\beta_{t}\boldsymbol{v}_{t}\boldsymbol{k}_{t}^{\top}$$
    • 其中 \(\beta_{t}\) 代表(自适应)学习率。从这个角度来看,门控 Delta 规则可以被视为将自适应权重衰减项 \(\alpha_{t}\) 纳入 SGD 更新中,这是一种在深度学习中广泛使用的技术 (1991; 2023)
    • 同时,Titans (2024) 证明了在 RNN 测试时 SGD 更新中结合权重衰减机制的有效性

Case study: Single Needle in a Haystack, S-NIAH,大海捞针

  • 为了更好地理解 Delta 规则和门控规则之间的互补优势,论文提供了一个关于来自 RULER (2024) 的“大海捞针”基准测试套件的案例研究,其中一个键值对充当“大海”(上下文)中的“针”,模型必须在给定键时回忆起值
  • 表 2 展示了结果,论文得出三个主要观察:
    • 衰减损害记忆保留 (Decay hurts memory retention)
      • 在最简单的 S-NIAH-1 设置中,具有重复的合成上下文,模型记忆的信息最少,测试长期保留能力
      • DeltaNet 在所有序列长度上都实现了接近完美的性能
      • Mamba2 在超过 2K 序列后性能显著下降,因为它衰减历史信息太快,而 Gated DeltaNet 的下降不那么严重,这得益于使用了 Delta 规则
    • 门控促进过滤 (Gating facilitates filtering)
      • 在具有真实世界文章上下文的 S-NIAH-2/3 中,模型存储所有潜在相关信息,测试高效的内存管理
      • 在固定状态大小的情况下,缺乏清除会导致内存碰撞,信息变得叠加且无法区分
      • DeltaNet 的性能在较长序列时显著下降,原因是内存清除能力差
      • Mamba2 和 Gated DeltaNet 通过过滤无关信息的门控机制保持了更好的性能
    • Delta 规则有助于记忆 (Delta rule helps memorization)
      • 在 S-NIAH-3 中,值从数字变为 UUID,测试复杂模式记忆
      • Mamba2 的性能迅速下降,而 Gated DeltaNet 表现更好,验证了 Delta 规则确实具有更好的记忆能力

Algorithm: Hardware-efficient Chunkwise training

  • 在本小节中,论文推导出用于训练 Gated DeltaNet 的硬件高效分块算法。通过部分展开方程 10 中的递归,论文有
    $$\mathbf{S}_{[t]}^{r}=\underbrace{\mathbf{S}_{[t]}\left(\prod_{i=1}^{ r}\alpha_{[t]}^{i}\left(\mathbf{I}-\beta_{[t]}^{i}\boldsymbol{k}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i\mathsf{T} }\right)\right)}_{\text{:=F}_{[t]}^{r} }+\underbrace {\sum_{i=1}^{r}\left(\beta_{[t]}^{i}\boldsymbol{v}_{[t]}^{i}\boldsymbol{k}_{[t]} ^{i\mathsf{T} }\prod_{j=i+1}^{r}\alpha_{[t]}^{j}\left(\mathbf{I}-\beta_{[t]}^{j }\boldsymbol{k}_{[t]}^{j}\boldsymbol{k}_{[t]}^{j\mathsf{T} }\right)\right)}_{ \text{:=G}_{[t]}^{r} }$$
    • 很容易看出 \(\mathbf{F}_{[t]}^{r}=\gamma_{[t]}^{r}\mathrm{P}_{[t]}^{r}=\hat{\overline{\mathbf {P} }_{[t]}^{r} }\)
    • 至于 \(\mathbf{G}_{[t]}^{r}\),论文调整方程 5 如下:
      $$\mathbf{G}_{[t]}^{r}=\sum_{i=1}^{r}\frac{\gamma_{[t]}^{r} }{\gamma_{[t]}^{i} }\tilde{\mathbf{u} }_{[t]}^{i}\boldsymbol{k}_{[t]}^{i\mathsf{T} }\in \mathbb{R}^{d_{v}\times d_{k} }\qquad\tilde{\mathbf{u} }_{[t]}^{r}=\beta_{[t]}^{r }\left(\boldsymbol{v}_{[t]}^{r}-\sum_{i=1}^{r-1}\left(\tilde{\mathbf{u} }_{[t]}^ {i}(\frac{\gamma_{[t]}^{r} }{\gamma_{[t]}^{i} }\boldsymbol{k}_{[t]}^{i\mathsf{T } }\boldsymbol{k}_{[t]}^{r})\right)\right)\in\mathbb{R}^{d_{v} }$$
      • (证明见附录 A)
    • 通过 UT 变换,论文得到矩阵形式:
      $$\widetilde{\mathbf{U} }_{[t]}^{-}=\left[\mathbf{I}+\text{strictLower}\left( \operatorname{diag}\left(\beta_{[t]}\right)(\Gamma_{[t]}\odot\mathbf{K}_{[t]} \mathbf{K}_{[t]}^{\mathsf{T} })\right)\right]^{-1}\operatorname{diag}\left( \beta_{[t]}\right)\mathbf{V}_{[t]}\qquad\in\mathbb{R}^{C\times d_{v} }$$
  • 类似于 Mamba2 扩展线性注意力的方式(方程 1),我们可以调整 DeltaNet 的分块算法(方程 8-9)用于 Gated DeltaNet,以实现硬件高效训练,如下所示:
    $$
    \begin{align}
    \mathbf{S}_{[t+1]}&=\overrightarrow{\mathbf{S}_{[t]} }+\left( \overrightarrow{\mathbf{U} }_{[t]}-\overleftarrow{\mathbf{W} }_{[t]}\mathbf{S}_{[ t]}^{\mathsf{T} }\right)^{\mathsf{T} }\overrightarrow{\mathbf{K} }_{[t]} \in\mathbb{R}^{d_{v}\times d_{k} } \\
    \mathbf{O}_{[t]}&=\overleftarrow{\mathbf{O} }_{[t]}\mathbf{S}_{[t]} ^{\mathsf{T} }+\left(\mathbf{Q}_{[t]}\mathbf{K}_{[t]}^{\mathsf{T} }\odot \mathbf{M}\right)\left(\overrightarrow{\mathbf{U} }_{[t]}-\overleftarrow{ \mathbf{W} }_{[t]}\mathbf{S}_{[t]}^{\mathsf{T} }\right)\in\mathbb{R}^{C\times d_ {v} }
    \end{align}
    $$
    • 其中 \(\overleftarrow{\boldsymbol{q} }_{[t]}^{r}=\gamma_{[t]}^{r}\boldsymbol{q}_{[t]}^ {r}\), \(\overleftarrow{\mathbf{w} }_{[t]}^{r}=\gamma_{[t]}^{r}\mathbf{w}_{[t]}^{r}\), \(\overleftarrow{\boldsymbol{k} }_{[t]}^{r}=\frac{\gamma_{[t]}^{C} }{\gamma_{[t]} } \boldsymbol{k}_{[t]}^{r}\), 并且 \(\overrightarrow{\mathbf{S} }_{[t]}\neq\gamma_{[t]}^{C}\mathbf{S}_{[t]}\) 如同论文在方程 2 中的定义

Gated Delta Networks and Hybrid Models

  • Token 混合器块 (Token mixer block)
    • 基本的 Gated DeltaNet 遵循 Llama 的宏观架构,将 Token 混合器层与 SwiGLU MLP 层堆叠,但用门控 Delta 规则 Token 混合取代了自注意力
    • 图 1(右)显示了其块设计
      • 对于门控 Delta 规则(方程 10),查询、键和值 \(\{\boldsymbol{q},\boldsymbol{k},\boldsymbol{v}\}\) 通过线性投影、短卷积 (short convolution) 和 SiLU 生成,并对 \(\boldsymbol{q},\boldsymbol{k}\) 应用 L2 归一化以保持训练稳定性。\(\alpha,\beta\) 仅使用线性投影(论文对 \(\alpha\) 使用 Mamba2 的参数化,但为简洁起见省略了细节)
      • 遵循 Sun 等 (2023a),输出在应用输出投影之前通过归一化和门控处理
  • 混合模型 (Hybrid models)
    • 线性 Transformer 在建模局部偏移和比较方面存在局限性,并且其固定状态大小使得检索任务变得困难 (2023a)
    • 遵循最近的混合架构,如 Griffin (2024) 和 Samba (2024),论文将线性循环层与滑动窗口注意力 (sliding window attention, SWA) 相结合,产生了 GatedDeltaNet-H1
    • 论文还堆叠了 Mamba2、GatedDeltaNet 和 SWA,产生了 GatedDeltaNet-H2

Experiments

Setup

  • 论文的实验包括对最近最先进架构的全面比较,包括纯 Transformer 模型、基于 RNN 的方法和混合架构
  • 论文针对以下基线进行评估:RetNet (2023a)、HGRN2 (2024b)、Mamba (2023)、Mamba2 (2024a)、Samba (2024) 和 DeltaNet (2024b)
  • 为了公平比较,所有模型都在相同条件下训练,具有 1.3B 参数,使用从 FineWeb-Edu 数据集 (2024) 中采样的 100B Token
  • 论文使用 AdamW 优化器,峰值学习率为 4e-4,权重衰减为 0.1,梯度裁剪为 1.0
  • 学习率遵循余弦退火调度,具有 1B Token 的预热期 和 0.5M Token 的批次大小
  • 所有模型都使用词表为 32,000 的 Llama2 分词器
  • 对于序列建模,论文将训练长度设置为 4K Token ,Samba 和论文的混合模型使用 2K 的滑动窗口大小
  • 评估设置见附录 B.1,消融研究见附录 B.2

Common-sense reasoning

  • 在表 3 中,论文展示了具有 0.4B 和 1.3B 参数模型的语言建模困惑度以及在常识推理基准测试上的零样本准确率
  • Gated DeltaNet 在两个规模上都持续优于其他线性模型,包括 RetNet、HGRN2、Mamba、Mamba2 和 DeltaNet
  • 正如预期的那样,混合变体进一步提升了性能

In-context retrieval on real-world data

  • 表 4 展示了在 Arora 等 (2023a) 使用的真实世界回忆密集型任务上的结果
  • 正如预期的那样,线性循环模型与 Transformer 相比显示出显著的性能差距,而结合线性循环和注意力的混合模型在检索任务中优于纯注意力模型
  • 对于纯循环模型,尽管 DeltaNet 在合成的上下文检索任务上表现出色 (2024b),但其真实世界的检索性能落后于 Mamba2,这与论文在 S-NIAH-2 和 S-NIAH-3 中的观察一致(表 2)
  • Gated DeltaNet 由于其门控 Delta 规则,性能优于 DeltaNet 和 Mamba2,尽管改进幅度小于表 2
  • 论文将这种性能差距的缩小归因于未进行指令调优的小型语言模型容易产生重复错误,这些错误是这些任务中错误的主要来源(参见 Arora 等 (2023a, 附录 E))
  • 由于这个问题在很大程度上与更新规则的选择无关,因此模型之间的性能差异与表 2 相比不那么明显

Length extrapolation on long sequences

  • 如图 2 所示,论文评估了模型在六个长上下文基准测试中外推到长达 20K Token 序列的能力
    • 在 RNN 模型中,Gated DeltaNet 在所有任务中实现了最低的整体困惑度
    • 虽然论文在长度外推中观察到结果好坏参半,但 Gated DeltaNet 表现出相对更稳健的性能,表明其具有更好的内存管理能力
    • 混合模型通过利用注意力进行局部上下文建模,进一步改善了这一点,从而减轻了其循环组件在内存管理上的负担
    • 未来的工作将探索这些模型在更长序列上的能力

Long context understanding

  • 如表 5 所示,论文评估了模型在 LongBench (2023) 上的性能
  • 在循环模型中,Gated DeltaNet 显示出持续的优势,特别是在单文档问答、少样本上下文学习和代码任务中,分别展示了其在检索、上下文学习和状态跟踪方面的卓越能力

Throughput Comparison

  • 不同模型的训练吞吐量对比如图 3 所示
    • 正如论文的分析所示,论文提出的门控 delta 规则 (gated delta rule) 与原始 delta 规则相比仅引入了微小的开销,Gated DeltaNet 的吞吐量与 DeltaNet 基本相同
    • 由于它们具有更具表达力的转移矩阵 (transition matrices),两者都比 Mamba2 稍慢(约慢 2–3K tokens/sec)
  • 正如论文的分析所示,与原始 Delta 规则相比,提出的门控 Delta 规则仅引入了边际开销,Gated DeltaNet 实现了与 DeltaNet 基本相同的吞吐量
  • 由于它们具有更具表达力的转移矩阵,两者都比 Mamba2(2-3K Token /秒)稍慢
  • Transformer++ 在 2K 上下文窗口领域实现了最佳性能
    • 这得益于高度优化的 Flash-Attention-2 内核 (2023)
    • 因此,将 2K 窗口大小的滑动窗口注意力 (sliding window attention, SWA) 与其他令牌混合器 (token mixer) 相结合的混合方法,其吞吐量高于独立的混合器:Samba 优于 Mamba,而 Gated DeltaNet-H1 和 Gated DeltaNet-H2 优于 Gated DeltaNet
    • 值得注意的是,Gated DeltaNet-H1 在所有序列长度上均保持了可观的训练吞吐量,即使在短序列上也是如此

Related Work

Gated linear RNN

  • 大型线性循环语言模型因其训练和推理效率而备受关注
  • 线性循环神经网络领域已经从使用数据无关的衰减机制迅速发展为在更近期的架构中融入数据相关的衰减机制
    • 数据无关的衰减机制:例如 S4 (2022)、S5 (2023)、LRU (2023)、RWKV4/5 (2023) 和 RetNet (2023a) 等模型所例证
    • 数据相关的衰减机制:例如 HGRN1/2 (2023a; 2024b)、Mamba1/2 (2023; 2024a)、RWKV6 (2024)、GSA (2024)
  • 这一转变源于门控/遗忘机制(在 Mamba 中称为选择性机制)被证实的优势,这是一个源于门控循环神经网络文献 (2000) 的经典概念,其重要性一直被反复确认 (2015; 2018; 2023b; 2024b)
  • 现代遗忘门与传统设计(如 LSTM 中的遗忘门)的不同之处在于,它移除了对先前隐藏状态的依赖,仅依赖于输入数据
    • 这种修改使得在序列长度上能够实现高效的并行性 (2018; 2023b)
    • 缺乏遗忘门一直是 DeltaNet 的一个显著局限,而论文的门控扩展以一种自然、有效且硬件高效的方式弥补了这一差距
    • 论文也注意到最近的一项并行工作 RWKV-7 使用了类似的想法,但采用了更宽松的使用对角加低秩转移的形式化表示:
      $$\mathbf{S}_{t}=\mathbf{S}_{t-1}(\mathrm{diag}(\mathbf{d}_{t})-\mathbf{a}_{t}\mathbf{b}_{t}^{\top})+v_{t}\boldsymbol{k}_{t}^{\top}$$
    • 其中 \(\mathbf{d}_{t},\mathbf{a}_{t},\mathbf{b}_{t}\in\mathbb{R}^{d_{k} }\)
    • 分块算法可以类似地适用于这种情况,正如在 Flash Linear Attention (2024) 中所实现的那样

Delta rule

  • Delta 学习规则在记忆容量上表现出优于赫布学习规则 (1988; 1989),DeltaNet 利用了这一点,而线性 Transformer 则依赖于类赫布规则
    • 这种记忆容量优势在合成上下文学习任务中表现明显,并延伸到语言建模 (2021; 2024b)、强化学习 (2022b) 和图像生成 (2023)
    • (2024b) 将 delta 规则计算并行化,并证明了 DeltaNet 的数据相关的单位加低秩结构 (\(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t}\boldsymbol{k}_{t}^{\top}\)) 比 Mamba2 的数据相关的对角矩阵 (\(\alpha_{\mathbf{I} }\)) 提供了更大的灵活性
    • 这种结构优势可能实现复杂的推理,包括正则语言识别 (2024; 2024) 和超越 TC’ 复杂度的状态跟踪 (2024),这对于编码和推理应用至关重要
  • 尽管有这些显著优势,delta 规则面临理论局限 (2023) 并且在真实世界数据集上表现中等 (2024b),表明仍有改进空间
    • 先前通过非线性循环 (2021, 2022a) 来增强表达力的尝试解决了一些局限,但牺牲了训练并行性,造成了性能-效率的权衡
    • 最近的工作提出了一些不损害并行性的增强方法,以获得更好的状态跟踪性能,包括使用负特征值 (2024) 和使用多个户主转移矩阵的乘积 (2025),这实现了高秩变换
    • 这些方法可以无缝应用于 Gated DeltaNet
  • 从(在线)学习目标的角度来看,替代的公式化可以进一步扩展表达力:
    • 非线性回归 (\(\mathcal{L}(\mathbf{S}_{t})=\frac{1}{2}||\boldsymbol{f}_{\mathbf{S}_{t} }(\boldsymbol{k}_{t})-\boldsymbol{v}_{t}||^{2}\)),如 TTT (2024a) 和 Titans (2024) 中那样,其中 \(\boldsymbol{f}_{\mathbf{S} }\) 是由 \(\mathbf{S}\) 参数化的非线性函数;
    • 或者考虑整个历史的回归 (\(\mathcal{L}(\mathbf{S}_{t})=\frac{1}{2}\sum_{i=1}^{d}||\mathbf{S}_{t}\boldsymbol{k}_{i}-\boldsymbol{v}_{t}||^{2}\)),如 Mesa 层 (von 2024) 中那样——类似于最小均方算法和递归最小二乘算法之间的区别
    • 然而,这些更具表达力的变体引入了非线性循环,并且需要变通方法,例如在处理完整个分块后才执行非线性更新(如在 TTT 和 Titans 中);或者近似非线性循环方法,如 (2024; 2024; 2025)

Hybrid models

  • 在这项工作中,论文探索了在层间交错混合注意力层,这在诸如 MiniMax-01 (2025) 和 Hybrid Mamba2-Attention (2024) 中常用
  • 研究在单个层内混合线性/softmax 注意力也很有趣 (2022a; 2024; 2025)

NLP——LLM对齐微调-SDPO

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:(SDPO)Reinforcement Learning via Self-Distillation, 20260128 & 20260216, ETH Zurich & Max Planck Institute for Intelligent Systems & MIT & Stanford

Paper Summary

  • SDPO 整体总结:
    • RLVR 方法仅从每次尝试的标量结果奖励中学习,存在信用分配瓶颈
      • 许多可验证环境实际上提供了丰富的文本反馈,比如运行时错误或评判员评估,这些反馈解释了尝试失败的原因
      • 注:
        • 将 RLVR 描述为仅结果奖励的方法也不一定准确吧,RLVR 主要是指有可验证奖励的 RL 场景,其实也可以有 PRM 等
        • 个人感觉 RLVR 更像一种场景,而不是一种方法
    • 本文将此设定形式化为具有丰富反馈的强化学习,并提出了 SDPO
    • SDPO(Self-Distillation Policy Optimization)是一种针对 LLM 在可验证环境(如代码、数学推理)中进行强化学习的方法
      • 理解:其实只要有反馈信号(比如 开放问答领域的 Reference Answer 或者 Reward Model 的打分),也可以考虑使用(做一些信号处理后输入 Teacher 模型的上下文)
    • SDPO 基本思路:
      • 通过利用环境提供的反馈(如运行时错误、测试失败信息、评语等)
      • 以自蒸馏的方式实现密集信用分配(Dense Credit Assignment)
      • 克服传统 RL 方法中因标量奖励导致的奖励稀疏性 带来的信用分配瓶颈
      • 理解:SDPO 的本意实际上就是想使用更丰富的奖励信号来给 Token 训练提效
    • 本文的范式可称为 RLRF 范式(带 Rich Feedback 的 RL)
      • 环境提供超出标量奖励的 Tokenized Feedback,这论证了这移除了 RLVR 的一个关键信息瓶颈
    • 作者强调 SDPO 的优点:
      • 在缺乏 Rich Feedback 的标准 RLVR 环境中进行训练,SDPO 在推理任务上也展现出比 GRPO 更优的样本效率和实际收敛时间
      • SDPO 的收益随着模型规模的增大而增长
        • 表明自我修正的能力与模型的上下文学习能力共同扩展
      • 在 Test-time 对单个困难的二元奖励任务执行 SDPO,与强基线相比,能加速解的发现
    • SDPO 能够以一种更接近人类认知的方式从 Rich Feedback 中学习:利用精确的结果而不仅仅是二元奖励
      • 通过允许模型回顾性地确定它本应如何行动
  • SDPO(本文) vs OPSD 方法:
    • 除了 OPSD 强调信号来自已有数据集 \(\mathcal{D}\) 的参考答案外,OPSD 几乎和本文 SDPO 思路一致,都是 On-policy 蒸馏的
    • SDPO 强调外部评估环境信号
      • 实际上,在 SDPO 原始论文的表 2 中可以看到:如果存在正确 Rollout 的话,Teacher 的 Prompt 中可能会包含之前生成的正确 Rollout 作为 Hint
    • OPSD 强调参考答案 \(y^*\),类似 SFT 的样本(注意:OPSD 要求包含的参考答案 \(y^*\) 是原始数据集中必须存在的,不是 Student 也不是 Teacher 生成的)
      • 两者算法上几乎没有差异(两篇文章几乎同时发出 20260126 vs 20260128,算是并发的工作,OPSD 引用了 SDPO )
      • 总结来说:两者核心区别在于 hint 信号不同:
        • OPSD 更强调仅使用自身(能生成至少一次正确答案的自身)模型,数据集必须包含参考答案 \(y^*\)(类似 SFT 的样本)
        • SDPO 则强调外部评估信号的引入(数据集中只需要 Query,不需要参考答案 \(y^*\),但需要外部评估的反馈信号)
  • 对 SDPO 评价:
    • 新颖的想法:将符号化反馈转化为密集学习信号;在 Token-level 分配 Advantage ,提升学习效率
    • 训练提效:在仅返回标量反馈的标准 RLVR 环境中
      • SDPO 将成功的 Rollout 作为失败尝试的隐式反馈,同样优于基线方法
    • Test-time 提效:
      • 在 Test-time 将 SDPO 应用于单个问题,可以加速在困难的二元奖励任务上的发现过程
      • 在达到与 Best-of-\(k\) 采样或多轮对话相同的发现概率时,所需的尝试次数减少了 \(3 \times\)
    • 无需外部 Teacher :完全自监督,适用于在线学习
    • 适配性强:可作为标准 RLVR 方法的即插即用替代
    • 可能存在问题:
      • 性能依赖于模型的上下文学习能力 ,对较弱模型可能还不如 GRPO
      • 反馈质量直接影响学习效果,要重点关注

Introduction and Discussion

  • 背景 & 分析:
    • 传统 RL:对经验(行动、接收反馈、更新策略)进行迭代,可以解锁仅从静态监督中难以获得的能力 (2015; 2016; 2017; 2019)
    • 在 LLM RL:RL 显著提高了在重推理任务上的性能,尤其是在具有程序化或其他可验证评估的环境中 (2024; 2025; 2025; 2025)
    • 目前 LLM 后训练的主流 RL 方法仍然受限于信用分配
      • 大多数现有方法在 “RLVR” 设定下运行:
        • 给定一个问题 \(x\),模型采样一个答案 \(y \sim \pi_{\theta}(\cdot | x)\) 并接收一个标量奖励 \(r \in \mathbb{R}\),通常是二元的(例如,代码生成中的单元测试通过/失败)
      • 现代策略梯度 RLVR 方法(如 GRPO)从这些稀疏的结果奖励中估计 Advantage
      • 而且:当一个组中的所有 Rollout 都获得相同(通常为零)奖励时,GRPO 的 Advantage 会坍缩为零,导致学习停滞
  • 已有一些解法:
    • 一些方法从强大的 Teacher 那里进行蒸馏 (2025; 2026)
      • 可以提供密集的 Token-level 监督
      • 但在在线学习中,当目标是提升现有模型的能力上限时,强大的 Teacher 通常是不可用的
  • 本文认为关键的限制不在于 RL 本身,而在于标量结果奖励所施加的信息瓶颈
    • 许多可验证环境暴露了超出标量奖励 \(r\) 的丰富 Tokenized 反馈,例如运行时错误、失败的单元测试或来自 LLM 评判员的评估
    • 这种反馈不仅揭示了 Rollout 是否正确,还揭示了错误所在
  • 本文将这个更一般的设定形式化为 Reinforcement Learning with Rich Feedback (RLRF)
    • 图 2 中说明了其与 RLVR 的区别
    • RLRF 中,反馈可以是 Agentic 系统达到的任何状态的任何 Tokenized 表示
  • 核心问题提出:如何在不依赖外部强大 Teacher 提供监督的情况下,将丰富的反馈转化为有效的信用分配?
  • 本文出发点是观察到 LLM 已经拥有一个使用反馈的强大机制(上下文学习)
    • 当以反馈为条件时,同一个模型通常能够识别出可能的错误并提出修正方法
    • 这种反馈的一个常见例子是像 LeetCode 这样的编码平台上失败测试用例的总结(图 3)
    • 许多近期的工作利用这种返回机制来迭代生成修正 (2021a; 2023; 2023; 2024; 2025; 2025)
    • 本文使用当前的策略作为一个 “Self-Teacher”
      • 在接收到丰富的反馈后,不是采样一个新的 Response,而是重新评估现有的 Rollout
      • 将反馈包含在上下文中会改变模型的下一 Token 分布,使得 Self-Teacher 能够在特定 Token 上同意或不同意 Student 原先的选择
      • 这产生了密集的、Logit-level 信用分配
    • 当提供图 3 中的反馈时, Self-Teacher 可以识别出应如何修改最初的尝试以避免运行时错误
      • 这种机制不会产生采样的开销:只需在 Self-Teacher 以反馈增强的上下文中重新计算原始尝试的对数概率
  • 本文提出一种通过自蒸馏执行 RL 的 On-Policy 算法 SDPO
    • SDPO 从当前策略中采样 Rollout,获得丰富的环境反馈,然后最小化一个 Logit-level 蒸馏损失,该损失使当前策略的下一 Token 分布与 Self-Teacher 的分布相匹配
    • SDPO 解决了将蒸馏应用于 Online Learning 时的核心限制:缺乏更强的外部 Teacher
    • SDPO 不依赖于固定的 Teacher ,而是利用模型事后识别自身错误的能力
    • Self-Teacher 以刚刚收到的丰富反馈为条件来构建当前策略 ,提供了蒸馏的密集监督,同时保留了 On-Policy RL 的探索 Advantage
    • 表 1 总结了 SDPO 相对于 RLVR 和蒸馏基线所处的位置(这里强调 OPD 和 SDPO 的区别主要是前者需要一个 更强的 Teacher 模型)
  • 注:本文在第 6 节中包含了相关工作的全面总结
  • 注:SDPO 是一种策略梯度算法,其 Advantage 是使用 Self-Teacher 估计的
    • SDPO 只需替换掉 Advantage 项,就能对标准的 RLVR 流程进行微小改动来实现
  • 本文在三种在线 RL 设定中评估 SDPO:
    • Learning without rich feedback (§3):
      • 评估场景:不返回任何标量奖励之外反馈的标准 RLVR 环境
      • SDPO 将当前批次中采样到的成功尝试视为对同一问题上失败尝试的“反馈”
      • 使用 Qwen3-8B 和 Olmo3-7B-Instruct 作为起点,在科学推理和工具使用任务上进行了训练
      • 结果:
        • SDPO 在总体最终准确率上优于整合了近期改进的强 GRPO 基线(分别为 \(70.2%\) 和 \(66.6%\))
        • 与 GRPO 相比,SDPO 在实现更高准确率的同时,生成长度最多可缩短 \(11\times\)
          • 这表明有效的推理不一定需要冗长
    • Learning with rich feedback (§4):
      • 评估场景:来自 LiveCodeBench v6 的、带有 LeetCode 风格反馈的竞争性编程问题
      • 如图 1 所示
        • SDPO 显著优于 GRPO,达到了更高的最终准确率(分别为 \(48.8%\) 和 \(41.2%\)),并且达到 GRPO 最终准确率所需的生成次数减少了 \(4\times\)
        • SDPO 的增益随着模型规模的扩大而增长,这表明自教学的能力随着模型成为更强的上下文学习者而出现
    • Discovering novel solutions to hard tasks at test-time(§5): 评估模型在推理阶段,发现困难任务的新颖解决方案
      • SDPO 可以加速在困难的二元奖励问题上找到解决方案的过程
        • RLVR 只有在找到第一个解决方案后才会开始学习
        • 利用 SDPO 进行 Test-time 自蒸馏,这是一种模型针对单个测试问题进行特化的 Test-time 训练形式
        • 这里针对的是 LiveCodeBench 中非常困难的问题(基础模型的 pass@64 低于 0.03)
        • SDPO 将发现解决方案的速度加快了 \(3\times\)

SDPO: Self-Distillation Policy Optimization

  • SDPO 利用当前策略的上下文学习能力来进行信用分配
  • Self-Teacher 是 \(\pi_{\theta}(\cdot \mid x,f)\),以问题 \(x\) 和丰富反馈 \(f\) 为条件的当前策略(“Student”)
    • 除了 Student 最初的尝试 \(y\),\(f\) 可以包含两种关键的反馈信息:
      • 任何环境输出:例如来自代码环境的运行时错误
      • 已有解决方案:比如 \(x\) 已经在 Rollout 组中被另一个尝试解决了的样本解决方案
    • Self-Teacher \(\pi_{\theta}(\cdot |x,f)\) 应该比 Student \(\pi_{\theta}(\cdot |x)\) 有更高的准确率,因为它在上下文中看到了额外的信息
  • 这引导作者得出如下观察:

    可以将同一个策略用于两种不同的角色:作为 Student 进行最初的尝试,以及作为 Teacher 来确定行动在事后的价值

  • SDPO 反复地将 Self-Teacher 蒸馏到 Student 中
    • 给定一个问题 \(x\)
      • 首先从 Student \(\pi_{\theta}\) 中采样 Rollout,并获得相应的环境反馈
      • 然后使用 KL 散度,作为 Student 和 Teacher 下一 Token 分布之间的距离度量
        $$ \text{KL}(p| q) = \sum_{i}p(i)\log \frac{p(i)}{q(i)} $$
      • 并优化一个标准的 Logit 蒸馏损失:
        $$\mathcal{L}_{\text{SDPO} }(\theta):= \sum_{t}\text{KL}(\pi_{\theta}(\cdot |x,y_{< t})| \text{stopgrad}(\pi_{\theta}(\cdot |x,f,y_{< t}))) \tag{1}$$
        • stopgrad 运算符阻止梯度通过 Teacher 传播,防止 Teacher 向 Student 退化而忽略 \(f\)
        • Teacher 的直观作用是,基于反馈 \(f\) 通过回溯来确定 Student 的原始尝试 \(y\) 在何处以及如何出错
  • 图 4 展示了使用 Qwen3-8B 作为 Student 和 Self-Teacher 的自教学示例
  • 算法 1 中总结了 SDPO
  • 表 2 中展示了 Teacher 的报告模板
    • prompt 替换为问题
    • 一个先前由 Student 生成的样本解决方案被替换为 successful previous rollout(如果该问题有的话;否则跳过该段落)
    • environment_output 被替换为环境输出(例如,见图 3),该输出来自模型的原始尝试(如果尝试不成功且没有解决方案;否则跳过该段落)
    • 如果模型的原始尝试成功,则将此尝试作为正确的解决方案传递
    • original_response 被替换为模型的原始尝试,以便在 Self-Teacher 下重新评估其对数概率

Proposition 2.1,SDPO 的梯度

  • 本节详情参见附录 A.1
  • \(\mathcal{L}_{\text{SDPO} }\) 的梯度是
    $$\nabla \mathcal{L}_{\text{SDPO} }(\theta) = \mathbb{E}_{y\sim \pi_{\theta}(\cdot |x)}\left[\sum_{t = 1}^{|y|}\mathbb{E}_{\hat{y}_{t}\sim \pi_{\theta}(\cdot |x,y_{< t})}\left[\log \frac{\pi_{\theta}(\hat{y}_{t}\mid x,y_{< t})}{\pi_{\theta}(\hat{y}_{t}\mid x,f,y_{< t})}\cdot \nabla_{\theta}\log \pi_{\theta}(\hat{y}_{t}\mid x,y_{< t})\right]\right]. \tag{2}$$

Comparison to RLVR

  • SDPO 梯度是一个(负的)Logit-level 策略梯度,其 Advantage 是使用 Self-Teacher 估计的
    • 也就是说:SDPO 可以重用标准的 RLVR 实现,只需替换掉 Advantage 项
  • 设 \(y_{i}\) 是问题 \(x\) 的规模为 \(G\) 的 Rollout 组中的第 \(i\) 个 Rollout,那么比较 GRPO 和 SDPO,有:
    $$
    \begin{align}
    A_{i,t}^{\text{GRPO} }:&=r_{i}-\text{mean}\{r_{i}\}_{i=1}^{G}(\text{constant in }t),\\
    A_{i,t}^{\text{SDPO} }(\hat{y}_{i,t})&=\log\frac{\pi_{\theta}(\hat{y}_{i,t}\mid x,f_{i},y_{i,t})}{\pi_{\theta}(\hat{y}_{i,t}\mid x,y_{i,t})}
    \end{align}
    $$
    • GRPO 的 Advantage 仅应用于 Sampled Token \(y_{i,t}\),并且在 Rollout \(y_{i}\) 内是常数
      • 问题:这里有必要强调 Advantage 仅应用于 Sampled Token 吗?
      • 回答:有的,这里是想强调 SDPO 本身的 Advantage 是词表上所有可能的 Token 都有的
    • SDPO 的 Advantage 仅当 Student 和 Teacher 在 Token (对数概率)上完全一致时才为零
      • 在 Teacher 下更有可能出现的 Token(相对 Student),SDPO Advantage 为正
      • 在 Teacher 下不太可能出现的 Token(相对 Student),SDPO Advantage 为负
    • SDPO 可以在两个方面被视为标准 RLVR 方法的直接扩展:
      • 1)从 1-bit 反馈扩展到允许任意 Token 序列作为反馈
      • 2)利用这种丰富的反馈来估计密集的 Logit-level Advantage
  • Off-policy 实现:
    • SDPO 方法本身是与 RLVR 方法解耦的,我们可以通过 PPO 风格的裁剪重要性采样,从而将公式 (2) 中的 SDPO 梯度轻松扩展到 Off-Policy 数据
    • 详见附录 A.4

Compute time & memory

  • 与 GRPO 相比,SDPO 唯一的计算开销是需要额外计算 Self-Teacher 的对数概率,这可以有效地并行化,并且比顺序生成快得多
  • 图 5 比较了 SDPO 和 GRPO 的计算时间
    • SDPO 的计算开销相对较小 (+17.1% 或 +5.8%)
    • 注:这里使用的 Micro Batch Size 为 \(2\),通过使用更大的 Micro Batch Size 可以进一步减少计算时间
      • 这里的 MBS 对应一次累计梯度使用的 Rollout 数量
  • 朴素地计算 Student 和 Teacher 之间的 KL 散度需要在内存中保存两者的完整 Logits
    • 为避免这种情况,在 SDPO 损失中执行 Top-\(K\) 蒸馏来近似 KL 散度
      • 即:仅计算两部分:
        • 1) Student 的 Top-\(K\) Logits 和 Teacher 对应的 Logits
        • 2)一个捕捉尾概率的项
      • 参见附录 A.3
      • 通过合理选择 \(K\)(例如,\(K = 100\)),这几乎可以避免任何内存开销,同时捕获大部分信息

Stability improvements

  • 两个实际的修改可以显著增强 SDPO 的训练稳定性
    • 1)采用正则化的 Self-Teacher ,通过 Student 参数的指数移动平均或通过将当前 Teacher 与初始 Teacher 进行插值来实现(参见附录 A.2)
      • 后面会详述,这两种策略都能有效地稳定学习
    • 2)采用对称的 Jensen-Shannon 散度作为蒸馏损失
      • 已有研究表明,这种形式也能提高从外部 Teacher 进行 On-Policy 蒸馏时的稳定性 (2024)

Learning without Rich Environment Feedback

  • 标准的 RLVR 环境中的反馈仅限于标量奖励
  • 在这个场景下,SDPO 不使用标量奖励,而是将当前批次中采样到的成功尝试视为对同一问题上失败尝试的“反馈”
    • 通过将 Student 的尝试与正确的解决方案进行比较
    • Self-Teacher 可以识别出 Student 在何处出错,并提供密集的信用分配

Experimental setting

  • 评估模型尚未经过显式微调的任务:
    • Science Q&A (化学、物理、生物学、材料科学): 使用 SciKnowEval (2024a) 中的推理子集 (L3) 进行本科水平的科学推理
    • Tool use: 根据 ToolAlpaca (2023),将工具 API 规范和用户请求映射到正确的工具调用
  • 对数据进行训练/测试分割以测试领域内的泛化能力
  • 使用 Qwen3-8B (2025a) 和 Olmo3-7B-Instruct (2025) 作为初始检查点
  • 报告相对于实际训练时间(不包括初始化和验证)的 avg@16
Baselines
  • 将 SDPO 与一个改进版的 GRPO 进行比较
    • 该版本 GRPO 整合了最近的几项修改 (2025; 2026),例如非对称裁剪 (2025)、避免有偏归一化 (2025b) 以及在使用高效推理框架时校正 Off-Policy 数据 (2025)
    • 将这些修改整合到一个 GRPO 实现中,代表了一个强基线,详见附录 A.4 中的公式 (12)
  • GRPO 通过 PPO 的裁剪重要性加权支持 Off-Policy 训练 (2017)
    • 本文还报告了 On-Policy GRPO 的特例(与原始 SDPO 的超参数匹配)
  • 对于这两个基线,执行了超参数扫描,并报告了在所有目标任务上实现最高验证性能的模型的结果
  • 超参数和训练细节在附录 E 中提供
  • 使用 ver1 库 (2025) 进行快速的多 GPU 训练

Results

  • 表 3 总结了本文的结果
    • SDPO 在几乎所有运行中都优于 GRPO,通常带来显著的改进
    • SDPO 的学习速度明显快于 GRPO
      • 在几种情况下,SDPO 仅训练 1 小时就达到了 GRPO 训练 5 小时的性能
  • 如图 6(左)所示,SDPO 在化学任务上比 GRPO 取得了特别显著的改进
    • 使用 Olmo3-7B-Instruct,SDPO 在 50 分钟的实际训练时间内就达到了 GRPO 5 小时的准确率,速度提升了 \(6\times\)
    • SDPO 的 5 小时准确率比 GRPO 高出超过 \(10%\) 的百分点
  • 使用 SDPO 的结果严格采用了 On-Policy 训练(即,每个生成批次只进行一次梯度更新)
    • 考虑到 Off-Policy 方法(每个生成批次执行多次梯度更新)已知的效率 Advantage
  • 注:作者认为研究 SDPO 与 Off-Policy 更新相结合是一个令人兴奋的未来方向
Takeaway 1
  • SDPO 能够学习有效地进行推理,泛化到具有挑战性的推理任务
  • 无需对现有的 RLVR 环境进行任何修改,SDPO 在多个案例中显著优于 GRPO

Self-distillation learns to reason concisely

  • 实验观察可知,SDPO 生成的 Response 比 GRPO 短得多,同时实现了更高的准确率
    • 平均而言,SDPO 的 Response 比 GRPO 短 \(3\times\) 以上(参见附录 D 的表 8)
    • 在化学任务上使用 Olmo3-7B-Instruct,SDPO 甚至将 Response 长度减少了 \(11\times\),同时保持了更高的准确率(图 6(右))
  • 注:RLVR 上的实验表明,扩展 Response 长度是激发新兴推理能力的强大驱动力 (2024; 2025; 2025)
    • 本文的结果表明,有效的推理不一定总是冗长的(比如 SDPO 就提高了推理的效率)
  • 定性地看,GRPO 较长的 Response 通常源于“表面”的推理,而非必要的分析步骤
    • GRPO 经常生成诸如“Hmm”和“Wait”之类的填充短语,或者进入循环逻辑,逐字重复之前的步骤
    • 图 7 显示了这种现象的一个代表性例子
      • SDPO 的生成保持简洁,避免了这些表面模式
      • 这可以通过 SDPO 的密集信用分配来解释,它为每个下一 Token 预测分配一个特定的 Advantage ,从而产生稀疏的 Advantage (参见附录 F 的图 21)
      • 通过提高推理效率,SDPO 减少了推理生成时间,并表明推理性能可以通过改进模型的推理方式(而不仅仅是推理长度)来提高

Learning with Rich Environment Feedback

  • 本节在编码任务上评估 SDPO,编码是一个典型的 RL 环境示例,提供了丰富的反馈,例如运行时错误和失败的单元测试
    • 学习解决这些编码问题需要强大的信用分配,因为模型必须识别其精确的错误,以避免在未来重复这些错误
  • LiveCodeBench (2025) 提供了一组竞赛风格的编码问题,范围从简单到竞赛级别
    • 本文将评估限定在 LCB 的最新子集 LCBv6 上,该子集包含 2025 年 2 月至 5 月发布的 131 个问题
    • 本文考虑一个包含公开和私有单元测试的设置,这在代码竞赛和 LeetCode 等编码平台中很常见,其中公开测试用于训练期间的评估,私有测试用于验证 (2022; 2022; El-2025; 2025)
  • 本文在实验中使用 Qwen3 (2025a) 模型系列,除非另有说明,默认使用 Qwen3-8B
    • 本文报告 4 次 Rollout 的平均准确率,并使用与第 3.1 节中概述的相同的 GRPO 基线
  • Results
    • 图 1 比较了 SDPO 和 GRPO 在 LCBv6 上的学习曲线
    • SDPO 达到了比 GRPO 显著更高的最终准确率 \((48.8%)\),相比 GRPO 的 \((41.2%)\),同时也优于公开 LCBv6 排行榜上最强的指令模型(Claude Sonnet 4 \((40.5%)\) 和 Claude Opus 4 \((39.7%)\))
    • SDPO 达到 GRPO 最终准确率所需的生成次数减少了 \(4\times\)
  • 在附录的表 9 中包含了与其他性能与 GRPO 相似的 RLVR 基线的扩展比较
    • 本文根据 LCB 的简单、中等和困难问题进行分类,SDPO 在解决中等和困难问题方面相比 GRPO 有特别显著的改进(参见附录的图 15)

Self-distillation benefits from stronger models

  • 本文工作的一个核心问题是 SDPO 是否对基础模型的上下文学习能力敏感
    • 我们期望 SDPO 受益于强大的上下文学习器,因为这使得 Teacher 能够执行更准确的回顾
  • 为了回答这个问题,本文对 Qwen3 系列的不同模型大小进行了扩展研究
    • 正如大量先前工作所证明的,上下文学习的能力随着模型大小的增加而增强(例如,2020)
  • 如图 8 所示
    • SDPO 在较大模型上显著优于 GRPO,而在较小模型上仅略微优于 GRPO
  • 为了确定 SDPO 在比 Qwen3-0.6B 更弱的模型上是否也可能表现不如 GRPO,本文使用 Qwen2.5-Instruct (2024) 进行了额外的扩展研究
    • 在使用 Qwen2.5-7B 时 SDPO 优于 GRPO,在使用 Qwen2.5-8B 时 SDPO 和 GRPO 表现相当
    • SDPO 在 Qwen2.5-1.5B 上表现不如 GRPO,如附录 D 的图 17 所示
Takeaway 2
  • 本文实验结果表明:SDPO 相对于 GRPO 的边际改进与基础模型的强度紧密相关
    • 就像上下文学习是随着规模扩展而出现的一种现象一样
    • SDPO 中 Self-Teacher 的准确回顾能力似乎也是随着规模扩展而出现的
  • 未来建议对比 Qwen3-8B 更强的模型进行研究

Self-distillation performs dense credit assignment

  • GRPO 为每个生成的 Token 分配一个恒定的 Advantage (Sequence-level)
  • SDPO 根据 Student 和 Teacher 的一致性,为生成序列中每个位置上的每个可能的下一个 Token 分配一个单独的 Advantage (Token-level)
    • 在生成序列 \(y\) 的每个位置 \(t\),有 \(|\nu|\) 个可能的下一个 Token,其中 \(\nu\) 是词表
  • 在蒸馏中,这个层级通常被称为 Logit 层级,因为它对应于模型的 Logits
    • 在实践中,本文通过前 \(K\) 个 Token 加上 尾部来近似完整的下一个 Token 分布
    • 也就是说:SDPO 为每个序列分配 \(|y|\cdot (K + 1)\) 个唯一的 Advantage
      • 注:这里的 \(|y|\) 是当前序列的 Token 数
    • 如图 9 所示,这允许 SDPO 执行密集的信用分配
      • 在图 4 的示例中,SDPO 的密集信用分配
      • 蓝色显示的是在 Self-Teacher 下变得更有可能的 Token
      • Self-Teacher 识别出返回的 range 语句必须如何修改,以便它不包含 n
  • 一个自然的问题是,SDPO 的性能提升是由于利用了 RLRF 中的丰富反馈,还是由于 SDPO 的密集信用分配?
  • 为了回答上述问题,作者在三种配置中进行了 SDPO 性能的消融实验:
    • Logit-level SDPO : 在每个位置上对(Student 认为的)最有可能的 100 个 Token 进行信用分配
    • Token-level SDPO : 在每个位置上只对最有可能的 Token 进行信用分配
    • Sequence-level SDPO : 在整个序列上计算 Generated Token 的 SDPO Advantage 的均值
      • 这产生的是每个序列的单个标量 Advantage (如 GRPO 中那样)
      • 注:Sequence-level SDPO 并不比 GRPO 执行更密集的信用分配,但仍然利用了丰富的反馈 \(f\)
  • 如图 10(左)所示,Logit-level SDPO 的密集信用分配带来了比 Token-level SDPO 和 Sequence-level SDPO 显著的性能提升
    • 但即使 Sequence-level SDPO 也优于 GRPO
      • 这表明即使在没有密集信用分配的情况下,在 RLRF 中利用丰富的反馈也能带来比 RLRV 方法实质性的收益
  • 图 10:
    • 左图: RLRF 中的丰富反馈和 SDPO 的密集信用分配是互补的
      • 本文比较了 Logit-level、Token-level 和 Sequence-level 的 SDPO Advantage 与 GRPO
      • 虽然 SDPO 中更密集的信用分配是有益的(Logit-level > Token-level > Sequence-level),但即使 Sequence-level SDPO 也因利用了丰富的反馈而显著优于 GRPO
      • 误差线表示 3 次随机种子的标准误差
    • 右图: Self-Teacher 在训练过程中得到改进
      • Self-Teacher 与 Student 相比在当前训练批次上的生成准确率(5 步滚动平均)
        • 最终 Student 得分在第 80 步取得
      • Student 的性能显著超过了初始 Teacher 的准确率
      • 误差线表示 3 次随机种子的标准差

The self-teacher improves during training

  • 与标准的蒸馏相反,SDPO 中的 Self-Teacher 不是冻结的,而是在整个训练过程中更新
    • 这是 SDPO 的一个关键组成部分,因为它使 Teacher 能够随着时间的推移而改进,这意味着 Student 可以从更强的目标中学习
  • 为了研究 Self-Teacher 是否在训练过程中得到改进,本文在图 10(右)中绘制了使用 Self-Teacher 生成时的平均准确率
    • Self-Teacher 在训练过程中显著提高
  • 最值得注意的是, Student 的在训练的后期阶段, Student 的准确率超过了初始 Teacher 的准确率。这表明 SDPO 能够实现从弱模型到强模型的真正的自举,而不受初始 Self-Teacher 性能对最终 Student 的限制
  • 表 4: 第 90 步之前各种 Teacher 正则化方法的最佳/平均准确率
    • Trust-region 和 EMA Teacher 使用 \(\alpha = 0.01\)。\(q_{\theta}\) 的训练最终发散
    • 误差范围表示 3 次随机种子的标准误差
  • 如第 2.3 节所述,SDPO 使用正则化的 Teacher 来稳定训练
    • 从表 4 可以看出,未正则化的 Teacher 性能显著低于正则化的 Teacher
    • 信任区域 (Trust-region) 和 EMA Teacher 的表现优于冻结在初始 Teacher 参数上的 Teacher ,这表明 Teacher 通过与 Student 的参数共享得到了改进
    • 即使使用冻结的 Teacher ,SDPO 也表现良好

On-policy self-distillation avoids catastrophic forgetting,OPSD 防止了灾难性遗忘

  • 先前的工作表明,On-policy 算法(如 GRPO)的一个关键好处是模型往往不会忘记先前获得的能力 (2026b; 2025b; 2025)
    • 这在实践上是可取的,因为它支持持续训练流程,即模型可以在不同任务上顺序训练,而无需从头开始重新训练
  • 为了评估遗忘,本文在各种保留任务上测试了 GRPO 和 SDPO 的最终检查点:
    • IFEval (2023):测试模型遵循精确格式指令的能力
    • ArenaHard-v2 (2025a):基于 LLM 评判的真实世界指令跟随 Prompt 的基准,源自 LMArena (2024)
    • MMLU-Pro (2024b):测试广泛的多任务知识和推理能力
  • 如表 5 所示,SDPO 在学习新任务的同时减轻了初始能力的退化,总体上实现了比 GRPO 更好的性能-遗忘权衡
    • 注:这里看起来 SDPO 相对 GRPO 的提升不太明显
Off-policy self-distillation baseline
  • 作为一个额外的基线,考虑通过 SFT 来训练 Student ,使用来自 Self-Teacher 的成功生成 (2023; 2024; 2025)
    • 对于相同数量的步骤,这需要 \(2\times\) 的 SDPO 生成次数,因为必须同时从 Student 和 Teacher 生成
  • 本文报告了对 Self-Teacher 的成功生成进行 SFT 的结果
    • 这比在 SFT 数据中同时包含来自 Student 的初始成功取得了更高的准确率
    • 如表 5 所示,对 Self-Teacher 进行 SFT 在 LCBv6 上的表现显著不如 SDPO,同时导致对先前能力的遗忘更严重
      • 这反映了先前关于 Off-policy 模仿不稳定的发现(例如,2024)

Can GRPO and SDPO be combined?

  • GRPO 使用蒙特卡洛 Advantage ,这对于最大化期望奖励的目标是无偏的:
    $$J(\theta) := \mathbb{E}_{y\sim \pi_{\theta}(\cdot |x)}[r(y|x)]$$
  • SDPO Advantage 相对于 \(J(\theta)\) 本质上是有偏的,因为它们是由丰富的反馈和 Self-Teacher 计算得出的
  • 这种二分法类似于 RL 中蒙特卡洛和 Bootstrapped Advantage 之间的根本区别:
    • 虽然后者有偏,但它们通常具有更低的方差 (1998; 2016)
    • 这激发了一种混合方法,结合了从奖励派生的 GRPO Advantage 和从反馈派生的 SDPO Advantage :
      $$A_{i,t}^{\text{SDPO + GRPO} }(\hat{y}_{i,t}) := \lambda A_{i,t}^{\text{GRPO} }(\hat{y}_{i,t}) + (1 - \lambda) A_{i,t}^{\text{SDPO} }(\hat{y}_{i,t}), \quad \lambda \in [0,1]. \tag{3}$$
  • 如图 11 所示,SDPO+GRPO 在较弱模型上似乎比 SDPO 更鲁棒
    • 在像 Qwen3-0.6B 这样的较弱模型中,SDPO Advantage 的可靠性较低,因此包含 GRPO Advantage 有助于稳定训练
    • 在像 Qwen3-8B 这样的强模型中,SDPO+GRPO 略逊于 SDPO
    • 这表明,仅由标量奖励提供信息的 GRPO 信号,在初始模型较强时可能是有害的

Which feedback is most informative?

  • 为了理解哪种类型的丰富反馈信息最丰富,本文消融了在像代码生成这样的可验证环境中存在的三种反馈类型:
    • 样本解决方案:仅在当前 Rollout 组中有一个成功的 Rollout 时可用
    • 环境输出:例如运行时错误
    • Student 的原始尝试
      • 问题:Student 的原始尝试没有任何反馈要怎么用?
      • 回答:这里只是写出来作为验证,下文会证明这个做法没有用,而且是有害的
Sample solutions
  • 在失败尝试的 Rollout 组中包含一个样本解决方案,这与 GRPO 的组相对 Advantage 非常相似
    • 注:这些样本解决方案总是由 Student 生成的,就像在 GRPO 中一样,并且不需要专家模型
  • 如果模型已经能够解决问题,可以用来抑制不成功的方法
    • (但与 GRPO 中所有 Token 都收到相同负 Advantage 不同) Self-Teacher 可以识别特定的错误并提供关于如何修复它们的反馈
Environment output
  • 环境输出描述了 Student 尝试后的环境状态
  • 这与样本解决方案是互补的,因为即使 Student 以前从未解决过该问题,它也能提供有用的信号(作者将在第 5 节中广泛探讨这种设置)
  • 利用环境输出是 RLRF 和 RLVR 设置之间的一个关键区别因素
Student‘s original attempt
  • Student 的原始尝试 \(y\) 不必包含在 Teacher 的 Reprompting 模板中
  • 事实上,包含它会使 Teacher 偏向于 Student 的尝试(参见表 6)
    • 这降低了 Student 分布的熵(特别是在最初不确定的 Token 上),从而减少了探索
  • 表 6 中总结了结果,本文评估了对 SDPO 训练的影响以及对 Self-Teacher 的直接影响
    • 本文发现环境输出和样本解决方案是互补的,各自提供信息丰富的反馈
    • 注:通常,性能对表 2 中 Reprompting 模板的句法变化不敏感
  • 表 6 的说明:
    • 本文根据 SDPO 训练(直到第 60 步)以及对 Self-Teacher 的直接影响来评估反馈的信息量
    • “相同输出” 衡量:Teacher 收到与 Student 初始尝试相同的环境输出的情况百分比(即,不探索替代方法)
      • 问题:这句话如何理解?

        “Same output” measures the percentage of cases where the teacher receives the same environment output as the student’s initial attempt (i.e., not exploring alternative approaches).

      • 环境输出和样本解决方案是互补的,各自提供信息丰富的反馈
    • 仅包含解决方案或初始尝试 \(y\) 会显著降低 Teacher 和 Student 的多样性
    • 样本解决方案是由 Student 生成的,这使得类似于 GRPO 的组相对 Advantage 估计成为可能
    • 误差线表示 3 次随机种子的标准差

Solving Hard Questions via Test-Time Self-Distillation

  • 第 3 节和第 4 节已证明了 SDPO 能够在推理任务的 “train-time RL” 中显著优于 RLVR 方法
  • 本节转向一个 Test-time 设定,其中模型仅被给予一个困难的(二元奖励)问题 \(x\) ,并且必须尽快发现一个解:
  • Definition 5.1(Discovery time)
    • 发现时间是直到找到解所需的尝试次数(即最小的 \(k\) ,使得第 \(k\) 次尝试 \(y_{k}\) 获得奖励 1)
  • 基于这个概念,我们可以定义发现效率的度量:
    $$
    \begin{align}
    \text{discovery@k}:&= \mathbb{P}(\text{discovery time}\leq k)\\
    &= \mathbb{P}(r(y_1\mid x) = 1\text{ or }r(y_2\mid x) = 1\text{ or }\ldots \text{ or }r(y_k\mid x) = 1)
    \end{align}
    \tag{4}
    $$
    • 其中概率覆盖了生成 \(y_{k}\) 的算法以及奖励中的所有随机性
    • discovery@k 度量了在 \(k\) 次内发现解的概率
  • 原文中一句难以理解的话:

    While prior work has studied discovery with continuous rewards (2025;2026), discovery with language models in sparse or binary-reward settings does not allow “hill-climbing” a continuous reward and has remained less well understood.

    • 个人理解:这句话的意思是,在不同环境中,模型学习(Discovery)过程不同:
      • 在连续奖励环境 中,模型可通过逐步 优化(“Climbing”) 奖励梯度来逼近最优解
      • 在稀疏奖励 或二元奖励 的环境中,语言模型的 “Discovery” 过程无法像在连续奖励环境 中逐步 Climbing
  • 在二元奖励任务中最朴素发现方法是重复地从基础模型中独立同分布地采样,也称为 best-of-\(k\)
    • 用于 best-of-\(k\) 采样的标准 pass@k 指标正是从固定模型中进行 \(k\) 次独立采样时至少发现一个解的概率(与 discovery@k 一致)
    • discovery@k 指标将 pass@k 推广到顺序采样尝试的算法
    • 一种常见的顺序方法是使用来自先前尝试的额外上下文来 Reprompting 基础模型(2023;2023)
      • 这可称为多轮采样(multi-turn) ,此时模型本身并未改变,只有其上下文随时间演化
  • 论文中一个核心的 Insight:
    • 对问题 \(x\) 执行 RLVR 并不会比从基础模型进行 best-of-\(k\) 采样有所改进
      • 因为在第一个解被发现之前,二元奖励不提供任何信号
    • 像 SDPO 这样的 RLRF 方法则没有同样的限制,因为它在每次尝试后都会从环境中接收到 Rich Feedback
      • 这种 Rich Feedback 使得模型能够在遇到错误并收到反馈后反复“修正”其错误,甚至在发现解之前就能做到
      • 与多轮采样相比,SDPO 通过将 \(\pi_{\theta}(\cdot | x, c)\) 蒸馏到模型 \(\pi_{\theta^{\prime} }(\cdot | x)\) 中,反复压缩上下文
        $$c = (y_{k}, f_{k})$$
        • 如图 12 所示
      • 这种自蒸馏使得 SDPO 能够在长上下文中持续学习,而 Transformer 的内存瓶颈从根本上限制了多轮采样的上下文长度(2017)
  • 本节试图回答以下问题:通过自蒸馏反复将上下文压缩到模型权重中能否加速困难问题的发现?

Experimental setting

  • 本文考虑 LCBv6 中一个特别具有挑战性的子集
    • 这些问题是 Qwen3-8B 性能的上限,需要大量的 Test-time 采样才能找到解
  • 本文使用 Qwen3-8B 的 pass@k 定义了两组任务:
    • 困难任务 (Hard tasks),其 pass@64 < 0.5
    • 非常困难的任务 (very hard tasks),其 pass@64 < 0.03
  • 在这些任务中,仅保留那些 best-of-\(k\)、多轮或 SDPO 方法在 5 个随机种子的 512 步内至少找到一个解的问题
    • 这产生了 19 个困难问题和 9 个非常困难的问题
    • 问题:这里的 512 步是指 512 轮交互吗?
  • 对于基础模型下的 best-of-\(k\) 采样
    • 本文报告来自 2944 次独立 Rollout 的标准 pass@k 估计值 (2021b)
  • 对于多轮采样下,使用先前尝试的串联反馈在上下文中顺序地 Reprompting 模型
    • 为了保持在 Qwen3-8B 的 40k Token 上下文限制内,本文采用先进先出的滑动窗口,一旦达到最大提示长度(32k Token),就丢弃最早的反馈
  • 附录 D 的图 19 中消融了多轮 Reprompting 策略
    • 发现:遗忘过去尝试(仅保留过去反馈) 显著优于 保留过去尝试(过去尝试 + 过去反馈)
  • 本文使用 16 的 Batch size 评估 SDPO,并在附录 D 的图 19 中消融了这个选择
    • 发现:总体性能差异很小,但有点差异
      • 较小的 Batch size 有利于在低生成预算下取得改进
      • 较大的 Batch size 能产生更稳定的更新,这些更新在运行的后期阶段仍然能够学习解决问题

Results

  • 图 13 比较了 SDPO、多轮采样和 best-of-\(k\) 采样在 LCBv6 的非常困难(左)和困难(右)问题上的 discovery@k
    • 在几乎所有生成预算下,SDPO 都实现了显著更高的 discovery@k 比率
  • 在非常困难的任务上:
    • 多轮和 best-of-\(k\) 在可用的生成预算内很大程度上未能解决问题
      • discovery@2750 分别仅为 \(35.6%\) 和 \(41.5%\)
    • SDPO 表现最好
      • 在 \(53.2%\) 的情况下发现了解
    • SDPO 不仅在总体上解决了更多问题,而且使用的尝试次数也少得多
    • 为了在非常困难的问题上达到 \(22%\) 的发现概率,SDPO 所需的生成次数比 best-of-\(k\) 和多轮采样少大约 \(3\times\)
  • 在困难任务上
    • SDPO 达到了 \(78%\) 的 discovery@2750 概率
      • 且以比 best-of-\(k\) 和多轮采样少大约 \(2.4\times\) 的生成次数实现了 \(67%\) 的发现概率
    • 多轮和 best-of-\(k\) 采样分别只解决了 \(68.4%\) 和 \(72.3%\) 的问题
    • 多轮采样的上下文窗口长度在困难问题的 837 (\(\pm 466\)) 步后和非常困难问题的 1007 (\(\pm 349\)) 步后达到极限
      • 这为其在高生成预算下收益递减提供了一个可能的解释
Question 3 is only solved by SDPO
  • SDPO 解决了所有被 best-of-\(k\) 和多轮采样解决的问题,且只有 SDPO 能为 Q3 发现了一个解
    • 该问题在 2750 次尝试内既无法通过多轮采样也无法通过 best-of-\(k\) 采样解决
  • SDPO 在 321 次尝试后首次为 Q3 发现了一个解
    • 这对应于使用 16 的 Batch size 进行 20 次基于反馈的自蒸馏迭代步骤
  • 在附录 D 的表 10 中包含了详细的每个问题的结果
The initial self-teacher does not solve hard questions
  • 对于几乎所有问题, Self-Teacher 的初始准确率都 \(< 1%\) ,甚至在 \(78%\) 的问题上精确为 \(0%\) (附录 D 的表 11)
    • 这表明单轮上下文中的反馈不足以解决问题,但 Self-Teacher 的信用分配对于 SDPO 迭代地精炼策略并最终解决这些问题来说已经足够有效
Takeaway 3
  • 本文证明了丰富的环境反馈使 SDPO 能够显著加速困难问题的发现
  • 对比 RLVR:RLVR 方法仅接收二元奖励信号,因此只有在第一个解已经被找到后才能开始学习

Related Work

Reinforcement Learning with LLMs

  • LLM 在 RL 上的进展主要得益于使用奖励的蒙特卡洛估计的 RLVR 方法,例如 STaR 或 GRPO (2022;2024),类似于经典的 REINFORCE 算法 (1992)
    • 几种传统的 RLVR 算法依赖于学习独立的价值网络 (2017):
      • 会带来显著的内存成本
      • 存在标量奖励的信息瓶颈
  • RLVR 设定通常(结果)奖励仅在序列结束时给出
    • 为了改进信用分配,一些工作学习了过程奖励模型 (PRM),用于估计序列中每一步的奖励 (2023;2024a;2025)
    • 与本文 RLRF 设定不同,PRM 通常在标量奖励上进行训练,要么基于中间状态的价值估计,要么基于结果奖励 (2025)
    • 与 SDPO 中的 Self-Teacher 也不同,PRM 是与 Student 不同的模型,引入了显著的内存开销
  • 本文工作表明,如果给予 Rich Feedback,每个语言模型通过 retrospection 隐式地就是一个 PRM
  • 从概念上讲,本文工作与 “bootstrapping your own latent”(BYOL;2020) 和 “expert iteration”(2017) 相关
    • 这些方法中 Student 通过反复模仿自身的改进版本(称为 “expert”)来引导自身
      • 通常,expert 将 Student 与 Test-time 搜索(如树搜索 (2017) 或多数投票 (2025))相结合
    • 本文的 SDPO 利用 Student 从上下文中提供的 Rich Feedback 中学习的能力,这与 BYOL 中的 “augmented views” 相关

Learning from Rich Feedback and through Retrospection,从 Rich Feedback 和通过 Retrospection 学习

  • 除了标量的结果奖励之外,近期工作还利用丰富的执行或口头反馈来指导生成 (2025;2024b;2025)
    • 一个主要的研究方向集中在将口头反馈转化为 RL 的奖励函数
    • 这通常通过使用外部冻结模型将反馈映射到离散的 Token-level 奖励 (2026),或通过使用强大的外部 LLM 显式构建状态级奖励函数来实现 (2019;2024;2026)
  • 也可以在没有显式奖励建模的情况下利用反馈
    • 有几种方法专注于上下文中的改进,而没有将这个过程整合到 RL 优化循环中 (2021a;2023;2023;2024;2025;2025)
    • 其他方法通过配对反馈前后的 Response 来手动构建偏好数据集,以便使用直接偏好优化进行训练 (2024;2024)
      • 但这需要额外的生成,并且缺乏 SDPO 的直接信用分配
    • 最近的几项工作从已知答案中引导思考轨迹,使用这些答案作为 Rich Feedback (2026;2026;2025)
  • 近期几项工作的核心对象是反馈条件策略 \(\pi_{\theta}(y \mid x, f)\) ,学习能够导致反馈 \(f\) 的答案 \(y\) (2023;2023;2025),通常通过监督目标来实现
    • 这些方法背后的思想是部署一个以期望的(即正面的)反馈为条件的策略
    • 这种方法在概念上与目标条件 RL (2015;2025a) 相关,在该方法中,可以通过目标重新标记 (2017) 从负例中学习
    • 区别:
      • 反馈条件策略将反馈视为一个目标
      • RLRF 将反馈视为一个状态,可用于确定目标 \(x\) 是否实现
    • 总结:与 SDPO 不同,这些方法不使用反馈来对负面轨迹进行信用分配,而是作为目标重新标记的数据转换

Distillation

  • 当强大的 Teacher 模型可用时,蒸馏经常被用作 SFT 的替代方案
    • 蒸馏通过训练 Student 模仿 Teacher 的输出分布或中间表示来迁移能力 (2015;2015;2016;2019;2020)
    • 蒸馏通常是在固定的 Off-Policy 数据集上进行的
  • 为了解决训练和推理之间的分布偏移,近期的工作探索了 On-Policy 蒸馏
    • Student 从外部 Teacher 对 Student 生成结果的反馈中学习 (2024;2024;2025a;2025)
    • 这减轻了训练-测试不匹配的问题,与早期关于在线模仿学习的工作密切相关 (2011)

Self-Distillation

  • 自蒸馏的概念首先由 Snell 等人 (2022) 在类似于监督学习的设定中提出,引入了从提供了额外上下文的模型中采样 ,并训练同一模型在没有该上下文的情况下模仿这些预测 的思想
    • 这种机制已被证明能有效地将行为 (2022;2022;2024;2025b) 和事实信息 (2026;2025;2025a) 压缩到模型权重中
  • 除了将固定上下文压缩到模型权重之外,近期的工作还使用自蒸馏从环境反馈中学习 (2023;2024;2025;2025;2026)
    • 这些方法使用 Off-Policy 自蒸馏目标
    • 本文发现其性能显著低于 SDPO 的 On-Policy 学习
      • Off-Policy 自蒸馏训练 Student 在 Teacher 的生成结果上进行学习
      • SDPO 则训练 Student 避免在其自身生成中犯错
  • 在同时进行的工作中
    • Chen 等人 (2025c) 将 On-Policy 自蒸馏应用于反馈为标量奖励的网格世界环境,并在 Self-Teacher 中设置了一个反思阶段来诊断可能的错误
      • · 09o展示了与学习价值网络进行 Advantage 估计相比有所改进的信用分配
    • 其他同时进行的工作研究了在固定专家演示数据集上的 SDPO,没有在线环境交互 (2026a;2026)

Limitations & Future Work

Limitations

  • SDPO 的性能依赖于模型的上下文学习能力
    • SDPO 主要适用于 RL 训练较强的基础模型,在较弱的模型上可能不如 GRPO
  • SDPO 性能取决于环境反馈的质量
    • 如果环境提供的信息不足或具有误导性,模型可能无法通过 SDPO 从中学习
  • 与 GRPO 相比,SDPO 在计算回顾性模型的 Logit 概率时增加了少量的计算开销
    • 通常可以忽略不计,但对于生成长度较短(生成时间相对较小)的小型模型来说,这可能是一个较大的开销
      • 注:此时相对成本会增加

Future Work

Long-horizon and agentic settings
  • 当轨迹很长或暴露关于中间状态的信息时,RLRF 尤其具有吸引力
  • 在 Agentic 环境中 SDPO 可能会有收益
Training dynamics at scale
  • 除了本文在 LiveCodeBench 上的评估之外,可考虑将 SDPO 扩展到大规模的多任务 RL 训练运行,并进一步研究其在前沿基础模型上的扩展特性
Beyond verifiable rewards
  • 本文专注于可验证的代码生成,但许多任务提供文本反馈却没有一个 ground-truth 验证器
  • 后续可考虑研究 SDPO 的 retrospection 机制能否应用于 开放式文本生成 或 连续奖励任务
Behavioral differences in reasoning
  • 实验中看到 SDPO 诱导出与 GRPO 本质上不同的推理模式
    • SDPO 避免了 GRPO 倾向于冗长和表面化推理的倾向
  • 未来的工作可以系统地研究各个方面的因素(例如 Reprompting 模板)如何影响行为

附录 A:Implementation of SDPO

  • Figure 14: The pseudo-code of SDPO within a standard RL training pipeline. Omitted here is the filtering to top-K logprobs for student and teacher (including a tail term) as described in Appendix A.3. Further, we omit here any importance sampling weights to correct for off-policy data. reprompt modifies the batch to incorporate teacher context (i.e., rich feedback). divergence implements any per-token divergence such as reverse-KL, forward-KL, or Jensen-Shannon.

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    def compute_sdpo_loss(batch, teacher_context, loss_mask):
    """
    Computes probabilities of response y under the self-teacher
    and the per-logit SDPO loss.
    """
    # Compute model probabilities for response y
    logprobs_student = compute_log_prob(batch) # (T,V)
    probs_student = logprobs_student.exp() # (T,V),注:实际上,这一样没有用

    # Compute self-teacher probabilities for response y
    teacher_batch = reprompt(batch, teacher_context)
    logprobs_teacher = compute_log_prob(teacher_batch).detach() # (T,V)

    # Compute SDPO loss: per-token divergence
    per_token_loss = divergence(logprobs_student, logprobs_teacher) # (T,)
    return agg_loss(per_token_loss, loss_mask, loss_agg_mode="token-mean")
  • 附录 A 的下文内容包括:

    • 本文实现中使用的梯度估计器 (附录 A.1)
    • Teacher 正则化 (Teacher regularization) (附录 A.2)
    • 通过 top-K logits 近似 logit 蒸馏以节省 GPU 内存 (附录 A.3)
    • 将 PPO 风格的策略梯度算法泛化到 Logit-level Advantage (附录 A.4)
  • 为了区分 Self-Teacher 的符号,下文中使用
    $$ q_{\theta}(\cdot |x,f):= \pi_{\theta}(\cdot |\text{reprompt}(x,f))$$

    • 这里,reprompt 表示 Self-Teacher 的重提示模板

A.1 Gradient Estimators

  • 本节讨论当前策略 \(\pi_{\theta}(y|x)\) 和 Teacher 策略 \(q_{\theta}(y|x,f)\) 之间 KL 散度的两种可能的梯度估计器
Per-token estimator
  • 推导公式 (1) 定义的 SDPO 损失的梯度:
    $$\mathcal{L}_{\text{token} }(\theta):= \mathbb{E}_{y\sim \text{stopgrad}(\pi_{\theta}(\cdot |x))}\left[\sum_{t = 1}^{T}\text{KL}(\pi_{\theta}(\cdot |x,y_{< t})||\text{stopgrad}(\pi_{\theta}(\cdot |x,f,y_{< t})))\right] \tag{5}$$
  • 得到以下估计器(详细证明见附录 B.1),它对应于每个 token 处 KL 散度的梯度之和:
    $$\nabla \mathcal{L}_{\text{token} }(\theta) = \mathbb{E}_{y\sim \pi_{\theta}(\cdot |x)}\left[\sum_{t = 1}^{T}\mathbb{E}_{\hat{y}_{t}\sim \pi_{\theta}(\cdot |x,y_{< t})}\left[\nabla_{\theta}\log \pi_{\theta}(\hat{y}_{t}\mid x,y_{< t})\cdot \log \frac{\pi_{\theta}(\hat{y}_{t}\mid x,y_{< t})}{\pi_{\theta}(\hat{y}_{t}\mid x,f,y_{< t})}\right]\right]. \tag{6}$$
  • 这对应于命题 2.1 中提出的估计器
    • 该梯度估计器有效地假定了生成 \(y\) 的采样分布是固定的
Sequence-level estimator
  • 另一种自蒸馏目标是最小化 Student 和 Self-Teacher 之间的 Sequence-level KL 散度,即:
    $$\begin{align} \mathcal{L}_{\text{seq} }(\theta):&= \text{KL}(\pi_{\theta}| q_{\theta}) = \mathbb{E}_{y\sim \pi_{\theta}(\cdot |x)}\left[\log \frac{\pi_{\theta}(y\mid x)}{q_{\theta}(y\mid x,f)}\right]\\
    &= \sum_{t = 1}^{T}\mathbb{E}_{s_{t}\sim \Pi_{\theta} }\left[\text{KL}(\pi_{\theta}(\cdot \mid s_{t})| q_{\theta}(\cdot \mid s_{t},f))\right] \end{align} \tag{7}$$
    • \(s_{t} = (x,y_{< t})\) 是前缀(“状态”)
    • \(\Pi_{\theta}\) 表示策略 \(\pi_{\theta}\) 下的前缀分布
      • 理解,即zhuangtai状态 \(s_t\) 的分布
  • 估计该目标的梯度还会考虑 \(y_{t}\) 的选择如何影响未来的状态 \(y_{> t}\)(由于对 \(\Pi_{\theta}\) 的额外依赖性)
  • Amini 等人 (2025) 表明相应的梯度估计器由下式给出:
    $$\pmb {\nabla}\mathcal{L}_{\text{seq} }(\theta) = \pmb {\nabla}\mathcal{L}_{\text{token} }(\theta) + \mathbb{E}_{y\sim \pi_{\theta}(\cdot |x)}\left[\sum_{t = 1}^{T}\text{KL}(\pi_{\theta}(\cdot \mid s_{t})| q_{\theta}(\cdot \mid s_{t},f))\pmb {\nabla}_{\theta}\log \Pi_{\theta}(s_{t})\right]. \tag{8}$$
  • Sequence-level 梯度的额外项捕捉了前缀如何影响未来 token 的自蒸馏散度
  • 本文也实验了这种 Sequence-level 梯度估计器,但没有发现相对于其额外复杂性的可衡量的收益

A.2 Regularized teacher

  • 与标准蒸馏不同,SDPO 中的 Teacher 在训练过程中会发生变化
  • 这种自举 (bootstrapping) 使得 Teacher 能够改进,但也可能导致训练不稳定
  • 为了稳定训练,本文尝试防止 Teacher \(q\) 快速偏离初始 Teacher \(\color{red}{q_{\theta_{\text{ref} } }}\)
  • 可以通过对 \(q\) 施加一个明确的信任域约束 (trust-region constraint) (2015; 019) 来实现这一点,即:
    $$\sum_{t}\text{KL}(q(y_{t}\mid x,f,y_{< t})| \color{red}{q_{\theta_{\text{ref} } }}(y_{t}\mid x,f,y_{< t}))\leq \epsilon ,\quad \epsilon >0. \tag{9}$$
  • 这个信任域可以通过两种方式实现:
    • 1)Explicit trust-region: 将 Teacher 定义为在满足信任域约束的同时最接近 \(q_{\theta}\) 的策略,这个 Teacher 可以表示为
      $$q(y_{t}\mid x,f,y_{< t})\propto \exp ((1 - \alpha)\log \color{red}{q_{\theta_{\text{ref} } }}(y_{t}\mid x,f,y_{< t}) + \alpha \log q_{\theta}(y_{t}\mid x,f,y_{< t})), \tag{10}$$
      • 其中 \(\alpha \in (0,1)\) 是信任域约束的逆拉格朗日乘子
      • 本文在附录 B.2 中包含了完整的推导
      • 可以将这个显式约束的 Teacher 直接代入 SDPO 目标
    • 2)Exponential moving average, EMA: 直接稳定 Teacher 的参数
      • 将 \(q_{\theta}\) 参数化为 \(\theta^{\prime}\) 并更新为
        $$ \theta^{\prime}\leftarrow (1 - \alpha)\theta^{\prime} + \alpha \theta$$
        • 其中 \(\alpha \in (0,1)\)
  • 每种实现都有不同的实际 Advantage :
    • EMA Teacher 需要额外的 GPU 内存来存储 \(\theta^{\prime}\)
      • 但不引入任何运行时间开销
    • 信任域 Teacher 需要额外的 log-prob 计算与 \(\color{red}{q_{\theta_{\text{ref} } }}\)
      • 但如果 \(\theta_{\text{ref} }\) 用于显式的 KL 正则化 ,则不需要额外的 GPU 内存

A.3 Approximate Logit Distillation

  • 为了节省 GPU 内存,本文仅对 Student 预测的 top-\(K\) 个 token 进行蒸馏:
    $$\begin{aligned}
    \mathcal{L}_{\text{SDPO} }(\theta) &= \sum_{t=1}^{T} \text{KL}(\pi_{\theta}(\cdot | x, y_{ < t}) | \text{stopgrad}(q_{\theta}(\cdot | x, f, y_{ < t}))) \\
    &\approx \sum_{t=1}^{T} \left[ \sum_{\hat{y}_t \in \text{top}_K(\pi_{\theta})} \pi_{\theta}(\hat{y}_t | x, y_{ < t}) \cdot \log \frac{\pi_{\theta}(\hat{y}_t | x, y_{ < t})}{\text{stopgrad}(q_{\theta}(\hat{y}_t | x, f, y_{ < t}))} \\
    \quad \quad \quad + \underbrace{\left(1 - \sum_{\hat{y}_t \in \text{top}_K(\pi_{\theta})} \pi_{\theta}(\hat{y}_t | x, y_{ < t})\right) \cdot \log \frac{1 - \sum_{\hat{y}_t \in \text{top}_K(\pi_{\theta})} \pi_{\theta}(\hat{y}_t | x, y_{ < t})}{\text{stopgrad}\left(1 - \sum_{\hat{y}_t \in \text{top}_K(\pi_{\theta})} q_{\theta}(\hat{y}_t | x, f, y_{ < t})\right)}}_{\text{tail}} \right]
    \end{aligned}
    \tag{11}
    $$
    • top-\(K\) 是相对于 Student 的
  • 如果没有 top-K 蒸馏,我们将不得不在内存中保留两份 logits 副本:一份给 Student ,一份给 Teacher
    • Top-K 蒸馏避免了几乎所有的内存开销,且不会显著影响性能,因为在每个 Token 点(Step \(t\)),词表中的大部分 token 在并不包含信息

A.4 Off-Policy Training: Generalization to Logit-Level Losses, 泛化到 Logit-level

  • 下面的 Token-level Off-policy 损失函数为:
    $$\mathcal{L}_{\text{token} }(\theta) := -\color{red}{\frac{1}{\sum_{i=1}^{G} |y_i|}} \sum_{i=1}^{G} \sum_{t=1}^{|y_i|} \color{red}{\min \left( w_{i,t}^{\text{TIS} }, \rho \right)} \min \left( w_{i,t} A_{i,t}, \text{clip}(w_{i,t}, 1 - \varepsilon_{\text{low} }, 1 + \color{red}{\varepsilon_{\text{high} }}) A_{i,t} \right), \tag{12}$$
    • PPO 风格的裁剪 (clipping) (2017) :
      $$ \color{red}{\frac{1}{\sum_{i=1}^{G} |y_i|}} $$
    • 截断重要性采样 (truncated importance sampling) (2025):
      $$ \color{red}{\min \left( w_{i,t}^{\text{TIS} }, \rho \right)} $$
      • 其中:
        $$ w_{i,t}^{\text{TIS} } := \frac{\pi_{\theta_{\text{old} } }(y_{i,t}|x,y_{i,<t})}{\pi_{\text{rollout} }^{\theta_{\text{old} } }(y_{i,t}|x,y_{i,<t})} $$
    • clip-higher (2025) 和固定长度归一化 (2025b):
      $$ \color{red}{\varepsilon_{\text{high} }} $$
    • 其中 \(A_{i,t}\) 表示每个 token 的 Advantage ,且
      $$
      \begin{align}
      w_{i,t} &:= \frac{\pi_{\theta}(y_{i,t}|x,y_{i,<t})}{\pi_{\theta_{\text{old} } }(y_{i,t}|x,y_{i,<t})} \\
      \end{align}
      $$
  • 将其扩展到 Logit-level 损失:
    $$
    \begin{align}
    \mathcal{L}_{\text{logit} }(\theta) := -\color{red}{\frac{1}{\sum_{i=1}^{G} |y_i|}} \sum_{i=1}^{G} \sum_{t=1}^{|y_i|} \color{blue}{\sum_{\hat{y}_{i,t} }} &\color{red}{\min \left( \pi_{\theta_{\text{old} } }(\hat{y}_{i,t} | x, y_{i,<t}), \rho \pi_{\text{rollout} }^{\theta_{\text{old} } }(\hat{y}_{i,t} | x, y_{i,<t}) \right)}\\
    \cdot &\min \left( w_{i,t}(\hat{y}_{i,t}) A_{i,t}(\hat{y}_{i,t}), \text{clip}(w_{i,t}(\hat{y}_{i,t}), 1 - \varepsilon_{\text{low} }, 1 + \color{red}{\varepsilon_{\text{high} }}) A_{i,t}(\hat{y}_{i,t}) \right)
    \end{align}
    \tag{13}$$
    • 其中 \(\hat{y}_{i,t}\) 对位置 \(t\) 处 Rollout \(i\) 的所有可能 token(或 \(\pi_{\theta_{\text{old} } }\) 下最可能的 \(K\) 个 token)求和
      • \(\pi_{\theta_{\text{old} } }\) 下最可能的 \(K\) 个 token 详情参见附录 A.3
    • TIS (Truncated Importance Sampling) 发生了变化
      $$\color{red}{\min \left( \pi_{\theta_{\text{old} } }(\hat{y}_{i,t} | x, y_{i,<t}), \rho \pi_{\text{rollout} }^{\theta_{\text{old} } }(\hat{y}_{i,t} | x, y_{i,< t}) \right)} $$
      • 这里明确地使用 \(\pi_{\theta_{\text{old} } }\) 下的概率来加权每个 logit,而不是依赖对下一个 token 预测期望的蒙特卡洛估计
      • 问题:这里 TIS 为什么相对前面 Token-level 的式子同时乘以了 \(\pi_{\text{rollout} }^{\theta_{\text{old} } }(\hat{y}_{i,t} | x, y_{i,< t})\) ?
        • 采样使用的是 \(\pi_{\text{rollout} }^{\theta_{\text{old} } }\),这本身就相当于乘过 \(\pi_{\text{rollout} }^{\theta_{\text{old} } }\) 了吧
    • 这里 \(A_{i,t}(\hat{y}_{i,t})\) 是每个 logit 的 Advantage
  • 注:在本文对 SDPO 的实验中,在 Token-level 而非 Logit-level 应用 TIS 项

附录 B:Theoretical Analysis

  • 附录 B.1 推导命题 2.1 中的 SDPO 梯度
  • 附录 B.2 推导附录 A.2 中讨论的信任域正则化 Teacher
  • 注:同附录 A,下文中使用:
    $$ q_{\theta}(\cdot | x, f ) := \pi_{\theta}(\cdot | \text{reprompt}(x, f )) $$
    • reprompt 表示 Self-Teacher 的重提示模板

B.1 Proof of Proposition 2.1.

  • 令:
    $$ A_{t,k} := \log \left( \frac{\text{stopgrad}(q_{\theta}(\hat{y}_t|x,f,y_{ < t}))}{\pi_{\theta}(\hat{y}_t|x,y_{ < t})} \right)$$
  • 于是,从上文有 \(\mathcal{L}_{\text{SDPO} }\) 的梯度:
    $$
    \begin{aligned}
    \nabla_{\theta} \mathcal{L}_{\text{SDPO} }(\theta) &= \nabla_{\theta} \sum_{t=1}^{T} \text{KL}(\pi_{\theta}(\cdot | x, y_{ < t}) | \text{stopgrad}(q_{\theta}(\cdot | x, f, y_{ < t}))) \\
    &= \nabla_{\theta} \sum_{t=1}^{T} \sum_{\hat{y}_t} \pi_{\theta}(\hat{y}_t | x, y_{ < t}) \log \left( \frac{\pi_{\theta}(\hat{y}_t | x, y_{ < t})}{\text{stopgrad}(q_{\theta}(\hat{y}_t | x, f, y_{ < t}))} \right) \\
    &= -\nabla_{\theta} \sum_{t=1}^{T} \sum_{\hat{y}_t} \pi_{\theta}(\hat{y}_t | x, y_{ < t}) A_{t,k} \\
    &= -\sum_{t=1}^{T} \sum_{\hat{y}_t} \left( \pi_{\theta}(\hat{y}_t | x, y_{ < t}) \nabla_{\theta} A_{t,k} + A_{t,k} \nabla_{\theta} \pi_{\theta}(\hat{y}_t | x, y_{ < t}) \right)
    \end{aligned}
    $$
  • 由于有
    $$ \nabla_{\theta} A_{t,k} = -\nabla_{\theta} \log \pi_{\theta}(\hat{y}_t | x, y_{ < t})$$
    • 这是负的得分函数 (score function)
  • 使用得分技巧 (score trick)
    $$ \pi_{\theta}(\hat{y}_t | x, y_{ < t}) \nabla_{\theta} \log \pi_{\theta}(\hat{y}_t | x, y_{ < t}) = \nabla_{\theta} \pi_{\theta}(\hat{y}_t | x, y_{ < t})$$
  • 于是第一项简化为
    $$-\sum_{t=1}^{T} \sum_{\hat{y}_t} \pi_{\theta}(\hat{y}_t | x, y_{ < t}) \nabla_{\theta} A_{t,k} = \sum_{t=1}^{T} \sum_{\hat{y}_t} \nabla_{\theta} \pi_{\theta}(\hat{y}_t | x, y_{ < t}) = \sum_{t=1}^{T} \nabla_{\theta} \underbrace{\sum_{\hat{y}_t} \pi_{\theta}(\hat{y}_t | x, y_{ < t})}_{=1} = 0.$$
  • 因此,\(\mathcal{L}_{\text{SDPO} }\) 的梯度是
    $$\begin{aligned}
    \nabla_{\theta} \mathcal{L}_{\text{SDPO} } &= -\sum_{t=1}^{T} \sum_{\hat{y}_t} A_{t,k} \nabla_{\theta} \pi_{\theta}(\hat{y}_t | x, y_{ < t}) \\
    &= -\sum_{t=1}^{T} \sum_{\hat{y}_t} \pi_{\theta}(\hat{y}_t | x, y_{ < t}) \left( A_{t,k} \nabla_{\theta} \log \pi_{\theta}(\hat{y}_t | x, y_{ < t}) \right) \\
    &= -\sum_{t=1}^{T} \mathbb{E}_{\hat{y}_t \sim \pi_{\theta}(\cdot | x, y_{ < t})} \left[ A_{t,k} \nabla_{\theta} \log \pi_{\theta}(\hat{y}_t | x, y_{ < t}) \right].
    \end{aligned}$$
  • 这意味着如果 \(A_{t,k} = \text{stopgrad}\left( \log \frac{q_{\theta}(y_t|x,f,y_{ < t})}{\pi_{\theta}(y_t|x,y_{ < t})} \right)\),则 \(\mathcal{L}_{\text{SDPO} }\) 的梯度等价于上述梯度

B.2 Trust-region Teacher

  • 为稳定训练,本文尝试防止 Teacher \(q\) 偏离初始 Teacher \(\color{red}{q_{\theta_{\text{ref} } }}\)
  • 可以通过对 Teacher \(q\) 施加一个明确的信任域约束 (2015; 2019) 来实现这一点,即:
    $$\sum_{t} \text{KL}(q(y_t | x, f, y_{ < t}) | \color{red}{q_{\theta_{\text{ref} } }}(y_t | x, f, y_{ < t})) \le \epsilon, \quad \epsilon > 0. \tag{14}$$
  • 下面本文推导出一个 Teacher \(q\),它满足信任域约束,同时保持接近目标 \(q_{\theta}\)
  • 以下优化问题刻画了这样一个 \(q\) (2019):
    $$\begin{array}{rl}
    & \underset{q\in \Delta}{\arg \max} \sum_{t} \sum_{y_t} q(y_t \mid x, f, y_{< t}) \log \frac{q_{\theta}(y_t \mid x, f, y_{< t})}{\color{red}{q_{\theta_{\text{ref} } }}(y_t \mid x, f, y_{< t})} \\
    & \qquad \text{s.t.} \sum_{t} \text{KL}(q(y_t \mid x, f, y_{< t}) | \color{red}{q_{\theta_{\text{ref} } }}(y_t \mid x, f, y_{< t})) \le \epsilon,
    \end{array} \tag{15}$$
    • 其中 \(\Delta\) 表示概率单纯形
  • 直观地说,解是满足信任域约束的、最接近 \(q_{\theta}\)(即与 \(q_{\theta}\) 的交叉熵最小)同时最远离 \(\color{red}{q_{\theta_{\text{ref} } }}\)(即与 \(\color{red}{q_{\theta_{\text{ref} } }}\) 的交叉熵最大)的 \(q\)
Proposition B.1.
  • 公式 (15) 的解可以闭式表示为
    $$q^{*}(y_{t}\mid x,f,y_{< t})\propto \exp ((1 - \alpha)\log \color{red}{q_{\theta_{\text{ref} } }}(y_{t}\mid x,f,y_{< t}) + \alpha \log q_{\theta}(y_{t}\mid x,f,y_{< t})). \tag{16}$$
  • Proof
    • 为简化符号,在下文中省略条件
    • 拉格朗日函数(对于 KL 约束有 \(\lambda \ge 0\),对于归一化有 \(\nu\))为
      $$\mathcal{L}(q,\lambda ,\nu) = \sum_{t}\sum_{y_{t} } q(y_{t}) \log \frac{q_{\theta}(y_{t})}{\color{red}{q_{\theta_{\text{ref} }} }(y_{t})} -\lambda \Big( \sum_{y_{t} } q(y_{t}) \log \frac{q(y_{t})}{\color{red}{q_{\theta_{\text{ref} } }}(y_{t})} -\epsilon \Big) + \nu \Big( \sum_{y_{t} } q(y_{t}) - 1 \Big).$$
    • 平稳性条件给出,对于所有 \(y_t\)
      $$0 = \frac{\partial \mathcal{L} }{\partial q(y_t)} = \log \frac{q_{\theta}(y_t)}{\color{red}{q_{\theta_{\text{ref} } }}(y_t)} - \lambda \Big( \log \frac{q(y_t)}{\color{red}{q_{\theta_{\text{ref} } }}(y_t)} + 1 \Big) + \nu.$$
    • 令 \(\alpha := \frac{1}{\lambda}\),于是公式 (15) 的解可以闭式刻画为
      $$
      \begin{align}
      q^{*}(y_t) &\propto \color{red}{q_{\theta_{\text{ref} } }}(y_t) \exp \Big( \alpha \log \frac{q_{\theta}(y_t)}{\color{red}{q_{\theta_{\text{ref} } }}(y_t)} \Big) \\
      \qquad &\propto \exp \big( (1 - \alpha) \log \color{red}{q_{\theta_{\text{ref} } }}(y_t) + \alpha \log q_{\theta}(y_t) \big)
      \end{align}
      $$
      • \(\color{red}{q_{\theta_{\text{ref} } }}(y_t) \) 变成 \(\exp (\log \color{red}{q_{\theta_{\text{ref} } }}(y_t)) \) 即可推导得到
  • 注:Chen 等人 (2025c) 进行了类似的推导,但使用参考策略 \(\pi_{\theta_{\text{ref} } }\)
    • 本文作者观察到这比使用参考策略 \(\color{red}{q_{\theta_{\text{ref} } }}\) 表现更差

附录 C:Additional Related Work

  • Value networks and Monte Carlo advantage estimation

    • 一些先前的方法旨在改进信用分配,但面临着与 GRPO 相同的信息瓶颈
    • 经典 RL 经常训练价值网络来提供 Token-level Advantage ,但这些网络本身是从标量奖励中学习的 (2016; 2017)
      • 而且价值网络会带来显著的计算和内存开销,因此通常不用于训练 LLM
    • 其他近期工作通过从原始尝试中的不同位置开始执行额外的生成来估计 Token-level Advantage (2025; 2025b)
      • 虽然这可以用比 GRPO 更少的梯度步骤进行学习,但它仍然仅使用标量奖励作为信号,并且需要昂贵的额外生成
  • Dense credit assignment with a reward model

    • 近期几项工作研究了在能够访问外部奖励模型的情况下进行密集(每个 token)奖励分配,通常通过利用奖励模型的内部结构来实现 (2024; 2025b)
    • 与此相关的是,Li 等人 (2025b) 认为,通过将下一个 token 预测与离线逆 RL 联系起来,LLM 的 logits 中隐含着一个 Token-level 奖励信号,这实际上为 RL 微调提供了一个无需训练 (training-free) 的奖励模型
  • Partial observability

    • 从经典 RL 的角度来看,LLM 的许多可验证领域本质上是部分可观测的:
      • 执行一个 Proposed 解决方案会引发一个潜在的环境状态(例如,失败的测试或 Agent 系统的状态),该状态仅通过丰富的反馈才会显现
      • 这与部分可观测马尔可夫决策过程 (POMDP) 的形式化方法一致,其中 Agent 必须在状态观测不完整的情况下行动 (1998; 1998)
      • 相反,RLVR 和 RLHF 流程通常会丢弃这个观测通道,并且仅从终端标量奖励或成对偏好中学习
  • Relation to test-time training

    • 本文在第 5 节中的设置可以被视为 Test-time 训练的一个特例
      • 其中模型本身在 Test-time 通过自蒸馏进行更新
    • 在 Test-time 更新模型被称为 Test-time 训练 (2020; 2025; 2024; 2025; 2025; 2026)
    • 与以上提到的工作都不同,自蒸馏使用当前模型的上下文学习能力来在接收反馈后进行信用分配
      • 这可以被视为通过将上下文周期性地压缩到模型权重中来模拟长上下文推理

C.1 SDPO as Maximum Entropy RL

  • SDPO 目标类似于最大熵 RL (例如,2018) 中的目标,但具有特定的奖励函数选择
Maximum Entropy RL
  • Maximum Entropy RL 考虑优化
    $$\arg \max_{\theta} \mathbb{E}_{y\sim \pi_{\theta}(\cdot |x)} \left[ \sum_t r(y_t \mid x, y_{ < t}) \right] + \lambda \text{H}[\pi_{\theta}(\cdot \mid x)], \quad \lambda > 0 \tag{17}$$
    • 其中 \(\pi_{\theta}(y\mid x) = \prod_{t=1}^{T} \pi_{\theta}(y_t \mid x, y_{ < t})\) 且 \(\text{H}[\pi_{\theta}(\cdot \mid x)] = \mathbb{E}_{y\sim \pi_{\theta}(\cdot \mid x)}[-\log \pi_{\theta}(y \mid x)]\) 是策略的熵
    • 且 \(r(y_t \mid x, y_{ < t})\) 是一个任意的奖励函数,可能是“密集的”(即每个 token 的)
  • 公式 (17) 被称为最大熵 RL
    • 已知该目标等价于求解一个变分推理问题,本文接下来将讨论这个问题
  • 本文定义一个伯努利随机变量 \(\mathcal{C}\)
    • 如果尝试 \(y\) 正确则为 1,否则为 0
  • 接着定义其分布为
    $$ p(\mathcal{C} = 1 \mid x, y) \propto \exp \left( \frac{1}{\lambda} \sum_t r(y_t \mid x, y_{ < t}) \right)$$
    • 进一步假设在响应上的“先验”是均匀的,本文可以将在正确性事件条件下的后验表示为
      $$\pi^{\star}(y \mid x) := p(y \mid x, \mathcal{C} = 1) \propto p(\mathcal{C} = 1 \mid x, y) \propto \exp \left( \frac{1}{\lambda} \sum_t r(y_t \mid x, y_{ < t}) \right). \tag{18}$$
  • 那么,公式 (17) 等价于最小化与 \(\pi^{*}\) 的 KL 散度:
    $$\arg \min_{\theta} \sum_t \text{KL}(\pi_{\theta}(y_t \mid x, y_{ < t}) | \pi^{*}(y_t \mid x, y_{ < t})). \tag{19}$$
SDPO optimizes an implicit reward defined by the teacher
  • 公式 (19) 等价于 SDPO 目标 (公式 (1)),其中隐式奖励
    $$ r(y_{t} \mid x, y_{< t}) = \log q(y_{t} \mid x, f, y_{< t})$$
    • 注: \(\lambda = 1\)
  • 从这个意义上说,SDPO 可以被视为一种最大熵 RL 算法,其密集奖励通过回顾模型 (retrospective model) 隐式构建
  • 这也指出了 SDPO 与逆 RL (2000; 2008; 2023) 的联系,后者的目标是恢复未知的奖励函数
  • 在 SDPO 中, Student 学习了由回顾模型定义的隐式奖励函数

附录 D:Additional Results & Ablations

D.1 Learning without rich environment feedback

  • 表 7 报告了为每个模型/任务组合选择最优超参数时的结果,表 7 的具体实验细节如下
    • 问题:
      • 为什么没有数学场景?
    • 本文报告了在 1 小时和 5 小时实际训练时间内达到的最高 avg@16
    • 梯度更新步骤:
      • SDPO 和 on-policy GRPO 每个生成批次执行一个梯度步骤
      • GRPO 执行 4 个 off-policy 小批量步骤
    • 本文根据 5 小时的准确率为 SDPO 和基线选择最优超参数
      • 本文为每个模型和数据集独立执行此选择
    • 每次运行在一个包含 4 个 NVIDIA GH200 GPU 的节点上执行
    • 包括初始化和验证,每次运行大约需要 6 小时
    • 与表 3(为每种方法全局选择最优超参数)不同,表 7 根据 5 小时准确率为每个模型/任务组合单独选择最优超参数
    • 注:超参数网格在附录 E.2.1 中描述
  • 表 8 比较了 SDPO 和 GRPO 的平均响应长度 (对第 3 节中的任务取平均)
    • 两种算法均在 on-policy 设置下进行评估
  • 核心结论:在 without rich environment feedback 的场景中,实验(比如工具调用和材料科学(Materials)等)的结果看,SDPO 训练相同时间得到的效果并没比 On-policy GRPO 和 GRPO 好太多(部分指标甚至没有提升)

D.2 Learning with rich environment feedback

D.2.1 Additional Results
  • 图 15 显示了按问题难度分层的 SDPO 和 GRPO 的平均准确率(训练期间直至第 80 步的平均准确率,按难度分层)
    • LCB 区分了简单 (easy)、中等 (medium) 和困难 (hard) 问题
      • 注:这种问题分类与第 5 节中的不同
    • 如图所示,SDPO 在解决中等和困难问题方面显著优于 GRPO
      • 突显了丰富反馈对具有挑战性任务的重要性
  • 图 16 比较了在 LCBv6 上训练 GRPO 和 SDPO 时不同的 Training Batch Size 和 Rollout 数量
  • 图 17 中展示了使用 Qwen2.5-Instruct (2024) 的额外结果(作为对图 8 所示结果的补充)
D.2.2 Training Stability
  • 图 18 显示了训练期间记录的各种指标
D.2.3 Baselines
  • 表 9 比较了 LCBv6 上各种基线的性能,包括 GRPO 的两种变体、GSPO 和 CISPO 与 SDPO(训练到 80 步的结果)

D.3 Test-time self-distillation

  • 本节作为对第 5 节所示结果的补充
  • 图 20 中展示了所有困难问题的 discovery@k 曲线
    • 注:曲线表示每个问题 5 个随机种子的均值和 90% 置信区间
  • 表 10 中报告了首次发现前的平均生成次数(这里的首次发现就是首次成功的尝试)
  • 表 11 显示了 SDPO 初始训练步骤中 Self-Teacher 在每个问题上的准确率
    • 对于这些困难和非常困难的任务中的大多数, Teacher 的准确率接近或正好为 \(0%\)
    • 即便如此自蒸馏的 Token-level Advantage 仍然足够丰富,SDPO 能够迭代地改进其策略,并在后续的更新中解决这些问题
  • 图 19 对 SDPO 的 Batch Size 和多轮采样的上下文内重提示策略进行了消融研究
    • 左图:SDPO Batch Size 对 pass@k 曲线的影响
      • 较小的 Batch Size (8 和 16) 可以在非常低的生成预算下 \((k< 2^{6})\) 导致稍早的发现
      • 较大的 Batch Size (16, 32) 会产生更稳定的更新,随着预算的增加显著提高发现率
    • 右图:在困难问题子集上多轮重提示模板的比较
      • “仅反馈”模板使用先进先出滑动窗口连接来自先前尝试的反馈
      • “尝试 + 反馈”模板也使用滑动窗口连接完整的对话轮次
      • 仅包含反馈的方法显著优于连接完整对话的方法
  • 注:在困难问题的筛选中,本文丢弃了一个格式错误的问题 (Q9)
    • 因为编码环境由于舍入不准确而无法正确验证解决方案,即使逻辑正确也会导致失败

附录 E:Experiment Details

E.1 Technical setup

  • 所有实验均在配备四块 NVIDIA GH200 GPU(总计 378GB 显存)的单个节点上进行
    • 每张显卡的显存是 94.5GB ?
  • 本文的环境基于 NVIDIA PyTorch 容器 nvcr.io/nvidia/pytorch:25.02-py3 构建,使用 CUDA 12.8 和 PyTorch v2.7.0
  • 本文的实现基于 ver1 库 (2025)
    • 本文使用 PyTorch 全分片数据并行(FSDP2)进行分布式训练
    • 对于 Rollout 生成,本文采用 vLLM (2023),它能在多 GPU 节点上实现高效的批量推理

E.2 Hyperparameters

  • 表 12 中总结了用于 SDPO 的超参数
  • 表 13 中总结了用于 GRPO 的超参数
  • 对于第 3 节的实验,本文对 GRPO 进行了学习率 \(\{10^{-5}, 10^{-6}\}\) 和小 Batch Size {8, 32} 的网格搜索
    • 对于 on-policy GRPO,本文在固定小 Batch Size 为 32 的情况下搜索相同的学习率
    • 对于 SDPO,本文对 KL 变体(前向 KL,Jensen-Shannon)、学习率 \(\{10^{-5}, 10^{-6}\}\) 和小 Batch Size {8, 32} 进行了网格搜索
    • 对于每种方法(GRPO、on-policy GRPO 和 SDPO),本文选择了一个超参数配置,该配置在训练的前 5 小时内,在第 3 节使用的所有数据集和模型上取得了最高的验证准确率
    • 表 3 中报告了为每个模型和数据集单独选择最佳超参数配置所获得的结果

E.3 User Templates

  • 对于多项选择题和工具使用,必须以特定于任务的方式提示模型

  • 因此,本文在下面提供了用于这些设置的 Prompt 模板

  • Listing 1: System prompt: Multiple Choice Questions

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    Given a question and four options, please select the right answer. Respond in the following format:

    <reasoning>
    ...
    </reasoning>
    <answer>
    ...
    </answer>

    For the answer, only output the letter corresponding to the correct option (A, B, C, or D), and nothing else. Do not restate the answer text. For example, if the answer is "A", just output:

    <answer>
    A
    </answer>
  • Listing 2: User prompt: Multiple Choice Questions

    1
    {question} Please reason step by step.
  • Listing 3: Example user prompt: Tool use

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    Your task is to answer the user's question using available tools. You have access to the following tools:

    Name: Axolotl
    Description: Collection of axolotl pictures and facts
    Documentation:
    getRandomAxolotlImage: Retrieve a random axolotl image with information on the image source.
    Parameters:
    Output: Successful response.
    - Format: application/json
    - Structure: Object{url, source, description}
    SearchAxolotlImages: Search for axolotl images based on specific criteria such as color, gender, and size.
    Parameters:
    {"color": "string. One of: [wild, leucistic, albino]. The color of the axolotl (e.g., 'wild', 'leucistic', 'albino', etc.)",
    "gender": "string. One of: [male, female]. The gender of the axolotl ('male', 'female')",
    "size": "string. One of: [small, medium, large]. The size of the axolotl ('small', 'medium', 'large')",
    "page": "integer. The page number for pagination purposes."}
    Output: Successful response.
    - Format: application/json
    - Structure: Object{results: Array[Object{url, source, description}], pagination: Object{current_page, total_pages, total_results} }
    getAxolotlFacts: Retrieve interesting facts about axolotls such as their habits, habitats, and physical characteristics.
    Parameters:
    {"category": "string. One of: [habits, habitat, physical characteristics]. The category of facts to retrieve (e.g., 'habits', 'habitat', 'physical characteristics')",
    "limit": "integer. The maximum number of facts to return."}
    Output: Successful response.
    - Format: application/json
    - Structure: Array[Object{fact, source}]

    Use the following format:
    Thought: you should always think about what to do
    Action: the action to take, should be one of the tool names.
    Action Input: the input to the action, must be in JSON format. All of the action input must be realistic and from the user.
    Begin!
    Question: Hey, can you show me a random picture of an axolotl?

附录 F:Qualitative Examples

F.1 Visualization of Advantages

  • 图 21 在一个代表性示例中比较了 SDPO 和 GRPO 的 Advantage
    • SDPO 和 GRPO 中 Advantage 的可视化,使用 Olmo3-7B-Instruct,在第 3 节化学任务的一个批次中
    • 每一行对应一个 Response 的开头
    • 颜色表示该 Token 位置的 Advantage 值
      • 蓝色表示正 Advantage
      • 红色表示负 Advantage

F.2 Examples

  • 下面是一个使用 Qwen3-8B 在 LCBv6 上训练 SDPO 的示例
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    [Prompt]
    You are a coding expert. You will be given a coding problem, and you need to write a correct Python program that matches the specification and passes all tests. The time limit is 1 second. You may start by outlining your thought process. In the end, please provide the complete code in a code block enclosed with "``` ```".

    You are given a binary string s of length n, where:

    '1' represents an active section.
    '0' represents an inactive section.

    You can perform at most one trade to maximize the number of active sections in s. In a trade, you:

    - Convert a contiguous block of '1's that is surrounded by '0's to all '0's.
    - Afterward, convert a contiguous block of '0's that is surrounded by '1's to all '1's.

    Return the maximum number of active sections in s after making the optimal trade.
    Note: Treat s as if it is augmented with a '1' at both ends, forming t = '1' + s + '1'. The augmented '1's do not contribute to the final count.

    Example 1:

    Input: s = "01"
    Output: 1
    Explanation: Because there is no block of '1's surrounded by '0's, no valid trade is possible. The maximum number of active sections is 1.

    Example 2:

    Input: s = "0100"
    Output: 4
    Explanation:
    String "0100" -> Augmented to "101001". Choose "0100", convert "101001" -> "100001" -> "111111". The final string without augmentation is "1111". The maximum number of active sections is 4.

    Example 3:

    Input: s = "1000100"
    Output: 7
    Explanation:
    String "1000100" -> Augmented to "110001001". Choose "000100", convert "110001001" -> "110000001" -> "111111111". The final string without augmentation is "1111111". The maximum number of active sections is 7.

    Example 4:

    Input: s = "01010"
    Output: 4
    Explanation:
    String "01010" -> Augmented to "1010101". Choose "010", convert "1010101" -> "1000101" -> "1111101". The final string without augmentation is "111110". The maximum number of active sections is 4.

    Constraints:

    1 <= n == s.length <= 10^5
    s[i] is either '0' or '1'

    Your solution should have the following signature: "python def maxActiveSectionsAfterTrade(s: str) -> int:"

    [Response]
    <think>
    </think>

    To solve this problem, we need to understand the structure of the string and the trade operation. Here's the plan:

    ### Key Insights:

    1. We are given a binary string 's', and we are allowed **at most one trade**.
    2. A **trade** consists of two steps:
    - Convert a contiguous block of '1' that is **surrounded by '0'** into '0' (i.e., flip it).
    - Then convert a contiguous block of '0' that is **surrounded by '1'** into '1' (i.e., flip it).
    3. The actual count of active sections is the number of '1' characters **after** the trade. The augmented '1' at the beginning and end do not count in the final result.

    ### Strategy:

    - We simulate the trade by considering each possible way to flip a block of "1" (bounded by "0") and then a block of "0" (bounded by "1").
    - For each possible valid block, simulate the trade and compute the resulting number of "1" characters.
    - We only need to consider the original string "s", not the augmented one (since we can simulate the trade directly on "s" by checking boundaries).

    We'll use a sliding window technique to find contiguous blocks of "1" and "0" and simulate the trade.

    ### Implementation Steps:

    1. Simulate the trade by checking for a block of "1" surrounded by "0" or at the edges (we simulate the augmented "1" at both sides).
    2. For each such block, simulate converting it to "0", then convert a "0" block (again, surrounded by "1") to "1".
    3. Count the number of "1" in the final string and track the maximum.

    ### Python Code:

    '''python
    def maxActiveSectionsAfterTrade(s: str) -> int:
    if not s:
    return 0
    max_ones = 0
    n = len(s)
    # We are simulating the augmented "1" on both ends, so we treat the boundaries as '1's
    # So '0' at index 0 or n-1 is considered surrounded by '1's

    # Step 1: Find all candidate positions for the first part of the trade
    # (convert a block of 1s surrounded by 0s into 0s)
    for i in range(n):
    if s[i] == '1':
    start = i
    while i < n and s[i] == '1':
    i += 1
    end = i - 1
    # Check if this block is surrounded by '0's
    left_boundary = (start == 0 or s[start - 1] == '0')
    right_boundary = (end == n - 1 or s[end + 1] == '0')
    if left_boundary and right_boundary:
    # This is a candidate block of '1's to flip
    # Now, look for a block of '0's surrounded by '1's in the remaining string
    # Replace the '1' block with '0's
    # Now find a '0' block surrounded by '1's
    # We can simulate the remaining string
    temp = list(s)
    for j in range(start, end + 1):
    temp[j] = '0'
    # Now find a contiguous block of '0's surrounded by '1's
    max_trade_ones = 0
    for k in range(n):
    if temp[k] == '0':
    start0 = k
    while k < n and temp[k] == '0':
    k += 1
    end0 = k - 1
    # Check if this block is surrounded by '1's
    left0 = (start0 == 0 or temp[start0 - 1] == '1')
    right0 = (end0 == n - 1 or temp[end0 + 1] == '1')
    if left0 and right0:
    # Flip this block of '0's to '1's
    for j in range(start0, end0 + 1):
    temp[j] = '1'
    # Count the number of '1's
    count = sum(temp)
    max_ones = max(max_ones, count)
    break # Only take the first valid block for this trade
    # After trying this '1' block, restore the original string
    # No need, we just used a copy
    return max_ones
    '''

    #### Optimization Note:
    The above approach is **not optimal** for very large strings. A more efficient method would involve precomputing the positions of blocks of '1' and '0', and checking valid trades between these blocks in linear time.

    But with the constraints given (up to 10^5), we need to ensure linear or near-linear performance. The current code is correct but **not efficient** for large inputs.

    Would you like me to provide a more optimized version that runs in **O(n)** time and uses only one pass through the string?

    [Feedback]
    b'Runtime Error\nTypeError: unsupported operand type(s) for +: \'int\' and \'str\'\nLine 48 in maxActiveSectionsAfterTrade (Solution.py)\n\nLast Executed Input\n"11000"'

F.3 Environment Feedback

  • 本文在代码环境中展示了三个反馈示例,灵感来自 LeetCode

  • Listing 4: Example of feedback “Wrong Answer” from our code environment in case of a wrong answer, inspired by LeetCode

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    Test Case 3: Wrong Answer
    Input
    [[3,6,1],[4,21,4]]
    Output
    6
    Expected
    22.87500

    Test Case 6: Wrong Answer
    Input
    [[12,25,3],[3,14,2]]
    Output
    14
    Expected
    25.83333
  • Listing 5: Example of feedback “Memory Error” from our code environment in case of a wrong answer, inspired by LeetCode

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    Runtime Error
    MemoryError:
    Line 91 in <module> (Solution.py)
    Line 25 in solve (Solution.py)

    Last Executed Input
    10
    633 9312
    1314 8548
    8857 1062
    6410 3289
    8594 1263
    8549 733
    3858 5973
    ... (3 more lines)
  • Listing 6: Example of feedback “Index Error” from our code environment in case of a wrong answer, inspired by LeetCode

    1
    2
    3
    4
    5
    6
    Runtime Error
    IndexError: list index out of range
    Line 28 in sortMatrix (Solution.py)

    Last Executed Input
    [[-1,-1,-1,-1,-1,-1,-1,-2,-1,-1,-1,-1,-1,-1,-

F.4 Illustrative Example,说明性示例

  • 图 22 展示了 SDPO 中密集信用分配的一个说明性示例
    • 图 22 为通过 SDPO 中的自教学进行密集信用分配
    • 答案由模型(Qwen3-8B)在看到反馈之前生成
    • 在看到反馈后,用 Self-Teacher 重新评估原始尝试的 log-prob
    • 本文展示了每个 Token 的
      $$ \log (\frac{\text{P}(\text{self-teacher})}{\text{P}(\text{student})})$$
      • 红色表示负值( Self-Teacher 不同意)
      • 蓝色表示正值( Teacher 强化)
      • 白色表示接近于零的值
    • 奖励分配:
      • GRPO 使用二元奖励,且会给序列中的所有 Token 分配相同的负 Advantage
      • SDPO 将反馈转化为序列上的密集信用分配
    • 第一行显示生成的 Response 的 Token
    • 其他 3 行显示 Self-Teacher 用于自蒸馏的 top-\(k\) logits,暗示了替代 Token
      • 理解:这里的三行表示每个 Token 位置最可能得前三个候选 Token
    • 在这个例子中, Self-Teacher 通过回溯识别了错误,而没有明确的解决方案
      • 对生成序列的信用分配,以及替代的 top-\(k\) logits,正确表明将 set 替换为 dict 可以保持元素的顺序
    • 在第七个显示的位置,模型还识别了一个替代的解决方案路径,该路径从已看到的 Token 开始,而不是直接返回输出
      • 激活是稀疏的,识别出错误发生的位置,并专门针对这几个 Token 调整 Student 的 Response 分布

NLP——LLM对齐微调-SDPO(Segment-Level-DPO)

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:(Segment-Level-DPO)SDPO: Segment-Level Direct Preference Optimization for Social Agents, 202502, Nankai, Alibaba
      • 论文 GitHub 开源网址:AlibabaResearch/DAMO-ConvAI
      • 论文是 南开实习生在阿里实习时的工作

Paper Summary

  • 整体说明:
    • 论文提出了 Segment-level 直接偏好优化(Segment-Level Direct Preference Optimization, SDPO),用于提升 LLM-based Agent 在多轮社交对话中 的表现
    • 对比:与现有的多轮对齐方法(如 ETO 和 DMPO)不同:SDPO 通过聚焦 Session 中的关键片段(key segments)来优化 Agent 策略
  • 背景 & 问题:
    • LLM 驱动的社交 Agent 能够模拟人类社交行为,但在处理复杂社交对话时仍存在不足
    • DPO 在多种 Agent 任务中已被证明能有效对齐 LLM 行为与人类偏好
      • 标准 DPO 仅关注单轮对话,这限制了其在多轮社交互动中的效果
      • 一些基于DPO的多轮对齐方法(利用 Session-level 数据)已展现出解决这一问题的潜力
    • 这些方法考虑了整个 Session 中的多轮对话,但其粒度通常过于粗糙,会引入训练噪声,且缺乏坚实的理论支持
  • 本文解法:SDPO
    • 该方法通过动态选择交互中的关键分段(Segment)来优化多轮 Agent 行为
    • SDPO 能够最小化训练噪声,并建立在严格的理论框架之上
    • 在 SOTOPIA 基准测试上的评估表明,经 SDPO 调优的 Agent 在性能上 consistently 优于现有的基于 DPO 的方法以及 GPT-4o 等专有 LLM,这凸显了 SDPO 在提升 LLM-based Agent 社交智能方面的潜力
  • 注:论文的代码和数据已开源

Introduction and Discussion

  • LLM 的进步显著提升了其在语言理解和生成方面的能力,尤其是在人机交互领域
  • 通过融入身份特定信息,LLM-based Agent 能够模拟人类社交行为,在角色扮演休闲对话[22, 23]和模拟社交环境导航[24]等任务中展现出基本的社交智能
  • 然而,近期研究表明,在更复杂的、目标导向的社交场景(如谈判、竞争与合作)中,LLM 仍难以表现出人类社交互动中特有的细腻决策能力
  • 为应对这些挑战,研究者们开发了多种方法以更好地对齐 LLM 行为与人类偏好
    • 这些方法为改进 LLM 的社交决策提供了有前景的策略
  • 论文关注基于直接偏好优化(DPO)的方法。标准 DPO 通过识别单轮对话,并利用该轮中的“ Positive-Negative”响应对,通过偏好损失函数优化模型
    • 尽管 DPO 已展现出一定效果,但其对单轮对话的关注限制了其在目标导向社交对话中建模目标完成的能力,因为这类对话的成功通常依赖于跨越多轮的高质量互动
  • 为更有效地对齐多轮互动中的 Agent 行为,研究者提出了多种多轮对齐方法,包括 ETO[20]和 DMPO[26]
    • 这些方法将采样范围从单轮扩展至整个 Session ,通过构建 “good and bad” 的 Session 对并应用改进的 DPO 损失进行训练
    • 论文将这些方法归类为 Session-level DPO,它们在数据粒度和理论基础上均存在局限性
  • 从数据角度来看,Session-level DPO 因其较粗的对齐粒度而存在以下缺陷:
    • (i) Negative Session 中无错误的轮次也会被误判为 Negative 输出 ,从而引入大量噪声,对训练过程产生 Negative 影响
    • (ii) 从零开始采样为对话者(interlocutor)提供了巨大的行动空间。 Positive Session 的高分可能源于对话者行为的改变 ,这使得模型难以从 Positive 样本中学习正确的行为模式
      • 理解:这里是说高分 Session 可能不是 Agent 回答的好,而是 Interlocutor 回答的好
  • 从理论角度来看,在多轮场景中,直接应用 DPO 无法消除配分函数(partition function) \( Z \)(2024)
    • ETO 将 DPO 损失扩展至多轮互动,但缺乏形式化的理论保证
    • DMPO 结合了 SAOM 理论,成功将 \( Z \) 转换为常数;
    • 然而,由于 Positive 和 Negative Session 的轮次数量不同,DMPO 通过启发式长度归一化消除 \( Z \),但缺乏严格的数学证明
    • 这部分的详细的理论分析见附录B.2
  • 为克服 Session-level DPO 的局限性,论文提出了 Segment-level 直接偏好优化(SDPO)
    • 论文的方法将采样起点向后移动,并截断 Session 末尾的无用内容,从而获取关键分段对以细化粒度
    • 同时,论文确保 Positive 和 Negative 分段的轮次数量一致,从而消除 \( Z \) 并严格推导出 SDPO 损失
  • 具体而言,SDPO 的做法如下:
    • 识别 Negative Session 中的错误轮次 ,利用该错误轮次之前的交互历史进行多次采样 ,生成 Positive Session
      • 问题:怎么识别到错误轮次呢?回答:见后面的章节
    • 以首个差异轮次为起点,从 Positive Session 中选择有助于提升分数的关键分段 ,并从 Negative Session 中截取相同长度的对应分段以形成数据对
      • 问题:如何评估有助于提升分数?回答:见后面的章节
    • 针对分段内的轮次计算 SDPO 损失
  • 论文在图1中展示了三种社交对话(social dialogues)对齐算法的概览
  • 在数据层面,SDPO 能够解决 Session-level DPO 的缺陷,SDPO 优点如下:
    • (i) 仅针对 Negative 和 Positive 分段中的轮次计算损失,从而大幅消除了非错误轮次引入的训练噪声
    • (ii) 从错误轮次开始采样缩小了对话者的行动空间,使得采样得到的 Positive Session 更可能包含 Agent 的正确行为模式
      • 理解:这是相对上面 Session-level 对话对应的问题
  • 在理论上,得益于分段选择的灵活性,SDPO 能够控制 Positive 和 Negative 分段的轮次数量以确保一致性,从而消除 \( Z \) 并得到一个简洁而严谨的 SDPO 损失
  • 我们通过 SOTOPIA(2024)这一开放、交互式的社交智能基准测试对我们的方法进行了实证评估
    • 实验采用了自对话(self-chat)以及与其他智能体(包括 GPT-4o 和 GPT-4o-mini)的交互形式
    • 结果表明,经过 SDPO 调优的智能体在性能上始终优于 DPO、ETO、DMPO 等现有方法,甚至超越了 GPT-4o 等专有大语言模型,这充分证明了 Segment-level 对齐(segment-level alignment)的有效性
  • Segment level 是一种更灵活且统一的数据粒度,能够根据不同数据对动态选择优化范围,同时优雅地解决了多轮对齐的理论挑战。在论文中,我们主要将SDPO应用于提升智能体的社交智能,但我们相信该方法同样适用于其他场景,从而进一步扩展智能体在不同领域的能力
  • 论文的主要贡献包括以下三点:
    • 提出了 SDPO,这是一种新颖的多轮对齐算法,为 Segment-level 偏好数据对的构建提供了完整的流程,实现了更高效的对齐粒度
    • 指出了现有多轮对齐方法在理论上的局限性,并通过分段粒度(segment-level granularity)的灵活性解决了这些问题,严谨地推导出了一个简洁的 SDPO 损失函数公式
    • 在 SOTOPIA 这一模拟交互式社交基准测试上评估了论文的方法,SDPO 的表现及深入分析证明了 Segment-level 对齐的有效性和鲁棒性

Preliminary

SOTOPIA Environment

  • SOTOPIA 提供了一个交互式、开放且逼真的模拟环境,能够更精准地评估 Agent 的社交智能(不同于其他静态问答形式测试的社交基准(2019; 2024)不同)
    • SOTOPIA 中的社交任务包含一个场景、两个角色档案及其需要通过互动达成的私人社交目标
    • 场景与社交目标的多样化组合涵盖了谈判、合作与竞争等广泛的社交互动类型
    • SOTOPIA 定义了七个评估社交 Agent 的维度
  • 论文主要关注“目标(goal)”(0 to 10,int)和“关系(relationship)”(-5 to 5,int)
    • 因为 GPT-4o 在这些指标上的评分与人类评估高度一致
  • SOTOPIA-\(\pi\)(2024a)是一项后续工作,利用 GPT-4 自动构建了一组场景(与 SOTOPIA 完全不重叠),作为本研究的训练数据集
  • 此外,论文重构了 SOTOPIA 的提示组织格式以支持多轮对齐,具体细节见附录A

Task Formulation

  • 在 SOTOPIA 任务中,论文将 Agent 可用的背景信息记为 \( b \),包括场景、角色档案及其目标
  • Agent 在第 \( n \) 轮面临的交互历史 \( h_n \) 定义如下:
    $$
    h_n =
    \begin{cases}
    b, y_0, y’_0, \ldots, y_{n-1}, y’_{n-1}, & \text{if speak first} \\
    b, y’_0, y_0, \ldots, y_{n-1}, y’_n, & \text{if speak later}
    \end{cases} \tag{1}
    $$
    • \( y_i \sim \pi_\theta(\cdot|h_i) \) 表示 LLM-based Agent 在第 \( i \) 轮根据其策略 \( \pi_\theta \) 和参数 \( \theta \) 生成的输出;
    • \( y’_i \) 表示对话者的输出,其服从一个未知分布
  • 基于此形式化,论文在附录B.1和B.2中分别给出了 ETO 和 DMPO 的损失函数

Direct Preference Optimization

  • Rafailov 等 (2023) 提出了直接偏好优化(DPO),该方法利用成对偏好数据训练策略模型,而无需依赖强化学习(2022)。在社交对话场景中,论文将错误轮次编号记为 \( e \),DPO 损失函数定义如下:
    $$
    L_{DPO} = -\mathbb{E}_{(h_e, y^w_e, y^l_e) \sim D} \log \sigma
    \left[ \beta \log \frac{\pi_\theta(y^w_e|h_e)}{\pi_{ref}(y^w_e|h_e)} - \beta \log \frac{\pi_\theta(y^l_e|h_e)}{\pi_{ref}(y^l_e|h_e)} \right], \tag{2}
    $$
    • 其中 \( y^w_e, y^l_e \sim \pi_\theta(\cdot|h_e) \) 分别表示错误轮次中的 Positive 和 Negative 输出
  • 然而,由于其单轮优化的特性,DPO 并不适合社交对话
    • 将 DPO 严谨地扩展至多轮场景是一个待解决的挑战,论文将在第3.3节中探讨这一问题

Method

Behavioral Cloning

  • 注:使用 Behavioral Cloning 初始化 Agent
  • Behavioral Cloning(BC)作为一种有效的模仿学习方法,被广泛应用于各类 LLM-based Agent 构建中
  • 在本研究中,论文使用 GPT-4-turbo 作为专家模型,通过自对话(self-chat)以及与 GPT-4o 的交互,在 SOTOPIA-\(\pi\) 数据集上收集专家 Session 数据
  • 基于这些数据,论文对开源大语言模型(如 Llama-3.1)进行微调,构建了实验所需的初始社交 Agent

Preference Data Construction

  • 构建高质量的 Segment-level偏好数据对是本方法的核心
  • 在 SOTOPIA-\(\pi\) 数据集上,论文的社交 Agent 通过自对话以及与 GPT-4o 的交互生成数据
  • 论文设定目标维度(goal dimension)的阈值为 7,所有目标完成度低于此阈值的对话均被视为潜在的负样本
  • 给定一个负样本 Session,生成 Segment-level 数据对的流程包括以下三个步骤(如图2 所示):
  • 错误定位(Error Location)
    • 与数学等具有明确错误定义的场景不同,社交对话中的错误是一个相对模糊的概念
    • 在负样本 Session 中,如果论文的 Agent 在某一轮次(turn)的回应满足以下条件,则将该轮次 Token 为错误轮次:
      • (1) 该轮次对实现角色目标至关重要;
      • (2) 当前回应在目标完成度或双方关系改善方面仍有提升空间
    • 注:错误定位由 GPT-4o 完成,具体提示词(prompt)见附录 C.3
  • 正样本 Session 采样(Positive Session Sampling)
    • 在错误定位后,基于该轮次之前的交互历史,论文采样 5 个完整 Session
    • 从这些 Session 中选择目标分和关系分(优先考虑目标分)最高的一个
      • 如果最优 Session 的目标分或关系分高于负样本,则该 Session 与负样本组成数据对;
      • 否则丢弃该负样本
  • 段选择(Segment Selection)
    • 获得 Session-level 数据对后,论文将正负样本同时提供给 GPT-4o,提示其从正样本中选择一个段(segment)
      • 该段应包含对正样本获得更高目标分和关系分起到关键作用的部分
      • 接着从负样本中提取相同长度的段,与正样本段配对形成 Segment-level 数据对
      • 问题:相同长度的段是什么意思?正负样本对的段不一定都完全一样吧
    • 这一过程旨在排除与 Session 目标无关的轮次(如寒暄内容)
    • 注:给 GPT-4o 的提示词见附录 C.3
  • 论文评估了 GPT-4o 在错误定位和段选择任务中的表现,结论显示其能够有效完成这两项任务,详细分析见附录 C.2

SDPO Loss

  • 论文将 DMPO 的框架迁移到对话领域,并首先引入状态-动作占用测度(State-Action Occupancy Measure,SAOM)
  • 在此框架下:
    • 交互历史 \( h \) 作为状态(state),Agent 的输出 \( y \) 作为动作(action)
    • 策略 \( \pi \) 的折扣 SAOM \( d^{\pi}(h, y) \) 定义如下:
      $$
      d^{\pi}(h = h_t, y = y_t) = \gamma^t \cdot P(h_0) \cdot \prod_{k=0}^{t-1} \pi(y_k | h_k) P(h_{k+1} | h_k, y_k), \tag{3}
      $$
      • \( \gamma \) 为折扣因子
    • 基于 \( d^{\pi} \) 的强化学习目标为:
      $$
      \max_{\pi_\theta} \mathbb{E}_{(h,y) \sim d^{\pi_\theta}(h,y)} [r(h, y)] - \beta \mathbb{D}_{KL} [d^{\pi_\theta}(h,y) || d^{\pi_{ref} }(h,y)], \tag{4}
      $$
    • 根据 DPO,式(4)的最优解为:
      $$
      d^{\pi^*}(h,y) = \frac{1}{Z} d^{\pi_{ref} }(h,y) \exp \left( \frac{1}{\beta} r(h,y) \right), \tag{5}
      $$
      • \( \pi^* \) 为最优策略
      • \( Z \) 为归一化概率的配分函数(partition function)
  • 由于 \( d^{\pi}(h,y) \) 是 \( (h,y) \) 对的函数,对其进行归一化会导致配分函数 \( Z \) 与当前历史 \( h \) 无关
    • 因此,\( Z \) 对所有 \( (h,y) \) 对均为常数,这是消除 \( Z \) 的关键步骤
  • 奖励函数的形式为:
    $$
    r(h,y) = \beta \log \frac{d^{\pi^*}(h,y)}{d^{\pi_{ref} }(h,y)} + \beta \log Z. \tag{6}
    $$
  • 随后,论文使用 Bradley-Terry(BT)模型对偏好分布建模
    • 在这一步中,DMPO 错误地重复计算了式(3)中的 \( \gamma \),并在后续步骤中启发式地对长度进行归一化以消除 \( Z \),但缺乏严格证明(详细讨论见附录 B.2)
  • 给定偏好数据对,BT 模型的正确应用如下:
    $$
    p(\tau^w \succ \tau^l | h_0) = \sigma \left( \sum_{t=0}^{T_w - 1} r(h_t^w, y_t^w) - \sum_{t=0}^{T_l - 1} r(h_t^l, y_t^l) \right), \tag{7}
    $$
    • 其中 \( \tau^w \) 和 \( \tau^l \) 分别表示“胜”和“负”样本,\( T_w, T_l \) 为各自的轮次数
    • Session-level DPO 无法控制正负样本的长度,由于通常 \( T_w \neq T_l \),配分函数 \( Z \) 无法直接在式(7)中抵消
      • 理解:正负样本通常不同,此时配分函数是消不掉的
  • 与这些方法不同,SDPO 从正负样本中各选一个段进行优化,可以自由控制其长度
    • 通过确保两段长度相同,论文能够直接消除式(7)中的 \( Z \)
    • 同时,结合式(3)替换 \( d^{\pi} \),论文得到以下简洁的 SDPO 损失函数:
      $$
      L_\text{SDPO} = -\mathbb{E}_{(h_e, h^w, h^l) \sim D} \log \sigma \left[ \sum_{t=e}^{e+k} \beta \left( \log \frac{\pi_\theta(y_t^w | h_t^w)}{\pi_{ref}(y_t^w | h_t^w)} - \log \frac{\pi_\theta(y_t^l | h_t^l)}{\pi_{ref}(y_t^l | h_t^l)} \right) \right], \tag{8}
      $$
      • \( e \) 表示错误轮次的编号
      • \( k \) 为所选段中的总轮次数

Experiments

Datasets

  • 训练使用的 SOTOPIA-\(\pi\) 共包含 410 个场景:
    • 其中 100 个场景用于行为克隆(Behavioral Cloning, BC),每个场景包含 10 对角色;
    • 310 个场景用于对齐,每个场景包含 8 对角色
    • 测试使用的 SOTOPIA 包含 90 个场景,每个场景包含 5 对角色,总计 450 个自对话任务和 900 个非自对话任务

Experimental Setup

  • 训练(Training)
    • 论文主要使用 Llama-3.1-8B-Chat 作为基础 LLM 来构建社交 Agent
    • 最大 token 限制设置为 4096
    • 所有训练过程均采用 AdamW 优化器
    • 在 SFT 阶段:
      • 批大小为 32
      • dropout 率为 0.2
      • 学习率为 \(1 \times 10^{-5}\)
      • 预热比例为 5%
      • 采用余弦衰减调度
    • 在 SDPO 的对齐训练阶段:
      • 批大小保持为 32
      • SDPO 损失中的 \(\beta\) 设为 0.1
      • 学习率为 \(1 \times 10^{-6}\)
      • 无预热
      • 采用余弦衰减调度
    • SDPO 训练数据的统计细节见附录 C.1
  • SOTOPIA
    • 在正样本采样过程中:
      • 目标 Agent 的温度(temperature)设置为 1.0
      • 另一 Agent 的温度设置为 0.7
    • 测试时:
      • 论文将两个交互 Agent 的温度均设为 0.7
      • 尽管温度会为 Agent 的输出引入随机性,但论文发现评估结果在数值上保持稳定,因此论文基于单次测试报告结果(吐槽:这么稳定吗?仅单次报告就行?)

Baselines

  • 论文将提出的 SDPO 与以下几种强基线方法进行比较:
    • 1)OpenAI 的专有大语言模型(如 GPT-4o),具体版本见附录 D.1
    • 2)监督微调行为克隆(SFT Behavioral Cloning) :在专家交互数据上微调 LLM,生成的模型作为 SDPO 及其他基线的基础 Agent
    • 3)直接偏好优化(DPO) :基于单轮数据优化 Agent 策略,具体针对 SDPO 中正负样本的首个差异轮次
    • 4)基于探索的轨迹优化(ETO) :使用 Session-level 数据优化 Agent 策略
      • ETO 使用与 SDPO 相同的负样本 Session,但从头采样五个新 Session 来构成数据对
    • 5)动态多轮偏好优化(DMPO) :使用与 ETO 相同的数据,但采用 DMPO 新的损失函数更新策略
    • 6)偏好监督微调(Preferred-SFT) :在 SDPO 的正样本 Session 上微调基础 Agent

Results

  • 表1 展示了 SDPO 和所有基线方法在 SOTOPIA 上的结果
  • 如表所示,在目标和关系两个维度上,SDPO 显著优于标准 DPO、 Session-level 的 ETO 和 DMPO,甚至大幅超越 GPT-4o 等专有 LLM,凸显了 Segment-level 对齐的有效性
  • 通过分析 SOTOPIA 中的交互历史,论文发现较弱的 Agent 通常表现出固执性,仅重复表达需求,导致目标和关系水平较低,尤其是在自对话场景中
  • 使用专家数据进行行为克隆可以有效改善这种情况,使 Agent 更具沟通性
  • Llama-8B+BC 在与 GPT-4o 交互时目标率(goal rate)下降的原因是 Agent 变得更具说服力
  • 论文还观察到,对齐后的 Agent 在目标和关系(goal and relationship)上同时提升,这表明对齐方法确实增强了模型的社会智能(social intelligence),而非通过违反社会规范的行为(如威胁或欺骗)实现目标
  • 论文也在 Mistral-Instruct-v0.3 上重复了上面的实验,实验结果见 表2,实验设置的细节见附录 D.2
    • 在这些实验上,SDPO 也一致超过了基线模型,展示了其泛化性

Analysis

  • 模型输出长度的变化(Variation in Model Output Length)
    • 图3 展示了各种 Agent 与 GPT-4o 交互时的输出长度
      • 与 BC Agent 相比,所有对齐方法均增加了 Agent 的输出长度
      • 这种现象在将 DPO 应用于 AI 聊天机器人时常见
      • 但与用户对长回复的潜在偏见不同,有效的社交策略通常需要更多 token 进行沟通,因此输出长度的增加是合理的
    • 此外,论文实验在 SDPO 调优的 Agent 达到 10 轮交互时终止对话,以在相似 token 数量下比较性能
      • 结果显示,SDPO 仍优于其他多轮方法,表明 SDPO 调优的 Agent 能更高效地利用词汇
  • 正样本利用效率(Efficiency of Positive Sample Utilization)
    • 图4 展示了 Session-level 和 Segment-level 正样本的质量对比
      • 在采样次数一致的情况下, Session-level 正样本在目标和关系上优于 Segment-level 正样本
      • 从头采样提供了更大的采样空间,增加了生成高质量 Session 的可能性
      • 但尽管 Session-level DPO 使用了更高质量的数据,其表现仍不及 SDPO
        • 这表明由于 Segment-level 的更细粒度和对损失函数的理论支持,SDPO 能更高效地利用正样本
  • 论文还分析了 DPO 和 SDPO 对正负样本概率差异的影响,详细内容见附录 E.2

Ablation Study

  • 段选择方法(Segment Selection)
    • 论文探索了 SDPO 的不同段选择方法,结果如表3 所示
      • 方括号中,负样本段的长度在前,正样本段的长度在后
      • 对于对称段长度,固定长度为 3 和 5 的段优于长度为 1 的段(DPO),证明了多轮对齐的有效性
      • 长度为 5 的段效果略逊于长度为 3 的段,表明更长的段并非总是更好
    • 基于此,论文利用 GPT-4o 动态从每个正样本中识别关键段,取得了最佳结果
      • 对于非对称段长度,[3,1] 和 [5,3] 的模型训练崩溃,无法正常交互
      • 其他非对称段的表现均不及对称段,支持了 3.3 节的理论讨论
    • 此外,随着非对称程度的降低,模型性能提升,这可能是因为未消除的 \(Z\) 对损失的影响减弱
      • 这一发现有助于解释 ETO 的有效性,因其未对正负样本 Session 的长度施加约束
  • 采样对话者(Interlocutor for Sampling)
    • SDPO 的对齐数据分别使用 BC Agent 自身和 GPT-4o 作为对话伙伴收集
    • 论文在每个数据子集上独立训练模型,结果如表4 所示
      • 使用单一数据源训练的模型在自对话和与 GPT-4o 交互中均表现提升,进一步验证了 SDPO 的泛化能力
      • 组合数据集训练的模型优于单一数据集训练的模型,表明引入多样化对话伙伴的数据可以进一步提升 Agent 的社会智能
  • 分布外数据(Out-of-Distribution Data)
    • 基础 BC Agent 从 GPT-4-turbo 生成的专家数据中学习,使用 GPT-4-turbo 生成正样本是否能带来更好的性能?
    • 论文让 GPT-4-turbo 与 BC Agent 交互,并为 SDPO 采样 5 次
      • 生成的正样本在目标和关系评分上均优于自采样样本
    • 然而,如表4 所示,使用该数据训练的模型表现不及自采样方法
      • 这表明分布外正样本的效果不如分布内样本
      • 在训练过程中,分布外正样本的概率显著低于负样本,这种更大的概率差距可能是性能不佳的原因

Related Work

Social Intelligence

  • 社交智能(Social Intelligence)可以定义为 Agent在社会互动中理解、适应和回应他人情感、意图和行为的能力
  • 大多数关于社交智能的研究集中在评估方面,例如
    • SOCIALIQA(2019)强调对社会情境的常识推理
    • SocialIQ(2019)将评估模式从纯文本扩展到视频
    • Shapira 等(2023)使用 Faux Pas Test 评估 LLM
    • SocialBench(2024)则在个体和群体层面评估角色扮演 Agent 的社交性
  • 此外,一些研究(2019; 2024)从心理理论(theory-of-mind)的角度探讨模型的社交智能
  • 随着 LLM 的发展,基于大语言模型的社交 Agent 已经能够在真实社交场景中互动,传统的静态问答式基准(QA-style benchmarks)已不足以评估 Agent 的社交智能
    • SOTOPIA(2024)是目前唯一动态且交互式的社交基准,为当代社交 Agent 提供了模拟测试环境
    • 作者希望这项工作能够通过方法创新,进一步推动提升模型社交智能的研究

Alignment Methods with Refined Granularity

  • Rafailov 等(2023)提出了直接偏好优化(Direct Preference Optimization, DPO),该方法利用离线数据和简化的损失函数对齐大语言模型
  • 基于 DPO,多种细粒度对齐算法被开发出来。例如
    • Token-level DPO(2024)在 Token-level 别整合前向 KL 散度约束,同时提升对齐性和多样性
    • Step-DPO(2024)利用单个推理步骤(reasoning steps)进行偏好优化,而非整体答案级评估
    • SePO(2024)提出了一种 Token-level 奖励函数估计方法,选择性优化回答中的关键 Token
  • 然而,在社交对话或网络导航等多轮交互场景中,单轮对齐(single-turn alignment)是不够的
    • ETO 和 DMPO 将 Session-level 数据(session-level data)引入 DPO,将其扩展到多轮场景
  • 论文进一步提出了 SDPO,通过动态的 Segment-level 优化框架(segment-level optimization framework)实现多轮交互中更细粒度的对齐

Limitations

  • 论文提出的 SDPO 假设正负片段的长度相等 ,并在此假设下实现了 SOTA 性能
    • 具体而言,在从正样本中选出一个片段后,论文从负样本中选取相同长度的片段以消除配分函数 \( Z \)
    • 这种方法存在一定局限性
    • 负样本片段可能包含无关或无错误的轮次(turns),或未能捕捉到所有错误轮次,这表明在从负样本中选择片段时需要更细粒度的控制
    • 目前,作者尚未找到能够有效支持不等长片段对齐的理论框架
    • 作者希望这项工作能够激发进一步研究,鼓励多样化的理论分析以解决多轮对齐中的这一问题
  • 由于 SOTOPIA 是目前唯一可用的交互式社交基准 ,论文的实验仅在该数据集上进行
    • 未来,作者计划引入更多交互式 Agent 任务,以进一步验证 SDPO 的通用性

附录A Modifications to SOTOPIA

  • 在 SOTOPIA 中,每次交互都以单轮(single-turn)格式组织,这不支持多轮对齐(multi-turn alignment)
  • 为了解决这一限制,论文修改了提示组织格式(prompt organization format),如图5 所示
  • 这些修改在调用 LLM 的 API 之前应用,确保其对 SOTOPIA 本身不可见,且不会影响 GPT-4o 的评估
    • 更多细节可在论文的代码仓库中找到

附录B Supplementary Theoretical Analysis

B.1 ETO

  • Song等人(2024)提出了基于探索的轨迹优化(Exploration-Based Trajectory Optimization, ETO),该方法在没有严格证明的情况下将直接偏好优化(Direct Preference Optimization, DPO)扩展到 Session-level 别(session level)。其损失函数如下:
    $$
    L_{ETO} = -\mathbb{E}_{(b,h^{w},h^{l})\sim D} \log \sigma \left[ \sum_{t=0}^{T_{w}-1} \beta \log \frac{\pi_{\theta}(y_{t}^{w}|h_{t}^{w})}{\pi_{ref}(y_{t}^{w}|h_{t}^{w})} - \sum_{t=0}^{T_{l}-1} \beta \log \frac{\pi_{\theta}(y_{t}^{l}|h_{t}^{l})}{\pi_{ref}(y_{t}^{l}|h_{t}^{l})} \right] \tag{9}
    $$
    • \( h^{w} \) 和 \( h^{l} \) 分别表示完整的正向和负向交互历史
    • \( T_{w} \) 和 \( T_{l} \) 表示各自的轮数
  • 当 \( T_{w} = T_{l} \) 时,ETO 的损失函数与 SDPO的损失函数等价

B.2 Discussion on DMPO

  • 应用BT模型时的错误(Mistake when Applying BT Model) :在公式(6)之后,DMPO应用了布拉德利-特里模型(Bradley-Terry model, BT model)得到以下公式:
    $$
    p(\tau^{w} \succ \tau^{l}|h_{0}) = \sigma \left( \sum_{t=0}^{T_{w}-1} \gamma^{t} r(h_{t}^{w}, y_{t}^{w}) - \sum_{t=0}^{T_{l}-1} \gamma^{t} r(h_{t}^{l}, y_{t}^{l}) \right) \tag{10}
    $$
    • \( \tau^{w} \) 和 \( \tau^{l} \) 分别表示“胜”和“负”样本
    • \( T_{w} \) 和 \( T_{l} \) 表示各自的轮数
    • 仔细检查公式(10)可以发现,对 \( (h,y) \) 对的求和应排除 \( \gamma^{t} \),因为它已经被纳入 \( d^{\pi}(h,y) \) 中
  • 长度归一化的限制(Limitation of Length Normalization) :暂时忽略公式(10)中的错误,DMPO启发式地引入了基于轮数的正则化来消除 \( Z \):
    $$
    p(\tau^{w} \succ \tau^{l}|h_{0}) = \sigma \left( \frac{1-\gamma}{1-\gamma^{T_{w} } } \sum_{t=0}^{T_{w}-1} \gamma^{t} r(h_{t}^{w}, y_{t}^{w}) - \frac{1-\gamma}{1-\gamma^{T_{l} } } \sum_{t=0}^{T_{l}-1} \gamma^{t} r(h_{t}^{l}, y_{t}^{l}) \right)
    $$
    • 然而,DMPO 并未讨论为什么可以应用长度归一化或这一操作带来的影响。这种转换缺乏严格的理论依据
  • DMPO 损失函数(DMPO Loss Function) :遵循DMPO的方法,其损失函数如下:
    $$
    L_{DMPO} = -\mathbb{E}_{(b,h^{w},h^{l})\sim D} \log \sigma \left[ \sum_{t=0}^{T_{w}-1} \beta \phi(t,T_{w}) \log \frac{\pi_{\theta}(y_{t}^{w}|h_{t}^{w})}{\pi_{ref}(y_{t}^{w}|h_{t}^{w})} - \sum_{t=0}^{T_{l}-1} \beta \phi(t,T_{l}) \log \frac{\pi_{\theta}(y_{t}^{l}|h_{t}^{l})}{\pi_{ref}(y_{t}^{l}|h_{t}^{l})} \right]
    $$
    • 其中,折扣函数 \( \phi(t,T) = (1-\gamma^{T-t})/(1-\gamma^{T}) \)

附录C Data Construction Details

C.1 Statistics and Analysis of SDPO Data

  • SDPO 数据集包含 1019 对样本
    • GPT-4o 识别的错误轮次分布如表5 所示
    • GPT-4 识别的片段长度分布如表6 所示
    • 截断轮数的分布如 表7 所示
  • 结合表3 和表6,尽管在自动片段长度选择中,长度为 3 的片段占比近 90%,但自动选择方法的性能仍明显优于固定长度为 3 的方法,这凸显了自动选择方法的有效性

C.2 GPT-4o’s Performance in Pipeline

  • 论文从 SDPO 数据中随机选择 40 对样本,由三位作者独立评估 GPT-4o 在错误定位和片段选择中的表现
  • 在社交对话的背景下,正确性和错误的概念本质上是模糊的
  • 为此,论文定义了三个评估类别:正确、模糊和错误
  • 平均评估结果如表8所示。评估者均表示,模糊的主要原因是他们可以确定GPT-4o的选择是合理的,但难以判断是否最优。总体而言,论文得出结论:GPT-4o能够处理错误定位和片段选择任务

C.3 rompts in Data Construction

  • GPT-4o用于错误定位和片段选择的提示分别如图7 和图8 所示

附录D Supplementary Experimental Setup

D.1 Versions of OpenAI LLMs

  • 论文使用的 OpenAI 大语言模型版本如下:GPT-4o-2024-08-06、GPT-4-turbo-2024-04-09、GPT-4o-mini-2024-07-18 和 GPT-3.5-turbo-0125

D.2 Mistral Training Details

  • 与 Llama 的实验设置一致,最大 Token 限制设为 4096,所有训练过程均使用 AdamW 优化器
  • 在 SFT 阶段:
    • 批大小为 32
    • 丢弃率为 0.2
    • 学习率为 \( 3e^{-6} \)
    • 预热比例为 5%
    • 采用余弦衰减调度
  • 在 SDPO 的训练阶段:
    • 批大小为32
    • SDPO 损失中的 \( \beta \) 为 0.1
    • 学习率为 \( 5e^{-7} \)
    • 学习率无预热
    • 采用余弦衰减调度
  • Mistral 的 SDPO 数据构建过程与 Llama 相同

附录E Additional Empirical Results

E.1 SOTOPIA Hard Subset

  • SOTOPIA 将数据集中更具挑战性的部分 Token 为困难子集(Hard subset),详细结果如表9 所示
    • 各种方法在困难子集上的排名与完整数据集上的表现基本一致
    • SDPO 仍然取得了最佳结果,这表明 SDPO 在不同难度的场景中均能提升 Agent 的社交智能(social intelligence)

E.2 多轮对齐的必要性

  • 在 DPO 调整第一轮正向和负向片段的输出概率后,后续轮次中正向片段的概率会增加而负向片段的概率会降低吗?
  • 为了探究这一点,论文绘制了 DPO 和 SDPO 在训练过程中正向与负向片段概率差的变化,如图9 所示(只有SDPO可以直接与DPO比较,因此未提及ETO和DMPO)
  • DPO-turn 轨迹几乎与 DPO 轨迹平行,表明 DPO 对后续轮次的概率差几乎没有影响
  • 相比之下,SDPO 轨迹上升更陡峭。这些结果证明了显式修改整个片段内多轮概率分布的必要性,也为多轮对齐优于 DPO 提供了依据

NLP——技术报告解读-Step-3.5-Flash

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:Step 3.5 Flash: Open Frontier-Level Intelligence with 11B Active Parameters, 20260211 & 20260223, StepFun Team

Paper Summary

  • 整体总结:
    • Step 3.5 Flash 是 MoE(196B-A11B),纯文本 的针对 Agentic 优化过的模型
    • 最大优点是计算效率高(也是名字中 Flash 的来源),亲测速度确实还不错
    • 交错使用 3:1 的 Sliding Window Attention(SWA)/Full Attention 和 Multi-Token Prediction(MTP-3)进行加速
    • 提出 MIS-PO 算法
      • 虽然使用了一种新的方式 MIS(Metropolis Independence Sampling)去描述 ,但本文 MIS(Metropolis Independence Sampling)的思路跟一些博客和文章中提到的 Masked Importance Sampling(MIS) 其实类似,并不是很创新
    • 不同领域专家构建 + Self-Distillation 融合不同领域专家能力
    • 在 Terminal-Bench 2.0 上达到 \(51.0%\),几乎与 GPT-5.2 xHigh 和 Gemini 3.0 Pro 等相当
    • 注:本文中,作者经过评测认为 XML 的工具模板比 JSON 更好(去年还有争论,目前似乎慢慢成为共识)
  • 架构情况简单介绍
  • 注:暂时重点对 Post-training 部分进行解读,后续有时间再补充其他部分

Post-Training

  • 用于大规模 RL 的统一后训练方案
    • 从一个统一 SFT 模型开始
    • 结合可验证的奖励信号和人类偏好反馈,实现了持续的自改进,在 MoE 模型的大规模 Off-policy 训练期间也能保持稳定性
    • 两阶段方案(类似 DeepSeek-V3.2):
      • 第一步(构建专家模型):在 SFT 的基础上,训练数学、代码、STEM、工具使用、长上下文理解、人类偏好和 Agentic Reasoning 等领域的特定领域 RL
      • 第二步(专家能力融合):使用 Self-Distillation 和可扩展 RL 将这些专门的专家蒸馏到一个通才模型中
        • 通过在目标专业化和广泛综合之间系统地交替,实现了强大的泛化能力,而不会牺牲专家级的性能

Expert Model Construction and Self-Distillation

  • SFT 阶段:采用 两阶段 SFT 流程 来构建用于后续 RL 的 Base
    • 第一阶段执行大规模的多领域 SFT,涵盖数学、代码、STEM、逻辑、通用问答、Code Agent、工具使用、Search Agent 和长上下文理解
      • 应用难度感知过滤和策略平衡来培养广泛的智能体行为
    • 第二阶段通过注入 OOD 信号 来明确最大化推理密度,这些信号包括约 3 万条专家级化学轨迹和合成算术任务
      • 这种对独特推理模式的有针对性暴露,仅在三个 Epoch 内就解锁了潜在能力,使模型具备了初始化后续特定领域 RL 阶段所需的复杂结构复杂性
  • 特定领域 RL 阶段:针对特定领域分别进行 RL
  • 专家能力合并:将不同的专家能力整合到一个统一的 Student 模型中,该模型从 Mid-train 检查点初始化 (注意:不是从 SFT 后的 ckpt 初始化)
    • SFT 数据构造:
      • 使用专家模型,在上面特定领域的 SFT 阶段使用的相同 Prompt 数据,重新 Rollout 得到 SFT 样本
        • 给直接 RL 集成提供了一种更稳定、更高效的替代方案
    • 在构造数据时,采用拒绝采样来消除不良模式,如语言混杂或过度思考,从而将专家知识集中到单个 Student 模型中
    • 通过建立这个高质量的 Base 模型,自蒸馏显著减轻了后续 RL 阶段的优化负担
Hyper-Parameters
  • 采用 Muon 优化器,进行 3% 的 warmup,并采用余弦衰减,学习率从 \(1.0 \times 10^{-5}\) 衰减到 \(5.0 \times 10^{-6}\)
  • 冻结 MoE 路由器的权重,并像 Mid-training 一样禁用 EP 组平衡损失
  • SFT 训练以 0.1 的 MTP 损失权重、32 的 Global batch size 和 128k 的 Global sequence length 进行
  • 关于 ROPE,作者保持 \(\theta_{SWA} = 10,000\),并调整 \(\theta_{Full} = 5,000,000\) 以适应 128k 的上下文长度 (2023)

Scalable RL

  • LLM RL 的目标是优化策略 \(\pi_{\theta}\) 以最大化轨迹 \(\tau = (s_0, a_0, \ldots , s_T)\) 上的终端奖励
    • 其中 \(a_t\) 表示在状态 \(s_t\) 生成的 Token
  • 对于推理任务,由于极高的时间跨度和模型规模进一步放大了由高梯度方差引起的严重不稳定性(图 5 (2))
  • 这种方差主要源于高吞吐量推理引擎和训练框架之间的基础设施差异,以及迭代更新固有的 Off-policy 偏差
  • 在这种设置下,Importance Sampling 本质上是非稳定的,因为微小的 Token-level 概率偏移会累积成噪声梯度,阻碍收敛
MIS-Filtered Policy Optimization, MIS-PO
  • 注意:这里的 MIS 名字来源不是之前其他文章中的 Masked IS
  • 本文提出了 MIS-PO,一种受 Metropolis 独立性采样 (Metropolis Independence Sampling, MIS) (1953; 1970) 启发的方法
    • 注:Metropolis Independence Sampling 是一种 MCMC 采样方法,详情见附录
    • 将推理策略视为 Proposal distribution,训练策略视为目标分布,并将更新限制在那些与目标分布保持足够接近的样本上
    • 与通过有界比率缩放梯度且通常遭受高方差影响的重要性采样不同,MIS-PO 应用二元掩码来过滤 Off-distribution 的样本,并将保留的轨迹视为有效的 On-policy ,从而显著降低梯度方差并实现稳定的优化
  • 定义一个二元指示函数并将其应用于两个不同的粒度级别
    $$\mathbb{I}(x) = \mathbb{1}[\rho_{\min}\leq x\leq \rho_{\max}]$$
    • 在 Token-level ,该函数过滤概率比率
      $$x_{t} = \frac{\pi_{\theta_{\text{old} } }(a_{t}|s_{t})}{\pi_{\theta_{\text{vllm} } }(a_{t}|s_{t})}$$
      • 以抑制训练策略和推理策略之间的局部不匹配 (2025)
    • 在 Trajectory-level ,将相同的指示函数应用于几何平均比率
      $$\bar{\rho} (\tau) = (\prod_{t}x_{t})^{\frac{1}{T} }$$
      • 有效地丢弃那些已显著偏离目标分布的整个轨迹,这个思想类似 GSPO
    • 重新制定的 Actor 损失用这些双级别离散掩码替代了连续的重要性权重:
      $$\mathcal{L}_{actor} = -\mathbb{E}_{\tau \sim \pi_{\theta_{\text{vllm} } } }\left[\mathbb{I}(x_t)\cdot \mathbb{I}(\bar{\rho} (\tau))\cdot \log \pi_\theta (a_t|s_t)\cdot \hat{A}_t\right]. \tag{2}$$
      • 理解:
        • 上述重要性权重有两级过滤(理论上稳定性是最高的)
        • 可以看到,该目标函数将有效样本视为 On-policy (重要性采样的输出是一个二元的值,有梯度回传的 Token 也没有重要性采样比例修正)
          • 论文中提到这在信任区域约束下显著降低了长时程推理任务的梯度方差
        • 问题:对于 off-policy 的场景,即使已经做过 MIS 过滤了,是否也应该继续保留重要性权重以恢复采样差异呢?(至少在数学上保证准确)
      • 图 5 展示了一个大约 5,000 个训练步骤的消融研究,其中 MIS-PO 在 Actor 梯度范数上表现出比 PPO 显著更低的噪声,表明了其改进的可扩展性
      • 更多消融研究见附录 D.2.3
  • 为进一步稳定训练动态,采用了多种技术:
    • 截断感知的价值引导 (Truncation-Aware Value Bootstrapping) (2026) :纠正由上下文长度截断引入的乐观奖励偏差
    • 路由置信度 (Routing Confidence) 监控:预测特定于 MoE 架构的不稳定性
Truncation-Aware Value Bootstrapping
  • 为截断的上下文轨迹分配零奖励会将截断与任务失败混为一谈
    • 这种模糊性通过未能区分不完整和错误的结果来惩罚长链推理
  • 为了解决这个问题,本文的做法是区分两者
    • 将零奖励替换为对最终状态的价值引导估计 (bootstrapped value estimate)
    • 将截断视为视野中断而不是终端失败
    • 轨迹 \(\tau_{i}\) 的修改后奖励定义为:
      $$\hat{R}_{i} = \left\{ \begin{array}{ll}V_{\phi}(s_{T}) & \text{if the response is truncated,}\\ R_{i} & \text{otherwise.} \end{array} \right. \tag{3}$$
      • 理解:可以看到
        • 发生截断时,其奖励估计是截断状态的估计值,即 \(V_{\phi}(s_{T})\)
        • 发生失败时,其分数正常表示为 \(R_i\),可能为 0 或 -1 等
  • 作者的实践经验表明
    • 这种截断感知的价值引导在截断率高达 20% 时也能稳定训练,防止了通常由不完整轨迹 (2025; 2025) 引发的奖励退化
  • 消融研究证实,该技术对 Competition-level 基准测试特别有益,因为长时程推理使截断效应最为普遍
Routing Confidence as a Stability Proxy
  • 一些之前的研究将 RL 稳定性与 MoE 路由一致性联系起来
  • 在此基础上,作者提出将路由置信度 \((\Sigma_{k})\) 作为稳定性的代理
    • 即激活专家的平均概率质量
      • 理解:被激活专家的平均概率(每个专家都有一个概率)
    • 较低的 \(\Sigma_{k}\) 意味着较高的路由不确定性,这会放大训练-推理不匹配
      • 理解:较低的 \(\Sigma_{k}\) 意味着被激活的专家并不是概率非常大的,不确定性大(类似熵较高)
  • 通过初步实验,确定了一个明显的稳定性相变:
    • 路由置信度低的模型是脆弱的,需要极端的稳定化措施
      • 极端稳定化措施如:路由器重放 (Router Replay) (2025) 和 严格的 On-policy 更新 等
    • 路由置信度高的模型保持鲁棒性,能够在没有复杂干预的情况下进行 Off-policy 训练
RL Training Dynamics
  • 本文在图 6 中展示了 Step 3.5 Flash 的 RLVR 训练动态和下游评估改进
  • 训练奖励的稳步上升表明了一个稳定且有效的学习过程
  • Step 3.5 Flash 在不同的评估基准上实现了一致的性能提升
    • IMO-AnswerBench (2025) 上显著提升了 \(+3.2%\)
    • CF-Div2-Stepfun-cpp(附录 E.2.1:作者定制的 CodeForces Div.2 基准)上提升了 \(+6.1%\)
    • ARC-AGI-1 (2019) 上提升了 \(+10.6%\)
    • \(\text{HLE}_{\text{text} }\) (2025) 上提升了 \(+3.4%\)
Reward System
  • 将 RL 框架解耦为 with Verifiable Rewards RL (RLVR) 和 with Non-verifiable Rewards RL(如 RLHF)
    • 每个部分都由一个根据其监督特性量身定制的独特奖励支持
Verifiable Rewards
  • 对于 RLVR,每个 Prompt 都与一个特定于任务的验证器配对,该验证器输出一个奖励
    • 基于规则的检查器用于逻辑、指令遵循和代码任务
    • 基于模型的验证器则用于 STEM 任务
  • 在作者内部模型上进行的 450 步 RL 训练消融研究中
    • 对 STEM 任务使用基于模型的验证器比直接使用普通的数学验证器平均高出 \(2.0%\)
  • 更多细节见附录 D.2.2
Non-Verifiable Reward
  • 使用 Pairwise GRM (2025) 来处理不可验证的任务,该模型根据固定参考来评估响应
    • 理解:这里使用的是 Pairwise 的 GRM,应该是输入一个 Response 和 一个 Reference Response,然后融 GRM 评估 Response 的胜率是多少
  • GenRM 是一个推理模型,输出一个置信度分数 ,指示一个 Response 获胜的可能性
    • 该分数随后被转换为 Bradley-Terry 胜率 (1952),作为奖励信号
  • Length Control 在 GenRM 中作为置信度分数惩罚进行建模 ,并传播到胜率奖励中,从而有效地抑制 RL 训练期间过度的长度增长
    • 理解:这里应该是 Prompt GRM 评估 Response 胜率时也考虑 Length Control
  • 通过对捏造的引用、过度自信的主张或语言不一致的 Response 分配零奖励来进一步确保鲁棒性
    • 问题:这里是通过 GRM 直接判断的吗?
Agent Reward(待讨论)
  • 搜索任务使用基于实体匹配分数的 LLM 进行评估
  • 对于报告生成,基于 Rubric 的 LLM judge 评估研究查询、Rubric 规范和候选报告,产生三元判断(满意、部分满意、不满意)(2025)
  • 由于中间类别经常与专家偏好不一致,本文将输出映射到非对称的二元奖励,从而产生更清晰的学习信号,并更快地收敛到符合专家行为的行为
    • 问题:如何理解这里的 非对称的二元奖励
GenRM Training and MetaRM
  • GenRM 初始化:使用 RM-Specific Prompts 微调本文的 SFT 模型来初始化
  • GenRM RL 训练:使用精心挑选的 Pairwise 偏好数据,并采用类似于标量奖励模型公式的 logsigmoid 损失
  • For 提高 GenRM 的鲁棒性
    • 集成 MetaRM(一个额外的验证器)来惩罚表现出虚假推理 (即, 从有缺陷的逻辑中得出正确的偏好 )的 Response
    • 当检测到此类存在问题的模式时,MetaRM 会降低奖励(避免模型陷入中间过程错误/结果正确的 Reward Hacking 模式)
  • 在作者的内部模型上进行的 200 步 RL 训练消融研究中,MetaRM-augmented GenRM 在每个基准测试上都比普通 GenRM 高出 \(0.5% - 3%\)
    • 注意:每个基准上都高
Hyper-Parameters
  • 对于 Rollout
    • 采样温度和 top-\(p\) 都设置为 1.0
    • 最大序列长度为 128k 个 Token
    • 每次生成
      • 对推理任务:采样 256 个唯一的 Prompt,每个 Prompt 生成 16 个 Response
      • 对人类偏好任务:采样 512 个唯一的 Prompt,每个 Prompt 生成 8 个 Response
      • 对工具使用任务:采样 128 个唯一的 Prompt,每个 Prompt 生成 8 个 Response
  • Rollout 后,完成的样本被划分成 mini-batch,并用于训练一个 Epoch
    • 其中 Actor 使用 4 个 mini-batch,Critic 使用 12 个 mini-batch
    • 优化使用 Muon 优化器进行,权重衰减为 0.1
  • Actor 的学习率为 \(2\times 10^{-6}\), Warmup 步数为 20
  • Critic 的学习率为 \(5\times 10^{-6}\),Warmup 步数为 50
  • 遵循 ORZ (2025),将 \(\gamma\) 和 \(\lambda\) 都设置为 1
  • 在最后阶段进一步采用了无偏的 KL 损失 (2025),系数为 0.001
  • 对于公式 (2),Token-level 和 Trajectory-level 的掩码边界分别设置为 [0.5, 2] 和 [0.996, 1.001]
    • 理解:可以证明 Trajectory-level 下 几何平均的 Mask 权重 方差非常小(类似 GSPO 一样,需要设置非常小的过滤边界才能生效),其重要性采样系数几乎在 1 附近,而 Token-level 的权重则波动较大

Data Synthesis & Curation

  • 整体说明:融合开源数据、合成数据和用户轨迹,构建了一个多样且难度均衡的 Prompt 池
    • 使用统一的合成和策展流程,结合严格的全局过滤和特定领域的优化,以最大化推理密度
    • 使用基于规则的启发式方法和基于模型的保真度检查相结合来确保数据质量
  • 生成的数据集包含 871K 个样本(7.23B Token),详细的统计数据总结在表 3 中
General and Reasoning
  • 本文的训练语料库聚合了来自不同开源社区的 Prompt、专家 Response 和合成数据
    • 包括数学 (2025; 注:这里的开源数据集非常多)、编码 (2023; 2025; 2025) 以及科学与开放式问答 (2023; 2025; 2025; 2024)
  • 为了最大化推理密度,本文采用了一个统一的流程,将严格的全局过滤与特定领域的优化相结合,通过基于规则的启发式方法和基于模型的保真度检查相结合来强制保证质量
    • 对数学,通过专家引导的拒绝采样和合成的大数算术来确保数值稳定性
    • 对编程,通过选择严谨的算法挑战来优先考虑离线可执行性,同时严格清除与 RAG 相关的幻觉
      • 问题:如何理解这里与 RAG 相关的幻觉?
    • 特别地, 本文减少了模型错误地声称可以访问外部搜索引擎或假装检索在线解决方案的倾向
    • 此外,本文将科学数据限制在具有唯一、可确定答案的明确问题上
  • 为了能够泛化到实际场景,本文还扩展了开源检查器,并使用一些现实世界的约束来增强样本
  • 同时,作者从开源、合成和用户轨迹中收集通用 Prompt,形成一个多样且难度均衡的池
    • 这个过程产生了一个包含数百万样本、达到十亿 Token-level 的高保真数据集
Generalized Tool Learning
  • 本文提出了一个执行驱动的数据生成框架,用于在智能体中学习可靠的工具使用行为,解决了现有合成流程中的关键限制,如数据不一致、缺乏可验证性和模型幻觉
  • 作者的方法不依赖于随机探索 (2025; 2025) 或基于模型的模拟 (2025; 2025),而是将工具使用行为分解为原子意图 (atomic intents),并使用有限状态机 (finite state machine, FSM) 对它们进行建模
    • 明确地将抽象的工具调用逻辑与参数化的执行约束分离开来
  • 数据通过一个 “采样-执行-验证” 循环与拒绝采样生成,所有候选轨迹都在真实环境中执行 ,并通过确定性反馈进行验证 ,从而确保保真度并消除幻觉行为
  • 通过组合方式组合原子意图,该框架支持复杂、可控的工具使用场景的可扩展生成
  • 使用此范式,本文构建了超过 10 万条高质量轨迹 ,总计数十亿个 Token,为基于工具的计划、推理和执行提供了精确的监督
Code Agents
  • Code Agent 可以通过可验证环境构建和解决方案生成之间的闭环干预来自我改进,其中可执行的反馈不断优化这两种能力
    • 本文作者将环境构建视为与错误修复和功能实现同等重要的第一类能力 (first-class capability),并在可验证的奖励信号下对其进行合成
  • 为此,本文作者开发了一个专门的智能体流程,该流程从 SWE-factory (2026) 框架演变而来
    • 结合了一个跨任务记忆池,用于检索历史构建成功案例作为 Few-shot 示例,以及一个循环检测机制,以防止冗余探索
    • 该流程实现了 40% 的环境构建成功率,通过来自构建轨迹(包括 Shell 命令和错误恢复)的密集监督,形成了一个模型自我进化的正反馈循环
  • 为了进一步提高信号质量,本文通过抽象和屏蔽那些对最终解决没有贡献的瞬态故障和冗余执行模式来规范化环境构建轨迹
    • 得到的环境充当动态测试平台,利用执行反馈和单元测试生成高质量的合成数据,用于持续对齐的奖励信号
  • 经验中观察到一种双向转移:
    • 构建专业知识加速了编码性能,而在这些环境中编码进一步提高了构建准确性,如 DockSmith (2026) 所示
    • 利用这个进化流程,作者精心挑选了 5 万个经过验证的环境,涵盖了超过 1.5 万个 GitHub 仓库和 20 多种编程语言
    • 这个多样化的集合捕捉了广泛的现实世界场景,为训练通才 Code Agent 提供了坚实的基础
  • 此外,作者还整合了几个著名的开源环境,包括 SWE-smith (2025)、SWE-Gym (2024)、R2E-Gym (2025)、SWE-rebench (2025) 和 SETA (2026)
Search and Research Agents
  • 为了促进高级信息搜索,本文的流程集成了基于图和多文档合成 的方法来强制实现多跳推理
    • 通过在知识图谱(例如,Wikidata5m (2021))上执行拓扑扩展并模拟跨网站浏览轨迹,生成了反映现实世界研究复杂性的数据
  • 为了保证外部检索的必要性,针对 DeepSeek-R1 (2025) 验证生成的查询,系统地排除 了这个强大的推理模型无需工具交互即可解决的实例
  • 生成的轨迹通过结构化的报告生成流程 (2025) 进行精炼
    • 该流程强制执行严格的指令遵循和结构完整性
    • 具体方法:
      • 强制执行对预设研究计划的严格遵守,丢弃任何偏离该结构的轨迹
      • 随后有效的输出通过基于模型的评判器和启发式规则进行迭代清洗,以解决诸如非正式写作、时间幻觉和混合语言等细微问题
  • 这种端到端的方法在 RESEARCHUBRICS (2025) 基准测试上实现了业界领先的性能

Agent Infrastructure

Reasoning with Tool-Use Template Design
  • 为了有效地将推理和智能体能力整合到一个基础模型中,确定思考过程和工具使用的适当模板至关重要
  • 关于推理模板,作者评估了三种管理策略
    • 每轮丢弃推理历史的方法 (2025) 虽然激励了独立生成,但在长时程任务(例如,超过 100 轮的编码会话)中会导致任务失败
    • 保留完整的推理历史会带来高昂的上下文消耗,这会迅速耗尽模型的容量并阻止后续的工具调用
    • 作者采用了一种选择性保留策略:仅为最近用户指令触发的工具使用轨迹保留推理痕迹
  • 这种设计实现了推理连贯性和上下文效率之间的最佳权衡,这一做法与最新的前沿模型 (2025; 2025) 一致
    • 关于工具使用模板,本文比较了流行的 JSON 和 XML 格式
      • JSON 的严格语法,包括转义序列和分隔符,经常导致小型、训练不足的模型出现解析错误
      • XML 格式允许平坦的字符串输出,语法开销显著更低
      • 因此,本文选择 XML 格式以确保在复杂、真实的智能体编码场景中的鲁棒性
Scalable Code Agent Infrastructure
  • 本文的集成架构侧重于可扩展的会话管理和跨框架泛化,以促进高吞吐量的智能体编码
    • 其核心是一个专有的 Session-Router,它通过 Kubernetes 编排容器生命周期,并通过 Tmux 确保交互一致性
  • 该架构支持数千个并发环境,具有无缝的状态持久性,无需手动配置特定于 Scaffold 的 Docker 配置
  • 为了确保跨不同智能体工作流的高度泛化,本文训练模型适应广泛的交互框架,从学术标准(例如,OpenHands (2024)、SWE-agent (2024) 和 Terminus-2 (2026))到企业级协议(例如,KiloCode (2026)、RooCode (2026) 和 ClaudeCode (2026))
  • 通过在训练期间让模型接触这些多样化的交互范式,有效地防止了它过拟合到特定的流程模式,确保无论底层执行环境如何,它都能保持鲁棒性

附录:Metropolis Independence Sampling (MIS) 介绍

  • Metropolis Independence Sampling (MIS) 是一种马尔可夫链蒙特卡罗 (Markov Chain Monte Carlo, MCMC) 方法
  • Metropolis Independence Sampling 是一种经典的 MCMC 方法,其核心在于使用独立于当前状态的提议分布进行采样
  • Step 3.5 Flash 论文中的创新之处在于,将 MIS 的思想从概率采样领域迁移到强化学习的策略优化中 ,用二元过滤替代了传统的重要性采样权重,从而在保持有效学习信号的同时,极大地降低了梯度方差,为大规模 MoE 模型的稳定 RL 训练提供了新的有效范式
    • 思路跟一些博客和文章中提到的 Masked IS 其实类似,并不是很创新

Metropolis Independence Sampling 方法的核心思想

  • Metropolis Independence Sampling 的核心目标是从一个复杂的目标分布 \( \pi(x) \) 中采样,当直接采样困难时,通过构建一个马尔可夫链来间接生成服从该分布的样本
  • Metropolis Independence Sampling 的基本框架是:
    • 1)提议分布 (Proposal Distribution) \( q(x) \):一个易于采样的分布,用于生成候选样本
    • 2)接受-拒绝机制 (Acceptance-Rejection Mechanism) :根据 Metropolis-Hastings 准则决定是否接受候选样本
  • 注:独立性采样 意味着提议分布 \( q(x) \) 不依赖于当前状态 \( x^{(t)} \),即 \( q(x’ | x^{(t)}) = q(x’) \)
    • 这与随机游走 Metropolis (Random Walk Metropolis) 形成对比,随机游走的提议分布依赖于当前位置

Metropolis-Hastings 接受概率

  • 对于 MIS,接受概率简化为:
    $$
    \alpha(x^{(t)}, x’) = \min\left(1, \frac{\pi(x’) / q(x’)}{\pi(x^{(t)}) / q(x^{(t)})}\right)
    $$
    • \( \pi(x) \) 是目标分布(未归一化的概率密度)
    • \( q(x) \) 是提议分布

论文中 MIS-PO 的创新应用

  • 在 Step 3.5 Flash 的 RL 框架中,作者将 MIS 的思想创造性地应用于策略优化,提出了 MIS-Filtered Policy Optimization (MIS-PO)
问题背景
  • 在 RL 训练中,由于推理引擎和训练框架的分离,以及迭代更新的特性,会产生训练-推理策略不一致的问题
  • 这导致传统的重要性采样 (Importance Sampling) 方法出现高方差和不稳定性,尤其是在 MoE 模型的长程推理任务中
MIS-PO 的核心创新
  • 论文将 MIS 的思想进行了以下转化(MCMC 概念 -> MIS-PO 中的对应):
    • 目标分布 \( \pi(x) \) -> 训练策略 \( \pi_{\theta} \)
    • 提议分布 \( q(x) \) -> 推理策略 \( \pi_{\theta_{\text{vllm} } } \)(用于生成 Rollout 数据)
    • 接受概率 \( \alpha \) -> 二元过滤函数 \( \mathbb{I}(x) \)
双重过滤机制
  • MIS-PO 引入了两个层级的过滤(同时过滤):
    • Token-level Filtering :
      $$
      x_t = \frac{\pi_{\theta_{\text{old} } }(a_t|s_t)}{\pi_{\theta_{\text{vllm} } }(a_t|s_t)}
      $$
      • 通过阈值 \([\rho_{\min}, \rho_{\max}]\) 过滤单个令牌的概率比,抑制局部的分布不匹配
    • Trajectory-level Filtering :
      $$
      \bar{\rho}(\tau) = \left(\prod_{t=0}^{T-1} x_t\right)^{\frac{1}{T} }
      $$
      • 计算几何平均概率比,丢弃整个偏离目标分布过远的轨迹(类似 GSPO)
  • 最终的 Actor 损失函数
    $$
    \mathcal{L}_{\text{actor} } = -\mathbb{E}_{\tau \sim \pi_{\theta_{\text{vllm} } } }\left[\mathbb{I}(x_t)\cdot \mathbb{I}(\bar{\rho}(\tau))\cdot \log \pi_\theta (a_t|s_t)\cdot \hat{A}_t\right]
    $$
    • 其中 \( \mathbb{I}(x) = \mathbb{1}[\rho_{\min} \leq x \leq \rho_{\max}] \) 是二元指示函数

MIS-PO 的优势

  • 论文通过实验验证了 MIS-PO 相对于传统方法(如 PPO 和 GSPO)的显著优势:
    • 1)更高的样本效率 :收敛速度更快,能够达到更高的奖励平台
    • 2)更强的稳定性 :有效抑制了策略梯度范数的剧烈波动,消除了训练过程中的大幅尖峰
    • 3)更好的探索-利用平衡 :熵值衰减更慢,保持了更持久的探索能力
    • 4)对 MoE 架构更友好 :有效控制了训练-推理的分布漂移,解决了 MoE 模型在离策略 RL 训练中的关键稳定性问题

NLP——旋转位置编码-RoPE

  • 参考链接:
    • 原始论文:(RoPE) RoFormer: Enhanced Transformer with Rotary Position Embedding, Arxiv 2023 & Neurocomputing 2024, 追一科技
    • 苏神博客:Transformer升级之路:2、博采众长的旋转式位置编码
    • 旋转式位置编码 (RoPE) 知识总结 - Soaring的文章 - 知乎,一篇把知识串的比较好的博客

原始 Transformer

基本 Attention 公式

  • 在标准的Transformer模型中,自注意力机制(Self-Attention)的公式是核心组成部分
  • 给定查询矩阵 \( Q \)、键矩阵 \( K \) 和值矩阵 \( V \),注意力输出计算为:
    $$
    \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k} }\right)V
    $$
    • \( Q \in \mathbb{R}^{n \times d_k} \), \( K \in \mathbb{R}^{m \times d_k} \), \( V \in \mathbb{R}^{m \times d_v} \)(\( n \)是目标序列长度,\( m \)是源序列长度)
    • \( d_k \) 是键/查询向量的维度
    • \( \sqrt{d_k} \) 用于缩放点积,防止梯度消失

Multi-Head Attention

  • Transformer使用多头注意力扩展基本注意力:
    $$
    \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O
    $$
  • 每个头的计算为:
    $$
    \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
    $$
    • \( W_i^Q \in \mathbb{R}^{d_{\text{model} } \times d_k} \), \( W_i^K \in \mathbb{R}^{d_{\text{model} } \times d_k} \), \( W_i^V \in \mathbb{R}^{d_{\text{model} } \times d_v} \)
    • \( W^O \in \mathbb{R}^{hd_v \times d_{\text{model} } } \) 是输出投影矩阵
    • \( h \) 是头的数量,通常满足 \( d_k = d_v = \frac{d_{\text{model} } }{h} \)

加入位置编码(仅修改输入即可)

  • 在Transformer中,输入会加上正弦位置编码 \( P \in \mathbb{R}^{d_{\text{model}}} \):
    $$
    X = \text{Embedding}(x) + P
    $$
    • 其中 \( P \in \mathbb{R}^{d_{\text{model}}} \) 的每个元素为:
      $$
      P_{pos, 2i} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model} } } }\right), \quad
      P_{pos, 2i+1} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model} } } }\right)
      $$

Self-Attention 完整公式(以单头为例)

  • 对于一个输入序列 \( X \in \mathbb{R}^{n \times d_{\text{model} } } \):
    $$
    \begin{aligned}
    Q &= XW^Q, \quad K = XW^K, \quad V = XW^V \\
    \text{Attention}(X) &= \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k} } + M\right)V
    \end{aligned}
    $$
    • \( M \) 是可选的掩码矩阵(如解码器的因果掩码)
    • 注:实际实现时,也可以直接在进入 Softmax 操作前,将 \(\frac{QK^\top}{\sqrt{d_k} }\) 的结果置为最小值 \(-e^9\),效果是等价的

Self-Attention 简单实现

  • Self-Attention的Python代码简单实现
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    # 单头Attention
    class ScaledDotProductAttention(nn.Module):
    def__init__(self, d_k):
    super().__init__()
    self.d_k = d_k

    def forward(self, Q, K, V, mask=None):
    # Q, K, V shape: (batch_size, seq_len, d_k(or d_v))
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
    if mask is not None:
    scores = scores.masked_fill(mask == 0, -1e9)
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)
    return output
    # 多头Attention
    class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
    super().__init__()
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads

    self.W_q = nn.Linear(d_model, d_model)
    self.W_k = nn.Linear(d_model, d_model)
    self.W_v = nn.Linear(d_model, d_model)
    self.W_o = nn.Linear(d_model, d_model)
    # 使用单头Attention
    self.attention = ScaledDotProductAttention(self.d_k)

    def split_heads(self, x):
    """
    x shape: (batch_size, seq_len, d_model)
    return shape: (batch_size, num_heads, seq_len, d_k)
    """
    batch_size, seq_len, _ = x.size()
    return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
    """
    x shape: (batch_size, num_heads, seq_len, d_k)
    return shape: (batch_size, seq_len, d_model)
    """
    batch_size, _, seq_len, _ = x.size()
    return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

    def forward(self, Q, K, V, mask=None):
    Q = self.W_q(Q)
    K = self.W_k(K)
    V = self.W_v(V)

    Q = self.split_heads(Q)
    K = self.split_heads(K)
    V = self.split_heads(V)

    # 如果需要,扩展mask以匹配多头(这里假设了mask是为单头准备的)
    if mask is not None:
    mask = mask.unsqueeze(1) # (batch_size, 1, seq_len) -> (batch_size, 1, 1, seq_len)

    attn_output = self.attention(Q, K, V, mask)
    output = self.combine_heads(attn_output)
    output = self.W_o(output)
    return output

固定位置编码实现

  • 固定位置编码,比如正弦位置编码,直接在 Attention 之前将位置编码向量加入到原始向量 \(X\) 中,Attention代码不需要做任何修改

Rotary Position Embedding, RoPE

  • 本节符号和原始论文 (RoPE) RoFormer: Enhanced Transformer with Rotary Position Embedding, Arxiv 2023 & Neurocomputing 2024, 追一科技 符号保持一致
  • 旋转位置编码(RoPE)的核心思想 :通过旋转矩阵将位置信息融入Self-Attention 机制中
  • 基本定义 :对于位置\( m \)的词向量\( \boldsymbol{x}_m \in \mathbb{R}^d \),通过线性变换得到查询向量\( \boldsymbol{q}_m \)和键向量\( \boldsymbol{k}_n \):
    $$
    \boldsymbol{q}_m = W_q \boldsymbol{x}_m, \quad \boldsymbol{k}_n = W_k \boldsymbol{x}_n
    $$
  • 旋转操作 :将\( \boldsymbol{q}_m \)和\( \boldsymbol{k}_n \)划分为\( d/2 \)个复数对(每组2维,RoPE要求维度必须是偶数,这一般都能满足),对第\( i \)组复数应用旋转矩阵:
    $$
    \begin{aligned}
    \boldsymbol{q}_m^{(i)} &= \begin{pmatrix}
    q_{m,2i} \\
    q_{m,2i+1}
    \end{pmatrix}, \quad
    \boldsymbol{k}_n^{(i)} = \begin{pmatrix}
    k_{n,2i} \\
    k_{n,2i+1}
    \end{pmatrix} \\
    R_{\theta_i}^m &= \begin{pmatrix}
    \cos m\theta_i & -\sin m\theta_i \\
    \sin m\theta_i & \cos m\theta_i
    \end{pmatrix}, \quad \theta_i = 10000^{-2i/d}
    \end{aligned}
    $$
    • 注意:位置为 \(m\) 的旋转矩阵对应正余弦角度为 \(\color{red}{m}\theta_i\)
    • 理解:旋转矩阵 \(R_{\theta_i}^m\) 可以将目标向量进行旋转,\(R_{\theta_i}^m \boldsymbol{x}\) 相当于将 \( \boldsymbol{x}\) 向逆时针方向旋转 \(m\theta_i\) 度(注意:只是旋转,并不修改原始向量的模长,因为 \(R_{\theta_i}^m\) 是正交矩阵),详情见附录
  • 旋转后的向量 :旋转后的查询和键向量为:
    $$
    \begin{aligned}
    \boldsymbol{q}_m’ = \bigoplus_{i=0}^{d/2-1} R_{\theta_i}^m \boldsymbol{q}_m^{(i)}, \quad
    \boldsymbol{k}_n’ = \bigoplus_{i=0}^{d/2-1} R_{\theta_i}^n \boldsymbol{k}_n^{(i)}
    \end{aligned}
    $$
    • 其中\( \oplus \)表示向量拼接:
      $$\bigoplus_{i=0}^{d/2-1} R_{\theta_i}^m \boldsymbol{q}_m^{(i)} = \text{Concat}(\{ R_{\theta_i}^m \boldsymbol{q}_m^{(i)}\}_{i=0}^{d/2-1})$$
  • 旋转后的Attention权重变化
    $$
    \begin{equation}
    (\boldsymbol{\mathcal{R}}_m \boldsymbol{q}_m)^{\top}(\boldsymbol{\mathcal{R}}_n \boldsymbol{k}_n) = \boldsymbol{q}_m^{\top} \boldsymbol{\mathcal{R}}_m^{\top}\boldsymbol{\mathcal{R}}_n \boldsymbol{k}_n = \boldsymbol{q}_m^{\top} \boldsymbol{\mathcal{R}}_{n-m} \boldsymbol{k}_n
    \end{equation}
    $$
    • 位置为 \(m\) 的向量 \(\boldsymbol{q}_m\) 乘以矩阵 \(\boldsymbol{\mathcal{R}}_m\);位置为 \(n\) 的向量 \(\boldsymbol{k}_n\) 乘以矩阵 \(\boldsymbol{\mathcal{R}}_n\)(注意角标)
    • 上面的式子中等式是恒成立的(\( \boldsymbol{\mathcal{R}}_m^{\top}\boldsymbol{\mathcal{R}}_n = \boldsymbol{\mathcal{R}}_{\color{red}{n-m}}\)的详细证明见附录),右边的 \(\boldsymbol{\mathcal{R}}_{\color{red}{n-m}}\)仅与相对位置 \(n-m\) 有关,体现了相对位置编码的核心要义
      • 注:\(\boldsymbol{\mathcal{R}}_{m-n}\) 和 \(\boldsymbol{\mathcal{R}}_{n-m}\) 不相等,旋转角度相同,但方向相反
  • 展开成矩阵相乘的形式为(refer to Transformer升级之路:2、博采众长的旋转式位置编码):
    $$
    \begin{equation}\scriptsize{\underbrace{\begin{pmatrix}
    \cos m\theta_0 & -\sin m\theta_0 & 0 & 0 & \cdots & 0 & 0 \\
    \sin m\theta_0 & \cos m\theta_0 & 0 & 0 & \cdots & 0 & 0 \\
    0 & 0 & \cos m\theta_1 & -\sin m\theta_1 & \cdots & 0 & 0 \\
    0 & 0 & \sin m\theta_1 & \cos m\theta_1 & \cdots & 0 & 0 \\
    \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
    0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2-1} & -\sin m\theta_{d/2-1} \\
    0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2-1} & \cos m\theta_{d/2-1} \\
    \end{pmatrix}}_{\boldsymbol{\mathcal{R}}_m} \begin{pmatrix}q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1}\end{pmatrix}}\end{equation}
    $$
  • 由于旋转矩阵是一个稀疏矩阵,所以旋转过程可以改进为如下等价实现:
    $$
    \begin{equation}\begin{pmatrix}q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1}
    \end{pmatrix}\otimes\begin{pmatrix}\cos m\theta_0 \\ \cos m\theta_0 \\ \cos m\theta_1 \\ \cos m\theta_1 \\ \vdots \\ \cos m\theta_{d/2-1} \\ \cos m\theta_{d/2-1}
    \end{pmatrix} + \begin{pmatrix}-q_1 \\ q_0 \\ -q_3 \\ q_2 \\ \vdots \\ -q_{d-1} \\ q_{d-2}
    \end{pmatrix}\otimes\begin{pmatrix}\sin m\theta_0 \\ \sin m\theta_0 \\ \sin m\theta_1 \\ \sin m\theta_1 \\ \vdots \\ \sin m\theta_{d/2-1} \\ \sin m\theta_{d/2-1}
    \end{pmatrix}\end{equation}
    $$
    • 其中 \(\otimes\) 是按位相乘
  • RoPE下的Attention公式总结 :(旋转位置编码的核心公式)
    $$
    \begin{aligned}
    \text{Attention}(\boldsymbol{x}) &= \text{softmax}\left(\frac{(\boldsymbol{q}’)^\top \boldsymbol{k}’}{\sqrt{d} }\right)V
    \end{aligned}
    $$
    • 注:\(V\) 是 Attention 中的 Value 矩阵,中不需要位置编码信息
    • 这里使用 \((\boldsymbol{q}’)^\top \boldsymbol{k}’\),转置在 \(\boldsymbol{q}’\) 上,和原始论文表达方式一致,实际上这种表示是OK的,数学中常用这种表示 ,这种表示下,向量为列向量;原始 Transformer 论文中的符号转置在 Key 上,此时向量为行向量
  • 原始论文中的RoPE示意图:

多头注意力下的 RoPE

  • 为了跟传统的 Transformer 符号对齐,本节改用 \(Q,K,V\)表示矩阵,与 RoPE 原始论文符号不再一致
  • 给定输入序列 \( X \in \mathbb{R}^{n \times d_{\text{model} } } \),先投影到查询、键、值空间:
    $$
    Q = XW^Q, \quad K = XW^K, \quad V = XW^V
    $$
    • 其中 \( W^Q, W^K \in \mathbb{R}^{d_{\text{model} } \times d_k} \), \( W^V \in \mathbb{R}^{d_{\text{model} } \times d_v} \)
  • 应用旋转位置编码(RoPE) :对 \( Q \) 和 \( K \) 的每个位置 \( m \) 和 \( n \) 的分量应用旋转矩阵 \( R_{\theta}^m \) 和 \( R_{\theta}^n \):
    $$
    \begin{aligned}
    Q’ &= \text{RoPE}(Q) = \bigoplus_{i=0}^{d_k/2-1} R_{\theta_i}^m Q^{(i)} \\
    K’ &= \text{RoPE}(K) = \bigoplus_{i=0}^{d_k/2-1} R_{\theta_i}^n K^{(i)}
    \end{aligned}
    $$
    • \( Q^{(i)} \in \mathbb{R}^2 \) 和 \( K^{(i)} \in \mathbb{R}^2 \) 是 \( Q \) 和 \( K \) 的第 \( i \) 个二维分量
    • \( \oplus \)表示向量拼接:
      $$\bigoplus_{i=0}^{d_k/2-1} R_{\theta_i}^m Q^{(i)} = \text{Concat}(\{R_{\theta_i}^m Q^{(i)}\}_{i=0}^{d_k/2-1} )$$
  • 旋转矩阵 \( R_{\theta_i}^m \) 定义为:
    $$
    R_{\theta_i}^m = \begin{pmatrix}
    \cos m\theta_i & -\sin m\theta_i \\
    \sin m\theta_i & \cos m\theta_i
    \end{pmatrix}, \quad \theta_i = 10000^{-2i/d_k}
    $$
  • 多头注意力输出
    $$
    \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_H) W^O \\
    \text{head}_h = \text{Softmax}\left(\frac{Q^{\prime}_h {K^{\prime}_h}^\top}{\sqrt{d_k} }\right) V_h
    $$
    • 其中 \( W^O \in \mathbb{R}^{H d_v \times d_{\text{model} } } \) 是输出投影矩阵
  • 多头注意力下的 RoPE 的 Attention 公式总结 :
    $$
    \begin{aligned}
    \text{Attention}(Q,K,V) &= \text{softmax}\left(\frac{(\text{RoPE}(Q))(\text{RoPE}(K))^\top}{\sqrt{d_k} }\right)V \\
    \text{where} \quad \text{RoPE}(X) &= \bigoplus_{i=0}^{d/2-1} R_{\theta_i}^m X^{(i)}
    \end{aligned}
    $$

多头注意力下的 RoPE 实现

  • 多头注意力下的RoPE PyTorch 实现 (待补充)
    • 注:多头注意力下,每个头是独立编码的(每个头维度从 0 开始),且使用的旋转矩阵一样,即旋转矩阵的维度 \(d = d_{\text{head}} = d_{\text{model}}/h\)

关于 RoPE 的一些讨论

  • RoPE 传统 Transformer 的区别 :
    • 传统Transformer:位置编码是加性的(\( X + P \))
    • RoPE:位置编码是乘性的(通过旋转矩阵直接修改 \( Q \) 和 \( K \))
  • 相对位置保持性 :
    • 旋转后的注意力分数 \( [Q’ K’^\top]_{m,n} \) 仅依赖于相对位置 \( m-n \),满足线性注意力性质
  • 远程衰减性 :(注:如下图所示的远距离衰减,会导致太远的距离下,难以区分位置,效果不好,外推性变差?)

附录:相对位置旋转公式证明

  • 目标:证明 \( \boldsymbol{\mathcal{R}}_m^{\top}\boldsymbol{\mathcal{R}}_n = \boldsymbol{\mathcal{R}}_{\color{red}{n-m}}\)
  • 考虑到 \(\boldsymbol{\mathcal{R}}_m\) 是以二维子矩阵为单位的“对角”矩阵,故只要证明 \((R_{\theta_i}^m)^\top R_{\theta_i}^n = R_{\theta_i}^{\color{red}{n-m}}\) 即可,证明过程如下:
  • 给定旋转矩阵 \( R_{\theta_i}^m \) 定义为:
    $$
    R_{\theta_i}^m = \begin{pmatrix}
    \cos m\theta_i & -\sin m\theta_i \\
    \sin m\theta_i & \cos m\theta_i
    \end{pmatrix}
    $$
  • 其转置矩阵为:
    $$
    (R_{\theta_i}^m)^\top = \begin{pmatrix}
    \cos m\theta_i & \sin m\theta_i \\
    -\sin m\theta_i & \cos m\theta_i
    \end{pmatrix}
    $$
  • 计算 \((R_{\theta_i}^m)^\top R_{\theta_i}^n\)
    $$
    (R_{\theta_i}^m)^\top R_{\theta_i}^n = \begin{pmatrix}
    \cos m\theta_i & \sin m\theta_i \\
    -\sin m\theta_i & \cos m\theta_i
    \end{pmatrix}
    \begin{pmatrix}
    \cos n\theta_i & -\sin n\theta_i \\
    \sin n\theta_i & \cos n\theta_i
    \end{pmatrix}
    $$
  • 回顾三角函数和差角公式:
    $$
    \sin(A \pm B) = \sin A \cos B \pm \cos A \sin B \\
    \cos(A \pm B) = \cos A \cos B \mp \sin A \sin B
    $$
  • 计算矩阵乘积的每个元素:
    • 左上角元素:
      $$
      \begin{align}
      \cos m\theta_i \cdot \cos n\theta_i + \sin m\theta_i \cdot \sin n\theta_i &= \cos(n\theta_i - m\theta_i) \\
      &= \cos((n - m)\theta_i) \\
      \end{align}
      $$
    • 右上角元素:
      $$
      \begin{align}
      \cos m\theta_i \cdot (-\sin n\theta_i) + \sin m\theta_i \cdot \cos n\theta_i &= -\cos m\theta_i \sin n\theta_i + \sin m\theta_i \cos n\theta_i \\
      &= \sin m\theta_i \cos n\theta_i - \cos m\theta_i \sin n\theta_i\\
      & = \sin((m-n)\theta_i) \\
      & = - \sin((n-m)\theta_i)
      \end{align}
      $$
    • 左下角元素:
      $$
      \begin{align}
      -\sin m\theta_i \cdot \cos n\theta_i + \cos m\theta_i \cdot \sin n\theta_i &= -\sin m\theta_i \cos n\theta_i + \cos m\theta_i \sin n\theta_i \\
      &= \sin n\theta_i \cos m\theta_i - \cos n\theta_i \sin m\theta_i \\
      &= \sin((n - m)\theta_i)
      \end{align}
      $$
    • 右下角元素:
      $$
      \begin{align}
      -\sin m\theta_i \cdot (-\sin n\theta_i) + \cos m\theta_i \cdot \cos n\theta_i &= \sin m\theta_i \sin n\theta_i + \cos m\theta_i \cos n\theta_i \\
      &= \sin m\theta_i \sin n\theta_i + \cos m\theta_i \cos n\theta_i \\
      &= \cos((n - m)\theta_i)
      \end{align}
      $$
  • 因此,乘积矩阵为:
    $$
    (R_{\theta_i}^m)^\top R_{\theta_i}^n = \begin{pmatrix}
    \cos((n - m)\theta_i) & -\sin((n - m)\theta_i) \\
    \sin((n - m)\theta_i) & \cos((n - m)\theta_i)
    \end{pmatrix} = R_{\theta_i}^{\color{red}{n-m}}
    $$
  • 至此,我们证明了:
    $$
    (R_{\theta_i}^m)^\top R_{\theta_i}^n = R_{\theta_i}^{\color{red}{n-m}}
    $$
  • 证毕

附录:不同参数下 RoPE 对 Attention 的影响

  • RoPE原始论文已经说明了随着距离的增长,Attention Score 有越来越小的趋势(且长距离部分会波动)
  • 下面是固定 query_index 下,Attention Score 随 key_index(横轴)变化的图像(代码参考链接修改的 旋转式位置编码 (RoPE) 知识总结 - Soaring的文章 - 知乎)
  • 从上图可以观察到:
    • 图1:RoPE确实有远距离衰减趋势(震荡递减),且dim_model=256时,q,k 距离为500时Attention值已经衰减的较小了
    • 图2:与图1类似,RoPE的衰减是对称的(图2展示的是当 query_index=256时的图像);
      • 注:注意虽然 Attention Score 是对称相等的,但是旋转角度是相反的,即 \(\boldsymbol{\mathcal{R}}_{m-n}\) 和 \(\boldsymbol{\mathcal{R}}_{n-m}\) 旋转角度相同,但方向是相反的
    • 图3:RoPE 在拉到足够长的距离后,不会一直衰减(从后续的图可以知道实际上还是周期函数,只是周期很大)
    • 图4+图5:与图1对比可以发现,RoPE 在拉到足够长的距离后,实际上还是周期函数,只不过周期与 d_model 相关(d_model越大,周期越长),图4说明了当 d_model=8 时,周期是 10000 左右
    • 图6:缩小了图5的横轴区间,将图5的前半部分图像放大了看,是在小周期上震荡的,且还存在图5所示的大周期
  • 附上图的代码:
    >>>点击展开折叠内容...
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    # refer to: [旋转式位置编码 (RoPE) 知识总结 - Soaring的文章 - 知乎](https://zhuanlan.zhihu.com/p/662790439)
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.axes import Axes

    def create_sin_cos_table_cache(max_num_tokens, dim_model):
    # 所有pos下对应的cos/sin值分别存储为矩阵
    theta = 10000 ** (-np.arange(0, dim_model, 2) / dim_model)
    theta = theta.reshape(-1, 1).repeat(2, axis=1).flatten()

    pos = np.arange(0, max_num_tokens)
    table = pos.reshape(-1, 1) @ theta.reshape(1, -1) # [max_num_tokens, dim_model]

    sin_table_cache = np.sin(table)
    sin_table_cache[:, ::2] = -sin_table_cache[:, ::2]

    cos_table_cache = np.cos(table)
    return sin_table_cache, cos_table_cache

    def rotate_half(q_vec):
    # 将q_vec的值两个一组分组并在分组内对调,实现从[q_0,q_1,q_2,q_3,...,q_{d-1},q_d]到的转换[q_1,q_0,q_3,q_2,...,q_d,q_{d-1}]
    return q_vec.reshape(-1, 2)[:, ::-1].flatten()

    def rotary(vec, pos, sin_table, cos_table):
    # 原始论文中的公式
    return vec * cos_table[pos] + rotate_half(vec) * sin_table[pos]

    def plot(plt_obj: Axes, pic_index, query_index=0, dim_model=256, max_num_tokens=8192, step=1):
    # q_vec 和 k_vec 都设定为1,仅关注 RoPE 引发的Attention Score的变化
    q_vec = np.ones(dim_model)
    k_vec = np.ones(dim_model)
    sin_table, cos_table = create_sin_cos_table_cache(max_num_tokens, dim_model)

    rotated_q_vec = rotary(q_vec, query_index, sin_table, cos_table)
    k_indices = np.arange(0, max_num_tokens, step)
    rotated_k_vecs = rotary(k_vec, k_indices, sin_table, cos_table)
    attn_scores = (rotated_k_vecs @ rotated_q_vec) / np.sqrt(dim_model) # 未经过Softmax的Attention权重,用于展示RoPE对原始Attention Score

    plt_obj.plot(k_indices, attn_scores)
    plt_obj.set_title(f"Figure {pic_index}: query_index={query_index}, dim_model={dim_model}")
    plt_obj.set_xlabel("key index")
    plt_obj.set_ylabel("Attention Score")

    plt.rcParams.update({
    "font.sans-serif": ["Times New Roman", ],
    "font.size": 10
    })

    _, axes = plt.subplots(nrows=2, ncols=3, figsize=(10, 10))
    plot(axes[0, 0], 1, query_index=0, max_num_tokens=512)
    plot(axes[0, 1], 2, query_index=256, max_num_tokens=512)
    plot(axes[0, 2], 3, query_index=0, dim_model=256, max_num_tokens=65535)
    # plot(axes[1, 0], 4, query_index=0, dim_model=32, max_num_tokens=65535)
    plot(axes[1, 0], 4, query_index=0, dim_model=16, max_num_tokens=65535)
    plot(axes[1, 1], 5, query_index=0, dim_model=8, max_num_tokens=65536)
    plot(axes[1, 2], 6, query_index=0, dim_model=8, max_num_tokens=512)
    plt.show()

附录:如何改进 RoPE 以实现绝对位置编码?

  • 当前的设计仅能实现相对位置编码(因为 Query 和 Key 的内积只与他们的相对位置有关,与绝对位置无关),但如果在 Value 上也施加位置编码则能实现绝对位置编码的能力
  • 让研究人员绞尽脑汁的Transformer位置编码中苏神提到:

    这样一来,我们得到了一种融绝对位置与相对位置于一体的位置编码方案,从形式上看它有点像乘性的绝对位置编码,通过在 \(\boldsymbol{q},\boldsymbol{k}\) 中施行该位置编码,那么效果就等价于相对位置编码,而如果还需要显式的绝对位置信息,则可以同时在 \(\boldsymbol{v}\) 上也施行这种位置编码。总的来说,我们通过绝对位置的操作,可以达到绝对位置的效果,也能达到相对位置的效果


附录:旋转体现在哪里?

  • 旋转体现在 \(\boldsymbol{q}_m\)(或\(\boldsymbol{k}_n\)) 的每两个维度组成的向量 \(\boldsymbol{q}_m^{(i)} = \begin{pmatrix} q_{m,2i} \\ q_{m,2i+1} \end{pmatrix}\) 经过 RoPE 变换前后,他们的向量长度不变,即:
    $$
    \begin{pmatrix}cos(m\theta_1) &-sin(m\theta_1) \\ sin(m\theta_1) &cos(m\theta_1)\end{pmatrix} \begin{pmatrix}q_{m,2i} \\ q_{m,2i+1}\end{pmatrix} = \begin{pmatrix}q_{m,2i}\cdot cos(m\theta_1) -q_{m,2i+1}\cdot sin(m\theta_1) \\ q_{m,2i}\cdot sin(m\theta_1)+ q_{m,2i+1}\cdot cos(m\theta_1) \end{pmatrix} \\
    $$
  • 进一步,由于 \((cos(m\theta_1))^2 + (sin(m\theta_1))^2 = 1\) 有
    $$
    \begin{align}
    \left|\begin{pmatrix}q_{m,2i}\cdot cos(m\theta_1) -q_{m,2i+1}\cdot sin(m\theta_1) \\ q_{m,2i}\cdot sin(m\theta_1)+ q_{m,2i+1}\cdot cos(m\theta_1) \end{pmatrix} \right| &= \sqrt{\left(q_{m,2i}\cdot cos(m\theta_1) -q_{m,2i+1}\cdot sin(m\theta_1)\right)^2 + \left(q_{m,2i}\cdot sin(m\theta_1)+ q_{m,2i+1}\cdot cos(m\theta_1)\right)^2} \\
    &= \sqrt{(q_{m,2i})^2 + (q_{m,2i+1})^2}
    \end{align}
    $$
  • 也就是说: \(\boldsymbol{q}_m\)(或\(\boldsymbol{k}_n\))相邻两两维度在变换前后的向量长度并没有变化,是一个旋转操作

附录:可视化RoPE旋转过程

  • 旋转矩阵 \(R_{\theta_i}^m\) 的定义如下:
    $$
    \begin{aligned}
    R_{\theta_i}^m &= \begin{pmatrix}
    \cos m\theta_i & -\sin m\theta_i \\
    \sin m\theta_i & \cos m\theta_i
    \end{pmatrix}, \quad \theta_i = 10000^{-2i/d}
    \end{aligned}
    $$
  • 旋转矩阵 \(R_{\theta_i}^m\) 可以将目标向量进行旋转,\(R_{\theta_i}^m \boldsymbol{x}\) 相当于将 \( \boldsymbol{x}\) 向逆时针方向旋转 \(m\theta_i\) 度
  • 当 \(m\theta_i = \frac{\pi}{4}\) 时,其旋转可视化结果如下:
  • 实现上述旋转的代码如下
    >>>点击展开折叠内容...
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
       import numpy as np
    import matplotlib.pyplot as plt

    # 设定原始向量 x
    x = np.array([1, 0])

    # 角度转换为弧度
    angle = np.radians(45)

    # 旋转矩阵定义
    R = np.array([[np.cos(angle), -np.sin(angle)],
    [np.sin(angle), np.cos(angle)]])

    # 矩阵惩罚实现旋转向量
    x_rotated = R @ x

    plt.rcParams['figure.dpi'] = 300

    plt.figure(figsize=(6, 6))
    plt.quiver(0, 0, x[0], x[1], angles='xy', scale_units='xy', scale=1, color='b', label='Original Vector')
    plt.quiver(0, 0, x_rotated[0], x_rotated[1], angles='xy', scale_units='xy', scale=1, color='r', label='Rotated Vector')

    plt.xlim(-1.5, 1.5)
    plt.ylim(-1.5, 1.5)

    plt.grid(True)

    plt.legend()
    plt.title('Vector Rotation')
    plt.xlabel('X-axis')
    plt.ylabel('Y-axis')

    plt.savefig("./demo.png")
    plt.show()

附录:RoPE的诞生历史

  • RoPE方案来自苏神原始博客:
    • 2021年2月在让研究人员绞尽脑汁的Transformer位置编码中提出想法
    • 2023年3月在Transformer升级之路:2、博采众长的旋转式位置编码中给出详细方案和推导,同时提交论文到arXiv上
    • 原始论文:Roformer: Enhanced Transformer With Rotray Position Embedding,该论文24年发表于Neurocomputing期刊上(《Neurocomputing》是国际知名期刊,被列为中科院SCI二区top期刊,CCF-C类期刊)
    • 随后,各种开源大模型开始使用RoPE,RoPE逐渐成为大模型的标配

附录:RoPE的高维扩展

  • 论文介绍当前的是一维 RoPE,每 2 维为一组,旋转矩阵为 \(2\times 2\),在二维 RoPE 场景中,可以每 4 维为一组,旋转矩阵变成 \(4\times 4\) 即可,详情见 旋转式位置编码 (RoPE) 知识总结
  • 将四个维度作为一组:
    $$
    \boldsymbol{R}_{m_1,m_2} =
    \begin{bmatrix}
    \cos m_1 \theta & -\sin m_1 \theta & 0 & 0 \\
    \sin m_1 \theta & \cos m_1 \theta & 0 & 0 \\
    0 & 0 & \cos m_2 \theta & -\sin m_2 \theta \\
    0 & 0 & \sin m_2 \theta & \cos m_2 \theta
    \end{bmatrix}
    $$
  • 上述分组下满足:
    $$
    \mathbf{R}_{m_1,m_2}^{\top} \cdot \mathbf{R}_{n_1,n_2} = \mathbf{R}_{n_1 - m_1,n_2 - m_2}
    $$
  • 注:更高维度的可以继续扩展,比如三维的扩展为每 6 维为一组,旋转矩阵为 \(6\times 6\) 即可

附录:RoPE中的复数和旋转矩阵等价性证明

问题定义

  • 定义二维旋转矩阵和向量如下:
    $$
    \begin{align}
    R_{\theta_i}^m &= \begin{pmatrix}
    \cos m\theta_i & -\sin m\theta_i \\
    \sin m\theta_i & \cos m\theta_i
    \end{pmatrix} \\
    \boldsymbol{x} &= \begin{pmatrix} x_1 \\ x_2
    \end{pmatrix}
    \end{align}
    $$
  • 目标:证明下面的等式
    $$R_{\theta_i}^m \boldsymbol{x} = z e^{i m\theta_i}$$
    • 其中 \( z = x_1 + i x_2 \) 是向量 \(\boldsymbol{x} = \begin{pmatrix} x_1 \\ x_2 \end{pmatrix}\) 的复数形式
    • 即目标是证明:旋转矩阵 \( R_{\theta_i}^m \) 作用在 \(\boldsymbol{x}\) 上相当于将复数 \( z \) 乘以旋转因子 \( e^{i m\theta_i} \)

证明

  • 方程左边展开
    $$
    R_{\theta_i}^m \boldsymbol{x} = \begin{pmatrix}
    \cos m\theta_i & -\sin m\theta_i \\
    \sin m\theta_i & \cos m\theta_i
    \end{pmatrix}
    \begin{pmatrix} x_1 \\ x_2 \end{pmatrix} =
    \begin{pmatrix}
    x_1 \cos m\theta_i - x_2 \sin m\theta_i \\
    x_1 \sin m\theta_i + x_2 \cos m\theta_i
    \end{pmatrix}
    $$

    • 这个结果对应的复数为:
      $$
      R_{\theta_i}^m \boldsymbol{x} = (x_1 \cos m\theta_i - x_2 \sin m\theta_i) + i (x_1 \sin m\theta_i + x_2 \cos m\theta_i)
      $$
  • 等方程右边展开

    • 右边的 \(\boldsymbol{x} e^{i m\theta_i}\) 表示将复数 \( z = x_1 + i x_2 \) 乘以 \( e^{i m\theta_i} \),即:
      $$
      z e^{i m\theta_i} = (x_1 + i x_2)(\cos m\theta_i + i \sin m\theta_i)
      $$
    • 可以重新整理为:
      $$ z e^{i m\theta_i} = (x_1 \cos m\theta_i - x_2 \sin m\theta_i) + i (x_1 \sin m\theta_i + x_2 \cos m\theta_i) $$
    • 展开后与旋转矩阵作用的结果完全一致
  • 结论: 在复数表示下,旋转矩阵 \( R_{\theta_i}^m \) 作用在向量 \(\boldsymbol{x}\) 上等价于将对应的复数 \( z \) 乘以旋转因子 \( e^{i m\theta_i} \)。因此,等式成立:
    $$
    \color{red}{R_{\theta_i}^m \boldsymbol{x} = \boldsymbol{x} e^{i m\theta_i}}
    $$


附录:旋转位置编码的其他推导过程

  • 下面的推导过程来自通俗易懂-大模型的关键技术之一:旋转位置编码rope(2)和通俗易懂-大模型的关键技术之一:旋转位置编码rope (3),推导过程看过,基本没有问题,先截图,以后有时间再手打一遍
  • \(q,k\) 向量旋转后再进行Attention,注意图中 \(e^{im\theta}\) 的 \(i\) 是虚数的意思,这里使用二维向量 \((\boldsymbol{W}_q\boldsymbol{x}_m)\) 乘以一个虚数的本质是想表达向量内积的意思(虚数可以展开成二维向量),此外,\(a_{m,n}\) 表示 Attention 权重,在不考虑权重时值为 Softmax,Element-wise看:\(a_{m,n} =\frac{\exp(\frac{q^T_m k_n}{\sqrt{d_k}})}{\sum_{j=1}^N \exp(\frac{q^T_m k_j}{\sqrt{d_k}})}\),注:下图中的表达有误,实际上应该是 \(a_{m,n}=\frac{\exp(\frac{ {x^{\prime}_m}^T x^{\prime}_n}{\sqrt{d_k} })}{\sum_{j=1}^N \exp(\frac{ {x^{\prime}_m}^T x^{\prime}_j}{\sqrt{d_k} })}\)(详情见原始论文:(RoPE) RoFormer: Enhanced Transformer with Rotary Position Embedding, Arxiv 2023 & Neurocomputing 2024, 追一科技 )
  • 公式1,2,推导出公式3的过程(内积角度)
  • 公式1,2,推导出公式3的过程(公式角度)
  • 内积角度和公式角度推导结果可以对齐
  • 扩展到多维的方式,将模型d_model维度按照两两一组分(注意这里不是序列两两分,序列本来做Attention就是两两做内积的),这里要求模型维度是偶数的
  • 公式化简
  • 代码实现
    • 上述实现中使用了torch.einsum,本质是爱因斯坦求和约定 ,是矩阵乘法的一种表示,C = torch.einsum("n,d->md", A,B)表示矩阵C[m,d] = A[n]*B[d],,这是个很常用的省略写法,更详细的可以看看图学 AI:einsum 爱因斯坦求和约定到底是怎么回事?

NLP——VeRL框架使用总结

  • 参考链接:
    • 源码地址:github.com/volcengine/verl
    • 官方教程文档:https://verl.readthedocs.io/
      • 官方配置链接:verl.readthedocs.io/en/latest/examples/config.html
      • 官方调优链接:verl.readthedocs.io/en/latest/perf/perf_tuning.html
      • 官方设备调优链接:verl.readthedocs.io/en/latest/perf/device_tuning.html
      • 其他官方示例文档:
        • https://github.com/volcengine/verl/blob/main/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh
        • https://github.com/volcengine/verl/blob/main/examples/tuning/14b/qwen2-14b_grpo_4_h800_fsdp_vllm.sh
        • https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen2-7b.sh
    • DeepWiki 解读:deepwiki.com/volcengine/verl
    • 官方公开讲座(青稞社区):verl 源码解读与 HybridFlow 编程范式讲解
    • 字节跳动Seed官方解读:最高提升20倍吞吐量!豆包大模型团队发布全新 RLHF 框架,现已开源!
    • 相关论文:HybridFlow: A Flexible and Efficient RLHF Framework, EuroSys 2025, HKU & ByteDance
      • 论文解读:HybridFlow / veRL 原文浅析 - Chayenne Zhao的文章 - 知乎
    • 其他解读:
      • 基于 Ray 的分离式架构:veRL、OpenRLHF 工程设计 - 杨远航的文章 - 知乎
      • verl:一个集SFT与RL于一体的灵活大模型post-training框架 (快速入门) - Cyril-KI的文章 - 知乎
      • [AI Infra] VeRL 框架入门&代码带读 - 不关岳岳的事的文章 - 知乎
      • 跟着 verl 代码学习 GRPO 算法流程 - 想当大侠的文章 - 知乎
      • 跟着 verl 代码学习 PPO 算法流程 - 想当大侠的文章 - 知乎
      • 从零开始的verl框架解析 - Nasusu的文章 - 知乎
    • verl 参数速览 - Chayenne Zhao的文章 - 知乎
    • 不错的系列文章:
      • RLHF Infra — Verl 学习(一):Overview - swtheking的文章 - 知乎
      • RLHF Infra — Verl 学习(二):Initialization - swtheking的文章 - 知乎
      • RLHF Infra — Verl 学习(三):Sample Generation - swtheking的文章 - 知乎
      • RLHF Infra — Verl 学习(四): Train Data Organize & Reward Model - swtheking的文章 - 知乎
      • RLHF Infra — Verl 学习(五): Review Verl - swtheking的文章 - 知乎
      • RLHF Infra — Verl 学习(六)Fully Async Policy Trainer - swtheking的文章 - 知乎

环境安装

  • 参考链接:verl.readthedocs.io/en/latest/start/install
  • 建议使用 docker 镜像安装方式,亲测本地直接安装坑很多,且安装后还会陆陆续续出现错误
  • 特别注意:官方镜像加载后还需要执行本地安装 pip3 install --no-deps -e .
    • 不执行这一步会提示 verl 库找不到
    • 建议将代码拉到本地 host 机器,然后用镜像挂载 host 路径
  • 注:官方镜像可能缺失一些依赖包,比如我就遇到缺少 vllm 库,遇到这种情况直接安装即可
    • 最新测试过可用的镜像为:verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.0-fa2.7.4,仅需要自己安装一个 vllm 即可,还有个较小的包按需要安装
      1
      2
      3
      4
      5
      sudo docker create --gpus all --net=host --shm-size="10g" --cap-add=SYS_ADMIN -v ../verl:/workspace/verl -v ~/llm:/workspace/llm --name verl verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.0-fa2.7.4 sleep infinity
      sudo docker start verl
      sudo docker exec -it verl bash
      cd verl && pip3 install --no-deps -e .
      sudo docker stop verl

模型训练

  • Quick Start 可参考:verl.readthedocs.io/en/latest/start/quickstart
  • 多节点启动:verl.readthedocs.io/en/latest/start/multinode

源码阅读

verl 库的目标

  • 将原始问题建模为一个有向图 DataFlow 问题
  • 统一实现,让算法开发者仅需要考虑自身的代码优化即可

数据流的流向过程

  • 原始论文的图片

SPMD 的初始化

  • 在 RayPPOTrainer.init_workers() 内找到相关流程
  • 对每个 资源池分别初始化(for resource_pool, class_dict in self.resource_pool_to_cls.items():)
  • 每个资源池进行如下操作(self.ray_worker_group_cls)
    • 进一步地,执行函数 self._init_with_resource_pool
    • for 循环依次处理每个 GPU(每个 GPU 启动一个进程),每个进程配置好对应的分别是环境变量
    • 每个 GPU 对应一个 worker

数据的分发是如何实现的

  • 每个 Worker 的函数都会接受来自上游的数据,处理数据并输出
    • 注意传入每个 Worker 的数据已经是分布式处理过的,仅仅是 1/WORLD_SIZE,这里的数据分发是使用 @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) 定义的
    • @register 是一个注解,用于实现数据分发过程和收集过程,dispatch_modee=Dispatch.DP_COMPUTE_PROTO 会对应的绑定两个函数(分别负责分发和收集)

每个 Worker 的大致工作流程(Multi Controller 逻辑核心)

  • 具体函数:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking):
    class Functor:
    def __call__(this, *args, **kwargs):
    args, kwargs = dispatch_fn(self, *args, **kwargs)
    padding_count = kwargs.pop(_padding_size_key, 0)
    output = execute_fn(method_name, *args, **kwargs)
    if blocking:
    output = ray.get(output)
    output = collect_fn(self, output)
    if padding_count > 0:
    if isinstance(output, DataProto):
    indices = [i for i in range(len(output))][:-padding_count]
    output = output.select_idxs(indices)
    elif isinstance(output, list):
    output = output[:-padding_count]
    return output

    # use class type to pass the method_name to get a better observability
    return type(method_name, (Functor,), {})()
  • 核心函数名为 func_generator,这个函数会接受5个参数 method_name,dispatch_fn,collect_fn,execute_fn, blocking

    • dispatch_fn 负责 dispatch 参数
    • execute_fn 负责根据 dispatch 后的参数调用 method_name 函数(使用 getattr 方法实现)
    • blocking 决定在这里是否等待 execute_fn 执行完成
    • collect_fn 负责收集 collect_fn 函数返回的分组结果
  • 注:原始代码中的 type(method_name, (Functor,), {})() 表示:动态创建一个名为 method_name、继承自 Functor 且无自定义属性的类,然后实例化该类,最终得到一个继承自 Functor 的实例对象


verl 编程接口

数据集修改(最简单)

  • 保证数据集符合 verl 的格式即可,verl 要求数据是 .parquet 格式,且包含下面 5 列

    • prompt:是一个 message list,每个 message 是 {"role":"...", "content": "..."} 的格式
    • data_source:数据来源,比如 gsm8k 来自 openai/gsm8k
    • ability:数据分类,比如 gsm8k 属于 math 类
    • reward_model:是一个字典,比如 {'ground_truth': '72', 'style':'rule'} 说明使用规则型 reward 模型
    • extro_info:是一个字典,作为额外的信息在训练中使用,可以包含一些自定义信息,比如 PPO 官方示例中的 gsm8k 数据处理就是将 prompt 的 answer 放进去了,完整格式为:{'answer': '...', 'index': 0, 'question': '[原始问题]', 'split': 'train'}
      • 注:extro_info 的 [原始问题] 比 prompt 的 content 少一些模板内容
  • 注:支持 VLM 时,需要 images 和 videos 这样的字段

  • 注:建议使用 pandas 加载数据后多看:

    1
    2
    import pandas as pd
    df = pd.read_parquet(file_path)
  • 数据处理的参考模板见:examples/data_preprocess/ 目录下,比如 gsm8k 数据集的处理文件是 examples/data_preprocess/gsm8k.py

  • 特别地:还可以自定义数据类,通过参数将定义类的 Python 文件路径和类名传入并注册到 verl 中即可,详情见:verl 源码解读与 HybridFlow 编程范式讲解:40:06

自定义 Reward

  • reward fuction 的参数定义:

    1
    2
    3
    4
    5
    custom_reward_function:
    path: null # 指定源码路径
    name: compute_score # 指定函数
    reward_model:
    reward_manager: naive # 指定 reward_manager 类 NaiveRewardManager
  • 可以通过参数传入,示例如下:

    1
    2
    3
    --custom_reward_function.path=./examples/reward_fn/custom_reward_fn.py \
    --custom_reward_function.name=compute_score \
    --reward_model.reward_manager=naive
  • 函数定义可参考 NaiveRewardManager 类的定义

自定义损失函数

  • 全局搜索找到 .backward() 函数调用的地方,这里就是损失定义的地方
    • 在这里可以修改函数 compute_policy_loss
    • 也可以添加其他损失项,比如 交叉熵损失
  • verl 的损失函数定义方式和 llama_factory 的模板类有点相似,是通过将 loss 注册到 POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {} 中实现的
  • 可以通过修改 POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {} 所在文件增加自己的损失函数

修改整个训练逻辑(最复杂)

  • 核心是修改 fit 函数
  • DAPO 的实现类 RayDAPOTrainer 就是继承 RayPPOTrainer 后实现的
    1
    2
    3
    4
    5
    6
    7
    8
    class RayDAPOTrainer(RayPPOTrainer):
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    def fit(self):
    # ...
    # DAPO 的 fit 实现

模型融合

  • 训练完成模型是按照 GPU,以分片的形式存储的,所以需要进行模型融合

    1
    2
    3
    4
    python3 -m verl.model_merger merge \
    --backend fsdp \
    --local_dir checkpoints/verl_examples/gsm8k/global_step_410/actor \
    --target_dir checkpoints/verl_examples/gsm8k/global_step_410/actor/huggingface
    • 将模型路径替换为目标路径
    • 融合结果会存储到 target_dir 下,即 huggingface 目录下,执行后会存贮 model.safetensors 文件和 tokenizer.json 文件
  • 注:模型融合不一定需要安装 verl 的所有的依赖,某些情况下,安装所有 verl 依赖是昂贵的,容易出错,建议简单安装(参考:官方安装说明)

    1
    2
    3
    4
    5
    # 安装底层框架依赖
    USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh # 仅使用 FSDP,不适用 Megatron(Megatron 安装容易出错)

    # 安装 verl
    pip install --no-deps -e . # 不安装依赖,在使用模型融合命令时遇到缺失的再安装,否则安装依赖容易出错

使用 verl 进行模型评估

  • 评估分成生成回答和评估结果两个部分

  • 生成回答

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    python3 -m verl.trainer.main_generation \
    trainer.nnodes=1 \
    trainer.n_gpus_per_node=2 \
    data.path=/path/to/test.parquet \
    data.prompt_key=prompt \
    data.batch_size=1024 \
    data.n_samples=8 \
    data.output_path=/path/to/output.parquet \
    model.path=/path/to/model \
    rollout.temperature=0.6 \
    rollout.top_p=0.95
    • 注意:这里会为每个 Prompt 生成 8 个样本
    • 路径替换为目标模型和目标输出文件名(注意:输出必须到文件名)
  • 评估结果

    1
    2
    3
    4
    5
    6
    python3 -m recipe.r1.main_eval \
    data.path=/path/to/output.parquet \
    data.prompt_key=prompt \
    data.response_key=responses \
    custom_reward_function.path=./recipe/r1/reward_score.py \
    custom_reward_function.name=reward_func
    • 注意:原始的 ./recipe/r1/reward_score.py 文件中不含有 gsm8k 数据集,只需要在数学类型中加入 “openai/gsm8k” 即可
    • 执行该命令可能需要安装 math-verify 包,执行 pip install math-verify 即可
  • 亲测:对 Qwen2.5-0.5B-Instruct 模型在 gsm8k 上训练,从 step=30 到 step=410 (batch_size=256, epoch=15),测试集上的准确率从 0.45 提升至 0.53 左右


附录:如何传入多个数据集?

  • 传入下面的参数?
    1
    2
    3
    4
    5
    train_files="['$train_data_path1','$train_data_path2']"
    test_files="['$valid_data_path1','$valid_data_path2']"

    data.train_files="$train_files" \
    data.val_files="$test_files" \

附录:其他注意事项或技巧

  • 控制保留的 ckpt 数量
    1
    2
    trainer.max_actor_ckpt_to_keep=10
    trainer.max_critic_ckpt_to_keep=10

附录:错误记录

HTTPRequestEntityTooLarge 错误

  • 问题详情:HTTPRequestEntityTooLarge: Request Entity Too Large
  • 原因:Ray 打包文件上传时上传了太多东西,导致实体过大,需要在 verl/trainer/runtime_env.yaml 中增加需要移出的文件 至 excludes
    • 一般都是 *.safetensors 相关的文件导致
  • 详情参考:github.com/volcengine/verl/issues/696

NCCL 错误

  • 表现是单机多卡没错误,多机多卡就会出现错误,错误信息为:

    1
    torch.distributed.DistBackendError: NCCL error in: /pytorch/torch/csrc/distributed/c10d/NCCLUtils.hpp:268, unhandled system error (run with NCCL_DEBUG=INFO for details), NCCL version 2.21.5
  • 一般是 NCCL 相关的环境变量配置有问题,需要检查一下,被修改过后成功运行的参数包括

    1
    2
    3
    4
    NCCL_SOCKET_IFNAME
    NCCL_SOCKET_IFNAME
    NCCL_IB_DISABLE
    NCCL_NET_GDR_LEVEL
  • 注:分布式训练中经常遇到 NCCL 相关的错误,下面是 NCCL 相关的官方错误说明:docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html


附录:特殊参数说明和记录

  • log_prob_micro_batch_size_per_gpu:表示 ref 或 rollout(actor) 一次前向推理时的真实 样本数

    from https://verl.readthedocs.io/en/latest/examples/config.html#actor-rollout-reference-policy
    The batch size for one forward pass in the computation of ref_log_prob. The value represent the local num per gpu.

    • actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu 表示 ref_log_prob 的配置
    • actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu 表示 log_prob 的配置
    • 注:log_prob 的计算是一个前向过程,但 batch_size 较大时显存会比较大,所以进一步进行拆分
  • 注:更多 batch_size 相关介绍:
    • 参考链接:聊聊verl中的batch_size
  • mini_batch,ppo_mini_batch_size(mini_batch_size) :一个 mini_batch 表示一次 PPO 参数更新
  • micro_batch,ppo_micro_batch_size_per_gpu :一次前向/反向过程的批次大小,多个 micro_batch 会累加梯度,直到足够一次 mini_batch 再更新一次模型

NLP——vLLM使用相关笔记

  • 参考链接:
    • GitHub 地址:github.com/vllm-project/vllm
    • 文档地址:Welcome to vLLM
    • 中文文档地址:vLLM 中文站

vLLM 采样参数:SamplingParams

  • SamplingParams 是控制模型“如何生成”的核心对象,常常包含下面的几个参数
    • n : 每个输入提示生成的输出序列数量(默认为 1)
    • best_of : 从生成的一组序列中选择最好的 k 个(用于集束搜索等)
    • temperature : 采样温度,控制随机性;0 表示贪心采样(确定性),值越高越随机
    • top_p : 核采样概率阈值,控制候选词的累积概率
    • top_k : 仅从概率最高的 k 个 token 中采样
    • max_tokens : 每个输出序列生成的最大 token 数
    • stop : 停止生成的字符串列表(遇到这些词即停止)
    • ignore_eos : 是否忽略结束符(EOS),强制生成直到达到最大长度
  • 更多详细参数见附录

输入格式:Prompts

  • vLLM 支持两种形式的输入,可以在同一个 batch 中混合使用:
    • 直接传入字符串,例如 "Hello, world"
      • vLLM 会自动调用内置 Tokenizer 进行编码
    • 传入已经编码好的 Token ID 列表
      • 这在需要自定义 Tokenizer 或复用已编码数据时非常有用
  • 还可以在一个列表中混合输入以上两种输入

vLLM 使用示例

  • 本文将通过三个维度的代码示例来展示 vLLM 的核心能力:
    • 高层同步接口 (LLM) :最常用的离线批量推理方式
    • 高层异步接口 (AsyncLLM) :适用于构建高并发服务的异步流式处理
    • 底层引擎接口 (LLMEngine) :展示如何手动控制调度循环 (Step-level control)

离线批量推理:LLM

  • 这是最简单的用法,适用于处理数据集
  • LLM 类封装了引擎的初始化和调度循环:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    from vllm import LLM, SamplingParams

    # # 初始化 LLM
    # tensor_parallel_size: 使用的 GPU 数量
    # gpu_memory_utilization: 显存占用比例 (0.0 - 1.0)
    llm = LLM(
    model="path_to_model",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9
    )

    # # 定义采样参数
    # sampling_params_greedy = SamplingParams(temperature=0, max_tokens=10) # 贪心采样策略
    sampling_params_creative = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=50)

    prompts = [
    "Hello, my name is", # 索引 0
    "The capital of France is" # 索引 1
    ]

    # # 执行批量解码 (Batch Decoding)
    # generate 函数是同步阻塞的,直到所有请求完成
    outputs = llm.generate(prompts, sampling_params_creative)

    # # 处理输出结果
    for i, output in enumerate(outputs):
    prompt = output.prompt
    # output.outputs 是一个列表,包含 'n' 个生成的序列 (这里 n=1,每个 Prompt 仅生成一个)
    generated_text = output.outputs[0].text

    print(f"--- Sample {i+1} ---")
    print(f"Prompt: {prompt!r}")
    print(f"Generated: {generated_text!r}")
    print(f"Finish Reason: {output.outputs[0].finish_reason}") # e.g., 'stop', 'length'

异步流式推理:AsyncLLM

  • AsyncLLM 是 LLM 的异步版本,基于 AsyncLLMEngine 构建
  • AsyncLLM 允许你在 Python 的 asyncio 循环中非阻塞地提交请求并获取结果,非常适合搭建 API 服务
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    import asyncio
    from vllm import AsyncLLM, SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs

    async def run_async_inference(): # 使用 async 关键字定义一个协程函数
    # 1. 配置引擎参数
    # AsyncEngineArgs 允许更细粒度地控制引擎行为,如 max_num_seqs (最大并发序列数)
    engine_args = AsyncEngineArgs(
    model="path_to_model",
    tensor_parallel_size=1,
    disable_log_requests=True
    )

    # 2. 初始化异步 LLM
    # AsyncLLM 内部维护了一个后台循环来处理请求
    llm = AsyncLLM.from_engine_args(engine_args)

    # 3. 定义采样参数
    sampling_params = SamplingParams(temperature=0.7, max_tokens=20)

    # 4. 定义异步生成任务
    # request_id 是必须的,用于在引擎内部追踪请求,需保证唯一性
    async def generate_stream(request_id, prompt):
    results_generator = llm.generate(
    prompt,
    sampling_params,
    request_id=request_id
    )

    # 异步迭代生成结果 (Streaming)
    final_output = None
    async for request_output in results_generator:
    # 这里可以实现流式推送到前端
    final_output = request_output

    return final_output

    # 5. 模拟并发请求 (多样本解码)
    # 同时发送文本提示和 Token 提示
    tasks = [
    generate_stream("req_001", "To be or not to be,"),
    generate_stream("req_002", "The capital of France is") # TokensPrompt
    ]

    # 等待所有任务完成
    results = await asyncio.gather(*tasks)

    for res in results:
    print(f"Request ID: {res.request_id}")
    print(f"Output: {res.outputs[0].text}")

    # 运行异步主函数
    if __name__ == "__main__":
    asyncio.run(run_async_inference())

底层引擎手动调度:LLMEngine

  • LLMEngine 是 vLLM 最底层的核心
    • 通常用户不需要直接操作它,除非你需要极度定制化的调度逻辑(例如自定义 Web Server 或特殊的强化学习循环)
  • LLM 类本质上就是在这个类外面包了一层 while 循环
  • 这个示例展示了 vLLM 内部是如何通过 step() 函数一步步完成推理的
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    from vllm import LLMEngine, SamplingParams, RequestOutput
    from vllm.engine.arg_utils import EngineArgs
    from vllm.utils import random_uuid

    def run_core_engine_loop():
    # 1. 初始化引擎参数与实例
    engine_args = EngineArgs(model="path_to_model")
    engine = LLMEngine.from_engine_args(engine_args)

    sampling_params = SamplingParams(temperature=0, max_tokens=10)

    # 2. 手动添加请求 (Add Requests)
    # 必须手动管理 request_id
    engine.add_request( # 注意:add_request 函数不会启动推理,需要等待 step 函数来执行
    request_id="req_text",
    prompt="Artificial Intelligence is",
    sampling_params=sampling_params
    )

    engine.add_request(
    request_id="req_text",
    prompt="The capital of France is",
    sampling_params=sampling_params
    )

    # 3. 手动执行调度循环 (The Step Loop)
    # 只要引擎中还有未完成的请求,就继续循环
    while engine.has_unfinished_requests():
    # step() 执行一次推理迭代:
    # 1. 调度器决定哪些请求进入 GPU 计算
    # 2. 执行模型的前向传播 (Model Forward)
    # 3. 采样下一个 Token
    # 4. 更新 KV Cache
    request_outputs: list[RequestOutput] = engine.step() # 注意 step 是一次仅采样一个 Token!streaming 也是借助 step 函数实现的;平时不需要 step 函数是因为封装到底层了

    # 打印当前步的中间结果 (Streaming 效果)
    for output in request_outputs:
    if output.finished:
    print(f"[{output.request_id}] Finished: {output.outputs[0].text}")
    else:
    # 仅打印当前生成的最新 token(简化展示)
    # 实际 output.outputs[0].text 包含完整的累积文本
    pass

    # 运行
    if __name__ == "__main__":
    run_core_engine_loop()

vLLM 部署及参数说明

部署命令

  • 推荐的 Linux 启动命令(可根据实际情况修改路径和显卡数量):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    vllm serve /data/models/Llama-3-8B-Instruct \
    --served-model-name llama3-8b \
    --host 0.0.0.0 \
    --port 8000 \
    --dtype auto \
    --tensor-parallel-size 1 \
    --pipeline-parallel-size 1 \
    --gpu-memory-utilization 0.90 \
    --swap-space 4 \
    --max-model-len 8192 \
    --max-num-seqs 256 \
    --max-num-batched-tokens 8192 \
    --trust-remote-code \
    --enable-chunked-prefill \
    --disable-custom-all-reduce \
    --quantization awq \
    --enforce-eager \
    --api-key "sk-your-secure-password"
  • TLDR:参数配置建议:

    • 如果追求极致吞吐量(Throughput) :增大 --max-num-batched-tokens
      • 这允许一次性处理更多数据,但可能会导致生成过程中的停顿感(因为大批量的预填充会抢占计算资源)
    • 如果追求低延迟和流畅度(Latency) :建议保持适中的 --max-num-batched-tokens,并 开启 --enable-chunked-prefill
      • 这样可以将大的预填充任务打散,避免计算尖峰,确保正在生成的对话不会卡顿
    • 显存限制 :请注意,max-num-batched-tokens 的大小直接影响 KV Cache 的瞬时显存需求
      • 如果该值过大而显存不足,可能会触发 OOM 或强制调度器减少并发序列数(--max-num-seqs)

参数详细解析

  • vllm serve <path_to_model>
    • 这是 vLLM 的启动入口命令
    • 后面的 path_to_model 路径是模型在本地文件系统中的绝对路径(也可以是 Hugging Face 的模型 ID)
  • --served-model-name <model_name>
    • 指定服务对外显示的名称,建议使用类似 “llama-8b” 等类似名称标注
    • 当客户端调用 OpenAI 兼容 API 时,model 字段需要匹配这个名字
    • 如果不设置,默认使用模型路径作为名字
  • --host 0.0.0.0
    • 指定服务绑定的 IP 地址
    • 0.0.0.0 表示允许来自任何网络接口的连接(对外网开放);如果仅限本地访问,可设置为 127.0.0.1
  • --port 8000
    • 指定服务监听的端口号
  • --dtype auto
    • 指定模型权重的加载精度
    • 设置为 auto 时,vLLM 会根据配置文件(config.json)自动检测(通常是 float16 或 bfloat16)
    • 也可以强制指定为 float16、bfloat16 或 float32
  • --tensor-parallel-size 1 (TP)
    • 张量并行度,即把一个模型的层拆分到几张显卡上并行计算
    • 通常设置为单机内的 GPU 数量
  • --pipeline-parallel-size 1 (PP)
    • 流水线并行度,即把模型的不同层分配到不同的显卡上
    • 通常用于模型过大,单卡显存塞不下且 TP 无法解决时
    • 注:模型部署不建议开 PP,使用 TP 即可,一般情况下 PP 保持为 1
  • --gpu-memory-utilization 0.90
    • GPU 显存使用率上限,注意:这是影响并发能力的核心参数
    • vLLM 会预先占用这部分比例的显存(此处为 90%)
      • 其中一部分用于加载模型权重,剩余的所有空间都会被预分配给 KV Cache(键值缓存)
    • 如果设得太高容易 OOM(显存溢出),设得太低则浪费显存,导致并发量上不去
  • --swap-space 4 (新增重要参数)
    • CPU 交换空间大小(单位:GiB)
    • 当 GPU 显存不足以存放 KV Cache 时,vLLM 会将部分 KV Block 换出到 CPU 内存中
    • 设置此参数可以防止在请求突发高峰时发生 OOM 崩溃
  • --max-model-len 8192
    • 模型的最大上下文长度(输入+输出)
      • 如果未指定,vLLM 会尝试从模型配置中读取
    • 显式指定可以限制显存占用,避免处理过长的序列导致崩溃
  • --max-num-seqs 256
    • 最大并发序列数,即同一时刻 vLLM 能处理的请求数量(Batch Size)
    • 这个值越高,吞吐量越大,但每个请求的延迟可能会增加
  • --max-num-batched-tokens 8192
    • 一次迭代(iteration)中处理的最大 Token 总数
    • 这包括了 Prefill(预填充)阶段和 Decode(解码)阶段的所有 Token
    • 通常默认为 max(max_model_len, 2048),建议根据卡的性能灵活配置,以最大化效率
  • --trust-remote-code
    • 允许执行模型仓库中的自定义 Python 代码
    • 对于某些非标准架构的模型(如 ChatGLM、Qwen 的早期版本等),必须开启此选项才能正确加载模型架构
  • --enable-chunked-prefill
    • 为了解决长 Prompt 导致的“队头阻塞”问题(即一个超长 Prompt 占满计算资源,导致短请求延迟增加),引入了分块预填充机制
    • 开启分块预填充,这是一个优化参数,允许将长 Prompt 的 Prefill 阶段拆分成多个小块,与 Decode 阶段混合调度
    • 这可以显著降低长文本输入时的首字延迟(TTFT) ,因为允许解码(Decode)任务和预填充(Prefill)任务更平滑地交错执行,显著降低了其他并发请求的 Inter-Token Latency(ITL,Token 间延迟),使生成过程更加流畅
  • --max-num-partial-prefills
    • 并发预填充数,当启用了分块预填充(Chunked Prefill)后,这个参数变得非常重要
    • 限制了在同一时刻,有多少个请求可以处于“部分预填充”状态,
    • 默认为 1:意味着在任何给定的迭代中,调度器最多只允许 1 个请求进行部分预填充计算(与其他正在解码的请求并行),这有助于防止过多的上下文切换开销,同时保证显存管理的稳定性
  • --long-prefill-token-threshold
    • 长预填充阈值,这是一个辅助参数,用于配合分块预填充使用
    • 定义了多少 Token 数量的 Prompt 被视为“长请求”
    • 当 Prompt 长度超过此阈值时,vLLM 才会考虑对其应用特殊的调度策略或分块逻辑。默认值为 0,意味着所有请求都遵循统一的规则
  • --disable-custom-all-reduce
    • 禁用 vLLM 自定义的 All-Reduce 内核
    • 通常在某些 GPU 架构不支持或驱动不兼容导致 NCCL 通信错误时使用
    • 如果硬件环境标准,通常不需要加这个,但在排查多卡通信问题时很有用
  • --quantization awq (新增重要参数)
    • 指定量化格式
    • 如果模型是量化版本(如 AWQ, GPTQ, SqueezeLLM),必须指定此参数
    • 例如加载 Llama-3-8B-AWQ 时,需设置为 awq
    • 如果是非量化模型,请去掉此参数
  • --enforce-eager (新增重要参数)
    • 强制使用 PyTorch 的 Eager 模式,禁用 CUDA Graph
    • 虽然 CUDA Graph 能加速小 Batch 的推理,但在某些特定显卡或驱动版本上可能会导致显存分配错误或死锁
    • 开启此项有助于调试和提高稳定性
  • --api-key "sk-your-secure-password" (新增重要参数)
    • 设置访问 API 的密钥
    • 在生产环境中,为了防止未授权访问,配置 API Key 是必须的安全措施
    • 客户端请求头需携带 Authorization: Bearer sk-your-secure-password

附录:如果是量化模型,不添加 --quantization 参数 会怎样?

  • 以 W8A8-QuaRot(Weight 8-bit / Activation 8-bit,使用了 QuaRot 旋转算法进行离群值抑制)量化模型 为例,如果在启动 vLLM 时不指定 --quantization 参数,通常会发生以下三种情况之一(具体取决于模型的 config.json 配置和 vLLM 的版本)
  • 注:QuaRot 是一种算法技术,它生成的模型最终通常以 FP8 (E4M3/E5M2) 或 Int8 的格式存储
情况1:直接报错并无法启动(最常见的情况)
  • 这是最可能发生的结果
  • vLLM 启动时会读取模型的 config.json
    • 如果该配置文件中包含 quantization_config 字段(例如标记为 fp8、compressed-tensors 或自定义格式),但 vLLM 在默认模式下无法自动匹配到合适的 Kernel(内核),或者检测到硬件不支持该量化格式(例如在非 Hopper 架构显卡上加载 FP8),程序会直接抛出 ValueError 或 RuntimeError
  • 终端会打印类似 ValueError: Unknown quantization method... 或 RuntimeError: Shape mismatch... 的错误日志,服务启动失败
情况2:加载成功但输出乱码,Garbage Output
  • 这种情况比较危险,因为它看起来“跑起来了”,但完全不可用
  • 如果模型的 config.json 中缺失 了量化相关的元数据,或者 vLLM 错误地将其识别为标准模型,它会尝试以默认精度(通常是 float16 或 bfloat16)来解释权重数据
  • 数据曲解:原本是 8-bit 的整数或 FP8 数据,被当成了 16-bit 的浮点数读取
  • 模型可以接受输入,但吐出来的全是乱码、重复符号或毫无逻辑的字符
情况3:自动识别成功(理想情况,但有前提)
  • 在较新的 vLLM 版本中,如果模型打包规范(例如使用 llm-compressor 或 AutoGPTQ 正确导出),config.json 中会有明确的 quantization_config 字段
  • config.json 里的 quantization 字段(如 fp8 或 compressed-tensors)被当前版本的 vLLM 原生支持
  • 如果显卡支持该精度(例如 w8a8 的 QuaRot 通常对应 FP8 ,这通常需要 NVIDIA Ada Lovelace (RTX 4090) 或 Hopper (H100) 架构的 GPU)
  • 此时即使你不写 --quantization,vLLM 也会根据配置文件自动启用对应的量化内核,服务正常运行
推荐做法
  • 先看模型文件夹下的 config.json,寻找 quantization_config 字段

  • 如果格式是 FP8(常见于 QuaRot 转换的模型):

    1
    --quantization fp8 --kv-cache-dtype fp8
    • 注意:支持 fp8 通常需要 H100/L40/RTX4090 等新显卡
  • 如果格式是 Compressed-Tensors / Neural Magic 格式:

    • vLLM 通常能自动识别,但如果报错,可能需要指定:
      1
      --quantization compressed-tensors

附录:SamplingParams 参数项详解

  • vLLM 的 SamplingParams 参数很多,覆盖了多个方面:
    • 从基础生成控制(长度、终止)
    • 采样策略(随机性、候选集)
    • 重复控制(惩罚)
    • 输出格式(detokenize、 Special Token )
    • 高级自定义(logits 处理器、结构化输出)的全维度参数
  • 这些参数既兼容 OpenAI API 规范,又扩展了 beam search、结构化输出、不良词过滤等特有功能
  • 一些简单的常用理解:
    • 追求确定性可以配置:temperature=0 + top_k=1;
      • 问题:temperature=0 其实就已经是贪心采样了,但是我们一般还是会使用 top_k=1 进一步明确 贪心采样
    • 追求多样性可以配置:temperature=0.7 + top_p=0.9;
      • 理解:temperature=0.7 + top_p=0.9 是很常用的参数
    • 避免重复:presence_penalty=0.5 + frequency_penalty=0.3;
      • presence_penalty 惩罚是否出现过
      • frequency_penalty 惩罚出现频次

SamplingParams 源码配置

  • 以下源码参考自:github.com/vllm-project
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    class SamplingParams(
    PydanticMsgspecMixin,
    msgspec.Struct,
    omit_defaults=True, # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
    ): # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """

    n: int = 1
    """Number of outputs to return for the given prompt request.

    NOTE:
    `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
    are generated and streamed cumulatively per request. To see all `n`
    outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
    in `SamplingParams`."""
    presence_penalty: float = 0.0
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    frequency_penalty: float = 0.0
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    repetition_penalty: float = 1.0
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
    temperature: float = 1.0
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
    top_p: float = 1.0
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
    top_k: int = 0
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
    min_p: float = 0.0
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
    seed: int | None = None
    """Random seed to use for the generation."""
    stop: str | list[str] | None = None
    """String(s) that stop the generation when they are generated. The returned
    output will not contain the stop strings."""
    stop_token_ids: list[int] | None = None
    """Token IDs that stop the generation when they are generated. The returned
    output will contain the stop tokens unless the stop tokens are special
    tokens."""
    ignore_eos: bool = False
    """Whether to ignore the EOS token and continue generating
    tokens after the EOS token is generated."""
    max_tokens: int | None = 16
    """Maximum number of tokens to generate per output sequence."""
    min_tokens: int = 0
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated"""
    logprobs: int | None = None
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
    prompt_logprobs: int | None = None
    """Number of log probabilities to return per prompt token.
    When set to -1, return all `vocab_size` log probabilities."""
    flat_logprobs: bool = False
    """Whether to return logprobs in flatten format (i.e. FlatLogprob)
    for better performance.
    NOTE: GC costs of FlatLogprobs is significantly smaller than
    list[dict[int, Logprob]]. After enabled, PromptLogprobs and
    SampleLogprobs would populated as FlatLogprobs."""
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    """Whether to detokenize the output."""
    skip_special_tokens: bool = True
    """Whether to skip special tokens in the output."""
    spaces_between_special_tokens: bool = True
    """Whether to add spaces between special tokens in the output."""
    # `list[LogitsProcessor] | None` type. We use Any here because
    # `list[LogitsProcessor] | None` type is not supported by msgspec.
    logits_processors: Any | None = None
    """Functions that modify logits based on previously generated tokens, and
    optionally prompt tokens as a first argument."""
    include_stop_str_in_output: bool = False
    """Whether to include the stop strings in output text."""
    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)

    # Fields used to construct logits processors
    structured_outputs: StructuredOutputsParams | None = None
    """Parameters for configuring structured outputs."""
    logit_bias: dict[int, float] | None = None
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
    allowed_token_ids: list[int] | None = None
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
    extra_args: dict[str, Any] | None = None
    """Arbitrary additional args, that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""

    # Fields used for bad words
    bad_words: list[str] | None = None
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
    _bad_words_token_ids: list[list[int]] | None = None

    skip_reading_prefix_cache: bool | None = None

基础参数说明

  • n: int = 1:
    • 为单个 Prompt 请求返回的生成结果数量
    • vLLM 默认一个个输出结果,当 n > 1 时,所有 n 个结果会按请求累积流式一个个返回;
      • 问题:这里的流式,不是通常意义上的流式,而是针对 Response n 粒度的流式?
    • 若希望仅在生成完成后一次性获取所有 n 个结果,需将 output_kind 设置为 RequestOutputKind.FINAL_ONLY
  • max_tokens: int | None = 16:
    • 每个输出序列允许生成的最大 token 数量
    • 若设为 None,需确保模型有明确的终止条件(如 EOS 或 stop 词),否则可能无限生成
  • min_tokens: int = 0
    • 每个输出序列在生成 EOS(结束符)或 stop_token_ids 之前必须生成的最小 token 数
    • 作用 :避免生成过短的结果,例如设置 min_tokens=5 时,即使模型提前触发终止条件 ,也会继续生成直到达到 5 个 token
  • ignore_eos: bool = False
    • 是否忽略 EOS token,强制模型在生成 EOS 后继续生成
    • 适用于需要生成超长文本、绕过模型默认终止逻辑的场景(如生成完整文档而非单句)

采样策略参数

  • 采样策略参数参数控制模型生成 token 时的随机性和候选范围,是最常用的参数,当不做采样时,no_sample
  • temperature: float = 1.0
    • 控制采样的随机性,本质是对 logits(token 概率对数)进行缩放(注意:是在 Softmax 前进行缩放)
    • temperature = 0:贪心采样(Greedy Sampling),直接选择概率最高的 token,结果完全确定;
    • 0 < temperature < 1:降低随机性(提高确定性),结果更聚焦、确定(如 0.7 是平衡随机性和确定性的常用值);
    • temperature > 1:提高随机性,结果更发散、创意性更强,但可能出现无意义内容
    • temperature 越小越容易出现重复现象
    • 注意 :当 temperature=0 时,top_p/top_k 等参数会失效(贪心采样无需候选集)
      • vLLM 中没有 do_sample 参数 参照了 HF Transformer 相似的思路,但是实现方式不同,通过 temperature 隐晦地实现了是否贪心采样的控制
      • temperature=0 强制 do_sample=False(贪心采样,只选概率最高的 token);
      • temperature>0 等价 do_sample=True(启用随机采样,按概率分布选 token)
  • top_p: float = 1.0
    • 核采样(Nucleus Sampling),控制待选 token 的累积概率阈值,取值范围 (0, 1]
    • 将所有 token 按概率从高到低排序,累加概率直到达到 top_p,仅从这些 token 中采样
      • top_p=0.9 时,仅选择累计概率前 90% 的 token 作为候选;
      • top_p=1.0 时,包含所有 token(等同于不限制)
    • 相比 top_k 更灵活,能自适应调整候选集大小(高概率 token 少则候选集小,反之则大)
  • top_k: int = 0
    • 限制采样的候选 token 数量,仅从概率最高的 top_k 个 token 中选择
      • top_k=0(或 -1):不限制,包含所有 token;
      • top_k=50:仅从概率前 50 的 token 中采样
    • 对比 top_p :top_k 是固定数量限制,top_p 是概率累积限制,通常两者二选一使用
      • 两者组合时:先按 top_k 筛选,再按 top_p 过滤
  • min_p: float = 0.0
    • 基于最高概率 token 的相对概率阈值,筛选候选 token,取值范围 [0, 1]
    • 设本次采样遇到的最高概率 token 的概率为 P_max(注意:是个随分布变化的值),仅保留概率 \(\ge\) min_p * P_max 的 token
      • min_p=0.1 且 P_max=0.5 时,仅保留概率 \(\ge\) 0.05 的 token;
      • min_p=0 时禁用该规则
    • 优势 :相比 top_k/top_p,能避免极端情况下的不合理筛选(如 top_k 可能漏掉低概率但有意义的 token,top_p 可能包含过多低概率 token)
  • seed: int | None = None
    • 生成随机数的种子,用于复现生成结果
    • 设置固定 seed 后,相同 Prompt 和参数下,模型会生成完全相同的结果(解决采样随机性导致的不可复现问题)

重复/惩罚类参数

  • 用于控制模型生成时的重复率,避免生成冗余、重复的文本
  • presence_penalty: float = 0.0
    • 基于 token 是否“出现过”的惩罚,与出现次数无关
      • 正值(如 0.5):惩罚已出现的 token,鼓励生成新内容;
      • 负值(如 -0.5):奖励已出现的 token,鼓励重复;
      • 0:无惩罚/奖励
    • 适用场景 :需要避免模型重复提及相同实体(如人名、地名)的场景
  • frequency_penalty: float = 0.0
    • 基于 token 出现“频率”的惩罚,出现次数越多,惩罚越重
      • 正值:抑制高频 token,减少重复;
      • 负值:强化高频 token,增加重复;
      • 0:无惩罚/奖励
    • 区别于 presence_penalty :前者是“有无”惩罚,后者是“多少”惩罚,例如重复 3 次的 token 会比重复 1 次的 token 受到更重的频率惩罚
  • repetition_penalty: float = 1.0
    • 基于 prompt 和已生成文本中 token 出现的惩罚,核心是调整 token 的概率
      • 取值 > 1:惩罚重复 token(概率 = 原概率 / repetition_penalty),鼓励新内容;
      • 取值 < 1:奖励重复 token(概率 = 原概率 * repetition_penalty),鼓励重复;
      • 1:无惩罚/奖励
    • 覆盖范围(特别注意) :同时作用于 prompt 和生成文本中的 token,是更通用的重复控制参数
      • 理解:这里的含义是在 prompt 中的 Token 也会当做是否重复的判断依据进行累计

终止条件参数

  • 控制模型何时停止生成,避免无限制输出
  • stop: str | list[str] | None = None
    • 触发生成终止的字符串(单个或列表),返回的结果中默认不包含这些停止字符串
    • stop=["\n", "###"] 时,模型生成到换行符或 ### 时立即停止
  • stop_token_ids: list[int] | None = None
    • 触发生成终止的 token ID 列表(底层 token-level 的终止条件)
    • 返回结果中会包含 stop_token_ids 对应的 stop token(Special Token 服从本规则)
      • 如果是 Special Token,可能是不会在输出结果中的,有自己的规则
      • 与 stop(字符串级)互补,分别用于指定字符串或者 Token
  • include_stop_str_in_output: bool = False
    • 是否将 stop 参数指定的停止字符串包含在输出文本中
      • 注意:这里只影响 stop,与 stop_token_ids 无关,stop_token_ids 不受此参数影响
    • 若设为 True,停止字符串会出现在最终输出里
  • 理解终止条件参数,vLLM 的 SamplingParams 内部会维护一个参数:_all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
    • _all_stop_token_ids 存储所有终止 token ID
      • 包括 stop_token_ids 转换后的 ID、EOS token 等
      • 这个参数无需用户手动设置,由 post_init 自动初始化

日志概率(logprobs)参数

  • 用于获取 token 生成的概率信息,适用于需要分析模型决策过程的场景(如评估生成可靠性)
  • logprobs: int | None = None
    • 每个输出 token 返回的最高概率 token 的数量(包含选中的 token)
      • logprobs=None:不返回概率;
      • logprobs=k(\(k \in \mathbb{Z}^+\)):返回概率最高的 k 个 token 的 log 概率(实际返回 k+1 个,因为包含选中的 token);
        • 理解:这里选中的 Token 不一定是概率最高的, 所以被选中的一定会返回
      • logprobs=-1:返回全词表(vocab_size 维度)所有 token 的 log 概率
    • Following OpenAI API :始终返回选中 token 的 log 概率
  • prompt_logprobs: int | None = None
    • 每个 Prompt token 返回的最高概率 token 的数量
      • 取值规则同 logprobs,-1 表示返回全词表概率
    • 问题:prompt 为什么也会对应概率?
      • prompt_logprobs 是专门针对输入的 prompt 部分(而非生成的 completion 部分)返回的每个 token 的对数概率信息
      • logprobs 则通常指生成部分的对数概率
  • flat_logprobs: bool = False
    • 是否展平返回 logprobs,优化性能
    • 优势 :FlatLogprob 的 GC(垃圾回收)成本远低于 list[dict[int, Logprob]] 格式,适合高并发场景;
    • 启用后 PromptLogprobs 和 SampleLogprobs 均会以 FlatLogprob 格式返回

输出格式与处理参数

  • 控制生成结果的格式、是否过滤 Special Token 等
  • detokenize: bool = True
    • 是否将生成的 token ID 转换为文本
    • 注意 :仅在引擎层暴露,OpenAI API 不支持仅返回 token ID,默认开启,得到的就是文本而不是 Token ID
  • skip_special_tokens: bool = True
    • 是否在输出中跳过 Special Token (如 、、 等)
    • 注意默认是 True(跳过),避免输出包含无意义的特殊标记
  • spaces_between_special_tokens: bool = True
    • 是否在 Special Token 之间添加空格
    • 优化 Special Token 的可读性,例如 <|endoftext|><|user|> 会变成 <|endoftext|> <|user|>
    • 理解:为什么这里默认是 True,目前我们几乎不用,但确从不需要设置?猜测如下(待确定):
      • Hugging Face Tokenizer 基类的 通用默认值 是 True;
      • LLaMA/Qwen 等模型的 专属默认值 是 False(通过代码硬编码覆盖了通用默认值)
  • output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
    • 输出类型,控制流式返回的方式:
      • output_kind=RequestOutputKind.CUMULATIVE(默认):累积式输出(如第 1 次返回第 1 个 token,第 2 次返回前 2 个 token,依此类推);
      • output_kind=RequestOutputKind.FINAL_ONLY:仅在生成完成后返回最终完整结果(此时不是异步生成了)
      • output_kind=RequestOutputKind.DELTA:仅返回增量
    • 问题:这个参数的使用待测试确认
  • output_text_buffer_length: int = 0
    • 内部参数,存储输出文本缓冲区长度,无需用户设置,由 post_init 初始化

Prompt 处理参数

  • truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
    • Prompt 的左截断规则(仅保留最后 k 个 token):
      • -1:使用模型支持的最大截断长度;
      • 正整数 k:仅保留 Prompt 最后 k 个 token;
      • None:禁用截断
    • 常用参数,适配模型的上下文窗口限制,避免 Prompt 过长导致超出模型容量
  • skip_reading_prefix_cache: bool | None = None
    • 是否跳过读取前缀缓存(prefix cache),用于优化 Prompt 处理性能,通常无需用户手动设置
  • logits_processors: Any | None = None
    • 修改 logits 的自定义处理器列表(函数),可基于已生成的 token(或 Prompt token)调整 token 概率
    • 因 msgspec 不支持 list[LogitsProcessor] | None,故用 Any 替代;适用于自定义生成逻辑(如强制生成特定 token、限制生成内容)
    • 问题:待确认这个参数
  • structured_outputs: StructuredOutputsParams | None = None
    • 结构化输出参数,用于控制模型生成符合特定格式的内容(如 JSON、XML)
    • 需要结构化结果的场景(如数据提取、API 调用返回)
  • logit_bias: dict[int, float] | None = None
    • token 级别的概率偏置,键为 token ID,值为偏置值
    • 调整指定 token 的生成概率(正值提高概率,负值降低概率),例如 logit_bias={123: 5.0} 会大幅提高 ID 为 123 的 token 被选中的概率
    • 问题:待尝试这个参数
  • allowed_token_ids: list[int] | None = None
    • 允许生成的 token ID 列表,后续生成时,会仅保留这些 token 的概率,其余 token 概率置 0
    • 严格限制生成内容的范围(如仅允许生成数字、特定词汇)
    • 问题:待尝试这个参数
  • extra_args: dict[str, Any] | None = None
    • 自定义额外参数,供第三方插件、自定义采样逻辑使用,vLLM 内置采样逻辑不使用该参数

不良词过滤参数

  • bad_words: list[str] | None = None
    • 禁止生成的词汇列表,核心逻辑是:当生成的 token 即将完成某个 bad word 的 token 序列时,禁止生成该序列的最后一个 token
    • 比如 bad_words=["暴力"] 时,模型会避免生成“暴力”这个词(通过阻止其最后一个 token 的生成),直接停止
    • 问题:待测试这个参数
  • _bad_words_token_ids: list[list[int]] | None = None
    • 内部参数,存储 bad_words 转换后的 token ID 序列,无需用户设置,由 post_init 初始化
    • 问题:待测试这个参数

NLP——技术报告解读-DeepSeek-R1

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning, arXiv 20250122 & 20260104, DeepSeek-AI
      • 20260104 补充了更多细节【本文还未补充,待后续有时间再更新】
    • 补充文档:(DeepSeek-R1-Supplements)Supplementary Information for: DeepSeek-R1 Incentivizes Reasoning in LLMs via Reinforcement Learning, DeepSeek-AI
      • 补充材料阅读笔记见另一篇
    • 中文完整版: 梁文锋Nature论文的同行评审和团队回应- 上
    • rebuttal过程

Paper Summary

  • 评价:
    • 划时代的一篇文章,25 年春节前后 DeepSeek 给大家带来的冲击是巨大的,众多社区一起复现 DeepSeek-R1 的 Aha Moment 的空前盛况
    • 本文及其附录都是非常值得深入阅读的文章
  • 论文介绍了 DeepSeek 的第一代推理模型 DeepSeek-R1-Zero 和 DeepSeek-R1
  • DeepSeek-R1-Zero 是一个通过大规模 RL 训练、无需 SFT 作为初步步骤的模型,展现出卓越的推理能力
    • 通过强化学习,DeepSeek-R1-Zero 自然地涌现出许多强大而有趣的推理行为
    • 但 DeepSeek-R1-Zero 也面临可读性差和语言混合等挑战
  • 为了解决这些问题并进一步提升推理性能,论文引入了 DeepSeek-R1,它在强化学习之前引入了多阶段训练和冷启动数据
    • DeepSeek-R1 在推理任务上的性能与 OpenAI-o1-1217 相当
    • 为了支持研究社区,论文开源了 DeepSeek-R1-Zero、DeepSeek-R1,以及六个基于 Qwen 和 Llama 从 DeepSeek-R1 蒸馏得到的稠密模型(1.5B、7B、8B、14B、32B、70B)

Introduction and Discussion

  • Post-training 已成为完整训练流程中的一个重要组成部分
    • 能够提升推理任务的准确性、与社会价值观对齐并适应用户偏好,同时相对于预训练所需计算资源相对较少
  • 在推理能力的背景下,OpenAI 的 o1 (OpenAI) 系列模型首次通过增加 CoT 推理过程的长度引入了推理时扩展(inference-time scaling)
    • 这种方法在数学、编程和科学推理等各种推理任务中取得了显著改进
    • 但有效的测试时扩展(test-time scaling)的挑战仍然是研究社区的一个开放性问题
  • 之前的几项工作探索了各种方法,包括基于过程的奖励模型(process-based reward models)(2022; 2023)、强化学习 (2024) 以及蒙特卡洛树搜索(Monte Carlo Tree Search)和束搜索(Beam Search)等搜索算法 (2024; 2024; 2024)
    • 然而,这些方法都没有达到与 OpenAI 的 o1 系列模型相媲美的通用推理性能
  • 在论文中,论文迈出了第一步,使用纯 RL 来改进语言模型的推理能力
    • 论文的目标是探索大语言模型在没有任何监督数据的情况下发展推理能力的潜力,重点关注它们通过纯强化学习过程进行的自我进化
    • 具体来说,论文使用 DeepSeek-V3-Base 作为基础模型,并采用 GRPO (2024) 作为强化学习框架来提高模型在推理中的性能
    • 在训练过程中,DeepSeek-R1-Zero 自然地涌现出许多强大而有趣的推理行为
    • 经过数千个强化学习步骤后,DeepSeek-R1-Zero 在推理基准测试中表现出卓越的性能
      • 例如,在 AIME 2024 上的 pass@1 分数从 15.6% 增加到 71.0%,并且通过多数投票(majority voting),分数进一步提高到 86.7%,与 OpenAI-o1-0912 的性能相匹配
  • 然而,DeepSeek-R1-Zero 遇到了可读性差和语言混合等挑战
    • 为了解决这些问题并进一步提升推理性能,论文引入了 DeepSeek-R1,它结合了少量冷启动数据(cold-start data)和多阶段训练流程
    • 具体来说
      • 论文首先收集数千个冷启动数据来微调 DeepSeek-V3-Base 模型
      • 随后,论文像 DeepSeek-R1-Zero 一样执行面向推理的强化学习
    • 当强化学习过程接近收敛时,论文通过对强化学习检查点(checkpoint)进行拒绝采样(rejection sampling)来创建新的监督微调数据,并结合来自 DeepSeek-V3 在写作、事实问答(factual QA)和自我认知(self-cognition)等领域的有监督数据,然后重新训练 DeepSeek-V3-Base 模型
    • 在使用新数据微调后,该检查点会经历额外的强化学习过程,考虑所有场景的 Prompts
    • 经过这些步骤,论文获得了一个称为 DeepSeek-R1 的检查点,其性能与 OpenAI-o1-1217 相当
  • 论文进一步探索了从 DeepSeek-R1 到更小稠密模型(dense models)的蒸馏(distillation)
    • 使用 Qwen2.5-32B (Qwen) 作为基础模型,直接从 DeepSeek-R1 进行蒸馏优于在其上应用强化学习
    • 这表明由更大基础模型发现的推理模式对于提高推理能力至关重要
  • 论文开源了蒸馏后的 Qwen 和 Llama (2024) 系列
    • 值得注意的是,论文蒸馏的 14B 模型大幅优于 SOTA 开源模型 QwQ-32B-Preview (Qwen),并且蒸馏的 32B 和 70B 模型在稠密模型的推理基准测试中创造了新纪录
  • 补充:来自辅助材料的说明
    • DeepSeek-V3-Base 指基础模型
    • DeepSeek-V3 指经过指令微调的模型
    • DeepSeek-R1 与 DeepSeek-R1-Zero 均在 DeepSeek-V3-Base 的基础上训练而成
      • 且 DeepSeek-R1 还利用了 DeepSeek-V3 监督微调数据中的非推理类数据

Approach

Overview

  • 先前的工作严重依赖大量的监督数据来提升模型性能
  • 在本研究中,论文证明了即使不使用 SFT 作为冷启动,通过大规模 RL 也能显著提升推理能力
    • 此外,加入少量冷启动数据可以进一步提升性能
  • 在接下来的小节中,论文将介绍:
    • (1) DeepSeek-R1-Zero,它直接在基础模型(DeepSeek-V3-Base)上应用强化学习,不使用任何监督微调数据;
    • (2) DeepSeek-R1,它从一个经过数千个长 CoT 示例微调过的检查点开始应用强化学习;
    • (3) 将 DeepSeek-R1 的推理能力蒸馏到小型稠密模型中

DeepSeek-R1-Zero: Reinforcement Learning on the Base Model

  • 强化学习在推理任务中已展现出显著的有效性,这在论文先前的工作 (2024; 2023) 中得到了证明
    • 但这些工作严重依赖监督数据,而收集这些数据非常耗时
  • 在本节中,论文探索了 LLM 在没有任何监督数据的情况下发展推理能力的潜力,重点关注其通过纯强化学习过程进行的自我进化
  • 论文首先简要概述论文的强化学习算法,随后展示一些令人兴奋的结果,并希望这能为研究社区提供有价值的见解
Reinforcement Learning Algorithm
Group Relative Policy Optimization, GRPO
  • 为了节省强化学习的训练成本,论文采用了 GRPO (2024)
  • 该方法省去了通常与策略模型大小相同的评论家模型,转而从组分数中估计基线
  • 具体来说,对于每个问题 \(q\),GRPO 从旧策略 \(\pi_{\theta_{old} }\) 中采样一组输出 \(\{o_{1},o_{2},\cdots,o_{G}\}\),然后通过最大化以下目标来优化策略模型 \(\pi_{\theta}\):
    $$
    \mathcal{J}_{GRPO}(\theta)=\mathbb{E}_{[q\sim P(Q),\{o_{i}\}_{i=1}^{ G}\sim\pi_{\theta_{old} }(O|q)]} \frac{1}{G}\sum_{i=1}^{G}\left(\min\left(\frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old} }(o_{i}|q)}A_{i},\text{clip}\left(\frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old} }(o_{i}|q)},1-\varepsilon,1+\varepsilon\right)A_{i}\right)-\beta\mathbb{D}_{KL}\left(\pi_{\theta}||\pi_{ref}\right)\right),
    $$
    • 其中 \(\varepsilon\) 和 \(\beta\) 是超参数,\(A_{i}\) 是优势函数,使用与组内每个输出对应的一组奖励 \(\{r_{1},r_{2},\ldots,r_{G}\}\) 计算得出:
      $$
      A_{i}=\frac{r_{i}-\text{mean}(\{r_{1},r_{2},\cdots,r_{G}\})}{\text{std}(\{r_{1},r_{2},\cdots,r_{G}\})}.
      $$
  • KL 散度项 \(\mathbb{D}_{KL}\left(\pi_{\theta}||\pi_{ref}\right)\) 定义为:
    $$
    \mathbb{D}_{KL}\left(\pi_{\theta}||\pi_{ref}\right)=\frac{\pi_{ ref}(o_{i}|q)}{\pi_{\theta}(o_{i}|q)}-\log\frac{\pi_{ref}(o_{i}|q)}{\pi_{ \theta}(o_{i}|q)}-1.
    $$
Reward Modeling
  • 奖励是训练信号的来源,它决定了强化学习的优化方向
  • 为了训练 DeepSeek-R1-Zero,论文采用了一个基于规则的奖励系统,主要包括两种类型的奖励:
    • 准确性奖励 (Accuracy rewards) :准确性奖励模型评估响应是否正确
      • 例如,对于具有确定性结果的数学问题,模型需要以指定格式(例如,在方框内)提供最终答案,从而能够基于规则可靠地验证正确性
      • 类似地,对于 LeetCode 问题,可以使用编译器根据预定义的测试用例生成反馈
    • 格式奖励 (Format rewards) :除了准确性奖励模型,论文还采用了一个格式奖励模型,强制模型将其思维过程放在 <think> 和 </think> 标签之间
  • 在开发 DeepSeek-R1-Zero 时,论文没有使用基于结果的或基于过程的神经奖励模型(neural reward model)
    • 因为论文发现神经奖励模型在大规模强化学习过程中可能遭受奖励破解(reward hacking)问题 ,并且重新训练奖励模型需要额外的训练资源,并使整个训练流程复杂化
Training Template
  • 为了训练 DeepSeek-R1-Zero,论文首先设计了一个简单的模板,引导基础模型遵循论文指定的指令
  • 如表 1 所示,该模板要求 DeepSeek-R1-Zero 先生成一个推理过程,然后是最终答案
  • 论文有意将约束限制在这种结构格式上,避免任何特定于内容的偏见——例如强制进行反思性推理或推广特定的问题解决策略,以确保论文能够准确观察模型在强化学习过程中的自然进展
Performance, Self-evolution Process and Aha Moment of DeepSeek-R1-Zero
Performance of DeepSeek-R1-Zero
  • 图 2 描绘了 DeepSeek-R1-Zero 在 AIME 2024 基准测试上的性能随强化学习训练过程的变化轨迹
  • 如图所示,随着强化学习训练的进行,DeepSeek-R1-Zero 表现出稳定且一致的性能提升
    • AIME 2024 的平均 pass@1 分数显著增加,从最初的 15.6% 跃升至令人印象深刻的 71.0%,达到了与 OpenAI-o1-0912 相当的性能水平
    • 这一显著改进凸显了论文的强化学习算法在随时间优化模型性能方面的有效性
    • 注:图 2 中 cons@k 是多数投票的结果(cons 表示 consensus,即共识):详情见 NLP——技术报告解读-DeepSeek-R1-Supplements
  • 表 2 提供了 DeepSeek-R1-Zero 与 OpenAI 的 o1-0912 模型在各种推理相关基准测试上的比较分析
    • 研究结果表明,强化学习使 DeepSeek-R1-Zero 能够在不需要任何监督微调数据的情况下获得强大的推理能力
    • 这是一个值得注意的成就,因为它强调了模型仅通过强化学习就能有效学习和泛化的能力
  • 此外,通过应用多数投票(majority voting),可以进一步增强 DeepSeek-R1-Zero 的性能
    • 例如,在 AIME 基准测试上使用多数投票时,DeepSeek-R1-Zero 的性能从 71.0% 提升至 86.7%(图 2 中 cons@16 的结果),从而超过了 OpenAI-o1-0912 的性能
    • DeepSeek-R1-Zero 在有和没有多数投票的情况下都能实现如此有竞争力的性能,这突显了其强大的基础能力及其在推理任务中进一步发展的潜力
Self-evolution Process of DeepSeek-R1-Zero
  • DeepSeek-R1-Zero 的自我进化过程是一个迷人的演示(demonstration),展示了强化学习如何驱动模型自主提高其推理能力
    • 通过直接从基础模型启动强化学习,我们可以在不受监督微调阶段影响的情况下密切监控模型的进展
    • 这种方法清晰地展示了模型随时间演变的过程,特别是在其处理复杂推理任务的能力方面
  • 如图 3 所示,DeepSeek-R1-Zero 的思考时间在整个训练过程中持续改善
    • 这种改进不是外部调整的结果,而是模型内部的内在发展
    • DeepSeek-R1-Zero 自然地获得了通过利用延长的测试时间计算来解决日益复杂的推理任务的能力
    • 这种计算范围从生成数百到数千个推理 Token,使模型能够更深入地探索和完善其思维过程
  • 这种自我进化最显著的方面之一是随着测试时间计算的增加而出现的复杂行为
    • 诸如反思(模型重新审视和重新评估其先前步骤)以及探索替代性问题解决方法等行为自发产生
    • 这些行为不是显式编程的,而是模型与强化学习环境交互的结果
    • 这种自发的发展显著增强了 DeepSeek-R1-Zero 的推理能力,使其能够更高效、更准确地应对更具挑战性的任务
Aha Moment of DeepSeek-R1-Zero
  • 在 DeepSeek-R1-Zero 的训练过程中观察到一个特别有趣的现象是 “顿悟时刻”(aha moment) 的出现
  • 如表 3 所示,这个时刻发生在模型的一个中间版本中
    • 在此阶段,DeepSeek-R1-Zero 学会了通过重新评估其初始方法为问题分配更多的思考时间
    • 这种行为不仅证明了模型不断增长的推理能力,也是强化学习如何导致意外和复杂结果的一个引人入胜的例子
  • 这个时刻不仅是模型的“顿悟时刻”,对观察其行为的研究人员来说也是如此
    • 它强调了强化学习的力量和美感:论文不是明确地教导模型如何解决问题,而是简单地提供正确的激励,它就会自主地发展出高级的问题解决策略
    • “顿悟时刻”有力地提醒论文强化学习在人工智能系统中解锁新智能水平的潜力,为未来更自主和自适应的模型铺平道路
  • 个人理解:后面的一些文章逐步分析并证明,一些顿悟时刻实际上并不是一个突然发生的过程,而是逐步发生的,只是在特定任务上看起来像是突然发生一样
DeepSeek-R1-Zero 的缺点 (Drawback of DeepSeek-R1-Zero)**
  • 尽管 DeepSeek-R1-Zero 表现出强大的推理能力并自主发展出意想不到的强大推理行为,但它也面临几个问题
    • DeepSeek-R1-Zero 存在可读性差和语言混合等挑战
  • 为了使推理过程更具可读性并与开放社区分享,论文探索了 DeepSeek-R1,这是一种利用强化学习和对人类友好的冷启动数据的方法

DeepSeek-R1: Reinforcement Learning with Cold Start

  • 受 DeepSeek-R1-Zero 有希望的结果的启发,两个自然的问题出现了:
    • 1)通过加入少量高质量数据作为冷启动,能否进一步提高推理性能或加速收敛?
    • 2)论文如何训练一个用户友好的模型,不仅能产生清晰连贯的思维链(CoT),还能展现出强大的通用能力?
  • 为了解决这些问题,论文设计了一个训练 DeepSeek-R1 的流程
    • 该流程包括四个阶段,概述如下文所示
  • 补充:来自其他博主制作的非常好的 DeepSeek-R1 训练过程:
    • 注意:根据 DeepSeek-V3 辅助材料给出的结论,下图中存在问题(已补充),DeepSeek-R1 和 DeepSeek-R1-Zero 均是从 DeepSeek-V3-Base 训练而来,图中给的是 DeepSeek-V3 (这是 DeepSeek-V3-Base 的微调版本);部分训练数据(监督微调数据中的非推理类数据)确实来源于 DeepSeek-V3
Cold Start
  • 与 DeepSeek-R1-Zero 不同,为了防止从基础模型开始强化学习训练时早期不稳定的冷启动阶段,对于 DeepSeek-R1,论文构建并收集了少量长思维链(CoT)数据来微调模型,微调后的模型作为初始的强化学习 Actor
  • 为了收集此类数据,论文探索了几种方法:
    • 使用带有长 CoT 示例的少样本提示(few-shot prompting),直接提示模型生成带有反思和验证的详细答案,以可读格式收集 DeepSeek-R1-Zero 的输出,以及通过人工标注员的后处理来细化结果
  • 在这项工作中,论文收集了数千个冷启动数据来微调 DeepSeek-V3-Base,作为强化学习的起点
  • 与 DeepSeek-R1-Zero 相比,冷启动数据的优势包括:
    • 可读性 (Readability) :DeepSeek-R1-Zero 的一个关键限制是其内容通常不适合阅读
      • 响应可能混合多种语言或缺乏用于向用户突出显示答案的 markdown 格式
      • 在为 DeepSeek-R1 创建冷启动数据时,论文设计了一种可读的模式 ,包括在每个响应末尾进行总结 ,并过滤掉对读者不友好的响应
      • 在这里,论文将输出格式定义为 \(|\)special_token\(|\)\(<\)reasoning_process\(>\)\(|\)special_token\(|\)\(<\)summary\(>\),其中推理过程是针对查询的 CoT,而总结(summary)用于总结推理结果
    • 潜力 (Potential) :通过利用人类先验知识精心设计冷启动数据的模式,论文观察到相对于 DeepSeek-R1-Zero 更好的性能
      • 作者相信,对于推理模型来说,迭代训练(iterative training)是一种更好的方式
Reasoning-oriented Reinforcement Learning
  • 在基于冷启动数据对 DeepSeek-V3-Base 进行微调之后,论文应用了与 DeepSeek-R1-Zero 相同的大规模强化学习训练过程
  • 此阶段侧重于增强模型的推理能力 ,特别是在编码、数学、科学和逻辑推理等推理密集型任务中,这些任务涉及具有明确解决方案的明确定义的问题
  • 在训练过程中,论文观察到 CoT 经常出现语言混合 ,特别是当强化学习提示涉及多种语言时
    • 为了缓解语言混合问题,论文在强化学习训练期间引入了语言一致性奖励(language consistency reward),该奖励计算为 CoT 中目标语言单词的比例
    • 尽管消融实验表明这种对齐会导致模型性能略有下降 ,但这种奖励符合人类偏好,使其更具可读性
    • 问题:CoT 不需要让人可以阅读吧?
  • 论文通过直接求和将推理任务的准确性和语言一致性奖励结合起来,形成最终奖励
  • 论文在微调后的模型上应用强化学习训练,直到其在推理任务上达到收敛
Rejection Sampling and Supervised Fine-Tuning
  • 当面向推理的强化学习收敛时,论文利用得到的检查点来为后续轮次收集 SFT 数据
  • 与主要关注推理的初始冷启动数据不同,此阶段合并了来自其他领域的数据 ,以增强模型在写作、角色扮演和其他通用任务中的能力
  • 具体来说,论文按照下述方式生成数据并微调模型(分推理数据和非推理数据)
Reasoning data
  • 论文策划(curate)推理提示词(prompts),并通过从上述强化学习训练的检查点执行拒绝采样(rejection sampling)来生成推理轨迹(trajectories)
  • 在前一阶段,论文只包含了可以使用基于规则的奖励进行评估的数据
  • 在此阶段,论文通过合并额外的数据来扩展数据集,其中一些数据使用生成式奖励模型,通过将真实值(ground-truth)和模型预测输入到 DeepSeek-V3 中进行判断
  • 由于模型输出有时混乱且难以阅读,论文过滤掉了具有混合语言、长段落和代码块的思维链
  • 对于每个提示,论文采样多个响应,并仅保留正确的响应
  • 总共,论文收集了约 60 万个与推理相关的训练样本
Non-Reasoning data
  • 对于非推理数据,例如写作、事实问答(factual QA)、自我认知(self-cognition)和翻译,论文采用 DeepSeek-V3 的流程并重用部分 DeepSeek-V3 的监督微调数据集
  • 对于某些非推理任务,论文调用 DeepSeek-V3,通过在回答问题时进行提示来生成潜在的思维链
  • 但对于更简单的查询 ,例如“你好”,论文不提供 CoT 响应
  • 最终,论文总共收集了约 20 万个与推理无关的训练样本
  • 论文使用上述策划的大约 80 万个样本的数据集 ,对 DeepSeek-V3-Base 进行了两个轮次(epochs)的微调
    • 注:这个 2 个 epochs 的设定和 大约 100W 样本的微调量级,已经成为后来一些文章的标配
Reinforcement Learning for all Scenarios
  • 为了进一步使模型与人类偏好对齐,论文实施了第二轮强化学习阶段,旨在提高模型的有用性(helpfulness)和无害性(harmlessness) ,同时完善其推理能力
  • 论文结合使用奖励信号和多样化的提示词分布(diverse prompt distributions)来训练模型
    • 对于推理数据:
      • 论文遵循 DeepSeek-R1-Zero 中概述的方法论,该方法利用基于规则的奖励来指导数学、代码和逻辑推理领域的学习过程
    • 对于通用数据
      • 论文用奖励模型来捕捉复杂和细微场景中的人类偏好
        • 论文基于 DeepSeek-V3 的流程,并采用了类似的偏好对(preference pairs)和训练提示词分布
      • 对于有用性,论文只关注最终总结(summary),确保评估强调响应对于用户的实用性和相关性,同时最小化对底层推理过程的干扰
      • 对于无害性,论文评估模型的整个响应,包括推理过程和总结,以识别和减轻生成过程中可能出现的任何潜在风险、偏见或有害内容
      • 最终,奖励信号和多样化数据分布的整合使论文能够训练出一个在推理方面表现出色,同时优先考虑有用性和无害性的模型

Distillation: Empower Small Models with Reasoning Capability

  • 为了让更高效的小型模型具备像 DeepSeek-R1 一样的推理能力,论文直接使用 DeepSeek-R1 策划的 80 万个样本(详见 章节2.3.3)对开源模型如 Qwen 和 Llama 进行了微调
  • 论文的研究结果表明,这种简单的蒸馏方法显著增强了小型模型的推理能力
  • 论文这里使用的基础模型是 Qwen2.5-Math-1.5B, Qwen2.5-Math-7B, Qwen2.5-14B, Qwen2.5-32B, Llama-3.1-8B 和 Llama-3.3-70B-Instruct
    • 论文选择 Llama-3.3 是因为其推理能力略优于 Llama-3.1
    • 问题:为什么 70B 量级的模型,选择的不是 Qwen2.5-72B-Instruct ?
  • 对于蒸馏模型 ,论文仅应用了 SFT ,没有包含 RL 阶段 ,尽管加入强化学习可以大幅提升模型性能
    • 论文这里的主要目标是证明蒸馏技术的有效性 ,将强化学习阶段的探索留给更广泛的研究社区

Experiment

Benchmarks

  • 标准测试基准方面:
    • 论文在 MMLU (2020)、MMLU-Redux (2024)、MMLU-Pro (2024)、C-Eval (2023)、CMMLU (2023)、IFFval (2023)、FRAMES (2024)、GPQA Diamond (2023)、SimpleQA (OpenAI)、C-SimpleQA (2024)、SWE-Bench Verified (OpenAI, 2024d)、Aider、LiveCodeBench (2024) (2024-08 - 2025-01)、Codeforces、中国高中数学奥林匹克竞赛(Chinese National High School Mathematics Olympiad, CNMO 2024) 和美国数学邀请赛 2024(American Invitational Mathematics Examination 2024, AIME 2024)(MAA, 2024) 上评估模型
  • 除了标准基准测试外,论文还使用 LLM 作为评判者,在开放式生成任务上评估论文的模型
    • 论文遵循 AlpacaEval 2.0 (2024) 和 Arena-Hard (2024) 的原始配置,它们利用 GPT-4-Turbo-1106 作为配对比较的评判者
    • 在这里,论文仅将最终摘要提供给评估,以避免长度偏差
    • 对于蒸馏模型,论文报告了在 AIME 2024、MATH-500、GPQA Diamond、Codeforces 和 LiveCodeBench 上的代表性结果

Evaluation Prompts

  • 遵循 DeepSeek-V3 的设置,使用 simple-evals 框架提供的提示来评估标准基准测试,如 MMLU、DROP、GPQA Diamond 和 SimpleQA
  • 对于 MMLU-Redux,论文在零样本(zero-shot)设置中采用 Zero-Eval 提示格式 (2024)
  • 对于 MMLU-Pro、C-Eval 和 CLUE-WSC,由于原始提示是少样本(few-shot)的,论文略微修改了提示以适应零样本设置
    • 少样本中的思维链(CoT)可能会损害 DeepSeek-R1 的性能
    • 问题:这里怎么理解?
  • 其他数据集遵循其创建者提供的默认提示的原始评估协议
  • 对于代码和数学基准测试
    • HumanEval-Mul 数据集涵盖了八种主流编程语言(Python、Java、C++、C#、JavaScript、TypeScript、PHP 和 Bash)
    • 使用 CoT 格式评估模型在 LiveCodeBench 上的性能,数据收集时间为 2024 年 8 月至 2025 年 1 月
    • 使用 10 场 Div.2 比赛的题目以及专家精心设计的测试用例来评估 Codeforces 数据集,然后计算预期评分和参赛者百分比
    • 通过无代理框架(agentless framework)(2024) 获得 SWE-Bench 验证结果
    • 使用“diff”格式衡量 AIDER 相关基准测试
  • DeepSeek-R1 的输出在每个基准测试中最多限制为 32,768 个 Token

Baselines

  • 论文对多个强基线模型进行了全面评估,包括 DeepSeek-V3、Claude-Sonnet-3.5-1022、GPT-4o-0513、OpenAI-o1-mini 和 OpenAI-o1-1217
  • 由于在中国大陆访问 OpenAI-o1-1217 API 具有挑战性,论文根据官方报告报告其性能
  • 对于蒸馏模型,论文还比较了开源模型 QwQ-32B-Preview (Qwen)

Evaluation Setup

  • 论文将模型的最大生成长度设置为 32,768 个 Token
  • 论文发现使用贪婪解码(greedy decoding)来评估长输出推理模型会导致更高的重复率和不同检查点之间的显著变异性
    • 因此,论文默认使用 pass@\(k\) 评估 (2021),并使用非零温度(non-zero temperature)报告 pass@1
    • 理解:零温度表示贪婪解码
  • 具体来说,论文使用采样温度 0.6 和 top-\(p\) 值 0.95 为每个问题生成 \(k\) 个回复(通常在 4 到 64 之间,取决于测试集大小)。然后 pass@1 计算为
    $$
    \text{pass@}1=\frac{1}{k}\sum_{i=1}^{k}p_{i}
    $$
    • 其中 \(p_{i}\) 表示第 \(i\) 个回复的正确性
    • 这种方法提供了更可靠的性能估计
  • 对于 AIME 2024,论文还使用 64 个样本报告了共识(consensus),即多数投票(majority vote)结果 (2022),记为 cons@64

DeepSeek-R1 Evaluation

  • 评估结果如表 4 所示:
  • 对于面向教育的知识基准测试,如 MMLU、MMLU-Pro 和 GPQA Diamond,DeepSeek-R1 相较于 DeepSeek-V3 展现出更优越的性能
    • 这一改进主要归功于通过大规模强化学习在 STEM 相关问题上准确率的显著提升
  • DeepSeek-R1 在 FRAMES(一个依赖长上下文的问答任务)上表现卓越,展示了其强大的文档分析能力
    • 这凸显了推理模型在 AI 驱动的搜索和数据分析任务中的潜力
  • 在事实性基准测试 SimpleQA 上,DeepSeek-R1 的表现优于 DeepSeek-V3,证明了其处理基于事实的查询的能力
    • OpenAI-o1 在该基准测试上超越 GPT-4o 也观察到了类似的趋势
    • 由于在安全强化学习(safety RL)后倾向于拒绝回答某些查询 ,DeepSeek-R1 在中文 SimpleQA 基准测试上的表现不如 DeepSeek-V3
    • 若没有安全强化学习,DeepSeek-R1 的准确率可以超过 70%
  • DeepSeek-R1 在 IF-Eval(一个旨在评估模型遵循格式指令能力的基准测试)上也取得了令人印象深刻的结果
    • 这些改进可以归因于在 SFT 和强化学习训练的最后阶段包含了遵循指令的数据
  • 在 AlpacaEval2.0 和 ArenaHard 上观察到了卓越的性能,表明 DeepSeek-R1 在写作任务和开放域问答方面的优势
    • 其显著超越 DeepSeek-V3 的表现凸显了大规模强化学习的泛化益处,它不仅提升了推理能力,还提高了跨不同领域的性能
    • 特别地,DeepSeek-R1 生成的摘要长度简洁,在 ArenaHard 上平均为 689 个 Token,在 AlpacaEval 2.0 上平均为 2218 个字符
      • 这表明 DeepSeek-R1 在基于 GPT 的评估中避免了引入长度偏差,进一步巩固了其在多项任务中的鲁棒性
  • 在数学任务和编码算法任务(如 LiveCodeBench 和 Codeforces)上,DeepSeek-R1 表现出与 OpenAI-o1-1217 相当的性能,大幅超越其他模型
    • 专注于推理的模型在这些基准测试中占据主导地位
  • 特别地,在面向工程的编码任务上 ,OpenAI-o1-1217 在 Aider 上优于 DeepSeek-R1 ,但在 SWE Verified 上取得了相当的性能
    • 作者认为 DeepSeek-R1 的工程性能将在下一个版本中得到改善,因为目前相关的强化学习训练数据量仍然非常有限

Distilled Model Evaluation

  • 如表 5 所示
    • 仅通过蒸馏 DeepSeek-R1 的输出,高效的 DeepSeek-R1-7B(即 DeepSeek-R1-Distill-Qwen-7B,下文类似缩写)就能全面超越如 GPT-4o-0513 这样的非推理模型
    • DeepSeek-R1-14B 在所有评估指标上均超越了 QwQ-32B-Preview,而 DeepSeek-R1-32B 和 DeepSeek-R1-70B 在大多数基准测试上显著超过了 o1-mini
    • 这些结果展示了蒸馏的强大潜力
  • 论文发现对这些蒸馏模型应用强化学习能带来显著的进一步增益
    • 作者认为这值得进一步探索,因此在此仅展示简单 SFT 蒸馏模型的结果

Discussion

Distillation v.s. Reinforcement Learning

  • 在 3.2 节中,我们可以看到通过蒸馏 DeepSeek-R1,小模型能够取得令人印象深刻的结果
    • 但还有一个问题悬而未决:模型能否不通过蒸馏,而是通过论文讨论的大规模强化学习训练达到相当的性能?
  • 为了回答这个问题,论文在 Qwen-32B-Base 上使用数学、代码和 STEM 数据进行了大规模强化学习训练,训练超过 10K 步,得到了 DeepSeek-R1-Zero-Qwen-32B
  • 实验结果如表 6 所示,表明 32B 基础模型在经过大规模强化学习训练后,性能与 QwQ-32B-Preview 相当
    • 但从 DeepSeek-R1 蒸馏得到的 DeepSeek-R1-Distill-Qwen-32B 在所有基准测试上的表现均显著优于 DeepSeek-R1-Zero-Qwen-32B
  • 因此,我们可以得出两个结论:
    • 首先,将更强大的模型蒸馏到较小的模型中能产生优异的结果 ,而依赖论文提到的大规模强化学习的小模型需要巨大的计算能力 ,甚至可能无法达到蒸馏的性能
    • 其次,虽然蒸馏策略既经济又有效 ,但要突破智能的边界可能仍然需要更强大的基础模型和更大规模的强化学习

Unsuccessful Attempts

  • 在开发 DeepSeek-R1 的早期阶段,论文也遇到了一些失败和挫折
    • 论文在此分享论文的失败经验以提供见解,但这并不意味着这些方法无法开发出有效的推理模型
Process Reward Model, PRM
  • PRM 是一种合理的方法,可以指导模型采用更好的方法来解决推理任务 (2023; 2022;);但在实践中,PRM 有三个主要局限性可能阻碍其最终成功
    • 第一,在通用推理中明确定义细粒度的步骤具有挑战性
    • 第二,判断当前中间步骤是否正确是一项艰巨的任务
      • 使用模型进行自动标注可能无法产生令人满意的结果,而手动标注不利于扩大规模
    • 第三,一旦引入基于模型的 PRM,就不可避免地会导致奖励黑客攻击(reward hacking)(2022)
      • 并且重新训练奖励模型需要额外的训练资源,并使整个训练流程复杂化
  • 总之,虽然 PRM 在重排模型生成的 top-N 响应或辅助引导式搜索 (2024) 方面表现出良好的能力,但在论文实验的大规模强化学习过程中,与其引入的额外计算开销相比,其优势有限
    • 问题:如何理解 PRM 在重排模型生成的 top-N 响应或辅助引导式搜索 方面表现出良好的能力?
Monte Carlo Tree Search,MCTS
  • 受 AlphaGo (2017a) 和 AlphaZero (2017b) 的启发,论文探索了使用蒙特卡洛树搜索(MCTS)来增强测试时计算的可扩展性
  • MCTS 方法涉及将答案分解成更小的部分,以便模型能够系统地探索解决方案空间
  • 为了促进这一点,论文提示模型生成多个标签,这些标签对应于搜索所需的特定推理步骤
  • 对于训练:
    • 首先使用收集的提示,通过由预训练价值模型引导的 MCTS 来寻找答案
      • 理解:这里的价值模型决定了每次选择哪些节点进行扩展
    • 随后使用产生的 问题-答案 对来训练行动者模型和价值模型,并迭代地改进这个过程
  • 但这种方法在扩大训练规模时遇到了几个挑战
    • 首先,与搜索空间相对明确的象棋不同,Token 生成呈现出一个指数级更大的搜索空间
      • 为了解决这个问题,论文为每个节点设置了最大扩展限制,但这可能导致模型陷入局部最优
    • 其次,价值模型直接影响生成的质量,因为它指导搜索过程的每一步
      • 训练一个细粒度的价值模型本身就很困难,这使得模型难以迭代改进
      • 虽然 AlphaGo 的核心成功依赖于训练一个价值模型来逐步提高其性能,但由于 Token 生成的复杂性,这一原则在论文的设置中难以复制
  • 总之,虽然 MCTS 在与预训练价值模型配对时可以在推理过程中提高性能,但通过自我搜索迭代地提升模型性能仍然是一个重大挑战

Conclusion, Limitations, and Future Work

  • 在本工作中,论文分享了通过 RL 来增强模型推理能力的探索历程
    • DeepSeek-R1-Zero 代表了一种不依赖冷启动数据(cold-start data)的纯强化学习方法,在各种任务上均取得了强劲的性能
    • DeepSeek-R1 更加强大,它利用了冷启动数据以及迭代式的强化学习微调
    • 最终,DeepSeek-R1 在一系列任务上达到了与 OpenAI-o1-1217 相当的性能
  • 论文进一步探索了将推理能力蒸馏(distillation)到小型稠密模型(small dense models)中的方法
    • 论文使用 DeepSeek-R1 作为教师模型(teacher model)来生成 80 万条训练样本,并对多个小型稠密模型进行了微调
    • 结果令人鼓舞:
      • DeepSeek-R1-Distill-Qwen-1.5B 在数学基准测试中超越了 GPT-4o 和 Claude-3.5-Sonnet,在 AIME 上达到了 28.9%,在 MATH 上达到了 83.9%
      • 其他稠密模型也取得了令人印象深刻的结果,显著超越了基于相同底层检查点(underlying checkpoints)的其他指令微调模型(instruction-tuned models)
  • 未来,论文计划在以下几个方向为 DeepSeek-R1 投入研究
    • 通用能力 (General Capability):
      • 目前,DeepSeek-R1 在函数调用(function calling)、多轮对话(multi-turn)、复杂角色扮演(complex role-playing)和 JSON 输出等任务上的能力尚不及 DeepSeek-V3
      • 接下来,论文计划探索如何利用长思维链(long Chain-of-Thought, CoT)来增强这些领域的任务
    • 语言混合 (Language Mixing):
      • DeepSeek-R1 目前针对中文和英文进行了优化,这可能导致在处理其他语言的查询时出现语言混合问题
      • 例如,即使用户查询使用的不是英文或中文,DeepSeek-R1 也可能使用英文进行推理和回复
      • 论文旨在未来的更新中解决这一局限性
    • 提示工程 (Prompting Engineering):
      • 在评估 DeepSeek-R1 时,论文观察到它对 Prompts 很敏感
      • 少样本提示(few-shot prompting)consistently 会降低其性能
      • 因此,论文建议用户在使用零样本(zero-shot)设置时直接描述问题并指定输出格式 ,以获得最佳结果
    • 软件工程任务 (Software Engineering Tasks):
      • 由于评估时间较长,影响了强化学习过程的效率,大规模强化学习尚未广泛应用于软件工程任务
      • 因此,DeepSeek-R1 在软件工程基准测试中并未显示出相对于 DeepSeek-V3 的巨大改进
      • 未来的版本将通过对软件工程数据实施拒绝采样(rejection sampling),或在强化学习过程中引入异步评估(asynchronous evaluations)来提高效率,从而解决这一问题
1…91011…66
Joe Zhou

Joe Zhou

Stay Hungry. Stay Foolish.

659 posts
53 tags
GitHub E-Mail
© 2026 Joe Zhou
Powered by Hexo
|
Theme — NexT.Gemini v5.1.4