整体说明
- 不记录梯度的方法包括:
torch.no_grad()、torch.set_grad_enabled(False)、tensor.detach()、tensor.requires_grad = False等- 注:除了不记录梯度,这些方法还会释放计算图占用的内存,显著降低内存开销
- 特殊注意 :
model.eval()仅影响特定层(如Dropout、BatchNorm)的行为,不会禁用梯度计算 ,必须配合with torch.no_grad():使用才能完全关闭梯度 - 各种模式选择建议
- 若要临时停止梯度计算,推荐使用
with torch.no_grad()上下文管理器或者torch.set_grad_enabled() - 若想永久性地停止某个张量的梯度计算,可使用
detach()方法或者直接设置requires_grad=False - 对大型模型进行微调时,在模型层面设置
requires_grad能有效节省内存 - 模型推理阶段,要同时使用
model.eval()和with torch.no_grad()
- 若要临时停止梯度计算,推荐使用
方法一:使用 with torch.no_grad() 上下文管理器
with torch.no_grad()管理器能够暂停所有计算图的构建,进而显著降低内存的使用量并加快计算速度1
2
3
4
5
6import torch
x = torch.tensor([1.0], requires_grad=True)
with torch.no_grad():
y = x * 2
print(y.requires_grad) # 输出 False
方法二:使用 @torch.no_grad() 作为装饰器
装饰整个函数,使其在执行期间禁用梯度计算
1
2
3
def inference(model, input_data):
return model(input_data)注:
@torch.no_grad()装饰器和with torch.no_grad()上下文管理器的效果是一样的,一个针对方法,一个针对上下文
方法三:使用 detach() 方法
- 运用
detach()方法可以创建一个新的张量,这个新张量和计算图没有关联1
2
3
4x = torch.tensor([1.0], requires_grad=True)
y = x * 2
z = y.detach() # z 和计算图无关
print(z.requires_grad) # 输出 False
方法四:使用 torch.set_grad_enabled() 实现全局开关
- 借助这个全局开关,能够控制整个代码块是否进行梯度计算
1
2
3
4
5x = torch.tensor([1.0], requires_grad=True)
torch.set_grad_enabled(False)
y = x * 2
torch.set_grad_enabled(True)
print(y.requires_grad) # 输出 False
方法五:在张量层面设置 requires_grad=False
你可以在创建张量时或者之后,把
requires_grad属性设置为False,以此来阻止梯度的计算1
2
3
4
5
6
7x = torch.tensor([1.0], requires_grad=False) # 创建时设置
y = x * 2
print(y.requires_grad) # 输出 False
# 或者之后设置
x = torch.tensor([1.0], requires_grad=True)
x.requires_grad = False特别示例,也可以对模型参数直接操作:对于预训练模型进行微调时,你可以冻结部分层,只对特定层计算梯度
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15import torch.nn as nn
model = nn.Sequential(
nn.Linear(10, 10),
nn.ReLU(),
nn.Linear(10, 1)
)
# 冻结所有参数
for param in model.parameters():
param.requires_grad = False
# 只训练最后一层
for param in model[2].parameters():
param.requires_grad = True
模型推理时的特殊场景
特别地,推理时,常常用
model.eval()和with torch.no_grad()或@torch.no_grad()结合在模型推理阶段,同时使用这两个方法能够有效减少内存占用并提高计算效率
1
2
3model.eval() # 关闭 Dropout 和 BatchNorm 等训练特有的层
with torch.no_grad():
outputs = model(inputs)在仅推理的场景中,更常用
torch.inference_mode()来替代torch.no_grad()(要求 PyTorch 版本 >= 1.9),以获得更好的性能(跳过一些推理阶段非必要的检查),详情见:PyTorch——torch.no_grad的用法torch.inference_mode()可以用作上下文管理器或者装饰器- 推理场景优先推荐使用
torch.inference_mode(),但torch.inference_mode()仅适用于推理场景,其他场景不可乱用 torch.inference_mode()不能像torch.no_grad()那样嵌套使用
附录:PyTorch中禁用梯度计算的方法对比
- 整体对比
方法 是否不记录梯度 是否释放计算图内存 作用范围 使用场景 with torch.no_grad():是 是 代码块 推理阶段(如模型预测)、不需要梯度的计算(如验证集评估)。 torch.set_grad_enabled(False)是 是 全局(直到恢复为True) 临时关闭整个代码段的梯度计算,例如批量推理。 tensor.detach()是 是(对新张量) 单个张量 从计算图中分离张量,例如生成对抗网络(GAN)中的生成器输出。 tensor.requires_grad = False是 是(设置后) 单个张量或模型参数 冻结预训练模型的部分层,只训练特定参数。 model.eval()否 否 模型(影响Dropout/BatchNorm),使用 model.train()可恢复推理阶段,关闭训练特有的层(如Dropout),但仍会记录梯度(需配合 no_grad()使用)。
附录:前向过程不涉及梯度计算,为什么需要关闭梯度?
- 在调用
loss.backward()之前,PyTorch 不会计算梯度,但是模型的前向过程会构建计算图,这也会消耗额外的内存
附录:非常简单的代码示例
- 简单的代码示例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22import torch
# 方法1: torch.no_grad()
with torch.no_grad():
y = x * 2 # y 不记录梯度,不构建计算图
# 方法2: set_grad_enabled
torch.set_grad_enabled(False)
y = x * 2 # 全局禁用梯度
torch.set_grad_enabled(True)
# 方法3: detach()
y = x.detach() * 2 # y 是脱离计算图的新张量
# 方法4: requires_grad = False
x.requires_grad = False
y = x * 2 # x 不再需要梯度,y 也不记录
# 方法5: model.eval()(需配合 no_grad())
model.eval()
with torch.no_grad():
outputs = model(inputs) # 完全禁用梯度