Jiahong 的个人博客

凡事预则立,不预则废


  • Home

  • Tags

  • Archives

  • Navigation

  • Search

RS——生成式推荐

  • 参考链接:
    • (TIGER)Recommender Systems with Generative Retrieval, NeurIPS 2023, Google: 一篇比较早,也比较基础的工作
    • OneRec: Unifying Retrieve and Rank with Generative Recommender and Preference Alignment, 202502, KuaiShou
    • (COBRA)Sparse Meets Dense: Unified Generative Recommendations with Cascaded Sparse-Dense Representations, 202503, Baidu
    • (RARE)Real-time Ad retrieval via LLM-generative Commercial Intention for Sponsored Search Advertising, 202504, Tencent
      • 相关博客:腾讯搜索广告生成式检索
    • 一文汇总:LLM应用到推荐系统的各类玩法总结
    • Slow Thinking for Sequential Recommendation, 2025, RUC
      • 相关博客:STREAM-Rec: 推荐系统实现慢思考推理
    • LEARN: Knowledge Adaptation from Large Language Model to Recommendation for Practical Industrial Application, AAAI 2025, Kuaishou
      • 相关博客:AAAI’25 | 快手LEARN:使用LLM做特征增强用于电商广告推荐
    • Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for Generative Recommendations
    • HLLM: Enhancing Sequential Recommendations via Hierarchical Large Language Models for Item and User Modeling, 202409, ByteDance
    • Unlocking Scaling Law in Industrial Recommendation Systems with a Three-step Paradigm based Large User Model, 202502, Alibaba

整体说明

  • 传统深度学习推荐模型(Deep Learning Recommendation Models,DLRMs)范式主要是“Retrieval-Ranking”模式
  • 目前大模型在推荐领域的实践大致分为三类:
    • 第一类:保留 DLRMs 范式,仅利用通用的 LLM 做知识增强,或从领域数据中提取表征信息等
    • 第二类:仍保留 DLRMs 范式,部分重构 Retrieval 阶段或者 Ranking 阶段(或者在 Retrieval 阶段增加一路生成式召回通道)
    • 第二类:直接修改 DLRMs 范式,将输入和输出都重新构造为 LLM 模式,直接利用重构后的领域数据做预训练和微调

RS——Meta-GRs-HSTU

  • 参考链接:
    • (HSTU)Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for Generative Recommendations, ICML 2024, Meta

整体说明

  • 当前推荐系统的特点 :
    • 大规模推荐系统的特点在于其依赖高基数(high cardinality,注:也称为高维)、异构特征(heterogeneous features),并且需要每天处理数百亿用户行为
    • 尽管大多数工业界的深度学习推荐模型(Deep Learning Recommendation Models,DLRMs)利用海量数据和数千个特征进行训练,却未能实现与计算资源的有效扩展(即增加了资源,效果得不到对应的提升)
  • 论文方案一句话说明 :(受 Transformer 在语言和视觉领域成功的启发)
    • 论文将推荐问题重新定义为生成式建模框架下的 Squential Transduction 任务(Generative Recommenders(GRs),生成式推荐器),并提出了一种新架构HSTU ,专为高基数、非平稳的流式推荐数据设计
    • 注:论文的 Squential Transduction 和 Transduction Learning(直推式学习)没有直接关系(虽然二者名字很像);
      • Transduction Learning(直推式学习) :利用训练数据和测试数据的整体信息,直接对见过的测试数据做出预测
      • Inductive Learning(归纳式学习) :利用训练数据的信息,训练一个可以泛化到未见过的测试数据的通用的模型
      • 从定义看,论文本质还是一个Inductive Learning(归纳式学习) ,学到的模型是要使用在未知数据场景的(包括未知用户、未知时间和未知商品)
  • 效果和效率 :
    • 离线效果 :在合成和公开数据集上,HSTU相比基线模型在 NDCG 指标上最高提升 65.8%
    • 效率 :并且在处理 8192 长度序列时比基于 FlashAttention2 的 Transformer 快 5.3 倍至 15.2 倍
    • 在线效果 :基于 HSTU 的生成式推荐器拥有1.5万亿参数,在线A/B测试中指标提升12.4%,已在拥有数十亿用户的互联网平台的多个场景中部署
  • scaling law发现 :更重要的是,生成式推荐器的模型质量随训练计算量呈幂律增长(power-law),扩展范围达到GPT-3/LLaMa-2规模,这减少了未来模型开发所需的 carbon footprint(碳排放总和),并为推荐系统中的首个基础模型铺平了道路
    • 理解:power-law说明模型的算力投入可以很高效的得到回报,所以不需要盲目增加很多计算资源,也就减少了碳排放
  • 注意:论文的方法依然保留了推荐系统中的 Retrieval-Ranking 两阶段(Retrieval也称为召回或 Retrieval 阶段),只是分别将两阶段都使用生成式模型来建模了,对于召回阶段,甚至可以将生成式模型作为一个新的召回通道使用,作为对原始 DLRM 的一种补足

一些背景讨论

  • 传统推荐系统 :最先进的推荐方法主要基于深度学习推荐模型 (DLRM),其特点在于使用异构特征,包括数值特征(如 counters 和 ratios)、嵌入以及分类特征(如商品ID、用户ID等)
    • 由于每分钟都有新内容和商品加入,特征空间具有极高的基数,通常达到十亿级别
    • 为了利用这些数以万计的特征,DLRM 采用各种神经网络组合特征、转换中间表示并生成最终输出
    • 尽管 DLRM 利用了大量人工设计的特征和海量数据进行训练,但工业界大多数 DLRM 的计算扩展性较差。这一显著限制至今仍未得到解决
  • 论文的思路 :受 Transformer 在语言和视觉领域成功的启发,论文重新审视了现代推荐系统中的基础设计选择,但作者观察到:在十亿用户规模下,这需要克服三个挑战 :
    • 推荐系统中的特征缺乏明确结构 :虽然在小规模场景中已探索过序列化方法(附录B详细讨论),但工业级 DLRM 中异构特征(包括高基数ID、交叉特征、计数器、比率等)起着关键作用
    • 持续变化的十亿级词汇表 :与语言模型中十万级静态词汇表相比,十亿级动态词汇表带来了训练挑战,并由于需要考虑数万个目标感知(target-aware)候选而推高了推理成本
      • 目标感知(target-aware)指生成用户表示或进行预测时,能够动态地结合当前候选内容(target item)的信息,从而更精准地建模用户与特定候选内容之间的交互
      • 注:传统推荐系统(如DLRMs)中,target-aware通常指在特征交互阶段显式引入候选内容(target item)的信息
    • 计算成本大 :(这是大规模序列模型落地的主要瓶颈),GPT-3使用数千个GPU在1-2个月内处理了总计3000亿 token。这一规模看似惊人,但与用户行为规模相比则相形见绌。最大规模的互联网平台每天服务数十亿日活用户,用户每天与数十亿帖子、图片和视频互动。用户序列长度可达 \(10^5\) 。因此,推荐系统每天需要处理的 token 数量比语言模型1-2个月处理的量高出几个数量级
  • 一个特别重要的创新点 :在本工作中,论文将用户行为视为生成式建模中的新模态(在图片、文本和视频等多模态的基础上,增加一个模态叫做用户行为)。论文的 key insights 是:
    • a)给定适当的新特征空间,工业级推荐系统中的核心排序和召回任务可被转化为生成式建模问题;
    • b)这一范式使论文能够系统性地利用特征、训练和推理中的冗余(redundancies)来提高效率,基于这一新范式,论文部署的模型计算复杂度比之前最先进技术高出三个数量级(如图1所示),同时核心业务指标(英文中称为topline metrics)提升了12.4%
      • 问题:如何理解”特征、训练和推理中的冗余(redundancies)“这句话?
  • 论文的贡献可总结如下:
    • 第一 :提出生成式推荐器(Generative Recommenders,GRs) ,这一新范式将取代 DLRM
      • 论文将 DLRM 中的异构特征空间序列化并统一(注:用户行为视为特定的一个新模态),当序列长度趋近无穷时,新方法可近似完整的 DLRM 特征空间。使得:
        • 论文能将主要推荐问题(排序和召回)重新定义为GR中的纯 Squential Transduction 任务
        • 论文模型训练能以序列化、生成式的方式进行,从而允许论文在相同计算量下处理多几个数量级的数据
    • 第二 :解决训练和推理过程中的计算成本挑战 :
      • 提出新结构 :提出新的 Squential Transduction 架构——分层 Squential Transduction 单元(Hierarchical Sequential Transduction Units,HSTU)。HSTU针对大规模非平稳词汇表修改了注意力机制 ,并利用推荐数据集特性实现比基于 FlashAttention2 的 Transformer 在 8192 长度序列上快5.3倍至15.2倍
      • 提出新算法M-FALCON :通过新算法M-FALCON完全平摊计算成本,论文能在传统 DLRM 使用的相同推理预算下,服务复杂度高出285倍的GR模型,同时实现1.50-2.99倍的加速
    • 第三 :验证所提技术在合成数据集、公开数据集以及在拥有数十亿日活用户的大型互联网平台多个场景中的部署效果
      • 据论文所知,论文的工作首次展示了纯 Squential Transduction 架构(如HSTU)在生成式设置(GRs)中显著优于工业级大规模 DLRM。值得注意的是,论文不仅克服了传统 DLRM 中已知的扩展瓶颈,还进一步证明了扩展定律(scaling law)适用于推荐系统 ,这可能是推荐系统的”ChatGPT时刻“

推荐作为 Squential Transduction 任务:从 DLRM 到GR

统一 DLRM 中的异构特征空间

  • 现代 DLRM 模型通常使用大量分类(sparse)和数值(dense)特征进行训练。在GR中,论文将这些特征整合并编码为统一的时间序列 ,如图2所示
  • 分类(sparse)特征 :这类特征的例子包括用户喜欢的商品、用户关注的某类别(如户外)创作者、用户语言、用户加入的社区、请求发起的城市等。论文按以下方式将这些特征序列化:
    • 选择最长的时间序列(通常通过合并代表用户互动商品的特征)作为主时间序列
    • 其余特征通常是随时间缓慢变化的序列(如人口统计信息或关注的创作者),论文通过保留每个连续段的最早条目(entry)来压缩这些时间序列,然后将结果合并到主时间序列中。由于这些时间序列变化非常缓慢,这种方法不会显著增加总体序列长度
      • 理解:如图2所示,auxiliary time series 1 和 auxiliary time series 2 中,分别只有每个连续段的第一个entry被插入到 main time series 中(注:auxiliary time series 1 包含多个连续段,每个连续段的第一个 entry 都会保留并插入 main time series 中)
    • 这部分写得不是很清晰,有点晦涩,原文介绍如下:

      Categorical (‘sparse’) features. Examples of such features include items that user liked, creators in a category (e.g., Outdoors) that user is following, user languages, communities that user joined, cities from which requests were initiated, etc. We sequentialize these features as follows. We first select the longest time series, typically by merging the features that represent items user engaged with, as the main time series. The remaining features are generally time series that slowly change over time, such as demographics or followed creators. We compress these time series by keeping the earliest entry per consecutive segment and then merge the results into the main time series. Given these time series change very slowly, this approach does not significantly increase the overall sequence length.

  • 数值(dense)特征 :这类特征的例子包括加权和衰减计数器、比率等(例如,某个特征可能代表用户过去对匹配特定主题商品的点击率(CTR))
    • 与分类特征相比,这些特征变化更频繁 ,可能随每个(用户,商品)互动而变化,所以从计算和存储角度看,完全序列化这些特征并不可行
    • 然而一个重要观察是:论文执行这些聚合所基于的分类特征(如商品主题、位置)在GR中已被序列化和编码。因此 ,给定足够强大的 Squential Transduction 架构与 target-aware 公式相结合,当增加GR中的总体序列长度和计算量时,论文可以移除数值特征 ,因为它们能被有效地捕捉

将排序和召回重新定义为 Squential Transduction 任务

  • 给定按时间顺序排列的 \(n\) 个 token 列表 \(x_{0},x_{1},\ldots,x_{n-1}\) (\(x_{i}\in\mathbb{X}\)),观察到这些 token 的时间 \(t_{0},t_{1},\ldots,t_{n-1}\) ,Squential Transduction 任务将此输入序列映射到输出 token \(y_{0},y_{1},\ldots,y_{n-1}\) (\(y_{i}\in\mathbb{X}\cup\{\varnothing\}\)),其中 \(y_{i}=\varnothing\) 表示 \(y_{i}\) 未定义
  • 论文用 \(\Phi_{i}\in\mathbb{X}_{c}\) (\(\mathbb{X}_{c}\subseteq\mathbb{X}\))表示系统向用户提供的内容(即历史上展示给用户看过的内容 ,如图片或视频)。由于不断有新内容产生 , \(\mathbb{X}_{c}\) 和 \(\mathbb{X}\) 是非平稳的。用户可以用某个行为 \(a_{i}\) (如点赞、跳过、视频完播+分享) \(a_{i}\in\mathbb{X}\) 来回应 \(\Phi_{i}\) 。论文用 \(n_{c}\) 表示用户互动过的内容总数
    • 对 Retrieval 输入的理解:输入是 \(\{(\Phi_{0},a_0),(\Phi_{1},a_1),\cdots,(\Phi_{n_c-1},a_{n_c-1})\}\) ,表示历史上依次给用户分别展示了内容 \(\Phi_{i}\) 以后用户的行为 \(a_{i}\) 的元组序列
    • 对 Retrieval 输出的理解:每个元组 \((\Phi_{i},a_i)\) 对应输出一个目标值 \(\Phi_{i}^{\prime}\) :
      $$\Phi_{i}^{\prime}=
      \begin{cases}
      \Phi_{i}& \ a_i \ \text{ is positive}\\
      \varnothing& \ \text{ otherwise}
      \end{cases}$$
    • 对 Ranking 输入的理解:如表1所示,对输入 \(x_i\) 的建模是, \(\{\Phi_{0},a_0,\Phi_{1},a_1,\cdots,\Phi_{n_c-1},a_{n_c-1}\}\) 是历史行为序列,表示历史上依次给用户分别展示了内容 \(\Phi_{i}\) 以后用户的行为 \(a_{i}\)
    • 对 Ranking 输出的理解:如表1所示,输出和输入依次对应,输入为 \(\Phi_{i}\) 时预估目标是用户的行为 \(a_{i}\)
  • 在因果自回归设置中,标准排序和 Retrieval 任务可定义为 Squential Transduction 任务(表1)。论文得出以下观察:
    • 召回(Retrieval) :在推荐系统的召回阶段,论文学习 \(\Phi_{i+1}\in\mathbb{X}_{c}\) 上的分布 \(p(\Phi_{i+1}|u_{i})\) ,其中 \(u_{i}\) 是用户 \(i\) 的 token 表示。典型目标是选择 \(\arg\max_{\Phi\in\mathbb{X}_{c} }p(\Phi|u_{i})\) 以最大化某些奖励。这与标准自回归设置有两个不同:首先, \(x_{i}\) 的监督信号 \(y_{i}\) 不一定是 \(\Phi_{i+1}\) ,因为用户可能对 \(\Phi_{i+1}\) 做出负面回应。其次,当 \(x_{i+1}\) 代表非互动相关的分类特征(如人口统计信息)时, \(y_{i}\) 未定义
    • 排序(Ranking) :GR中的 Ranking 任务带来独特挑战,因为工业推荐系统通常需要“目标感知”(target-aware)公式。在此设置中,目标 \(\Phi_{i+1}\) 与历史特征的”交互”(interaction)需要尽早发生,这在标准自回归设置中不可行(理解:不建模动作时,”交互”发生较晚)。论文通过表1中的交错排列商品和行为来解决这一问题,使得 Ranking 任务可被公式化为 \(p(a_{i+1}|\Phi_{0},a_{0},\Phi_{1},a_{1},\ldots,\Phi_{i+1})\) (在分类特征之前)。实践中论文使用小型神经网络将在 \(\Phi_{i+1}\) 的输出转换为多任务预测。重要的是,这使论文能在单次传递中对所有 \(n_{c}\) 次互动应用 target-aware 交叉注意力
  • 补充:对论文 target-aware 实现方式的理解 :其实 Transformer 的 Attention 也有交叉功能,只要输入端有目标 item 和用户历史交互 item 即可,但论文所说的传统的自回归模型是指输入侧不包含目标 item 的情况,论文中,将动作也建模进去,则在输出对目标 item 的动作 token 时,自然就需要将目标 item 作为输入,从而也就实现了 target-aware 交叉注意力:
    • 传统自回归模型预估目标为 : \(p(\Phi_{i+1}|\Phi_{0},\Phi_{1},\ldots,\Phi_{i})\),目标物品 \(\Phi_{i+1}\) 与其他历史序列无交叉
      • 注:此时动作 \(a_i\) 信息可能会作为额外表征加入到 \(\Phi_{i}\)中,所以 表述为 \(p(\Phi_{i}|(\Phi_{0}, a_{0}), …, (\Phi_{i - 1}, a_{i - 1}))\) 也可以
    • 论文预估目标为 : \(p(a_{i+1}|\Phi_{0},a_{0},\Phi_{1},a_{1},\ldots,\Phi_{i+1})\),目标物品 \(\Phi_{i+1}\) 与其他历史序列有交叉

生成式训练

  • 工业推荐系统通常在流式设置中训练,其中每个样本在可用时被顺序处理。在此设置中,基于自注意力的 Squential Transduction 架构(如 Transformer)的总计算需求量级为 \(\sum_{i}n_{i}(n_{i}^{2}d+n_{i}d_{ff}d)\) ,其中 \(n_{i}\) 是用户 \(i\) 的 token 数量, \(d\) 是嵌入维度,设 \(N=\max_{i}n_{i}\) ,总体时间复杂度降至 \(O(N^{3}d+N^{2}d^{2})\) ,这对推荐场景来说成本过高
    • 注:更多对公示的理解看笔者补充的附录
  • 为应对训练 Squential Transduction 模型处理长序列的挑战,论文从传统曝光级(impression-level)训练转向生成式训练,将计算复杂度降低 \(O(N)\) 因子,如图2顶部所示。这样做可将编码器成本平摊到多个目标上。具体来说,当论文以 \(s_{u}(n_{i})\) 的速率采样第 \(i\) 个用户时,总训练成本现在按 \(\sum_{i}s_{u}(n_{i})n_{i}(n_{i}^{2}d+n_{i}d^{2})\) 缩放,通过设 \(s_{u}(n_{i})\) 为 \(1/n_{i}\) 可降至 \(O(N^{2}d+Nd^{2})\) 。在工业级系统中实现此采样的一种方法是在用户请求或会话结束时发出训练样本,使得 \(\hat{s_{u} }(n_{i})\propto 1/n_{i}\)
    • 理解:从「曝光级(impression-level)训练转向生成式训练」主要是强调从以前每个曝光都是一个样本,现在一个用户的所有行为组成一个样本
    • 问题:这里对用户进行采样不会带来效果损失吗?模型看不到部分用户是OK的吗?

面向生成式推荐的高性能自注意力编码器(A High Performance Self-Attention Encoder for Generative Recommendations)

  • 为了将生成式推荐系统(GRs)扩展至具有大规模非静态词表的工业级推荐场景,论文提出了一种新型编码器设计——分层序列转导单元(Hierarchical Sequential Transduction Unit, HSTU)

  • HSTU由多个通过残差连接堆叠的相同层构成,每层包含三个子层:Pointwise Projection(公式1)、Spatial Aggregation(公式2)和 Pointwise Transformation(公式3):

$$ U(X), V(X), Q(X), K(X) = \text{Split}(\phi_1(f_1(X))) \tag{1} $$

$$ A(X)V(X) = \phi_2\left(Q(X)K(X)^T + \text{rab}^{p,t}\right)V(X) \tag{2} $$

$$ Y(X) = f_2\left(\text{Norm}\left(A(X)V(X)\right) \odot U(X)\right) \tag{3} $$

  • 其中:
    • \( f_i(X) \) 表示多层感知机(MLP);为降低计算复杂度, \( f_1 \) 和 \( f_2 \) 使用单线性层 \( f_i(X) = W_i(X) + b_i \)
      • 注:还通过融合内核批量处理查询 \( Q(X) \) 、键 \( K(X) \) 、值 \( V(X) \) 和门控权重 \( U(X) \) (即一次性将Q,K,V和Q映射都做完)
    • \( \phi_1 \) 和 \( \phi_2 \) 为非线性激活函数,均采用SiLU(即Swish);
    • Norm 表示 LayerNorm
    • \( \text{rab}^{p,t} \) 为结合位置(\( p \))和时间(\( t \))信息的相对注意力偏置
  • HSTU的编码器设计允许用单一模块化块替代DLRMs中的异构模块。论文观察到DLRMs实际包含三个阶段:特征提取、特征交互和表征变换
    • 1:特征提取(Feature Extraction)阶段通过池化操作获取类别特征的嵌入表示,其高级形式可泛化为成对注意力和 target-aware 池化
      • 如 HSTU 层所实现,能够做到上述能力
    • 2:特征交互(Feature Interaction)是DLRMs的核心,常用方法包括因子分解机及其神经网络变体、高阶特征交互等
      • HSTU通过注意力池化特征直接与其他特征交互(即 \(\text{Norm}(A(X)V(X)) \odot U(X)\))替代传统特征交互模块
      • 该设计源于学习型MLP难以近似点积的挑战(Rendle等, 2020; Zhai等, 2023a),由于SILU作用于 \(U(X)\),此结构也可视为SwiGLU(Shazeer, 2020)的变体
    • 3:表征变换(Transformations of Representations)阶段通常采用混合专家(MoE)和路由机制处理异构用户群体,其核心思想是通过子网络 specialization 实现条件计算
      • HSTU中的逐点乘积操作(Element-wise dot products)本质上能在归一化因子范围内模拟MoE的门控机制
  • 以上实现中可以看出,HSTU Layer 实现了类似 Transformer Layer 的功能
    • \(X\) 是 HSTU 的输入, \(Y(X)\) 是 HSTU 的输出
    • \(A(X)\) 为输入 \(X\) 的注意力向量(注:这里没有做归一化),是带位置偏置 \(\text{rab}^{p,t}\) 的 \(Q,K\) 内积
  • HSTU的一层 与传统的 Transformer 相比,HSTU有以下不同:
    • HSTU 增加了一个 \(U(X)\),跟Q,K,V的 projection 在同一层处理,\(U(X)\) 会作为一个 point-wise的向量与 HSTU Attention 的结果相乘
    • HSTU 没有使用传统 Transformer 中的 Softmax Attention机制,放弃了 Attention 权重和为1的约束,文中提到,这能保留用户对 item 的动作强度信息(Softmax会使得这部分信息失真)
    • HSTU 没有使用 FFN 层,这一层被 \(f_2\left(\text{Norm}\left(A(X)V(X)\right) \odot U(X)\right)\) 替换了,实现了将 MLP 替换为 point-wise 乘法,速度更快
      • 理解1:这里更像是使用 \(U(X)\) 保留了原始特征,和 Attention 后的高阶交叉特征 point-wise 相乘,得到高阶交叉特征同时也保留低阶交叉;
      • 理解2:这里的操作可以看做是 对 \(U(X)\) 进行加权,如前文所述,像是门网络/个性化加权机制实现对原始 Embedding \(U(X)\) 的加权(可类比于MoE的思想)

逐点聚合注意力(Pointwise aggregated attention)

  • HSTU采用了一种新型的Pointwise aggregated attention(区别于传统的Softmax注意力)。其动机有二:
    • 第一:与目标相关的历史数据点数量是用户偏好强度的关键特征,而Softmax归一化会削弱这一信息(Softmax的归一化会将强度信息也归一化为比例,这会使得绝对值失真)
    • 第二:Softmax对噪声鲁棒,但在流式非静态词表场景中表现欠佳(频繁增加新的 item 会影响 Softmax)
  • 实验表明,在合成数据(基于 Dirichlet 过程生成的非静态词表流式数据)中,逐点聚合注意力机制相比Softmax注意力可将Hit Rate@10提升高达44.7%(见表2)

利用算法增加稀疏性

  • 在推荐系统中,用户历史序列的长度通常呈现偏态分布(skewed distribution,这里主要指大部分用户交互序列短,少部分用户交互序列长),导致输入序列具有稀疏性,尤其是在处理超长序列时。这种稀疏性可以被有效利用以显著提升编码器的效率。为此,论文开发了一种高效的GPU注意力核函数,其设计类似于[Dao等, 2022; Zhai等, 2023b],但实现了完全分组的注意力计算。这本质上将注意力计算转化为不同尺寸的分组GEMM(通用矩阵乘法)运算(详见附录G)。因此,HSTU中的自注意力变为内存受限(memory-bound)操作,其内存访问复杂度为 \(\Theta(\sum_{i}n_{i}^{2}d_{qk}^{2}R^{-1})\) ,其中 \(n_i\) 为样本 \(i\) 的序列长度, \(d_{qk}\) 为注意力维度, \(R\) 为寄存器大小。仅此一项改进即可带来2-5倍的吞吐量提升(详见第4.2节讨论)
  • 论文进一步通过随机长度(Stochastic Length, SL)算法增加用户历史序列的稀疏性。推荐场景中用户历史序列的一个关键特征是用户行为在时间上具有重复性,这种重复性体现在用户交互历史的多尺度行为模式中。这为论文提供了在不损失模型质量的前提下人为增加稀疏性的机会,从而显著降低编码器计算成本(其复杂度为 \(\Theta(\sum_{i}n_{i}^{2})\))
    • 问题:这种降低稀疏性的方式是有损的吧?此外,相同算力消耗下,随机长度是不是不如仅保留最近多个 item 的策略?
  • 假设用户 \(j\) 的历史序列为 \((x_i)_{i=0}^{n_{c,j} }\) ,其中 \(n_{c,j}\) 为用户交互过的内容数量。令 \(N_c = \max_j n_{c,j}\) , \((x_{i_k})_{k=0}^L\) 为从1原始序列 \((x_i)_{i=0}^{n_{c,j} }\) 中选取的长度为 \(L\) 的子序列。SL算法按以下方式选择输入序列:
    $$
    \begin{align}
    (x_i)_{i=0}^{n_{c,j} } \text{ if } n_{c,j} &\leq N_c^{\alpha/2}\\
    (x_{i_k})_{k=0}^{N_c^{\alpha/2} } \text{ if } n_{c,j} &> N_c^{\alpha/2}, \text{w/ probability } 1 - N_c^{\alpha}/n_{c,j}^2\\
    (x_i)_{i=0}^{n_{c,j} } \text{ if } n_{c,j} &> N_c^{\alpha/2}, \text{w/ probability } N_c^{\alpha}/n_{c,j}^2 \tag{4}
    \end{align}
    $$
    • 其中 “w/ probability xx” 是 “with probability”的简写,表示以 “xx” 概率采样
    • \(N_c^{\alpha/2}\) 表示 \(N_c\) 的 \(\alpha/2\) 次方, \(\alpha\) 是一个调节采样比例的数字
  • 公式4描述了Stochastic Length (SL)算法如何根据用户历史序列的长度(\(n_{c,j}\))动态调整子序列的采样策略:
    • 如果序列长度 \(n_{c,j} \leq N_c^{\alpha/2}\) ,则:
      • 直接使用完整序列
    • 如果序列长度 \(n_{c,j} > N_c^{\alpha/2}\) ,则
      • 以概率 \(1 - N_c^\alpha / n_{c,j}^2\) 选择长度为 \(N_c^{\alpha/2}\) 的子序列
      • 以概率 \(N_c^\alpha / n_{c,j}^2\) 仍使用完整序列
  • 该算法将注意力相关复杂度降低至 \(O(N_c^{\alpha}d) = O(N^{\alpha}d)\) ,其中 \(\alpha \in (1,2]\) 。关于子序列选择的更详细讨论见附录F.1。值得注意的是,将SL应用于训练能够实现高性价比的系统设计,因为训练的计算成本通常远高于推理
  • 表3展示了在具有30天用户历史的典型工业规模配置下,不同序列长度和 \(\alpha\) 值对应的稀疏性(定义见附录F)。其中对模型质量影响可忽略的设置用下划线和蓝色标出。标为“ \(\alpha=2.0\) ”的行表示未应用SL时的基准稀疏性。更低的 \(\alpha\) 值适用于更长的序列,论文测试的最长序列长度为8,192

最小化激活内存使用

  • 在推荐系统中,使用大的 Batch Size 训练对提升训练吞吐量和模型质量至关重要(理解:大的 Batch Size 更容易得到准确的梯度),因此,激活内存的使用成为扩展的主要瓶颈,这与通常以小的 Batch Size 训练且以参数内存为主的大型语言模型形成鲜明对比
  • 与 Transformer 相比,HSTU采用了一种简化且完全融合的设计,显著降低了激活内存的使用
    • 首先,HSTU将注意力外的线性层数量从 6 个减少到 2 个,这与近期通过逐元素门控减少MLP计算的工作[Hua等, 2022; Zhai等, 2023b]一致
    • 其次,HSTU将计算过程激进地融合为单一算子,包括公式(1)中的 \(\phi_1(f_1(\cdot))\) ,以及公式(3)中的层归一化、可选Dropout和输出MLP。这种简化设计将每层的激活内存使用量降低至 \(2d + 2d + 4hd_{qk} + 4hd_v + 2hd_v = 14d\) (以bfloat16格式存储)
    • 理解:这里的减少线性层数量主要是通过移除传统的FFN来实现(传统的一个FFN包含两个线性层),HSTU总体仅保留两个线性层,即输入层和输出层
    • 注:文章这里说的将线性层从 6 层降低到 2 层的表达应该是跳过了一些设计细节,比如网络结构中FFN的数量未明确
  • 作为对比, Transformer 在注意力后使用前馈层和Dropout(中间状态为 \(3hd_v\)),随后是一个逐点前馈块,包含层归一化、线性层、激活函数、线性层和Dropout,中间状态为 \(2d + 4d_{ff} + 2d + 1d = 4d + 4d_{ff}\) 。这里论文假设 \(hd_v \geq d\) 且 \(d_{ff} = 4d\) [Vaswani等, 2017; Touvron等, 2023a]。因此,在考虑输入和输入层归一化(\(4d\))以及QKV投影后,总激活状态为 \(33d\) 。HSTU的设计使其能够扩展到比 Transformer 深2倍以上的网络
  • 此外,用于表示词汇表的大规模原子ID也需要大量内存。对于包含100亿词汇、512维嵌入和Adam优化器的配置,仅存储嵌入和优化器状态(以fp32格式)就需要60TB内存。为缓解内存压力,论文采用了 row-wise AdamW 优化器并将优化器状态放置在DRAM中,从而将每个浮点数的HBM使用量从12字节减少到2字节
    • 注:论文给的引用中没有详细介绍 row-wise AdamW 优化器
    • HBM(High Bandwidth Memory,高带宽内存)和 DRAM(Dynamic Random Access Memory,动态随机存取存储器)都是计算机内存技术,HBM 将多个 DRAM 芯片堆叠在一起(HBM可以理解为更为昂贵的告诉内存),拥有比 DRAM 更高的数据传输性能
  • 总结来看 :HSTU通过以下设计显著降低内存占用:
    • 第一:将注意力外的线性层从6层减至2层,采用逐点门控减少MLP计算;
    • 第二:将计算融合为单一算子(如公式1中的 \( \phi_1(f_1(\cdot)) \) 和公式3中的层归一化、Dropout及输出MLP);
    • 第三:使用row-wise AdamW优化器,将优化器状态存储于DRAM,使每个浮点数在HBM的占用从12字节降至2字节
    • 最终,与 Transformer 相比,HSTU的激活内存占用从每层33d降至14d(bfloat16),支持构建2倍更深的网络

