Jiahong 的个人博客

凡事预则立,不预则废


  • Home

  • Tags

  • Archives

  • Navigation

  • Search

PyTorch——使用问题记录


PyTorch和torchvision版本不兼容

  • 问题描述:
    RuntimeError: Couldn't load custom C++ ops. This can happen if your PyTorch and torchvision versions are incompatible, or if you had errors while compiling torchvision from source. For further information on the compatible versions, check https://github.com/pytorch/vision#installation for the compatibility matrix. Please check your PyTorch version with torch.version and your torchvision version with torchvision.version and verify if they are compatible, and if not please reinstall torchvision so that it matches your PyTorch install.

  • 解决方案:

    1
    pip install torch torchvision --upgrade

PyTorch——关闭梯度的方法


整体说明

  • 不记录梯度的方法包括:torch.no_grad()、torch.set_grad_enabled(False)、tensor.detach()、tensor.requires_grad = False 等
    • 注:除了不记录梯度,这些方法还会释放计算图占用的内存,显著降低内存开销
  • 特殊注意 :model.eval() 仅影响特定层(如Dropout、BatchNorm)的行为,不会禁用梯度计算 ,必须配合 with torch.no_grad(): 使用才能完全关闭梯度
  • 各种模式选择建议
    • 若要临时停止梯度计算,推荐使用with torch.no_grad()上下文管理器或者torch.set_grad_enabled()
    • 若想永久性地停止某个张量的梯度计算,可使用detach()方法或者直接设置requires_grad=False
    • 对大型模型进行微调时,在模型层面设置requires_grad能有效节省内存
    • 模型推理阶段,要同时使用model.eval()和with torch.no_grad()

方法一:使用 with torch.no_grad() 上下文管理器

  • with torch.no_grad() 管理器能够暂停所有计算图的构建,进而显著降低内存的使用量并加快计算速度
    1
    2
    3
    4
    5
    6
    import torch

    x = torch.tensor([1.0], requires_grad=True)
    with torch.no_grad():
    y = x * 2
    print(y.requires_grad) # 输出 False

方法二:使用 @torch.no_grad() 作为装饰器

  • 装饰整个函数,使其在执行期间禁用梯度计算

    1
    2
    3
    @torch.no_grad()
    def inference(model, input_data):
    return model(input_data)
  • 注:@torch.no_grad() 装饰器和 with torch.no_grad() 上下文管理器的效果是一样的,一个针对方法,一个针对上下文


方法三:使用 detach() 方法

  • 运用detach()方法可以创建一个新的张量,这个新张量和计算图没有关联
    1
    2
    3
    4
    x = torch.tensor([1.0], requires_grad=True)
    y = x * 2
    z = y.detach() # z 和计算图无关
    print(z.requires_grad) # 输出 False

方法四:使用 torch.set_grad_enabled() 实现全局开关

  • 借助这个全局开关,能够控制整个代码块是否进行梯度计算
    1
    2
    3
    4
    5
    x = torch.tensor([1.0], requires_grad=True)
    torch.set_grad_enabled(False)
    y = x * 2
    torch.set_grad_enabled(True)
    print(y.requires_grad) # 输出 False

方法五:在张量层面设置 requires_grad=False

  • 你可以在创建张量时或者之后,把requires_grad属性设置为False,以此来阻止梯度的计算

    1
    2
    3
    4
    5
    6
    7
    x = torch.tensor([1.0], requires_grad=False)  # 创建时设置
    y = x * 2
    print(y.requires_grad) # 输出 False

    # 或者之后设置
    x = torch.tensor([1.0], requires_grad=True)
    x.requires_grad = False
  • 特别示例,也可以对模型参数直接操作:对于预训练模型进行微调时,你可以冻结部分层,只对特定层计算梯度

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    import torch.nn as nn

    model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 1)
    )

    # 冻结所有参数
    for param in model.parameters():
    param.requires_grad = False

    # 只训练最后一层
    for param in model[2].parameters():
    param.requires_grad = True

模型推理时的特殊场景

  • 特别地,推理时,常常用 model.eval() 和 with torch.no_grad() 或 @torch.no_grad() 结合

  • 在模型推理阶段,同时使用这两个方法能够有效减少内存占用并提高计算效率

    1
    2
    3
    model.eval()  # 关闭 Dropout 和 BatchNorm 等训练特有的层
    with torch.no_grad():
    outputs = model(inputs)
  • 在仅推理的场景中,更常用 torch.inference_mode() 来替代 torch.no_grad()(要求 PyTorch 版本 >= 1.9),以获得更好的性能(跳过一些推理阶段非必要的检查),详情见:PyTorch——torch.no_grad的用法

    • torch.inference_mode() 可以用作上下文管理器或者装饰器
    • 推理场景优先推荐使用 torch.inference_mode(),但 torch.inference_mode() 仅适用于推理场景,其他场景不可乱用
    • torch.inference_mode() 不能像 torch.no_grad() 那样嵌套使用

附录:PyTorch中禁用梯度计算的方法对比

  • 整体对比
    方法 是否不记录梯度 是否释放计算图内存 作用范围 使用场景
    with torch.no_grad(): 是 是 代码块 推理阶段(如模型预测)、不需要梯度的计算(如验证集评估)。
    torch.set_grad_enabled(False) 是 是 全局(直到恢复为True) 临时关闭整个代码段的梯度计算,例如批量推理。
    tensor.detach() 是 是(对新张量) 单个张量 从计算图中分离张量,例如生成对抗网络(GAN)中的生成器输出。
    tensor.requires_grad = False 是 是(设置后) 单个张量或模型参数 冻结预训练模型的部分层,只训练特定参数。
    model.eval() 否 否 模型(影响Dropout/BatchNorm),使用model.train()可恢复 推理阶段,关闭训练特有的层(如Dropout),但仍会记录梯度(需配合no_grad()使用)。

附录:前向过程不涉及梯度计算,为什么需要关闭梯度?

  • 在调用 loss.backward() 之前,PyTorch 不会计算梯度,但是模型的前向过程会构建计算图,这也会消耗额外的内存

附录:非常简单的代码示例

  • 简单的代码示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    import torch

    # 方法1: torch.no_grad()
    with torch.no_grad():
    y = x * 2 # y 不记录梯度,不构建计算图

    # 方法2: set_grad_enabled
    torch.set_grad_enabled(False)
    y = x * 2 # 全局禁用梯度
    torch.set_grad_enabled(True)

    # 方法3: detach()
    y = x.detach() * 2 # y 是脱离计算图的新张量

    # 方法4: requires_grad = False
    x.requires_grad = False
    y = x * 2 # x 不再需要梯度,y 也不记录

    # 方法5: model.eval()(需配合 no_grad())
    model.eval()
    with torch.no_grad():
    outputs = model(inputs) # 完全禁用梯度

PyTorch——叶子张量


