NLP——GShard

注:本文包含 AI 辅助创作

Paper Summary

  • 评价:
    • 本文是 2020 年的文章,是非常重要的 Scaling 方向的作品,对模型规模的扩展意义重大
    • 核心思路:
      • 条件计算(Conditional Computation):采用了稀疏门控混合专家(Sparsely-Gated MoE)架构
      • 自动分片(Automatic Sharding):通过轻量级标注 API(replicate/split/shard)和 XLA 编译器扩展,实现张量的自动分布式划分
      • SPMD 编译优化:提出单程序多数据(SPMD)编译范式(替代传统多程序多数据(MPMD)模式),为所有设备生成单一程序,编译时间与设备数量无关
    • Google 出品,必属精品
  • GShard 是一个由一组轻量级标注接口(API)和 XLA 编译器扩展组成的模块
    • Gshard 提供了一种简洁的方式,只需对现有模型代码进行少量修改,就能实现多种并行计算模式
    • Gshard 解决了神经网络扩展过程中的 计算成本、编程便捷性以及在并行设备上的高效实现等问题
  • 作者使用 GShard 将带有稀疏门控混合专家层的多语言神经机器翻译 Transformer 模型扩展到了 600B 参数以上(应该是当时第一次有这么大的模型)
  • 论文的实验表明:
    • 这样一个巨型模型能够在 2048 个 TPU v3 加速器上高效训练 4 天,在 100 种语言到英语的翻译任务中,取得了远超现有技术水平的翻译质量

Introduction and Discussion

  • 在各类机器学习问题中,神经网络扩展都带来了显著的质量提升
    • 在计算机视觉领域,提升模型容量使得多种计算机视觉架构在图像分类和检测精度上有了更好的表现
    • 在自然语言处理领域,Transformer 模型的扩展在语言理解任务(2018, 2019)、跨语言下游迁移任务(2018, 2019)以及(大规模)多语言神经机器翻译任务(2019)中都实现了稳定的性能提升
  • 这种普遍趋势促使近期研究开始深入探究影响扩展成功的关键因素(2017, 2019, 2020),其中包括过往研究发现的训练数据量、模型大小和计算资源利用率等
  • 尽管研究发现最终模型质量与数据量、计算资源和模型大小呈幂律关系(2017, 2020),但更大模型带来的显著质量提升也伴随着各种实际挑战
  • 训练效率便是其中至关重要的挑战之一,论文将其定义为:为获得优于现有最佳系统的模型质量所消耗的计算资源训练时间 ,而这一指标往往被忽视
  • 图1:
    • 多语言翻译质量(与双语基线模型相比的平均 \(\delta\) BLEU 值)随着混合专家(MoE)模型规模增长到 600B 参数而不断提升,而端到端训练成本(以 TPU v3 核心年为单位)仅呈亚线性增长
    • 将模型规模从 37.5B 参数扩大到 600B 参数(16 倍),计算成本仅从 6 个核心年增加到 22 个核心年(3.6倍)
    • 达到最佳翻译质量的 600B 参数模型,使用 2048 个 TPU v3 核心训练了 4 天,总成本为 22 个 TPU v3 核心年
    • 相比之下,训练所有 100 个双语基线模型则需要 29 个 TPU v3 核心年
    • 论文训练的最佳质量密集型单一 Transformer 模型(2.3B 参数), \(\delta\) BLEU 值为 6.1,该模型使用 GPipe(2019)在 2048 个 TPU v3 核心上训练了 6 周,总成本达 235.5 个 TPU v3 核心年

Practical Challenges for Scaling(扩展面临的实际挑战)

  • 本节将列举在训练超大规模模型时面临的主要实际挑战,这类模型的规模远超单个加速器内存(如 GPU 或 TPU)的容量限制
  • 特定架构的模型并行支持问题 :在 TensorFlow(2016)和 PyTorch(2017)等常用深度学习框架中,缺乏对高效模型并行算法的支持
    • 框架虽支持基于图分割的简单模型并行,但由于网络的顺序依赖性和基于梯度的优化方式,这种并行方式会导致设备利用率严重不足
    • 为了高效扩展现有模型,用户通常需要投入大量工程工作,例如将模型代码迁移到专门的框架中(2018, 2019)
  • 计算成本与模型规模的超线性扩展问题 :通过增加网络深度或宽度来直接扩展模型规模(2018, 2019),通常会导致训练步长时间至少呈线性增长
    • 为解决这一问题,通常需要通过在多个设备间分割层权重和计算过程来实现模型并行,但这会带来网络通信开销和设备利用率不足的问题
    • 设备利用率不足源于神经网络底层计算任务分配不均衡以及存在顺序依赖关系
    • 这种计算成本与模型规模之间的超线性关系,无法通过简单增加设备数量来解决,使得训练超大规模模型变得不切实际
  • 巨型模型表示的基础设施扩展性问题 :对于分布在数千个设备上的超大规模模型,简单的图表示可能会成为深度学习框架及其优化编译器的瓶颈
    • 例如,通过操作间分割(inter-op partitioning)增加 D 倍的层数,或通过操作内分割(intra-op partitioning)在 D 个设备上增加模型维度,都可能导致图节点数量大幅增加
    • 设备间的通信通道可能会进一步使图大小增加(例如,分割聚合或转置操作时)
    • 对于超大规模模型而言,图大小的这种增长会导致图构建和编译时间达到无法实现的程度
  • 实现分片策略的复杂工作 :将模型高效地分割到多个设备上运行具有挑战性,因为这需要协调设备间的通信
    • 在图级分割方面,需要复杂的算法(2019, 2018)来减少因不同设备上分配的图分割部分之间存在顺序依赖而引入的开销
    • 在算子级并行方面,不同分割后的算子会有不同的通信模式,这取决于算子的语义,例如是否需要累积部分结果或重新排列数据分片
    • 根据论文的经验,由于 TensorFlow 等框架包含大量具有特殊语义的算子,在模型中手动处理这些问题需要耗费大量精力
    • 在所有情况下,实现模型分片对研究人员和工程师来说都是一项负担,因为改变模型架构就需要修改底层的设备通信逻辑,从而产生连锁反应

