Jiahong 的个人博客

凡事预则立,不预则废


  • Home

  • Tags

  • Archives

  • Navigation

  • Search

PyTorch——Random相关状态管理


整体说明

  • PyTorch 中,包含很多随机操作,比如
    • 可以使用 torch.rand() 等函数获取随机数
    • 可以使用 torch.nn.functional.dropout() 实现随机 drop 一些神经元
    • 可以使用 tensor.random_() 等函数随机初始化参数
  • 这些涉及随机数/采样的方法均受限于一个随机状态管理

torch Seed 打印

  • torch Seed 打印代码:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    # 打印 torch 的随机种子情况
    def print_torch_seeds():
    print("=" * 30 + "PyTorch Random Seeds Status")
    print("=" * 30)
    cpu_seed = torch.initial_seed()
    print(f"[CPU] Seed: {cpu_seed}")

    if torch.cuda.is_available():
    try:
    gpu_seed = torch.cuda.initial_seed()
    current_device = torch.cuda.current_device()
    device_name = torch.cuda.get_device_name(current_device)

    print(f"[GPU] Seed: {gpu_seed}")
    print(f" Device: {current_device} ({device_name})")
    except Exception as e:
    print(f"[GPU] Error getting seed: {e}")
    else:
    print("[GPU] CUDA is not available.")

    print("=" * 30)
    print_torch_seeds()

torch Seed 设置

  • 全局 torch Seed 设置代码:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    import torch

    # 固定CPU种子
    torch.manual_seed(42)

    # 固定所有GPU的种子(单GPU/多GPU通用)
    if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42) # 替代 torch.cuda.manual_seed(42)(单GPU)

    # GPU上生成随机排列
    perm = torch.randperm(10, device="cuda") # 注意:需要指定 "cuda" 才会在 GPU 上执行
    print("GPU随机排列:", perm) # 每次运行结果一致
    print("draw a random number:", torch.rand()) # 每次运行结果一致
  • 使用独立的 torch 生成器(独立管理自己的随机生成器):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    import torch

    # 创建独立的生成器并设置种子
    generator = torch.Generator()
    generator.manual_seed(42)
    # generator = torch.Generator().manual_seed(42) # 等价实现

    # 生成随机排列时指定生成器
    perm1 = torch.randperm(10, generator=generator)
    perm2 = torch.randperm(10, generator=generator)

    print("独立生成器-第一次:", perm1) # tensor([2, 7, 3, 1, 0, 9, 4, 5, 8, 6])
    print("独立生成器-第二次:", perm2) # tensor([2, 0, 7, 9, 8, 4, 3, 6, 1, 5])

    # 重置生成器种子,结果重复
    generator.manual_seed(42)
    perm3 = torch.randperm(10, generator=generator)
    print("重置生成器后:", perm3) # tensor([2, 7, 3, 1, 0, 9, 4, 5, 8, 6])(和perm1一致)
    • 说明:torch.Generator 是 PyTorch 中统一的随机数生成器(RNG)核心对象,几乎所有 PyTorch 内置的随机操作都支持通过 generator 参数指定该生成器

附录:torch.Generator 详细说明

  • torch.Generator 是 PyTorch 中统一的随机数生成器(RNG)核心对象 ,几乎所有 PyTorch 内置的随机操作都支持通过 generator 参数指定该生成器,仅极少数场景不支持(或无需支持)
  • torch.Generator 的核心作用是隔离随机状态 :任何依赖 PyTorch 内置随机数生成的操作,只要支持 generator 参数,就能通过该生成器控制随机行为;无 generator 参数的操作,要么不依赖随机数,要么复用全局生成器(CPU/CUDA)
  • 所有需要随机逻辑的场景均支持 torch.Generator 的随机操作(全场景)
  • 注:无随机逻辑的操作本身无随机行为,因此不需要(也无法)指定 generator:
    • 张量基础操作:torch.ones()、torch.zeros()、torch.arange()、torch.cat()、torch.matmul() 等
    • 数学运算:torch.sin()、torch.exp()、torch.mean()、torch.argmax() 等
    • 索引/切片:x[:, 0]、x.index_select() 等
    • 设备/类型转换:x.to('cuda')、x.float() 等

随机逻辑的场景示例

  • 所有操作均可通过 generator 参数指定自定义 torch.Generator,实现随机状态隔离
  • 基础随机数生成
    函数/方法 用途 示例
    torch.rand() 均匀分布随机数 torch.rand(3, generator=g)
    torch.randn() 标准正态分布随机数 torch.randn(2, 4, generator=g)
    torch.randint() 整数随机数 torch.randint(0, 10, (3,), generator=g)
    torch.randperm() 随机排列 torch.randperm(5, generator=g)
    torch.rand_like() 按形状生成均匀随机数 torch.rand_like(torch.ones(2), generator=g)
    torch.randn_like() 按形状生成正态随机数 torch.randn_like(torch.ones(2), generator=g)
    torch.normal() 自定义均值/方差正态分布 torch.normal(0, 1, (3,), generator=g)
    torch.poisson() 泊松分布随机数 torch.poisson(torch.ones(3), generator=g)
    torch.exponential() 指数分布随机数 torch.exponential(1.0, (3,), generator=g)
    torch.cauchy() 柯西分布随机数 torch.cauchy(0, 1, (3,), generator=g)
    torch.log_normal() 对数正态分布随机数 torch.log_normal(0, 1, (3,), generator=g)
    torch.multinomial() 多项分布采样 torch.multinomial(torch.ones(5), 3, generator=g)
    torch.bernoulli() 伯努利分布(0/1) torch.bernoulli(torch.ones(3)*0.5, generator=g)
    • 注:指定参数 generator 时,前面的参数也需要指定(Python 本身的规则)
  • 张量随机初始化
    函数/方法 用途 示例
    tensor.random_() 原地随机初始化(整数) tensor.random_(generator=g)
    tensor.uniform_() 原地均匀分布初始化 tensor.uniform_(0, 1, generator=g)
    tensor.normal_() 原地正态分布初始化 tensor.normal_(0, 1, generator=g)
    tensor.cauchy_() 原地柯西分布初始化 tensor.cauchy_(0, 1, generator=g)
  • 随机采样/变换(数据增强等)
    函数/方法 用途 示例
    torch.utils.data.RandomSampler 数据集随机采样 RandomSampler(dataset, generator=g)
    torch.nn.functional.dropout() Dropout层随机失活 F.dropout(x, p=0.5, generator=g)
    torch.nn.functional.dropout2d() 2D Dropout F.dropout2d(x, p=0.5, generator=g)
    torch.nn.functional.dropout3d() 3D Dropout F.dropout3d(x, p=0.5, generator=g)
    torchvision.transforms 中的随机变换 图像随机增强(如RandomCrop) transforms.RandomCrop(32, generator=g)(需torchvision)
    torch.distributions 分布采样 概率分布采样(如Normal、Uniform) dist = Normal(0, 1); dist.sample((3,), generator=g)

特殊说明:随机场景但不支持 torch.Generator 的场景

  • 有随机逻辑但不支持自定义 generator 的场景;依赖随机数,但 PyTorch 未开放 generator 参数,只能复用全局生成器(CPU/CUDA):
    操作 原因 替代方案
    torch.shuffle() 底层绑定全局生成器 用 torch.randperm(generator=g) 手动实现洗牌
    torch.nn.Dropout 模块(默认) 模块初始化时未绑定生成器 改用 F.dropout(generator=g) 或自定义模块绑定生成器
    部分第三方库的随机操作(如某些数据增强) 未适配 generator 参数 替换为 PyTorch 原生实现或手动设置全局种子
    torch.multiprocessing 多进程随机 进程间生成器隔离限制 每个进程内重新初始化 generator
  • 实践思考:
    • 1)凡是生成随机数的 PyTorch 原生函数,优先检查是否有 generator 参数,有则建议使用(隔离随机状态)
    • 2)对不支持 generator 的随机操作,要么手动实现(如用 randperm 替代 shuffle),要么临时设置全局种子并尽快恢复
    • 3)CUDA 场景务必创建对应设备的 generator,避免跨设备混用导致随机状态混乱

最佳实践:torch.Generator 隔离随机状态

  • 多个 torch.Generator 隔离随机状态示例
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    import torch

    # 创建两个独立生成器
    g1 = torch.Generator().manual_seed(42)
    g2 = torch.Generator().manual_seed(42)

    # 用g1生成随机数(消耗g1的状态)
    a = torch.rand(2, generator=g1)
    b = torch.rand(2, generator=g1)

    # 用g2生成随机数(g2状态未被消耗,结果和g1初始一致)
    c = torch.rand(2, generator=g2)

    print("a (g1第一次):", a) # tensor([0.8823, 0.9150])
    print("b (g1第二次):", b) # tensor([0.3829, 0.9593])
    print("c (g2第一次):", c) # tensor([0.8823, 0.9150])(和a一致)

附录:GPU 下的 torch.Generator

  • torch.Generator 必须与操作的设备(CPU/CUDA)对齐 ,否则会导致隐式设备拷贝、性能损耗,甚至随机状态混乱
    • CPU 操作时使用 CPU 生成器
    • CUDA 操作时使用对应 CUDA 设备的生成器
    • 核心目的:避免隐式跨设备拷贝,保证随机状态的隔离性和可复现性
  • 所有支持 CUDA 的随机操作(如 torch.rand(3, device='cuda', generator=g)),需指定与生成器同设备的 generator
  • CUDA 生成器的随机状态与 CPU 生成器完全隔离,互不干扰
生成器的设备属性
  • torch.Generator 可通过 device 参数绑定具体设备,默认是 CPU:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    import torch

    # CPU生成器(默认)
    g_cpu = torch.Generator() # 等价于 torch.Generator(device="cpu")
    print("CPU生成器设备:", g_cpu.device) # 输出:cpu

    # CUDA生成器(需显式指定)
    if torch.cuda.is_available():
    g_cuda = torch.Generator(device="cuda:0") # 绑定cuda:0
    print("CUDA生成器设备:", g_cuda.device) # 输出:cuda:0
对齐 vs 不对齐的示例对比
  • 正确:生成器设备 等于 操作 device(推荐)

    1
    2
    3
    4
    5
    6
    7
    8
    if torch.cuda.is_available():
    # 创建cuda:0的生成器
    g_cuda = torch.Generator(device="cuda:0").manual_seed(42)
    # 操作指定device=cuda:0,与生成器对齐
    perm = torch.randperm(10, device="cuda:0", generator=g_cuda)

    print("结果设备:", perm.device) # cuda:0
    print("无隐式拷贝,效率最高")
  • 错误:生成器设备 不等于 操作 device(性能坑)

    1
    2
    3
    4
    5
    6
    7
    if torch.cuda.is_available():
    # 生成器是cuda:0,但操作指定 device=cpu
    g_cuda = torch.Generator(device="cuda:0").manual_seed(42)
    perm = torch.randperm(10, device="cpu", generator=g_cuda)

    print("结果设备:", perm.device) # cpu
    print("隐式拷贝:GPU生成随机数 拷贝到CPU(额外开销)")
  • 更隐蔽的错误:CUDA操作用CPU生成器

    1
    2
    3
    4
    5
    6
    7
    if torch.cuda.is_available():
    # 生成器是CPU,操作指定device=cuda
    g_cpu = torch.Generator().manual_seed(42)
    perm = torch.randperm(10, device="cuda", generator=g_cpu)

    print("结果设备:", perm.device) # cuda:0
    print("隐式拷贝:CPU生成随机数 拷贝到GPU(额外开销)")