整体说明

  • tensor.is_leaf=True 的张量被称为叶子张量(Leaf Tensor),也称为叶子变量(Leaf Variable),部分博客或书籍也称为叶子节点
  • 叶子张量分两种类型,即以下两种情况下的张量 tensor.is_leaf 返回 True:
    • 类型一:requires_grad 为 False 的张量都是叶子张量(由于 requires_grad 为 False,所以不会存储梯度)
    • 类型二:requires_grad 为 True 的张量,如果是由用户创建的 ,而不是通过其他张量运算得到的,那么它是叶子张量;
      • 通过其他张量运算得到的,都是非叶子张量
  • 特别说明:
    • detach() 函数可以将节点从计算图中剥离,使其成为叶子节点,此时 requires_grad 为 False,同时 tensor.is_leaf 会变为 True 了
    • 从 CPU 定义好后移动到 GPU 时产生的张量,或者从 GPU 定义后,挪到 CPU 上的向量,也都是非叶子张量,tensor.is_leaf 返回 False
      • 注意:不论是 GPU 还是 CPU,叶子张量的判定不变,只有用户定义的张量是叶子张量,挪动以后得都不是叶子张量(除非同时修改其
    • 只有叶子张量的 requires_grad 属性可以被修改,非叶子张量的 requires_grad 属性是不能被修改的
      • 理解:非叶子张量都是派生出来的,且 requires_grad=True 的张量,叶子张量计算梯度时依赖性非叶子张量的梯度,不能随便修改 requires_grad 属性
    • 叶子张量的 grad_fn=None (包括 requires_grad=True 和 requires_grad=False 的都是)
      • 理解:tensor.grad_fn 属性指向/存储生成 tensor 张量的计算操作(如加法、减法、乘法等),叶子张量要么是直接由用户定义出来的(此时 requires_grad=True),要么是不需要计算梯度的(此时 requires_grad=False)
    • 非叶子张量的梯度不会被保存(因为不需要使用)

叶子张量的特性

  • 在反向传播过程中,只有叶子张量的梯度会被保留并存储在张量的grad属性中
    • 用于后续优化器更新参数等操作
    • 比如:神经网络层中的权值 w 的张量均为叶子节点,反向传播 backward() 就是为了求它们的梯度,进而更新权值
  • 非叶子张量的梯度在反向传播使用完后通常会被清除 ,以节省内存
    • 比如:计算中中间的一些派生阶段就是非叶子张量

叶子张量相关示例

  • 各种叶子张量判断示例
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    import torch as torch

    a = torch.randn(3,3)
    print(a.is_leaf, a.requires_grad) # True False
    a.requires_grad = True
    print(a.is_leaf, a.requires_grad) # True True
    b = a.cuda()
    print(b.is_leaf, b.requires_grad) # False True
    c = b.detach()
    print(c.is_leaf, c.requires_grad) # True False
    d = a + 2
    print(d.is_leaf, d.requires_grad) # False True

    e = torch.randn(3,3, device="cuda", requires_grad=True)
    print(e.is_leaf, e.requires_grad) # True True
    f = e.to("cpu")
    print(f.is_leaf, f.requires_grad) # False True

附录:如何打印非叶子张量的梯度?

  • 在 PyTorch 里,当开启自动求导功能时,中间变量(非叶子张量)的梯度默认不会被保存,目的是节省内存
    • 只有叶子节点(比如直接创建的张量)的梯度会被保留
    • 在任意时刻时,非中间节点的梯度(grad属性)是都是 None
  • 要获取中间变量的梯度,有两种方法:
    • 运用 retain_grad() 方法保留梯度
    • 借助钩子(hook)来捕获梯度(可以打印或者赋值给全局变量)
  • 打印非节点张量的代码示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    import torch

    x = torch.tensor(2.0, requires_grad=True)
    y = x**2
    z = y**2

    # 方法一:使用 retain_grad() 显式保留梯度
    y.retain_grad()

    z.backward(retain_graph=True) # 若不使用 retain_graph=True,计算图会在 backward 被清空,则后续想要调用 backward() 前需要重新构造计算图
    print("方法一:", y.grad) # 输出: 16.0

    # 方法二:使用钩子 hook 捕捉梯度
    gradient_list = []
    def save_gradient(grad):
    gradient_list.append(grad)

    ## 如果 前面的 z.backward(retain_graph=True) 不使用 retain_graph=True,则 backward() 会清空计算图,这里就需要重新构造计算图,目前使用的是 z.backward(retain_graph=True) ,不需要下面的两句
    # y = x**2
    # z = y**2

    hook = y.register_hook(save_gradient)
    z.backward()
    hook.remove() # 重点:即时移除钩子,防止不必要的内存占用

    print("方法二:", gradient_list[0]) # 输出: 16.0

PyTorch——torch.distributed.rpc的使用


整体说明

  • torch.distributed.rpc 库是 PyTorch 用于 RPC 通信的包,初始化 RPC 环境的核心函数是 init_rpc
  • 在 PyTorch 中
    • torch.distributed.init_process_group 用于初始化进程组环境, 数据并行 和 集体通信 ,比如同步梯度
    • torch.distributed.rpc.init_rpc 则用于初始化 RPC 通信环境, 模型并行 和 远程过程调用 ,比如在另一台机器上执行一个函数
  • 在很多复杂的分布式训练场景中(如全分片数据并行+流水线并行),torch.distributed 和 torch.distributed.rpc 这两个模块甚至会同时被使用 ,以发挥它们各自的优势

回顾:torch.distributed.init_process_group 初始化进程组

  • torch.distributed.init_process_group 函数是 PyTorch Distributed (c10d) 包的初始化入口

    • 注:c10d 是 “Caffe2 TenSoR Distributed” 的缩写,c10d 代表了 PyTorch 中负责分布式通信的核心底层模块,提供了跨进程、跨节点的数据传输和同步机制;实现了诸如 ProcessGroup(进程组)、Backend(通信后端,如 NCCL、Gloo 等)等核心组件,是 torch.distributed 高层 API 的底层支撑
  • torch.distributed.init_process_group 函数会启动一个进程组,用于执行集体通信

  • torch.distributed.init_process_group 启动的进程组的通信范式是集体通信 :

    • 所有进程都必须参与同一个操作,并且需要等待所有进程都完成该操作后才能继续
    • 常见的操作包括:
      • all_reduce:所有进程的Tensor进行汇总(如求和)后,将结果返回给所有进程。(用于梯度同步)
      • broadcast:将一个进程的Tensor广播给所有其他进程
      • barrier:同步所有进程,确保大家都到达同一个点
      • scatter, gather, all_gather 等
  • 特点是训练代码在每个进程(通常是每个GPU)上几乎是完全相同的

    • 每个进程都做同样的事情:读一部分数据、前向传播、反向传播、然后通过集体通信同步梯度
    • 示例如下:
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      # 示例:在每个进程上初始化进程组
      import torch.distributed as dist

      def setup(rank, world_size):
      dist.init_process_group(
      backend='nccl', # 或 'gloo', 'mpi'
      init_method='env://',
      rank=rank,
      world_size=world_size
      )
      # 创建模型、DDP包装、定义采样器等...
      torch.cuda.set_device(rank)

      # 假设使用 spawn 启动多个进程
      # 每个进程都会运行 setup 函数和后续相同的训练循环
  • 常用于:

    • 数据并行训练:(最经典的用法)
      • 在 DistributedDataParallel (DDP) 中,每个 GPU 上的模型副本计算完梯度后,使用 all_reduce 来对所有梯度进行求和平均,确保每个模型副本都得到相同的更新后的梯度
    • 分布式数据加载:使用 DistributedSampler 确保每个进程读取数据集中不重复的部分
    • 其他需要进程间紧密同步的计算任务

前置说明:RPC 框架建立在进程组之上

  • torch.distributed.rpc.init_rpc 启动的 RPC 建立在进程组之上:

    • RPC 框架底层依赖于 torch.distributed 的进程组来进行初始的握手和协调
    • 当调用 rpc.init_rpc 时,它会在内部检查是否已经存在一个初始化好的进程组
      • 如果存在,它就复用这个进程组
      • 如果不存在,它会*隐式地调用 dist.init_process_group * ,使用默认的 gloo 后端(除非你通过 rpc_backend_options 指定)来创建一个进程组
  • 在大多数情况下,只需要调用 rpc.init_rpc 即可,它会把底层需要的进程组也初始化好

    • 如果需要对进程组的后端(如使用nccl)或初始化方法进行更精细的控制,可以选择 显式地先调用 dist.init_process_group,然后再调用rpc.init_rpc ,代码示例如下:
      1
      2
      3
      4
      5
      6
      # 显式控制进程组后端的示例
      def setup(rank, world_size):
      # 1. 首先,用 NCCL 后端初始化进程组(为了更好的GPU通信性能)
      dist.init_process_group(backend='nccl', ...)
      # 2. 然后,初始化RPC框架。此时RPC检测到已有进程组,不会再次初始化
      rpc.init_rpc(...)
  • 亲测:init_rpc 使用的 master_ip:master_port,必须和 init_process_group 使用的一样,否则无法启动

    • 暂未看到官方明确说明(待补充),猜测原因是因为 RPC 框架依赖已经初始化好的进程组(若没有初始化进程组,init_rpc 会自己新建进程组)
    • 详细示例见附录
  • RPC 的 world_size 可以比 已初始化的进程组的多(虽然有依赖,但两者是相对独立的体系)【具体深层逻辑待梳理】

    • 即 init_rpc 的 world_size 可以比 init_process_group 的大
    • 具体实现时,部分 rank 进程不调用 init_process_group 即可实现

torch.distributed.rpc.init_rpc 初始化 PyTorch RPC 框架

  • torch.distributed.rpc.init_rpc 函数是 PyTorch RPC 框架 的初始化入口

  • torch.distributed.rpc.init_rpc 函数核心目的是启用远程过程调用 ,允许一个进程调用另一个进程(可能在不同的机器上)上定义的函数或方法

  • torch.distributed.rpc.init_rpc 初始化的 PyTorch RPC 框架 通信范式是点对点通信

    • 一个进程(工作者)可以主动请求另一个进程执行某个任务,而不需要所有进程都参与
    • 核心 API 包括:
      • rpc_sync():同步调用,调用方会等待远程函数执行完毕并返回结果
      • rpc_async():异步调用,调用方立即返回一个 Future 对象,未来再从中获取结果
      • remote():异步地在远程工作者上创建一个对象(例如一个模块),并返回一个指向该远程对象的引用(RRef)
  • 常用于:

    • 模型并行:将一个大模型的不同部分放在不同的机器或 GPU 上
      • 前向传播时,数据需要从一个机器流到下一个机器
      • RPC 可以管理这种跨机器的计算流
    • 参数服务器(PS):
      • 一个或多个进程作为参数服务器存储模型参数,其他工作者通过 RPC 向参数服务器拉取参数或推送梯度
    • 强化学习:
      • 多个环境模拟器(Actor)通过RPC将经验数据发送给一个中心 learner 进行训练
    • 分布式流水线并行(待补充):
      • 与 torch.distributed.pipeline.sync.Pipe 结合使用,RPC 负责处理不同设备间层的通信
  • 特点是每个进程的角色和执行的代码可能完全不同 ,一个简单的代码示例如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    # start shell: torchrun --nproc_per_node=2 torch_distributed_rpc_demo.py
    import os
    import torch
    import torch.distributed as dist
    import torch.distributed.rpc as rpc

    # 正常定义函数,可用于远程调用(在 worker0 发送命令从 worker1 执行函数)
    def worker_function(a, b):
    worker_info = rpc.get_worker_info()
    print(f"Worker {worker_info.name} received a call with a={a}, b={b}")
    return a + b

    def main():
    # 获取环境变量
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # 1. 初始化进程组和 RPC
    dist.init_process_group(backend='gloo', rank=rank, world_size=world_size) # 建议提前初始化好进程组,如没有这行,下面的函数也会自动初始化一个 进程组
    rpc.init_rpc(
    name=f"worker_{rank}", # 若不唯一,会报错 RuntimeError: RPC worker name worker is not unique. Workers 1 and 0 share the same name
    rank=rank,
    world_size=world_size
    )

    print(f"Worker {rank} is ready.")

    if rank == 0:
    # 2. worker_0 发起 RPC 调用
    print("--- Worker 0 sending RPC calls ---")

    # 同步调用
    result_sync = rpc.rpc_sync(
    to=f"worker_1", # 真正的执行过程发生在 worker1 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 0 received sync result from Worker 1: {result_sync}")

    # 异步调用
    future_async = rpc.rpc_async(
    to=f"worker_1",
    func=worker_function,
    args=(30, 40)
    )
    print("Worker 0 is doing other work while waiting for async result...")
    result_async = future_async.wait()
    print(f"Worker 0 received async result from Worker 1: {result_async}")

    else:
    # 3. worker_1 只是等待并处理 RPC 请求
    print("--- Worker 1 is waiting for RPC calls ---")
    # 由于 worker_0 正在调用,这里不需要做任何事,RPC 框架会自动处理

    rpc.shutdown() # 默认等待所有 RPC 调用完成后再关闭
    dist.destroy_process_group() # 注:若没有显示调用 dist.init_process_group,也不能调用这一行,否则会报错

    if __name__ == "__main__":
    main()
    • 上述代码的启动脚本(单机双卡):torchrun --nproc_per_node=2 torch_distributed_rpc_demo.py

rpc.init_rpc 函数 详细用法

  • PyTorch 的 torch.distributed.rpc.init_rpc 函数原型

    1
    2
    3
    4
    5
    6
    7
    torch.distributed.rpc.init_rpc(
    name,
    backend=None,
    rank=-1,
    world_size=None,
    rpc_backend_options=None
    )
  • 参数说明如下:

  • name(必填)

    • 当前工作进程的全局唯一名称 ,不能重复,重复会报错,后续会作为交互对象的指定
    • 名称只能包含数字、字母、下划线、冒号和/或短划线,且必须少于128个字符
    • 例如 "trainer0", "parameter_server1"
  • backend(选填)

    • 指定使用的RPC 后端实现
    • 可选值来自 torch.distributed.rpc.BackendType 枚举
      • BackendType.TENSORPIPE,更现代且推荐的 RPC 后端
      • BackendType.PROCESS_GROUP,不常用,对应的是 PyTorch 的后端,如 NCCL 和 Gloo
  • rank(必填)

    • 当前进程在组中的全局唯一 ID
    • 通常是一个整数,例如 0, 1, 2…
    • 必须与 world_size 一起正确设置
  • world_size(必填)

    • 参与作业的进程总数
    • 所有进程的 world_size 必须相同
  • rpc_backend_options(选填)

    • 传递给底层 RPC 代理的高级选项 ,类型必须与 backend 参数选定的值严格对齐,否则会报错
      • backend=BackendType.TENSORPIPE 对应 rpc_backend_options 类别为:TensorPipeRpcBackendOptions
      • backend=BackendType.PROCESS_GROUP 对应 rpc_backend_options 类别为:ProcessGroupRpcBackendOptions
    • 可常用于设置 RPC 超时、初始化方法等
      • 若未指定,默认超时为60秒,并使用 init_method="env://"(这意味着需要设置 MASTER_ADDR 和 MASTER_PORT 环境变量)
  • 使用示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import TensorPipeRpcBackendOptions

    # 配置 TensorPipe 后端选项
    options = TensorPipeRpcBackendOptions(
    num_worker_threads=8, # 8 个工作线程,指定处理 RPC 请求的工作线程数,影响并发处理能力,默认值是 CPU 数量
    rpc_timeout=10000, # RPC 超时时间设置为 10 秒
    init_method="tcp://127.0.0.1:29500", # 初始化地址,格式和 init_process_group 的 init_method 参数一致
    device_maps={ "worker1": {0: 1}, "worker2": {1: 0} } # 本地 GPU 0 映射到 worker1 的 GPU 1;本地 GPU 1 映射到 worker2 的 GPU 0
    # devices=?, # 默认使用所有可用设备
    security_token="my_rpc_token" # str 类型安全令牌,用于验证节点身份的安全令牌,多机环境下防止未授权节点接入
    )

    # 初始化 RPC,传入配置
    rpc.init_rpc(
    name="worker0", # 当前节点名称
    backend=rpc.BackendType.TENSORPIPE,
    rpc_backend_options=options,
    rank=0, # 全局 rank
    world_size=2 # 总节点数
    )

TensorPipeRpcBackendOptions 的初始化细节

  • TensorPipeRpcBackendOptions 的函数原型为:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    def __init__(
    self,
    *, # Python 中可变位置参数的用法,仅使用一个 * 表示后续的参数必须按照 关键字参数形式调用
    num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS,
    rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC,
    init_method: str = rpc_contants.DEFAULT_INIT_METHOD,
    device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
    devices: Optional[List[DeviceType]] = None,
    _transports: Optional[List] = None,
    _channels: Optional[List] = None,
    )
  • 除了前文示例指定的参数外,其他参数介绍如下:

  • devices 参数

    • 指定当前节点上可用于 RPC 通信的设备(主要是 GPU),限制 RPC 操作只能使用列表中声明的设备
    • 默认值是 None(表示允许使用所有可见设备)
    • 当设置为非空列表时(如 [0, 1] 或 [torch.device('cuda:0'), torch.device('cuda:1')]),TensorPipe 只会使用列表中的设备进行数据传输,忽略其他设备
    • 主要用于多 GPU 场景下的资源隔离,例如限制 RPC 通信仅使用特定 GPU,避免占用训练用 GPU 资源
  • _transports 参数(以下划线开头,是私有参数,谨慎使用)

    • 指定 TensorPipe 可使用的传输协议(底层物理通信方式),用于跨节点或跨进程的数据传输
    • (推荐)默认值是自动检测并选择最优传输方式(通常为 ["tcp", "ibv"],即 TCP 和 InfiniBand)
    • 支持的传输协议:
      • "tcp":基于 TCP/IP 的网络传输,兼容性好,适用于普通网络环境
      • "ibv":基于 InfiniBand 的传输,低延迟、高带宽,适用于高性能计算集群
      • "shm":表示优先使用共享内存作为传输介质
      • "uv":基于 libuv 的本地传输(用于同一节点内的进程通信,通常自动启用)
    • 列表顺序表示优先级,TensorPipe 会优先尝试前面的传输方式
    • 仅在需要强制指定传输协议时使用(如集群仅支持 InfiniBand 时设置 ["ibv"]),默认自动选择即可满足多数场景
  • _channels 参数(以下划线开头,是私有参数,谨慎使用)

    • 指定 TensorPipe 可使用的通道类型(数据在设备间的内存传输方式),用于同一节点内的设备间通信(如 CPU-GPU、GPU-GPU)
    • (推荐)默认是自动检测并选择最优通道(通常为 ["cuda_ipc", "cuda_basic", "shm"])
    • 支持的通道类型有:
      • "cuda_ipc":GPU 进程间通信(同一节点内不同进程的 GPU 直接通信,效率最高)
      • "cuda_basic":通过 CPU 中转的 GPU 通信(当 cuda_ipc 不可用时降级使用)
      • "shm":共享内存通信(同一节点内的 CPU 进程间通信),这里强调通信通道,_transports 的 "shm" 强调通信介质
      • "basic":通过系统内存拷贝的通信(通用 fallback 方式)
    • 列表顺序表示优先级,优先使用高效的通道(如 cuda_ipc 优于 cuda_basic)
    • 主要用于调试或特殊硬件环境下的兼容性调整,默认配置已针对性能优化

初始化常见错误排查

  • 初始化失败:检查 MASTER_ADDR 和 MASTER_PORT 是否设置正确且所有机器可达,防火墙是否放行了指定端口,以及 rank 和 world_size 是否配置正确
  • 超时错误:增加 rpc_backend_options 中的 rpc_timeout 值
  • 版本不匹配:确保所有参与计算的 PyTorch 版本一致

RPC 框架的交互模式

  • PyTorch 的 torch.distributed.rpc 模块为 分布式模型训练 提供了核心的远程过程调用(RPC)和远程引用(RRef)机制
  • 以下是 RPC 相关核心方法:
  • rpc_sync()(同步方法,阻塞)
    • 在远程 worker 上同步调用函数
    • 返回函数调用的直接结果
  • rpc_async()(异步方法,非阻塞)
    • 在远程 worker 上异步调用函数
    • 返回一个 Future 对象,可通过 future.wait() 获取结果
  • remote()(非阻塞)
    • 异步在远程 worker 上创建对象
    • 返回一个 RRef (远程引用) 指向创建的对象
  • RRef.to_here()(阻塞)
    • 将 RRef 引用的值从所有者复制到本地节点
    • 返回引用的对象本身

同步远程调用 (rpc_sync)

  • rpc_sync 会阻塞调用者,直到远程函数执行完成并返回结果

  • 函数原型:

    1
    2
    3
    4
    5
    6
    7
    8
    # 在目标 worker (例如 to="worker1") 上同步运行函数 func 并获取结果
    result = torch.distributed.rpc.rpc_sync(
    to, # (str or WorkerInfo or int) 目标 worker 的名称、rank 或 WorkerInfo
    func, # (Callable) 待执行函数,是一个可调用函数(如 Python 函数、内置运算符或 TorchScript 函数)
    args=None, # (tuple) 传给函数的可变位置参数
    kwargs=None, # (dict) 传给函数的可变关键字参数
    timeout=-1.0 # (float, optional) RPC 超时时间(秒),-1.0 表示使用默认值(timeout=UNSET_RPC_TIMEOUT, 此时走初始化逻辑,或_set_rpc_timeout 设置的值),0 表示无限超时
    )
  • 使用示例:

    1
    2
    3
    # 在 worker 0 上:
    ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
    # 在 worker 1 上需要定义相应的函数或直接使用内置函数, worker 1 执行函数体,并返回结果

异步远程调用 (rpc_async)

  • rpc_async 会立即返回一个 Future 对象,你可以在之后需要时等待并获取结果,适用于非阻塞调用

  • 函数原型:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    # 异步调用远程函数,返回 Future 对象
    future = torch.distributed.rpc.rpc_async(
    to, # (str or WorkerInfo or int) 目标 worker 的名称、rank 或 WorkerInfo
    func, # (Callable) 待执行函数,是一个可调用函数(如 Python 函数、内置运算符或 TorchScript 函数)
    args=None, # (tuple) 传给函数的可变位置参数
    kwargs=None, # (dict) 传给函数的可变关键字参数
    timeout=-1.0 # (float, optional) RPC 超时时间(秒),-1.0 表示使用默认值(timeout=UNSET_RPC_TIMEOUT, 此时走初始化逻辑,或_set_rpc_timeout 设置的值),0 表示无限超时
    # 注:_set_rpc_timeout 是 PyTorch 分布式 RPC 框架中用于动态修改全局默认超时时间的函数,通常在 init_rpc(RPC 环境初始化)之后调用
    # 注:异步操作不会阻塞,这里的 timeout 是从此刻开始,远程操作本身返回的时间,如果有错也是在 future.wait() 时抛出
    )
    # ... 执行其他操作 ...
    result = future.wait() # 等待并获取结果,这里也可以设置 future 从此刻开始的等待 timeout, 设置后与 rpc_async 中的 timeout 时间两者时间取小(注意两者起始时间不同)
  • 使用示例:

    1
    2
    3
    fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
    fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
    result = fut1.wait() + fut2.wait()

远程对象创建与引用 (remote 和 RRef)

  • remote 方法用于在远程 worker 上异步创建对象 ,并立即返回一个指向该对象的 RRef (Remote Reference)

  • RRef 的 to_here() 方法可以阻塞地将远程对象的值复制到本地

  • remote 函数原型

    1
    2
    3
    4
    5
    6
    7
    8
    9
    # 在远程 worker 上创建对象并返回 RRef,持有 func 函数的返回值
    rref = torch.distributed.rpc.remote(
    to, # (str or WorkerInfo or int) 目标 worker 的名称、rank 或 WorkerInfo,目标 worker(将成为所有者)
    func, # (Callable) 用于创建对象的函数,是一个可调用函数(如 Python 函数、内置运算符或 TorchScript 函数)
    args=None, # (tuple) 传给函数的可变位置参数
    kwargs=None, # (dict) 传给函数的可变关键字参数
    timeout=-1.0 # (float, optional) RPC 超时时间(秒),-1.0 表示使用默认值(timeout=UNSET_RPC_TIMEOUT, 此时走初始化逻辑,或_set_rpc_timeout 设置的值),0 表示无限超时
    # 注:异步操作不会阻塞,这里的 timeout 是从此刻开始,远程操作本身返回的时间,如果有错也是在 rref.to_here() 时抛出
    )
  • RRef 对象的一些使用方式:

    • 下面的两种调用相同
      1
      2
      3
      4
      5
      6
      7
      8
      9
      rref = rpc.remote("worker1", MyClass, args=(arg1, arg2))

      # 方式一:直接调用远程对象的函数(func_name 必须是 rref 持有对象的函数)
      rref.rpc_async().func_name(*args, **kwargs)

      # 方式二:原生的 rpc 实现过程,较为复杂,效果和方式一完全相同
      def run(rref, func_name, args, kwargs):
      return getattr(rref.local_value(), func_name)(*args, **kwargs)
      rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • RRef 的更多使用详情见附录

shutdown 函数的使用

  • 在所有 RPC 工作完成后、进程退出前,必须显式调用 torch.distributed.rpc.shutdown ();
  • 推荐使用 graceful=True(默认) 以等待所有未完成 RPC 结束并避免死锁与 SIGABRT 风险
  • 函数原型
    1
    2
    3
    4
    5
    6
    7
    # torch.distributed.rpc.shutdown
    def shutdown(
    graceful=True, # (bool),默认值为 True,表示是否优雅的结束
    # 如果 True,则 1) 等待所有针对 UserRRef 的系统消息处理完毕并删除这些引用;2) 阻塞直到本地与远程所有 RPC 进程均调用此方法,并等待所有未完成工作完成;
    # 当 graceful=False 时:不等待未完成的 RPC 任务和其他进程,直接强制关闭本地 RPC 系统。可能导致未完成任务失败或资源泄漏,仅建议在紧急退出场景使用
    timeout=DEFAULT_SHUTDOWN_TIMEOUT # 限制 “优雅等待” 的最大时长:
    )

重要注意事项

  • 较早的 PyTorch 版本(如 1.4 之前)API 可能不能使用
  • 分布式作业中的每个进程都必须成功调用 init_rpc,并在最后调用 rpc.shutdown() 来优雅地释放资源
  • RPC 框架在处理张量时,通常视为张量位于 CPU 上进行操作,以避免设备不匹配的错误
    • 如果需要在 GPU 上操作,可能需要在函数内部显式地将张量移动到 GPU
  • 网络问题、函数执行失败或超时都可能导致 RPC 调用失败,建议添加适当的异常处理
  • RPC 调用会有通信开销
    • 应尽量减少小规模的频繁调用,考虑批量操作或传输更大尺寸的数据以减少开销

附录:init_rpc 的初始化 IP 和端口要与 init_process_group 一致

  • 亲测:init_rpc 使用的 master_ip:master_port,必须和 init_process_group 使用的一样,否则无法启动
    • 暂未看到官方明确说明(待补充),猜测原因是因为 RPC 框架依赖已经初始化好的进程组(若没有初始化进程组,init_rpc 会自己新建进程组)
  • 示例如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    # start shell: torchrun --nproc_per_node=2 torch_distributed_rpc_demo.py
    import os
    import torch.distributed as dist
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import TensorPipeRpcBackendOptions

    def worker_function(a, b):
    worker_info = rpc.get_worker_info()
    print(f"Worker {worker_info.name} received a call with a={a}, b={b}")
    return a + b

    def main():
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    master_addr = os.environ.get('MASTER_ADDR', '未设置')
    master_port = int(os.environ.get('MASTER_PORT', '未设置'))
    print(f"master addr: {master_addr}:{master_port}")

    dist.init_process_group(backend='gloo', rank=rank, world_size=2,init_method=f"tcp://{master_addr}:{master_port}") # 提前初始化进程组

    rpc_backend_options = TensorPipeRpcBackendOptions(
    init_method=f"tcp://{master_addr}:{master_port}", # 与进程组初始化(init_process_group)时相同 IP和端口,可以正常初始化,端口不一致时会一直卡在这里不动
    num_worker_threads=2,
    rpc_timeout=10
    )

    # rpc_backend_options = None
    print(f"rank {rank} start init_rpc..., dist.is_initialized={dist.is_initialized()}")
    rpc.init_rpc(
    name=f"worker_{rank}",
    rank=rank,
    world_size=world_size,
    rpc_backend_options=rpc_backend_options
    )
    print(f"rank {rank} have done init_rpc")

    print(f"Worker {rank} is ready.")

    if rank == 0:
    result_sync = rpc.rpc_sync(
    to=f"worker_1", # 真正的执行过程发生在 worker1 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 0 received sync result from Worker 1: {result_sync}")
    else:
    print("--- Worker 1 is waiting for RPC calls ---")

    rpc.shutdown() # 等待所有 RPC 完成后自动关闭
    dist.destroy_process_group() # 注:若没有显示调用 dist.init_process_group,也不用这一行

    if __name__ == "__main__":
    main()

附录:RPC 的 world_size 可以比 已初始化的进程组的多

  • RPC 的 world_size 可以比 已初始化的进程组的多(虽然有依赖,但两者是相对独立的体系)【具体深层逻辑待梳理】
    • 即 init_rpc 的 world_size 可以比 init_process_group 的大
    • 具体实现时,部分 rank 进程不调用 init_process_group 即可实现
  • 示例如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    # start shell: torchrun --nproc_per_node=4 torch_distributed_rpc_demo.py
    import os
    import torch.distributed as dist
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import TensorPipeRpcBackendOptions

    def worker_function(a, b):
    worker_info = rpc.get_worker_info()
    print(f"Worker {worker_info.name} received a call with a={a}, b={b}")
    return a + b

    def main():
    # 获取环境变量
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    master_addr = os.environ.get('MASTER_ADDR', '未设置')
    master_port = int(os.environ.get('MASTER_PORT', '未设置'))
    print(f"master addr: {master_addr}:{master_port}")

    if rank < 2: # 仅 rank=0,1 的节点初始化进程组
    print(f"rank {rank} start init_process_group...")
    dist.init_process_group(
    backend='gloo',
    rank=rank,
    world_size=2, # world_size 设置为2,甚至仅初始化一个也可以,rank<2 改成 rank<1 即可
    init_method=f"tcp://{master_addr}:{master_port}"
    ) # 提前初始化好进程组
    print(f"rank {rank} have done init_process_group")
    print(dist)

    rpc_backend_options = TensorPipeRpcBackendOptions(
    init_method=f"tcp://{master_addr}:{master_port}", # 与进程组初始化(init_process_group)时相同 IP和端口,可以正常初始化,端口不一致时会一直卡在这里不动
    num_worker_threads=2,
    rpc_timeout=10
    )

    print(f"rank {rank} start init_rpc..., dist.is_initialized={dist.is_initialized()}")
    rpc.init_rpc(
    name=f"worker_{rank}",
    rank=rank,
    world_size=world_size, # world_size 由传入的确定,实际上可能是 4
    rpc_backend_options=rpc_backend_options
    )
    print(f"rank {rank} have done init_rpc")

    print(f"Worker {rank} is ready.")

    if rank == 0:
    print("--- Worker 0 sending RPC calls ---")

    # 同步调用
    result_sync = rpc.rpc_sync(
    to=f"worker_1", # 真正的执行过程发生在 worker1 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 0 received sync result from Worker 1: {result_sync}")

    # 异步调用
    future_async = rpc.rpc_async(
    to=f"worker_1",
    func=worker_function,
    args=(30, 40)
    )
    print("Worker 0 is doing other work while waiting for async result...")
    result_async = future_async.wait()
    print(f"Worker 0 received async result from Worker 1: {result_async}")

    elif rank == 3:
    print("--- Worker 3 sending RPC calls ---")

    # 同步调用
    result_sync = rpc.rpc_sync(
    to=f"worker_2", # 真正的执行过程发生在 worker2 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 3 received sync result from Worker 2: {result_sync}")

    else:
    print("--- Worker 1 is waiting for RPC calls ---")

    rpc.shutdown() # 等待所有 RPC 完成后自动关闭
    if dist.is_initialized(): # 必须判断,否则可能报错,因为部分进程为初始化 init_process_group
    dist.destroy_process_group() # 注:若没有显示调用 dist.init_process_group,不用这一行

    if __name__ == "__main__":
    main()

附录:RRef 的使用方式细节

  • 参考自:PyTorch 官方文档-RRef
  • Warning:当前使用 CUDA 张量时,暂不支持 RRef,需要把张量从 GPU 转移到 CPU 才可以
  • RRef(Remote REFerence,远程引用)是指向远程工作节点上某一类型(如张量 Tensor)值的引用
    • 该句柄会确保被引用的远程值在其所属节点上保持有效
    • 在多机训练场景中,RRef 可用于持有对其他工作节点上 nn.Module(神经网络模块)的引用,并在训练过程中调用相应函数来获取或修改这些模块的参数
    • 更多细节可参考:Remote Reference Protocol
  • 注:RRef 也是一个 Python 类(torch.distributed.rpc.RRef),但有时泛指所有远程引用,而不是某一个类

torch.distributed.rpc.PyRRef 类型总体说明

  • 类型定义原型:

    1
    2
    3
    4
    5
    6
    7
    # __pybind11_builtins.pybind11_object 表明 PyRRef 是一个通过 pybind11 技术从 C++ 创建的 Python 类
    class PyRRef(__pybind11_builtins.pybind11_object):
    # ...

    # RRef 的定义
    class RRef(PyRRef, Generic[T]):
    pass
  • PyRRef 类用于封装对远程工作节点上某一类型值的引用

    • 此句柄会确保被引用的远程值在对应工作节点上保持有效
    • 当满足以下任一条件时,UserRRef(用户侧远程引用)将被销毁:
      • 1)在应用代码和本地 RRef 上下文环境中,均不存在对该 UserRRef 的引用;
      • 2)应用已调用 graceful shutdown 操作
    • 调用已销毁 RRef 的方法会导致未定义行为
      • RRef 实现仅提供“尽力而为”的错误检测机制(best-effort error detection),因此在调用 rpc.shutdown()(RPC 关闭函数)后,应用不应再使用 UserRRef
  • 警告:RRef 只能由 RPC 模块进行序列化和反序列化操作

    • 若不通过 RPC 模块(如使用 Python pickle 模块、PyTorch 的 save()/load() 函数、JIT 的 save()/load() 函数等)对 RRef 进行序列化或反序列化,将会引发错误

PyRRef 类型初始化参数(不常用,一般不自己定义 RRef)

  • value(object 类型):需用此 RRef 封装的值
  • type_hint(Type 类型,可选):传递给 TorchScript 编译器的 Python 类型提示,用于指定 value 的类型

PyRRef 类型使用示例

  • 以下示例为简化说明,省略了 RPC 初始化和关闭相关代码。有关这些操作的详细内容,请参考 RPC 文档

  • 使用 rpc.remote() 创建 RRef

    1
    2
    3
    4
    5
    6
    7
    import torch
    import torch.distributed.rpc as rpc

    # 在远程工作节点“worker1”上执行 torch.add 操作,并创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) # Debug 时看到的可能是 torch._C._distributed_rpc.PyRRef 类型,实际和 torch.distributed.rpc.PyRRef 是同一个类的不同路径
    # 从 RRef 中获取值的副本
    x = rref.to_here()
  • 基于本地对象创建 RRef

    1
    2
    3
    4
    5
    6
    7
    import torch
    from torch.distributed.rpc import RRef

    # 创建本地张量
    x = torch.zeros(2, 2)
    # 用 RRef 封装本地张量,生成 RRef 对象
    rref = RRef(x)
  • 与其他工作节点共享 RRef

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    # 在工作节点“worker0”和“worker1”上均执行以下函数定义
    def f(rref):
    # 获取 RRef 指向的值并加 1
    return rref.to_here() + 1

    # 在工作节点“worker0”上执行以下代码
    import torch
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import RRef

    # 基于本地零矩阵创建 RRef
    rref = RRef(torch.zeros(2, 2))
    # 通过 RPC 调用,将 RRef 共享给工作节点“worker1”,引用计数会自动更新
    rpc.rpc_sync("worker1", f, args=(rref,))

PyRRef.backward 方法

  • backward 方法原型

    1
    backward(self: torch._C._distributed_rpc.PyRRef, dist_autograd_ctx_id: int = -1, retain_graph: bool = False) -> None
  • backward 方法以 RRef 作为反向传播的根节点,执行反向传播过程

  • 若提供了 dist_autograd_ctx_id(分布式自动求导上下文 ID),则会使用该上下文 ID,从 RRef 的所属节点开始执行分布式反向传播

    • 这种情况下,应使用 get_gradients() 函数获取梯度
  • 若 dist_autograd_ctx_id 为 None,则默认当前为本地自动求导图,仅执行本地反向传播

    • 在本地反向传播场景中,调用此 API 的节点必须是该 RRef 的所属节点,且 RRef 所指向的值需为标量张量(scalar Tensor)
  • PyRRef.backward 方法参数

    • dist_autograd_ctx_id(int 类型,可选):用于获取梯度的分布式自动求导上下文 ID,默认值为 -1
    • retain_graph(bool 类型,可选):若设为 False,则计算梯度所用的计算图会被释放
      • 几乎不需要设为 True
      • 一般仅在需要多次执行反向传播时,才需将其设为 True,默认值为 False
  • PyRRef.backward 方法示例

    1
    2
    3
    4
    5
    6
    import torch.distributed.autograd as dist_autograd

    # 创建分布式自动求导上下文,并获取上下文 ID
    with dist_autograd.context() as context_id:
    # 以当前 RRef 为根节点,执行分布式反向传播
    rref.backward(context_id)

PyRRef 一些简单方法说明

  • confirmed_by_owner 方法:

    1
    confirmed_by_owner(self: torch._C._distributed_rpc.PyRRef)  bool
    • 返回该 RRef 是否已得到其所属节点的确认(真正拥有对象的节点为所属节点,对象时存储到所属节点上的)
      • OwnerRRef(所属节点侧远程引用)始终返回 True;
      • UserRRef(用户侧远程引用)仅在所属节点知晓该 UserRRef 存在时,才返回 True
  • is_owner 方法:

    1
    is_owner(self: torch._C._distributed_rpc.PyRRef) -> bool
    • 返回当前节点是否为该 RRef 的所属节点(对象时存储到所属节点上的)
  • local_value 方法:

    1
    local_value(self: torch._C._distributed_rpc.PyRRef) -> object
    • 若当前节点是该 RRef 的所属节点 ,则返回对本地值的引用;否则,抛出异常
  • owner 方法:

    1
    owner(self: torch._C._distributed_rpc.PyRRef) -> torch._C._distributed_rpc.WorkerInfo
    • 返回该 RRef 所属节点的工作节点信息(WorkerInfo 对象)
  • owner_name 方法:

    1
    owner_name(self: torch._C._distributed_rpc.PyRRef) -> str
    • 返回该 RRef 所属节点的工作节点名称(字符串形式)

PyRRef.remote 方法

  • remote 方法原型:

    1
    remote(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • remote 方法创建一个辅助代理(helper proxy)

    • 该代理可以用于发起 remote 调用,以 RRef 的所属节点为目标节点,在该 RRef 所指向的对象上执行函数
  • 更具体地说,rref.remote().func_name(*args, **kwargs) 等价于以下代码:

    1
    2
    3
    4
    5
    6
    def run(rref, func_name, args, kwargs):
    # 获取 RRef 本地值,并调用其指定方法
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

    # 向 RRef 所属节点发起 remote 调用,执行 run 函数
    rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • remote 方法参数:

    • timeout(float 类型,可选):rref.remote() 的超时时间(单位:秒)
      • 若在该时间内未成功创建 RRef,则下次尝试使用该 RRef(如调用 to_here() 方法)时,会引发超时错误
      • 若未提供该参数,则使用默认的 RPC 超时时间
  • remote 方法示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    from torch.distributed import rpc

    # 在远程节点“worker1”上执行 torch.add,创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
    # 调用 RRef 指向对象的 size() 方法,获取结果(远程执行,本地获取)
    rref.remote().size().to_here() # 返回结果:torch.Size([2, 2])
    # 注:返回类型
    # rref.remote(): <torch.distributed.rpc.rref_proxy.RRefProxy object at 0x11fa6e470>
    # rref.remote().size(): UserRRef(RRefId = GloballyUniqueId(created_on=0, local_id=4), ForkId = GloballyUniqueId(created_on=0, local_id=5))
    # rref.remote().size().to_here(): torch.Size([2])

    # 调用 RRef 指向对象的 view() 方法,重塑张量形状并获取结果
    rref.remote().view(1, 4).to_here() # 返回结果:tensor([[1., 1., 1., 1.]])

PyRRef.rpc_async 方法

  • rpc_async 方法原型:

    1
    rpc_async(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • rpc_async 方法创建一个辅助代理

    • 该代理可发起 rpc_async 调用,以 RRef 的所属节点为目标节点,在该 RRef 所指向的对象上执行函数
  • 更具体地说,rref.rpc_async().func_name(*args, **kwargs) 等价于以下代码:

    1
    2
    3
    4
    5
    6
    def run(rref, func_name, args, kwargs):
    # 获取 RRef 本地值,并调用其指定方法
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

    # 向 RRef 所属节点发起异步 RPC 调用,执行 run 函数
    rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • rpc_async 方法参数

    • timeout(float 类型,可选):rref.rpc_async() 的超时时间(单位:秒)
      • 若在该时间内调用未完成,则会引发超时异常
      • 若未提供该参数,则使用默认的 RPC 超时时间
  • rpc_async 方法示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    from torch.distributed import rpc

    # 在远程节点“worker1”上执行 torch.add,创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
    # 异步调用 size() 方法,等待结果返回
    rref.rpc_async().size().wait() # 返回结果:torch.Size([2, 2])
    # 注:返回类型
    # rref.rpc_async(): <torch.distributed.rpc.rref_proxy.RRefProxy object at 0x11bcfa3e0>
    # rref.rpc_async().size(): <torch.jit.Future object at 0x11bcc65c0>
    # rref.rpc_async().size().wait(): torch.Size([2])

    # 异步调用 view() 方法,重塑张量形状并等待结果返回
    rref.rpc_async().view(1, 4).wait() # 返回结果:tensor([[1., 1., 1., 1.]])

PyRRef.rpc_sync 方法

  • rpc_sync 方法原型:

    1
    rpc_sync(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • rpc_sync 方法创建一个辅助代理

    • 该代理可发起 rpc_sync 调用,以 RRef 的所属节点为目标节点,在该 RRef 所指向的对象上执行函数
  • 更具体地说,rref.rpc_sync().func_name(*args, **kwargs) 等价于以下代码:

    1
    2
    3
    4
    5
    6
    def run(rref, func_name, args, kwargs):
    # 获取 RRef 本地值,并调用其指定方法
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

    # 向 RRef 所属节点发起同步 RPC 调用,执行 run 函数
    rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • rpc_sync 方法参数:

    • timeout(float 类型,可选):rref.rpc_sync() 的超时时间(单位:秒)
      • 若在该时间内调用未完成,则会引发超时异常
      • 若未提供该参数,则使用默认的 RPC 超时时间
  • rpc_sync 方法示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    from torch.distributed import rpc

    # 在远程节点“worker1”上执行 torch.add,创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
    # 同步调用 size() 方法,获取张量尺寸
    rref.rpc_sync().size() # 返回结果:torch.Size([2, 2])
    # 注:返回类型:
    # rref.rpc_sync(): <torch.distributed.rpc.rref_proxy.RRefProxy object at 0x11bcfa440>
    # rref.rpc_sync().size(): torch.Size([2])

    # 同步调用 view() 方法,重塑张量形状
    rref.rpc_sync().view(1, 4) # 返回结果:tensor([[1., 1., 1., 1.]])

PyRRef.to_here 方法

  • to_here 方法原型:

    1
    to_here(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • to_here 方法是阻塞式调用,将 RRef 指向的值从其所属节点复制到本地节点,并返回该值

    • 若当前节点是该 RRef 的所属节点,则直接返回对本地值的引用
  • to_here 方法参数:

    • timeout(float 类型,可选):to_here 方法的超时时间(单位:秒)
      • 若在该时间内调用未完成,则会引发超时异常
      • 若未提供该参数,则使用默认的 RPC 超时时间(60 秒)

附录:torch.distributed.rpc.PyRRef 和 torch._C._distributed_rpc.PyRRef 的区别

  • 这两个是同一个类的不同引用路径,但有重要的使用区别,两者本质关系为

    1
    2
    3
    4
    # 推荐使用的方式;稳定,版本间兼容性好;可能包含额外的封装、错误检查、文档
    torch.distributed.rpc.PyRRef # 公共API接口,官方支持,文档完整
    # 不推荐直接使用;内部实现细节;可能在PyTorch版本更新时发生变化;直接的C++绑定,可能缺少一些Python层的便利功能
    torch._C._distributed_rpc.PyRRef # 底层C++实现,以`_`开头表示私有
  • 实际上 torch.distributed.rpc.PyRRef 就是对 torch._C._distributed_rpc.PyRRef 的封装或直接引用

  • 验证两者的等价关系

    1
    2
    3
    4
    5
    import torch.distributed.rpc as rpc
    import torch._C._distributed_rpc as _rpc

    # 通常这两个是相同的对象
    print(rpc.PyRRef is _rpc.PyRRef) # 可能输出 True

为什么会看到内部 API torch._C._distributed_rpc.PyRRef?

  • 调试时的堆栈跟踪可能显示内部路径
  • 某些高级用法或源码分析
  • 类型检查时可能遇到

PyTorch——激活函数调用的最佳实践


整体说明

  • 在PyTorch中,调用激活函数有几种常见方法:

    • 方法一 :通过 torch.relu() 直接调用激活函数【几乎不会使用,常用torch.nn.functional.relu()方式替代】
      • 使用时不需要实例化对象,适合用于简单的函数式调用
      • 它接受一个张量作为输入,并对这个张量应用ReLU操作
    • 方法二 :使用 torch.nn.functional.relu()(通常简写为 F)直接调用激活函数,一种函数式的方式来应用ReLU激活函数
      • 使用时不需要实例化对象,适合用于简单的函数式调用
      • 和torch.relu()类似,但它提供了更多的灵活性,比如你可以通过参数控制是否进行原地操作(inplace)等
      • 从技术角度来看,使用 torch.relu 和 F.relu 最终调用了相同的底层实现
    • 方法三 :通过 torch.nn.ReLU() 创建对象后调用
      • 这是ReLU作为一个层(layer)的形式出现,属于torch.nn模块。当你需要将ReLU作为一个网络的一部分时使用
      • 在使用前需要先实例化一个ReLU对象,然后可以像其他层一样调用这个对象。这种方式更适合于构建神经网络模型的架构中,因为它遵循了面向对象的设计理念,可以方便地集成到模型定义中
  • 总结来说,如果你只是想应用ReLU而不考虑网络结构,可以直接使用torch.relu()或torch.nn.functional.relu()。若你在构建一个复杂的神经网络并且希望以层的形式组织你的激活函数,则推荐使用torch.nn.ReLU()。对于更细粒度的控制需求,如执行原地操作来节省内存,torch.nn.functional.relu()是更好的选择


使用 torch.nn.functional 调用激活函数

  • 设计目的 : torch.nn.functional 提供了一系列函数式接口,适用于直接对张量执行操作,比如激活函数、池化等。这种方式非常适合用于需要灵活地应用不同操作的场景
  • 典型使用场景 : 当你需要在模型外部或者在自定义的前向传播逻辑中灵活地应用某些操作时,F 模块下的函数是非常有用的。例如,在定义一个自定义的 forward 方法时,你可以直接对输入张量调用 F.relu()
    1
    2
    3
    4
    5
    import torch
    import torch.nn.functional as F

    input_tensor = torch.randn(2, 3)
    output_tensor = F.relu(input_tensor)

使用 torch.nn.ReLU 调用激活函数

  • 设计目的 : 尽管 torch 下也有可以直接调用的激活函数(如 torch.relu),但这种做法并不常见
  • 典型使用场景 : 实际上,对于激活函数这类操作,更推荐使用 torch.nn.functional 或者对应的 nn 模块中的类(例如 torch.nn.ReLU)。这是因为它们提供了更清晰的设计模式,并且与PyTorch的整体设计理念更加一致
    1
    2
    3
    4
    5
    6
    import torch
    import torch.nn as nn

    relu_layer = nn.ReLU()
    input_tensor = torch.randn(2, 3)
    output_tensor = relu_layer(input_tensor)

总结

  • F 模块提供了一种更为灵活的方式来调用激活函数和其他层操作,因为它允许你直接将这些操作应用于任何张量,而不需要先将其包装在一个模块中
  • 使用 torch.nn.functional 或者对应的 nn 模块中的类来调用激活函数有助于保持代码的一致性和提高代码的可读性,这是由于PyTorch社区普遍采用这样的编码风格
  • 因此,尽管从技术角度来看,使用 torch.relu 和 F.relu 最终调用了相同的底层实现,但为了遵循最佳实践和保持代码的一致性,推荐使用 F.relu 或者 nn.ReLU 来应用ReLU激活函数。这样做不仅使得代码更具可读性,也更容易维护

PyTorch——计算机视觉torchvision

PyTorch中有个torchvision包,里面包含着很多计算机视觉相关的数据集(datasets),模型(models)和图像处理的库(transforms)等
本文主要介绍数据集中(ImageFolder)类和图像处理库(transforms)的用法


PyTorch预先实现的Dataset

  • ImageFolder

    1
    from torchvision.datasets import ImageFolder
  • COCO

    1
    from torchvision.datasets import coco
  • MNIST

    1
    from torchvision.datasets import mnist
  • LSUN

    1
    from torchvision.datasets import lsun
  • CIFAR10

    1
    from torchvision.datasets import CIFAR10

ImageFolder

  • ImageFolder假设所有的文件按照文件夹保存,每个文件夹下面存储统一类别的文件,文件夹名字为类名

  • 构造函数

    1
    ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
    • root:在root指定的路径下寻找图片,root下面的每个子文件夹就是一个类别,每个子文件夹下面的所有文件作为当前类别的数据
    • transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
      • PIL是 Python Imaging Library 的简称,是Python平台上图像处理的标准库
    • target_transform:对label的转换, 默认会自动编码
      • 默认编码为从0开始的数字,如果我们自己将文件夹命名为从0开头的数字,那么将按照我们的意愿命名,否则命名顺序不确定
      • 测试证明,如果文件夹下面是root/cat/, root/dog/两个文件夹,则自动编码为{‘cat’: 0, ‘dog’: 1}
      • class_to_idx属性存储着文件夹名字和类别编码的映射关系,dict
      • classes属性存储着所有类别,list
    • loader:从硬盘读取图片的函数
      • 不同的图像读取应该用不同的loader
      • 默认读取为RGB格式的PIL Image对象
      • 下面是默认的loader
        1
        2
        3
        4
        5
        6
        def default_loader(path):
        from torchvision import get_image_backend
        if get_image_backend() == 'accimage':
        return accimage_loader(path)
        else:
        return pil_loader(path)

transfroms详解

  • 包导入

    1
    from torchvision.transforms import transforms
  • transforms包中包含着很多封装好的transform操作

    • transforms.Scale(size):将数据变成制定的维度
    • transforms.ToTensor():将数据封装成PyTorch的Tensor类
    • transforms.Normalize(mean, std): 将数据标准话,具体标准化的参数可指定
  • 可将多个操作组合到一起,同时传入 ImageFolder 等对数据进行同时操作,每个操作被封装成一个类

    1
    2
    3
    4
    simple_transform = transforms.Compose([transforms.Resize((224,224))
    ,transforms.ToTensor()
    ,transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    train = ImageFolder('dogsandcats/train/',simple_transform)
  • torchvision.transforms.transforms包下的操作类都是基于torchvision.transforms.functional下的函数实现的

    • 导入torchvision.transforms.functional的方式
      1
      from torchvision.transforms import functional

PyTorch——GPU管理

本文主要介绍PyTorch相关的GPU管理


GPU管理整体说明

  • PyTorch的GPU管理包含以下接口:
    • cuda设备可用性判断:torch.cuda.is_available()
    • 设备对象创建和获取:torch.device('cuda')
    • 张量切换设备:tensor.to(gpu_device)
    • 模型切换设备:model.to(gpu_device)
    • 模型参数设备打印:print(f"参数所在设备: {param.device}") for param in model.parameters(): # model.parameters()
    • cuda设备可用数量查看:torch.cuda.device_count()
    • 当前环境cuda设备管理:torch.cuda.device(1)
    • 当前环境cuda设备号查看:torch.cuda.current_device()

Torch的默认指定GPU设定

  • PyTorch 会默认选择一个 GPU 作为当前设备(通常是编号为 0 的 GPU),原因如下:
    • 单 GPU 环境 :如果系统中只有一块 GPU,PyTorch 会默认使用它(编号为 0)
    • 多 GPU 环境 :如果有多个 GPU,PyTorch 仍然会默认使用编号为 0 的 GPU,除非显式地指定使用其他 GPU
    • 简化开发 :默认设备机制使得开发者在不显式指定设备的情况下,代码仍然可以正常运行
  • 默认 GPU 的设定是为了方便开发者,避免每次都需要手动指定设备

GPU管理先关概念

设备编号(Device Index)

  • PyTorch 使用从 0 开始的整数编号来标识 GPU
  • 例如,如果有 4 块 GPU,它们的编号分别是 0、1、2、3
  • 可以通过 torch.cuda.device_count() 获取当前可用的 GPU 数量

当前设备(Current Device)

  • PyTorch 会跟踪当前正在使用的 GPU,称为「当前设备」
  • 默认情况下,当前设备是编号为 0 的 GPU
  • 可以通过 torch.cuda.current_device() 获取当前设备的编号

切换设备简单示例

  • 可以使用 torch.cuda.set_device(device_id) 来切换当前设备

    1
    2
    torch.cuda.set_device(1)  # 切换到编号为 1 的 GPU
    # 接下来的操作会默认在指定GPU上
  • 也可以通过 torch.device 来指定设备:

    1
    2
    device = torch.device("cuda:1")  # 使用编号为 1 的 GPU
    tensor = torch.tensor([1, 2, 3]).to(device)

GPU高阶管理

  • 如果没有显式指定设备,PyTorch 会将张量和模型放在默认 GPU(通常是 cuda:0)上

  • 例如:

    1
    x = torch.tensor([1, 2, 3]).cuda()  # 默认放在 cuda:0 上
  • 可以通过环境变量 CUDA_VISIBLE_DEVICES 控制哪些 GPU 对 PyTorch 可见。例如:

    1
    export CUDA_VISIBLE_DEVICES=1,2  # 只让 GPU 1 和 2 对 PyTorch 可见

    在这种情况下,PyTorch 会将可见的 GPU 重新编号为 0 和 1

代码示例

  • 使用示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    ## 设备管理
    # 创建一个 CPU 设备对象
    cpu_device = torch.device('cpu')
    # 创建一个 GPU 设备对象,如果有多个 GPU,可以指定编号,如 'cuda:1'
    gpu_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    ## 检查GPU是否可用
    # if torch.cuda.is_available()

    ## 获取GPU的数量
    num_gpus = torch.cuda.device_count()

    ## 张量切换设备
    # 将张量移动到 GPU 设备
    tensor = tensor.to(gpu_device)
    # 或者使用 .cuda() 方法
    tensor = tensor.cuda() if torch.cuda.is_available() else tensor
    # 将张量移动到 CPU 设备
    tensor = tensor.cpu()
    ## 检查张量所在设备
    print(f"张量所在设备: {tensor.device}")

    ## 模型切换设备(跟张量操作一致)
    model = torch.nn.Linear(10, 10)
    # 将模型移动到 GPU 设备
    model = model.to(gpu_device)
    # 或者使用 .cuda() 方法
    model = model.cuda() if torch.cuda.is_available() else model
    # 将模型移动到 CPU 设备
    model = model.cpu()
    ## 检查参数所在设备
    for param in model.parameters(): # model.parameters()
    print(f"参数所在设备: {param.device}")

    ## 当前环境的多设备管理
    # 指定GPU设备方案一
    if torch.cuda.is_available() and num_gpus > 1:
    with torch.cuda.device(1): # 指定使用第二个 GPU
    tensor = torch.randn(3, 3).cuda()
    print(f"张量所在设备: {tensor.device}")
    # 指定GPU设备方案二
    if torch.cuda.is_available() and num_gpus > 1:
    torch.cuda.set_device(1) # 指定使用第二个 GPU
    # 输出设备号
    if torch.cuda.is_available():
    current_device = torch.cuda.current_device()
    print(f"当前使用的 GPU 编号: {current_device}")

    ## 更清晰的设备管理方式
    # 将 tensor 移动到 cuda:0
    gpu_device0 = torch.device('cuda:0')
    tensor0 = tensor.to(gpu_device0)

    # 将 tensor 移动到 cuda:1
    gpu_device1 = torch.device('cuda:1')
    tensor1 = tensor.to(gpu_device1)

PyTorch——模型存储与加载


整体说明

  • 在 PyTorch 中,训练完成后,可以将模型保存到磁盘上持久化存储,模型保存方式有:
    • 保存/加载整个模型(包括模型结构和参数)
    • 保存/加载模型参数(状态字典)(推荐方式,包含参数和其他状态信息)
    • 保存/加载检查点(用于断点续训)
  • 对于跨渠道的保存和加载,可以存储为 TorchScript 格式,这是一种专为生产环境优化的模型序列化格式
    • TorchScript 格式能够将 PyTorch 模型转换为一种可序列化、可优化的中间表示形式,便于在不同环境(包括 C++ 部署)中运行
  • 保存模型时 .pt 和 .pth 两种存储格式之间没有本质区别,主要区别在于使用习惯:
    • 早期 PyTorch 文档和示例中更常用 .pth 扩展名
    • 后来随着 PyTorch 版本中,官方示例逐渐开始使用 .pt,逐渐成为更推荐的格式
    • 实际上两者完全等价,亲测:.pt 和 .pth 直接修改后缀就能混用

存储模型和加载模型示例

  • 有两种主要的方法来保存模型:保存整个模型或仅保存模型的状态字典(推荐)

整个模型存储(不常用)

  • 保存整个模型并加载整个模型

    1
    2
    3
    4
    5
    # 保存整个模型
    torch.save(model, 'simple_model.pt') # 保存

    # 加载整个模型
    model_loaded = torch.load('simple_model.pt') # 加载
    • 注意:这种方法要求模型类定义必须可用,读不到类名会出错
  • 特别提示(容易出错):容易导致模型代码定义和加载存储模型不一致情况

    • 加载规则:实际上只要模型的类名相同即可加载(按照类名匹配的),模型结构可以定义不一致
    • 模型加载后模型实际对象结构与存储真实结构一致,会丢失当前代码定义的模型所有结构(包括类属性也不存在)
    • 理论上仅仅需要定义一个类名即可正常加载,但是使用时是按照存储模型的结构来使用的

仅存储模型参数(常用)

  • 保存模型的状态字典 并加载:

    1
    2
    3
    4
    5
    6
    # 保存模型参数
    torch.save(model.state_dict(), 'simple_model_state_dict.pth')

    # 创建模型示例并加载模型参数
    model_loaded = SimpleModel() # 实例化模型
    model_loaded.load_state_dict(torch.load('simple_model_state_dict.pth'))
    • 待加载模型参数和定义的模型参数不同时会直接报错
    • 状态字典方法更灵活,允许你只保存模型参数,不包括模型结构,这在部署时特别有用

常见保存和加载方式的整体示例

  • 整体示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    import torch
    import torch.nn as nn
    import torch.optim as optim

    class DiyModel(nn.Module):
    def __init__(self):
    super(DiyModel, self).__init__()
    self.fc1 = nn.Linear(10, 20)
    self.fc2 = nn.Linear(20, 2)
    self.relu = nn.ReLU()

    def forward(self, x):
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    return x

    model = DiyModel()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # 1. 保存和加载整个模型(包括结构和参数)
    # 保存
    torch.save(model, 'entire_model.pt')

    # 加载
    loaded_entire_model = torch.load('entire_model.pt')
    loaded_entire_model.eval() # 设置为评估模式,方便后续的 Serving
    print("加载整个模型成功")

    # 2. 仅保存和加载模型参数(推荐方式)
    # 保存
    torch.save(model.state_dict(), 'model_parameters.pt')

    # 加载
    loaded_model = DiyModel() # 需要先创建模型实例
    loaded_model.load_state_dict(torch.load('model_parameters.pt'))
    loaded_model.eval() # 设置为评估模式
    print("加载模型参数成功")

    # 3. 保存和加载检查点(用于断点续训)
    # 保存:包含模型参数、优化器状态、 epoch 等信息
    checkpoint = {
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.23,
    # 可以添加其他需要保存的信息
    }
    torch.save(checkpoint, 'training_checkpoint.pt')

    # 加载检查点
    loaded_checkpoint = torch.load('training_checkpoint.pt')

    # 恢复模型和优化器状态
    restored_model = DiyModel()
    restored_optimizer = optim.SGD(restored_model.parameters(), lr=0.001, momentum=0.9)

    restored_model.load_state_dict(loaded_checkpoint['model_state_dict'])
    restored_optimizer.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
    epoch = loaded_checkpoint['epoch']
    loss = loaded_checkpoint['loss']

    restored_model.train() # 恢复训练时设置为训练模式
    print(f"加载检查点成功: epoch {epoch}, loss {loss}")

    # 4. 跨设备保存和加载(跨设备
    # 在 GPU 上保存,在 CPU 上加载
    if torch.cuda.is_available():
    model.cuda()
    torch.save(model.state_dict(), 'model_gpu.pt') # 同时也会保存一些 GPU 相关信息

    # 在CPU上加载GPU保存的模型
    cpu_model = DiyModel()
    # 特别需要注意的点:map_location指明设备,是必须的参数(实际上,为了兼容,建议所有的加载都加上 `map_location=device` 参数)
    cpu_model.load_state_dict(torch.load('model_gpu.pt', map_location=torch.device('cpu')))
    print("在 CPU 上加载 GPU 保存的模型成功")

TorchScript 格式

  • TorchScript 格式能够跨平台存储和加载(可在 C++ 环境中运行,无需 Python 依赖)
  • TorchScript 格式在生产环境中更常用
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    import torch
    import torch.nn as nn

    class DiyModel(nn.Module):
    def __init__(self):
    super(DiyModel, self).__init__()
    self.fc1 = nn.Linear(10, 20)
    self.fc2 = nn.Linear(20, 2)
    self.relu = nn.ReLU()

    def forward(self, x):
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    return x

    model = DiyModel() # 创建模型

    # 1. 将模型转换为TorchScript格式(注:这一步是必要的)
    # 方法1:跟踪(tracing) - 适用于无控制流的模型(不常用,不推荐,容易出错)
    example_input = torch.randn(1, 10)
    traced_script_module = torch.jit.trace(model, example_input)
    # tracing 方法引入一个输入数据来执行流程,同时基于执行流程生成模型静态图
    # 如果存在控制流,会只保留example_input下会遇到的控制流
    # 以后遇到其他数据也会都走这个控制流,从而导致错误发生
    # 所以仅适用于无控制流的模型,不推荐使用

    # 方法2:脚本(scripting) - 适用于有控制流的模型和无控制流的模型(常用,推荐,兼容性好)
    scripted_script_module = torch.jit.script(model) # 解析 Python 代码结构生成静态图,能完整保留模型结构
    # 特别说明:scripting形式和tracing保存的模型性能上差异不大

    # 2. 保存TorchScript模型
    traced_script_module.save("traced_model.pt")
    scripted_script_module.save("scripted_model.pt")

    # 3. 加载TorchScript模型
    loaded_traced_model = torch.jit.load("traced_model.pt")
    loaded_scripted_model = torch.jit.load("scripted_model.pt")

    # 4. 使用加载的模型进行推理
    loaded_traced_model.eval()
    loaded_scripted_model.eval()
    with torch.no_grad():
    input_data = torch.randn(1, 10)
    output1 = loaded_traced_model(input_data)
    output2 = loaded_scripted_model(input_data)

使用 tracing 方法保存含控制流模型

  • 下面是使用 tracing 方法保存含控制流模型出错的示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    import torch
    import torch.nn as nn

    class ControlFlowModel(nn.Module):
    def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(1, 1)
    self.fc2 = nn.Linear(1, 1)

    def forward(self, x):
    if x > 0: # 控制流语句:根据输入值选择不同分支
    output = self.fc1(x) # 分支1
    else:
    output = self.fc2(x) # 分支2
    return output

    model = ControlFlowModel()
    model.fc1.weight.data = torch.tensor([[2.0]]) # 分支1:输出 = 2*x
    model.fc1.bias.data = torch.tensor([0.0])
    model.fc2.weight.data = torch.tensor([[3.0]]) # 分支2:输出 = 3*x
    model.fc2.bias.data = torch.tensor([0.0])

    # 1. 使用 tracing 方法转换模型(示例输入为正数,触发分支1)
    example_input = torch.tensor([1.0]) # 正数:走fc1分支
    traced_model = torch.jit.trace(model, example_input)

    # 2. 测试不同输入的推理结果
    test_inputs = [
    torch.tensor([2.0]), # 正数(与示例输入同分支)
    torch.tensor([-1.0]) # 负数(与示例输入不同分支)
    ]

    print("Original Model Output:")
    for x in test_inputs:
    print(f"Input {x.item()}: Output {model(x).item()}")

    print("\nTracing Model Output:")
    for x in test_inputs:
    print(f"Input {x.item()}: Output {traced_model(x).item()}")
    # 输入 -1.0 时错误地输出了 -2.0,实际上应该输出 -3.0

    # Original Model Output:
    # Input 2.0: Output 4.0
    # Input -1.0: Output -3.0
    #
    # Tracing Model Output:
    # Input 2.0: Output 4.0
    # Input -1.0: Output -2.0

大模型的存储和加载

  • 大模型一般以 Safetensors 格式存储(Hugging Face 的默认存储形式),许多 CV 和 NLP 的开源模型都是这个格式
  • 超大规模模型还可以分片保存
  • 大模型存储和加载的示例:
    1
    # 待补充

附录:加载模型后的动作

  • 切换模式 :无论哪种保存方式,在加载模型后,如果模型包含 Batch Normalization、Dropout 等层,都必须调用 model.eval() 来确保这些层在推理时行为正确
    • 因为这样会关闭 Dropout 和 Batch Normalization 等层的行为变化
  • 禁用梯度 :调用模型 Serving 前建议禁用梯度
    • 加速并节省内存 :禁用梯度计算可以减少运行时的内存占用,因为在前向过程中不需要存储用于反向传播的信息,由于也不需要准备一些梯度计算的步骤,可小幅提升 Serving 速度
    • 增加代码可读性 :使用 torch.no_grad() 强调了当前操作的目的(即,仅执行前向计算而不更新模型权重),这提高了代码的可读性和意图的一致性(这对于维护和理解代码非常有帮助)
  • torch.no_grad()的实践Demo:
    1
    2
    3
    4
    5
    model.eval()  # 设置模型为评估模式,不会禁用梯度
    with torch.no_grad(): # 在评估过程中禁用梯度计算
    output = model(X_test)
    loss = criterion(output, Y_test)
    print(f'Test Loss: {loss.item():.4f}')

Shell——进程查找


ps

  • 应用场景:当使用命令sh run.sh启动一个进程后,想要删除,却不知道进程号
  • 查找步骤:
    • 首先使用ps aux | grep run.sh列出进程
    • 杀死进程kill -9 [PID]

Spark——DataFrame读取Array类型


spark 从 DataFrame 中读取 Array 类型的列

  • 代码示例
    1
    2
    3
    4
    dataFrame.rdd.map(row => {
    val vectorCol = row.getAs[Seq[Double]]("VectorCol")
    vectorCol.toArray
    }).collect().foreach(println)
1…555657…61
Joe Zhou

Joe Zhou

Stay Hungry. Stay Foolish.

608 posts
49 tags
GitHub E-Mail
© 2026 Joe Zhou
Powered by Hexo
|
Theme — NexT.Gemini v5.1.4