Design Principles for Efficient Training at Scale

  • 论文展示了如何通过构建一个拥有 600B 参数、带有稀疏门控混合专家层(Sparsely-Gated Mixture-of-Experts layers)的 sequence-to-sequence Transformer 模型,来克服上述挑战
    • 该模型的计算成本呈亚线性增长,编译时间为 \(O(1)\)
    • 论文在 2048 个 TPU v3 设备上,针对多语言机器翻译任务训练该模型 4 天,最终实现了单个非集成模型在 100 种语言到英语翻译任务中远超现有技术的翻译质量
    • 论文对不同规模的模型进行了实验,结果发现,随着模型规模增大,翻译质量不断提升,而训练总耗时(wall-time)相对于模型规模仅呈亚线性增长,如图1所示
  • 为构建如此庞大的模型,论文做出了以下关键设计选择
    • 亚线性扩展(Sub-linear Scaling) :首先,模型架构的设计应确保计算和通信需求相对于模型容量呈亚线性增长
      • 条件计算(2015, 2019, 2020, 2020)通过在每个输入样本的基础上激活一个子网络,使论文能够兼顾训练和推理效率
      • 通过在基于循环神经网络(RNN)的机器翻译和语言模型中添加位置感知稀疏门控混合专家层(Position-wise Sparsely Gated Mixture-of-Experts, MoE)(2019),可以在实现最先进性能的同时,使计算成本呈亚线性增长
      • 因此,论文将在第2节中详细介绍如何用混合专家层扩展 Transformer 架构
    • 抽象的力量(The Power of Abstraction) :其次,模型描述应与分片实现和优化相分离
      • 这种关注点分离使模型开发人员能够专注于网络架构,并灵活地改变分片策略,而底层系统则负责执行语义保持转换并实现高效的并行执行
      • 为此,论文提出了 GShard 模块,用户只需在模型中对少数关键张量标注分片策略即可
      • 该模块包含一组简单的标注接口,以及一个用于自动并行化的 XLA(2019)编译器扩展
      • 模型开发人员可以像在一个拥有超大内存和计算能力的单一设备上编写模型,编译器会根据标注信息和自身的启发式算法,自动为目标设备分割计算任务
      • 论文将在3.2节中提供更多标注示例
    • 可扩展编译器(Scalable Compilers) :第三,包括计算表示和编译在内的系统基础设施,必须能够支持数千个设备的并行执行
      • 例如,图2 展示了在 4 个设备上分割点积运算(用颜色编码)的两种不同方式
      • 需要注意的是,图2a 中常用的 MPMD(多程序多数据,Multiple Program Multiple Data)方法在扩展性方面面临挑战
        • 因为图中的节点数量会随着设备数量的增加而线性增长
      • 相反,论文开发了一种用于 SPMD(单程序多数据,Single Program Multiple Data)转换的编译器技术
        • 该技术生成一个可在所有设备上运行的单一程序,使得编译时间与设备数量无关,保持恒定,如图2b所示
        • 论文将在3.3节中详细讨论 SPMD 框架
  • 论文其余部分的结构如下:
    • 第2节详细描述带有稀疏门控混合专家层的 Transformer 架构;
    • 第3节介绍论文开发的 GShard 模块;
    • 第4节展示混合专家模型在100个语言对的多语言机器翻译任务中的应用;
    • 第5节对实现的性能和内存使用情况进行评估;
    • 第6节讨论相关工作
  • 图2:
    • 在4个设备上对点积算子( \([M, K] \times [K, N] = [M, N]\) )进行 MPMD 分片与论文提出的 SPMD 分片的对比
    • 在该示例中,两个操作数均沿收缩维度 K 进行分片,每个设备计算本地结果后,通过 AllReduce 操作进行全局合并
    • MPMD 分片会为每个设备生成独立的算子,限制了其扩展性;
    • 而 SPMD 分片则生成一个可在所有设备上运行的程序
    • 需要注意的是,使用论文提出的 SPMD 分片时,编译时间与所使用的设备数量无关

Model