为什么必须对齐?
  • 随机数生成器的硬件绑定 :
    • CPU 生成器依赖 CPU 的随机数算法
    • CUDA 生成器依赖 GPU 的 cuRAND 库,直接在 GPU 显存生成随机数;
      • 跨设备使用时,PyTorch 会先在生成器设备生成随机数,再通过 PCIe 总线拷贝到操作指定的设备,产生额外耗时
  • 随机状态的隔离性(容易因为误用而出错) :
    • CUDA生成器的随机状态(get_state())和 CPU 生成器完全隔离,若跨设备使用,会导致“生成器状态和操作设备不匹配”,破坏随机种子的可复现性:
      1
      2
      3
      4
      5
      6
      7
      if torch.cuda.is_available():
      g_cuda = torch.Generator(device="cuda").manual_seed(42)
      # 第一次:跨设备使用(cuda生成器 到 cpu操作)
      perm1 = torch.randperm(10, device="cpu", generator=g_cuda)
      # 第二次:直接用cuda生成器 到 cuda操作
      perm2 = torch.randperm(10, device="cuda", generator=g_cuda)
      # perm2的结果不等于“重新seed后cuda操作的结果”(状态已被跨设备操作消耗)
torch.Generator 的最佳实践
  • 创建生成器时显式指定设备 :

    • 不要依赖默认的CPU生成器,GPU场景务必创建 device="cuda" 的生成器
  • 封装成函数,强制对齐 :

    1
    2
    3
    4
    5
    6
    7
    8
    9
    def get_generator(device: str = "cpu", seed: int = 42):
    g = torch.Generator(device=device)
    g.manual_seed(seed)
    return g

    # 使用
    if torch.cuda.is_available():
    g = get_generator(device="cuda:0", seed=42)
    perm = torch.randperm(10, device="cuda:0", generator=g)
  • 多GPU场景:每个GPU对应独立生成器 :

    1
    2
    3
    4
    5
    6
    7
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # 为cuda:0和cuda:1分别创建生成器
    g0 = torch.Generator(device="cuda:0").manual_seed(42)
    g1 = torch.Generator(device="cuda:1").manual_seed(100)

    perm0 = torch.randperm(10, device="cuda:0", generator=g0)
    perm1 = torch.randperm(10, device="cuda:1", generator=g1)

PyTorch——Tensor的内存布局Layout


整体说明

  • 在 PyTorch 里,张量的 layout 属性主要用于表明内存的组织形式(tensor.layout 属性可查看张量当前的布局类型)
  • 张量的存储主要分为稀疏布局(Sparse Layout)和稠密布局(Dense Layout)两种
    • 稠密布局适合进行常规的张量运算
    • 稀疏布局在处理大规模稀疏数据时,能够显著减少内存占用和计算量
  • 在使用稀疏张量进行计算时,需要注意:
    • 并非所有的 PyTorch 操作都支持稀疏张量,部分操作在执行前可能需要先将稀疏张量转换为稠密张量

稠密布局(torch.strided)

  • torch.strided 是标准的多维数组布局,采用连续的内存存储方式
  • torch.strided 在每一个维度都具备步长(stride)属性,借助该属性可以计算出内存中的偏移量
  • 在 PyTorch 1.13 及后续版本中,很多张量创建函数(如 torch.ones 等)的参数的默认值是都是 torch.strided
  • 稠密布局示例:
    1
    2
    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    print(x.layout) # 输出:torch.strided

稀疏 CSR 布局(torch.sparse_csr)

  • torch.sparse_csr 布局适用于存储稀疏矩阵,能有效节省内存
  • torch.sparse_csr 运用压缩稀疏行(Compressed Sparse Row)格式,借助三个张量来表示:
    • crow_indices:记录每行在 col_indices 和 values 中的起始位置
    • col_indices:存储非零元素所在的列索引
    • values:存放非零元素的值
  • 稀疏 CSR 布局示例:
    1
    2
    3
    4
    5
    crow_indices = torch.tensor([0, 2, 3])
    col_indices = torch.tensor([0, 1, 2])
    values = torch.tensor([1, 2, 3])
    sparse_tensor = torch.sparse_csr_tensor(crow_indices, col_indices, values, (2, 3))
    print(sparse_tensor.layout) # 输出:torch.sparse_csr

其他稀疏布局

  • PyTorch 还支持多种稀疏布局,以满足不同的访问和计算需求
  • 稀疏 COO(torch.sparse_coo)布局
  • 稀疏 CSC(torch.sparse_csc)布局

布局转换方法示例

  • 可以使用以下方法在不同布局之间进行转换:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    # 从稠密布局转换为稀疏 COO 布局
    dense_tensor = torch.tensor([[0, 1], [2, 0]])
    sparse_coo = dense_tensor.to_sparse() # 默认为 COO 格式

    # 从稀疏 COO 布局转换回稠密布局
    dense_tensor = sparse_coo.to_dense()

    # 稀疏布局之间的转换
    sparse_csr = sparse_coo.to_sparse_csr()

PyTorch——Module中的Parameter和Buffer


Parameter 和 Buffer 的整体理解

  • 在 PyTorch 的 nn.Module 里,Parameter 和 Buffer 都是张量类型(是两种类型的张量)
  • 两者的关键区别有:
    • 可训练性 :
      • Parameter 可训练 ,通常用于模型的权重和偏置,会在反向传播时被优化器更新
      • Buffer 不可训练 ,通常用于存储需要在训练或推理过程中保留,但不需要梯度更新的值(如 BatchNorm 的统计信息)
    • 注册方式 :
      • Parameter 显示定义字段:通过 nn.Parameter() 初始化,或通过 nn.Linear() 等类初始化
      • Buffer 一般是隐式定义:通过 register_buffer 或 BN 层等隐式自动定义
    • 访问方式:
      • Parameter 作为可训练的张量,会被自动添加到模型的 parameters() 迭代器中
      • Buffer 是不可训练的张量,不会被添加到 parameters() 中,也不会被优化器更新
  • 两者的共同点有:
    • 两者都会被保存在模型的 state_dict 中,因此在保存/加载模型时都会被保留
    • 当调用 model.to(device) 时,Parameter 和 Buffer 都会被移动到指定设备
  • 最佳实践:
    • Parameter用于
      • 定义模型权重、偏置等需要学习的参数
    • Buffer用于
      • 非训练状态的统计量(如 BatchNorm 的均值/方差)
      • 固定的预训练权重或常量张量
      • 以及其他需要与模型一起保存 ,但不需要梯度的中间结果

