PyTorch——compile函数的理解和使用


整体说明

  • torch.compile() 函数是 PyTorch 2.0 引入的一个重要功能,用于对模型进行编译优化,以提升训练和推理性能
    • 将 PyTorch 模型从“解释型”的逐行执行模式,转变为“编译型”的、一次性优化的执行模式
    • 当模型被编译后,它在后续的推理或训练中会运行得更快,并且使用的内存可能更少
  • torch.compile() 的核心作用是通过对模型计算图进行一系列优化(如算子融合、常量折叠、内存优化等),生成更高效的代码,从而加速模型的执行
    • 这个过程是自动的,并且大部分时间是无侵入性的,不需要修改模型的代码
  • torch.compile() 效果提升:
    • 因模型结构和硬件而异,通常对于有大量小操作(如 Transformer 模型)或对 GPU 算力要求高的模型效果更显著
    • 在大规模训练和推理场景中效果显著
  • 首次运行编译后的模型会有一定的启动延迟,用于图优化和代码生成(compile 函数是惰性执行的),但后续重复调用会更快
  • 编译过程可能会增加模型的内存占用,需根据实际情况调整

调用 torch.compile(model) 后会发生什么?

  • 具体来说,调用 torch.compile(model) 会发生以下过程:
  • 1)捕获模型计算图(Graph Capturing) :分析模型的前向传播逻辑,记录张量的操作序列和依赖关系,构建计算图表示
    • 当你第一次调用 torch.compile 编译后的模型时,它会追踪模型的前向传播过程
    • 这就像在记录模型中的每一步操作,比如矩阵乘法、卷积、激活函数等
    • PyTorch 会创建一个计算图(computational graph) ,这个图代表了模型从输入到输出的所有计算路径
    • 这个过程是惰性的,即只在第一次实际运行模型时发生(所以第一次调用比较慢,后面会比较快)
    • 计算图被捕获后,编译器会对计算图进行一系列优化(接下面)
  • 2)图优化(Graph Optimization) :包含一系列优化,例如:
    • 算子融合(Operator Fusion) : 将多个连续的小算子合并为一个大算子(如将卷积+批归一化+激活函数融合),减少 kernel 调用次数和内存读写
    • 常量折叠(Constant Folding) :计算图中固定不变的常量表达式会被预先计算,避免重复计算
    • 死代码消除(Dead Code Elimination) : 移除计算图中未被使用的节点或操作,减少不必要计算
    • 内存优化(Memory Optimization) : 优化张量的内存分配和复用,重新安排计算顺序,以减少中间结果所需的内存,从而更好地利用缓存
    • 优化后的计算图会用来生成高效的、针对特定硬件(如 GPU)的可执行代码(具体见下节)
  • 3)代码生成
    • 编译器将高级的 PyTorch 操作转换成底层的、更接近硬件指令的特定后端的代码(如 CPU 上的 C++/OpenMP 代码,GPU 上的 CUDA 代码)
      • 例如,它可能会将 PyTorch 的张量操作转换成 CUDA 内核
    • 支持多种后端(如 inductoraot_eager 等),默认使用 inductor 后端(也叫做 TorchInductor 后端),它能生成高效的 GPU/CPU 代码
  • 4)返回编译后的模型
    • torch.compile(model) 返回一个经过包装的模型对象,其接口与原模型一致(可直接调用 forward 方法或进行训练),但内部执行的是优化后的代码

使用示例

  • 使用非常简单,仅需一行代码
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    import torch
    import torch.nn as nn

    class DiyModel(nn.Module):
    def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3, 32, kernel_size=3)
    self.relu = nn.ReLU()

    def forward(self, x):
    return self.relu(self.conv(x))

    model = DiyModel()
    # 编译模型
    compiled_model = torch.compile(model)

    # 使用编译后的模型(接口与原模型万全一致)
    x = torch.randn(1, 3, 224, 224)
    output = compiled_model(x)