DL——梯度检查点技术


整体总结

  • 梯度检查点(Gradient Checkpointing)技术广泛应用于当前大规模深度学习模型训练中,能有效降低显存的使用
  • 梯度检查点技术也叫做 重计算(re-materialization)技术
  • Training Deep Nets with Sublinear Memory Cost, MIT, 2016这篇文章的核心贡献是:
    • 提出并系统化了一种 “用计算换内存”的通用方法(在自动微分领域也被称为梯度检查点技术)
    • 通过分段和重计算,巧妙地将深度学习训练过程中的内存瓶颈从 \(O(n)\) 降低到 \(O(\sqrt{n})\),甚至理论上的 \(O(\log n)\),极大地拓展了论文使用现有硬件能够训练的模型规模和深度

核心问题:为什么训练深度网络如此消耗内存?

  • 在训练神经网络时,标准的训练流程包含两个步骤:
    • 前向传播 (Forward Pass)
      • 输入数据从网络的第一层开始,逐层计算,直到最后一层得到输出
      • 在这个过程中,每一层的输出(称为“激活”或“特征图”)都需要被保存下来
    • 反向传播 (Backward Pass)
      • 计算输出的损失(预测值与真实值的差距),然后从最后一层开始,反向逐层计算梯度
      • 为了计算某一层参数的梯度,通常需要用到该层在前向传播时产生的激活值
  • 问题出现:
    • 为了进行反向传播,必须在内存中保留网络中每一层的激活值
    • 如果一个网络有 \(n\) 层,那么内存消耗就大致与 \(n\) 成正比 ,即内存成本为 \(O(n)\)
    • 对于现在动辄成百上千层的深度模型(如 ResNet),这笔内存开销会迅速占满顶配 GPU 的几十 GB 显存,从而限制了探索更深、更复杂模型的能力

传统内存优化方法(治标不治本)

  • 论文首先提到了一些已有的内存优化技术,这些技术主要通过分析计算图 (Computation Graph) 来实现
    • 原地操作 (In-place Operation)
      • 如果一个操作的输入值在后续计算中不再需要,那么其输出可以直接覆盖输入的内存空间
      • 例如,y = relu(x),如果 x 后面用不到了,y 的结果可以直接写在 x 的内存里
    • 内存共享 (Memory Sharing)
      • 分析所有变量的“生命周期”,将生命周期不重叠的变量共享同一块内存
  • 这些方法能将内存占用降低2到3倍,但无法改变内存消耗随网络层数线性增长的趋势,当网络深到一定程度时,内存瓶颈依然存在

梯度检查点:用计算换内存 (Trade Computation for Memory)

  • 既然保存所有中间结果是内存消耗的根源,那么论文提出的核心思想非常直接:
    • 不保存所有中间结果,只保存其中一部分
    • 当反向传播需要用到某个被丢弃的中间结果时,再临时重新计算它
  • 这是一种典型的“用时间换空间”的策略
    • 虽然会增加一些计算量(因为需要重新执行部分前向计算),但可以极大地降低内存峰值

工作原理一:\(O(\sqrt{n})\) 内存成本算法(实用策略)

  • 这是论文提出的主要实用算法
  • \(O(\sqrt{n})\) 内存成本算法原理如下:
    • 1)分段 (Segmenting) :将一个包含 \(n\) 层的网络链条,切分成 \(k\) 个小段(segment)
    • 2)前向传播 :在正常的前向传播过程中,只保存每个分段的最终输出 ,而丢弃每个分段内部的所有中间激活值
    • 3)反向传播
      • 当反向传播进行到第 \(i\) 段时,由于计算该段的梯度需要其内部的激活值(这些值已经被丢弃了),算法会执行一次“局部前向传播”:
        • 利用保存的第 \(i-1\) 段的输出作为输入,重新计算一次第 \(i\) 段的前向传播,以得到所有需要的激活值
      • 计算完梯度后,这些临时重新计算的激活值可以立即被丢弃