Parameter 和 Buffer 的定义方式

  • 需要注意的关键经验和知识点:
    • 定义位置 :建议将 Parameter 和 Buffer 都定义在 __init__ 中(虽然可以动态定义到 __init__ 之外)
    • 同名覆盖规则 :
      • Buffer 对象之间可以互相覆盖
      • Parameter 对象之间可以互相覆盖
      • Parameter 对象可以覆盖 Buffer 对象
      • Buffer 对象不可以覆盖 Parameter 对象
    • Parameter 或 Buffer 为 None 时 ,仅仅是一个声明 ,不会被 parameters()、buffers()或state_dict()等包含(注:named_xx() 和 xx() 数量一样,也不会包含)
    • 冻结参数 :属性为 requires_grad=False 的参数不会被更新(但可以被 parameters() 返回,也可以被加入优化器,此时有优化器状态,但是没有梯度,也不会被更新)
    • 初始化类型要求 :
      • Parameter 对象一定要用 nn.Parameter 对象初始化
      • Buffer 对象可以用 Tensor 对象初始化
    • Parameter 和 Buffer 更新规则 :
      • in-place update :
        • 可使用 model.x.data += 2 或 model.x.data.fill_(2.0) 的方式修改 Buffer 或 Parameter 的值,实现 in-place update
        • 此时针对 Parameter,不需要重新初始化 优化器
      • 替换张量数据:
        • 当使用类似 model.x.data = torch.tensor(2.0) 的方式修改,或重新注册新的参数对象时,此时会替换整个 Buffer 或 Parameter 对象
        • 此时针对 Parameter,需要重新初始化 优化器 ,否则优化器无法识别到被修改后的参数的张量
    • 优化器更新规则 :
      • 对于被重新赋值data张量的参数,需要重新初始化 优化器 ,否则优化器无法识别到被修改后的参数的张量
      • 对于新增加的参数 ,必须重新初始化优化器 ,以保证优化器能够优化到新的参数
      • 如果一个参数没有被优化器追踪(被追踪的参数在 optimizer.param_groups() 中),该参数不会被更新(即使在 loss.backward() 阶段已经计算了梯度,参数也不会被更新)
      • 更多补充见附录
  • Parameter 和 Buffer 定义代码和测试(详细阅读注释和输出部分)
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    import torch
    import torch.nn as nn
    import math

    class RegistrationDemo(nn.Module):
    def __init__(self):
    super().__init__()

    # 1. 直接定义Parameter
    self.direct_param = nn.Parameter(torch.randn(3, 3))

    # 2. 使用register_parameter()方法
    custom_tensor = torch.ones(2, 2) * 0.5
    self.register_parameter('explicit_param', nn.Parameter(custom_tensor)) # 必须是 nn.Parameter 对象或 None,如果是其他对象会出错

    # 3. 通过nn.Linear隐式注册Parameter
    self.linear = nn.Linear(8 * 32 * 32, 4) # 输入维度=8×32×32

    # 4. 通过nn.Conv2d隐式注册Parameter
    self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    # 5. 通过nn.BatchNorm2d隐式注册Parameter和Buffer
    self.bn = nn.BatchNorm2d(8) # weight(Parameter)和bias(Parameter), running_mean和running_var(Buffer)

    # 6. 自定义Buffer
    self.register_buffer('custom_buffer', torch.tensor([math.pi])) # 必须是 Tensor 对象或 None,如果是其他对象会出错

    # 7. 使用不同初始化方式的Parameter
    self.xavier_init = nn.Parameter(torch.zeros(5, 5))
    nn.init.xavier_uniform_(self.xavier_init)

    self.kaiming_init = nn.Parameter(torch.zeros(5, 5))
    nn.init.kaiming_normal_(self.kaiming_init)

    # 8. 可学习的标量Parameter
    self.scalar_param = nn.Parameter(torch.tensor(0.1))

    # 9. 冻结的Parameter (requires_grad=False)
    self.frozen_param = nn.Parameter(torch.tensor(0.1), requires_grad=False)

    # 10. 预声明动态注册的参数名(重要!避免state_dict问题,注意,不推荐这么使用,但部分情况下可以用来做高阶的模型设计)
    self.register_parameter('dynamic_param', None) # 仅声明参数名,但该参数不会被加入 state_dict,直到被初始化为实际的 Parameter 对象
    self.register_buffer('dynamic_buffer', None) # 仅声明缓冲区名,但该参数不会被加入 state_dict,直到被初始化为实际的 Buffer 对象

    # 11. 测试通过register重复注册对象
    self.register_parameter('test_multiple_param', nn.Parameter(torch.tensor(1.0))) # 定义参数 test_multiple_param
    self.register_parameter('test_multiple_param', nn.Parameter(torch.tensor(2.0))) # 重新修改参数对象,值为2.0(注意不是简单的修改值)
    self.test_multiple_param = nn.Parameter(torch.tensor(3.0)) # 重新修改参数对象,值为3.0(注意不是简单的修改值)
    self.test_multiple_param = nn.Parameter(torch.tensor(4.0)) # 重新修改参数对象,值为4.0(注意不是简单的修改值)
    self.register_parameter('test_multiple_param', nn.Parameter(torch.tensor(5.0))) # 重新修改参数对象,值为5.0(注意不是简单的修改值)

    # 11(a)特别地,buffer的注册和参数类似,但 Parameter 可以覆盖 Buffer 对象,但 Buffer 不可覆盖 Parameter 对象
    self.register_buffer('test_multiple_buffer', nn.Parameter(torch.tensor(1.0)))
    self.register_buffer('test_multiple_buffer', nn.Parameter(torch.tensor(2.0)))
    self.test_multiple_buffer = nn.Parameter(torch.tensor(3.0)) # 重新修改 Buffer 对象为 Parameter 对象,值为3.0(注意不是简单的修改值)
    self.test_multiple_buffer = nn.Parameter(torch.tensor(4.0)) # 重新修改 Parameter 对象,值为4.0(注意不是简单的修改值)
    # self.register_buffer('test_multiple_buffer', nn.Parameter(torch.tensor(5.0))) # 这行会报错 KeyError: "attribute 'test_multiple_buffer' already exists":Buffer 不允许使用 register_buffer 覆盖 Parameter对象

    # 11(b) Buffer 不可 覆盖 Parameter 对象的再次验证
    self.register_buffer('test_multiple_buffer_param', nn.Parameter(torch.tensor(1.0)))
    self.test_multiple_buffer_param = torch.tensor(1.5) # 仍然是 Buffer 对象,值变成1.5
    self.test_multiple_buffer_param += 10 # 仍然是 Buffer 对象,值变成11.5,且是 原地修改(in-place update)
    # self.register_parameter('test_multiple_buffer_param', nn.Parameter(torch.tensor(2.0))) # 这行会报错,KeyError: "attribute 'test_multiple_buffer_param' already exists",不允许 Buffer 和 Parameter 互相覆盖
    self.test_multiple_buffer_param = nn.Parameter(torch.tensor(3.0)) # 通过重写将 Buffer修改为 Parameter 对象
    self.register_parameter('test_multiple_buffer_param', nn.Parameter(torch.tensor(4.0))) # 这行不会报错,因为test_multiple_buffer_param已经是 Parameter 对象了,可以被参数重写

    # 11(c) Buffer 不可 覆盖 Parameter 对象的再次验证(即使是通过 register_parameter 注册的参数也不能覆盖)
    self.register_parameter('test_multiple_param_buffer', nn.Parameter(torch.tensor(1.0)))
    # self.register_buffer('test_multiple_param_buffer', nn.Parameter(torch.tensor(2.0))) # 这行会报错,KeyError: "attribute 'test_multiple_param_buffer' already exists",不允许 Buffer 和 Parameter 互相覆盖


    def init_dynamic_params(self, input_size):
    """在__init__外动态注册Parameter和Buffer(注意,不推荐在 __init__ 外定义 Parameter 和 Buffer ,但部分情况下可以用来做高阶的模型设计)"""
    # 动态注册Parameter(注册__init__中声明过的参数)
    self.register_parameter('dynamic_param', nn.Parameter(torch.randn(input_size)))

    # 动态注册Buffer(注册__init__中声明过的参数)
    self.register_buffer('dynamic_buffer', torch.arange(5).float())

    # 注册__init__中未声明的参数,也可以被正常更新,但需要保证定义执行此语句后再执行 forward 操作
    self.register_parameter('dynamic_param2', nn.Parameter(torch.randn(1)))

    def forward(self, x):
    x = self.conv(x)
    x = self.bn(x)
    x = x * self.scalar_param
    x = x * self.frozen_param

    if self.dynamic_param is not None: # 不论是否提前初始化,均没有问题,因为 __init__ 中已经声明
    x = x + self.dynamic_param.mean()

    if self.dynamic_param2 is not None: # 如果在调用 forward 前未定义 dynamic_param2,会报错,建议像是 dynamic_param 一样在 __init__ 函数中进行声明
    x = x + self.dynamic_param2

    x = x.view(x.size(0), -1) # 展平为 (batch_size, 8*32*32)
    x = self.linear(x)
    return x

    def print_registration_info(model):
    print("\n=== 模型整体结构 ===")
    print(model)

    print("\n=== Parameters ===")
    for name, param in model.named_parameters():
    print(f"参数名称: {name}, 形状: {param.shape}, 类型: {type(param)}, 是否可训练: {param.requires_grad}")

    print("\n=== Buffers ===")
    for name, buffer in model.named_buffers():
    print(f"缓冲区名称: {name}, 形状: {buffer.shape}, 类型: {type(buffer)}")

    print("\n=== state_dict ===")
    for key in model.state_dict().keys():
    print(f"Key: {key}")

    if __name__ == "__main__":
    # 创建模型但不初始化动态参数
    model = RegistrationDemo()

    # 打印注册信息(动态参数尚未初始化)
    print("初始化前的参数状态:")
    print_registration_info(model)

    # 初始化动态参数
    model.init_dynamic_params(input_size=4)

    # 打印注册信息(动态参数已初始化)
    print("\n初始化后的参数状态:")
    print_registration_info(model)

    # 测试前向传播
    x = torch.randn(2, 3, 32, 32)
    output = model(x)

    # 验证动态参数是否参与训练
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    loss = output.sum()
    loss.backward()

    print()

    print("训练前 dynamic_param:", model.dynamic_param.data)
    print("训练前 dynamic_param2:", model.dynamic_param2.data)
    print("训练前 scalar_param:", model.scalar_param.data)
    print("训练前 frozen_param:", model.frozen_param.data)

    optimizer.step() # 更新参数,注意:这里只更新需要更新的参数,不参与训练的参数不会被更新(即使他们被加入了优化器中)

    print("训练后 dynamic_param:", model.dynamic_param.data)
    print("训练前 dynamic_param2:", model.dynamic_param2.data)
    print("训练后 scalar_param:", model.scalar_param.data)
    print("训练后 frozen_param(未发生改变):", model.frozen_param.data)

    print()
    print("test_multiple_param:", model.test_multiple_param)
    print("test_multiple_buffer:", model.test_multiple_buffer)

    # 初始化前的参数状态:
    #
    # === 模型整体结构 ===
    # RegistrationDemo(
    # (linear): Linear(in_features=8192, out_features=4, bias=True)
    # (conv): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    # (bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    # )
    #
    # === Parameters ===
    # 参数名称: direct_param, 形状: torch.Size([3, 3]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: explicit_param, 形状: torch.Size([2, 2]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: xavier_init, 形状: torch.Size([5, 5]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: kaiming_init, 形状: torch.Size([5, 5]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: scalar_param, 形状: torch.Size([]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: frozen_param, 形状: torch.Size([]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: False
    # 参数名称: test_multiple_param, 形状: torch.Size([]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: test_multiple_buffer, 形状: torch.Size([]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: test_multiple_buffer_param, 形状: torch.Size([]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: test_multiple_param_buffer, 形状: torch.Size([]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: linear.weight, 形状: torch.Size([4, 8192]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: linear.bias, 形状: torch.Size([4]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: conv.weight, 形状: torch.Size([8, 3, 3, 3]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: conv.bias, 形状: torch.Size([8]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: bn.weight, 形状: torch.Size([8]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: bn.bias, 形状: torch.Size([8]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    #
    # === Buffers ===
    # 缓冲区名称: custom_buffer, 形状: torch.Size([1]), 类型: <class 'torch.Tensor'>
    # 缓冲区名称: bn.running_mean, 形状: torch.Size([8]), 类型: <class 'torch.Tensor'>
    # 缓冲区名称: bn.running_var, 形状: torch.Size([8]), 类型: <class 'torch.Tensor'>
    # 缓冲区名称: bn.num_batches_tracked, 形状: torch.Size([]), 类型: <class 'torch.Tensor'>
    #
    # === state_dict ===
    # Key: direct_param
    # Key: explicit_param
    # Key: xavier_init
    # Key: kaiming_init
    # Key: scalar_param
    # Key: frozen_param
    # Key: test_multiple_param
    # Key: test_multiple_buffer
    # Key: test_multiple_buffer_param
    # Key: test_multiple_param_buffer
    # Key: custom_buffer
    # Key: linear.weight
    # Key: linear.bias
    # Key: conv.weight
    # Key: conv.bias
    # Key: bn.weight
    # Key: bn.bias
    # Key: bn.running_mean
    # Key: bn.running_var
    # Key: bn.num_batches_tracked
    #
    # 初始化后的参数状态:
    #
    # === 模型整体结构 ===
    # RegistrationDemo(
    # (linear): Linear(in_features=8192, out_features=4, bias=True)
    # (conv): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    # (bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    # )
    #
    # === Parameters ===
    # 参数名称: direct_param, 形状: torch.Size([3, 3]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: explicit_param, 形状: torch.Size([2, 2]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: xavier_init, 形状: torch.Size([5, 5]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: kaiming_init, 形状: torch.Size([5, 5]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: scalar_param, 形状: torch.Size([]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: frozen_param, 形状: torch.Size([]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: False
    # 参数名称: dynamic_param, 形状: torch.Size([4]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: test_multiple_param, 形状: torch.Size([]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: test_multiple_buffer, 形状: torch.Size([]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: test_multiple_buffer_param, 形状: torch.Size([]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: test_multiple_param_buffer, 形状: torch.Size([]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: dynamic_param2, 形状: torch.Size([1]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: linear.weight, 形状: torch.Size([4, 8192]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: linear.bias, 形状: torch.Size([4]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: conv.weight, 形状: torch.Size([8, 3, 3, 3]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: conv.bias, 形状: torch.Size([8]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: bn.weight, 形状: torch.Size([8]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    # 参数名称: bn.bias, 形状: torch.Size([8]), 类型: <class 'torch.nn.parameter.Parameter'>, 是否可训练: True
    #
    # === Buffers ===
    # 缓冲区名称: custom_buffer, 形状: torch.Size([1]), 类型: <class 'torch.Tensor'>
    # 缓冲区名称: dynamic_buffer, 形状: torch.Size([5]), 类型: <class 'torch.Tensor'>
    # 缓冲区名称: bn.running_mean, 形状: torch.Size([8]), 类型: <class 'torch.Tensor'>
    # 缓冲区名称: bn.running_var, 形状: torch.Size([8]), 类型: <class 'torch.Tensor'>
    # 缓冲区名称: bn.num_batches_tracked, 形状: torch.Size([]), 类型: <class 'torch.Tensor'>
    #
    # === state_dict ===
    # Key: direct_param
    # Key: explicit_param
    # Key: xavier_init
    # Key: kaiming_init
    # Key: scalar_param
    # Key: frozen_param
    # Key: dynamic_param
    # Key: test_multiple_param
    # Key: test_multiple_buffer
    # Key: test_multiple_buffer_param
    # Key: test_multiple_param_buffer
    # Key: dynamic_param2
    # Key: custom_buffer
    # Key: dynamic_buffer
    # Key: linear.weight
    # Key: linear.bias
    # Key: conv.weight
    # Key: conv.bias
    # Key: bn.weight
    # Key: bn.bias
    # Key: bn.running_mean
    # Key: bn.running_var
    # Key: bn.num_batches_tracked
    #
    # 训练前 dynamic_param: tensor([-1.2286, 0.4382, 2.0483, 0.1235])
    # 训练前 dynamic_param2: tensor([-0.3353])
    # 训练前 scalar_param: tensor(0.1000)
    # 训练前 frozen_param: tensor(0.1000)
    # 训练后 dynamic_param: tensor([-1.2186, 0.4482, 2.0583, 0.1335])
    # 训练前 dynamic_param2: tensor([-0.3253])
    # 训练后 scalar_param: tensor(0.0900)
    # 训练后 frozen_param(未发生改变): tensor(0.1000)
    #
    # test_multiple_param: Parameter containing:
    # tensor(5., requires_grad=True)
    # test_multiple_buffer: Parameter containing:
    # tensor(4., requires_grad=True)

