整体说明
- PyTorch 的高级索引操作允许以非常灵活和强大的方式选择和修改张量中的元素
- 高级索引包括整数索引、切片(slicing)、布尔索引和整数数组索引等
- PyTorch 中的索引使用(包括基础索引和高级索引)和 Numpy 中基本一致
- 可用于选择元素,也可以用于修改元素
- 高级索引一般不共享存储区(普通索引一般共享存储区)
- 普通索引一般可以通过修改 Tensor 的偏移量(offset)、步长(stride)或形状实现,不需要修改存储区的数据(使用共享存储区可以节省内存和处理速度)
- 高级索引则一般都是不规则的变化 ,需要修改存储区,故而不使用共享存储区
- 这也是高级索引与切片的最大差别
- 检索维度匹配要求 :多个索引数组的维度必须能够广播成一致的形状,否则报错
- 高级索引的判定方式:
- 在 PyTorch 中,当索引对象是一个非元组序列对象、一个
Tensor(数据类型为整数或布尔,在 NumPy 中为ndarray),或一个至少包含一个序列对象或Tensor(数据类型为整数或布尔,在 NumPy 中为ndarray)的元组时,会触发高级索引判定
- 在 PyTorch 中,当索引对象是一个非元组序列对象、一个
基础索引回顾
- 在深入高级索引之前,我们先快速回顾一下基础索引:
- 单个整数索引 : 选取特定位置的单个元素
- 切片 : 选取连续的子范围。例如
a[start:end:step] - 省略号 (
...) : 表示选择所有剩余的维度- 这在张量维度较多时非常有用,可以避免写出冗长的
:
- 这在张量维度较多时非常有用,可以避免写出冗长的
- 基础索引的 Demo:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23import torch
a = torch.arange(27).reshape(3, 3, 3)
print("原始张量 a:\n", a)
# 输出:
# tensor([[[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8]],
#
# [[ 9, 10, 11],
# [12, 13, 14],
# [15, 16, 17]],
#
# [[18, 19, 20],
# [21, 22, 23],
# [24, 25, 26]]])
print("a[0, 1, 2]:", a[0, 1, 2]) # 输出:tensor(5)(第一个维度第0个,第二个维度第1个,第三个维度第2个)
print("a[1, :, 0]:\n", a[1, :, 0]) # 输出:tensor([ 9, 12, 15])(第一个维度第1个,第二个维度所有,第三个维度第0个)
print("a[..., 1]:\n", a[..., 1]) # 选取所有维度,最后一个维度第1个,等价于 a[:,:,1]
# 输出:
# tensor([[ 1, 4, 7],
# [10, 13, 16],
# [19, 22, 25]])
高级索引-整数数组索引 (Integer Array Indexing)
当使用
torch.Tensor(类型为torch.long或torch.int) 或 Python 列表作为索引时,这被称为整数数组索引(高级索引的一种)- 第一步:广播 : 如果有多个整数数组索引,并且它们的形状不同,PyTorch 会尝试对它们进行广播(broadcasting),广播为相同形状
- 第二步:返回结果 : 整数数组索引的返回张量的形状由广播后的索引张量的形状决定 ,当部分维度没有被索引时,默认保留该维度的所有值
示例 1 : 在一个维度上使用整数数组索引
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24import torch
a = torch.arange(12).reshape(3, 4)
print("原始张量 a:\n", a)
# 输出:
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
# 选取第0行和第2行
indices = torch.tensor([0, 2])
result = a[indices]
print("\na[torch.tensor([0, 2])]:\n", result)
# 输出:
# tensor([[ 0, 1, 2, 3],
# [ 8, 9, 10, 11]])
# 选取第1列和第3列
result = a[:, torch.tensor([1, 3])]
print("\na[:, torch.tensor([1, 3])]:\n", result)
# 输出:
# tensor([[ 1, 3],
# [ 5, 7],
# [ 9, 11]])示例 2: 在多个维度上使用整数数组索引(“坐标”式选择)
- 注:当你在多个维度上同时使用整数数组索引时,它们会被解释为坐标对(这与 NumPy 的行为非常相似)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41import torch
a = torch.arange(12).reshape(3, 4)
print("原始张量 a:\n", a)
# 选取 (0, 1), (1, 3), (2, 0) 这些位置的元素
row_indices = torch.tensor([0, 1, 2])
col_indices = torch.tensor([1, 3, 0])
result = a[row_indices, col_indices]
print("\na[torch.tensor([0, 1, 2]), torch.tensor([1, 3, 0])]:", result)
# 解释:
# 结果的第一个元素是 a[0, 1]
# 结果的第二个元素是 a[1, 3]
# 结果的第三个元素是 a[2, 0]
# 输出: tensor([ 1, 7, 8])
# 索引张量形状不一致时会进行广播
a = torch.arange(27).reshape(3, 3, 3)
print("\n原始张量 a (3x3x3):\n", a)
# 选取多个坐标,例如:
# 维度0取索引0和2
# 维度1取索引1和2
# 维度2取索引0和1
idx0 = torch.tensor([[0], [2]]) # shape (2, 1)
idx1 = torch.tensor([1, 2]) # shape (2,)
idx2 = torch.tensor([0, 1]) # shape (2,)
# PyTorch 会尝试广播这些索引
# idx0: [[0], [2]] -> 广播为 [[0, 0], [2, 2]] (因为 idx1 和 idx2 的大小是2)
# idx1: [1, 2] -> 广播为 [[1, 2], [1, 2]]
# idx2: [0, 1] -> 广播为 [[0, 1], [0, 1]]
# 最终会选择以下坐标的元素:
# (0, 1, 0), (0, 2, 1)
# (2, 1, 0), (2, 2, 1)
result = a[idx0, idx1, idx2]
print("\na[idx0, idx1, idx2]:\n", result)
# 输出:
# tensor([[ 3, 5], # a[0,1,0], a[0,2,1]
# [21, 23]]) # a[2,1,0], a[2,2,1]
- 注:当你在多个维度上同时使用整数数组索引时,它们会被解释为坐标对(这与 NumPy 的行为非常相似)
布尔索引 (Boolean Indexing)
- 当使用一个布尔张量作为索引时,PyTorch 会选择布尔张量中值为
True的所有元素- 形状要求 :布尔张量的形状必须与被索引张量的一个或多个维度匹配
- 返回结果 :布尔张量索引的结果张量通常是 1 维的张量,包含所有满足条件(
True)的元素 - 特别说明:布尔张量通常被称为“掩码”
- 布尔张量的 Demo:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print("原始张量 a:\n", a)
# 选取所有大于 5 的元素
mask = a > 5
print("\n布尔掩码 (a > 5):\n", mask)
# 输出:
# tensor([[False, False, False],
# [False, False, True],
# [ True, True, True]])
result = a[mask]
print("\na[a > 5]:", result)
# 输出: tensor([6, 7, 8, 9])
# 注意以上输出是 1 维的
混合索引操作:a[index1, :, index2] 和 a[index1, index2, :]
特别说明广播机制:当索引数组的维度不匹配时,PyTorch 会尝试运用广播规则来使它们的维度变得兼容,即使混合索引也适用,详情如下:
1
2
3
4
5# 行索引是[0, 2],列索引对所有行都是0
rows = torch.tensor([0, 2]) # 形状为(2,)
cols = torch.tensor([0]) # 形状为(1,)
result = a[rows, :, cols] # 广播后形状变为(2, 3, 1)
print(result.shape) # 输出: torch.Size([2, 3, 1])关于索引后的形状:
- 若高级索引没被隔离,则正常覆盖高级索引所在的维度区域
- 若高级索引被隔离,则对齐维度后,高级索引需要合并,并放到最前面(注意:不是第一个高级索引所在的位置,而是从第 0 维开始的最前面几维)
这种索引方式是在多个维度上混合使用高级索引和切片操作
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54import torch
a = torch.arange(27).reshape(3, 3, 3)
print(a)
# 输出:
# tensor([[[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8]],
# [[ 9, 10, 11],
# [12, 13, 14],
# [15, 16, 17]],
# [[18, 19, 20],
# [21, 22, 23],
# [24, 25, 26]]])
index1 = torch.tensor([[0, 2]]) # 使用 [[0,2]] 也可以
index2 = torch.tensor([[0]]) # 1)使用 [[0]] 也可以;2)[0] 也可以得到相同结果,因为广播后结果是一样的
result = a[index1, index2, :]
print(result.shape) # 输出: torch.Size([1, 2, 3])
print(result)
# 输出:
# tensor([[[ 0, 1, 2],
# [18, 19, 20]]])
# 理解:index1和index2广播以后是shape=[1,2],结果中,第0,1维的[3,3]会被[1,2]替换
index1 = torch.tensor([[0, 2]]) # 使用 [[0,2]] 也可以
index2 = torch.tensor([[0]]) # 1)使用 [[0]] 也可以;2)[0] 也可以得到相同结果,因为广播后结果是一样的
result = a[index1, :, index2]
print(result.shape) # 输出: torch.Size([1, 2, 3])
print(result)
# 输出:
# tensor([[[ 0, 3, 6],
# [18, 21, 24]]])
# 理解:相当于先调用 a.transpose_(2,1) 对齐索引维度,然后再调用 a[index1,index2,:],因为 高级索引需要合并到一起去广播并放到最前面
# # 也可以用permute替代transpose,但permute不能inplace
# 测试下面的代码替换 result = a[index1, :, index2] 后与上述输出一致
# a.transpose_(2,1)
# result = a[index1, index2, :]
index1 = torch.tensor([0, 2])
index2 = torch.tensor([0, 2])
result = a[index1, :, index2]
print(result.shape) # 输出: torch.Size([2, 3])
print(result)
# 输出:
# tensor([[[ 0, 3, 6],
# [20, 23, 26]]])
# tensor([[ 0, 3, 6],
# [20, 23, 26]])
# 理解1:index1和index2广播以后是shape=[2],结果中,第0,1维的[3,3]会被[2]替换
# 理解2:相当于先调用 a.transpose_(2,1) 对齐索引维度,然后再调用 a[index1,index2,:],因为 高级索引需要合并到一起去广播并放到最前面
# 说明:已经测试下面的代码替换 result = a[index1, :, index2] 后与上述输出一致
# a.transpose_(2,1)
# result = a[index1, index2, :]高级索引被隔离的补充实验 Demo:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38import torch
a = torch.arange(3*4*5*6).reshape(3, 4, 5, 6)
# print(a)
index1 = torch.tensor([[0, 2]])
index2 = torch.tensor([[0, 2]])
result = a[:, index1, :, index2]
print(result.shape) # 输出:torch.Size([1, 2, 3, 3])
print(result)
# 输出:
# tensor([[[[ 0, 6, 12, 18, 24],
# [120, 126, 132, 138, 144],
# [240, 246, 252, 258, 264]],
#
# [[ 62, 68, 74, 80, 86],
# [182, 188, 194, 200, 206],
# [302, 308, 314, 320, 326]]]])
index1 = torch.tensor([[0, 2]])
index2 = torch.tensor([[0, 2]])
a = a.permute(1,3,0,2)
result = a[index1,index2,:,:]
print(result.shape) # 输出:torch.Size([1, 2, 3, 3])
print(result)
# 输出:
# tensor([[[[ 0, 6, 12, 18, 24],
# [120, 126, 132, 138, 144],
# [240, 246, 252, 258, 264]],
#
# [[ 62, 68, 74, 80, 86],
# [182, 188, 194, 200, 206],
# [302, 308, 314, 320, 326]]]])
# 理解1:高级索引index1和index2被基础索引隔断了,index1和index2广播后为shape=[1,2],然后高级索引的维度会替换所在维度的[3,3]后放到最前面
# 理解2:相当于先执行 a = a.permute(1,3,0,2)(注意permute操作不能inplace),然后再执行 result = a[index1,index2,:,:]
# 测试说明:从上面的代码输出可以看到,result = a[:, index1, :, index2] 替换为下面语句后,结果一致:
# a = a.permute(1,3,0,2)
# result = a[index1,index2,:,:]
附录:高级索引的一些常见用法
批量数据选择
- 在处理批量数据时,可以利用高级索引为每个样本选择不同的元素
1
2
3
4
5
6
7
8batch_size = 4
features = 10
data = torch.randn(batch_size, features)
# 为每个样本选择不同的特征
indices = torch.tensor([2, 5, 1, 8]) # 为4个样本分别选择第2、5、1、8个特征
selected = data[torch.arange(batch_size), indices]
print(selected.shape) # 输出: torch.Size([4])
并行更新
- 借助高级索引,能够并行地更新张量中的多个位置
1
2
3
4
5
6
7x = torch.zeros(5, 5)
rows = torch.tensor([0, 1, 2])
cols = torch.tensor([1, 2, 3])
values = torch.tensor([10, 20, 30])
x[rows, cols] = values # 将(0,1)、(1,2)、(2,3)位置的值分别更新为10、20、30
print(x)
附录:高级索引出发条件总结
在 PyTorch 中,当索引对象是一个非元组序列对象、一个
Tensor(数据类型为整数或布尔,在 NumPy 中为ndarray),或一个至少包含一个序列对象或Tensor(数据类型为整数或布尔,在 NumPy 中为ndarray)的元组时,会触发高级索引判定以下是会触发高级索引判定的情况
索引对象为单个高维数组或张量 :索引对象不是一个元组序列,而是一个高维数组或者张量,其中布尔型比整数型更常见
举例:
1
2x = torch.arange(24)
y = x.reshape(6, 4)[x.reshape(6, 4) > 10] # y.shape = torch.Size([13]),以为数组这里
torch.arange(24) > 10是布尔型张量,会触发高级索引
索引对象为整数型数组或张量组成的元组序列 :索引对象是一个元组序列,并且元组序列完全由整数型高维数组或者整数型张量组成
举例:
1
2
3
4x = torch.randn(3, 4)
rows = torch.tensor([0, 2])
cols = torch.tensor([1, 3])
y = x[(rows, cols)] # y.shape = torch.Size([2])其中
(rows, cols)构成的元组序列就是由整数型张量组成,会触发高级索引
索引对象为列表序列组成的元组序列 :索引对象是一个元组序列,并且元组序列完全由列表序列组成
举例:
1
2x = torch.randn(3, 4)
y = x[([0, 2], [1, 3])] # y.shape = torch.Size([2])这里
([0, 2], [1, 3])是由 列表序列组成的元组 ,会触发高级索引
索引对象为混合组成的元组序列(包含数组或张量与序列对象) :索引对象是一个元组序列,元组序列不仅包含高维整数型数组或者高维整数型张量,还包括序列对象
举例:
1
2
3x = torch.randn(3, 4)
rows = torch.tensor([0, 2])
y = x[(rows, [1, 3])] # y.shape = torch.Size([2])此元组序列中既有整数型张量
rows,又有列表[1, 3],会触发高级索引
索引对象为混合组成的元组序列(包含数组或张量与整数标量) :索引对象是一个元组序列,元组序列不仅包含高维整数型数组或者高维整数型张量,还包括整数标量
举例:
1
2
3x = torch.randn(3, 4)
rows = torch.tensor([0, 2])
y = x[(rows, 2)] # y.shape = torch.Size([2])元组序列中包含整数型张量
rows和整数标量2,会触发高级索引
索引对象为混合组成的元组序列(包含数组、张量、标量和序列对象) :索引对象是一个元组序列,元组序列包含高维整数型数组或者高维整数型张量、整数标量和序列对象
举例:
1
2
3x = torch.randn(3, 4, 5)
rows = torch.tensor([0, 2])
y = x[(rows, 2, [1, 3])] # y.shape = torch.Size([2])这种情况同样会触发高级索引