Sparse scaling of the Transformer architecture

  • Transformer架构(2017)已被广泛应用于自然语言处理领域,成为许多 sequence-to-sequence 任务(如机器翻译)的事实标准
  • Transformer 包含两个计算模块,即 Encoder 和 Decoder,两者均通过堆叠多个 Transformer 层实现
  • Transformer Encoder 层由两个连续的层组成,即自注意力层(self-attention layer)之后紧跟一个位置感知前馈层(position-wise feed-forward layer)
  • Decoder 则额外增加了第三个交叉注意力层(cross-attention layer),该层会关注 Encoder 的输出
  • 论文通过条件计算(conditional computation)对 Transformer 进行稀疏扩展,在 Encoder 和 Decoder 中,每隔一个前馈层就用一个位置感知混合专家层(Position-wise Mixture of Experts, MoE)(2019)替代,该混合专家层采用了一种改进的 top-2 门控机制(图3)
  • 论文通过改变 Transformer 层的数量和每个混合专家层中专家的数量来调整模型容量
  • 每个训练样本由一对子词 Token(subword tokens)序列组成
  • 在训练和推理过程中,每个 Token 都会激活混合专家 Transformer 的一个子网络
  • 子网络的大小大致与每个混合专家层中的专家数量无关,这使得计算成本能够像前一节所述的那样呈亚线性增长
  • 3.1节将进一步分析计算复杂度,5节将分析训练性能

Position-wise Mixture-of-Experts Layer

  • 论文模型中使用的混合专家(MoE)层基于文献(2019)的设计,并在稀疏门控函数和辅助损失函数方面进行了改进
  • Transformer 的混合专家层由 E 个前馈网络(FFN)组成,其计算过程如下:

$$ g_{s,e} = \text{GATE}(x_s) \quad \tag{1} $$

$$ \text{FFN}_e(x_s) = w o_e \cdot \text{ReLU}(w i_e \cdot x_s) \quad \tag{2} $$

$$ y_s = \sum_{e=1}^{E} g_{s,e} \cdot \text{FFN}_e(x_s) $$

  • 其中:
    • \(x_s\) 是混合专家层的输入 Token
    • \(w i_e\) 和 \(w o_e\) 分别是前馈层(即一个专家)的输入投影矩阵和输出投影矩阵
    • 向量 \(g_{s,e}\) 由门控网络计算得到
    • \(g_{s,e}\) 中每个专家对应一个非负值,其中大部分值为零,这意味着该 Token 不会被分配给对应的专家
      • 每个 Token 仅被分配给极少数专家,论文设定每个 Token 最多被分配给两个专家
    • \(g_{s,e}\) 中对应的值非零,表示该专家对最终网络输出的贡献程度
    • 每个专家 \(\text{FFN}_e\) 都采用含 ReLU 激活函数(2010)的两层全连接网络对 \(x_s\) 进行处理
    • 混合专家层的输出 \(y_s\) 是所有被选中专家输出的加权平均值
  • 门控函数 \(\text{GATE}(\cdot)\) 是混合专家层的核心,它采用 softmax 激活函数来表示每个专家在处理输入 Token 时的权重,即表示某个专家处理该输入 Token 的适合程度
  • 此外,门控函数必须满足以下两个目标:
    • 负载均衡(Balanced load) :对于给定的 Token ,混合专家层最好能稀疏地激活专家
      • 一种简单的方案是根据 softmax 概率分布选择前k个专家,但研究表明这种方法会导致训练中的负载不均衡问题(2019):训练过程中看到的大多数 Token 会被分配给少数几个专家,导致这几个(繁忙的)专家的输入缓存变得极大,而其他专家则处于未充分训练的状态,从而减慢训练速度
        • 问题:分配给少数专家也不一定会导致训练速度变慢吧,除非有 EP (专家并行)?
      • 同时,许多其他专家根本无法得到充分训练
      • 因此,门控函数需要设计得更合理,以实现所有专家间处理负载的更均匀分配
    • 大规模下的效率(Efficiency at scale) :如果门控函数采用顺序执行方式,实现负载均衡相对容易
      • 但对于输入批次中所有 N 个 TokenE 个专家 ,仅门控函数的计算成本就至少为 \(O(N \cdot E)\)
      • 在论文的研究中,N 达到数百万量级,E达到数千量级,若门控函数采用顺序实现,会导致大部分计算资源在大部分时间处于空闲状态
      • 因此,论文需要一种高效的并行门控函数实现方式,以充分利用多个设备的计算能力
  • 为满足上述要求,论文在门控函数 \(\text{GATE}(\cdot)\) 中设计了以下机制(细节如算法1所示):
    • 专家容量限制(Expert capacity) :为确保负载均衡,论文强制每个专家处理的 Token 数量不超过某个统一的阈值,论文将该阈值定义为专家容量
      • 假设一个训练批次中的 Token 总数为 N,每个 Token 最多被分配给两个专家,那么专家容量设置为 \(\frac{2N}{E}\)
      • 门控函数 \(\text{GATE}(\cdot)\) 会为每个专家维护一个计数器 \(c_e\) ,用于记录分配给该专家的 Token 数量
      • 当一个 Token 选中的两个专家都已超过其容量限制时,该 Token 被视为溢出 Token ,此时 \(g_{s,e}\) 退化为零向量
      • 这类 Token 的表示会通过残差连接传递到下一层
        • 问题:这里是说专家容量限制也传递到下一层?
    • 本地组分配(Local group dispatching)
      • 门控函数 \(\text{GATE}(\cdot)\) 将训练批次中的所有 Token 均匀划分为 G 个组,即每个组包含 \(S = \frac{N}{G}\) 个 Token ,所有组独立并行处理
      • 每个组被分配到每个专家的部分容量为 \(\frac{2N}{G \cdot E}\)
      • 每个组会确保分配给每个专家的 Token 数量不超过该部分容量
      • 通过这种方式,既能保证专家容量限制得到遵守,又能实现整体负载均衡
    • 辅助损失(Auxiliary loss) :门控函数不应总是选择相同的几个专家,否则会导致仅少数专家出现容量溢出,而其余专家利用率不足(这一点至关重要)
      • 借鉴文献(2019)的方法,论文定义了一个辅助损失项来强制执行这一约束
      • 该辅助损失项以一个常数系数 k 添加到模型的总损失函数中
      • 算法1第13行中辅助损失项的具体形式基于以下考虑:
        • \(\frac{c_e}{S}\) 表示分配给每个专家的输入 Token 比例,作者希望最小化 \(\frac{c_e}{S}\) 的均方值
        • 但由于 \(c_e\) 是通过 top-2 操作得到的,不具有可微性,因此论文使用每个专家的平均门控值作为可微近似,并用 \(m_e \cdot \frac{c_e}{S}\) 替代 \((\frac{c_e}{S})^2\) ,这样就可以通过梯度下降法对其进行优化
    • 随机路由(Random routing)
      • 直观上,由于 \(y_s\) 是所选专家输出的加权平均值,如果第二个专家的权重非常小,我们可以直接忽略该专家 ,以节省整体专家容量
      • 因此,除了遵守专家容量约束外,门控函数 \(\text{GATE}(\cdot)\) 还会以与第二个最佳专家权重 \(g_2\) 成比例的概率,将 Token 分配给该专家