附录:loss.backward() 和 optimizer.step() 的工作流程

  • loss.backward() 负责计算梯度并存储到参数的 .grad 中(若 requires_grad = False 则不会计算梯度)
  • optimizer.step() 负责根据梯度更新参数(.grad为 None时不更新)
  • 如果在 loss.backward() 之前执行 requires_grad = False 可保证 .grad 为 None,参数不会更新
  • 如果在 loss.backward() 和 optimizer.step() 中间执行 requires_grad = False,参数会更新这一次,下次不会更新
    • 梯度计算发生在 loss.backward() 阶段 :此时参数的 requires_grad 为 True,梯度已被计算并存储在 param.grad 中
    • 优化器只检查 .grad 是否为 None :修改 requires_grad = False 不会清除已计算的梯度,因此优化器仍会使用已有的梯度更新参数
    • 后续迭代中参数被忽略:一旦 requires_grad = False,后续的 loss.backward() 将不再计算该参数的梯度,优化器也会跳过它

PyTorch——PyTorch中的高级索引


整体说明

  • PyTorch 的高级索引操作允许以非常灵活和强大的方式选择和修改张量中的元素
  • 高级索引包括整数索引、切片(slicing)、布尔索引和整数数组索引等
  • PyTorch 中的索引使用(包括基础索引和高级索引)和 Numpy 中基本一致
  • 可用于选择元素,也可以用于修改元素
  • 高级索引一般不共享存储区(普通索引一般共享存储区)
    • 普通索引一般可以通过修改 Tensor 的偏移量(offset)、步长(stride)或形状实现,不需要修改存储区的数据(使用共享存储区可以节省内存和处理速度)
    • 高级索引则一般都是不规则的变化 ,需要修改存储区,故而不使用共享存储区
    • 这也是高级索引与切片的最大差别
  • 检索维度匹配要求 :多个索引数组的维度必须能够广播成一致的形状,否则报错
  • 高级索引的判定方式:
    • 在 PyTorch 中,当索引对象是一个非元组序列对象、一个 Tensor(数据类型为整数或布尔,在 NumPy 中为 ndarray),或一个至少包含一个序列对象或 Tensor(数据类型为整数或布尔,在 NumPy 中为 ndarray)的元组时,会触发高级索引判定

基础索引回顾

  • 在深入高级索引之前,我们先快速回顾一下基础索引:
    • 单个整数索引 : 选取特定位置的单个元素
    • 切片 : 选取连续的子范围。例如 a[start:end:step]
    • 省略号 (...) : 表示选择所有剩余的维度
      • 这在张量维度较多时非常有用,可以避免写出冗长的 :
  • 基础索引的 Demo:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    import torch
    a = torch.arange(27).reshape(3, 3, 3)
    print("原始张量 a:\n", a)
    # 输出:
    # tensor([[[ 0, 1, 2],
    # [ 3, 4, 5],
    # [ 6, 7, 8]],
    #
    # [[ 9, 10, 11],
    # [12, 13, 14],
    # [15, 16, 17]],
    #
    # [[18, 19, 20],
    # [21, 22, 23],
    # [24, 25, 26]]])

    print("a[0, 1, 2]:", a[0, 1, 2]) # 输出:tensor(5)(第一个维度第0个,第二个维度第1个,第三个维度第2个)
    print("a[1, :, 0]:\n", a[1, :, 0]) # 输出:tensor([ 9, 12, 15])(第一个维度第1个,第二个维度所有,第三个维度第0个)
    print("a[..., 1]:\n", a[..., 1]) # 选取所有维度,最后一个维度第1个,等价于 a[:,:,1]
    # 输出:
    # tensor([[ 1, 4, 7],
    # [10, 13, 16],
    # [19, 22, 25]])

高级索引-整数数组索引 (Integer Array Indexing)

  • 当使用 torch.Tensor (类型为 torch.long 或 torch.int) 或 Python 列表作为索引时,这被称为整数数组索引(高级索引的一种)

    • 第一步:广播 : 如果有多个整数数组索引,并且它们的形状不同,PyTorch 会尝试对它们进行广播(broadcasting),广播为相同形状
    • 第二步:返回结果 : 整数数组索引的返回张量的形状由广播后的索引张量的形状决定 ,当部分维度没有被索引时,默认保留该维度的所有值
  • 示例 1 : 在一个维度上使用整数数组索引

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    import torch

    a = torch.arange(12).reshape(3, 4)
    print("原始张量 a:\n", a)
    # 输出:
    # tensor([[ 0, 1, 2, 3],
    # [ 4, 5, 6, 7],
    # [ 8, 9, 10, 11]])

    # 选取第0行和第2行
    indices = torch.tensor([0, 2])
    result = a[indices]
    print("\na[torch.tensor([0, 2])]:\n", result)
    # 输出:
    # tensor([[ 0, 1, 2, 3],
    # [ 8, 9, 10, 11]])

    # 选取第1列和第3列
    result = a[:, torch.tensor([1, 3])]
    print("\na[:, torch.tensor([1, 3])]:\n", result)
    # 输出:
    # tensor([[ 1, 3],
    # [ 5, 7],
    # [ 9, 11]])
  • 示例 2: 在多个维度上使用整数数组索引(“坐标”式选择)

    • 注:当你在多个维度上同时使用整数数组索引时,它们会被解释为坐标对(这与 NumPy 的行为非常相似)
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      import torch

      a = torch.arange(12).reshape(3, 4)
      print("原始张量 a:\n", a)

      # 选取 (0, 1), (1, 3), (2, 0) 这些位置的元素
      row_indices = torch.tensor([0, 1, 2])
      col_indices = torch.tensor([1, 3, 0])
      result = a[row_indices, col_indices]
      print("\na[torch.tensor([0, 1, 2]), torch.tensor([1, 3, 0])]:", result)
      # 解释:
      # 结果的第一个元素是 a[0, 1]
      # 结果的第二个元素是 a[1, 3]
      # 结果的第三个元素是 a[2, 0]
      # 输出: tensor([ 1, 7, 8])

      # 索引张量形状不一致时会进行广播
      a = torch.arange(27).reshape(3, 3, 3)
      print("\n原始张量 a (3x3x3):\n", a)

      # 选取多个坐标,例如:
      # 维度0取索引0和2
      # 维度1取索引1和2
      # 维度2取索引0和1
      idx0 = torch.tensor([[0], [2]]) # shape (2, 1)
      idx1 = torch.tensor([1, 2]) # shape (2,)
      idx2 = torch.tensor([0, 1]) # shape (2,)

      # PyTorch 会尝试广播这些索引
      # idx0: [[0], [2]] -> 广播为 [[0, 0], [2, 2]] (因为 idx1 和 idx2 的大小是2)
      # idx1: [1, 2] -> 广播为 [[1, 2], [1, 2]]
      # idx2: [0, 1] -> 广播为 [[0, 1], [0, 1]]

      # 最终会选择以下坐标的元素:
      # (0, 1, 0), (0, 2, 1)
      # (2, 1, 0), (2, 2, 1)
      result = a[idx0, idx1, idx2]
      print("\na[idx0, idx1, idx2]:\n", result)
      # 输出:
      # tensor([[ 3, 5], # a[0,1,0], a[0,2,1]
      # [21, 23]]) # a[2,1,0], a[2,2,1]

布尔索引 (Boolean Indexing)

  • 当使用一个布尔张量作为索引时,PyTorch 会选择布尔张量中值为 True 的所有元素
    • 形状要求 :布尔张量的形状必须与被索引张量的一个或多个维度匹配
    • 返回结果 :布尔张量索引的结果张量通常是 1 维的张量,包含所有满足条件(True)的元素
    • 特别说明:布尔张量通常被称为“掩码”
  • 布尔张量的 Demo:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    import torch

    a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    print("原始张量 a:\n", a)

    # 选取所有大于 5 的元素
    mask = a > 5
    print("\n布尔掩码 (a > 5):\n", mask)
    # 输出:
    # tensor([[False, False, False],
    # [False, False, True],
    # [ True, True, True]])

    result = a[mask]
    print("\na[a > 5]:", result)
    # 输出: tensor([6, 7, 8, 9])
    # 注意以上输出是 1 维的

