PyTorch——各种常用函数总结

PyTorch使用笔记,函数持续更新


torch.max()

  • 注:torch.mintorch.max用法类似

单参数(取全局最大值)

  • 用法

    1
    torch.max(input) -> Tensor
    • input: 一个Tensor的对象
    • return: 返回input变量中的最大值

多参数(按维度取最大值)

  • 用法

    1
    torch.max(input, dim, keepdim=False, out=None) -> tuple[Tensor, Tensor]
    • input: 一个Tensor的对象
    • 返回是一个包含 values 和 indices 的对象,其中 values 是最大值,indices 是其索引
    • dim: 指明维度,生成的结果中,indices 用于替换第 0 维度(这对应 gather() 的检索方式)
    • keepdim: 是否保持输出张量的维度与输入张量一致(默认值为 False)
      • 如果为 True,输出张量在指定维度上的大小为 1;
      • 如果为 False,输出张量将减少一个维度
  • keepdim=False示例(包含索引使用示例):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    import torch
    x = torch.tensor([[3, 2, 1],
    [4, 5, 6]])
    # 获取每行的最大值
    values, indices = torch.max(x, dim=1) # 默认在最后一个维度,即 dim=1 上操作,按行取最大值(生成的index是指定dim=1维度的索引)
    print("values:\n", values)
    print("indices:\n", indices)
    # 用 indices 从 x 中检索出 values,方法一:gather 方法
    retrieved_values = torch.gather(x, dim=1, index=indices.unsqueeze(1)) # 用 indices 替换 dim=1 索引维度即可抽取到对应的值
    print("retrieved_values:\n", retrieved_values)
    print("retrieved_values.squeeze():\n", retrieved_values.squeeze())
    print("---")
    # 用 indices 从 x 中检索出 values,方法二:高级索引
    retrieved_values = x[torch.arange(2), indices]
    print("retrieved_values:\n", retrieved_values)

    # values:
    # tensor([3, 6])
    # indices:
    # tensor([0, 2])
    # retrieved_values:
    # tensor([[3],
    # [6]])
    # retrieved_values.squeeze():
    # tensor([3, 6])
    # ---
    # retrieved_values:
    # tensor([3, 6])
  • keepdim=True示例(包含索引使用示例):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    import torch
    x = torch.tensor([[3, 2, 1],
    [4, 5, 6]])
    # 获取每行的最大值,保留维度
    values, indices = torch.max(x, dim=1, keepdim=True) # 默认在最后一个维度,即 dim=1 上操作,按行取最大值(生成的index是指定dim=1维度的索引)
    print("values:\n", values)
    print("indices:\n", indices)
    # 用 indices 从 x 中检索出 values,方法一:gather 方法
    retrieved_values = torch.gather(x, dim=1, index=indices) # 用 indices 替换 dim=1 索引维度即可抽取到对应的值
    print("retrieved_values:\n", retrieved_values)
    print("---")
    # 用 indices 从 x 中检索出 values,方法二:高级索引
    retrieved_values = x[torch.arange(2), indices.squeeze()]
    print("retrieved_values:\n", retrieved_values)
    print("retrieved_values.unsqueeze(dim=1):\n", retrieved_values.unsqueeze(dim=1))

    # values:
    # tensor([[3],
    # [6]])
    # indices:
    # tensor([[0],
    # [2]])
    # retrieved_values:
    # tensor([[3],
    # [6]])
    # ---
    # retrieved_values:
    # tensor([3, 6])
    # retrieved_values.unsqueeze(dim=1):
    # tensor([[3],
    # [6]])

torch.backward() & torch.no_grad()

  • 用法:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    import torch
    x = torch.ones(2, 2, requires_grad=True)
    m = torch.ones(2, 2, requires_grad=True)
    with torch.no_grad(): # 这两句合起来
    z = x + 1 # 等价于z = x.detach() + 1
    y = z * 2 + m
    out = y.mean()
    out.backward()
    print(x.grad) # output None,因为梯度在z = x+1处断开了
    print(m.grad) # output tensor([[0.2500, 0.2500],
    # [0.2500, 0.2500]])
    print(y.grad) # output None,因为y只是中间操作节点,不是叶子Tensor变量(leaf Tensor)
  • 以上代码中,x不会被计算梯度,因为z = x + 1处梯度断开了,等价于z = x.detach() + 1

  • 注意,如果调用torch.backward()时,没有任何可以计算的梯度,会报错,如下面的代码:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    import torch
    x = torch.ones(2, 2, requires_grad=True)
    m = torch.ones(2, 2, requires_grad=True)
    with torch.no_grad(): # 这两句合起来
    z = x + 1 # 等价于z = x.detach() + 1
    y = z * 2 + m.detach() # 相对上面的改动点,仅这里
    out = y.mean()
    out.backward()
    print(x.grad) # output None,因为梯度在z = x+1处断开了
    print(m.grad) # output tensor([[0.2500, 0.2500],
    # [0.2500, 0.2500]])
    print(y.grad) # output None,因为y只是中间操作节点,不是叶子Tensor变量(leaf Tensor)
    • 仅修改了m.detach()out.backward()就开始报错RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn,因为没有任何梯度可以计算了
    • 其他说明:如果仅修改m = torch.ones(2, 2, requires_grad=True)m = torch.ones(2, 2, requires_grad=False),也一样会导师没有任何梯度可以计算了,执行out.backward()也会报相同错误

torch.nn.Module

  • 直接打印torch.nn.Module类对象print(model),默认情况下这个方法会调用model.__str__(),进一步调用model.__repr__(),这个函数会给出一个包含所有被定义为类属性的层的可读表示,但不会包括那些被定义为简单属性或方法的对象,这个方法可以用来查看模型包含哪些网络层
  • 举例如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    import torch
    import torch.nn as nn
    class MLPModel(nn.Module):
    def __init__(self):
    super(MLPModel, self).__init__()
    self.fc1 = nn.Linear(2, 2) # 会被打印
    self.fc2 = nn.Linear(2, 1) # 会被打印
    self.activation = torch.relu # 不会被打印

    def forward(self, x):
    x = self.activation(self.fc1(x))
    x = self.fc2(x)
    return x
    model = MLPModel()
    print(model) # 等价于print(model.__str__())和print(model.__repr__())

    # MLPModel(
    # (fc1): Linear(in_features=2, out_features=2, bias=True)
    # (fc2): Linear(in_features=2, out_features=1, bias=True)
    # )

tensor.item()

  • 用于抽取一个tensor对象中的单一数值(注意,只能是单一数值,否则会报错)
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    import torch
    # 单元素张量
    tensor = torch.tensor([5]) # 形状为 (1,)
    value = tensor.item() # 提取出 Python 的 int 值
    print(value) # 输出: 5
    print(type(value)) # 输出: <class 'int'>

    tensor = torch.tensor([1, 2, 3])
    value = tensor.item() # 报错: ValueError: only one element tensors can be converted to Python scalars

    # 强化学习中,常常用于抽取动作,方便和环境交互
    action = action_dist.sample() # 采样一个动作,action 是一个单元素张量
    action_value = action.item() # 提取出动作的标量值

torch.autograd.grad(outputs,inputs,grad_outputs)

  • 功能:outputsinputs求导
  • 关键点:
    • grad_outputs取默认值时,要求outputs必须是标量张量(一维)
    • outputs不是标量张量时,要求grad_outputsoutputs维度一致,指明outputs中每个值对最终梯度的权重,且此时不可省略该参数(否则会报错:“RuntimeError: grad can be implicitly created only for scalar outputs”)
  • Demo展示:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    import torch
    # 创建需要梯度的张量
    x = torch.tensor([2.0, 3.0], requires_grad=True)

    # 定义一个多维张量 y = [x[0]**2, x[1]**3]
    y = torch.stack([x[0]**2+x[1]**2, x[0]*3+x[1]**2]) # y 是一个形状为 (2,) 的张量

    # 计算 y 对 x 的梯度
    # 需要指定 grad_outputs 作为 y 的每个元素的权重(默认情况下,grad_outputs 是全 1 的张量)
    grads = torch.autograd.grad(outputs=y, inputs=x,grad_outputs=torch.ones_like(y))

    print("x:", x)
    print("y:", y)
    print("Gradients of y with respect to x:", grads)

    ### output:
    # x: tensor([2., 3.], requires_grad=True)
    # y: tensor([13., 15.], grad_fn=<StackBackward0>)
    # Gradients of y with respect to x: (tensor([ 7., 12.]),)
    # 其中 7 = 2 * x[0] + 3; 12 = 2 * x[1] + 2 * x[1]