Highly Parallel Implementation using GShard(高度并行化)

  • 本节将介绍第2节中模型在张量处理单元(TPU)设备集群上的高效运行实现方案
  • 实现的第一步是将模型用线性代数运算表示,因为论文的软件栈(TensorFlow(2016))和硬件平台(TPU)在这类运算上经过了高度定制和优化
    • 与原始 Transformer 模型类似,将模型的大部分部分用线性代数运算编写非常容易
    • 但由于混合专家层(MoE Layer),尤其是算法1 中提出的 \(\text{GATE}(\cdot)\) 函数具有顺序执行特性,要用线性代数运算表示该层则需要额外付出努力,论文将在3.1节详细介绍相关细节
  • 接下来,论文对线性代数计算进行标注,以体现并行性
  • 通过3.2节中的分片接口(sharding APIs),可以对计算中的每个张量进行标注,指定其在设备集群中是采用复制(replication)还是分布(distribution)方式存储
  • 分片标注能够实现模型描述与高效并行实现的关注点分离,让用户可以灵活地表达各种并行化策略;例如:
    • (1)注意力层(attention layer)通过沿批次维度(batch dimension)分割并将权重复制到所有设备上来实现并行化;
    • (2)由于混合专家层中专家(expert)的规模极大,无法在所有设备上进行复制,因此唯一可行的策略是将专家分片存储到多个设备中
    • 此外,整个模型会在这两种模式(1)-(2)之间切换
    • 通过标注,模型开发人员无需关注系统优化工作,也不必将并行实现细节和底层细节融入模型代码中
  • 最后,编译器基础设施会接收(部分)标注后的线性代数计算,并生成可在数千个设备上高效运行的并行程序
  • 如3.3节所述,编译器会应用单程序多数据(SPMD,Single Program Multiple Data)分片转换来表示每个设备的计算任务,插入必要的跨设备通信操作,处理非规则模式(如非均匀分片),最终生成一个可在所有设备上启动并执行并行计算的单一程序