混合索引操作:a[index1, :, index2] 和 a[index1, index2, :]

  • 特别说明广播机制:当索引数组的维度不匹配时,PyTorch 会尝试运用广播规则来使它们的维度变得兼容,即使混合索引也适用,详情如下:

    1
    2
    3
    4
    5
    # 行索引是[0, 2],列索引对所有行都是0
    rows = torch.tensor([0, 2]) # 形状为(2,)
    cols = torch.tensor([0]) # 形状为(1,)
    result = a[rows, :, cols] # 广播后形状变为(2, 3, 1)
    print(result.shape) # 输出: torch.Size([2, 3, 1])
  • 关于索引后的形状:

    • 若高级索引没被隔离,则正常覆盖高级索引所在的维度区域
    • 若高级索引被隔离,则对齐维度后,高级索引需要合并,并放到最前面(注意:不是第一个高级索引所在的位置,而是从第 0 维开始的最前面几维)
  • 这种索引方式是在多个维度上混合使用高级索引和切片操作

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    import torch
    a = torch.arange(27).reshape(3, 3, 3)
    print(a)
    # 输出:
    # tensor([[[ 0, 1, 2],
    # [ 3, 4, 5],
    # [ 6, 7, 8]],
    # [[ 9, 10, 11],
    # [12, 13, 14],
    # [15, 16, 17]],
    # [[18, 19, 20],
    # [21, 22, 23],
    # [24, 25, 26]]])

    index1 = torch.tensor([[0, 2]]) # 使用 [[0,2]] 也可以
    index2 = torch.tensor([[0]]) # 1)使用 [[0]] 也可以;2)[0] 也可以得到相同结果,因为广播后结果是一样的
    result = a[index1, index2, :]
    print(result.shape) # 输出: torch.Size([1, 2, 3])
    print(result)
    # 输出:
    # tensor([[[ 0, 1, 2],
    # [18, 19, 20]]])
    # 理解:index1和index2广播以后是shape=[1,2],结果中,第0,1维的[3,3]会被[1,2]替换

    index1 = torch.tensor([[0, 2]]) # 使用 [[0,2]] 也可以
    index2 = torch.tensor([[0]]) # 1)使用 [[0]] 也可以;2)[0] 也可以得到相同结果,因为广播后结果是一样的
    result = a[index1, :, index2]
    print(result.shape) # 输出: torch.Size([1, 2, 3])
    print(result)
    # 输出:
    # tensor([[[ 0, 3, 6],
    # [18, 21, 24]]])
    # 理解:相当于先调用 a.transpose_(2,1) 对齐索引维度,然后再调用 a[index1,index2,:],因为 高级索引需要合并到一起去广播并放到最前面
    # # 也可以用permute替代transpose,但permute不能inplace
    # 测试下面的代码替换 result = a[index1, :, index2] 后与上述输出一致
    # a.transpose_(2,1)
    # result = a[index1, index2, :]


    index1 = torch.tensor([0, 2])
    index2 = torch.tensor([0, 2])
    result = a[index1, :, index2]
    print(result.shape) # 输出: torch.Size([2, 3])
    print(result)
    # 输出:
    # tensor([[[ 0, 3, 6],
    # [20, 23, 26]]])
    # tensor([[ 0, 3, 6],
    # [20, 23, 26]])
    # 理解1:index1和index2广播以后是shape=[2],结果中,第0,1维的[3,3]会被[2]替换
    # 理解2:相当于先调用 a.transpose_(2,1) 对齐索引维度,然后再调用 a[index1,index2,:],因为 高级索引需要合并到一起去广播并放到最前面
    # 说明:已经测试下面的代码替换 result = a[index1, :, index2] 后与上述输出一致
    # a.transpose_(2,1)
    # result = a[index1, index2, :]
  • 高级索引被隔离的补充实验 Demo:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    import torch
    a = torch.arange(3*4*5*6).reshape(3, 4, 5, 6)
    # print(a)

    index1 = torch.tensor([[0, 2]])
    index2 = torch.tensor([[0, 2]])
    result = a[:, index1, :, index2]
    print(result.shape) # 输出:torch.Size([1, 2, 3, 3])
    print(result)
    # 输出:
    # tensor([[[[ 0, 6, 12, 18, 24],
    # [120, 126, 132, 138, 144],
    # [240, 246, 252, 258, 264]],
    #
    # [[ 62, 68, 74, 80, 86],
    # [182, 188, 194, 200, 206],
    # [302, 308, 314, 320, 326]]]])

    index1 = torch.tensor([[0, 2]])
    index2 = torch.tensor([[0, 2]])
    a = a.permute(1,3,0,2)
    result = a[index1,index2,:,:]
    print(result.shape) # 输出:torch.Size([1, 2, 3, 3])
    print(result)
    # 输出:
    # tensor([[[[ 0, 6, 12, 18, 24],
    # [120, 126, 132, 138, 144],
    # [240, 246, 252, 258, 264]],
    #
    # [[ 62, 68, 74, 80, 86],
    # [182, 188, 194, 200, 206],
    # [302, 308, 314, 320, 326]]]])

    # 理解1:高级索引index1和index2被基础索引隔断了,index1和index2广播后为shape=[1,2],然后高级索引的维度会替换所在维度的[3,3]后放到最前面
    # 理解2:相当于先执行 a = a.permute(1,3,0,2)(注意permute操作不能inplace),然后再执行 result = a[index1,index2,:,:]
    # 测试说明:从上面的代码输出可以看到,result = a[:, index1, :, index2] 替换为下面语句后,结果一致:
    # a = a.permute(1,3,0,2)
    # result = a[index1,index2,:,:]

附录:高级索引的一些常见用法

批量数据选择

  • 在处理批量数据时,可以利用高级索引为每个样本选择不同的元素
    1
    2
    3
    4
    5
    6
    7
    8
    batch_size = 4
    features = 10
    data = torch.randn(batch_size, features)

    # 为每个样本选择不同的特征
    indices = torch.tensor([2, 5, 1, 8]) # 为4个样本分别选择第2、5、1、8个特征
    selected = data[torch.arange(batch_size), indices]
    print(selected.shape) # 输出: torch.Size([4])

并行更新

  • 借助高级索引,能够并行地更新张量中的多个位置
    1
    2
    3
    4
    5
    6
    7
    x = torch.zeros(5, 5)
    rows = torch.tensor([0, 1, 2])
    cols = torch.tensor([1, 2, 3])
    values = torch.tensor([10, 20, 30])

    x[rows, cols] = values # 将(0,1)、(1,2)、(2,3)位置的值分别更新为10、20、30
    print(x)

附录:高级索引出发条件总结

  • 在 PyTorch 中,当索引对象是一个非元组序列对象、一个 Tensor(数据类型为整数或布尔,在 NumPy 中为 ndarray),或一个至少包含一个序列对象或 Tensor(数据类型为整数或布尔,在 NumPy 中为 ndarray)的元组时,会触发高级索引判定

  • 以下是会触发高级索引判定的情况

  • 索引对象为单个高维数组或张量 :索引对象不是一个元组序列,而是一个高维数组或者张量,其中布尔型比整数型更常见

    • 举例:

      1
      2
      x = torch.arange(24)
      y = x.reshape(6, 4)[x.reshape(6, 4) > 10] # y.shape = torch.Size([13]),以为数组
    • 这里 torch.arange(24) > 10 是布尔型张量,会触发高级索引

  • 索引对象为整数型数组或张量组成的元组序列 :索引对象是一个元组序列,并且元组序列完全由整数型高维数组或者整数型张量组成

    • 举例:

      1
      2
      3
      4
      x = torch.randn(3, 4)
      rows = torch.tensor([0, 2])
      cols = torch.tensor([1, 3])
      y = x[(rows, cols)] # y.shape = torch.Size([2])
    • 其中 (rows, cols) 构成的元组序列就是由整数型张量组成,会触发高级索引

  • 索引对象为列表序列组成的元组序列 :索引对象是一个元组序列,并且元组序列完全由列表序列组成

    • 举例:

      1
      2
      x = torch.randn(3, 4)
      y = x[([0, 2], [1, 3])] # y.shape = torch.Size([2])
    • 这里 ([0, 2], [1, 3]) 是由 列表序列组成的元组 ,会触发高级索引

  • 索引对象为混合组成的元组序列(包含数组或张量与序列对象) :索引对象是一个元组序列,元组序列不仅包含高维整数型数组或者高维整数型张量,还包括序列对象

    • 举例:

      1
      2
      3
      x = torch.randn(3, 4)
      rows = torch.tensor([0, 2])
      y = x[(rows, [1, 3])] # y.shape = torch.Size([2])
    • 此元组序列中既有整数型张量rows,又有列表[1, 3],会触发高级索引

  • 索引对象为混合组成的元组序列(包含数组或张量与整数标量) :索引对象是一个元组序列,元组序列不仅包含高维整数型数组或者高维整数型张量,还包括整数标量

    • 举例:

      1
      2
      3
      x = torch.randn(3, 4)
      rows = torch.tensor([0, 2])
      y = x[(rows, 2)] # y.shape = torch.Size([2])
    • 元组序列中包含整数型张量 rows 和整数标量2,会触发高级索引

  • 索引对象为混合组成的元组序列(包含数组、张量、标量和序列对象) :索引对象是一个元组序列,元组序列包含高维整数型数组或者高维整数型张量、整数标量和序列对象

    • 举例:

      1
      2
      3
      x = torch.randn(3, 4, 5)
      rows = torch.tensor([0, 2])
      y = x[(rows, 2, [1, 3])] # y.shape = torch.Size([2])
    • 这种情况同样会触发高级索引

PyTorch——compile函数的理解和使用


整体说明

  • torch.compile() 函数是 PyTorch 2.0 引入的一个重要功能,用于对模型进行编译优化,以提升训练和推理性能
    • 将 PyTorch 模型从“解释型”的逐行执行模式,转变为“编译型”的、一次性优化的执行模式
    • 当模型被编译后,它在后续的推理或训练中会运行得更快,并且使用的内存可能更少
  • torch.compile() 的核心作用是通过对模型计算图进行一系列优化(如算子融合、常量折叠、内存优化等),生成更高效的代码,从而加速模型的执行
    • 这个过程是自动的,并且大部分时间是无侵入性的,不需要修改模型的代码
  • torch.compile() 效果提升:
    • 因模型结构和硬件而异,通常对于有大量小操作(如 Transformer 模型)或对 GPU 算力要求高的模型效果更显著
    • 在大规模训练和推理场景中效果显著
  • 首次运行编译后的模型会有一定的启动延迟,用于图优化和代码生成(compile 函数是惰性执行的),但后续重复调用会更快
  • 编译过程可能会增加模型的内存占用,需根据实际情况调整

调用 torch.compile(model) 后会发生什么?

  • 具体来说,调用 torch.compile(model) 会发生以下过程:
  • 1)捕获模型计算图(Graph Capturing) :分析模型的前向传播逻辑,记录张量的操作序列和依赖关系,构建计算图表示
    • 当你第一次调用 torch.compile 编译后的模型时,它会追踪模型的前向传播过程
    • 这就像在记录模型中的每一步操作,比如矩阵乘法、卷积、激活函数等
    • PyTorch 会创建一个计算图(computational graph) ,这个图代表了模型从输入到输出的所有计算路径
    • 这个过程是惰性的,即只在第一次实际运行模型时发生(所以第一次调用比较慢,后面会比较快)
    • 计算图被捕获后,编译器会对计算图进行一系列优化(接下面)
  • 2)图优化(Graph Optimization) :包含一系列优化,例如:
    • 算子融合(Operator Fusion) : 将多个连续的小算子合并为一个大算子(如将卷积+批归一化+激活函数融合),减少 kernel 调用次数和内存读写
    • 常量折叠(Constant Folding) :计算图中固定不变的常量表达式会被预先计算,避免重复计算
    • 死代码消除(Dead Code Elimination) : 移除计算图中未被使用的节点或操作,减少不必要计算
    • 内存优化(Memory Optimization) : 优化张量的内存分配和复用,重新安排计算顺序,以减少中间结果所需的内存,从而更好地利用缓存
    • 优化后的计算图会用来生成高效的、针对特定硬件(如 GPU)的可执行代码(具体见下节)
  • 3)代码生成 :
    • 编译器将高级的 PyTorch 操作转换成底层的、更接近硬件指令的特定后端的代码(如 CPU 上的 C++/OpenMP 代码,GPU 上的 CUDA 代码)
      • 例如,它可能会将 PyTorch 的张量操作转换成 CUDA 内核
    • 支持多种后端(如 inductor、aot_eager 等),默认使用 inductor 后端(也叫做 TorchInductor 后端),它能生成高效的 GPU/CPU 代码
  • 4)返回编译后的模型 :
    • torch.compile(model) 返回一个经过包装的模型对象,其接口与原模型一致(可直接调用 forward 方法或进行训练),但内部执行的是优化后的代码