torch.cat、stack、hstack、vstack

  • torch.cattorch.stack两者都是用来拼接tensor的函数,主要区别是使用 torch.cat 在现有维度上拼接,使用 torch.stack 在新维度上拼接

    • 维度变化 :
      • torch.cat 不新增维度,只在现有维度上拼接
      • torch.stack 会新增一个维度,并在该维度上拼接,torch.stack的工作需要分成两步,第一步是增加维度(比如(3,)经过dim=0会变成(1,3)),第二步是将该维度拼接
    • 形状要求 :
      • torch.cat 要求非拼接维度上的形状相同
      • torch.stack 要求所有张量的形状完全一致
  • 使用特别说明 :一般使用 torch.cattorch.stack就够了,torch.hstacktorch.vstack基本可以被torch.cattorch.stack替代,所以不常用

  • torch.cat Demo 展示:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    import torch

    a = torch.tensor([[1, 2], [3, 4]])
    b = torch.tensor([[5, 6], [7, 8]])

    # 沿第0维拼接
    c = torch.cat((a, b), dim=0)
    print(c)
    # 输出:
    # tensor([[1, 2],
    # [3, 4],
    # [5, 6],
    # [7, 8]])

    # 沿第1维拼接
    d = torch.cat((a, b), dim=1)
    print(d)
    # 输出:
    # tensor([[1, 2, 5, 6],
    # [3, 4, 7, 8]])
  • torch.stack Demo 展示:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    import torch

    a = torch.tensor([1, 2, 3])
    b = torch.tensor([4, 5, 6])

    # 在第0维新增维度并拼接: (3,) -> (1,3) -> (2,3)
    c = torch.stack((a, b), dim=0)
    print(c)
    # 输出:
    # tensor([[1, 2, 3],
    # [4, 5, 6]])

    # 在第1维新增维度并拼接: (3,) -> (3,1) -> (3,2)
    d = torch.stack((a, b), dim=1)
    print(d)
    # 输出:
    # tensor([[1, 4],
    # [2, 5],
    # [3, 6]])
    • stack的动作分两步,第一步是在指定维度增加一维,比如:
      • 上面的式子(3,)经过dim=0后变成(1,3),进一步地两个(1,3)堆叠变成(2,3)
      • 上面的式子(3,)经过dim=1后变成(3,1),进一步地两个(3,1)堆叠变成(3,2)
  • torch.hstacktorch.vstack函数:

    • torch.hstack:在水平方向(dim=1)处拼接张量,不新增维度
      • 对于 1 维张量会特殊处理,直接将 1 维张量变长(此时相当于在 dim=0处拼接)
    • torch.vstack:在垂直方向(第 0 维)拼接张量,除了原始输入为 1 维时,不新增维度
      • 对于 1 维张量会特殊处理,拼接后变成 2 维(此时相当于先在dim=0增加一维再按照这个维度拼接)
    • 不同函数在输入1维和5维向量时的Demo:
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      import torch
      tensors_1d = [torch.tensor([1, 3, 4]), torch.tensor([1, 3, 4])]
      tensors_5d = [torch.randn(1, 3, 4, 5, 6), torch.randn(1, 3, 4, 5, 6)]
      # 使用 torch.stack
      stacked_1d = torch.stack(tensors_1d)
      stacked_5d = torch.stack(tensors_5d, dim=2)
      # 使用 torch.hstack
      hstacked_1d = torch.hstack(tensors_1d)
      hstacked_5d = torch.hstack(tensors_5d)
      # 使用 torch.vstack
      vstacked_1d = torch.vstack(tensors_1d)
      vstacked_5d = torch.vstack(tensors_5d)
      # 使用 torch.cat
      cat_1d = torch.cat(tensors_1d)
      cat_5d = torch.cat(tensors_5d, dim=2)

      print("tensors_1d shape:", tensors_1d[0].shape)
      print("tensors_5d shape:", tensors_5d[0].shape)
      print("torch.stack 1D shape:", stacked_1d.shape)
      print("torch.stack 5d shape(dim=2):", stacked_5d.shape)
      print("torch.hstack 1D shape:", hstacked_1d.shape)
      print("torch.hstack 5d shape:", hstacked_5d.shape)
      print("torch.vstack 1D shape:", vstacked_1d.shape)
      print("torch.vstack 5d shape:", vstacked_5d.shape)
      print("torch.cat 1D shape:", cat_1d.shape)
      print("torch.cat 5d shape(dim=2):", cat_5d.shape)

      # tensors_1d shape: torch.Size([3])
      # tensors_5d shape: torch.Size([1, 3, 4, 5, 6])
      # torch.stack 1D shape: torch.Size([2, 3])
      # torch.stack 5d shape(dim=2): torch.Size([1, 3, 2, 4, 5, 6])
      # torch.hstack 1D shape: torch.Size([6])
      # torch.hstack 5d shape: torch.Size([1, 6, 4, 5, 6])
      # torch.vstack 1D shape: torch.Size([2, 3])
      # torch.vstack 5d shape: torch.Size([2, 3, 4, 5, 6])
      # torch.cat 1D shape: torch.Size([6])
      # torch.cat 5d shape(dim=2): torch.Size([1, 3, 8, 5, 6])

tensor.repeat()函数

  • 基本用法

    1
    tensor.repeat(*sizes)
    • *sizes:一个整数序列,表示每个维度上重复的次数
    • 返回值:一个新的张量,其形状是原张量形状的每个维度乘以对应的重复次数
  • 注意事项

    • 新维度的顺序repeat 会按照输入的顺序对每个维度进行扩展
    • 内存共享repeat 不会复制数据,而是通过视图(view)来实现重复。这意味着返回的张量与原始张量共享底层数据
    • 维度扩展规则 :如果输入的 sizes 长度大于张量的维度,则会在前面补充新的维度,比如(2,)的原始张量调用repeat(2,3)后,会先扩展成(1,2)的张量,再执行repeat(2,3),但是不建议补充这个功能
  • 函数调用Demo:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    import torch
    x = torch.tensor([1, 2])
    print("Original tensor:", x)

    # 沿第 0 维重复 3 次
    result = x.repeat(3)
    print("Repeated tensor(3):", result)
    # 添加一个新维度,并在两个维度上重复
    result = x.repeat(2, 3)
    print("Repeated tensor shape(2,3):", result.shape)
    print("Repeated tensor(2,3):\n", result)

    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    print("Original tensor:\n", x)

    # 在第 0 维重复 2 次,第 1 维重复 3 次
    result = x.repeat(2, 3)
    print("Repeated tensor:\n", result)

    # Original tensor: tensor([1, 2])
    # Repeated tensor(3): tensor([1, 2, 1, 2, 1, 2])
    # Repeated tensor shape(2,3): torch.Size([2, 6])
    # Repeated tensor(2,3):
    # tensor([[1, 2, 1, 2, 1, 2],
    # [1, 2, 1, 2, 1, 2]])
    # Original tensor:
    # tensor([[1, 2, 3],
    # [4, 5, 6]])
    # Repeated tensor:
    # tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
    # [4, 5, 6, 4, 5, 6, 4, 5, 6],
    # [1, 2, 3, 1, 2, 3, 1, 2, 3],
    # [4, 5, 6, 4, 5, 6, 4, 5, 6]])

parameters() & named_parameters()

  • model.parameters():返回模型中所有参数,不返回参数名称
  • model.named_parameters():返回模型中所有参数,同时返回参数名称
  • Demo
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    # 定义一个简单的全连接神经网络
    class SimpleNN(nn.Module):
    def __init__(self):
    super(SimpleNN, self).__init__()
    self.fc1 = nn.Linear(2, 2) # 输入层到隐藏层
    self.fc2 = nn.Linear(2, 1) # 隐藏层到输出层
    for param in self.fc2.parameters(): # 可指定设置某些参数不参与梯度更新(Frozen参数)
    param.requires_grad = False

    def forward(self, x):
    x = torch.relu(self.fc1(x)) # ReLU 激活函数
    x = self.fc2(x)
    return x

    # 创建网络实例
    model = SimpleNN()

    # 打印模型结构
    print(model)

    print("\nmodel.named_parameters():")
    # 遍历模型的所有参数并打印那些 requires_grad=True 的参数
    for name, param in model.named_parameters():
    if param.requires_grad:
    print(f"Parameter {name} requires gradient: {param}")
    else:
    print(f"Parameter {name} don't require gradient: {param}")

    print("\nmodel.parameters():")
    # 遍历模型的所有参数并打印那些 requires_grad=True 的参数
    for param in model.parameters():
    print(param.requires_grad)
    # SimpleNN(
    # (fc1): Linear(in_features=2, out_features=2, bias=True)
    # (fc2): Linear(in_features=2, out_features=1, bias=True)
    # )
    #
    # model.named_parameters():
    # Parameter fc1.weight requires gradient: Parameter containing:
    # tensor([[-0.6769, 0.7008],
    # [-0.6705, 0.2912]], requires_grad=True)
    # Parameter fc1.bias requires gradient: Parameter containing:
    # tensor([-0.6091, -0.2306], requires_grad=True)
    # Parameter fc2.weight don't require gradient: Parameter containing:
    # tensor([[ 0.5131, -0.4101]])
    # Parameter fc2.bias don't require gradient: Parameter containing:
    # tensor([-0.0788])
    #
    # model.parameters():
    # True
    # True
    # False
    # False

torch.empty()创建变量

  • torch.empty() 是 PyTorch 中用于创建一个新的张量(tensor)的函数,该张量不会被初始化,即它的元素值是未定义的。这意味着张量中的元素可能是任意值,取决于分配给张量的内存块之前的状态
  • 具体用途:
    • 快速创建张量 :当你需要一个具有特定形状的张量,但不关心初始值时,可以使用 torch.empty() 来快速创建它。这对于只需要分配内存而不必初始化数据的情况非常有用
    • 占位符张量 :在某些情况下,你可能想要创建一个张量作为占位符,稍后将用实际的数据填充它。例如,在实现算法或构建计算图时,你可能提前知道所需的空间大小,但暂时没有具体数值
    • 性能优化 :由于不需要初始化张量中的数据,torch.empty() 可以比 torch.zeros()torch.ones() 更快地分配内存,因为它跳过了设置默认值的过程
    • 显存管理 :在GPU上操作时,有时会利用 torch.empty() 来预分配显存,避免后续操作中可能出现的内存碎片化问题

