整体说明
- 在 PyTorch 中,
torch.no_grad()可用作装饰器@torch.no_grad()或上下文管理器with torch.no_grad()(两者形式不同,但作用相同),用于禁用梯度计算 - 如果 PyTorch 版本 >= 1.9,可以考虑使用
torch.inference_mode()来替代torch.no_grad(),以获得更好的性能
torch.no_grad()的作用
torch.no_grad()的主要作用是临时关闭自动求导机制(autograd)。在被装饰的函数或代码块中,所有涉及张量的操作都不会构建计算图(computation graph),从而节省内存和计算资源:- 自动求导机制 :PyTorch 默认会记录张量操作的历史信息(即计算图),以便支持反向传播(
backward())来计算梯度 - 关闭梯度计算 :在推理阶段或其他不需要梯度的场景下,关闭自动求导可以减少内存占用,提高运行效率
- 自动求导机制 :PyTorch 默认会记录张量操作的历史信息(即计算图),以便支持反向传播(
使用场景
模型推理(Inference)
- 在推理阶段,我们只需要前向传播(forward pass),而不需要计算梯度。因此,可以使用
@torch.no_grad()来优化性能1
2
3
4
5
6
7
8
9
def evaluate_model(model, test_loader):
model.eval() # 设置模型为评估模式,改回训练模式可以调用 model.train()
total_loss = 0
for data, target in test_loader:
output = model(data)
loss = loss_function(output, target)
total_loss += loss.item()
return total_loss
更新模型参数时不计算梯度
- 在某些情况下,我们需要手动更新模型参数(例如权重剪枝、量化等),但不希望这些操作影响梯度计算
1
2
3
4
def update_weights(model):
for param in model.parameters():
param.add_(1.0) # 在参数上加 1,不会记录到计算图中
计算评估指标时不计算梯度
- 当计算评估指标(如准确率、F1 分数等)时,不需要梯度计算
1
2
3
4
5
6
7
8
9
10
def compute_accuracy(model, data_loader):
correct = 0
total = 0
for inputs, labels in data_loader:
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return correct / total
附录:装饰器和上下文管理器的示例
作为装饰器
- 装饰整个函数,使其在执行期间禁用梯度计算
1
2
3
def inference(model, input_data):
return model(input_data)
作为上下文管理器
- 仅在特定代码块中禁用梯度计算
1
2
3def inference(model, input_data):
with torch.no_grad():
return model(input_data)
附录:推理场景 torch.inference_mode() 的使用
- 从 PyTorch 1.9 开始,引入了
torch.inference_mode(),它是torch.no_grad()的更高效替代品,专门用于推理阶段。与torch.no_grad()相比:- 性能更高 :
torch.inference_mode()会跳过一些额外的检查,进一步提升性能 - 不可嵌套 :
torch.inference_mode()不能像torch.no_grad()那样嵌套使用 - 推荐使用 :如果只用于推理,建议优先使用
torch.inference_mode()
- 性能更高 :
- 示例:
1
2
3
4
def evaluate_model(model, test_loader):
model.eval()
...