整体说明
tensor.is_leaf=True的张量被称为叶子张量(Leaf Tensor),也称为叶子变量(Leaf Variable),部分博客或书籍也称为叶子节点- 叶子张量分两种类型,即以下两种情况下的张量
tensor.is_leaf返回True:- 类型一:
requires_grad为False的张量都是叶子张量(由于requires_grad为False,所以不会存储梯度) - 类型二:
requires_grad为True的张量,如果是由用户创建的 ,而不是通过其他张量运算得到的,那么它是叶子张量;- 通过其他张量运算得到的,都是非叶子张量
- 类型一:
- 特别说明:
detach()函数可以将节点从计算图中剥离,使其成为叶子节点,此时requires_grad为False,同时tensor.is_leaf会变为True了- 从 CPU 定义好后移动到 GPU 时产生的张量,或者从 GPU 定义后,挪到 CPU 上的向量,也都是非叶子张量,
tensor.is_leaf返回False- 注意:不论是 GPU 还是 CPU,叶子张量的判定不变,只有用户定义的张量是叶子张量,挪动以后得都不是叶子张量(除非同时修改其
- 只有叶子张量的
requires_grad属性可以被修改,非叶子张量的requires_grad属性是不能被修改的- 理解:非叶子张量都是派生出来的,且
requires_grad=True的张量,叶子张量计算梯度时依赖性非叶子张量的梯度,不能随便修改requires_grad属性
- 理解:非叶子张量都是派生出来的,且
- 叶子张量的
grad_fn=None(包括requires_grad=True和requires_grad=False的都是)- 理解:
tensor.grad_fn属性指向/存储生成tensor张量的计算操作(如加法、减法、乘法等),叶子张量要么是直接由用户定义出来的(此时requires_grad=True),要么是不需要计算梯度的(此时requires_grad=False)
- 理解:
- 非叶子张量的梯度不会被保存(因为不需要使用)
叶子张量的特性
- 在反向传播过程中,只有叶子张量的梯度会被保留并存储在张量的
grad属性中- 用于后续优化器更新参数等操作
- 比如:神经网络层中的权值
w的张量均为叶子节点,反向传播backward()就是为了求它们的梯度,进而更新权值
- 非叶子张量的梯度在反向传播使用完后通常会被清除 ,以节省内存
- 比如:计算中中间的一些派生阶段就是非叶子张量
叶子张量相关示例
- 各种叶子张量判断示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17import torch as torch
a = torch.randn(3,3)
print(a.is_leaf, a.requires_grad) # True False
a.requires_grad = True
print(a.is_leaf, a.requires_grad) # True True
b = a.cuda()
print(b.is_leaf, b.requires_grad) # False True
c = b.detach()
print(c.is_leaf, c.requires_grad) # True False
d = a + 2
print(d.is_leaf, d.requires_grad) # False True
e = torch.randn(3,3, device="cuda", requires_grad=True)
print(e.is_leaf, e.requires_grad) # True True
f = e.to("cpu")
print(f.is_leaf, f.requires_grad) # False True
附录:如何打印非叶子张量的梯度?
- 在 PyTorch 里,当开启自动求导功能时,中间变量(非叶子张量)的梯度默认不会被保存,目的是节省内存
- 只有叶子节点(比如直接创建的张量)的梯度会被保留
- 在任意时刻时,非中间节点的梯度(
grad属性)是都是None
- 要获取中间变量的梯度,有两种方法:
- 运用
retain_grad()方法保留梯度 - 借助钩子(hook)来捕获梯度(可以打印或者赋值给全局变量)
- 运用
- 打印非节点张量的代码示例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26import torch
x = torch.tensor(2.0, requires_grad=True)
y = x**2
z = y**2
# 方法一:使用 retain_grad() 显式保留梯度
y.retain_grad()
z.backward(retain_graph=True) # 若不使用 retain_graph=True,计算图会在 backward 被清空,则后续想要调用 backward() 前需要重新构造计算图
print("方法一:", y.grad) # 输出: 16.0
# 方法二:使用钩子 hook 捕捉梯度
gradient_list = []
def save_gradient(grad):
gradient_list.append(grad)
## 如果 前面的 z.backward(retain_graph=True) 不使用 retain_graph=True,则 backward() 会清空计算图,这里就需要重新构造计算图,目前使用的是 z.backward(retain_graph=True) ,不需要下面的两句
# y = x**2
# z = y**2
hook = y.register_hook(save_gradient)
z.backward()
hook.remove() # 重点:即时移除钩子,防止不必要的内存占用
print("方法二:", gradient_list[0]) # 输出: 16.0