Positions-wise Mixture-of-Expert Layer Expressed in Linear Algebra(线性代数)

  • 论文的模型实现(算法2)将整个加速器集群视为一个单一设备,并用少量与集群具体配置无关的张量运算来表示其核心数学算法

    • 爱因斯坦求和符号(Einstein summation notation)(1923)(即tf.einsum)是一种能够简洁表示模型的强大工具,论文在实现中大量使用了该符号
    • softmax 门控计算可以通过一个爱因斯坦求和运算(einsum)后紧跟 softmax 函数来轻松表示;
    • 将输入分配给选定专家的操作,可以通过分配掩码(dispatching mask)与输入之间的单次爱因斯坦求和运算来表示;
    • 所有 \(\text{FFN}_e\) 的权重被组合成单个三维张量 \(wi\) 和 \(wo\)
    • \(\text{FFN}_1 \dots \text{FFN}_E\) 的计算则通过三个算子(两个爱因斯坦求和运算和一个relu运算)来表示;
    • 最后,对所有专家输出进行加权平均以得到最终输出的操作,也可以通过另一个爱因斯坦求和运算来表示
  • 算法2中的 Top2Gating 计算了算法1中所有组本地(group-local)门控决策的并集

    • combine_weights是一个四维张量,形状为[G, S, E, C]
    • 其中,combine_weights[g, s, e, c] 非零时,表示组 g 中的输入 Token s被发送到专家 e 的输入缓存中,且位于缓存位置 c
    • 对于特定的 g 和 s,combine_weight 切片中最多包含两个非零值
    • 通过将所有非零值设为 1,可由 combine_weights 生成二进制分配掩码(binary dispatch_mask)
  • 论文需要合理选择组数(G)和专家数量(E),以确保该算法能够在包含D个设备的集群上实现扩展

    • 有必要分析在一个包含 N 个 Token 的训练批次中,该算法在一个训练步骤内的整体计算复杂度(浮点运算总数)
  • 算法2 :位置感知混合专家层的前向传播(Forward pass of the Positions-wise MoE layer),下划线字母(如G和E)表示张量将沿该维度进行分片

    1
    2
    3
    4
    5
    6
    7
    1: gates = softmax(einsum("GSM, ME -> GSE", inputs, wg))
    2: combine_weights, dispatch_mask = Top2Gating(gates)
    3: dispatched_expert_inputs = einsum("GSEC, GSM -> EGCM", dispatch_mask, reshaped_inputs)
    4: h = einsum("EGCM, EMH -> EGCH", dispatched_expert_inputs, wi)
    5: h = relu(h)
    6: expert_outputs = einsum("EGCH, EHM -> GECM", h, wo)
    7: outputs = einsum("GSEC, GECM -> GSM", combine_weights, expert_outputs)
    • 在分析算法2的计算复杂度随设备数量D的扩展情况时,论文做出以下假设:
      • a)每个设备上的 Token 数量 \(\frac{N}{D}=O(1)\) (在实际应用中,为避免设备内存溢出,这一假设通常是必要的),且保持恒定;
      • b) \(G=O(D)\) 、 \(S=O(1)\) ;
      • c) \(H=O(1)\) ;
      • d) \(E=O(D)\) ;
      • e) \(C=O(\frac{2S}{E})=O(\frac{1}{D})\) ,且 \(D< S\) ,D 为正整数
  • 算法2中的浮点运算总数计算如下:
    $$
    \begin{align}
    \text{FLOPS}_{\text{Softmax} } + \text{FLOPS}_{\text{Top2Gating} } + \text{FLOPS}_{\text{Dispatch/Combine} } + \text{FLOPS}_{\text{FFN} } &= \\
    O(G S M E) + O(G S E) + O(G S M E C) + O(G E C H) &= \\
    O(D \cdot 1 \cdot 1 \cdot D) + O(D \cdot 1 \cdot D) + O(D \cdot 1 \cdot 1 \cdot D \cdot \frac{1}{D}) + O(D \cdot D \cdot \frac{1}{D} \cdot 1) &= \\
    O(D^2) + O(D^2) + O(D) + O(D)
    \end{align
    }
    $$

    • 因此,每个设备的浮点运算量为 \(O(D)\)
    • 每个设备上 softmax 的计算复杂度 \(\text{FLOPS}_{\text{Softmax} } / D = O(D)\) 与设备数量呈线性关系,但实际上,由于 \(D< S\) ,该复杂度会被其他项主导,因此可将整体复杂度视为 \(O(1)\) ,满足亚线性扩展的设计要求
    • 5节通过实验验证了这一分析结果
  • 除计算成本外,跨设备通信成本并非恒定,但如5节所述,当设备数量增加时,通信成本仅以 \(O(\sqrt{D})\) 的温和速率增长

GShard Annotation API for Parallel Execution

  • 由于算法1中张量的规模和计算需求极大,论文必须在多个设备上对该算法进行并行化处理
  • 算法2中带下划线的字母展示了对每个张量进行分片的一种直接方案
  • GShard 中的分片接口允许论文对程序中的张量进行标注,选择性地指定其分片方式
  • 这些信息会传递给编译器,以便编译器自动应用转换以实现并行执行
  • 在论文的研究中,使用了 TensorFlow/Lingvo(2019)中的以下接口
    • replicate(tensor):对张量进行标注,使其在各个分片(partition)中复制,并返回标注后的张量。该接口通常用于模型中的非混合专家层(non-MoE layers),以实现权重复制
    • split(tensor, split_dimension, num_partitions):对张量进行标注,使其沿split_dimension维度进行分片,并返回标注后的张量。第i个分片会被放置在第i个设备上,且num_partitions(分片数量)不得超过系统中的设备数量
    • shard(tensor, device_assignment):是split()接口的泛化形式,支持对多个维度进行分片,并指定每个分片的放置位置。附录A.3对该接口进行了更详细的描述
  • 需要注意的是,调用splitshard接口仅会添加标注,不会改变用户程序中张量的逻辑形状。用户仍可使用完整形状的张量进行操作,无需担心非均匀分片等问题
  • GShard 具有良好的通用性,其简单接口可同样应用于所有维度
  • 根据具体应用场景,分片维度可以包括批次维度(数据并行)、特征维度、专家维度,甚至图像模型中的空间维度
  • 此外,由于分片标注是基于每个张量单独进行的,模型的不同部分可以采用不同的分片方式
  • 这种灵活性使论文能够对巨型混合专家层权重进行分片,并在混合专家层和非混合专家层之间切换分片模式,同时也支持论文未涉及的其他应用场景,例如对大型图像进行空间分片(2019)(附录A.4)
  • 通过上述分片接口,可以将算法2中所示的分片策略表示如下:
    • 输入张量沿第一个维度(组维度G)进行分片,门控权重张量(wg)采用复制方式存储
    • 计算得到分配后的专家输入(dispatched expert inputs)后,应用split接口将分片维度从组维度(G)切换为专家维度(E)。其中,D为设备数量
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      1: # 沿组维度(G)对输入进行分片
      2: inputs = split(inputs, 0, D)
      3: # 复制门控权重
      4: wg = replicate(wg)
      5: gates = softmax(einsum("GSM, ME -> GSE", inputs, wg))
      6: combine_weights, dispatch_mask = Top2Gating(gating_logits)
      7: dispatched_expert_inputs = einsum("GSEC, GSM -> EGCM", dispatch_mask, reshaped_inputs)
      8: # 沿专家维度(E)对分配后的输入进行分片
      9: dispatched_expert_inputs = split(dispatched_expert_inputs, 0, D)
      10: h = einsum("EGCM, EMH -> EGCH", dispatched_expert_inputs, wi)
