整体说明
torch.compile()函数是 PyTorch 2.0 引入的一个重要功能,用于对模型进行编译优化,以提升训练和推理性能- 将 PyTorch 模型从“解释型”的逐行执行模式,转变为“编译型”的、一次性优化的执行模式
- 当模型被编译后,它在后续的推理或训练中会运行得更快,并且使用的内存可能更少
torch.compile()的核心作用是通过对模型计算图进行一系列优化(如算子融合、常量折叠、内存优化等),生成更高效的代码,从而加速模型的执行- 这个过程是自动的,并且大部分时间是无侵入性的,不需要修改模型的代码
torch.compile()效果提升:- 因模型结构和硬件而异,通常对于有大量小操作(如 Transformer 模型)或对 GPU 算力要求高的模型效果更显著
- 在大规模训练和推理场景中效果显著
- 首次运行编译后的模型会有一定的启动延迟,用于图优化和代码生成(
compile函数是惰性执行的),但后续重复调用会更快 - 编译过程可能会增加模型的内存占用,需根据实际情况调整
调用 torch.compile(model) 后会发生什么?
- 具体来说,调用
torch.compile(model)会发生以下过程: - 1)捕获模型计算图(Graph Capturing) :分析模型的前向传播逻辑,记录张量的操作序列和依赖关系,构建计算图表示
- 当你第一次调用
torch.compile编译后的模型时,它会追踪模型的前向传播过程 - 这就像在记录模型中的每一步操作,比如矩阵乘法、卷积、激活函数等
- PyTorch 会创建一个计算图(computational graph) ,这个图代表了模型从输入到输出的所有计算路径
- 这个过程是惰性的,即只在第一次实际运行模型时发生(所以第一次调用比较慢,后面会比较快)
- 计算图被捕获后,编译器会对计算图进行一系列优化(接下面)
- 当你第一次调用
- 2)图优化(Graph Optimization) :包含一系列优化,例如:
- 算子融合(Operator Fusion) : 将多个连续的小算子合并为一个大算子(如将卷积+批归一化+激活函数融合),减少 kernel 调用次数和内存读写
- 常量折叠(Constant Folding) :计算图中固定不变的常量表达式会被预先计算,避免重复计算
- 死代码消除(Dead Code Elimination) : 移除计算图中未被使用的节点或操作,减少不必要计算
- 内存优化(Memory Optimization) : 优化张量的内存分配和复用,重新安排计算顺序,以减少中间结果所需的内存,从而更好地利用缓存
- 优化后的计算图会用来生成高效的、针对特定硬件(如 GPU)的可执行代码(具体见下节)
- 3)代码生成 :
- 编译器将高级的 PyTorch 操作转换成底层的、更接近硬件指令的特定后端的代码(如 CPU 上的 C++/OpenMP 代码,GPU 上的 CUDA 代码)
- 例如,它可能会将 PyTorch 的张量操作转换成 CUDA 内核
- 支持多种后端(如
inductor、aot_eager等),默认使用inductor后端(也叫做 TorchInductor 后端),它能生成高效的 GPU/CPU 代码
- 编译器将高级的 PyTorch 操作转换成底层的、更接近硬件指令的特定后端的代码(如 CPU 上的 C++/OpenMP 代码,GPU 上的 CUDA 代码)
- 4)返回编译后的模型 :
torch.compile(model)返回一个经过包装的模型对象,其接口与原模型一致(可直接调用forward方法或进行训练),但内部执行的是优化后的代码
使用示例
- 使用非常简单,仅需一行代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19import torch
import torch.nn as nn
class DiyModel(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 32, kernel_size=3)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.conv(x))
model = DiyModel()
# 编译模型
compiled_model = torch.compile(model)
# 使用编译后的模型(接口与原模型万全一致)
x = torch.randn(1, 3, 224, 224)
output = compiled_model(x)