torch.linspace() & torch.arange()

  • torch.linspacetorch.arange 都是 PyTorch 中用于生成数值序列的函数,但它们在生成数值的方式和使用场景上有一些关键的不同
    • torch.linspace : 返回一个一维张量,包含从 startend(包括两端点)之间等间距分布的 steps 个点
    • torch.arange : 返回一个一维张量,包含从 start 开始到 end 结束(不包括 end),以步长 step 增长的连续整数或浮点数
  • Demo
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    import torch

    # 使用 linspace
    linspace_tensor = torch.linspace(0, 2, steps=3)
    print("Linspace:", linspace_tensor)

    # 使用 arange
    arange_tensor = torch.arange(0, 10, step=2)
    print("Arange:", arange_tensor)

    # Linspace: tensor([0.0000, 0.6667, 1.3333, 2.0000])
    # Arange: tensor([0, 2, 4, 6, 8])

tensor.is_contiguous()

  • tensor.is_contiguous() 用于检查一个张量(Tensor)是否是连续的(contiguous)。在PyTorch中,张量的存储方式可以分为连续和非连续。一个张量被认为是连续的,如果它的元素在内存中是按行优先顺序(row-major order)依次存放的,也就是说,在内存中这些元素是连续排列的
  • 连续排列对于需要高效访问或操作张量数据的操作非常重要,因为许多底层实现(例如,CUDNN库中的函数)要求输入的数据必须是连续存储的。如果一个张量不是连续的,即使它包含相同的数据,其内存布局也可能导致性能下降或者某些操作无法执行
  • 当调用 tensor.is_contiguous() 方法时,如果返回值为 True,则表示该张量是连续的;如果返回 False,则表示张量当前不是以连续的方式存储的。如果你需要将一个非连续的张量转换为连续的,可以使用 .contiguous() 方法来创建一个新的、与原张量具有相同数据但存储方式为连续的张量副本
  • 示例代码如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    import torch

    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    print(x.is_contiguous()) # 输出: True

    # 假设进行了转置操作,通常会导致张量变为非连续
    x_t = x.t()
    print(x_t.is_contiguous()) # 输出: False

    # 使用 contiguous 方法将非连续张量转换为连续
    x_t_contig = x_t.contiguous()
    print(x_t_contig.is_contiguous()) # 输出: True

x.squeeze(dim) vs x.unsqueeze(dim)

  • unsqueeze:在指定维度添加一个维度,且该长度为 1
  • squeeze:去掉指定长度为 1 的维度
    • 如果指定的不是长度为 1 的维度,则不做任何修改
    • 如果不指定任何维度,则将所有长度为 1 的维度都去掉;
    • 如果不指定任何维度,且没有长度为 1 的维度,则不做任何修改
  • unsqueeze() 有等价形式,如对于二维的张量:
    • tensor.unsqueeze(dim=0) 等价于 tensor[None]tensor[None,:,:]
    • tensor.unsqueeze(dim=1) 等价于 tensor[:,None]tensor[:,None,:]
    • tensor.unsqueeze(dim=0) 等价于 tensor[:,:,None]
  • unsqueeze() 代码示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    import torch
    import numpy as np
    x = torch.tensor([[3, 2, 1],
    [4, 5, 6]], dtype=torch.float32)

    y = x[None] # 等价于 y = x[None,:,:] 和 y = x.unsqueeze(dim=0)
    print(y)
    print(y.shape)

    y = x[:,None] # 等价于 y = x[:,None,:] 和 y = x.unsqueeze(dim=1)
    print(y)
    print(y.shape)

    y = x[:,:,None] # 和 y = x.unsqueeze(dim=2)
    print(y)
    print(y.shape)

    # tensor([[[3., 2., 1.],
    # [4., 5., 6.]]])
    # torch.Size([1, 2, 3])
    # tensor([[[3., 2., 1.]],
    # [[4., 5., 6.]]])
    # torch.Size([2, 1, 3])
    # tensor([[[3.],
    # [2.],
    # [1.]],
    # [[4.],
    # [5.],
    # [6.]]])
    # torch.Size([2, 3, 1])

torch.flatten()torch.unflatten()

  • torch.flatten() 基本用法:

    1
    torch.flatten(input, start_dim=0, end_dim=-1)
    • input:需要被展平的输入张量
    • start_dim:从哪个维度开始展平,默认值为0(即从第0维开始)
    • end_dim:到哪个维度结束展平,默认值为-1(即到最后一维结束)
  • torch.unflatten() 基本用法

    1
    torch.unflatten(input, dim, unflattened_size)
    • input:需要被处理的输入张量
    • dim:需要被拆分的维度索引
    • unflattened_size:指定拆分后的维度大小,可以是元组或列表
  • 示例
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    import torch

    x = torch.tensor([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]])
    print(x.shape) # 输出: torch.Size([1, 2, 2, 2])

    # 展平整个张量
    y = torch.flatten(x)
    print(y.shape) # 输出: torch.Size([8])
    print(y) # 输出: tensor([1, 2, 3, 4, 5, 6, 7, 8])

    # 展平张量的中间两个维度
    z = torch.flatten(x, start_dim=1, end_dim=2)
    print(z.shape) # 输出: torch.Size([1,4,2])

    # 使用 unflatten 拆分维度
    w = z.unflatten(dim=1, sizes=(2,2))
    print(w) # 输出: torch.Size([1, 2, 2, 2])