使用示例

  • 使用非常简单,仅需一行代码
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    import torch
    import torch.nn as nn

    class DiyModel(nn.Module):
    def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3, 32, kernel_size=3)
    self.relu = nn.ReLU()

    def forward(self, x):
    return self.relu(self.conv(x))

    model = DiyModel()
    # 编译模型
    compiled_model = torch.compile(model)

    # 使用编译后的模型(接口与原模型万全一致)
    x = torch.randn(1, 3, 224, 224)
    output = compiled_model(x)

PyTorch——einsum函数使用


整体说明:

  • torch.einsum 是 PyTorch 中一个非常强大且灵活的函数,用于执行基于爱因斯坦求和约定(Einstein summation convention)的张量运算
    • 通过这种约定,你可以简洁地表示复杂的多维数组操作,如矩阵乘法、转置、点积等,而不需要显式地编写循环
  • torch.einsum 提供了极大的灵活性来处理各种复杂的张量运算,但需要注意的是,不恰当的使用可能导致性能下降,因为它可能会隐藏潜在的优化机会(比如矩阵乘法建议直接调用矩阵乘法)
  • torch.einsum的本质就是爱因斯坦求和约定(一种爱因斯坦发表论文中提到的表达式省略写法),是矩阵乘法的一种表示
    • 比如 C = torch.einsum("m,d->nd", A,B)表示矩阵 C[m,d] = A[n]*B[d],这是个很常用的省略写法

基本用法

  • torch.einsum 的基本语法如下:

    1
    torch.einsum(equation, *operands)
    • equation:一个字符串,指定了输入张量的下标标签以及输出张量的计算方式
    • operands:可变数量的张量参数,它们将根据 equation 进行运算
  • 在 equation 参数中:

    • 每个输入张量的维度由字母标记,不同张量之间使用逗号分隔(亲测字符串中间可以随便加空格,不影响最终结果)
      • 相同字母只能表示相同维度Size,但是相同维度Size不要求必须相同字母
      • 下标数量和输入矩阵维度数量一定要对齐
    • 输出的维度是在箭头 -> 右侧指定,如果没有指定输出,则自动推断(注意,需要自动推断的场景是没有 -> 的场景,有->的场景不需要推断,->后面为空时表示输出是一个标量)
    • 推断思路是:重复下标都去掉,不重复下边按照顺序保留
      • ij,jk == ij,jk->ik
      • j,j == j,j->
      • ijd,jk == ijd,jk->idk
      • idi,jk == idi,jk->djk
  • 举例:给定两个二维张量 A 和 B,要进行矩阵乘法并求和可以写作:

    1
    torch.einsum('ij,jk->ik', A, B) # 等价于 torch.einsum('ij,jk', A, B)
    • 'ij' 表示张量 A 的维度,
    • 'jk' 表示张量 B 的维度,
    • 'ik' 表示输出张量的维度
    • 重复的下标(在这个例子中的 j)意味着沿着这些维度进行乘积和求和(后面会详细讲解)
  • 后面会有详细讲解


爱因斯坦求和约定讲解

  • 以 看图学 AI:einsum 爱因斯坦求和约定到底是怎么回事? 中的一个例子为例,torch.einsum('ij,jk->ik', A, B) 求解过程相当于下面的图片展示的形式:

  • 爱因斯坦求和约定的基本理解:

    • 对于任意的表达式,结果都等价于一个多重循环乘积(可能包含求和)的过程:
      • 输入:函数参数,字符串 -> 左边表示矩阵的输入下标
      • 输出:返回值,字符串 -> 右边表示矩阵的输出下标
      • 右边的下标有时候可以省略,此时需要自动推断
  • C = torch.einsum("ij,jk->ijk", A,B) 结果相当于下面的函数(性能上并不想当,因为下面是一种速度较慢的函数)

    1
    2
    3
    4
    for i in range(A[0]):
    for j in range(A[1]): # 也可以用 range(B[0])
    for k in range(B[1]):
    C[i][j][k] = A[i][j] * B[j][k]
  • C = torch.einsum("ij,jk->ik", A,B) 结果相当于下面的函数

    1
    2
    3
    4
    for i in range(A[0]):
    for j in range(A[1]): # 也可以用 range(B[0])
    for k in range(B[1]):
    C[i][k] += A[i][j] * B[j][k] # C[i][k]从0开始累加
  • C = torch.einsum("ii->i", A) 结果相当于下面的函数

    1
    2
    for i in range(A[0]): # 也可用A[1]
    C[i] = A[i][i]
  • 实际上,所有的表达式都可以表达成同一个形式,只要初始化结果的各个元素为0 ,然后统一使用加法即可


附录:一些简单示例

  • 矩阵乘法 :

    1
    2
    3
    A = torch.randn(3, 4)
    B = torch.randn(4, 5)
    C = torch.einsum('ij,jk->ik', A, B) # 等价于 torch.mm(A, B),C[i][k] += A[i][j] * B[j][k]
  • 向量内积 :

    1
    2
    3
    u = torch.randn(3)
    v = torch.randn(3)
    C = torch.einsum('i,i->', u, v) # 等价于 torch.dot(u, v),C += A[i] * B[i]
  • 张量转置 :

    1
    2
    A = torch.randn(2, 3, 4)
    C = torch.einsum('ijk->kji', A) # 转置张量 permute(2,1,0),C[k][j][i] += A[i][j][k]
  • 批量矩阵乘法 :

    1
    2
    3
    A = torch.randn(3, 2, 5)
    B = torch.randn(3, 5, 4)
    C = torch.einsum('bij,bjk->bik', A, B) # 对每一批次做矩阵乘法, C[b][i][k] += A[b][i][k] * B[b][k][k]
  • 乘法+求和 :

    1
    2
    3
    A = torch.randn(3, 4)
    B = torch.randn(4, 5)
    C = torch.einsum('ij,jk->', A, B) # 输出为一个标量 C += A[i][j] * B[j][k]

附录:高阶用法之省略号

  • 省略号容易误解,不建议使用!

附录:einops 包

  • 除了爱因斯坦求和(包含在 torch 包中)外,还有许多相似的爱因斯坦操作(包含在 einops 包中)

  • einops 包中的函数支持跨框架的数据格式,不仅仅是 PyTorch,比如 NumPy,TensorFlow 等

  • 最常见的有 rearrange 函数,可用于做下面的工作:

    • 对张量做更高阶的 reshape 或 view 操作
    • 对张量做 permute 或 transpose 操作
  • 特别地,爱因斯坦操作还有些对应的网络层,如 rearrange 函数对应的 einops.layers.torch.Rearrange 类可以像 nn.Linear 或 nn.ReLU 一样加入到 nn.Sequential() 中作为一个网络模块使用

  • rearrange 函数的简单示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    import torch
    import einops

    A = torch.randn(6,2,3,4)
    B = einops.rearrange(A, 'b c d e -> b (c d e)') # 等价于 A.reshape(6, -1)
    print(B.shape)
    # 输出:torch.Size([6, 24])

    A = torch.randn(6,2,3,4)
    B = A.reshape(6, -1)
    print(B.shape)
    # 输出:torch.Size([6, 24])

    A = torch.randn(6,2,3,4)
    B = einops.rearrange(A, 'b c d e -> b e d c') # 等价于 A.permute(0, 3, 2, 1)
    print(B.shape)
    # 输出:torch.Size([6, 4, 3, 2])

    A = torch.randn(6,2,3,4)
    B = A.permute(0, 3, 2, 1)
    print(B.shape)
    # 输出:torch.Size([6, 4, 3, 2])

    # 更复杂的用法
    A = torch.randn(2, 3, 9, 8)
    B = einops.rearrange(A, 'b c (d1 d2) (e1 e2) -> b c (d1 e1) d2 e2', d1=3, e1=2)
    print(B.shape)
    # 输出:torch.Size([2, 3, 6, 3, 4])
  • reduce 函数的简单示例(可选的 reduction 参数有 'sum','max','min','mean' 等):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    import torch
    import einops

    A = torch.randn(6,2,3,4)
    B = einops.reduce(A, 'b c d e ->b c d', reduction='sum') # 等价于 A.sum(-1)
    print(B.shape)
    # 输出:torch.Size([6, 2, 3])

    A = torch.randn(6,2,3,4)
    B = einops.reduce(A, 'b c d e ->b c', reduction='sum') # 等价于 A.sum(-1).sum(-1)
    print(B.shape)
    # 输出:torch.Size([6, 2])

    A = torch.randn(6,2,3,4)
    B = einops.reduce(A, 'b c d e ->b c 1 1', reduction='mean') # 等价于 A.sum(-1).sum(-1)
    print(B.shape)
    # 输出:torch.Size([6, 2, 1, 1])

PyTorch——gather函数使用

  • 参考链接:
    • 一篇浅显易懂的博客:PyTorch中的高级索引方法——gather详解

gather函数形式

  • 包含 torch.gather 和 tensor.gather 两种形式,基本思路等价,他们的函数签名如下

    1
    2
    torch.gather(input, dim, index, *, sparse_grad=False, out=None)
    tensor.gather(dim, index, *, sparse_grad=False, out=None)
    • 注:在 PyTorch 的函数签名中,* 是 Python 语法中的一个特殊标记,用于表示强制关键字参数(keyword-only arguments)。这意味着在 * 之后的所有参数(如 sparse_grad 和 out)必须通过关键字(即使用参数名)来传递,而不能通过位置参数传递
  • 参数解释

    • input (Tensor): 输入张量
    • dim (int): 沿着哪个维度进行收集
    • index (LongTensor): 索引张量,包含要收集的元素的索引
    • sparse_grad (bool, 可选): 如果为True,梯度将是稀疏张量
    • out (Tensor, 可选): 输出张量(若该值不为 None,则会将返回值存储到 out 引用中,此时 out 和 返回值 是同一个对象)
  • 特别说明:gather 和普通的矩阵索引操作一样,操作支持反向传播


基本原理

  • 对于 3D 张量,gather操作可以表示为:

    1
    2
    3
    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
    out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
    out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
  • 特别说明(记忆方法)

    • index[i][j][k] 用于索引 out[i][j][k]
    • dim=0 :index 指定行号(第0维) ,列号(第1维)和通道(第2维)保持不变
    • dim=1 :index 指定列号(第1维) ,行号(第0维)和通道(第2维)保持不变
    • dim=2 :index 指定通道(第2维) ,行号(第0维)和列号(第1维)保持不变
    • 特别地:输出形状 = index 形状

