Jiahong 的个人博客

凡事预则立,不预则废


  • Home

  • Tags

  • Archives

  • Navigation

  • Search

RS——Meta-GRs-HSTU

  • 参考链接:
    • (HSTU)Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for Generative Recommendations, ICML 2024, Meta

整体说明

  • 当前推荐系统的特点 :
    • 大规模推荐系统的特点在于其依赖高基数(high cardinality,注:也称为高维)、异构特征(heterogeneous features),并且需要每天处理数百亿用户行为
    • 尽管大多数工业界的深度学习推荐模型(Deep Learning Recommendation Models,DLRMs)利用海量数据和数千个特征进行训练,却未能实现与计算资源的有效扩展(即增加了资源,效果得不到对应的提升)
  • 论文方案一句话说明 :(受 Transformer 在语言和视觉领域成功的启发)
    • 论文将推荐问题重新定义为生成式建模框架下的 Squential Transduction 任务(Generative Recommenders(GRs),生成式推荐器),并提出了一种新架构HSTU ,专为高基数、非平稳的流式推荐数据设计
    • 注:论文的 Squential Transduction 和 Transduction Learning(直推式学习)没有直接关系(虽然二者名字很像);
      • Transduction Learning(直推式学习) :利用训练数据和测试数据的整体信息,直接对见过的测试数据做出预测
      • Inductive Learning(归纳式学习) :利用训练数据的信息,训练一个可以泛化到未见过的测试数据的通用的模型
      • 从定义看,论文本质还是一个Inductive Learning(归纳式学习) ,学到的模型是要使用在未知数据场景的(包括未知用户、未知时间和未知商品)
  • 效果和效率 :
    • 离线效果 :在合成和公开数据集上,HSTU相比基线模型在 NDCG 指标上最高提升 65.8%
    • 效率 :并且在处理 8192 长度序列时比基于 FlashAttention2 的 Transformer 快 5.3 倍至 15.2 倍
    • 在线效果 :基于 HSTU 的生成式推荐器拥有1.5万亿参数,在线A/B测试中指标提升12.4%,已在拥有数十亿用户的互联网平台的多个场景中部署
  • scaling law发现 :更重要的是,生成式推荐器的模型质量随训练计算量呈幂律增长(power-law),扩展范围达到GPT-3/LLaMa-2规模,这减少了未来模型开发所需的 carbon footprint(碳排放总和),并为推荐系统中的首个基础模型铺平了道路
    • 理解:power-law说明模型的算力投入可以很高效的得到回报,所以不需要盲目增加很多计算资源,也就减少了碳排放
  • 注意:论文的方法依然保留了推荐系统中的 Retrieval-Ranking 两阶段(Retrieval也称为召回或 Retrieval 阶段),只是分别将两阶段都使用生成式模型来建模了,对于召回阶段,甚至可以将生成式模型作为一个新的召回通道使用,作为对原始 DLRM 的一种补足

一些背景讨论

  • 传统推荐系统 : SOTA 推荐方法主要基于深度学习推荐模型 (DLRM),其特点在于使用异构特征,包括数值特征(如 counters 和 ratios)、嵌入以及分类特征(如商品ID、用户ID等)
    • 由于每分钟都有新内容和商品加入,特征空间具有极高的基数,通常达到十亿级别
    • 为了利用这些数以万计的特征,DLRM 采用各种神经网络组合特征、转换中间表示并生成最终输出
    • 尽管 DLRM 利用了大量人工设计的特征和海量数据进行训练,但工业界大多数 DLRM 的计算扩展性较差。这一显著限制至今仍未得到解决
  • 论文的思路 :受 Transformer 在语言和视觉领域成功的启发,论文重新审视了现代推荐系统中的基础设计选择,但作者观察到:在十亿用户规模下,这需要克服三个挑战 :
    • 推荐系统中的特征缺乏明确结构 :虽然在小规模场景中已探索过序列化方法(附录B详细讨论),但工业级 DLRM 中异构特征(包括高基数ID、交叉特征、计数器、比率等)起着关键作用
    • 持续变化的十亿级词汇表 :与语言模型中十万级静态词汇表相比,十亿级动态词汇表带来了训练挑战,并由于需要考虑数万个目标感知(target-aware)候选而推高了推理成本
      • 目标感知(target-aware)指生成用户表示或进行预测时,能够动态地结合当前候选内容(target item)的信息,从而更精准地建模用户与特定候选内容之间的交互
      • 注:传统推荐系统(如DLRMs)中,target-aware通常指在特征交互阶段显式引入候选内容(target item)的信息
    • 计算成本大 :(这是大规模序列模型落地的主要瓶颈),GPT-3使用数千个GPU在1-2个月内处理了总计3000亿 token。这一规模看似惊人,但与用户行为规模相比则相形见绌。最大规模的互联网平台每天服务数十亿日活用户,用户每天与数十亿帖子、图片和视频互动。用户序列长度可达 \(10^5\) 。因此,推荐系统每天需要处理的 token 数量比语言模型1-2个月处理的量高出几个数量级
  • 一个特别重要的创新点 :在本工作中,论文将用户行为视为生成式建模中的新模态(在图片、文本和视频等多模态的基础上,增加一个模态叫做用户行为)。论文的 key insights 是:
    • a)给定适当的新特征空间,工业级推荐系统中的核心排序和召回任务可被转化为生成式建模问题;
    • b)这一范式使论文能够系统性地利用特征、训练和推理中的冗余(redundancies)来提高效率,基于这一新范式,论文部署的模型计算复杂度比之前最先进技术高出三个数量级(如图1所示),同时核心业务指标(英文中称为topline metrics)提升了12.4%
      • 问题:如何理解”特征、训练和推理中的冗余(redundancies)“这句话?
  • 论文的贡献可总结如下:
    • 第一 :提出生成式推荐器(Generative Recommenders,GRs) ,这一新范式将取代 DLRM
      • 论文将 DLRM 中的异构特征空间序列化并统一(注:用户行为视为特定的一个新模态),当序列长度趋近无穷时,新方法可近似完整的 DLRM 特征空间。使得:
        • 论文能将主要推荐问题(排序和召回)重新定义为GR中的纯 Squential Transduction 任务
        • 论文模型训练能以序列化、生成式的方式进行,从而允许论文在相同计算量下处理多几个数量级的数据
    • 第二 :解决训练和推理过程中的计算成本挑战 :
      • 提出新结构 :提出新的 Squential Transduction 架构——分层 Squential Transduction 单元(Hierarchical Sequential Transduction Units,HSTU)。HSTU针对大规模非平稳词汇表修改了注意力机制 ,并利用推荐数据集特性实现比基于 FlashAttention2 的 Transformer 在 8192 长度序列上快5.3倍至15.2倍
      • 提出新算法M-FALCON :通过新算法M-FALCON完全平摊计算成本,论文能在传统 DLRM 使用的相同推理预算下,服务复杂度高出285倍的GR模型,同时实现1.50-2.99倍的加速
    • 第三 :验证所提技术在合成数据集、公开数据集以及在拥有数十亿日活用户的大型互联网平台多个场景中的部署效果
      • 据论文所知,论文的工作首次展示了纯 Squential Transduction 架构(如HSTU)在生成式设置(GRs)中显著优于工业级大规模 DLRM。值得注意的是,论文不仅克服了传统 DLRM 中已知的扩展瓶颈,还进一步证明了扩展定律(scaling law)适用于推荐系统 ,这可能是推荐系统的”ChatGPT时刻“

推荐作为 Squential Transduction 任务:从 DLRM 到GR

统一 DLRM 中的异构特征空间

  • 现代 DLRM 模型通常使用大量分类(sparse)和数值(dense)特征进行训练。在GR中,论文将这些特征整合并编码为统一的时间序列 ,如图2所示
  • 分类(sparse)特征 :这类特征的例子包括用户喜欢的商品、用户关注的某类别(如户外)创作者、用户语言、用户加入的社区、请求发起的城市等。论文按以下方式将这些特征序列化:
    • 选择最长的时间序列(通常通过合并代表用户互动商品的特征)作为主时间序列
    • 其余特征通常是随时间缓慢变化的序列(如人口统计信息或关注的创作者),论文通过保留每个连续段的最早条目(entry)来压缩这些时间序列,然后将结果合并到主时间序列中。由于这些时间序列变化非常缓慢,这种方法不会显著增加总体序列长度
      • 理解:如图2所示,auxiliary time series 1 和 auxiliary time series 2 中,分别只有每个连续段的第一个entry被插入到 main time series 中(注:auxiliary time series 1 包含多个连续段,每个连续段的第一个 entry 都会保留并插入 main time series 中)
    • 这部分写得不是很清晰,有点晦涩,原文介绍如下:

      Categorical (‘sparse’) features. Examples of such features include items that user liked, creators in a category (e.g., Outdoors) that user is following, user languages, communities that user joined, cities from which requests were initiated, etc. We sequentialize these features as follows. We first select the longest time series, typically by merging the features that represent items user engaged with, as the main time series. The remaining features are generally time series that slowly change over time, such as demographics or followed creators. We compress these time series by keeping the earliest entry per consecutive segment and then merge the results into the main time series. Given these time series change very slowly, this approach does not significantly increase the overall sequence length.

  • 数值(dense)特征 :这类特征的例子包括加权和衰减计数器、比率等(例如,某个特征可能代表用户过去对匹配特定主题商品的点击率(CTR))
    • 与分类特征相比,这些特征变化更频繁 ,可能随每个(用户,商品)互动而变化,所以从计算和存储角度看,完全序列化这些特征并不可行
    • 然而一个重要观察是:论文执行这些聚合所基于的分类特征(如商品主题、位置)在GR中已被序列化和编码。因此 ,给定足够强大的 Squential Transduction 架构与 target-aware 公式相结合,当增加GR中的总体序列长度和计算量时,论文可以移除数值特征 ,因为它们能被有效地捕捉

将排序和召回重新定义为 Squential Transduction 任务

  • 给定按时间顺序排列的 \(n\) 个 token 列表 \(x_{0},x_{1},\ldots,x_{n-1}\) (\(x_{i}\in\mathbb{X}\)),观察到这些 token 的时间 \(t_{0},t_{1},\ldots,t_{n-1}\) ,Squential Transduction 任务将此输入序列映射到输出 token \(y_{0},y_{1},\ldots,y_{n-1}\) (\(y_{i}\in\mathbb{X}\cup\{\varnothing\}\)),其中 \(y_{i}=\varnothing\) 表示 \(y_{i}\) 未定义
  • 论文用 \(\Phi_{i}\in\mathbb{X}_{c}\) (\(\mathbb{X}_{c}\subseteq\mathbb{X}\))表示系统向用户提供的内容(即历史上展示给用户看过的内容 ,如图片或视频)。由于不断有新内容产生 , \(\mathbb{X}_{c}\) 和 \(\mathbb{X}\) 是非平稳的。用户可以用某个行为 \(a_{i}\) (如点赞、跳过、视频完播+分享) \(a_{i}\in\mathbb{X}\) 来回应 \(\Phi_{i}\) 。论文用 \(n_{c}\) 表示用户互动过的内容总数
    • 对 Retrieval 输入的理解:输入是 \(\{(\Phi_{0},a_0),(\Phi_{1},a_1),\cdots,(\Phi_{n_c-1},a_{n_c-1})\}\) ,表示历史上依次给用户分别展示了内容 \(\Phi_{i}\) 以后用户的行为 \(a_{i}\) 的元组序列
    • 对 Retrieval 输出的理解:每个元组 \((\Phi_{i},a_i)\) 对应输出一个目标值 \(\Phi_{i}^{\prime}\) :
      $$\Phi_{i}^{\prime}=
      \begin{cases}
      \Phi_{i}& \ a_i \ \text{ is positive}\\
      \varnothing& \ \text{ otherwise}
      \end{cases}$$
    • 对 Ranking 输入的理解:如表1所示,对输入 \(x_i\) 的建模是, \(\{\Phi_{0},a_0,\Phi_{1},a_1,\cdots,\Phi_{n_c-1},a_{n_c-1}\}\) 是历史行为序列,表示历史上依次给用户分别展示了内容 \(\Phi_{i}\) 以后用户的行为 \(a_{i}\)
    • 对 Ranking 输出的理解:如表1所示,输出和输入依次对应,输入为 \(\Phi_{i}\) 时预估目标是用户的行为 \(a_{i}\)
  • 在因果自回归设置中,标准排序和 Retrieval 任务可定义为 Squential Transduction 任务(表1)。论文得出以下观察:
    • 召回(Retrieval) :在推荐系统的召回阶段,论文学习 \(\Phi_{i+1}\in\mathbb{X}_{c}\) 上的分布 \(p(\Phi_{i+1}|u_{i})\) ,其中 \(u_{i}\) 是用户 \(i\) 的 token 表示。典型目标是选择 \(\arg\max_{\Phi\in\mathbb{X}_{c} }p(\Phi|u_{i})\) 以最大化某些奖励。这与标准自回归设置有两个不同:首先, \(x_{i}\) 的监督信号 \(y_{i}\) 不一定是 \(\Phi_{i+1}\) ,因为用户可能对 \(\Phi_{i+1}\) 做出负面回应。其次,当 \(x_{i+1}\) 代表非互动相关的分类特征(如人口统计信息)时, \(y_{i}\) 未定义
    • 排序(Ranking) :GR中的 Ranking 任务带来独特挑战,因为工业推荐系统通常需要“目标感知”(target-aware)公式。在此设置中,目标 \(\Phi_{i+1}\) 与历史特征的”交互”(interaction)需要尽早发生,这在标准自回归设置中不可行(理解:不建模动作时,”交互”发生较晚)。论文通过表1中的交错排列商品和行为来解决这一问题,使得 Ranking 任务可被公式化为 \(p(a_{i+1}|\Phi_{0},a_{0},\Phi_{1},a_{1},\ldots,\Phi_{i+1})\) (在分类特征之前)。实践中论文使用小型神经网络将在 \(\Phi_{i+1}\) 的输出转换为多任务预测。重要的是,这使论文能在单次传递中对所有 \(n_{c}\) 次互动应用 target-aware 交叉注意力
  • 补充:对论文 target-aware 实现方式的理解 :其实 Transformer 的 Attention 也有交叉功能,只要输入端有目标 item 和用户历史交互 item 即可,但论文所说的传统的自回归模型是指输入侧不包含目标 item 的情况,论文中,将动作也建模进去,则在输出对目标 item 的动作 token 时,自然就需要将目标 item 作为输入,从而也就实现了 target-aware 交叉注意力:
    • 传统自回归模型预估目标为 : \(p(\Phi_{i+1}|\Phi_{0},\Phi_{1},\ldots,\Phi_{i})\),目标物品 \(\Phi_{i+1}\) 与其他历史序列无交叉
      • 注:此时动作 \(a_i\) 信息可能会作为额外表征加入到 \(\Phi_{i}\)中,所以 表述为 \(p(\Phi_{i}|(\Phi_{0}, a_{0}), …, (\Phi_{i - 1}, a_{i - 1}))\) 也可以
    • 论文预估目标为 : \(p(a_{i+1}|\Phi_{0},a_{0},\Phi_{1},a_{1},\ldots,\Phi_{i+1})\),目标物品 \(\Phi_{i+1}\) 与其他历史序列有交叉

生成式训练

  • 工业推荐系统通常在流式设置中训练,其中每个样本在可用时被顺序处理。在此设置中,基于自注意力的 Squential Transduction 架构(如 Transformer)的总计算需求量级为 \(\sum_{i}n_{i}(n_{i}^{2}d+n_{i}d_{ff}d)\) ,其中 \(n_{i}\) 是用户 \(i\) 的 token 数量, \(d\) 是嵌入维度,设 \(N=\max_{i}n_{i}\) ,总体时间复杂度降至 \(O(N^{3}d+N^{2}d^{2})\) ,这对推荐场景来说成本过高
    • 注:更多对公示的理解看笔者补充的附录
  • 为应对训练 Squential Transduction 模型处理长序列的挑战,论文从传统曝光级(impression-level)训练转向生成式训练,将计算复杂度降低 \(O(N)\) 因子,如图2顶部所示。这样做可将编码器成本平摊到多个目标上。具体来说,当论文以 \(s_{u}(n_{i})\) 的速率采样第 \(i\) 个用户时,总训练成本现在按 \(\sum_{i}s_{u}(n_{i})n_{i}(n_{i}^{2}d+n_{i}d^{2})\) 缩放,通过设 \(s_{u}(n_{i})\) 为 \(1/n_{i}\) 可降至 \(O(N^{2}d+Nd^{2})\) 。在工业级系统中实现此采样的一种方法是在用户请求或会话结束时发出训练样本,使得 \(\hat{s_{u} }(n_{i})\propto 1/n_{i}\)
    • 理解:从「曝光级(impression-level)训练转向生成式训练」主要是强调从以前每个曝光都是一个样本,现在一个用户的所有行为组成一个样本
    • 问题:这里对用户进行采样不会带来效果损失吗?模型看不到部分用户是OK的吗?

面向生成式推荐的高性能自注意力编码器(A High Performance Self-Attention Encoder for Generative Recommendations)

  • 为了将生成式推荐系统(GRs)扩展至具有大规模非静态词表的工业级推荐场景,论文提出了一种新型编码器设计——分层序列转导单元(Hierarchical Sequential Transduction Unit, HSTU)

  • HSTU由多个通过残差连接堆叠的相同层构成,每层包含三个子层:Pointwise Projection(公式1)、Spatial Aggregation(公式2)和 Pointwise Transformation(公式3):

$$ U(X), V(X), Q(X), K(X) = \text{Split}(\phi_1(f_1(X))) \tag{1} $$

$$ A(X)V(X) = \phi_2\left(Q(X)K(X)^T + \text{rab}^{p,t}\right)V(X) \tag{2} $$

$$ Y(X) = f_2\left(\text{Norm}\left(A(X)V(X)\right) \odot U(X)\right) \tag{3} $$

  • 其中:
    • \( f_i(X) \) 表示多层感知机(MLP);为降低计算复杂度, \( f_1 \) 和 \( f_2 \) 使用单线性层 \( f_i(X) = W_i(X) + b_i \)
      • 注:还通过融合内核批量处理查询 \( Q(X) \) 、键 \( K(X) \) 、值 \( V(X) \) 和门控权重 \( U(X) \) (即一次性将Q,K,V和Q映射都做完)
    • \( \phi_1 \) 和 \( \phi_2 \) 为非线性激活函数,均采用SiLU(即Swish);
    • Norm 表示 LayerNorm
    • \( \text{rab}^{p,t} \) 为结合位置(\( p \))和时间(\( t \))信息的相对注意力偏置
  • HSTU的编码器设计允许用单一模块化块替代DLRMs中的异构模块。论文观察到DLRMs实际包含三个阶段:特征提取、特征交互和表征变换
    • 1:特征提取(Feature Extraction)阶段通过池化操作获取类别特征的嵌入表示,其高级形式可泛化为成对注意力和 target-aware 池化
      • 如 HSTU 层所实现,能够做到上述能力
    • 2:特征交互(Feature Interaction)是DLRMs的核心,常用方法包括因子分解机及其神经网络变体、高阶特征交互等
      • HSTU通过注意力池化特征直接与其他特征交互(即 \(\text{Norm}(A(X)V(X)) \odot U(X)\))替代传统特征交互模块
      • 该设计源于学习型MLP难以近似点积的挑战(Rendle等, 2020; Zhai等, 2023a),由于SILU作用于 \(U(X)\),此结构也可视为SwiGLU(Shazeer, 2020)的变体
    • 3:表征变换(Transformations of Representations)阶段通常采用 MoE 和路由机制处理异构用户群体,其核心思想是通过子网络 specialization 实现条件计算
      • HSTU中的逐点乘积操作(Element-wise dot products)本质上能在归一化因子范围内模拟MoE的门控机制
  • 以上实现中可以看出,HSTU Layer 实现了类似 Transformer Layer 的功能
    • \(X\) 是 HSTU 的输入, \(Y(X)\) 是 HSTU 的输出
    • \(A(X)\) 为输入 \(X\) 的注意力向量(注:这里没有做归一化),是带位置偏置 \(\text{rab}^{p,t}\) 的 \(Q,K\) 内积
  • HSTU的一层 与传统的 Transformer 相比,HSTU有以下不同:
    • HSTU 增加了一个 \(U(X)\),跟Q,K,V的 projection 在同一层处理,\(U(X)\) 会作为一个 point-wise的向量与 HSTU Attention 的结果相乘
    • HSTU 没有使用传统 Transformer 中的 Softmax Attention机制,放弃了 Attention 权重和为1的约束,文中提到,这能保留用户对 item 的动作强度信息(Softmax会使得这部分信息失真)
    • HSTU 没有使用 FFN 层,这一层被 \(f_2\left(\text{Norm}\left(A(X)V(X)\right) \odot U(X)\right)\) 替换了,实现了将 MLP 替换为 point-wise 乘法,速度更快
      • 理解1:这里更像是使用 \(U(X)\) 保留了原始特征,和 Attention 后的高阶交叉特征 point-wise 相乘,得到高阶交叉特征同时也保留低阶交叉;
      • 理解2:这里的操作可以看做是 对 \(U(X)\) 进行加权,如前文所述,像是门网络/个性化加权机制实现对原始 Embedding \(U(X)\) 的加权(可类比于MoE的思想)

逐点聚合注意力(Pointwise aggregated attention)

  • HSTU采用了一种新型的Pointwise aggregated attention(区别于传统的Softmax注意力)。其动机有二:
    • 第一:与目标相关的历史数据点数量是用户偏好强度的关键特征,而Softmax归一化会削弱这一信息(Softmax的归一化会将强度信息也归一化为比例,这会使得绝对值失真)
    • 第二:Softmax对噪声鲁棒,但在流式非静态词表场景中表现欠佳(频繁增加新的 item 会影响 Softmax)
  • 实验表明,在合成数据(基于 Dirichlet 过程生成的非静态词表流式数据)中,逐点聚合注意力机制相比Softmax注意力可将Hit Rate@10提升高达44.7%(见表2)

利用算法增加稀疏性

  • 在推荐系统中,用户历史序列的长度通常呈现偏态分布(skewed distribution,这里主要指大部分用户交互序列短,少部分用户交互序列长),导致输入序列具有稀疏性,尤其是在处理超长序列时。这种稀疏性可以被有效利用以显著提升编码器的效率。为此,论文开发了一种高效的GPU注意力核函数,其设计类似于[Dao等, 2022; Zhai等, 2023b],但实现了完全分组的注意力计算。这本质上将注意力计算转化为不同尺寸的分组GEMM(通用矩阵乘法)运算(详见附录G)。因此,HSTU中的自注意力变为内存受限(memory-bound)操作,其内存访问复杂度为 \(\Theta(\sum_{i}n_{i}^{2}d_{qk}^{2}R^{-1})\) ,其中 \(n_i\) 为样本 \(i\) 的序列长度, \(d_{qk}\) 为注意力维度, \(R\) 为寄存器大小。仅此一项改进即可带来2-5倍的吞吐量提升(详见第4.2节讨论)
  • 论文进一步通过随机长度(Stochastic Length, SL)算法增加用户历史序列的稀疏性。推荐场景中用户历史序列的一个关键特征是用户行为在时间上具有重复性,这种重复性体现在用户交互历史的多尺度行为模式中。这为论文提供了在不损失模型质量的前提下人为增加稀疏性的机会,从而显著降低编码器计算成本(其复杂度为 \(\Theta(\sum_{i}n_{i}^{2})\))
    • 问题:这种降低稀疏性的方式是有损的吧?此外,相同算力消耗下,随机长度是不是不如仅保留最近多个 item 的策略?
  • 假设用户 \(j\) 的历史序列为 \((x_i)_{i=0}^{n_{c,j} }\) ,其中 \(n_{c,j}\) 为用户交互过的内容数量。令 \(N_c = \max_j n_{c,j}\) , \((x_{i_k})_{k=0}^L\) 为从1原始序列 \((x_i)_{i=0}^{n_{c,j} }\) 中选取的长度为 \(L\) 的子序列。SL算法按以下方式选择输入序列:
    $$
    \begin{align}
    (x_i)_{i=0}^{n_{c,j} } \text{ if } n_{c,j} &\leq N_c^{\alpha/2}\\
    (x_{i_k})_{k=0}^{N_c^{\alpha/2} } \text{ if } n_{c,j} &> N_c^{\alpha/2}, \text{w/ probability } 1 - N_c^{\alpha}/n_{c,j}^2\\
    (x_i)_{i=0}^{n_{c,j} } \text{ if } n_{c,j} &> N_c^{\alpha/2}, \text{w/ probability } N_c^{\alpha}/n_{c,j}^2 \tag{4}
    \end{align}
    $$
    • 其中 “w/ probability xx” 是 “with probability”的简写,表示以 “xx” 概率采样
    • \(N_c^{\alpha/2}\) 表示 \(N_c\) 的 \(\alpha/2\) 次方, \(\alpha\) 是一个调节采样比例的数字
  • 公式4描述了Stochastic Length (SL)算法如何根据用户历史序列的长度(\(n_{c,j}\))动态调整子序列的采样策略:
    • 如果序列长度 \(n_{c,j} \leq N_c^{\alpha/2}\) ,则:
      • 直接使用完整序列
    • 如果序列长度 \(n_{c,j} > N_c^{\alpha/2}\) ,则
      • 以概率 \(1 - N_c^\alpha / n_{c,j}^2\) 选择长度为 \(N_c^{\alpha/2}\) 的子序列
      • 以概率 \(N_c^\alpha / n_{c,j}^2\) 仍使用完整序列
  • 该算法将注意力相关复杂度降低至 \(O(N_c^{\alpha}d) = O(N^{\alpha}d)\) ,其中 \(\alpha \in (1,2]\) 。关于子序列选择的更详细讨论见附录F.1。值得注意的是,将SL应用于训练能够实现高性价比的系统设计,因为训练的计算成本通常远高于推理
  • 表3展示了在具有30天用户历史的典型工业规模配置下,不同序列长度和 \(\alpha\) 值对应的稀疏性(定义见附录F)。其中对模型质量影响可忽略的设置用下划线和蓝色标出。标为“ \(\alpha=2.0\) ”的行表示未应用SL时的基准稀疏性。更低的 \(\alpha\) 值适用于更长的序列,论文测试的最长序列长度为8,192

最小化激活内存使用

  • 在推荐系统中,使用大的 Batch Size 训练对提升训练吞吐量和模型质量至关重要(理解:大的 Batch Size 更容易得到准确的梯度),因此,激活内存的使用成为扩展的主要瓶颈,这与通常以小的 Batch Size 训练且以参数内存为主的大型语言模型形成鲜明对比
  • 与 Transformer 相比,HSTU采用了一种简化且完全融合的设计,显著降低了激活内存的使用
    • 首先,HSTU将注意力外的线性层数量从 6 个减少到 2 个,这与近期通过逐元素门控减少MLP计算的工作[Hua等, 2022; Zhai等, 2023b]一致
    • 其次,HSTU将计算过程激进地融合为单一算子,包括公式(1)中的 \(\phi_1(f_1(\cdot))\) ,以及公式(3)中的层归一化、可选Dropout和输出MLP。这种简化设计将每层的激活内存使用量降低至 \(2d + 2d + 4hd_{qk} + 4hd_v + 2hd_v = 14d\) (以bfloat16格式存储)
    • 理解:这里的减少线性层数量主要是通过移除传统的FFN来实现(传统的一个FFN包含两个线性层),HSTU总体仅保留两个线性层,即输入层和输出层
    • 注:文章这里说的将线性层从 6 层降低到 2 层的表达应该是跳过了一些设计细节,比如网络结构中FFN的数量未明确
  • 作为对比, Transformer 在注意力后使用前馈层和Dropout(中间状态为 \(3hd_v\)),随后是一个逐点前馈块,包含层归一化、线性层、激活函数、线性层和Dropout,中间状态为 \(2d + 4d_{ff} + 2d + 1d = 4d + 4d_{ff}\) 。这里论文假设 \(hd_v \geq d\) 且 \(d_{ff} = 4d\) [Vaswani等, 2017; Touvron等, 2023a]。因此,在考虑输入和输入层归一化(\(4d\))以及QKV投影后,总激活状态为 \(33d\) 。HSTU的设计使其能够扩展到比 Transformer 深2倍以上的网络
  • 此外,用于表示词汇表的大规模原子ID也需要大量内存。对于包含100亿词汇、512维嵌入和Adam优化器的配置,仅存储嵌入和优化器状态(以fp32格式)就需要60TB内存。为缓解内存压力,论文采用了 row-wise AdamW 优化器并将优化器状态放置在DRAM中,从而将每个浮点数的HBM使用量从12字节减少到2字节
    • 注:论文给的引用中没有详细介绍 row-wise AdamW 优化器
    • HBM(High Bandwidth Memory,高带宽内存)和 DRAM(Dynamic Random Access Memory,动态随机存取存储器)都是计算机内存技术,HBM 将多个 DRAM 芯片堆叠在一起(HBM可以理解为更为昂贵的告诉内存),拥有比 DRAM 更高的数据传输性能
  • 总结来看 :HSTU通过以下设计显著降低内存占用:
    • 第一:将注意力外的线性层从6层减至2层,采用逐点门控减少MLP计算;
    • 第二:将计算融合为单一算子(如公式1中的 \( \phi_1(f_1(\cdot)) \) 和公式3中的层归一化、Dropout及输出MLP);
    • 第三:使用row-wise AdamW优化器,将优化器状态存储于DRAM,使每个浮点数在HBM的占用从12字节降至2字节
    • 最终,与 Transformer 相比,HSTU的激活内存占用从每层33d降至14d(bfloat16),支持构建2倍更深的网络

通过成本分摊扩展推理规模

  • 论文需要解决的最后一个挑战是推荐系统在服务时需要处理大量候选内容。论文主要关注 Ranking 任务 :
  • 在Retrieval 任务(无需论文关注)中,编码器成本完全可以分摊,且已有高效的算法支持基于量化(quantization)、哈希(hashing)或分区(partitioning)的最大内积搜索(MIPS),以及基于束搜索或分层检索的 non-MIPS 场景
    • MIPS 是 Maximum Inner Product Search(最大内积搜索)的缩写,主要用于高效地找到与给定查询向量内积最大的候选向量 ,在召回阶段,MIPS可以与量化、哈希或分区等技术结合使用,以提升效率
  • 对于Ranking 任务 ,候选内容数量可能高达数万。论文提出了一种算法M-FALCON(Microbatched-Fast Attention Leveraging Cacheable Operations)(可译为微批次-快速注意力利用可缓存操作),用于在输入序列长度为 \(n\) 时对 \(m\) 个候选内容进行推理
    • 批量推理 :在一次前向传播中,M-FALCON通过修改注意力掩码和 \(\text{rab}^{p,t}\) 偏置,使得 \(b_m\) 个候选内容的注意力操作完全一致,从而并行处理 \(b_m\) 个候选内容,此时原本的 \(b_m\) 次长度为 \(n\) 的推理转换为 1 次 \(b_m+n\) 的推理(注:这里有MTP的味道)
      • 这将交叉注意力的计算成本从 \(O(b_m n^2 d)\) 降低到 \(O((n + b_m)^2 d) = O(n^2 d)\) (当 \(b_m\) 相对于 \(n\) 可视为小常数时)
    • KV Caching机制 :论文还可以将全部 \(m\) 个候选内容划分为 \(\lceil m/b_m \rceil\) 个大小为 \(b_m\) 的微批次,以利用编码器级的KV缓存 ,从而在跨前向传播时降低成本,或在跨请求时最小化尾部延迟
    • 关于M-FALCON方法的更多细节见附录H
  • 总体而言,M-FALCON使得模型复杂度能够与传统DLRM排序阶段的候选数量线性扩展。论文成功地在典型的排序配置中(见第4.3节)应用了复杂度高28倍的目标感知交叉注意力模型 ,同时在恒定推理预算下实现了1.5倍至3倍的吞吐量提升
  • M-FALCON方法的整体结构图如下(From 附录):

Experiments

验证HSTU编码器的归纳假设

传统序列化设置
  • 论文首先在两个流行的推荐数据集(MovieLens和Amazon Reviews)上评估HSTU的性能。实验遵循文献中的序列化推荐设置,包括完全打乱和多轮次训练。基线采用 SOTA Transformer 实现SASRec(Kang & McAuley, 2018)。与近期工作(Dallmann等, 2021; Zhai等, 2023a)一致,论文报告整个语料库的Hit Rate@K和NDCG@K
  • 结果如表4所示。“SASRec (2023)”表示(Zhai等, 2023a)中报告的最佳SASRec配置。“HSTU”行使用与SASRec相同的配置(层数、头数等)。“HSTU-large”表示更大的HSTU编码器(层数4倍,头数2倍)。结果表明:
    • a)HSTU因其针对推荐优化的设计,在相同配置下显著优于基线;
    • b)HSTU在扩展后性能进一步提升
  • 需要注意的是,此处使用的评估方法与工业级设置差异显著,因为完全打乱和多轮次训练在工业级流式设置中通常不实用(Liu等, 2022)
工业级流式设置
  • 接下来,论文在工业级数据集的流式设置中比较HSTU、消融版HSTU和 Transformer 的性能。本节剩余部分中,论文报告排名的归一化熵(Normalized Entropy,NE)(详细定义见附录)。模型训练超过1000亿样本(DLRM等效),每个任务使用64-256块H100 GPU(!!!)。由于排名是多任务设置,论文报告主要互动事件(“E-Task”)和主要消费事件(“C-Task”)。在论文的上下文中,NE降低0.001通常意味着数十亿用户的顶层指标提升0.5%。对于 Retrieval 任务,由于设置与语言建模类似,论文报告对数困惑度。在小规模设置中固定编码器参数(排名任务: \(l=3\) , \(n=2048\) , \(d=512\) ; Retrieval 任务: \(l=6\) , \(n=512\) , \(d=256\)),并因资源限制对其他超参数进行网格搜索
  • 结果如表5所示:
    • 第一 :HSTU显著优于 Transformer,尤其是在排名任务中,这可能是由于点注意力机制和改进的相对注意力偏差
    • 第二 :消融版HSTU与完整HSTU之间的差距验证了论文的设计有效性。Softmax版HSTU和 Transformer 的最佳学习率比其他模型低约10倍,这是为了训练稳定性。即使采用较低学习率和预归一化残差连接(Xiong等, 2020),标准 Transformer 在排名任务中仍频繁出现损失爆炸
    • 第三 :HSTU优于LLM中流行的 Transformer 变体 Transformer ++(Touvron等, 2023b), Transformer ++ 使用了RoPE、SwiGLU等技术。总体而言,在此小规模设置中,HSTU在质量上表现更优,同时训练速度提升1.5-2倍,HBM使用减少50%

编码器效率

  • 随机长度(Stochastic Length, SL)。图4和图5(a)展示了随机长度(SL)对模型指标的影响
    • 当 \(\alpha=1.6\) 时,长度为 4096 的序列在大多数情况下被压缩为 776,移除了超过 80% 的 token。即使 sparse 率增加到 64%-84%,主要任务的NE下降不超过 0.002(0.2%)
    • 以上这一证据表明,对于合适的 \(\alpha\) 值,SL 不会对模型质量产生负面影响,同时可以通过高 sparse 性降低训练成本。论文还在附录F.3中验证了 SL 显著优于现有的长度外推技术
  • 编码器效率。图5比较了HSTU和 Transformer 编码器在训练和推理设置中的效率
    • 对于 Transformer,论文使用 SOTA FlashAttention-2(Dao, 2023)实现
    • 考虑序列长度从 1,024 到 8,192,并在训练中应用随机长度(SL)
    • 评估时,HSTU和 Transformer 使用相同配置(\(d=512\) , \(h=8\) , \(d_{qk}=64\)),并消融相对注意力偏差(如第4.1.2节所示,HSTU在不使用rab \(^{p,t}\) 时仍优于 Transformer)。在NVIDIA H100 GPU上以bfloat16比较编码器级性能。总体而言,HSTU在训练和推理中的效率分别比 Transformer 高15.2倍和5.6倍
  • 此外,如第3.3节讨论的激活内存使用减少,使论文能够构建比 Transformer 深2倍以上的HSTU网络

工业级流式设置中生成式推荐器(GR)与 DLRM 的比较

  • 最后,论文比较了GR与 SOTA DLRM 基线在工业级流式设置中的端到端性能。论文的GR实现反映了生产中使用的典型配置,而 DLRM 设置则是数百人多年迭代的结果
    • 由于推荐系统的召回阶段使用了多个生成器(理解:这里应该是指多路召回),论文报告了将 GR 作为新增召回通道(“new source”)和替换现有主要 DLRM 通道(“replace”)的在线结果。表6和表7显示,GR不仅在离线测试中显著优于 DLRM,还在A/B测试中带来了12.4%的提升
    • 从表6中可知,作为一个召回通道加入能得到的更好的效果
  • 如第2节所述,GR基于原始分类互动特征构建,而 DLRM 通常使用更多手工特征(handcrafted features)。如果论文将GR使用的相同特征集提供给 DLRM (“DLRM(abl. features)”,消融特征),DLRM 的性能会显著下降,这表明GR通过其架构和统一特征空间可以有效地捕捉这些特征
  • 论文进一步通过与传统序列推荐器设置的比较验证了第2.2节中的GR公式。传统设置仅考虑用户交互过的物品(Kang & McAuley, 2018),结果显示其性能显著下降(表6和表7中的“GR(interactions only)”行)。论文还包含了一个仅使用内容特征的GR基线(“GR(content-based)”)。内容基线与 DLRM /GR之间的巨大性能差距凸显了高基数用户行为的重要性
  • 图6比较了GR与生产 DLRM 的效率。尽管GR模型的复杂度是 DLRM 的285倍(其中 \(285\times \text{FLOPs}\) 表示其复杂度是DLRM的285倍),但由于HSTU和第3.4节中的M-FALCON算法,论文在评分1024/16384个候选时实现了1.50倍/2.99倍的更高QPS
推荐系统的 scaling law
  • 众所周知,在大规模工业设置中,DLRM 在特定计算和参数范围内会达到质量饱和(Zhao等, 2023)。论文比较了GR和 DLRM 的可扩展性以更好地理解这一现象
  • 由于特征交互层对 DLRM 性能至关重要(Mudigere等, 2022),论文尝试了 Transformer (Vaswani等, 2017)、DHEN(Zhang等, 2022)以及论文生产设置中使用的带残差连接的DCN变体(He等, 2015)来扩展 DLRM 基线
    • 对于召回基线,论文通过调整隐藏层大小、嵌入维度和层数来扩展模型
    • 对于基于HSTU的生成式推荐器(GR),论文通过调整HSTU的超参数(如残差层数、序列长度、嵌入维度、注意力头数等)扩展模型,并调整 Retrieval 任务的负样本数量
  • 结果如图7所示:
    • 在低计算量区域,由于手工特征(handcrafted features)的存在,DLRM 可能优于GR,这验证了传统 DLRM 中特征工程的重要性
    • 然而,GR在FLOPs方面表现出显著更好的可扩展性(scalability),而 DLRM 性能趋于饱和,这与先前的研究结果一致
    • 论文还观察到,GR在嵌入参数和非嵌入参数方面均具有更好的可扩展性,其模型参数达到 1,500B ,而 DLRM 在约 200B 参数时性能饱和(这个参数量的训练,Meta真够有钱!)
  • 最后,论文的所有主要指标(包括 Retrieval 任务的 Hit Rate@100 和 Hit Rate@500,以及排名任务的NE)在适当的超参数下 ,随计算量的增加呈现幂律扩展(power law of compute used)
    • 这一现象在三个数量级范围内均成立,直至论文测试的最大模型(序列长度8,192,嵌入维度1,024,24层HSTU)
    • 此时,论文使用的总计算量(按365天标准化,因为论文采用标准流式训练设置)接近GPT-3(Brown等, 2020)和LLaMa-2(Touvron等, 2023b)的总训练计算量,如图1所示
    • 与语言建模(Kaplan等, 2020)相比,序列长度在GR中扮演更重要的角色,因此需要同步扩展序列长度和其他参数
    • 这或许是论文提出的方法最重要的优势,因为论文首次证明了 LLM 的 scaling law 可能也适用于大规模推荐系统

相关工作

传统序列推荐方法

  • 先前关于序列推荐的研究将用户交互行为简化为单一同质的物品序列(Hidasi等,2016;Kang & McAuley,2018)。工业级应用中,序列方法主要作为深度推荐模型(DLRMs)的一部分,包括成对注意力(Zhou等,2018)或序列编码器(Chen等,2019;Xia等,2023)。为了提高效率,已有研究探索了多阶段注意力替代自注意力机制(Chang等,2023)。在 Retrieval 任务中,将ID表示为 token 序列的生成方法也有所探索(Zhuo等,2020)。附录B.1中对相关工作进行了更详细的讨论

高效注意力机制

  • 由于自注意力机制的 \(O(n^2)\) 复杂度,高效注意力一直是研究热点,主要工作包括因子分解注意力(Child等,2019)、低秩近似(Katharopoulos等,2020)等。近期,针对序列转导任务的新架构也被提出(Gu等,2022;Hua等,2022)。HSTU的逐元素门控设计尤其受到FLASH(Hua等,2022)的启发。硬件感知的优化方法显著降低了内存使用(Rabe & Staats,2021;Korthikanti等,2022;Zhai等,2023b),并大幅缩短了实际运行时间(Dao等,2022)

长度外推技术

  • 长度外推技术使模型能够在较短的序列上训练后泛化到更长的序列,但多数研究聚焦于微调或改进偏置机制(Press等,2022)。论文则通过在长度维度引入随机性,受深度维度随机性研究(Huang等,2016)的启发

大语言模型(LLMs)的应用

  • 大语言模型的兴起促使研究者将推荐任务视为上下文学习(Silco等,2022)、指令微调(Bao等,2023)或基于预训练LLMs的迁移学习(Li等,2023)。LLMs中嵌入的世界知识可以迁移到下游任务(Cui等,2022),并在零样本或少样本场景中提升推荐效果。用户行为序列的文本表示也在中等规模数据集上展现出良好的扩展性(Shin等,2023)。然而,多数关于LLMs用于推荐的研究集中在低数据量场景;在大规模场景中,它们尚未在MovieLens等数据集上超越协同过滤方法(Hou等,2024)

整体总结

  • 新范式提出 :论文提出了一种新范式:生成式推荐器(Generative Recommenders, GRs),将排序和 Retrieval 任务重新定义为序列直推式任务(Sequential Transduction Tasks),使其能够以生成式方式进行训练
    • 注:最核心的创新点可能是将用户行为也视为一个像是图片、视频和文本类似的新模态
  • 性能优化 :得益于新颖的HSTU编码器设计,其在8192长度的序列上比当前 SOTA Transformer 快5.3至15.2倍,同时结合了M-FALCON等新型训练和推理算法
  • 实验 :通过GRs,论文部署了复杂度提升285倍的模型,同时减少了推理计算量。GRs和HSTU在生产环境中实现了12.4%的指标提升,并展现出比传统 DLRMs 更优的扩展性能
    • 论文的结果验证了用户行为是生成式建模中一个尚未充分探索的模态——正如标题所言,Actions speak louder than words(行动胜于雄辩)
  • 论文通过统一特征空间的简化设计,为推荐、搜索和广告领域的首个基础模型铺平了道路。GRs的完全序列化设置还支持端到端的生成式推荐框架。这两点使得推荐系统能够更全面地辅助用户
  • 未来规划 :生成式推荐器的完全序列化特性有望进一步推动端到端推荐系统的发展,例如通过直接生成推荐序列而非传统的列表排序。作者相信,这一方向将为推荐系统带来更广阔的应用前景

附录(个人): Transformer 模型计算量评估

  • 在原论文2.3节的公式 \(\sum_{i} n_{i}(n_{i}^{2}d + n_{i}d_{ff}d)\) 的来源解释如下

  • 注: \(d_{ff}\) 表示 前馈神经网络(Feed-Forward Network, FFN)的隐藏层维度(即中间层的神经元数量),在原始 Transformer 中通常设为 \(d_{ff} = 4d\)

  • 在 Transformer 或类似的自注意力架构(如HSTU)中,每个编码器层通常包含两个核心子层:

    • 自注意力机制(Self-Attention):复杂度为 \(O(n_i^2 d)\) (与序列长度平方和嵌入维度相关)
      • 自注意力的计算涉及三个主要步骤:
        • 计算查询(Q)、键(K)、值(V)矩阵:复杂度为 \(O(n_i d^2)\) ,其中 \(n_i\) 是序列长度, \(d\) 是嵌入维度
        • 计算注意力分数 \(QK^T\) :复杂度为 \(O(n_i^2 d)\)
        • 加权求和 \(AV\) :复杂度为 \(O(n_i^2 d)\)
      • 因此,自注意力的总复杂度为 \(O(n_i^2 d + n_i d^2)\) 。对于长序列(\(n_i \gg d\)),主导项为 \(O(n_i^2 d)\)
    • 前馈神经网络(FFN):复杂度为 \(O(n_i d_{ff} d)\) (与序列长度、嵌入维度和FFN隐藏层维度相关,)
    • FFN的典型结构是:
      $$
      \text{FFN}(x) = W_2 \cdot \text{ReLU}(W_1 x + b_1) + b_2
      $$
      • 其中:
        • \(W_1 \in \mathbb{R}^{d \times d_{ff} }\) :将输入从 \(d\) 维映射到 \(d_{ff}\) 维(通常 \(d_{ff} = O(d)\) ,注意不是相等,一般是通常 \(d_{ff} = 4d\) 等)
        • \(W_2 \in \mathbb{R}^{d_{ff} \times d}\) :将结果映射回 \(d\) 维
    • 在 Transformer 中, \(d_{ff}\) 通常远大于 \(d\) (例如 \(d_{ff}=4d\)),这使得FFN成为计算瓶颈之一。HSTU通过点乘门控(如 \(U(X)\))替代了传统FFN(见式3),因此可能不再需要显式的 \(d_{ff}\) 参数。但在对比 Transformer 基线时,公式中仍保留了 \(d_{ff}\) 以反映传统架构的成本
  • 总计算复杂度

    • 对于每个用户的序列,计算量为自注意力和前馈网络复杂度的总和:
      $$
      O(n_i^2 d + n_i d^2)
      $$
    • 若对所有用户求和,总计算量为:
      $$
      \sum_i n_i (n_i^2 d + n_i d^2)
      $$
    • 假设最大序列长度为 \(N = \max_i n_i\) ,则复杂度可简化为:
      $$
      O(N^3 d + N^2 d^2)
      $$
  • 论文中的其他优化:论文提出通过生成式训练(generative training)对计算量进行优化:

    • 通过调整用户采样率 \(s_u(n_i) \propto 1/n_i\) ,将总计算量降低为:
      $$
      \sum_i s_u(n_i) n_i (n_i^2 d + n_i d^2) \approx O(N^2 d + N d^2)
      $$
    • 这一优化将复杂度从三次方(\(N^3\))降至二次方(\(N^2\))

附录(个人):关于NE的完整定义

  • 参考原始论文:Practical Lessons from Predicting Clicks on Ads at Facebook
  • NE(Normalized Cross Entropy,归一化交叉熵),也可更准确地称为Normalized Logarithmic Loss(归一化对数损失),其定义为平均每个展示的对数损失除以模型预测每个展示的 CTR 时的平均对数损失
  • 假设给定的训练数据集有 \(N\) 个样本,则计算公式为:
    $$NE=\frac{-\frac{1}{N} \sum_{i=1}^{n}\left(\frac{1+y_{i} }{2} log \left(p_{i}\right)+\frac{1-y_{i} }{2} log \left(1-p_{i}\right)\right)}{-(p * log (p)+(1-p) * log (1-p))} $$
    • \(y_{i} \in{-1,+1}\) 为真实标签
    • \(p_{i}\) 为第 \(i\) 个样本的CTR预估值,其中 \(i = 1,2,\cdots,N\)
    • \(p\) 平均经验点击率(基于统计的真实点击率)
  • 该指标用于评估模型预测的好坏,其值越低,模型的预测效果越好
  • 进行这种归一化处理的原因是,真实 CTR 预估值越接近0或1,就越容易获得更好的对数损失,而除以 真实 CTR 的熵可使 NE 对真实 CTR 不敏感
  • 其他相关指标:Relative Information Gain (RIG)
    $$ RIG = 1 − NE $$

附录A:符号说明

  • 论文在下表中总结了论文使用的关键符号(对应原始论文的表8和表9)
    符号 描述
    \(\Psi_{k}(t_{j})\) Feature Logging 系统在 \(t_{j}\) 时刻发出的第 \(k\) 个 Training Example(\(k\) 是全局排序的)
    在典型的深度推荐模型(DLRM)推荐系统中,用户消费某些内容 \(\Phi_{i}\) (通过诸如跳过、视频观看完成和分享等动作 \(a_{i}\) 做出响应)后,特征记录系统将元组 \((\Phi_{i}, a_{i})\) 与用于对 \(\Phi_{i}\) 进行排名的特征相结合,并发出 \((\Phi_{i}, a_{i}\) ,以及 \(\Phi_{i}\) 的特征) 作为 Training Example \(\Psi_{k}(t_{j})\)
    如2.3节所述,DLRM和生成式推荐器(GRs)处理的 Training Example 数量不同,GRs中的 Example 数量通常少1 - 2个数量级
    \(n_{c}(n_{c,i})\) 与用户/样本 \(i\) 交互的内容数量
    \(\Phi_{0}, …, \Phi_{n_{c}-1}\) 在推荐系统的上下文中,与用户交互的内容列表
    \(a_{0}, …, a_{n_{c}-1}\) 与内容 \(\Phi_{i}\) 对应的用户动作列表
    当所有预测事件都是 Binary 时,每个动作可以被视为一个 multi-hot vector 事件(如点赞、分享、评论、图片浏览、视频初始化、视频观看完成、隐藏等)
    \(E, F\) 图2中DLRM中的分类特征
    \(E_{0}, E_{1}, …, E_{7}, E_{8}\) 和 \(F_{0}, F_{1}, …, F_{7}\) 表示在不同时间点通过特征提取(例如,最近喜欢的10张图片、用户过去点击的与当前候选内容最相似的50个网址等)从 \((\Phi_{0}, a_{0}, t_{0}), …, (\Phi_{n_{c}-1}, a_{n_{c}-1}, t_{n_{c}-1})\) 获得的转换结果
    “merge & sequentialize” 表示获取原始参与系列 \((\Phi_{0}, a_{0}, t_{0}), …, (\Phi_{n_{c}-1}, a_{n_{c}-1}, t_{n_{c}-1})\) 的(虚拟)反向过程
    \(G, H\) 图2中DLRM中与 user-item 参与无关的分类特征。这些特征(例如人口统计信息或关注的创作者)被合并到主时间序列(用户参与的内容列表,例如 \(\Phi_{0}, a_{0}, …, \Phi_{n_{c}-1}, a_{n_{c}-1}\))中,如2.1节所述并在图2中说明
    \(n(n_{i})\) Squential Transduction 任务中的 token 数量(对于用户或样本 \(i\))。虽然 \(O(n) = O(n_{c})\) ,但即使没有任何与非交互相关的分类特征, \(n\) 也可能与 \(n_{c}\) 不同;例如,见表1
    \(x_{0}, …, x_{n - 1}\) Squential Transduction 任务中的输入 token 列表
    \(y_{0}, …, y_{n - 1}\) Squential Transduction 任务中的输出 token 列表
    \(t_{0}, …, t_{n - 1}\) 与观察到 \(x_{0}, …, x_{n - 1}\) 的时间对应的时间戳列表
    \(\mathbb{X}, \mathbb{X}_{c}\) 所有输入/输出 token 的词汇表(\(\mathbb{X}\))及其内容子集(\(\mathbb{X}_{c}\))
    \(N, N_{c}\) \(\max_{i}n_{i}\) , \(\max_{i}n_{c,i}\)
    \(u_{t}\) 在时间 \(t\) 的用户表示
    \(s_{u}(n_{i}), \hat{s}_{u}(n_{i})\) 生成式训练(2.3节)中用于用户 \(i\) 的采样率
    \(d\) 模型维度(嵌入维度)
    \(d_{qk}\) HSTU和 Transformer 中注意力维度的大小。这适用于公式(1)中的 \(Q(X)\) 和 \(K(X)\)
    \(d_{v}\) HSTU中 Value 维度的大小。对于 Transformer,论文通常有 \(d_{qk} = d_{v}\)
    \(d_{ff}\) Transformer pointwise 前馈层中(feedforward)的隐藏维度大小。HSTU不使用前馈层;见下面的 \(U(X)\)
    \(h\) 注意力头的数量
    \(l\) HSTU中的层数。对于 Transformer,注意力层和逐点前馈层共同构成一层
    \(Q(X), K(X), V(X)\) 根据公式(1)为给定输入 \(X\) 在HSTU中获得的查询、键、值(Query/Key/Value)。其定义与标准 Transformer 中的 \(Q, K, V\) 类似。 \(Q(X), K(X) \in \mathbb{R}^{h \times N \times d_{qk} }\) ,并且HSTU使用 \(U(X)\) (与 \(f_{2}(·)\) 一起)在公式(3)中 “门控” 注意力池化值(\(V(X)\)),这使得HSTU完全避免了前馈层。 \(U(X) \in \mathbb{R}^{h \times N \times d_{v} }\)
    \(A(X)\) 为输入 \(X\) 获得的注意力张量。 \(A(X) \in \mathbb{R}^{h \times N \times N}\)
    \(Y(X)\) HSTU层针对输入 \(X\) 的输出。 \(Y(X) \in \mathbb{R}^{d}\)
    \(Split(·)\) 将张量分割成块的操作。 \(Split(\phi_{1}(f_{1}(X))) \in \mathbb{R}^{N \times (2hd_{qk} + 2hd_{v})}\) ;论文通过分割较大的张量(并置换维度)获得 \(U(X), V(X)\) (两者形状均为 \(h \times N \times d_{v}\))、 \(Q(X), K(X)\) (两者形状均为 \(h \times N \times d_{qk}\))
    \(rab^{p,t}\) 结合了位置(Raffel等,2020)和时间(基于观察到 token 的时间 \(t_{0}, …, t_{n - 1}\) ;一种可能的实现方式是对 \((t_{j} - t_{i})\) 应用某种分桶函数得到 \((i, j)\))信息的相对注意力偏差。在实践中,论文在一层内的不同注意力头之间共享 \(rab^{p,t}\) ,因此 \(rab^{p,t} \in \mathbb{R}^{1 \times N \times N}\)
    \(\alpha\) 控制HSTU中随机长度算法(3.2节)稀疏性的参数
    \(R\) GPU上的寄存器大小,在3.2节讨论的HSTU算法的上下文中
    \(m\) 推荐系统Ranking 阶段考虑的候选数量
    \(b_{m}\) M-FALCON 算法(3.4节)中的微批次大小

附录B:生成式推荐器:背景与公式

  • 许多读者可能对经典的深度学习推荐模型(DLRM)更为熟悉(Mudigere等,2022),因为自YouTube DNN时代起它就颇受欢迎(Covington等,2016),并且在每个大型在线内容和电子商务平台上都得到了广泛应用(Cheng等,2016;Zhou等,2018;Wang等,2021;Chang等,2023;Xia等,2023;Zhai等,2023a)。DLRM在异构特征空间上运行,使用各种神经网络,包括特征交互模块(Guo等,2017;Xiao等,2017;Wang等,2021)、顺序池化或目标感知成对注意力模块(Hidasi等,2016;Zhou等,2018;Chang等,2023)以及先进的多专家多任务模块(Ma等,2018;Tang等,2020)。因此,论文在第2节和第3节中通过将生成式推荐器(GRs)与经典DLRM进行明确对比,概述了GRs。在本节中,论文从经典的顺序推荐文献出发,为读者提供另一种视角

附录B.1:背景:学术界和工业界的顺序推荐

附录B.1.1:学术研究(传统顺序推荐设置)
  • 循环神经网络(RNNs)最早在GRU4Rec(Hidasi等,2016)中应用于推荐场景。Hidasi等(2016)考虑了门控循环单元(GRUs),并将其应用于两个数据集,即RecSys Challenge 2015和VIDEO(一个专有数据集)。在这两种情况下,只有 positive 事件(点击的电子商务商品或用户观看至少一定时间的视频)被保留作为输入序列的一部分。论文进一步观察到,在由检索和Ranking 阶段组成的经典工业规模两阶段推荐系统设置中(Covington等,2016),Hidasi等(2016)解决的任务主要对应于检索任务
  • 后来, Squential Transduction 架构的进步,特别是 Transformer (Vaswani等,2017),推动了推荐系统的类似进展。SASRec(Kang和McAuley,2018)首次在自回归设置中应用 Transformer。他们将评论或评分的存在视为 positive 反馈 ,从而将像亚马逊评论和MovieLens这样的经典数据集转换为 positive item 的序列,类似于GRU4Rec。采用二元交叉熵损失,其中正目标定义为下一个 “positive” item (回想一下,这本质上只是评论或评分的存在),负目标从 item 语料库 \(\mathbb{X}=\mathbb{X}_{c}\) 中随机采样
  • 随后的大多数研究都基于与上述GRU4Rec(Hidasi等,2016)和SASRec(Kang和McAuley,2018)类似的设置,例如BERT4Rec(Sun等,2019)应用来自BERT(Devlin等,2019)的双向编码器设置,S3Rec(Zhou等,2020)引入明确的预训练阶段,等等
附录B.1.2:作为深度学习推荐模型(DLRM)一部分的工业应用
  • 顺序方法,包括顺序编码器和成对注意力模块,由于能够作为DLRM的一部分增强用户表示,已在工业环境中得到广泛应用。DLRM通常使用相对较短的序列长度 ,例如BST(Chen等,2019)中为20,DIN(Zhou等,2018)中为1000,TransAct(Xia等,2023)中为100。论文观察到,论文中的8192比传统DLRM序列长度大 1-3 个数量级
  • 尽管使用短序列长度,大多数DLRM仍能成功捕捉长期用户偏好。这可归因于两个关键方面:
    • 第一 :现代DLRM中通常使用预计算的用户配置文件/嵌入或外部向量存储(Chang等,2023),这两者都有效地扩展了回顾窗口
      • 理解,用户 Embedding 其实与用于历史长期用户偏好有关?
    • 第二 :通常会采用大量的上下文、用户和 item 侧特征,并且使用各种异构网络,如因子分解机(FMs)、深度交叉网络(DCNs)、混合专家(MoEs)等,来转换表示并组合输出
      • 理解:特征交叉能泛化到用户长期偏好?
  • 与附录B.1.1中讨论的顺序设置相比,所有主要的工业界工作都在(用户/请求,候选 item)对上定义损失。在排名设置中,通常使用多任务二元交叉熵损失。在检索设置中,双塔设置(Covington等,2016)仍然是主导方法。最近的工作研究了将下一个推荐 item 表示为(子) token 序列上的概率分布,如OTM(Zhuo等,2020)和DR(Gao等,2021)(注意,在其他近期工作中,相同的设置有时被称为 “生成式召回”)。它们通常利用 beam search 从子 token 中解码 item。随着现代加速器(如GPU、定制ASIC和TPU)的普及,还提出并部署了先进的学习相似性函数,如混合逻辑(Zhai等,2023a),作为双塔设置和 beam search 的替代方案
  • 从问题公式化的角度来看,考虑到模型架构、使用的特征和损失与附录B.1.1中讨论的学术顺序推荐研究有显著差异,作者认为上述所有工作都属于DLRM(Mudigere等,2022)的一部分。值得注意的是,在这项工作之前,工业界尚未成功应用完全顺序排名设置,特别是在每日活跃用户(DAU)达到数十亿规模的情况下

附录B.2:公式化表述:生成式推荐器(GRs)中作为序列转换任务的排序与检索

  • 接下来,论文讨论传统序列推荐器设置和深度推荐模型(DLRM)设置中的三个局限性,以及生成式推荐器(GRs)如何从问题公式化的角度解决这些问题
  • 问题1:忽略用户交互物品之外的特征 :以往的序列公式化表述仅考虑用户明确交互过的内容(物品) ,而在GRs出现之前,传统推荐系统会基于大量特征进行训练 ,以增强用户和内容的表征 ,GRs通过以下方式解决这一局限性:
    • a)压缩其他分类特征并将其与主时间序列合并;
    • b)如2.1节和图2所述,利用目标感知公式通过交叉注意力交互来捕捉数值特征。论文通过实验验证了这一点,结果表明,忽略这些特征的传统 “interactions only” 公式化表述会显著降低模型质量;实验结果可在表7和表6中标记为 “GR(interactions only)” 的行中找到,论文发现仅利用交互历史会导致检索的 HitRate@100 下降1.3%,排序的归一化熵(NE)下降2.6%(回想一下,如4.1.2节和4.3.1节所述,NE变化0.1% 即为显著变化)
  • 问题2:用户表征在与目标无关的设置中计算 :大多数传统序列推荐器,包括GRU4Rec(2016)、SASRec(2018)、BERT4Rec(2019)、S3Rec(2020)等,都是以与目标无关的方式构建的。在这种方式中,对于目标物品 \(\Phi_{i}\) , \(\Phi_{0}, \Phi_{1}, …, \Phi_{i - 1}\) 被用作编码器输入来计算用户表征,然后用于进行预测。相比之下,工业环境中使用的大多数主要DLRM方法在构建所使用的序列模块时考虑了目标感知,能够将 “目标”(排序候选)信息整合到用户表征中。这些方法包括DIN(2018)(阿里巴巴)、BST(2019)(阿里巴巴)、TWIN(2023)(快手)和TransAct(2023)(Pinterest)
    • 生成式推荐器(GRs)通过交错内容和动作序列(2.2节),在因果自回归设置中实现目标感知注意力机制 ,结合了两者的优点(序列推荐器和DLRM的优点)。论文在表10中对先前的工作和本工作进行了分类和对比
    • 理解:(详情见前文对论文 target-aware 实现方式的理解),其实 Transformer 的 Attention 也有交叉功能,只要输入端有目标 item 和用户历史交互 item 即可,但论文所说的传统的自回归模型是指输入侧不包含目标 item 的情况,论文中,将动作也建模进去,则在输出对目标 item 的动作 token 时,自然就需要将目标 item 作为输入,从而也就实现了 target-aware 交叉注意力:
      • 传统自回归模型预估目标为 : \(p(\Phi_{i+1}|\Phi_{0},\Phi_{1},\ldots,\Phi_{i})\),目标物品 \(\Phi_{i+1}\) 与其他历史序列无交叉
        • 注:此时动作 \(a_i\) 信息可能会作为额外表征加入到 \(\Phi_{i}\)中,所以 表述为 \(p(\Phi_{i}|(\Phi_{0}, a_{0}), …, (\Phi_{i - 1}, a_{i - 1}))\) 也可以
      • 论文预估目标为 : \(p(a_{i+1}|\Phi_{0},a_{0},\Phi_{1},a_{1},\ldots,\Phi_{i+1})\),目标物品 \(\Phi_{i+1}\) 与其他历史序列有交叉
  • 问题3:判别式公式限制了先前序列推荐器工作的适用性 :传统的序列推荐器本质上是判别式的。现有的序列推荐文献,包括GRU4Rec和SASRec等开创性工作,对 \(p(\Phi_{i}|\Phi_{0}, a_{0}, …, \Phi_{i - 1}, a_{i - 1})\) 进行建模(注:虽然这里传统这里输入中包含动作信息,但是实际上动作是作为额外的信息加入到 item 表征中的,不是单独将 action 作为一个 token,所以写成 \(p(\Phi_{i}|(\Phi_{0}, a_{0}), …, (\Phi_{i - 1}, a_{i - 1}))\) 更合适),即根据用户当前状态推荐下一个物品的条件分布
    • 推荐系统中实际上存在两个概率过程 :
      • 1)推荐系统向用户推荐内容 \(\Phi_{i}\) 的过程;
      • 2)用户通过某些动作 \(a_{i}\) 对推荐内容 \(\Phi_{i}\) 做出反应的过程
    • 生成式方法对推荐内容和用户动作序列的联合分布进行建模 ,如2.2节所述,即联合概率分布 \(p(\Phi_{0}, a_{0}, \Phi_{1}, a_{1}, …, \Phi_{n_{c}-1}, a_{n_{c}-1})\)(如表11(图8)所示,论文提出的生成式推荐器能够对这种分布进行建模):
      • Next action token (\(a_{i}\)) prediction 任务正是 Ranking 任务(即表1中讨论的 GR Ranking 设定)
      • Next content token (\(\Phi_{i}\)) prediction 任务对应 Retrieval 任务,目标是学习 next item
  • 重要的是,这种公式化表述不仅能够对数据分布进行适当建模,还能够通过例如 beam search 直接采样要推荐给用户的物品序列。论文假设这将产生一种比传统列表式设置(例如DPP(2014)和RL(2018))更优越的方法,论文将此类系统的完整公式化表述和评估(在6节中简要讨论)留作未来的工作

附录C:评估:合成数据

  • 如3.1节先前所述,标准的softmax注意力机制因其归一化因子,难以捕捉用户偏好的强度,而这对于用户表示学习至关重要。在推荐场景中,这一点很关键,因为系统可能不仅需要预测 item 的相对排序,还需要预测用户参与的强度(例如,未来对特定主题的 positive 行为数量)
  • 为理解这种行为,论文构建了遵循狄利克雷过程(Dirichlet Process)的合成数据,该过程在动态词汇集上生成流式数据。狄利克雷过程捕捉了用户参与历史中 “富者更富”(rich gets richer) 的行为。论文设置合成实验如下:
    • 论文将 20,000 个 item ID中的每一个随机分配到100个类别中的某一个
    • 论文生成 1,000,000 条长度为 128 的记录,其中前90%用于训练,最后10%用于测试。为模拟流式训练设置,论文最初提供40%的 item ID,其余的以相等的间隔逐步提供;即在记录 500,000 时,可以采样的最大ID是 \((40% + 60% * 0.5) * 20,000 = 14,000\)
    • 论文为每条记录从100个类别中随机选择最多5个类别,并为这5个类别随机采样一个先验 \(H_{c}\) 。按照狄利克雷过程,为每个位置顺序采样类别,具体如下:
      • 对于 \(n>1\) :
        • 以概率 \(\alpha / (\alpha + n - 1)\) ,从 \(H_{c}\) 中抽取类别 \(c\)
        • 以概率 \(n_{c} / (\alpha + n - 1)\) ,抽取类别 \(c\) ,其中 \(n_{c}\) 是先前具有类别 \(c\) 的 item 数量
        • 随机采样一个符合类别 \(c\) 且受流式约束的分配 item
      • 其中 \(\alpha\) 从(1.0, 500.0)中均匀随机采样
  • 结果见表2。由于此数据集没有时间戳,论文在HSTU中去除 \(rab^{p, t}\)。论文观察到,相对于标准 Transformer,HSTU的 HitRate@10 提高了100%以上。重要的是,将HSTU的逐点聚合注意力机制(Pointwise aggregated attention)替换为 Softmax(“HSTU w/ Softmax”)也会导致 HitRate 显著降低,这验证了类似逐点聚合注意力机制的重要性

附录D:评估:传统顺序推荐器设置

  • 论文在4.1.1节的评估重点是将HSTU与 SOTA Transformer 基线SASRec进行比较,使用最新的训练方法。在本节中,论文进一步考虑另外两种替代方法
  • 循环神经网络(RNNs)。论文考虑顺序推荐器的经典工作GRU4Rec(Hidasi等,2016),以帮助读者理解包括 Transformer 和HSTU在内的自注意力模型,在充分融入最新的建模和训练改进后,与传统RNNs相比如何
  • 自监督顺序方法。论文考虑最受欢迎的工作BERT4Rec(Sun等,2019),以了解双向自监督(BERT4Rec通过完形填空目标利用)与单向因果自回归设置(如SASRec和HSTU)相比如何
  • 结果见表12。论文重用Klenitskiy和Vasilev(2023)报告的BERT4Rec和GRU4Rec在ML-1M和ML-20M上的结果。由于使用了采样softmax损失,论文保持负样本数量不变(ML-1M、ML-20M为128,亚马逊图书为512),以确保方法之间的公平比较
  • 结果证实,在使用采样softmax损失的传统顺序推荐设置中,SASRec仍然是最具竞争力的方法之一(Zhai等,2023a;Klenitskiy和Vasilev,2023),而HSTU显著优于所评估的 Transformer 、RNN和自监督双向 Transformer

附录E:评估:传统DLRM基线

  • 第4节中使用的DLRM基线配置反映了数百名研究人员和工程师多年来的持续迭代,并且在部署HSTU/GR之前,是对拥有数十亿日活跃用户的大型互联网平台上生产配置的近似。本节对所使用的模型进行简要描述

排名设置

  • 如(Mudigere等,2022)所述,基线排名模型采用了大约一千个密集特征和五十个稀疏特征。论文结合了各种建模技术,如 Mixture of Experts(Ma等,2018)、Deep & Cross Network 的变体(Wang等,2021)、各种顺序推荐模块,包括 target-aware pairwise 注意力(在工业设置中常用的一种变体可参见(Zhou等,2018)),以及特殊交互层上的残差连接(He等,2015;Zhang等,2022)。在 scaling law 部分(4.3.1节)的低FLOP区域,一些计算成本高的模块被简化、替换为其他 state-of-the-art 的变体,如DCN,以达到所需的FLOP
  • 由于保密考虑,论文无法透露确切设置 ,但据论文所知,在充分纳入最新研究成果后,论文的基线代表了最优秀的DLRM方法之一。为验证这一说法并帮助读者理解,论文在表7中报告了基于相同特征,但仅利用主要已发表成果(包括DIN(Zhou等,2018)、DCN(Wang等,2021)和MMoE(Ma等,2018))的典型设置(“DLRM (DIN + DCN)”),并在图9中展示了组合架构。该设置在主要E任务的NE上比论文的生产DLRM设置低0.71%,在主要C任务的NE上低0.57%(0.1%的NE变化即具有显著性)

检索设置

  • 基线检索模型采用标准的双塔神经检索设置(Covington等,2016),并结合了批内(in-batch)和批外(out-of-batch)采样。输入特征集包括高基数稀疏特征(如 item ID、用户ID)和低基数稀疏特征(如语言、主题、兴趣实体)。使用带有残差连接的前馈层堆栈(He等,2015)将输入特征压缩为用户和 item 嵌入

特征和序列长度

  • DLRM基线中使用的特征,包括各种顺序编码器/成对注意力模块所利用的主要用户交互历史,是所有 GR 候选模型使用特征的严格超集(strict supersets)。这适用于论文中进行的所有研究,包括缩放研究(4.3.1节)中使用的特征

附录F:随机长度

附录F.1:子序列选择

  • 在公式(4)中,论文从完整的用户历史中选择长度为 \(L\) 的子序列以增加稀疏性。论文的实证结果表明,精心设计子序列选择技术可以提高模型质量。论文计算一个指标 \(f_{i}=t_{n}-t_{i}\) ,它对应于用户与 item \(x_{i}\) 交互后经过的时间量。论文使用以下子序列选择方法进行离线实验:
    • 贪心选择(Greedy Selection) - 从 \(S\) 中选择 \(L\) 个 \(f_{i}\) 值最小的 item,即保留最近交互的 item
    • 随机选择(Random Selection) - 从 \(S\) 中随机选择 \(L\) 个 item
    • 特征加权选择(Feature-Weighted Selection) - 根据加权分布 \(1 - f_{n, i} / (\sum_{j = 1}^{L}f_{j, i})\) 从 \(S\) 中选择 \(L\) 个 item(理解:对越近交互的样本,采样权重越大)
  • 在离线实验中,特征加权子序列选择方法产生了最佳的模型质量,如表13所示

附录F.2:随机长度对序列稀疏性的影响

  • 在表3中,论文展示了随机长度对具有30天用户参与历史的代表性工业规模配置中序列稀疏性的影响。序列稀疏性定义为1减去所有样本的平均序列长度与最大序列长度的比值。为了更好地描述稀疏注意力的计算成本,论文还定义了 \(s2\) ,它被定义为:1减去注意力矩阵的稀疏性(which is defined as one minus the sparsity of the attention matrix)。作为参考,论文在表14和表15中分别给出了60天和90天用户参与历史的结果

附录F.3:与序列长度外推技术的比较

  • 论文进行了额外的研究,以验证随机长度与语言建模中使用的现有序列长度外推技术相比具有竞争力。许多现有方法通过修改旋转位置嵌入(RoPE)(Su等,2023)来进行序列长度外推。为了与现有方法进行比较,论文训练了一个没有相对注意力偏差和旋转嵌入的HSTU变体(HSTU - RoPE)
  • 论文在HSTU-RoPE上评估以下序列长度外推方法:
    • 零样本(Zero-Shot) - 应用NTK感知的RoPE(Peng等,2024),然后直接评估模型,不进行 Fine-tune
    • 微调(Fine-tune) - 应用逐部分NTK(Peng等,2024)后,对模型进行1000步 Fine-tune
  • 论文在HSTU(包括相对注意力偏差,无旋转嵌入)上评估以下序列长度外推方法:
    • 零样本(Zero-Shot) - 根据最大训练序列长度夹紧相对位置偏差,直接评估模型(Raffel等,2020;Press等,2022)
    • 微调(Fine-tune) - 根据最大训练序列长度夹紧相对位置偏差,在评估模型之前对模型进行1000步 Fine-tune
  • 在表16中,论文报告了训练期间引入数据稀疏性的模型(随机长度、Zero-Shot、Fine-tune)与在完整数据上训练的模型之间的NE差异。论文将 Zero-Shot 和 Fine-tune 技术的稀疏性定义为训练期间的平均序列长度与评估期间的最大序列长度之比。所有 Zero-Shot 和 Fine-tune 模型都在1024序列长度的数据上进行训练,并在2048和4096序列长度的数据上进行评估。为了为这些技术找到合适的随机长度基线,论文选择了导致相同数据稀疏性指标的随机长度设置
  • 作者认为,Zero-Shot 和 Fine-tune 的序列长度外推方法不太适合处理高基数ID的推荐场景。从经验上看,论文观察到随机长度明显优于 Fine-tune 和 Zero-Shot 方法。作者认为这可能是由于论文的词表较大 ,Zero-Shot 和 Fine-tune 方法无法为较旧的ID学习良好的表示,这可能会损害它们充分利用较长序列中包含的信息的能力

附录G:稀疏分组通用矩阵乘法(GEMMs)和融合相对注意力偏差

  • 论文提供3.2节中介绍的高效HSTU注意力内核的更多信息。论文的方法基于内存高效注意力(Rabe和Staats,2021)和FlashAttention(Dao等,2022),是一种内存高效的自注意力机制,它将输入划分为块,并避免在反向传播中具体化大的 \(h×N×N\) 中间注意力张量。通过利用输入序列的稀疏性,我们可以将注意力计算重新表述为一组具有不同形状的连续GEMM运算。论文实现了高效的GPU内核来加速此计算。相对注意力偏差的构建也因内存访问而成为瓶颈。为解决此问题,论文将相对偏差构建和分组GEMM运算融合到单个GPU内核中,并在反向传播中使用GPU的快速共享内存来累积梯度。尽管论文的算法在反向传播中需要重新计算注意力和相对偏差,但它比 Transformer 中使用的标准方法明显更快且内存使用更少

附录H:Microbatched-Fast Attention Leveraging Cacheable OperatioNs (M-FALCON)

  • 在本节中,论文详细描述3.4节中讨论的 M-FALCON 算法
  • M-FALCON 引入了三个关键思想
  • 批量推理可应用于因果自回归设置 :GR中的排名任务以目标感知的方式制定,如2.2节所述。通常认为,在目标感知设置中,论文需要一次对一个 item 进行推理,对于 \(m\) 个候选 item 和长度为 \(n\) 的序列,成本为 \(O(mn^{2}d)\) 。但这里论文表明这不是最优解决方案;即使使用普通 Transformer,也可以修改自注意力中使用的注意力掩码以进行批量操作(“批量推理”),并将成本降低到 \(O((n + m)^{2}d)=O(n^{2}d)\)
    • 图11给出了说明。这里,图11(a)和(b)都涉及因果自回归设置的注意力掩码矩阵。关键区别在于,图11(a)在因果训练中使用大小为 \(2n_{c}\) 的标准下三角矩阵,而图11(b)通过将 \(i, j≥2n_{c}\) 且 \(i≠j\) 的条目设置为 False 或 \(-\infty\) 来修改大小为 \(2n_{c}+b_{m}\) 的下三角矩阵,以防止目标位置 \(\Phi_{0}’, …, \Phi_{b_{m}-1}’\) 相互关注。很容易看出,通过这样做,自注意力块对 \(\Phi_{i}’, a_{i}’\) 的输出仅取决于 \(\Phi_{0}, a_{0}, …, \Phi_{n_{c}-1}, a_{n_{c}-1}\) ,而不取决于 \(\Phi_{j}’\) (\(i≠j\))。换句话说,通过使用修改后的注意力掩码对 \((2n_{c}+b_{m})\) 个 token 进行前向传递,论文现在可以获得与对 \((2n_{c}+1)\) 个 token 进行 \(b_{m}\) 次单独前向传递相同的最后 \(b_{m}\) 个 token 的结果,在第 \(i\) 次前向传递中, \(\Phi_{i}’\) 位于第 \(2n_{c}\) (基于0)位置,使用标准因果注意力掩码
  • 微批次将批量推理扩展到大型候选集 :Ranking 阶段可能需要处理大量的排名候选 item ,多达数万个(Wang等,2020)。我们可以将总共 \(m\) 个候选 item 划分为 \(\lceil m / b_{m}\rceil\) 个大小为 \(b_{m}\) 的微批次,使得 \(O(b_{m}) = O(n)\) ,这在大多数实际推荐设置中,对于多达数万个候选 item ,保持了前面讨论的 \(O((n + m)^{2}d)=O(n^{2}d)\) 的运行时间
  • 编码器级缓存可在请求内和请求间实现计算共享 :最后,键值缓存(Pope等,2022)可在请求内和请求间应用。例如,对于论文中介绍的HSTU模型(3节), \(K(X)\) 和 \(V(X)\) 在微批次内和/或请求间完全可缓存。对于缓存的前向传递,论文只需要为最后 \(b_{m}\) 个 token 计算 \(U(X), Q(X), K(X)\) 和 \(V(X)\) ,同时为包含 \(n\) 个 token 的序列化用户历史重用缓存的 \(K(X)\) 和 \(V(X)\) 。同样, \(f_{2}(Norm(A(X)V(X))\odot U(X))\) 只需要为 \(b_{m}\) 个候选 item 重新计算。这将缓存前向传递的计算复杂度降低到 \(O(b_{m}d^{2}+b_{m}nd)\) ,即使 \(b_{m}=n\) ,也比 \(O((n + b_{m})d^{2}+(n + b_{m})^{2}d)\) 提高了2 - 4倍
  • 算法1说明了 M-FALCON 算法,有助于理解。论文注意到, M-FALCON 不仅适用于HSTU和GR,还广泛适用于其他基于自注意力架构的目标感知因果自回归设置的推理优化算法

附录H.1:推理吞吐量评估:使用 M-FALCON 的生成式推荐器(GRs)与DLRM的比较

  • 如3.4节所述, M-FALCON 在推理时并行处理 \(b_{m}\) 个候选 item ,以在所有 \(m\) 个候选 item 之间分摊计算成本。为理解论文的设计,论文在相同硬件设置下比较了GR和DLRM的吞吐量(即每秒评分的候选 item 数,QPS)
  • 如图12和图13所示,由于批量推理实现了成本分摊,GR的吞吐量在一定区域内(在论文的案例研究中 \(m = 2048\))随Ranking 阶段候选 item 数 \((m)\) 呈次线性增长。这证实了批量推理在因果自回归设置中的关键性。由于注意力复杂度按 \(O((n + b_{m})^{2})\) 缩放,利用多个微批次本身就提高了吞吐量。缓存进一步消除了微批次之上的冗余线性和注意力计算。两者结合,相对于使用单个微批次的 \(b_{m}=m = 1024\) 基线,实现了高达1.99倍的额外加速,如图13所示。总体而言,凭借高效的HSTU编码器设计和利用 M-FALCON ,基于HSTU的生成式推荐器在大规模生产设置中的吞吐量比DLRM高出2.99倍,尽管GR在FLOP方面复杂285倍

NLP——Megatron源码阅读笔记

  • 参考链接:
    • 各种并行的通信量:图解大模型训练之:张量模型并行(TP),Megatron-LM - 猛猿的文章 - 知乎
    • 模型分组及方式理解:Megatron-LM训练的大模型如何分组? - wx1997的文章 - 知乎
      • DP,TP,PP:TP > DP > PP,优先在机器内进行TP,其次是DP,最后是PP,因为通信量上是TP最多,DP其次,PP最后
    • 图解大模型系列之:Megatron源码解读1,分布式环境初始化 - 猛猿的文章 - 知乎
    • 图解大模型训练之:Megatron源码解读2,模型并行 - 猛猿的文章 - 知乎
    • 图解大模型训练系列之:Megatron源码解读3,分布式混合精度训练 - 猛猿的文章 - 知乎
    • Megatron-LM 中使用 DeepSpeed 加速:(DeepSpeed 官方文档)在 Megatron-LM 中加入 DeepSpeed
    • 专家并行论文:GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding, 2020, Google
    • [张量/序列并行]📚图解 DeepSpeed-Ulysses & Megatron-LM TP/SP - DefTruth的文章 - 知乎:写的非常详细
    • [转]Megatron-LM源码系列(八): Context Parallel并行 - 李睿的文章 - 知乎:作者还有一些博客内容可供参考
      • Megatron-LM源码系列(一):模型并行初始化
      • Megatron-LM源码系列(二):Tensor模型并行和Sequence模型并行训练
      • Megatron-LM源码系列(三):详解Pipeline模型并行训练实现
      • Megatron-LM源码系列(四):重计算(recompute)
      • Megatron-LM源码系列(五): FP16使用
      • Megatron-LM源码系列(六):Distributed-Optimizer分布式优化器实现Part1
      • Megatron-LM源码系列(七):Distributed-Optimizer分布式优化器实现Part2
      • Megatron-LM源码系列(八): Context Parallel并行
    • Megatron 新版 MoE 源码阅读 - Fizzmy的文章 - 知乎
    • 知乎专栏:跟着执行流程阅读源码系列:
      • 跟代码执行流程,读Megatron源码(一)从目录结构开始 - AIer的文章 - 知乎
      • 跟代码执行流程,读Megatron源码(二)训练入口pretrain_gpt.py - AIer的文章 - 知乎
      • 跟代码执行流程,读Megatron源码(三)megatron训练脚本training.py之pretrain() - AIer的文章 - 知乎
      • 跟代码执行流程,读Megatron源码(四)megatron训练脚本initialize.py之initialize_megatron()分布式环境初始化 - AIer的文章 - 知乎

整体说明

  • 本文的代码以 Megatron-LM 20250904 的版本 e000263e21ac89571123303c4043ec9ea7261513 为主,还有部分更早的版本的代码(之前写的,没有修改)

Megatron 数据处理

  • 数据预处理负责将 .jsonl 的文本数据 tokenize 并处理成 Megatron 可以直接读取的数据格式(.bin 和 .idx 类型的文件),减少训练时的数据处理时间

  • 数据处理的使用方式详情参考:github.com/NVIDIA/Megatron-LM

    • 准备 .jsonl 文件,文件格式如下:

      1
      2
      {"text": "Your training text here..."}
      {"text": "Another training sample..."}
    • 数据预处理:

      1
      2
      3
      4
      5
      6
      7
      python tools/preprocess_data.py \
      --input data.jsonl \
      --output-prefix processed_data \
      --tokenizer-type HuggingFaceTokenizer \
      --tokenizer-model /path/to/tokenizer.model \
      --workers 8 \
      --append-eod
      • output-prefix:输出文件的前缀
      • append-eod:是否添加 EOD Token?
      • 注意:还可以根据需要设置 split_sentences 参数,对文档进行拆分成 sentence 再做 tokenize
  • process_data.py 的核心处理逻辑如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    # split sentences in partition files
    if args.split_sentences and not split_sentences_present:
    processes = []
    for name in in_ss_out_names:
    p = multiprocessing.Process(target=partition.split_sentences, # TODO(ZJH): 构造完成数据 sentence 分隔的进程
    args=((name['partition'], name['sentence_split']),)) # TODO(ZJH): 参数是输入文件名和分隔结果的输出文件名
    p.start() # TODO(ZJH): 启动进程
    processes.append(p)

    for p in processes:
    p.join()

    if args.partitions == 1:
    return


    # encode partition files in parallel
    processes = []
    input_key = 'sentence_split' if args.split_sentences else 'partition' # TODO(ZJH): 根据是否走 sentence_split 来选择输入文件
    for name in in_ss_out_names:
    p = multiprocessing.Process(target=partition.process_json_file, # TODO(ZJH): 构造完成数据 encode 的进程
    args=((name[input_key], name['output_prefix']),)) # TODO(ZJH): 参数是输入文件名(上一步处理后的)和分隔结果的输出文件名
    p.start() # TODO(ZJH): 启动进程
    processes.append(p)

    for p in processes:
    p.join()

    if args.partitions == 1:
    return

Megatron-LM 训练过程梳理

  • 总入口(以 GPTModel 为例):
    1
    2
    3
    4
    5
    6
    7
    8
    9
    pretrain(
    train_valid_test_datasets_provider,
    partial(model_provider, gpt_builder), # TODO(ZJH): model_provider 调用 gpt_builder 构造模型
    ModelType.encoder_or_decoder,
    forward_step,
    args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
    extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
    store=store,
    )

pretrain 函数是总入口,包含的核心参数如下

pretrain 参数一:train_valid_test_datasets_provider,负责管理数据
  • 输出返回迭代器,这个迭代器每个 Batch 将包含一个 micro-batch 数据
pretrain 参数二:partial(model_provider, gpt_builder),对应 model_provider 参数,负责构造并,返回模型对象
  • 返回对象的 __init__ 函数负责实现 模型结构定义

  • 返回的模型对象会实现一个 forward 函数

  • 该函数依次调用 _preprocess(),decoder() 和 _postprocess() 实现整体逻辑

    • _preprocess 负责处理输入层,包含位置编码等信息,返回 decoder 的输入
    • decoder 负责模型主要前向流程
    • _postprocess 负责处理输出层,包括 MTP 处理、 损失函数定义等,返回 损失函数值(lm_loss, 交叉熵损失)
      • 若需要执行 MTP 过程,执行 MTP 过程,同时若打开训练,则 MTP loss 在这里被计算(
        • 使用 mtp_num_layers 来表示 MTP 的深度,每深一层都会多预测一个 Token,每层对应交叉熵损失,然后乘以 loss_mask
        • 处理后的 MTP 损失使用 MTPLossAutoScaler(是 torch.autograd.Function 的子类,是 PyTorch 自定义算子的实现 ) 实现前向和反向传播
          1
          2
          3
          4
          5
          6
          7
          8
          9
          10
          11
          12
          13
          14
          15
          16
          17
          18
          19
          20
          21
          22
          23
          24
          25
          26
          27
          28
          29
          30
          31
          32
          33
          34
          35
          36
          37
          38
          39
          40
          41
          42
          43
          44
          45
          46
          47
          48
          49
          50
          51
          52
          53
          54
          55
          56
          57
          58
          59
          60
          61
          62
          63
          64
          65
          66
          67
          68
          # TODO(ZJH): 将 MTP 的每个 Token(t+2...t+k)的 loss 都添加到(通过特殊的自定义算子)主网络的计算依赖上,从而保证对主网络求梯度时,MTP 相关的梯度也能回传
          if self.mtp_process:
          mtp_labels = labels.clone()
          hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) # TODO(ZJH): 将多个 hidden_state 拆开
          hidden_states = hidden_states_list[0] # TODO(ZJH): 主函数的输出
          if loss_mask is None:
          # if loss_mask is not provided, use all ones as loss_mask
          loss_mask = torch.ones_like(mtp_labels)
          for mtp_layer_number in range(self.config.mtp_num_layers): # TODO(ZJH): 每一层都计算 loss,每一层代表一个更深的未来 Token 预测目标
          # output
          mtp_logits, _ = self.output_layer( # TODO(ZJH): 每个 hidden_states 都要走 output_layer 得到 logits 后再计算损失
          hidden_states_list[mtp_layer_number + 1], # TODO(ZJH): hidden_states_list[0] 是主网络的 hidden_state
          weight=output_weight,
          runtime_gather_output=runtime_gather_output,
          )
          # Calc loss for the current Multi-Token Prediction (MTP) layers.
          mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group) # TODO(ZJH): MTP 目标逐步后移
          loss_mask, num_tokens = roll_tensor(
          loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group
          )
          mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
          mtp_loss = loss_mask * mtp_loss
          if self.training:
          # TODO(shifangx): remove the use of parallel_state here
          # after moving loss logging to loss_func in pretrain_gpt.py
          MTPLossLoggingHelper.save_loss_to_tracker( # TODO(ZJH): for logging
          torch.sum(mtp_loss) / num_tokens,
          mtp_layer_number,
          self.config.mtp_num_layers,
          avg_group=parallel_state.get_data_parallel_group(
          with_context_parallel=True
          ),
          )
          mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers # TODO(ZJH): 根据参数和层数计算 scale,除以 mtp_num_layers 得到平均值,保证总的 MTP loss 量级(影响)不变
          if self.config.calculate_per_token_loss: # TODO(ZJH): 判断损失是否按照 Token 做平均
          # TODO(ZJH): MTPLossAutoScaler 是特殊的自定义算子,不改变第一个参数的值(输入即输出),求导时直接返回 第二个参数(loss)*scale 作为梯度
          hidden_states = MTPLossAutoScaler.apply( # TODO(ZJH): 经过这个自定义算子后,不会改变 hidden_states 的值(注意 hidden_states 始终是主网络的隐藏层),但对 hidden_states 计算梯度会直接返回 mtp_loss_scale * mtp_loss
          hidden_states, mtp_loss_scale * mtp_loss
          ) # TODO(ZJH): hidden_states 经过所有层后,最终得到的是所有 MTP 层 Token 的梯度(多个深度的 Token 一起)
          else:
          hidden_states = MTPLossAutoScaler.apply(
          hidden_states, mtp_loss_scale * mtp_loss / num_tokens
          )
          sequence_parallel_override = False
          if in_inference_mode and inference_context.materialize_only_last_token_logits:
          if inference_context.is_static_batching():
          hidden_states = hidden_states[-1:, :, :]
          else:
          if self.output_layer.sequence_parallel:
          # Perform the sequence parallel gather here instead of after the output layer
          # because we need to slice the last token logits from the full view of the
          # packed logits across all requests.
          # TODO(ksanthanam): Make the equivalent change in the `MambaModel` code after
          # merging in !3722.
          hidden_states = gather_from_sequence_parallel_region(
          hidden_states, group=self.model_comm_pgs.tp
          )
          self.output_layer.sequence_parallel = False
          sequence_parallel_override = True
          # Reshape [B, 1, H] to [1, B, H] -> extract each sample’s true last‐token hidden
          # state ([B, H]) -> unsqueeze back to [1, B, H]
          # (so that the output layer, which expects S×B×H, receives only the final token)
          hidden_states = inference_context.last_token_logits(
          hidden_states.squeeze(1).unsqueeze(0)
          ).unsqueeze(1)
          logits, _ = self.output_layer( # TODO(ZJH):主网络的 logits 计算
          hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
          )
  • model_provider 函数(也是 pretrain 的入参)会返回一个模型 model

    • 这个返回的模型实现了 forward 函数,model.forward 函数整体返回值是一个 loss 值(主网络的 loss,但计算图上带着 MTP 所有深度上的 loss),该值是由 _postprocess 返回的值
pretrain 参数三:forward_step,输入参数包括模型,负责调用模型执行前向过程,并返回 loss 指针等
  • 返回的 loss 函数指针可以被调用,从而计算 loss

pretrain 的工作包括环境初始化,执行训练过程等

第一步:initialize_megatron(),初始化分布式环境,包括 TP,PP,DP 等的子进程组等
第二步:setup_model_and_optimizer(),定义模型架构,切割模型,完成 optimizer 初始化
第三步:build_train_valid_test_data_iterators(), 获取数据 iterator
第四步:train(),训练入口
  • 训练的入参包括上面得到的各种结果

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    def train(
    forward_step_func,
    model,
    optimizer,
    opt_param_scheduler,
    train_data_iterator,
    valid_data_iterator,
    process_non_loss_data_func,
    config,
    checkpointing_context,
    non_loss_data_func,
    ):
  • train_step():训练过程包含一个主要的 while 循环,每次走一个 train_step()

    1
    2
    def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func): # TODO(ZJH): 单步训练入口
    """Single training step."""
    • train_step 第一步:forward_backward_func(),完成一次前向和后向过程,是训练的核心函数,也最难

      • 实际上调用的函数 forward_backward_func 经过层层函数传递 train_step() <- train() <- megatron/core/pipeline_parallel/schedules.py,最终可追述到 schedules.py 文件的 get_forward_backward_func() 函数

        1
        2
        3
        4
        5
        6
        7
        8
        9
        10
        def get_forward_backward_func():
        pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
        if pipeline_model_parallel_size > 1: # TODO(ZJH): 若打开 PP
        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
        forward_backward_func = forward_backward_pipelining_with_interleaving # TODO(ZJH): 若打开 interleaving pipeline 调度
        else:
        forward_backward_func = forward_backward_pipelining_without_interleaving # TODO(ZJH): 若关闭 interleaving pipeline 调度
        else:
        forward_backward_func = forward_backward_no_pipelining # TODO(ZJH): 没有 PP 的情况
        return forward_backward_func
      • 若打开 PP

        • 开启 interleaving pipeline:forward_backward_pipelining_with_interleaving
          • 负责实现对应的 1F1B 调度策略,函数内部像 forward_backward_no_pipelining 函数一样,会调用 forward_step 和 backward_step 两步完成前向后向过程和梯度的积累
        • 未开启 interleaving pipeline:forward_backward_pipelining_without_interleaving
          • 在 forward_backward_pipelining_with_interleaving 的基础上,增加了 interleaving 调度策略(实现则更为复杂),进一步优化气泡
        • 其他特殊的 Pipeline 并行调度策略,如 zero_bubble 的调度,实现都在这里新建函数就可以
      • 若为 没有打开 PP 的情况:调用同文件(schedules.py)下的 forward_backward_no_pipelining() 函数,下面是该函数的介绍:

        • 前向过程+后向过程函数为(config.overlap_moe_expert_parallel_comm 为 True):combined_1f1b_schedule_for_no_pipelining

          • 注:config.overlap_moe_expert_parallel_comm 为 True 表示 框架会尝试将专家并行所需的通信操作(如数据传输)与模型的计算操作(如其他层的前向 / 反向计算)重叠进行,而不是等通信完成后再执行计算

            1
            2
            3
            4
            5
            6
            7
            8
            9
            10
            11
            12
            13
            14
            15
            16
            17
            18
            19
            20
            if config.overlap_moe_expert_parallel_comm and not forward_only: # TODO(ZJH): 如果打开 overlap MoE 的专家并行通信(将专家并行(expert parallelism)中的通信操作与计算操作重叠执行),且包含 backward
            # TODO(ZJH): 当 config.overlap_moe_expert_parallel_comm 设为 True 时,框架会尝试将专家并行所需的通信操作(如数据传输)与模型的计算操作(如其他层的前向 / 反向计算)重叠进行,而不是等通信完成后再执行计算
            forward_data_store, total_num_tokens = combined_1f1b_schedule_for_no_pipelining( # TODO(ZJH): 1次前向+1次后向过程
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            output_tensor_grad,
            forward_data_store,
            config,
            collect_non_loss_data,
            first_val_step,
            forward_only,
            no_sync_func,
            total_num_tokens,
            partial(check_first_val_step, first_val_step, forward_only),
            )
            else:
            # forward_step 和 backward_step 交替执行
          • 分开执行的函数分别为:forward_step 和 backward_step (这里会调用 for 循环完成多个 microbatches, forward_step 和 backward_step 在循环南北部)

            • 前置说明:microbatches - 1 个 microbatches 先调用,然后最后一个负责处理梯度同步等
              1
              2
              3
              4
              5
              6
              7
              8
              9
              10
              11
              12
              13
              14
              15
              16
              17
              18
              19
              20
              21
              22
              23
              24
              25
              26
              27
              28
              29
              30
              31
              32
              33
              34
              35
              36
              37
              38
              39
              40
              41
              42
              43
              44
              45
              46
              47
              48
              49
                  with no_sync_func(): # TODO(ZJH): 如果 no_sync_func 上下文管理器 是一个实际的同步禁用逻辑(比如禁用某些 IO 同步、锁机制等),则代码块会在 “不执行同步” 的环境中运行
              for i in range(num_microbatches - 1): # TODO(ZJH): 每个设备负责多个 microbatches,注意这里少一个
              output_tensor, num_tokens = forward_step( # TODO(ZJH): 前向过程,注意,这里是单个 microbatch 走一次
              forward_step_func,
              data_iterator,
              model,
              num_microbatches,
              input_tensor,
              forward_data_store,
              config,
              grad_finalize_pgs.cp.size(),
              collect_non_loss_data,
              is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
              current_microbatch=i,
              )
              total_num_tokens += num_tokens # TODO(ZJH): 累加 Token 数
              if not forward_only:
              backward_step( # TODO(ZJH): 后向过程,梯度直接累加(forward 中已经对 loss/num_microbatches)
              input_tensor, output_tensor, output_tensor_grad, model_type, config
              )
              # Run computation for last microbatch out of context handler (want to
              # synchronize gradients).
              output_tensor, num_tokens = forward_step( # TODO(ZJH): 最后一个梯度单独处理,这个梯度的计算要在 context handler 之外,核心原因是确保最后一次梯度计算完成后能触发必要的同步操作,从而保证梯度的正确性和一致性
              forward_step_func,
              data_iterator,
              model,
              num_microbatches,
              input_tensor,
              forward_data_store,
              config,
              grad_finalize_pgs.cp.size(),
              collect_non_loss_data,
              is_first_microbatch=check_first_val_step(
              first_val_step, forward_only, num_microbatches == 1
              ),
              current_microbatch=num_microbatches - 1,
              )
              total_num_tokens += num_tokens # TODO(ZJH): 累加 Token 数
              if not forward_only:
              backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
              # TODO(ZJH): 梯度聚合
              if config.finalize_model_grads_func is not None and not forward_only:
              # Finalize model grads (perform full grad all-reduce / reduce-scatter for
              # data parallelism and layernorm all-reduce for sequence parallelism).
              config.finalize_model_grads_func( # TODO(ZJH): 梯度聚合操作 all_reduce + reduce-scatter ?
              [model],
              total_num_tokens if config.calculate_per_token_loss else None,
              grad_finalize_pgs=grad_finalize_pgs,
              )
          • forward_step() 执行过程:

            1
            2
            3
            4
            5
            6
            7
            8
            9
            10
            11
            12
            13
            14
            15
            16
            17
            18
            19
            20
            21
            forward_step 核心代码如下:

            with context_manager:
            if checkpoint_activations_microbatch is None:
            output_tensor, loss_func = forward_step_func(data_iterator, model)
            else:
            output_tensor, loss_func = forward_step_func(
            data_iterator, model, checkpoint_activations_microbatch
            )
            output_tensor, num_tokens = forward_step_calc_loss(
            model,
            output_tensor,
            loss_func,
            config,
            vp_stage,
            collect_non_loss_data,
            num_microbatches,
            forward_data_store,
            cp_group_size,
            is_last_stage,
            )
            • 其中,forward_step_calc_loss 核心代码如下:
              1
              2
              3
              4
              5
              6
              7
              8
              9
              10
              11
              12
              13
              14
              15
              16
              17
              18
              19
              if is_last_stage: # TODO(ZJH): 只有最后一个 stage 包含 loss,其他 stage 都不需要计算
              if not collect_non_loss_data:
              outputs = loss_func(output_tensor) # TODO(ZJH): 获取损失值,详情见 pretrain_gpt.py 的 loss_func 的返回值(中间使用 forward_step 作为 partial 封装)
              if len(outputs) == 3:
              output_tensor, num_tokens, loss_reduced = outputs
              # TODO(ZJH): 当 calculate_per_token_loss=True 时,损失计算会保留每个 token 的损失值(即按 token 粒度计算损失),通常用于需要获取单 token 损失的场景(如后续可能的梯度裁剪、损失分析等)
              if not config.calculate_per_token_loss: # TODO(ZJH): 当 calculate_per_token_loss=False 时,损失会被归一化(通常除以总 token 数和微批次数量),得到一个全局平均损失,这是训练中更常见的做法(避免损失值因批次大小不同而波动)
              output_tensor /= num_tokens # TODO(ZJH): 视情况看是否需要做 Token 粒度的归一化
              output_tensor /= num_microbatches # TODO(ZJH): 这里是单个 Batch,但除以 num_microbatches,是为了后续 backward 时梯度可以直接累加
              else:
              # preserve legacy loss averaging behavior (ie, over the number of microbatches)
              assert len(outputs) == 2
              output_tensor, loss_reduced = outputs
              output_tensor *= cp_group_size
              output_tensor /= num_microbatches
              forward_data_store.append(loss_reduced)
              else:
              data = loss_func(output_tensor, non_loss_data=True)
              forward_data_store.append(data)
          • backward_step() 执行过程:

            • 后向过程,可选择自定义的 backward 或 PyTorch 标准的官方实现,梯度直接累加(forward 中已经对 loss/num_microbatches)
              1
              2
              3
              4
              5
              if output_tensor[0].requires_grad:
              if config.deallocate_pipeline_outputs:
              custom_backward(output_tensor[0], output_tensor_grad[0]) # TODO(ZJH): 使用自定义的 backward
              else:
              torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) # TODO(ZJH): 直接使用 backward
    • train_step 第二步:optimizer.step()

      • 用梯度完成一次完整的参数更新
    • train_step 第三步:继续处理 loss 并上报

      • 注:调用完 optimizer 后,还要继续处理 loss 的原因是梯度更新不需要汇总 DP 的 loss,只有上报时需要聚合 所有 DP 的数据

附录:Megatron MTP 损失绑定函数的测试

  • MTP 损失绑定到 main_hidden_states 的方式是通过一个不修改值,但绑定梯度的自定义算子 MTPLossAutoScaler 实现:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    # megatron/core/transformer/multi_token_prediction.py
    class MTPLossAutoScaler(torch.autograd.Function): # TODO(ZJH): 相当于在实现自定义的 PyTorch 算子
    """An AutoScaler that triggers the backward pass and scales the grad for mtp loss."""

    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)

    @staticmethod
    def forward(ctx, output: torch.Tensor, mtp_loss: torch.Tensor): # TODO(ZJH): 前向过程,存储 loss,返回输入的原始值,不做任何计算
    """Preserve the mtp by storing it in the context to avoid garbage collection.

    Args:
    output (torch.Tensor): The output tensor.
    mtp_loss (torch.Tensor): The mtp loss tensor.

    Returns:
    torch.Tensor: The output tensor.
    """
    ctx.save_for_backward(mtp_loss)
    return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor): # TODO(ZJH): 后向过程,获取前向过程存储的 loss,乘以 main_loss_backward_scale 并返回
    """Compute and scale the gradient for mtp loss..

    Args:
    grad_output (torch.Tensor): The gradient of the output.

    Returns:
    Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled mtp loss
    gradient.
    """
    (mtp_loss,) = ctx.saved_tensors
    mtp_loss_backward_scale = MTPLossAutoScaler.main_loss_backward_scale
    scaled_mtp_loss_grad = torch.ones_like(mtp_loss) * mtp_loss_backward_scale
    return grad_output, scaled_mtp_loss_grad
  • MTPLossAutoScaler 算子测试代码:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    import torch

    # 待测试类(直接复用原代码)
    class MTPLossAutoScaler(torch.autograd.Function):
    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)

    @staticmethod
    def forward(ctx, output: torch.Tensor, mtp_loss: torch.Tensor):
    ctx.save_for_backward(mtp_loss)
    return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
    (mtp_loss,) = ctx.saved_tensors
    scaled_mtp_loss_grad = torch.ones_like(mtp_loss) * MTPLossAutoScaler.main_loss_backward_scale
    return grad_output, scaled_mtp_loss_grad

    def test_all_scenarios():
    # 初始化模型可训练参数(所有场景共用)
    x = torch.tensor([2.0], requires_grad=True)
    y = torch.tensor([3.0], requires_grad=True)

    # 定义测试场景配置(场景名、mtp_loss构造、scale系数)
    scenarios = [
    # 场景1:mtp_loss不依赖参数 + scale=1.0
    ("不依赖参数 + scale=1.0", lambda: torch.tensor([5.0], requires_grad=True), 1.0),
    # 场景2:mtp_loss不依赖参数 + scale=0.3
    ("不依赖参数 + scale=0.3", lambda: torch.tensor([5.0], requires_grad=True), 0.3),
    # 场景3:mtp_loss依赖参数(x+y) + scale=1.0
    ("依赖参数3*(x+y) + scale=1.0", lambda: 3*(x + y), 1.0),
    # 场景4:mtp_loss依赖参数(x+y) + scale=0.6
    ("依赖参数3*(x+y) + scale=0.6", lambda: 3*(x + y), 0.6),
    ]

    for scenario_name, mtp_loss_fn, scale in scenarios:
    # 重置梯度和scale系数
    x.grad = None
    y.grad = None
    MTPLossAutoScaler.main_loss_backward_scale = torch.tensor(scale)

    # 1. 计算模型输出和mtp_loss
    output = x * y # 模型输出(固定逻辑:x*y,梯度易验证)
    mtp_loss = mtp_loss_fn()

    # 关键修复:为mtp_loss保留梯度(无论是否为叶子张量)
    mtp_loss.retain_grad() # mtp_loss 依赖模型参数时(mtp_loss 非叶子张量),mtp_loss.grad 不存在,使用 retain_grad() 强制保留其梯度,方便后续查看

    # 2. 使用AutoScaler处理
    scaled_output = MTPLossAutoScaler.apply(output, mtp_loss)

    # 3. 构造总损失并反向传播
    total_loss = scaled_output.sum()
    total_loss.backward()

    # 4. 打印结果(保留1位小数,简洁清晰)
    print(f"=== {scenario_name} ===")
    # 确保grad存在(避免None报错)
    mtp_grad = mtp_loss.grad.item() # 注意:若不使用 mtp_loss.retain_grad(),则 mtp_loss 依赖模型参数时(mtp_loss 非叶子张量),mtp_loss.grad 不存在
    print(f"mtp_loss梯度: {mtp_grad:.1f}") # 验证scale是否生效
    print(f"参数x梯度: {x.grad.item():.1f}") # 验证是否受mtp_loss依赖关系影响
    print(f"参数y梯度: {y.grad.item():.1f}\n") # 验证是否受mtp_loss依赖关系影响

    if __name__ == "__main__":
    test_all_scenarios()

    # === 不依赖参数 + scale=1.0 ===
    # mtp_loss梯度: 1.0
    # 参数x梯度: 3.0
    # 参数y梯度: 2.0
    #
    # === 不依赖参数 + scale=0.3 ===
    # mtp_loss梯度: 0.3
    # 参数x梯度: 3.0
    # 参数y梯度: 2.0
    #
    # === 依赖参数3*(x+y) + scale=1.0 ===
    # mtp_loss梯度: 1.0
    # 参数x梯度: 6.0
    # 参数y梯度: 5.0
    #
    # === 依赖参数3*(x+y) + scale=0.6 ===
    # mtp_loss梯度: 0.6
    # 参数x梯度: 4.8
    # 参数y梯度: 3.8

NLP——LLM-Rubric-RL相关总结

注:本文包含 AI 辅助创作


整体说明

  • 在 2025 年的 LLM RL 研究中,利用 Rubrics(评分细则/量规) 来构建更精细、可解释的奖励函数已经成为主流趋势之一
  • 这种方法主要解决传统标量奖励(Scalar Reward)无法提供细粒度指导的问题
  • 本文总结相关领域具有代表性的论文,论文详细阅读见其他内容

RaR

  • 原始论文:(RaR)Rubrics as Rewards: Reinforcement Learning Beyond Verifiable Domains, Scale AI, 20251003
  • 该论文提出了“可验证奖励强化学习”(RLVR)的概念,证明了在复杂推理任务中,使用明确的 Rubrics 作为奖励信号比单纯的人类偏好更有效,能够显著提升模型推理的正确性

DR Tulu: Reinforcement Learning with Evolving Rubrics for Deep Research

  • 原始博客:DR Tulu: An open, end-to-end training recipe for long-form deep research, 20251118, AI2
  • 原始论文:DR Tulu: Reinforcement Learning with Evolving Rubrics for Deep Research, 20251124 & 20251126, AI2
  • 针对深度研究(Deep Research)场景,提出了一种“演化细则”(Evolving Rubrics)机制
  • 随着模型能力的提升,评分标准也会动态调整,从而引导模型适应性地使用工具并完成更复杂的科学问答任务

RubricRL: Simple Generalizable Rewards for Text-to-Text Generation

  • 原始论文:RubricRL: Simple Generalizable Rewards for Text-to-Image Generation, 20251125, Microsoft CoreAI
  • 文生图 Rubric,提出了一个名为 RubricRL 的通用框架,旨在为文本生成任务设计简单且可泛化的基于 Rubric 的奖励
  • 该方法强调了奖励设计的可解释性和可组合性,使用户能更灵活地定制模型行为

AdvancedIF: Rubric-Based Benchmarking and Reinforcement Learning for Advancing LLM Instruction

  • 原始论文:AdvancedIF: Rubric-Based Benchmarking and Reinforcement Learning for Advancing LLM Instruction Following, 20251113 & 20251126, Meta Superintelligence Labs & CMU
  • 该工作发布了 AdvancedIF 框架,利用基于 Rubric 的流水线来提升大模型的Instruction Following能力
  • 不仅将 Rubric 用于评估(Benchmarking),还将其直接用于 RL 训练环节

(Rubicon) Reinforcement Learning with Rubric Anchors

  • 原始论文:(Rubicon) Reinforcement Learning with Rubric Anchors, 20250818, Inclusion AI & Ant Group & Zhejiang University
  • 探讨了在 RLVR(可验证奖励 RL)范式下,如何利用“Rubric Anchors”(评分锚点)来增强大模型
  • 通过锚点机制,模型能够更稳定地对齐到预期的细粒度标准上

(RuscaRL) Breaking the Exploration Bottleneck: Rubric-Scaffolded Reinforcement Learning for General LLM Reasoning

  • 原始论文:(RuscaRL) Breaking the Exploration Bottleneck: Rubric-Scaffolded Reinforcement Learning for General LLM Reasoning, 20250823-20251022, ZJU
  • 提出了 RuscaRL 框架,将 Rubric 作为一种教学脚手架(Instructional Scaffolding)
  • 该方法旨在帮助模型突破复杂任务中的“探索瓶颈”,通过结构化的细则引导模型逐步探索出正确的策略

Self-Rewarding Rubric-Based Reinforcement Learning for Open-Ended Reasoning

  • 原始论文:Self-Rewarding Rubric-Based Reinforcement Learning for Open-Ended Reasoning, 20250919, Ant Group
  • 针对开放式推理任务,提出了一种自我奖励机制
  • 模型能够根据预设的 Rubric 对自己的输出进行评分和反馈,从而在缺乏外部大规模标注的情况下实现自我迭代和提升

RLAC: Reinforcement Learning with Adversarial Critic for Dynamic Rubric Generation

  • 原始论文:RLAC: Reinforcement Learning with Adversarial Critic for Free-Form Generation Tasks, 20251103, SJTU & UC Berkeley
  • 提出了一种结合对抗性 Critic 的强化学习方法(RLAC),通过动态生成的 Rubric 来应对训练过程中的挑战,属于Post-training阶段的优化策略

PaTaRM: Bridging Pairwise and Pointwise Signals via Preference-Aware Task-Adaptive Reward Modeling

  • 原始论文:PaTaRM: Bridging Pairwise and Pointwise Signals via Preference-Aware Task-Adaptive Reward Modeling, BUPT & Meituan, 20251028

Chasing the Tail: Effective Rubric-based Reward Modeling for Large Language Model Post-Training

  • 原始论文:Chasing the Tail: Effective Rubric-based Reward Modeling for Large Language Model Post-Training, Scale AI, 20250925

Auto-Rubric: Learning to Extract Generalizable Criteria for Reward Modeling

  • 原始论文:Auto-Rubric: Learning to Extract Generalizable Criteria for Reward Modeling, 20251020
  • 构造静态 Rubrics

(Self-Rewarding Rubrics) Self-Rewarding Rubric-Based Reinforcement Learning for Open-Ended Reasoning

  • 原始论文:Self-Rewarding Rubric-Based Reinforcement Learning for Open-Ended Reasoning, 20250919
  • 将策略自己用作 Rubrics 生成器

QA-LIGN: Aligning LLMs through Constitutionally Decomposed QA

  • 原始论文:QA-LIGN: Aligning LLMs through Constitutionally Decomposed QA, 20250609-20251204, Arizona State University

AutoRubric-R1V: Rubric-Based Generative Rewards for Faithful Multimodal Reasoning

  • 原始论文:AutoRubric-R1V: Rubric-Based Generative Rewards for Faithful Multimodal Reasoning, 20251016, University of Notre Dame
  • 多模态 Rubric

(DeepSeek-GRM)Inference-Time Scaling for Generalist Reward Modeling

  • 原始论文:(DeepSeek-GRM)Inference-Time Scaling for Generalist Reward Modeling, DeepSeek & THU, 20250403-20250925
  • 推出了 DeepSeek-GRM 模型,是 Pointwise GRM,模型地址:huggingface.co/collections/BBQGOD/deepseek-grm

Rubric-ARM

  • 原始论文:(Rubric-ARM)Alternating Reinforcement Learning for Rubric-Based Reward Modeling in Non-Verifiable LLM Post-Training, 20260202, Emory University & Purdue University
  • 交替训练

Rationale Consistency

  • 原始论文:(Rationale Consistency)Outcome Accuracy is Not Enough: Aligning the Reasoning Process of Reward Models, 20260204, Qwen Team & Fudan & THU
  • 作者提出,目前的 RLHF 中,训练和评估时都仅考虑结果准确性,这本身是不够的

MaMs(Multi-agent Markov-state)

  • 原始论文:(MaMs)Learning Query-Specific Rubrics from Human Preferences for DeepResearch Report Generation, 20260203, Tencent & Fudan
  • for DeepResearch

NLP——技术报告解读-Qwen3

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:Qwen3 Technical Report, Qwen, 20250514
    • Qwen3技术报告解读-包包算法笔记
    • Qwen3技术报告解读-刘聪NLP
    • Github 链接:github.com/QwenLM/Qwen3
      • github.com/QwenLM/Qwen3/tree/main/examples/llama-factory 文件夹下有 Qwen3 在 Llama-factory 框架下微调的方法

Paper Summary

  • 整体说明:
    • Qwen3 包含一系列 LLM,包括 Dense 和 MoE 模型,参数规模从 0.6B 到 235B 不等(这个模型 Size 的丰富度,可以称为地表最强开源模型了)
    • Qwen3 的一项关键创新是将思考模式(thinking mode,用于复杂的多步推理)和非思考模式(non-thinking mode,用于快速的上下文驱动响应)集成到一个统一框架中
    • Qwen3 引入了思考预算(thinking budget)机制,允许用户在推理过程中自适应地分配计算资源,从而根据任务复杂度平衡延迟和性能
    • 该模型基于包含 36 万亿 token 的大规模数据集进行预训练,能够理解和生成 119 种语言和方言的文本
    • Qwen3 在预训练和训练后模型的标准基准测试中均表现出色,涵盖代码生成、数学、推理和 Agent 相关任务
  • 利用旗舰模型的知识(蒸馏小模型),显著减少了构建小规模模型所需的计算资源,同时确保其具备高度竞争力
  • Qwen3 在多样化基准测试中取得了 SOTA 结果,包括代码生成、数学推理、智能体任务等,与更大的 MoE 模型和专有模型相比具有竞争力
  • 与上一代 Qwen2.5 相比,Qwen3 将多语言支持从 29 种扩展到 119 种语言和方言

Introduction and Discussion

  • 人类的目标:人工通用智能(Artificial General Intelligence, AGI)或人工超级智能(Artificial Super Intelligence, ASI)
  • 这项工作介绍了基础模型家族 Qwen 的最新系列 Qwen3
    • Qwen3 是一组开源 LLM ,在广泛的任务和领域中实现了 SOTA 性能
    • 发布了 Dense 和 MoE 模型,参数数量从 0.6B 到 235B 不等,以满足不同下游应用的需求
    • 旗舰模型 Qwen3-235B-A22B 是一个 MoE 模型,总参数为 235B,每个 token 激活的参数为 22B
      • 这种设计确保了高性能和高效的推理
  • Qwen3 引入了多项关键技术进步以增强其功能和可用性
    • 首先,它将两种不同的操作模式(思考模式和非思考模式)集成到单一模型中
      • 这使得用户无需在不同模型之间切换(例如从 Qwen2.5 切换到 QwQ (2024))即可切换模式
      • 这种灵活性确保开发者和用户可以高效地根据特定任务调整模型行为
    • 此外,Qwen3 引入了思考预算,使用户能够精细控制模型在执行任务时应用的推理努力水平
      • 这一能力对于优化计算资源和性能至关重要,能够根据实际应用中的复杂度调整模型的思考行为
    • 此外,Qwen3 在涵盖 119 种语言和方言的 36 万亿 token 上进行了预训练,有效增强了其多语言能力
      • 这种广泛的语言支持扩大了其在全球用例和国际应用中的部署潜力
    • 这些进步共同使 Qwen3 成为 SOTA 开源大语言模型家族,能够有效解决跨领域和跨语言的复杂任务
  • Qwen3 的预训练过程使用了约 36 万亿 token 的大规模数据集,经过精心筛选以确保语言和领域的多样性
    • 为了高效扩展训练数据,作者采用了多模态方法:Qwen2.5-VL (2025) 被微调以从大量 PDF 文档中提取文本
    • 作者还使用领域特定模型生成合成数据:Qwen2.5-Math (2024) 用于数学内容,Qwen2.5-Coder (2024) 用于代码相关数据
  • 预训练过程采用三阶段策略:
    • 1)通用阶段(S1) :所有 Qwen3 模型在 4,096 token 的序列长度上训练超过 30 万亿 token
      • 在此阶段,模型在语言能力和通用世界知识上完成了全面预训练,训练数据涵盖 119 种语言和方言
    • 2)推理阶段(S2) :为了进一步提升推理能力,作者优化了本阶段的预训练语料库,增加了 STEM、编码、推理和合成数据的比例
      • 模型在 4,096 token 的序列长度上进一步训练约 5 万亿高质量 token
      • 在此阶段,作者加速了学习率衰减
    • 3)长上下文阶段(S3) :在最后的预训练阶段,作者收集高质量的长上下文语料以扩展 Qwen3 模型的上下文长度
      • 所有模型在 32,768 token 的序列长度上训练了数千亿 token
      • 长上下文语料库包含 75% 长度在 16,384 到 32,768 token 之间的文本,以及 25% 长度在 4,096 到 16,384 token 之间的文本
  • 为了更好地将基础模型与人类偏好和下游应用对齐,作者采用了多阶段后训练方法,同时赋能思考(推理)和非思考模式
    • 在前两个阶段,通过长 CoT 冷启动微调和专注于数学与编码任务的强化学习 ,开发了强大的推理能力
    • 在最后两个阶段,将带有和不带有推理路径的数据合并为一个统一数据集进行进一步微调 ,使模型能够有效处理两种类型的输入,随后应用通用领域的强化学习以提升在广泛下游任务中的性能
    • 对于小模型,使用强到弱蒸馏(strong-to-weak distillation),利用从大模型的 Off-policy 和 On-policy 知识迁移来增强其能力
      • 从高级教师模型的蒸馏在性能和训练效率上显著优于强化学习

Architecture

  • Qwen3 系列包括 6 个 Dense 模型(Qwen3-0.6B、Qwen3-1.7B、Qwen3-4B、Qwen3-8B、Qwen3-14B 和 Qwen3-32B)和 2 个 MoE 模型(Qwen3-30B-A3B 和 Qwen3-235B-A22B)
    • 旗舰模型 Qwen3-235B-A22B 总参数为 235B,激活参数为 22B
  • Qwen3 Dense 模型的架构与 Qwen2.5 (2024) 类似,包括使用如下组件:
    • 分组查询注意力(Grouped Query Attention, GQA)(2023)
    • SwiGLU (2017)
    • 旋转位置嵌入(Rotary Positional Embeddings, RoPE)(2024)
    • RMSNorm (2023) 的预归一化(pre-normalization)
  • Qwen3 移除了 Qwen2 (2024) 中使用的 QKV 偏置(QKV-bias),并在注意力机制中引入了 QK 归一化(QK-Norm)(2023) 以确保 Qwen3 的训练稳定性
  • 模型架构的关键信息如表 1 所示
  • Qwen3 MoE 模型与 Qwen3 Dense 模型共享相同的基础架构,模型架构的关键信息如表 2 所示
    • 遵循 Qwen2.5-MoE (2024) 并实现了细粒度专家分割(fine-grained expert segmentation)(2024)
    • Qwen3 MoE 模型共有 128 个专家,每个 token 激活 8 个专家
    • 与 Qwen2.5-MoE 不同,Qwen3-MoE 设计排除了共享专家
    • Qwen3 采用全局批量负载均衡损失(global-batch load balancing loss)(2025) 以鼓励专家专业化
    • 这些架构和训练创新在下游任务中显著提升了模型性能
  • Qwen3 模型使用 Qwen 的分词器 (2023),它实现了 BBPE(Byte-level Byte-Pair Encoding)(2020; 2020; 2016),词表为 151,669

Pre-training

Pre-training Data

  • 与 Qwen2.5(2024)相比,Qwen3 显著扩大了预训练数据的规模和多样性
    • 作者收集了 36 万亿 token 的数据,覆盖 119 种语言和方言
      • 这些数据包含多个领域的高质量内容,例如代码、STEM(科学、技术、工程和数学)、推理任务、书籍、多语言文本以及合成数据
  • 作者对预训练数据做了如下扩展:
    • First,使用 Qwen2.5-VL 模型(2025)对大量 PDF 类文档进行文本识别,随后通过 Qwen2.5 模型(2024)对识别文本进行优化,从而获得额外的高质量文本 token,总量达到数万亿
    • Second,利用 Qwen2.5(2024)、Qwen2.5-Math(2024)和 Qwen2.5-Coder(2024)模型合成了数万亿不同格式的文本 token,涵盖教材、问答、指令和代码片段等,覆盖数十个领域
    • Third,通过引入更多语言数据进一步扩展了预训练语料库。与 Qwen2.5 相比,支持的语言数量从 29 种大幅增加到 119 种,显著提升了模型的跨语言能力
  • 作者开发了一套多语言数据标注系统,用于提升训练数据的质量和多样性
    • 该系统已应用于大规模预训练数据集,对超过 30 万亿 token 进行了多维度标注,包括教育价值、领域和安全性等
    • 这些细粒度的标注支持更高效的数据过滤和组合
    • 与以往研究(2023;2023;2024)在数据源或领域级别优化数据混合比例不同,Qwen3 通过在小规模代理模型上进行细粒度数据标签的消融实验,实现了实例级别的数据混合优化

Pre-training Stages

  • Qwen3 模型的预训练分为三个阶段:
    • 1)General Stage(S1) :在第一阶段,所有 Qwen3 模型在超过 30 万亿 token 上以 4,096 token 的序列长度进行训练
      • 此阶段模型已完全掌握语言能力和通用世界知识,训练数据覆盖 119 种语言和方言
    • 2)Reasoning Stage(S2) :为了进一步提升推理能力,作者优化了本阶段的预训练语料库,增加了 STEM、代码、推理和合成数据的比例
      • 模型以 4,096 token 的序列长度进一步预训练约 5 万亿高质量 token,并在此阶段加速学习率衰减
    • 3)Long Context Stage(S3) :在最后的预训练阶段,作者收集高质量的长上下文语料库,将 Qwen3 模型的上下文长度扩展到 32,768 token
      • 长上下文语料库中,75% 的文本长度介于 16,384 到 32,768 token 之间,25% 介于 4,096 到 16,384 token 之间
      • 遵循 Qwen2.5(2024)的方法,作者使用 ABF 技术(2023)将 RoPE 的基础频率从 10,000 增加到 1,000,000,同时引入 YARN(2023)和双块注意力(Dual Chunk Attention, DCA)(2024),在推理时实现序列长度容量的四倍提升
  • 与 Qwen2.5(2024)类似,Qwen3 基于上述三个阶段开发了超参数(如学习率 Scheduler 和 Batch Size)的扩展规律
    • 通过大量实验,作者系统研究了模型架构、训练数据、训练阶段与最优训练超参数之间的关系,最终为每个 Dense 或 MoE 模型预测了最优学习率和 Batch Size 策略

Pre-training Evaluation

  • Qwen3 系列的基础语言模型评估重点关注其在通用知识、推理、数学、科学知识、代码和多语言能力方面的表现。评估数据集包括 15 个基准测试:
    • 通用任务(General Tasks) :MMLU(2021a)(5-shot)、MMLU-Pro(2024)(5-shot,CoT)、MMLU-redux(2024)(5-shot)、BBH(2023)(3-shot,CoT)、SuperGPQA(2025)(5-shot,CoT)
    • 数学与 STEM 任务(Math & STEM Tasks) :GPQA(2023)(5-shot,CoT)、GSM8K(2021)(4-shot,CoT)、MATH(2021b)(4-shot,CoT)
    • 代码任务(Coding Tasks) :EvalPlus(2023a)(0-shot)(HumanEval(2021)、MBPP(2021)、Humaneval+、MBPP+ 的平均值)、MultiPL-E(2023)(0-shot)(Python、C++、Java、PHP、TypeScript、C#、Bash、JavaScript)、MBPP-3shot(2021)、CRUX-O of CRUXEval(2024)(1-shot)
    • 多语言任务(Multilingual Tasks) :MGSM(2023)(8-shot,CoT)、MMMLU(OpenAI,2024)(5-shot)、INCLUDE(2024)(5-shot)
  • 在基础模型的基线对比中,将 Qwen3 系列基础模型与 Qwen2.5 基础模型(2024)以及其他领先的开源基础模型进行了比较,包括 DeepSeek-V3 Base(2024a)、Gemma-3(2025)、Llama-3(2024)和 Llama-4(Meta-AI,2025)系列基础模型
    • 所有模型均使用相同的评估流程和广泛采用的评估设置,以确保公平比较
Summary of Evaluation Results
  • 基于整体评估结果,论文中总结了 Qwen3 基础模型的关键结论:
    • (1) 与之前开源的 SOTA Dense 和 MoE 基础模型(如 DeepSeek-V3 Base、Llama-4-Maverick Base 和 Qwen2.5-72B-Base)相比,Qwen3-235B-A22B-Base 在大多数任务中表现更优,且参数量显著更少
    • (2) 对于 Qwen3 MoE 基础模型,实验结果表明:
      • (a) 使用相同的预训练数据,Qwen3 MoE 基础模型仅需 1/5 的激活参数即可达到与 Dense 模型相当的性能;
      • (b) 由于架构改进、训练 token 规模扩大和更先进的训练策略,Qwen3 MoE 基础模型的性能优于 Qwen2.5 MoE 基础模型,且激活参数和总参数量更少;
      • (c) 即使仅使用 Qwen2.5 Dense 基础模型 1/10 的激活参数,Qwen3 MoE 基础模型也能达到相当的性能,显著降低了推理和训练成本
    • (3) Qwen3 Dense 基础模型的整体性能与更高参数规模的 Qwen2.5 基础模型相当
      • 例如,Qwen3-1.7B/4B/8B/14B/32B-Base 分别与 Qwen2.5-3B/7B/14B/32B/72B-Base 性能相当,尤其在 STEM、代码和推理基准测试中,Qwen3 Dense 基础模型的性能甚至超越了更高参数规模的 Qwen2.5 基础模型
  • 注:【此处省略一些评估细节】

Post-training

  • 图 1 展示了 Qwen3 系列模型的后训练流程,该流程围绕两个核心目标设计:
    • (1) 思维控制(Thinking Control) :通过整合“非思维模式(non-thinking mode)”和“思维模式(thinking mode)”,使用户能够灵活选择模型是否进行推理,并通过指定思维过程的 token 预算来控制思考深度
    • (2) 强到弱蒸馏(Strong-to-Weak Distillation) :旨在简化和优化轻量级模型的后训练过程
      • 通过利用大规模模型的知识,显著降低构建小规模模型所需的计算成本和开发工作量

Long-CoT Cold Start

  • 首先构建了一个涵盖数学、代码、逻辑推理和 STEM 问题的综合数据集,每个问题均配有已验证的参考答案或基于代码的测试用例
  • 该数据集用于长链思维(long Chain-of-Thought, long-CoT)训练的冷启动阶段
  • 数据集构建包含两阶段过滤流程:查询过滤(query filtering)和响应过滤(query filtering)
  • 查询过滤阶段:使用 Qwen2.5-72B-Instruct 识别并移除难以验证的查询(例如包含多个子问题或要求生成通用文本的查询),同时排除 Qwen2.5-72B-Instruct 无需 CoT 推理即可正确回答的查询,以确保仅包含需要深度推理的复杂问题
    • 使用 Qwen2.5-72B-Instruct 标注每个查询的领域,以保持数据集的领域平衡
  • 响应过滤阶段:保留验证集后,使用 QwQ-32B 为每个剩余查询生成 \(N\) 个候选响应
    • 若 QwQ-32B 无法生成正确解决方案,则由人工标注者评估响应准确性
    • 对于通过 Pass@\(N\) 的查询,进一步应用严格过滤标准,移除以下类型的响应:
      • (1) 最终答案错误;
      • (2) 包含大量重复内容;
      • (3) 明显猜测而无充分推理;
      • (4) 思维内容与总结不一致;
      • (5) 语言混合不当或风格突变;
      • (6) 疑似与验证集内容过度相似
  • 随后,使用精炼后的数据集子集进行推理模式的初始冷启动训练
    • 此阶段的目标是为模型奠定基础推理模式 ,而非过度强调即时推理性能,从而确保模型潜力不受限制,为后续RL阶段提供更大改进空间
    • 为实现这一目标,建议在此准备阶段尽量减少训练样本数量和训练步数

Reasoning RL

  • 推理强化学习阶段使用的 query-verifier 对需满足以下四个标准:
    • (1) 未在冷启动阶段使用;
    • (2) 冷启动模型可学习;
    • (3) 尽可能具有挑战性;
    • (4) 覆盖广泛的子领域
  • 最终收集了 3995 对 query-verifier ,并采用 GRPO(2024)更新模型参数
    • 实验表明,使用大的 Batch Size 和(每个查询)高 rollout 次数,结合 Off-policy 训练以提高样本效率,对训练过程有益
    • 此外,通过控制模型的熵(entropy)使其稳步增加或保持稳定 ,可以有效平衡探索与利用(exploration and exploitation),这对维持训练稳定性至关重要
      • 注:这里作者说模型的熵是稳步增加的,但大部分论文中给出的训练都是看到熵在下降的,作者应该是否做了一些特定的设计来提升模型在这方面的能力?
    • 结果表明,在单次 RL 训练过程中,无需手动干预超参数,即可实现训练奖励和验证性能的持续提升
      • 例如,Qwen3-235B-A22B 模型的 AIME’24 分数在 170 个 RL 训练步中从 70.1 提升至 85.1

Thinking Mode Fusion

  • 思维模式融合阶段的目标是将“非思维(non-thinking)”能力整合到已具备“思维(thinking)”能力的模型中
  • 这种方法使开发者能够管理和控制推理行为,同时降低部署独立模型以处理思维和非思维任务的成本和复杂性
  • 为实现这一目标,作者对推理强化学习模型进行持续 SFT ,并设计聊天模板以融合两种模式
  • 作者发现能够熟练处理两种模式的模型在不同思维预算下均表现稳定
Construction of SFT data
  • SFT 数据集结合了“thinking”和“non-thinking”数据
  • 为确保 Stage 2 模型的性能不受额外 SFT 影响 :
    • “thinking”数据通过使用 Stage 2 模型对 Stage 1 查询进行拒绝采样(rejection sampling)生成;
    • “non-thinking”数据则精心策划,涵盖编码、数学、指令遵循、多语言任务、创意写作、问答和角色扮演等多样化任务
    • 理解:因为 Stage 2 模型已经是训练好的了,为了增加思维链融合功能,额外的 SFT 是可能会导致模型出现问题的,这里使用 Stage 2 和 Stage 1 模型相关的数据来训练,从而保证 Stage 2 和 Stage 1 模型的原始能力?
  • 此外,作者采用自动生成的检查表评估“non-thinking”数据的响应质量,并特别增加低资源(low-resource)任务上的性能,作者特地增加了翻译任务的比例
Chat Template Design
  • 为更好整合两种模式并支持用户动态切换思维过程,作者设计了 Qwen3 的聊天模板(如表 9 所示)
    • 对于思维模式和非思维模式的样本,作者分别在用户查询或系统消息中引入 /think 和 /no_think 标志,使模型能够根据用户输入选择相应的思维模式
    • 对于非思维模式样本 ,作者在助手响应中保留空的思维块(thinking block)
      • 这一设计确保模型内部格式一致性,并允许开发者通过在聊天模板中拼接空的思维块来阻止模型进行思维行为
    • 默认情况下,模型以思维模式运行,因此作者添加了一些用户查询中不包含 /think 标志的思维模式训练样本
    • 对于更复杂的多轮对话,作者在用户查询中随机插入多个 /think 和 /no_think 标志,模型响应则遵循最后遇到的标志
Thinking Budget
  • 思维模式融合的额外优势在于,一旦模型学会以非思维和思维模式响应,它自然能够处理中间情况(即基于不完整思维生成响应)
  • 这一能力为实施模型思维过程的预算控制奠定了基础
    • 具体而言,当模型的思维长度达到用户定义的阈值时,手动停止思维过程,并插入停止思维指令:“Considering the limited time by the user, I have to give the solution based on the thinking directly now.\n</think>.\n\n”
    • 插入该指令后,模型基于已积累的推理生成最终响应
    • 值得注意的是,此能力并非通过显式训练获得,而是思维模式融合的自然结果

General RL

  • 通用强化学习阶段旨在广泛增强模型在多样化场景中的能力和稳定性
  • 为此,作者建立了一个覆盖 20 多项任务 的复杂奖励系统 ,每项任务均配备定制化评分标准,重点关注以下核心能力:
    • Instruction Following :确保模型准确解释并遵循用户指令,包括内容、格式、长度和结构化输出等要求,生成符合用户期望的响应
    • 格式遵循(Format Following) :除显式指令外,模型需遵循特定格式约定
      • 例如,通过 /think 和 /no_think 标志切换思维模式,并在最终输出中使用指定 token(如 <think> 和 </think>)分隔思维和响应部分
    • 偏好对齐(Preference Alignment) :针对开放式查询,偏好对齐侧重于提升模型的帮助性、参与度和风格,最终提供更自然和令人满意的用户体验
    • 智能体能力(Agent Ability) :训练模型通过指定接口正确调用工具
      • 在 RL rollout 过程中,模型可执行完整的多轮交互周期 ,并接收真实环境执行反馈 ,从而提升其在长视野决策任务中的性能和稳定性
    • 专项场景能力(Abilities for Specialized Scenarios) :在更专业的场景中,作者设计特定任务
      • 例如,在检索增强生成(Retrieval-Augmented Generation, RAG)任务中,引入奖励信号引导模型生成准确且上下文合适的响应 ,从而最小化幻觉风险
  • 为上述任务提供反馈时,作者采用三种不同类型的奖励:
    • (1)基于规则的奖励(Rule-based Reward) :广泛用于推理强化学习阶段,也适用于指令遵循(2024)和格式遵循等通用任务
      • 设计良好的基于规则的奖励可高精度评估模型输出的正确性,避免奖励破解(reward hacking)问题
    • (2)带参考答案的模型奖励(Model-based Reward with Reference Answer) :为每个查询提供参考答案,并提示 Qwen2.5-72B-Instruct 基于参考答案对模型响应评分
      • 该方法无需严格格式化即可灵活处理多样化任务,避免纯规则奖励可能导致的假阴性
    • (3)无参考答案的模型奖励(Model-based Reward without Reference Answer) :利用人类偏好数据训练奖励模型 ,为模型响应分配标量分数
      • 该方法不依赖参考答案,可处理更广泛查询,同时有效提升模型的参与度和帮助性

Strong-to-Weak Distillation

  • 强到弱蒸馏流程专为优化轻量级模型设计,涵盖 5 个 Dense 模型(Qwen3-0.6B、1.7B、4B、8B 和 14B)和 1 个 MoE 模型(Qwen3-30B-A3B)
    • 该方法在提升模型性能的同时,有效赋予其强大的模式切换能力
  • 蒸馏过程分为两个主要阶段:
    • (1) Off-policy 蒸馏(Off-policy Distillation) :在此初始阶段,作者结合教师模型在 /think 和 /no_think 模式下的输出进行响应蒸馏 ,帮助轻量级学生模型发展基础推理能力和思维模式切换能力,为后续 On-policy 训练阶段奠定基础
    • (2) On-policy 蒸馏(On-policy Distillation) :在此阶段,学生模型生成 On-policy 序列进行微调
      • 使用学生模型采样提示后,学生模型以 /think 或 /no_think 模式生成响应,并通过对齐其 logits 与教师模型(Qwen3-32B 或 Qwen3-235B-A22B)以最小化 KL 散度(Kullback-Leibler divergence)进行微调

Post-training Evaluation

  • For 全面评估指令调优模型的质量,采用自动化基准测试思维和非思维模式下的模型性能。这些基准分为以下几类:
    • 通用任务(General Tasks) :使用 MMLU-Redux(2024)、GPQA-Diamond(2023)、C-Eval(2023)和 LiveBench(2024-11-25)(2024)等基准
      • 对于 GPQA-Diamond,每个查询采样 10 次并报告平均准确率
    • 对齐任务(Alignment Tasks) :评估模型与人类偏好的对齐程度
      • 针对指令遵循性能,报告 IFEval(2023)的严格提示准确率;
      • 针对通用主题的人类偏好对齐,使用 Arena-Hard(2024)和 AlignBench v1.1(2023);
      • 针对写作任务,依赖 Creative Writing V3(2024)和 WritingBench(2025)评估模型的熟练度和创造力
    • 数学与文本推理(Math & Text Reasoning) :评估数学和逻辑推理能力,采用高阶数学基准 MATH-500(2023)、AIME’24 和 AIME’25(2025),以及文本推理任务 ZebraLogic(2025)和 AutoLogi(2025)
      • 对于 AIME 问题,每年试题包含 Part I 和 Part II,共 30 题,每题采样 64 次并以平均准确率作为最终分数
    • 智能体与编码(Agent & Coding) :测试模型在编码和智能体任务中的熟练度,使用 BFCL v3(2024)、LiveCodeBench(v5, 2024.10-2025.02)(2024)和 Codeforces Ratings from CodeElo(2025)
      • 对于 BFCL,所有 Qwen3 模型均采用 FC 格式评估,并使用 yarn 将模型部署至 64k 上下文长度以进行多轮评估
        • 部分基线来自 BFCL 排行榜,取 FC 和 Prompt 格式中的较高分;未在排行榜中报告的模型则评估 Prompt 格式
      • 对于 LiveCodeBench,非思维模式使用官方推荐提示,思维模式则调整提示模板以允许模型更自由地思考(移除限制 You will not return anything except for the program)
      • 为评估模型与竞技编程专家的性能差距,使用 CodeForces 计算 Elo 评分,每个问题最多生成八次独立推理尝试
    • 多语言任务(Multilingual Tasks) :评估四种任务:指令遵循、知识、数学和逻辑推理
      • 指令遵循使用 Multi-IF(2024)(覆盖 8 种关键语言);
      • 知识评估包含两类:区域知识通过 INCLUDE(2024)(覆盖 44 种语言)评估,通用知识通过 MMMLU(2024)(覆盖 14 种语言,排除未优化的约鲁巴语);
      • 数学任务采用 MT-AIME2024(2025)(涵盖 55 种语言)和 PolyMath(2025)(涵盖 18 种语言);
      • 逻辑推理使用 MlogiQA(覆盖 10 种语言,源自 2024)
  • 所有 Qwen3 模型采样超参数设置如下:
    • 在思维模式下的采样超参数设置为 temperature\(=0.6\)、top-p\(=0.95\)、top-k\(=20\);
      • 对于 Creative Writing v3 和 WritingBench,应用 presence penalty\(=1.5\) 以鼓励生成更多样化内容
      • 补充:presence penalty 会对已经生成过的词汇的概率值进行削弱
        • 值越大,削弱越厉害,生成的文本更加多样化;
        • 值越小,更容易生成重复的词,但文本会更加连贯
    • 非思维模式的采样超参数为 temperature\(=0.7\)、top-p \(=0.8\)、top-k \(=20\)、presence penalty\(=1.5\)
    • 两种模式的最大输出长度均设为 32,768 token,AIME’24 和 AIME’25 除外(扩展至 38,912 token 以提供充足思维空间)
      Summary of Evaluation Results
  • 从评估结果中,我们总结出最终确定的Qwen3模型的几个关键结论如下:
    • (1)我们的旗舰模型 Qwen3-235B-A22B 在思考模式和非思考模式下均展现出开源模型中的顶尖整体性能,超越了 DeepSeek-R1 和 DeepSeek-V3 等强基线模型
      • Qwen3-235B-A22B 与闭源领先模型(如 OpenAI-o1、Gemini2.5-Pro 和 GPT-4o)相比也具有高度竞争力,彰显了其深厚的推理能力和全面的通用能力
    • (2)我们的旗舰 Dense 模型 Qwen3-32B 在大多数基准测试中优于我们此前最强的推理模型 QwQ-32B,且性能与闭源的OpenAI-o3-mini 相当,表明其推理能力令人瞩目
      • Qwen3-32B 在非思考模式下的表现也极为出色,超越了我们此前的旗舰非推理 Dense 模型 Qwen2.5-72B-Instruct。
    • (3)我们的轻量级模型(包括 Qwen3-30B-A3B、Qwen3-14B 及其他较小的 Dense 模型)与参数规模相近或更大的开源模型相比,性能始终更优,证明了我们“Strong-to-Weak Distillation”方法的成功。
  • 注:【此处省略一些评估细节】

Discussion

The Effectiveness of Thinking Budget

  • 为验证 Qwen3 可通过增加思维预算提升智能水平,在数学、编码和 STEM 领域的四个基准上调整分配的思维预算
  • 如图 2 所示,Qwen3 展现出与分配思维预算相关的可扩展且平滑的性能提升
    • Thinking Budget 越大,效果越好
  • 若进一步将输出长度扩展至 32K 以上,模型性能有望在未来继续提升,作者将此探索留作未来工作

The Effectiveness and Efficiency of On-Policy Distillation

  • 通过比较蒸馏与直接强化学习在相同 Off-policy 蒸馏 8B 检查点后的性能和计算成本(以 GPU 小时计),评估 On-policy 蒸馏的有效性和效率
  • 注:为简化,以下仅关注数学和代码相关查询
  • 表 21 的结果表明,蒸馏在仅需约 1/10 GPU 小时的情况下,性能显著优于强化学习
    • 从教师 logits 蒸馏使学生模型能够扩展其探索空间并增强推理潜力,表现为蒸馏后 AIME’24 和 AIME’25 基准的 pass@64 分数较初始检查点有所提升;
    • 强化学习未带来 pass@64 分数的任何改进。
    • 这些观察凸显了利用更强教师模型指导学生模型学习的优势

The Effects of Thinking Mode Fusion and General RL

  • For 评估后训练中思维模式融合和通用强化学习的有效性,对 Qwen-32B 模型的各个阶段进行评估
  • 除前述数据集外,引入多个内部基准以监控其他能力:
    • CounterFactQA :包含反事实问题,模型需识别问题非事实并避免生成幻觉答案
    • LengthCtrl :包含带长度要求的创意写作任务,最终分数基于生成内容长度与目标长度的差异
    • ThinkFollow :包含随机插入 /think 和 /no_think 标志的多轮对话,测试模型能否根据用户查询正确切换思维模式
    • ToolUse :评估模型在单轮、多轮和多步工具调用过程中的稳定性,分数包括工具调用过程中的意图识别准确率、格式准确率和参数准确率
  • 结果如表 22 所示,可得出以下结论:
    • (1) Stage 3 将非思维模式整合至已通过前两阶段训练具备思维能力的模型中
      • ThinkFollow 基准分数 88.7 表明模型已具备初步模式切换能力,但仍偶有错误
      • Stage 3 还提升了模型在思维模式下的通用和指令遵循能力,CounterFactQA 提升 10.9 分,LengthCtrl 提升 8.0 分
    • (2) Stage 4 进一步强化模型在思维和非思维模式下的通用、指令遵循和智能体能力
      • ThinkFollow 分数提升至 98.9,确保准确模式切换
    • (3) 对于知识、STEM、数学和编码任务 ,思维模式融合和通用强化学习未带来显著提升
      • 相反,对于 AIME’24 和 LiveCodeBench 等挑战性任务,思维模式性能在这两个训练阶段后实际下降
      • 作者推测这种退化是由于模型在更广泛通用任务上训练,可能影响其处理复杂问题的专项能力
      • 在 Qwen3 开发中,作者选择接受这种性能权衡以增强模型的整体通用性

Future Work

  • 在不久的将来,作者的研究将聚焦于以下几个关键方向:
    • (1) 预训练扩展(Scale up pretraining) :作者将继续扩展预训练规模,使用质量更高、内容更多样的数据
    • (2) 架构优化(Improving model architecture) :改进模型架构和训练方法,以实现高效压缩、超长上下文支持等目标
    • (3) RL :增加计算资源投入,重点关注基于环境反馈的智能体强化学习系统,以构建能够处理需要推理时间扩展的复杂任务的智能体

Appendix

A.1 Additional Evaluation Results

A.1.1 Long-Context Ability
  • For 评估长上下文处理能力,在 RULER 基准测试 (2024) 中报告了结果(表 23)
  • 为实现长度外推(length extrapolation),使用 YARN (2023) 并设置缩放因子 \( \text{scaling_factor} = 4 \)
  • 在思考模式下,作者将思考预算(thinking budget)设为 8192 token,以减少对超长输入的冗余推理
  • 结果显示:
    • (1) 在非思考模式下,Qwen3 在长上下文处理任务中优于同规模的 Qwen2.5 模型
    • (2) 在思考模式下,模型性能略有下降
      • 作者推测思考内容对这些检索任务(无需依赖推理)帮助有限 ,甚至可能干扰检索过程
      • 未来版本将重点提升思考模式下的长上下文能力

A.1.2 Multilingual Ability

  • 表 24-35 展示了 Qwen3 系列模型在西班牙语、法语、葡萄牙语、意大利语、阿拉伯语、日语、韩语、印尼语、俄语、越南语、德语和泰语等多种语言中的详细基准得分
    • 这些结果表明,Qwen3 系列模型在所有评估基准中均表现优异,展现了强大的多语言能力
  • 为更广泛评估 Qwen3 的多语言表现,作者使用 Belebele (2023) 基准测试,覆盖 80 种优化语言(表 36 按语系列出)
  • 表 37 展示了 Qwen3 与其他基线模型在 Belebele 基准上的性能对比
  • 结果显示,Qwen3 在同等规模下与 Gemma 模型表现相当,同时显著优于 Qwen2.5

附录:Qwen3-MoE 源码

  • 参考博客:图解 Qwen3 MoE 模型源码
    • 包含一些流程图和源码,比较清晰

NLP——LLM对齐微调-VC-PPO

注:本文包含 AI 辅助创作

  • 参考链接:
    • (VC-PPO)What’s Behind PPO’s Collapse in Long-CoT? Value Optimization Holds the Secret, arXiv 20250303, ByteDance Seed

整体总结

  • 本文提出一种改进的 PPO 方法, VC-PPO(Value-Calibrated PPO)
  • 核心贡献:
    • Pretrained value:开始 RL 前先预训练价值网络
    • Decoupled-GAE:计算 Advantage (for Actor 损失)时和 计算 Target Reward(for Critic 损失)时使用不同的 \(\lambda\)

Motivation

  • LLM 在复杂推理任务(如数学、编程)中表现出色,尤其是通过生成长链思维(Long-CoT)来逐步推导答案
  • OpenAI的o1、DeepSeek-R1等模型都采用了这种“推理时扩展”策略
  • 一个发现:在Long-CoT任务中,PPO经常失效 ,表现为:
    • 模型输出长度迅速下降(注:通常是训练初期即开始大幅下降)
    • 验证集性能大幅退化
    • 无法有效利用长链思维
  • 注:这些现象与 PPO 在传统 RL 任务中的成功形成鲜明对比

问题诊断:PPO 失效的两大原因简单描述

  • 作者通过实验和分析,识别出 PPO 在 Long-CoT 任务中失效的两大根本原因
  • 注:导致模型输出长度迅速下降的直接原因是:前置 Advantage 被高估,详情见本文后面补充的附录

Value/Critic Initialization Bias

  • 在 RLHF 中,常用做法是用训练好的 奖励模型来初始化Value/Critic 模型
    • 注:这种做法则源于奖励模型和价值模型之间的表面相似性,因为两个模型都旨在预测关于响应的标量信息
  • 但奖励模型只在 <EOS> 处提供评分,对前面的 token 没有监督信号,导致它对前置 token 的打分偏低
  • 这种偏差在 GAE 中会被放大,导致前置 token 的 Advantage 被高估,进而促使模型倾向于生成短回答

奖励信号衰减(Reward Signal Decay)

  • 在 GAE 中,当 \(\lambda < 1\) 时,来自 <EOS> 的奖励信号会随传播距离指数衰减
  • 在 Long-CoT 任务中,序列长度可达数千 token,前置 token 几乎接收不到任何奖励信号
  • 这导致值模型无法有效学习,进而影响策略优化
  • 注:传统 RLHF 中常常使用 \(\lambda = 0.95\)
    • 这种做法源于传统的 RL 文献,其中 PPO 已在像 Mujoco 和 Atari 这样的环境中得到广泛测试
    • 在这些环境中,奖励会在轨迹上累积,导致高方差的回报,因此,使用 \(\lambda < 1\) 方差降低是必要的
    • 但这种方式会导致模型收敛缓慢
奖励信号衰减的数学推导
  • 定义:
    • \(V\):值函数
    • \(r_t\):即时奖励
    • \(\lambda\):GAE 平滑因子
  • 在标准 GAE 中,优势估计为:
    $$
    \hat{A}_t = \sum_{l=0}^{T-t-1} \lambda^l \delta_{t+l}, \quad \delta_t = r_t + V(s_{t+1}) - V(s_t)
    $$
  • Critic 目标为回报的估值:
    $$ R \approx V^{\text{target} }(s_t) = V_{old}(s_t) + A^{GAE}_t$$
  • 展开即可得到:
    $$
    V^{\text{target} }(s_t) =
    \begin{cases}
    \sum_{l=0}^{T-t-1} \lambda^l (r_{t+l} + V(s_{t+l+1}) - V(s_{t+l})) + V(s_t), & \lambda < 1 \\
    \sum_{l=0}^{T-t-1} r_{t+l}, & \lambda = 1
    \end{cases}
    $$
    • 上面式子的详细推导见:RL——强化学习中的方差与偏差 的 GAE 部分
    • 当 \(\lambda=1\) 时, Critic 目标即为累积奖励,无偏且稳定
    • 当 \(\lambda<1\) 时, Critic 目标引入了值函数自身的估计,可能不稳定
      • 当长度太大时,可以看到最终的奖励 \(r_T\) 几乎被淹没了,需要逐步透传到 Critic,所以 Critic 收敛性本身也会变慢

模型输出长度迅速下降 的原因详细分析

  • 现象描述:在标准的 PPO 训练中,模型本应生成长链思维(Long-CoT)来逐步推理答案
    • 但在实验中,训练刚开始不久,模型输出的平均长度急剧下降 ,随之而来的是验证集性能的崩溃(如图 1 所示)
  • 根据论文内容,可以得到一个清晰的因果链:
    • 1)初始阶段 :Critic 模型对前置 token 的估计偏低(来自奖励模型初始化)
    • 2)GAE 计算 :这种偏差被 GAE 放大,使前置 token 获得过高的优势值
    • 3)策略更新 :PPO 鼓励高优势值的动作,即鼓励模型多输出前置 token,少输出后置 token
    • 4)行为变化 :模型学会尽早结束生成,因为“早结束”意味着“多输出前置 token、少输出后置 token”
    • 5)恶性循环 :输出变短后,训练数据中的长序列减少,值模型更难学习长序列的价值,进一步强化短输出倾向

直接原因:前置 token 的优势被高估

  • 通过分析优势值(Advantage)与token位置的关联(如图 2 所示)发现:
  • 越靠前的 token,其优势值越高(正偏差越大);
    • 这种偏差导致模型倾向于更早结束生成 ,因为前置 token 被“鼓励”输出,而后置 token 被“惩罚”或忽略
    • 模型学会“尽早收尾”,从而输出长度急剧下降

根本原因之一:值初始化偏差

奖励模型的训练目标
  • 奖励模型只在 token 处给出评分(如正确=1,错误=-1);
  • 它对前面的 token 没有直接的监督信号;
  • 因此,奖励模型对前置 token 的评分偏低(因为信息不完整)
    • 理解:这里不够严谨
      • 其实需要看初始化值和大部分目标值的相对关系,本质应该是可能偏高也可能偏低才对
值模型从奖励模型初始化
  • 在 RLHF 中,常用做法是将训练好的奖励模型作为值模型的初始化;
  • 这导致值模型在初始阶段也对前置token的预期回报估计偏低
偏差在 GAE 中被放大
  • GAE 的优势估计公式为:
    $$
    \hat{A}_t = \sum_{l=0}^{T-t-1} \lambda^l \delta_{t+l}, \quad \delta_t = r_t + V(s_{t+1}) - V(s_t)
    $$
  • 由于前置token的 \(V(s_t)\) 被低估,而 \(V(s_{t+1})\) 相对较高(因为更接近 <EOS>),导致:
    • \(\delta_t\) 为正;
    • 这些正偏差在累加过程中被放大,最终使前置token的优势值显著偏高

根本原因之二:奖励信号衰减

GAE 中的奖励传播机制
  • 当 \(\lambda < 1\) 时(如默认的 0.95),来自 <EOS> 的奖励信号会随传播距离指数衰减:
    $$
    \text{传播到第 } t \text{ 个 token 的奖励信号} = \lambda^{T-t} \cdot r_{}
    $$
Long-CoT 任务的特点
  • 序列长度 \(T\) 可能达到数千token;
  • 前置 token 的 \(T-t\) 很大,\(\lambda^{T-t} \approx 0\);
  • 前置 token 几乎接收不到任何来自最终答案的奖励信号
对 Critic 模型的影响
  • 值模型难以学习到前置 token 的真实价值;
  • 值估计进一步失真,加剧了前置 token 的优势偏差

解决方案:VC-PPO

  • VC-PPO(Value-Calibrated PPO) 同时解决上述两个问题

创新1:Value-Pretraining

  • 目的:解决值初始化偏差,使值模型在训练开始前就与初始策略对齐
  • 方法流程
    • Step 1: 固定策略模型 ,使用初始策略(如SFT后的模型)生成大量 Response
    • Step 2: 使用 Monte-Carlo 回报(即 GAE \(\lambda=1.0\)) 作为值模型的目标,进行离线训练
    • Step 3: 训练至值损失和解释方差(explained variance)收敛
    • Step 4: 将该值模型作为后续 PPO 训练的初始值模型
  • 实验效果
    • 消除了前置 token 的优势偏差
    • 保留了 Long-CoT 的模式,避免输出长度崩溃

创新2:Decoupled-GAE(解耦 GAE)

  • 这是本文的核心创新点,目的是在策略优化和值函数优化中使用不同的 \(\lambda\) 值,以分别满足两者的不同需求
  • 注:使用不同的 \(\lambda\) 值更新时,需要证明其策略梯度还能准确,下文会证明这个事情
背景问题描述
  • 策略优化需要低方差 的梯度估计,因此希望使用较小的 \(\lambda\)(如 0.95);
  • 但值函数优化需要无偏的目标 ,因此希望使用 \(\lambda=1.0\),避免奖励信号衰减;
  • 传统 PPO 中,两者共用同一个 \(\lambda\),无法兼顾
方法流程
  • Step 1: 策略优化使用 \(\lambda_{\text{actor} } < 1.0\)(如 0.95),以降低梯度方差
  • Step 2: 值函数优化使用 \(\lambda_{\text{critic} } = 1.0\),以确保 Critic 目标无偏
  • Step 3: 两者的 GAE 计算独立进行 ,互不干扰
  • Algorithm1:
策略梯度的无偏性证明(待详细推导和理解)
  • 作者进一步证明,即使值函数是用不同 \(\lambda\) 训练的,将其代入策略梯度中仍然无偏:
  • 定义 n-step 回报为:
    $$
    G_{t:t+h} =
    \begin{cases}
    \sum_{l=0}^{h-1} r_{t+l} + \bar{V}(s_{t+h}), & t+h < T \\
    \sum_{l=0}^{T-t-1} r_{t+l}, & t+h = T
    \end{cases}
    $$
  • 则优势可写为:
    $$
    \hat{A}_t = (1-\lambda) \sum_{l=1}^{T-t-1} \lambda^{l-1} G_{t:t+l} + \lambda^{T-t-1} G_{t:T} - \bar{V}(s_t)
    $$
  • 即,具有任意 \(\lambda\) 的策略梯度可以重写如下:
    $$\begin{aligned} \mathbb{E}_t [\nabla_\theta \log \pi_\theta(a_t|s_t) A_t] &= \mathbb{E}_t \left[ \nabla_\theta \log \pi_\theta(a_t|s_t) \sum_{l=0}^{T-t-1} \lambda^l (r_{t+l} + \bar{V}(s_{t+l+1}) - \bar{V}(s_{t+l})) \right] \\ &= \mathbb{E}_t \left[ \nabla_\theta \log \pi_\theta(a_t|s_t) \left( (1-\lambda) \sum_{l=1}^{T-t-1} \lambda^{l-1} G_{t:t+l} + \lambda^{T-t-1} G_{t:T} - \bar{V}(s_t) \right) \right] \\ &= \mathbb{E}_t \left[ \nabla_\theta \log \pi_\theta(a_t|s_t) \left( (1-\lambda) \sum_{l=1}^{T-t-1} \lambda^{l-1} G_{t:t+l} + \lambda^{T-t-1} G_{t:T} \right) \right] \end{aligned} \tag{8}$$
    • 根据公式 8,代入任意价值函数不会给策略梯度引入额外的偏差
    • 鉴于大型语言模型所需的大量时间和计算资源,使用较小的 \(\lambda\) 来加快策略的收敛是可取的
    • 一个候选配置可以是 \(\lambda_{\text{policy}} = 0.95\) 和 \(\lambda_{\text{value}} = 1.0\)
  • 论文中提到:使用任意值函数不会引入额外偏差
    • 问题:这里并不是说明使用 Decoupled-GAE 前后的梯度一致,因为 \(G_{t:t+l}\) 中也包含了 \(\bar{V}(s_{t+l})\)
      • 此时 Actor 更新时最大化的目标(Advantage)本身已经发生了改变,目标变成了 最大化新的 \(\lambda\) 下的 Critic 值
    • 这里仅仅是证明了这 Actor 和 Critic 解耦的 \(\lambda\) 下,也有一个可以学习的目标值(形式上与原始的 策略梯度法/PPO 的结果是一样的)
      • 注:这也就是说形式上本身是不冲突的,但没有说明两者的目标是完全等价的

实验

Setting

  • 主要任务 :AIME、GPQA、Codeforces
  • Base 模型型 :Qwen2.5 32B
  • 冷启动 :使用少量 Long-CoT 格式样本进行 SFT
  • 奖励 :规则驱动的答案匹配(正确=1,错误=-1)
  • Baselines :PPO(\(\lambda=0.95\))和 GRPO(DeepSeek-R1 中使用)

对照 GRPO

  • VC-PPO vs GRPO
    模型 AIME 2024 pass@1 GPQA pass@1 CodeForces pass@1
    GRPO 38.9 49.4 12.6
    VC-PPO 48.8 48.8 12.8
  • VC-PPO 在 AIME 上显著优于 GRPO,达到 Qwen-32B 模型在该任务上的 SOTA

消融实验

  • 移除 Value-Pretraining :AIME pass@1 从 41.9 降至29.4
  • 移除 Decoupled-GAE :AIME pass@1 降至 29.4
  • \(\lambda_{\text{actor} }\) 调优 :\(\lambda=0.99\) 效果最佳,\(\lambda=1.0\) 效果最差

关键 Insight

  • 值模型与策略的对齐至关重要 :尤其是在 Long-CoT 任务中,值模型必须充分理解策略的生成模式,否则会破坏 CoT 结构
  • 值预训练不仅是值热身,更是知识注入 :它帮助模型理解哪些 token 更有价值,是提升性能的关键
  • 值优化对噪声更鲁棒 :值模型可以使用 \(\lambda=1.0\) 的无偏目标,而策略需要 \(\lambda<1.0\) 来降低方差,二者解耦是合理的

附录:证明 score function 的梯度性质

  • 核心思路:
    • 概率和的梯度为 0
  • 证明详情见:Math——KL中得分函数和路径梯度的理解

NLP——技术报告解读-DeepSeek-V3

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:DeepSeek-V3 Technical Report, arXiv 202412 & 202502, DeepSeek-AI
    • GitHub:github.com/deepseek-ai/DeepSeek-V3

Paper Summary

  • 核心内容总结:
    • DeepSeek-V3 是一个大型 MoE 语言模型(671B-A37B),在 14.8T Token 上进行了训练
    • 除了 DeepSeek-V2 中的 MLA 和 DeepSeekMoE 架构外,还开创性地采用了无辅助损失(auxiliary-loss-free)的负载平衡策略,并设置了 Multi-token 预测(Multi-token prediction)训练目标以获得更强的性能
    • 性能:截止到发布,DeepSeek-V3 是最强的开源模型,并且实现了与领先的闭源模型(如 GPT-4o 和 Claude-3.5-Sonnet)相媲美的性能
    • 最大特点:
      • 性能强大且保持了经济的训练成本,其完整训练(包括预训练、上下文长度扩展和后训练)仅需 2.788M H800 GPU 小时
      • 训练过程非常稳定,在整个训练过程中,论文没有遇到任何不可恢复的损失尖峰或执行任何回滚
  • DeepSeek-V3 采用了:
    • 多头潜在注意力(Multi-head Latent Attention, MLA),from DeepSeek-V2
    • DeepSeekMoE 架构, from DeepSeek-V2
    • DeepSeek-V3 开创了一种无辅助损失的(auxiliary-loss-free)负载均衡策略,并设定了 Multi-token 预测训练目标以获得更强的性能
  • 论文在 14.8T 个多样化和高质量的 Token 上对 DeepSeek-V3 进行了预训练(随后进行了 SFT 和 RL)

Introduction and Discussion

  • LLMs 经历了快速的迭代和演进 (Anthropic, 2024; Google, 2024; OpenAI, 2024a),逐步缩小了通往通用人工智能(Artificial General Intelligence, AGI)的差距(以上都是闭源模型)
  • 开源模型,包括 DeepSeek 系列 (DeepSeek-AI, 2024a, 2024b, 2024c; 2024)、LLaMA 系列 (Al@Meta, 2024a,2023b)、Qwen 系列 (Qwen, 2023, 2024a, 2024b) 和 Mistral 系列 (2023; Mistral, 2024),也取得了显著进展,努力缩小与闭源模型的差距
  • 为了进一步推动开源模型能力的边界,论文扩大了模型规模并推出了 DeepSeek-V3,这是一个大型 MoE 模型,拥有 671B 参数,其中每个 Token 激活 37B 参数
  • 以前瞻性的视角,论文始终致力于实现强大的模型性能和经济高效的成本
    • 在架构方面,DeepSeek-V3 仍然采用 MLA (DeepSeek-AI, 2024c) 以实现高效推理,并采用 DeepSeekMoE (2024) 以实现经济高效的训练
    • 这两种架构已在 DeepSeek-V2 (DeepSeek-AI, 2024c) 中得到验证,证明了它们在保持强大模型性能的同时实现高效训练和推理的能力
    • 除了基本架构之外,论文还实施了两种额外的策略以进一步增强模型能力
      • 1)DeepSeek-V3 开创了一种无辅助损失的负载均衡策略 (2024a),旨在最小化因鼓励负载均衡而对模型性能产生的不利影响
      • 2)DeepSeek-V3 采用了 Multi-token 预测训练目标,论文观察到这可以提升在评估基准上的整体性能
  • 为了实现高效训练,论文支持 FP8 混合精度训练并对训练框架进行了全面的优化
    • 低精度训练已成为一种有前景的高效训练解决方案 (2022; 2019; 2017; 2023b),其发展与硬件能力的进步密切相关 (2024; 2022; 2023a)
    • 论文引入了一个 FP8 混合精度训练框架,并首次在超大规模模型上验证了其有效性
    • 通过对 FP8 计算和存储的支持,论文实现了加速训练和降低 GPU 内存使用
  • 至于训练框架,论文设计了 DualPipe 算法以实现高效的流水线并行(Pipeline Parallelism, PP),该算法具有更少的流水线气泡,并通过计算-通信重叠(Overlap)隐藏了训练过程中的大部分通信
    • 这种重叠确保了随着模型的进一步扩展,只要论文保持恒定的计算-通信比,论文仍然可以在节点间使用细粒度专家(fine-grained experts),同时实现接近零的全交换(all-to-all)通信开销
    • 论文还开发了高效的跨节点全交换通信内核,以充分利用 InfiniBand (IB) 和 NVLink 带宽
    • 论文精心优化了内存占用,使得在不使用昂贵的张量并行(Tensor Parallelism, TP)的情况下训练 DeepSeek-V3 成为可能
  • 结合这些努力,论文实现了高训练效率
  • 在预训练期间,论文在 14.8T 高质量和多样化的 Token 上训练 DeepSeek-V3
    • 预训练过程非常稳定,在整个训练过程中,论文没有遇到任何不可恢复的损失尖峰或不得不回滚的情况
  • 接下来,论文对 DeepSeek-V3 进行了两阶段的上下文长度扩展
    • 在第一阶段,最大上下文长度扩展到 32K
    • 在第二阶段,进一步扩展到 128K
  • 此后,论文进行了Post-training,包括对 DeepSeek-V3 基础模型进行 SFT 和 RL ,以使其与人类偏好对齐并进一步释放其潜力
    • 在后训练阶段,论文从 DeepSeek-R1 系列模型中蒸馏推理能力,同时仔细维护模型准确性和生成长度之间的平衡
  • 论文在全面的基准测试套件上评估了 DeepSeek-V3
    • 综合评估显示 DeepSeek-V3-Base 已成为当前最强的开源基础模型,尤其是在代码和数学领域
    • 其聊天版本也在一系列标准和开放式基准测试中优于其他开源模型,并实现了与领先闭源模型(包括 GPT-4o 和 Claude-3.5-Sonnet)相媲美的性能
  • 最后,论文再次强调 DeepSeek-V3 的经济训练成本,总结在表 1 中,这是通过论文对算法、框架和硬件的优化协同设计实现的
    • 在预训练阶段,在每万亿 Token 上训练 DeepSeek-V3 仅需 180K H800 GPU 小时,即在论文拥有 2048 个 H800 GPU 的集群上需要 3.7 天
    • 因此,论文的预训练阶段在不到两个月内完成,成本为 2664K GPU 小时
    • 加上上下文长度扩展的 119K GPU 小时和后训练的 5K GPU 小时,DeepSeek-V3 的完整训练仅花费 2.788M GPU 小时
    • 假设 H800 GPU 的租赁价格为每小时 2 美元,论文的总训练成本仅为 557.6 万美元
    • 请注意,上述成本仅包括 DeepSeek-V3 的官方训练,不包括先前在架构、算法或数据上的研究和消融实验相关的成本
  • 论文的主要贡献包括:
    • 架构:创新的负载均衡策略和训练目标 (Architecture: Innovative Load Balancing Strategy and Training Objective)
      • 在 DeepSeek-V2 的高效架构基础上,论文开创了一种无辅助损失的负载均衡策略 (auxiliary-loss-free strategy for load balancing),该策略最小化了因鼓励负载均衡而产生的性能下降
      • 论文研究了 Multi-token 预测 (Multi-token Prediction, MTP) 目标,并证明其有益于模型性能。它也可用于推测解码 (speculative decoding) 以加速推理
    • 预训练:追求极致的训练效率 (Pre-Training: Towards Ultimate Training Efficiency)
      • 论文设计了一个 FP8 混合精度训练框架 (FP8 mixed precision training framework),并首次在超大规模模型上验证了 FP8 训练的可行性和有效性
      • 通过算法、框架和硬件的协同设计 (co-design),论文克服了跨节点 MoE 训练中的通信瓶颈,实现了近乎完全的计算-通信重叠 (computation-communication overlap)
        • 这显著提高了论文的训练效率并降低了训练成本,使论文能够在没有额外开销的情况下进一步扩大模型规模
      • 仅以 2.664M H800 GPU 小时的经济成本,论文完成了 DeepSeek-V3 在 14.8T Token 上的预训练,产生了当前最强的开源基础模型
        • 预训练之后的后续训练阶段仅需 0.1M GPU 小时
    • 后训练:来自 DeepSeek-R1 的知识蒸馏 (Post-Training: Knowledge Distillation from DeepSeek-R1)
      • 论文引入了一种创新方法,将推理能力从长思维链 (long Chain-of-Thought, long CoT) 模型,特别是从 DeepSeek R1 系列模型之一,蒸馏到标准大语言模型中,尤其是 DeepSeek-V3
      • 论文的流程巧妙地将 R1 的验证和反思模式融入 DeepSeek-V3,并显著提高了其推理性能
      • 同时,论文也保持了对 DeepSeek-V3 输出风格和长度的控制
    • 核心评估结果总结 (Summary of Core Evaluation Results)
      • 知识 (Knowledge) :
        • (1) 在教育类基准测试如 MMLU、MMLU-Pro 和 GPQA 上,DeepSeek-V3 优于所有其他开源模型,在 MMLU 上达到 88.5,在 MMLU-Pro 上达到 75.9,在 GPQA 上达到 59.1
          • 其性能与领先的闭源模型如 GPT-4o 和 Claude-Sonnet-3.5 相当,缩小了开源模型与闭源模型在该领域的差距
        • (2) 在事实性基准测试方面,DeepSeek-V3 在 SimpleQA 和 Chinese SimpleQA 上均表现出优于其他开源模型的性能
          • 虽然它在英语事实知识 (SimpleQA) 上落后于 GPT-4o 和 Claude-Sonnet-3.5,但在中文事实知识 (Chinese SimpleQA) 上超越了这些模型,突显了其中文事实知识的优势
      • 代码、数学和推理 (Code, Math, and Reasoning) :
        • (1) DeepSeek-V3 在所有非长思维链的开源和闭源模型中,在数学相关基准测试上达到了 SOTA 性能。值得注意的是,它在特定基准测试上(如 MATH-500)甚至超过了 o1-preview,展示了其强大的数学推理能力
        • (2) 在代码相关任务上,DeepSeek-V3 成为代码竞赛基准测试(如 LiveCodeBench)中表现最佳的模型,巩固了其在该领域的领先地位
          • 对于工程相关任务,虽然 DeepSeek-V3 的表现略低于 Claude-Sonnet-3.5,但它仍然以显著优势领先所有其他模型,展示了其在各种技术基准测试上的竞争力
  • 剩余部分组织如下:
    • 详细阐述论文的 DeepSeek-V3 模型架构(第 2 节)
    • 介绍论文的基础设施,包括论文的计算集群、训练框架、对 FP8 训练的支持、推理部署策略以及论文对未来硬件设计的建议(第 3 节)
    • 描述论文的预训练过程,包括训练数据的构建、超参数设置、长上下文扩展技术、相关评估以及一些讨论(第 4 节)
    • 讨论论文在后训练方面的工作,包括 SFT 、 RL 、相应的评估和讨论(第 5 节)
    • 总结讨论 DeepSeek-V3 现有的局限性,并提出未来研究的潜在方向(第 6 节)

Architecture

  • 论文首先介绍 DeepSeek-V3 的基本架构,其特点是采用 MLA (DeepSeek-AI) 实现高效推理,以及采用 DeepSeekMoE (2024) 实现经济高效的训练
  • 然后,论文介绍一种 MTP 训练目标,论文观察到该目标能够提升模型在评估基准上的整体性能
  • 对于其他未明确提及的细节,DeepSeek-V3 遵循 DeepSeek-V2 (DeepSeek-AI) 的设置

Basic Architecture

  • DeepSeek-V3 的基本架构仍然在 Transformer (2017) 框架内
  • 为了高效推理和经济高效的训练,DeepSeek-V3 同样采用了 MLA 和 DeepSeekMoE,这两者已在 DeepSeek-V2 中得到充分验证
  • 与 DeepSeek-V2 相比,一个例外是:
    • 论文额外为 DeepSeekMoE 引入了一种 无辅助损失负载均衡 (auxiliary-loss-free load balancing) 策略 (Wang 等人),以减轻因努力确保负载均衡而导致的模型性能下降
    • 图 2 展示了 DeepSeek-V3 的基本架构,论文将在本节简要回顾 MLA 和 DeepSeekMoE 的细节
Multi-Head Latent Attention
  • 对于注意力机制,DeepSeek-V3 采用 MLA 架构
  • 令 \(d\) 表示嵌入维度,\(n_{h}\) 表示注意力头数,\(d_{h}\) 表示每个头的维度,\(\mathbf{h}_{t}\in\mathbb{R}^{d}\) 表示给定注意力层中第 \(t\) 个 Token 的注意力输入
  • MLA 的核心是对注意力 Key 和 Value 进行低秩联合压缩,以减少推理期间的 Key-Value 缓存 (Key-Value, KV cache):
    $$
    \begin{align}
    \boxed{\mathbf{c}_{t}^{KV} } &= \mathbf{W}^{DKY}\mathbf{h}_{t},\\
    \left[\mathbf{k}_{t,1}^{C};\mathbf{k}_{t,2}^{C};…;\mathbf{k}_{t,n_{h} }^{C}\right] = \mathbf{k}_{t}^{C} &=\mathbf{W}^{UK}\mathbf{c}_{t}^{KV},\\
    \boxed{\mathbf{k}_{t}^{R}} &= \text{RoPE}(\mathbf{W}^{KR}\mathbf{h}_{t}),\\
    \mathbf{k}_{t,i} &= \left[\mathbf{k}_{t,i}^{C};\mathbf{k}_{t}^{R}\right],\\
    [\mathbf{v}_{t,1}^{C};\mathbf{v}_{t,2}^{C};…;\mathbf{v}_{t,n_{h} }^{C}] = \mathbf{v}_{t}^{C} &=\mathbf{W}^{UV}\mathbf{c}_{t}^{KV}
    \end{align}
    $$
    • 其中 \(\mathbf{c}_{t}^{KV}\in\mathbb{R}^{d_{c} }\) 是 Key 和 Value 的压缩潜在向量;
    • \(d_{c}(\ll d_{h}n_{h})\) 表示 KV 压缩维度;
    • \(W^{DKY}\in\mathbb{R}^{d_{c}\times d}\) 表示下投影矩阵;
    • \(W^{UK},W^{UV}\in\mathbb{R}^{d_{h}n_{h}\times d_{c} }\) 分别是 Key 和 Value 的上投影矩阵;
    • \(W^{KR}\in\mathbb{R}^{d_{h}^{R}\times d}\) 是用于生成携带 旋转位置嵌入 (Rotary Positional Embedding, RoPE) (2024) 的解耦 Key 的矩阵;
    • RoPE(\(\cdot\)) 表示应用 RoPE 矩阵的操作;
    • [\(\cdot\);\(\cdot\)] 表示拼接
  • 对于 MLA,在生成过程中只需要缓存蓝色框内的向量(即 \(\mathbf{c}_{t}^{KV}\) 和 \(\mathbf{k}_{t}^{R}\)),这能在保持性能与标准 多头注意力 (Multi-Head Attention, MHA) (2017) 相当的同时,显著减少 KV 缓存
  • 对于注意力 Query ,论文也执行低秩压缩,这可以减少训练期间的激活内存:
    $$
    \begin{align}
    \mathbf{c}_{t}^{Q}&=W^{DQ}\mathbf{h}_{t},\\
    \left[\mathbf{q}_{t,1}^{C};\mathbf{q}_{t,2}^{C};…;\mathbf{q}_{t,n_{h} }^{C}\right]=\mathbf{q}_{t}^{C}&=W^{UQ}\mathbf{c}_{t}^{Q},\\
    \left[\mathbf{q}_{t,1}^{R};\mathbf{q}_{t,2}^{R};…;\mathbf{q}_{t,n_{h} }^{R}\right]=\mathbf{q}_{t}^{R}&=\text{RoPE}(W^{QR}\mathbf{c}_{t}^{Q}),\\
    \mathbf{q}_{t,i}&=\left[\mathbf{q}_{t,i}^{C};\mathbf{q}_{t,i}^{R}\right],
    \end{align}
    $$
    • 其中 \(\mathbf{c}_{t}^{Q}\in\mathbb{R}^{d_{c}^{\prime} }\) 是 Query 的压缩潜在向量;
    • \(d_{c}^{\prime}(\ll d_{h}n_{h})\) 表示 Query 压缩维度;
    • \(W^{DQ}\in\mathbb{R}^{d_{c}^{\prime}\times d}\),\(W^{UQ}\in\mathbb{R}^{d_{h}n_{h}\times d_{c}^{\prime} }\) 分别是 Query 的下投影和上投影矩阵;
    • \(W^{QR}\in\mathbb{R}^{d_{h}^{R}n_{h}\times d_{c}^{\prime} }\) 是用于生成携带 RoPE 的解耦 Query 的矩阵
  • 最终,注意力 Query (\(\mathbf{q}_{t,i}\))、 Key (\(\mathbf{k}_{j,i}\)) 和值 (\(\mathbf{v}_{j,i}^{C}\)) 被组合以产生最终的注意力输出 \(\mathbf{u}_{t}\):
    $$
    \begin{align}
    \mathbf{o}_{t,i}&=\sum_{j=1}^{t}\text{Softmax}_{j}(\frac{\mathbf{q}_{t,i}^{T}\mathbf{k}_{j,i} }{\sqrt{d_{h}+d_{h}^{R} } })\mathbf{v}_{j,i}^{C},\\
    \mathbf{u}_{t}&=W^{O}\left[\mathbf{o}_{t,1};\mathbf{o}_{t,2};…;\mathbf{o}_{t,n_{h} }\right],
    \end{align}
    $$
    • 其中 \(W^{O}\in\mathbb{R}^{d\times d_{h}n_{h} }\) 表示输出投影矩阵
DeepSeekMoE with Auxiliary-Loss-Free Load Balancing
Basic Architecture of DeepSeekMoE
  • 对于 前馈网络 (Feed-Forward Networks, FFNs),DeepSeek-V3 采用 DeepSeekMoE 架构 (2024)
  • 与传统的 MoE 架构(如 CShard (2021))相比,DeepSeekMoE 使用更细粒度的专家,并将一些专家隔离为共享专家
  • 令 \(\mathbf{u}_{t}\) 表示第 \(t\) 个 Token 的 FFN 输入,论文计算 FFN 输出 \(\mathbf{h}_{t}^{\prime}\) 如下:
    $$
    \begin{align}
    \mathbf{h}_{t}^{\prime} &=\mathbf{u}_{t}+\sum_{i=1}^{N_{s} }\text{FFN }_{i}^{(s)}\left(\mathbf{u}_{t}\right)+\sum_{i=1}^{N_{r} }g_{i,t}\text{ FFN}_{i}^{(r)}\left(\mathbf{u}_{t}\right),\\
    g_{i,t} &= \frac{g_{i,t}^{\prime} }{\sum_{j=1}^{N_{r} }g_{j,t}^{\prime} },\\
    g_{i,t}^{\prime} &= \begin{cases}s_{i,t},&s_{i,t}\in\text{Topk}(\{s_{j,t}|1\leqslant j\leqslant N_{r}\},K_{r}),\\ 0,&\text{otherwise},\end{cases} \\
    s_{i,t} &= \text{Sigmoid}\left(\mathbf{u}_{t}^{T}\mathbf{e}_{t}\right),
    \end{align}
    $$
    • FFN\({}_{i}^{(s)}(\cdot)\) 表示第 \(i\) 个共享专家;
      • \(N_{s}\) 表示共享专家数量
    • FFN\({}_{i}^{(r)}(\cdot)\) 表示第 \(i\) 个路由专家;
      • 和 \(N_{r}\) 表示路由专家数量
    • \(K_{r}\) 表示激活的路由专家数量;
    • \(g_{i,t}\) 是第 \(i\) 个专家的门控值;
    • \(s_{i,t}\) 是 Token 到专家的亲和度;
    • \(\mathbf{e}_{t}\) 是第 \(i\) 个路由专家的质心向量;
    • Topk\((\cdot,K)\) 表示包含为第 \(t\) 个 Token 和所有路由专家计算的亲和度分数中 \(K\) 个最高分的集合
  • 与 DeepSeek-V2 略有不同,DeepSeek-V3 使用 sigmoid 函数计算亲和度分数,并在所有选定的亲和度分数之间应用归一化以产生门控值
Auxiliary-Loss-Free Load Balancing
  • 对于 MoE 模型,不平衡的专家负载将导致 路由崩溃 (routing collapse) (2017),并在使用专家并行的情况下降低计算效率
    • 传统的解决方案通常依赖 辅助损失 (auxiliary loss) (2021;) 来避免负载不平衡,但过大的辅助损失会损害模型性能 (Wang 等人)
  • 为了在负载均衡和模型性能之间取得更好的权衡,论文开创了一种 无辅助损失负载均衡 (auxiliary-loss-free load balancing) 策略 (Wang 等人) 来确保负载均衡
    • 论文为每个专家引入一个偏置项 \(b_{i}\),并将其加到相应的亲和度分数 \(s_{i,t}\) 上以确定 top-K 路由:
      $$g^{\prime}_{i,t}=\begin{cases}s_{i,t},&s_{i,t}+b_{i}\in\text{Topk}((s_{j,t}+b_{j}|1\leqslant j\leqslant N_{r}),K_{r}),\\ 0,&\text{otherwise}.\end{cases}$$
    • 特别说明:偏置项仅用于路由
    • 将与 FFN 输出相乘的门控值仍然来自原始亲和度分数 \(s_{i,t}\)
  • 在训练期间,论文持续监控每个训练步骤的整个批次的专家负载
    • 在每个步骤结束时,如果某个专家对应的负载过重,论文将其偏置项减少 \(y\),如果其对应的专家负载不足,则将其偏置项增加 \(y\),其中 \(y\) 是一个称为 偏置更新速度 (bias update speed) 的超参数
    • 通过这种动态调整,DeepSeek-V3 在训练期间保持均衡的专家负载,并且比那些仅通过纯辅助损失来鼓励负载均衡的模型实现了更好的性能
Complementary Sequence-Wise Auxiliary Loss
  • 尽管 DeepSeek-V3 主要依赖无辅助损失策略进行负载均衡,但为了防止任何单个序列内出现极端不平衡,论文也采用了互补的序列级均衡损失:
    $$
    \begin{align}
    \mathcal{L}_{\text{Bal} }&=\alpha\sum_{i=1}^{N_{r} }f_{i}P_{i},\\
    f_{i}&=\frac{N_{r} }{K_{r}T}\sum_{t=1}^{T}\mathbb{I}\left(s_{i,t}\in\text{Topk}(\{s_{j,t}|1\leqslant j\leqslant N_{r}\},K_{r})\right),\\
    s^{\prime}_{i,t}&=\frac{s_{i,t} }{\sum_{j=1}^{N_{r} }s_{j,t} },\\
    P_{i}&=\frac{1}{T}\sum_{t=1}^{T}s^{\prime}_{i,t},
    \end{align}
    $$
    • 平衡因子 \(\alpha\) 是一个超参数,对于 DeepSeek-V3 将被赋予一个极小的值;
    • \(\mathbb{I}(\cdot)\) 表示指示函数;
    • \(T\) 表示序列中的 Token 数
    • 序列级均衡损失鼓励每个序列上的专家负载保持均衡
    • 理解:
      • 每次计算最小化 损失 \(\mathcal{L}_{\text{Bal} }\) 时,都可以看做是一个求解约束优化问题的过程
      • 约束优化问题为:
        $$
        \begin{align}
        \min_{P_i} \quad &\sum_{i=1}^{N_{r} }f_{i}P_{i} \\
        \text{s.t.} &\sum_i P_i = 1
        \end{align}
        $$
        • 求解约束优化问题时,系数 \(f_{i}\) 可以看做是固定值(是每个序列的统计值,不同序列该值不同)
        • 变量是 \(P_{i}\) ,需要满足一定约束
          $$ \sum_i P_i=1 $$
          • 该约束可以通过定义推导得到:
            $$ \sum_j P_i = \sum_i \frac{1}{T}\sum_{t=1}^{T}s^{\prime}_{i,t} = \sum_i \frac{1}{T}\sum_{t=1}^{T} \frac{s_{i,t} }{\sum_{j=1}^{N_{r} }s_{j,t} } = \frac{1}{T}\sum_{t=1}^{T} \sum_i \frac{s_{i,t} }{\sum_{j=1}^{N_{r} }s_{j,t} } = \frac{1}{T}\sum_{t=1}^{T} 1 = \frac{1}{T}T = 1$$
        • 直观上看,最小化 \(\sum_{j=1}^{D+1}f_iP_i\) 的解就是让概率 \(P_i\) 随着 \(f_i\) 变化, \(f_i\) 越小,则 \(P_i\) 应该越大
          • 这样才能才能满足最小化 \(\sum_{j=1}^{D+1}f_iP_i\)
        • 从梯度上看
          $$ \frac{\partial \mathcal{L}_{\text{ LB} }}{\partial P_i} = \alpha f_i $$
          • 对于 \(f_i\) 越大的组,其概率 \(P_i\) 下降的越多
          • 进一步理解:\(P_i\) 下降是通过调整模型 router 参数实现的,这会导致参数更新后下一轮中真实统计值 \(f_i\) 下降,最终会收敛到一个大家的真实分配统计值 \(f_i\) 都差不多相同的地方(这也就实现了所谓的均衡)
Node-Limited Routing
  • 与 DeepSeek-V2 使用的设备受限路由类似,DeepSeek-V3 也使用一种受限路由机制来限制训练期间的通信成本
    • 注:该方法在 DeepSeek-V2 中被称为 Device-Limited Routing
      • 对于 DeepSeek-V2,除了对被路由专家进行简单的 top-K 选择外,还额外确保每个 Token 的目标专家最多分布在 \(M\) 个设备上
      • 具体来说,对于每个 Token
        • 首先筛选出 \(M\) 个设备,这些设备中包含亲和度得分(affinity scores)最高的专家;
        • 随后,在这 \(M\) 个设备的专家中执行 top-K 选择
      • 在实际实验中作者发现,当 \(M \ge 3\) 时,这种设备受限路由能够实现良好的性能,其效果与无限制的 top-K 路由基本持平
  • 简而言之,论文确保每个 Token 最多被发送到 \(M\) 个节点 ,这些节点是根据分布在该节点上的专家的最高 \(\frac{K}{M}\) 个亲和度分数之和来选择的
    • 理解(为什么是选择依据是每个节点最高 \(\frac{K}{M}\) 个亲和度之和?):对于每个 Token,需要分发到共 \(K\) 个专家,在限定只能发送到 \(M\) 个节点的情况下,则每个节点平均需要负责这个 Token 的 \(\frac{K}{M}\) 个专家,所以根据每个节点的最高 \(\frac{K}{M}\) 个亲和度之和来选择最终的 \(M\) 个节点
  • 在此约束下,论文的 MoE 训练框架几乎可以实现完全的计算-通信重叠
    • 理解:因为限定了每个 Token 需要分发的节点数量,这种情况下可以大幅减少通信开销,避免一个 Token 被发送到太多的节点而造成通信开销太大的情况
No Token-Dropping
  • 由于有效的负载均衡策略,DeepSeek-V3 在整个训练过程中保持良好的负载均衡,DeepSeek-V3 在训练期间不丢弃任何 Token
  • 论文还实现了特定的部署策略以确保推理时的负载均衡,因此 DeepSeek-V3 在推理期间也不丢弃 Token

Multi-token Prediction

  • 受 Gloeckle 等人 (2024) 的启发,论文研究并为 DeepSeek-V3 设定了 MTP 目标,该目标将每个位置的预测范围扩展到多个未来 Token
    • 1)MTP 目标使训练信号更加密集,并可能提高数据效率
    • 2)MTP 可能使模型能够预规划其表示,以更好地预测未来 Token
  • 图 3 展示了论文的 MTP 实现
    • 与 Gloeckle 等人 (2024) 使用独立的输出头并行预测 \(D\) 个额外 Token 不同,论文顺序预测额外 Token ,并在每个预测深度保持完整的因果链
  • 注:上图是训练的流程,按照 Teacher Forcing 方式训练的,所以 MTP 头的输入是真实的 Token,在推理时,会使用主网络的输出作为第二个 Token 的输入,以此类推,详情见后文
MTP Modules
  • 具体来说,论文的 MTP 实现使用 \(D\) 个顺序模块来预测 \(D\) 个额外 Token
    • 注:从图3可以看出,并行头只在模型的最后一层(主模型的 Head 层前)上增加,而不是在 Transformer 网络前面的层增加
    • 注:加入的内容是单层的 Transformer 再加一个 Head
  • 第 \(k\) 个 MTP 模块包含一个共享嵌入层 \(\text{Emb}(\cdot)\)、一个共享输出头 \(\text{OutHead}(\cdot)\)、一个 Transformer 块 \(\text{TRM}_{k}(\cdot)\) 和一个投影矩阵 \(M_{k}\in\mathbb{R}^{d\times 2d}\)
  • 对于第 \(i\) 个输入 Token \(t_{i}\),在第 \(k\) 个预测深度(注:这里的预测深度是指往后预测的 Token 数,不是 Transformer 网络层深度)
    • 论文首先将第 \(i\) 个 Token 在第 \((k-1)\) 个深度的表示 \(\mathbf{h}_{i}^{k-1}\in\mathbb{R}^{d}\) 与第 \((i+k)\) 个 Token 的嵌入通过线性投影相结合:
      $$\mathbf{h}_{i}^{ \prime{k} }=M_{k}[\textrm{RMSNorm}(\mathbf{h}_{i}^{ {k}-1});\textrm{RMSNorm}(\textrm{Emb}(t_{i+k}))],$$
    • 其中 \([.;.]\) 表示拼接
    • 特别地,当 \(k=1\) 时,\(\mathbf{h}_{i}^{ {k}-1}\) 指的是主模型给出的表示
    • 特别注意:对于每个 MTP 模块,其嵌入层与主模型共享
    • 组合后的 \(\mathbf{h}_{i}^{ {k} }\) 作为第 \(k\) 个深度的 Transformer 块的输入,以产生当前深度的输出表示 \(\mathbf{h}_{i}^{ {k} }\):
      $$\mathbf{h}_{1:T-k}^{ {k} }=\textrm{TRM}_{k}(\mathbf{h}_{1:T-k}^{ \prime{k} }),$$
    • 其中 \(T\) 表示输入序列长度,\({}_{ {i:j} }\) 表示切片操作(包括左右边界)
    • 最后,以 \(\mathbf{h}_{i}^{ {k} }\) 作为输入,共享输出头将计算第 \(k\) 个额外预测 Token 的概率分布 \(p_{ {i+1+k} }^{ {k} }\in\mathbb{R}^{V}\),其中 \(V\) 是词汇表大小:
      $$p_{ {i+k+1} }^{ {k} }=\textrm{OutHead}(\mathbf{h}_{i}^{ {k} }).$$
  • 输出头 \(\textrm{OutHead}(\cdot)\) 将表示线性映射为 logits,随后应用 \(\textrm{Softmax}(\cdot)\) 函数来计算第 \(k\) 个额外 Token 的预测概率
    • 同样,对于每个 MTP 模块,其输出头与主模型共享
    • 论文保持预测因果链的原则与 EAGLE (Li 等人) 类似,但 EAGLE 主要目标是 推测解码 (speculative decoding) (2023; 2023),而论文利用 MTP 来改进训练
MTP Training Objective
  • 对于每个预测深度,论文计算一个交叉熵损失 \(\mathcal{L}_{\textrm{MTP} }^{ {k} }\):
    $$\mathcal{L}_{\textrm{MTP} }^{ {k} }=\textrm{CrossEntropy}(p_{2+ k:T+1}^{ {k} },t_{2+k:T+1})=\frac{1}{T}\sum_{i=2+k}^{T+1}\log p_ {i}^{ {k} }[t_{i}],$$
    • 其中 \(T\) 表示输入序列长度,\(t_{i}\) 表示第 \(i\) 个位置的真实 Token
    • \(p_{i}^{ {k} }[t_{i}]\) 表示由第 \(k\) 个 MTP 模块给出的 \(t_{i}\) 的相应预测概率
  • 最后,论文计算所有深度上 MTP 损失的平均值,并将其乘以权重因子 \(\lambda\) 以获得总体 MTP 损失 \(\mathcal{L}_{\textrm{MTP} }\),该损失作为 DeepSeek-V3 的额外训练目标:
    $$\mathcal{L}_{\textrm{MTP} }=\frac{\lambda}{D}\sum_{k=1}^{D}\mathcal{L}_{\textrm{ MTP} }^{ {k} }.$$
MTP in Inference
  • 论文的 MTP 策略主要旨在提高主模型的性能 ,因此在推理期间,我们可以直接丢弃 MTP 模块,主模型可以独立且正常地运行
  • 也可以将这些 MTP 模块重新用于推测解码,以进一步改善生成延迟
  • 理解:实际上,后续的其他工作来看,还是提高主模型本身的性能用的多
  • 额外补充:关于 MTP 推理的其他讲解
    • 参考链接:deepseek技术解读(2)-MTP(Multi-token Prediction)的前世今生 - 姜富春的文章 - 知乎
    • 基本思想:
      • 预测阶段(Predict) :通过 K 个头一次生成 K 个 Token 的预测
      • 验证阶段(Verify) :将 K 个 Token 组装成 K 个 <input,label> 对,并行地利用输入 Main Model 作为评估验证,如果输出 label 与 Main Model 一致,则接受该 token
      • 接受阶段(Accept) :最终接受满足 Main Model 的最大长度 tokens 作为输出
    • 但是,从上述 图3 来看,多头预估时,预测 \(t^{\prime}_3\) 的输入包括了 \(t_1\) 的 Transformer 最后一层表征和 \(t_2\) 的 Embedding(即 Teacher Forcing 的形式),所以推理时是无法直接使用的,流程同上述方法一样,但应该使用下图的形式:
  • 特别说明:虽然主头的预测和后面头的预测是串行的,但是实际上,后面的头进需要走一个 Transformer 层,速度是很快的,和动辄几十层的主网络比起来(比如 DeepSeek-V3 是 61 层),算是很快的了(几乎可以认为是并行了)

Infrastructures

Compute Clusters

  • DeepSeek-V3 在一个配备有 2048 个 NVIDIA H800 GPU 的集群上进行训练
  • H800 集群中的每个节点包含 8 个通过节点内的 NVLink 和 NVSwitch 连接的 GPU
  • 在不同节点之间,利用 InfiniBand (IB) 互连来促进通信

Training Framework

  • DeepSeek-V3 的训练由 HAI-LLM 框架支持,这是一个由论文的工程师从头开始精心打造的、高效且轻量级的训练框架
  • DeepSeek-V3 应用了 16 路流水线并行 (Pipeline Parallelism, PP) (2023b)、跨越 8 个节点的 64 路专家并行 (Expert Parallelism, EP) (2021) 和 ZeRO-1 数据并行 (Data Parallelism, DP) (2020)
  • 为了促进 DeepSeek-V3 的高效训练,论文实施了细致的工程优化
    • 首先,论文设计了 DualPipe 算法用于高效的流水线并行
      • 与现有的 PP 方法相比,DualPipe 具有更少的流水线气泡
      • 更重要的是,它在前向和后向过程中重叠了计算和通信阶段,从而解决了跨节点专家并行引入的沉重通信开销的挑战
    • 其次,论文开发了高效的跨节点 All-to-All 通信内核,以充分利用 IB 和 NVLink 的带宽,并节省专用于通信的流式多处理器 (Streaming Multiprocessors, SMs)
    • 最后,论文精心优化了训练期间的内存占用,从而使得论文能够在训练 DeepSeek-V3 时不使用昂贵的张量并行 (Tensor Parallelism, TP)
DualPipe and Computation-Communication Overlap
  • 对于 DeepSeek-V3,跨节点专家并行引入的通信开销导致了大约 1:1 的低效计算-通信比
  • 为了应对这一挑战,论文设计了一种名为 DualPipe 的创新流水线并行算法,该算法不仅通过有效重叠前向和后向计算-通信阶段来加速模型训练,而且还减少了流水线气泡
  • DualPipe 的关键思想是在一对独立的前向和后向块 (chunk) 内重叠计算和通信
    • 论文将每个块划分为四个组件:注意力 (attention)、 All-to-All 分发 (all-to-all dispatch)、MLP 和 All-to-All 合并 (all-to-all combine)
    • 特别地,对于一个后向块,注意力和 MLP 都进一步分为两部分:对输入的梯度 (backward for input) 和对权重的梯度 (backward for weights),类似于 ZeroBubble (2023a) 中的做法
    • 此外,论文还有一个 PP 通信组件
  • 如图 4 所示,对于一对前向和后向块,论文重新排列了这些组件,并手动调整了专用于通信与计算的 GPU SMs 的比例
    • 在这种重叠策略中,我们可以确保 All-to-All 通信和 PP 通信在执行过程中都能被完全隐藏
    • 颜色上看:
      • 绿色(B)表示 backward for input
      • 蓝色(W)表示 backward for weight
      • 橙色(F)表示 forward
    • 通信上看:
      • 绿色(B)表示 backward 的通信
      • 橙色(F)表示 forward 的通信
      • 紫色(PP)表示 PP 的通信
  • 鉴于高效的重叠策略,完整的 DualPipe 调度如图 5 所示
    • 它采用了双向流水线调度,同时从流水线的两端输入微批次 (micro-batches),并且大部分通信可以完全重叠
    • 这种重叠也确保了,随着模型的进一步扩展,只要论文保持恒定的计算-通信比,论文仍然可以在跨节点使用细粒度专家,同时实现接近零的 All-to-All 通信开销
  • 即使在通信负担较轻的更一般情况下,DualPipe 仍然展现出效率优势
    • 在表 2 中,论文总结了不同 PP 方法的流水线气泡和内存使用情况
    • 如表所示,与 ZB1P (2023b) 和 1F1B (2018) 相比,DualPipe 显著减少了流水线气泡,同时仅将峰值激活内存增加了 \(\frac{1}{PP}\) 倍
    • 尽管 DualPipe 需要保留两份模型参数 ,但这并不会显著增加内存消耗 ,因为论文在训练期间使用了较大的 EP 规模
    • 与 Chimera (2021) 相比,DualPipe 仅要求流水线阶段和微批次能被 2 整除,而不要求微批次能被流水线阶段整除
    • 此外,对于 DualPipe,气泡和激活内存都不会随着微批次数量的增加而增加
Efficient Implementation of Cross-Node All-to-All Communication
  • 为了确保 DualPipe 具有足够的计算性能,论文定制了高效的跨节点 All-to-All 通信内核(包括分发和合并),以节省专用于通信的 SMs 数量
    • 这些内核的实现与论文的 MoE 门控算法和集群的网络拓扑结构协同设计
  • 具体来说,在论文的集群中,跨节点 GPU 通过 IB 完全互连,节点内通信通过 NVLink 处理
    • NVLink 提供 160 GB/s 的带宽,大约是 IB (50 GB/s) 的 3.2 倍
    • 为了有效利用 IB 和 NVLink 的不同带宽,论文将每个 Token 限制为最多分发到 4 个节点,从而减少 IB 流量
    • 对于每个 Token ,当做出路由决策时,它将首先通过 IB 传输到其目标节点上具有相同节点内索引的 GPU
    • 一旦到达目标节点,论文将尽力确保它通过 NVLink 瞬时转发到承载其目标专家的特定 GPU,而不会被后续到达的 Token 阻塞
    • 通过这种方式,通过 IB 和 NVLink 的通信完全重叠,每个 Token 可以高效地平均在每个节点选择 3.2 个专家,而不会产生 NVLink 的额外开销
    • 这意味着,尽管 DeepSeek-V3 在实践中仅选择 8 个路由专家,但它可以将此数量最多扩展到 13 个专家(4 个节点 × 3.2 个专家/节点),同时保持相同的通信成本
    • 总体而言,在这种通信策略下,仅需 20 个 SMs 就足以充分利用 IB 和 NVLink 的带宽
  • 详细来说,论文采用了 warp 专业化技术 (2014),并将 20 个 SMs 划分为 10 个通信通道,在分发过程中下面的操作由各自的 warp 处理
    • (1) IB 发送
    • (2) IB 到 NVLink 转发
    • (3) NVLink 接收
  • 分配给每个通信任务的 warp 数量根据所有 SMs 上的实际工作负载动态调整,类似地,在合并过程中下面的步骤也有由动态调整的 warp 处理
    • (1) NVLink 发送
    • (2) NVLink 到 IB 转发和累加
    • (3) IB 接收和累加
  • 此外,分发和合并内核都与计算流重叠,因此论文也考虑了它们对其他 SM 计算内核的影响
    • 具体来说,论文采用了定制的 PTX (Parallel Thread Execution) 指令并自动调整通信块大小,这显著减少了 L2 缓存的使用以及对其他 SMs 的干扰
Extremely Memory Saving with Minimal Overhead
  • 为了减少训练期间的内存占用,论文采用了以下技术
  • RMSNorm 和 MLA 上投影的重计算 (Recomputation of RMSNorm and MLA Up-Projection)
    • 论文在反向传播期间重新计算所有 RMSNorm 操作和 MLA 上投影,从而无需持久存储它们的输出激活
    • 以微小的开销为代价,该策略显著减少了存储激活所需的内存
  • CPU 中的指数移动平均 (Exponential Moving Average in CPU)
    • 在训练期间,论文保留模型参数的指数移动平均 (EMA) 用于在学习率衰减后早期估计模型性能
      • EMA 参数存储在 CPU 内存中,并在每个训练步骤后异步更新
    • 这种方法使论文能够维护 EMA 参数,而不会产生额外的内存或时间开销
  • Multi-token 预测的共享嵌入和输出头 (Shared Embedding and Output Head for Multi-token Prediction)。通过 DualPipe 策略,论文将模型的最浅层(包括嵌入层)和最深层(包括输出头)部署在同一个 PP 排名 (rank) 上。这种安排使得 MTP 模块和主模型之间能够物理共享共享嵌入和输出头的参数和梯度。这种物理共享机制进一步提高了论文的内存效率

FP8 Training

  • 受低精度训练最新进展 (2022; 2023b) 的启发,论文提出了一个利用 FP8 数据格式的细粒度混合精度框架来训练 DeepSeek-V3
  • 虽然低精度训练前景广阔,但它通常受到激活、权重和梯度中异常值 (outliers) 存在的限制 (2024; 2024)
    • 尽管在推理量化方面取得了显著进展 (2022; 2023),但相对较少的研究证明了低精度技术在大规模语言模型预训练中的成功应用 (2024)
  • 为了应对这一挑战并有效扩展 FP8 格式的动态范围,论文引入了一种细粒度的量化策略:
    • 使用 \(1\times N_{c}\) 元素的 tile-wise 分组 (tile-wise grouping) 或 \(N_{c}\times N_{c}\) 元素的块级分组 (block-wise grouping)
      • 注:tile-wise 说明量化维度是按照块做的,不是整个张量统一量化(tile 粒度比张量粒度更细)
    • 相关的反量化开销在论文提高了精度的累加过程下得到了很大程度的缓解,这是实现精确 FP8 通用矩阵乘法 (General Matrix Multiplication, GEMM) 的关键方面
    • 为了进一步减少 MoE 训练中的内存和通信开销,论文以 FP8 格式缓存和分发激活,同时以 BF16 格式存储低精度的优化器状态
    • 论文在两个与 DeepSeek-V2-Lite 和 DeepSeek-V2 规模相似的模型上验证了所提出的 FP8 混合精度框架,训练了大约 1T 个 Token(更多细节见附录 B.1)
  • 与 BF16 基线相比,论文的 FP8 训练模型的相对损失误差始终低于 0.25%,这一水平完全在训练随机性的可接受范围内
Mixed Precision Framework
  • 基于低精度训练中广泛采用的技术 (2019; 2017),论文提出了一个用于 FP8 训练的混合精度框架
  • 在此框架中,大多数计算密集型操作以 FP8 精度进行,而一些关键操作则策略性地保持其原始数据格式,以平衡训练效率和数值稳定性
  • 整体框架如图 6 所示
    • 首先,为了加速模型训练,大多数核心计算内核,即 GEMM 操作,都以 FP8 精度实现
    • 这些 GEMM 操作接受 FP8 张量作为输入,并产生 BF16 或 FP32 的输出。如图 6 所示,与线性算子 (Linear operator) 相关的所有三个 GEMM,即 Fprop(前向传播)、Dgrad(激活反向传播)和 Wgrad(权重反向传播),都在 FP8 中执行
    • 该设计理论上比原始的 BF16 方法提高一倍的计算速度
    • 此外,FP8 的 Wgrad GEMM 允许激活以 FP8 格式存储,用于反向传播。这显著减少了内存消耗
  • 尽管 FP8 格式具有效率优势,但某些算子由于其对低精度计算的敏感性,仍然需要更高的精度
    • 此外,一些低成本算子也可以使用更高的精度,而对整体训练成本的开销可以忽略不计
    • 因此,经过仔细研究,论文为以下组件保留了原始精度(例如 BF16 或 FP32):嵌入模块、输出头、MoE 门控模块、归一化算子和注意力算子
    • 这些有针对性的高精度保留确保了 DeepSeek-V3 的训练动态稳定性
    • 为了进一步保证数值稳定性,论文以更高的精度存储主权重 (master weights)、权重梯度和优化器状态
    • 虽然这些高精度组件会产生一些内存开销,但它们的影响可以通过在论文的分布式训练系统中跨多个 DP 排名进行高效分片来最小化
Improved Precision from Quantization and Multiplication
  • 基于论文的混合精度 FP8 框架,论文引入了若干策略来增强低精度训练的准确性,重点关注量化方法和乘法过程
Fine-Grained Quantization
  • 在低精度训练框架中,由于 FP8 格式的有限动态范围(受限于其减少的指数位),溢出和下溢是常见的挑战
    • 作为标准实践,通过将输入张量的最大绝对值缩放到 FP8 格式的最大可表示值来将输入分布对齐到 FP8 格式的可表示范围 (2017)
  • 这种方法使得低精度训练对激活异常值高度敏感,这会严重降低量化精度
  • 为了解决这个问题,论文提出了一种细粒度量化方法,在更细粒度的级别上应用缩放
  • 如图 7 (a) 所示
    • (1) 对于激活,论文在 1x128 的 tile 基础上(即每个 Token 每 128 个通道)对元素进行分组和缩放;
    • (2) 对于权重,论文在 128x128 的块基础上(即每 128 个输入通道每 128 个输出通道)对元素进行分组和缩放
  • 这种方法通过根据更小的元素组调整缩放比例,确保量化过程能更好地适应异常值
  • 在附录 B.2 中,论文进一步讨论了当论文以与权重量化相同的方式对激活进行块级分组和缩放时出现的训练不稳定性
  • 论文方法中的一个关键修改是引入了沿 GEMM 操作内部维度 (inner dimension) 的每组缩放因子 (per-group scaling factors)
    • 此功能在标准的 FP8 GEMM 中并不直接支持
    • 但结合论文精确的 FP32 累加策略,它可以被高效实现
  • 值得注意的是,论文的细粒度量化策略与微缩放格式 (microscaling formats) 的思想高度一致 (2023a),而 NVIDIA 下一代 GPU(Blackwell 系列)的 Tensor Cores 已宣布支持具有更小量化粒度的微缩放格式 (2022b)
    • 作者希望论文的设计能为未来的工作提供参考,以跟上最新的 GPU 架构
Increasing Accumulation Precision
  • 低精度 GEMM 操作经常遭受下溢问题,其精度在很大程度上依赖于高精度累加,这通常以 FP32 精度执行 (2019; 2017)
    • 但论文观察到在 NVIDIA H800 GPU 上,FP8 GEMM 的累加精度仅限于保留大约 14 位,这显著低于 FP32 的累加精度
    • 当内部维度 K 很大时,这个问题会更加明显 (2023),这在大规模模型训练中增加批大小和模型宽度时是典型情况
    • 以两个随机矩阵的 GEMM 操作为例,其中 \(\textit{K}=4096\),在论文的初步测试中,Tensor Cores 中有限的累加精度导致最大相对误差接近 \(2%\)
  • 尽管存在这些问题,有限的累加精度仍然是一些 FP8 框架中的默认选项 (2022c),严重限制了训练精度
  • 为了解决这个问题,论文采用了提升到 CUDA Cores 以获得更高精度的策略 (2023)
    • 该过程如图 7 (b) 所示
    • 在 Tensor Cores 上执行 MMA(矩阵乘积累加)期间,中间结果使用有限的位宽进行累加
    • 一旦达到 \(N_c\) 个元素的间隔,这些部分结果将被复制到 CUDA Cores 上的 FP32 寄存器中,在那里执行全精度的 FP32 累加
    • 如前所述,论文的细粒度量化沿内部维度 K 应用每组缩放因子
    • 这些缩放因子可以作为反量化过程在 CUDA Cores 上高效地相乘,而只需最小的额外计算成本
  • 值得注意的是,这种修改降低了单个 warpgroup 的 WGMMA (Warpgroup-level Matrix Multiply-Accumulate) 指令发出率
    • 但在 H800 架构上,通常可以同时维持两个 WGMMA:当一个 warpgroup 执行提升操作时,另一个能够执行 MMA 操作
    • 这种设计使得两个操作能够重叠,保持了 Tensor Cores 的高利用率
    • 根据论文的实验,设置 \(N_c\) = 128 个元素,相当于 4 个 WGMMAs,是能够显著提高精度而不引入大量开销的最小累加间隔
Mantissa over Exponents(尾数优先于指数)
  • 与先前工作采用的混合 FP8 格式 (2022c; 2023b;) 相比,后者在 Fprop 中使用 E4M3(4 位指数和 3 位尾数),在 Dgrad 和 Wgrad 中使用 E5M2(5 位指数和 2 位尾数),论文在所有张量上采用 E4M3 格式以获得更高的精度
  • 论文将此方法的可行性归功于论文的细粒度量化策略,即 tile 和块级缩放
  • 通过在更小的元素组上操作,论文的方法有效地在这些分组元素之间共享指数位,从而减轻了有限动态范围的影响
Online Quantization
  • 张量级量化框架 (2022c; 2023b) 中采用了延迟量化 (delayed quantization),它维护先前迭代的最大绝对值历史记录以推断当前值
  • 为了确保准确的缩放因子并简化框架,论文为每个 1x128 激活 tile 或 128x128 权重块在线计算最大绝对值
  • 基于此,论文推导出缩放因子,然后将激活或权重在线量化为 FP8 格式
Low-Precision Storage and Communication
  • 结合论文的 FP8 训练框架,论文通过将缓存的激活和优化器状态压缩成更低精度的格式,进一步减少了内存消耗和通信开销
  • 低精度优化器状态 (Low-Precision Optimizer States)
    • 论文采用 BF16 数据格式而不是 FP32 来跟踪 AdamW (2017) 优化器中的一阶矩和二阶矩,而不会引起可观察到的性能下降
    • 然而,主权重(由优化器存储)和梯度(用于批大小累加)仍然保留在 FP32 中,以确保整个训练过程中的数值稳定性
  • 低精度激活 (Low-Precision Activation)
    • 如图 6 所示,Wgrad 操作以 FP8 精度执行
    • 为了减少内存消耗,一个自然的选择是以 FP8 格式缓存激活,用于线性算子的反向传播。然而,论文对几个算子进行了特殊考虑,以实现低成本的高精度训练:
      • (1) 注意力算子后的线性算子的输入
        • 这些激活也用于注意力算子的反向传播,这使得它对精度敏感
        • 论文为这些激活专门采用了一种定制的 E5M6 数据格式
        • 此外,这些激活在反向传播中将从 1x128 量化 tile 转换为 128x1 tile
        • 为了避免引入额外的量化误差,所有的缩放因子都是 2 的整数幂舍入缩放
      • (2) MoE 中 SwiGLU 算子的输入
        • 为了进一步降低内存成本,论文缓存 SwiGLU 算子的输入,并在反向传播中重新计算其输出
        • 这些激活也使用论文的细粒度量化方法以 FP8 格式存储,在内存效率和计算精度之间取得了平衡
  • 低精度通信 (Low-Precision Communication)
    • 通信带宽是 MoE 模型训练的关键瓶颈
    • 为了缓解这一挑战,论文在 MoE 上投影之前将激活量化为 FP8,然后应用分发组件,这与 MoE 上投影中的 FP8 Fprop 兼容
      • 与注意力算子后的线性算子的输入类似,此激活的缩放因子是 2 的整数幂
      • 类似的策略应用于 MoE 下投影之前的激活梯度
    • 对于前向和后向的合并组件,论文将它们保留在 BF16 中,以在训练管道的关键部分保持训练精度

Inference and Deployment

  • 论文将 DeepSeek-V3 部署在 H800 集群上,其中每个节点内的 GPU 使用 NVLink 互连,集群中的所有 GPU 通过 IB 完全互连
  • 为了同时确保在线服务的服务水平目标 (Service-Level Objective, SLO) 和高吞吐量,论文采用了以下将预填充 (prefilling) 和解码 (decoding) 阶段分离的部署策略
Prefilling
  • 预填充阶段的最小部署单元由 4 个节点(32 个 GPU)组成
  • 注意力部分采用 4 路张量并行 (Tensor Parallelism, TP4) 结合序列并行 (Sequence Parallelism, SP),以及 8 路数据并行 (Data Parallelism, DP8)
  • 其较小的 TP 规模(4)限制了 TP 通信的开销
  • 对于 MoE 部分,论文使用 32 路专家并行 (Expert Parallelism, EP32),这确保了每个专家处理足够大的批大小,从而提高了计算效率
  • 对于 MoE All-to-All 通信,论文使用与训练相同的方法:
    • 首先通过 IB 跨节点传输 Token ,然后通过 NVLink 在节点内 GPU 之间转发
    • 特别地,论文对浅层的稠密 MLP 使用 1 路张量并行以节省 TP 通信
  • 为了实现 MoE 部分中不同专家之间的负载平衡,论文需要确保每个 GPU 处理大致相同数量的 Token
  • 为此,论文引入了冗余专家 (redundant experts) 的部署策略,该策略复制高负载专家并冗余部署它们
    • 高负载专家是基于在线部署期间收集的统计信息检测出来的,并定期调整(例如,每 10 分钟)
    • 在确定了冗余专家集合后,论文根据观察到的负载,仔细地在节点内的 GPU 之间重新安排专家,力求在不增加跨节点 All-to-All 通信开销的情况下,尽可能平衡 GPU 间的负载
    • 对于 DeepSeek-V3 的部署,论文为预填充阶段设置了 32 个冗余专家
    • 对于每个 GPU,除了它原本承载的 8 个专家外,它还将承载一个额外的冗余专家
  • 在预填充阶段,为了提高吞吐量并隐藏 All-to-All 和 TP 通信的开销,论文同时处理两个计算工作量相似的微批次,将一个微批次的注意力和 MoE 与另一个微批次的分发和合并重叠起来
  • 最后,作者正在探索一种专家的动态冗余 (dynamic redundancy) 策略,其中每个 GPU 承载更多专家(例如,16 个专家),但在每次推理步骤中只激活 9 个
    • 在每层的 All-to-All 操作开始之前,论文实时计算全局最优的路由方案
    • 鉴于预填充阶段涉及大量计算,计算此路由方案的开销几乎可以忽略不计
Decoding
  • 在解码期间,论文将共享专家视为路由专家
    • 从这个角度来看,每个 Token 将在路由期间选择 9 个专家,其中共享专家被视为一个总是被选中的高负载专家
  • 解码阶段的最小部署单元由 40 个节点(320 个 GPU)组成
    • 注意力部分采用 TP4 结合 SP,以及 DP80,而 MoE 部分使用 EP320
    • 对于 MoE 部分,每个 GPU 仅承载一个专家,其中 64 个 GPU 负责承载冗余专家和共享专家
  • 分发和合并部分的 All-to-All 通信通过 IB 上的直接点对点传输进行,以实现低延迟
    • 论文还利用 IBSDA (2022a) 技术进一步最小化延迟并提高通信效率
  • 与预填充类似,论文基于在线服务的统计专家负载,以一定的间隔定期确定冗余专家集合
    • 但论文不需要重新安排专家,因为每个 GPU 只承载一个专家
  • 论文也在探索用于解码的动态冗余策略。然而,这需要更仔细地优化计算全局最优路由方案的算法以及与分发内核的融合以减少开销
  • 为了提高吞吐量并隐藏 All-to-All 通信的开销,论文也在探索在解码阶段同时处理两个计算工作量相似的微批次
    • 与预填充不同,注意力在解码阶段消耗的时间比例更大
    • 因此,论文将一个微批次的注意力与另一个微批次的(分发+MoE+合并)重叠起来
    • 在解码阶段,每个专家的批大小相对较小(通常在 256 个 Token 以内),瓶颈是内存访问而非计算
    • 由于 MoE 部分只需要加载一个专家的参数,内存访问开销很小,因此使用较少的 SMs 不会显著影响整体性能
    • 因此,为了避免影响注意力部分的计算速度,我们可以只分配一小部分 SMs 给(分发+MoE+合并)

Suggestions on Hardware Design

  • 基于论文 All-to-All 通信和 FP8 训练方案的实现,论文向 AI 硬件供应商提出以下芯片设计建议
Communication Hardware
  • 在 DeepSeek-V3 中,论文实现了计算和通信之间的重叠 ,以在计算期间隐藏通信延迟
    • 与串行计算和通信相比,这显著降低了对通信带宽的依赖
  • 但当前的通信实现依赖于昂贵的 SMs(例如,论文分配了 H800 GPU 中可用的 132 个 SMs 中的 20 个用于此目的),这将限制计算吞吐量
    • 而且使用 SMs 进行通信会导致显著的效率低下,因为 Tensor Cores 完全未被充分利用
  • 目前,SMs 主要为 All-to-All 通信执行以下任务:
    • 在 IB (InfiniBand) 和 NVLink 域之间转发数据 ,同时聚合来自单个 GPU、目的地为同一节点内多个 GPU 的 IB 流量
    • 在 RDMA 缓冲区(已注册的 GPU 内存区域)和输入/输出缓冲区之间传输数据
    • 执行归约操作以进行 All-to-All 合并
    • 在跨 IB 和 NVLink 域的分块数据传输到多个专家期间,管理细粒度的内存布局
  • 论文期望未来的供应商能够开发出将这些通信任务从宝贵的计算单元 SM 上卸载的硬件,作为 GPU 协处理器或类似 NVIDIA SHARP (2016) 的网络协处理器
    • 此外,为了降低应用程序编程的复杂性,作者希望这种硬件能从计算单元的角度统一 IB(横向扩展)和 NVLink(纵向扩展)网络
    • 通过这种统一的接口,计算单元可以基于简单的原语提交通信请求,轻松地在整个 IB-NVLink 统一域中完成诸如读、写、多播和归约等操作
Compute Hardware
  • Tensor Cores 中更高的 FP8 GEMM 累加精度 (Higher FP8 GEMM Accumulation Precision in Tensor Cores)
    • 在 NVIDIA Hopper 架构当前的 Tensor Core 实现中,FP8 GEMM 受到有限累加精度的困扰
    • 在基于最大指数通过右移对齐 32 个尾数乘积后,Tensor Core 仅使用每个尾数乘积的最高 14 位进行加法,并截断超出此范围的位(加法结果累加到寄存器中也采用 14 位精度)
    • 论文的实现通过将 128 次 FP8×FP8 乘法的加法结果累加到 CUDA core 中具有 FP32 精度的寄存器中,部分缓解了这一限制
    • 尽管这有助于实现成功的 FP8 训练,但这仅仅是由于 Hopper 架构在 FP8 GEMM 累加精度方面的硬件缺陷而做出的妥协
    • 未来的芯片需要采用更高的精度
  • 支持 tile 和块级量化 (Support for Tile- and Block-Wise Quantization)
    • 当前的 GPU 仅支持每张量 (per-tensor) 量化,缺乏对论文 tile 和块级量化等细粒度量化的原生支持
    • 在当前实现中,当达到 \(N_{C}\) 间隔时,部分结果将从 Tensor Cores 复制到 CUDA cores,乘以缩放因子,并添加到 CUDA cores 上的 FP32 寄存器中
    • 尽管结合论文精确的 FP32 累加策略,反量化开销得到了显著缓解,但 Tensor Cores 和 CUDA cores 之间频繁的数据移动仍然限制了计算效率
    • 论文建议未来的芯片通过使 Tensor Cores 能够接收缩放因子并实现具有组缩放的 MMA 来支持细粒度量化
    • 通过这种方式,整个部分和累加和反量化可以直接在 Tensor Cores 内部完成,直到产生最终结果,从而避免频繁的数据移动
  • 支持在线量化 (Support for Online Quantization)
    • 尽管论文的研究证明了在线量化的有效性,但当前的实现难以有效支持它
    • 在现有流程中,论文需要从 HBM(高带宽内存)中读取 128 个 BF16 激活值(先前计算的输出)进行量化,然后量化后的 FP8 值写回 HBM,只是为了再次读取用于 MMA
    • 为了解决这种低效率问题,论文建议未来的芯片将 FP8 转换和 TMA(Tensor Memory Accelerator)访问融合到单个融合操作中,这样量化可以在激活从全局内存传输到共享内存的过程中完成,避免频繁的内存读写
    • 论文还建议支持 warp 级的转换指令以加速,这进一步促进了层归一化和 FP8 转换的更好融合
      • 或者,可以采用近内存计算方法,将计算逻辑放置在 HBM 附近
    • 在这种情况下,BF16 元素在从 HBM 读入 GPU 时可以直接转换为 FP8,将片外内存访问减少大约 50%
  • 支持转置 GEMM 操作 (Support for Transposed GEMM Operations)
    • 当前的架构使得将矩阵转置与 GEMM 操作融合变得很麻烦
    • 在论文的工作流程中,前向传播期间的激活被量化为 1x128 的 FP8 tile 并存储
    • 在后向传播期间,矩阵需要被读出、反量化、转置、重新量化为 128x1 tile ,并存储在 HBM 中
    • 为了减少内存操作,论文建议未来的芯片能够在 MMA 操作之前直接从共享内存中对矩阵进行转置读取,以支持训练和推理中所需的那些精度
    • 结合 FP8 格式转换和 TMA 访问的融合,这一增强将显著简化量化工作流程

Pre-Training

Data Construction

  • 与 DeepSeek-V2 相比,论文通过提高数学和编程样本的比例来优化预训练语料库,同时将多语言覆盖范围扩展到英语和中文之外
  • 论文的数据处理流程经过改进,在保持语料库多样性的同时最大限度地减少了冗余
  • 受 (2024) 的启发,论文实施了文档打包方法(Document Packing Method)以保证数据完整性,但在训练期间并未引入跨样本注意力掩码
  • 最终,DeepSeek-V3 的训练语料库包含 14.8T 个高质量且多样化的 Token(使用论文的分词器)
  • 在 DeepSeekCoder-V2 (2024) 的训练过程中,论文观察到 Fill-in-Middle (FIM) 策略在使模型能够根据上下文线索准确预测中间文本的同时,并不会损害其下一个 Token 预测能力
    • 与 DeepSeekCoder-V2 保持一致,论文也在 DeepSeek-V3 的预训练中纳入了 FIM 策略
    • 具体来说,论文采用 Prefix-Suffix-Middle (PSM, Prefix, Suffix, Middle) 框架来结构化数据,如下所示:
      $$ \text{<|fim_begin|>} f_{\text{pre} } \text{<|fim_hole|>} f_{\text{suff} } \text{<|fim_end|>} f_{\text{middle} } \text{<|eos_token|>} $$
    • 该结构在文档级别应用,作为预打包过程的一部分
    • FIM 策略的应用率为 0.1,与 PSM 框架保持一致
    • 注:Fill-in-Middle(FIM,中间填充)是大语言模型的一种训练目标,核心是让模型根据上下文的 “前缀” 和 “后缀” 信息,预测并补全中间缺失的内容;这种设计旨在增强模型对文本全局逻辑的理解能力,尤其适用于需要双向参考上下文的场景
  • DeepSeek-V3 的分词器采用字节级 BPE (1999),并扩展了 128K Token 的词汇表
    • 论文对分词器的预分词器和训练数据进行了修改,以优化多语言压缩效率
    • 与 DeepSeek-V2 相比,新的预分词器引入了结合标点符号和换行符的 Token
      • 然而,当模型处理没有终止换行符的多行提示(特别是少样本评估提示)时,这种技巧可能会引入 Token 边界偏差 (2023)
      • 为了解决这个问题,论文在训练期间随机拆分一定比例的此类组合 Token ,使模型接触到更广泛的特例,从而减轻这种偏差

Hyper-Parameters

  • 模型超参数 (Model Hyper-Parameters)
    • 论文将 Transformer 层数设置为 61,隐藏维度设置为 7168
    • 所有可学习参数均使用标准差为 0.006 进行随机初始化
    • 在 MLA 中,论文将注意力头数 \(n_{h}\) 设置为 128,每个头的维度 \(d_{h}\) 设置为 128
    • KV 压缩维度 \(d_{c}\) 设置为 512, Query 压缩维度 \(d^{\prime}_{c}\) 设置为 1536
    • 对于解耦的 Query 和 Key ,论文将每个头的维度 \(d^{R}_{h}\) 设置为 64
    • 论文将除前三层之外的所有 FFN 替换为 MoE 层
    • 每个 MoE 层包含 1 个共享专家和 256 个路由专家,其中每个专家的中间隐藏维度为 2048
    • 在路由专家中,每个 Token 将激活 8 个专家,并且确保每个 Token 最多被发送到 4 个节点
    • 多 Token 预测深度 \(D\) 设置为 1 ,即除了精确的下一个 Token 外,每个 Token 还会预测一个额外的 Token
    • 与 DeepSeek-V2 一样,DeepSeek-V3 也在压缩潜在向量之后采用了额外的 RMSNorm 层,并在宽度瓶颈处乘以额外的缩放因子
    • 总结:在此配置下,DeepSeek-V3 总共包含 671B 参数,其中每个 Token 激活 37B 参数
  • 训练超参数 (Training Hyper-Parameters)
    • 采用 AdamW 优化器 (2017),超参数设置为 \(\beta_{1}=0.9\),\(\beta_{2}=0.95\),权重衰减 = 0.1
    • 在预训练期间,论文将最大序列长度设置为 4K,并在 14.8T Token 上对 DeepSeek-V3 进行预训练
    • 学习率调度方面
      • 首先,在前 2K 步期间将其从 \(0\) 线性增加到 \(2.2\times 10^{-4}\)
      • 然后,论文保持 \(2.2\times 10^{-4}\) 的恒定学习率,直到模型消耗完 10T 训练 Token
      • 随后,论文在 4.3T Token 内,按照余弦衰减曲线将学习率逐渐衰减到 \(2.2\times 10^{-5}\)
      • 在最后 500B Token 的训练期间,论文在前 333B Token 中保持 \(2.2\times 10^{-5}\) 的恒定学习率,并在剩余的 167B Token 中切换到另一个恒定学习率 \(7.3\times 10^{-6}\)
    • 梯度裁剪范数设置为 1.0
    • 论文采用批量大小调度策略(batch size scheduling strategy)
      • 在前 469B Token 的训练中,批量大小从 3072 逐渐增加到 15360,然后在剩余训练中保持 15360
    • 论文利用流水线并行将模型的不同层部署在不同的 GPU 上,对于每一层,路由专家将均匀部署在属于 8 个节点的 64 个 GPU 上
      • 问题:这里如何理解?
    • 至于节点限制路由 ,每个 Token 最多被发送到 4 个节点(即 \(M=4\))
    • 对于无辅助损失的负载平衡,论文在前 14.3T Token 中将偏置更新速度 \(\gamma\) 设置为 0.001,在剩余的 500B Token 中设置为 0.0
    • 对于平衡损失,论文将 \(\alpha\) 设置为 0.0001,仅用于避免任何单个序列内的极端不平衡
    • MTP 损失权重 \(\lambda\) 在前 10T Token 中设置为 0.3,在剩余的 4.8T Token 中设置为 0.1

Long Context Extension

  • 论文采用与 DeepSeek-V2 (2024) 类似的方法来使 DeepSeek-V3 具备长上下文能力
  • 在预训练阶段之后,论文应用 YaRN (2023) 进行上下文扩展,并执行两个额外的训练阶段,每个阶段包含 1000 步,以逐步将上下文窗口从 4K 扩展到 32K,然后再扩展到 128K
  • YaRN 配置与 DeepSeek-V2 中使用的配置一致,仅应用于解耦的共享 Key \(\mathbf{k}_{t}^{R}\)
  • 两个阶段的超参数保持相同
    • 尺度 \(s=40\),\(\alpha=1\),\(\beta=32\),缩放因子 \(\sqrt{t}=0.1\ln s+1\)
    • 在第一阶段,序列长度设置为 32K,批量大小为 1920
    • 在第二阶段,序列长度增加到 128K,批量大小减少到 480
  • 两个阶段的学习率均设置为 \(7.3\times 10^{-6}\),与预训练阶段的最终学习率相匹配
  • 通过这种两阶段的扩展训练,DeepSeek-V3 能够处理长达 128K 的输入,同时保持强大的性能
  • 图 8 显示,经过监督微调后,DeepSeek-V3 在 “Needle In A Haystack” (NIAH) 测试中取得了显著性能,在上下文窗口长度高达 128K 的范围内表现出一致的鲁棒性

Evaluations

Evaluation Benchmarks
  • DeepSeek-V3 的基础模型是在英语和中文占多数的多语言语料库上进行预训练的,因此论文在一系列主要以英语和中文为主的基准测试上评估其性能,同时也包括一个多语言基准测试
  • 论文的评估基于集成在论文 HAI-LLM 框架中的内部评估框架
  • 所考虑的基准测试分类并列出如下,其中带下划线的基准测试为中文,带双下划线的基准测试为多语言:
    • 多学科多项选择题 (Multi-subject multiple-choice) 数据集包括 MMLU (2020)、MMLU-Redux (2024)、MMLU-Pro (2024)、MMMLU (2024)、C-Eval (2023) 和 CMMLU (2023)
    • 语言理解和推理 (Language understanding and reasoning) 数据集包括 HellaSwag (2019)、PIQA (2020)、ARC (2018) 和 BigBench Hard (BBH) (2022)
    • 闭卷问答 (Closed-book question answering) 数据集包括 TriviaQA (2017) 和 NaturalQuestions (2019)
    • 阅读理解 (Reading comprehension) 数据集包括 RACE (2017)、DROP (2019)、C3 (2019) 和 CMRC (2019)
    • 指代消解 (Reference disambiguation) 数据集包括 CLUEWSC (2020) 和 WinoGrande (2019)
    • 语言建模 (Language modeling) 数据集包括 Pile (2020)
    • 中文理解与文化 (Chinese understanding and culture) 数据集包括 CCPM (2021)
    • 数学 (Math) 数据集包括 GSM8K (2021)、MATH (2021)、MGSM (2023) 和 CMath (2023)
    • 代码 (Code) 数据集包括 HumanEval (2021)、LiveCodeBench-Base (0801-1101) (2024)、MBPP (2021) 和 CRUXEval (2024)
    • 标准化考试 (Standardized exams) 包括 AGIEval (2023)。注意,AGIEval 包含英语和中文子集
  • 遵循论文之前的工作 (2024, 2024)
    • 对包括 HellaSwag、PIQA、WinoGrande、RACE-Middle、RACE-High、MMLU、MMLU-Redux、MMLU-Pro、MMMLU、ARC-Easy、ARC-Challenge、C-Eval、CMMLU、C3 和 CCPM 在内的数据集采用基于困惑度的评估;
    • 对 TriviaQA、NaturalQuestions、DROP、MATH、GSM8K、MGSM、HumanEval、MBPP、LiveCodeBench-Base、CRUXEval、BBH、AGIEval、CLUEWSC、CMRC 和 CMath 采用基于生成的评估
    • 此外,论文对 Pile-test 执行基于语言建模的评估,并使用 Bits-Per-Byte (BPB) 作为度量标准,以保证使用不同分词器的模型之间的公平比较
      • 理解:Pile-test 是大语言模型评估中常用的标准测试集,源于更庞大的通用文本数据集 The Pile
        • The Pile 由 EleutherAI 构建的大规模开源文本数据集,总规模约 800GB,涵盖 22 个不同来源的文本类型(如学术论文、网页文本、书籍、新闻、代码等),旨在为模型提供多样化、高质量的训练与评估数据,避免单一数据分布导致的 “过拟合评估”
        • Pile-test 是 The Pile 的测试子集,与训练集(Pile-train)严格划分,用于客观衡量模型在通用语言理解与生成任务上的泛化能力,由于其覆盖场景广,模型在 Pile-test 上的表现能更真实反映 “通用能力”,而非仅适配某类特定数据
      • 问题:这里 BPB 是什么含义?
      • 回答:Bits-Per-Byte(字节每比特,简称 BPB)是一种用于消除 “分词器差异” 影响、实现不同模型公平比较的度量标准,其核心是 “归一化模型的预测成本
        • 比如,计算模型的困惑度时,如果按照 Token 计算,可能是不公平的,不同模型的 Token 数量不一样,但是如果按照 BPB 计算,则能保证与 Tokenizer 无关
Evaluation Results
  • 在表 3 中,论文将 DeepSeek-V3 的基础模型与 SOTA 开源基础模型进行了比较,包括 DeepSeek-V2-Base (2024)(论文之前的发布)、Qwen2.5 72B Base (2024) 和 LLaMA-3.1 405B Base (2024)
  • 论文在内部评估框架下评估所有这些模型,并确保它们共享相同的评估设置
  • 请注意,由于过去几个月论文评估框架的变化,DeepSeek-V2-Base 的性能与论文之前报告的结果略有不同
  • 总体而言,DeepSeek-V3-Base 全面超越了 DeepSeek-V2-Base 和 Qwen2.5 72B Base,并在大多数基准测试中超过了 LLaMA-3.1 405B Base,基本上成为最强的开源模型
  • 从更详细的角度看,论文将 DeepSeek-V3-Base 与其他开源基础模型进行了单独比较
    • (1) 与 DeepSeek-V2-Base 相比,由于论文模型架构的改进、模型规模和训练 Token 的扩大以及数据质量的提高,DeepSeek-V3-Base 如预期般取得了显著更好的性能
    • (2) 与 SOTA 中文开源模型 Qwen2.5 72B Base 相比,DeepSeek-V3-Base 仅以一半的激活参数量,也在英语、多语言、代码和数学基准测试上展现出了显著优势,尤其是在中文基准测试上,除了中文多学科多项选择题任务 CMMLU 外,DeepSeek-V3-Base 也显示出比 Qwen2.5 72B 更好的性能
    • (3) 与最大的开源模型 LLaMA-3.1 405B Base(其激活参数量是 DeepSeek-V3-Base 的 11 倍)相比,DeepSeek-V3-Base 在多语言、代码和数学基准测试上也表现出更优的性能
    • 对于英语和中文语言基准测试,DeepSeek-V3-Base 显示出竞争性或更好的性能,尤其在 BBH、MMLU 系列、DROP、C-Eval、CMMLU 和 CCPM 上表现突出
  • 得益于论文高效的架构和全面的工程优化,DeepSeek-V3 实现了极高的训练效率
    • 在论文的训练框架和基础设施下,训练 DeepSeek-V3 每万亿 Token 仅需 180K H800 GPU 小时,这比训练 72B 或 405B 的稠密模型要便宜得多

Discussion

Ablation Studies for Multi-token Prediction
  • 在表 4 中,论文展示了 MTP 策略的消融结果
  • 具体来说,论文在两个不同规模的基线模型上验证了 MTP 策略
    • 在小规模上,论文在 1.33T Token 上训练了一个包含 15.7B 总参数的基线 MoE 模型
    • 在大规模上,论文在 540B Token 上训练了一个包含 228.7B 总参数的基线 MoE 模型
  • 在此基础上,保持训练数据和其他架构不变,论文为它们附加了一个深度为 1 的 MTP 模块,并训练了两个采用 MTP 策略的模型进行比较
  • 注意,在推理期间,论文直接丢弃 MTP 模块,因此比较模型的推理成本完全相同
  • 从表中我们可以观察到,MTP 策略在大多数评估基准上持续提升了模型性能
Ablation Studies for the Auxiliary-Loss-Free Balancing Strategy
  • 在表 5 中,论文展示了无辅助损失平衡策略的消融结果
  • 论文在两个不同规模的基线模型上验证了该策略
    • 在小规模上,论文在 1.33T Token 上训练了一个包含 15.7B 总参数的基线 MoE 模型
    • 在大规模上,论文在 578B Token 上训练了一个包含 228.7B 总参数的基线 MoE 模型
  • 这两个基线模型都纯粹使用辅助损失来鼓励负载平衡,并使用 sigmoid 门控函数和 top-K 亲和度归一化
    • 它们控制辅助损失强度的超参数分别与 DeepSeek-V2-Lite 和 DeepSeek-V2 相同
  • 在这两个基线模型的基础上,保持训练数据和其他架构不变,论文移除了所有辅助损失,并引入了无辅助损失平衡策略进行比较
  • 从表中我们可以观察到,与纯辅助损失方法相比,无辅助损失策略在大多数评估基准上持续实现了更好的模型性能
Batch-Wise Load Balance VS. Sequence-Wise Load Balance
  • 无辅助损失平衡与序列级辅助损失之间的关键区别在于它们的平衡范围:批次级与序列级
    • 与序列级辅助损失相比,批次级平衡施加了更灵活的约束,因为它不强制每个序列在域内平衡
    • 这种灵活性允许专家更好地专精于不同领域
  • 为了验证这一点,论文记录并分析了基于辅助损失的 16B 基线模型和无辅助损失的 16B 模型在 Pile 测试集上不同领域的专家负载
  • 如图 9 所示,论文观察到无辅助损失模型如预期那样表现出更大的专家专精模式(注:也就是说有辅助损失的模型得到的专家负载均衡结果更加平均)
  • 为了进一步研究这种灵活性与模型性能优势之间的相关性,论文额外设计并验证了一种批次级辅助损失,该损失鼓励每个训练批次而不是每个序列上的负载平衡
  • 实验结果表明,当达到相似水平的批次级负载平衡时,批次级辅助损失也可以实现与无辅助损失方法相似的模型性能
  • 具体来说,在论文使用 1B MoE 模型的实验中,验证损失分别为:2.258(使用序列级辅助损失)、2.253(使用无辅助损失方法)和 2.253(使用批次级辅助损失)
  • 论文在 3B MoE 模型上也观察到了类似的结果:使用序列级辅助损失的模型验证损失为 2.085,而使用无辅助损失方法或批次级辅助损失的模型验证损失均为 2.080
  • 此外,尽管批次级负载平衡方法显示出持续的性能优势,但它们在效率方面也面临两个潜在的挑战:
    • (1) 某些序列内或小批次内的负载不平衡
    • (2) 推理期间由领域偏移引起的负载不平衡
  • 第一个挑战通过论文使用大规模专家并行和数据并行的训练框架自然得到解决,这保证了每个微批次的规模足够大
  • 对于第二个挑战,论文也设计并实现了一个具有冗余专家部署的高效推理框架,如第 3.4 节所述,以克服它

Post-Training

Supervised Fine-Tuning

  • 论文精心策划了论文的指令微调数据集,包含了涵盖多个领域的 150 万条实例,每个领域都采用了针对其特定需求而定制的不同数据创建方法
  • 推理数据 (Reasoning Data)
    • 对于推理相关的数据集,包括那些专注于数学、代码竞赛问题和逻辑谜题的,论文通过利用内部的 DeepSeek-R1 模型来生成数据
      • 问题:DeepSeek-R1 的一些流程依赖着 DeepSeek-V3 添加一些 CoT 数据吧,两者之间目前从论文看起来是互相依赖的关系
    • 具体来说,虽然 R1 生成的数据表现出很高的准确性,但它存在一些问题,如过度思考、格式不佳和长度过长
    • 论文的目标是在 R1 生成的高精度推理数据和常规格式的清晰简洁的推理数据之间取得平衡
    • 为了建立论文的方法,论文首先为特定领域(例如代码、数学或通用推理)开发一个专家模型,该模型使用结合了 SFT 和 RL 的训练流程
      • 这个专家模型作为最终模型的数据生成器
    • 训练过程涉及为每个实例生成两种不同类型的 SFT 样本:
      • 第一种将问题与其原始回答配对,格式为 <问题, 原始回答>;
      • 第二种则结合了系统提示、问题以及 R1 的回答,格式为 <系统提示, 问题, R1 回答>
    • 系统提示经过精心设计,包含指导模型产生富含反思和验证机制的回答的指令
      • 在 RL 阶段,模型利用高温采样来生成融合了 R1 生成数据和原始数据模式的回答,即使在没有显式系统提示的情况下也是如此。经过数百个 RL 步骤后,中间的 RL 模型学会了融入 R1 的模式,从而战略性地提升了整体性能
    • 在完成 RL 训练阶段后,论文实施拒绝采样来为最终模型筛选高质量的 SFT 数据,其中专家模型被用作数据源
      • 这种方法确保了最终训练数据保留了 DeepSeek-R1 的优势,同时产生的回答简洁有效
  • 非推理数据 (Non-Reasoning Data)
    • 对于非推理数据,例如创意写作、角色扮演和简单问答,论文使用 DeepSeek-V2.5 来生成回答,并聘请人工标注员来验证数据的准确性和正确性
  • SFT 设置 (SFT Settings)
    • 使用 SFT 数据集对 DeepSeek-V3-Base 进行了两个 epoch 的微调
    • 使用了余弦衰减学习率调度
    • 学习率从 \(5 \times 10^{-6}\) 开始,逐渐降低到 \(1 \times 10^{-6}\)
    • 在训练期间,每个单独的序列由多个样本打包而成
      • 论文采用了样本掩码策略来确保这些样本保持隔离且相互不可见

Reinforcement Learning

Reward Model
  • 论文在 RL 过程中使用了基于规则的奖励模型 (Reward Model, RM) 和基于模型的 RM
  • 基于规则的 RM (Rule-Based RM)
    • 对于可以使用特定规则验证的问题,论文采用基于规则的奖励系统来确定反馈
      • 例如,某些数学问题具有确定性的结果,论文要求模型以指定的格式(例如,在一个框内)提供最终答案,从而允许论文应用规则来验证正确性
      • 类似地,对于 LeetCode 问题,我们可以利用编译器根据测试用例生成反馈
    • 通过在可能的情况下利用基于规则的验证,论文确保了更高水平的可靠性,因为这种方法能够抵抗操纵或利用(manipulation or exploitation)
  • 基于模型的 RM (Model-Based RM)
    • 对于具有自由形式标准答案的问题,论文依赖奖励模型来确定回答是否符合预期的标准答案
    • 相反,对于没有明确标准答案的问题,例如涉及创意写作的问题,奖励模型的任务是基于问题和相应的回答作为输入来提供反馈
    • 该奖励模型是从 DeepSeek-V3 SFT 检查点训练而来的
    • 为了增强其可靠性,论文构建了不仅提供最终奖励,还包含导致该奖励的思维链 (Chain-of-Thought) 的偏好数据
    • 这种方法有助于减轻特定任务中 Reward Hacking 的风险
GRPO (Group Relative Policy Optimization)
  • 与 DeepSeek-V2 (2024c) 类似,论文采用 GRPO(2024),它摒弃了通常与策略模型大小相同的评论家模型 (Critic Model),而是从组分数中估计基线
  • 具体来说,对于每个问题 \(q\),GRPO 从旧策略模型 \(\pi_{\theta_{old} }\) 中采样一组输出 \(\{o_{1}, o_{2}, \cdots, o_{G}\}\),然后通过最大化以下目标来优化策略模型 \(\pi_{\theta}\):
    $$
    \mathcal{J}_{GRPO}(\theta) = \mathbb{E} \left[ q \sim P(Q), \{o_{i}\}_{i=1}^{G} \sim \pi_{\theta_{old} }(O|q) \right] \frac{1}{G} \sum_{i=1}^{G} \left( \min \left( \frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old} }(o_{i}|q)} A_{i}, \text{clip} \left( \frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old} }(o_{i}|q)}, 1-\epsilon, 1+\epsilon \right) A_{i} \right) - \beta \mathbb{D}_{KL} \left( \pi_{\theta} || \pi_{ref} \right) \right), \\
    \mathbb{D}_{KL} \left( \pi_{\theta} || \pi_{ref} \right) = \frac{\pi_{ref}(o_{i}|q)}{\pi_{\theta}(o_{i}|q)} - \log \frac{\pi_{ref}(o_{i}|q)}{\pi_{\theta}(o_{i}|q)} - 1,
    $$
    • 其中 \(\epsilon\) 和 \(\beta\) 是超参数;\(\pi_{ref}\) 是参考模型;\(A_{i}\) 是优势 (Advantage),由每个组内输出对应的奖励 \(\{r_{1}, r_{2}, \ldots, r_{G}\}\) 推导得出:
      $$
      A_{i} = \frac{r_{i} - \text{mean}(\{r_{1}, r_{2}, \cdots, r_{G}\})}{\text{std}(\{r_{1}, r_{2}, \cdots, r_{G}\})}.
      $$
  • 论文在 RL 过程中融入了来自不同领域的提示,例如编码、数学、写作、角色扮演和问答
    • 这种方法不仅使模型更符合人类偏好,还提高了在基准测试上的性能,特别是在可用 SFT 数据有限的情况下

Evaluations

Evaluation Settings
  • 评估基准 (Evaluation Benchmarks)
    • 除了用于基础模型测试的基准之外,论文还在 IFEval (2023)、FRAMES (2024)、LongBench v2 (2024)、GPQA (2023)、SimpleQA (OpenAI, 2024c)、C-SimpleQA (2024)、SWE-Bench Verified (OpenAI, 2024d)、Aider、LiveCodeBench (2024)(2024年8月至11月的问题)、Codeforces、中国高中数学奥林匹克 (Chinese National High School Mathematics Olympiad, CNMO 2024) 和美国数学邀请赛 2024 (American Invitational Mathematics Examination 2024, AIME 2024) (MAA, 2024) 上进一步评估了指令模型
  • 对比基线 (Compared Baselines)
    • 论文对论文的聊天模型与几个强大的基线进行了全面评估,包括 DeepSeek-V2-0506、DeepSeek-V2.5-0905、Qwen2.5 72B Instruct、LLaMA-3.1 405B Instruct、Claude-Sonnet-3.5-1022 和 GPT-4o-0513。对于 DeepSeek-V2 模型系列,论文选择了最具代表性的变体进行比较
    • 对于闭源模型,通过它们各自的 API 进行评估
  • 详细评估配置 (Detailed Evaluation Configurations)
    • 对于包括 MMLU、DROP、GPQA 和 SimpleQA 在内的标准基准,论文采用了来自 simple-evals 框架的评估提示
      • 注:simple-evals 是一个轻量级的大语言模型评估框架,主要用于快速验证模型在特定任务上的基础能力,其Evaluation Prompts设计遵循简洁、直接的任务导向原则,通常针对具体任务类型构建标准化提示模板
    • 对于 MMLU-Redux,论文在零样本设置下使用了 Zero-Eval 提示格式 (Lin, 2024)
    • 对于其他数据集,论文遵循其原始的评估协议,使用数据集创建者提供的默认提示
    • 对于代码和数学基准,HumanEval-Mul 数据集总共包含了 8 种主流编程语言(Python、Java、C++、C#、JavaScript、TypeScript、PHP 和 Bash)
      • 使用 CoT 和非 CoT 方法来评估模型在 LiveCodeBench 上的性能,其数据收集自 2024 年 8 月至 11 月
      • Codeforces 数据集使用参赛者百分比来衡量
      • SWE-Bench verified 使用无代理框架 (Agentless Framework) (2024) 进行评估
      • 论文使用 “diff” 格式来评估与 Aider 相关的基准
    • 对于数学评估,AIME 和 CNMO 2024 在温度为 0.7 的情况下进行评估,结果取 16 次运行的平均值,而 MATH-500 则使用贪婪解码
    • 论文允许所有模型为每个基准输出最多 8192 个 Token
Standard Evaluation
  • 表 6 展示了评估结果,表明 DeepSeek-V3 是性能最好的开源模型
  • 此外,它与前沿的闭源模型(如 GPT-4o 和 Claude-3.5-Sonnet)相比也具有竞争力
  • 英文基准 (English Benchmarks)
    • MMLU 是一个广泛认可的基准,旨在评估大语言模型在不同知识领域和任务上的表现
      • DeepSeek-V3 展示了具有竞争力的性能,与顶级模型如 LLaMA-3.1-405B、GPT-4o 和 Claude-Sonnet 3.5 不相上下,同时显著优于 Qwen2.5 72B
      • DeepSeek-V3 在 MMLU-Pro(一个更具挑战性的教育知识基准)上表现出色,紧随 Claude-Sonnet 3.5 之后
      • 在 MMLU-Redux(一个带有修正标签的 MMLU 改进版本)上,DeepSeek-V3 超越了其同行
    • 在 GPQA-Diamond(一个博士级评估测试平台)上,DeepSeek-V3 取得了显著成果,仅次于 Claude 3.5 Sonnet,并以显著优势优于所有其他竞争对手
    • 在长上下文理解基准测试中,如 DROP、LongBench v2 和 FRAMES,DeepSeek-V3 继续展示其作为顶级模型的地位
      • DeepSeek-V3 在 DROP 的 3-shot 设置中取得了令人印象深刻的 91.6 F1 分数,超过了该类别中的所有其他模型
      • 在 FRAMES(一个需要在 100k Token 上下文中进行问答的基准测试)上,DeepSeek-V3 紧随 GPT-4o,同时以显著优势优于所有其他模型
        • 这证明了 DeepSeek-V3 在处理极长上下文任务方面的强大能力
      • DeepSeek-V3 的长上下文能力进一步通过其在 LongBench v2(一个在 DeepSeek V3 发布前几周才发布的数据集)上的最佳表现得到了验证
      • 在事实性知识基准 SimpleQA 上,DeepSeek-V3 落后于 GPT-4o 和 Claude-Sonnet,这主要归因于其设计重点和资源分配
      • DeepSeek-V3 分配了更多的训练 Token 来学习中文知识,导致在 C-SimpleQA 上表现卓越
      • 在指令遵循基准测试中,DeepSeek-V3 显著优于其前身 DeepSeek-V2 系列,突显了其理解和遵守用户定义格式约束能力的提升
  • 代码和数学基准 (Code and Math Benchmarks)
    • 编码对于大语言模型来说是一项具有挑战性且实用的任务,涵盖了以工程为重点的任务(如 SWE-Bench-Verified 和 Aider)以及算法任务(如 HumanEval 和 LiveCodeBench)
    • 在工程任务中,DeepSeek-V3 落后于 Claude-Sonnet-3.5-1022,但显著优于开源模型
      • 开源的 DeepSeek-V3 有望推动编码相关工程任务的进步
    • 通过提供对其强大功能的访问,DeepSeek-V3 可以推动软件工程和算法开发等领域的创新和改进,使开发人员和研究人员能够突破开源模型在编码任务中能力的界限
    • 在算法任务中,DeepSeek-V3 表现出卓越的性能,在 HumanEval-Mul 和 LiveCodeBench 等基准测试中优于所有基线
      • 这一成功归功于其先进的知识蒸馏技术,该技术有效地增强了其在算法任务中的代码生成和问题解决能力
    • 在数学基准测试中,DeepSeek-V3 展示了卓越的性能,显著超过了基线,并为非 o1 类模型设定了新的最先进水平
      • 具体来说,在 AIME、MATH-500 和 CNMO 2024 上,DeepSeek-V3 在绝对分数上比第二好的模型 Qwen2.5 72B 高出约 10%,这对于如此具有挑战性的基准测试来说是一个显著的差距
      • 这种卓越的能力突显了从 DeepSeek-R1 进行蒸馏技术的有效性,该技术已被证明对非 o1 类模型非常有益
  • 中文基准 (Chinese Benchmarks)
    • Qwen 和 DeepSeek 是两个对中文和英文都有强大支持的代表性模型系列
    • 在事实性基准 Chinese SimpleQA 上,DeepSeek-V3 超过了 Qwen2.5-72B 16.4 分,尽管 Qwen2.5 是在包含 18T Token (比 DeepSeek-V3 预训练的 14.8T Token 多 20%)的更大语料库上训练的
    • 在 C-Eval(一个代表性的中文教育知识评估基准)和 CLUEWSC(中文 Winograd 模式挑战赛)上,DeepSeek-V3 和 Qwen2.5-72B 表现出相似的水平,表明这两个模型都为具有挑战性的中文推理和教育任务进行了良好的优化
Open-Ended Evaluation
  • 除了标准基准测试,论文还使用大语言模型作为评判者对论文的模型在开放式生成任务上进行了评估,结果如表 7 所示
  • 论文遵循 AlpacaEval 2.0 (2024) 和 Arena-Hard (2024a) 的原始配置,它们利用 GPT-4-Turbo-1106 作为成对比较的评判者
  • 在 Arena-Hard 上,DeepSeek-V3 相对于基线 GPT-4-0314 取得了超过 86% 的惊人胜率,与 Claude-Sonnet-3.5-1022 等顶级模型表现相当
    • 这凸显了 DeepSeek-V3 的强大能力,尤其是在处理复杂提示(包括编码和调试任务)方面
  • DeepSeek-V3 实现了一个突破性的里程碑,成为第一个在 Arena-Hard 基准测试中超过 85% 的开源模型
    • 这一成就显著缩小了开源模型和闭源模型之间的性能差距,为开源模型在挑战性领域所能达到的水平设定了新标准
  • 类似地,DeepSeek-V3 在 AlpacaEval 2.0 上展示了卓越的性能,优于闭源和开源模型
    • 这证明了其在写作任务和处理直接问答场景方面的出色熟练度
  • 值得注意的是,它以 20% 的显著优势超过了 DeepSeek-V2.5-0905,突显了其在处理简单任务方面的实质性改进,并展示了其进步的有效性
DeepSeek-V3 as a Generative Reward Model
  • 论文将 DeepSeek-V3 的判断能力与 SOTA 模型(即 GPT-4o 和 Claude-3.5)进行了比较
  • 表 8 展示了这些模型在 RewardBench (2024) 中的性能(注:RewardBench 是用来评估 Reward Model 本身性能的)
  • DeepSeek-V3 达到了与最佳版本的 GPT-4o-0806 和 Claude-3.5-Sonnet-1022 相当的水平,同时超越了其他版本
  • 此外,DeepSeek-V3 的判断能力也可以通过投票技术得到增强
  • 因此,论文采用 DeepSeek-V3 并结合投票来为开放式问题提供自我反馈,从而提高对齐过程的有效性和鲁棒性

Discussion

Distillation from DeepSeek-R1
  • 论文基于 DeepSeek-V2.5 对从 DeepSeek-R1 进行蒸馏的贡献进行了消融研究
  • 基线是在短 CoT 数据上训练的,而其对比模型则使用上述专家检查点生成的数据
  • 表 9 展示了蒸馏数据的有效性,在 LiveCodeBench 和 MATH-500 基准测试上均显示出显著的改进
  • 论文的实验揭示了一个有趣的权衡:蒸馏带来了更好的性能,但也显著增加了平均回答长度
  • 为了在模型准确性和计算效率之间保持平衡,论文为 DeepSeek-V3 在蒸馏过程中仔细选择了最优设置
  • 论文的研究表明,从推理模型进行知识蒸馏为后训练优化提供了一个有前景的方向
    • 虽然论文目前的工作侧重于蒸馏数学和编码领域的数据,但这种方法显示出在需要复杂推理的各种任务领域中具有更广泛应用的潜力
    • 在这些特定领域展示的有效性表明,长 CoT 蒸馏对于增强其他认知任务中的模型性能可能很有价值
    • 在不同领域进一步探索这种方法仍然是未来研究的一个重要方向
Self-Rewarding
  • 奖励在强化学习中起着关键作用,引导着优化过程
  • 在通过外部工具易于验证的领域,例如某些编码或数学场景,强化学习表现出卓越的功效
    • 但在更一般的场景中,通过硬编码构建反馈机制是不切实际的
  • 在 DeepSeek-V3 的开发过程中,对于这些更广泛的上下文,论文采用了 Constitutional AI 方法 (2022),利用 DeepSeek-V3 自身的投票评估结果作为反馈源
    • 这种方法产生了显著的对齐效果,显著提升了 DeepSeek-V3 在主观评估中的性能
    • 通过整合额外的 Constitutional 输入,DeepSeek-V3 可以向 Constitutional 方向优化
  • 作者相信,这种将补充信息与大语言模型作为反馈源相结合的范式至关重要
    • 大语言模型作为一个通用的处理器,能够将来自不同场景的非结构化信息转化为奖励,最终促进大语言模型的自我改进
    • 除了自我奖励,论文还致力于发现其他通用和可扩展的奖励方法,以持续提升模型在一般场景中的能力
Multi-token Prediction Evaluation
  • DeepSeek-V3 不仅仅预测下一个单独的 Token ,而是通过 MTP 技术预测接下来的 2 个 Token
  • 结合推测解码 (Speculative Decoding) (2023; 2023) 的框架,它可以显著加速模型的解码速度
  • 一个自然的问题是,额外预测的 Token 的接受率如何
    • 根据论文的评估,在不同生成主题下,第二个 Token 预测的接受率在 85% 到 90% 之间,表现出一致的可靠性
    • 这种高接受率使得 DeepSeek-V3 能够实现显著提高的解码速度,提供 1.8 倍的 TPS(每秒 Token 数)

Limitations, and Future Directions

Limitations

  • 在承认其强大性能和成本效益的同时,论文也认识到 DeepSeek-V3 存在一些局限性,尤其是在部署方面
  • 首先,为了确保高效的推理,DeepSeek-V3 推荐的部署单元相对较大,这可能会给小型团队带来负担
  • 其次,尽管论文为 DeepSeek-V3 设计的部署策略已经实现了端到端生成速度超过 DeepSeek-V2 的两倍,但仍有进一步提升的潜力
  • 幸运的是,随着更先进硬件的发展,这些局限性有望得到自然解决

Future Directions

  • DeepSeek 始终坚持具有长期主义(longtermism)的开源模型路线,旨在稳步接近通用人工智能(Artificial General Intelligence, AGI)的终极目标
  • 未来,论文计划在以下几个方向进行战略性投入和研究:
    • DeepSeek 将持续研究和改进论文的模型架构,旨在进一步提高训练和推理效率,努力实现对无限上下文长度的高效支持
      • 此外,DeepSeek 将尝试突破 Transformer 的架构限制,从而推动其建模能力的边界
    • DeepSeek 将持续迭代训练数据的数量和质量,并探索纳入额外的训练信号源,旨在推动数据在更全面维度上的扩展
    • DeepSeek 将持续探索和迭代模型的深度思考能力,旨在通过扩展其推理长度和深度来增强其智能和问题解决能力
    • DeepSeek 将探索更全面、多维度的模型评估方法,以防止在研究过程中倾向于优化固定的基准测试集,这可能造成对模型能力的误导性印象并影响论文的基础评估

附录 B: Ablation Studies for Low-Precision Training

  • 图 10:BF16 与 FP8 训练的损失曲线比较(结果使用系数为 0.9 的指数移动平均(Exponential Moving Average, EMA)进行了平滑处理)

B.1: FP8 v.s. BF16 Training

  • 论文在两个不同规模的基线模型上验证了论文的 FP8 混合精度框架,并与 BF16 训练进行了比较
    • 在小规模上,论文训练了一个总参数量约为 16B 的 MoE 基线模型,使用了 1.33T token
    • 在大规模上,论文训练了一个总参数量约为 230B 的 MoE 基线模型,使用了约 0.9T token
  • 论文在图 10 中展示了训练曲线,并证明了通过论文的高精度累加和细粒度量化策略,相对误差保持在 0.25% 以下

B.2:Discussion About Block-Wise Quantization(分块量化)

  • 尽管论文的切片式(tile-wise)细粒度量化有效缓解了特征异常值(feature outliers)引入的误差,但它需要对激活量化进行不同的分组,即在正向传播中使用 1x128 的分组,在反向传播中则需要 128x1 的分组
    • 激活梯度也需要类似的处理过程
  • 一个直接的策略是像论文量化模型权重那样,对每 128x128 个元素应用分块(block-wise)量化
    • 这样,反向传播只需要进行转置操作
  • 论文进行了一项实验,将与 Dgrad 相关的所有张量都在分块基础上进行量化
    • 结果显示,计算激活梯度并以链式方式反向传播到浅层的 Dgrad 操作对精度高度敏感
  • 论文在总参数量约为 16B、训练了约 300B Token 的 MoE 模型上,激活梯度的分块量化会导致模型发散
    • 论文假设这种敏感性源于激活梯度在 Token 之间高度不平衡,导致了与 Token 相关的异常值 (2023)
    • 这些异常值无法通过分块量化方法有效管理

附录 C:16B 基于辅助损失和无辅助损失模型的专家专业化模式 (Expert Specialization Patterns of the 16B Aux-Loss-Based and Aux-Loss-Free Models)

  • 本节记录了 16B 基于辅助损失的基线模型和无辅助损失模型在 Pile 测试集上的专家负载
  • 如图 11 所示,无辅助损失模型在所有层中都倾向于表现出更强的专家专业化程度(Expert Specialization)
    • 问题:如何理解这里的 专家专业化程度(Expert Specialization)?
    • 理解:是指某些任务上,某类专家被激活的更多,而在另外的任务上,其他专家被激活的更多,专家体现出一定的专业化倾向(注:这不是我们想要的,因为会导致模型拟合能力下降,且单个序列上不利于 EP 负载均衡)
  • 图 11:无辅助损失模型和基于辅助损失的模型在 Pile 测试集的三个领域上的专家负载
    • 无辅助损失模型 比 基于辅助损失模型 表现出 更强的专家专业化模式
    • 相对专家负载(Relative Expert Load)表示实际专家负载与理论平衡专家负载之间的比率

RS——生成式推荐

  • 参考链接:
    • (TIGER)Recommender Systems with Generative Retrieval, NeurIPS 2023, Google: 一篇比较早,也比较基础的工作
    • OneRec: Unifying Retrieve and Rank with Generative Recommender and Preference Alignment, 202502, KuaiShou
    • (COBRA)Sparse Meets Dense: Unified Generative Recommendations with Cascaded Sparse-Dense Representations, 202503, Baidu
    • (RARE)Real-time Ad retrieval via LLM-generative Commercial Intention for Sponsored Search Advertising, 202504, Tencent
      • 相关博客:腾讯搜索广告生成式检索
    • 一文汇总:LLM应用到推荐系统的各类玩法总结
    • Slow Thinking for Sequential Recommendation, 2025, RUC
      • 相关博客:STREAM-Rec: 推荐系统实现慢思考推理
    • LEARN: Knowledge Adaptation from Large Language Model to Recommendation for Practical Industrial Application, AAAI 2025, Kuaishou
      • 相关博客:AAAI’25 | 快手LEARN:使用LLM做特征增强用于电商广告推荐
    • Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for Generative Recommendations
    • HLLM: Enhancing Sequential Recommendations via Hierarchical Large Language Models for Item and User Modeling, 202409, ByteDance
    • Unlocking Scaling Law in Industrial Recommendation Systems with a Three-step Paradigm based Large User Model, 202502, Alibaba

整体说明

  • 传统深度学习推荐模型(Deep Learning Recommendation Models,DLRMs)范式主要是“Retrieval-Ranking”模式
  • 目前大模型在推荐领域的实践大致分为三类:
    • 第一类:保留 DLRMs 范式,仅利用通用的 LLM 做知识增强,或从领域数据中提取表征信息等
    • 第二类:仍保留 DLRMs 范式,部分重构 Retrieval 阶段或者 Ranking 阶段(或者在 Retrieval 阶段增加一路生成式召回通道)
    • 第二类:直接修改 DLRMs 范式,将输入和输出都重新构造为 LLM 模式,直接利用重构后的领域数据做预训练和微调

NLP——Gated-Delta-Net

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始文档:(Gated DeltaNet)Gated Delta Networks: Improving Mamba2 with Delta Rule, ICLR 2025, NVIDIA
    • GitHub:github.com/NVlabs/GatedDeltaNet

Paper Summary

  • 整体总结:
    • 本文提出了一种新的框架:Gated DeltaNet
      • vs Mamba2:能够实现更好的键值关联学习
      • vs DeltaNet:具有更强的自适应内存清除能力
  • 背景 & 问题:
    • 线性 Transformer (Linear Transformers) 作为标准 Transformer 的高效替代方案获得了关注,但它们在检索和长上下文任务中的性能有限
    • 为了解决这些限制,最近的工作探索了两种不同的机制:
      • 用于自适应内存控制的门控 (gating for adaptive memory control)
      • 用于精确内存修改的 Delta 更新规则 (delta update rule for precise memory modifications)
    • 论文观察到这些机制是互补的,门控支持快速内存擦除,而 Delta 规则促进定向更新
      • 基于这一见解,论文引入了门控 Delta 规则 (gated delta rule),并开发了一种针对现代硬件优化的并行训练算法
  • 论文提出新的架构 Gated DeltaNet
    • 在包括语言建模、常识推理、上下文检索、长度外推和长上下文理解在内的多个基准测试中,超越了 Mamba2 和 DeltaNet 等现有模型
  • 作者通过开发将 Gated DeltaNet 层与滑动窗口注意力 (sliding window attention) 或 Mamba2 层相结合的混合架构,进一步提升了性能,实现了改进的训练效率和卓越的任务性能

Introduction and Discussion

  • Transformer 架构显著提升了 LLM 的能力,由于其有效的注意力机制,在各种任务上展现出卓越的性能
    • 该机制在精确序列建模方面表现出色,并在训练期间利用了现代 GPU 的并行处理能力
    • 但自注意力 (self-attention) 组件的计算复杂度随序列长度呈二次方增长,导致巨大的计算需求,给训练和推理都带来了挑战
  • 为了缓解这些问题,研究人员探索了诸如线性 Transformer (2020a) 等替代方案,它们用基于核化点积的线性注意力 (kernelized dot-product-based linear attention) 取代了传统的基于 softmax 的注意力,通过将其重构为具有矩阵值状态的线性 RNN,显著减少了推理期间的内存需求
  • 虽然早期版本的线性 Transformer 在语言建模任务中表现不如标准 Transformer,但最近的增强已经显示出有前景的改进
    • 例如结合类似于 LSTM 中的数据依赖门控机制,以 GLA (2024a) 和 Mamba2 (2024a) 等模型为例
  • 然而,在管理长序列信息方面仍然存在挑战,特别是在上下文检索任务中,传统 Transformer 仍保持其优势 (2023a; 2024;)
  • 这种现象并不令人惊讶:
    • 线性 Transformer 可以被解释为实现了一种基于外积 (outer-product) 的键值关联记忆,让人联想到张量积表示 (1990)
    • 但它们可以存储的正交键值对的数量受到模型维度的限制
      • 当序列长度超过这个维度时,“记忆碰撞 (memory collisions)”将不可避免,阻碍精确检索 (2021a)
  • Mamba2 通过引入一个简单的门控更新规则来解决这个限制:
    $$ \mathbf{S}_{t}=\alpha_{t}\mathbf{S}_{t-1}+\boldsymbol{v}_{t}\boldsymbol{k}^{\mathrm{T} }_{t}$$
    • 该规则在每个时间步通过一个动态比率 \(\alpha_t \in (0,1)\) 统一衰减所有键值关联
  • 但这种方法没有考虑不同键值关联的重要性差异,可能导致内存利用效率低下
    • 如果模型需要忘记一个特定的键值关联,所有键值关联都会被同等程度地遗忘 ,使得这个过程缺乏针对性和效率
  • 相比之下,采用 Delta 规则 (1960) 的线性 Transformer,即 DeltaNet (2021a; 2024b),通过(软性地)用传入的新键值对替换旧的键值对来选择性更新记忆
    • 这种方法在上下文检索的合成基准测试中展示了令人印象深刻的性能
    • 但由于这个过程((一次只修改一个键值对)),模型缺乏快速清除过时或无关信息的能力 ,尤其是在需要擦除先前数据的上下文切换期间
    • 因此,人们发现 DeltaNet 在现实世界任务中表现一般 (2024b),这很可能是因为缺乏强大的内存清除机制
  • 认识到门控更新规则和 Delta 规则在内存管理方面的互补优势,论文提出了门控 Delta 规则 (gated delta rule)
    • 这是一种简单直观的机制,结合了两种方法
    • 这个统一的规则实现了灵活的内存控制:
      • 可以通过设置 \(\alpha_{t}\to 0\) 来迅速清除内存
      • 同时可以通过设置 \(\alpha_{t}\to 1\) 来选择性更新特定内容而不影响其他信息(有效地切换到纯 Delta 规则)
  • 剩下的挑战在于以硬件高效的方式实现门控 Delta 规则
    • 基于 (2024b) 使用 WY 表示 (1985) 并行化 Delta 规则计算的高效算法,论文仔细扩展了他们的方法以纳入门控项
    • 论文的扩展保留了分块并行 (chunkwise parallelism) 的优势 (2022a; 2023a; 2024a),实现了硬件高效的训练
  • 论文最终的架构 Gated DeltaNet,在一套全面的基准测试中,包括语言建模、常识推理、上下文检索、长度外推和长上下文理解,持续优于 Mamba2 和 DeltaNet
  • 基于这些结果,作者还开发了混合架构,策略性地将 Gated DeltaNet 层与滑动窗口注意力或 Mamba2 层相结合,进一步提升了训练效率和模型性能

Preliminary

Mamba2:带衰减的线性注意力 (Mamba2: Linear Attention with decay)

  • 众所周知,线性 Transformer (2020a) 在排除归一化和查询/键激活的情况下,可以表述为以下线性递归:
    $$
    \begin{align}
    \mathbf{S}_{t}&=\mathbf{S}_{t-1}+\boldsymbol{v}_{t}\boldsymbol{k}_{t}^{\intercal} \in\mathbb{R}^{d_{v}\times d_{k} },\\
    \boldsymbol{o}_{t}&=\mathbf{S}_{t}\boldsymbol{q}_ {t}\in\mathbb{R}^{d_{v} }
    \end{align}
    $$
    • 其中 \(d_{k}\) 和 \(d_{v}\) 分别代表查询/键 (query/key) 和值 (value) 的(头)维度
  • 通过展开递归,我们可以将其表示为向量形式(左)和矩阵形式(右)如下:
    $$
    \begin{align}
    \boldsymbol{o}_{t}&=\sum_{i=1}^{t}(\boldsymbol{v}_{i}\boldsymbol{k}_{i}^{\intercal})\boldsymbol{q }_{t}=\sum_{i=1}^{t}\boldsymbol{v}_{i}(\boldsymbol{k}_{i}^{\intercal}\boldsymbol{q}_{t})\in\mathbb{R}^{d_{v} },\\
    \mathbf{O}&=(\mathbf{Q}\mathbf{K}^{\intercal}\odot\mathbf{M}) \mathbf{V}\in\mathbb{R}^{L\times d_{v} }
    \end{align}
    $$
    • 其中 \(L\) 是序列长度,\(\mathbf{M}\in\mathbb{R}^{L\times L}\) 是由 \(\mathbf{M}_{ij}=0\)(当 \(i< j\))和 \(1\)(其他情况)定义的因果掩码 (causal mask)
  • 然而,这种普通的线性注意力在语言建模中表现远不如 Transformer
    • 理解:本质上 \(\sum_{i=1}^{t}\boldsymbol{v}_{i}(\boldsymbol{k}_{i}^{\intercal}\boldsymbol{q}_{t})\in\mathbb{R}^{d_{v} }\) 也有了一定的按照 Q 和 K 的相似度对 V 进行加权的思想,但这里主要差异在于 Transformer 是 Softmax 的 Attention 权重:
      $$Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\frac{\boldsymbol{Q}\boldsymbol{K}^{\top}}{\sqrt{d_k}}\right)\boldsymbol{V}$$
    • 理解:线性 Attention 和 Softmax Attention 都是 QK 内积越大,V 权重越大,但两者有本质区别:
      • 线性 Attention 未实现归一化,仅仅累加所有 Token 的 V,理论上会导致越靠后的 Token,方差是越大的
      • Softmax Attention 的权重是 \(e^{qk}\) 加权归一化,本质与线性归一化是不同的
      • 问题:如果线性 Attention 中使用 \(e^{qk}\) 作为累加,是否基本上可以实践 Softmax 的等价实现?
        • 回答:不可以,因为 Softmax Attention 中的 权重是 \(e^{qk}\),而线性 Attention 累加的对象是 \(kv\),重点:即使使用 \(e^{kv}\) 累加,累加对象也不同,反而导致含义变了!线性 Attention 的形式导致了他们无法实现 Softmax Attention 这样的 \(e^{qk}\) 加权平均
  • 为了解决这个问题,通常添加一个衰减项来遗忘历史信息
    • 这里论文以 Mamba2 (2024a) 为例,它可以表示为以下线性递归(取决于具体的参数化):
      $$
      \begin{align}
      \mathbf{S}_{t}&=\alpha_{t}\mathbf{S}_{t-1}+\boldsymbol{v}_{t}\boldsymbol{k}_{t}^ {\intercal},\\
      \boldsymbol{o}_{t}&=\mathbf{S}_{t}\boldsymbol{q}_{t}
      \end{align}
      $$
    • 其中 \(\alpha_{t}\in(0,1)\) 是一个数据依赖的标量值衰减项,随 \(t\) 变化
  • 定义累积衰减乘积
    $$\gamma_{j}=\prod_{i=1}^{j}\alpha_{i}$$
  • 并通过展开递归,我们可以将结果表示为向量形式(左)和矩阵并行形式(右):
    $$
    \begin{align}
    \boldsymbol{o}_{t}&=\sum_{i=1}^{t}\left(\frac{\gamma_{t} }{\gamma_{i} }\boldsymbol{v }_{i}\boldsymbol{k}_{i}^{\intercal}\right)\boldsymbol{q}_{t}=\sum_{i=1}^{t}\boldsymbol{v}_{i}\left( \frac{\gamma_{t} }{\gamma_{i} }\boldsymbol{k}_{i}^{\intercal}\boldsymbol{q}_{t}\right),\\
    \mathbf{O}&=((\mathbf{Q}\mathbf{K}^{\intercal})\odot\Gamma),\mathbf{V}
    \end{align}
    $$
    • 这里,\(\Gamma\in\mathbb{R}^{L\times L}\) 是一个衰减感知的因果掩码,其中 \(\Gamma_{ij}=\frac{\gamma_{i} }{\gamma_{j} }\)(如果 \(i\geq j\)),否则 \(\Gamma_{ij}=0\)
  • 这种并行形式和递归形式之间的等价性也被称为 状态空间对偶性 (state space duality, SSD) (2024a)
    • 这种递归结构也出现在其他几种架构中,包括 Gated RFA (2021)、xLSTM (2024) 和 Gated RetNet (2023a)
    • 当 \(\gamma_{t}\) 与数据无关时(data-independent),该公式简化为 RetNet (2023a) 和 Lightning-Attention (2024a)
  • 此外,如果 \(\gamma_{t}\) 扩展为矩阵值而非标量值,当使用外积结构参数化时,高效的训练算法仍然是可能的,正如 (2024a) 所展示并被 (2024b;2024;2025 等) 所使用的 分块训练 (Chunkwise training)
    • 但递归形式和并行形式对于高效训练来说都不是理想的 (2022a; 2024a),这促使了使用分块并行形式 (2022a; 2023a) 进行硬件高效的线性时间训练,如下所述
  • 总结来说,分块并行形式将输入和输出分割成几个大小为 \(C\) 的块 (chunk),并根据前一个块的最终状态以及当前块的查询/键/值块来计算每个块的输出
    • 遵循 (2023a); (2024a) 的符号,论文以查询块 \(\boldsymbol{q}\) 为例
  • 论文将
    • \(\mathbf{Q}_{[t]}:={\boldsymbol{q} }_{tC+1:(t+1)C+1}\) 表示为块 \(t\) 的查询块
    • \({\boldsymbol{q} }_{[t]}^{r}:={\boldsymbol{q} }_{tC+r}\) 表示为块 \(t\) 内的第 \(r\) 个查询块
    • \(t\) 的初始状态定义为 \(\mathbf{S}_{[t]}:=\mathbf{S}_{[t]}^{0}=\mathbf{S}_{[t-1]}^{C}\)
  • 通过部分展开递归,论文有
    $$
    \begin{align}
    \mathbf{S}_{[t]}^{r}&=\mathbf{S}_{[t]}+\sum_{i=1}^{r}\boldsymbol{v}_{[t]}^{i}\boldsymbol{k}_{[t ]}^{i\tau}\in\mathbb{R}^{d_{v}\times d_{k} },\\
    \boldsymbol{\sigma}_{[t]}^{r}&=\mathbf{S }_{[t]}^{r}{\boldsymbol{q} }_{[t]}^{r}=\mathbf{S}_{[t]}{\boldsymbol{q} }_{[t]}^{r}+\sum_{i=1}^{r} \boldsymbol{v}_{[t]}^{i}\left(\boldsymbol{k}_{[t]}^{i\tau}{\boldsymbol{q} }_{[t]}^{r}\right)\in\mathbb{R}^ {d_{v} }
    \end{align}
    $$
  • 等价地,以矩阵形式表示:
    $$
    \begin{align}
    \mathbf{S}_{[t+1]}&=\mathbf{S}_{[t]}+\mathbf{V}_{[t]}\mathbf{K}_{[t]}^{\tau} \in\mathbb{R}^{d_{v}\times d_{k} },\\
    \mathbf{O}_{[t]}&=\mathbf{Q}_{[t]}{ \mathbf{S} }_{[t]}^{\tau}+\left(\mathbf{Q}_{[t]}\mathbf{K}_{[t]}^{\tau}\odot \mathbf{M}\right)\mathbf{V}_{[t]}\in\mathbb{R}^{C\times d_{v} }
    \end{align}
    $$
    • 其中 \(\mathbf{M}\in\mathbb{R}^{C\times C}\) 是因果掩码
    • 上述等式富含矩阵乘法 (matmuls),允许基于张量核心 (tensor-core) 的硬件优化
  • 这个分块算法可以轻松扩展到带衰减的线性注意力:
    $$
    \begin{align}
    \mathbf{S}_{[t+1]}&=\overrightarrow{\mathbf{S}_{[t]} }+\mathbf{V}_{[t]}^{\tau} \overrightarrow{\mathbf{K}_{[t]} }\in\mathbb{R}^{d_{v}\times d_{k} },\\
    \mathbf {O}_{[t]}&=\overleftarrow{\mathbf{Q}_{[t]}\mathbf{S}_{[t]}^{\tau} }+\left( \mathbf{Q}_{[t]}\mathbf{K}_{[t]}^{\tau}\odot\Gamma_{[t]}\right)\mathbf{V}_{[t ]}\in\mathbb{R}^{C\times d_{v} }
    \end{align} \tag{1}
    $$
    • 其中有
      $$(\Gamma_{[t]})_{ij}=\frac{\gamma_{[t]}^{i} }{\gamma_{[t]}^{j} },\gamma_{[t]}^{j}= \prod_{j=tC+1}^{tC+j}\alpha_{j} $$
      • 注:这里论文稍微滥用了 \(\gamma\) 的符号来表示每个块的累积乘积(分别从每个块的第一个位置开始),而不是整个序列
      • 这里论文使用左箭头 (\(\overset{\leftarrow}{\cdot}\)) 或右箭头 (\(\overset{\rightarrow}{\cdot}\)) 分别表示衰减到每个块第一个位置和最后一个位置的变量,
        $$
        \begin{align}
        \overleftarrow{ {\boldsymbol{q} }_{[t]}^{r} }&=\gamma_{[t]}^{r}{\boldsymbol{q} }_{[t]}^{r} \text{ 将每个向量衰减到块 t 的第一个位置 (decaying each vector to the first position of chunk t)}\\
        \overleftarrow{ {\boldsymbol{k} }_{[t]}^{r} }&=\frac{\gamma_{[t]}^{C} }{ \gamma_{[t]}^{c} }\boldsymbol{k}_{[t]}^{r} \text{ 将每个向量衰减到块 t 的最后一个位置 (decaying each vector to the last position of chunk t)} \\
        \overrightarrow{\mathbf{S}_{[t]} }&=\gamma_{[t]}^{C}\mathbf{S}_{[t]} \text{ 在整个块 t 上衰减状态矩阵 (decaying the state matrix over the entire chunk t)} \tag{2}
        \end{align}
        $$
        • 其他变量(例如 \(\overrightarrow{\boldsymbol{v} }\))也类似
  • Mamba2 中引入的 SSD 分解算法在很大程度上等同于这种分块算法
    • 对于更通用的方法, (2024a) 提出了一种扩展的分块算法,用于线性注意力,该算法结合了细粒度的衰减机制

Delta Networks: Linear Attention with Delta Rule

  • Delta 更新规则(Delta Update Rule) (1960; 2021a) 动态地擦除与当前输入键 (\(\boldsymbol{k}_{t}\)) 关联的旧值 (\(\boldsymbol{v}_{t}^{\text{old} }\)) ,并写入一个新值 (\(\boldsymbol{v}_{t}^{\text{new} }\)) ,该新值是基于“写入强度” \(\beta_{t}\in(0,1)\) 的当前输入值和旧值的线性组合 (可以将 \(\beta_{t}\in(0,2)\) 设置为允许负特征值,以解锁 DeltaNet 的状态跟踪能力 (2024; 2025))
    $$
    \begin{align}
    \mathbf{S}_{t}&=\mathbf{S}_{t-1}-\underbrace{(\mathbf{S}_{t-1}\boldsymbol{k}_{t})}_{\boldsymbol {v}_{t}^{\text{old} } }\boldsymbol{k}_{t}^{\top}+\underbrace{(\beta_{t}\boldsymbol{v}_{t}+(1-\beta_{t})\mathbf{S}_{t-1}\boldsymbol{k}_{t}))}_{\boldsymbol{v}_{t}^{\text{new} } }\boldsymbol{k}_{t}^{\top} \\
    &= \mathbf{S}_{t-1}\left(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t}\boldsymbol{k}_{t}^{\top}\right)+ \beta_{t}\boldsymbol{v}_{t}\boldsymbol{k}_{t}^{\top}
    \end{align}
    $$
    • 如上所示,DeltaNet 实现了具有广义 Householder 转移矩阵 (\(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t}\boldsymbol{k}_{t}^{\mathsf{T} }\)) 的一阶线性递归
    • 尽管在关联回忆和语言建模性能上表现出色 (2021a),但由于计算效率低下,DeltaNet 受到的关注有限,直到 (2024b) 引入了一种硬件高效的分块训练算法,详情如下
  • 分块并行形式 (Chunkwise parallel form)。通过部分展开递归,论文有
    $$\mathbf{S}_{[t]}^{r}=\mathbf{S}_{[t]}\underbrace{\left(\prod_{i=1}^{r} \mathbf{I}-\beta_{[t]}^{i}\boldsymbol{k}_{[t]}^{i}\boldsymbol{k}_{[t]}^{i\mathsf{T} }\right)}_{:= \mathbf{P}_{[t]}^{r} }+\underbrace{\sum_{i=1}^{r}\left(\beta_{[t]}^{i}\boldsymbol{v}_{[t ]}^{i}\boldsymbol{k}_{[t]}^{i\mathsf{T} }\prod_{j=i+1}^{r}\left(\mathbf{I}-\beta_{[t]}^{ j}\boldsymbol{k}_{[t]}^{j}\boldsymbol{k}_{[t]}^{j\mathsf{T} }\right)\right)}_{:=\mathbf{H}_{[t]}^{r} } \tag{3}$$
  • 其中 \(\mathbf{P}_{[t]}^{j}\) 涉及广义 Householder 矩阵的累积乘积,可以通过经典的 WY 表示 (1985) 进行优化:
    $$\mathbf{P}_{[t]}^{r}=\mathbf{I}-\sum_{i=1}^{r}\mathbf{w}_{[t]}^{i}\boldsymbol{k}_{[t]}^{ i\mathsf{T} }\in\mathbb{R}^{d_{k}\times d_{k} }\qquad\mathbf{w}_{[t]}^{r}=\beta_{ [t]}^{r}\left(\boldsymbol{k}_{[t]}^{r}-\sum_{i=1}^{r-1}\left(\mathbf{w}_{[t]}^{i}(\boldsymbol{k}_ {[t]}^{i\mathsf{T} }\boldsymbol{k}_{[t]}^{r})\right)\right)\in\mathbb{R}^{d_{k} } \tag{4}$$
  • 同样,\(\mathbf{H}_{[t]}^{r}\) 可以表示为:
    $$\mathbf{H}_{[t]}^{r}=\sum_{i=1}^{r}\mathbf{u}_{[t]}^{i}\boldsymbol{k}_{[t]}^{i\mathsf{T } }\in\mathbb{R}^{d_{v}\times d_{k} }\qquad\mathbf{u}_{[t]}^{r}=\beta_{[t]}^{r} \left(\boldsymbol{v}_{[t]}^{r}-\sum_{i=1}^{r-1}\left(\mathbf{u}_{[t]}^{i}(\boldsymbol{k}_{[t]}^{ i\mathsf{T} }\boldsymbol{k}_{[t]}^{r})\right)\right)\in\mathbb{R}^{d_{v} } \tag{5}$$
  • 并以矩阵形式表示:\(\mathbf{P}_{[t]}=\mathbf{I}-\mathbf{W}_{[t]}^{\top}\mathbf{K}_{[t]}\in\mathbb{R }^{d_{k}\times d_{k} }\),\(\mathbf{H}_{[t]}=\mathbf{U}_{[t]}^{\top}\mathbf{K}_{[t]}\in\mathbb{R}^{d_{v} \times d_{k} }\)。通过使用 UT 变换 (2006),我们可以进一步将 \(\mathbf{W}\) 和 \(\mathbf{U}\) 写成矩阵形式:
    $$\mathbf{T}_{[t]} =\left[\mathbf{I}+\text{strictLower}\left(\text{diag}(\beta_{[t]} )\mathbf{K}_{[t]}\mathbf{K}_{[t]}^{\mathsf{T} }\right)\right]^{-1}\text{ diag}\left(\beta_{[t]}\right)\in\mathbb{R}^{C\times C} \\
    \mathbf{W}_{[t]} =\mathbf{T}_{[t]}\mathbf{K}_{[t]}\in\mathbb{R}^{C\times d_{k} }, \qquad\mathbf{U}_{[t]}=\mathbf{T}_{[t]}\mathbf{V}_{[t]}\in\mathbb{R}^{C \times d_{v} }\tag{6-7}$$
  • 将这些代回方程 3,得到了一个硬件高效的 DeltaNet 分块算法,该算法利用了矩阵乘法,实现了基于张量核心的 GPU 优化:
    $$\mathbf{S}_{[t+1]} =\mathbf{S}_{[t]}\mathbf{P}_{[t]}+\mathbf{H}_{[t]}=\mathbf{S}_{[t ]}+\left(\mathbf{U}_{[t]}-\mathbf{W}_{[t]}\mathbf{S}_{[t]}^{\mathsf{T} }\right)^ {\mathsf{T} }\mathbf{K}_{[t]} \in\mathbb{R}^{d_{v}\times d_{k} } \\
    \mathbf{O}_{[t]} =\mathbf{Q}_{[t]}\mathbf{S}_{[t]}^{\mathsf{T} }+(\mathbf{Q}_{[t]} \mathbf{K}_{[t]}^{\mathsf{T} }\odot\mathbf{M})\left(\mathbf{U}_{[t]}-\mathbf{W }_{[t]}\mathbf{S}_{[t]}^{\mathsf{T} }\right) \in\mathbb{R}^{C\times d_{v} } \tag{8-9}$$

Gated Delta Networks

Formulation: Gated Delta Rule

  • 论文提出的门控 Delta 规则简单而有效:
    $$\mathbf{S}_{t}=\mathbf{S}_{t-1}\left(\alpha_{t}(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t }\boldsymbol{k}_{t}^{\mathsf{T} })\right)+\beta_{t}\boldsymbol{v}_{t}\boldsymbol{k}_{t}^{\mathsf{T} } \tag{10}$$
    • 其中数据依赖的门控项 \(\alpha_{t}\in(0,1)\) 控制状态衰减
    • 这个公式统一了门控机制和 Delta 规则的优势:门控项支持自适应内存管理,而 Delta 更新结构促进有效的键值关联学习
  • 论文通过 Liu 等 (2024) 引入的在线学习框架的视角,对门控 Delta 规则进行了正式分析
    • 在这个框架中,循环状态更新作为在线学习问题的闭式解出现,如表 1 所示
    • 最近的线性 RNN 架构通常在其在线学习目标中包含一个正则化项,以防止状态偏离先前的值,从而实现记忆保留
      • 但当状态被信息饱和时,这种保留机制会变得有问题。在这种情况下,每个状态将对多个信息片段进行编码,使得精确检索变得困难
    • 为了解决这个限制,Mamba2 和 Gated DeltaNet 引入了一个自适应缩放因子 \(\alpha_{t}\),它放松了正则化项,允许 \(\mathbf{S}_{t}\) 和 \(\mathbf{S}_{t-1}\) 之间的受控偏差
      • 这个修改通过选择性遗忘实现了动态内存管理,这在过滤无关信息时可能很有用(见 章节3.2)
  • 另一方面,线性注意力 (Linear Attention, LA) 和 Mamba2 使用简单的负内积损失 -\(\langle\mathbf{S}_{t}\boldsymbol{k}_{t},\boldsymbol{v}_{t}\rangle\),而 Longhorn (2024) 使用更具表达力的在线回归目标 \(|\mathbf{S}_{t}\boldsymbol{k}_{t}-\boldsymbol{v}_{t}|^{2}\) 来更好地建模键值关联。由此产生的 Longhorn 更新规则与 Delta 更新规则非常相似,这表明(门控)Delta 规则在上下文关联回忆方面优于 Mamba2
    • Longhorn 更新规则与 Delta 更新规则理论上的区别在于优化方法:Longhorn 使用隐式在线学习 (2010) 来推导闭式的全局最优更新,而 DeltaNet 通过一步显式梯度下降来优化相同的目标,正如 (2024) 所指出的
  • 从快速权重编程 (fast weight programming) (2022a)、测试时训练 (test-time training) (2024a) 和回归 (2025) 的角度来看,隐藏状态 \(\mathbf{S}\) 可以被解释为一个(快速)权重矩阵,Delta 规则通过测试时随机梯度下降 (stochastic gradient descent, SGD) 优化在线回归目标 \(\mathcal{L}(\mathbf{S}_{t})=\frac{1}{2}|\mathbf{S}_{t}\boldsymbol{k}_{t}-\boldsymbol{v}_{t}|^{2}\):
    $$\mathbf{S}_{t+1}=\mathbf{S}_{t}-\beta_{t}\nabla\mathcal{L}(\mathbf{S}_{t})= \mathbf{S}_{t}-\beta_{t}(\mathbf{S}_{t}\boldsymbol{k}_{t}-\boldsymbol{v}_{t})\boldsymbol{k}_{t}^{\top}= \mathbf{S}_{t}\left(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t}\boldsymbol{k}_{t}^{\top}\right)+\beta_{t}\boldsymbol{v}_{t}\boldsymbol{k}_{t}^{\top}$$
    • 其中 \(\beta_{t}\) 代表(自适应)学习率。从这个角度来看,门控 Delta 规则可以被视为将自适应权重衰减项 \(\alpha_{t}\) 纳入 SGD 更新中,这是一种在深度学习中广泛使用的技术 (1991; 2023)
    • 同时,Titans (2024) 证明了在 RNN 测试时 SGD 更新中结合权重衰减机制的有效性

Case study: Single Needle in a Haystack, S-NIAH,大海捞针

  • 为了更好地理解 Delta 规则和门控规则之间的互补优势,论文提供了一个关于来自 RULER (2024) 的“大海捞针”基准测试套件的案例研究,其中一个键值对充当“大海”(上下文)中的“针”,模型必须在给定键时回忆起值
  • 表 2 展示了结果,论文得出三个主要观察:
    • 衰减损害记忆保留 (Decay hurts memory retention)
      • 在最简单的 S-NIAH-1 设置中,具有重复的合成上下文,模型记忆的信息最少,测试长期保留能力
      • DeltaNet 在所有序列长度上都实现了接近完美的性能
      • Mamba2 在超过 2K 序列后性能显著下降,因为它衰减历史信息太快,而 Gated DeltaNet 的下降不那么严重,这得益于使用了 Delta 规则
    • 门控促进过滤 (Gating facilitates filtering)
      • 在具有真实世界文章上下文的 S-NIAH-2/3 中,模型存储所有潜在相关信息,测试高效的内存管理
      • 在固定状态大小的情况下,缺乏清除会导致内存碰撞,信息变得叠加且无法区分
      • DeltaNet 的性能在较长序列时显著下降,原因是内存清除能力差
      • Mamba2 和 Gated DeltaNet 通过过滤无关信息的门控机制保持了更好的性能
    • Delta 规则有助于记忆 (Delta rule helps memorization)
      • 在 S-NIAH-3 中,值从数字变为 UUID,测试复杂模式记忆
      • Mamba2 的性能迅速下降,而 Gated DeltaNet 表现更好,验证了 Delta 规则确实具有更好的记忆能力

Algorithm: Hardware-efficient Chunkwise training

  • 在本小节中,论文推导出用于训练 Gated DeltaNet 的硬件高效分块算法。通过部分展开方程 10 中的递归,论文有
    $$\mathbf{S}_{[t]}^{r}=\underbrace{\mathbf{S}_{[t]}\left(\prod_{i=1}^{ r}\alpha_{[t]}^{i}\left(\mathbf{I}-\beta_{[t]}^{i}\boldsymbol{k}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i\mathsf{T} }\right)\right)}_{\text{:=F}_{[t]}^{r} }+\underbrace {\sum_{i=1}^{r}\left(\beta_{[t]}^{i}\boldsymbol{v}_{[t]}^{i}\boldsymbol{k}_{[t]} ^{i\mathsf{T} }\prod_{j=i+1}^{r}\alpha_{[t]}^{j}\left(\mathbf{I}-\beta_{[t]}^{j }\boldsymbol{k}_{[t]}^{j}\boldsymbol{k}_{[t]}^{j\mathsf{T} }\right)\right)}_{ \text{:=G}_{[t]}^{r} }$$
    • 很容易看出 \(\mathbf{F}_{[t]}^{r}=\gamma_{[t]}^{r}\mathrm{P}_{[t]}^{r}=\hat{\overline{\mathbf {P} }_{[t]}^{r} }\)
    • 至于 \(\mathbf{G}_{[t]}^{r}\),论文调整方程 5 如下:
      $$\mathbf{G}_{[t]}^{r}=\sum_{i=1}^{r}\frac{\gamma_{[t]}^{r} }{\gamma_{[t]}^{i} }\tilde{\mathbf{u} }_{[t]}^{i}\boldsymbol{k}_{[t]}^{i\mathsf{T} }\in \mathbb{R}^{d_{v}\times d_{k} }\qquad\tilde{\mathbf{u} }_{[t]}^{r}=\beta_{[t]}^{r }\left(\boldsymbol{v}_{[t]}^{r}-\sum_{i=1}^{r-1}\left(\tilde{\mathbf{u} }_{[t]}^ {i}(\frac{\gamma_{[t]}^{r} }{\gamma_{[t]}^{i} }\boldsymbol{k}_{[t]}^{i\mathsf{T } }\boldsymbol{k}_{[t]}^{r})\right)\right)\in\mathbb{R}^{d_{v} }$$
      • (证明见附录 A)
    • 通过 UT 变换,论文得到矩阵形式:
      $$\widetilde{\mathbf{U} }_{[t]}^{-}=\left[\mathbf{I}+\text{strictLower}\left( \operatorname{diag}\left(\beta_{[t]}\right)(\Gamma_{[t]}\odot\mathbf{K}_{[t]} \mathbf{K}_{[t]}^{\mathsf{T} })\right)\right]^{-1}\operatorname{diag}\left( \beta_{[t]}\right)\mathbf{V}_{[t]}\qquad\in\mathbb{R}^{C\times d_{v} }$$
  • 类似于 Mamba2 扩展线性注意力的方式(方程 1),我们可以调整 DeltaNet 的分块算法(方程 8-9)用于 Gated DeltaNet,以实现硬件高效训练,如下所示:
    $$
    \begin{align}
    \mathbf{S}_{[t+1]}&=\overrightarrow{\mathbf{S}_{[t]} }+\left( \overrightarrow{\mathbf{U} }_{[t]}-\overleftarrow{\mathbf{W} }_{[t]}\mathbf{S}_{[ t]}^{\mathsf{T} }\right)^{\mathsf{T} }\overrightarrow{\mathbf{K} }_{[t]} \in\mathbb{R}^{d_{v}\times d_{k} } \\
    \mathbf{O}_{[t]}&=\overleftarrow{\mathbf{O} }_{[t]}\mathbf{S}_{[t]} ^{\mathsf{T} }+\left(\mathbf{Q}_{[t]}\mathbf{K}_{[t]}^{\mathsf{T} }\odot \mathbf{M}\right)\left(\overrightarrow{\mathbf{U} }_{[t]}-\overleftarrow{ \mathbf{W} }_{[t]}\mathbf{S}_{[t]}^{\mathsf{T} }\right)\in\mathbb{R}^{C\times d_ {v} }
    \end{align}
    $$
    • 其中 \(\overleftarrow{\boldsymbol{q} }_{[t]}^{r}=\gamma_{[t]}^{r}\boldsymbol{q}_{[t]}^ {r}\), \(\overleftarrow{\mathbf{w} }_{[t]}^{r}=\gamma_{[t]}^{r}\mathbf{w}_{[t]}^{r}\), \(\overleftarrow{\boldsymbol{k} }_{[t]}^{r}=\frac{\gamma_{[t]}^{C} }{\gamma_{[t]} } \boldsymbol{k}_{[t]}^{r}\), 并且 \(\overrightarrow{\mathbf{S} }_{[t]}\neq\gamma_{[t]}^{C}\mathbf{S}_{[t]}\) 如同论文在方程 2 中的定义

Gated Delta Networks and Hybrid Models

  • Token 混合器块 (Token mixer block)
    • 基本的 Gated DeltaNet 遵循 Llama 的宏观架构,将 Token 混合器层与 SwiGLU MLP 层堆叠,但用门控 Delta 规则 Token 混合取代了自注意力
    • 图 1(右)显示了其块设计
      • 对于门控 Delta 规则(方程 10),查询、键和值 \(\{\boldsymbol{q},\boldsymbol{k},\boldsymbol{v}\}\) 通过线性投影、短卷积 (short convolution) 和 SiLU 生成,并对 \(\boldsymbol{q},\boldsymbol{k}\) 应用 L2 归一化以保持训练稳定性。\(\alpha,\beta\) 仅使用线性投影(论文对 \(\alpha\) 使用 Mamba2 的参数化,但为简洁起见省略了细节)
      • 遵循 Sun 等 (2023a),输出在应用输出投影之前通过归一化和门控处理
  • 混合模型 (Hybrid models)
    • 线性 Transformer 在建模局部偏移和比较方面存在局限性,并且其固定状态大小使得检索任务变得困难 (2023a)
    • 遵循最近的混合架构,如 Griffin (2024) 和 Samba (2024),论文将线性循环层与滑动窗口注意力 (sliding window attention, SWA) 相结合,产生了 GatedDeltaNet-H1
    • 论文还堆叠了 Mamba2、GatedDeltaNet 和 SWA,产生了 GatedDeltaNet-H2

Experiments

Setup

  • 论文的实验包括对最近最先进架构的全面比较,包括纯 Transformer 模型、基于 RNN 的方法和混合架构
  • 论文针对以下基线进行评估:RetNet (2023a)、HGRN2 (2024b)、Mamba (2023)、Mamba2 (2024a)、Samba (2024) 和 DeltaNet (2024b)
  • 为了公平比较,所有模型都在相同条件下训练,具有 1.3B 参数,使用从 FineWeb-Edu 数据集 (2024) 中采样的 100B Token
  • 论文使用 AdamW 优化器,峰值学习率为 4e-4,权重衰减为 0.1,梯度裁剪为 1.0
  • 学习率遵循余弦退火调度,具有 1B Token 的预热期 和 0.5M Token 的批次大小
  • 所有模型都使用词表为 32,000 的 Llama2 分词器
  • 对于序列建模,论文将训练长度设置为 4K Token ,Samba 和论文的混合模型使用 2K 的滑动窗口大小
  • 评估设置见附录 B.1,消融研究见附录 B.2

Common-sense reasoning

  • 在表 3 中,论文展示了具有 0.4B 和 1.3B 参数模型的语言建模困惑度以及在常识推理基准测试上的零样本准确率
  • Gated DeltaNet 在两个规模上都持续优于其他线性模型,包括 RetNet、HGRN2、Mamba、Mamba2 和 DeltaNet
  • 正如预期的那样,混合变体进一步提升了性能

In-context retrieval on real-world data

  • 表 4 展示了在 Arora 等 (2023a) 使用的真实世界回忆密集型任务上的结果
  • 正如预期的那样,线性循环模型与 Transformer 相比显示出显著的性能差距,而结合线性循环和注意力的混合模型在检索任务中优于纯注意力模型
  • 对于纯循环模型,尽管 DeltaNet 在合成的上下文检索任务上表现出色 (2024b),但其真实世界的检索性能落后于 Mamba2,这与论文在 S-NIAH-2 和 S-NIAH-3 中的观察一致(表 2)
  • Gated DeltaNet 由于其门控 Delta 规则,性能优于 DeltaNet 和 Mamba2,尽管改进幅度小于表 2
  • 论文将这种性能差距的缩小归因于未进行指令调优的小型语言模型容易产生重复错误,这些错误是这些任务中错误的主要来源(参见 Arora 等 (2023a, 附录 E))
  • 由于这个问题在很大程度上与更新规则的选择无关,因此模型之间的性能差异与表 2 相比不那么明显

Length extrapolation on long sequences

  • 如图 2 所示,论文评估了模型在六个长上下文基准测试中外推到长达 20K Token 序列的能力
    • 在 RNN 模型中,Gated DeltaNet 在所有任务中实现了最低的整体困惑度
    • 虽然论文在长度外推中观察到结果好坏参半,但 Gated DeltaNet 表现出相对更稳健的性能,表明其具有更好的内存管理能力
    • 混合模型通过利用注意力进行局部上下文建模,进一步改善了这一点,从而减轻了其循环组件在内存管理上的负担
    • 未来的工作将探索这些模型在更长序列上的能力

Long context understanding

  • 如表 5 所示,论文评估了模型在 LongBench (2023) 上的性能
  • 在循环模型中,Gated DeltaNet 显示出持续的优势,特别是在单文档问答、少样本上下文学习和代码任务中,分别展示了其在检索、上下文学习和状态跟踪方面的卓越能力

Throughput Comparison

  • 不同模型的训练吞吐量对比如图 3 所示
    • 正如论文的分析所示,论文提出的门控 delta 规则 (gated delta rule) 与原始 delta 规则相比仅引入了微小的开销,Gated DeltaNet 的吞吐量与 DeltaNet 基本相同
    • 由于它们具有更具表达力的转移矩阵 (transition matrices),两者都比 Mamba2 稍慢(约慢 2–3K tokens/sec)
  • 正如论文的分析所示,与原始 Delta 规则相比,提出的门控 Delta 规则仅引入了边际开销,Gated DeltaNet 实现了与 DeltaNet 基本相同的吞吐量
  • 由于它们具有更具表达力的转移矩阵,两者都比 Mamba2(2-3K Token /秒)稍慢
  • Transformer++ 在 2K 上下文窗口领域实现了最佳性能
    • 这得益于高度优化的 Flash-Attention-2 内核 (2023)
    • 因此,将 2K 窗口大小的滑动窗口注意力 (sliding window attention, SWA) 与其他令牌混合器 (token mixer) 相结合的混合方法,其吞吐量高于独立的混合器:Samba 优于 Mamba,而 Gated DeltaNet-H1 和 Gated DeltaNet-H2 优于 Gated DeltaNet
    • 值得注意的是,Gated DeltaNet-H1 在所有序列长度上均保持了可观的训练吞吐量,即使在短序列上也是如此

Related Work

Gated linear RNN

  • 大型线性循环语言模型因其训练和推理效率而备受关注
  • 线性循环神经网络领域已经从使用数据无关的衰减机制迅速发展为在更近期的架构中融入数据相关的衰减机制
    • 数据无关的衰减机制:例如 S4 (2022)、S5 (2023)、LRU (2023)、RWKV4/5 (2023) 和 RetNet (2023a) 等模型所例证
    • 数据相关的衰减机制:例如 HGRN1/2 (2023a; 2024b)、Mamba1/2 (2023; 2024a)、RWKV6 (2024)、GSA (2024)
  • 这一转变源于门控/遗忘机制(在 Mamba 中称为选择性机制)被证实的优势,这是一个源于门控循环神经网络文献 (2000) 的经典概念,其重要性一直被反复确认 (2015; 2018; 2023b; 2024b)
  • 现代遗忘门与传统设计(如 LSTM 中的遗忘门)的不同之处在于,它移除了对先前隐藏状态的依赖,仅依赖于输入数据
    • 这种修改使得在序列长度上能够实现高效的并行性 (2018; 2023b)
    • 缺乏遗忘门一直是 DeltaNet 的一个显著局限,而论文的门控扩展以一种自然、有效且硬件高效的方式弥补了这一差距
    • 论文也注意到最近的一项并行工作 RWKV-7 使用了类似的想法,但采用了更宽松的使用对角加低秩转移的形式化表示:
      $$\mathbf{S}_{t}=\mathbf{S}_{t-1}(\mathrm{diag}(\mathbf{d}_{t})-\mathbf{a}_{t}\mathbf{b}_{t}^{\top})+v_{t}\boldsymbol{k}_{t}^{\top}$$
    • 其中 \(\mathbf{d}_{t},\mathbf{a}_{t},\mathbf{b}_{t}\in\mathbb{R}^{d_{k} }\)
    • 分块算法可以类似地适用于这种情况,正如在 Flash Linear Attention (2024) 中所实现的那样

Delta rule

  • Delta 学习规则在记忆容量上表现出优于赫布学习规则 (1988; 1989),DeltaNet 利用了这一点,而线性 Transformer 则依赖于类赫布规则
    • 这种记忆容量优势在合成上下文学习任务中表现明显,并延伸到语言建模 (2021; 2024b)、强化学习 (2022b) 和图像生成 (2023)
    • (2024b) 将 delta 规则计算并行化,并证明了 DeltaNet 的数据相关的单位加低秩结构 (\(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t}\boldsymbol{k}_{t}^{\top}\)) 比 Mamba2 的数据相关的对角矩阵 (\(\alpha_{\mathbf{I} }\)) 提供了更大的灵活性
    • 这种结构优势可能实现复杂的推理,包括正则语言识别 (2024; 2024) 和超越 TC’ 复杂度的状态跟踪 (2024),这对于编码和推理应用至关重要
  • 尽管有这些显著优势,delta 规则面临理论局限 (2023) 并且在真实世界数据集上表现中等 (2024b),表明仍有改进空间
    • 先前通过非线性循环 (2021, 2022a) 来增强表达力的尝试解决了一些局限,但牺牲了训练并行性,造成了性能-效率的权衡
    • 最近的工作提出了一些不损害并行性的增强方法,以获得更好的状态跟踪性能,包括使用负特征值 (2024) 和使用多个户主转移矩阵的乘积 (2025),这实现了高秩变换
    • 这些方法可以无缝应用于 Gated DeltaNet
  • 从(在线)学习目标的角度来看,替代的公式化可以进一步扩展表达力:
    • 非线性回归 (\(\mathcal{L}(\mathbf{S}_{t})=\frac{1}{2}||\boldsymbol{f}_{\mathbf{S}_{t} }(\boldsymbol{k}_{t})-\boldsymbol{v}_{t}||^{2}\)),如 TTT (2024a) 和 Titans (2024) 中那样,其中 \(\boldsymbol{f}_{\mathbf{S} }\) 是由 \(\mathbf{S}\) 参数化的非线性函数;
    • 或者考虑整个历史的回归 (\(\mathcal{L}(\mathbf{S}_{t})=\frac{1}{2}\sum_{i=1}^{d}||\mathbf{S}_{t}\boldsymbol{k}_{i}-\boldsymbol{v}_{t}||^{2}\)),如 Mesa 层 (von 2024) 中那样——类似于最小均方算法和递归最小二乘算法之间的区别
    • 然而,这些更具表达力的变体引入了非线性循环,并且需要变通方法,例如在处理完整个分块后才执行非线性更新(如在 TTT 和 Titans 中);或者近似非线性循环方法,如 (2024; 2024; 2025)

Hybrid models

  • 在这项工作中,论文探索了在层间交错混合注意力层,这在诸如 MiniMax-01 (2025) 和 Hybrid Mamba2-Attention (2024) 中常用
  • 研究在单个层内混合线性/softmax 注意力也很有趣 (2022a; 2024; 2025)

NLP——EfficientAttention-Survey(THU-2025)

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:(Attention Survey of THU 2025)Efficient Attention Mechanisms for Large Language Models: A Survey, THU, 20250807
      • 有豆包基金的支持

Paper Summary

  • 整体总结:
    • 本文是一篇高效 Attention 的 Survey
    • 整体涉及到了 Linear Attention 和 Sparse Attention
  • 问题提出:
    • Transformer-based 架构已成为 LLM 的主流 backbone
    • 自注意力机制存在的二次时间和内存复杂度,是实现高效长上下文建模的根本障碍
  • 当前解法:
    • 为解决这一限制,近期研究引入了两大类高效注意力机制
    • 机制一:Linear Attention 方法:
      • 通过核近似、循环公式或快速权重动态来实现线性复杂度,从而以更低计算开销实现可扩展推理
    • 机制二:稀疏注意力 (Sparse Attention) 技术:
      • 基于固定模式、分块路由或聚类策略,将注意力计算限制在选定的 Token 子集上,在保持上下文覆盖的同时提升效率
  • 本综述主要工作如下:
    • 第一:对上述机制的发展进行了系统而全面的概述,结合了算法创新和硬件
    • 第二:分析了高效注意力在大规模预训练语言模型中的整合方式
      • 包括完全基于高效注意力构建的架构,以及结合局部和全局组件的混合设计

综述引言和一些讨论

  • Transformer-based 架构 (2017) 已成为现代 LLM 事实上的 backbone 选择
    • 标准的自注意力机制仍然是一个显著的计算瓶颈,其时间和内存复杂度相对于输入序列长度呈二次方增长
    • 这一限制对 LLM 在处理日益增长的长上下文时,同时保持高性能和高效率,提出了巨大挑战
  • 为解决此问题,出现了两个主要方向来降低 softmax 注意力 (softmax Attention) 的时间和空间复杂度
    • 第一种机制是Linear Attention (2020; 2020; 2021; 2021; 2023; 2024):
      • 通过将 softmax 注意力重新参数化或近似为线性操作来降低注意力复杂度
    • 第二种候选方案是稀疏注意力 (Sparse Attention) (2019; 2021; 2021; 2022; 2024):
      • 基于固定或动态的稀疏模式,将注意力计算限制在完整键空间的某个子集上
    • 虽然这两种方法都旨在提高效率,但它们在公式、设计选择和硬件影响方面存在显著差异
  • 本综述全面回顾了高效注意力 (Efficient Attention) 机制的最新进展,同时关注算法原理和系统级实现
    • 在此基础上,论文还研究了采用这些高效注意力的预训练 LLM
  • 论文将线性注意力方法分为三大范式
    • 范式一:核化线性注意力 (kernelized linear attention):
      • 通过特征空间内的内积来近似 softmax 核,借助随机特征映射 (2020; 2021) 或固定正映射 (2020) 实现线性复杂度
    • 范式二:带有遗忘机制 (forgetting mechanism) 的循环线性注意力 (recurrent linear attention):
      • 引入了位置感知的循环,通过数据无关 (2021) 或数据相关 (2023; 2023) 的衰减机制来控制过去信息随时间衰减的方式,从而实现对长序列的建模
    • 范式三:基于快速权重 (fast-weight) 和元学习 (meta-learning) 的公式:
      • 将线性注意力重新解释为在线优化的记忆更新过程,其中如 DeltaNet (2021; 2024) 和 TTT (2024) 等模型将快速学习动态直接整合到状态演化中
  • 作者还探讨了线性注意力在硬件友好的表示形式(包括并行、循环和分块形式)重点分析了它们在计算复杂度、内存占用以及与训练或推理工作流兼容性方面的权衡
  • 论文将稀疏注意力分为固定模式稀疏度 (fixed-pattern sparsity)、分块稀疏度 (block sparsity) 和基于聚类的稀疏度 (clustering-based sparsity)
    • 固定模式稀疏度采用静态的 Token-level 掩码,如滑动窗口、扩张位置或指定的全局 Token ,提供了简单性和硬件友好性 (2020; 2019; 2023; 2023)
    • 分块稀疏度在分块粒度上选择或路由注意力,可以通过启发式评分 (2021; 2022; 2025)、可训练的选通机制 (2024; 2024) 来实现,从而实现结构化的内存访问和高效的 GPU 利用率
    • 基于聚类的稀疏度使用基于内容或位置感知的分组方法(如 k-means 或 LSH)来组织键值对,以降低内存开销并促进基于语义感知的检索 (2023; 2020; 2024)
  • 最后,本综述还讨论了将稀疏模式扩展到编码器风格模型的双向稀疏设计
    • 这些方法在稀疏粒度、选择机制以及与 FlashAttention (2023) 等硬件原语的契合度上有所不同,共同代表了现代 Transformer 中实现高效长上下文建模的基础
  • 近期已有文章将高效注意力机制整合到工业级预训练语言模型中
    • 这包括纯高效架构:如线性注意力和状态空间模型,以及结合了局部和全局注意力模式的混合设计
      • 像 EAGLE (2024)、Falcon Mamba (2024) 和 MiniCPM4 (2025) 这样的模型展示了纯线性或稀疏方法在数十亿参数规模上的可扩展性,在提供强性能的同时实现了恒定时间推理
    • 同时,混合模型 (2020; 2022; 2024; 2025; 2025; 2025) 交错使用密集、稀疏和局部注意力,以平衡计算效率和上下文建模能力
      • 反映了现代 LLM 中朝着组合化、硬件感知注意力设计发展的趋势
  • 论文的目标是为理解注意力机制在算法和硬件双重约束下的演进,以及这些设计如何被整合到可扩展的 LLM 架构中,提供一个统一的框架
    • 通过将理论洞见与实际实现相结合,作者希望本综述能为致力于高效且可部署模型设计的研究者和从业者提供有价值的参考
  • 论文将讨论安排如下:
    • 第 2 节:介绍Linear Attention,涵盖其在不同模型世代中的演变、相关的设计原理以及对硬件实现的影响
    • 第 3 节:介绍稀疏注意力 (Sparse Attention),对稀疏模式进行分类,分析部署场景,并提供实用的系统级设计建议
    • 第 4 节:回顾整合了高效注意力机制的预训练语言模型 (Pretrained Language Models),包括统一高效架构和整合了局部、稀疏、密集注意力的混合模型
    • 第 5 节:对未来的发展方向进行展望 (Outlook),讨论算法和硬件对齐研究中的开放挑战和潜在进展

Linear Attention

  • 传统的线性注意力方法旨在以与序列长度成线性关系的方式近似基于 softmax 的注意力机制
    • 核心思想是用一种基于核函数(kernel)的注意力权重近似来替代计算代价高昂的 softmax 计算
  • 在标准的自注意力中,每个输出是值 \(V\) 的加权和,权重由查询-键相似度经过 softmax 得到:
    $$
    \text{Attn}(Q,K,V)=\text{softmax}(QK^{\top})V, \tag{1}
    $$
    • \(Q,K,V \in \mathbb{R}^{L \times d}\)
    • (\(L\) 是序列长度
    • \(d\) 是每个头的模型维度)
  • softmax 为查询 \(q_i\) 和键 \(k_j\) 产生权重 \(\propto \exp(q_i^{\top}k_j)\)
  • 核化线性注意力则寻找一个特征映射 \(\phi(\cdot)\),使得 softmax 核函数可以在诱导出的特征空间中通过一个简单的点积来近似:\(\exp(q^{\top}k) \approx \phi(q)^{\top} \phi(k)\)(2019)
    • 给定这样的 \(\phi\),可以将注意力重写为:
      $$
      O = \frac{\phi(Q)(\phi(K)^{\top}V)}{\phi(Q)(\phi(K)^{\top}\mathbf{I})} \tag{2}
      $$
  • 由于 \(\exp(\cdot)\) 的值域是非负的,\(\phi(\cdot)\) 通常被选择为产生非负输出,同时应用归一化除数来模拟 softmax 概率
  • 这种重新表述将复杂度从 \(O(L^2 d)\) 降低到 \(O(L d^2)\)(或者在适当的特征降维下甚至可以达到 \(O(L d)\)),因为昂贵的 \(L \times L\) 注意力矩阵从未显式形成
  • Linear Transformer (2020)
    • 用一个固定的正特征映射替换了 softmax 核函数
    • 在实践中,他们设置 \(\phi(x) = \operatorname{ELU}(x) + 1\)
    • \(\operatorname{ELU}(\cdot)\) 在整个定义域内可微,与朴素的 \(\operatorname{ReLU}(\cdot)\) 函数相比表现出更好的性能
  • Performer (2020)
    • 引入了 FAVOR+(一种能够无偏估计 softmax 核函数的随机特征方案)
    • 它对随机特征映射 \(\phi\) 进行采样,使得 \(E[\phi(Q)\phi(K)^{\top}] = \exp(QK^{\top})\)
    • 这产生了一个仅需 \(O(N)\) 操作即可证明是完整 softmax 注意力的无偏估计量
    • 特别地,Performer 使用正交互随机特征,降低了近似中的方差
  • 随机特征注意力 (Random Feature Attention,RFA) (2021)
    • 一种通过为 softmax 核函数使用随机傅里叶特征构建的线性注意力
    • 与 Performer 类似,RFA 利用随机映射和三角激活来近似 softmax
    • RFA 在随机投影之前进一步对查询和键进行归一化以减少方差
    • RFA 还有一个变体 RFA-Gate,它增加了一个可选的选通机制以引入近因偏差
  • cosFormer (2022)
    • 使用余弦函数来近似 softmax
    • 由于 \(\cos(a+b) = \cos a \cos b - \sin a \sin b\),cosFormer 将余弦重加权注意力 \(S_{ij} = Q_i’ K_j’ \cos(\frac{\pi}{2} \times \frac{i-j}{M})\) 分解为线性注意力的形式
  • HedgeDog (2024)
    • 利用了一个尖峰核函数 \(\phi(x) = \exp(Wx + b)\),因为他们观察到 Transformer 和线性 Transformer 之间的性能差距源于缺乏尖峰和单调性属性
    • HedgeDog 展示了更好的注意力熵和单调性

Linear Attention with Forgetting Mechanism

  • 最近的一系列工作通过循环神经网络或连续状态空间模型的视角来解释注意力
  • 传统的线性注意力通常是无位置感知的,其中循环顺序对输出没有影响,但现代的线性注意力表现得更像具有状态追踪和隐藏记忆的 RNN
  • 因此,这些模型明确地结合了循环、门控或状态动力学,以线性复杂度处理长序列
  • 衰减因子是引入遗忘机制的最重要因素
Data-Independent Decay
  • Retentive Networks(RetNet) (2023)
    • RetNet 引入了一种 Retention 机制,它使用固定的衰减系数,以一种循环风格的更新来取代注意力
    • 在 RetNet 层中,每个时间步 \(t\) 维护一个状态向量 \(s_t\),该向量以指数遗忘的方式聚合过去的输入
    • 循环可以写为:
      $$
      S_t = \gamma S_{t-1} + k_t^{\top} \nu_t \tag{3}
      $$
      • \(\gamma \in (0,1)\) 是一个学习的衰减因子(每个保持头)
      • \(k_t^{\top} \nu_t\) 是来自当前 Token 的新贡献(\(\nu_t\) 是 \(x_t\) 的值投影,\(k_t\) 是键投影)
    • 然后通过一个线性“查询”投影获得输出:\(o_t = q_t s_t\);展开方程 3 得到保持的显式公式:
      $$
      o_t = q_t S_t = \sum_{n=1}^{t} \gamma^{t-n} q_t k_t^{\top} \nu_t \tag{4}
      $$
    • 这表明 Token \(n\) 的贡献在时间步 \(t\) 时以因子 \(\gamma^{t-n}\) 呈指数衰减
    • Crucially,\(\gamma\) 是一个数据无关的衰减,是层的固定参数(在多头保持中通常每个头一个),而不是输入内容的函数
      • 这使得 RetNet 能够像 RNN 一样进行 \(O(1)\) 的内存更新,同时仍然允许通过等效的矩阵公式在训练期间进行并行计算(例如,可以证明方程 3 等价于一个“保持矩阵”形式,\(\text{Retention}(X) = (Q K^{\top} \odot D) V\),其中对于 \(t \geq n\),\(D_{t,n} = \gamma^{t-n}\) 实现了衰减和因果掩码)
    • RetNet 的保持机制与其他数据无关的循环模型有共同的主题
  • Eagle (2024)
    • 通过外积记忆改进了 RWKV 设计,这等效于线性注意力
    • 在 RWKV 系列中,衰减因子参数化为 \(\gamma = \exp(-\exp(w))\),其中 \(w\) 是一个数据无关的可学习因子
    • 在实践中,RetNet 和 Eagle 都使用固定衰减来遗忘旧信息,实现了线性推理扩展和具有竞争力的性能
    • 经验上,RetNet 每个头使用一个固定标量 \(\gamma\)(通常每层有多个具有不同 \(\gamma\) 值的保持头,形成一种多尺度衰减),而 Eagle 使用可学习的标量 \(w\) 来参数化衰减因子
  • Lightning Attention (2023; 2024)
    • Lightning Attention 也提出了一种线性注意力层 ,每个头增加了一个固定标量衰减,以实现长度不变的计算速度
    • 在 Lightning Attention 中,对于某个常数 \(\lambda\)(\(\lambda\) 由模型学习或设置),隐藏状态本质上是 \(s_t = \lambda s_{t-1} + k_t^{\top} \nu_t\),这与 RetNet 的 \(\gamma\) 精神相同,但针对硬件效率进行了优化
  • H3 (2022)
    • H3 将循环状态空间模型 (2021) 引入线性注意力,通过 SSM 为键值外积隐藏状态使用学习的、数据无关的指数衰减
    • 线性注意力通过分块计算实现了高效的训练,但 H3 需要为 SSM 计算显式状态扩展,从而限制了头维度,导致表达能力有限
  • In summary,数据无关的衰减方法维持一个随时间以预定速率衰减的持久状态,实现了 \(O(1)\) 循环和每步恒定内存
    • 它们牺牲了一些适应性 ,这促使了在更近期的模型中引入数据相关机制
Data-Dependent Decay
  • 虽然固定衰减提供了简单性和速度,但它们可能未能充分利用输入流中的信息
  • 门控或数据相关的方法使遗忘因子本身成为当前输入的学习函数
  • 这种循环更新的一般形式是:
    $$
    S_t = G_t S_{t-1} + k_t^{\top} \nu_t \tag{5}
    $$
    • \(S_{t-1}\) 是前一个状态
    • \(G_t\) 是由 Token \(x_t\) 确定的门控张量
      • 如果 \(G_t\) 在某个分量上接近 0,则该分量中的过去状态在时间 \(t\) 被大量遗忘;
      • 如果 \(G_t \approx 1\),则过去状态被保留
    • 与 RetNet 中的常数 \(\gamma\) 不同,这里的 \(G_t\) 通过 \(x_t\) 随 \(t\) 变化
  • 在大语言模型设计中,这种策略的两个显著例子是 Mamba (2023) 和门控线性注意力 (Gated Linear Attention, GLA) (2023)
  • Mamba
    • 一种循环状态空间模型,它赋予状态衰减率以输入依赖性
    • 在每个 Mamba 层中,基本的状态演化类似于 S4 (2021),但状态矩阵实际上变得动态
    • \(G_t\) 是一个范围在 0 到 1 之间的分组向量,作为动态遗忘门
      • 这弥合了注意力和纯 SSM 之间的差距
    • 实证结果表明,Mamba2 在语言建模任务上可以超越相似甚至更大规模的 Transformer,凸显了数据相关衰减在长序列建模中的能力
  • GLA
    • 直接在线性注意力中引入了门控机制,将一个门控函数嵌入到线性化的注意力层中以提高其表达能力
    • GLA 通过一个可学习的 Element-wise 遗忘门 \(G_t\) 来修改保持循环
  • 除此之外,其他几个模型同样赋予了其循环以内容相关的门控
  • xLSTM (2024)
    • 用线性门控信号(带归一化)的指数变换取代了标准的 sigmoid 遗忘门,对其单元状态产生了平滑的、输入条件化的衰减
  • GateLoop (2023)
    • 在保持机制上应用了头级门控,实现了一个简单但有效的数据相关衰减,同时保持了高效的硬件实现
  • HGRN (2023)
    • 在线性 RNN 中引入了门控循环
    • HGRN2 (2024) 进一步在 HGRN 框架中增加了状态扩展
    • 状态扩展等效于线性注意力中的键值外积
  • Finch (2024)
    • 在 Eagle 上使用了数据相关的门控
    • 由于 Eagle 与其他正交修改的保持机制相似,Finch 也显示出与上述模型的深厚联系
  • In summary,数据相关衰减模型通过基于内容的门控来增强线性注意力或 RNN 风格的架构,这些门控控制着信息的流动
    • 论文中的结果表明,这些模型在语言任务上通常能够匹配或超越 Transformer 的性能 ,同时能够扩展到非常长的输入

Linear Attention as In-Context Learners

  • 除了线性注意力机制带来的效率提升外,一项重大进展在于它们应用于增强上下文学习能力
    • 这指的是模型能够从给定的 Prompt 中快速适应或学习,而无需对其预训练权重进行显式的梯度更新
    • 理解:这里是说注意力机制本身是上下文学习期,即增加一些 Prompt Token,就可以借助修改注意力内容来实现对模型的快速学习
  • 大型 Transformer 模型通过将 Prompt 解释为一种训练数据的形式,固有地表现出上下文学习
    • 但最近的创新已将快速学习规则直接整合到注意力机制中,有效地将序列处理视为一个在线训练过程
    • 问题:这里的方法具体是什么?
  • FWP (2021) 建立了现有线性注意力机制与快速权重编程器之间的形式等价性
    • 在 FWP 范式中,一个慢速神经网络学习去编程另一个神经网络的“快速权重”,通常通过自发明的键和值模式的外积加法来实现
  • 表 1:
    • 不同线性注意力变体的更新规则
    • 每个模型都是关于矩阵记忆 \(S_t\) 的循环
Learning Objective
  • 从元学习的角度来看,这些模型定义了一个在推理过程中优化的隐式学习目标
  • 用 \(q_t, k_t, \nu_t\) 表示时间步 t 的查询、键和值,上下文记忆 \(S_t\) 通过以下目标进行优化:
    $$
    \mathcal{L}_t(S) = \frac{1}{2} || f_S(k_t) - \nu_t ||^2 \tag{6}
    $$
  • DeltaNet
    • DeltaNet 融合了经典的 delta 规则 (2021),其中 \(f_S(k_t) = S k_t\)
    • DeltaNet 更新规则为
      $$ S_t = S_{t-1} + \eta_t (\nu_t - S_{t-1} k_t) k_t^{\top} $$
      • 可以通过最小化当前记忆检索 \(S_{t-1} k_t\) 与新值 \(\nu_t\) 之间的误差推导出来
      • 这标志着朝着在线学习键值映射、基于即时上下文改进记忆迈出了一步
  • TTT (2024)
    • TTT 用不同的建模架构概括了元学习目标:
      $$
      f_S(k_t) = \begin{cases}
      \text{LM}(S k_t) + k_t, & \text{TTT-Linear} \\
      \text{LM}(\text{MLP}_S(k_t)) + k_t, & \text{TTT-MLP}
      \end{cases} \tag{7}
      $$
    • 上下文网络 \(f_S\) 增强了上下文元学习的能力
    • 但由于 \(f_S\) 的梯度比简单的线性投影复杂得多,在线更新不能写成一个简单的规则
  • 批量更新 (Batch Update)
    • 批量更新试图解决当 \(f_S\) 作为神经网络工作时训练并行性的困难
    • 通常,上下文记忆是以批大小为 1 进行元学习的,这对于一般的 TTT 模型来说不可行
    • 相反,类似于分块并行,TTT 将整个块视为一个批次
    • 在批次内没有状态更新(即 \(S\) 保持不变)
    • 处理完批次后,\(S\) 使用来自批次中所有样本的聚合梯度或更新信号进行一次更新
    • 这种策略在保持并行效率的同时,适应了更复杂架构的训练要求
  • Momentum Titans (2025)
    • Titans 中引入了优化中常用的动量,以加强记忆更新机制的能力:
      $$
      \mathcal{M}_t = (1 - \alpha_t) \mathcal{M}_{t-1} + S_t \\
      o_t = q_t \mathcal{M}_t \tag{8}
      $$
    • 动量项允许记忆通过对状态 \(S\) 进行指数移动平均来逐渐积累信息
    • 这可以看作是元学习的一种形式,其中更新规则本身学会了在长序列上更加稳定和鲁棒
  • 权重衰减 (Weight Decay)
    • 权重衰减是训练中的另一种正则化技术,对应于线性注意力模型中的遗忘机制
    • Gated DeltaNet (2024) 和 Titans 在其记忆更新中使用了权重衰减,作为学习的遗忘门来限制非常旧或嘈杂数据的影响
    • 它对应于在 RetNet (2023) 和 Mamba (2023) 等架构中发现的选择性状态保持机制,其中衰减机制被证明对语言建模性能至关重要:
      $$
      S_n = \gamma_n S_{t-1} + \eta_t (v_t - S_{t-1} k_t) k_t^{\top} \tag{9}
      $$
  • In summary,线性注意力机制中的这些进步通过明确地将元学习原则整合到其架构中,正在推动上下文学习的边界
    • 通过快速权重更新、复杂的内存管理技术和在线学习规则,这些模型正朝着一种范式发展,在这种范式中,训练和推理之间的区别变得越来越模糊,从而产生了能够直接从上下文中学习和利用知识的更高效、适应性更强的大语言模型

Discussion on Other Designs

Element-wise Linear Attention
  • 无需注意力的 Transformer (Attention-Free Transformer) (2021)
    • 利用一个简单的权重 \(\exp(K_{\nu’} + w_{t,\nu’})\) 来代替 \(\exp(QK^{\top})\):
      $$
      O_t = \sigma_q(Q_t) \odot \frac{\sum_{\nu’=1}^{t} \exp(K_{\nu’} + w_{t,\nu’}) \odot V_{\nu’} }{\sum_{\nu’=1}^{t} \exp(K_{\nu’} + w_{t,\nu’})} \tag{10}
      $$
    • 其中 \(w_{t,\nu’}\) 是学习的成对位置偏置
    • 在 AFT 的变体中,AFT-Simple 移除了 \(w_{t,t’}\),实现了线性化的推理模式
    • 由于 \(K\) 和 \(V\) 的乘积是 Element-wise 的,循环状态大小是 \(\mathbb{R}^d\) 而不是外积状态 \(\mathbb{R}^{d \times d}\)
  • RWKV (2023)
    • 在 AFT-Simple 上利用了衰减机制。具体来说,RWKV 通过指数衰减 \(w_{t,i} = -(t-i)w\) 改进了 AFT 的位置偏置
    • 指数形式在引入位置偏置的同时保留了循环属性
  • Element-wise 线性注意力带来了强大的推理优势,但它受到状态大小瓶颈的影响,性能低于基于矩阵的状态大小
    • Besides,尽管 Element-wise 内存比外积内存快得多,但由于其他组件在拥有外积内存的情况下占据了超过 95% 的延迟 (2023),端到端的优势仍然有限
Multi-Pass Linear Attention
  • 带有限内存控制的注意力 (Attention with Bounded-memory Control)
    • 将线性注意力视为一个有界内存模块:
      $$
      \begin{split}
      \tilde{K}_n &= \sum_{i=1}^{n} K_i \otimes \phi_i, \quad \tilde{V}_n = \sum_{i=1}^{n} V_i \otimes \phi_i \\
      O_n &= \text{softmax}(Q_n \tilde{K}_n^{\top}) \tilde{V}_n
      \end{split} \tag{11}
      $$
    • 其中 \(\tilde{K}_n, \tilde{V}_n\) 是在线更新的、大小受限的键和值。在实现中,ABC 可以简化为两遍线性注意力
  • 门控 Slot 注意力 (Gated Slot Attention) (2024)
    • 进一步将 GLA 引入到 ABC 框架 (2021) 中
    • 由于 \(\tilde{K}_n, \tilde{V}_n\) 作为一个隐式的线性注意力工作,GSA 将更新改进为门控形式:
      $$
      \tilde{K}_n = \text{Diag}(\alpha_n) \tilde{K}_{n-1} + (1 - \alpha_n) \otimes K_n, \quad \tilde{V}_n = \text{Diag}(\alpha_n) \tilde{V}_{n-1} + (1 - \alpha_n) \otimes V_n \tag{12}
      $$
  • Multi-Pass 是增强线性注意力表达能力的一种有效方式,但它也带来了额外的计算开销,这使得架构设计需要在训练效率和性能之间进行权衡
Bidirectional Linear Attention
  • 双向注意力在编码器风格的架构(如 BERT (2019))中扮演着重要角色
  • 单向和双向注意力在线性表述中的关键区别在于推理瓶颈和计算模式
  • Decoder-only 模型通常表现出 \(O(N^2)\) 的复杂度,且 Decoder-only 模型中的每个 Token 都可以访问全局信息
    • 因此双向线性注意力通常维护一个恒定长度的全局 Token 池以降低复杂度,同时保留 softmax 函数的使用
  • 例如
    • Linformer (2020) 通过一个额外的矩阵投影将键和值的数量减少到一个恒定长度
    • Luna (2021) 通过跨模型层编码全局 Token 池进一步扩展了 Linformer 设计
  • 虽然双向线性注意力对 Decoder-only 架构有效,但这些设计在应用于因果设置时面临重大挑战,因为基于全局池的方法往往计算成本高昂
    • 因此,此类架构不太适合大语言模型

Hardware Implementation

  • 并行表示 (The Parallel Representation)
    • 可以将带门控衰减的因果线性注意力定义为:
      $$
      \begin{split}
      Q &= \phi(XW_Q), \quad K = \phi(XW_K), \quad V = XW_V, \quad \gamma = f_\gamma(X) \\
      D_{nm} &= \begin{cases}
      \prod_{i=m+1}^{n} \gamma_i, & n \geq m \\
      0, & n < m
      \end{cases}, \quad O(X) = \text{LN}((Q K^{\top} \odot D) V)
      \end{split} \tag{13}
      $$
    • 其中
      • \(W_Q, W_K, W_V \in \mathbb{R}^{d \times d}\)
      • \(f_\gamma\) 控制衰减的锐度
      • 矩阵 \(D \in \mathbb{R}^{N \times N}\) 编码了具有衰减模式的因果掩码,确保信息的单向流动
    • 当衰减是数据无关时,\(f_\gamma(\cdot) = \text{const} \in (0,1]\)
      • 注意,线性注意力后的分组归一化 (GroupNorm) (2018) 已经是一个强制性的组件 (2023),方程 2 中核化线性注意力的显式除数变得不必要
    • 并行表示简单易懂,但有两个缺点
      • 第一:并行形式仍然保持了与 softmax 注意力相同的 \(O(N^2)\) 复杂度
      • 第二:当表示第 2.3 节中的 ICL 风格线性注意力时,其复杂性会增加
  • 循环表示 (The Recurrent Representation)
    • 上述并行公式可以等效地表达为逐步解码的循环形式,如图 2(b) 所示,在每个时间步 \(n\),输出计算为:
      $$
      S_n = f_{\text{update} }(S_{n-1}, K_n, V_n), \quad O_n = Q_n S_n \\
      f_{\text{update} }(S_{n-1}, K_n, V_n) = \begin{cases}
      \gamma_n S_{n-1} + K_n^{\top} V_n, & \text{Linear Attention with Decay} \\
      \gamma_n S_{n-1} + \eta_n (V_n - S_{n-1} K_n) K_n^{\top}, & \text{ICL-style Linear Attention}
      \end{cases} \tag{14}
      $$
    • 这种循环表示通过维护单个状态向量 \(S_n\),实现了具有恒定内存的高效自回归生成
    • 虽然循环表示将计算复杂度从 \(O(N^2)\) 降低到 \(O(N)\),但它在训练期间会产生大量的内存开销
      • 因为 \(S_n\) 涉及存储 \(K_n\) 和 \(V_n\) 的外积,这对于长序列来说是极其昂贵的
      • 因此,循环形式通常仅限于解码阶段
  • 分块循环表示 (The Chunkwise Recurrent Representation)
    • 分块表示结合了线性复杂度和硬件友好的并行性的优势 (2022; 2023)
    • 如图 2(a) 所示,以衰减风格线性注意力为例,给定一个块大小 \(B\),令 \(x_{[i]}\) 表示第 \(i\) 块。定义块内的累积衰减为:
      $$
      \beta_{(i-1)B+j} = \prod_{k=(i-1)B+1}^{(i-1)B+j} \gamma_k, \quad D_{[i]}(j,k) = \begin{cases}
      \frac{\beta_{(i-1)B+k} }{\beta_{(i-1)B+j} }, & j \leq k \\
      0, & \text{Other wise}
      \end{cases} \tag{15}
      $$
    • 块级记忆状态 \(R_i\) 计算为:
      $$
      R_i = K_{[i]}^{\top} (V_{[i]} \odot \frac{\beta_{iB} }{\beta_{[i]} }) + \beta_{iB} R_{i-1} \tag{16}
      $$
    • 第 \(i\) 块的输出为:
      $$
      O_{[i]} = (Q_{[i]} K_{[i]}^{\top} \odot D_{[i]}) V_{[i]} + (Q_{[i]} R_{i-1}) \odot \beta_{[i]} \tag{17}
      $$
    • 这种表述提供了循环和并行的统一视图:
      • 第一项捕获块内依赖关系
      • 第二项通过单个矩阵-向量乘积传播块间记忆
      • 由于其效率和可并行性,分块表示通常在训练和预填充阶段被采用
  • 对于 ICL 风格的线性注意力,已经使用 Householder 变换 (2024) 开发了硬件友好的分块表示
    • 但对于更复杂的变体,如 TTT 和 Titans,构建显式的分块形式仍然具有挑战性
    • 相反,这些架构通常依赖于大的批次大小进行记忆更新,通过固定超参数有效地模拟分块计算
  • 内核级优化对于实现高性能至关重要
    • 广泛采用的 FLA (2024) 为许多常见的线性注意力模块提供了基于 Triton 的实现
    • 或者,开发者还提供了用 CUDA 或 TileLang (2024) 编写的自定义实现,可以用于进一步加速
  • 图 2:线性注意力的双重形式

Sparse Attention

  • 稀疏注意力(Sparse Attention)方法利用注意力计算中固有的稀疏特性,通过以下公式近似完整注意力计算:
    $$
    \text{Attn}(Q,K,V) = \text{softmax}(QK_{[\mathcal{S}]}^T)V_{[\mathcal{S}]} \tag{18}
    $$
    • 其中 \(\mathcal{S}(t)\) 是查询向量 \(Q(t)\) 所关注索引的子集
    • 不同的方法基于选择准确性和硬件效率的考量,设计了不同的选择标准 \(\mathcal{S}(t)\)
    • 目标是在预填充(prefilling)阶段实现亚线性或线性复杂度,或在解码(decoding)阶段控制固定的计算预算

Fixed-pattern Sparse Attention (固定模式稀疏注意力)

  • 一些研究工作利用 Token-level 别稀疏性的结构化模式,为注意力计算构建固定的稀疏掩码
  • Local Window Attention (局部窗口注意力)
    • 局部窗口注意力将每个查询限制在仅与固定滑动窗口 \(w\) 内的相邻 Token 交互,从而在保留局部上下文的同时降低内存和计算需求
    • Sparse Transformer (2019) 首先应用局部窗口(行)注意力,其中 \(w\) 接近 \(\sqrt{N}\),然后通过额外的列注意力来总结先前位置并在全局传播信息
    • GPT-3 (2020) 也采用了与 Sparse Transformer 类似的稀疏注意力模式
    • StreamingLLM (2023) 发现大量的注意力得分被分配给了输入序列的初始 Token ,他们称之为“注意力汇集点(attention sink)”。他们提出了一种简单的固定模式注意力,只保留汇集点 Token 和滑动窗口内的 Token 。例如,给定一个长度为 \(n\) 的输入序列,StreamingLLM 中查询 Token \(q_t\) 的选定 Token 子集 \(\mathcal{S}(t)\) 被定义为:
      $$
      \mathcal{S}(t) = \left\{ j \mid 0 \leq j \leq s \ \lor \ t-w \leq j \leq t \right\}, \ \forall t \in [1, n] \tag{19}
      $$
      • 其中 \(s\) 是汇集点 Token 的大小,\(w\) 是滑动窗口的大小
      • 为了获得更好的硬件效率,采用块粒度的 StreamingLLM (2024) 以块方式保留汇集点 Token 和局部 Token ,从而实现高效的内存加载和计算
  • Dilated Attention (膨胀注意力)
    • LongNet (2023)
      • 引入了膨胀注意力作为长上下文训练和推理的固定稀疏模式。膨胀注意力随着距离的增长呈指数级扩展注意力范围,从而将注意力的复杂度从 \(O(n^2)\) 降低到 \(O(n)\)。具体来说,在沿序列维度将输入分割成长度为 \(w\) 的片段后,从每个片段中以间隔 \(r\) 选择膨胀稀疏索引。第 \(i\) 个片段的选择索引为:
        $$
        \hat{I}_i = \left[iw, iw+r, iw+2r, …, (i+1)w-1 \right] \tag{20}
        $$
      • 稀疏化的片段 \(Q_{\hat{I}_i}, K_{\hat{I}_i}, V_{\hat{I}_i},\ i \in \{0, 1, …, \frac{n}{w}\}\) 被并行输入到注意力计算中,得到注意力输出 \(O\)。结合不同片段大小和膨胀率 \(\{r_i, w_i\}^{k}\) 的注意力输出,最终注意力计算如下:
        $$
        O = \sum_{i=1}^{K} \alpha_i O|_{r_i, w_i}, \quad \alpha_i = \frac{s_i}{\sum_{j} s_j} \tag{21}
        $$
      • 其中 \(s_i\) 表示 \(O|_{r_i, w_i}\) 的注意力 softmax 的分母
  • LogSparse (2019)
    • 采用指数稀疏注意力方案,其中每个位置仅关注 \(\log N\) 个 Token ,这可以看作是指数膨胀注意力的一个实例

Block Sparse Attention (块稀疏注意力)

  • 给定一个长度为 \(n\)、块大小为 \(b\) 的输入序列,可以将 \(Q, K, V \in \mathcal{R}^{n \times d}\) 各自划分为 \(\frac{n}{b}\) 个块,每个块大小为 \(b \times d\)
  • 目标是近似一个块级掩码 \(M \in \{0, 1\}^{n/b \times n/b}\),用于选择关键块进行计算,如图 3 所示
    $$
    \text{Attn}(Q, K, V)_i = \sum_{j=1}^{n/b} M_{ij} \cdot \text{softmax}(Q_i K_j^T) V_j \tag{22}
    $$
  • 块级选择对于在现代 GPU 上实现高效计算至关重要
  • 图 3: 块稀疏注意力:长序列被分成若干块,每个 Token 仅关注其局部窗口和 top-k 相关块
Block-Sparse Attention for Prefill (用于预填充的块稀疏注意力)
  • 使用块稀疏注意力进行预填充的方法,近似选择覆盖大部分注意力得分且具有高召回率的 Top-K 块,从而将注意力计算复杂度从 \(O(n^2)\) 降低到 \(O(K)\)
    $$
    S = \text{softmax}(QK^T - c(1 - M)) \\
    \min \ |S(M) - S_{\text{dense} }|
    $$
    • 其中 \(M\) 是如上定义的块级稀疏掩码,\(c\) 是一个大常数(例如 1e5),确保重要性较低的注意力权重在 softmax 计算后趋近于零
    • 块稀疏注意力的目标是在最小开销下实现更大的加速,同时尽可能保留更多的注意力权重
  • MInference (2024)
    • 观察到注意力权重中存在三种模式:
      • 流式(A 型)模式(Streaming (A Shape) Pattern)
      • 垂直斜线模式(Vertical-Slash Pattern)
      • 块稀疏模式(Block-Sparse Pattern)
    • 它离线确定每个注意力头的最佳模式,并在推理过程中基于指定的模式动态构建稀疏索引
  • FlexPrefill (2025)
    & 提出了一种上下文感知的稀疏注意力机制,能够实时动态调整注意力模式和计算预算
  • XAttention (2025)
    • 提出了一个块稀疏注意力框架,利用反对角线(antidiagonal)评分来预测注意力块的重要性,从而能够高效识别并剪除非必要块,实现高稀疏性和显著的计算增益
  • SpargeAttn (2025)
    • 同样在预填充阶段采用块级稀疏注意力,通过一个双阶段在线过滤过程完成:
      • 第一阶段快速预测注意力图以跳过某些矩阵乘法
      • 第二阶段应用 softmax 感知的过滤器以进一步消除不必要的计算
Block-Sparse Attention for Decode (用于解码的块稀疏注意力)
  • 用于解码的块稀疏注意力方法动态选择包含每个解码步骤最关键的 Token 的子集 \(S\) 的 \(K, V\) 向量,从而减少内存加载并提高效率
  • Quest (2024) 通过计算注意力权重的上界来近似每个块的关键性。对于块 \(K_i\),论文通过以下公式维护 Element-wise 的 Min 和 Max Key \(m_i\) 和 \(M_i\):
    $$
    m_{i,d} = \min(K_{i,d}), \quad M_{i,d} = \max(K_{i,d})
    $$
    • 其中 \(\min(\cdot)\) 和 \(\max(\cdot)\) 在每个维度 \(d\) 上 Element-wise 应用
  • 给定查询 \(q\),块 \(i\) 的近似注意力得分由下式给出:
    $$
    \text{score}_i = \sum_{j=1}^{d} \max(q_j \times M_{i,j}, q_j \times m_{i,j})
    $$
  • 然后选择得分最高的 Top-K 块作为注意力计算的稀疏子集 \(S\):
    $$
    S = \text{argtopk}(\text{score}, k)
    $$
  • DoubleSparsity (2024) 通过降低计算 \(QK^T\) 乘积的矩阵乘法维度来高效近似关键 Token
    • 它首先离线计算 \(QK^T\) 中的离群通道,记为 \(C\)
    • 然后选择具有最高近似注意力得分 \(\hat{s}\) 的 Top-K Token 作为稀疏子集 \(S\):
      $$
      Q_{\text{label} } = Q_{\{C\} }, \quad \hat{s} = Q_{\text{label} } K_{\text{label} }^T, \quad S = \text{argtopk}(\hat{s}, k)
      $$
  • ReSA (2025)
    • 结合了无需训练的块稀疏估计和 GQA 共享,有助于提高效率
    • ReSA 还提出了一个修正阶段来控制 KV 缓存累积误差
    • ReSA 在长序列生成任务上显示出优势
Routing-based Block-Sparse Attention (基于路由的块稀疏注意力)
  • 基于路由的块稀疏注意力通过可训练的 MLP 层学习每个 Token 块的重要性,该层在推理期间充当门控网络以选择关键块
  • Learnable Sparsity on Pretrained Models (预训练模型上的可学习稀疏性)
  • SeerAttention (2024)
    • 通过自蒸馏(self-distillation)的方式在预训练的 LLM 上训练门控网络
    • 为了获得每个块的重要性分数,它首先沿序列维度对 \(Q\) 和 \(K\) 进行池化,记为 \(P_q\) 和 \(P_k\)。下采样后的 \(Q, K\) 然后通过一个可学习的线性层 \(W_q\) 和 \(W_k\)
    • 投影后的 \(W_q P_q(Q)\) 和 \(W_k P_k(K)\) 的矩阵相乘结果通过 softmax 运算符作为门控过程:
      $$
      \text{score} = \text{softmax}((W_q P_q(Q)) \cdot (W_k P_k(K)))
      $$
    • 可学习的线性层通过自蒸馏的方式训练,以与原始 LLM 的 2D 最大池化结果对齐。蒸馏损失计算如下:
      $$
      gt = \text{MaxPool2D}(\text{softmax}(QK^T)), \quad loss = D_{KL}(gt \ || \ \text{score})
      $$
    • 在推理过程中,门控分数通过 Top-K 或阈值化来预测块级稀疏性,用于稀疏计算和效率提升
  • Training-aware Sparse Attention (训练感知的稀疏注意力)
    • Landmark (2023)
      • 提出使用特殊的“地标”(landmark) Token 来表示每个块,并训练注意力机制通过这些地标 Token 直接检索 Top-K 块
      • 然而,它没有在大型预训练模型上进行实验
    • MoBA (2025)
      • 将可训练的稀疏注意力集成到预训练阶段
      • 它提出了 Mixture of Block Attention,应用来自 MoE(专家混合)的 Top-K 机制作为门控机制,为每个查询 Token 决定关键块
      • 每个块的重要性分数通过查询 Token \(q\) 和块 \(K_i\) 沿 Token 维度的平均池化结果的内积计算:
        $$
        s_i = \langle q, P_{\text{mean} }(K_i) \rangle
        $$
      • 然后选择具有最高 \(s\) 分数的 Top-K 块用于计算 \(q\) 的注意力
      • Notably,MoBA 使用的 Top-K 块选择是不可微分的
        • 因此,在预训练阶段,稀疏模式仍然以无需训练的模式进行估计,从而实现高效的推理和加速的训练
    • NSA (2025)
      • 引入了一种训练感知的多粒度稀疏注意力机制,包含三个分支 \(C \in \{\text{cmp}, \text{slc}, \text{win}\}\),分别对应压缩(compression)、选择(selection)和滑动窗口(sliding window)策略
      • NSA 利用一个可微分的压缩分支来学习块选择分数。结合三个分支,NSA 的注意力输出由下式给出:
        $$
        o = \sum_{c \in \mathcal{C} } g^c \cdot \text{Attn}(q, K^c, V^c), \quad g_c \in [0, 1]
        $$
      • 对于压缩分支 \(c = \text{cmp}\),块 \(i\) 的键 \(K_i \in \mathcal{R}^{d_i \times b}\) 通过一个可学习的 MLP 层 \(\varphi\) 被压缩为单个键 \(K_i^{\text{cmp} } \in \mathcal{R}^{d_i \times 1}\)
      • 对于选择分支 \(c = \text{slc}\),基于块重要性分数 \(p\) 选择 Top-K 块,该分数可以直接从压缩分支获得
    • InfLLM-v2 (2025)
      • 采用了与 MoBA 类似的训练感知 Top-K 块稀疏注意力机制
      • 为了提高 Top-K 块选择的准确性,它将块划分为具有重叠的小粒度内核,并在每个块内对内核重要性分数进行聚合
System-level Design Choices (系统级设计选择)
  • 训练感知的稀疏注意力 (2025; 2025; 2025) 开始考虑内核实现和高效执行
  • 为了实现块稀疏注意力的高效实现,FlashAttention (2023) 被用于在高效的平铺机制中进行注意力计算,这为更好地利用硬件资源带来了要求和机会,包括:
    • 为避免内存访问不一致,在 SeerAttention (2024) 和 MInference (2024) 中,块大小 \(b\) 通常设置为至少 64 的相对较大值
    • 为了与 GPU 张量核心上分组矩阵乘法指令(Grouped Matrix Multiplication)的最低要求对齐,在 NSA (2025) 和 InfLLM-v2 (2025) 中,一个查询组内的 K、V 头数被设置为至少 16
    • 为了减少内存访问,NSA (2025) 和 InfLLM-v2 (2025) 强制查询组之间共享选定的块,这是通过在查询组内对块级重要性分数进行池化来完成的

Clustering Attention (聚类注意力)

  • 与块稀疏注意力类似,聚类注意力旨在为解码选择最关键地 Token ,但将 Token 组织在数据结构中以获得更好的语义属性或采样效率
  • RetrievalAttention (2024)
    • 采用近似最近邻搜索(Approximate Nearest Neighbor Search, ANNS)来选择关键的 K 个聚类
    • 为了解决注意力机制中查询向量和键向量之间的分布外(out-of-distribution)性质带来的挑战,它引入了一种适应查询向量分布的注意力感知向量搜索算法
  • ClusterKV (2024)
    • 在语义聚类的粒度上选择 Token ,克服了诸如 Quest 等页面级检索方法内部碎片化的问题
    • 在预填充阶段之后, Token 通过 K-means 算法进行聚类
    • Token \(i\) 和 \(j\) 之间的语义相似性通过键向量的余弦相似性度量:\(\mathcal{D}(i, j) = 1 - \frac{\langle k_i, k_j \rangle}{|k_i| \cdot |k_j|}\)
    • 语义聚类由其质心 \(\mu_1, \mu_2, …, \mu_C \in \mathcal{R}^d\) 表示
    • 在每个解码步骤,基于查询 Token \(q\) 和质心 \(\mu_i\) 的注意力权重(即 \(q \mu_i^T\))的排名来选择聚类
  • MagicPIG (2024)
    • 利用局部敏感哈希(Locality Sensitive Hashing, LSH)采样来高效近似注意力计算
    • 它使用 LSH 将相似的查询向量和键向量映射到相同的哈希桶中,并将存储和部分计算卸载到 CPU,以解决 KV 缓存的瓶颈
    • 它还引入了 Oracle Top-K 采样作为比暴力 Top-K 更好的策略

Bidirectional Sparse Attention (双向稀疏注意力)

  • 双向稀疏注意力建立在编码器风格的架构上,使用静态模式或块级稀疏性来加速注意力计算
  • 块稀疏性在双向稀疏注意力中被广泛使用
    • BigBird (2020)
      • 使用块级随机注意力,作为缩短 Token 之间间接路径的桥梁
    • Longformer (2020)
      • 使用静态的全局-局部混合注意力
      • Longformer 也依赖于块级稀疏性,并带有额外的全局和随机链接,以促进结构化计算和内存高效的并行性
  • 基于聚类的方法也用于双向稀疏注意力
    • Reformer (2020)
      • 使用局部敏感哈希(LSH)将相似的 Token 分配到同一个桶中
    • Routing Transformer (2020)
      • 在每一层执行在线 \(k\)-means 聚类
    • ClusterFormer (2022)
      • 引入了一个与下游目标共同训练的可微分聚类模块
      • 这些方法通过分组相关 Token 来减少计算,同时通过学习的适应性来保持性能

采用高效注意力机制的预训练大语言模型

采用统一高效注意力机制的预训练模型

  • 尽管早期的线性注意力探索通常局限于小规模模型,但近期的进展已证明其能够成功扩展到数十亿参数的规模,使其成为标准 Transformer 的一种可行且高效的替代方案
  • 这些模型完全基于线性注意力或其架构等价物,如状态空间模型和循环神经网络,即使在大规模下也能保持其标志性的推理效率
  • 基于 RWKV 的模型
    • RWKV 项目代表了一项持续且有影响力的努力,旨在创建一个可扩展的循环神经网络架构,它结合了 Transformer 的可并行化训练与传统 RNN 的高效推理 (2023)
    • 例如,EAGLE 系列引入了矩阵值状态以增加容量,而后续迭代如 Finch (RWKV-6) (2024) 和 Goose (RWKV-7) (2025) 则结合了动态循环和更具表达力的状态演化机制,以支持更复杂、数据依赖的状态转换
  • 基于 Mamba 的模型
    • Mamba 架构的成功及其数据依赖的选择机制,已引发了主要研究实验室的一波采用和扩展浪潮
    • Falcon Mamba (2024)
      • 基于纯 Mamba 架构,在一系列通用语言基准测试中展现出与领先 Transformer 模型相竞争的性能,验证了该架构在此类任务上的可行性,同时保持了其标志性的恒定时间推理
    • Codestral Mamba (2024)
      • 基于 Mamba-2 架构,进一步证明了该范式的潜力
      • 虽然专门用于代码生成,但它在相关基准测试中取得了 SOTA 结果,并支持 256K 个 Token 的上下文长度,展示了 SSM 方法在复杂结构化领域内的可扩展性和有效性
  • 基于稀疏注意力的模型
    • MiniCPM-4 (2025) 引入了一种两阶段稀疏注意力机制,根据语义相似性为每个查询 Token 动态选择相关的键值块
    • MiniCPM-4 利用 InfLLM-v2(一种块稀疏注意力变体)来替代标准注意力机制
    • 此外,一种轻量级的 LogSumExp 近似实现了高效的 top-k 选择,使得该方法能够扩展到极长序列
    • 这些技术共同使 MiniCPM-4 能够在细粒度上下文感知能力与可控的内存和计算需求之间取得平衡,使其成为长上下文建模的有力候选者

采用混合高效注意力机制的预训练模型

  • 随着对高效长上下文建模和多样化计算范式需求的增长,最近的研究广泛探索了混合注意力机制
  • 此类策略结合了全局和局部注意力组件,通常交错使用专门设计的层,以平衡计算成本和性能
  • 候选模型架构示意图如图 4 所示
  • 稀疏混合模型
    • GPT-3 (2020) 通过交错使用稠密注意力和局部带状稀疏注意力层,集成了一种混合注意力机制,其灵感来自 Sparse Transformer (2019)
    • 稠密注意力提供全上下文建模,而稀疏层则采用固定或跨步模式来减少被关注的 Token 数量
    • 这种设计使得 GPT-3 能够在 2048 个 Token 的固定上下文窗口内高效扩展到大规模模型,平衡了建模能力和计算效率
  • 线性-全注意力混合模型
    • Jamba (2024) 和 MiniMax-01 (2025) 结合了线性注意力和全注意力层,以在吞吐量和表达力之间实现高效的权衡
    • MiniMax-01 在大多数层中使用 Lightning Attention,并每八层插入一次基于 Softmax 的全注意力
    • Jamba 采用了相似的比例,在每八层的 Mamba 块中插入一个 Transformer 层
    • 两者都通过限制计算密集的全注意力的使用,实现了更快的解码和改善的长序列性能
  • 局部-全注意力混合模型
    • Gemma 3 (2025)、Command A (2025) 和 LLaMA-4-Maverick (2025) 在局部和全局注意力层之间交替使用,其共享的设计理念是稀疏地使用全局层(例如,每 4-6 层一次)以提升效率
    • 虽然局部层采用滑动窗口模式,但关键区别在于位置编码策略
    • Gemma 3 调制 RoPE 的基础频率——为局部层分配 10K,为全局层分配 1M(以更好地捕获长距离依赖关系)
    • Command A 和 LLaMA-4-Maverick 混合了基于 RoPE 的局部层与完全省略位置嵌入的全注意力层,从而实现了更强的长序列性能
  • 先进混合模型
    • Character.AI (2024) 将滑动窗口的局部注意力与每六层应用一次的稀疏全局注意力层交错使用
    • 特别是,它们在多个非相邻层之间复用全局注意力层的键值表示
    • 这种 KV 共享机制能以减少的内存和延迟开销实现高效的长上下文处理
  • YOCO (2024) 和 Phi-4-mini-flash (2025) 采用了一种双解码器架构,将预填充阶段和生成阶段分离开来
    • 自解码器在预填充和生成中都使用 RetNet 和滑动窗口注意力等线性注意力机制,而交叉解码器仅在生成期间激活
    • 整个过程中使用单层全局 KV 缓存,实现了线性时间的预填充和高效的解码,同时 GPU 内存消耗最小
  • In summary,这些最新进展强调了混合注意力机制以在不同计算约束和序列长度下实现平衡性能的趋势。每种架构都独特地贡献了关于如何有效结合局部细节管理与全局上下文整合的见解,从而为未来注意力机制的发展提供了有价值的框架

Outlook

  • 本综述全面概述了高效注意力机制,重点关注其算法基础、实际实现以及在大规模预训练语言模型中的集成
  • 通过将线性和稀疏注意力分类为明确定义的范式,论文识别了实现可扩展性、计算效率和长上下文能力的关键设计原则
  • 论文还分析了这些机制在最先进模型中如何部署,无论是作为独立架构还是作为平衡局部和全局计算的混合设计的一部分
  • 展望未来,论文强调几个预计将塑造该领域未来研究的关键方向:
  • 对混合模型的架构理解
    • 虽然先前关于线性注意力的工作主要集中于独立的线性架构,但混合模型通常通过结合现成的线性注意力模块与稠密或局部组件来构建
      • 但更强的线性主干是否直接转化为改进的混合性能尚不清楚
    • 未来的工作应将混合模型作为一个独特的架构类别来研究,试图理解它们的组成、相互作用效应和优化动态
  • 无损稀疏注意力与扩展上下文
    • 稀疏注意力仍然受到精度和计算增益之间权衡的挑战。完全训练的稀疏模型通常性能不如稠密模型,而训练后稀疏近似则由于缺乏端到端训练而面临限制
    • 一个主要的研究前沿在于开发能够保持稠密注意力的表达力和精度,同时扩展到更长上下文的稀疏注意力机制
    • Besides,稀疏预算与上下文长度之间的关系尚不明确,固定的 top-k 方案在序列更长时可能会退化,这需要更具适应性的策略
  • 对稀疏和混合注意力的机制性洞见
    • 实证研究反复证明,混合注意力模型可以用更少的注意力计算匹配甚至超越稠密模型,但其有效性的根本原因仍未得到充分探索
    • 此外,研究在合成基准测试中表现良好的稀疏模式是否适用于现实世界任务,以及描述基于稀疏性的泛化极限,尤为重要
  • 随着基于注意力的模型不断发展,论文预计架构创新、理论洞见和硬件感知设计之间将进一步融合
  • 作者希望本综述能为未来高效、高性能语言建模系统的研究奠定坚实的基础

NLP——旋转位置编码-RoPE

  • 参考链接:
    • 原始论文:(RoPE) RoFormer: Enhanced Transformer with Rotary Position Embedding, Arxiv 2023 & Neurocomputing 2024, 追一科技
    • 苏神博客:Transformer升级之路:2、博采众长的旋转式位置编码
    • 旋转式位置编码 (RoPE) 知识总结 - Soaring的文章 - 知乎,一篇把知识串的比较好的博客

原始 Transformer

基本 Attention 公式

  • 在标准的Transformer模型中,自注意力机制(Self-Attention)的公式是核心组成部分
  • 给定查询矩阵 \( Q \)、键矩阵 \( K \) 和值矩阵 \( V \),注意力输出计算为:
    $$
    \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k} }\right)V
    $$
    • \( Q \in \mathbb{R}^{n \times d_k} \), \( K \in \mathbb{R}^{m \times d_k} \), \( V \in \mathbb{R}^{m \times d_v} \)(\( n \)是目标序列长度,\( m \)是源序列长度)
    • \( d_k \) 是键/查询向量的维度
    • \( \sqrt{d_k} \) 用于缩放点积,防止梯度消失

Multi-Head Attention

  • Transformer使用多头注意力扩展基本注意力:
    $$
    \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O
    $$
  • 每个头的计算为:
    $$
    \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
    $$
    • \( W_i^Q \in \mathbb{R}^{d_{\text{model} } \times d_k} \), \( W_i^K \in \mathbb{R}^{d_{\text{model} } \times d_k} \), \( W_i^V \in \mathbb{R}^{d_{\text{model} } \times d_v} \)
    • \( W^O \in \mathbb{R}^{hd_v \times d_{\text{model} } } \) 是输出投影矩阵
    • \( h \) 是头的数量,通常满足 \( d_k = d_v = \frac{d_{\text{model} } }{h} \)

加入位置编码(仅修改输入即可)

  • 在Transformer中,输入会加上正弦位置编码 \( P \in \mathbb{R}^{d_{\text{model}}} \):
    $$
    X = \text{Embedding}(x) + P
    $$
    • 其中 \( P \in \mathbb{R}^{d_{\text{model}}} \) 的每个元素为:
      $$
      P_{pos, 2i} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model} } } }\right), \quad
      P_{pos, 2i+1} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model} } } }\right)
      $$

Self-Attention 完整公式(以单头为例)

  • 对于一个输入序列 \( X \in \mathbb{R}^{n \times d_{\text{model} } } \):
    $$
    \begin{aligned}
    Q &= XW^Q, \quad K = XW^K, \quad V = XW^V \\
    \text{Attention}(X) &= \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k} } + M\right)V
    \end{aligned}
    $$
    • \( M \) 是可选的掩码矩阵(如解码器的因果掩码)
    • 注:实际实现时,也可以直接在进入 Softmax 操作前,将 \(\frac{QK^\top}{\sqrt{d_k} }\) 的结果置为最小值 \(-e^9\),效果是等价的

Self-Attention 简单实现

  • Self-Attention的Python代码简单实现
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    # 单头Attention
    class ScaledDotProductAttention(nn.Module):
    def__init__(self, d_k):
    super().__init__()
    self.d_k = d_k

    def forward(self, Q, K, V, mask=None):
    # Q, K, V shape: (batch_size, seq_len, d_k(or d_v))
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
    if mask is not None:
    scores = scores.masked_fill(mask == 0, -1e9)
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)
    return output
    # 多头Attention
    class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
    super().__init__()
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads

    self.W_q = nn.Linear(d_model, d_model)
    self.W_k = nn.Linear(d_model, d_model)
    self.W_v = nn.Linear(d_model, d_model)
    self.W_o = nn.Linear(d_model, d_model)
    # 使用单头Attention
    self.attention = ScaledDotProductAttention(self.d_k)

    def split_heads(self, x):
    """
    x shape: (batch_size, seq_len, d_model)
    return shape: (batch_size, num_heads, seq_len, d_k)
    """
    batch_size, seq_len, _ = x.size()
    return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
    """
    x shape: (batch_size, num_heads, seq_len, d_k)
    return shape: (batch_size, seq_len, d_model)
    """
    batch_size, _, seq_len, _ = x.size()
    return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

    def forward(self, Q, K, V, mask=None):
    Q = self.W_q(Q)
    K = self.W_k(K)
    V = self.W_v(V)

    Q = self.split_heads(Q)
    K = self.split_heads(K)
    V = self.split_heads(V)

    # 如果需要,扩展mask以匹配多头(这里假设了mask是为单头准备的)
    if mask is not None:
    mask = mask.unsqueeze(1) # (batch_size, 1, seq_len) -> (batch_size, 1, 1, seq_len)

    attn_output = self.attention(Q, K, V, mask)
    output = self.combine_heads(attn_output)
    output = self.W_o(output)
    return output

固定位置编码实现

  • 固定位置编码,比如正弦位置编码,直接在 Attention 之前将位置编码向量加入到原始向量 \(X\) 中,Attention代码不需要做任何修改

Rotary Position Embedding, RoPE

  • 本节符号和原始论文 (RoPE) RoFormer: Enhanced Transformer with Rotary Position Embedding, Arxiv 2023 & Neurocomputing 2024, 追一科技 符号保持一致
  • 旋转位置编码(RoPE)的核心思想 :通过旋转矩阵将位置信息融入Self-Attention 机制中
  • 基本定义 :对于位置\( m \)的词向量\( \boldsymbol{x}_m \in \mathbb{R}^d \),通过线性变换得到查询向量\( \boldsymbol{q}_m \)和键向量\( \boldsymbol{k}_n \):
    $$
    \boldsymbol{q}_m = W_q \boldsymbol{x}_m, \quad \boldsymbol{k}_n = W_k \boldsymbol{x}_n
    $$
  • 旋转操作 :将\( \boldsymbol{q}_m \)和\( \boldsymbol{k}_n \)划分为\( d/2 \)个复数对(每组2维,RoPE要求维度必须是偶数,这一般都能满足),对第\( i \)组复数应用旋转矩阵:
    $$
    \begin{aligned}
    \boldsymbol{q}_m^{(i)} &= \begin{pmatrix}
    q_{m,2i} \\
    q_{m,2i+1}
    \end{pmatrix}, \quad
    \boldsymbol{k}_n^{(i)} = \begin{pmatrix}
    k_{n,2i} \\
    k_{n,2i+1}
    \end{pmatrix} \\
    R_{\theta_i}^m &= \begin{pmatrix}
    \cos m\theta_i & -\sin m\theta_i \\
    \sin m\theta_i & \cos m\theta_i
    \end{pmatrix}, \quad \theta_i = 10000^{-2i/d}
    \end{aligned}
    $$
    • 注意:位置为 \(m\) 的旋转矩阵对应正余弦角度为 \(\color{red}{m}\theta_i\)
    • 理解:旋转矩阵 \(R_{\theta_i}^m\) 可以将目标向量进行旋转,\(R_{\theta_i}^m \boldsymbol{x}\) 相当于将 \( \boldsymbol{x}\) 向逆时针方向旋转 \(m\theta_i\) 度(注意:只是旋转,并不修改原始向量的模长,因为 \(R_{\theta_i}^m\) 是正交矩阵),详情见附录
  • 旋转后的向量 :旋转后的查询和键向量为:
    $$
    \begin{aligned}
    \boldsymbol{q}_m’ = \bigoplus_{i=0}^{d/2-1} R_{\theta_i}^m \boldsymbol{q}_m^{(i)}, \quad
    \boldsymbol{k}_n’ = \bigoplus_{i=0}^{d/2-1} R_{\theta_i}^n \boldsymbol{k}_n^{(i)}
    \end{aligned}
    $$
    • 其中\( \oplus \)表示向量拼接:
      $$\bigoplus_{i=0}^{d/2-1} R_{\theta_i}^m \boldsymbol{q}_m^{(i)} = \text{Concat}(\{ R_{\theta_i}^m \boldsymbol{q}_m^{(i)}\}_{i=0}^{d/2-1})$$
  • 旋转后的Attention权重变化
    $$
    \begin{equation}
    (\boldsymbol{\mathcal{R}}_m \boldsymbol{q}_m)^{\top}(\boldsymbol{\mathcal{R}}_n \boldsymbol{k}_n) = \boldsymbol{q}_m^{\top} \boldsymbol{\mathcal{R}}_m^{\top}\boldsymbol{\mathcal{R}}_n \boldsymbol{k}_n = \boldsymbol{q}_m^{\top} \boldsymbol{\mathcal{R}}_{n-m} \boldsymbol{k}_n
    \end{equation}
    $$
    • 位置为 \(m\) 的向量 \(\boldsymbol{q}_m\) 乘以矩阵 \(\boldsymbol{\mathcal{R}}_m\);位置为 \(n\) 的向量 \(\boldsymbol{k}_n\) 乘以矩阵 \(\boldsymbol{\mathcal{R}}_n\)(注意角标)
    • 上面的式子中等式是恒成立的(\( \boldsymbol{\mathcal{R}}_m^{\top}\boldsymbol{\mathcal{R}}_n = \boldsymbol{\mathcal{R}}_{\color{red}{n-m}}\)的详细证明见附录),右边的 \(\boldsymbol{\mathcal{R}}_{\color{red}{n-m}}\)仅与相对位置 \(n-m\) 有关,体现了相对位置编码的核心要义
      • 注:\(\boldsymbol{\mathcal{R}}_{m-n}\) 和 \(\boldsymbol{\mathcal{R}}_{n-m}\) 不相等,旋转角度相同,但方向相反
  • 展开成矩阵相乘的形式为(refer to Transformer升级之路:2、博采众长的旋转式位置编码):
    $$
    \begin{equation}\scriptsize{\underbrace{\begin{pmatrix}
    \cos m\theta_0 & -\sin m\theta_0 & 0 & 0 & \cdots & 0 & 0 \\
    \sin m\theta_0 & \cos m\theta_0 & 0 & 0 & \cdots & 0 & 0 \\
    0 & 0 & \cos m\theta_1 & -\sin m\theta_1 & \cdots & 0 & 0 \\
    0 & 0 & \sin m\theta_1 & \cos m\theta_1 & \cdots & 0 & 0 \\
    \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
    0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2-1} & -\sin m\theta_{d/2-1} \\
    0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2-1} & \cos m\theta_{d/2-1} \\
    \end{pmatrix}}_{\boldsymbol{\mathcal{R}}_m} \begin{pmatrix}q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1}\end{pmatrix}}\end{equation}
    $$
  • 由于旋转矩阵是一个稀疏矩阵,所以旋转过程可以改进为如下等价实现:
    $$
    \begin{equation}\begin{pmatrix}q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1}
    \end{pmatrix}\otimes\begin{pmatrix}\cos m\theta_0 \\ \cos m\theta_0 \\ \cos m\theta_1 \\ \cos m\theta_1 \\ \vdots \\ \cos m\theta_{d/2-1} \\ \cos m\theta_{d/2-1}
    \end{pmatrix} + \begin{pmatrix}-q_1 \\ q_0 \\ -q_3 \\ q_2 \\ \vdots \\ -q_{d-1} \\ q_{d-2}
    \end{pmatrix}\otimes\begin{pmatrix}\sin m\theta_0 \\ \sin m\theta_0 \\ \sin m\theta_1 \\ \sin m\theta_1 \\ \vdots \\ \sin m\theta_{d/2-1} \\ \sin m\theta_{d/2-1}
    \end{pmatrix}\end{equation}
    $$
    • 其中 \(\otimes\) 是按位相乘
  • RoPE下的Attention公式总结 :(旋转位置编码的核心公式)
    $$
    \begin{aligned}
    \text{Attention}(\boldsymbol{x}) &= \text{softmax}\left(\frac{(\boldsymbol{q}’)^\top \boldsymbol{k}’}{\sqrt{d} }\right)V
    \end{aligned}
    $$
    • 注:\(V\) 是 Attention 中的 Value 矩阵,中不需要位置编码信息
    • 这里使用 \((\boldsymbol{q}’)^\top \boldsymbol{k}’\),转置在 \(\boldsymbol{q}’\) 上,和原始论文表达方式一致,实际上这种表示是OK的,数学中常用这种表示 ,这种表示下,向量为列向量;原始 Transformer 论文中的符号转置在 Key 上,此时向量为行向量
  • 原始论文中的RoPE示意图:

多头注意力下的 RoPE

  • 为了跟传统的 Transformer 符号对齐,本节改用 \(Q,K,V\)表示矩阵,与 RoPE 原始论文符号不再一致
  • 给定输入序列 \( X \in \mathbb{R}^{n \times d_{\text{model} } } \),先投影到查询、键、值空间:
    $$
    Q = XW^Q, \quad K = XW^K, \quad V = XW^V
    $$
    • 其中 \( W^Q, W^K \in \mathbb{R}^{d_{\text{model} } \times d_k} \), \( W^V \in \mathbb{R}^{d_{\text{model} } \times d_v} \)
  • 应用旋转位置编码(RoPE) :对 \( Q \) 和 \( K \) 的每个位置 \( m \) 和 \( n \) 的分量应用旋转矩阵 \( R_{\theta}^m \) 和 \( R_{\theta}^n \):
    $$
    \begin{aligned}
    Q’ &= \text{RoPE}(Q) = \bigoplus_{i=0}^{d_k/2-1} R_{\theta_i}^m Q^{(i)} \\
    K’ &= \text{RoPE}(K) = \bigoplus_{i=0}^{d_k/2-1} R_{\theta_i}^n K^{(i)}
    \end{aligned}
    $$
    • \( Q^{(i)} \in \mathbb{R}^2 \) 和 \( K^{(i)} \in \mathbb{R}^2 \) 是 \( Q \) 和 \( K \) 的第 \( i \) 个二维分量
    • \( \oplus \)表示向量拼接:
      $$\bigoplus_{i=0}^{d_k/2-1} R_{\theta_i}^m Q^{(i)} = \text{Concat}(\{R_{\theta_i}^m Q^{(i)}\}_{i=0}^{d_k/2-1} )$$
  • 旋转矩阵 \( R_{\theta_i}^m \) 定义为:
    $$
    R_{\theta_i}^m = \begin{pmatrix}
    \cos m\theta_i & -\sin m\theta_i \\
    \sin m\theta_i & \cos m\theta_i
    \end{pmatrix}, \quad \theta_i = 10000^{-2i/d_k}
    $$
  • 多头注意力输出
    $$
    \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_H) W^O \\
    \text{head}_h = \text{Softmax}\left(\frac{Q^{\prime}_h {K^{\prime}_h}^\top}{\sqrt{d_k} }\right) V_h
    $$
    • 其中 \( W^O \in \mathbb{R}^{H d_v \times d_{\text{model} } } \) 是输出投影矩阵
  • 多头注意力下的 RoPE 的 Attention 公式总结 :
    $$
    \begin{aligned}
    \text{Attention}(Q,K,V) &= \text{softmax}\left(\frac{(\text{RoPE}(Q))(\text{RoPE}(K))^\top}{\sqrt{d_k} }\right)V \\
    \text{where} \quad \text{RoPE}(X) &= \bigoplus_{i=0}^{d/2-1} R_{\theta_i}^m X^{(i)}
    \end{aligned}
    $$

多头注意力下的 RoPE 实现

  • 多头注意力下的RoPE PyTorch 实现 (待补充)
    • 注:多头注意力下,每个头是独立编码的(每个头维度从 0 开始),且使用的旋转矩阵一样,即旋转矩阵的维度 \(d = d_{\text{head}} = d_{\text{model}}/h\)

关于 RoPE 的一些讨论

  • RoPE 传统 Transformer 的区别 :
    • 传统Transformer:位置编码是加性的(\( X + P \))
    • RoPE:位置编码是乘性的(通过旋转矩阵直接修改 \( Q \) 和 \( K \))
  • 相对位置保持性 :
    • 旋转后的注意力分数 \( [Q’ K’^\top]_{m,n} \) 仅依赖于相对位置 \( m-n \),满足线性注意力性质
  • 远程衰减性 :(注:如下图所示的远距离衰减,会导致太远的距离下,难以区分位置,效果不好,外推性变差?)

附录:相对位置旋转公式证明

  • 目标:证明 \( \boldsymbol{\mathcal{R}}_m^{\top}\boldsymbol{\mathcal{R}}_n = \boldsymbol{\mathcal{R}}_{\color{red}{n-m}}\)
  • 考虑到 \(\boldsymbol{\mathcal{R}}_m\) 是以二维子矩阵为单位的“对角”矩阵,故只要证明 \((R_{\theta_i}^m)^\top R_{\theta_i}^n = R_{\theta_i}^{\color{red}{n-m}}\) 即可,证明过程如下:
  • 给定旋转矩阵 \( R_{\theta_i}^m \) 定义为:
    $$
    R_{\theta_i}^m = \begin{pmatrix}
    \cos m\theta_i & -\sin m\theta_i \\
    \sin m\theta_i & \cos m\theta_i
    \end{pmatrix}
    $$
  • 其转置矩阵为:
    $$
    (R_{\theta_i}^m)^\top = \begin{pmatrix}
    \cos m\theta_i & \sin m\theta_i \\
    -\sin m\theta_i & \cos m\theta_i
    \end{pmatrix}
    $$
  • 计算 \((R_{\theta_i}^m)^\top R_{\theta_i}^n\)
    $$
    (R_{\theta_i}^m)^\top R_{\theta_i}^n = \begin{pmatrix}
    \cos m\theta_i & \sin m\theta_i \\
    -\sin m\theta_i & \cos m\theta_i
    \end{pmatrix}
    \begin{pmatrix}
    \cos n\theta_i & -\sin n\theta_i \\
    \sin n\theta_i & \cos n\theta_i
    \end{pmatrix}
    $$
  • 回顾三角函数和差角公式:
    $$
    \sin(A \pm B) = \sin A \cos B \pm \cos A \sin B \\
    \cos(A \pm B) = \cos A \cos B \mp \sin A \sin B
    $$
  • 计算矩阵乘积的每个元素:
    • 左上角元素:
      $$
      \begin{align}
      \cos m\theta_i \cdot \cos n\theta_i + \sin m\theta_i \cdot \sin n\theta_i &= \cos(n\theta_i - m\theta_i) \\
      &= \cos((n - m)\theta_i) \\
      \end{align}
      $$
    • 右上角元素:
      $$
      \begin{align}
      \cos m\theta_i \cdot (-\sin n\theta_i) + \sin m\theta_i \cdot \cos n\theta_i &= -\cos m\theta_i \sin n\theta_i + \sin m\theta_i \cos n\theta_i \\
      &= \sin m\theta_i \cos n\theta_i - \cos m\theta_i \sin n\theta_i\\
      & = \sin((m-n)\theta_i) \\
      & = - \sin((n-m)\theta_i)
      \end{align}
      $$
    • 左下角元素:
      $$
      \begin{align}
      -\sin m\theta_i \cdot \cos n\theta_i + \cos m\theta_i \cdot \sin n\theta_i &= -\sin m\theta_i \cos n\theta_i + \cos m\theta_i \sin n\theta_i \\
      &= \sin n\theta_i \cos m\theta_i - \cos n\theta_i \sin m\theta_i \\
      &= \sin((n - m)\theta_i)
      \end{align}
      $$
    • 右下角元素:
      $$
      \begin{align}
      -\sin m\theta_i \cdot (-\sin n\theta_i) + \cos m\theta_i \cdot \cos n\theta_i &= \sin m\theta_i \sin n\theta_i + \cos m\theta_i \cos n\theta_i \\
      &= \sin m\theta_i \sin n\theta_i + \cos m\theta_i \cos n\theta_i \\
      &= \cos((n - m)\theta_i)
      \end{align}
      $$
  • 因此,乘积矩阵为:
    $$
    (R_{\theta_i}^m)^\top R_{\theta_i}^n = \begin{pmatrix}
    \cos((n - m)\theta_i) & -\sin((n - m)\theta_i) \\
    \sin((n - m)\theta_i) & \cos((n - m)\theta_i)
    \end{pmatrix} = R_{\theta_i}^{\color{red}{n-m}}
    $$
  • 至此,我们证明了:
    $$
    (R_{\theta_i}^m)^\top R_{\theta_i}^n = R_{\theta_i}^{\color{red}{n-m}}
    $$
  • 证毕

附录:不同参数下 RoPE 对 Attention 的影响

  • RoPE原始论文已经说明了随着距离的增长,Attention Score 有越来越小的趋势(且长距离部分会波动)
  • 下面是固定 query_index 下,Attention Score 随 key_index(横轴)变化的图像(代码参考链接修改的 旋转式位置编码 (RoPE) 知识总结 - Soaring的文章 - 知乎)
  • 从上图可以观察到:
    • 图1:RoPE确实有远距离衰减趋势(震荡递减),且dim_model=256时,q,k 距离为500时Attention值已经衰减的较小了
    • 图2:与图1类似,RoPE的衰减是对称的(图2展示的是当 query_index=256时的图像);
      • 注:注意虽然 Attention Score 是对称相等的,但是旋转角度是相反的,即 \(\boldsymbol{\mathcal{R}}_{m-n}\) 和 \(\boldsymbol{\mathcal{R}}_{n-m}\) 旋转角度相同,但方向是相反的
    • 图3:RoPE 在拉到足够长的距离后,不会一直衰减(从后续的图可以知道实际上还是周期函数,只是周期很大)
    • 图4+图5:与图1对比可以发现,RoPE 在拉到足够长的距离后,实际上还是周期函数,只不过周期与 d_model 相关(d_model越大,周期越长),图4说明了当 d_model=8 时,周期是 10000 左右
    • 图6:缩小了图5的横轴区间,将图5的前半部分图像放大了看,是在小周期上震荡的,且还存在图5所示的大周期
  • 附上图的代码:
    >>>点击展开折叠内容...
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    # refer to: [旋转式位置编码 (RoPE) 知识总结 - Soaring的文章 - 知乎](https://zhuanlan.zhihu.com/p/662790439)
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.axes import Axes

    def create_sin_cos_table_cache(max_num_tokens, dim_model):
    # 所有pos下对应的cos/sin值分别存储为矩阵
    theta = 10000 ** (-np.arange(0, dim_model, 2) / dim_model)
    theta = theta.reshape(-1, 1).repeat(2, axis=1).flatten()

    pos = np.arange(0, max_num_tokens)
    table = pos.reshape(-1, 1) @ theta.reshape(1, -1) # [max_num_tokens, dim_model]

    sin_table_cache = np.sin(table)
    sin_table_cache[:, ::2] = -sin_table_cache[:, ::2]

    cos_table_cache = np.cos(table)
    return sin_table_cache, cos_table_cache

    def rotate_half(q_vec):
    # 将q_vec的值两个一组分组并在分组内对调,实现从[q_0,q_1,q_2,q_3,...,q_{d-1},q_d]到的转换[q_1,q_0,q_3,q_2,...,q_d,q_{d-1}]
    return q_vec.reshape(-1, 2)[:, ::-1].flatten()

    def rotary(vec, pos, sin_table, cos_table):
    # 原始论文中的公式
    return vec * cos_table[pos] + rotate_half(vec) * sin_table[pos]

    def plot(plt_obj: Axes, pic_index, query_index=0, dim_model=256, max_num_tokens=8192, step=1):
    # q_vec 和 k_vec 都设定为1,仅关注 RoPE 引发的Attention Score的变化
    q_vec = np.ones(dim_model)
    k_vec = np.ones(dim_model)
    sin_table, cos_table = create_sin_cos_table_cache(max_num_tokens, dim_model)

    rotated_q_vec = rotary(q_vec, query_index, sin_table, cos_table)
    k_indices = np.arange(0, max_num_tokens, step)
    rotated_k_vecs = rotary(k_vec, k_indices, sin_table, cos_table)
    attn_scores = (rotated_k_vecs @ rotated_q_vec) / np.sqrt(dim_model) # 未经过Softmax的Attention权重,用于展示RoPE对原始Attention Score

    plt_obj.plot(k_indices, attn_scores)
    plt_obj.set_title(f"Figure {pic_index}: query_index={query_index}, dim_model={dim_model}")
    plt_obj.set_xlabel("key index")
    plt_obj.set_ylabel("Attention Score")

    plt.rcParams.update({
    "font.sans-serif": ["Times New Roman", ],
    "font.size": 10
    })

    _, axes = plt.subplots(nrows=2, ncols=3, figsize=(10, 10))
    plot(axes[0, 0], 1, query_index=0, max_num_tokens=512)
    plot(axes[0, 1], 2, query_index=256, max_num_tokens=512)
    plot(axes[0, 2], 3, query_index=0, dim_model=256, max_num_tokens=65535)
    # plot(axes[1, 0], 4, query_index=0, dim_model=32, max_num_tokens=65535)
    plot(axes[1, 0], 4, query_index=0, dim_model=16, max_num_tokens=65535)
    plot(axes[1, 1], 5, query_index=0, dim_model=8, max_num_tokens=65536)
    plot(axes[1, 2], 6, query_index=0, dim_model=8, max_num_tokens=512)
    plt.show()

附录:如何改进 RoPE 以实现绝对位置编码?

  • 当前的设计仅能实现相对位置编码(因为 Query 和 Key 的内积只与他们的相对位置有关,与绝对位置无关),但如果在 Value 上也施加位置编码则能实现绝对位置编码的能力
  • 让研究人员绞尽脑汁的Transformer位置编码中苏神提到:

    这样一来,我们得到了一种融绝对位置与相对位置于一体的位置编码方案,从形式上看它有点像乘性的绝对位置编码,通过在 \(\boldsymbol{q},\boldsymbol{k}\) 中施行该位置编码,那么效果就等价于相对位置编码,而如果还需要显式的绝对位置信息,则可以同时在 \(\boldsymbol{v}\) 上也施行这种位置编码。总的来说,我们通过绝对位置的操作,可以达到绝对位置的效果,也能达到相对位置的效果


附录:旋转体现在哪里?

  • 旋转体现在 \(\boldsymbol{q}_m\)(或\(\boldsymbol{k}_n\)) 的每两个维度组成的向量 \(\boldsymbol{q}_m^{(i)} = \begin{pmatrix} q_{m,2i} \\ q_{m,2i+1} \end{pmatrix}\) 经过 RoPE 变换前后,他们的向量长度不变,即:
    $$
    \begin{pmatrix}cos(m\theta_1) &-sin(m\theta_1) \\ sin(m\theta_1) &cos(m\theta_1)\end{pmatrix} \begin{pmatrix}q_{m,2i} \\ q_{m,2i+1}\end{pmatrix} = \begin{pmatrix}q_{m,2i}\cdot cos(m\theta_1) -q_{m,2i+1}\cdot sin(m\theta_1) \\ q_{m,2i}\cdot sin(m\theta_1)+ q_{m,2i+1}\cdot cos(m\theta_1) \end{pmatrix} \\
    $$
  • 进一步,由于 \((cos(m\theta_1))^2 + (sin(m\theta_1))^2 = 1\) 有
    $$
    \begin{align}
    \left|\begin{pmatrix}q_{m,2i}\cdot cos(m\theta_1) -q_{m,2i+1}\cdot sin(m\theta_1) \\ q_{m,2i}\cdot sin(m\theta_1)+ q_{m,2i+1}\cdot cos(m\theta_1) \end{pmatrix} \right| &= \sqrt{\left(q_{m,2i}\cdot cos(m\theta_1) -q_{m,2i+1}\cdot sin(m\theta_1)\right)^2 + \left(q_{m,2i}\cdot sin(m\theta_1)+ q_{m,2i+1}\cdot cos(m\theta_1)\right)^2} \\
    &= \sqrt{(q_{m,2i})^2 + (q_{m,2i+1})^2}
    \end{align}
    $$
  • 也就是说: \(\boldsymbol{q}_m\)(或\(\boldsymbol{k}_n\))相邻两两维度在变换前后的向量长度并没有变化,是一个旋转操作

附录:可视化RoPE旋转过程

  • 旋转矩阵 \(R_{\theta_i}^m\) 的定义如下:
    $$
    \begin{aligned}
    R_{\theta_i}^m &= \begin{pmatrix}
    \cos m\theta_i & -\sin m\theta_i \\
    \sin m\theta_i & \cos m\theta_i
    \end{pmatrix}, \quad \theta_i = 10000^{-2i/d}
    \end{aligned}
    $$
  • 旋转矩阵 \(R_{\theta_i}^m\) 可以将目标向量进行旋转,\(R_{\theta_i}^m \boldsymbol{x}\) 相当于将 \( \boldsymbol{x}\) 向逆时针方向旋转 \(m\theta_i\) 度
  • 当 \(m\theta_i = \frac{\pi}{4}\) 时,其旋转可视化结果如下:
  • 实现上述旋转的代码如下
    >>>点击展开折叠内容...
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
       import numpy as np
    import matplotlib.pyplot as plt

    # 设定原始向量 x
    x = np.array([1, 0])

    # 角度转换为弧度
    angle = np.radians(45)

    # 旋转矩阵定义
    R = np.array([[np.cos(angle), -np.sin(angle)],
    [np.sin(angle), np.cos(angle)]])

    # 矩阵惩罚实现旋转向量
    x_rotated = R @ x

    plt.rcParams['figure.dpi'] = 300

    plt.figure(figsize=(6, 6))
    plt.quiver(0, 0, x[0], x[1], angles='xy', scale_units='xy', scale=1, color='b', label='Original Vector')
    plt.quiver(0, 0, x_rotated[0], x_rotated[1], angles='xy', scale_units='xy', scale=1, color='r', label='Rotated Vector')

    plt.xlim(-1.5, 1.5)
    plt.ylim(-1.5, 1.5)

    plt.grid(True)

    plt.legend()
    plt.title('Vector Rotation')
    plt.xlabel('X-axis')
    plt.ylabel('Y-axis')

    plt.savefig("./demo.png")
    plt.show()

附录:RoPE的诞生历史

  • RoPE方案来自苏神原始博客:
    • 2021年2月在让研究人员绞尽脑汁的Transformer位置编码中提出想法
    • 2023年3月在Transformer升级之路:2、博采众长的旋转式位置编码中给出详细方案和推导,同时提交论文到arXiv上
    • 原始论文:Roformer: Enhanced Transformer With Rotray Position Embedding,该论文24年发表于Neurocomputing期刊上(《Neurocomputing》是国际知名期刊,被列为中科院SCI二区top期刊,CCF-C类期刊)
    • 随后,各种开源大模型开始使用RoPE,RoPE逐渐成为大模型的标配

附录:RoPE的高维扩展

  • 论文介绍当前的是一维 RoPE,每 2 维为一组,旋转矩阵为 \(2\times 2\),在二维 RoPE 场景中,可以每 4 维为一组,旋转矩阵变成 \(4\times 4\) 即可,详情见 旋转式位置编码 (RoPE) 知识总结
  • 将四个维度作为一组:
    $$
    \boldsymbol{R}_{m_1,m_2} =
    \begin{bmatrix}
    \cos m_1 \theta & -\sin m_1 \theta & 0 & 0 \\
    \sin m_1 \theta & \cos m_1 \theta & 0 & 0 \\
    0 & 0 & \cos m_2 \theta & -\sin m_2 \theta \\
    0 & 0 & \sin m_2 \theta & \cos m_2 \theta
    \end{bmatrix}
    $$
  • 上述分组下满足:
    $$
    \mathbf{R}_{m_1,m_2}^{\top} \cdot \mathbf{R}_{n_1,n_2} = \mathbf{R}_{n_1 - m_1,n_2 - m_2}
    $$
  • 注:更高维度的可以继续扩展,比如三维的扩展为每 6 维为一组,旋转矩阵为 \(6\times 6\) 即可

附录:RoPE中的复数和旋转矩阵等价性证明

问题定义

  • 定义二维旋转矩阵和向量如下:
    $$
    \begin{align}
    R_{\theta_i}^m &= \begin{pmatrix}
    \cos m\theta_i & -\sin m\theta_i \\
    \sin m\theta_i & \cos m\theta_i
    \end{pmatrix} \\
    \boldsymbol{x} &= \begin{pmatrix} x_1 \\ x_2
    \end{pmatrix}
    \end{align}
    $$
  • 目标:证明下面的等式
    $$R_{\theta_i}^m \boldsymbol{x} = z e^{i m\theta_i}$$
    • 其中 \( z = x_1 + i x_2 \) 是向量 \(\boldsymbol{x} = \begin{pmatrix} x_1 \\ x_2 \end{pmatrix}\) 的复数形式
    • 即目标是证明:旋转矩阵 \( R_{\theta_i}^m \) 作用在 \(\boldsymbol{x}\) 上相当于将复数 \( z \) 乘以旋转因子 \( e^{i m\theta_i} \)

证明

  • 方程左边展开
    $$
    R_{\theta_i}^m \boldsymbol{x} = \begin{pmatrix}
    \cos m\theta_i & -\sin m\theta_i \\
    \sin m\theta_i & \cos m\theta_i
    \end{pmatrix}
    \begin{pmatrix} x_1 \\ x_2 \end{pmatrix} =
    \begin{pmatrix}
    x_1 \cos m\theta_i - x_2 \sin m\theta_i \\
    x_1 \sin m\theta_i + x_2 \cos m\theta_i
    \end{pmatrix}
    $$

    • 这个结果对应的复数为:
      $$
      R_{\theta_i}^m \boldsymbol{x} = (x_1 \cos m\theta_i - x_2 \sin m\theta_i) + i (x_1 \sin m\theta_i + x_2 \cos m\theta_i)
      $$
  • 等方程右边展开

    • 右边的 \(\boldsymbol{x} e^{i m\theta_i}\) 表示将复数 \( z = x_1 + i x_2 \) 乘以 \( e^{i m\theta_i} \),即:
      $$
      z e^{i m\theta_i} = (x_1 + i x_2)(\cos m\theta_i + i \sin m\theta_i)
      $$
    • 可以重新整理为:
      $$ z e^{i m\theta_i} = (x_1 \cos m\theta_i - x_2 \sin m\theta_i) + i (x_1 \sin m\theta_i + x_2 \cos m\theta_i) $$
    • 展开后与旋转矩阵作用的结果完全一致
  • 结论: 在复数表示下,旋转矩阵 \( R_{\theta_i}^m \) 作用在向量 \(\boldsymbol{x}\) 上等价于将对应的复数 \( z \) 乘以旋转因子 \( e^{i m\theta_i} \)。因此,等式成立:
    $$
    \color{red}{R_{\theta_i}^m \boldsymbol{x} = \boldsymbol{x} e^{i m\theta_i}}
    $$


附录:旋转位置编码的其他推导过程

  • 下面的推导过程来自通俗易懂-大模型的关键技术之一:旋转位置编码rope(2)和通俗易懂-大模型的关键技术之一:旋转位置编码rope (3),推导过程看过,基本没有问题,先截图,以后有时间再手打一遍
  • \(q,k\) 向量旋转后再进行Attention,注意图中 \(e^{im\theta}\) 的 \(i\) 是虚数的意思,这里使用二维向量 \((\boldsymbol{W}_q\boldsymbol{x}_m)\) 乘以一个虚数的本质是想表达向量内积的意思(虚数可以展开成二维向量),此外,\(a_{m,n}\) 表示 Attention 权重,在不考虑权重时值为 Softmax,Element-wise看:\(a_{m,n} =\frac{\exp(\frac{q^T_m k_n}{\sqrt{d_k}})}{\sum_{j=1}^N \exp(\frac{q^T_m k_j}{\sqrt{d_k}})}\),注:下图中的表达有误,实际上应该是 \(a_{m,n}=\frac{\exp(\frac{ {x^{\prime}_m}^T x^{\prime}_n}{\sqrt{d_k} })}{\sum_{j=1}^N \exp(\frac{ {x^{\prime}_m}^T x^{\prime}_j}{\sqrt{d_k} })}\)(详情见原始论文:(RoPE) RoFormer: Enhanced Transformer with Rotary Position Embedding, Arxiv 2023 & Neurocomputing 2024, 追一科技 )
  • 公式1,2,推导出公式3的过程(内积角度)
  • 公式1,2,推导出公式3的过程(公式角度)
  • 内积角度和公式角度推导结果可以对齐
  • 扩展到多维的方式,将模型d_model维度按照两两一组分(注意这里不是序列两两分,序列本来做Attention就是两两做内积的),这里要求模型维度是偶数的
  • 公式化简
  • 代码实现
    • 上述实现中使用了torch.einsum,本质是爱因斯坦求和约定 ,是矩阵乘法的一种表示,C = torch.einsum("n,d->md", A,B)表示矩阵C[m,d] = A[n]*B[d],,这是个很常用的省略写法,更详细的可以看看图学 AI:einsum 爱因斯坦求和约定到底是怎么回事?
1…678…63
Joe Zhou

Joe Zhou

Stay Hungry. Stay Foolish.

628 posts
53 tags
GitHub E-Mail
© 2026 Joe Zhou
Powered by Hexo
|
Theme — NexT.Gemini v5.1.4