通过成本分摊扩展推理规模

  • 论文需要解决的最后一个挑战是推荐系统在服务时需要处理大量候选内容。论文主要关注 Ranking 任务 :
  • 在Retrieval 任务(无需论文关注)中,编码器成本完全可以分摊,且已有高效的算法支持基于量化(quantization)、哈希(hashing)或分区(partitioning)的最大内积搜索(MIPS),以及基于束搜索或分层检索的 non-MIPS 场景
    • MIPS 是 Maximum Inner Product Search(最大内积搜索)的缩写,主要用于高效地找到与给定查询向量内积最大的候选向量 ,在召回阶段,MIPS可以与量化、哈希或分区等技术结合使用,以提升效率
  • 对于Ranking 任务 ,候选内容数量可能高达数万。论文提出了一种算法M-FALCON(Microbatched-Fast Attention Leveraging Cacheable Operations)(可译为微批次-快速注意力利用可缓存操作),用于在输入序列长度为 \(n\) 时对 \(m\) 个候选内容进行推理
    • 批量推理 :在一次前向传播中,M-FALCON通过修改注意力掩码和 \(\text{rab}^{p,t}\) 偏置,使得 \(b_m\) 个候选内容的注意力操作完全一致,从而并行处理 \(b_m\) 个候选内容,此时原本的 \(b_m\) 次长度为 \(n\) 的推理转换为 1 次 \(b_m+n\) 的推理(注:这里有MTP的味道)
      • 这将交叉注意力的计算成本从 \(O(b_m n^2 d)\) 降低到 \(O((n + b_m)^2 d) = O(n^2 d)\) (当 \(b_m\) 相对于 \(n\) 可视为小常数时)
    • KV Caching机制 :论文还可以将全部 \(m\) 个候选内容划分为 \(\lceil m/b_m \rceil\) 个大小为 \(b_m\) 的微批次,以利用编码器级的KV缓存 ,从而在跨前向传播时降低成本,或在跨请求时最小化尾部延迟
    • 关于M-FALCON方法的更多细节见附录H
  • 总体而言,M-FALCON使得模型复杂度能够与传统DLRM排序阶段的候选数量线性扩展。论文成功地在典型的排序配置中(见第4.3节)应用了复杂度高28倍的目标感知交叉注意力模型 ,同时在恒定推理预算下实现了1.5倍至3倍的吞吐量提升
  • M-FALCON方法的整体结构图如下(From 附录):

Experiments

验证HSTU编码器的归纳假设

传统序列化设置
  • 论文首先在两个流行的推荐数据集(MovieLens和Amazon Reviews)上评估HSTU的性能。实验遵循文献中的序列化推荐设置,包括完全打乱和多轮次训练。基线采用最先进的 Transformer 实现SASRec(Kang & McAuley, 2018)。与近期工作(Dallmann等, 2021; Zhai等, 2023a)一致,论文报告整个语料库的Hit Rate@K和NDCG@K
  • 结果如表4所示。“SASRec (2023)”表示(Zhai等, 2023a)中报告的最佳SASRec配置。“HSTU”行使用与SASRec相同的配置(层数、头数等)。“HSTU-large”表示更大的HSTU编码器(层数4倍,头数2倍)。结果表明:
    • a)HSTU因其针对推荐优化的设计,在相同配置下显著优于基线;
    • b)HSTU在扩展后性能进一步提升
  • 需要注意的是,此处使用的评估方法与工业级设置差异显著,因为完全打乱和多轮次训练在工业级流式设置中通常不实用(Liu等, 2022)
工业级流式设置
  • 接下来,论文在工业级数据集的流式设置中比较HSTU、消融版HSTU和 Transformer 的性能。本节剩余部分中,论文报告排名的归一化熵(Normalized Entropy,NE)(详细定义见附录)。模型训练超过1000亿样本(DLRM等效),每个任务使用64-256块H100 GPU(!!!)。由于排名是多任务设置,论文报告主要互动事件(“E-Task”)和主要消费事件(“C-Task”)。在论文的上下文中,NE降低0.001通常意味着数十亿用户的顶层指标提升0.5%。对于 Retrieval 任务,由于设置与语言建模类似,论文报告对数困惑度。在小规模设置中固定编码器参数(排名任务: \(l=3\) , \(n=2048\) , \(d=512\) ; Retrieval 任务: \(l=6\) , \(n=512\) , \(d=256\)),并因资源限制对其他超参数进行网格搜索
  • 结果如表5所示:
    • 第一 :HSTU显著优于 Transformer,尤其是在排名任务中,这可能是由于点注意力机制和改进的相对注意力偏差
    • 第二 :消融版HSTU与完整HSTU之间的差距验证了论文的设计有效性。Softmax版HSTU和 Transformer 的最佳学习率比其他模型低约10倍,这是为了训练稳定性。即使采用较低学习率和预归一化残差连接(Xiong等, 2020),标准 Transformer 在排名任务中仍频繁出现损失爆炸
    • 第三 :HSTU优于LLM中流行的 Transformer 变体 Transformer ++(Touvron等, 2023b), Transformer ++ 使用了RoPE、SwiGLU等技术。总体而言,在此小规模设置中,HSTU在质量上表现更优,同时训练速度提升1.5-2倍,HBM使用减少50%

编码器效率

  • 随机长度(Stochastic Length, SL)。图4和图5(a)展示了随机长度(SL)对模型指标的影响
    • 当 \(\alpha=1.6\) 时,长度为 4096 的序列在大多数情况下被压缩为 776,移除了超过 80% 的 token。即使 sparse 率增加到 64%-84%,主要任务的NE下降不超过 0.002(0.2%)
    • 以上这一证据表明,对于合适的 \(\alpha\) 值,SL 不会对模型质量产生负面影响,同时可以通过高 sparse 性降低训练成本。论文还在附录F.3中验证了 SL 显著优于现有的长度外推技术
  • 编码器效率。图5比较了HSTU和 Transformer 编码器在训练和推理设置中的效率
    • 对于 Transformer,论文使用最先进的FlashAttention-2(Dao, 2023)实现
    • 考虑序列长度从 1,024 到 8,192,并在训练中应用随机长度(SL)
    • 评估时,HSTU和 Transformer 使用相同配置(\(d=512\) , \(h=8\) , \(d_{qk}=64\)),并消融相对注意力偏差(如第4.1.2节所示,HSTU在不使用rab \(^{p,t}\) 时仍优于 Transformer)。在NVIDIA H100 GPU上以bfloat16比较编码器级性能。总体而言,HSTU在训练和推理中的效率分别比 Transformer 高15.2倍和5.6倍
  • 此外,如第3.3节讨论的激活内存使用减少,使论文能够构建比 Transformer 深2倍以上的HSTU网络

工业级流式设置中生成式推荐器(GR)与 DLRM 的比较

  • 最后,论文比较了GR与最先进的 DLRM 基线在工业级流式设置中的端到端性能。论文的GR实现反映了生产中使用的典型配置,而 DLRM 设置则是数百人多年迭代的结果
    • 由于推荐系统的召回阶段使用了多个生成器(理解:这里应该是指多路召回),论文报告了将 GR 作为新增召回通道(“new source”)和替换现有主要 DLRM 通道(“replace”)的在线结果。表6和表7显示,GR不仅在离线测试中显著优于 DLRM,还在A/B测试中带来了12.4%的提升
    • 从表6中可知,作为一个召回通道加入能得到的更好的效果
  • 如第2节所述,GR基于原始分类互动特征构建,而 DLRM 通常使用更多手工特征(handcrafted features)。如果论文将GR使用的相同特征集提供给 DLRM (“DLRM(abl. features)”,消融特征),DLRM 的性能会显著下降,这表明GR通过其架构和统一特征空间可以有效地捕捉这些特征
  • 论文进一步通过与传统序列推荐器设置的比较验证了第2.2节中的GR公式。传统设置仅考虑用户交互过的物品(Kang & McAuley, 2018),结果显示其性能显著下降(表6和表7中的“GR(interactions only)”行)。论文还包含了一个仅使用内容特征的GR基线(“GR(content-based)”)。内容基线与 DLRM /GR之间的巨大性能差距凸显了高基数用户行为的重要性
  • 图6比较了GR与生产 DLRM 的效率。尽管GR模型的复杂度是 DLRM 的285倍(其中 \(285\times \text{FLOPs}\) 表示其复杂度是DLRM的285倍),但由于HSTU和第3.4节中的M-FALCON算法,论文在评分1024/16384个候选时实现了1.50倍/2.99倍的更高QPS
推荐系统的 scaling law
  • 众所周知,在大规模工业设置中,DLRM 在特定计算和参数范围内会达到质量饱和(Zhao等, 2023)。论文比较了GR和 DLRM 的可扩展性以更好地理解这一现象
  • 由于特征交互层对 DLRM 性能至关重要(Mudigere等, 2022),论文尝试了 Transformer (Vaswani等, 2017)、DHEN(Zhang等, 2022)以及论文生产设置中使用的带残差连接的DCN变体(He等, 2015)来扩展 DLRM 基线
    • 对于召回基线,论文通过调整隐藏层大小、嵌入维度和层数来扩展模型
    • 对于基于HSTU的生成式推荐器(GR),论文通过调整HSTU的超参数(如残差层数、序列长度、嵌入维度、注意力头数等)扩展模型,并调整 Retrieval 任务的负样本数量
  • 结果如图7所示:
    • 在低计算量区域,由于手工特征(handcrafted features)的存在,DLRM 可能优于GR,这验证了传统 DLRM 中特征工程的重要性
    • 然而,GR在FLOPs方面表现出显著更好的可扩展性(scalability),而 DLRM 性能趋于饱和,这与先前的研究结果一致
    • 论文还观察到,GR在嵌入参数和非嵌入参数方面均具有更好的可扩展性,其模型参数达到 1,500B ,而 DLRM 在约 200B 参数时性能饱和(这个参数量的训练,Meta真够有钱!)
  • 最后,论文的所有主要指标(包括 Retrieval 任务的 Hit Rate@100 和 Hit Rate@500,以及排名任务的NE)在适当的超参数下 ,随计算量的增加呈现幂律扩展(power law of compute used)
    • 这一现象在三个数量级范围内均成立,直至论文测试的最大模型(序列长度8,192,嵌入维度1,024,24层HSTU)
    • 此时,论文使用的总计算量(按365天标准化,因为论文采用标准流式训练设置)接近GPT-3(Brown等, 2020)和LLaMa-2(Touvron等, 2023b)的总训练计算量,如图1所示
    • 与语言建模(Kaplan等, 2020)相比,序列长度在GR中扮演更重要的角色,因此需要同步扩展序列长度和其他参数
    • 这或许是论文提出的方法最重要的优势,因为论文首次证明了 LLM 的 scaling law 可能也适用于大规模推荐系统

相关工作

传统序列推荐方法

  • 先前关于序列推荐的研究将用户交互行为简化为单一同质的物品序列(Hidasi等,2016;Kang & McAuley,2018)。工业级应用中,序列方法主要作为深度推荐模型(DLRMs)的一部分,包括成对注意力(Zhou等,2018)或序列编码器(Chen等,2019;Xia等,2023)。为了提高效率,已有研究探索了多阶段注意力替代自注意力机制(Chang等,2023)。在 Retrieval 任务中,将ID表示为 token 序列的生成方法也有所探索(Zhuo等,2020)。附录B.1中对相关工作进行了更详细的讨论

高效注意力机制

  • 由于自注意力机制的 \(O(n^2)\) 复杂度,高效注意力一直是研究热点,主要工作包括因子分解注意力(Child等,2019)、低秩近似(Katharopoulos等,2020)等。近期,针对序列转导任务的新架构也被提出(Gu等,2022;Hua等,2022)。HSTU的逐元素门控设计尤其受到FLASH(Hua等,2022)的启发。硬件感知的优化方法显著降低了内存使用(Rabe & Staats,2021;Korthikanti等,2022;Zhai等,2023b),并大幅缩短了实际运行时间(Dao等,2022)

长度外推技术

  • 长度外推技术使模型能够在较短的序列上训练后泛化到更长的序列,但多数研究聚焦于微调或改进偏置机制(Press等,2022)。论文则通过在长度维度引入随机性,受深度维度随机性研究(Huang等,2016)的启发

大语言模型(LLMs)的应用

  • 大语言模型的兴起促使研究者将推荐任务视为上下文学习(Silco等,2022)、指令微调(Bao等,2023)或基于预训练LLMs的迁移学习(Li等,2023)。LLMs中嵌入的世界知识可以迁移到下游任务(Cui等,2022),并在零样本或少样本场景中提升推荐效果。用户行为序列的文本表示也在中等规模数据集上展现出良好的扩展性(Shin等,2023)。然而,多数关于LLMs用于推荐的研究集中在低数据量场景;在大规模场景中,它们尚未在MovieLens等数据集上超越协同过滤方法(Hou等,2024)

整体总结

  • 新范式提出 :论文提出了一种新范式:生成式推荐器(Generative Recommenders, GRs),将排序和 Retrieval 任务重新定义为序列直推式任务(Sequential Transduction Tasks),使其能够以生成式方式进行训练
    • 注:最核心的创新点可能是将用户行为也视为一个像是图片、视频和文本类似的新模态
  • 性能优化 :得益于新颖的HSTU编码器设计,其在8192长度的序列上比当前最先进的 Transformer 快5.3至15.2倍,同时结合了M-FALCON等新型训练和推理算法
  • 实验 :通过GRs,论文部署了复杂度提升285倍的模型,同时减少了推理计算量。GRs和HSTU在生产环境中实现了12.4%的指标提升,并展现出比传统 DLRMs 更优的扩展性能
    • 论文的结果验证了用户行为是生成式建模中一个尚未充分探索的模态——正如标题所言,Actions speak louder than words(行动胜于雄辩)
  • 论文通过统一特征空间的简化设计,为推荐、搜索和广告领域的首个基础模型铺平了道路。GRs的完全序列化设置还支持端到端的生成式推荐框架。这两点使得推荐系统能够更全面地辅助用户
  • 未来规划 :生成式推荐器的完全序列化特性有望进一步推动端到端推荐系统的发展,例如通过直接生成推荐序列而非传统的列表排序。作者相信,这一方向将为推荐系统带来更广阔的应用前景

附录(个人): Transformer 模型计算量评估

  • 在原论文2.3节的公式 \(\sum_{i} n_{i}(n_{i}^{2}d + n_{i}d_{ff}d)\) 的来源解释如下

  • 注: \(d_{ff}\) 表示 前馈神经网络(Feed-Forward Network, FFN)的隐藏层维度(即中间层的神经元数量),在原始 Transformer 中通常设为 \(d_{ff} = 4d\)

  • 在 Transformer 或类似的自注意力架构(如HSTU)中,每个编码器层通常包含两个核心子层:

    • 自注意力机制(Self-Attention):复杂度为 \(O(n_i^2 d)\) (与序列长度平方和嵌入维度相关)
      • 自注意力的计算涉及三个主要步骤:
        • 计算查询(Q)、键(K)、值(V)矩阵:复杂度为 \(O(n_i d^2)\) ,其中 \(n_i\) 是序列长度, \(d\) 是嵌入维度
        • 计算注意力分数 \(QK^T\) :复杂度为 \(O(n_i^2 d)\)
        • 加权求和 \(AV\) :复杂度为 \(O(n_i^2 d)\)
      • 因此,自注意力的总复杂度为 \(O(n_i^2 d + n_i d^2)\) 。对于长序列(\(n_i \gg d\)),主导项为 \(O(n_i^2 d)\)
    • 前馈神经网络(FFN):复杂度为 \(O(n_i d_{ff} d)\) (与序列长度、嵌入维度和FFN隐藏层维度相关,)
    • FFN的典型结构是:
      $$
      \text{FFN}(x) = W_2 \cdot \text{ReLU}(W_1 x + b_1) + b_2
      $$
      • 其中:
        • \(W_1 \in \mathbb{R}^{d \times d_{ff} }\) :将输入从 \(d\) 维映射到 \(d_{ff}\) 维(通常 \(d_{ff} = O(d)\) ,注意不是相等,一般是通常 \(d_{ff} = 4d\) 等)
        • \(W_2 \in \mathbb{R}^{d_{ff} \times d}\) :将结果映射回 \(d\) 维
    • 在 Transformer 中, \(d_{ff}\) 通常远大于 \(d\) (例如 \(d_{ff}=4d\)),这使得FFN成为计算瓶颈之一。HSTU通过点乘门控(如 \(U(X)\))替代了传统FFN(见式3),因此可能不再需要显式的 \(d_{ff}\) 参数。但在对比 Transformer 基线时,公式中仍保留了 \(d_{ff}\) 以反映传统架构的成本
  • 总计算复杂度

    • 对于每个用户的序列,计算量为自注意力和前馈网络复杂度的总和:
      $$
      O(n_i^2 d + n_i d^2)
      $$
    • 若对所有用户求和,总计算量为:
      $$
      \sum_i n_i (n_i^2 d + n_i d^2)
      $$
    • 假设最大序列长度为 \(N = \max_i n_i\) ,则复杂度可简化为:
      $$
      O(N^3 d + N^2 d^2)
      $$
  • 论文中的其他优化:论文提出通过生成式训练(generative training)对计算量进行优化:

    • 通过调整用户采样率 \(s_u(n_i) \propto 1/n_i\) ,将总计算量降低为:
      $$
      \sum_i s_u(n_i) n_i (n_i^2 d + n_i d^2) \approx O(N^2 d + N d^2)
      $$
    • 这一优化将复杂度从三次方(\(N^3\))降至二次方(\(N^2\))

附录(个人):关于NE的完整定义

  • 参考原始论文:Practical Lessons from Predicting Clicks on Ads at Facebook
  • NE(Normalized Cross Entropy,归一化交叉熵),也可更准确地称为Normalized Logarithmic Loss(归一化对数损失),其定义为平均每个展示的对数损失除以模型预测每个展示的 CTR 时的平均对数损失
  • 假设给定的训练数据集有 \(N\) 个样本,则计算公式为:
    $$NE=\frac{-\frac{1}{N} \sum_{i=1}^{n}\left(\frac{1+y_{i} }{2} log \left(p_{i}\right)+\frac{1-y_{i} }{2} log \left(1-p_{i}\right)\right)}{-(p * log (p)+(1-p) * log (1-p))} $$
    • \(y_{i} \in{-1,+1}\) 为真实标签
    • \(p_{i}\) 为第 \(i\) 个样本的CTR预估值,其中 \(i = 1,2,\cdots,N\)
    • \(p\) 平均经验点击率(基于统计的真实点击率)
  • 该指标用于评估模型预测的好坏,其值越低,模型的预测效果越好
  • 进行这种归一化处理的原因是,真实 CTR 预估值越接近0或1,就越容易获得更好的对数损失,而除以 真实 CTR 的熵可使 NE 对真实 CTR 不敏感
  • 其他相关指标:Relative Information Gain (RIG)
    $$ RIG = 1 − NE $$

附录A:符号说明

  • 论文在下表中总结了论文使用的关键符号(对应原始论文的表8和表9)
    符号 描述
    \(\Psi_{k}(t_{j})\) Feature Logging 系统在 \(t_{j}\) 时刻发出的第 \(k\) 个 Training Example(\(k\) 是全局排序的)
    在典型的深度推荐模型(DLRM)推荐系统中,用户消费某些内容 \(\Phi_{i}\) (通过诸如跳过、视频观看完成和分享等动作 \(a_{i}\) 做出响应)后,特征记录系统将元组 \((\Phi_{i}, a_{i})\) 与用于对 \(\Phi_{i}\) 进行排名的特征相结合,并发出 \((\Phi_{i}, a_{i}\) ,以及 \(\Phi_{i}\) 的特征) 作为 Training Example \(\Psi_{k}(t_{j})\)
    如2.3节所述,DLRM和生成式推荐器(GRs)处理的 Training Example 数量不同,GRs中的 Example 数量通常少1 - 2个数量级
    \(n_{c}(n_{c,i})\) 与用户/样本 \(i\) 交互的内容数量
    \(\Phi_{0}, …, \Phi_{n_{c}-1}\) 在推荐系统的上下文中,与用户交互的内容列表
    \(a_{0}, …, a_{n_{c}-1}\) 与内容 \(\Phi_{i}\) 对应的用户动作列表
    当所有预测事件都是 Binary 时,每个动作可以被视为一个 multi-hot vector 事件(如点赞、分享、评论、图片浏览、视频初始化、视频观看完成、隐藏等)
    \(E, F\) 图2中DLRM中的分类特征
    \(E_{0}, E_{1}, …, E_{7}, E_{8}\) 和 \(F_{0}, F_{1}, …, F_{7}\) 表示在不同时间点通过特征提取(例如,最近喜欢的10张图片、用户过去点击的与当前候选内容最相似的50个网址等)从 \((\Phi_{0}, a_{0}, t_{0}), …, (\Phi_{n_{c}-1}, a_{n_{c}-1}, t_{n_{c}-1})\) 获得的转换结果
    “merge & sequentialize” 表示获取原始参与系列 \((\Phi_{0}, a_{0}, t_{0}), …, (\Phi_{n_{c}-1}, a_{n_{c}-1}, t_{n_{c}-1})\) 的(虚拟)反向过程
    \(G, H\) 图2中DLRM中与 user-item 参与无关的分类特征。这些特征(例如人口统计信息或关注的创作者)被合并到主时间序列(用户参与的内容列表,例如 \(\Phi_{0}, a_{0}, …, \Phi_{n_{c}-1}, a_{n_{c}-1}\))中,如2.1节所述并在图2中说明
    \(n(n_{i})\) Squential Transduction 任务中的 token 数量(对于用户或样本 \(i\))。虽然 \(O(n) = O(n_{c})\) ,但即使没有任何与非交互相关的分类特征, \(n\) 也可能与 \(n_{c}\) 不同;例如,见表1
    \(x_{0}, …, x_{n - 1}\) Squential Transduction 任务中的输入 token 列表
    \(y_{0}, …, y_{n - 1}\) Squential Transduction 任务中的输出 token 列表
    \(t_{0}, …, t_{n - 1}\) 与观察到 \(x_{0}, …, x_{n - 1}\) 的时间对应的时间戳列表
    \(\mathbb{X}, \mathbb{X}_{c}\) 所有输入/输出 token 的词汇表(\(\mathbb{X}\))及其内容子集(\(\mathbb{X}_{c}\))
    \(N, N_{c}\) \(\max_{i}n_{i}\) , \(\max_{i}n_{c,i}\)
    \(u_{t}\) 在时间 \(t\) 的用户表示
    \(s_{u}(n_{i}), \hat{s}_{u}(n_{i})\) 生成式训练(2.3节)中用于用户 \(i\) 的采样率
    \(d\) 模型维度(嵌入维度)
    \(d_{qk}\) HSTU和 Transformer 中注意力维度的大小。这适用于公式(1)中的 \(Q(X)\) 和 \(K(X)\)
    \(d_{v}\) HSTU中 Value 维度的大小。对于 Transformer,论文通常有 \(d_{qk} = d_{v}\)
    \(d_{ff}\) Transformer pointwise 前馈层中(feedforward)的隐藏维度大小。HSTU不使用前馈层;见下面的 \(U(X)\)
    \(h\) 注意力头的数量
    \(l\) HSTU中的层数。对于 Transformer,注意力层和逐点前馈层共同构成一层
    \(Q(X), K(X), V(X)\) 根据公式(1)为给定输入 \(X\) 在HSTU中获得的查询、键、值(Query/Key/Value)。其定义与标准 Transformer 中的 \(Q, K, V\) 类似。 \(Q(X), K(X) \in \mathbb{R}^{h \times N \times d_{qk} }\) ,并且HSTU使用 \(U(X)\) (与 \(f_{2}(·)\) 一起)在公式(3)中 “门控” 注意力池化值(\(V(X)\)),这使得HSTU完全避免了前馈层。 \(U(X) \in \mathbb{R}^{h \times N \times d_{v} }\)
    \(A(X)\) 为输入 \(X\) 获得的注意力张量。 \(A(X) \in \mathbb{R}^{h \times N \times N}\)
    \(Y(X)\) HSTU层针对输入 \(X\) 的输出。 \(Y(X) \in \mathbb{R}^{d}\)
    \(Split(·)\) 将张量分割成块的操作。 \(Split(\phi_{1}(f_{1}(X))) \in \mathbb{R}^{N \times (2hd_{qk} + 2hd_{v})}\) ;论文通过分割较大的张量(并置换维度)获得 \(U(X), V(X)\) (两者形状均为 \(h \times N \times d_{v}\))、 \(Q(X), K(X)\) (两者形状均为 \(h \times N \times d_{qk}\))
    \(rab^{p,t}\) 结合了位置(Raffel等,2020)和时间(基于观察到 token 的时间 \(t_{0}, …, t_{n - 1}\) ;一种可能的实现方式是对 \((t_{j} - t_{i})\) 应用某种分桶函数得到 \((i, j)\))信息的相对注意力偏差。在实践中,论文在一层内的不同注意力头之间共享 \(rab^{p,t}\) ,因此 \(rab^{p,t} \in \mathbb{R}^{1 \times N \times N}\)
    \(\alpha\) 控制HSTU中随机长度算法(3.2节)稀疏性的参数
    \(R\) GPU上的寄存器大小,在3.2节讨论的HSTU算法的上下文中
    \(m\) 推荐系统Ranking 阶段考虑的候选数量
    \(b_{m}\) M-FALCON 算法(3.4节)中的微批次大小

附录B:生成式推荐器:背景与公式

  • 许多读者可能对经典的深度学习推荐模型(DLRM)更为熟悉(Mudigere等,2022),因为自YouTube DNN时代起它就颇受欢迎(Covington等,2016),并且在每个大型在线内容和电子商务平台上都得到了广泛应用(Cheng等,2016;Zhou等,2018;Wang等,2021;Chang等,2023;Xia等,2023;Zhai等,2023a)。DLRM在异构特征空间上运行,使用各种神经网络,包括特征交互模块(Guo等,2017;Xiao等,2017;Wang等,2021)、顺序池化或目标感知成对注意力模块(Hidasi等,2016;Zhou等,2018;Chang等,2023)以及先进的多专家多任务模块(Ma等,2018;Tang等,2020)。因此,论文在第2节和第3节中通过将生成式推荐器(GRs)与经典DLRM进行明确对比,概述了GRs。在本节中,论文从经典的顺序推荐文献出发,为读者提供另一种视角

附录B.1:背景:学术界和工业界的顺序推荐

附录B.1.1:学术研究(传统顺序推荐设置)
  • 循环神经网络(RNNs)最早在GRU4Rec(Hidasi等,2016)中应用于推荐场景。Hidasi等(2016)考虑了门控循环单元(GRUs),并将其应用于两个数据集,即RecSys Challenge 2015和VIDEO(一个专有数据集)。在这两种情况下,只有 positive 事件(点击的电子商务商品或用户观看至少一定时间的视频)被保留作为输入序列的一部分。论文进一步观察到,在由检索和Ranking 阶段组成的经典工业规模两阶段推荐系统设置中(Covington等,2016),Hidasi等(2016)解决的任务主要对应于检索任务
  • 后来, Squential Transduction 架构的进步,特别是 Transformer (Vaswani等,2017),推动了推荐系统的类似进展。SASRec(Kang和McAuley,2018)首次在自回归设置中应用 Transformer。他们将评论或评分的存在视为 positive 反馈 ,从而将像亚马逊评论和MovieLens这样的经典数据集转换为 positive item 的序列,类似于GRU4Rec。采用二元交叉熵损失,其中正目标定义为下一个 “positive” item (回想一下,这本质上只是评论或评分的存在),负目标从 item 语料库 \(\mathbb{X}=\mathbb{X}_{c}\) 中随机采样
  • 随后的大多数研究都基于与上述GRU4Rec(Hidasi等,2016)和SASRec(Kang和McAuley,2018)类似的设置,例如BERT4Rec(Sun等,2019)应用来自BERT(Devlin等,2019)的双向编码器设置,S3Rec(Zhou等,2020)引入明确的预训练阶段,等等
附录B.1.2:作为深度学习推荐模型(DLRM)一部分的工业应用
  • 顺序方法,包括顺序编码器和成对注意力模块,由于能够作为DLRM的一部分增强用户表示,已在工业环境中得到广泛应用。DLRM通常使用相对较短的序列长度 ,例如BST(Chen等,2019)中为20,DIN(Zhou等,2018)中为1000,TransAct(Xia等,2023)中为100。论文观察到,论文中的8192比传统DLRM序列长度大 1-3 个数量级
  • 尽管使用短序列长度,大多数DLRM仍能成功捕捉长期用户偏好。这可归因于两个关键方面:
    • 第一 :现代DLRM中通常使用预计算的用户配置文件/嵌入或外部向量存储(Chang等,2023),这两者都有效地扩展了回顾窗口
      • 理解,用户 Embedding 其实与用于历史长期用户偏好有关?
    • 第二 :通常会采用大量的上下文、用户和 item 侧特征,并且使用各种异构网络,如因子分解机(FMs)、深度交叉网络(DCNs)、混合专家(MoEs)等,来转换表示并组合输出
      • 理解:特征交叉能泛化到用户长期偏好?
  • 与附录B.1.1中讨论的顺序设置相比,所有主要的工业界工作都在(用户/请求,候选 item)对上定义损失。在排名设置中,通常使用多任务二元交叉熵损失。在检索设置中,双塔设置(Covington等,2016)仍然是主导方法。最近的工作研究了将下一个推荐 item 表示为(子) token 序列上的概率分布,如OTM(Zhuo等,2020)和DR(Gao等,2021)(注意,在其他近期工作中,相同的设置有时被称为 “生成式召回”)。它们通常利用 beam search 从子 token 中解码 item。随着现代加速器(如GPU、定制ASIC和TPU)的普及,还提出并部署了先进的学习相似性函数,如混合逻辑(Zhai等,2023a),作为双塔设置和 beam search 的替代方案
  • 从问题公式化的角度来看,考虑到模型架构、使用的特征和损失与附录B.1.1中讨论的学术顺序推荐研究有显著差异,作者认为上述所有工作都属于DLRM(Mudigere等,2022)的一部分。值得注意的是,在这项工作之前,工业界尚未成功应用完全顺序排名设置,特别是在每日活跃用户(DAU)达到数十亿规模的情况下