Per-tensor sharding assignment
  • 如上述示例所示,用户无需对程序中的每个张量都进行标注
  • 通常只需对模型中的少数重要算子(如爱因斯坦求和算子)进行标注,编译器会通过自身的启发式算法推断出其余张量的分片方式
    • 注:由于反向传播计算通常由前端框架自动生成,用户无法访问这些张量,因此编译器推断缺失的分片信息也非常重要
    • 例如,由于输入张量沿组维度(G)进行分片,而权重张量采用复制方式存储,编译器会选择沿相同维度(组维度G)对爱因斯坦求和运算的输出进行分片(第5行)
    • 类似地,由于输入分配的爱因斯坦求和运算(第7行)的两个输入均沿组维度(G)进行分片,因此输出的分片方式会被推断为沿组维度(G)进行分片,之后论文通过添加split标注将输出的分片维度切换为专家维度(E)
    • 上述示例中的某些标注(如replicate(wg))也可由编译器自动确定,但建议对计算的初始输入张量和最终输出张量进行标注
  • 目前,编译器采用迭代数据流分析(iterative data-flow analysis),从用户标注的算子开始,将分片信息传播到其相邻算子(操作数和使用者)
    • 该分析通过对齐相邻算子的分片决策,尽量减少重分片(resharding)的需求
    • 虽然还可以采用整数规划或机器学习等其他方法,但改进自动分片分配并非论文的重点,论文将其留作未来的研究工作
Mixing manual and automatic sharding(手动分片与自动分片结合)
  • 在常见情况下,通过分片标注实现自动分片通常已足够,但GShard也支持灵活地将手动分片算子与自动分片算子结合使用
  • 这让用户能够更好地控制算子的分片方式,例如当用户掌握算子语义之外的运行时知识时
  • 例如,XLA 和 TensorFlow 的 Gather 算子定义均未包含输入中不同范围的索引边界信息,但用户可能知道某个特定的 Gather 算子仅在每个分片内对数据进行重排
  • 在这种情况下,用户只需缩小维度大小并执行本地 Gather 操作,即可轻松对该算子进行分片;
  • 否则,编译器需要对索引范围采取保守处理,并添加不必要的通信开销
  • 例如,算法2中使用独热矩阵(one-hot matrix)分配输入的爱因斯坦求和算子(第3行),也可以通过手动分片的 Gather 算子来实现,而模型的其余部分仍采用自动分片方式
  • 以下伪代码展示了该应用场景
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    1: # 输入形状为[G, S, M],split()不改变逻辑形状
    2: input = split(input, 0, num_devices)
    3: # s_indices形状为[E, G, C, 1],值为输入中S维度的索引
    4: s_indices = split(s_indices, 1, num_devices)
    5:
    6: # 开始手动分片
    7: # partitioned_input形状为[G/num_devices, S, M]
    8: partitioned_input = auto_to_manual_spmd_partition(input)
    9: # partitioned_s_indices形状为[E, G/num_devices, C, 1]
    10: partitioned_s_indices = auto_to_manual_spmd_partition(s_indices)
    11: # 在partitioned_input中拼接G维度的索引:在G维度上生成Iota张量
    12: partitioned_gs_indices = concat(
    13: iota([E, G/num_devices, C, 1], 1),
    14: partitioned_s_indices,
    15: 3
    16: )
    17: # partitioned_data形状为[E, G/num_devices, C, M]
    18: partitioned_data = gather(partitioned_input, partitioned_gs_indices)
    19:
    20: # 切换回自动分片
    21: # data形状为[E, G, C, M]
    22: data = manual_to_auto_spmd_partition(partitioned_data)
    23: ...