关键推导:为什么是 \(O(\sqrt{n})\)?

  • 内存成本分析:假设网络总共有 \(n\) 层,被均匀地切分成 \(k\) 段,那么每一段的长度就是 \(n/k\)层
  • 总内存成本主要由两部分构成:
    • 1)段间内存 (Inter-segment Memory) :用于存储 \(k\) 个分段的输出,以便在反向传播时作为“检查点”(checkpoint)
      • 这部分的成本是 \(O(k)\)
    • 2)段内内存 (Intra-segment Memory) :在对任何一段进行反向传播时,需要临时重新计算并存储该段内部的所有激活值
      • 由于所有段中最大的内存开销决定了峰值,这部分的成本是 \(O(n/k)\)
  • 因此,总的内存成本可以表示为:
    $$\text{Cost}(n, k) = O(k) + O(n/k)$$
  • 为了让总成本最低,论文需要让这两部分达到一个平衡。一个简单的优化方法是让它们的量级相等(实际上等价于导数为 0 的推导):
    $$k \approx \frac{n}{k} \implies k^2 \approx n \implies k = \sqrt{n}$$
  • 当选择 \(k=\sqrt{n}\) 时,总内存成本为:
    $$\text{Cost} = O(\sqrt{n}) + O(n/\sqrt{n}) = O(\sqrt{n}) + O(\sqrt{n}) = O(\sqrt{n})$$
  • 至此,我们就成功地将内存成本从线性 \(O(n)\) 降到了亚线性 \(O(\sqrt{n})\)
  • 作为代价,整个训练过程大约需要额外进行一次完整的前向传播计算(因为每个分段都被重新计算了一次),使得训练时间增加了约 30%
    • 问题:30% 的计算时间怎么来的呢?
    • 回答:反向传播的时间复杂度大约是前向传播的 2~3 倍,折合计算以后大致能算出这个数字(增加了一次前向传播计算),详情见附录

工作原理二:\(O(\log n)\) 内存成本算法(理论极限)

  • 论文进一步展示,通过递归 (Recursion) 的方式,可以实现更低的内存成本
  • 我们可以把一个分段本身看作一个“超级操作符”
  • 对这个超级操作符内部的计算,作者同样可以再次应用分段和重计算的策略