附录B.2:公式化表述:生成式推荐器(GRs)中作为序列转换任务的排序与检索

  • 接下来,论文讨论传统序列推荐器设置和深度推荐模型(DLRM)设置中的三个局限性,以及生成式推荐器(GRs)如何从问题公式化的角度解决这些问题
  • 问题1:忽略用户交互物品之外的特征 :以往的序列公式化表述仅考虑用户明确交互过的内容(物品) ,而在GRs出现之前,传统推荐系统会基于大量特征进行训练 ,以增强用户和内容的表征 ,GRs通过以下方式解决这一局限性:
    • a)压缩其他分类特征并将其与主时间序列合并;
    • b)如2.1节和图2所述,利用目标感知公式通过交叉注意力交互来捕捉数值特征。论文通过实验验证了这一点,结果表明,忽略这些特征的传统 “interactions only” 公式化表述会显著降低模型质量;实验结果可在表7和表6中标记为 “GR(interactions only)” 的行中找到,论文发现仅利用交互历史会导致检索的 HitRate@100 下降1.3%,排序的归一化熵(NE)下降2.6%(回想一下,如4.1.2节和4.3.1节所述,NE变化0.1% 即为显著变化)
  • 问题2:用户表征在与目标无关的设置中计算 :大多数传统序列推荐器,包括GRU4Rec(2016)、SASRec(2018)、BERT4Rec(2019)、S3Rec(2020)等,都是以与目标无关的方式构建的。在这种方式中,对于目标物品 \(\Phi_{i}\) , \(\Phi_{0}, \Phi_{1}, …, \Phi_{i - 1}\) 被用作编码器输入来计算用户表征,然后用于进行预测。相比之下,工业环境中使用的大多数主要DLRM方法在构建所使用的序列模块时考虑了目标感知,能够将 “目标”(排序候选)信息整合到用户表征中。这些方法包括DIN(2018)(阿里巴巴)、BST(2019)(阿里巴巴)、TWIN(2023)(快手)和TransAct(2023)(Pinterest)
    • 生成式推荐器(GRs)通过交错内容和动作序列(2.2节),在因果自回归设置中实现目标感知注意力机制 ,结合了两者的优点(序列推荐器和DLRM的优点)。论文在表10中对先前的工作和本工作进行了分类和对比
    • 理解:(详情见前文对论文 target-aware 实现方式的理解),其实 Transformer 的 Attention 也有交叉功能,只要输入端有目标 item 和用户历史交互 item 即可,但论文所说的传统的自回归模型是指输入侧不包含目标 item 的情况,论文中,将动作也建模进去,则在输出对目标 item 的动作 token 时,自然就需要将目标 item 作为输入,从而也就实现了 target-aware 交叉注意力:
      • 传统自回归模型预估目标为 : \(p(\Phi_{i+1}|\Phi_{0},\Phi_{1},\ldots,\Phi_{i})\),目标物品 \(\Phi_{i+1}\) 与其他历史序列无交叉
        • 注:此时动作 \(a_i\) 信息可能会作为额外表征加入到 \(\Phi_{i}\)中,所以 表述为 \(p(\Phi_{i}|(\Phi_{0}, a_{0}), …, (\Phi_{i - 1}, a_{i - 1}))\) 也可以
      • 论文预估目标为 : \(p(a_{i+1}|\Phi_{0},a_{0},\Phi_{1},a_{1},\ldots,\Phi_{i+1})\),目标物品 \(\Phi_{i+1}\) 与其他历史序列有交叉
  • 问题3:判别式公式限制了先前序列推荐器工作的适用性 :传统的序列推荐器本质上是判别式的。现有的序列推荐文献,包括GRU4Rec和SASRec等开创性工作,对 \(p(\Phi_{i}|\Phi_{0}, a_{0}, …, \Phi_{i - 1}, a_{i - 1})\) 进行建模(注:虽然这里传统这里输入中包含动作信息,但是实际上动作是作为额外的信息加入到 item 表征中的,不是单独将 action 作为一个 token,所以写成 \(p(\Phi_{i}|(\Phi_{0}, a_{0}), …, (\Phi_{i - 1}, a_{i - 1}))\) 更合适),即根据用户当前状态推荐下一个物品的条件分布
    • 推荐系统中实际上存在两个概率过程 :
      • 1)推荐系统向用户推荐内容 \(\Phi_{i}\) 的过程;
      • 2)用户通过某些动作 \(a_{i}\) 对推荐内容 \(\Phi_{i}\) 做出反应的过程
    • 生成式方法对推荐内容和用户动作序列的联合分布进行建模 ,如2.2节所述,即联合概率分布 \(p(\Phi_{0}, a_{0}, \Phi_{1}, a_{1}, …, \Phi_{n_{c}-1}, a_{n_{c}-1})\)(如表11(图8)所示,论文提出的生成式推荐器能够对这种分布进行建模):
      • Next action token (\(a_{i}\)) prediction 任务正是 Ranking 任务(即表1中讨论的 GR Ranking 设定)
      • Next content token (\(\Phi_{i}\)) prediction 任务对应 Retrieval 任务,目标是学习 next item
  • 重要的是,这种公式化表述不仅能够对数据分布进行适当建模,还能够通过例如 beam search 直接采样要推荐给用户的物品序列。论文假设这将产生一种比传统列表式设置(例如DPP(2014)和RL(2018))更优越的方法,论文将此类系统的完整公式化表述和评估(在6节中简要讨论)留作未来的工作

附录C:评估:合成数据

  • 如3.1节先前所述,标准的softmax注意力机制因其归一化因子,难以捕捉用户偏好的强度,而这对于用户表示学习至关重要。在推荐场景中,这一点很关键,因为系统可能不仅需要预测 item 的相对排序,还需要预测用户参与的强度(例如,未来对特定主题的 positive 行为数量)
  • 为理解这种行为,论文构建了遵循狄利克雷过程(Dirichlet Process)的合成数据,该过程在动态词汇集上生成流式数据。狄利克雷过程捕捉了用户参与历史中 “富者更富”(rich gets richer) 的行为。论文设置合成实验如下:
    • 论文将 20,000 个 item ID中的每一个随机分配到100个类别中的某一个
    • 论文生成 1,000,000 条长度为 128 的记录,其中前90%用于训练,最后10%用于测试。为模拟流式训练设置,论文最初提供40%的 item ID,其余的以相等的间隔逐步提供;即在记录 500,000 时,可以采样的最大ID是 \((40% + 60% * 0.5) * 20,000 = 14,000\)
    • 论文为每条记录从100个类别中随机选择最多5个类别,并为这5个类别随机采样一个先验 \(H_{c}\) 。按照狄利克雷过程,为每个位置顺序采样类别,具体如下:
      • 对于 \(n>1\) :
        • 以概率 \(\alpha / (\alpha + n - 1)\) ,从 \(H_{c}\) 中抽取类别 \(c\)
        • 以概率 \(n_{c} / (\alpha + n - 1)\) ,抽取类别 \(c\) ,其中 \(n_{c}\) 是先前具有类别 \(c\) 的 item 数量
        • 随机采样一个符合类别 \(c\) 且受流式约束的分配 item
      • 其中 \(\alpha\) 从(1.0, 500.0)中均匀随机采样
  • 结果见表2。由于此数据集没有时间戳,论文在HSTU中去除 \(rab^{p, t}\)。论文观察到,相对于标准 Transformer,HSTU的 HitRate@10 提高了100%以上。重要的是,将HSTU的逐点聚合注意力机制(Pointwise aggregated attention)替换为 Softmax(“HSTU w/ Softmax”)也会导致 HitRate 显著降低,这验证了类似逐点聚合注意力机制的重要性

附录D:评估:传统顺序推荐器设置

  • 论文在4.1.1节的评估重点是将HSTU与最先进的 Transformer 基线SASRec进行比较,使用最新的训练方法。在本节中,论文进一步考虑另外两种替代方法
  • 循环神经网络(RNNs)。论文考虑顺序推荐器的经典工作GRU4Rec(Hidasi等,2016),以帮助读者理解包括 Transformer 和HSTU在内的自注意力模型,在充分融入最新的建模和训练改进后,与传统RNNs相比如何
  • 自监督顺序方法。论文考虑最受欢迎的工作BERT4Rec(Sun等,2019),以了解双向自监督(BERT4Rec通过完形填空目标利用)与单向因果自回归设置(如SASRec和HSTU)相比如何
  • 结果见表12。论文重用Klenitskiy和Vasilev(2023)报告的BERT4Rec和GRU4Rec在ML-1M和ML-20M上的结果。由于使用了采样softmax损失,论文保持负样本数量不变(ML-1M、ML-20M为128,亚马逊图书为512),以确保方法之间的公平比较
  • 结果证实,在使用采样softmax损失的传统顺序推荐设置中,SASRec仍然是最具竞争力的方法之一(Zhai等,2023a;Klenitskiy和Vasilev,2023),而HSTU显著优于所评估的 Transformer 、RNN和自监督双向 Transformer

附录E:评估:传统DLRM基线

  • 第4节中使用的DLRM基线配置反映了数百名研究人员和工程师多年来的持续迭代,并且在部署HSTU/GR之前,是对拥有数十亿日活跃用户的大型互联网平台上生产配置的近似。本节对所使用的模型进行简要描述

排名设置

  • 如(Mudigere等,2022)所述,基线排名模型采用了大约一千个密集特征和五十个稀疏特征。论文结合了各种建模技术,如 Mixture of Experts(Ma等,2018)、Deep & Cross Network 的变体(Wang等,2021)、各种顺序推荐模块,包括 target-aware pairwise 注意力(在工业设置中常用的一种变体可参见(Zhou等,2018)),以及特殊交互层上的残差连接(He等,2015;Zhang等,2022)。在 scaling law 部分(4.3.1节)的低FLOP区域,一些计算成本高的模块被简化、替换为其他 state-of-the-art 的变体,如DCN,以达到所需的FLOP
  • 由于保密考虑,论文无法透露确切设置 ,但据论文所知,在充分纳入最新研究成果后,论文的基线代表了最优秀的DLRM方法之一。为验证这一说法并帮助读者理解,论文在表7中报告了基于相同特征,但仅利用主要已发表成果(包括DIN(Zhou等,2018)、DCN(Wang等,2021)和MMoE(Ma等,2018))的典型设置(“DLRM (DIN + DCN)”),并在图9中展示了组合架构。该设置在主要E任务的NE上比论文的生产DLRM设置低0.71%,在主要C任务的NE上低0.57%(0.1%的NE变化即具有显著性)

检索设置

  • 基线检索模型采用标准的双塔神经检索设置(Covington等,2016),并结合了批内(in-batch)和批外(out-of-batch)采样。输入特征集包括高基数稀疏特征(如 item ID、用户ID)和低基数稀疏特征(如语言、主题、兴趣实体)。使用带有残差连接的前馈层堆栈(He等,2015)将输入特征压缩为用户和 item 嵌入

特征和序列长度

  • DLRM基线中使用的特征,包括各种顺序编码器/成对注意力模块所利用的主要用户交互历史,是所有 GR 候选模型使用特征的严格超集(strict supersets)。这适用于论文中进行的所有研究,包括缩放研究(4.3.1节)中使用的特征

附录F:随机长度

附录F.1:子序列选择

  • 在公式(4)中,论文从完整的用户历史中选择长度为 \(L\) 的子序列以增加稀疏性。论文的实证结果表明,精心设计子序列选择技术可以提高模型质量。论文计算一个指标 \(f_{i}=t_{n}-t_{i}\) ,它对应于用户与 item \(x_{i}\) 交互后经过的时间量。论文使用以下子序列选择方法进行离线实验:
    • 贪心选择(Greedy Selection) - 从 \(S\) 中选择 \(L\) 个 \(f_{i}\) 值最小的 item,即保留最近交互的 item
    • 随机选择(Random Selection) - 从 \(S\) 中随机选择 \(L\) 个 item
    • 特征加权选择(Feature-Weighted Selection) - 根据加权分布 \(1 - f_{n, i} / (\sum_{j = 1}^{L}f_{j, i})\) 从 \(S\) 中选择 \(L\) 个 item(理解:对越近交互的样本,采样权重越大)
  • 在离线实验中,特征加权子序列选择方法产生了最佳的模型质量,如表13所示

附录F.2:随机长度对序列稀疏性的影响

  • 在表3中,论文展示了随机长度对具有30天用户参与历史的代表性工业规模配置中序列稀疏性的影响。序列稀疏性定义为1减去所有样本的平均序列长度与最大序列长度的比值。为了更好地描述稀疏注意力的计算成本,论文还定义了 \(s2\) ,它被定义为:1减去注意力矩阵的稀疏性(which is defined as one minus the sparsity of the attention matrix)。作为参考,论文在表14和表15中分别给出了60天和90天用户参与历史的结果

附录F.3:与序列长度外推技术的比较

  • 论文进行了额外的研究,以验证随机长度与语言建模中使用的现有序列长度外推技术相比具有竞争力。许多现有方法通过修改旋转位置嵌入(RoPE)(Su等,2023)来进行序列长度外推。为了与现有方法进行比较,论文训练了一个没有相对注意力偏差和旋转嵌入的HSTU变体(HSTU - RoPE)
  • 论文在HSTU-RoPE上评估以下序列长度外推方法:
    • 零样本(Zero-Shot) - 应用NTK感知的RoPE(Peng等,2024),然后直接评估模型,不进行 Fine-tune
    • 微调(Fine-tune) - 应用逐部分NTK(Peng等,2024)后,对模型进行1000步 Fine-tune
  • 论文在HSTU(包括相对注意力偏差,无旋转嵌入)上评估以下序列长度外推方法:
    • 零样本(Zero-Shot) - 根据最大训练序列长度夹紧相对位置偏差,直接评估模型(Raffel等,2020;Press等,2022)
    • 微调(Fine-tune) - 根据最大训练序列长度夹紧相对位置偏差,在评估模型之前对模型进行1000步 Fine-tune
  • 在表16中,论文报告了训练期间引入数据稀疏性的模型(随机长度、Zero-Shot、Fine-tune)与在完整数据上训练的模型之间的NE差异。论文将 Zero-Shot 和 Fine-tune 技术的稀疏性定义为训练期间的平均序列长度与评估期间的最大序列长度之比。所有 Zero-Shot 和 Fine-tune 模型都在1024序列长度的数据上进行训练,并在2048和4096序列长度的数据上进行评估。为了为这些技术找到合适的随机长度基线,论文选择了导致相同数据稀疏性指标的随机长度设置
  • 作者认为,Zero-Shot 和 Fine-tune 的序列长度外推方法不太适合处理高基数ID的推荐场景。从经验上看,论文观察到随机长度明显优于 Fine-tune 和 Zero-Shot 方法。作者认为这可能是由于论文的词汇量较大 ,Zero-Shot 和 Fine-tune 方法无法为较旧的ID学习良好的表示,这可能会损害它们充分利用较长序列中包含的信息的能力

附录G:稀疏分组通用矩阵乘法(GEMMs)和融合相对注意力偏差

  • 论文提供3.2节中介绍的高效HSTU注意力内核的更多信息。论文的方法基于内存高效注意力(Rabe和Staats,2021)和FlashAttention(Dao等,2022),是一种内存高效的自注意力机制,它将输入划分为块,并避免在反向传播中具体化大的 \(h×N×N\) 中间注意力张量。通过利用输入序列的稀疏性,我们可以将注意力计算重新表述为一组具有不同形状的连续GEMM运算。论文实现了高效的GPU内核来加速此计算。相对注意力偏差的构建也因内存访问而成为瓶颈。为解决此问题,论文将相对偏差构建和分组GEMM运算融合到单个GPU内核中,并在反向传播中使用GPU的快速共享内存来累积梯度。尽管论文的算法在反向传播中需要重新计算注意力和相对偏差,但它比 Transformer 中使用的标准方法明显更快且内存使用更少

附录H:Microbatched-Fast Attention Leveraging Cacheable OperatioNs (M-FALCON)

  • 在本节中,论文详细描述3.4节中讨论的 M-FALCON 算法
  • M-FALCON 引入了三个关键思想
  • 批量推理可应用于因果自回归设置 :GR中的排名任务以目标感知的方式制定,如2.2节所述。通常认为,在目标感知设置中,论文需要一次对一个 item 进行推理,对于 \(m\) 个候选 item 和长度为 \(n\) 的序列,成本为 \(O(mn^{2}d)\) 。但这里论文表明这不是最优解决方案;即使使用普通 Transformer,也可以修改自注意力中使用的注意力掩码以进行批量操作(“批量推理”),并将成本降低到 \(O((n + m)^{2}d)=O(n^{2}d)\)
    • 图11给出了说明。这里,图11(a)和(b)都涉及因果自回归设置的注意力掩码矩阵。关键区别在于,图11(a)在因果训练中使用大小为 \(2n_{c}\) 的标准下三角矩阵,而图11(b)通过将 \(i, j≥2n_{c}\) 且 \(i≠j\) 的条目设置为 False 或 \(-\infty\) 来修改大小为 \(2n_{c}+b_{m}\) 的下三角矩阵,以防止目标位置 \(\Phi_{0}’, …, \Phi_{b_{m}-1}’\) 相互关注。很容易看出,通过这样做,自注意力块对 \(\Phi_{i}’, a_{i}’\) 的输出仅取决于 \(\Phi_{0}, a_{0}, …, \Phi_{n_{c}-1}, a_{n_{c}-1}\) ,而不取决于 \(\Phi_{j}’\) (\(i≠j\))。换句话说,通过使用修改后的注意力掩码对 \((2n_{c}+b_{m})\) 个 token 进行前向传递,论文现在可以获得与对 \((2n_{c}+1)\) 个 token 进行 \(b_{m}\) 次单独前向传递相同的最后 \(b_{m}\) 个 token 的结果,在第 \(i\) 次前向传递中, \(\Phi_{i}’\) 位于第 \(2n_{c}\) (基于0)位置,使用标准因果注意力掩码
  • 微批次将批量推理扩展到大型候选集 :Ranking 阶段可能需要处理大量的排名候选 item ,多达数万个(Wang等,2020)。我们可以将总共 \(m\) 个候选 item 划分为 \(\lceil m / b_{m}\rceil\) 个大小为 \(b_{m}\) 的微批次,使得 \(O(b_{m}) = O(n)\) ,这在大多数实际推荐设置中,对于多达数万个候选 item ,保持了前面讨论的 \(O((n + m)^{2}d)=O(n^{2}d)\) 的运行时间
  • 编码器级缓存可在请求内和请求间实现计算共享 :最后,键值缓存(Pope等,2022)可在请求内和请求间应用。例如,对于论文中介绍的HSTU模型(3节), \(K(X)\) 和 \(V(X)\) 在微批次内和/或请求间完全可缓存。对于缓存的前向传递,论文只需要为最后 \(b_{m}\) 个 token 计算 \(U(X), Q(X), K(X)\) 和 \(V(X)\) ,同时为包含 \(n\) 个 token 的序列化用户历史重用缓存的 \(K(X)\) 和 \(V(X)\) 。同样, \(f_{2}(Norm(A(X)V(X))\odot U(X))\) 只需要为 \(b_{m}\) 个候选 item 重新计算。这将缓存前向传递的计算复杂度降低到 \(O(b_{m}d^{2}+b_{m}nd)\) ,即使 \(b_{m}=n\) ,也比 \(O((n + b_{m})d^{2}+(n + b_{m})^{2}d)\) 提高了2 - 4倍
  • 算法1说明了 M-FALCON 算法,有助于理解。论文注意到, M-FALCON 不仅适用于HSTU和GR,还广泛适用于其他基于自注意力架构的目标感知因果自回归设置的推理优化算法

附录H.1:推理吞吐量评估:使用 M-FALCON 的生成式推荐器(GRs)与DLRM的比较

  • 如3.4节所述, M-FALCON 在推理时并行处理 \(b_{m}\) 个候选 item ,以在所有 \(m\) 个候选 item 之间分摊计算成本。为理解论文的设计,论文在相同硬件设置下比较了GR和DLRM的吞吐量(即每秒评分的候选 item 数,QPS)
  • 如图12和图13所示,由于批量推理实现了成本分摊,GR的吞吐量在一定区域内(在论文的案例研究中 \(m = 2048\))随Ranking 阶段候选 item 数 \((m)\) 呈次线性增长。这证实了批量推理在因果自回归设置中的关键性。由于注意力复杂度按 \(O((n + b_{m})^{2})\) 缩放,利用多个微批次本身就提高了吞吐量。缓存进一步消除了微批次之上的冗余线性和注意力计算。两者结合,相对于使用单个微批次的 \(b_{m}=m = 1024\) 基线,实现了高达1.99倍的额外加速,如图13所示。总体而言,凭借高效的HSTU编码器设计和利用 M-FALCON ,基于HSTU的生成式推荐器在大规模生产设置中的吞吐量比DLRM高出2.99倍,尽管GR在FLOP方面复杂285倍

ML——集成学习

集成学习(Ensemble)的本质是一种组合基础模型实现更高泛化能力和精度的技术框架
本文参考了博客: http://www.cnblogs.com/jasonfreak/p/5657196.html


集成学习的三族算法

Bagging

通过重采样技术生成若干不同子训练集 ,然后在每个训练集上训练一个分类器 ,最终采用投票方式产生模型最终结果

  • m个基础模型
  • 从原始训练集中抽样生成m个子训练集,用子训练集训练每个基础模型
  • 最终预测结果: 对所有基础模型预测的结果进行综合产生
代表模型
随机森林
  • RF = Bagging + DT (随机森林中特征的选择也是随机的,这一点不同于DT,也不同与Bagging)
  • 随机森林详情可参考ML——RF

Boosting

每个训练样例都有权重 ,每次训练新分类器的时候都着重训练那些在上一次分类过程中分错的样例 ,权重会随着迭代次数的变化而变化

  • 训练思想是,每一轮迭代都将重心放到分错类的样本上
  • 训练过程为阶梯状
  • 基础模型按照次序依次训练(实现时可做到并行)
  • 前一个模型的预测结果修改后一个模型的样本权重(注意:模型的训练集时不会变的,只有每个样本的权重在变化,增大分错的样本的权重,使得后面训练时重视分错的样本),以此类推
  • 最终预测结果: 对所有基础模型预测的结果进行线性综合产生
代表模型
提升树
  • Boosting Tree = AdaBoost + DT (AdaBoost是Boosting族算法的一种)
梯度提升树
  • GBDT = Gradient Boosting + DT (Gradient Boosting是Boosting族算法的一种)
  • GBDT详情可参考ML——GBDT

Stacking

每个分类器首先做一遍决策,然后将分类器们的决策送到更高一层的模型中,把他们当做特征再进行一次训练

  • 训练所有基础模型
  • 获取所有基础模型的预测结果
  • 第j个模型对某个训练样本产生的预测值作为该训练样本的第j个特征,标签不变,生成一个新的数据集(注意,样本的特征空间大小发生了变化,标签没变)
  • 基于新的训练集进行训练,得到预测模型M()
  • 测试时也要将特征转换后再用M进行预测
  • 实质上就是先用一些模型提取特征(这些模型的输出作为更高层模型的特征),然后再用模型的输出作为最终模型的特征,从而实现模型的堆叠(Stacking)
  • 理论上可以堆叠各种各样的分类器

方差与偏差

  • 偏差与方差可以参考我的博客: 模型的偏差与方差

ML——DT-决策树

决策树算法时很多优秀的集成算法的基础,包括RF,GBDT,提升树(Boosting Tree)等

论文参考: 《统计学习方法》


说明

  • 决策树(DT)是一种基本的分类和回归方法
  • 分类问题中可以理解为if-then规则的集合
  • 分类问题中也可以理解为定义在特征空间->类空间上的条件概率分布
  • 分类问题中使用的是ID3和C4.5
    • ID3 基于 最大化信息增益 来选择特征
    • C4.5基于 最大化信息增益比 来选择特征
  • 回归问题使用的是分类与回归树(Classification And Regression Tree, CART)
    • 分类树: 基于 最小化基尼(Gini Index)指数 来选择特征
    • 回归树: 基于 最小化平方误差 来选择特征
  • 关于树模型的复杂度可以用下面的方式评估, 统计学习方法中CART选择使用树的节点总数来评估树的复杂度
    • 叶子节点的数量
    • 节点总数
    • 树的深度

树模型的优缺点

优点

  • 可解释性强
  • 可处理混合类型特征(?)
  • 具体伸缩不变性(不用归一化特征,LR模型需要归一化特征)
  • 有特征组合的作用
  • 可自然地处理缺失值(神经网络不能处理缺失值)
  • 对异常点鲁棒
  • 有特征选择作用
  • 可扩展性强,容易并行

缺点

  • 缺乏平滑性(回归预测时输出值只能 输出有限的若干种数值)
  • 不适合处理高维稀疏数据
    • 数据稀疏时
      • 比如某个数据集有10000个样本
      • 某个特征只有10个样本存在值,其他样本都不存在值
    • 决策树:
    • 这样的话树容易将当前特征选中作为分类特征,这种划分可能在训练集上看起来很好,但测试集中表现可能不太好(这里不是简单的缺失值,而是数据很稀疏,这里需要进一步的理解[待更新])
    • LR等线性模型:
      • 因为现在的模型普遍都会带着正则项,而LR等线性模型的正则项是对权重的惩罚,也就是特征对应的权重一旦过大,惩罚就会很大,进一步压缩权重的值,使他不至于过大,而树模型则不一样,树模型的惩罚项通常为叶子节点数和深度等,而我们都知道,对于上面这种 case,树只需要一个节点就可以完美分割9990和10个样本,惩罚项极其之小.

决策树训练的三个步骤

特征选择

信息增益
  • 对数据集D的经验熵: \(H(D)=-\sum_{k=1}^{K}\frac{|C_{k}|}{|D|}log_{2}\frac{|C_{k}|}{|D|}\)
  • 对特征A对数据集D的经验条件熵: \(H(D|A)=\sum_{n=1}^{n}\frac{|D_{i}|}{|D|}H(D_{i})=-\sum_{n=1}^{n}\frac{|D_{i}|}{|D|}\sum_{k=1}^{K}\frac{|D_{ik}|}{|D_{i}|}log_{2}\frac{|D_{ik}|}{|D_{i}|}\),n为特征A的取值个数
信息增益比
  • 特征A对训练数据D的信息增益比: \(g_{R}(D,A)=\frac{g(D,A)}{H_{A}(D)}\)
    • 其中 \(H_{A}=-\sum_{i=1}^{n}\frac{|D_{i}|}{|D|}log_{2}\frac{|D_{i}|}{|D|}\),n为特征A的取值个数
基尼指数
  • 对于分布 \(p=(p_{1},…,p_{k})\) 的基尼指数: \(Gini(p)=\sum_{k=1}^{K}p_{k}(1-p_{k})=1-\sum_{k=1}^{K}p_{k}^{2}\)
  • 对样本集合D,基尼指数为: \(Gini(p)=1-\sum_{k=1}^{K}\left(\frac{|C_{k}|}{|D|}\right)^{2}\)
  • 在特征A的条件下,集合D的基尼指数: \(Gini(D,A)=\frac{|C_{1}|}{|D|}Gini(D_{1})+\frac{|C_{2}|}{|D|}Gini(D_{2})\)
    • 我们将集合D根据特征A分为两类:
    • 是否取特征A的某个值,其中 \(A=a\) 的是 \(D_{1}\), \(A\neq a\) 的是 \(D_{2}\)

决策树的生成

ID3
  • 基于最大化信息增益来选择特征
  • 选取所有信息增益最大的特征作为当前结点的特征
  • 对取值数目较多的属性有所偏好
  • 只有分类树,没有回归树
  • ID3相当于用极大似然法对模型进行选择(问题:如何理解?)
C4.5
  • 基于最大化信息增益比来选择特征
  • 选取信息增益比最大的特征作为当前结点的特征
    • 由于使用最大的信息增益比特征可能对取值数目少的特征属性有所偏好,所以C4.5算法一般不会直接选信息增益比最大的,而是
      • 先从候选区属性中找出信息增益高于平均水平的
      • 再从筛选结果中寻找信息增益比最大的
  • 处理连续型特征时使用二分法(bi-partition)
  • 只有分类树,没有回归树
CART
  • CART的决策树是二叉树
  • 分类树: 基于最小化基尼指数来选择特征
    • 输出变量为离散变量
    • 基于基尼指数选取所有可能的特征A和所有可能的切分点a中,选择基尼指数最小的特征和切分点为最优特征和最优且分点
  • 回归树: 基于最小化平方误差来选择特征
    • CART回归树一般称为最小二乘回归树(因为目标函数的优化是最小化误差的平方和实现的)
    • 输出变量为连续变量
    • 选取时选择最优划分变量(特征)j和最优切分点s,然后按照变量j的划分点s将结点分为两个结点,大于s的为第一个结点,小于s的为第二个结点
  • 一点说明:
      • 《统计学习方法》 *中:
        • 回归树的特征默认是连续变量,选取划分点s时使用的是连续变量的各种中间值作为候选值,大于s的分为一个结点,小于s的分为一个结点
        • 分类树的特征默认是离散变量,选取划分点a时使用的是离散变量的值作为候选值,等于a的分为一个结点,不等于a的分为另一个结点
        • 但是实际上无论时分类树还是回归树,CART都可以用相同手段处理连续值

剪枝

  • 剪枝的核心思想:
    • 就是加入考虑树的复杂度考量(决策树的生成时仅仅是考虑到信息增益和信息增益比,没有考量树的复杂度)
      • 树的深度越深,树越复杂
    • 整体上再考虑树变得更简单的同时保证分类误差较小
预剪枝
  • 预剪枝发生在决策树生成过程中
  • 在每个节点划分前先进行估计
  • 若当前节点划分不能带来决策树准确率提升,则停止划分
  • 在 《统计学习方法》 中未提到这个方法,只讲了一种简单的后剪枝算法
  • 时间复杂度小,欠拟合风险大
后剪枝
  • 后剪枝发生在决策树生成后
  • 自底向上的对非叶子节点进行考察(注意千万不可从根节点开始自顶向下的剪枝,可能失去整体最优的决策树)
  • 若将该节点换成叶子节点能带来决策树准确率提升,则将该节点替换为叶子节点
  • 时间复杂度大,欠拟合风险小
常见的后剪枝方法

各种方法各有优劣,关注不同的优化角度

  • REP 错误率降低剪枝(Reduced Error Pruning)
  • PEP 悲观剪枝(Pessimistic Error Pruning)
  • CCP 代价复杂度剪枝(Cost Complexity Pruning), 详细过程参考CPP剪枝算法描述
  • MEP 最小误差剪枝(Minimum Error Pruning)
  • CVP (Critical Value Pruning)
  • OPP (Optimal Pruning)
ID3和C4.5的剪枝
CART的剪枝

CART使用的是CCP剪枝

  • 对任意的子树,我们可以定义子树的损失
    • \(C_{a}(T)=C(T)+\alpha|T|\)
    • 子树的损失 = 子树的预测误差 + \(\alpha\) * 子树的节点数
    • 对于回归树和分类树,子树的预测误差定义不一样,前者是误差的平方和,后者是基尼指数
  • 可以证明对于给定的Alpha,一定存在某个损失最小的子树,也就是我们要的最优子树
  • 现实中实现时可以使用递归的方法实现
CCP剪枝算法描述

CPP剪枝也是一种后剪枝算法
修正:统计学习方法CART算法第六步中应该跳到第二步,而不是第三步

  • 1:计算所有节点对应的 \(\alpha\) 值: \(\alpha=\frac{C(t)-C(T_{t})}{|T_{t}|-1}\)
    • \(C(t)\) 是以t节点单一节点为树时单一节点树的损失函数
    • \(C(T_{t})\) 是以t节点为根节点的子树时整棵子树的损失函数
    • \(|T_{t}|\) 是以t节点为根节点的子树时整棵子树的节点数量
  • 2:对当前 \(\alpha\) 值最小的节点t剪枝,并存储中间结果的 \(\alpha\) 值和剪枝后的树结构
    • 当 \(\alpha\) 确定时,存在唯一的最小子树 \(T_{\alpha}\) 使损失函数 \(C_{\alpha}(T)\) 最小
  • 3:选取当前树为剪枝后的树,跳转到第1步,直到剪枝到只有三个节点的树时截止
  • 4:截止后得到节点数量从大到小的多个子树 \(T_{\alpha_{0}}, T_{\alpha_{1}},…,T_{\alpha_{n}}\) . (其中 \(T_{\alpha_{i}}\) 也就对应着第i个 \(\alpha\) 值 \(\alpha_{i}\))
  • 5:用交叉验证法对 \(\alpha\) 的值进行选择(CART算法执行时 \(\alpha\) 类似超参数,整个算法学习的过程类似于用交叉验证法确定超参数的过程, \(\alpha\) 的值确定了,对应的决策树也就确定了!)