代码示例

  • 一维张量示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    import torch

    # 基本用法
    input_tensor = torch.tensor([10, 20, 30, 40, 50])
    index_tensor = torch.tensor([0, 2, 4, 1])
    result = torch.gather(input_tensor, dim=0, index=index_tensor)
    print(result) # tensor([10, 30, 50, 20])

    # 使用tensor.gather方法
    result2 = input_tensor.gather(0, index_tensor)
    print(result2) # tensor([10, 30, 50, 20])
  • 二维张量示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    # 二维张量
    input_2d = torch.tensor([[1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]])

    # 沿着dim=0收集(按行收集)
    index_2d = torch.tensor([[0, 1, 2],
    [2, 0, 1]])
    result_dim0 = torch.gather(input_2d, dim=0, index=index_2d)
    print("dim=0 结果:")
    print(result_dim0)
    # tensor([[1, 5, 9],
    # [7, 2, 6]])

    # 沿着dim=1收集(按列收集)
    index_2d_col = torch.tensor([[0, 2],
    [1, 0],
    [2, 1]])
    result_dim1 = torch.gather(input_2d, dim=1, index=index_2d_col)
    print("dim=1 结果:")
    print(result_dim1)
    # tensor([[1, 3],
    # [5, 4],
    # [9, 8]])

一些实际应用场景

  • 获取最大值,获取对应标签等实例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    # 场景1: 获取每行的最大值索引对应的值
    scores = torch.tensor([[0.1, 0.8, 0.3],
    [0.6, 0.2, 0.9],
    [0.4, 0.7, 0.1]])

    # 获取每行最大值的索引
    max_indices = torch.argmax(scores, dim=1, keepdim=True)
    print("最大值索引:", max_indices) # tensor([[1], [2], [1]])

    # 使用gather获取最大值
    max_values = torch.gather(scores, dim=1, index=max_indices)
    print("最大值:", max_values) # tensor([[0.8], [0.9], [0.7]])

    # 场景2: 根据标签获取对应的预测概率
    predictions = torch.tensor([[0.2, 0.3, 0.5],
    [0.1, 0.8, 0.1],
    [0.6, 0.2, 0.2]])
    labels = torch.tensor([2, 1, 0]) # 真实标签

    # 获取每个样本对应标签的预测概率
    label_probs = torch.gather(predictions, dim=1, index=labels.unsqueeze(1))
    print("标签概率:", label_probs) # tensor([[0.5], [0.8], [0.6]])

注意事项

  • 索引值必须在 [0, input.size(dim)) 范围内
  • 除了指定的dim维度外,input 和 index 的其他维度大小必须相同
  • 输出张量的形状与 index 张量相同

PyTorch——backward函数详细解析

本文主要介绍PyTorch中backward函数和grad的各种用法


梯度的定义

  • \(y\) 对 \(x\) 的梯度可以理解为: 当 \(x\) 增加1的时候, \(y\) 值的增加量
  • 如果 \(x\) 是矢量(矩阵或者向量等),那么计算时也需要看成是多个标量的组合来计算,算出来的值表示的也是 \(x\) 当前维度的值增加1的时候, \(y\) 值的增加量

backward基础用法

  • tensorflow是先建立好图,在前向过程中可以选择执行图的某个部分(每次前向可以执行图的不同部分,前提是,图里必须包含了所有可能情况)
  • pytorch是每次前向过程都会重新建立一个图,反向(backward)的时候会释放,每次的图可以不一样, 所以在Pytorch中可以随时使用if, while等语句
    • tensorflow中使用if, while就得在传入数据前(构建图时)告诉图需要构建哪些逻辑,然后才能传入数据运行
    • PyTorch中由于不用在传入数据前先定义图(图和数据一起到达,图构建的同时开始计算数据?)
  • backward操作会计算梯度并将梯度直接加到变量的梯度上,所以为了保证梯度准确性,需要使用optimizer.zero_grad()清空梯度

计算标量对标量的梯度

  • 结构图如下所示

  • 上面图的代码构建如下

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    import torch
    from torch.autograd import Variable

    w1 = Variable(torch.Tensor([2]),requires_grad=True)
    w2 = Variable(torch.Tensor([3]),requires_grad=True)
    w3 = Variable(torch.Tensor([5]),requires_grad=True)
    x = w1 + w2
    y = w2*w3
    z = x+y
    z.backward()
    print(w1.grad)
    print(w2.grad)
    print(w3.grad)
    print(x.grad)
    print(y.grad)

    # output:
    tensor([1.])
    tensor([6.])
    tensor([3.])
    None
    None
    • 从图中的推导可知,梯度符合预期
    • \(x, y\) 不是叶节点,没有梯度存储下来,注意可以理解为梯度计算了,只是没有存储下来,PyTorch中梯度是一层层计算的

计算标量对矢量的梯度

  • 修改上面的构建为

    • 增加变量 \(s = z.mean\),然后直接求取 \(s\) 的梯度
  • 结构图如下:

  • 代码如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    import torch
    from torch.autograd import Variable

    w1 = Variable(torch.ones(2,2)*2,requires_grad=True)
    w2 = Variable(torch.ones(2,2)*3,requires_grad=True)
    w3 = Variable(torch.ones(2,2)*5,requires_grad=True)
    x = w1 + w2
    y = w2*w3
    z = x+y
    # z.backward()
    s = z.mean()
    s.backward()
    print(w1.grad)
    print(w2.grad)
    print(w3.grad)
    print(x.grad)
    print(y.grad)
    # output:
    tensor([[0.2500, 0.2500],
    [0.2500, 0.2500]])
    tensor([[1.5000, 1.5000],
    [1.5000, 1.5000]])
    tensor([[0.7500, 0.7500],
    [0.7500, 0.7500]])
    None
    None
    • 显然推导结果符合代码输出预期
    • 梯度的维度与原始自变量的维度相同,每个元素都有自己对应的梯度,表示当当前元素增加1的时候, 因变量值的增加量

计算矢量对矢量的梯度

  • 还以上面的结构图为例

  • 直接求中间节点 \(z\) 关于自变量的梯度

  • 代码如下

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    import torch
    from torch.autograd import Variable

    w1 = Variable(torch.ones(2,2)*2, requires_grad=True)
    w2 = Variable(torch.ones(2,2)*3, requires_grad=True)
    w3 = Variable(torch.ones(2,2)*5, requires_grad=True)
    x = w1 + w2
    y = w2*w3
    z = x+y
    z_w1_grad = torch.autograd.grad(outputs=z, inputs=w1, grad_outputs=torch.ones_like(z))
    print(z_w1_grad)
    • 在因变量是矢量时,grad_outputs参数不能为空,标量时可以为空(grad_outputs为空时和grad_outputs维度为1时等价)
    • grad_outputs的维度必须和outputs参数的维度兼容

关于autograd.grad函数

  • 参考博客: https://blog.csdn.net/qq_36556893/article/details/91982925
grad_outputs参数详解
  • 在因变量是矢量时,grad_outputs参数不能为空,标量时可以为空(grad_outputs为空时和grad_outputs维度为1时等价)
  • grad_outputs的维度必须和outputs参数的维度兼容
    [待更新]

backward 对计算图的清空

  • 在 PyTorch 中,backward() 方法在计算反向梯度后,不仅不会保存中间节点(非叶子张量)的梯度,还会释放对应的计算图

    • 所以,多次调用 backward() 方法是会报错的 RuntimeError: Trying to backward through the graph a second time
  • 如果想要多次调用 backward() 方法,需要使用 backward(retain_graph=True)

  • 多次调用 backward() 的错误示例:

    1
    2
    3
    4
    5
    6
    7
    8
    import torch

    x = torch.tensor(2.0, requires_grad=True)
    y = x**3
    z = y + 1

    z.backward() # 第一次调用,正常运行
    z.backward() # 报错:RuntimeError: Trying to backward through the graph a second time...
  • 多次调用 backward() 的正确示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    import torch

    x = torch.tensor(2.0, requires_grad=True)
    y = x**3
    z = y + 1

    z.backward(retain_graph=True) # 第一次调用,保留计算图
    print(x.grad) # 输出:tensor(12.)

    z.backward() # 第二次调用,此时可以不保留计算图
    print(x.grad) # 输出:tensor(24.),梯度累加了

自定义算子的前向后向实现

  • PyTorch 中,自定义算子适用于一些 PyTorch 内置函数无法满足的场景:
    • 实现个性化运算算子:当需要使用 PyTorch 未内置的数学运算(如特殊激活函数、自定义卷积等),且需支持自动求导时
    • 优化计算效率:对特定操作手动编写前向 / 反向逻辑,可能比内置函数更高效(如简化冗余计算)
    • 整合外部库:可用于将 C++、CUDA 或其他语言实现的算法接入 PyTorch,并支持自动求导
  • 自定义算子是 Function 类需继承 torch.autograd.Function,并实现两个静态方法 :
    • forward():定义前向传播逻辑,输入为张量,输出为计算结果
    • backward():定义反向传播梯度计算逻辑,输入为上游梯度,输出为各输入的梯度
  • PyTorch 自定义实现代码 Demo:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    import torch

    class SquareSumFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
    """
    :param ctx: 可用于保存需要在反向传播中使用的张量,torch.autograd.function.BackwardCFunction 的实例
    :param x: 输入张量
    :param y: 输入张量
    :return: 输出张量
    """
    ctx.save_for_backward(x, y) # ctx 用于存储反向传播所需的中间变量
    output = x **2 + y **2
    return output

    @staticmethod
    def backward(ctx, grad_output): # backward 输入参数和 forward 输出参数数量一致
    """
    :param ctx: 获取在 forward() 中保存的中间变量
    :param grad_output: 上游梯度,计算到当前算子的梯度后,链式法则乘上这个梯度就行,形式为torch.Tensor,第一次反向操作传入值为常量1
    :return: 返回多个值,依次对齐 forward 过程输入的张量的梯度
    """
    x, y = ctx.saved_tensors # 取出前向保存的变量
    grad_x = grad_output * 2 * x # 计算这个算子的 forward 输出 对x的梯度,然后按照链式法则乘上上游梯度
    grad_y = grad_output * 2 * y
    return grad_x, grad_y # backward 输出 和 forward 输入参数一一对齐

    x = torch.tensor([2.0], requires_grad=True)
    y = torch.tensor([3.0], requires_grad=True)
    z = SquareSumFunction.apply(x, y) # 通过 apply 调用
    z.backward() # 计算梯度
    print(x.grad, y.grad) # 输出:tensor([4.]) tensor([6.]),符合预期
    print(z.grad_fn)
    # 若计算图没有被清空,可以用下面的式子主动调用算子的 backward 函数(注意这里不会自动给 x.grad和y.grad赋值)
    # print(z.grad_fn.apply(1)) # 传入常量1作为上游梯度 `grad_output`,会返回两个梯度值+grad_fn的 tuple (tensor([4.], grad_fn=<MulBackward0>), tensor([6.], grad_fn=<MulBackward0>)),

