- 参考链接:
- 一篇浅显易懂的博客:PyTorch中的高级索引方法——gather详解
gather函数形式
包含
torch.gather和tensor.gather两种形式,基本思路等价,他们的函数签名如下1
2torch.gather(input, dim, index, *, sparse_grad=False, out=None)
tensor.gather(dim, index, *, sparse_grad=False, out=None)- 注:在 PyTorch 的函数签名中,
*是 Python 语法中的一个特殊标记,用于表示强制关键字参数(keyword-only arguments)。这意味着在*之后的所有参数(如sparse_grad和out)必须通过关键字(即使用参数名)来传递,而不能通过位置参数传递
- 注:在 PyTorch 的函数签名中,
参数解释
- input (Tensor): 输入张量
- dim (int): 沿着哪个维度进行收集
- index (LongTensor): 索引张量,包含要收集的元素的索引
- sparse_grad (bool, 可选): 如果为True,梯度将是稀疏张量
- out (Tensor, 可选): 输出张量(若该值不为
None,则会将返回值存储到out引用中,此时out和 返回值 是同一个对象)
特别说明:
gather和普通的矩阵索引操作一样,操作支持反向传播
基本原理
对于 3D 张量,gather操作可以表示为:
1
2
3out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2特别说明(记忆方法)
index[i][j][k]用于索引out[i][j][k]- dim=0 :index 指定行号(第0维) ,列号(第1维)和通道(第2维)保持不变
- dim=1 :index 指定列号(第1维) ,行号(第0维)和通道(第2维)保持不变
- dim=2 :index 指定通道(第2维) ,行号(第0维)和列号(第1维)保持不变
- 特别地:输出形状 = index 形状
代码示例
一维张量示例
1
2
3
4
5
6
7
8
9
10
11import torch
# 基本用法
input_tensor = torch.tensor([10, 20, 30, 40, 50])
index_tensor = torch.tensor([0, 2, 4, 1])
result = torch.gather(input_tensor, dim=0, index=index_tensor)
print(result) # tensor([10, 30, 50, 20])
# 使用tensor.gather方法
result2 = input_tensor.gather(0, index_tensor)
print(result2) # tensor([10, 30, 50, 20])二维张量示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24# 二维张量
input_2d = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 沿着dim=0收集(按行收集)
index_2d = torch.tensor([[0, 1, 2],
[2, 0, 1]])
result_dim0 = torch.gather(input_2d, dim=0, index=index_2d)
print("dim=0 结果:")
print(result_dim0)
# tensor([[1, 5, 9],
# [7, 2, 6]])
# 沿着dim=1收集(按列收集)
index_2d_col = torch.tensor([[0, 2],
[1, 0],
[2, 1]])
result_dim1 = torch.gather(input_2d, dim=1, index=index_2d_col)
print("dim=1 结果:")
print(result_dim1)
# tensor([[1, 3],
# [5, 4],
# [9, 8]])
一些实际应用场景
- 获取最大值,获取对应标签等实例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22# 场景1: 获取每行的最大值索引对应的值
scores = torch.tensor([[0.1, 0.8, 0.3],
[0.6, 0.2, 0.9],
[0.4, 0.7, 0.1]])
# 获取每行最大值的索引
max_indices = torch.argmax(scores, dim=1, keepdim=True)
print("最大值索引:", max_indices) # tensor([[1], [2], [1]])
# 使用gather获取最大值
max_values = torch.gather(scores, dim=1, index=max_indices)
print("最大值:", max_values) # tensor([[0.8], [0.9], [0.7]])
# 场景2: 根据标签获取对应的预测概率
predictions = torch.tensor([[0.2, 0.3, 0.5],
[0.1, 0.8, 0.1],
[0.6, 0.2, 0.2]])
labels = torch.tensor([2, 1, 0]) # 真实标签
# 获取每个样本对应标签的预测概率
label_probs = torch.gather(predictions, dim=1, index=labels.unsqueeze(1))
print("标签概率:", label_probs) # tensor([[0.5], [0.8], [0.6]])
注意事项
- 索引值必须在
[0, input.size(dim))范围内 - 除了指定的dim维度外,
input和index的其他维度大小必须相同 - 输出张量的形状与
index张量相同