关于连续特征

统计学习方法
  • 在 《统计学习方法》 中回归树的特征默认是连续变量,分类树的特征默认是离散变量
机器学习 周志华
  • 在<<机器学习>>中提到连续特征的一种解决方案:
    • 把该连续特征所有出现的取值排序
    • 取临近两两之间的平均值作为划分点
    • 像处理离散的点一样,使用信息增益最大化,信息增益率最大化,或者是基尼指数最小化实现对应的划分选择
个人总结
  • 处理连续型变量(特征的能力)
    • ID3 不能处理连续型特征
      • 因为连续型特征往往使得每一个样本该特征取值都不一样,造成该特征对数据集D的经验条件熵为0?
      • 使得ID3算法趋向于选择这个特征?
    • C4.5 能处理连续型特征
      • 将数据排序后找到类别不同的分割线作为切分点
      • 根据切分点把连续属性二分类装换为布尔类型
      • 可多次使用连续属性分类
    • CART 能处理连续型特征
      • 实际对连续型特征的处理与C4.5一样
      • 由于CART树构建时每次都会对特征进行二值化分,所以可以很好的适用与连续型变量

一些总结

  • 算法ID3生成可能是多叉树,而CART一定是二叉树,《统计学习方法》中二者生成相同的数是巧合,除了不同评价方式的特征选择结果一样以外,还需要被选中的特征都是二值的!
  • ID3相当于用极大似然法对模型进行选择

一种很好的理解思路

ID3
  • ID3算法就是用信息增益大小来判断当前节点应该用什么特征来构建决策树,用计算出的信息增益最大的特征来建立决策树的当前节点
ID3的缺点:
  • ID3不能处理连续特征
  • ID3对取值较多的特征有着偏好在相同条件下,取值比较多的特征比取值少的特征信息增益大
  • ID3不能处理缺失值
  • 没有考虑过拟合的问题(问题: ID3没有剪枝吗?)
C4.5
  • C4.5可以看成是对ID3进行改进
C4.5对ID3的改进
  • 对于ID3不能处理连续特征,C4.5的思路是将连续的特征离散化.

    • 样本(m个样本)按照特征的取值排列后取相邻样本的平均值作为划分点(m-1个划分点),分别计算以每个划分点作为二元分类时的信息增益. 最终选择信息增益高的(问题:这里是C4.5,为什么不是使用信息增益比而是信息增益来作为区分?)
    • C4.5选择连续特征作为分类特征时,只分两类,但是后面(其他层结点)可以使用该特征分类,也就是说连续特征在C4.5中可以多次使用,但每次只分为两部分
  • 对于ID3信息增益最大指标会造成偏向于取值较多的特征的问题. C4.5使用信息增益比来解决问题

  • 对于ID3不能处理缺失值的问题,C4.5主要解决两个问题

    • 1)如何在属性值缺失条件下进行属性选择
      • 没有缺失值的属性,正常处理
      • 每个属性A有缺失值:
        • 对该特征进行划分时仅仅考虑在属性A上没有缺失的部分数据,有缺失的数据不考虑.
        • 无缺失的数据计算收益时需要乘以一个权重(无缺失的样本总数/样本总数)
        • 相当于信息增益适当缩小
    • 2)给定划分属性,若样本在该属性上的值缺失,如何对样本进行划分
      • 在A属性没有缺失值的样本,正常划分
      • 在A属性有缺失值的样本:
        • 给每个样本引入权重,初始值都为1
        • 同一个样本按照不同概率划入到不同的结点(当前叶节点)中去(概率是当前结点中样本数量/无缺失样本总数)
  • 对于没有考虑过拟合的问题:

    • C4.5引入了正则化系数剪枝(问题: ID3没有剪枝吗?)
C4.5的缺点
  • [这个问题存疑,没有任何书籍显示C4.5的剪枝策略是PEP, 《统计学习方法》 中只是简单的介绍了后剪枝]C4.5的剪枝方法时PEP,PEP准确度高,但是存在下面两个子问题:
    • 1)PEP使用由上而下的剪枝策略,会导致与预先剪枝相同的问题,造成过度剪枝
    • 2)会造成剪枝失败的情况
  • C4.5生成多叉树,计算机中很多时候多叉树运算效率不如二叉树来的高
  • C4.5只能用于分类
  • C4.5需要进行大量的对数运算(计算熵)
CART
  • 可以理解为对C4.5进一步的改进
CART对C4.5的改进
  • CART使用CPP代价复杂度剪枝算法
    • 详细过程参考CCP剪枝算法描述
  • CART使用二叉树只生成二叉树,即使是离散特征
    • CART对连续特征的处理与C4.5完全相同
    • CART对离散特征也是二分类且也是可以多次使用同一特征(ID3与C4.5中离散特征只能使用一次,且是多分叉)
    • 所以CART是一颗二叉树
  • CART可有分类树和回归树两种
    • 回归树的目标函数的优化是: 最小化误差的平方和
    • 分类树以概率最大的类别作为叶节点的类别
    • 回归树以中位数或者均值作为预测结果
  • CART使用基尼指数作为结点混乱度的度量指标
    • 避免了对数计算(与熵比较)
ID3,C4.5,CART的缺点
  • 每次之全责一个最优特征作为分类决策,而实际中其实可能需要多个特征一起决策
    • 解决方案: 多变量决策树(每次选择多个特征一起决策)
      • 单个特征决策可以看成是直线
      • 多个特征决策可以看成是斜线
  • 样本改变一点点都会造成树的结构改变很大
    • 解决方案: 随机森林等集成学习方法

ML——GBDT-梯度提升树-概念性总结

GBDT(GradientBoostingDecisionTree), 梯度提升树
GBDT泛指所有梯度提升树,包括XGBoost(XGBoost是GBDT的变种),平时为了进行区分,GBDT特指“Greedy Function Approximation:A Gradient Boosting Machine”(GBDT论文原文)提出的算法,只利用了一阶导数信息(XGBoost利用了二阶导数信息)
*梯度的数学定义:函数上升最快的方向

参考论文:Greedy Function Approximation: A Gradient Boosting Machine
一篇很详细的论文阅读笔记:GBM Paper Reading

引用一个常见的通俗例子:GBDT的思想可以用一个通俗的例子解释,假如有个人30岁,我们首先用20岁去拟合,发现损失有10岁,这时我们用6岁去拟合剩下的损失,发现差距还有4岁,第三轮我们用3岁拟合剩下的差距,差距就只有一岁了(残差作为下一轮拟合的数据的理解)。如果我们的迭代轮数还没有完,可以继续迭代下面,每一轮迭代,拟合的岁数误差都会减小,最终预测时使用他们的结果


五种简称

  • 各种简称,都是同一种算法:
    • GBDT(Gradient Boosting Decision Tree)
    • MART(Multiple Additive Regression Tree)
    • GBT(Gradient Boosting Tree)
    • GTB(Gradient Tree Boosting)
    • GBRT(Gradient Boosting Regression Tree)

模型原理

核心

  • 使用CART作为基础模型(GDBT只能使用CART,不能使用其他树C4.5和ID3等), 后来作者提出也可以使用逻辑回归模型
  • 每棵树学习的是前一棵树的残差(预测值与真实值的差)[这里是当损失函数是均方误差(平方损失, square loss)时可直接使用残差,其他类似的学习算法中若]
  • 残差 = 真实值 - 预测值

如何理解GBDT中的梯度G?

  • G表示Gradient,表示梯度,在GBDT中梯度是指损失函数关于每一个迭代中模型的梯度

与AdaBoost不同

  • AdaBoost是通过利用前一轮弱学习器的误差率来更新训练集的权重

    • 增大分类错误的样本权重,从而使得样本更加关注上一步中分类错误的样本
  • GBDT是通过学习上一轮学习器的残差来实现对真实结果的不断逼近

    • 上一步中预测越接近真实结果的样本,残差越接近0,下一轮中对该样本的关注度越低
    • 上一步中预测越不接近真实结果的样本,残差越大,下一轮中对该样本的关注度越高
  • GBDT的弱学习器只能使用CART回归树,只能用回归树这一点与AdaBoost和RF均不同

    • 因为我们的目标是拟合上一若学习器的残差
    • 而残差往往是数值,不是类别,所以只能使用回归树CART

GBDT分类

二分类
  • GBDT实现二分类时可以每轮迭代直接用一个模型去学残差
  • 此时类别可编码为一个一维向量,取值0或1
多分类
  • GBDT实现多分类时每轮使用一个模型不够了,因为三个模型时使用1,2,3编码显然是不科学的,类别之间不应该有这种数值大小关系
  • 此时三分类模型的类别可编码为一个三维向量,每一维的取值为0或1
  • 在每一轮迭代时,为每个类训练一个CART树,每棵CART树是相互独立的
  • 然后每个模型每轮分别学习当前特征的残差
  • 每个模型都会用到所有的样本
    • 比如一个标记为标记为类别2的样本(x, y=2)
    • 编码为(x, [0,1,0])
    • 对于CART1和CART3(类别1和类别3的CART树)来说,该样本对应输入为(x, y=0)
    • 对于CART2(类别2CART树)来说,该样本对应输入为(x, y=1)
  • 可以理解为把一个三分类问题转化成了3个二分类问题解决了
  • 最后预测时
    • 给定一个未标记样本
    • 每个类对应的模型(每个类的模型个数是该类上模型的迭代次数)都对应给出该类的打分
    • 最后选择分数最高的一个类为最终分类即可

GDBT + LR

  • 为什么在广告CTR预估中, GDBT+LR能够提升效果?
    • 和LR对比: 线性模型
    • 和GBDT对比:

GBDT和神经网络的优劣

深度神经网络

  • 通过不同层级和类型的网络可以对时空信息建模
  • 适合图像, 声音, 文字等带有时序特质的数据

GBDT

  • 基于树模型的GBDT则能很好地处理表格数据
  • 模型的可解释性
  • 输入数据的不变性(几乎不用格式化数据)
  • 更易于调参等特质更适合数值型数据

ML——GBDT-梯度提升树-推导过程

本文介绍梯度提升树(族)的推导过程,包括传统的GBDT,XGBoost等


参数空间的优化

  • 参数空间中我们使用梯度下降法(一阶导数)和牛顿迭代法(二阶导数)来优化
  • 关于无约束参数优化方法参考无约束参数优化方法

从参数空间优化到函数空间的优化

  • 函数空间中我们使用GBDT(一阶导数)和XGBoost(二阶导数)来优化
  • 函数空间的优化完全类比参数空间的优化

Boosting是一种加法模型

加法模型:additive training
$$
\begin{align}
F(x) = \sum_{t=0}^{T}f_{t}(x)
\end{align}
$$

  • 上式中 \(f_{t}(x)\) 为基分类器, 我们通常采用回归树[Friedman 1999] 和逻辑回归
    [Friedman 2000]
  • 树模型的优缺点可以参考ML——DT-决策树

GBDT算法原理

这里只原论文中的Gradient Boosting Tree算法
Friedman于论文”GreedyFunctionApproximation…”中最早 出GBDT

模型 \(F\) 的加法定义:

$$
\begin{align}
F(x) &= \sum_{t=0}^{T}\alpha_{t} h_{t}(x;w_{t}) \\
&= \sum_{t=0}^{T} f_{t}(x;w_{t})
\end{align}
$$

  • 其中, \(x\) 为输入样本, \(h_{t}\) 为分类回归树, \(w_{t}\) 是树 \(h_{t}\) 的参数, \(\alpha_{t}\) 是树 \(h_{t}\) 的权重

损失函数的定义

$$
\begin{align}
F^{\star} = \mathop{\arg\max}_{F}\sum_{i=1}^{N}L(y_{i}, F(x_{i};w))
\end{align}
$$

  • 其中, \(N\) 为样本数量,所有样本的损失总和为总的损失函数
  • 最小化损失函数即可得到最优模型

最优模型的求解

  • 直接列举所有可能的树——NP难问题
  • 所以GBDT算法使用贪心法, 迭代求局部最优解
  • 详细的Gradient Boosting Tree迭代过程如下
  • 上面的推导中
    • 2.1中: \(\tilde{y}_{i}\) 是当前的损失函数 \(L(y_{i}, F(x_{i}))\) 关于当前函数 \(F(x)\) (模型)在 \(F(x)=F_{t-1}(x)\) 处的负梯度(每个样本都有一个负梯度),这个梯度也是GBDT名字中梯度的由来
      • 这个损失函数在使用不同回归方法时时定义各不相同
      • 原始论文中提到两个损失函数定义,根据损失函数的不同,对应的归回方法不同: 最小二乘归回或者最小绝对误差回归
    • 2.2中: \(w^{\star}\) 是指能够拟合当前负梯度的树 \(h_{t}(x;w^{\star})\) 的最佳参数,这里我们认为最佳参数就是最小二乘的最佳参数,实际上这个地方可以使用其他测拟合标准(这个标准是拟合当前负梯度的拟合标准,与后面的损失函数 \(L(y, F(x;w))\) 无关),只是这里最小二乘是最简单也是最自然的选择
      • 原始论文中这里使用的基函数是 \(\beta h(x;w)\),其中 \(\beta\) 是当前基函数 \(h(x;w)\) 的权重,这里我认为直接使用 \(h(x;w)\) 作为基函数即可,权重 \(\beta\) 会自动被基函数学到的(可能原始论文中 \(h(x;w)\) 指的是简单的基函数,是不含权重学习功能的)
      • 但是需要注意的是,如果我们使用的基分类器是逻辑回归,那么这里每个基分类器的结果都是在[0-1]之间的数,是需要前面的 \(\beta\) 的
      • 这里我们推导的时候假定了基分类器是回归树,所以不需要使用 \(\beta\)
    • 2.3中:由于 \(w^{\star}\) 只能保证当前树 \(h_{t}(x;w^{\star})\) 是能拟合负梯度的,不能保证把当前这棵树添加到模型中时模型的损失函数是最小的,所以我们加了一个步长参数 \(\rho\),用来表示得到当前树的最优权重
      • 损失函数是平方损失函数时,这里的参数为 \(\rho\) 就是1,无需计算
    • 2.4中:将当前树的最优树(包括权重)一起加入到模型中

不同损失函数和基函数对应不同的算法

上述式子中推导用到的基函数为树模型,实际使用中也可以使用逻辑回归模型[Friedman 2000]等基函数
本小节将介绍不同损失函数或者基函数带来的不同算法

  • 注意:包括Adaboost和GBDT在内Boosting框架中,基函数(基分类器)都不能使用线性模型
    • 理解: Boosting框架本质上是一个加法模型,是对多个基函数的线性组合,得到更优的分类器,可以证明线性模型的线性组合还是线性模型,如果Boosting框架中使用线性模型,那么我们最终得到的分类器也是线性模型,这就局限了我们的整体模型的表达能力
最小二乘回归(损失函数)

Least-Squares Regression

损失函数定义
  • 此时损失函数定义为
    $$
    \begin{align}
    L(y,F(x)) = \frac{1}{2}(y-F(x))^{2}
    \end{align}
    $$
进一步理解推导过程
  • 2.1中 \(\tilde{y}_{i}=y_{i}-F_{t-1}(x_{i})\),这里直接对损失函数求导即可的到
    $$
    \begin{align}
    \tilde{y_{i}} &= -\left [\frac{\partial L(y,F(x_{i}))}{\partial F(x_{i})}\right ]_{F(x) = F_{t-1}(x)} \\
    &= -\left [\frac{\partial L(y,F_{t-1}(x_{i}))}{\partial F_{t-1}(x_{i})}\right ] \\
    &= -\left [\frac{\partial \frac{1}{2}(y_{i}-F_{t-1}(x_{i}))^{2}}{\partial F_{t-1}(x_{i})}\right ] \\
    &= 2 \cdot -\frac{1}{2}(y_{i}-F_{t-1}(x_{i})) \cdot -1 \\
    &= y_{i}-F_{t-1}(x_{i})
    \end{align}
    $$
  • 2.2中正常拟合(使用线性回归和CART回归树均可)
  • 2.3中基函数的权重 \(\rho^{\star}\) 是常数1,推导如下
    $$
    \begin{align}
    L(y,F_{t}(x)) &= \sum_{i=1}^{N}L(y_{i}, F_{t}(x_{i})) \\
    &= \frac{1}{2}\sum_{i=1}^{N}((y_{i}-F_{t}(x_{i}))^{2}) \\
    &= \frac{1}{2}\sum_{i=1}^{N}((y_{i}-F_{t-1}(x_{i}) - h_{t}(x_{i};w))^{2}) \\
    &= \frac{1}{2}\sum_{i=1}^{N}((\tilde{y}_{i}-h_{t}(x_{i};w))^{2}) \\
    \end{align}
    $$
  • 显然,这个式子和2.2中拟合目标完全相同(只差着2倍常数权重),所以2.2中得到的最优基函数 \(h_{t}(x;w^{\star})\) 就是2.3中使得模型损失函数 \(L(y,F_{t}(x))\) 最小的最优基函数,无需添加任何的权重系数
最小绝对偏差回归(损失函数)

Least Absolute Deviation Regression, LAD

损失函数定义
  • 此时损失函数定义为
    $$
    \begin{align}
    L(y,F(x)) = |y-F(x)|
    \end{align}
    $$
进一步理解推导过程
  • 2.1中 \(\tilde{y}_{i}=sign(y_{i}-F_{t-1}(x_{i}))\)
    • 绝对值的导数就是1或者-1,当
      $$y_{i}-F_{t-1}(x_{i}) > 0$$
    • 对损失函数求导得到-1,负梯度为1
    • 同理得到,当
      $$y_{i}-F_{t-1}(x_{i}) > 0$$
    • 对损失函数求导得到1,负梯度为-1
    • 总结得到负梯度为
      $$\tilde{y}_{i}=sign(y_{i}-F_{t-1}(x_{i}))$$
  • 2.2中正常拟合(使用线性回归和CART回归树均可)
  • 2.3中基函数的权重 \(\rho^{\star}\) 不再是常数,推导如下
  • 待更新*
    $$
    \begin{align}
    待更新
    \end{align}
    $$
回归树(基函数)

Regression Trees

  • 传统GBDT中原始论文使用树回归 ,论文见Firedman 1999,后来作者提出可以使用逻辑回归 ,论文见Friedman 2000
  • 回归树和逻辑回归的优缺点比较
    • 树模型优点 :
      • 可解释性强
      • 可处理混合类型特征(混合类型特征指的是数值型和类别型均可处理)
      • 具有伸缩不变性(无需特征归一化: 神经网络和逻辑回归都需要,逻辑回归中是为了保证随机梯度下降的方向正确,速度快)
      • 可自然的处理缺失值, C4.5树处理缺失值默认使用的方法是先用未缺失样本计算信息增益确定分裂结点,然后将缺失值的每个样本按照权重(当前叶子节点未缺失样本数量 / 未缺失样本总数)分配到各个结点
      • 对异常点鲁棒(不用去除异常点,LR不具有这个优点)
      • 有特征选择的作用
      • 可扩展性强,容易并行(并行是最好的,用来解释为什么XGBoost等都用树模型)
    • 树模型缺点 :
      • 缺乏平滑性(回归预测时输出值只能输出若干种数值,而不是连续的数值,所以不平滑[不平滑即离散])
      • 不适合处理高维稀疏数据(当数据量太少时,非常容易过拟合,树的深度太深,从而模型变得太复杂)

传统GDBT与XGBoost的比较

  • 参考博客: ML——XGBoost-vs-传统GBDT

ML——RF

随机森林是一种集成学习(Ensemble)方法,属于集成学习中的Bagging族,是一种典型的Bagging方法


随机森林

训练

  • 假设训练集总大小为N,特征总数为K
  • 有放回抽样N个训练样本
    • 说明:这种有放回的抽样技术又称为自助法(Bootstrap)重采样技术
  • 随机从所有特征中选取k个特征(k一般远小于总特征数K)
  • 用bootstrap采样的结果(N个样本)和k个特征训练一棵决策树(训练时不剪枝,每棵树都尽可能生长)
  • 重复上述三个步骤:bootstrap采样和训练新的决策树,直到生成m棵决策树
  • 将m棵决策树组合成随机森林

预测

  • 对于一个确定的未标记样本
  • 经过每一棵决策树预测
  • 投票产生结果,确定最终分类

优点

  • 避免过拟合(bootstrap采样[随机]和随机选择部分特征)
  • 便于处理高维度的数据
  • 对高维度的数据无需做特征选择,全部保留也无所谓
  • 可并行计算
  • 模型实现简单

缺点

  • 模型训练多棵树,所以训练时间长
  • 预测时经过多棵树,预测时间长

参数

  • 决策树类型: 三种(ID3,C4.5,CART)均可选,默认均分的三种算法混合到一起
    • 决策树详细信息可参考ML——DT
  • 树的数量m:
    • n_estimators
    • int i: n_estimators = i
    • 默认为10
  • 特征个数k(以Sklearn库为例):
    • max_features = k
    • int i: k = i
    • float f: k = f * K
    • sqrt: k = sqrt(K)
    • log2: k = log2(K)
    • None: k = K
    • Sklearn.ensemble.RandomRorestClassifier默认是”auto”,也就是sqrt(K)
  • 树的深度:
    • max_depth
    • 单棵树的深度
    • 默认是None,不做限制
    • int i: 指定树的深度为i
  • 叶子节点最小样本数:
    • int i: min_samples_leaf = i
    • float f: min_samples_leaf = f * N
    • 默认为1

问题: 随机森林的随机体现在哪里?

  • 个人理解
    • 体现在每棵树的训练样本随机 ,有放回采样(自助法(Bootstrap)重采样技术)
    • 体现在每棵树的训练样本的特征随机 , 对每棵树而言,随机从K个特征中选取k个特征,当前树的每个叶节点分裂时只能从这k个特征中选择

NLP——DuoAttention

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:DuoAttention: Efficient Long-Context LLM Inference with Retrieval and Streaming Heads, arXiv 202410 , MIT & THU & SJTU & NVIDIA
      • 与 StreamingLLM 同作者
    • GitHub:github.com/mit-han-lab/duo-attention

Paper Summary

  • 整体总结:
    • 核心:DuoAttention 是一种通过区分 Retrieval Heads 和 Streaming Heads 来优化 LLM 内存和计算资源的框架
    • 具体:DuoAttention 可以显著减少了长上下文应用中解码和 Pre-filling 的内存使用和延迟
      • 因为 DuoAttention 对 Retrieval Heads 应用完整的 KV 缓存(Streaming Heads 仅缓存 Sink Token 和 Recent Token)
    • 对比之前 MHA 和 GQA 的效果(内存大幅减少、解码速度大幅提升)
      • MHA 模型内存减少高达 \(2.55\times\),MHA 模型解码速度提升高达 \(2.18\times\),Pre-filling 加速高达 \(1.73\times\)
      • GQA 模型内存减少高达 \(1.67\times\),GQA 模型解码速度提升高达 \(1.50\times\),Pre-filling 加速高达 \(1.63\times\)
      • 且与完全注意力相比准确率损失最小(minimal accuracy loss)
    • 当与量化结合时,DuoAttention 可以进一步提升 KV 缓存容量,在单个 A100 GPU 上支持高达 3.30M 个上下文 Token
  • 背景 & 问题提出:
    • 部署长上下文(long-context)LLM 至关重要,但长上下文带来了显著的计算和内存挑战
    • 跨所有注意力头缓存所有 Key 和 Value (KV)状态会消耗大量内存
    • 现有的 KV 缓存剪枝方法要么损害 LLM 的长上下文能力,要么仅提供有限的效率提升
  • 作者发现:
    • 只有一小部分注意力头(Retrieval Heads),对于处理长上下文至关重要,并且需要对所有 Token 进行完整的注意力计算
    • 而其他头(Streaming Heads),主要关注最近的 Token 和 Attention Sinks,不需要完整的注意力计算
  • 基于这一洞察,论文引入了 DuoAttention:
    • 该框架仅对 Retrieval Heads 应用完整的 KV 缓存,同时对 Streaming Heads 使用轻量级的、恒定长度的 KV 缓存
    • 从而在不损害其长上下文能力的情况下,减少 LLM 解码和 Pre-filling 的内存占用和延迟
  • DuoAttention 使用一种轻量级的、基于优化的算法以及合成数据来准确识别 Retrieval Heads
  • 内存方面:
    • 对于多头注意力模型最高减少 2.55\(\times\)
    • 对于分组 Query 注意力模型最高减少 1.67\(\times\)
  • 效率方面:
    • 对于多头注意力模型解码速度最高提升 2.18\(\times\), Pre-filling 速度最高提升 1.73\(\times\)
    • 和分组 Query 注意力模型解码速度最高提升1.50\(\times\), Pre-filling 速度最高提升 1.63\(\times\)
    • 与完整注意力相比,准确率损失最小(minimal accuracy loss)
  • 开源链接:github.com/mit-han-lab/duo-attention

Introduction and Discussion

  • LLM 处于人工智能革命的前沿,驱动着高级应用,如多轮对话、长文档摘要以及涉及混合模态的任务,如视觉和视频理解
    • 这些应用通常需要处理大量的上下文 Token ;
    • 例如,总结整个《哈利·波特》系列可能涉及分析约一百万个 Token
    • 对于视觉语言模型,挑战更加严峻,其中一张 224×224 的图像对应 256 个 Token ,而一段三分钟、24 FPS 的视频会生成约 1.1M 个 Token
  • 在此类应用中部署 LLM 的一个关键问题是长上下文推理问题
    • 完整的注意力机制要求所有 Token 关注所有先前的 Token 以获得准确的表示,这导致解码延迟线性增加, Pre-filling 延迟二次方增加
    • KV 缓存技术存储所有先前 Token 的 Key 和 Value ,导致内存使用量随上下文长度线性增长
    • 随着序列变长,内存越来越多地被 KV 缓存消耗,给注意力机制带来了显著的计算负担
      • 例如,在 Llama-3-8B 模型架构中,为 1M 个 Token 提供服务并使用 FP16 KV 缓存将需要至少 137 GB 的内存(这已经超过了单个 80GB GPU 的容量)
    • 而且使用如此大上下文进行 Pre-filling 和解码会有显著延迟,这对 LLM 在长上下文场景中的有效使用构成了重大挑战
  • 尽管有许多努力来克服注意力机制在长上下文推理中的挑战,但显著的计算和内存问题仍然存在
    • 架构修改,如分组 Query 注意力,需要模型预训练,并且无法降低计算成本
      • 线性注意力(Linear Attention)方法虽然在计算和内存需求上较低,但在长上下文场景下往往不如 Transformer 模型
      • 近似注意力(Approximative attention)方法,如 H\({}_{2}\)O、StreamingLLM、TOVA 和 FastGen,常常在长上下文应用中牺牲精度,并且与关键的 KV 缓存优化技术(如分组 Query 注意力)不兼容
    • KV 缓存量化虽然有用,但并未减少注意力机制的计算时间
      • 系统级优化,包括 FlashAttention、FlashDecoding 和 PagedAttention,虽然有效,但并未减少 KV 缓存大小,并且在扩展上下文时仍然需要大量计算
  • 论文引入了一个关键观察
    • LLM 中的注意力头可以分为两种不同的类型 :Retrieval Heads 和 Streaming Heads ,如图 1 所示
    • Retrieval Heads 仅占总头数的一小部分,对于处理长上下文至关重要,并且需要对所有 Token 进行完整的注意力计算
    • Streaming Heads(大多数注意力头),主要关注最近的 Token 和 Attention Sinks ,并且可以在仅包含 Recent Token 和 Attention Sinks 的简化 KV 缓存下有效运行
  • 基于 Retrieval Heads 和 Streaming Heads 的二分法,论文提出了 DuoAttention
    • DuoAttention 是一种通用、直接且易于集成的方法,能显著加速 LLM 的解码和 Pre-filling ,并减少内存占用,尤其是在长上下文场景中
    • DuoAttention 的核心创新是一种轻量级的、基于优化的过程,它使用合成数据集来识别不可压缩的 Retrieval Heads
    • 与依赖注意力模式分析 (2024;) 的现有方法不同,DuoAttention 直接测量因 Token 丢弃而产生的输出偏差,从而实现更高的压缩率和改进的部署效率
  • DuoAttention 的设计注重简洁和高效:每个 Transformer 层有两个 KV 缓存
    • 一个用于关键 Retrieval Heads 的完整 KV 缓存
    • 一个用于 Streaming Heads 的恒定 KV 缓存,仅存储 Attention Sinks 和最近的 Token
  • 这种设计使得 DuoAttention 能够显著减少内存使用、提高模型的解码速度,且与完整注意力相比,精度损失最小
  • DuoAttention 与重要的优化技术(如分组 Query 注意力和量化)完全兼容
    • 当结合 8-bit 权重和 4-bit KV 缓存量化时,DuoAttention 使得 Llama-3-8B 模型能够在单个 A100 GPU 上处理高达 3.3M 上下文 Token
      • 与标准的完整注意力 FP16 部署相比,实现了 \(6.4\times\) 的容量提升
    • DuoAttention 为在需要百万级上下文处理的应用中部署 LLM 铺平了道路

DuoAttention

Retrieval Heads 和 Streaming Heads

Retrieval Heads
  • 在基于 Transformer 的 LLM 中,注意力头表现出独特且一致的模式,反映了它们的专门功能
  • 图 1 使用句子“最好的水果是橙子。什么是最好的水果?橙子。”可视化了 Llama-2-7B-32K-Instruct 模型中的两种注意力头
  • 左图突出显示了一个在解码过程中强调相关 Token 的注意力头;
    • 例如,在解码第二个“最好的水果”时,第一个“最好的水果”被加重;在推断第二个“橙子”时,初始的“橙子”被突出显示
    • 这些注意力头,论文称之为 Retrieval Heads ,对于上下文处理至关重要,因为它们捕获了上下文相关的 Token
    • 压缩 Retrieval Heads 的 KV 缓存将导致关键上下文信息的丢失,因此它们需要对所有 Token 进行完整的注意力计算
Streaming Heads
  • 图 1 中间图描绘的注意力头主要关注最近的 Token 和 Attention Sinks,不强调上下文中较早的相关 Token
    • 论文称这些为 Streaming Heads
  • 压缩 Streaming Heads 的 KV 缓存是可行的,因为丢弃未被关注的中国 Token 不会显著改变注意力输出
    • 可以通过仅保留 Attention Sinks 和 Recent Token 的 KV 状态来优化 Streaming Heads ,而不会损害模型管理长上下文的能力
Impact of Token Pruning on Retrieval and Streaming Heads
  • 图 1 的右图显示了一个初步的 Passkey 检索实验
    • 当 Retrieval Heads KV 缓存中的中间 Token 被剪枝时,模型的性能显著下降
    • 移除 Streaming Heads 的中间 Token 对 Passkey 检索精度没有显著影响
  • 这一观察表明,我们可以在不牺牲模型长上下文能力的情况下提高计算效率:
    • 通过丢弃 Streaming Heads 的中间 Token ,同时保持 Retrieval Heads 的完整注意力,将 Streaming Heads 的内存需求降低到 \(O(1)\),从而提高了处理长上下文的效率

