PyTorch — Type checking for the Variable and Tensor classes


Problem description

requires_grad=True

Equivalent to requires_grad=a, where a is any nonzero integer; a cannot be a float.
A float raises: TypeError: integer argument expected, got float

  • Test code

    import torch
    from torch.autograd import Variable

    tensor = torch.ones(1)
    variable = Variable(tensor, requires_grad=True)
    print(tensor)
    print(variable)
    print("type1: ", type(tensor), type(variable))
    print(tensor.data)
    print(variable.data)
    print("type2: ", type(tensor.data), type(variable.data))
    print(tensor.data.numpy())
    print(variable.data.numpy())
    print("type3: ", type(tensor.data.numpy()), type(variable.data.numpy()))
    print(tensor.numpy())
    print(variable.numpy())
    print("type4: ", type(tensor.numpy()), type(variable.numpy()))

    # Output:
    tensor([1.])
    tensor([1.], requires_grad=True)
    ('type1: ', <class 'torch.Tensor'>, <class 'torch.Tensor'>)
    tensor([1.])
    tensor([1.])
    ('type2: ', <class 'torch.Tensor'>, <class 'torch.Tensor'>)
    [1.]
    [1.]
    ('type3: ', <type 'numpy.ndarray'>, <type 'numpy.ndarray'>)
    [1.]
    Traceback (most recent call last):
    File "/home/jiahong/JupyterWorkspace/test.py", line 16, in <module>
    print(variable.numpy())
    RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.
  • The test case above shows:

    • Both Variable and Tensor report torch.Tensor when their type is checked
      • type(tensor) == type(variable) == torch.Tensor
    • Almost all operations behave identically
      • tensor.data == variable.data
      • tensor.data.numpy() == variable.data.numpy()
    • Printing the two directly gives different results
      • tensor prints without requires_grad=True
      • variable prints with requires_grad=True
    • variable cannot call variable.numpy() directly; it raises an exception
      • The message says that because the Variable requires grad (its requires_grad attribute is True), it cannot be converted directly — use var.detach().numpy() instead
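The workaround suggested by the error message can be sketched as follows (a minimal sketch, assuming PyTorch >= 0.4, where detach() returns a Tensor that shares storage with the original but is excluded from autograd):

```python
import torch

t = torch.ones(1, requires_grad=True)
# t.numpy() would raise RuntimeError here, because t requires grad.
# detach() produces a view that is cut off from the autograd graph,
# so numpy() can be called on it safely.
arr = t.detach().numpy()
print(arr)        # [1.]
print(type(arr))  # <class 'numpy.ndarray'>
```

Note that the resulting ndarray shares memory with the tensor, so in-place changes to one are visible in the other.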

requires_grad=False

Equivalent to requires_grad=0.
Not equivalent to requires_grad=None; None raises: TypeError: an integer is required

  • Test code:

    import torch
    from torch.autograd import Variable

    tensor = torch.ones(1)
    variable = Variable(tensor, requires_grad=False)
    print(tensor)
    print(variable)
    print("type1: ", type(tensor), type(variable))
    print(tensor.data)
    print(variable.data)
    print("type2: ", type(tensor.data), type(variable.data))
    print(tensor.data.numpy())
    print(variable.data.numpy())
    print("type3: ", type(tensor.data.numpy()), type(variable.data.numpy()))
    print(tensor.numpy())
    print(variable.numpy())
    print("type4: ", type(tensor.numpy()), type(variable.numpy()))

    # Output:
    tensor([1.])
    tensor([1.])
    ('type1: ', <class 'torch.Tensor'>, <class 'torch.Tensor'>)
    tensor([1.])
    tensor([1.])
    ('type2: ', <class 'torch.Tensor'>, <class 'torch.Tensor'>)
    [1.]
    [1.]
    ('type3: ', <type 'numpy.ndarray'>, <type 'numpy.ndarray'>)
    [1.]
    [1.]
    ('type4: ', <type 'numpy.ndarray'>, <type 'numpy.ndarray'>)
  • The test case above shows:

    • When requires_grad=False, the Variable fully degenerates into a plain tensor
      • Printing it shows no requires_grad=False attribute
      • variable.numpy() can be called directly
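The "degenerates into a plain tensor" claim also extends to storage: the Variable wraps (or, in PyTorch >= 0.4, simply is) the same underlying tensor, so an in-place change to one is visible through the other. A minimal sketch, assuming PyTorch >= 0.4 where Variable is a deprecated alias:

```python
import torch
from torch.autograd import Variable  # deprecated alias in PyTorch >= 0.4

tensor = torch.ones(1)
variable = Variable(tensor, requires_grad=False)

# The two share storage: mutating the tensor in place
# is reflected in the variable as well.
tensor[0] = 5.0
print(variable)
assert variable.item() == 5.0
```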

Three equivalent ways to define a Variable

The Variable-typed variable produced by each of the three definitions below is equivalent

  • requires_grad=False

    variable = Variable(tensor, requires_grad=False)
  • No requires_grad argument

    variable = Variable(tensor)
  • requires_grad=True, then variable = variable.detach()

    variable = Variable(tensor, requires_grad=True)
    variable = variable.detach()
  • All three definitions above are also equivalent to the original tensor

    • This equivalence has not been tested exhaustively, but it holds in at least the following respects:
      • The objects themselves have the same type(), namely torch.Tensor
      • The .data attribute is available, with type torch.Tensor
      • .grad is available, though it is None in every case
      • Printing the objects gives identical output, none of which includes requires_grad=True
      • The same .numpy() call works, returning numpy.ndarray
      • The same .data.numpy() call works, returning numpy.ndarray
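The equivalences listed above can be checked directly. A minimal sketch, assuming PyTorch >= 0.4 (where Variable is a deprecated alias that simply returns a Tensor):

```python
import torch
from torch.autograd import Variable  # deprecated alias in PyTorch >= 0.4

tensor = torch.ones(1)

# The three "equivalent" definitions from the text:
v1 = Variable(tensor, requires_grad=False)
v2 = Variable(tensor)
v3 = Variable(tensor, requires_grad=True).detach()

for v in (v1, v2, v3):
    assert type(v) is torch.Tensor          # same type as a plain tensor
    assert v.requires_grad is False         # none of them tracks gradients
    assert v.grad is None                   # .grad exists but is None
    assert type(v.numpy()).__name__ == "ndarray"       # .numpy() works
    assert type(v.data.numpy()).__name__ == "ndarray"  # .data.numpy() works
```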