关键推导:\(O(\log n)\) 的递推关系

  • 让 \(g(n)\) 表示训练一个 \(n\) 层网络所需的内存
  • 假设将这个网络分成 \(k+1\) 个子问题,每个子问题的规模是 \(n/(k+1)\)
  • 为了连接这些子问题,需要存储 \(k\) 个中间结果
  • 那么,\(g(n)\) 可以表示为递推公式:
    $$g(n) = k + g\left(\frac{n}{k+1}\right)$$
    • 这是一个典型的对数关系
    • \(k\) 是存储这 \(k\) 个结果的成本
    • \(g\left(\frac{n}{k+1}\right)\) 是解决其中一个子问题所需的成本
    • 注:这里使用 \(k+1\) 或 \(k\) 不影响最终结果
  • 通过解这个递推公式,我们可以得到:
    $$g(n) = k \cdot \log_{k+1}(n)$$
  • 作为一个特例,如果每次只将问题一分为二,即只存储一个中间结果(\(k=1\)),那么递推关系变为 \(g(n) = 1 + g(n/2)\),解得:
    $$g(n) = \log_2(n)$$
    • 注:简单理解一下,展开 \(g(n)\) 后,大致共有 \(\log_2(n)\) 个 1
  • 这揭示了一个终极的理论可能性:
    • 训练一个 \(n\) 层网络的内存成本可以降低到 \(O(\log n)\)
    • 不过,这种方法的计算开销会大得多(需要 \(O(\log n)\) 次额外的前向传播,因此在实践中不如 \(O(\sqrt{n})\) 策略常用

实验效果与结论

  • 作者通过在深度残差网络 (ResNet) 和长短期记忆网络 (LSTM) 上的实验,验证了该方法的有效性
    • 对于一个 1000层 的 ResNet,标准优化方法需要 48GB 显存,而使用亚线性算法后仅需 7GB
    • 在 LSTM 上,该方法同样带来了 超过4倍 的内存节省
    • 代价是训练速度降低了大约30%,这对于能够训练原本无法训练的模型来说,是一个非常值得的交换

Transformer 中的梯度检查点

  • Transformer中的梯度检查点(Gradient Checkpointing)与上述论文中的基本原理上是完全相同的 ,但其应用方式和带来的收益上,针对Transformer 的结构有更多特点
  • 无论是用于CNN、RNN还是Transformer,梯度检查点的核心思想始终是:
    • 目标 :打破模型训练时内存消耗与网络深度(层数)之间的线性关系
    • 方法 :在前向传播时,不再保存所有中间层的激活值,而是只保存少数几个关键节点(检查点)
    • 代价 :在反向传播时,当需要用到被丢弃的激活值时,就从最近的一个检查点开始,重新进行一小段前向计算来恢复它们
    • 权衡 :本质上都是“用计算换内存”的策略

针对 Transformer 结构的应用说明

  • Transformer的独特结构使得梯度检查点的应用非常直接,且效果尤其显著
  • 1. 应用位置非常明确
    • 一个标准的 Transformer 模型是由一个个完全相同的 Transformer Block 堆叠而成的
      • 每个块通常包含一个多头自注意力(Multi-Head Self-Attention)层和一个前馈神经网络(Feed-Forward Network, FFN)层
    • 最自然、最常见的应用方式就是 将每一个Transformer块作为一个分段(Segment)
      • 前向传播时
        • 当数据流经第 \(i\) 个Transformer块时,只保留送入这个块的输入(也就是第 \(i-1\) 块的输出)
        • 在块内部计算过程中产生的所有中间结果,例如注意力分数矩阵(Attention Scores)、注意力权重(Attention Weights)、FFN层的中间激活等,计算完毕后立即被丢弃
      • 反向传播时
        • 当反向传播回第 \(i\) 个块时,算法会利用之前保存的输入,重新执行一次该块的前向计算,从而得到计算梯度所必需的那些中间结果
  • 2. 带来的收益为何对 Transformer 尤其显著
    • 梯度检查点能极大缓解Transformer在两个维度上的内存压力:
    • 深度(层数 \(L\))
      • 现代的大型语言模型(如GPT、LLaMA)可以有几十甚至上百层
      • 如果没有梯度检查点,内存消耗会随着层数 \(L\) 线性暴增。梯度检查点将这个成本从 \(O(L)\) 降到了 \(O(\sqrt{L})\) ,使得训练极深的Transformer 成为可能
    • 序列长度(Sequence Length \(S\)) :这是Transformer最独特的内存瓶颈
      • 注意力矩阵的二次方开销
        • 自注意力机制的核心是计算一个注意力分数矩阵,其大小为 (序列长度 x 序列长度)
        • 这意味着内存开销与序列长度成二次方关系 ,即 \(O(S^2)\)
        • 当序列很长时(例如4096、8192甚至更长),这个矩阵会变得异常巨大
      • 梯度检查点的作用
        • 梯度检查点不能改变单次注意力计算需要 \(O(S^2)\) 内存峰值的事实
        • 但它能确保不必同时在内存中保留每一层的这个巨大矩阵
        • 在没有检查点的情况下,内存中需要为 \(L\) 个注意力矩阵的激活值(或其相关值)分配空间
        • 有了检查点,在任何时候,只需要为当前正在重计算的那一个块的注意力矩阵分配内存
        • 这极大地降低了总体的内存占用

附录:梯度检查点技术增加了多少训练成本?

  • 论文中 “训练时间增加了约30%” 这个数字主要是一个经验性的测量结果 ,来源于作者在特定硬件上进行的基准测试,并且这个结果也与理论上的计算开销分析相符

实验测量结果(经验来源)

  • 作者在论文的第5.4节 (Impact on Training Speed)图7 专门讨论了这个问题
    • 测试方法
      • 作者在单个 Titan X GPU 上对不同的内存分配策略进行了速度基准测试
      • 测量了在 ResNet 和 LSTM 两种模型上,处理一个批次(batch)数据所需的实际运行时间(秒)
    • 对比对象
      • 比较了采用标准内存优化(论文中称为 “sharing”)的策略和采用亚线性内存成本(”sublinear plan”)策略的速度
    • 测量结论
      • 实验结果图表(图7)直观地显示,”sublinear plan” 的时间成本曲线始终在 “sharing” 曲线之上
      • 论文在图7的说明文字和正文中明确指出,使用亚线性内存方案会带来“大约30%的额外运行时成本”

理论计算分析(理论支撑)

  • 训练一个批次的主要计算量如下:
    • 标准训练流程 :包含一次完整的前向传播(Forward Pass)和一次完整的反向传播(Backward Pass):
      $$ 总计算量 \approx 1F + 1B $$
    • 亚线性方案的流程 :它在反向传播过程中需要重新计算前向传播,因此:
      $$ 总计算量 \approx 1F + (1F_recompute + 1B) = 2F + 1B $$
  • 论文中提到,通常一次反向传播的计算量大约是前向传播的两倍(\(B \approx 2F\))
    • 实际上是大约 2~3 倍的样子
  • 基于上述这个假设,我们可以估算增加的计算开销:
    • 标准流程计算量
      $$T_{standard} = T_F + T_B \approx T_F + 2T_F = 3T_F$$
    • 亚线性方案计算量
      $$T_{sublinear} = 2T_F + T_B \approx 2T_F + 2T_F = 4T_F$$
  • 增加的运行时间百分比约为:
    $$\frac{T_{sublinear} - T_{standard}}{T_{standard}} = \frac{4T_F - 3T_F}{3T_F} = \frac{1T_F}{3T_F} = \frac{1}{3} \approx 33.3%$$
  • 这个理论计算出的 33.3% 与实验中测量到的 约30% 基本吻合(考虑到反向传播大致是前向传播的 2~3 倍,那就基本符合了)