Jiahong 的个人博客

凡事预则立,不预则废


  • Home

  • Tags

  • Archives

  • Navigation

  • Search

PyTorch——torch.no_grad的用法


整体说明

  • 在 PyTorch 中,torch.no_grad()可用作装饰器 @torch.no_grad() 或上下文管理器 with torch.no_grad()(两者形式不同,但作用相同),用于禁用梯度计算
  • 如果 PyTorch 版本 >= 1.9,可以考虑使用 torch.inference_mode() 来替代 torch.no_grad(),以获得更好的性能

torch.no_grad()的作用

  • torch.no_grad() 的主要作用是临时关闭自动求导机制(autograd)。在被装饰的函数或代码块中,所有涉及张量的操作都不会构建计算图(computation graph),从而节省内存和计算资源:
    • 自动求导机制 :PyTorch 默认会记录张量操作的历史信息(即计算图),以便支持反向传播(backward())来计算梯度
    • 关闭梯度计算 :在推理阶段或其他不需要梯度的场景下,关闭自动求导可以减少内存占用,提高运行效率

使用场景

模型推理(Inference)

  • 在推理阶段,我们只需要前向传播(forward pass),而不需要计算梯度。因此,可以使用 @torch.no_grad() 来优化性能
    1
    2
    3
    4
    5
    6
    7
    8
    9
    @torch.no_grad()
    def evaluate_model(model, test_loader):
    model.eval() # 设置模型为评估模式,改回训练模式可以调用 model.train()
    total_loss = 0
    for data, target in test_loader:
    output = model(data)
    loss = loss_function(output, target)
    total_loss += loss.item()
    return total_loss

更新模型参数时不计算梯度

  • 在某些情况下,我们需要手动更新模型参数(例如权重剪枝、量化等),但不希望这些操作影响梯度计算
    1
    2
    3
    4
    @torch.no_grad()
    def update_weights(model):
    for param in model.parameters():
    param.add_(1.0) # 在参数上加 1,不会记录到计算图中

计算评估指标时不计算梯度

  • 当计算评估指标(如准确率、F1 分数等)时,不需要梯度计算
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    @torch.no_grad()
    def compute_accuracy(model, data_loader):
    correct = 0
    total = 0
    for inputs, labels in data_loader:
    outputs = model(inputs)
    _, predicted = torch.max(outputs, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
    return correct / total

附录:装饰器和上下文管理器的示例

作为装饰器

  • 装饰整个函数,使其在执行期间禁用梯度计算
    1
    2
    3
    @torch.no_grad()
    def inference(model, input_data):
    return model(input_data)

作为上下文管理器

  • 仅在特定代码块中禁用梯度计算
    1
    2
    3
    def inference(model, input_data):
    with torch.no_grad():
    return model(input_data)

附录:推理场景 torch.inference_mode() 的使用

  • 从 PyTorch 1.9 开始,引入了 torch.inference_mode(),它是 torch.no_grad() 的更高效替代品,专门用于推理阶段。与 torch.no_grad() 相比:
    • 性能更高 :torch.inference_mode() 会跳过一些额外的检查,进一步提升性能
    • 不可嵌套 :torch.inference_mode() 不能像 torch.no_grad() 那样嵌套使用
    • 推荐使用 :如果只用于推理,建议优先使用 torch.inference_mode()
  • 示例:
    1
    2
    3
    4
    @torch.inference_mode()
    def evaluate_model(model, test_loader):
    model.eval()
    ...

PyTorch——关闭梯度的方法


整体说明

  • 不记录梯度的方法包括:torch.no_grad()、torch.set_grad_enabled(False)、tensor.detach()、tensor.requires_grad = False 等
    • 注:除了不记录梯度,这些方法还会释放计算图占用的内存,显著降低内存开销
  • 特殊注意 :model.eval() 仅影响特定层(如Dropout、BatchNorm)的行为,不会禁用梯度计算 ,必须配合 with torch.no_grad(): 使用才能完全关闭梯度
  • 各种模式选择建议
    • 若要临时停止梯度计算,推荐使用with torch.no_grad()上下文管理器或者torch.set_grad_enabled()
    • 若想永久性地停止某个张量的梯度计算,可使用detach()方法或者直接设置requires_grad=False
    • 对大型模型进行微调时,在模型层面设置requires_grad能有效节省内存
    • 模型推理阶段,要同时使用model.eval()和with torch.no_grad()

方法一:使用 with torch.no_grad() 上下文管理器

  • with torch.no_grad() 管理器能够暂停所有计算图的构建,进而显著降低内存的使用量并加快计算速度
    1
    2
    3
    4
    5
    6
    import torch

    x = torch.tensor([1.0], requires_grad=True)
    with torch.no_grad():
    y = x * 2
    print(y.requires_grad) # 输出 False

方法二:使用 @torch.no_grad() 作为装饰器

  • 装饰整个函数,使其在执行期间禁用梯度计算

    1
    2
    3
    @torch.no_grad()
    def inference(model, input_data):
    return model(input_data)
  • 注:@torch.no_grad() 装饰器和 with torch.no_grad() 上下文管理器的效果是一样的,一个针对方法,一个针对上下文


方法三:使用 detach() 方法

  • 运用detach()方法可以创建一个新的张量,这个新张量和计算图没有关联
    1
    2
    3
    4
    x = torch.tensor([1.0], requires_grad=True)
    y = x * 2
    z = y.detach() # z 和计算图无关
    print(z.requires_grad) # 输出 False

方法四:使用 torch.set_grad_enabled() 实现全局开关

  • 借助这个全局开关,能够控制整个代码块是否进行梯度计算
    1
    2
    3
    4
    5
    x = torch.tensor([1.0], requires_grad=True)
    torch.set_grad_enabled(False)
    y = x * 2
    torch.set_grad_enabled(True)
    print(y.requires_grad) # 输出 False

方法五:在张量层面设置 requires_grad=False

  • 你可以在创建张量时或者之后,把requires_grad属性设置为False,以此来阻止梯度的计算

    1
    2
    3
    4
    5
    6
    7
    x = torch.tensor([1.0], requires_grad=False)  # 创建时设置
    y = x * 2
    print(y.requires_grad) # 输出 False

    # 或者之后设置
    x = torch.tensor([1.0], requires_grad=True)
    x.requires_grad = False
  • 特别示例,也可以对模型参数直接操作:对于预训练模型进行微调时,你可以冻结部分层,只对特定层计算梯度

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    import torch.nn as nn

    model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 1)
    )

    # 冻结所有参数
    for param in model.parameters():
    param.requires_grad = False

    # 只训练最后一层
    for param in model[2].parameters():
    param.requires_grad = True

模型推理时的特殊场景

  • 特别地,推理时,常常用 model.eval() 和 with torch.no_grad() 或 @torch.no_grad() 结合

  • 在模型推理阶段,同时使用这两个方法能够有效减少内存占用并提高计算效率

    1
    2
    3
    model.eval()  # 关闭 Dropout 和 BatchNorm 等训练特有的层
    with torch.no_grad():
    outputs = model(inputs)
  • 在仅推理的场景中,更常用 torch.inference_mode() 来替代 torch.no_grad()(要求 PyTorch 版本 >= 1.9),以获得更好的性能(跳过一些推理阶段非必要的检查),详情见:PyTorch——torch.no_grad的用法

    • torch.inference_mode() 可以用作上下文管理器或者装饰器
    • 推理场景优先推荐使用 torch.inference_mode(),但 torch.inference_mode() 仅适用于推理场景,其他场景不可乱用
    • torch.inference_mode() 不能像 torch.no_grad() 那样嵌套使用

附录:PyTorch中禁用梯度计算的方法对比

  • 整体对比
    方法 是否不记录梯度 是否释放计算图内存 作用范围 使用场景
    with torch.no_grad(): 是 是 代码块 推理阶段(如模型预测)、不需要梯度的计算(如验证集评估)。
    torch.set_grad_enabled(False) 是 是 全局(直到恢复为True) 临时关闭整个代码段的梯度计算,例如批量推理。
    tensor.detach() 是 是(对新张量) 单个张量 从计算图中分离张量,例如生成对抗网络(GAN)中的生成器输出。
    tensor.requires_grad = False 是 是(设置后) 单个张量或模型参数 冻结预训练模型的部分层,只训练特定参数。
    model.eval() 否 否 模型(影响Dropout/BatchNorm),使用model.train()可恢复 推理阶段,关闭训练特有的层(如Dropout),但仍会记录梯度(需配合no_grad()使用)。

附录:前向过程不涉及梯度计算,为什么需要关闭梯度?

  • 在调用 loss.backward() 之前,PyTorch 不会计算梯度,但是模型的前向过程会构建计算图,这也会消耗额外的内存

附录:非常简单的代码示例

  • 简单的代码示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    import torch

    # 方法1: torch.no_grad()
    with torch.no_grad():
    y = x * 2 # y 不记录梯度,不构建计算图

    # 方法2: set_grad_enabled
    torch.set_grad_enabled(False)
    y = x * 2 # 全局禁用梯度
    torch.set_grad_enabled(True)

    # 方法3: detach()
    y = x.detach() * 2 # y 是脱离计算图的新张量

    # 方法4: requires_grad = False
    x.requires_grad = False
    y = x * 2 # x 不再需要梯度,y 也不记录

    # 方法5: model.eval()(需配合 no_grad())
    model.eval()
    with torch.no_grad():
    outputs = model(inputs) # 完全禁用梯度

PyTorch——torch.distributed.rpc的使用


整体说明

  • torch.distributed.rpc 库是 PyTorch 用于 RPC 通信的包,初始化 RPC 环境的核心函数是 init_rpc
  • 在 PyTorch 中
    • torch.distributed.init_process_group 用于初始化进程组环境, 数据并行 和 集体通信 ,比如同步梯度
    • torch.distributed.rpc.init_rpc 则用于初始化 RPC 通信环境, 模型并行 和 远程过程调用 ,比如在另一台机器上执行一个函数
  • 在很多复杂的分布式训练场景中(如全分片数据并行+流水线并行),torch.distributed 和 torch.distributed.rpc 这两个模块甚至会同时被使用 ,以发挥它们各自的优势

回顾:torch.distributed.init_process_group 初始化进程组

  • torch.distributed.init_process_group 函数是 PyTorch Distributed (c10d) 包的初始化入口

    • 注:c10d 是 “Caffe2 TenSoR Distributed” 的缩写,c10d 代表了 PyTorch 中负责分布式通信的核心底层模块,提供了跨进程、跨节点的数据传输和同步机制;实现了诸如 ProcessGroup(进程组)、Backend(通信后端,如 NCCL、Gloo 等)等核心组件,是 torch.distributed 高层 API 的底层支撑
  • torch.distributed.init_process_group 函数会启动一个进程组,用于执行集体通信

  • torch.distributed.init_process_group 启动的进程组的通信范式是集体通信 :

    • 所有进程都必须参与同一个操作,并且需要等待所有进程都完成该操作后才能继续
    • 常见的操作包括:
      • all_reduce:所有进程的Tensor进行汇总(如求和)后,将结果返回给所有进程。(用于梯度同步)
      • broadcast:将一个进程的Tensor广播给所有其他进程
      • barrier:同步所有进程,确保大家都到达同一个点
      • scatter, gather, all_gather 等
  • 特点是训练代码在每个进程(通常是每个GPU)上几乎是完全相同的

    • 每个进程都做同样的事情:读一部分数据、前向传播、反向传播、然后通过集体通信同步梯度
    • 示例如下:
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      # 示例:在每个进程上初始化进程组
      import torch.distributed as dist

      def setup(rank, world_size):
      dist.init_process_group(
      backend='nccl', # 或 'gloo', 'mpi'
      init_method='env://',
      rank=rank,
      world_size=world_size
      )
      # 创建模型、DDP包装、定义采样器等...
      torch.cuda.set_device(rank)

      # 假设使用 spawn 启动多个进程
      # 每个进程都会运行 setup 函数和后续相同的训练循环
  • 常用于:

    • 数据并行训练:(最经典的用法)
      • 在 DistributedDataParallel (DDP) 中,每个 GPU 上的模型副本计算完梯度后,使用 all_reduce 来对所有梯度进行求和平均,确保每个模型副本都得到相同的更新后的梯度
    • 分布式数据加载:使用 DistributedSampler 确保每个进程读取数据集中不重复的部分
    • 其他需要进程间紧密同步的计算任务