Optimization-Based Identification of Retrieval Heads

Definition of Retrieval Heads
  • 第 2.1 节定性地定义了 Retrieval Heads 和 Streaming Heads ,但为了精确识别,论文需要一个具体且量化的定义
  • 在论文中,论文将“Retrieval Heads”定义为:
    • 当被限制为仅关注 Recent Token 和 Attention Sinks 时,会显著改变模型输出的注意力头
  • 论文使用这个标准来区分 Retrieval Heads 和 Streaming Heads
    • 这个定义不同于现有工作 (2024; ),它们仅依赖注意力分数来识别 Retrieval Heads ,忽略了
      • 1)压缩特定注意力头 KV 缓存的端到端影响
      • 2)Value 状态的角色
      • 3)注意力分布在层和头之间的可变性
    • 论文的定义直接测量输出偏差 ,即使它们在注意力分数中不明显,论文也能够识别对长上下文处理至关重要的注意力头
    • 论文在第 3.5 节中提供的消融研究支持了这一论点
Optimization-based Identification
  • 论文采用一种基于优化的方法来识别 Retrieval Heads ,灵感来自先前在 CNN 滤波器剪枝方面的工作,如图 2 所示
    • 首先为 LLM 中的每个 KV 头分配一个门控值 \(\alpha_{i,j}\)
      • 这个值直观地表示了第 \(i\) 层第 \(j\) 个 KV 头在处理长上下文信息时的重要性
      • 在使用分组 Query 注意力的模型中,一个 KV 头可能与多个注意力头相关联,论文的方法考虑了对整个注意力头组的 KV 缓存压缩
  • 论文的基于优化的识别方法直接评估了仅使用 Sink Token 和 Recent Token 压缩每个 KV 头 KV 缓存的影响
    • 首先将每个头的门控值 \(\alpha_{i,j}\in[0,1]\) 初始化为 1,假设所有头最初都作为 Retrieval Heads
    • 然后优化这些门控值,同时保持 LLM 的参数固定,将可训练参数的数量限制在 \(N\times H\),并防止影响模型的原始能力
  • 在前向传播过程中,论文结合每个 KV 头的完整注意力和流式注意力的输出,使用门控值作为混合权重:
    $$\texttt{attn}_{i,j}=\alpha_{i,j}\cdot\texttt{full_attn}+(1-\alpha_{i,j})\cdot\texttt{streaming_attn}$$
    • 其中注意力计算定义为:
      $$\texttt{full_attn} =\texttt{softmax}(\boldsymbol{Q}\boldsymbol{K}^{T}\odot\boldsymbol{M}_{\text{causal} })\boldsymbol{V}, \\
      \texttt{streaming_attn} =\texttt{softmax}(\boldsymbol{Q}\boldsymbol{K}^{T}\odot\boldsymbol{M}_{\text{streaming} })\boldsymbol{V},$$
    • 其中 \(\boldsymbol{M}_{\text{causal} }\) 是因果注意力掩码,而 \(\boldsymbol{M}_{\text{streaming} }\) 表示一个类 \(\Lambda\) 掩码,仅关注最近和初始的 Token
Synthetic Dataset for Identifying Retrieval Heads
  • 仅依赖自然语言建模目标不足以识别 Retrieval Heads
    • 自然文本中需要长跨度推理的监督信号是稀疏的,且大多数 Token 可以使用局部上下文进行推断
  • 论文设计了一个专门旨在增强模型长上下文检索能力的合成数据集,使论文能够有效地识别哪些 KV 头可以在不损害模型性能的情况下被压缩
  • 如图 3 所示,论文通过在一个非常长的上下文中,在十个随机位置嵌入十个随机生成的 \(s\) 个 Token 的 passkey sequences 来创建一个 passkey-retrieval 数据集
    • 模型的任务是在上下文末尾回忆这十个序列
Training and Loss Functions
  • 论文优化蒸馏损失,即完整注意力模型的最后一个隐藏状态与使用 DuoAttention 的模型的最后一个隐藏状态之间的 L2 差异,仅关注整个输入中最后 \(l\) 个 Passkey Token :
    $$\mathcal{L}_{\text{distill} }=\frac{1}{N}\sum_{i=1}^{N}\sum_{j=\bar{T}-l+1}^{T}(\boldsymbol{H}_{\text{full} }^{(i)}[j]-\boldsymbol{H}_{\text{mixed} }^{(i)}[j])^{2}$$
  • 论文的合成数据集确保每个监督信号都与最终的压缩策略相关,使得该过程在信息检索精度方面是无损的
    • 事实证明,它比仅使用自然语言建模更有效
    • 论文使用 L1 正则化项来鼓励门控值的稀疏性:
      $$\mathcal{L}_{\text{reg} }=\sum_{i=1}^{L}\sum_{j=1}^{H}|\alpha_{i,j}|,.$$
  • 最终的训练损失是蒸馏损失和正则化损失的组合,由一个超参数 \(\lambda\) 加权,论文在实验中将其设置为 0.05:
    $$\mathcal{L}=\mathcal{L}_{\text{distill} }+\lambda\mathcal{L}_{\text{reg} }.$$
  • 由于可训练参数的总数仅为数千个浮点数,此优化过程相当快,仅需要 2,000 步
    • 论文论文中的所有训练实验都可以在 8×NVIDIA A100 GPU 服务器上进行

Deploying LLMs with DuoAttention

Binarizing Attention Implementations(二值化注意力)
  • 在推理时,论文仅对指定的 Retrieval Heads 应用完整注意力,这些 Retrieval Heads 是使用训练阶段优化的门控值识别的
  • 论文根据阈值 \(\tau\) 对每个头的注意力策略进行二值化,以区分 Retrieval Heads 和 Streaming Heads :
    $$\text{attn}_{i,j}=\begin{cases}\text{full_attn}&\text{if }\alpha_{i,j}>\tau \\ \text{streaming_attn}&\text{otherwise}\ \end{cases}$$
Reordering Attention Heads(重排注意力头)
  • 在部署之前,论文通过根据注意力头分配重新排序 Query 、 Key 和 Value 投影权重的输出通道来预处理模型
  • 这种重新排序将 Retrieval Heads 和 Streaming Heads 分组为两个不同的、连续的簇,从而允许在层内管理这两种类型头的 KV 缓存时进行高效的切片和连接操作,而不是依赖 scattering 和 gathering 操作
Decoding
  • 如图 5 所示,论文在解码期间为 LLM 的每一层分配两个 KV 缓存 :
    • 一个用于 Retrieval Heads ,存储所有过去的 Key 和 Value ;
    • 另一个用于 Streaming Heads ,仅存储 Attention Sinks 和最近的 Token ,保持恒定大小
  • 当处理一个新 Token 时,其 Query 、 Key 和 Value 向量沿头维度分割,以计算 Retrieval Heads 的完整注意力和 Streaming Heads 的流式注意力
    • 然后将结果沿头维度连接以进行输出投影
Chunked Pre-filling(分块 Pre-filling)
  • 论文使用 FlashAttention-2 来 Pre-fill Retrieval Heads 和 Streaming Heads 的 KV 缓存
    • 在长上下文 LLM 中,分块 Pre-filling 是一种常见做法,将提示分成固定长度的块来 Pre-filling KV 缓存
    • 这种技术通过将线性层中的峰值中间激活大小从序列长度降低到块大小,显著降低了峰值内存使用
  • DuoAttention 与分块 Pre-filling 完全兼容,并且 DuoAttention 中 Streaming Heads 的 Pre-filling 可以在线性时间和恒定内存复杂度下实现,无需专门的核
  • 如图 5 所示,计算了某一层的 KV 后,Streaming Heads 的 KV 缓存会立即被剪枝,仅保留 Sink Token 和最近的 Token
    • 下一个传入 Token 块在 Pre-filling 期间将仅关注恒定数量的上下文 Token
  • 令 \(L\) 表示序列长度,\(K\) 表示块大小(chunk size)
    • Streaming Heads 的 Pre-filling 时间复杂度从 \(O(L^{2})\) 优化到 \(O(LK)\),内存复杂度从 \(O(L)\) 减少到 \(O(K)\)
  • 需要注意的是,DuoAttention 的设计非常适合批量操作,这可以在具有大批量大小的服务场景中进一步提高 LLM 的效率

Experiments

Setups

  • 模型、数据集和基线 (Models, Datasets, and Baselines)
    • 论文在长上下文和短上下文基准测试上评估 DuoAttention,证明论文的方法在保留模型处理长短上下文任务性能的同时,显著提高了效率
      • 对于长上下文评估
        • 论文使用 Needle-in-a-Haystack (NIAH) 基准测试 (Kamradt, 2024) 和 LongBench (2023)
      • 对于短上下文评估
        • 论文评估了在 MMLU (2021)、MBPP (2021) 和 MT-Bench (2023) 上的性能
    • 论文采用了最先进的开源模型,包括 Llama-2-7B-chat (2023b)(及其长上下文变体 Llama-2-7B-32K-Instruct (Together, 2023))、Llama-3-[8,70]B-Instruct(及其长上下文变体 Llama-3-8B-Instruct-Gradient-1048k)以及 Mistral-7B-v0.2-Instruct (2023)
    • 论文将论文的方法与 KV 缓存压缩算法进行了比较,包括 H2O (2023b)、TOVA (2024)、FastGen (2024) 和 StreamingLLM (2023b)
  • Implementation details
    • 论文使用 PyTorch (2019) 和来自 FlashInfer (2024) 的 RoPE (2021) 和 RMSNorm 内核来实现 DuoAttention
    • 对于 Retrieval Heads 的识别
      • 论文使用批量大小为 1,将 10 个 32 词(Words)的 passkeys 插入到 BookSum (2021) 数据集中
      • 识别过程使用 128 个 Sink Token 和 256 个 Recent Token
      • 训练样本从范围为 1,000 个 Token 到模型特定的最大长度(间隔 50 intervals)中采样(问题:这里是指样本长度的采样)

        Training samples are drawn from 50 intervals ranging from 1,000 tokens to the model-specific maximum length

    • passkeys 在上下文中的 1000 个点处随机插入(更多细节包含在附录 A.1 节中)
    • 论文使用 AdamW (2015) 优化器优化门控值,初始学习率为 0.02,在前 400 步从 0.002 进行预热,并在最后 400 步降回 0.002
      • 所有实验在 NVIDIA A100 GPU 上运行 2,000 步

Long-Context Benchmarks

  • 使用 Needle-in-a-Haystack (NIAH) 基准测试和 LongBench (2023) 来评估 DuoAttention
  • 使用了两个长上下文模型:Llama-2-7B-32K-Instruct 和 Llama-3-8B-Instruct-Gradient-1048k
    • DuoAttention 配置:
      • Llama-2-7B-32K-Instruct 使用 25% 的 Retrieval Heads 比例
      • Llama-3-8B-Instruct-Gradient-1048k 使用 50% 的比例
    • 论文在相同的 KV 缓存预算下,将 DuoAttention 与 H2O、TOVA 和 StreamingLLM 进行比较
      • 论文为 DuoAttention 使用 64 个 Sink Token 、256 个 Recent Token 和 32,000 的 Pre-filling 块大小
    • 由于 H2O 和 TOVA 的原始设计不支持长上下文,论文修改了它们的算法,将 Pre-filling 阶段替换为 FlashAttention,并模拟输入最后 50 个 Token 的解码(遵循 Tang 等人 (2024b) 的方法)
    • FastGen 的算法不允许指定 KV 压缩比,因为它会随输入波动
      • 论文调整了注意力恢复比例,以确保在图 6 所示的实验中,KV 缓存预算平均高于 25% 或 50%
    • FastGen 在 Attention Profiling 阶段的二次内存成本限制了其处理长上下文样本的能力
      • 论文测量了 FastGen 在 NIAH 上对 Llama-2-7B 最高到 24K 上下文、对 Llama-3-8B 最高到 32K 上下文的性能;
      • 超过这些大小会导致内存不足错误
    • 详细的基线实现和理由在附录 A.3 节和 A.5 节中提供
  • Needle-in-a-Haystack (NIAH) 是一个具有挑战性的压力测试,旨在评估模型从冗长上下文中准确识别和检索相关信息的能力
    • 如图 6 所示,所有基线方法都无法从长序列的不同深度检索到正确答案,因为它们在生成过程中丢弃了包含必要信息的 KV 缓存
    • DuoAttention 保留了 Retrieval Heads 中的所有 KV 缓存,同时仅丢弃 Streaming Heads 中的缓存,从而保留了模型的检索能力
    • DuoAttention 在所有序列深度上都表现出强大的性能,有效处理高达 1048K Token 的长度
  • LongBench (2023) 是一个全面的长上下文数据集套件,涵盖多个任务和自然文本,旨在更全面地评估长上下文理解能力
    • 图 7 显示了在 14 个 LongBench 任务上的性能,比较了不同方法基于其 KV 缓存预算的表现
    • DuoAttention 在大多数任务上显示出 KV 预算和准确性之间的优越权衡,突显了其泛化能力
    • DuoAttention 在大多数任务上实现了与完全注意力相当的性能,对 MHA 使用 25% 的 KV 缓存预算,对 GQA 使用 50% 的 KV 缓存预算,这与在 needle-in-a-haystack 基准测试中观察到的结果一致
    • 论文在附录的表 5 和表 6 中将 DuoAttention 与 FastGen 进行了比较
    • 附录中的表 3 和表 4 提供了两个模型使用 25% 和 50% KV 缓存预算在所有 21 个 LongBench 任务上的完整结果,表明 DuoAttention 在大多数任务上始终优于基线,并取得了最高的平均分数

Short-Context Benchmarks

  • 为了确保 DuoAttention 不损害模型在短上下文任务上的性能,论文将其与所有基线一起在三个短上下文基准测试上进行了评估:MMLU、MBPP 和 MT-Bench
    • 这些基准测试评估模型的知识、编码能力和帮助性
    • 对 MMLU 使用 one-shot 提示,对 MBPP 和 MT-Bench 使用 zero-shot 提
    • 对于 DuoAttention,在 MMLU 上配置 32 个 Sink Token 和 128 个 Recent Token ,在 MBPP 和 MT-Bench 上配置 16 个 Sink Token 和 64 个 Recent Token
  • 如图 8 和表 1 所示
    • 在相同的 KV 缓存预算下,DuoAttention 在各种模型(包括 Llama-2-7B、Llama-3-8B 和 Llama-3-70B-Instruct)上始终优于所有基线
    • 在 50% KV 缓存预算下,DuoAttention 在大多数基准测试上实现了近乎无损的性能,表明它保留了模型的原始能力

Efficiency Results

  • 论文在单个 NVIDIA A100 GPU 上评估了 DuoAttention 在 Llama-2-7B 和 Llama-3-8B 模型上的解码延迟和内存使用情况
  • 论文为整个基准测试序列预分配 KV 缓存,以防止动态内存分配的额外开销
  • 权重和激活的默认数字格式为 BFloat16
  • 通过对 Llama-2-7B 采用 25% 的 Retrieval Heads 比例,对 Llama-3-8B 采用 50% 的比例,DuoAttention 在保持准确性的同时显著提高了效率
Decoding Efficiency
  • 如图 9 所示
    • DuoAttention 的解码速度呈线性缩放,但与完全注意力相比斜率更平缓,这反映了所选的 Retrieval Heads 比例
      • 这种高效的缩放带来了内存使用的显著减少和解码速度的显著提升
    • 这些改进随着上下文长度的增加而接近 Retrieval Heads 比例的倒数
  • 图 11 显示
    • 在固定上下文大小下,DuoAttention 在不同 KV 预算设置下的加速和内存节省
    • 随着部署配置中 Retrieval Heads 比例的降低,解码延迟和内存使用都线性下降
    • 在图 11 的设置下,DuoAttention 在 A100 GPU 上实现了最大改进:MHA 模型内存减少 2.55 倍,GQA 模型内存减少 1.67 倍;MHA 模型延迟减少 2.18 倍,GQA 模型延迟减少 1.50 倍
Pre-filling Efficiency
  • 如第 2.3 节所述,DuoAttention 也加速了 LLM 的长上下文 Pre-filling
  • 图 10 显示
    • DuoAttention 显著降低了 Pre-filling 延迟和内存使用,并且这些节省随着 Pre-filling 块大小的减小而增加
      • 这是因为 Streaming Heads 的时间和内存复杂度随着块大小的减小而降低
    • DuoAttention 实现了 MHA 模型延迟减少高达 1.73 倍,GQA 模型延迟减少高达 1.63 倍,同时 MHA 模型内存减少高达 2.38 倍,GQA 模型内存减少高达 1.53 倍
Combination with Quantization
  • 为了将更多 Token 装入有限的内存,我们可以将权重和 KV 缓存量化与 DuoAttention 结合,以最大化 KV 缓存容量
  • 先前的研究表明,权重量化 (2023a;) 和 4-bit KV 缓存量化 (2024;) 不会损害模型性能
  • 论文将 DuoAttention 与 QServe (2024) 量化方法和内核相结合,以实现 8-bit 权重和 4-bit KV 缓存的 LLM 推理
  • 测量结果如图 12 所示
    • 将量化技术与 DuoAttention 结合,使论文能够在单个 A100-80G GPU 上使用 Llama-3-8B 模型容纳高达 3.30M 个 Token ,与朴素的完全注意力 BF16 部署相比,容量增加了 \(6.4\times\)

Ablation Studies

  • 论文使用 Mistral-7B-Instruct-v0.2 在 passkeys 检索和 MMLU 数据集上进行了消融研究
  • 对于 passkeys 检索任务,论文将一个 8 词的 passkeys 嵌入到一个 30K 词的文本中,并在 100 个插入深度上进行线性扫描,报告精确匹配准确率
  • 基于优化与基于 Attention Profiling 的 Retrieval Heads 识别 (Optimization-based vs. Attention Profiling-based Retrieval Head Identification)
    • 论文评估了论文的基于优化的方法与 FastGen (2024) 和 RazorAttention (2024a) 中使用的 Attention Profiling 方法,两者使用相同的合成 passkeys 数据集
    • 图 13 (1) 中的结果表明,论文的方法显著优于 Attention Profiling ,后者难以识别 Retrieval Heads ,从而影响了模型的准确优化
  • 使用合成数据优化与语言建模 (Optimizing with Synthetic Data vs. Language Modeling)
    • 如图 13 (1) 所示,论文使用合成数据识别 Retrieval Heads 的方法比传统的 Language Modeling(在自然数据中的所有 Token 上计算损失)产生了明显更好的结果
  • 优化中结合 Sink 和 Recent 注意力的必要性 (Necessity of Sink+Recent Attention in Optimization)
    • 图 13 (2) 强调了在优化阶段结合 Sink 和 Recent 注意力的重要性
    • 仅依赖 Sink Token 或 Recent Token 注意力不足以有效识别 Retrieval Heads
  • 部署阶段配置 (Deployment Phase Configuration)
    • 论文分析了 Streaming Heads 中注意力 Sink 和 Recent Token 的部署配置
    • 论文的发现表明:
      • 性能在 16 个 Sink Token 和 64 个 Recent Token 时达到稳定(图 13 (3))
      • 进一步增加只会带来边际改进
    • 问题:论文的发现跟 StreamingLLM 的 4 个 Token 足以的发现有矛盾!

Related Work

  • 已有很多方法在扩展 LLM 并提高其处理长上下文的效率;这些方法可以分为四个主要类别:
    • 优化模型架构、使用近似注意力机制、应用 KV 缓存量化以及系统级优化
  • Model Architecture
    • MQA (2019) 和 GQA (2023) 通过在 Query 头之间共享 KV 头来减小 KV 缓存的大小
    • 但这些方法需要使用特定架构进行预训练,并且不会降低计算成本(但会降低显存)
    • 线性注意力 Transformer (2023) 减少了内存使用,但在需要长上下文处理的任务上往往表现不佳
  • Approximate Attention
    • 诸如 Sparse Transformer (2019) 和 LongFormer (2020) 等方法使用 Local Attention 或 Block Attention 模式来降低计算复杂度
    • BigBird (2020) 通过结合 Local Attention 和 Global Attention 实现线性复杂度,但其中许多方法需要定制的 GPU 内核或重新训练,限制了其实用性
    • H2O (2023b) 和 TOVA (2024) 基于 Query 模式丢弃 Token 来简化注意力
    • StreamingLLM (2023b) 识别了“注意力 Sink ”并提出始终保留 Initial Token 和 Recent Token 以维持恒定的解码延迟和内存使用,使模型能够处理比预训练序列长度多得多的输入 Token
    • FastGen (2024) 分析注意力头以在解码期间丢弃 Token
    • 论文的实验表明:
      • 这些方法会降低 LLM 的长上下文能力
      • 这些方法无法降低长上下文 LLM 的 Pre-filling 成本
  • KV Cache Quantization
    • 诸如 8-bit 和 4-bit 量化 (2024; 2024; 2024) 等技术减小了 KV 缓存的大小,但它们没有解决注意力内核的计算开销问题
    • 这些方法与 DuoAttention 是互补的,可以结合使用以进一步减少内存使用
  • System Optimizations
    • vLLM (2023) 和 FlashAttention (2022; 2023) 通过优化批处理(Batch Processing)和利用 GPU 内存层次结构来提高注意力计算效率
    • FlashDecoding (2024) 和 RingAttention (2023a) 在解码速度和序列级并行性方面引入了进一步的改进
    • 这些方法提高了计算性能,但它们没有解决 KV 缓存大小减少的问题,它们与 DuoAttention 互补,以实现额外的速度和内存优化
  • Recent Works
    • 一些近期工作与 DuoAttention 有相似的想法
    • Wu 等人 (2024) 引入了 Retrieval Heads 的概念来解释 LLM 的长上下文能力
      • 但他们的方法没有压缩非 Retrieval Heads 的 KV 缓存,仅关注准确性
    • MInference (2024) 通过使用稀疏注意力模式来加速长上下文 LLM 的 Pre-filling
      • 但没有优化解码期间的 KV 缓存存储或延迟
    • RazorAttention (2024a) 也将注意力头分为 Retrieval 和 Non-Retrieval 类别
      • 但 RazorAttention 使用 Attention Profiling 方法而不是 Optimization-based 方法区区分
        • 论文的实验表明,Attention Profiling-based 方法不如论文的 Optimization-based 的方法准确
      • 而且,RazorAttention 没有优化 Pre-filling
        • DuoAttention 提供了更有效的 KV 缓存管理和更高的压缩率,从而在长上下文应用中为 Pre-filling 和解码带来了更好的性能

Appendix A

A.1 Experimental Details

  • 论文使用 PyTorch (2019) 中的 FSDP 进行模型训练,并使用 DeepSpeed Ulysses (2023) 序列并行来支持长序列
  • 在训练期间,论文使用 Guo 等人 (2024) 实现的、如图 14 所示的高效块稀疏近似 \(\Lambda\) 类注意力来计算流式注意力
  • 不同模型的最大序列长度各不相同,详见表 2

A.2 Full LongBench Results

A.3 在长上下文基准测试上 H2O 和 TOVA 的实现 (Implementation of H2O and TOVA on Long-Context Benchmarks)

  • H2O (2023b) 和 TOVA (2024) 算法的原始设计与 Pre-filling 阶段的 FlashAttention (2022) 不兼容,因为它们依赖注意力分数来执行 Token Eviction(驱逐)
    • 由于 FlashAttention 中的注意力分数从未被具体化,这些算法无法用于 Pre-filling ,这是它们的主要缺陷之一
    • 因此,不可能在像“大海捞针”和 LongBench 这样的长上下文设置中评估这些算法,因为它们会在上下文 Pre-filling 期间导致内存不足(OOM)
  • 为了与这些策略进行比较,论文修改了算法:
    • 在 Pre-filling 期间,论文使用 FlashAttention 进行精确计算
    • 在解码阶段,论文根据生成 Token 对上下文 Token 的注意力分数执行 Token Eviction
  • 这种修改相比原始设计提高了性能,因为 Pre-filling 是精确的,并且 Token Eviction 仅发生在解码期间
    • 在极端情况下,如果答案中只有一个生成 Token (例如,多项选择题任务),论文实现的 H2O 和 TOVA 将与完全注意力一样精确,这并非它们的真实精度
    • 为了接近它们的真实性能,论文在长输入基准测试(“大海捞针”和 LongBench)中模拟最后 50 个 Token 作为生成 Token ,以足够长时间地执行它们的 Token Eviction 策略,论文的算法也是如此
  • 此实验设置也被 Tang 等人 (2024b) 使用
    • 实验结果表明论文的方法可以通过此压力测试,而 H2O 和 TOVA 则不能

A.4 Implementation of FastGen on Long-Context Benchmarks

  • 由于缺乏 FastGen (2024) 算法的官方实现,论文使用一个社区代码库 (2024) 对其进行了复现,该代码库被 FastGen 的官方仓库引用
    • 在 FastGen 算法中,剪枝比率不能直接配置;而是使用恢复比率 \(T\) 来控制稀疏度,如 FastGen 论文中所述
  • 为了量化稀疏度,论文计算了所有测试用例的平均 KV 缓存使用量作为整体稀疏度的度量
    • 对于 Llama-2-7B 模型,论文将恢复比率设置为 \(0.7\),确保平均 KV 缓存预算超过完整 KV 缓存的 25%
    • 对于 Llama-3-8B 模型,论文将恢复比率设置为 \(0.87\),确保平均 KV 缓存预算超过完整 KV 缓存的 50%
  • 由于 FastGen 使用用户提供提示的完整注意力图来分析不同头的类型,它会导致 \(O(n^{2})\) 的注意力图复杂度
    • 论文无法在长上下文中测试其性能
  • 对于长上下文基准测试,论文使用了 8 个 A100-80G GPU,对于 Llama-2-7B 模型实现了最高 24k Token 的序列长度,对于 Llama-3-8B 模型实现了最高 32k Token 的序列长度
  • 除了图 6 中显示的“大海捞针”基准测试结果外,论文还评估了FastGen 在两个模型上的 LongBench 表现
    • 但由于 FastGen 的二次内存消耗,论文仅报告了在 8x A100-80G GPU 上使用 FastGen 可以运行的数据集结果
    • 如表 5 和表 6 所示,DuoAttention 在 LongBench 数据集上 consistently 优于 FastGen

补充表格和图标

  • 图 15: NIAH result on the Mistral-7B-Instruct-v0.2 model
  • 图 16: NIAH result on the Mistral-7B-Instruct-v0.3 model

NLP——Gated-Delta-Net

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始文档:(Gated DeltaNet)Gated Delta Networks: Improving Mamba2 with Delta Rule, ICLR 2025, NVIDIA
    • GitHub:github.com/NVlabs/GatedDeltaNet

Paper Summary

  • 整体总结:
    • 本文提出了一种新的框架:Gated DeltaNet
      • vs Mamba2:能够实现更好的键值关联学习
      • vs DeltaNet:具有更强的自适应内存清除能力
  • 背景 & 问题:
    • 线性 Transformer (Linear Transformers) 作为标准 Transformer 的高效替代方案获得了关注,但它们在检索和长上下文任务中的性能有限
    • 为了解决这些限制,最近的工作探索了两种不同的机制:
      • 用于自适应内存控制的门控 (gating for adaptive memory control)
      • 用于精确内存修改的 Delta 更新规则 (delta update rule for precise memory modifications)
    • 论文观察到这些机制是互补的,门控支持快速内存擦除,而 Delta 规则促进定向更新
      • 基于这一见解,论文引入了门控 Delta 规则 (gated delta rule),并开发了一种针对现代硬件优化的并行训练算法
  • 论文提出新的架构 Gated DeltaNet
    • 在包括语言建模、常识推理、上下文检索、长度外推和长上下文理解在内的多个基准测试中,超越了 Mamba2 和 DeltaNet 等现有模型
  • 作者通过开发将 Gated DeltaNet 层与滑动窗口注意力 (sliding window attention) 或 Mamba2 层相结合的混合架构,进一步提升了性能,实现了改进的训练效率和卓越的任务性能

Introduction and Discussion

  • Transformer 架构显著提升了 LLM 的能力,由于其有效的注意力机制,在各种任务上展现出卓越的性能
    • 该机制在精确序列建模方面表现出色,并在训练期间利用了现代 GPU 的并行处理能力
    • 但自注意力 (self-attention) 组件的计算复杂度随序列长度呈二次方增长,导致巨大的计算需求,给训练和推理都带来了挑战
  • 为了缓解这些问题,研究人员探索了诸如线性 Transformer (2020a) 等替代方案,它们用基于核化点积的线性注意力 (kernelized dot-product-based linear attention) 取代了传统的基于 softmax 的注意力,通过将其重构为具有矩阵值状态的线性 RNN,显著减少了推理期间的内存需求
  • 虽然早期版本的线性 Transformer 在语言建模任务中表现不如标准 Transformer,但最近的增强已经显示出有前景的改进
    • 例如结合类似于 LSTM 中的数据依赖门控机制,以 GLA (2024a) 和 Mamba2 (2024a) 等模型为例
  • 然而,在管理长序列信息方面仍然存在挑战,特别是在上下文检索任务中,传统 Transformer 仍保持其优势 (2023a; 2024;)
  • 这种现象并不令人惊讶:
    • 线性 Transformer 可以被解释为实现了一种基于外积 (outer-product) 的键值关联记忆,让人联想到张量积表示 (1990)
    • 但它们可以存储的正交键值对的数量受到模型维度的限制
      • 当序列长度超过这个维度时,“记忆碰撞 (memory collisions)”将不可避免,阻碍精确检索 (2021a)
  • Mamba2 通过引入一个简单的门控更新规则来解决这个限制:
    $$ \mathbf{S}_{t}=\alpha_{t}\mathbf{S}_{t-1}+\boldsymbol{v}_{t}\boldsymbol{k}^{\mathrm{T} }_{t}$$
    • 该规则在每个时间步通过一个动态比率 \(\alpha_t \in (0,1)\) 统一衰减所有键值关联
  • 但这种方法没有考虑不同键值关联的重要性差异,可能导致内存利用效率低下
    • 如果模型需要忘记一个特定的键值关联,所有键值关联都会被同等程度地遗忘 ,使得这个过程缺乏针对性和效率
  • 相比之下,采用 Delta 规则 (1960) 的线性 Transformer,即 DeltaNet (2021a; 2024b),通过(软性地)用传入的新键值对替换旧的键值对来选择性更新记忆
    • 这种方法在上下文检索的合成基准测试中展示了令人印象深刻的性能
    • 但由于这个过程((一次只修改一个键值对)),模型缺乏快速清除过时或无关信息的能力 ,尤其是在需要擦除先前数据的上下文切换期间
    • 因此,人们发现 DeltaNet 在现实世界任务中表现一般 (2024b),这很可能是因为缺乏强大的内存清除机制
  • 认识到门控更新规则和 Delta 规则在内存管理方面的互补优势,论文提出了门控 Delta 规则 (gated delta rule)
    • 这是一种简单直观的机制,结合了两种方法
    • 这个统一的规则实现了灵活的内存控制:
      • 可以通过设置 \(\alpha_{t}\to 0\) 来迅速清除内存
      • 同时可以通过设置 \(\alpha_{t}\to 1\) 来选择性更新特定内容而不影响其他信息(有效地切换到纯 Delta 规则)
  • 剩下的挑战在于以硬件高效的方式实现门控 Delta 规则
    • 基于 (2024b) 使用 WY 表示 (1985) 并行化 Delta 规则计算的高效算法,论文仔细扩展了他们的方法以纳入门控项
    • 论文的扩展保留了分块并行 (chunkwise parallelism) 的优势 (2022a; 2023a; 2024a),实现了硬件高效的训练
  • 论文最终的架构 Gated DeltaNet,在一套全面的基准测试中,包括语言建模、常识推理、上下文检索、长度外推和长上下文理解,持续优于 Mamba2 和 DeltaNet
  • 基于这些结果,作者还开发了混合架构,策略性地将 Gated DeltaNet 层与滑动窗口注意力或 Mamba2 层相结合,进一步提升了训练效率和模型性能

