PyTorch——关闭梯度的方法


整体说明

  • 不记录梯度的方法包括:torch.no_grad()torch.set_grad_enabled(False)tensor.detach()tensor.requires_grad = False
    • 注:除了不记录梯度,这些方法还会释放计算图占用的内存,显著降低内存开销
  • 特殊注意model.eval() 仅影响特定层(如Dropout、BatchNorm)的行为,不会禁用梯度计算 ,必须配合 with torch.no_grad(): 使用才能完全关闭梯度
  • 各种模式选择建议
    • 若要临时停止梯度计算,推荐使用with torch.no_grad()上下文管理器或者torch.set_grad_enabled()
    • 若想永久性地停止某个张量的梯度计算,可使用detach()方法或者直接设置requires_grad=False
    • 对大型模型进行微调时,在模型层面设置requires_grad能有效节省内存
    • 模型推理阶段,要同时使用model.eval()with torch.no_grad()

方法一:使用 with torch.no_grad() 上下文管理器

  • with torch.no_grad() 管理器能够暂停所有计算图的构建,进而显著降低内存的使用量并加快计算速度
    1
    2
    3
    4
    5
    6
    import torch

    x = torch.tensor([1.0], requires_grad=True)
    with torch.no_grad():
    y = x * 2
    print(y.requires_grad) # 输出 False

方法二:使用 @torch.no_grad() 作为装饰器

  • 装饰整个函数,使其在执行期间禁用梯度计算

    1
    2
    3
    @torch.no_grad()
    def inference(model, input_data):
    return model(input_data)
  • 注:@torch.no_grad() 装饰器和 with torch.no_grad() 上下文管理器的效果是一样的,一个针对方法,一个针对上下文


方法三:使用 detach() 方法

  • 运用detach()方法可以创建一个新的张量,这个新张量和计算图没有关联
    1
    2
    3
    4
    x = torch.tensor([1.0], requires_grad=True)
    y = x * 2
    z = y.detach() # z 和计算图无关
    print(z.requires_grad) # 输出 False

方法四:使用 torch.set_grad_enabled() 实现全局开关

  • 借助这个全局开关,能够控制整个代码块是否进行梯度计算
    1
    2
    3
    4
    5
    x = torch.tensor([1.0], requires_grad=True)
    torch.set_grad_enabled(False)
    y = x * 2
    torch.set_grad_enabled(True)
    print(y.requires_grad) # 输出 False

方法五:在张量层面设置 requires_grad=False

  • 你可以在创建张量时或者之后,把requires_grad属性设置为False,以此来阻止梯度的计算

    1
    2
    3
    4
    5
    6
    7
    x = torch.tensor([1.0], requires_grad=False)  # 创建时设置
    y = x * 2
    print(y.requires_grad) # 输出 False

    # 或者之后设置
    x = torch.tensor([1.0], requires_grad=True)
    x.requires_grad = False
  • 特别示例,也可以对模型参数直接操作:对于预训练模型进行微调时,你可以冻结部分层,只对特定层计算梯度

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    import torch.nn as nn

    model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 1)
    )

    # 冻结所有参数
    for param in model.parameters():
    param.requires_grad = False

    # 只训练最后一层
    for param in model[2].parameters():
    param.requires_grad = True

模型推理时的特殊场景

  • 特别地,推理时,常常用 model.eval()with torch.no_grad()@torch.no_grad() 结合

  • 在模型推理阶段,同时使用这两个方法能够有效减少内存占用并提高计算效率

    1
    2
    3
    model.eval()  # 关闭 Dropout 和 BatchNorm 等训练特有的层
    with torch.no_grad():
    outputs = model(inputs)
  • 在仅推理的场景中,更常用 torch.inference_mode() 来替代 torch.no_grad()(要求 PyTorch 版本 >= 1.9),以获得更好的性能(跳过一些推理阶段非必要的检查),详情见:PyTorch——torch.no_grad的用法

    • torch.inference_mode() 可以用作上下文管理器或者装饰器
    • 推理场景优先推荐使用 torch.inference_mode(),但 torch.inference_mode() 仅适用于推理场景,其他场景不可乱用
    • torch.inference_mode() 不能像 torch.no_grad() 那样嵌套使用

附录:PyTorch中禁用梯度计算的方法对比

  • 整体对比
    方法 是否不记录梯度 是否释放计算图内存 作用范围 使用场景
    with torch.no_grad(): 代码块 推理阶段(如模型预测)、不需要梯度的计算(如验证集评估)。
    torch.set_grad_enabled(False) 全局(直到恢复为True) 临时关闭整个代码段的梯度计算,例如批量推理。
    tensor.detach() 是(对新张量) 单个张量 从计算图中分离张量,例如生成对抗网络(GAN)中的生成器输出。
    tensor.requires_grad = False 是(设置后) 单个张量或模型参数 冻结预训练模型的部分层,只训练特定参数。
    model.eval() 模型(影响Dropout/BatchNorm),使用model.train()可恢复 推理阶段,关闭训练特有的层(如Dropout),但仍会记录梯度(需配合no_grad()使用)。

附录:前向过程不涉及梯度计算,为什么需要关闭梯度?

  • 在调用 loss.backward() 之前,PyTorch 不会计算梯度,但是模型的前向过程会构建计算图,这也会消耗额外的内存

附录:非常简单的代码示例

  • 简单的代码示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    import torch

    # 方法1: torch.no_grad()
    with torch.no_grad():
    y = x * 2 # y 不记录梯度,不构建计算图

    # 方法2: set_grad_enabled
    torch.set_grad_enabled(False)
    y = x * 2 # 全局禁用梯度
    torch.set_grad_enabled(True)

    # 方法3: detach()
    y = x.detach() * 2 # y 是脱离计算图的新张量

    # 方法4: requires_grad = False
    x.requires_grad = False
    y = x * 2 # x 不再需要梯度,y 也不记录

    # 方法5: model.eval()(需配合 no_grad())
    model.eval()
    with torch.no_grad():
    outputs = model(inputs) # 完全禁用梯度