PyTorch——torch.no_grad的用法


整体说明

  • 在 PyTorch 中,torch.no_grad()可用作装饰器 @torch.no_grad() 或上下文管理器 with torch.no_grad()(两者形式不同,但作用相同),用于禁用梯度计算
  • 如果 PyTorch 版本 >= 1.9,可以考虑使用 torch.inference_mode() 来替代 torch.no_grad(),以获得更好的性能

torch.no_grad()的作用

  • torch.no_grad() 的主要作用是临时关闭自动求导机制(autograd)。在被装饰的函数或代码块中,所有涉及张量的操作都不会构建计算图(computation graph),从而节省内存和计算资源:
    • 自动求导机制 :PyTorch 默认会记录张量操作的历史信息(即计算图),以便支持反向传播(backward())来计算梯度
    • 关闭梯度计算 :在推理阶段或其他不需要梯度的场景下,关闭自动求导可以减少内存占用,提高运行效率

使用场景

模型推理(Inference)

  • 在推理阶段,我们只需要前向传播(forward pass),而不需要计算梯度。因此,可以使用 @torch.no_grad() 来优化性能
    1
    2
    3
    4
    5
    6
    7
    8
    9
    @torch.no_grad()
    def evaluate_model(model, test_loader):
    model.eval() # 设置模型为评估模式,改回训练模式可以调用 model.train()
    total_loss = 0
    for data, target in test_loader:
    output = model(data)
    loss = loss_function(output, target)
    total_loss += loss.item()
    return total_loss

更新模型参数时不计算梯度

  • 在某些情况下,我们需要手动更新模型参数(例如权重剪枝、量化等),但不希望这些操作影响梯度计算
    1
    2
    3
    4
    @torch.no_grad()
    def update_weights(model):
    for param in model.parameters():
    param.add_(1.0) # 在参数上加 1,不会记录到计算图中

计算评估指标时不计算梯度

  • 当计算评估指标(如准确率、F1 分数等)时,不需要梯度计算
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    @torch.no_grad()
    def compute_accuracy(model, data_loader):
    correct = 0
    total = 0
    for inputs, labels in data_loader:
    outputs = model(inputs)
    _, predicted = torch.max(outputs, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
    return correct / total

附录:装饰器和上下文管理器的示例

作为装饰器

  • 装饰整个函数,使其在执行期间禁用梯度计算
    1
    2
    3
    @torch.no_grad()
    def inference(model, input_data):
    return model(input_data)

作为上下文管理器

  • 仅在特定代码块中禁用梯度计算
    1
    2
    3
    def inference(model, input_data):
    with torch.no_grad():
    return model(input_data)

附录:推理场景 torch.inference_mode() 的使用

  • 从 PyTorch 1.9 开始,引入了 torch.inference_mode(),它是 torch.no_grad() 的更高效替代品,专门用于推理阶段。与 torch.no_grad() 相比:
    • 性能更高torch.inference_mode() 会跳过一些额外的检查,进一步提升性能
    • 不可嵌套torch.inference_mode() 不能像 torch.no_grad() 那样嵌套使用
    • 推荐使用 :如果只用于推理,建议优先使用 torch.inference_mode()
  • 示例:
    1
    2
    3
    4
    @torch.inference_mode()
    def evaluate_model(model, test_loader):
    model.eval()
    ...