tensor.numpy() vs torch.from_numpy()

  • torch.from_numpy(ndarray):将 NumPy 数组转换为张量,两者共享内存 ,修改NumPy变量会改变Tensor变量,反之亦然

  • tensor.numpy():将张量转换为 NumPy 数组(仅限 CPU 张量),两者共享内存 ,修改NumPy变量会改变Tensor变量,反之亦然

  • 代码示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    import torch
    import numpy as np

    numpy_array = np.array([[1.0, 2, 3], [4, 5, 6]])
    tensor_from_numpy = torch.from_numpy(numpy_array)
    numpy_array[0, 0] = 100
    print("\n修改后的 NumPy 数组:\n", numpy_array.dtype,numpy_array)
    print("PyTorch 张量也会同步变化:\n", tensor_from_numpy)

    tensor = torch.tensor([[7, 8, 9], [10, 11, 12]], dtype=torch.float32)
    numpy_from_tensor = tensor.numpy()
    tensor[0, 0] = 77
    print("\n修改后的 PyTorch 张量:\n", tensor)
    print("NumPy 数组也会同步变化:\n", numpy_from_tensor)


    numpy_array = np.array([[1, 2, 3], [4, 5, 6]])
    tensor_from_numpy = torch.Tensor(numpy_array)
    numpy_array[0, 0] = 100
    print("\n修改后的 NumPy 数组:\n", numpy_array)
    print("PyTorch 张量不会同步变化:\n", tensor_from_numpy)

    # 修改后的 NumPy 数组:
    # float64 [[100. 2. 3.]
    # [ 4. 5. 6.]]
    # PyTorch 张量也会同步变化:
    # tensor([[100., 2., 3.],
    # [ 4., 5., 6.]], dtype=torch.float64)
    #
    # 修改后的 PyTorch 张量:
    # tensor([[77., 8., 9.],
    # [10., 11., 12.]])
    # NumPy 数组也会同步变化:
    # [[77. 8. 9.]
    # [10. 11. 12.]]
    #
    # 修改后的 NumPy 数组:
    # [[100 2 3]
    # [ 4 5 6]]
    # PyTorch 张量不会同步变化:
    # tensor([[1., 2., 3.],
    # [4., 5., 6.]])
  • 特别注意(其他共享内存相关函数)

    • 使用 torch.Tensor(numpy_array) 得到的变量只有在 numpy_array 的类型为 np.float32 时共享,否则不共享内存的(注意:torch.Tensor() 是无法指定数据类型的,默认类型就是 torch.float32
    • 使用 torch.tensor(numpy_array) 得到的变量总是不共享内存,都是副本

inplace参数的使用

  • 在PyTorch中,inplace参数允许你指定操作是否直接在输入张量上进行而不需要额外的内存分配来存储结果。这对于减少内存使用特别有用,但需要注意的是,这也会覆盖原始数据。下面通过具体的例子说明如何使用inplace参数,以torch.nn.functional.relu为例:

示例代码

  • 代码示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    import torch
    import torch.nn.functional as F

    # 创建一个随机张量
    input_tensor = torch.randn(2, 3)
    print("Original tensor:")
    print(input_tensor)

    # 不使用inplace模式(默认)
    output_tensor = F.relu(input_tensor)
    print("\nOutput tensor (not inplace):")
    print(output_tensor)
    print("\nOriginal tensor after not inplace operation:")
    print(input_tensor) # 原始张量未被修改

    # 使用inplace模式
    F.relu(input_tensor, inplace=True)
    print("\nOriginal tensor after inplace operation:")
    print(input_tensor) # 原始张量被修改,应用了ReLU

    # Original tensor:
    # tensor([[-0.7269, -0.2997, 0.1654],
    # [ 1.2219, 1.2698, -0.5245]])
    #
    # Output tensor (not inplace):
    # tensor([[0.0000, 0.0000, 0.1654],
    # [1.2219, 1.2698, 0.0000]])
    #
    # Original tensor after not inplace operation:
    # tensor([[-0.7269, -0.2997, 0.1654],
    # [ 1.2219, 1.2698, -0.5245]])
    #
    # Original tensor after inplace operation:
    # tensor([[0.0000, 0.0000, 0.1654],
    # [1.2219, 1.2698, 0.0000]])
    • 输出解释
      • 不使用inplace:当你调用F.relu(input_tensor)时,默认情况下不会修改原始的input_tensor,而是返回一个新的张量output_tensor作为ReLU操作的结果
      • 使用inplace=True:当你指定inplace=True,如F.relu(input_tensor, inplace=True),此时ReLU操作会直接作用于input_tensor本身,即它将所有负值设置为0,并且这个修改是直接在原来的张量上进行的,不会创建新的张量。因此,在执行完这个操作后,原始的input_tensor已经被修改

关于inplace的其他说明

  • 使用inplace操作可以节省一些内存,因为不需要为输出分配新的空间
    • 它会直接修改原张量的数据,在需要保留原数据的情况下要谨慎使用
    • 并不是所有的函数都支持inplace参数,具体的支持情况可以查阅相关函数的文档说明
  • 这种方式同样适用于其他许多具有inplace选项的操作,比如激活函数、归一化等,理解并正确使用inplace可以帮助更有效地管理和优化你的模型训练过程
  • PyTorch 中,常常用以 _ 结尾的函数实现 inplace 操作

torch.device() 和 tensor.to()函数

  • 设备获取:torch.device('cuda')用于设备获取
  • tensor迁移:tensor = tensor.to(device)
  • 模型迁移:model = model.to(device)
    • 会将模型的所有可学习参数以及模型中的缓冲区移动到指定的 device 设备上
    • 所有可学习参数 :也就是模型的权重和偏置,模型的可学习参数由 model.parameters() 对象来表示(这里面包含模型的所有nn.Parameter对象)。这些参数一般是在模型的 __init__ 方法中定义的
    • 模型中的缓冲区 :即不需要反向传播更新的张量,例如 BatchNorm 层中的运行统计量,详细来说,例如 torch.nn.BatchNorm2d 层中的 running_meanrunning_var

tensor.view() 和 tensor.reshape()

  • 在 PyTorch 中,tensor.view()tensor.reshape() 都可以用来改变张量的形状,但它们之间有一些关键的区别
  • tensor.view() :
    • 它要求数据在内存中是连续存储的。如果 tensor 不是连续的(例如,经过 transpose、permute 或其他操作后),使用 .view() 会抛出错误。这时你需要先调用 .contiguous() 来确保 tensor 在内存中是连续的,然后才能使用 .view()
    • .view() 的性能可能更好,因为它直接改变了对原始数据的视图而没有复制数据
  • tensor.reshape() :
    • tensor.reshape() 等价于 tensor.contiguous().view()
    • tensor.reshape()可以在张量不是连续的情况下工作,因为它会在必要时创建张量的一个副本
      • 如果张量是连续的,不会创建副本
    • .reshape() 是自 PyTorch 0.4.0 版本引入的一个函数,旨在提供一种更一致的方式来处理形状变换。无论张量是否连续,它都可以工作
  • 一句话总结:
    • 如果确定张量是连续的 ,并且希望明确表达避免不必要的数据复制提高性能 ,那么可以选择使用 .view()
    • 然而,如果不确定张量是否连续或者你不关心额外的数据复制.reshape() 提供了一个更为方便和灵活的选择
      • 在大多数情况下,为了代码的健壮性和易读性,推荐使用 .reshape()
      • 在连续存储下, .reshape().view() 是等价的,无需担心性能

tensor.topk()函数

  • torch.topk() 用于获取张量中指定维度上的前 k 个最大/最小值及其索引,常用于分类任务中获取 top-k 的预测结果,其函数定义如下:

    1
    torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
  • 参数说明:

    • input : 输入的张量
    • k : 需要返回的最大(或最小)值的数量
    • dim : 沿着哪个维度进行操作。如果未指定,则默认在最后一个维度上操作
    • largest : 如果为 True,则返回最大的 k 个值;如果为 False,则返回最小的 k 个值。默认为 True
    • sorted : 如果为 True,则返回的结果会按照大小排序;如果为 False,则返回的结果顺序不确定。默认为 True
    • out : 可选的输出元组,用于存储结果
  • 返回值

    • values : 前 k 个最大(或最小)的值
    • indices : 这些值在输入张量中的索引
  • 用法示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    import torch
    x = torch.tensor([[3, 2, 1],
    [4, 5, 6]])
    # 获取每行的前 2 个最大值,下面的式子等价于 values, indices = torch.topk(x, k=2, dim=1)
    values, indices = torch.topk(x, k=2) # 默认在最后一个维度,即 dim=1 上操作,按行取最大值(生成的index是指定dim=1维度的索引)
    # 用 indices 从 x 中检索出 values
    retrieved_values = torch.gather(x, dim=1, index=indices) # 用 indices 替换 dim=1 索引维度即可抽取到对应的值
    print("values:\n", values)
    print("indices:\n", indices)
    print("retrieved_values:\n", retrieved_values)

    print("---")
    # 获取每列的前 1 个最大值
    values, indices = torch.topk(x, k=1, dim=0) # 在 dim=0 上操作,生成的index是指定dim=0维度的索引
    # 用 indices 从 x 中检索出 values
    retrieved_values = torch.gather(x, dim=0, index=indices)
    print("values:\n", values)
    print("indices:\n", indices)
    print("retrieved_values:\n", retrieved_values)

    # values:
    # tensor([[3, 2],
    # [6, 5]])
    # indices:
    # tensor([[0, 1],
    # [2, 1]])
    # retrieved_values:
    # tensor([[3, 2],
    # [6, 5]])
    # ---
    # values:
    # tensor([[4, 5, 6]])
    # indices:
    # tensor([[1, 1, 1]])
    # retrieved_values:
    # tensor([[4, 5, 6]])

nn.Module.register_buffer

  • 在PyTorch中,self.register_buffer 是一个用于向模块(nn.Module)注册持久化缓冲区(buffer)的方法。它的主要作用是告诉PyTorch某些张量是模型的一部分,但不属于可训练参数(即不需要梯度更新),但这些张量在模型保存或加载时需要被包含进来
  • 非可训练参数:通过 register_buffer 注册的张量不会被优化器更新(不像 nn.Parameter
    • 注:在BatchNorm中,running_meanrunning_var 是统计量,需要跟踪但不参与梯度计算,也属于这一类
  • 持久化:注册的buffer会被包含在 model.state_dict() 中,因此当调用 torch.savetorch.load 时,它们会被自动保存和加载
  • 设备移动:当调用 model.to(device) 时,这些buffer会自动移动到对应的设备(如CPU/GPU),与模型的其他参数一致
  • 获取值:被包含在 model.state_dict()
  • 常见用法:一些固定常量参数或者统计量
  • 使用示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    class MyModel(nn.Module):
    def __init__(self):
    super().__init__()
    self.weight = nn.Parameter(torch.randn(3, 3))
    # 注册一个buffer(非可训练但需要持久化)
    self.register_buffer('buffer_a', torch.zeros(3))
    model = MyModel()
    print(model.state_dict())
    # 输出包含:weight 和 buffer_a

PyTorch中的对数运算

  • 以 e 为底:x.log() == torch.log(x)
  • 以 2 为底:x.log2() == torch.log2(x)
  • 以 10 为底:x.log10() == torch.log10(x)
  • 以其他数字为底的通过对数换底公式实现:
    $$ \log_5 x = \frac{\ln x}{\ln 5} $$
    • 实现代码
      1
      2
      3
      4
      5
      import torch

      def log5(x):
      """计算以5为底的对数,输入为Tensor"""
      return torch.log(x) / torch.log(torch.tensor(5.0))

torch.cliptorch.clamp

  • 在 PyTorch 里,torch.cliptorch.clamp 功能完全一样,都能用于将输入张量的数值限制在指定的区间内,从英文含义上看,两者分别是截取和限制的含义,都差不多,但理论上 clip 会更符合原本这个函数“裁剪”的含义
  • torch.clamp 是 PyTorch 从早期就存在的传统函数,一直被广泛运用
  • torch.clip 是在版本 1.7.0 时新增的函数,目的是和 NumPy 的 np.clip 保持 API 一致,方便用户从 NumPy 迁移过来
  • 推荐使用 torch.clip,方便阅读,但为了和更早的 torch 版本兼容,还有很多人使用 clamp

torch.allclose() 函数

  • 用于比较两个张量是否在给定的误差范围内“几乎相等”

  • 基本语法

    1
    torch.allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False)
  • 特别注意:

  • otherinput 形状必须相同

  • equal_nan:若为True,则NaN值被视为相等;若为False,则NaN会导致比较失败

  • 若所有元素的差异都在容差范围内,则返回True;否则返回False。对于每个元素,比较条件为:

    1
    |input - other| ≤ atol + rtol * |other|
  • 容易混淆的相关方法的比较(以下两个方法都是逐个元素比较,返回逐元素结果的)

    • torch.isclose()逐元素比较两个张量是否接近,返回与输入形状相同的布尔张量(注意不是只返回一个值,是逐元素比较结果)

      1
      torch.isclose(a, b, rtol=1e-5, atol=1e-8, equal_nan=False)
    • torch.eq()严格逐元素相等比较(不考虑容差),返回与输入形状相同的布尔张量


torch.chuck 函数用法

  • torch.chunk 是 PyTorch 中用于将张量按照指定维度拆分成多个子张量的函数,返回一个包含拆分后子张量的元组

  • torch.chunk 拆分后的数据与原张量共享内存(浅拷贝),修改子张量会影响原张量

  • 函数定义为:

    1
    torch.chunk(input, chunks, dim=0)
    • input:待拆分的输入张量(torch.Tensor
    • chunks:拆分的数量(int)。需注意:若输入张量在 dim 维度的大小不能被 chunks 整除,最后一个子张量的大小会略小(其余子张量大小相等)
    • dim:指定拆分的维度(int,默认值为 0)
    • 返回:一个元组(tuple),包含 chunks 个子张量(或最后一个子张量略小)
  • torch.chunktorch.split 的区别:

    • torch.chunk 按“数量”拆分(chunks 参数),子张量大小尽可能平均;
    • torch.split 按“指定大小”拆分(如 split_size_or_sections=2 表示每个子张量大小为 2)

torch.chunk 的一些示例

  • 基本用法(1D 张量)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    import torch

    # 能整除的情况
    x = torch.tensor([1, 2, 3, 4, 5, 6])
    chunks = torch.chunk(x, chunks=3, dim=0) # 沿第0维拆分成3份
    print(chunks)
    # 输出:(tensor([1, 2]), tensor([3, 4]), tensor([5, 6]))

    # 不能整除的情况
    x = torch.tensor([1, 2, 3, 4, 5])
    chunks = torch.chunk(x, chunks=2, dim=0) # 5不能被2整除,最后一个子张量多1个元素
    print(chunks)
    # 输出:(tensor([1, 2]), tensor([3, 4, 5]))
  • 高维张量拆分(2D 张量)示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    x = torch.arange(12).reshape(3, 4)  # 形状为 (3, 4) 的矩阵
    print("原始张量:\n", x)
    # 原始张量:
    # tensor([[ 0, 1, 2, 3],
    # [ 4, 5, 6, 7],
    # [ 8, 9, 10, 11]])

    # 沿 dim=0(行维度)拆分成2份
    chunks_dim0 = torch.chunk(x, chunks=2, dim=0)
    print("沿行拆分:\n", chunks_dim0)
    # 沿行拆分:
    # (tensor([[0, 1, 2, 3],
    # [4, 5, 6, 7]]),
    # tensor([[ 8, 9, 10, 11]]))

    # 沿 dim=1(列维度)拆分成2份
    chunks_dim1 = torch.chunk(x, chunks=2, dim=1)
    print("沿列拆分:\n", chunks_dim1)
    # 沿列拆分:
    # (tensor([[0, 1],
    # [4, 5],
    # [8, 9]]),
    # tensor([[ 2, 3],
    # [ 6, 7],
    # [10, 11]]))

torch.where 函数和 tensor.where 函数

  • torch.where() 是 PyTorch 中用于基于条件对张量元素进行选择性替换的方法,类似于“三目运算符”的向量版,语法如下:

    1
    torch.where(condition, x, y)
    • condition:布尔型张量(与原张量同形状),用于判断每个元素是否满足条件
    • x:当 conditionTrue 时,保留或使用 x 的值(可与原张量同形状,或为标量)
    • y:当 conditionFalse 时,使用 y 的值(可与原张量同形状,或为标量)
    • 返回值:一个新张量,每个元素根据 conditionxy 中取值
  • 注:x.where(condition, y) 是实例方法,等价于全局函数 torch.where(condition, x, y)

tensor.where(condition, y)(实例方法)

  • 张量对象的实例方法,语法为:

    1
    result = tensor.where(condition, y)
    • condition(必选):布尔型张量(torch.BoolTensor),形状必须与 tensor 相同(或可广播为相同形状)
      • 用于判断每个元素是否满足条件,决定最终取值来源
    • y(必选): 张量(与 tensor 同数据类型)或标量(如 intfloat
      • 若为张量,形状必须与 tensor 相同(或可广播为相同形状);若为标量,会自动广播到 tensor 的形状
        • conditionFalse 时,使用 y 的值(或对应位置的元素)替换 tensor 中的元素

torch.where(condition, x, y)(全局函数)

  • PyTorch 的全局函数,语法为:

    1
    result = torch.where(condition, x, y)
    • condition(必选):布尔型张量(torch.BoolTensor),形状必须与 xy 相同(或可广播为相同形状)
      • 用于判断每个元素是否满足条件,决定从 xy 中取值
    • x(必选):张量(与 y 同数据类型)或标量
      • 若为张量,形状必须与 condition 相同(或可广播为相同形状);若为标量,会自动广播到对应形状
      • conditionTrue 时,使用 x 的值(或对应位置的元素)
    • y(必选):张量(与 x 同数据类型)或标量
      • 若为张量,形状必须与 condition 相同(或可广播为相同形状);若为标量,会自动广播到对应形状
      • conditionFalse 时,使用 y 的值(或对应位置的元素)

示例

  • 代码示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    import torch

    tensor = torch.tensor([1, 5, 3, 8, 2, 7])
    condition = tensor % 2 == 0
    y = torch.tensor([-1, -1, -1, -1, -1, -1])

    # 方法1:tensor.where(condition, y) —— 实例方法
    result_instance = tensor.where(condition, y)

    # 等价方法:torch.where(condition, tensor, y) —— 函数
    result_function = torch.where(condition, tensor, y)

    print("tensor.where() 结果:", result_instance) # tensor([-1, -1, -1, 8, 2, -1])
    print("torch.where() 结果:", result_function) # tensor([-1, -1, -1, 8, 2, -1])

torch.size() 函数 和 tensor.size(n)

tensor.size() 函数

  • 返回张量的完整形状 ,描述张量在每个维度上的元素个数
  • 返回值 torch.Size 可直接当作 tuple 使用(支持索引、len() 等),例如 len(t2.size()) 会返回张量的维度数(2)
  • 等价写法:tensor.shape(属性,功能与 size() 完全一致,更简洁),例如 t2.shapet2.size() 结果相同

tensor.size(n) 函数

  • 返回张量第 n 维的元素个数(维度索引从 0 开始)

  • 用法:

    1
    tensor.size(dim)  # dim:指定维度的索引(0 表示第1维,1 表示第2维,以此类推)
    • tensor.size(1) 等价于 tensor.size()[1]tensor.shape[1]
  • 维度索引从 0 开始:

    • size(0) 是第 1 维
    • size(1) 是第 2 维
    • size(-1) 表示最后一维(常用技巧)
  • 若指定的维度索引超出张量的维度范围 会抛出 IndexError

    • 例如 1D 张量用 size(1) 会抛异常

对比 torch.size() 函数 和 tensor.size(n)

  • size():快速查看张量整体形状(例如确认输入数据是否符合模型要求)
  • size(1):常用于提取矩阵的列数、文本张量的词向量维度等(例如 batch_size, seq_len = tensor.size(0), tensor.size(1)
  • 使用示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    import torch

    # 2D 张量(矩阵)整体输出
    t2 = torch.tensor([[1, 2], [3, 4], [5, 6]])
    print(t2.size()) # 输出: torch.Size([3, 2]),2维,3行2列

    # 1D 张量:没有第1维(索引1超出范围),会报错
    t1 = torch.tensor([1, 2, 3])
    # print(t1.size(1)) # 报错:IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

    # 2D 张量:第1维(索引1)是列数
    print(t2.size(1)) # 输出: 2,对应 shape [3, 2] 的第2个元素(列数)
    print(t2.size()[1]) # 等价写法,输出: 2

torch.roll() 函数

  • torch.roll 用于循环移位(滚动)张量元素 ,支持沿指定维度对张量元素进行循环平移,移位后超出边界的元素会从另一侧补回(类似“循环队列”的逻辑)

  • 函数签名为:

    1
    torch.roll(input, shifts, dims=None) -> Tensor
    • input:Tensor,输入张量(任意维度) |
    • shifts:int / 序列,移位步数:
      • 正数:沿维度从后向前移(向右/向下);
      • 负数:沿维度从前向后移(向左/向上);
      • 序列形式(如 (2, -1)):需与 dims 一一对应,为每个维度指定独立移位步数
    • dims:int / 序列 / None,移位维度:
      • None(默认):先将张量展平为 1D 再移位,最后恢复原形状;
      • 单个 int:仅沿该维度移位;
      • 序列形式(如 (0, 1)):需与 shifts 长度一致,对多个维度依次移位
    • 返回值:与 input 形状、数据类型完全相同的新张量(原张量不改变,除非用 in-place 版本 torch.roll_

使用示例

  • 1D 张量示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    import torch

    x = torch.tensor([0, 1, 2, 3, 4])

    # 向右移位 2 步(正数=向后移)
    roll1 = torch.roll(x, shifts=2)
    print(roll1) # tensor([3, 4, 0, 1, 2])

    # 向左移位 1 步(负数=向前移,等价于 shifts=len(x)-1)
    roll2 = torch.roll(x, shifts=-1)
    print(roll2) # tensor([1, 2, 3, 4, 0])
  • 2D 张量示例(注:dim=0(行方向,上下移)、dim=1(列方向,左右移))

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    # 仅单维度移位示例
    x = torch.tensor([[1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]])

    # 沿 dim=0(行)向下移 1 步(最后一行绕回开头)
    roll_row = torch.roll(x, shifts=1, dims=0)
    print("行移位:")
    print(roll_row)
    # tensor([[7, 8, 9],
    # [1, 2, 3],
    # [4, 5, 6]])

    # 沿 dim=1(列)向左移 1 步(第一列绕回末尾)
    roll_col = torch.roll(x, shifts=-1, dims=1)
    print("列移位:")
    print(roll_col)
    # tensor([[2, 3, 1],
    # [5, 6, 4],
    # [8, 9, 7]])

    # 多维度同时移位:需保证 `shifts` 和 `dims` 长度一致,分别对应每个维度的移位步数
    x = torch.tensor([[1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]])

    # 沿 dim=0 向下移 1 步,沿 dim=1 向右移 2 步
    roll_multi = torch.roll(x, shifts=(1, 2), dims=(0, 1))
    print(roll_multi)
    # 步骤解析:
    # 1. dim=0 移位后:[[7,8,9],[1,2,3],[4,5,6]]
    # 2. 再对 dim=1 移位 2 步:[[8,9,7],[2,3,1],[5,6,4]]
    # 最终结果:
    # tensor([[8, 9, 7],
    # [2, 3, 1],
    # [5, 6, 4]])
  • dims=None(展平后移位,再恢复原来 shape):默认将张量展平为 1D 移位,再恢复原形状,慎用(可能导致元素顺序混乱)

    1
    2
    3
    4
    5
    x = torch.tensor([[1, 2], [3, 4]])

    # 展平为 [1,2,3,4],移位 1 步 -> [4,1,2,3],再恢复为 (2,2)
    roll_flat = torch.roll(x, shifts=1, dims=None)
    print(roll_flat) # tensor([[4, 1], [2, 3]])

使用注意事项

  • 移位步数可以超界(但不建议) :步数会自动对维度长度取模(如长度为 5 的维度,移位 7 步等价于移位 7%5=2 步)

    1
    2
    x = torch.tensor([0,1,2,3,4])
    print(torch.roll(x, shifts=7)) # 等价于 shifts=2 -> tensor([3,4,0,1,2])
  • 原张量不改变torch.roll 是“非 in-place”操作,如需修改原张量,使用 torch.roll_(末尾加下划线)

    1
    2
    3
    x = torch.tensor([0,1,2])
    x.roll_(shifts=1) # 原张量被修改
    print(x) # tensor([2, 0, 1])
  • 梯度传播 :支持自动求导(梯度会跟随移位逻辑反向传播),可用于神经网络层中


torch.quantile() 函数

  • torch.quantile 是用于计算张量分位数的函数,支持多维张量、指定维度计算、线性插值等功能,适用于统计分析、异常值检测等场景

  • torch.quantile 支持多维张量、批量分位数、多种插值模式,核心是通过 dim 控制计算维度、q 指定分位数、interpolation 调整插值逻辑

  • torch.quantile 函数签名

    1
    2
    3
    4
    5
    6
    7
    8
    torch.quantile(
    input: Tensor,
    q: Union[float, Tensor],
    dim: Optional[Union[int, Tuple[int, ...]]] = None,
    keepdim: bool = False,
    interpolation: str = 'linear',
    out: Optional[Tensor] = None
    ) -> Tensor
    • inputTensor
      • 输入张量(支持任意维度,如 1D/2D/3D 等)
    • qfloatTensor
      • 分位数(范围:[0, 1]
      • 可以是单个值(如 0.5 表示中位数)或张量(批量计算多个分位数)
    • dimintTuple[int, ...]
      • 计算分位数的维度(可选)。默认 None 表示对整个张量展平后计算
    • keepdimbool
      • 是否保留计算维度(默认 False,即压缩维度;True 则维度数不变)
    • interpolationstr
      • 分位数插值方式(默认 'linear'),支持 5 种模式(下文详细说明)
        • 'linear'(默认),线性插值:q = i + f 时,结果 = (1-f)*x[i] + f*x[i+1](最常用)
        • 'lower',取下界:结果 = x[floor(i + f)](即小于等于目标位置的最大元素)
        • 'higher',取上界:结果 = x[ceil(i + f)](即大于等于目标位置的最小元素)
        • 'nearest',取最近邻:结果 = x[round(i + f)](四舍五入到最近索引)
        • 'midpoint',中点插值:结果 = (x[floor(i + f)] + x[ceil(i + f)]) / 2(上下界平均)

使用示例

  • 1D 张量(展平计算),计算单个分位数(中位数)和多个分位数:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    import torch

    x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

    # 1. 计算中位数(0.5 分位数)
    median = torch.quantile(x, q=0.5)
    print("中位数:", median) # 输出:tensor(3.)

    # 2. 计算多个分位数(0.25、0.5、0.75 四分位数)
    q_list = torch.tensor([0.25, 0.5, 0.75])
    quantiles = torch.quantile(x, q=q_list)
    print("四分位数:", quantiles) # 输出:tensor([2., 3., 4.])
  • 多维张量(指定维度计算):对 2D 张量的行/列计算分位数,控制 dimkeepdim

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32)

    # 1. 按列(dim=0)计算中位数,保留维度(keepdim=True)
    col_median = torch.quantile(x, q=0.5, dim=0, keepdim=True)
    print("按列中位数(保留维度):\n", col_median)
    # 输出:tensor([[4., 5., 6.]])(形状:[1, 3])

    # 2. 按行(dim=1)计算多个分位数(0.25、0.75),不保留维度
    row_quantiles = torch.quantile(x, q=[0.25, 0.75], dim=1)
    print("按行四分位数(不保留维度):\n", row_quantiles)
    # 输出:tensor([[2., 5., 8.], [3., 6., 9.]])(形状:[2, 3])

注意事项

  • input 需为浮点型(float32/float64),整数型张量会自动转换为浮点型
  • q 必须在 [0, 1] 内,否则会抛出 ValueError
  • torch.quantile 的参数(如 interpolation)与 numpy.quantile 基本一致,可无缝迁移

torch.nn.functional.pad()(简称 F.pad()

  • F.pad() 是用于张量填充的核心函数,支持任意维度的对称/非对称填充、多种填充模式(常数、反射、复制等),

  • F.pad() 函数签名

    1
    2
    3
    4
    5
    6
    7
    8
    9
    import torch.nn.functional as F

    F.pad(
    input: torch.Tensor,
    pad: Sequence[int],
    mode: str = 'constant',
    value: float = 0.0,
    **kwargs
    ) -> torch.Tensor
    • inputtorch.Tensor
      • 输入张量(支持任意维度:1D/2D/3D/4D/5D 等)
    • padSequence[int]
      • 填充尺寸配置(关键参数),格式为 (left0, right0, left1, right1, left2, right2, ...),对应张量从最后一维到第一维的左右填充数
      • left0, right0 表示 dim=-1 的维度(最后一维)的左边和右边分别 pad 数量配置
      • 对于 N 维张量,pad 需包含 2*N 个整数(若某维度无需填充,填 0)
      • 每个维度的填充格式为 (left_pad, right_pad),即 左边填充数,右边填充数
      • 示例:
        • 1D 张量(shape [L]):pad=(left, right) ,填充后 shape [left + L + right]
        • 2D 张量(shape [H, W]):pad=(left_w, right_w, left_h, right_h),填充后 shape [left_h + H + right_h, left_w + W + right_w]
        • 3D 张量(shape [D, H, W]):pad=(left_w, right_w, left_h, right_h, left_d, right_d),填充后 shape [left_d + D + right_d, left_h + H + right_h, left_w + W + right_w]
    • modestr
      • 填充模式(默认 'constant'),支持 6 种模式
        • 'constant'(默认):用固定值 value 填充(如 0 填充)
        • 'reflect',反射填充:以张量边缘为对称轴反射(不包含边缘本身);注意,此时张量边缘是对称轴
        • 'replicate',复制填充:用张量边缘的元素填充(重复边缘值)
        • 'circular',循环填充:将张量视为循环结构,用对侧的元素填充(环绕填充)
        • 'edge',等价于 'replicate'(兼容旧版本 PyTorch)
        • 'symmetric',对称填充:以张量边缘为对称轴反射(包含边缘本身)
    • valuefloat
      • mode='constant' 时有效,指定填充的常数(默认 0)
    • **kwargs
      • 额外参数(如 reflect/replicate 模式下的 padding_mode 兼容参数,较少用)

不同模式效果对比

  • 以 2D 张量 [[1,2],[3,4]] 为例,左/右/上/下各填充 1 个元素,对比核心模式:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    import torch
    import torch.nn.functional as F

    x = torch.tensor([[1,2],[3,4]], dtype=torch.float32)
    pad = (1,1,1,1) # (left_w=1, right_w=1, left_h=1, right_h=1)

    # 不同模式填充结果简单对比
    print("constant (value=0):")
    print(F.pad(x, pad, mode='constant', value=0))
    # 输出:
    # tensor([[0., 0., 0., 0.],
    # [0., 1., 2., 0.],
    # [0., 3., 4., 0.],
    # [0., 0., 0., 0.]])

    print("\nreflect:")
    print(F.pad(x, pad, mode='reflect'))
    # 输出(反射不包含边缘):
    # tensor([[4., 3., 4., 3.],
    # [2., 1., 2., 1.],
    # [4., 3., 4., 3.],
    # [2., 1., 2., 1.]])

    print("\nreplicate:")
    print(F.pad(x, pad, mode='replicate'))
    # 输出(复制边缘):
    # tensor([[1., 1., 2., 2.],
    # [1., 1., 2., 2.],
    # [3., 3., 4., 4.],
    # [3., 3., 4., 4.]])

注意事项

  • pad 维度顺序 :必须是「最后一维 到 第一维」的左右填充,容易混淆(如 2D 张量先 W 后 H),填错会导致尺寸异常
  • F.pad 是可微分操作,填充的常数部分梯度为 0,反射/复制部分梯度会反向传播到原张量边缘元素
  • 'constant'/'replicate' 模式效率最高,'reflect'/'symmetric' 稍慢,'circular' 因循环逻辑效率较低(大张量建议提前优化)
  • nn.ZeroPad2d 等层的区别nn.ZeroPad2d/nn.ReflectionPad2d 是封装好的层(仅支持特定维度),F.pad 更灵活(支持任意维度和模式),功能完全覆盖前者

torch.unbind 函数

  • torch.unbind 是 PyTorch 中用于拆分张量维度的核心函数,作用是将一个张量沿着指定维度(dim)“解绑”(拆分)为多个独立的张量,返回这些张量的元组

  • 简单理解:假设有一个 shape 为 (batch_size, seq_len, hidden_dim) 的张量

    • 沿着 dim=0(batch 维度)unbind 后,会得到 batch_size 个 shape 为 (seq_len, hidden_dim) 的张量;
    • 沿着 dim=1(seq_len 维度)unbind 后,会得到 seq_len 个 shape 为 (batch_size, hidden_dim) 的张量
  • 函数签名与核心参数

    1
    torch.unbind(input, dim=0) -> tuple[Tensor, ...]
    • input:待拆分的输入张量(任意维度)
    • dim:指定拆分的维度(默认 0,即第一个维度)
    • 返回值:拆分后的张量组成的元组,元组长度 = 原张量在 dim 维度上的大小

tensor.flip(dims) 函数

  • tensor.flip(dims) 用于沿着指定维度翻转张量(reverse the order of elements),仅改变指定维度上元素的顺序,不改变张量的形状、数据类型和设备

    • dims 是整数或整数元组,指定要翻转的维度(0=第1维,1=第2维,以此类推)
    • 沿 dims 维度反转元素顺序,非指定维度保持不变
    • 张量的 shapedtypedevicerequires_grad 均不改变
    • 支持自动微分(梯度会正确传播)
  • 注:flip 还有个实现是参数可以是 可变位置参数(即展开的列表/元组)

    1
    2
    3
    4
    @overload
    def flip(self, dims: _size) -> Tensor: ...
    @overload
    def flip(self, *dims: _int) -> Tensor: ... # 可变位置参数
  • 使用示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    import torch

    # 1D 整数张量
    a = torch.tensor([1, 2, 3, 4])
    a_flip = a.flip(dims=[0]) # 或 a.flip(0),沿维度0翻转(唯一维度),等价于 `tensor[::-1]`
    print("1D flip结果:", a_flip) # 输出: tensor([4, 3, 2, 1])
    print("形状不变:", a_flip.shape == a.shape) # 输出: True

    # 2D 浮点数张量,可指定沿行(维度0)、列(维度1)或同时沿行列翻转
    b = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

    b_flip_0 = b.flip(dims=[0]) # 沿行翻转(上下颠倒) 等价于 b.flip(0)
    print("沿行翻转:\n", b_flip_0)
    # 输出:
    # tensor([[5., 6.],
    # [3., 4.],
    # [1., 2.]])

    b_flip_1 = b.flip(dims=[1]) # 沿列翻转(左右颠倒)
    print("沿列翻转:\n", b_flip_1)
    # 输出:
    # tensor([[2., 1.],
    # [4., 3.],
    # [6., 5.]])

    b_flip_01 = b.flip(dims=[0, 1]) # 同时沿行和列翻转 等价于 b.flip(0,1)
    print("沿行列翻转:\n", b_flip_01)
    # 输出:
    # tensor([[6., 5.],
    # [4., 3.],
    # [2., 1.]])

    # 3D 张量(模拟2个2×2的图像),可指定任意维度翻转(如batch维度、高度维度、宽度维度)
    c = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) # shape: (2, 2, 2)

    c_flip_2 = c.flip(dims=[2]) # 沿宽度维度(第3维)翻转
    print("3D沿宽度翻转:\n", c_flip_2)
    # 输出:
    # tensor([[[2, 1],
    # [4, 3]],
    # [[6, 5],
    # [8, 7]]])

    c_flip_0 = c.flip(dims=[0]) # 沿batch维度翻转
    print("3D沿batch翻转:\n", c_flip_0)
    # 输出:
    # tensor([[[5, 6],
    # [7, 8]],
    # [[1, 2],
    # [3, 4]]])

torch.nn.utils.rnn.pad_sequence 函数

  • torch.nn.utils.rnn.pad_sequence 是 PyTorch 中处理变长序列(Variable-length Sequences) 的工具,用于将一批长度不同的张量(序列)填充到相同长度,以便批量输入 RNN、Transformer 等模型

  • 核心功能是将一个 张量列表(每个张量对应一条变长序列)填充为一个 统一长度的二维/高维张量 ,填充值默认是 0(可自定义);其填充规则为:

    • 短序列在 末尾(右侧) 补零(默认,可通过 batch_first 调整返回值维度顺序);
    • 最终长度由列表中 最长序列的长度 决定
  • 函数签名

    1
    2
    3
    4
    5
    torch.nn.utils.rnn.pad_sequence(
    sequences, # 变长序列的列表
    batch_first=False, # 输出张量的维度顺序:是否为 (batch_size, seq_len, ...)
    padding_value=0.0 # 填充值(默认0)
    ) -> Tensor
    • sequencesList[Tensor]
      • 必须是张量列表 ,每个张量的形状需满足:
        • 1D 张量:(seq_len_i,)(如单特征序列);
        • 2D 张量:(seq_len_i, feature_dim)(如带特征维度的序列);
        • 高维张量:(seq_len_i, d1, d2, ...)(如序列+多特征维度);
        • 注意 :所有张量的除了 第一个 维度(要求第一个维度为待对齐的序列长度)外,其余维度必须一致
    • batch_firstbool
      • 控制输出张量的维度顺序:
        • False(默认):输出形状 (max_seq_len, batch_size, ...)(适配 PyTorch RNN 层默认输入格式);
        • True:输出形状 (batch_size, max_seq_len, ...)(更直观,适合自定义模型)
    • padding_valuefloat
      • 虽然写的是 float,但可以填充整数,默认 0 或 0.0,(类型需与输入张量 dtype 兼容)

用法示例

  • 1D 变长序列(单特征)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    import torch
    from torch.nn.utils.rnn import pad_sequence

    # 生成3条长度不同的1D序列(seq_len分别为2、3、1)
    seq1 = torch.tensor([1, 2]) # shape: (2,)
    seq2 = torch.tensor([3, 4, 5]) # shape: (3,)
    seq3 = torch.tensor([6]) # shape: (1,)
    sequences = [seq1, seq2, seq3]

    # 1. 默认参数(batch_first=False,padding_value=0)
    padded = pad_sequence(sequences)
    print("默认输出 shape:", padded.shape) # (max_seq_len=3, batch_size=3)
    print("默认输出:\n", padded)
    # 输出:
    # tensor([[1, 3, 6],
    # [2, 4, 0],
    # [0, 5, 0]])
  • 2D 变长序列(带特征维度):适用于每条序列的每个元素是一个特征向量(如词嵌入后的序列):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    # 2D序列:shape (seq_len_i, feature_dim),feature_dim=2(统一)
    seq1 = torch.tensor([[1, 10], [2, 20]]) # (2, 2)
    seq2 = torch.tensor([[3, 30], [4, 40], [5, 50]]) # (3, 2)
    seq3 = torch.tensor([[6, 60]]) # (1, 2)
    sequences = [seq1, seq2, seq3]

    padded = pad_sequence(sequences, batch_first=True)
    print("2D序列填充后 shape:", padded.shape) # (3, 3, 2)(batch_size=3, max_seq_len=3, feature_dim=2)
    print("2D序列填充后:\n", padded)
    # 输出:
    # tensor([[[ 1, 10],
    # [ 2, 20],
    # [ 0, 0]], # 短序列补零(特征维度均补0)
    #
    # [[ 3, 30],
    # [ 4, 40],
    # [ 5, 50]],
    #
    # [[ 6, 60],
    # [ 0, 0],
    # [ 0, 0]]])

torch.randperm 函数

  • torch.randperm 函数 的原型为:

    1
    torch.randperm(n, *, generator=None, out=None, dtype=torch.int64, layout=torch.strided, device=None, requires_grad=False) -> Tensor
  • torch.randperm 函数用于生成 0 到 n-1 的随机排列

    • 返回一个长度为 n、包含不重复整数1D 张量 ,常用于数据打乱(如训练集 shuffle、批次采样)
    • 注意返回的是索引 indices,并不改变按原始数据

torch.argsort 函数

  • torch.argsort 用于返回张量元素排序后索引的函数

    • 核心作用是:不改变原张量,仅返回一个索引张量,该索引对应将原张量按指定规则排序后各元素的原始位置
  • torch.argsort 函数原型为:

    1
    torch.argsort(input, dim=-1, descending=False, stable=False, *, out=None)
    • input:输入张量(任意维度,如 1D、2D、3D 等)
    • dim :指定排序的维度(负号表示倒数维度,如 dim=-1 表示最后一维);默认 -1
    • descending:是否降序排序(True 降序,False 升序);默认 False
    • stable:是否使用稳定排序(相同元素保留原始顺序,仅对 CPU 有效,GPU 不支持);默认 False
    • out:输出张量(可选,用于指定结果存储位置);默认 None
    • 返回值:一个与 input 形状完全相同的整数张量 ,元素为原张量排序后的索引
  • 1D 张量示例:对一维张量排序,返回排序后元素的原始索引(升序/降序)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    import torch

    # 1D 输入张量
    x = torch.tensor([3, 1, 4, 1, 5])

    # 升序排序(默认 descending=False)
    idx_asc = torch.argsort(x)
    print("升序索引:", idx_asc) # 输出:tensor([1, 3, 0, 2, 4])
    # 验证:x[idx_asc] 即为排序后的张量
    print("升序结果:", x[idx_asc]) # 输出:tensor([1, 1, 3, 4, 5])

    # 降序排序(descending=True)
    idx_desc = torch.argsort(x, descending=True)
    print("降序索引:", idx_desc) # 输出:tensor([4, 2, 0, 3, 1])
    print("降序结果:", x[idx_desc]) # 输出:tensor([5, 4, 3, 1, 1])

结合 torch.randperm 函数 和 torch.argsort 函数 实现打乱再恢复

  • 举个直观例子:

    • 原索引:[0,1,2,3,4](对应 x 的位置)
    • randperm(5) 生成 shuffle_idx = [3,0,4,1,2](随机打乱)
    • x_shuffled = x[3], x[0], x[4], x[1], x[2]
    • argsort(shuffle_idx) 计算:
      • shuffle_idx 升序排序为 [0,1,2,3,4],其元素对应在 shuffle_idx 中的位置是 [1,3,4,0,2](存储为 restore_idx
    • 验证:x_shuffled[restore_idx] = x[0],x[1],x[2],x[3],x[4](完全恢复原顺序)
  • 通过 torch.randperm 打乱 + torch.argsort(shuffle_idx) 恢复的核心是 “索引映射的可逆性”

    • 打乱:用随机索引映射到乱序;
    • 恢复:用 argsort 还原该映射,与张量值无关,高效且稳定
  • 1D 张量示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    import torch

    # 1. 生成原张量(长度为 5)
    x = torch.tensor([10, 20, 30, 40, 50]) # 原张量:[10,20,30,40,50]
    n = len(x)

    # 2. 用 randperm 打乱顺序
    shuffle_idx = torch.randperm(n) # 生成 0~4 的随机排列(例:[3,0,4,1,2])
    x_shuffled = x[shuffle_idx] # 打乱后的张量:x[3],x[0],x[4],x[1],x[2] -> [40,10,50,20,30]
    print("打乱索引 shuffle_idx:", shuffle_idx)
    print("打乱后 x_shuffled:", x_shuffled)

    # 3. 用 argsort 恢复顺序(核心:对 shuffle_idx 做 argsort)
    restore_idx = torch.argsort(shuffle_idx) # 恢复索引:[1,3,4,0,2](示例)
    x_restored = x_shuffled[restore_idx] # 恢复原张量
    print("恢复索引 restore_idx:", restore_idx)
    print("恢复后 x_restored:", x_restored)
    print("是否完全恢复:", torch.equal(x, x_restored)) # 输出 True
  • 注意(对 index 做排序的优点):

    • 如果原张量有重复值,直接对 x_shuffledargsort 可能因值排序导致恢复失败,但对 shuffle_idxargsort 不受值影响(仅依赖索引映射)

len(dataloader) 函数

  • len(dataloader) 得到的是 训练数据加载器(DataLoader)的批次数(batch count),即整个训练数据集被划分为多少个批次(batch)
  • 注:len(dataloader) 可以在任意时刻使用,不会影响 dataloader 数据集的 迭代操作
  • 示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # 构造简单数据集和dataloader
    dataset = TensorDataset(torch.randn(1000, 10)) # 1000个样本
    train_dataloader = DataLoader(dataset, batch_size=32, drop_last=False)

    # 定义get_batch(基于dataloader迭代器)
    def get_batch(loader):
    loader_iter = iter(loader) # 每次调用都重新初始化迭代器(关键)
    return next(loader_iter)

    # 第一步:调用len()
    print(len(train_dataloader)) # 输出32(仅计算批次数,不碰数据)

    # 第二步:正常get_batch
    batch = get_batch(train_dataloader)
    print(batch[0].shape) # 输出torch.Size([32, 10])(成功获取批次,无任何影响)

tensor.expand()tensor.expand_as() 函数

  • 核心:Expand 是“视图扩展”,不是“数据复制”
  • expand()expand_as() 的核心作用是将张量在维度上进行扩展(广播) ,但它并不会复制新的数据,而是返回原张量的一个“视图(view)”
    • 扩展后的张量和原张量共享内存(修改原张量,扩展后的张量也会变;)
    • 只有在维度大小为 1 的位置才能被扩展

tensor.expand() 用法

  • 基本语法

    1
    tensor.expand(*sizes)
    • *sizes:传入想要扩展后的张量形状(元组/多个整数)
    • 规则:
      • 对于原张量中大小为 1 的维度,可以扩展为任意正整数;
      • 对于原张量中大小大于 1 的维度,必须和原维度大小一致(不能改);
      • 可以用 -1 表示“保持原维度大小不变”;
      • 支持“维度扩展”(比如从 2D 扩展为 3D),只需在前面/后面加 1 再扩展
  • 代码示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    import torch

    # 扩展维度(可新增维度)
    x = torch.tensor([1, 2, 3]) # shape: [3]
    x_expand = x.expand(5, 3) # 扩展为 [5, 3]
    print("原张量x:", x)
    print("扩展后x_expand:", x_expand) # 复制 5 行
    print("x_expand形状:", x_expand.shape) # torch.Size([5, 3])

    # 用-1保持原维度
    y = torch.tensor([[1], [2], [3]]) # shape: [3, 1]
    y_expand = y.expand(3, 4) # 扩展第2维为4,shape: [3,4]
    y_expand2 = y.expand(-1, 4) # -1 等价于保持3不变,结果和上面一致
    print("\ny_expand:", y_expand)
    print("y_expand2形状:", y_expand2.shape) # torch.Size([3, 4])

    # 错误示例:非1维度不能修改
    z = torch.tensor([1, 2]) # shape: [2]
    # z_error = z.expand(3, 4) # 原第2维是2,不能扩展为4,会报错

tensor.expand_as() 用法

  • 基本语法

    1
    tensor.expand_as(other_tensor)
    • 作用:等价于 tensor.expand(other_tensor.size()),即把当前张量扩展为和 other_tensor 相同的形状
    • 规则:和 expand() 完全一致,只是不需要手动写形状,直接复用另一个张量的形状
  • 代码示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    import torch

    # 原张量
    a = torch.tensor([[1], [2], [3]]) # shape: [3, 1]
    # 目标形状的张量
    b = torch.randn(3, 4) # shape: [3, 4]

    # 扩展a为b的形状
    a_expand = a.expand_as(b)
    print("a_expand形状:", a_expand.shape) # torch.Size([3, 4])
    print("a_expand:", a_expand)
    # 输出:
    # tensor([[1, 1, 1, 1],
    # [2, 2, 2, 2],
    # [3, 3, 3, 3]])

    # 等价写法
    a_expand2 = a.expand(b.size())
    print(torch.equal(a_expand, a_expand2)) # True(两个结果完全一致)

torch.ge 函数

  • torch.ge 是 用于逐元素比较 的基础函数,核心功能是判断第一个输入张量的元素是否大于或等于第二个输入张量的对应元素,最终返回一个与输入同形状的布尔型张量(dtype=torch.bool),元素值为 True/False 表示对应位置的比较结果
  • 完整形式:torch.ge(input, other, *, out=None)
    • input:torch.Tensor(必选),第一个输入张量(比较的左操作数)
    • other: torch.Tensor, 数值型(int/float)(必选), 第二个输入(比较的右操作数),支持广播机制
  • 其他等价名称:torch.greater_equal(与 torch.ge 功能完全一致,可互换使用)
  • 注意输出类型:逐元素执行 input >= other 比较,输出布尔张量(与广播后形状一致),数据类型固定为 torch.bool,每个元素对应 inputother 对应位置的 >= 比较结果