PyTorch——nn.Parameter类


整体说明

  • nn.Parameter 是 torch.Tensor 的一个子类,其定义方式如下:

    1
    2
    3
    class Parameter(torch.Tensor, metaclass=_ParameterMeta):
    """..."""
    def __init__ ...
  • nn.Parameter是最常用的模型参数类,其他许多高级封装的层(如nn.Linear和nn.Conv2d等)都包含着nn.Parameter对象作为参数


torch.Tensor和nn.Parameter的却别

  • 两者主要区别是是否自动注册为模型参数,具体逻辑见下文:
    特性 nn.Parameter torch.Tensor
    是否自动注册为模型参数 是 否
    是否默认启用梯度计算 是 (requires_grad=True) 否 (requires_grad=False)
    是否被优化器自动更新 是 否(需手动添加到优化器)
    适用场景 可训练参数(如权重、偏置) 中间结果或固定值
  • 在构建神经网络时,通常使用 nn.Parameter 来定义可训练参数,而 torch.Tensor 更适合存储不需要训练的数据

是否自动注册为模型参数

  • nn.Parameter :

    • 当 nn.Parameter 被赋值给 nn.Module 的属性时,它会自动注册为模型的可训练参数
    • 可以通过 model.parameters() 访问这些参数,优化器会自动更新它们
    • 示例:
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      import torch.nn as nn

      class MyModel(nn.Module):
      def __init__(self):
      super(MyModel, self).__init__()
      self.weight = nn.Parameter(torch.randn(2, 2)) # 自动注册为模型参数

      def forward(self, x):
      return x @ self.weight

      model = MyModel()
      for param in model.parameters():
      print(param) # 可以访问到 self.weight

      # Parameter containing:
      # tensor([[-0.1866, 0.6549],
      # [-0.2559, -0.4768]], requires_grad=True)
  • torch.Tensor :

    • 直接使用 torch.Tensor 初始化的张量不会被自动注册为模型参数
    • 无法通过 model.parameters() 访问,优化器也不会更新它
    • 示例:
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      import torch.nn as nn

      class MyModel(nn.Module):
      def __init__(self):
      super(MyModel, self).__init__()
      self.weight = torch.randn(2, 2) # 只是一个普通的张量,不会注册为参数

      def forward(self, x):
      return x @ self.weight

      model = MyModel()
      for param in model.parameters():
      print(param) # 不会输出 self.weight

      # 无任何输出

是否支持自动梯度计算

  • nn.Parameter :

    • nn.Parameter 是 torch.Tensor 的子类,默认启用梯度计算(requires_grad=True)
    • 在反向传播时,PyTorch 会自动计算其梯度
  • torch.Tensor :

    • 默认情况下,torch.Tensor 的 requires_grad=False,不会计算梯度
    • 如果需要计算梯度,必须手动设置 requires_grad=True
    • 示例:
      1
      self.weight = torch.randn(2, 2, requires_grad=True)  # 手动启用梯度计算

优化器是否能更新

  • nn.Parameter :

    • 优化器可以通过 model.parameters() 获取 nn.Parameter 并更新其值
  • torch.Tensor :

    • 普通的 torch.Tensor 不会被优化器识别,除非手动将其添加到优化器的参数列表中
    • 示例:
      1
      optimizer = torch.optim.SGD([self.weight], lr=0.01)  # 手动添加到优化器

使用场景

  • nn.Parameter :

    • 适用于定义模型的可训练参数(如权重、偏置等)
    • 是构建神经网络时的标准做法
  • torch.Tensor :

    • 适用于存储不需要训练的中间结果或固定值(如常量张量)
    • 如果需要训练,必须手动设置 requires_grad=True 并注册到优化器

示例对比

  • 对比代码:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    import torch
    import torch.nn as nn

    # 使用 nn.Parameter
    class ModelWithParameter(nn.Module):
    def __init__(self):
    super(ModelWithParameter, self).__init__()
    self.weight = nn.Parameter(torch.randn(2, 2)) # 自动注册为参数

    def forward(self, x):
    return x @ self.weight

    # 使用 torch.Tensor
    class ModelWithTensor(nn.Module):
    def __init__(self):
    super(ModelWithTensor, self).__init__()
    self.weight = torch.randn(2, 2, requires_grad=True) # 需要手动设置 requires_grad

    def forward(self, x):
    return x @ self.weight

    # 比较
    model1 = ModelWithParameter()
    print(list(model1.parameters())) # 输出: [Parameter containing...]

    model2 = ModelWithTensor()
    print(list(model2.parameters())) # 输出: []

nn.Linear和nn.Parameter的关系

  • nn.Parameter 是 torch.Tensor 的子类,用于表示模型中的可训练参数,当它被赋值给 nn.Module 的属性时,会自动注册为模型参数,参与反向传播和优化

  • nn.Linear 内部包含两个 nn.Parameter,分别用于存储权重和偏置

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    class Linear(Module):
    """...comments..."""
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor
    def __init__(self, in_features: int, out_features: int, bias: bool = True,
    device=None, dtype=None) -> None:
    factory_kwargs = {'device': device, 'dtype': dtype}
    super(Linear, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs)) # Parameter定义
    if bias:
    self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) # Parameter定义
    else:
    self.register_parameter('bias', None)
    self.reset_parameters()
  • 对上面代码的其他解读:

    • __constants__ = ['in_features', 'out_features']: __constants__ 是 PyTorch 的一个特殊属性,用于声明哪些属性是常量。这些常量在 TorchScript(PyTorch 的 JIT 编译器)中会被优化,并且不会保存在模型的 state_dict 中
    • in_features: int: 是一个类型注解,声明 in_features 是一个整数类型的类属性。它表示输入特征的数量,即输入向量的维度,这个注解的主要作用是提高代码的可读性和类型检查(例如,使用静态类型检查工具如 mypy)
    • out_features: int: 同样是一个类型注解,声明 out_features 是一个整数类型的类属性。它表示输出特征的数量,即输出向量的维度,这个注解的主要作用是提高代码的可读性和类型检查
    • weight: Tensor: 是一个类型注解,声明 weight 是一个 Tensor 类型的类属性。weight 是线性层的权重矩阵,形状为 (out_features, in_features)。它会在 __init__ 方法中被初始化为一个可学习的参数(通过 torch.nn.Parameter 封装)

torch模型参数都是nn.Parameter类吗?

  • 在PyTorch中,模型参数通常是nn.Parameter类的实例,但并非所有模型中的可学习参数都必须是nn.Parameter

  • 模型中不是nn.Parameter类的一些特例和说明:

    • Buffer(缓冲区) :有些模型需要存储一些状态,但这些状态不是可学习的参数。这些状态可以通过register_buffer方法注册为缓冲区,而不是nn.Parameter。缓冲区不会被优化器更新,但会随模型一起保存和加载

      1
      self.register_buffer('running_mean', torch.zeros(num_features))
    • 非可学习参数 :有些参数虽然是模型的一部分,但不需要通过反向传播进行更新。这些参数可以是普通的torch.Tensor,而不是nn.Parameter

    • 自定义参数 :在某些情况下,开发者可能会手动管理参数,而不是使用nn.Parameter。例如,直接使用torch.Tensor并在需要时手动更新

    • 动态生成的临时变量 :在某些复杂的模型中,可能会有动态生成的参数或临时变量,这些可能不是nn.Parameter

    • 注:子模块中间接包含了nn.Parameter对象 :子模块(如nn.Linear、nn.Conv2d等)中包含的nn.Parameter对象参数会自动被识别并注册为模型参数,开发者可以通过parameters()方法访问这些参数。

PyTorch——torch.no_grad的用法


整体说明

  • 在 PyTorch 中,torch.no_grad()可用作装饰器 @torch.no_grad() 或上下文管理器 with torch.no_grad()(两者形式不同,但作用相同),用于禁用梯度计算
  • 如果 PyTorch 版本 >= 1.9,可以考虑使用 torch.inference_mode() 来替代 torch.no_grad(),以获得更好的性能

torch.no_grad()的作用

  • torch.no_grad() 的主要作用是临时关闭自动求导机制(autograd)。在被装饰的函数或代码块中,所有涉及张量的操作都不会构建计算图(computation graph),从而节省内存和计算资源:
    • 自动求导机制 :PyTorch 默认会记录张量操作的历史信息(即计算图),以便支持反向传播(backward())来计算梯度
    • 关闭梯度计算 :在推理阶段或其他不需要梯度的场景下,关闭自动求导可以减少内存占用,提高运行效率

使用场景

模型推理(Inference)

  • 在推理阶段,我们只需要前向传播(forward pass),而不需要计算梯度。因此,可以使用 @torch.no_grad() 来优化性能
    1
    2
    3
    4
    5
    6
    7
    8
    9
    @torch.no_grad()
    def evaluate_model(model, test_loader):
    model.eval() # 设置模型为评估模式,改回训练模式可以调用 model.train()
    total_loss = 0
    for data, target in test_loader:
    output = model(data)
    loss = loss_function(output, target)
    total_loss += loss.item()
    return total_loss

更新模型参数时不计算梯度

  • 在某些情况下,我们需要手动更新模型参数(例如权重剪枝、量化等),但不希望这些操作影响梯度计算
    1
    2
    3
    4
    @torch.no_grad()
    def update_weights(model):
    for param in model.parameters():
    param.add_(1.0) # 在参数上加 1,不会记录到计算图中

计算评估指标时不计算梯度

  • 当计算评估指标(如准确率、F1 分数等)时,不需要梯度计算
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    @torch.no_grad()
    def compute_accuracy(model, data_loader):
    correct = 0
    total = 0
    for inputs, labels in data_loader:
    outputs = model(inputs)
    _, predicted = torch.max(outputs, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
    return correct / total

附录:装饰器和上下文管理器的示例

作为装饰器

  • 装饰整个函数,使其在执行期间禁用梯度计算
    1
    2
    3
    @torch.no_grad()
    def inference(model, input_data):
    return model(input_data)

作为上下文管理器

  • 仅在特定代码块中禁用梯度计算
    1
    2
    3
    def inference(model, input_data):
    with torch.no_grad():
    return model(input_data)

附录:推理场景 torch.inference_mode() 的使用

  • 从 PyTorch 1.9 开始,引入了 torch.inference_mode(),它是 torch.no_grad() 的更高效替代品,专门用于推理阶段。与 torch.no_grad() 相比:
    • 性能更高 :torch.inference_mode() 会跳过一些额外的检查,进一步提升性能
    • 不可嵌套 :torch.inference_mode() 不能像 torch.no_grad() 那样嵌套使用
    • 推荐使用 :如果只用于推理,建议优先使用 torch.inference_mode()
  • 示例:
    1
    2
    3
    4
    @torch.inference_mode()
    def evaluate_model(model, test_loader):
    model.eval()
    ...
1…545556…61
Joe Zhou

Joe Zhou

Stay Hungry. Stay Foolish.

608 posts
49 tags
GitHub E-Mail
© 2026 Joe Zhou
Powered by Hexo
|
Theme — NexT.Gemini v5.1.4