前置说明:RPC 框架建立在进程组之上

  • torch.distributed.rpc.init_rpc 启动的 RPC 建立在进程组之上:

    • RPC 框架底层依赖于 torch.distributed 的进程组来进行初始的握手和协调
    • 当调用 rpc.init_rpc 时,它会在内部检查是否已经存在一个初始化好的进程组
      • 如果存在,它就复用这个进程组
      • 如果不存在,它会*隐式地调用 dist.init_process_group * ,使用默认的 gloo 后端(除非你通过 rpc_backend_options 指定)来创建一个进程组
  • 在大多数情况下,只需要调用 rpc.init_rpc 即可,它会把底层需要的进程组也初始化好

    • 如果需要对进程组的后端(如使用nccl)或初始化方法进行更精细的控制,可以选择 显式地先调用 dist.init_process_group,然后再调用rpc.init_rpc ,代码示例如下:
      1
      2
      3
      4
      5
      6
      # 显式控制进程组后端的示例
      def setup(rank, world_size):
      # 1. 首先,用 NCCL 后端初始化进程组(为了更好的GPU通信性能)
      dist.init_process_group(backend='nccl', ...)
      # 2. 然后,初始化RPC框架。此时RPC检测到已有进程组,不会再次初始化
      rpc.init_rpc(...)
  • 亲测:init_rpc 使用的 master_ip:master_port,必须和 init_process_group 使用的一样,否则无法启动

    • 暂未看到官方明确说明(待补充),猜测原因是因为 RPC 框架依赖已经初始化好的进程组(若没有初始化进程组,init_rpc 会自己新建进程组)
    • 详细示例见附录
  • RPC 的 world_size 可以比 已初始化的进程组的多(虽然有依赖,但两者是相对独立的体系)【具体深层逻辑待梳理】

    • 即 init_rpc 的 world_size 可以比 init_process_group 的大
    • 具体实现时,部分 rank 进程不调用 init_process_group 即可实现