Preliminary

Mamba2:带衰减的线性注意力 (Mamba2: Linear Attention with decay)

  • 众所周知,线性 Transformer (2020a) 在排除归一化和查询/键激活的情况下,可以表述为以下线性递归:
    $$
    \begin{align}
    \mathbf{S}_{t}&=\mathbf{S}_{t-1}+\boldsymbol{v}_{t}\boldsymbol{k}_{t}^{\intercal} \in\mathbb{R}^{d_{v}\times d_{k} },\\
    \boldsymbol{o}_{t}&=\mathbf{S}_{t}\boldsymbol{q}_ {t}\in\mathbb{R}^{d_{v} }
    \end{align}
    $$
    • 其中 \(d_{k}\) 和 \(d_{v}\) 分别代表查询/键 (query/key) 和值 (value) 的(头)维度
  • 通过展开递归,我们可以将其表示为向量形式(左)和矩阵形式(右)如下:
    $$
    \begin{align}
    \boldsymbol{o}_{t}&=\sum_{i=1}^{t}(\boldsymbol{v}_{i}\boldsymbol{k}_{i}^{\intercal})\boldsymbol{q }_{t}=\sum_{i=1}^{t}\boldsymbol{v}_{i}(\boldsymbol{k}_{i}^{\intercal}\boldsymbol{q}_{t})\in\mathbb{R}^{d_{v} },\\
    \mathbf{O}&=(\mathbf{Q}\mathbf{K}^{\intercal}\odot\mathbf{M}) \mathbf{V}\in\mathbb{R}^{L\times d_{v} }
    \end{align}
    $$
    • 其中 \(L\) 是序列长度,\(\mathbf{M}\in\mathbb{R}^{L\times L}\) 是由 \(\mathbf{M}_{ij}=0\)(当 \(i< j\))和 \(1\)(其他情况)定义的因果掩码 (causal mask)
  • 然而,这种普通的线性注意力在语言建模中表现远不如 Transformer
    • 理解:本质上 \(\sum_{i=1}^{t}\boldsymbol{v}_{i}(\boldsymbol{k}_{i}^{\intercal}\boldsymbol{q}_{t})\in\mathbb{R}^{d_{v} }\) 也有了一定的按照 Q 和 K 的相似度对 V 进行加权的思想,但这里主要差异在于 Transformer 是 Softmax 的 Attention 权重:
      $$Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\frac{\boldsymbol{Q}\boldsymbol{K}^{\top}}{\sqrt{d_k}}\right)\boldsymbol{V}$$
    • 理解:线性 Attention 和 Softmax Attention 都是 QK 内积越大,V 权重越大,但两者有本质区别:
      • 线性 Attention 未实现归一化,仅仅累加所有 Token 的 V,理论上会导致越靠后的 Token,方差是越大的
      • Softmax Attention 的权重是 \(e^{qk}\) 加权归一化,本质与线性归一化是不同的
      • 问题:如果线性 Attention 中使用 \(e^{qk}\) 作为累加,是否基本上可以实践 Softmax 的等价实现?
        • 回答:不可以,因为 Softmax Attention 中的 权重是 \(e^{qk}\),而线性 Attention 累加的对象是 \(kv\),重点:即使使用 \(e^{kv}\) 累加,累加对象也不同,反而导致含义变了!线性 Attention 的形式导致了他们无法实现 Softmax Attention 这样的 \(e^{qk}\) 加权平均
  • 为了解决这个问题,通常添加一个衰减项来遗忘历史信息
    • 这里论文以 Mamba2 (2024a) 为例,它可以表示为以下线性递归(取决于具体的参数化):
      $$
      \begin{align}
      \mathbf{S}_{t}&=\alpha_{t}\mathbf{S}_{t-1}+\boldsymbol{v}_{t}\boldsymbol{k}_{t}^ {\intercal},\\
      \boldsymbol{o}_{t}&=\mathbf{S}_{t}\boldsymbol{q}_{t}
      \end{align}
      $$
    • 其中 \(\alpha_{t}\in(0,1)\) 是一个数据依赖的标量值衰减项,随 \(t\) 变化
  • 定义累积衰减乘积
    $$\gamma_{j}=\prod_{i=1}^{j}\alpha_{i}$$
  • 并通过展开递归,我们可以将结果表示为向量形式(左)和矩阵并行形式(右):
    $$
    \begin{align}
    \boldsymbol{o}_{t}&=\sum_{i=1}^{t}\left(\frac{\gamma_{t} }{\gamma_{i} }\boldsymbol{v }_{i}\boldsymbol{k}_{i}^{\intercal}\right)\boldsymbol{q}_{t}=\sum_{i=1}^{t}\boldsymbol{v}_{i}\left( \frac{\gamma_{t} }{\gamma_{i} }\boldsymbol{k}_{i}^{\intercal}\boldsymbol{q}_{t}\right),\\
    \mathbf{O}&=((\mathbf{Q}\mathbf{K}^{\intercal})\odot\Gamma),\mathbf{V}
    \end{align}
    $$
    • 这里,\(\Gamma\in\mathbb{R}^{L\times L}\) 是一个衰减感知的因果掩码,其中 \(\Gamma_{ij}=\frac{\gamma_{i} }{\gamma_{j} }\)(如果 \(i\geq j\)),否则 \(\Gamma_{ij}=0\)
  • 这种并行形式和递归形式之间的等价性也被称为 状态空间对偶性 (state space duality, SSD) (2024a)
    • 这种递归结构也出现在其他几种架构中,包括 Gated RFA (2021)、xLSTM (2024) 和 Gated RetNet (2023a)
    • 当 \(\gamma_{t}\) 与数据无关时(data-independent),该公式简化为 RetNet (2023a) 和 Lightning-Attention (2024a)
  • 此外,如果 \(\gamma_{t}\) 扩展为矩阵值而非标量值,当使用外积结构参数化时,高效的训练算法仍然是可能的,正如 (2024a) 所展示并被 (2024b;2024;2025 等) 所使用的 分块训练 (Chunkwise training)
    • 但递归形式和并行形式对于高效训练来说都不是理想的 (2022a; 2024a),这促使了使用分块并行形式 (2022a; 2023a) 进行硬件高效的线性时间训练,如下所述
  • 总结来说,分块并行形式将输入和输出分割成几个大小为 \(C\) 的块 (chunk),并根据前一个块的最终状态以及当前块的查询/键/值块来计算每个块的输出
    • 遵循 (2023a); (2024a) 的符号,论文以查询块 \(\boldsymbol{q}\) 为例
  • 论文将
    • \(\mathbf{Q}_{[t]}:={\boldsymbol{q} }_{tC+1:(t+1)C+1}\) 表示为块 \(t\) 的查询块
    • \({\boldsymbol{q} }_{[t]}^{r}:={\boldsymbol{q} }_{tC+r}\) 表示为块 \(t\) 内的第 \(r\) 个查询块
    • \(t\) 的初始状态定义为 \(\mathbf{S}_{[t]}:=\mathbf{S}_{[t]}^{0}=\mathbf{S}_{[t-1]}^{C}\)
  • 通过部分展开递归,论文有
    $$
    \begin{align}
    \mathbf{S}_{[t]}^{r}&=\mathbf{S}_{[t]}+\sum_{i=1}^{r}\boldsymbol{v}_{[t]}^{i}\boldsymbol{k}_{[t ]}^{i\tau}\in\mathbb{R}^{d_{v}\times d_{k} },\\
    \boldsymbol{\sigma}_{[t]}^{r}&=\mathbf{S }_{[t]}^{r}{\boldsymbol{q} }_{[t]}^{r}=\mathbf{S}_{[t]}{\boldsymbol{q} }_{[t]}^{r}+\sum_{i=1}^{r} \boldsymbol{v}_{[t]}^{i}\left(\boldsymbol{k}_{[t]}^{i\tau}{\boldsymbol{q} }_{[t]}^{r}\right)\in\mathbb{R}^ {d_{v} }
    \end{align}
    $$
  • 等价地,以矩阵形式表示:
    $$
    \begin{align}
    \mathbf{S}_{[t+1]}&=\mathbf{S}_{[t]}+\mathbf{V}_{[t]}\mathbf{K}_{[t]}^{\tau} \in\mathbb{R}^{d_{v}\times d_{k} },\\
    \mathbf{O}_{[t]}&=\mathbf{Q}_{[t]}{ \mathbf{S} }_{[t]}^{\tau}+\left(\mathbf{Q}_{[t]}\mathbf{K}_{[t]}^{\tau}\odot \mathbf{M}\right)\mathbf{V}_{[t]}\in\mathbb{R}^{C\times d_{v} }
    \end{align}
    $$
    • 其中 \(\mathbf{M}\in\mathbb{R}^{C\times C}\) 是因果掩码
    • 上述等式富含矩阵乘法 (matmuls),允许基于张量核心 (tensor-core) 的硬件优化
  • 这个分块算法可以轻松扩展到带衰减的线性注意力:
    $$
    \begin{align}
    \mathbf{S}_{[t+1]}&=\overrightarrow{\mathbf{S}_{[t]} }+\mathbf{V}_{[t]}^{\tau} \overrightarrow{\mathbf{K}_{[t]} }\in\mathbb{R}^{d_{v}\times d_{k} },\\
    \mathbf {O}_{[t]}&=\overleftarrow{\mathbf{Q}_{[t]}\mathbf{S}_{[t]}^{\tau} }+\left( \mathbf{Q}_{[t]}\mathbf{K}_{[t]}^{\tau}\odot\Gamma_{[t]}\right)\mathbf{V}_{[t ]}\in\mathbb{R}^{C\times d_{v} }
    \end{align} \tag{1}
    $$
    • 其中有
      $$(\Gamma_{[t]})_{ij}=\frac{\gamma_{[t]}^{i} }{\gamma_{[t]}^{j} },\gamma_{[t]}^{j}= \prod_{j=tC+1}^{tC+j}\alpha_{j} $$
      • 注:这里论文稍微滥用了 \(\gamma\) 的符号来表示每个块的累积乘积(分别从每个块的第一个位置开始),而不是整个序列
      • 这里论文使用左箭头 (\(\overset{\leftarrow}{\cdot}\)) 或右箭头 (\(\overset{\rightarrow}{\cdot}\)) 分别表示衰减到每个块第一个位置和最后一个位置的变量,
        $$
        \begin{align}
        \overleftarrow{ {\boldsymbol{q} }_{[t]}^{r} }&=\gamma_{[t]}^{r}{\boldsymbol{q} }_{[t]}^{r} \text{ 将每个向量衰减到块 t 的第一个位置 (decaying each vector to the first position of chunk t)}\\
        \overleftarrow{ {\boldsymbol{k} }_{[t]}^{r} }&=\frac{\gamma_{[t]}^{C} }{ \gamma_{[t]}^{c} }\boldsymbol{k}_{[t]}^{r} \text{ 将每个向量衰减到块 t 的最后一个位置 (decaying each vector to the last position of chunk t)} \\
        \overrightarrow{\mathbf{S}_{[t]} }&=\gamma_{[t]}^{C}\mathbf{S}_{[t]} \text{ 在整个块 t 上衰减状态矩阵 (decaying the state matrix over the entire chunk t)} \tag{2}
        \end{align}
        $$
        • 其他变量(例如 \(\overrightarrow{\boldsymbol{v} }\))也类似
  • Mamba2 中引入的 SSD 分解算法在很大程度上等同于这种分块算法
    • 对于更通用的方法, (2024a) 提出了一种扩展的分块算法,用于线性注意力,该算法结合了细粒度的衰减机制

Delta Networks: Linear Attention with Delta Rule

  • Delta 更新规则(Delta Update Rule) (1960; 2021a) 动态地擦除与当前输入键 (\(\boldsymbol{k}_{t}\)) 关联的旧值 (\(\boldsymbol{v}_{t}^{\text{old} }\)) ,并写入一个新值 (\(\boldsymbol{v}_{t}^{\text{new} }\)) ,该新值是基于“写入强度” \(\beta_{t}\in(0,1)\) 的当前输入值和旧值的线性组合 (可以将 \(\beta_{t}\in(0,2)\) 设置为允许负特征值,以解锁 DeltaNet 的状态跟踪能力 (2024; 2025))
    $$
    \begin{align}
    \mathbf{S}_{t}&=\mathbf{S}_{t-1}-\underbrace{(\mathbf{S}_{t-1}\boldsymbol{k}_{t})}_{\boldsymbol {v}_{t}^{\text{old} } }\boldsymbol{k}_{t}^{\top}+\underbrace{(\beta_{t}\boldsymbol{v}_{t}+(1-\beta_{t})\mathbf{S}_{t-1}\boldsymbol{k}_{t}))}_{\boldsymbol{v}_{t}^{\text{new} } }\boldsymbol{k}_{t}^{\top} \\
    &= \mathbf{S}_{t-1}\left(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t}\boldsymbol{k}_{t}^{\top}\right)+ \beta_{t}\boldsymbol{v}_{t}\boldsymbol{k}_{t}^{\top}
    \end{align}
    $$
    • 如上所示,DeltaNet 实现了具有广义 Householder 转移矩阵 (\(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t}\boldsymbol{k}_{t}^{\mathsf{T} }\)) 的一阶线性递归
    • 尽管在关联回忆和语言建模性能上表现出色 (2021a),但由于计算效率低下,DeltaNet 受到的关注有限,直到 (2024b) 引入了一种硬件高效的分块训练算法,详情如下
  • 分块并行形式 (Chunkwise parallel form)。通过部分展开递归,论文有
    $$\mathbf{S}_{[t]}^{r}=\mathbf{S}_{[t]}\underbrace{\left(\prod_{i=1}^{r} \mathbf{I}-\beta_{[t]}^{i}\boldsymbol{k}_{[t]}^{i}\boldsymbol{k}_{[t]}^{i\mathsf{T} }\right)}_{:= \mathbf{P}_{[t]}^{r} }+\underbrace{\sum_{i=1}^{r}\left(\beta_{[t]}^{i}\boldsymbol{v}_{[t ]}^{i}\boldsymbol{k}_{[t]}^{i\mathsf{T} }\prod_{j=i+1}^{r}\left(\mathbf{I}-\beta_{[t]}^{ j}\boldsymbol{k}_{[t]}^{j}\boldsymbol{k}_{[t]}^{j\mathsf{T} }\right)\right)}_{:=\mathbf{H}_{[t]}^{r} } \tag{3}$$
  • 其中 \(\mathbf{P}_{[t]}^{j}\) 涉及广义 Householder 矩阵的累积乘积,可以通过经典的 WY 表示 (1985) 进行优化:
    $$\mathbf{P}_{[t]}^{r}=\mathbf{I}-\sum_{i=1}^{r}\mathbf{w}_{[t]}^{i}\boldsymbol{k}_{[t]}^{ i\mathsf{T} }\in\mathbb{R}^{d_{k}\times d_{k} }\qquad\mathbf{w}_{[t]}^{r}=\beta_{ [t]}^{r}\left(\boldsymbol{k}_{[t]}^{r}-\sum_{i=1}^{r-1}\left(\mathbf{w}_{[t]}^{i}(\boldsymbol{k}_ {[t]}^{i\mathsf{T} }\boldsymbol{k}_{[t]}^{r})\right)\right)\in\mathbb{R}^{d_{k} } \tag{4}$$
  • 同样,\(\mathbf{H}_{[t]}^{r}\) 可以表示为:
    $$\mathbf{H}_{[t]}^{r}=\sum_{i=1}^{r}\mathbf{u}_{[t]}^{i}\boldsymbol{k}_{[t]}^{i\mathsf{T } }\in\mathbb{R}^{d_{v}\times d_{k} }\qquad\mathbf{u}_{[t]}^{r}=\beta_{[t]}^{r} \left(\boldsymbol{v}_{[t]}^{r}-\sum_{i=1}^{r-1}\left(\mathbf{u}_{[t]}^{i}(\boldsymbol{k}_{[t]}^{ i\mathsf{T} }\boldsymbol{k}_{[t]}^{r})\right)\right)\in\mathbb{R}^{d_{v} } \tag{5}$$
  • 并以矩阵形式表示:\(\mathbf{P}_{[t]}=\mathbf{I}-\mathbf{W}_{[t]}^{\top}\mathbf{K}_{[t]}\in\mathbb{R }^{d_{k}\times d_{k} }\),\(\mathbf{H}_{[t]}=\mathbf{U}_{[t]}^{\top}\mathbf{K}_{[t]}\in\mathbb{R}^{d_{v} \times d_{k} }\)。通过使用 UT 变换 (2006),我们可以进一步将 \(\mathbf{W}\) 和 \(\mathbf{U}\) 写成矩阵形式:
    $$\mathbf{T}_{[t]} =\left[\mathbf{I}+\text{strictLower}\left(\text{diag}(\beta_{[t]} )\mathbf{K}_{[t]}\mathbf{K}_{[t]}^{\mathsf{T} }\right)\right]^{-1}\text{ diag}\left(\beta_{[t]}\right)\in\mathbb{R}^{C\times C} \\
    \mathbf{W}_{[t]} =\mathbf{T}_{[t]}\mathbf{K}_{[t]}\in\mathbb{R}^{C\times d_{k} }, \qquad\mathbf{U}_{[t]}=\mathbf{T}_{[t]}\mathbf{V}_{[t]}\in\mathbb{R}^{C \times d_{v} }\tag{6-7}$$
  • 将这些代回方程 3,得到了一个硬件高效的 DeltaNet 分块算法,该算法利用了矩阵乘法,实现了基于张量核心的 GPU 优化:
    $$\mathbf{S}_{[t+1]} =\mathbf{S}_{[t]}\mathbf{P}_{[t]}+\mathbf{H}_{[t]}=\mathbf{S}_{[t ]}+\left(\mathbf{U}_{[t]}-\mathbf{W}_{[t]}\mathbf{S}_{[t]}^{\mathsf{T} }\right)^ {\mathsf{T} }\mathbf{K}_{[t]} \in\mathbb{R}^{d_{v}\times d_{k} } \\
    \mathbf{O}_{[t]} =\mathbf{Q}_{[t]}\mathbf{S}_{[t]}^{\mathsf{T} }+(\mathbf{Q}_{[t]} \mathbf{K}_{[t]}^{\mathsf{T} }\odot\mathbf{M})\left(\mathbf{U}_{[t]}-\mathbf{W }_{[t]}\mathbf{S}_{[t]}^{\mathsf{T} }\right) \in\mathbb{R}^{C\times d_{v} } \tag{8-9}$$

Gated Delta Networks

Formulation: Gated Delta Rule

  • 论文提出的门控 Delta 规则简单而有效:
    $$\mathbf{S}_{t}=\mathbf{S}_{t-1}\left(\alpha_{t}(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t }\boldsymbol{k}_{t}^{\mathsf{T} })\right)+\beta_{t}\boldsymbol{v}_{t}\boldsymbol{k}_{t}^{\mathsf{T} } \tag{10}$$
    • 其中数据依赖的门控项 \(\alpha_{t}\in(0,1)\) 控制状态衰减
    • 这个公式统一了门控机制和 Delta 规则的优势:门控项支持自适应内存管理,而 Delta 更新结构促进有效的键值关联学习
  • 论文通过 Liu 等 (2024) 引入的在线学习框架的视角,对门控 Delta 规则进行了正式分析
    • 在这个框架中,循环状态更新作为在线学习问题的闭式解出现,如表 1 所示
    • 最近的线性 RNN 架构通常在其在线学习目标中包含一个正则化项,以防止状态偏离先前的值,从而实现记忆保留
      • 但当状态被信息饱和时,这种保留机制会变得有问题。在这种情况下,每个状态将对多个信息片段进行编码,使得精确检索变得困难
    • 为了解决这个限制,Mamba2 和 Gated DeltaNet 引入了一个自适应缩放因子 \(\alpha_{t}\),它放松了正则化项,允许 \(\mathbf{S}_{t}\) 和 \(\mathbf{S}_{t-1}\) 之间的受控偏差
      • 这个修改通过选择性遗忘实现了动态内存管理,这在过滤无关信息时可能很有用(见 章节3.2)
  • 另一方面,线性注意力 (Linear Attention, LA) 和 Mamba2 使用简单的负内积损失 -\(\langle\mathbf{S}_{t}\boldsymbol{k}_{t},\boldsymbol{v}_{t}\rangle\),而 Longhorn (2024) 使用更具表达力的在线回归目标 \(|\mathbf{S}_{t}\boldsymbol{k}_{t}-\boldsymbol{v}_{t}|^{2}\) 来更好地建模键值关联。由此产生的 Longhorn 更新规则与 Delta 更新规则非常相似,这表明(门控)Delta 规则在上下文关联回忆方面优于 Mamba2
    • Longhorn 更新规则与 Delta 更新规则理论上的区别在于优化方法:Longhorn 使用隐式在线学习 (2010) 来推导闭式的全局最优更新,而 DeltaNet 通过一步显式梯度下降来优化相同的目标,正如 (2024) 所指出的
  • 从快速权重编程 (fast weight programming) (2022a)、测试时训练 (test-time training) (2024a) 和回归 (2025) 的角度来看,隐藏状态 \(\mathbf{S}\) 可以被解释为一个(快速)权重矩阵,Delta 规则通过测试时随机梯度下降 (stochastic gradient descent, SGD) 优化在线回归目标 \(\mathcal{L}(\mathbf{S}_{t})=\frac{1}{2}|\mathbf{S}_{t}\boldsymbol{k}_{t}-\boldsymbol{v}_{t}|^{2}\):
    $$\mathbf{S}_{t+1}=\mathbf{S}_{t}-\beta_{t}\nabla\mathcal{L}(\mathbf{S}_{t})= \mathbf{S}_{t}-\beta_{t}(\mathbf{S}_{t}\boldsymbol{k}_{t}-\boldsymbol{v}_{t})\boldsymbol{k}_{t}^{\top}= \mathbf{S}_{t}\left(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t}\boldsymbol{k}_{t}^{\top}\right)+\beta_{t}\boldsymbol{v}_{t}\boldsymbol{k}_{t}^{\top}$$
    • 其中 \(\beta_{t}\) 代表(自适应)学习率。从这个角度来看,门控 Delta 规则可以被视为将自适应权重衰减项 \(\alpha_{t}\) 纳入 SGD 更新中,这是一种在深度学习中广泛使用的技术 (1991; 2023)
    • 同时,Titans (2024) 证明了在 RNN 测试时 SGD 更新中结合权重衰减机制的有效性

Case study: Single Needle in a Haystack, S-NIAH,大海捞针

  • 为了更好地理解 Delta 规则和门控规则之间的互补优势,论文提供了一个关于来自 RULER (2024) 的“大海捞针”基准测试套件的案例研究,其中一个键值对充当“大海”(上下文)中的“针”,模型必须在给定键时回忆起值
  • 表 2 展示了结果,论文得出三个主要观察:
    • 衰减损害记忆保留 (Decay hurts memory retention)
      • 在最简单的 S-NIAH-1 设置中,具有重复的合成上下文,模型记忆的信息最少,测试长期保留能力
      • DeltaNet 在所有序列长度上都实现了接近完美的性能
      • Mamba2 在超过 2K 序列后性能显著下降,因为它衰减历史信息太快,而 Gated DeltaNet 的下降不那么严重,这得益于使用了 Delta 规则
    • 门控促进过滤 (Gating facilitates filtering)
      • 在具有真实世界文章上下文的 S-NIAH-2/3 中,模型存储所有潜在相关信息,测试高效的内存管理
      • 在固定状态大小的情况下,缺乏清除会导致内存碰撞,信息变得叠加且无法区分
      • DeltaNet 的性能在较长序列时显著下降,原因是内存清除能力差
      • Mamba2 和 Gated DeltaNet 通过过滤无关信息的门控机制保持了更好的性能
    • Delta 规则有助于记忆 (Delta rule helps memorization)
      • 在 S-NIAH-3 中,值从数字变为 UUID,测试复杂模式记忆
      • Mamba2 的性能迅速下降,而 Gated DeltaNet 表现更好,验证了 Delta 规则确实具有更好的记忆能力

Algorithm: Hardware-efficient Chunkwise training

  • 在本小节中,论文推导出用于训练 Gated DeltaNet 的硬件高效分块算法。通过部分展开方程 10 中的递归,论文有
    $$\mathbf{S}_{[t]}^{r}=\underbrace{\mathbf{S}_{[t]}\left(\prod_{i=1}^{ r}\alpha_{[t]}^{i}\left(\mathbf{I}-\beta_{[t]}^{i}\boldsymbol{k}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i\mathsf{T} }\right)\right)}_{\text{:=F}_{[t]}^{r} }+\underbrace {\sum_{i=1}^{r}\left(\beta_{[t]}^{i}\boldsymbol{v}_{[t]}^{i}\boldsymbol{k}_{[t]} ^{i\mathsf{T} }\prod_{j=i+1}^{r}\alpha_{[t]}^{j}\left(\mathbf{I}-\beta_{[t]}^{j }\boldsymbol{k}_{[t]}^{j}\boldsymbol{k}_{[t]}^{j\mathsf{T} }\right)\right)}_{ \text{:=G}_{[t]}^{r} }$$
    • 很容易看出 \(\mathbf{F}_{[t]}^{r}=\gamma_{[t]}^{r}\mathrm{P}_{[t]}^{r}=\hat{\overline{\mathbf {P} }_{[t]}^{r} }\)
    • 至于 \(\mathbf{G}_{[t]}^{r}\),论文调整方程 5 如下:
      $$\mathbf{G}_{[t]}^{r}=\sum_{i=1}^{r}\frac{\gamma_{[t]}^{r} }{\gamma_{[t]}^{i} }\tilde{\mathbf{u} }_{[t]}^{i}\boldsymbol{k}_{[t]}^{i\mathsf{T} }\in \mathbb{R}^{d_{v}\times d_{k} }\qquad\tilde{\mathbf{u} }_{[t]}^{r}=\beta_{[t]}^{r }\left(\boldsymbol{v}_{[t]}^{r}-\sum_{i=1}^{r-1}\left(\tilde{\mathbf{u} }_{[t]}^ {i}(\frac{\gamma_{[t]}^{r} }{\gamma_{[t]}^{i} }\boldsymbol{k}_{[t]}^{i\mathsf{T } }\boldsymbol{k}_{[t]}^{r})\right)\right)\in\mathbb{R}^{d_{v} }$$
      • (证明见附录 A)
    • 通过 UT 变换,论文得到矩阵形式:
      $$\widetilde{\mathbf{U} }_{[t]}^{-}=\left[\mathbf{I}+\text{strictLower}\left( \operatorname{diag}\left(\beta_{[t]}\right)(\Gamma_{[t]}\odot\mathbf{K}_{[t]} \mathbf{K}_{[t]}^{\mathsf{T} })\right)\right]^{-1}\operatorname{diag}\left( \beta_{[t]}\right)\mathbf{V}_{[t]}\qquad\in\mathbb{R}^{C\times d_{v} }$$
  • 类似于 Mamba2 扩展线性注意力的方式(方程 1),我们可以调整 DeltaNet 的分块算法(方程 8-9)用于 Gated DeltaNet,以实现硬件高效训练,如下所示:
    $$
    \begin{align}
    \mathbf{S}_{[t+1]}&=\overrightarrow{\mathbf{S}_{[t]} }+\left( \overrightarrow{\mathbf{U} }_{[t]}-\overleftarrow{\mathbf{W} }_{[t]}\mathbf{S}_{[ t]}^{\mathsf{T} }\right)^{\mathsf{T} }\overrightarrow{\mathbf{K} }_{[t]} \in\mathbb{R}^{d_{v}\times d_{k} } \\
    \mathbf{O}_{[t]}&=\overleftarrow{\mathbf{O} }_{[t]}\mathbf{S}_{[t]} ^{\mathsf{T} }+\left(\mathbf{Q}_{[t]}\mathbf{K}_{[t]}^{\mathsf{T} }\odot \mathbf{M}\right)\left(\overrightarrow{\mathbf{U} }_{[t]}-\overleftarrow{ \mathbf{W} }_{[t]}\mathbf{S}_{[t]}^{\mathsf{T} }\right)\in\mathbb{R}^{C\times d_ {v} }
    \end{align}
    $$
    • 其中 \(\overleftarrow{\boldsymbol{q} }_{[t]}^{r}=\gamma_{[t]}^{r}\boldsymbol{q}_{[t]}^ {r}\), \(\overleftarrow{\mathbf{w} }_{[t]}^{r}=\gamma_{[t]}^{r}\mathbf{w}_{[t]}^{r}\), \(\overleftarrow{\boldsymbol{k} }_{[t]}^{r}=\frac{\gamma_{[t]}^{C} }{\gamma_{[t]} } \boldsymbol{k}_{[t]}^{r}\), 并且 \(\overrightarrow{\mathbf{S} }_{[t]}\neq\gamma_{[t]}^{C}\mathbf{S}_{[t]}\) 如同论文在方程 2 中的定义

Gated Delta Networks and Hybrid Models

  • Token 混合器块 (Token mixer block)
    • 基本的 Gated DeltaNet 遵循 Llama 的宏观架构,将 Token 混合器层与 SwiGLU MLP 层堆叠,但用门控 Delta 规则 Token 混合取代了自注意力
    • 图 1(右)显示了其块设计
      • 对于门控 Delta 规则(方程 10),查询、键和值 \(\{\boldsymbol{q},\boldsymbol{k},\boldsymbol{v}\}\) 通过线性投影、短卷积 (short convolution) 和 SiLU 生成,并对 \(\boldsymbol{q},\boldsymbol{k}\) 应用 L2 归一化以保持训练稳定性。\(\alpha,\beta\) 仅使用线性投影(论文对 \(\alpha\) 使用 Mamba2 的参数化,但为简洁起见省略了细节)
      • 遵循 Sun 等 (2023a),输出在应用输出投影之前通过归一化和门控处理
  • 混合模型 (Hybrid models)
    • 线性 Transformer 在建模局部偏移和比较方面存在局限性,并且其固定状态大小使得检索任务变得困难 (2023a)
    • 遵循最近的混合架构,如 Griffin (2024) 和 Samba (2024),论文将线性循环层与滑动窗口注意力 (sliding window attention, SWA) 相结合,产生了 GatedDeltaNet-H1
    • 论文还堆叠了 Mamba2、GatedDeltaNet 和 SWA,产生了 GatedDeltaNet-H2

Experiments

