PyTorch——叶子张量


整体说明

  • tensor.is_leaf=True 的张量被称为叶子张量(Leaf Tensor),也称为叶子变量(Leaf Variable),部分博客或书籍也称为叶子节点
  • 叶子张量分两种类型,即以下两种情况下的张量 tensor.is_leaf 返回 True
    • 类型一:requires_gradFalse 的张量都是叶子张量(由于 requires_gradFalse,所以不会存储梯度)
    • 类型二:requires_gradTrue 的张量,如果是由用户创建的 ,而不是通过其他张量运算得到的,那么它是叶子张量
      • 通过其他张量运算得到的,都是非叶子张量
  • 特别说明:
    • detach() 函数可以将节点从计算图中剥离,使其成为叶子节点,此时 requires_gradFalse,同时 tensor.is_leaf 会变为 True
    • 从 CPU 定义好后移动到 GPU 时产生的张量,或者从 GPU 定义后,挪到 CPU 上的向量,也都是非叶子张量,tensor.is_leaf 返回 False
      • 注意:不论是 GPU 还是 CPU,叶子张量的判定不变,只有用户定义的张量是叶子张量,挪动以后得都不是叶子张量(除非同时修改其
    • 只有叶子张量的 requires_grad 属性可以被修改,非叶子张量的 requires_grad 属性是不能被修改的
      • 理解:非叶子张量都是派生出来的,且 requires_grad=True 的张量,叶子张量计算梯度时依赖性非叶子张量的梯度,不能随便修改 requires_grad 属性
    • 叶子张量的 grad_fn=None (包括 requires_grad=Truerequires_grad=False 的都是)
      • 理解:tensor.grad_fn 属性指向/存储生成 tensor 张量的计算操作(如加法、减法、乘法等),叶子张量要么是直接由用户定义出来的(此时 requires_grad=True),要么是不需要计算梯度的(此时 requires_grad=False
    • 非叶子张量的梯度不会被保存(因为不需要使用)

叶子张量的特性

  • 在反向传播过程中,只有叶子张量的梯度会被保留并存储在张量的grad属性中
    • 用于后续优化器更新参数等操作
    • 比如:神经网络层中的权值 w 的张量均为叶子节点,反向传播 backward() 就是为了求它们的梯度,进而更新权值
  • 非叶子张量梯度反向传播使用完后通常会被清除 ,以节省内存
    • 比如:计算中中间的一些派生阶段就是非叶子张量

叶子张量相关示例

  • 各种叶子张量判断示例
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    import torch as torch

    a = torch.randn(3,3)
    print(a.is_leaf, a.requires_grad) # True False
    a.requires_grad = True
    print(a.is_leaf, a.requires_grad) # True True
    b = a.cuda()
    print(b.is_leaf, b.requires_grad) # False True
    c = b.detach()
    print(c.is_leaf, c.requires_grad) # True False
    d = a + 2
    print(d.is_leaf, d.requires_grad) # False True

    e = torch.randn(3,3, device="cuda", requires_grad=True)
    print(e.is_leaf, e.requires_grad) # True True
    f = e.to("cpu")
    print(f.is_leaf, f.requires_grad) # False True

附录:如何打印非叶子张量的梯度?

  • 在 PyTorch 里,当开启自动求导功能时,中间变量(非叶子张量)的梯度默认不会被保存,目的是节省内存
    • 只有叶子节点(比如直接创建的张量)的梯度会被保留
    • 在任意时刻时,非中间节点的梯度(grad属性)是都是 None
  • 要获取中间变量的梯度,有两种方法:
    • 运用 retain_grad() 方法保留梯度
    • 借助钩子(hook)来捕获梯度(可以打印或者赋值给全局变量)
  • 打印非节点张量的代码示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    import torch

    x = torch.tensor(2.0, requires_grad=True)
    y = x**2
    z = y**2

    # 方法一:使用 retain_grad() 显式保留梯度
    y.retain_grad()

    z.backward(retain_graph=True) # 若不使用 retain_graph=True,计算图会在 backward 被清空,则后续想要调用 backward() 前需要重新构造计算图
    print("方法一:", y.grad) # 输出: 16.0

    # 方法二:使用钩子 hook 捕捉梯度
    gradient_list = []
    def save_gradient(grad):
    gradient_list.append(grad)

    ## 如果 前面的 z.backward(retain_graph=True) 不使用 retain_graph=True,则 backward() 会清空计算图,这里就需要重新构造计算图,目前使用的是 z.backward(retain_graph=True) ,不需要下面的两句
    # y = x**2
    # z = y**2

    hook = y.register_hook(save_gradient)
    z.backward()
    hook.remove() # 重点:即时移除钩子,防止不必要的内存占用

    print("方法二:", gradient_list[0]) # 输出: 16.0