The XLA SPMD Partitioner for GShard

  • 本节将介绍基于分片标注自动对计算图进行分片的编译器基础设施
  • 分片标注告知编译器每个张量应如何在设备间分布
  • SPMD(单程序多数据,Single Program Multiple Data)分片器(简称“分片器”)是编译器的一个组件,它将计算图转换为可在所有设备上并行执行的单一程序
    • 注:另一种方案是MPMD(多程序多数据,Multiple Program Multiple Data),如图2所示,该方案的扩展性较差
  • 这使得编译时间几乎与分片数量无关,从而支持论文将模型扩展到数千个分片⁴
  • 论文在 XLA 编译器(2019)中实现了该分片器
  • TensorFlow、JAX、PyTorch 和 Julia 等多个前端框架已具备将其图表示转换为 XLA HLO 图的降级逻辑(lowering logic)
  • 与 TensorFlow 等流行前端框架相比,XLA 的算子数量要少得多,这在不影响通用性的前提下降低了实现分片器的负担,因为现有前端的降级过程已实现了强大的表达能力
    • 尽管论文在 XLA 中开发了该基础设施,但论文介绍的技术也可应用于其他机器学习框架的中间表示(如 ONNX(2019)、TVM Relay(2018)、Glow IR(2018))
  • XLA 将计算表示为数据流图(dataflow graph),其中节点表示算子,边表示在算子之间流动的张量
  • 分片器的核心是对每个算子进行处理,根据输入和输出指定的分片方式,将全尺寸算子转换为分片尺寸算子
  • 对计算进行分片时,会引入各种跨设备数据传输模式
  • 为了在大规模场景下最大化性能,定义一组核心通信原语(communication primitive)并针对目标平台对其进行优化至关重要
Communication Primitives,通信原语
  • 由于分片器强制所有设备运行相同的程序,通信模式也具有规律性,XLA 定义了一组集合算子(collective operator)来执行类似 MPI(2009)的通信操作
  • 下面列出了论文在 SPMD 分片器中使用的常见通信原语
    • CollectivePermute :该算子指定一组源-目标设备对(source-destination pairs),源设备的输入数据会发送到对应的目标设备
      • 它主要用于两个场景:改变分片张量在各分片间的设备顺序,以及本节后续讨论的 Halo Exchange
    • AllGather :该算子按照指定顺序将所有参与设备的张量连接起来,用于将分片张量转换为复制张量
    • AllReduce :该算子对所有参与设备的输入执行按元素归约(如求和)操作,用于合并来自不同分片的部分归约中间张量
      • 在 TPU 设备网络中,当分片数量增加时,AllReduce 的成本保持恒定(5.2节)
      • 它也是一种在其他类型网络拓扑中具有高效实现的常用原语(2019)
    • AllToAll :该算子从逻辑上沿某个维度对每个参与设备的输入进行分割,然后将每个数据片段发送到不同的参与设备
      • 每个设备在接收来自其他设备的数据片段后,会将这些片段连接起来生成结果
      • AllToAll 用于将分片张量从一个维度重分片到另一个维度
      • 在 TPU 设备网络中,AllToAll 是实现此类重分片的高效方式,其成本随分片数量的增加呈亚线性增长(5.2节)
Per-Operator SPMD Partitioning
  • 分片器的核心是根据指定的分片方式,对每个算子进行从全尺寸到分片尺寸的转换
  • 虽然某些算子(如按元素算子)的分片支持非常简单,但论文将讨论需要跨分片通信的几种常见情况
  • 一般情况下,存在一些重要的技术挑战,论文将在3.3.3节中介绍
  • 为了使讨论与混合专家模型更相关,本节将重点关注爱因斯坦求和算子(Einsum)的分片,以展示几种通信模式
  • 为简化讨论,此处假设所有张量均采用均匀分片方式,即待分片维度的大小是分片数量的整数倍
Einsum Case Study
  • 爱因斯坦求和算子是实现混合专家模型的最关键算子
  • 在 XLA HLO 中,它们被表示为点积(Dot)运算,其中每个操作数(左操作数 LHS 或右操作数 RHS)包含三种类型的维度:
    • 批次维度(Batch dimensions) :是易并行化维度(embarrassingly parallel dimensions)。所有左操作数、右操作数和输出都必须包含相同的批次维度集合,且输出中的每个元素仅依赖于左操作数和右操作数中对应的批次元素
    • 收缩维度(Contracting dimensions) :仅存在于操作数中。左操作数和右操作数必须包含相同的收缩维度集合,这些维度在输出中会被求和并压缩
    • 非收缩维度(Non-contracting dimensions) :也是并行维度,存在于某个操作数和输出中。左操作数和右操作数分别具有自己的非收缩维度集合,这些维度会被输出继承
  • 分片传播(Sharding propagation)优先选择在左操作数、右操作数和输出的批次维度上采用相同的分片方式,因为这样可以避免任何跨分片通信
  • 但在实际情况中,这并不总是可行的,以下三种情况需要跨分片通信:
    • 1)重分片(Resharding) :在论文构建的混合专家模型中,专家分配逻辑(算法2第3行)需要在爱因斯坦求和运算后切换分片维度。由于使用AllToAll可以高效实现重分片(5.2节),论文首先在本地执行爱因斯坦求和运算,然后将结果重分片到目标维度,如图4a所示
    • 2)累积部分结果(Accumulating partial results) :如果输入沿收缩维度进行分片,本地计算得到的结果只是部分结果,需要使用AllReduce将这些部分结果合并以生成最终结果,如图4b所示
    • 3)循环切片(Slicing in a loop) :在某些场景下,论文还实现了一种类似于Cannon算法(1969)的方法,以限制每个分片上张量的大小。例如,如果两个操作数均沿非收缩维度进行分片,由于操作数的非收缩维度不同,无法直接在本地执行爱因斯坦求和运算。复制其中一个操作数不会导致冗余计算,但需要确保被复制的操作数能够放入设备内存
      • 因此,如果操作数的规模过大,论文会保持两个操作数的分片状态,通过循环迭代计算结果的每个切片,并使用 CollectivePermute 来传输输入切片(图4c)
  • 图4:带有跨设备通信的爱因斯坦求和算子分片示例:
    • (a)分片爱因斯坦求和算子
      • 彩色字母(G和E)表示每个张量的分片维度
      • 分片器决定首先沿G维度执行批次并行爱因斯坦求和运算,然后将结果重分片到E维度。(注:图中省略了S和M维度)
    • (b)在收缩维度上分片的简单爱因斯坦求和算子(矩阵乘法)
      • 每个分片计算部分结果,然后通过AllReduce合并得到完整结果
    • (c)在循环中使用CollectivePermute的爱因斯坦求和算子(矩阵乘法)
      • 通过循环每次计算一个切片,整个过程中不会出现全尺寸张量