Setup

  • 论文的实验包括对最近最先进架构的全面比较,包括纯 Transformer 模型、基于 RNN 的方法和混合架构
  • 论文针对以下基线进行评估:RetNet (2023a)、HGRN2 (2024b)、Mamba (2023)、Mamba2 (2024a)、Samba (2024) 和 DeltaNet (2024b)
  • 为了公平比较,所有模型都在相同条件下训练,具有 1.3B 参数,使用从 FineWeb-Edu 数据集 (2024) 中采样的 100B Token
  • 论文使用 AdamW 优化器,峰值学习率为 4e-4,权重衰减为 0.1,梯度裁剪为 1.0
  • 学习率遵循余弦退火调度,具有 1B Token 的预热期 和 0.5M Token 的批次大小
  • 所有模型都使用词汇量为 32,000 的 Llama2 分词器
  • 对于序列建模,论文将训练长度设置为 4K Token ,Samba 和论文的混合模型使用 2K 的滑动窗口大小
  • 评估设置见附录 B.1,消融研究见附录 B.2

Common-sense reasoning

  • 在表 3 中,论文展示了具有 0.4B 和 1.3B 参数模型的语言建模困惑度以及在常识推理基准测试上的零样本准确率
  • Gated DeltaNet 在两个规模上都持续优于其他线性模型,包括 RetNet、HGRN2、Mamba、Mamba2 和 DeltaNet
  • 正如预期的那样,混合变体进一步提升了性能

In-context retrieval on real-world data

  • 表 4 展示了在 Arora 等 (2023a) 使用的真实世界回忆密集型任务上的结果
  • 正如预期的那样,线性循环模型与 Transformer 相比显示出显著的性能差距,而结合线性循环和注意力的混合模型在检索任务中优于纯注意力模型
  • 对于纯循环模型,尽管 DeltaNet 在合成的上下文检索任务上表现出色 (2024b),但其真实世界的检索性能落后于 Mamba2,这与论文在 S-NIAH-2 和 S-NIAH-3 中的观察一致(表 2)
  • Gated DeltaNet 由于其门控 Delta 规则,性能优于 DeltaNet 和 Mamba2,尽管改进幅度小于表 2
  • 论文将这种性能差距的缩小归因于未进行指令调优的小型语言模型容易产生重复错误,这些错误是这些任务中错误的主要来源(参见 Arora 等 (2023a, 附录 E))
  • 由于这个问题在很大程度上与更新规则的选择无关,因此模型之间的性能差异与表 2 相比不那么明显

Length extrapolation on long sequences

  • 如图 2 所示,论文评估了模型在六个长上下文基准测试中外推到长达 20K Token 序列的能力
    • 在 RNN 模型中,Gated DeltaNet 在所有任务中实现了最低的整体困惑度
    • 虽然论文在长度外推中观察到结果好坏参半,但 Gated DeltaNet 表现出相对更稳健的性能,表明其具有更好的内存管理能力
    • 混合模型通过利用注意力进行局部上下文建模,进一步改善了这一点,从而减轻了其循环组件在内存管理上的负担
    • 未来的工作将探索这些模型在更长序列上的能力

Long context understanding

  • 如表 5 所示,论文评估了模型在 LongBench (2023) 上的性能
  • 在循环模型中,Gated DeltaNet 显示出持续的优势,特别是在单文档问答、少样本上下文学习和代码任务中,分别展示了其在检索、上下文学习和状态跟踪方面的卓越能力

Throughput Comparison

  • 不同模型的训练吞吐量对比如图 3 所示
    • 正如论文的分析所示,论文提出的门控 delta 规则 (gated delta rule) 与原始 delta 规则相比仅引入了微小的开销,Gated DeltaNet 的吞吐量与 DeltaNet 基本相同
    • 由于它们具有更具表达力的转移矩阵 (transition matrices),两者都比 Mamba2 稍慢(约慢 2–3K tokens/sec)
  • 正如论文的分析所示,与原始 Delta 规则相比,提出的门控 Delta 规则仅引入了边际开销,Gated DeltaNet 实现了与 DeltaNet 基本相同的吞吐量
  • 由于它们具有更具表达力的转移矩阵,两者都比 Mamba2(2-3K Token /秒)稍慢
  • Transformer++ 在 2K 上下文窗口领域实现了最佳性能
    • 这得益于高度优化的 Flash-Attention-2 内核 (2023)
    • 因此,将 2K 窗口大小的滑动窗口注意力 (sliding window attention, SWA) 与其他令牌混合器 (token mixer) 相结合的混合方法,其吞吐量高于独立的混合器:Samba 优于 Mamba,而 Gated DeltaNet-H1 和 Gated DeltaNet-H2 优于 Gated DeltaNet
    • 值得注意的是,Gated DeltaNet-H1 在所有序列长度上均保持了可观的训练吞吐量,即使在短序列上也是如此

Related Work

Gated linear RNN

  • 大型线性循环语言模型因其训练和推理效率而备受关注
  • 线性循环神经网络领域已经从使用数据无关的衰减机制迅速发展为在更近期的架构中融入数据相关的衰减机制
    • 数据无关的衰减机制:例如 S4 (2022)、S5 (2023)、LRU (2023)、RWKV4/5 (2023) 和 RetNet (2023a) 等模型所例证
    • 数据相关的衰减机制:例如 HGRN1/2 (2023a; 2024b)、Mamba1/2 (2023; 2024a)、RWKV6 (2024)、GSA (2024)
  • 这一转变源于门控/遗忘机制(在 Mamba 中称为选择性机制)被证实的优势,这是一个源于门控循环神经网络文献 (2000) 的经典概念,其重要性一直被反复确认 (2015; 2018; 2023b; 2024b)
  • 现代遗忘门与传统设计(如 LSTM 中的遗忘门)的不同之处在于,它移除了对先前隐藏状态的依赖,仅依赖于输入数据
    • 这种修改使得在序列长度上能够实现高效的并行性 (2018; 2023b)
    • 缺乏遗忘门一直是 DeltaNet 的一个显著局限,而论文的门控扩展以一种自然、有效且硬件高效的方式弥补了这一差距
    • 论文也注意到最近的一项并行工作 RWKV-7 使用了类似的想法,但采用了更宽松的使用对角加低秩转移的形式化表示:
      $$\mathbf{S}_{t}=\mathbf{S}_{t-1}(\mathrm{diag}(\mathbf{d}_{t})-\mathbf{a}_{t}\mathbf{b}_{t}^{\top})+v_{t}\boldsymbol{k}_{t}^{\top}$$
    • 其中 \(\mathbf{d}_{t},\mathbf{a}_{t},\mathbf{b}_{t}\in\mathbb{R}^{d_{k} }\)
    • 分块算法可以类似地适用于这种情况,正如在 Flash Linear Attention (2024) 中所实现的那样

Delta rule

  • Delta 学习规则在记忆容量上表现出优于赫布学习规则 (1988; 1989),DeltaNet 利用了这一点,而线性 Transformer 则依赖于类赫布规则
    • 这种记忆容量优势在合成上下文学习任务中表现明显,并延伸到语言建模 (2021; 2024b)、强化学习 (2022b) 和图像生成 (2023)
    • (2024b) 将 delta 规则计算并行化,并证明了 DeltaNet 的数据相关的单位加低秩结构 (\(\mathbf{I}-\beta_{t}\boldsymbol{k}_{t}\boldsymbol{k}_{t}^{\top}\)) 比 Mamba2 的数据相关的对角矩阵 (\(\alpha_{\mathbf{I} }\)) 提供了更大的灵活性
    • 这种结构优势可能实现复杂的推理,包括正则语言识别 (2024; 2024) 和超越 TC’ 复杂度的状态跟踪 (2024),这对于编码和推理应用至关重要
  • 尽管有这些显著优势,delta 规则面临理论局限 (2023) 并且在真实世界数据集上表现中等 (2024b),表明仍有改进空间
    • 先前通过非线性循环 (2021, 2022a) 来增强表达力的尝试解决了一些局限,但牺牲了训练并行性,造成了性能-效率的权衡
    • 最近的工作提出了一些不损害并行性的增强方法,以获得更好的状态跟踪性能,包括使用负特征值 (2024) 和使用多个户主转移矩阵的乘积 (2025),这实现了高秩变换
    • 这些方法可以无缝应用于 Gated DeltaNet
  • 从(在线)学习目标的角度来看,替代的公式化可以进一步扩展表达力:
    • 非线性回归 (\(\mathcal{L}(\mathbf{S}_{t})=\frac{1}{2}||\boldsymbol{f}_{\mathbf{S}_{t} }(\boldsymbol{k}_{t})-\boldsymbol{v}_{t}||^{2}\)),如 TTT (2024a) 和 Titans (2024) 中那样,其中 \(\boldsymbol{f}_{\mathbf{S} }\) 是由 \(\mathbf{S}\) 参数化的非线性函数;
    • 或者考虑整个历史的回归 (\(\mathcal{L}(\mathbf{S}_{t})=\frac{1}{2}\sum_{i=1}^{d}||\mathbf{S}_{t}\boldsymbol{k}_{i}-\boldsymbol{v}_{t}||^{2}\)),如 Mesa 层 (von 2024) 中那样——类似于最小均方算法和递归最小二乘算法之间的区别
    • 然而,这些更具表达力的变体引入了非线性循环,并且需要变通方法,例如在处理完整个分块后才执行非线性更新(如在 TTT 和 Titans 中);或者近似非线性循环方法,如 (2024; 2024; 2025)

Hybrid models

  • 在这项工作中,论文探索了在层间交错混合注意力层,这在诸如 MiniMax-01 (2025) 和 Hybrid Mamba2-Attention (2024) 中常用
  • 研究在单个层内混合线性/softmax 注意力也很有趣 (2022a; 2024; 2025)

NLP——EfficientAttention-Survey(THU-2025)

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:(Attention Survey of THU 2025)Efficient Attention Mechanisms for Large Language Models: A Survey, THU, 20250807
      • 有豆包基金的支持

Paper Summary

  • 整体总结:
    • 本文是一篇高效 Attention 的 Survey
    • 整体涉及到了 Linear Attention 和 Sparse Attention
  • 问题提出:
    • Transformer-based 架构已成为 LLM 的主流 backbone
    • 自注意力机制存在的二次时间和内存复杂度,是实现高效长上下文建模的根本障碍
  • 当前解法:
    • 为解决这一限制,近期研究引入了两大类高效注意力机制
    • 机制一:Linear Attention 方法:
      • 通过核近似、循环公式或快速权重动态来实现线性复杂度,从而以更低计算开销实现可扩展推理
    • 机制二:稀疏注意力 (Sparse Attention) 技术:
      • 基于固定模式、分块路由或聚类策略,将注意力计算限制在选定的 Token 子集上,在保持上下文覆盖的同时提升效率
  • 本综述主要工作如下:
    • 第一:对上述机制的发展进行了系统而全面的概述,结合了算法创新和硬件
    • 第二:分析了高效注意力在大规模预训练语言模型中的整合方式
      • 包括完全基于高效注意力构建的架构,以及结合局部和全局组件的混合设计

综述引言和一些讨论

  • Transformer-based 架构 (2017) 已成为现代 LLM 事实上的 backbone 选择
    • 标准的自注意力机制仍然是一个显著的计算瓶颈,其时间和内存复杂度相对于输入序列长度呈二次方增长
    • 这一限制对 LLM 在处理日益增长的长上下文时,同时保持高性能和高效率,提出了巨大挑战
  • 为解决此问题,出现了两个主要方向来降低 softmax 注意力 (softmax Attention) 的时间和空间复杂度
    • 第一种机制是Linear Attention (2020; 2020; 2021; 2021; 2023; 2024):
      • 通过将 softmax 注意力重新参数化或近似为线性操作来降低注意力复杂度
    • 第二种候选方案是稀疏注意力 (Sparse Attention) (2019; 2021; 2021; 2022; 2024):
      • 基于固定或动态的稀疏模式,将注意力计算限制在完整键空间的某个子集上
    • 虽然这两种方法都旨在提高效率,但它们在公式、设计选择和硬件影响方面存在显著差异
  • 本综述全面回顾了高效注意力 (Efficient Attention) 机制的最新进展,同时关注算法原理和系统级实现
    • 在此基础上,论文还研究了采用这些高效注意力的预训练 LLM
  • 论文将线性注意力方法分为三大范式
    • 范式一:核化线性注意力 (kernelized linear attention):
      • 通过特征空间内的内积来近似 softmax 核,借助随机特征映射 (2020; 2021) 或固定正映射 (2020) 实现线性复杂度
    • 范式二:带有遗忘机制 (forgetting mechanism) 的循环线性注意力 (recurrent linear attention):
      • 引入了位置感知的循环,通过数据无关 (2021) 或数据相关 (2023; 2023) 的衰减机制来控制过去信息随时间衰减的方式,从而实现对长序列的建模
    • 范式三:基于快速权重 (fast-weight) 和元学习 (meta-learning) 的公式:
      • 将线性注意力重新解释为在线优化的记忆更新过程,其中如 DeltaNet (2021; 2024) 和 TTT (2024) 等模型将快速学习动态直接整合到状态演化中
  • 作者还探讨了线性注意力在硬件友好的表示形式(包括并行、循环和分块形式)重点分析了它们在计算复杂度、内存占用以及与训练或推理工作流兼容性方面的权衡
  • 论文将稀疏注意力分为固定模式稀疏度 (fixed-pattern sparsity)、分块稀疏度 (block sparsity) 和基于聚类的稀疏度 (clustering-based sparsity)
    • 固定模式稀疏度采用静态的 Token-level 掩码,如滑动窗口、扩张位置或指定的全局 Token ,提供了简单性和硬件友好性 (2020; 2019; 2023; 2023)
    • 分块稀疏度在分块粒度上选择或路由注意力,可以通过启发式评分 (2021; 2022; 2025)、可训练的选通机制 (2024; 2024) 来实现,从而实现结构化的内存访问和高效的 GPU 利用率
    • 基于聚类的稀疏度使用基于内容或位置感知的分组方法(如 k-means 或 LSH)来组织键值对,以降低内存开销并促进基于语义感知的检索 (2023; 2020; 2024)
  • 最后,本综述还讨论了将稀疏模式扩展到编码器风格模型的双向稀疏设计
    • 这些方法在稀疏粒度、选择机制以及与 FlashAttention (2023) 等硬件原语的契合度上有所不同,共同代表了现代 Transformer 中实现高效长上下文建模的基础
  • 近期已有文章将高效注意力机制整合到工业级预训练语言模型中
    • 这包括纯高效架构:如线性注意力和状态空间模型,以及结合了局部和全局注意力模式的混合设计
      • 像 EAGLE (2024)、Falcon Mamba (2024) 和 MiniCPM4 (2025) 这样的模型展示了纯线性或稀疏方法在数十亿参数规模上的可扩展性,在提供强性能的同时实现了恒定时间推理
    • 同时,混合模型 (2020; 2022; 2024; 2025; 2025; 2025) 交错使用密集、稀疏和局部注意力,以平衡计算效率和上下文建模能力
      • 反映了现代 LLM 中朝着组合化、硬件感知注意力设计发展的趋势
  • 论文的目标是为理解注意力机制在算法和硬件双重约束下的演进,以及这些设计如何被整合到可扩展的 LLM 架构中,提供一个统一的框架
    • 通过将理论洞见与实际实现相结合,作者希望本综述能为致力于高效且可部署模型设计的研究者和从业者提供有价值的参考
  • 论文将讨论安排如下:
    • 第 2 节:介绍Linear Attention,涵盖其在不同模型世代中的演变、相关的设计原理以及对硬件实现的影响
    • 第 3 节:介绍稀疏注意力 (Sparse Attention),对稀疏模式进行分类,分析部署场景,并提供实用的系统级设计建议
    • 第 4 节:回顾整合了高效注意力机制的预训练语言模型 (Pretrained Language Models),包括统一高效架构和整合了局部、稀疏、密集注意力的混合模型
    • 第 5 节:对未来的发展方向进行展望 (Outlook),讨论算法和硬件对齐研究中的开放挑战和潜在进展

Linear Attention

  • 传统的线性注意力方法旨在以与序列长度成线性关系的方式近似基于 softmax 的注意力机制
    • 核心思想是用一种基于核函数(kernel)的注意力权重近似来替代计算代价高昂的 softmax 计算
  • 在标准的自注意力中,每个输出是值 \(V\) 的加权和,权重由查询-键相似度经过 softmax 得到:
    $$
    \text{Attn}(Q,K,V)=\text{softmax}(QK^{\top})V, \tag{1}
    $$
    • \(Q,K,V \in \mathbb{R}^{L \times d}\)
    • (\(L\) 是序列长度
    • \(d\) 是每个头的模型维度)
  • softmax 为查询 \(q_i\) 和键 \(k_j\) 产生权重 \(\propto \exp(q_i^{\top}k_j)\)
  • 核化线性注意力则寻找一个特征映射 \(\phi(\cdot)\),使得 softmax 核函数可以在诱导出的特征空间中通过一个简单的点积来近似:\(\exp(q^{\top}k) \approx \phi(q)^{\top} \phi(k)\)(2019)
    • 给定这样的 \(\phi\),可以将注意力重写为:
      $$
      O = \frac{\phi(Q)(\phi(K)^{\top}V)}{\phi(Q)(\phi(K)^{\top}\mathbf{I})} \tag{2}
      $$
  • 由于 \(\exp(\cdot)\) 的值域是非负的,\(\phi(\cdot)\) 通常被选择为产生非负输出,同时应用归一化除数来模拟 softmax 概率
  • 这种重新表述将复杂度从 \(O(L^2 d)\) 降低到 \(O(L d^2)\)(或者在适当的特征降维下甚至可以达到 \(O(L d)\)),因为昂贵的 \(L \times L\) 注意力矩阵从未显式形成
  • Linear Transformer (2020)
    • 用一个固定的正特征映射替换了 softmax 核函数
    • 在实践中,他们设置 \(\phi(x) = \operatorname{ELU}(x) + 1\)
    • \(\operatorname{ELU}(\cdot)\) 在整个定义域内可微,与朴素的 \(\operatorname{ReLU}(\cdot)\) 函数相比表现出更好的性能
  • Performer (2020)
    • 引入了 FAVOR+(一种能够无偏估计 softmax 核函数的随机特征方案)
    • 它对随机特征映射 \(\phi\) 进行采样,使得 \(E[\phi(Q)\phi(K)^{\top}] = \exp(QK^{\top})\)
    • 这产生了一个仅需 \(O(N)\) 操作即可证明是完整 softmax 注意力的无偏估计量
    • 特别地,Performer 使用正交互随机特征,降低了近似中的方差
  • 随机特征注意力 (Random Feature Attention,RFA) (2021)
    • 一种通过为 softmax 核函数使用随机傅里叶特征构建的线性注意力
    • 与 Performer 类似,RFA 利用随机映射和三角激活来近似 softmax
    • RFA 在随机投影之前进一步对查询和键进行归一化以减少方差
    • RFA 还有一个变体 RFA-Gate,它增加了一个可选的选通机制以引入近因偏差
  • cosFormer (2022)
    • 使用余弦函数来近似 softmax
    • 由于 \(\cos(a+b) = \cos a \cos b - \sin a \sin b\),cosFormer 将余弦重加权注意力 \(S_{ij} = Q_i’ K_j’ \cos(\frac{\pi}{2} \times \frac{i-j}{M})\) 分解为线性注意力的形式
  • HedgeDog (2024)
    • 利用了一个尖峰核函数 \(\phi(x) = \exp(Wx + b)\),因为他们观察到 Transformer 和线性 Transformer 之间的性能差距源于缺乏尖峰和单调性属性
    • HedgeDog 展示了更好的注意力熵和单调性

Linear Attention with Forgetting Mechanism

  • 最近的一系列工作通过循环神经网络或连续状态空间模型的视角来解释注意力
  • 传统的线性注意力通常是无位置感知的,其中循环顺序对输出没有影响,但现代的线性注意力表现得更像具有状态追踪和隐藏记忆的 RNN
  • 因此,这些模型明确地结合了循环、门控或状态动力学,以线性复杂度处理长序列
  • 衰减因子是引入遗忘机制的最重要因素
Data-Independent Decay
  • Retentive Networks(RetNet) (2023)
    • RetNet 引入了一种 Retention 机制,它使用固定的衰减系数,以一种循环风格的更新来取代注意力
    • 在 RetNet 层中,每个时间步 \(t\) 维护一个状态向量 \(s_t\),该向量以指数遗忘的方式聚合过去的输入
    • 循环可以写为:
      $$
      S_t = \gamma S_{t-1} + k_t^{\top} \nu_t \tag{3}
      $$
      • \(\gamma \in (0,1)\) 是一个学习的衰减因子(每个保持头)
      • \(k_t^{\top} \nu_t\) 是来自当前 Token 的新贡献(\(\nu_t\) 是 \(x_t\) 的值投影,\(k_t\) 是键投影)
    • 然后通过一个线性“查询”投影获得输出:\(o_t = q_t s_t\);展开方程 3 得到保持的显式公式:
      $$
      o_t = q_t S_t = \sum_{n=1}^{t} \gamma^{t-n} q_t k_t^{\top} \nu_t \tag{4}
      $$
    • 这表明 Token \(n\) 的贡献在时间步 \(t\) 时以因子 \(\gamma^{t-n}\) 呈指数衰减
    • Crucially,\(\gamma\) 是一个数据无关的衰减,是层的固定参数(在多头保持中通常每个头一个),而不是输入内容的函数
      • 这使得 RetNet 能够像 RNN 一样进行 \(O(1)\) 的内存更新,同时仍然允许通过等效的矩阵公式在训练期间进行并行计算(例如,可以证明方程 3 等价于一个“保持矩阵”形式,\(\text{Retention}(X) = (Q K^{\top} \odot D) V\),其中对于 \(t \geq n\),\(D_{t,n} = \gamma^{t-n}\) 实现了衰减和因果掩码)
    • RetNet 的保持机制与其他数据无关的循环模型有共同的主题
  • Eagle (2024)
    • 通过外积记忆改进了 RWKV 设计,这等效于线性注意力
    • 在 RWKV 系列中,衰减因子参数化为 \(\gamma = \exp(-\exp(w))\),其中 \(w\) 是一个数据无关的可学习因子
    • 在实践中,RetNet 和 Eagle 都使用固定衰减来遗忘旧信息,实现了线性推理扩展和具有竞争力的性能
    • 经验上,RetNet 每个头使用一个固定标量 \(\gamma\)(通常每层有多个具有不同 \(\gamma\) 值的保持头,形成一种多尺度衰减),而 Eagle 使用可学习的标量 \(w\) 来参数化衰减因子
  • Lightning Attention (2023; 2024)
    • Lightning Attention 也提出了一种线性注意力层 ,每个头增加了一个固定标量衰减,以实现长度不变的计算速度
    • 在 Lightning Attention 中,对于某个常数 \(\lambda\)(\(\lambda\) 由模型学习或设置),隐藏状态本质上是 \(s_t = \lambda s_{t-1} + k_t^{\top} \nu_t\),这与 RetNet 的 \(\gamma\) 精神相同,但针对硬件效率进行了优化
  • H3 (2022)
    • H3 将循环状态空间模型 (2021) 引入线性注意力,通过 SSM 为键值外积隐藏状态使用学习的、数据无关的指数衰减
    • 线性注意力通过分块计算实现了高效的训练,但 H3 需要为 SSM 计算显式状态扩展,从而限制了头维度,导致表达能力有限
  • In summary,数据无关的衰减方法维持一个随时间以预定速率衰减的持久状态,实现了 \(O(1)\) 循环和每步恒定内存
    • 它们牺牲了一些适应性 ,这促使了在更近期的模型中引入数据相关机制
Data-Dependent Decay
  • 虽然固定衰减提供了简单性和速度,但它们可能未能充分利用输入流中的信息
  • 门控或数据相关的方法使遗忘因子本身成为当前输入的学习函数
  • 这种循环更新的一般形式是:
    $$
    S_t = G_t S_{t-1} + k_t^{\top} \nu_t \tag{5}
    $$
    • \(S_{t-1}\) 是前一个状态
    • \(G_t\) 是由 Token \(x_t\) 确定的门控张量
      • 如果 \(G_t\) 在某个分量上接近 0,则该分量中的过去状态在时间 \(t\) 被大量遗忘;
      • 如果 \(G_t \approx 1\),则过去状态被保留
    • 与 RetNet 中的常数 \(\gamma\) 不同,这里的 \(G_t\) 通过 \(x_t\) 随 \(t\) 变化
  • 在大语言模型设计中,这种策略的两个显著例子是 Mamba (2023) 和门控线性注意力 (Gated Linear Attention, GLA) (2023)
  • Mamba
    • 一种循环状态空间模型,它赋予状态衰减率以输入依赖性
    • 在每个 Mamba 层中,基本的状态演化类似于 S4 (2021),但状态矩阵实际上变得动态
    • \(G_t\) 是一个范围在 0 到 1 之间的分组向量,作为动态遗忘门
      • 这弥合了注意力和纯 SSM 之间的差距
    • 实证结果表明,Mamba2 在语言建模任务上可以超越相似甚至更大规模的 Transformer,凸显了数据相关衰减在长序列建模中的能力
  • GLA
    • 直接在线性注意力中引入了门控机制,将一个门控函数嵌入到线性化的注意力层中以提高其表达能力
    • GLA 通过一个可学习的 Element-wise 遗忘门 \(G_t\) 来修改保持循环
  • 除此之外,其他几个模型同样赋予了其循环以内容相关的门控
  • xLSTM (2024)
    • 用线性门控信号(带归一化)的指数变换取代了标准的 sigmoid 遗忘门,对其单元状态产生了平滑的、输入条件化的衰减
  • GateLoop (2023)
    • 在保持机制上应用了头级门控,实现了一个简单但有效的数据相关衰减,同时保持了高效的硬件实现
  • HGRN (2023)
    • 在线性 RNN 中引入了门控循环
    • HGRN2 (2024) 进一步在 HGRN 框架中增加了状态扩展
    • 状态扩展等效于线性注意力中的键值外积
  • Finch (2024)
    • 在 Eagle 上使用了数据相关的门控
    • 由于 Eagle 与其他正交修改的保持机制相似,Finch 也显示出与上述模型的深厚联系
  • In summary,数据相关衰减模型通过基于内容的门控来增强线性注意力或 RNN 风格的架构,这些门控控制着信息的流动
    • 论文中的结果表明,这些模型在语言任务上通常能够匹配或超越 Transformer 的性能 ,同时能够扩展到非常长的输入

Linear Attention as In-Context Learners

  • 除了线性注意力机制带来的效率提升外,一项重大进展在于它们应用于增强上下文学习能力
    • 这指的是模型能够从给定的 Prompt 中快速适应或学习,而无需对其预训练权重进行显式的梯度更新
    • 理解:这里是说注意力机制本身是上下文学习期,即增加一些 Prompt Token,就可以借助修改注意力内容来实现对模型的快速学习
  • 大型 Transformer 模型通过将 Prompt 解释为一种训练数据的形式,固有地表现出上下文学习
    • 但最近的创新已将快速学习规则直接整合到注意力机制中,有效地将序列处理视为一个在线训练过程
    • 问题:这里的方法具体是什么?
  • FWP (2021) 建立了现有线性注意力机制与快速权重编程器之间的形式等价性
    • 在 FWP 范式中,一个慢速神经网络学习去编程另一个神经网络的“快速权重”,通常通过自发明的键和值模式的外积加法来实现
  • 表 1:
    • 不同线性注意力变体的更新规则
    • 每个模型都是关于矩阵记忆 \(S_t\) 的循环
Learning Objective
  • 从元学习的角度来看,这些模型定义了一个在推理过程中优化的隐式学习目标
  • 用 \(q_t, k_t, \nu_t\) 表示时间步 t 的查询、键和值,上下文记忆 \(S_t\) 通过以下目标进行优化:
    $$
    \mathcal{L}_t(S) = \frac{1}{2} || f_S(k_t) - \nu_t ||^2 \tag{6}
    $$
  • DeltaNet
    • DeltaNet 融合了经典的 delta 规则 (2021),其中 \(f_S(k_t) = S k_t\)
    • DeltaNet 更新规则为
      $$ S_t = S_{t-1} + \eta_t (\nu_t - S_{t-1} k_t) k_t^{\top} $$
      • 可以通过最小化当前记忆检索 \(S_{t-1} k_t\) 与新值 \(\nu_t\) 之间的误差推导出来
      • 这标志着朝着在线学习键值映射、基于即时上下文改进记忆迈出了一步
  • TTT (2024)
    • TTT 用不同的建模架构概括了元学习目标:
      $$
      f_S(k_t) = \begin{cases}
      \text{LM}(S k_t) + k_t, & \text{TTT-Linear} \\
      \text{LM}(\text{MLP}_S(k_t)) + k_t, & \text{TTT-MLP}
      \end{cases} \tag{7}
      $$
    • 上下文网络 \(f_S\) 增强了上下文元学习的能力
    • 但由于 \(f_S\) 的梯度比简单的线性投影复杂得多,在线更新不能写成一个简单的规则
  • 批量更新 (Batch Update)
    • 批量更新试图解决当 \(f_S\) 作为神经网络工作时训练并行性的困难
    • 通常,上下文记忆是以批大小为 1 进行元学习的,这对于一般的 TTT 模型来说不可行
    • 相反,类似于分块并行,TTT 将整个块视为一个批次
    • 在批次内没有状态更新(即 \(S\) 保持不变)
    • 处理完批次后,\(S\) 使用来自批次中所有样本的聚合梯度或更新信号进行一次更新
    • 这种策略在保持并行效率的同时,适应了更复杂架构的训练要求
  • Momentum Titans (2025)
    • Titans 中引入了优化中常用的动量,以加强记忆更新机制的能力:
      $$
      \mathcal{M}_t = (1 - \alpha_t) \mathcal{M}_{t-1} + S_t \\
      o_t = q_t \mathcal{M}_t \tag{8}
      $$
    • 动量项允许记忆通过对状态 \(S\) 进行指数移动平均来逐渐积累信息
    • 这可以看作是元学习的一种形式,其中更新规则本身学会了在长序列上更加稳定和鲁棒
  • 权重衰减 (Weight Decay)
    • 权重衰减是训练中的另一种正则化技术,对应于线性注意力模型中的遗忘机制
    • Gated DeltaNet (2024) 和 Titans 在其记忆更新中使用了权重衰减,作为学习的遗忘门来限制非常旧或嘈杂数据的影响
    • 它对应于在 RetNet (2023) 和 Mamba (2023) 等架构中发现的选择性状态保持机制,其中衰减机制被证明对语言建模性能至关重要:
      $$
      S_n = \gamma_n S_{t-1} + \eta_t (v_t - S_{t-1} k_t) k_t^{\top} \tag{9}
      $$
  • In summary,线性注意力机制中的这些进步通过明确地将元学习原则整合到其架构中,正在推动上下文学习的边界
    • 通过快速权重更新、复杂的内存管理技术和在线学习规则,这些模型正朝着一种范式发展,在这种范式中,训练和推理之间的区别变得越来越模糊,从而产生了能够直接从上下文中学习和利用知识的更高效、适应性更强的大语言模型

Discussion on Other Designs

