PyTorch——gather函数使用


gather函数形式

  • 包含 torch.gathertensor.gather 两种形式,基本思路等价,他们的函数签名如下

    1
    2
    torch.gather(input, dim, index, *, sparse_grad=False, out=None)
    tensor.gather(dim, index, *, sparse_grad=False, out=None)
    • 注:在 PyTorch 的函数签名中,* 是 Python 语法中的一个特殊标记,用于表示强制关键字参数(keyword-only arguments)。这意味着在 * 之后的所有参数(如 sparse_gradout)必须通过关键字(即使用参数名)来传递,而不能通过位置参数传递
  • 参数解释

    • input (Tensor): 输入张量
    • dim (int): 沿着哪个维度进行收集
    • index (LongTensor): 索引张量,包含要收集的元素的索引
    • sparse_grad (bool, 可选): 如果为True,梯度将是稀疏张量
    • out (Tensor, 可选): 输出张量(若该值不为 None,则会将返回值存储到 out 引用中,此时 out 和 返回值 是同一个对象
  • 特别说明:gather 和普通的矩阵索引操作一样,操作支持反向传播


基本原理

  • 对于 3D 张量,gather操作可以表示为:

    1
    2
    3
    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
    out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
    out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
  • 特别说明(记忆方法)

    • index[i][j][k] 用于索引 out[i][j][k]
    • dim=0 :index 指定行号(第0维) ,列号(第1维)和通道(第2维)保持不变
    • dim=1 :index 指定列号(第1维) ,行号(第0维)和通道(第2维)保持不变
    • dim=2 :index 指定通道(第2维) ,行号(第0维)和列号(第1维)保持不变
    • 特别地:输出形状 = index 形状

代码示例

  • 一维张量示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    import torch

    # 基本用法
    input_tensor = torch.tensor([10, 20, 30, 40, 50])
    index_tensor = torch.tensor([0, 2, 4, 1])
    result = torch.gather(input_tensor, dim=0, index=index_tensor)
    print(result) # tensor([10, 30, 50, 20])

    # 使用tensor.gather方法
    result2 = input_tensor.gather(0, index_tensor)
    print(result2) # tensor([10, 30, 50, 20])
  • 二维张量示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    # 二维张量
    input_2d = torch.tensor([[1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]])

    # 沿着dim=0收集(按行收集)
    index_2d = torch.tensor([[0, 1, 2],
    [2, 0, 1]])
    result_dim0 = torch.gather(input_2d, dim=0, index=index_2d)
    print("dim=0 结果:")
    print(result_dim0)
    # tensor([[1, 5, 9],
    # [7, 2, 6]])

    # 沿着dim=1收集(按列收集)
    index_2d_col = torch.tensor([[0, 2],
    [1, 0],
    [2, 1]])
    result_dim1 = torch.gather(input_2d, dim=1, index=index_2d_col)
    print("dim=1 结果:")
    print(result_dim1)
    # tensor([[1, 3],
    # [5, 4],
    # [9, 8]])

一些实际应用场景

  • 获取最大值,获取对应标签等实例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    # 场景1: 获取每行的最大值索引对应的值
    scores = torch.tensor([[0.1, 0.8, 0.3],
    [0.6, 0.2, 0.9],
    [0.4, 0.7, 0.1]])

    # 获取每行最大值的索引
    max_indices = torch.argmax(scores, dim=1, keepdim=True)
    print("最大值索引:", max_indices) # tensor([[1], [2], [1]])

    # 使用gather获取最大值
    max_values = torch.gather(scores, dim=1, index=max_indices)
    print("最大值:", max_values) # tensor([[0.8], [0.9], [0.7]])

    # 场景2: 根据标签获取对应的预测概率
    predictions = torch.tensor([[0.2, 0.3, 0.5],
    [0.1, 0.8, 0.1],
    [0.6, 0.2, 0.2]])
    labels = torch.tensor([2, 1, 0]) # 真实标签

    # 获取每个样本对应标签的预测概率
    label_probs = torch.gather(predictions, dim=1, index=labels.unsqueeze(1))
    print("标签概率:", label_probs) # tensor([[0.5], [0.8], [0.6]])

注意事项

  • 索引值必须在 [0, input.size(dim)) 范围内
  • 除了指定的dim维度外,inputindex 的其他维度大小必须相同
  • 输出张量的形状与 index 张量相同