Supporting a Complete Set of Operators
  • 为了使 SPMD 分片器能够支持完整的算子集合,且不对张量形状或算子配置施加额外限制,论文解决了若干额外挑战
  • 这些挑战通常涉及分片间的非对称计算或通信模式,而由于单一程序需要对所有分片都具有通用性,这些模式在 SPMD 中尤其难以表达
  • 论文不能简单地根据运行时设备ID在单一程序中创建多个分支,因为这会导致程序规模激增

Static shapes and uneven partitioning

  • XLA要求张量形状是静态的(注:中间表示中有限的动态性对于高效适配加速器通常是必要的)
  • 然而,对计算进行分片时,由于维度大小可能无法被分片数量整除,并非所有分片的输入/输出形状都相同
  • 在这种情况下,会将形状大小向上取整到分片数量的下一个整数倍,填充区域(padded region)中的数据可以是任意值
  • 在计算算子时,为了保证正确性,可能需要在填充区域中填充已知值
    • 例如,对 Reduce-Add 算子进行分片时,需要使用单位元0作为填充值
    • 考虑一个示例:待分片维度大小为 15,无法被 2(分片数量)整除,因此分片1比所需大小多一列
    • 论文生成一个范围为 [0, 8) 的 Iota 算子,加上分片偏移量(由分片 ID×8 计算得到),并与全尺寸偏移量(15)进行比较。根据比较得到的谓词值(predicate value),选择从操作数中取值或从 0 中取值,最终得到掩码操作数(masked operand)

Static operator configurations

  • XLA算子具有静态配置,例如卷积(Convolution)中定义的填充(padding)、步幅(stride)和扩张(dilation)
  • 但不同分片可能不会使用相同的算子配置
    • 例如,对于卷积算子,最左侧的分片会在其左侧应用填充,而最右侧的分片会在其右侧应用填充
    • 在这种情况下,分片器可能会选择让某些分片生成略多于所需的数据,然后切分出无关部分
  • 附录A.4讨论了卷积和类似算子的示例

Halo Exchange

  • 某些算子具有一种通信模式,需要与相邻分片交换部分数据,论文称之为 Halo Exchange
  • 论文使用 CollectivePermute 算子在分片间交换 Halo 数据
  • Halo Exchange 最典型的应用场景是对基于窗口的算子(如卷积、ReduceWindow)进行分片,因为相邻分片可能需要重叠的输入数据(图5a)
  • 在实际应用中,由于窗口配置(扩张、步幅和填充)的复杂使用以及非均匀 Halo 大小,这些算子的 Halo Exchange 通常需要结合适当的填充、切片和掩码操作
    • 论文在附录A.4中描述了各种场景
  • Halo Exchange 的另一个应用场景是改变形状大小的数据格式化算子(data formatting operator)
    • 例如,经过 Slice 或 Pad 算子后,张量的形状会发生变化,分片间的边界也会随之改变
    • 这需要论文重新对齐不同分片上的数据,而这可以通过 Halo Exchange 的形式来处理(图5b)
  • 其他数据格式化算子(尽管在逻辑上不改变形状大小)也可能需要 Halo Exchange ,这主要是由于静态形状约束和非均匀分片
    • 例如,Reverse 算子会反转张量中元素的顺序,但如果对其进行非均匀分片,需要在分片间移动数据,以确保填充在逻辑上位于结果张量的右侧
    • 另一个示例是Reshape算子:
      • 考虑将形状为 [3, 2] 的张量重塑为 [6],其中输入在第一个维度上采用2种非均匀分片方式(分片形状为 [2, 2]),输出也采用2种分片方式(分片形状为[3])
      • 由于非均匀分片,输入中存在填充,但重塑后输出张量不再有填充;因此,需要采用与Slice类似的 Halo Exchange 方式(图5c)

Compiler optimizations

  • SPMD 分片器会创建各种数据格式化算子,以执行切片、填充、连接、掩码和 Halo Exchange 操作
  • 为解决这一问题,论文利用了 XLA 在 TPU 上的融合能力(fusion capabilities),以及针对切片和填充的代码移动优化(code motion optimizations),在很大程度上隐藏了数据格式化的开销
  • 因此,即使在大量使用掩码和填充的卷积网络中,运行时开销通常也可以忽略不计

Massively Multilingual, Massive Machine Translation, M4

  • 待补充

Performance and Memory Consumption

  • 待补充