torch.distributed.rpc.init_rpc 初始化 PyTorch RPC 框架

  • torch.distributed.rpc.init_rpc 函数是 PyTorch RPC 框架 的初始化入口

  • torch.distributed.rpc.init_rpc 函数核心目的是启用远程过程调用 ,允许一个进程调用另一个进程(可能在不同的机器上)上定义的函数或方法

  • torch.distributed.rpc.init_rpc 初始化的 PyTorch RPC 框架 通信范式是点对点通信

    • 一个进程(工作者)可以主动请求另一个进程执行某个任务,而不需要所有进程都参与
    • 核心 API 包括:
      • rpc_sync():同步调用,调用方会等待远程函数执行完毕并返回结果
      • rpc_async():异步调用,调用方立即返回一个 Future 对象,未来再从中获取结果
      • remote():异步地在远程工作者上创建一个对象(例如一个模块),并返回一个指向该远程对象的引用(RRef)
  • 常用于:

    • 模型并行:将一个大模型的不同部分放在不同的机器或 GPU 上
      • 前向传播时,数据需要从一个机器流到下一个机器
      • RPC 可以管理这种跨机器的计算流
    • 参数服务器(PS):
      • 一个或多个进程作为参数服务器存储模型参数,其他工作者通过 RPC 向参数服务器拉取参数或推送梯度
    • 强化学习:
      • 多个环境模拟器(Actor)通过RPC将经验数据发送给一个中心 learner 进行训练
    • 分布式流水线并行(待补充):
      • 与 torch.distributed.pipeline.sync.Pipe 结合使用,RPC 负责处理不同设备间层的通信
  • 特点是每个进程的角色和执行的代码可能完全不同 ,一个简单的代码示例如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    # start shell: torchrun --nproc_per_node=2 torch_distributed_rpc_demo.py
    import os
    import torch
    import torch.distributed as dist
    import torch.distributed.rpc as rpc

    # 正常定义函数,可用于远程调用(在 worker0 发送命令从 worker1 执行函数)
    def worker_function(a, b):
    worker_info = rpc.get_worker_info()
    print(f"Worker {worker_info.name} received a call with a={a}, b={b}")
    return a + b

    def main():
    # 获取环境变量
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # 1. 初始化进程组和 RPC
    dist.init_process_group(backend='gloo', rank=rank, world_size=world_size) # 建议提前初始化好进程组,如没有这行,下面的函数也会自动初始化一个 进程组
    rpc.init_rpc(
    name=f"worker_{rank}", # 若不唯一,会报错 RuntimeError: RPC worker name worker is not unique. Workers 1 and 0 share the same name
    rank=rank,
    world_size=world_size
    )

    print(f"Worker {rank} is ready.")

    if rank == 0:
    # 2. worker_0 发起 RPC 调用
    print("--- Worker 0 sending RPC calls ---")

    # 同步调用
    result_sync = rpc.rpc_sync(
    to=f"worker_1", # 真正的执行过程发生在 worker1 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 0 received sync result from Worker 1: {result_sync}")

    # 异步调用
    future_async = rpc.rpc_async(
    to=f"worker_1",
    func=worker_function,
    args=(30, 40)
    )
    print("Worker 0 is doing other work while waiting for async result...")
    result_async = future_async.wait()
    print(f"Worker 0 received async result from Worker 1: {result_async}")

    else:
    # 3. worker_1 只是等待并处理 RPC 请求
    print("--- Worker 1 is waiting for RPC calls ---")
    # 由于 worker_0 正在调用,这里不需要做任何事,RPC 框架会自动处理

    rpc.shutdown() # 默认等待所有 RPC 调用完成后再关闭
    dist.destroy_process_group() # 注:若没有显示调用 dist.init_process_group,也不能调用这一行,否则会报错

    if __name__ == "__main__":
    main()
    • 上述代码的启动脚本(单机双卡):torchrun --nproc_per_node=2 torch_distributed_rpc_demo.py

rpc.init_rpc 函数 详细用法

  • PyTorch 的 torch.distributed.rpc.init_rpc 函数原型

    1
    2
    3
    4
    5
    6
    7
    torch.distributed.rpc.init_rpc(
    name,
    backend=None,
    rank=-1,
    world_size=None,
    rpc_backend_options=None
    )
  • 参数说明如下:

  • name(必填)

    • 当前工作进程的全局唯一名称 ,不能重复,重复会报错,后续会作为交互对象的指定
    • 名称只能包含数字、字母、下划线、冒号和/或短划线,且必须少于128个字符
    • 例如 "trainer0", "parameter_server1"
  • backend(选填)

    • 指定使用的RPC 后端实现
    • 可选值来自 torch.distributed.rpc.BackendType 枚举
      • BackendType.TENSORPIPE,更现代且推荐的 RPC 后端
      • BackendType.PROCESS_GROUP,不常用,对应的是 PyTorch 的后端,如 NCCL 和 Gloo
  • rank(必填)

    • 当前进程在组中的全局唯一 ID
    • 通常是一个整数,例如 0, 1, 2…
    • 必须与 world_size 一起正确设置
  • world_size(必填)

    • 参与作业的进程总数
    • 所有进程的 world_size 必须相同
  • rpc_backend_options(选填)

    • 传递给底层 RPC 代理的高级选项 ,类型必须与 backend 参数选定的值严格对齐,否则会报错
      • backend=BackendType.TENSORPIPE 对应 rpc_backend_options 类别为:TensorPipeRpcBackendOptions
      • backend=BackendType.PROCESS_GROUP 对应 rpc_backend_options 类别为:ProcessGroupRpcBackendOptions
    • 可常用于设置 RPC 超时、初始化方法等
      • 若未指定,默认超时为60秒,并使用 init_method="env://"(这意味着需要设置 MASTER_ADDR 和 MASTER_PORT 环境变量)
  • 使用示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import TensorPipeRpcBackendOptions

    # 配置 TensorPipe 后端选项
    options = TensorPipeRpcBackendOptions(
    num_worker_threads=8, # 8 个工作线程,指定处理 RPC 请求的工作线程数,影响并发处理能力,默认值是 CPU 数量
    rpc_timeout=10000, # RPC 超时时间设置为 10 秒
    init_method="tcp://127.0.0.1:29500", # 初始化地址,格式和 init_process_group 的 init_method 参数一致
    device_maps={ "worker1": {0: 1}, "worker2": {1: 0} } # 本地 GPU 0 映射到 worker1 的 GPU 1;本地 GPU 1 映射到 worker2 的 GPU 0
    # devices=?, # 默认使用所有可用设备
    security_token="my_rpc_token" # str 类型安全令牌,用于验证节点身份的安全令牌,多机环境下防止未授权节点接入
    )

    # 初始化 RPC,传入配置
    rpc.init_rpc(
    name="worker0", # 当前节点名称
    backend=rpc.BackendType.TENSORPIPE,
    rpc_backend_options=options,
    rank=0, # 全局 rank
    world_size=2 # 总节点数
    )

TensorPipeRpcBackendOptions 的初始化细节

  • TensorPipeRpcBackendOptions 的函数原型为:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    def __init__(
    self,
    *, # Python 中可变位置参数的用法,仅使用一个 * 表示后续的参数必须按照 关键字参数形式调用
    num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS,
    rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC,
    init_method: str = rpc_contants.DEFAULT_INIT_METHOD,
    device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
    devices: Optional[List[DeviceType]] = None,
    _transports: Optional[List] = None,
    _channels: Optional[List] = None,
    )
  • 除了前文示例指定的参数外,其他参数介绍如下:

  • devices 参数

    • 指定当前节点上可用于 RPC 通信的设备(主要是 GPU),限制 RPC 操作只能使用列表中声明的设备
    • 默认值是 None(表示允许使用所有可见设备)
    • 当设置为非空列表时(如 [0, 1] 或 [torch.device('cuda:0'), torch.device('cuda:1')]),TensorPipe 只会使用列表中的设备进行数据传输,忽略其他设备
    • 主要用于多 GPU 场景下的资源隔离,例如限制 RPC 通信仅使用特定 GPU,避免占用训练用 GPU 资源
  • _transports 参数(以下划线开头,是私有参数,谨慎使用)

    • 指定 TensorPipe 可使用的传输协议(底层物理通信方式),用于跨节点或跨进程的数据传输
    • (推荐)默认值是自动检测并选择最优传输方式(通常为 ["tcp", "ibv"],即 TCP 和 InfiniBand)
    • 支持的传输协议:
      • "tcp":基于 TCP/IP 的网络传输,兼容性好,适用于普通网络环境
      • "ibv":基于 InfiniBand 的传输,低延迟、高带宽,适用于高性能计算集群
      • "shm":表示优先使用共享内存作为传输介质
      • "uv":基于 libuv 的本地传输(用于同一节点内的进程通信,通常自动启用)
    • 列表顺序表示优先级,TensorPipe 会优先尝试前面的传输方式
    • 仅在需要强制指定传输协议时使用(如集群仅支持 InfiniBand 时设置 ["ibv"]),默认自动选择即可满足多数场景
  • _channels 参数(以下划线开头,是私有参数,谨慎使用)

    • 指定 TensorPipe 可使用的通道类型(数据在设备间的内存传输方式),用于同一节点内的设备间通信(如 CPU-GPU、GPU-GPU)
    • (推荐)默认是自动检测并选择最优通道(通常为 ["cuda_ipc", "cuda_basic", "shm"])
    • 支持的通道类型有:
      • "cuda_ipc":GPU 进程间通信(同一节点内不同进程的 GPU 直接通信,效率最高)
      • "cuda_basic":通过 CPU 中转的 GPU 通信(当 cuda_ipc 不可用时降级使用)
      • "shm":共享内存通信(同一节点内的 CPU 进程间通信),这里强调通信通道,_transports 的 "shm" 强调通信介质
      • "basic":通过系统内存拷贝的通信(通用 fallback 方式)
    • 列表顺序表示优先级,优先使用高效的通道(如 cuda_ipc 优于 cuda_basic)
    • 主要用于调试或特殊硬件环境下的兼容性调整,默认配置已针对性能优化

初始化常见错误排查

  • 初始化失败:检查 MASTER_ADDR 和 MASTER_PORT 是否设置正确且所有机器可达,防火墙是否放行了指定端口,以及 rank 和 world_size 是否配置正确
  • 超时错误:增加 rpc_backend_options 中的 rpc_timeout 值
  • 版本不匹配:确保所有参与计算的 PyTorch 版本一致

RPC 框架的交互模式

  • PyTorch 的 torch.distributed.rpc 模块为 分布式模型训练 提供了核心的远程过程调用(RPC)和远程引用(RRef)机制
  • 以下是 RPC 相关核心方法:
  • rpc_sync()(同步方法,阻塞)
    • 在远程 worker 上同步调用函数
    • 返回函数调用的直接结果
  • rpc_async()(异步方法,非阻塞)
    • 在远程 worker 上异步调用函数
    • 返回一个 Future 对象,可通过 future.wait() 获取结果
  • remote()(非阻塞)
    • 异步在远程 worker 上创建对象
    • 返回一个 RRef (远程引用) 指向创建的对象
  • RRef.to_here()(阻塞)
    • 将 RRef 引用的值从所有者复制到本地节点
    • 返回引用的对象本身

同步远程调用 (rpc_sync)

  • rpc_sync 会阻塞调用者,直到远程函数执行完成并返回结果

  • 函数原型:

    1
    2
    3
    4
    5
    6
    7
    8
    # 在目标 worker (例如 to="worker1") 上同步运行函数 func 并获取结果
    result = torch.distributed.rpc.rpc_sync(
    to, # (str or WorkerInfo or int) 目标 worker 的名称、rank 或 WorkerInfo
    func, # (Callable) 待执行函数,是一个可调用函数(如 Python 函数、内置运算符或 TorchScript 函数)
    args=None, # (tuple) 传给函数的可变位置参数
    kwargs=None, # (dict) 传给函数的可变关键字参数
    timeout=-1.0 # (float, optional) RPC 超时时间(秒),-1.0 表示使用默认值(timeout=UNSET_RPC_TIMEOUT, 此时走初始化逻辑,或_set_rpc_timeout 设置的值),0 表示无限超时
    )
  • 使用示例:

    1
    2
    3
    # 在 worker 0 上:
    ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
    # 在 worker 1 上需要定义相应的函数或直接使用内置函数, worker 1 执行函数体,并返回结果

异步远程调用 (rpc_async)

  • rpc_async 会立即返回一个 Future 对象,你可以在之后需要时等待并获取结果,适用于非阻塞调用

  • 函数原型:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    # 异步调用远程函数,返回 Future 对象
    future = torch.distributed.rpc.rpc_async(
    to, # (str or WorkerInfo or int) 目标 worker 的名称、rank 或 WorkerInfo
    func, # (Callable) 待执行函数,是一个可调用函数(如 Python 函数、内置运算符或 TorchScript 函数)
    args=None, # (tuple) 传给函数的可变位置参数
    kwargs=None, # (dict) 传给函数的可变关键字参数
    timeout=-1.0 # (float, optional) RPC 超时时间(秒),-1.0 表示使用默认值(timeout=UNSET_RPC_TIMEOUT, 此时走初始化逻辑,或_set_rpc_timeout 设置的值),0 表示无限超时
    # 注:_set_rpc_timeout 是 PyTorch 分布式 RPC 框架中用于动态修改全局默认超时时间的函数,通常在 init_rpc(RPC 环境初始化)之后调用
    # 注:异步操作不会阻塞,这里的 timeout 是从此刻开始,远程操作本身返回的时间,如果有错也是在 future.wait() 时抛出
    )
    # ... 执行其他操作 ...
    result = future.wait() # 等待并获取结果,这里也可以设置 future 从此刻开始的等待 timeout, 设置后与 rpc_async 中的 timeout 时间两者时间取小(注意两者起始时间不同)
  • 使用示例:

    1
    2
    3
    fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
    fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
    result = fut1.wait() + fut2.wait()

远程对象创建与引用 (remote 和 RRef)

  • remote 方法用于在远程 worker 上异步创建对象 ,并立即返回一个指向该对象的 RRef (Remote Reference)

  • RRef 的 to_here() 方法可以阻塞地将远程对象的值复制到本地

  • remote 函数原型

    1
    2
    3
    4
    5
    6
    7
    8
    9
    # 在远程 worker 上创建对象并返回 RRef,持有 func 函数的返回值
    rref = torch.distributed.rpc.remote(
    to, # (str or WorkerInfo or int) 目标 worker 的名称、rank 或 WorkerInfo,目标 worker(将成为所有者)
    func, # (Callable) 用于创建对象的函数,是一个可调用函数(如 Python 函数、内置运算符或 TorchScript 函数)
    args=None, # (tuple) 传给函数的可变位置参数
    kwargs=None, # (dict) 传给函数的可变关键字参数
    timeout=-1.0 # (float, optional) RPC 超时时间(秒),-1.0 表示使用默认值(timeout=UNSET_RPC_TIMEOUT, 此时走初始化逻辑,或_set_rpc_timeout 设置的值),0 表示无限超时
    # 注:异步操作不会阻塞,这里的 timeout 是从此刻开始,远程操作本身返回的时间,如果有错也是在 rref.to_here() 时抛出
    )
  • RRef 对象的一些使用方式:

    • 下面的两种调用相同
      1
      2
      3
      4
      5
      6
      7
      8
      9
      rref = rpc.remote("worker1", MyClass, args=(arg1, arg2))

      # 方式一:直接调用远程对象的函数(func_name 必须是 rref 持有对象的函数)
      rref.rpc_async().func_name(*args, **kwargs)

      # 方式二:原生的 rpc 实现过程,较为复杂,效果和方式一完全相同
      def run(rref, func_name, args, kwargs):
      return getattr(rref.local_value(), func_name)(*args, **kwargs)
      rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • RRef 的更多使用详情见附录

shutdown 函数的使用

  • 在所有 RPC 工作完成后、进程退出前,必须显式调用 torch.distributed.rpc.shutdown ();
  • 推荐使用 graceful=True(默认) 以等待所有未完成 RPC 结束并避免死锁与 SIGABRT 风险
  • 函数原型
    1
    2
    3
    4
    5
    6
    7
    # torch.distributed.rpc.shutdown
    def shutdown(
    graceful=True, # (bool),默认值为 True,表示是否优雅的结束
    # 如果 True,则 1) 等待所有针对 UserRRef 的系统消息处理完毕并删除这些引用;2) 阻塞直到本地与远程所有 RPC 进程均调用此方法,并等待所有未完成工作完成;
    # 当 graceful=False 时:不等待未完成的 RPC 任务和其他进程,直接强制关闭本地 RPC 系统。可能导致未完成任务失败或资源泄漏,仅建议在紧急退出场景使用
    timeout=DEFAULT_SHUTDOWN_TIMEOUT # 限制 “优雅等待” 的最大时长:
    )

重要注意事项

  • 较早的 PyTorch 版本(如 1.4 之前)API 可能不能使用
  • 分布式作业中的每个进程都必须成功调用 init_rpc,并在最后调用 rpc.shutdown() 来优雅地释放资源
  • RPC 框架在处理张量时,通常视为张量位于 CPU 上进行操作,以避免设备不匹配的错误
    • 如果需要在 GPU 上操作,可能需要在函数内部显式地将张量移动到 GPU
  • 网络问题、函数执行失败或超时都可能导致 RPC 调用失败,建议添加适当的异常处理
  • RPC 调用会有通信开销
    • 应尽量减少小规模的频繁调用,考虑批量操作或传输更大尺寸的数据以减少开销

附录:init_rpc 的初始化 IP 和端口要与 init_process_group 一致

  • 亲测:init_rpc 使用的 master_ip:master_port,必须和 init_process_group 使用的一样,否则无法启动
    • 暂未看到官方明确说明(待补充),猜测原因是因为 RPC 框架依赖已经初始化好的进程组(若没有初始化进程组,init_rpc 会自己新建进程组)
  • 示例如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    # start shell: torchrun --nproc_per_node=2 torch_distributed_rpc_demo.py
    import os
    import torch.distributed as dist
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import TensorPipeRpcBackendOptions

    def worker_function(a, b):
    worker_info = rpc.get_worker_info()
    print(f"Worker {worker_info.name} received a call with a={a}, b={b}")
    return a + b

    def main():
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    master_addr = os.environ.get('MASTER_ADDR', '未设置')
    master_port = int(os.environ.get('MASTER_PORT', '未设置'))
    print(f"master addr: {master_addr}:{master_port}")

    dist.init_process_group(backend='gloo', rank=rank, world_size=2,init_method=f"tcp://{master_addr}:{master_port}") # 提前初始化进程组

    rpc_backend_options = TensorPipeRpcBackendOptions(
    init_method=f"tcp://{master_addr}:{master_port}", # 与进程组初始化(init_process_group)时相同 IP和端口,可以正常初始化,端口不一致时会一直卡在这里不动
    num_worker_threads=2,
    rpc_timeout=10
    )

    # rpc_backend_options = None
    print(f"rank {rank} start init_rpc..., dist.is_initialized={dist.is_initialized()}")
    rpc.init_rpc(
    name=f"worker_{rank}",
    rank=rank,
    world_size=world_size,
    rpc_backend_options=rpc_backend_options
    )
    print(f"rank {rank} have done init_rpc")

    print(f"Worker {rank} is ready.")

    if rank == 0:
    result_sync = rpc.rpc_sync(
    to=f"worker_1", # 真正的执行过程发生在 worker1 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 0 received sync result from Worker 1: {result_sync}")
    else:
    print("--- Worker 1 is waiting for RPC calls ---")

    rpc.shutdown() # 等待所有 RPC 完成后自动关闭
    dist.destroy_process_group() # 注:若没有显示调用 dist.init_process_group,也不用这一行

    if __name__ == "__main__":
    main()

附录:RPC 的 world_size 可以比 已初始化的进程组的多

  • RPC 的 world_size 可以比 已初始化的进程组的多(虽然有依赖,但两者是相对独立的体系)【具体深层逻辑待梳理】
    • 即 init_rpc 的 world_size 可以比 init_process_group 的大
    • 具体实现时,部分 rank 进程不调用 init_process_group 即可实现
  • 示例如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    # start shell: torchrun --nproc_per_node=4 torch_distributed_rpc_demo.py
    import os
    import torch.distributed as dist
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import TensorPipeRpcBackendOptions

    def worker_function(a, b):
    worker_info = rpc.get_worker_info()
    print(f"Worker {worker_info.name} received a call with a={a}, b={b}")
    return a + b

    def main():
    # 获取环境变量
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    master_addr = os.environ.get('MASTER_ADDR', '未设置')
    master_port = int(os.environ.get('MASTER_PORT', '未设置'))
    print(f"master addr: {master_addr}:{master_port}")

    if rank < 2: # 仅 rank=0,1 的节点初始化进程组
    print(f"rank {rank} start init_process_group...")
    dist.init_process_group(
    backend='gloo',
    rank=rank,
    world_size=2, # world_size 设置为2,甚至仅初始化一个也可以,rank<2 改成 rank<1 即可
    init_method=f"tcp://{master_addr}:{master_port}"
    ) # 提前初始化好进程组
    print(f"rank {rank} have done init_process_group")
    print(dist)

    rpc_backend_options = TensorPipeRpcBackendOptions(
    init_method=f"tcp://{master_addr}:{master_port}", # 与进程组初始化(init_process_group)时相同 IP和端口,可以正常初始化,端口不一致时会一直卡在这里不动
    num_worker_threads=2,
    rpc_timeout=10
    )

    print(f"rank {rank} start init_rpc..., dist.is_initialized={dist.is_initialized()}")
    rpc.init_rpc(
    name=f"worker_{rank}",
    rank=rank,
    world_size=world_size, # world_size 由传入的确定,实际上可能是 4
    rpc_backend_options=rpc_backend_options
    )
    print(f"rank {rank} have done init_rpc")

    print(f"Worker {rank} is ready.")

    if rank == 0:
    print("--- Worker 0 sending RPC calls ---")

    # 同步调用
    result_sync = rpc.rpc_sync(
    to=f"worker_1", # 真正的执行过程发生在 worker1 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 0 received sync result from Worker 1: {result_sync}")

    # 异步调用
    future_async = rpc.rpc_async(
    to=f"worker_1",
    func=worker_function,
    args=(30, 40)
    )
    print("Worker 0 is doing other work while waiting for async result...")
    result_async = future_async.wait()
    print(f"Worker 0 received async result from Worker 1: {result_async}")

    elif rank == 3:
    print("--- Worker 3 sending RPC calls ---")

    # 同步调用
    result_sync = rpc.rpc_sync(
    to=f"worker_2", # 真正的执行过程发生在 worker2 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 3 received sync result from Worker 2: {result_sync}")

    else:
    print("--- Worker 1 is waiting for RPC calls ---")

    rpc.shutdown() # 等待所有 RPC 完成后自动关闭
    if dist.is_initialized(): # 必须判断,否则可能报错,因为部分进程为初始化 init_process_group
    dist.destroy_process_group() # 注:若没有显示调用 dist.init_process_group,不用这一行

    if __name__ == "__main__":
    main()

附录:RRef 的使用方式细节

  • 参考自:PyTorch 官方文档-RRef
  • Warning:当前使用 CUDA 张量时,暂不支持 RRef,需要把张量从 GPU 转移到 CPU 才可以
  • RRef(Remote REFerence,远程引用)是指向远程工作节点上某一类型(如张量 Tensor)值的引用
    • 该句柄会确保被引用的远程值在其所属节点上保持有效
    • 在多机训练场景中,RRef 可用于持有对其他工作节点上 nn.Module(神经网络模块)的引用,并在训练过程中调用相应函数来获取或修改这些模块的参数
    • 更多细节可参考:Remote Reference Protocol
  • 注:RRef 也是一个 Python 类(torch.distributed.rpc.RRef),但有时泛指所有远程引用,而不是某一个类

torch.distributed.rpc.PyRRef 类型总体说明

  • 类型定义原型:

    1
    2
    3
    4
    5
    6
    7
    # __pybind11_builtins.pybind11_object 表明 PyRRef 是一个通过 pybind11 技术从 C++ 创建的 Python 类
    class PyRRef(__pybind11_builtins.pybind11_object):
    # ...

    # RRef 的定义
    class RRef(PyRRef, Generic[T]):
    pass
  • PyRRef 类用于封装对远程工作节点上某一类型值的引用

    • 此句柄会确保被引用的远程值在对应工作节点上保持有效
    • 当满足以下任一条件时,UserRRef(用户侧远程引用)将被销毁:
      • 1)在应用代码和本地 RRef 上下文环境中,均不存在对该 UserRRef 的引用;
      • 2)应用已调用 graceful shutdown 操作
    • 调用已销毁 RRef 的方法会导致未定义行为
      • RRef 实现仅提供“尽力而为”的错误检测机制(best-effort error detection),因此在调用 rpc.shutdown()(RPC 关闭函数)后,应用不应再使用 UserRRef
  • 警告:RRef 只能由 RPC 模块进行序列化和反序列化操作

    • 若不通过 RPC 模块(如使用 Python pickle 模块、PyTorch 的 save()/load() 函数、JIT 的 save()/load() 函数等)对 RRef 进行序列化或反序列化,将会引发错误

PyRRef 类型初始化参数(不常用,一般不自己定义 RRef)

  • value(object 类型):需用此 RRef 封装的值
  • type_hint(Type 类型,可选):传递给 TorchScript 编译器的 Python 类型提示,用于指定 value 的类型

PyRRef 类型使用示例

  • 以下示例为简化说明,省略了 RPC 初始化和关闭相关代码。有关这些操作的详细内容,请参考 RPC 文档

  • 使用 rpc.remote() 创建 RRef

    1
    2
    3
    4
    5
    6
    7
    import torch
    import torch.distributed.rpc as rpc

    # 在远程工作节点“worker1”上执行 torch.add 操作,并创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) # Debug 时看到的可能是 torch._C._distributed_rpc.PyRRef 类型,实际和 torch.distributed.rpc.PyRRef 是同一个类的不同路径
    # 从 RRef 中获取值的副本
    x = rref.to_here()
  • 基于本地对象创建 RRef

    1
    2
    3
    4
    5
    6
    7
    import torch
    from torch.distributed.rpc import RRef

    # 创建本地张量
    x = torch.zeros(2, 2)
    # 用 RRef 封装本地张量,生成 RRef 对象
    rref = RRef(x)
  • 与其他工作节点共享 RRef

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    # 在工作节点“worker0”和“worker1”上均执行以下函数定义
    def f(rref):
    # 获取 RRef 指向的值并加 1
    return rref.to_here() + 1

    # 在工作节点“worker0”上执行以下代码
    import torch
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import RRef

    # 基于本地零矩阵创建 RRef
    rref = RRef(torch.zeros(2, 2))
    # 通过 RPC 调用,将 RRef 共享给工作节点“worker1”,引用计数会自动更新
    rpc.rpc_sync("worker1", f, args=(rref,))

PyRRef.backward 方法

  • backward 方法原型

    1
    backward(self: torch._C._distributed_rpc.PyRRef, dist_autograd_ctx_id: int = -1, retain_graph: bool = False) -> None
  • backward 方法以 RRef 作为反向传播的根节点,执行反向传播过程

  • 若提供了 dist_autograd_ctx_id(分布式自动求导上下文 ID),则会使用该上下文 ID,从 RRef 的所属节点开始执行分布式反向传播

    • 这种情况下,应使用 get_gradients() 函数获取梯度
  • 若 dist_autograd_ctx_id 为 None,则默认当前为本地自动求导图,仅执行本地反向传播

    • 在本地反向传播场景中,调用此 API 的节点必须是该 RRef 的所属节点,且 RRef 所指向的值需为标量张量(scalar Tensor)
  • PyRRef.backward 方法参数

    • dist_autograd_ctx_id(int 类型,可选):用于获取梯度的分布式自动求导上下文 ID,默认值为 -1
    • retain_graph(bool 类型,可选):若设为 False,则计算梯度所用的计算图会被释放
      • 几乎不需要设为 True
      • 一般仅在需要多次执行反向传播时,才需将其设为 True,默认值为 False
  • PyRRef.backward 方法示例

    1
    2
    3
    4
    5
    6
    import torch.distributed.autograd as dist_autograd

    # 创建分布式自动求导上下文,并获取上下文 ID
    with dist_autograd.context() as context_id:
    # 以当前 RRef 为根节点,执行分布式反向传播
    rref.backward(context_id)

PyRRef 一些简单方法说明

  • confirmed_by_owner 方法:

    1
    confirmed_by_owner(self: torch._C._distributed_rpc.PyRRef)  bool
    • 返回该 RRef 是否已得到其所属节点的确认(真正拥有对象的节点为所属节点,对象时存储到所属节点上的)
      • OwnerRRef(所属节点侧远程引用)始终返回 True;
      • UserRRef(用户侧远程引用)仅在所属节点知晓该 UserRRef 存在时,才返回 True
  • is_owner 方法:

    1
    is_owner(self: torch._C._distributed_rpc.PyRRef) -> bool
    • 返回当前节点是否为该 RRef 的所属节点(对象时存储到所属节点上的)
  • local_value 方法:

    1
    local_value(self: torch._C._distributed_rpc.PyRRef) -> object
    • 若当前节点是该 RRef 的所属节点 ,则返回对本地值的引用;否则,抛出异常
  • owner 方法:

    1
    owner(self: torch._C._distributed_rpc.PyRRef) -> torch._C._distributed_rpc.WorkerInfo
    • 返回该 RRef 所属节点的工作节点信息(WorkerInfo 对象)
  • owner_name 方法:

    1
    owner_name(self: torch._C._distributed_rpc.PyRRef) -> str
    • 返回该 RRef 所属节点的工作节点名称(字符串形式)

PyRRef.remote 方法

  • remote 方法原型:

    1
    remote(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • remote 方法创建一个辅助代理(helper proxy)

    • 该代理可以用于发起 remote 调用,以 RRef 的所属节点为目标节点,在该 RRef 所指向的对象上执行函数
  • 更具体地说,rref.remote().func_name(*args, **kwargs) 等价于以下代码:

    1
    2
    3
    4
    5
    6
    def run(rref, func_name, args, kwargs):
    # 获取 RRef 本地值,并调用其指定方法
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

    # 向 RRef 所属节点发起 remote 调用,执行 run 函数
    rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • remote 方法参数:

    • timeout(float 类型,可选):rref.remote() 的超时时间(单位:秒)
      • 若在该时间内未成功创建 RRef,则下次尝试使用该 RRef(如调用 to_here() 方法)时,会引发超时错误
      • 若未提供该参数,则使用默认的 RPC 超时时间
  • remote 方法示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    from torch.distributed import rpc

    # 在远程节点“worker1”上执行 torch.add,创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
    # 调用 RRef 指向对象的 size() 方法,获取结果(远程执行,本地获取)
    rref.remote().size().to_here() # 返回结果:torch.Size([2, 2])
    # 注:返回类型
    # rref.remote(): <torch.distributed.rpc.rref_proxy.RRefProxy object at 0x11fa6e470>
    # rref.remote().size(): UserRRef(RRefId = GloballyUniqueId(created_on=0, local_id=4), ForkId = GloballyUniqueId(created_on=0, local_id=5))
    # rref.remote().size().to_here(): torch.Size([2])

    # 调用 RRef 指向对象的 view() 方法,重塑张量形状并获取结果
    rref.remote().view(1, 4).to_here() # 返回结果:tensor([[1., 1., 1., 1.]])

PyRRef.rpc_async 方法

  • rpc_async 方法原型:

    1
    rpc_async(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • rpc_async 方法创建一个辅助代理

    • 该代理可发起 rpc_async 调用,以 RRef 的所属节点为目标节点,在该 RRef 所指向的对象上执行函数
  • 更具体地说,rref.rpc_async().func_name(*args, **kwargs) 等价于以下代码:

    1
    2
    3
    4
    5
    6
    def run(rref, func_name, args, kwargs):
    # 获取 RRef 本地值,并调用其指定方法
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

    # 向 RRef 所属节点发起异步 RPC 调用,执行 run 函数
    rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • rpc_async 方法参数

    • timeout(float 类型,可选):rref.rpc_async() 的超时时间(单位:秒)
      • 若在该时间内调用未完成,则会引发超时异常
      • 若未提供该参数,则使用默认的 RPC 超时时间
  • rpc_async 方法示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    from torch.distributed import rpc

    # 在远程节点“worker1”上执行 torch.add,创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
    # 异步调用 size() 方法,等待结果返回
    rref.rpc_async().size().wait() # 返回结果:torch.Size([2, 2])
    # 注:返回类型
    # rref.rpc_async(): <torch.distributed.rpc.rref_proxy.RRefProxy object at 0x11bcfa3e0>
    # rref.rpc_async().size(): <torch.jit.Future object at 0x11bcc65c0>
    # rref.rpc_async().size().wait(): torch.Size([2])

    # 异步调用 view() 方法,重塑张量形状并等待结果返回
    rref.rpc_async().view(1, 4).wait() # 返回结果:tensor([[1., 1., 1., 1.]])

PyRRef.rpc_sync 方法

  • rpc_sync 方法原型:

    1
    rpc_sync(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • rpc_sync 方法创建一个辅助代理

    • 该代理可发起 rpc_sync 调用,以 RRef 的所属节点为目标节点,在该 RRef 所指向的对象上执行函数
  • 更具体地说,rref.rpc_sync().func_name(*args, **kwargs) 等价于以下代码:

    1
    2
    3
    4
    5
    6
    def run(rref, func_name, args, kwargs):
    # 获取 RRef 本地值,并调用其指定方法
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

    # 向 RRef 所属节点发起同步 RPC 调用,执行 run 函数
    rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • rpc_sync 方法参数:

    • timeout(float 类型,可选):rref.rpc_sync() 的超时时间(单位:秒)
      • 若在该时间内调用未完成,则会引发超时异常
      • 若未提供该参数,则使用默认的 RPC 超时时间
  • rpc_sync 方法示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    from torch.distributed import rpc

    # 在远程节点“worker1”上执行 torch.add,创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
    # 同步调用 size() 方法,获取张量尺寸
    rref.rpc_sync().size() # 返回结果:torch.Size([2, 2])
    # 注:返回类型:
    # rref.rpc_sync(): <torch.distributed.rpc.rref_proxy.RRefProxy object at 0x11bcfa440>
    # rref.rpc_sync().size(): torch.Size([2])

    # 同步调用 view() 方法,重塑张量形状
    rref.rpc_sync().view(1, 4) # 返回结果:tensor([[1., 1., 1., 1.]])

PyRRef.to_here 方法

  • to_here 方法原型:

    1
    to_here(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • to_here 方法是阻塞式调用,将 RRef 指向的值从其所属节点复制到本地节点,并返回该值

    • 若当前节点是该 RRef 的所属节点,则直接返回对本地值的引用
  • to_here 方法参数:

    • timeout(float 类型,可选):to_here 方法的超时时间(单位:秒)
      • 若在该时间内调用未完成,则会引发超时异常
      • 若未提供该参数,则使用默认的 RPC 超时时间(60 秒)

附录:torch.distributed.rpc.PyRRef 和 torch._C._distributed_rpc.PyRRef 的区别

  • 这两个是同一个类的不同引用路径,但有重要的使用区别,两者本质关系为

    1
    2
    3
    4
    # 推荐使用的方式;稳定,版本间兼容性好;可能包含额外的封装、错误检查、文档
    torch.distributed.rpc.PyRRef # 公共API接口,官方支持,文档完整
    # 不推荐直接使用;内部实现细节;可能在PyTorch版本更新时发生变化;直接的C++绑定,可能缺少一些Python层的便利功能
    torch._C._distributed_rpc.PyRRef # 底层C++实现,以`_`开头表示私有
  • 实际上 torch.distributed.rpc.PyRRef 就是对 torch._C._distributed_rpc.PyRRef 的封装或直接引用

  • 验证两者的等价关系

    1
    2
    3
    4
    5
    import torch.distributed.rpc as rpc
    import torch._C._distributed_rpc as _rpc

    # 通常这两个是相同的对象
    print(rpc.PyRRef is _rpc.PyRRef) # 可能输出 True

为什么会看到内部 API torch._C._distributed_rpc.PyRRef?

  • 调试时的堆栈跟踪可能显示内部路径
  • 某些高级用法或源码分析
  • 类型检查时可能遇到

PyTorch——叶子张量


整体说明

  • tensor.is_leaf=True 的张量被称为叶子张量(Leaf Tensor),也称为叶子变量(Leaf Variable),部分博客或书籍也称为叶子节点
  • 叶子张量分两种类型,即以下两种情况下的张量 tensor.is_leaf 返回 True:
    • 类型一:requires_grad 为 False 的张量都是叶子张量(由于 requires_grad 为 False,所以不会存储梯度)
    • 类型二:requires_grad 为 True 的张量,如果是由用户创建的 ,而不是通过其他张量运算得到的,那么它是叶子张量;
      • 通过其他张量运算得到的,都是非叶子张量
  • 特别说明:
    • detach() 函数可以将节点从计算图中剥离,使其成为叶子节点,此时 requires_grad 为 False,同时 tensor.is_leaf 会变为 True 了
    • 从 CPU 定义好后移动到 GPU 时产生的张量,或者从 GPU 定义后,挪到 CPU 上的向量,也都是非叶子张量,tensor.is_leaf 返回 False
      • 注意:不论是 GPU 还是 CPU,叶子张量的判定不变,只有用户定义的张量是叶子张量,挪动以后得都不是叶子张量(除非同时修改其
    • 只有叶子张量的 requires_grad 属性可以被修改,非叶子张量的 requires_grad 属性是不能被修改的
      • 理解:非叶子张量都是派生出来的,且 requires_grad=True 的张量,叶子张量计算梯度时依赖性非叶子张量的梯度,不能随便修改 requires_grad 属性
    • 叶子张量的 grad_fn=None (包括 requires_grad=True 和 requires_grad=False 的都是)
      • 理解:tensor.grad_fn 属性指向/存储生成 tensor 张量的计算操作(如加法、减法、乘法等),叶子张量要么是直接由用户定义出来的(此时 requires_grad=True),要么是不需要计算梯度的(此时 requires_grad=False)
    • 非叶子张量的梯度不会被保存(因为不需要使用)

叶子张量的特性

  • 在反向传播过程中,只有叶子张量的梯度会被保留并存储在张量的grad属性中
    • 用于后续优化器更新参数等操作
    • 比如:神经网络层中的权值 w 的张量均为叶子节点,反向传播 backward() 就是为了求它们的梯度,进而更新权值
  • 非叶子张量的梯度在反向传播使用完后通常会被清除 ,以节省内存
    • 比如:计算中中间的一些派生阶段就是非叶子张量

叶子张量相关示例

  • 各种叶子张量判断示例
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    import torch as torch

    a = torch.randn(3,3)
    print(a.is_leaf, a.requires_grad) # True False
    a.requires_grad = True
    print(a.is_leaf, a.requires_grad) # True True
    b = a.cuda()
    print(b.is_leaf, b.requires_grad) # False True
    c = b.detach()
    print(c.is_leaf, c.requires_grad) # True False
    d = a + 2
    print(d.is_leaf, d.requires_grad) # False True

    e = torch.randn(3,3, device="cuda", requires_grad=True)
    print(e.is_leaf, e.requires_grad) # True True
    f = e.to("cpu")
    print(f.is_leaf, f.requires_grad) # False True

附录:如何打印非叶子张量的梯度?

  • 在 PyTorch 里,当开启自动求导功能时,中间变量(非叶子张量)的梯度默认不会被保存,目的是节省内存
    • 只有叶子节点(比如直接创建的张量)的梯度会被保留
    • 在任意时刻时,非中间节点的梯度(grad属性)是都是 None
  • 要获取中间变量的梯度,有两种方法:
    • 运用 retain_grad() 方法保留梯度
    • 借助钩子(hook)来捕获梯度(可以打印或者赋值给全局变量)
  • 打印非节点张量的代码示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    import torch

    x = torch.tensor(2.0, requires_grad=True)
    y = x**2
    z = y**2

    # 方法一:使用 retain_grad() 显式保留梯度
    y.retain_grad()

    z.backward(retain_graph=True) # 若不使用 retain_graph=True,计算图会在 backward 被清空,则后续想要调用 backward() 前需要重新构造计算图
    print("方法一:", y.grad) # 输出: 16.0

    # 方法二:使用钩子 hook 捕捉梯度
    gradient_list = []
    def save_gradient(grad):
    gradient_list.append(grad)

    ## 如果 前面的 z.backward(retain_graph=True) 不使用 retain_graph=True,则 backward() 会清空计算图,这里就需要重新构造计算图,目前使用的是 z.backward(retain_graph=True) ,不需要下面的两句
    # y = x**2
    # z = y**2

    hook = y.register_hook(save_gradient)
    z.backward()
    hook.remove() # 重点:即时移除钩子,防止不必要的内存占用

    print("方法二:", gradient_list[0]) # 输出: 16.0

PyTorch——模型存储与加载


整体说明

  • 在 PyTorch 中,训练完成后,可以将模型保存到磁盘上持久化存储,模型保存方式有:
    • 保存/加载整个模型(包括模型结构和参数)
    • 保存/加载模型参数(状态字典)(推荐方式,包含参数和其他状态信息)
    • 保存/加载检查点(用于断点续训)
  • 对于跨渠道的保存和加载,可以存储为 TorchScript 格式,这是一种专为生产环境优化的模型序列化格式
    • TorchScript 格式能够将 PyTorch 模型转换为一种可序列化、可优化的中间表示形式,便于在不同环境(包括 C++ 部署)中运行
  • 保存模型时 .pt 和 .pth 两种存储格式之间没有本质区别,主要区别在于使用习惯:
    • 早期 PyTorch 文档和示例中更常用 .pth 扩展名
    • 后来随着 PyTorch 版本中,官方示例逐渐开始使用 .pt,逐渐成为更推荐的格式
    • 实际上两者完全等价,亲测:.pt 和 .pth 直接修改后缀就能混用

存储模型和加载模型示例

  • 有两种主要的方法来保存模型:保存整个模型或仅保存模型的状态字典(推荐)

整个模型存储(不常用)

  • 保存整个模型并加载整个模型

    1
    2
    3
    4
    5
    # 保存整个模型
    torch.save(model, 'simple_model.pt') # 保存

    # 加载整个模型
    model_loaded = torch.load('simple_model.pt') # 加载
    • 注意:这种方法要求模型类定义必须可用,读不到类名会出错
  • 特别提示(容易出错):容易导致模型代码定义和加载存储模型不一致情况

    • 加载规则:实际上只要模型的类名相同即可加载(按照类名匹配的),模型结构可以定义不一致
    • 模型加载后模型实际对象结构与存储真实结构一致,会丢失当前代码定义的模型所有结构(包括类属性也不存在)
    • 理论上仅仅需要定义一个类名即可正常加载,但是使用时是按照存储模型的结构来使用的

仅存储模型参数(常用)

  • 保存模型的状态字典 并加载:

    1
    2
    3
    4
    5
    6
    # 保存模型参数
    torch.save(model.state_dict(), 'simple_model_state_dict.pth')

    # 创建模型示例并加载模型参数
    model_loaded = SimpleModel() # 实例化模型
    model_loaded.load_state_dict(torch.load('simple_model_state_dict.pth'))
    • 待加载模型参数和定义的模型参数不同时会直接报错
    • 状态字典方法更灵活,允许你只保存模型参数,不包括模型结构,这在部署时特别有用

常见保存和加载方式的整体示例

  • 整体示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    import torch
    import torch.nn as nn
    import torch.optim as optim

    class DiyModel(nn.Module):
    def __init__(self):
    super(DiyModel, self).__init__()
    self.fc1 = nn.Linear(10, 20)
    self.fc2 = nn.Linear(20, 2)
    self.relu = nn.ReLU()

    def forward(self, x):
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    return x

    model = DiyModel()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # 1. 保存和加载整个模型(包括结构和参数)
    # 保存
    torch.save(model, 'entire_model.pt')

    # 加载
    loaded_entire_model = torch.load('entire_model.pt')
    loaded_entire_model.eval() # 设置为评估模式,方便后续的 Serving
    print("加载整个模型成功")

    # 2. 仅保存和加载模型参数(推荐方式)
    # 保存
    torch.save(model.state_dict(), 'model_parameters.pt')

    # 加载
    loaded_model = DiyModel() # 需要先创建模型实例
    loaded_model.load_state_dict(torch.load('model_parameters.pt'))
    loaded_model.eval() # 设置为评估模式
    print("加载模型参数成功")

    # 3. 保存和加载检查点(用于断点续训)
    # 保存:包含模型参数、优化器状态、 epoch 等信息
    checkpoint = {
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.23,
    # 可以添加其他需要保存的信息
    }
    torch.save(checkpoint, 'training_checkpoint.pt')

    # 加载检查点
    loaded_checkpoint = torch.load('training_checkpoint.pt')

    # 恢复模型和优化器状态
    restored_model = DiyModel()
    restored_optimizer = optim.SGD(restored_model.parameters(), lr=0.001, momentum=0.9)

    restored_model.load_state_dict(loaded_checkpoint['model_state_dict'])
    restored_optimizer.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
    epoch = loaded_checkpoint['epoch']
    loss = loaded_checkpoint['loss']

    restored_model.train() # 恢复训练时设置为训练模式
    print(f"加载检查点成功: epoch {epoch}, loss {loss}")

    # 4. 跨设备保存和加载(跨设备
    # 在 GPU 上保存,在 CPU 上加载
    if torch.cuda.is_available():
    model.cuda()
    torch.save(model.state_dict(), 'model_gpu.pt') # 同时也会保存一些 GPU 相关信息

    # 在CPU上加载GPU保存的模型
    cpu_model = DiyModel()
    # 特别需要注意的点:map_location指明设备,是必须的参数(实际上,为了兼容,建议所有的加载都加上 `map_location=device` 参数)
    cpu_model.load_state_dict(torch.load('model_gpu.pt', map_location=torch.device('cpu')))
    print("在 CPU 上加载 GPU 保存的模型成功")

TorchScript 格式

  • TorchScript 格式能够跨平台存储和加载(可在 C++ 环境中运行,无需 Python 依赖)
  • TorchScript 格式在生产环境中更常用
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    import torch
    import torch.nn as nn

    class DiyModel(nn.Module):
    def __init__(self):
    super(DiyModel, self).__init__()
    self.fc1 = nn.Linear(10, 20)
    self.fc2 = nn.Linear(20, 2)
    self.relu = nn.ReLU()

    def forward(self, x):
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    return x

    model = DiyModel() # 创建模型

    # 1. 将模型转换为TorchScript格式(注:这一步是必要的)
    # 方法1:跟踪(tracing) - 适用于无控制流的模型(不常用,不推荐,容易出错)
    example_input = torch.randn(1, 10)
    traced_script_module = torch.jit.trace(model, example_input)
    # tracing 方法引入一个输入数据来执行流程,同时基于执行流程生成模型静态图
    # 如果存在控制流,会只保留example_input下会遇到的控制流
    # 以后遇到其他数据也会都走这个控制流,从而导致错误发生
    # 所以仅适用于无控制流的模型,不推荐使用

    # 方法2:脚本(scripting) - 适用于有控制流的模型和无控制流的模型(常用,推荐,兼容性好)
    scripted_script_module = torch.jit.script(model) # 解析 Python 代码结构生成静态图,能完整保留模型结构
    # 特别说明:scripting形式和tracing保存的模型性能上差异不大

    # 2. 保存TorchScript模型
    traced_script_module.save("traced_model.pt")
    scripted_script_module.save("scripted_model.pt")

    # 3. 加载TorchScript模型
    loaded_traced_model = torch.jit.load("traced_model.pt")
    loaded_scripted_model = torch.jit.load("scripted_model.pt")

    # 4. 使用加载的模型进行推理
    loaded_traced_model.eval()
    loaded_scripted_model.eval()
    with torch.no_grad():
    input_data = torch.randn(1, 10)
    output1 = loaded_traced_model(input_data)
    output2 = loaded_scripted_model(input_data)

使用 tracing 方法保存含控制流模型

  • 下面是使用 tracing 方法保存含控制流模型出错的示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    import torch
    import torch.nn as nn

    class ControlFlowModel(nn.Module):
    def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(1, 1)
    self.fc2 = nn.Linear(1, 1)

    def forward(self, x):
    if x > 0: # 控制流语句:根据输入值选择不同分支
    output = self.fc1(x) # 分支1
    else:
    output = self.fc2(x) # 分支2
    return output

    model = ControlFlowModel()
    model.fc1.weight.data = torch.tensor([[2.0]]) # 分支1:输出 = 2*x
    model.fc1.bias.data = torch.tensor([0.0])
    model.fc2.weight.data = torch.tensor([[3.0]]) # 分支2:输出 = 3*x
    model.fc2.bias.data = torch.tensor([0.0])

    # 1. 使用 tracing 方法转换模型(示例输入为正数,触发分支1)
    example_input = torch.tensor([1.0]) # 正数:走fc1分支
    traced_model = torch.jit.trace(model, example_input)

    # 2. 测试不同输入的推理结果
    test_inputs = [
    torch.tensor([2.0]), # 正数(与示例输入同分支)
    torch.tensor([-1.0]) # 负数(与示例输入不同分支)
    ]

    print("Original Model Output:")
    for x in test_inputs:
    print(f"Input {x.item()}: Output {model(x).item()}")

    print("\nTracing Model Output:")
    for x in test_inputs:
    print(f"Input {x.item()}: Output {traced_model(x).item()}")
    # 输入 -1.0 时错误地输出了 -2.0,实际上应该输出 -3.0

    # Original Model Output:
    # Input 2.0: Output 4.0
    # Input -1.0: Output -3.0
    #
    # Tracing Model Output:
    # Input 2.0: Output 4.0
    # Input -1.0: Output -2.0

大模型的存储和加载

  • 大模型一般以 Safetensors 格式存储(Hugging Face 的默认存储形式),许多 CV 和 NLP 的开源模型都是这个格式
  • 超大规模模型还可以分片保存
  • 大模型存储和加载的示例:
    1
    # 待补充

附录:加载模型后的动作

  • 切换模式 :无论哪种保存方式,在加载模型后,如果模型包含 Batch Normalization、Dropout 等层,都必须调用 model.eval() 来确保这些层在推理时行为正确
    • 因为这样会关闭 Dropout 和 Batch Normalization 等层的行为变化
  • 禁用梯度 :调用模型 Serving 前建议禁用梯度
    • 加速并节省内存 :禁用梯度计算可以减少运行时的内存占用,因为在前向过程中不需要存储用于反向传播的信息,由于也不需要准备一些梯度计算的步骤,可小幅提升 Serving 速度
    • 增加代码可读性 :使用 torch.no_grad() 强调了当前操作的目的(即,仅执行前向计算而不更新模型权重),这提高了代码的可读性和意图的一致性(这对于维护和理解代码非常有帮助)
  • torch.no_grad()的实践Demo:
    1
    2
    3
    4
    5
    model.eval()  # 设置模型为评估模式,不会禁用梯度
    with torch.no_grad(): # 在评估过程中禁用梯度计算
    output = model(X_test)
    loss = criterion(output, Y_test)
    print(f'Test Loss: {loss.item():.4f}')

Python——Ray-多节点集群启动


整体说明

  • Ray 集群的启动有多种方法,本文简单总结这些方法
  • 核心概念补充:
    • 头节点(Head Node) :集群的主节点,负责管理整个集群的资源、任务调度和元数据存储
    • 工作节点(Worker Node) :通过连接头节点加入集群,提供计算资源(CPU/GPU/内存等)

配置文件启动

  • 设置配置文件 cluster.yaml

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    cluster_name: my-ray-cluster
    max_workers: 2

    # 头节点配置
    head_node:
    InstanceType: m5.large
    ImageId: ami-0abcdef1234567890 # 选择包含Ray的镜像

    # 工作节点配置
    worker_nodes:
    InstanceType: m5.xlarge

    # 启动命令(可选,自定义节点初始化)
    setup_commands:
    - pip install pandas # 安装依赖
  • 主节点启动 ray up cluster.yaml

  • 其他节点加入集群 ray attach cluster.yaml

  • 停止容器 ray down cluster.yaml


Kubernetes 部署

  • 待补充

Python 程序内启动

  • 主节点使用:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    import ray

    # 启动本地集群(自动创建头节点和工作进程)
    ray.init(
    num_cpus=4, # 模拟4核CPU
    num_gpus=1, # 模拟1块GPU
    port=6380, # 手动指定通信端口
    dashboard_host="0.0.0.0" # 允许外部访问Dashboard
    )

    # 验证集群
    print("集群节点数:", len(ray.nodes())) # 输出1(单节点模拟)

    # 关闭集群
    ray.shutdown()
  • 其他节点加入使用

    1
    2
    3
    4
    5
    6
    # 头节点启动Ray
    import ray
    ray.init()

    # 工作节点连接到头节点
    ray.init(address='头节点地址:6380')

使用 ray start 命令启动

  • ray start 是启动 Ray 集群节点的常用命令
  • 可分别分别使用 ray start 用于初始化头节点(Head Node)和工作节点(Worker Node)

启动头节点

  • 头节点是集群的入口,必须先启动

  • 基本命令格式:

    1
    ray start --head [其他可选参数]
  • 关键参数包括:

    参数 说明 示例
    --head 声明当前节点为头节点(必选) -
    --port 指定 Ray 内部通信端口(默认 6379,若被占用会自动切换) --port=6380
    --dashboard-host 允许外部访问 Ray dashboard 的主机地址(默认仅本地访问) --dashboard-host=0.0.0.0
    --dashboard-port Dashboard 端口(默认 8265) --dashboard-port=8266
    --num-cpus 手动指定该节点可用的 CPU 核心数(默认自动检测) --num-cpus=16
    --num-gpus 手动指定该节点可用的 GPU 数量(默认自动检测) --num-gpus=2
    --memory 限制节点可用内存(单位:字节,如 1000000000 表示 1GB) --memory=8000000000
    --object-store-memory 对象存储的内存上限(默认总内存的 30%) --object-store-memory=2000000000
    --block 启动后阻塞终端(不后台运行,便于调试) -
    --log-dir 指定日志目录(默认 ~/raylogs) --log-dir=/path/to/logs
  • 示例:启动一个允许外部访问 Dashboard、指定 CPU/GPU 资源的头节点:

    1
    2
    3
    4
    5
    6
    ray start --head \
    --port=6379 \
    --dashboard-host=0.0.0.0 \
    --dashboard-port=8265 \
    --num-cpus=12 \
    --num-gpus=1

启动工作节点

  • 工作节点需通过头节点的地址加入集群,命令格式:

    1
    ray start --address=<IP>:<port> [其他可选参数]
  • 关键参数说明:

    • --address:头节点的地址(必填,格式为 <IP>:<port>,即头节点启动时输出的地址)
    • 其他参数(如 --num-cpus、--num-gpus 等)与头节点相同,用于限制工作节点的资源
  • 示例:连接到 IP 为 192.168.1.100、端口为 6379 的头节点,同时指定工作节点的资源:

    1
    2
    3
    ray start --address='192.168.1.100:6379' \
    --num-cpus=8 \
    --num-gpus=0

使用 ray 命令验证集群状态

  • 可在任意节点执行下面脚本查看节点列表

    1
    ray status
    • 输出会显示集群中的所有节点及资源使用情况
  • 通过头节点的 http://<头节点IP>:8265(或其他指定端口) 查看集群监控、任务状态等

使用 ray 命令停止集群

  • 停止单个节点(包括工作节点或头节点):

    1
    ray stop
  • 特别注意:若头节点停止,整个集群会自动解散

附录:ray start 命令的其他高级配置

  • 可通过 --runtime-env 指定环境配置文件(如依赖安装、环境变量等):

    1
    ray start --head --runtime-env=runtime_env.yaml
  • 可通过 --redis-password 设置密码,防止未授权节点加入:

    1
    2
    3
    4
    5
    # 头节点
    ray start --head --redis-password='mysecret'

    # 工作节点
    ray start --address=<IP:port> --redis-password='mysecret'
  • 可通过 --log-dir 指定日志目录(默认 ~/raylogs):

    1
    ray start --head --log-dir=/path/to/logs

附录:通过 ray job submit 向已经启动的 Ray 集群提交任务

  • ray job submit 命令用于将作业提交到 Ray 集群

基本语法

  • 用法说明:

    1
    ray job submit [options] -- <entrypoint> [<entrypoint_args>]
    • [options] 是命令的可选参数
    • -- 之后的 <entrypoint> 是要执行的入口点脚本或命令
      • 可以是 -- python my_script.py 或 bash my_shell.sh
    • <entrypoint_args> 是传递给 <entrypoint> 的参数

常用参数

  • --address:指定 Ray 集群的地址,通常是集群头节点的地址和端口,如http://127.0.0.1:8265
  • --working-dir:指定作业的工作目录
    • 该目录下的文件会被同步到集群节点上,默认为当前目录
    • 启动脚本会有一些文件依赖,这里上传所有文件可以保证本地能访问的文件,集群也能访问
  • --runtime-env:用于指定作业的运行时环境,可以是一个 JSON 格式的字符串或 YAML 文件路径
    • 例如,可以通过该参数指定需要安装的 Python 包,如--runtime-env='{"pip": ["requests"]}'
  • --no-wait:提交作业后不等待作业完成,立即返回
    • 如果不指定该参数,命令会等待作业完成,并输出作业的日志和结果
  • --submission-id:指定作业的提交 ID
    • 如果不指定,Ray 会自动生成一个唯一的 ID

用法示例

  • 假设要提交一个 Python 脚本 my_script.py 到 Ray 集群,指定集群地址为 http://127.0.0.1:8265,工作目录为当前目录:

    1
    RAY_ADDRESS='http://127.0.0.1:8265' ray job submit --working-dir. -- python my_script.py
  • 假设提交一个 Python 脚本 train.py,并传递参数 --epochs 10 --batch-size 32,同时指定运行时环境需要安装 torch 和 numpy 包:

    1
    ray job submit --address=http://your-ray-cluster-address:8265 --runtime-env='{"pip": ["torch", "numpy"]}' -- python train.py --epochs 10 --batch-size 32

其他高级配置

  • py_modules :指定需要导入的自定义 Python 模块路径,支持将本地模块添加到 Python 路径(sys.path),示例如下:

    1
    runtime_env={"py_modules": ["./my_utils"]}  # 同步 my_utils 模块并添加到路径
  • env :指定预定义的环境名称(如 Ray 集群中已配置的共享环境),避免重复配置,示例如下:

    1
    runtime_env={"env": "shared-training-env"}  # 使用集群中预定义的环境
  • 特别说明:runtime_env 还支持其他非官方扩展,比如 verl 库中就为英伟达显卡配置了 nsight 工具的参数传入

    1
    2
    ./verl/trainer/main_ppo.py
    runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()

NLP——LLM内存优化技术总结

本文主要介绍内存优化相关技术


LLM内存优化技术总结

  • LLM(大型语言模型)在内存优化方面采用了多种技术,常见的方法包括

梯度检查点(Gradient Checkpointing)

  • 在前向传播时只保存部分激活值,其余的在反向传播时重新计算,主要用在预训练阶段
  • 特点:显著减少内存占用,但会增加计算量

混合精度训练(Mixed Precision Training)

  • 使用16位浮点数(FP16)代替32位浮点数(FP32)进行计算和存储,通过自动混合精度(AMP)工具实现,关键部分(如梯度更新)仍使用FP32以保证数值稳定性
  • 特点:减少内存使用并提升计算速度,但对显卡有要求,适用于支持FP16的硬件(如NVIDIA Tensor Core GPU)

模型并行(Model Parallelism)

  • 将模型的不同层分配到多个设备上,减少单个设备的内存负担
  • 特点:支持更大模型的训练,但增加了通信开销

数据并行(Data Parallelism)

  • 将数据批次分配到多个设备上,每个设备拥有完整的模型副本
  • 特点:通过增加设备数量来分摊内存压力

梯度累积(Gradient Accumulation)

  • 在多个小批次上累积梯度后再更新模型参数,多次累计梯度后一次更新参数,能够用小内存实现大梯度更新参数
  • 特点:减少单次内存需求,支持更大的批次训练

参数卸载(Parameter Offloading)

  • 将部分模型参数存储在CPU或磁盘上,需要时再加载到GPU,可通过框架(如DeepSpeed的ZeRO-Offload)自动管理参数的加载和卸载
  • 特点:减少GPU内存占用,但可能增加I/O开销

稀疏注意力机制(Sparse Attention)

  • 只计算输入序列中部分位置的注意力权重,通过设计稀疏模式(如局部窗口、随机采样)减少计算量
    • 稀疏注意力机制(Sparse Attention)是一种优化 Transformer 模型中注意力计算的技术,旨在减少计算复杂度和内存占用。它通过限制每个输入位置只与部分其他位置进行注意力计算,而不是与所有位置进行全连接计算,从而实现高效的计算和内存管理
  • 特点:降低内存和计算复杂度、支持更长序列等,对效果是有损的

量化(Quantization)

  • 将模型参数从高精度(如FP32)转换为低精度(如INT8),比如 QLoRA 等可以做到更低的量化,大幅降低模型微调和推理的内存
  • 特点:减少内存占用和计算量,对效果可能有损

知识蒸馏(Knowledge Distillation)

  • 用大模型训练小模型,使其性能接近大模型,小模型模仿大模型的输出
  • 特点:减少内存和计算资源需求,可加速模型的部署和推理

内存高效优化器(Memory-Efficient Optimizers)

  • 使用如Adafactor等优化器,减少存储优化状态的内存
    • AdaFactor是一种优化算法,旨在减少在训练深度学习模型时的内存占用,同时保持或提高模型性能。它是由Google的研究人员提出的一种自适应学习率优化器,其特点是显存成本(Memory Cost)是次线性的(Sublinear),意味着随着参数数量的增长,所需的额外内存不会线性增长
  • 特点:降低训练时的内存使用

分层训练(Layer-wise Training)

  • 逐层训练模型,每次只加载当前层的参数和梯度
  • 特点:减少内存需求,但可能影响模型性能

内存池(Memory Pooling)

  • 预先分配并复用内存块,减少频繁分配和释放的开销
  • 特点:提高内存使用效率

模型剪枝(Model Pruning)

  • 通过删减网络结构,移除不重要的神经元或连接,实现模型压缩
  • 特点:减少模型大小和内存占用,但这种方式不常用

低秩分解(Low-Rank Factorization)

  • 将大矩阵分解为多个小矩阵,常见的方式就是LoRA相关的技术,也可以和量化结合,如QLoRA等
  • 特点:减少内存和计算需求,适用于各种微调场景

vLLM推理框架

  • vLLM 是一个专注于高效推理的框架,要用于推理阶段 ,通过 PagedAttention、连续批处理、量化等技术优化内存和计算效率,显著提升吞吐量和响应速度
  • PagedAttention
    • 这是 vLLM 的核心技术,灵感来源于操作系统的虚拟内存分页机制。它通过分页管理注意力机制中的键值(KV)缓存,显著减少了内存浪费,并支持动态调整缓存大小,从而提高了吞吐量和内存利用率
  • 连续批处理(Continuous Batching)
    • vLLM 支持将多个请求批量处理,通过共享计算资源(如 KV 缓存)来减少重复计算,从而提高吞吐量
  • 量化技术
    • 支持多种量化方法(如 GPTQ、AWQ、FP8 KV Cache 等),通过降低模型参数的精度来减少内存占用和计算开销
  • 张量并行(Tensor Parallelism)
    • 支持将模型分布到多个 GPU 上运行,通过并行计算加速推理过程
  • 推测解码(Speculative Decoding) ,也称为 投机采样
    • 使用较小的模型预测词元,再用大模型验证结果,从而加速文本生成
  • Flash Attention
    • 优化 Transformer 模型的注意力计算,减少计算复杂度和内存占用
  • OpenAI 兼容 API
    • 提供与 OpenAI API 兼容的接口,便于集成到现有应用中
  • 多 LoRA 支持
    • 支持多 LoRA(低秩适应)模型,允许在同一框架下运行多个微调模型

ZeRO显存优化技术

  • ZeRO(Zero Redundancy Optimizer)是一种用于训练阶段的显存优化技术,主要用于训练阶段,通过分片存储、通信优化和混合精度训练等技术减少显存占用,支持更大规模的模型训练
  • ZERO 技术最初是微软在 2020 年的论文 ZeRO: Memory Optimization Towards Training Trillion Parameter Models 中被提出的,详细阐述了 ZERO 的三个阶段(ZERO-1、ZERO-2、ZERO-3)及其内存优化原理
  • ZERO 技术也是 DeepSpeed 框架的核心创新之一(注:DeepSpeed 是微软开发的一个用于大规模深度学习训练的优化库)
  • 分片存储(Sharding)
    • ZeRO 将模型参数、梯度和优化器状态分片存储到多个 GPU 上,从而减少单个 GPU 的内存占用。分为三个阶段:
      • ZeRO Stage 1(ZeRO-1) :仅分片优化器状态
      • ZeRO Stage 2(ZeRO-2) :分片优化器状态和梯度
      • ZeRO Stage 3(ZeRO-3) :分片优化器状态、梯度和模型参数
  • 通信优化
    • ZeRO 通过优化 GPU 间的通信(如 All-Reduce 和 Reduce-Scatter 等GPU通信操作),减少分布式训练中的通信开销
  • 混合精度训练
    • 支持 FP16 和 FP8 等低精度训练,减少显存占用并加速计算
  • 重计算(Gradient Checkpointing)
    • 在前向传播时只保存部分激活值,反向传播时重新计算其余部分,从而减少显存占用
  • 负载均衡
    • 在 MoE(Mixture of Experts)模型中,通过优化路由策略和负载分配,避免专家模型之间的负载不均衡
  • 后来微软 DeepSpeed 团队继续对 ZERO 技术进行扩展,退出了 ZeRO-Offload 和 ZeRO-Infinity 等高级技术:
    • ZeRO-Offload 可以将已划分的优化器状态和梯度卸载到 CPU 内存中
    • ZeRO-Infinity 是 ZeRO-3 的扩展,它可以利用 CPU 和 NVMe 内存来进一步扩展 GPU 的内存,支持训练更大型的模型

附录:大模型推理中的模型量化技术总结

  • 大模型推理中,模型量化旨在减少模型的存储和计算需求,同时尽量保持模型的性能

GPTQ(Gradient-based Post-training Quantization)

  • TLDR:基于梯度的后训练量化方法
  • 基本原理:
    • 在模型训练完成后,对模型权重进行量化
    • 通过优化目标函数来最小化量化误差,利用梯度调整量化时的权重误差,使量化后模型与未量化模型的表现尽可能接近
    • 采用误差反馈机制,将量化误差传播到后续层进行补偿,减少累积误差对模型输出的影响
  • 特点:
    • 适用于 8-bit 或更低的量化需求,尤其对大语言模型量化效果好
    • 不需要额外的训练数据,精度损失相对较小,特别适合复杂模型
    • 针对 GPU 使用进行了优化,在 GPU 推理时性能较好,能将权重动态去量化为 float16,提高性能的同时保持低内存占用

AWQ(Activation-aware Quantization)

  • TLDR:关注激活值的量化方法,量化过程中考虑激活值分布对模型性能的影响
  • 基本原理:
    • 分析激活值的分布特性,对激活值进行适应性处理
    • 采用非均匀量化,针对不同的激活值范围选择不同的量化尺度
  • 特点:
    • 精度较高,通过对激活值分布的考虑,能更好地保留模型的性能
    • 计算复杂度较大,因为需要分析激活值分布并进行非均匀量化操作

GGUF(Generalized Global Uniform Quantization Framework)

  • TLDR:一种通用的全局统一量化框架,用于处理大规模神经网络
  • 基本原理:
    • 通常采用全局统一量化策略,假设模型的所有层或某一类参数具有相似的分布,对整个模型的权重或激活值采用相同的量化参数
    • 采用均匀量化,将所有数值线性地映射到一个均匀的范围,并引入缩放因子,在推理阶段重定标量化后的数值,避免数值溢出或精度过低
  • 特点:
    • 简单高效,适用于资源受限的部署场景,如普通 CPU 环境
    • 兼容 Windows 和 Linux 操作系统
    • 提供从 2-bit 到 8-bit 的多级量化选项
    • 由于采用统一量化策略,可能导致某些模型层的精度损失

DL——TensorBoard的使用


整体说明

  • TensorBoard 是 TensorFlow 提供的可视化工具,能帮助理解、调试和优化深度学习模型
  • 安装 TensorBoard
    1
    pip install tensorboard

启动 TensorBoard

  • 可以使用命令行工具执行下面的命令从一个指定目录启动 TensorBoard:

    1
    tensorboard --logdir=path/to/logs --port=6006 --host=0.0.0.0
  • 参数解释:

    • --logdir:这个参数用于指定 TensorFlow 事件文件所在的目录
      • TensorBoard 会对该目录进行监控,一旦有新的事件文件生成,它就会实时更新可视化内容
      • 可以只指定一个目录,也可以通过逗号分隔的方式指定多个目录,或者使用通配符来匹配多个目录
      • 一个目录下可以有多个子目录,TensorBoard 会同时显示,可通过页面选择勾选目标子文件夹
    • --port:此参数用于设置 TensorBoard 服务监听的端口,默认使用 6006 端口
      • 如果你想在同一台机器上同时运行多个 TensorBoard 实例,可以为它们指定不同的端口后分别启动
    • --host:通过这个参数可以设置 TensorBoard 服务监听的 IP 地址
      • 默认是 localhost(即 127.0.0.1),此时仅能在本机上访问
      • 若你想让其他机器能够访问当前机器上的 TensorBoard,可将其设置为 0.0.0.0

TensorBoard 展示细节

  • TensorBoard 启动之后,通过浏览器即可访问
  • 最常用的 TensorBoard 页面包含 Scalars、Graphs、Histograms 等,点击这些选项卡可以查看不同类型的可视化数据

TensorBoard 文件的生成

PyTorch TensorBoard 示例

  • 使用 PyTorch 生成 TensorBoard 文件的示例
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    import torch
    from torch.utils.tensorboard import SummaryWriter

    # 创建 SummaryWriter 对象,指定日志保存目录
    writer = SummaryWriter('runs/pytorch_demo')

    # 模拟训练过程
    for epoch in range(100):
    # 模拟损失值(通常是训练过程中计算得到的)
    loss = 0.9 ** epoch

    # 记录损失值到 TensorBoard
    writer.add_scalar('Loss/train', loss, epoch)

    # 模拟模型权重
    weights = torch.randn(10) * (0.95 ** epoch)

    # 记录直方图到 TensorBoard
    writer.add_histogram('Weights', weights, epoch)

    # 关闭 writer
    writer.close()

PyTorch TensorBoard 示例

  • 使用 TensorFlow 生成 TensorBoard 文件的示例
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    import tensorflow as tf
    import numpy as np

    # 创建 SummaryWriter 对象,指定日志保存目录
    writer = tf.summary.create_file_writer('runs/tensorflow_demo')

    # 模拟训练过程
    for epoch in range(100):
    # 模拟损失值
    loss = 0.9 ** epoch

    # 记录损失值到 TensorBoard
    with writer.as_default():
    tf.summary.scalar('Loss/train', loss, step=epoch)

    # 模拟模型权重
    weights = np.random.randn(10) * (0.95 ** epoch)

    # 记录直方图到 TensorBoard
    with writer.as_default():
    tf.summary.histogram('Weights', weights, step=epoch)

一些说明

  • TensorBoard 可以同时显示多个项目的数据,如上述示例就在同一个目录 runs 下分别创建了 PyTorch 和 TensorFlow 的文件夹
  • 训练过程中数据不会自动刷新,可以随时刷新浏览器查看实时更新的数据

DL——Teacher-Forcing方法

本文主要介绍Transformer和Attention相关内容


整体总结

  • 教师强制(Teacher Forcing) 是一种在训练序列生成模型(包括循环神经网络 RNN、长短期记忆网络 LSTM 等)时使用的方法
  • 其核心思想是在训练过程中强制模型使用真实的目标序列作为输入 ,而非模型自身的预测结果,从而解决序列生成任务中可能出现的误差累积问题
  • 大模型的 SFT 方法就是一种 Teacher Forcing 方法,属于一种 Token-level 的行为克隆

Teacher Forcing 的基本原理

  • 在序列生成任务(如机器翻译、文本生成、语音识别等)中,模型需要根据历史输入和已生成的序列来预测下一个输出
  • 传统训练方式下,若直接使用模型前一步的预测结果作为下一步的输入,一旦某一步预测错误,后续预测可能会因误差累积而“偏离轨道”,导致训练不稳定
  • 教师强制的做法 :在每一步训练中,强制使用真实的目标序列(而非模型上一步的预测值)作为下一步的输入
    • 例如:在机器翻译中,当生成第二个词时,不使用模型预测的第一个词,而是直接使用参考译文中的第一个词,以此类推

Teacher Forcing 的具体流程(以LSTM为例)

  • 假设我们有一个序列生成任务,目标序列为 \( y_1, y_2, y_3, \dots, y_T \),模型输入为 \( x_1, x_2, \dots, x_T \),则训练过程如下:
    • 第一步 :输入 \( x_1 \),模型预测 \( \hat{y}_1 \),与真实值 \( y_1 \) 计算损失并更新参数
    • 第二步 :不使用 \( \hat{y}_1 \),而是将真实值 \( y_1 \) 作为输入,结合 \( x_2 \),模型预测 \( \hat{y}_2 \),计算损失并更新参数
    • 后续步骤 :重复上述过程,每一步都用真实的 \( y_{t-1} \) 作为当前步的部分输入,直至生成 \( \hat{y}_T \)。

Teacher Forcing 的优缺点分析

优点

  • 训练更稳定 :避免因早期预测错误导致的误差累积,模型更容易收敛
  • 加速收敛 :真实目标序列提供了更准确的监督信号,减少了训练迭代次数
  • 降低训练难度 :尤其适合复杂序列任务(如长文本生成),避免模型“发散”

缺点

  • 训练与推理偏差 :推理时(如实际生成文本)无法获取真实目标序列,需依赖模型自身预测,可能导致“暴露偏差(Exposure Bias)”(即训练时的输入分布与推理时不一致)
  • 缺乏抗噪能力 :模型可能过度依赖真实标签,对预测误差的鲁棒性较差

其他相关训练方法的对比

  • 教师强制 :始终使用真实标签作为输入
  • 为解决教师强制的“暴露偏差”问题,有人提出了 Scheduled Sampling(计划采样) 方法:
    • Scheduled Sampling :在训练初期以高概率使用真实标签,随着训练推进,逐渐增加使用模型预测值的概率,使模型逐步适应推理时的输入分布
    • Scheduled Sampling通过平衡“教师指导”和“自主预测”,减少训练与推理的差异,提升模型泛化能力

代码示例

  • PyTorch实现简单教师强制训练
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    import torch
    import torch.nn as nn
    import torch.optim as optim

    class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
    super(LSTMModel, self).__init__()
    self.lstm = nn.LSTM(input_size, hidden_size)
    self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
    lstm_out, hidden = self.lstm(x, hidden)
    output = self.fc(lstm_out)
    return output, hidden

    def train_with_teacher_forcing(model, input_seq, target_seq, criterion, optimizer):
    model.train()
    hidden = model.init_hidden()
    optimizer.zero_grad()
    loss = 0

    for t in range(target_seq.size(0)):
    output, hidden = model(input_seq[t].unsqueeze(0), hidden)
    input_seq[t+1] = target_seq[t] # 下一时间步的输入使用真实目标值(Teacher Forcing 方法的核心代码)
    loss += criterion(output, target_seq[t].unsqueeze(0))

    loss.backward()
    optimizer.step()
    return loss.item()

Python——Ray-使用笔记


远程调用时传入的函数指针必须是远程函数

  • 在 Ray 中不支持直接传入 local 函数指针作为远程函数的执行对象,需通过 Ray 装饰器(@ray.remote)将函数注册为远程可执行,再通过 函数名.remote() 调用(本质是基于函数标识而非指针传递)
  • 总结:
    • 不推荐将普通函数作为参数传递给 Ray 远程函数
    • 推荐使用 @ray.remote 装饰器或在远程函数内部定义逻辑
    • 注意:一些代码在单机环境下可能碰巧能运行,但不具有可移植性和可靠性(这一点需要注意 Ray 本地调试通过可能也无法分布式运行)

错误示例(未注册本地函数)

  • 若 add 未被 @ray.remote 注册,它只是一个本地函数 ,无法在 Ray 分布式环境中执行,直接传递给远程函数(如 execute_func)会报错

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    import ray

    ray.init(ignore_reinit_error=True)

    # 未注册的本地函数
    def add(a, b):
    return a + b

    # 已注册的远程函数
    @ray.remote
    def execute_func(func, x, y):
    # 此处调用本地函数会失败,因为 func 在远程节点无定义
    # # 远程节点的工作进程无法导入本地主模块的 add_local 函数,也无法序列化传递普通函数,可能会直接抛出 SerializationError
    # # 单进程/单节点下调用指针函数可以执行,但是分布式情况下,local_func 无法被序列化,会出错
    return func(x, y) # 报错:NameError,PicklingError 或 SerializationError

    # 调用会抛出异常
    try:
    result = ray.get(execute_func.remote(add, 4, 6))
    except Exception as e:
    print("错误:", e) # 提示无法序列化或找不到函数

    ray.shutdown()
  • 核心原因:Ray 远程函数执行依赖序列化传输和集群节点间代码同步

    • 未注册的本地函数无法被序列化为集群可识别的任务,且远程节点没有该函数的定义,会导致执行失败

正确示例(远程函数调用)

  • Ray 的远程函数依赖集群调度,通过 @ray.remote 显式注册后使用远程调用函数调用

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    import ray

    ray.init(ignore_reinit_error=True)

    # 定义远程函数(会注册到 Ray 集群)
    @ray.remote
    def add(a, b):
    return a + b

    # 远程函数,可接收其他远程函数的调用结果
    @ray.remote
    def execute_func(func, x, y):
    # 这里 func 是远程函数标识,通过 .remote() 触发执行
    result = ray.get(func.remote(x, y)) # 使用远程调用的方式调用函数指针,实现调用远程函数,正确!
    # result = func(x, y) # remote 函数无法被直接调用,错误!
    # result = add(x,y) # remote 函数无法被直接调用,错误!
    # result = add_local(x, y) # add_local 当做 local 函数调用(注意:不再是指针传入),正确!
    return result

    # # 不使用 remote 直接调用 远程函数,错误
    # result1 = add(2, 3)

    # 使用remote直接调用远程函数,正确
    result1 = ray.get(add.remote(2, 3))
    print("直接调用结果:", result1) # 输出:5

    # 间接通过另一个远程函数调用(模拟"传递函数逻辑")
    result2 = ray.get(execute_func.remote(add, 2, 3))
    print("间接调用结果:", result2) # 输出:10

    ray.shutdown()
  • Ray 的远程函数依赖集群调度,需通过 @ray.remote 显式注册,无法像本地代码那样传递函数指针(内存地址在分布式环境中无效)

  • 若需在远程函数中复用其他函数逻辑,直接传递已注册的远程函数名(如示例中的 add),再通过 func.remote() 调用即可

1…242526…66
Joe Zhou

Joe Zhou

Stay Hungry. Stay Foolish.

659 posts
53 tags
GitHub E-Mail
© 2026 Joe Zhou
Powered by Hexo
|
Theme — NexT.Gemini v5.1.4