Element-wise Linear Attention
  • 无需注意力的 Transformer (Attention-Free Transformer) (2021)
    • 利用一个简单的权重 \(\exp(K_{\nu’} + w_{t,\nu’})\) 来代替 \(\exp(QK^{\top})\):
      $$
      O_t = \sigma_q(Q_t) \odot \frac{\sum_{\nu’=1}^{t} \exp(K_{\nu’} + w_{t,\nu’}) \odot V_{\nu’} }{\sum_{\nu’=1}^{t} \exp(K_{\nu’} + w_{t,\nu’})} \tag{10}
      $$
    • 其中 \(w_{t,\nu’}\) 是学习的成对位置偏置
    • 在 AFT 的变体中,AFT-Simple 移除了 \(w_{t,t’}\),实现了线性化的推理模式
    • 由于 \(K\) 和 \(V\) 的乘积是 Element-wise 的,循环状态大小是 \(\mathbb{R}^d\) 而不是外积状态 \(\mathbb{R}^{d \times d}\)
  • RWKV (2023)
    • 在 AFT-Simple 上利用了衰减机制。具体来说,RWKV 通过指数衰减 \(w_{t,i} = -(t-i)w\) 改进了 AFT 的位置偏置
    • 指数形式在引入位置偏置的同时保留了循环属性
  • Element-wise 线性注意力带来了强大的推理优势,但它受到状态大小瓶颈的影响,性能低于基于矩阵的状态大小
    • Besides,尽管 Element-wise 内存比外积内存快得多,但由于其他组件在拥有外积内存的情况下占据了超过 95% 的延迟 (2023),端到端的优势仍然有限
Multi-Pass Linear Attention
  • 带有限内存控制的注意力 (Attention with Bounded-memory Control)
    • 将线性注意力视为一个有界内存模块:
      $$
      \begin{split}
      \tilde{K}_n &= \sum_{i=1}^{n} K_i \otimes \phi_i, \quad \tilde{V}_n = \sum_{i=1}^{n} V_i \otimes \phi_i \\
      O_n &= \text{softmax}(Q_n \tilde{K}_n^{\top}) \tilde{V}_n
      \end{split} \tag{11}
      $$
    • 其中 \(\tilde{K}_n, \tilde{V}_n\) 是在线更新的、大小受限的键和值。在实现中,ABC 可以简化为两遍线性注意力
  • 门控 Slot 注意力 (Gated Slot Attention) (2024)
    • 进一步将 GLA 引入到 ABC 框架 (2021) 中
    • 由于 \(\tilde{K}_n, \tilde{V}_n\) 作为一个隐式的线性注意力工作,GSA 将更新改进为门控形式:
      $$
      \tilde{K}_n = \text{Diag}(\alpha_n) \tilde{K}_{n-1} + (1 - \alpha_n) \otimes K_n, \quad \tilde{V}_n = \text{Diag}(\alpha_n) \tilde{V}_{n-1} + (1 - \alpha_n) \otimes V_n \tag{12}
      $$
  • Multi-Pass 是增强线性注意力表达能力的一种有效方式,但它也带来了额外的计算开销,这使得架构设计需要在训练效率和性能之间进行权衡
Bidirectional Linear Attention
  • 双向注意力在编码器风格的架构(如 BERT (2019))中扮演着重要角色
  • 单向和双向注意力在线性表述中的关键区别在于推理瓶颈和计算模式
  • Decoder-only 模型通常表现出 \(O(N^2)\) 的复杂度,且 Decoder-only 模型中的每个 Token 都可以访问全局信息
    • 因此双向线性注意力通常维护一个恒定长度的全局 Token 池以降低复杂度,同时保留 softmax 函数的使用
  • 例如
    • Linformer (2020) 通过一个额外的矩阵投影将键和值的数量减少到一个恒定长度
    • Luna (2021) 通过跨模型层编码全局 Token 池进一步扩展了 Linformer 设计
  • 虽然双向线性注意力对 Decoder-only 架构有效,但这些设计在应用于因果设置时面临重大挑战,因为基于全局池的方法往往计算成本高昂
    • 因此,此类架构不太适合大语言模型

Hardware Implementation

  • 并行表示 (The Parallel Representation)
    • 可以将带门控衰减的因果线性注意力定义为:
      $$
      \begin{split}
      Q &= \phi(XW_Q), \quad K = \phi(XW_K), \quad V = XW_V, \quad \gamma = f_\gamma(X) \\
      D_{nm} &= \begin{cases}
      \prod_{i=m+1}^{n} \gamma_i, & n \geq m \\
      0, & n < m
      \end{cases}, \quad O(X) = \text{LN}((Q K^{\top} \odot D) V)
      \end{split} \tag{13}
      $$
    • 其中
      • \(W_Q, W_K, W_V \in \mathbb{R}^{d \times d}\)
      • \(f_\gamma\) 控制衰减的锐度
      • 矩阵 \(D \in \mathbb{R}^{N \times N}\) 编码了具有衰减模式的因果掩码,确保信息的单向流动
    • 当衰减是数据无关时,\(f_\gamma(\cdot) = \text{const} \in (0,1]\)
      • 注意,线性注意力后的分组归一化 (GroupNorm) (2018) 已经是一个强制性的组件 (2023),方程 2 中核化线性注意力的显式除数变得不必要
    • 并行表示简单易懂,但有两个缺点
      • 第一:并行形式仍然保持了与 softmax 注意力相同的 \(O(N^2)\) 复杂度
      • 第二:当表示第 2.3 节中的 ICL 风格线性注意力时,其复杂性会增加
  • 循环表示 (The Recurrent Representation)
    • 上述并行公式可以等效地表达为逐步解码的循环形式,如图 2(b) 所示,在每个时间步 \(n\),输出计算为:
      $$
      S_n = f_{\text{update} }(S_{n-1}, K_n, V_n), \quad O_n = Q_n S_n \\
      f_{\text{update} }(S_{n-1}, K_n, V_n) = \begin{cases}
      \gamma_n S_{n-1} + K_n^{\top} V_n, & \text{Linear Attention with Decay} \\
      \gamma_n S_{n-1} + \eta_n (V_n - S_{n-1} K_n) K_n^{\top}, & \text{ICL-style Linear Attention}
      \end{cases} \tag{14}
      $$
    • 这种循环表示通过维护单个状态向量 \(S_n\),实现了具有恒定内存的高效自回归生成
    • 虽然循环表示将计算复杂度从 \(O(N^2)\) 降低到 \(O(N)\),但它在训练期间会产生大量的内存开销
      • 因为 \(S_n\) 涉及存储 \(K_n\) 和 \(V_n\) 的外积,这对于长序列来说是极其昂贵的
      • 因此,循环形式通常仅限于解码阶段
  • 分块循环表示 (The Chunkwise Recurrent Representation)
    • 分块表示结合了线性复杂度和硬件友好的并行性的优势 (2022; 2023)
    • 如图 2(a) 所示,以衰减风格线性注意力为例,给定一个块大小 \(B\),令 \(x_{[i]}\) 表示第 \(i\) 块。定义块内的累积衰减为:
      $$
      \beta_{(i-1)B+j} = \prod_{k=(i-1)B+1}^{(i-1)B+j} \gamma_k, \quad D_{[i]}(j,k) = \begin{cases}
      \frac{\beta_{(i-1)B+k} }{\beta_{(i-1)B+j} }, & j \leq k \\
      0, & \text{Other wise}
      \end{cases} \tag{15}
      $$
    • 块级记忆状态 \(R_i\) 计算为:
      $$
      R_i = K_{[i]}^{\top} (V_{[i]} \odot \frac{\beta_{iB} }{\beta_{[i]} }) + \beta_{iB} R_{i-1} \tag{16}
      $$
    • 第 \(i\) 块的输出为:
      $$
      O_{[i]} = (Q_{[i]} K_{[i]}^{\top} \odot D_{[i]}) V_{[i]} + (Q_{[i]} R_{i-1}) \odot \beta_{[i]} \tag{17}
      $$
    • 这种表述提供了循环和并行的统一视图:
      • 第一项捕获块内依赖关系
      • 第二项通过单个矩阵-向量乘积传播块间记忆
      • 由于其效率和可并行性,分块表示通常在训练和预填充阶段被采用
  • 对于 ICL 风格的线性注意力,已经使用 Householder 变换 (2024) 开发了硬件友好的分块表示
    • 但对于更复杂的变体,如 TTT 和 Titans,构建显式的分块形式仍然具有挑战性
    • 相反,这些架构通常依赖于大的批次大小进行记忆更新,通过固定超参数有效地模拟分块计算
  • 内核级优化对于实现高性能至关重要
    • 广泛采用的 FLA (2024) 为许多常见的线性注意力模块提供了基于 Triton 的实现
    • 或者,开发者还提供了用 CUDA 或 TileLang (2024) 编写的自定义实现,可以用于进一步加速
  • 图 2:线性注意力的双重形式

Sparse Attention

  • 稀疏注意力(Sparse Attention)方法利用注意力计算中固有的稀疏特性,通过以下公式近似完整注意力计算:
    $$
    \text{Attn}(Q,K,V) = \text{softmax}(QK_{[\mathcal{S}]}^T)V_{[\mathcal{S}]} \tag{18}
    $$
    • 其中 \(\mathcal{S}(t)\) 是查询向量 \(Q(t)\) 所关注索引的子集
    • 不同的方法基于选择准确性和硬件效率的考量,设计了不同的选择标准 \(\mathcal{S}(t)\)
    • 目标是在预填充(prefilling)阶段实现亚线性或线性复杂度,或在解码(decoding)阶段控制固定的计算预算

Fixed-pattern Sparse Attention (固定模式稀疏注意力)

  • 一些研究工作利用 Token-level 别稀疏性的结构化模式,为注意力计算构建固定的稀疏掩码
  • Local Window Attention (局部窗口注意力)
    • 局部窗口注意力将每个查询限制在仅与固定滑动窗口 \(w\) 内的相邻 Token 交互,从而在保留局部上下文的同时降低内存和计算需求
    • Sparse Transformer (2019) 首先应用局部窗口(行)注意力,其中 \(w\) 接近 \(\sqrt{N}\),然后通过额外的列注意力来总结先前位置并在全局传播信息
    • GPT-3 (2020) 也采用了与 Sparse Transformer 类似的稀疏注意力模式
    • StreamingLLM (2023) 发现大量的注意力得分被分配给了输入序列的初始 Token ,他们称之为“注意力汇集点(attention sink)”。他们提出了一种简单的固定模式注意力,只保留汇集点 Token 和滑动窗口内的 Token 。例如,给定一个长度为 \(n\) 的输入序列,StreamingLLM 中查询 Token \(q_t\) 的选定 Token 子集 \(\mathcal{S}(t)\) 被定义为:
      $$
      \mathcal{S}(t) = \left\{ j \mid 0 \leq j \leq s \ \lor \ t-w \leq j \leq t \right\}, \ \forall t \in [1, n] \tag{19}
      $$
      • 其中 \(s\) 是汇集点 Token 的大小,\(w\) 是滑动窗口的大小
      • 为了获得更好的硬件效率,采用块粒度的 StreamingLLM (2024) 以块方式保留汇集点 Token 和局部 Token ,从而实现高效的内存加载和计算
  • Dilated Attention (膨胀注意力)
    • LongNet (2023)
      • 引入了膨胀注意力作为长上下文训练和推理的固定稀疏模式。膨胀注意力随着距离的增长呈指数级扩展注意力范围,从而将注意力的复杂度从 \(O(n^2)\) 降低到 \(O(n)\)。具体来说,在沿序列维度将输入分割成长度为 \(w\) 的片段后,从每个片段中以间隔 \(r\) 选择膨胀稀疏索引。第 \(i\) 个片段的选择索引为:
        $$
        \hat{I}_i = \left[iw, iw+r, iw+2r, …, (i+1)w-1 \right] \tag{20}
        $$
      • 稀疏化的片段 \(Q_{\hat{I}_i}, K_{\hat{I}_i}, V_{\hat{I}_i},\ i \in \{0, 1, …, \frac{n}{w}\}\) 被并行输入到注意力计算中,得到注意力输出 \(O\)。结合不同片段大小和膨胀率 \(\{r_i, w_i\}^{k}\) 的注意力输出,最终注意力计算如下:
        $$
        O = \sum_{i=1}^{K} \alpha_i O|_{r_i, w_i}, \quad \alpha_i = \frac{s_i}{\sum_{j} s_j} \tag{21}
        $$
      • 其中 \(s_i\) 表示 \(O|_{r_i, w_i}\) 的注意力 softmax 的分母
  • LogSparse (2019)
    • 采用指数稀疏注意力方案,其中每个位置仅关注 \(\log N\) 个 Token ,这可以看作是指数膨胀注意力的一个实例

Block Sparse Attention (块稀疏注意力)

  • 给定一个长度为 \(n\)、块大小为 \(b\) 的输入序列,可以将 \(Q, K, V \in \mathcal{R}^{n \times d}\) 各自划分为 \(\frac{n}{b}\) 个块,每个块大小为 \(b \times d\)
  • 目标是近似一个块级掩码 \(M \in \{0, 1\}^{n/b \times n/b}\),用于选择关键块进行计算,如图 3 所示
    $$
    \text{Attn}(Q, K, V)_i = \sum_{j=1}^{n/b} M_{ij} \cdot \text{softmax}(Q_i K_j^T) V_j \tag{22}
    $$
  • 块级选择对于在现代 GPU 上实现高效计算至关重要
  • 图 3: 块稀疏注意力:长序列被分成若干块,每个 Token 仅关注其局部窗口和 top-k 相关块
Block-Sparse Attention for Prefill (用于预填充的块稀疏注意力)
  • 使用块稀疏注意力进行预填充的方法,近似选择覆盖大部分注意力得分且具有高召回率的 Top-K 块,从而将注意力计算复杂度从 \(O(n^2)\) 降低到 \(O(K)\)
    $$
    S = \text{softmax}(QK^T - c(1 - M)) \\
    \min \ |S(M) - S_{\text{dense} }|
    $$
    • 其中 \(M\) 是如上定义的块级稀疏掩码,\(c\) 是一个大常数(例如 1e5),确保重要性较低的注意力权重在 softmax 计算后趋近于零
    • 块稀疏注意力的目标是在最小开销下实现更大的加速,同时尽可能保留更多的注意力权重
  • MInference (2024)
    • 观察到注意力权重中存在三种模式:
      • 流式(A 型)模式(Streaming (A Shape) Pattern)
      • 垂直斜线模式(Vertical-Slash Pattern)
      • 块稀疏模式(Block-Sparse Pattern)
    • 它离线确定每个注意力头的最佳模式,并在推理过程中基于指定的模式动态构建稀疏索引
  • FlexPrefill (2025)
    & 提出了一种上下文感知的稀疏注意力机制,能够实时动态调整注意力模式和计算预算
  • XAttention (2025)
    • 提出了一个块稀疏注意力框架,利用反对角线(antidiagonal)评分来预测注意力块的重要性,从而能够高效识别并剪除非必要块,实现高稀疏性和显著的计算增益
  • SpargeAttn (2025)
    • 同样在预填充阶段采用块级稀疏注意力,通过一个双阶段在线过滤过程完成:
      • 第一阶段快速预测注意力图以跳过某些矩阵乘法
      • 第二阶段应用 softmax 感知的过滤器以进一步消除不必要的计算
Block-Sparse Attention for Decode (用于解码的块稀疏注意力)
  • 用于解码的块稀疏注意力方法动态选择包含每个解码步骤最关键的 Token 的子集 \(S\) 的 \(K, V\) 向量,从而减少内存加载并提高效率
  • Quest (2024) 通过计算注意力权重的上界来近似每个块的关键性。对于块 \(K_i\),论文通过以下公式维护 Element-wise 的 Min 和 Max Key \(m_i\) 和 \(M_i\):
    $$
    m_{i,d} = \min(K_{i,d}), \quad M_{i,d} = \max(K_{i,d})
    $$
    • 其中 \(\min(\cdot)\) 和 \(\max(\cdot)\) 在每个维度 \(d\) 上 Element-wise 应用
  • 给定查询 \(q\),块 \(i\) 的近似注意力得分由下式给出:
    $$
    \text{score}_i = \sum_{j=1}^{d} \max(q_j \times M_{i,j}, q_j \times m_{i,j})
    $$
  • 然后选择得分最高的 Top-K 块作为注意力计算的稀疏子集 \(S\):
    $$
    S = \text{argtopk}(\text{score}, k)
    $$
  • DoubleSparsity (2024) 通过降低计算 \(QK^T\) 乘积的矩阵乘法维度来高效近似关键 Token
    • 它首先离线计算 \(QK^T\) 中的离群通道,记为 \(C\)
    • 然后选择具有最高近似注意力得分 \(\hat{s}\) 的 Top-K Token 作为稀疏子集 \(S\):
      $$
      Q_{\text{label} } = Q_{\{C\} }, \quad \hat{s} = Q_{\text{label} } K_{\text{label} }^T, \quad S = \text{argtopk}(\hat{s}, k)
      $$
  • ReSA (2025)
    • 结合了无需训练的块稀疏估计和 GQA 共享,有助于提高效率
    • ReSA 还提出了一个修正阶段来控制 KV 缓存累积误差
    • ReSA 在长序列生成任务上显示出优势
Routing-based Block-Sparse Attention (基于路由的块稀疏注意力)
  • 基于路由的块稀疏注意力通过可训练的 MLP 层学习每个 Token 块的重要性,该层在推理期间充当门控网络以选择关键块
  • Learnable Sparsity on Pretrained Models (预训练模型上的可学习稀疏性)
  • SeerAttention (2024)
    • 通过自蒸馏(self-distillation)的方式在预训练的 LLM 上训练门控网络
    • 为了获得每个块的重要性分数,它首先沿序列维度对 \(Q\) 和 \(K\) 进行池化,记为 \(P_q\) 和 \(P_k\)。下采样后的 \(Q, K\) 然后通过一个可学习的线性层 \(W_q\) 和 \(W_k\)
    • 投影后的 \(W_q P_q(Q)\) 和 \(W_k P_k(K)\) 的矩阵相乘结果通过 softmax 运算符作为门控过程:
      $$
      \text{score} = \text{softmax}((W_q P_q(Q)) \cdot (W_k P_k(K)))
      $$
    • 可学习的线性层通过自蒸馏的方式训练,以与原始 LLM 的 2D 最大池化结果对齐。蒸馏损失计算如下:
      $$
      gt = \text{MaxPool2D}(\text{softmax}(QK^T)), \quad loss = D_{KL}(gt \ || \ \text{score})
      $$
    • 在推理过程中,门控分数通过 Top-K 或阈值化来预测块级稀疏性,用于稀疏计算和效率提升
  • Training-aware Sparse Attention (训练感知的稀疏注意力)
    • Landmark (2023)
      • 提出使用特殊的“地标”(landmark) Token 来表示每个块,并训练注意力机制通过这些地标 Token 直接检索 Top-K 块
      • 然而,它没有在大型预训练模型上进行实验
    • MoBA (2025)
      • 将可训练的稀疏注意力集成到预训练阶段
      • 它提出了 Mixture of Block Attention,应用来自 MoE(专家混合)的 Top-K 机制作为门控机制,为每个查询 Token 决定关键块
      • 每个块的重要性分数通过查询 Token \(q\) 和块 \(K_i\) 沿 Token 维度的平均池化结果的内积计算:
        $$
        s_i = \langle q, P_{\text{mean} }(K_i) \rangle
        $$
      • 然后选择具有最高 \(s\) 分数的 Top-K 块用于计算 \(q\) 的注意力
      • Notably,MoBA 使用的 Top-K 块选择是不可微分的
        • 因此,在预训练阶段,稀疏模式仍然以无需训练的模式进行估计,从而实现高效的推理和加速的训练
    • NSA (2025)
      • 引入了一种训练感知的多粒度稀疏注意力机制,包含三个分支 \(C \in \{\text{cmp}, \text{slc}, \text{win}\}\),分别对应压缩(compression)、选择(selection)和滑动窗口(sliding window)策略
      • NSA 利用一个可微分的压缩分支来学习块选择分数。结合三个分支,NSA 的注意力输出由下式给出:
        $$
        o = \sum_{c \in \mathcal{C} } g^c \cdot \text{Attn}(q, K^c, V^c), \quad g_c \in [0, 1]
        $$
      • 对于压缩分支 \(c = \text{cmp}\),块 \(i\) 的键 \(K_i \in \mathcal{R}^{d_i \times b}\) 通过一个可学习的 MLP 层 \(\varphi\) 被压缩为单个键 \(K_i^{\text{cmp} } \in \mathcal{R}^{d_i \times 1}\)
      • 对于选择分支 \(c = \text{slc}\),基于块重要性分数 \(p\) 选择 Top-K 块,该分数可以直接从压缩分支获得
    • InfLLM-v2 (2025)
      • 采用了与 MoBA 类似的训练感知 Top-K 块稀疏注意力机制
      • 为了提高 Top-K 块选择的准确性,它将块划分为具有重叠的小粒度内核,并在每个块内对内核重要性分数进行聚合
System-level Design Choices (系统级设计选择)
  • 训练感知的稀疏注意力 (2025; 2025; 2025) 开始考虑内核实现和高效执行
  • 为了实现块稀疏注意力的高效实现,FlashAttention (2023) 被用于在高效的平铺机制中进行注意力计算,这为更好地利用硬件资源带来了要求和机会,包括:
    • 为避免内存访问不一致,在 SeerAttention (2024) 和 MInference (2024) 中,块大小 \(b\) 通常设置为至少 64 的相对较大值
    • 为了与 GPU 张量核心上分组矩阵乘法指令(Grouped Matrix Multiplication)的最低要求对齐,在 NSA (2025) 和 InfLLM-v2 (2025) 中,一个查询组内的 K、V 头数被设置为至少 16
    • 为了减少内存访问,NSA (2025) 和 InfLLM-v2 (2025) 强制查询组之间共享选定的块,这是通过在查询组内对块级重要性分数进行池化来完成的

Clustering Attention (聚类注意力)

  • 与块稀疏注意力类似,聚类注意力旨在为解码选择最关键地 Token ,但将 Token 组织在数据结构中以获得更好的语义属性或采样效率
  • RetrievalAttention (2024)
    • 采用近似最近邻搜索(Approximate Nearest Neighbor Search, ANNS)来选择关键的 K 个聚类
    • 为了解决注意力机制中查询向量和键向量之间的分布外(out-of-distribution)性质带来的挑战,它引入了一种适应查询向量分布的注意力感知向量搜索算法
  • ClusterKV (2024)
    • 在语义聚类的粒度上选择 Token ,克服了诸如 Quest 等页面级检索方法内部碎片化的问题
    • 在预填充阶段之后, Token 通过 K-means 算法进行聚类
    • Token \(i\) 和 \(j\) 之间的语义相似性通过键向量的余弦相似性度量:\(\mathcal{D}(i, j) = 1 - \frac{\langle k_i, k_j \rangle}{|k_i| \cdot |k_j|}\)
    • 语义聚类由其质心 \(\mu_1, \mu_2, …, \mu_C \in \mathcal{R}^d\) 表示
    • 在每个解码步骤,基于查询 Token \(q\) 和质心 \(\mu_i\) 的注意力权重(即 \(q \mu_i^T\))的排名来选择聚类
  • MagicPIG (2024)
    • 利用局部敏感哈希(Locality Sensitive Hashing, LSH)采样来高效近似注意力计算
    • 它使用 LSH 将相似的查询向量和键向量映射到相同的哈希桶中,并将存储和部分计算卸载到 CPU,以解决 KV 缓存的瓶颈
    • 它还引入了 Oracle Top-K 采样作为比暴力 Top-K 更好的策略

Bidirectional Sparse Attention (双向稀疏注意力)

  • 双向稀疏注意力建立在编码器风格的架构上,使用静态模式或块级稀疏性来加速注意力计算
  • 块稀疏性在双向稀疏注意力中被广泛使用
    • BigBird (2020)
      • 使用块级随机注意力,作为缩短 Token 之间间接路径的桥梁
    • Longformer (2020)
      • 使用静态的全局-局部混合注意力
      • Longformer 也依赖于块级稀疏性,并带有额外的全局和随机链接,以促进结构化计算和内存高效的并行性
  • 基于聚类的方法也用于双向稀疏注意力
    • Reformer (2020)
      • 使用局部敏感哈希(LSH)将相似的 Token 分配到同一个桶中
    • Routing Transformer (2020)
      • 在每一层执行在线 \(k\)-means 聚类
    • ClusterFormer (2022)
      • 引入了一个与下游目标共同训练的可微分聚类模块
      • 这些方法通过分组相关 Token 来减少计算,同时通过学习的适应性来保持性能

采用高效注意力机制的预训练大语言模型

采用统一高效注意力机制的预训练模型

  • 尽管早期的线性注意力探索通常局限于小规模模型,但近期的进展已证明其能够成功扩展到数十亿参数的规模,使其成为标准 Transformer 的一种可行且高效的替代方案
  • 这些模型完全基于线性注意力或其架构等价物,如状态空间模型和循环神经网络,即使在大规模下也能保持其标志性的推理效率
  • 基于 RWKV 的模型
    • RWKV 项目代表了一项持续且有影响力的努力,旨在创建一个可扩展的循环神经网络架构,它结合了 Transformer 的可并行化训练与传统 RNN 的高效推理 (2023)
    • 例如,EAGLE 系列引入了矩阵值状态以增加容量,而后续迭代如 Finch (RWKV-6) (2024) 和 Goose (RWKV-7) (2025) 则结合了动态循环和更具表达力的状态演化机制,以支持更复杂、数据依赖的状态转换
  • 基于 Mamba 的模型
    • Mamba 架构的成功及其数据依赖的选择机制,已引发了主要研究实验室的一波采用和扩展浪潮
    • Falcon Mamba (2024)
      • 基于纯 Mamba 架构,在一系列通用语言基准测试中展现出与领先 Transformer 模型相竞争的性能,验证了该架构在此类任务上的可行性,同时保持了其标志性的恒定时间推理
    • Codestral Mamba (2024)
      • 基于 Mamba-2 架构,进一步证明了该范式的潜力
      • 虽然专门用于代码生成,但它在相关基准测试中取得了最先进的结果,并支持 256K 个 Token 的上下文长度,展示了 SSM 方法在复杂结构化领域内的可扩展性和有效性
  • 基于稀疏注意力的模型
    • MiniCPM-4 (2025) 引入了一种两阶段稀疏注意力机制,根据语义相似性为每个查询 Token 动态选择相关的键值块
    • MiniCPM-4 利用 InfLLM-v2(一种块稀疏注意力变体)来替代标准注意力机制
    • 此外,一种轻量级的 LogSumExp 近似实现了高效的 top-k 选择,使得该方法能够扩展到极长序列
    • 这些技术共同使 MiniCPM-4 能够在细粒度上下文感知能力与可控的内存和计算需求之间取得平衡,使其成为长上下文建模的有力候选者

采用混合高效注意力机制的预训练模型

  • 随着对高效长上下文建模和多样化计算范式需求的增长,最近的研究广泛探索了混合注意力机制
  • 此类策略结合了全局和局部注意力组件,通常交错使用专门设计的层,以平衡计算成本和性能
  • 候选模型架构示意图如图 4 所示
  • 稀疏混合模型
    • GPT-3 (2020) 通过交错使用稠密注意力和局部带状稀疏注意力层,集成了一种混合注意力机制,其灵感来自 Sparse Transformer (2019)
    • 稠密注意力提供全上下文建模,而稀疏层则采用固定或跨步模式来减少被关注的 Token 数量
    • 这种设计使得 GPT-3 能够在 2048 个 Token 的固定上下文窗口内高效扩展到大规模模型,平衡了建模能力和计算效率
  • 线性-全注意力混合模型
    • Jamba (2024) 和 MiniMax-01 (2025) 结合了线性注意力和全注意力层,以在吞吐量和表达力之间实现高效的权衡
    • MiniMax-01 在大多数层中使用 Lightning Attention,并每八层插入一次基于 Softmax 的全注意力
    • Jamba 采用了相似的比例,在每八层的 Mamba 块中插入一个 Transformer 层
    • 两者都通过限制计算密集的全注意力的使用,实现了更快的解码和改善的长序列性能
  • 局部-全注意力混合模型
    • Gemma 3 (2025)、Command A (2025) 和 LLaMA-4-Maverick (2025) 在局部和全局注意力层之间交替使用,其共享的设计理念是稀疏地使用全局层(例如,每 4-6 层一次)以提升效率
    • 虽然局部层采用滑动窗口模式,但关键区别在于位置编码策略
    • Gemma 3 调制 RoPE 的基础频率——为局部层分配 10K,为全局层分配 1M(以更好地捕获长距离依赖关系)
    • Command A 和 LLaMA-4-Maverick 混合了基于 RoPE 的局部层与完全省略位置嵌入的全注意力层,从而实现了更强的长序列性能
  • 先进混合模型
    • Character.AI (2024) 将滑动窗口的局部注意力与每六层应用一次的稀疏全局注意力层交错使用
    • 特别是,它们在多个非相邻层之间复用全局注意力层的键值表示
    • 这种 KV 共享机制能以减少的内存和延迟开销实现高效的长上下文处理
  • YOCO (2024) 和 Phi-4-mini-flash (2025) 采用了一种双解码器架构,将预填充阶段和生成阶段分离开来
    • 自解码器在预填充和生成中都使用 RetNet 和滑动窗口注意力等线性注意力机制,而交叉解码器仅在生成期间激活
    • 整个过程中使用单层全局 KV 缓存,实现了线性时间的预填充和高效的解码,同时 GPU 内存消耗最小
  • In summary,这些最新进展强调了混合注意力机制以在不同计算约束和序列长度下实现平衡性能的趋势。每种架构都独特地贡献了关于如何有效结合局部细节管理与全局上下文整合的见解,从而为未来注意力机制的发展提供了有价值的框架

Outlook

  • 本综述全面概述了高效注意力机制,重点关注其算法基础、实际实现以及在大规模预训练语言模型中的集成
  • 通过将线性和稀疏注意力分类为明确定义的范式,论文识别了实现可扩展性、计算效率和长上下文能力的关键设计原则
  • 论文还分析了这些机制在最先进模型中如何部署,无论是作为独立架构还是作为平衡局部和全局计算的混合设计的一部分
  • 展望未来,论文强调几个预计将塑造该领域未来研究的关键方向:
  • 对混合模型的架构理解
    • 虽然先前关于线性注意力的工作主要集中于独立的线性架构,但混合模型通常通过结合现成的线性注意力模块与稠密或局部组件来构建
      • 但更强的线性主干是否直接转化为改进的混合性能尚不清楚
    • 未来的工作应将混合模型作为一个独特的架构类别来研究,试图理解它们的组成、相互作用效应和优化动态
  • 无损稀疏注意力与扩展上下文
    • 稀疏注意力仍然受到精度和计算增益之间权衡的挑战。完全训练的稀疏模型通常性能不如稠密模型,而训练后稀疏近似则由于缺乏端到端训练而面临限制
    • 一个主要的研究前沿在于开发能够保持稠密注意力的表达力和精度,同时扩展到更长上下文的稀疏注意力机制
    • Besides,稀疏预算与上下文长度之间的关系尚不明确,固定的 top-k 方案在序列更长时可能会退化,这需要更具适应性的策略
  • 对稀疏和混合注意力的机制性洞见
    • 实证研究反复证明,混合注意力模型可以用更少的注意力计算匹配甚至超越稠密模型,但其有效性的根本原因仍未得到充分探索
    • 此外,研究在合成基准测试中表现良好的稀疏模式是否适用于现实世界任务,以及描述基于稀疏性的泛化极限,尤为重要
  • 随着基于注意力的模型不断发展,论文预计架构创新、理论洞见和硬件感知设计之间将进一步融合
  • 作者希望本综述能为未来高效、高性能语言建模系统的研究奠定坚实的基础
1…515253…61
Joe Zhou

Joe Zhou

Stay Hungry. Stay Foolish.

608 posts
49 tags
GitHub E-Mail
© 2026 Joe Zhou
Powered by Hexo
|
Theme — NexT.Gemini v5.1.4