PyTorch——torch.distributed.rpc的使用


整体说明

  • torch.distributed.rpc 库是 PyTorch 用于 RPC 通信的包,初始化 RPC 环境的核心函数是 init_rpc
  • 在 PyTorch 中
    • torch.distributed.init_process_group 用于初始化进程组环境, 数据并行集体通信 ,比如同步梯度
    • torch.distributed.rpc.init_rpc 则用于初始化 RPC 通信环境, 模型并行远程过程调用 ,比如在另一台机器上执行一个函数
  • 在很多复杂的分布式训练场景中(如全分片数据并行+流水线并行),torch.distributedtorch.distributed.rpc 这两个模块甚至会同时被使用 ,以发挥它们各自的优势

回顾:torch.distributed.init_process_group 初始化进程组

  • torch.distributed.init_process_group 函数是 PyTorch Distributed (c10d) 包的初始化入口

    • 注:c10d 是 “Caffe2 TenSoR Distributed” 的缩写,c10d 代表了 PyTorch 中负责分布式通信的核心底层模块,提供了跨进程、跨节点的数据传输和同步机制;实现了诸如 ProcessGroup(进程组)、Backend(通信后端,如 NCCL、Gloo 等)等核心组件,是 torch.distributed 高层 API 的底层支撑
  • torch.distributed.init_process_group 函数会启动一个进程组,用于执行集体通信

  • torch.distributed.init_process_group 启动的进程组的通信范式是集体通信

    • 所有进程都必须参与同一个操作,并且需要等待所有进程都完成该操作后才能继续
    • 常见的操作包括:
      • all_reduce:所有进程的Tensor进行汇总(如求和)后,将结果返回给所有进程。(用于梯度同步)
      • broadcast:将一个进程的Tensor广播给所有其他进程
      • barrier:同步所有进程,确保大家都到达同一个点
      • scatter, gather, all_gather
  • 特点是训练代码在每个进程(通常是每个GPU)上几乎是完全相同

    • 每个进程都做同样的事情:读一部分数据、前向传播、反向传播、然后通过集体通信同步梯度
    • 示例如下:
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      # 示例:在每个进程上初始化进程组
      import torch.distributed as dist

      def setup(rank, world_size):
      dist.init_process_group(
      backend='nccl', # 或 'gloo', 'mpi'
      init_method='env://',
      rank=rank,
      world_size=world_size
      )
      # 创建模型、DDP包装、定义采样器等...
      torch.cuda.set_device(rank)

      # 假设使用 spawn 启动多个进程
      # 每个进程都会运行 setup 函数和后续相同的训练循环
  • 常用于:

    • 数据并行训练:(最经典的用法)
      • 在 DistributedDataParallel (DDP) 中,每个 GPU 上的模型副本计算完梯度后,使用 all_reduce 来对所有梯度进行求和平均,确保每个模型副本都得到相同的更新后的梯度
    • 分布式数据加载:使用 DistributedSampler 确保每个进程读取数据集中不重复的部分
    • 其他需要进程间紧密同步的计算任务

前置说明:RPC 框架建立在进程组之上

  • torch.distributed.rpc.init_rpc 启动的 RPC 建立在进程组之上:

    • RPC 框架底层依赖于 torch.distributed 的进程组来进行初始的握手和协调
    • 当调用 rpc.init_rpc 时,它会在内部检查是否已经存在一个初始化好的进程组
      • 如果存在,它就复用这个进程组
      • 如果不存在,它会*隐式地调用 dist.init_process_group * ,使用默认的 gloo 后端(除非你通过 rpc_backend_options 指定)来创建一个进程组
  • 在大多数情况下,只需要调用 rpc.init_rpc 即可,它会把底层需要的进程组也初始化好

    • 如果需要对进程组的后端(如使用nccl)或初始化方法进行更精细的控制,可以选择 显式地先调用 dist.init_process_group,然后再调用rpc.init_rpc ,代码示例如下:
      1
      2
      3
      4
      5
      6
      # 显式控制进程组后端的示例
      def setup(rank, world_size):
      # 1. 首先,用 NCCL 后端初始化进程组(为了更好的GPU通信性能)
      dist.init_process_group(backend='nccl', ...)
      # 2. 然后,初始化RPC框架。此时RPC检测到已有进程组,不会再次初始化
      rpc.init_rpc(...)
  • 亲测:init_rpc 使用的 master_ip:master_port,必须和 init_process_group 使用的一样,否则无法启动

    • 暂未看到官方明确说明(待补充),猜测原因是因为 RPC 框架依赖已经初始化好的进程组(若没有初始化进程组,init_rpc 会自己新建进程组)
    • 详细示例见附录
  • RPC 的 world_size 可以比 已初始化的进程组的多(虽然有依赖,但两者是相对独立的体系)【具体深层逻辑待梳理】

    • init_rpcworld_size 可以比 init_process_group 的大
    • 具体实现时,部分 rank 进程不调用 init_process_group 即可实现

torch.distributed.rpc.init_rpc 初始化 PyTorch RPC 框架

  • torch.distributed.rpc.init_rpc 函数是 PyTorch RPC 框架 的初始化入口

  • torch.distributed.rpc.init_rpc 函数核心目的是启用远程过程调用 ,允许一个进程调用另一个进程(可能在不同的机器上)上定义的函数或方法

  • torch.distributed.rpc.init_rpc 初始化的 PyTorch RPC 框架 通信范式是点对点通信

    • 一个进程(工作者)可以主动请求另一个进程执行某个任务,而不需要所有进程都参与
    • 核心 API 包括:
      • rpc_sync():同步调用,调用方会等待远程函数执行完毕并返回结果
      • rpc_async():异步调用,调用方立即返回一个 Future 对象,未来再从中获取结果
      • remote():异步地在远程工作者上创建一个对象(例如一个模块),并返回一个指向该远程对象的引用(RRef)
  • 常用于:

    • 模型并行:将一个大模型的不同部分放在不同的机器或 GPU 上
      • 前向传播时,数据需要从一个机器流到下一个机器
      • RPC 可以管理这种跨机器的计算流
    • 参数服务器(PS):
      • 一个或多个进程作为参数服务器存储模型参数,其他工作者通过 RPC 向参数服务器拉取参数或推送梯度
    • 强化学习:
      • 多个环境模拟器(Actor)通过RPC将经验数据发送给一个中心 learner 进行训练
    • 分布式流水线并行(待补充):
      • torch.distributed.pipeline.sync.Pipe 结合使用,RPC 负责处理不同设备间层的通信
  • 特点是每个进程的角色和执行的代码可能完全不同 ,一个简单的代码示例如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    # start shell: torchrun --nproc_per_node=2 torch_distributed_rpc_demo.py
    import os
    import torch
    import torch.distributed as dist
    import torch.distributed.rpc as rpc

    # 正常定义函数,可用于远程调用(在 worker0 发送命令从 worker1 执行函数)
    def worker_function(a, b):
    worker_info = rpc.get_worker_info()
    print(f"Worker {worker_info.name} received a call with a={a}, b={b}")
    return a + b

    def main():
    # 获取环境变量
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # 1. 初始化进程组和 RPC
    dist.init_process_group(backend='gloo', rank=rank, world_size=world_size) # 建议提前初始化好进程组,如没有这行,下面的函数也会自动初始化一个 进程组
    rpc.init_rpc(
    name=f"worker_{rank}", # 若不唯一,会报错 RuntimeError: RPC worker name worker is not unique. Workers 1 and 0 share the same name
    rank=rank,
    world_size=world_size
    )

    print(f"Worker {rank} is ready.")

    if rank == 0:
    # 2. worker_0 发起 RPC 调用
    print("--- Worker 0 sending RPC calls ---")

    # 同步调用
    result_sync = rpc.rpc_sync(
    to=f"worker_1", # 真正的执行过程发生在 worker1 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 0 received sync result from Worker 1: {result_sync}")

    # 异步调用
    future_async = rpc.rpc_async(
    to=f"worker_1",
    func=worker_function,
    args=(30, 40)
    )
    print("Worker 0 is doing other work while waiting for async result...")
    result_async = future_async.wait()
    print(f"Worker 0 received async result from Worker 1: {result_async}")

    else:
    # 3. worker_1 只是等待并处理 RPC 请求
    print("--- Worker 1 is waiting for RPC calls ---")
    # 由于 worker_0 正在调用,这里不需要做任何事,RPC 框架会自动处理

    rpc.shutdown() # 默认等待所有 RPC 调用完成后再关闭
    dist.destroy_process_group() # 注:若没有显示调用 dist.init_process_group,也不能调用这一行,否则会报错

    if __name__ == "__main__":
    main()
    • 上述代码的启动脚本(单机双卡):torchrun --nproc_per_node=2 torch_distributed_rpc_demo.py

rpc.init_rpc 函数 详细用法

  • PyTorch 的 torch.distributed.rpc.init_rpc 函数原型

    1
    2
    3
    4
    5
    6
    7
    torch.distributed.rpc.init_rpc(
    name,
    backend=None,
    rank=-1,
    world_size=None,
    rpc_backend_options=None
    )
  • 参数说明如下:

  • name(必填)

    • 当前工作进程的全局唯一名称 ,不能重复,重复会报错,后续会作为交互对象的指定
    • 名称只能包含数字、字母、下划线、冒号和/或短划线,且必须少于128个字符
    • 例如 "trainer0", "parameter_server1"
  • backend(选填)

    • 指定使用的RPC 后端实现
    • 可选值来自 torch.distributed.rpc.BackendType 枚举
      • BackendType.TENSORPIPE,更现代且推荐的 RPC 后端
      • BackendType.PROCESS_GROUP,不常用,对应的是 PyTorch 的后端,如 NCCL 和 Gloo
  • rank(必填)

    • 当前进程在组中的全局唯一 ID
    • 通常是一个整数,例如 0, 1, 2
    • 必须与 world_size 一起正确设置
  • world_size(必填)

    • 参与作业的进程总数
    • 所有进程的 world_size 必须相同
  • rpc_backend_options(选填)

    • 传递给底层 RPC 代理的高级选项 ,类型必须与 backend 参数选定的值严格对齐,否则会报错
      • backend=BackendType.TENSORPIPE 对应 rpc_backend_options 类别为:TensorPipeRpcBackendOptions
      • backend=BackendType.PROCESS_GROUP 对应 rpc_backend_options 类别为:ProcessGroupRpcBackendOptions
    • 可常用于设置 RPC 超时初始化方法
      • 若未指定,默认超时为60秒,并使用 init_method="env://"(这意味着需要设置 MASTER_ADDRMASTER_PORT 环境变量)
  • 使用示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import TensorPipeRpcBackendOptions

    # 配置 TensorPipe 后端选项
    options = TensorPipeRpcBackendOptions(
    num_worker_threads=8, # 8 个工作线程,指定处理 RPC 请求的工作线程数,影响并发处理能力,默认值是 CPU 数量
    rpc_timeout=10000, # RPC 超时时间设置为 10 秒
    init_method="tcp://127.0.0.1:29500", # 初始化地址,格式和 init_process_group 的 init_method 参数一致
    device_maps={ "worker1": {0: 1}, "worker2": {1: 0} } # 本地 GPU 0 映射到 worker1 的 GPU 1;本地 GPU 1 映射到 worker2 的 GPU 0
    # devices=?, # 默认使用所有可用设备
    security_token="my_rpc_token" # str 类型安全令牌,用于验证节点身份的安全令牌,多机环境下防止未授权节点接入
    )

    # 初始化 RPC,传入配置
    rpc.init_rpc(
    name="worker0", # 当前节点名称
    backend=rpc.BackendType.TENSORPIPE,
    rpc_backend_options=options,
    rank=0, # 全局 rank
    world_size=2 # 总节点数
    )

TensorPipeRpcBackendOptions 的初始化细节

  • TensorPipeRpcBackendOptions 的函数原型为:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    def __init__(
    self,
    *, # Python 中可变位置参数的用法,仅使用一个 * 表示后续的参数必须按照 关键字参数形式调用
    num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS,
    rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC,
    init_method: str = rpc_contants.DEFAULT_INIT_METHOD,
    device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
    devices: Optional[List[DeviceType]] = None,
    _transports: Optional[List] = None,
    _channels: Optional[List] = None,
    )
  • 除了前文示例指定的参数外,其他参数介绍如下:

  • devices 参数

    • 指定当前节点上可用于 RPC 通信的设备(主要是 GPU),限制 RPC 操作只能使用列表中声明的设备
    • 默认值是 None(表示允许使用所有可见设备)
    • 当设置为非空列表时(如 [0, 1][torch.device('cuda:0'), torch.device('cuda:1')]),TensorPipe 只会使用列表中的设备进行数据传输,忽略其他设备
    • 主要用于多 GPU 场景下的资源隔离,例如限制 RPC 通信仅使用特定 GPU,避免占用训练用 GPU 资源
  • _transports 参数(以下划线开头,是私有参数,谨慎使用)

    • 指定 TensorPipe 可使用的传输协议(底层物理通信方式),用于跨节点或跨进程的数据传输
    • (推荐)默认值是自动检测并选择最优传输方式(通常为 ["tcp", "ibv"],即 TCP 和 InfiniBand)
    • 支持的传输协议:
      • "tcp":基于 TCP/IP 的网络传输,兼容性好,适用于普通网络环境
      • "ibv":基于 InfiniBand 的传输,低延迟、高带宽,适用于高性能计算集群
      • "shm":表示优先使用共享内存作为传输介质
      • "uv":基于 libuv 的本地传输(用于同一节点内的进程通信,通常自动启用)
    • 列表顺序表示优先级,TensorPipe 会优先尝试前面的传输方式
    • 仅在需要强制指定传输协议时使用(如集群仅支持 InfiniBand 时设置 ["ibv"]),默认自动选择即可满足多数场景
  • _channels 参数(以下划线开头,是私有参数,谨慎使用)

    • 指定 TensorPipe 可使用的通道类型(数据在设备间的内存传输方式),用于同一节点内的设备间通信(如 CPU-GPU、GPU-GPU)
    • (推荐)默认是自动检测并选择最优通道(通常为 ["cuda_ipc", "cuda_basic", "shm"]
    • 支持的通道类型有:
      • "cuda_ipc":GPU 进程间通信(同一节点内不同进程的 GPU 直接通信,效率最高)
      • "cuda_basic":通过 CPU 中转的 GPU 通信(当 cuda_ipc 不可用时降级使用)
      • "shm":共享内存通信(同一节点内的 CPU 进程间通信),这里强调通信通道,_transports"shm" 强调通信介质
      • "basic":通过系统内存拷贝的通信(通用 fallback 方式)
    • 列表顺序表示优先级,优先使用高效的通道(如 cuda_ipc 优于 cuda_basic
    • 主要用于调试或特殊硬件环境下的兼容性调整,默认配置已针对性能优化

初始化常见错误排查

  • 初始化失败:检查 MASTER_ADDRMASTER_PORT 是否设置正确且所有机器可达,防火墙是否放行了指定端口,以及 rankworld_size 是否配置正确
  • 超时错误:增加 rpc_backend_options 中的 rpc_timeout
  • 版本不匹配:确保所有参与计算的 PyTorch 版本一致

RPC 框架的交互模式

  • PyTorch 的 torch.distributed.rpc 模块为 分布式模型训练 提供了核心的远程过程调用(RPC)和远程引用(RRef)机制
  • 以下是 RPC 相关核心方法:
  • rpc_sync()(同步方法,阻塞)
    • 在远程 worker 上同步调用函数
    • 返回函数调用的直接结果
  • rpc_async()(异步方法,非阻塞)
    • 在远程 worker 上异步调用函数
    • 返回一个 Future 对象,可通过 future.wait() 获取结果
  • remote()(非阻塞)
    • 异步在远程 worker 上创建对象
    • 返回一个 RRef (远程引用) 指向创建的对象
  • RRef.to_here()(阻塞)
    • RRef 引用的值从所有者复制到本地节点
    • 返回引用的对象本身

同步远程调用 (rpc_sync)

  • rpc_sync阻塞调用者,直到远程函数执行完成并返回结果

  • 函数原型:

    1
    2
    3
    4
    5
    6
    7
    8
    # 在目标 worker (例如 to="worker1") 上同步运行函数 func 并获取结果
    result = torch.distributed.rpc.rpc_sync(
    to, # (str or WorkerInfo or int) 目标 worker 的名称、rank 或 WorkerInfo
    func, # (Callable) 待执行函数,是一个可调用函数(如 Python 函数、内置运算符或 TorchScript 函数)
    args=None, # (tuple) 传给函数的可变位置参数
    kwargs=None, # (dict) 传给函数的可变关键字参数
    timeout=-1.0 # (float, optional) RPC 超时时间(秒),-1.0 表示使用默认值(timeout=UNSET_RPC_TIMEOUT, 此时走初始化逻辑,或_set_rpc_timeout 设置的值),0 表示无限超时
    )
  • 使用示例:

    1
    2
    3
    # 在 worker 0 上:
    ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
    # 在 worker 1 上需要定义相应的函数或直接使用内置函数, worker 1 执行函数体,并返回结果

异步远程调用 (rpc_async)

  • rpc_async立即返回一个 Future 对象,你可以在之后需要时等待并获取结果,适用于非阻塞调用

  • 函数原型:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    # 异步调用远程函数,返回 Future 对象
    future = torch.distributed.rpc.rpc_async(
    to, # (str or WorkerInfo or int) 目标 worker 的名称、rank 或 WorkerInfo
    func, # (Callable) 待执行函数,是一个可调用函数(如 Python 函数、内置运算符或 TorchScript 函数)
    args=None, # (tuple) 传给函数的可变位置参数
    kwargs=None, # (dict) 传给函数的可变关键字参数
    timeout=-1.0 # (float, optional) RPC 超时时间(秒),-1.0 表示使用默认值(timeout=UNSET_RPC_TIMEOUT, 此时走初始化逻辑,或_set_rpc_timeout 设置的值),0 表示无限超时
    # 注:_set_rpc_timeout 是 PyTorch 分布式 RPC 框架中用于动态修改全局默认超时时间的函数,通常在 init_rpc(RPC 环境初始化)之后调用
    # 注:异步操作不会阻塞,这里的 timeout 是从此刻开始,远程操作本身返回的时间,如果有错也是在 future.wait() 时抛出
    )
    # ... 执行其他操作 ...
    result = future.wait() # 等待并获取结果,这里也可以设置 future 从此刻开始的等待 timeout, 设置后与 rpc_async 中的 timeout 时间两者时间取小(注意两者起始时间不同)
  • 使用示例:

    1
    2
    3
    fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
    fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
    result = fut1.wait() + fut2.wait()

远程对象创建与引用 (remoteRRef)

  • remote 方法用于在远程 worker 上异步创建对象 ,并立即返回一个指向该对象的 RRef (Remote Reference)

  • RRefto_here() 方法可以阻塞地将远程对象的值复制到本地

  • remote 函数原型

    1
    2
    3
    4
    5
    6
    7
    8
    9
    # 在远程 worker 上创建对象并返回 RRef,持有 func 函数的返回值
    rref = torch.distributed.rpc.remote(
    to, # (str or WorkerInfo or int) 目标 worker 的名称、rank 或 WorkerInfo,目标 worker(将成为所有者)
    func, # (Callable) 用于创建对象的函数,是一个可调用函数(如 Python 函数、内置运算符或 TorchScript 函数)
    args=None, # (tuple) 传给函数的可变位置参数
    kwargs=None, # (dict) 传给函数的可变关键字参数
    timeout=-1.0 # (float, optional) RPC 超时时间(秒),-1.0 表示使用默认值(timeout=UNSET_RPC_TIMEOUT, 此时走初始化逻辑,或_set_rpc_timeout 设置的值),0 表示无限超时
    # 注:异步操作不会阻塞,这里的 timeout 是从此刻开始,远程操作本身返回的时间,如果有错也是在 rref.to_here() 时抛出
    )
  • RRef 对象的一些使用方式:

    • 下面的两种调用相同
      1
      2
      3
      4
      5
      6
      7
      8
      9
      rref = rpc.remote("worker1", MyClass, args=(arg1, arg2))

      # 方式一:直接调用远程对象的函数(func_name 必须是 rref 持有对象的函数)
      rref.rpc_async().func_name(*args, **kwargs)

      # 方式二:原生的 rpc 实现过程,较为复杂,效果和方式一完全相同
      def run(rref, func_name, args, kwargs):
      return getattr(rref.local_value(), func_name)(*args, **kwargs)
      rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • RRef 的更多使用详情见附录

shutdown 函数的使用

  • 在所有 RPC 工作完成后、进程退出前,必须显式调用 torch.distributed.rpc.shutdown ();
  • 推荐使用 graceful=True(默认) 以等待所有未完成 RPC 结束并避免死锁与 SIGABRT 风险
  • 函数原型
    1
    2
    3
    4
    5
    6
    7
    # torch.distributed.rpc.shutdown
    def shutdown(
    graceful=True, # (bool),默认值为 True,表示是否优雅的结束
    # 如果 True,则 1) 等待所有针对 UserRRef 的系统消息处理完毕并删除这些引用;2) 阻塞直到本地与远程所有 RPC 进程均调用此方法,并等待所有未完成工作完成;
    # 当 graceful=False 时:不等待未完成的 RPC 任务和其他进程,直接强制关闭本地 RPC 系统。可能导致未完成任务失败或资源泄漏,仅建议在紧急退出场景使用
    timeout=DEFAULT_SHUTDOWN_TIMEOUT # 限制 “优雅等待” 的最大时长:
    )

重要注意事项

  • 较早的 PyTorch 版本(如 1.4 之前)API 可能不能使用
  • 分布式作业中的每个进程都必须成功调用 init_rpc,并在最后调用 rpc.shutdown() 来优雅地释放资源
  • RPC 框架在处理张量时,通常视为张量位于 CPU 上进行操作,以避免设备不匹配的错误
    • 如果需要在 GPU 上操作,可能需要在函数内部显式地将张量移动到 GPU
  • 网络问题、函数执行失败或超时都可能导致 RPC 调用失败,建议添加适当的异常处理
  • RPC 调用会有通信开销
    • 应尽量减少小规模的频繁调用,考虑批量操作或传输更大尺寸的数据以减少开销

附录:init_rpc 的初始化 IP 和端口要与 init_process_group 一致

  • 亲测:init_rpc 使用的 master_ip:master_port,必须和 init_process_group 使用的一样,否则无法启动
    • 暂未看到官方明确说明(待补充),猜测原因是因为 RPC 框架依赖已经初始化好的进程组(若没有初始化进程组,init_rpc 会自己新建进程组)
  • 示例如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    # start shell: torchrun --nproc_per_node=2 torch_distributed_rpc_demo.py
    import os
    import torch.distributed as dist
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import TensorPipeRpcBackendOptions

    def worker_function(a, b):
    worker_info = rpc.get_worker_info()
    print(f"Worker {worker_info.name} received a call with a={a}, b={b}")
    return a + b

    def main():
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    master_addr = os.environ.get('MASTER_ADDR', '未设置')
    master_port = int(os.environ.get('MASTER_PORT', '未设置'))
    print(f"master addr: {master_addr}:{master_port}")

    dist.init_process_group(backend='gloo', rank=rank, world_size=2,init_method=f"tcp://{master_addr}:{master_port}") # 提前初始化进程组

    rpc_backend_options = TensorPipeRpcBackendOptions(
    init_method=f"tcp://{master_addr}:{master_port}", # 与进程组初始化(init_process_group)时相同 IP和端口,可以正常初始化,端口不一致时会一直卡在这里不动
    num_worker_threads=2,
    rpc_timeout=10
    )

    # rpc_backend_options = None
    print(f"rank {rank} start init_rpc..., dist.is_initialized={dist.is_initialized()}")
    rpc.init_rpc(
    name=f"worker_{rank}",
    rank=rank,
    world_size=world_size,
    rpc_backend_options=rpc_backend_options
    )
    print(f"rank {rank} have done init_rpc")

    print(f"Worker {rank} is ready.")

    if rank == 0:
    result_sync = rpc.rpc_sync(
    to=f"worker_1", # 真正的执行过程发生在 worker1 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 0 received sync result from Worker 1: {result_sync}")
    else:
    print("--- Worker 1 is waiting for RPC calls ---")

    rpc.shutdown() # 等待所有 RPC 完成后自动关闭
    dist.destroy_process_group() # 注:若没有显示调用 dist.init_process_group,也不用这一行

    if __name__ == "__main__":
    main()

附录:RPC 的 world_size 可以比 已初始化的进程组的多

  • RPC 的 world_size 可以比 已初始化的进程组的多(虽然有依赖,但两者是相对独立的体系)【具体深层逻辑待梳理】
    • init_rpcworld_size 可以比 init_process_group 的大
    • 具体实现时,部分 rank 进程不调用 init_process_group 即可实现
  • 示例如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    # start shell: torchrun --nproc_per_node=4 torch_distributed_rpc_demo.py
    import os
    import torch.distributed as dist
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import TensorPipeRpcBackendOptions

    def worker_function(a, b):
    worker_info = rpc.get_worker_info()
    print(f"Worker {worker_info.name} received a call with a={a}, b={b}")
    return a + b

    def main():
    # 获取环境变量
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    master_addr = os.environ.get('MASTER_ADDR', '未设置')
    master_port = int(os.environ.get('MASTER_PORT', '未设置'))
    print(f"master addr: {master_addr}:{master_port}")

    if rank < 2: # 仅 rank=0,1 的节点初始化进程组
    print(f"rank {rank} start init_process_group...")
    dist.init_process_group(
    backend='gloo',
    rank=rank,
    world_size=2, # world_size 设置为2,甚至仅初始化一个也可以,rank<2 改成 rank<1 即可
    init_method=f"tcp://{master_addr}:{master_port}"
    ) # 提前初始化好进程组
    print(f"rank {rank} have done init_process_group")
    print(dist)

    rpc_backend_options = TensorPipeRpcBackendOptions(
    init_method=f"tcp://{master_addr}:{master_port}", # 与进程组初始化(init_process_group)时相同 IP和端口,可以正常初始化,端口不一致时会一直卡在这里不动
    num_worker_threads=2,
    rpc_timeout=10
    )

    print(f"rank {rank} start init_rpc..., dist.is_initialized={dist.is_initialized()}")
    rpc.init_rpc(
    name=f"worker_{rank}",
    rank=rank,
    world_size=world_size, # world_size 由传入的确定,实际上可能是 4
    rpc_backend_options=rpc_backend_options
    )
    print(f"rank {rank} have done init_rpc")

    print(f"Worker {rank} is ready.")

    if rank == 0:
    print("--- Worker 0 sending RPC calls ---")

    # 同步调用
    result_sync = rpc.rpc_sync(
    to=f"worker_1", # 真正的执行过程发生在 worker1 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 0 received sync result from Worker 1: {result_sync}")

    # 异步调用
    future_async = rpc.rpc_async(
    to=f"worker_1",
    func=worker_function,
    args=(30, 40)
    )
    print("Worker 0 is doing other work while waiting for async result...")
    result_async = future_async.wait()
    print(f"Worker 0 received async result from Worker 1: {result_async}")

    elif rank == 3:
    print("--- Worker 3 sending RPC calls ---")

    # 同步调用
    result_sync = rpc.rpc_sync(
    to=f"worker_2", # 真正的执行过程发生在 worker2 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 3 received sync result from Worker 2: {result_sync}")

    else:
    print("--- Worker 1 is waiting for RPC calls ---")

    rpc.shutdown() # 等待所有 RPC 完成后自动关闭
    if dist.is_initialized(): # 必须判断,否则可能报错,因为部分进程为初始化 init_process_group
    dist.destroy_process_group() # 注:若没有显示调用 dist.init_process_group,不用这一行

    if __name__ == "__main__":
    main()

附录:RRef 的使用方式细节

  • 参考自:PyTorch 官方文档-RRef
  • Warning:当前使用 CUDA 张量时,暂不支持 RRef,需要把张量从 GPU 转移到 CPU 才可以
  • RRef(Remote REFerence,远程引用)是指向远程工作节点上某一类型(如张量 Tensor)值的引用
    • 该句柄会确保被引用的远程值在其所属节点上保持有效
    • 在多机训练场景中,RRef 可用于持有对其他工作节点上 nn.Module(神经网络模块)的引用,并在训练过程中调用相应函数来获取或修改这些模块的参数
    • 更多细节可参考:Remote Reference Protocol
  • 注:RRef 也是一个 Python 类(torch.distributed.rpc.RRef),但有时泛指所有远程引用,而不是某一个类

torch.distributed.rpc.PyRRef 类型总体说明

  • 类型定义原型:

    1
    2
    3
    4
    5
    6
    7
    # __pybind11_builtins.pybind11_object 表明 PyRRef 是一个通过 pybind11 技术从 C++ 创建的 Python 类
    class PyRRef(__pybind11_builtins.pybind11_object):
    # ...

    # RRef 的定义
    class RRef(PyRRef, Generic[T]):
    pass
  • PyRRef 类用于封装对远程工作节点上某一类型值的引用

    • 此句柄会确保被引用的远程值在对应工作节点上保持有效
    • 当满足以下任一条件时,UserRRef(用户侧远程引用)将被销毁:
      • 1)在应用代码和本地 RRef 上下文环境中,均不存在对该 UserRRef 的引用;
      • 2)应用已调用 graceful shutdown 操作
    • 调用已销毁 RRef 的方法会导致未定义行为
      • RRef 实现仅提供“尽力而为”的错误检测机制(best-effort error detection),因此在调用 rpc.shutdown()(RPC 关闭函数)后,应用不应再使用 UserRRef
  • 警告:RRef 只能由 RPC 模块进行序列化和反序列化操作

    • 若不通过 RPC 模块(如使用 Python pickle 模块、PyTorch 的 save()/load() 函数、JIT 的 save()/load() 函数等)对 RRef 进行序列化或反序列化,将会引发错误

PyRRef 类型初始化参数(不常用,一般不自己定义 RRef)

  • value(object 类型):需用此 RRef 封装的值
  • type_hint(Type 类型,可选):传递给 TorchScript 编译器的 Python 类型提示,用于指定 value 的类型

PyRRef 类型使用示例

  • 以下示例为简化说明,省略了 RPC 初始化和关闭相关代码。有关这些操作的详细内容,请参考 RPC 文档

  • 使用 rpc.remote() 创建 RRef

    1
    2
    3
    4
    5
    6
    7
    import torch
    import torch.distributed.rpc as rpc

    # 在远程工作节点“worker1”上执行 torch.add 操作,并创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) # Debug 时看到的可能是 torch._C._distributed_rpc.PyRRef 类型,实际和 torch.distributed.rpc.PyRRef 是同一个类的不同路径
    # 从 RRef 中获取值的副本
    x = rref.to_here()
  • 基于本地对象创建 RRef

    1
    2
    3
    4
    5
    6
    7
    import torch
    from torch.distributed.rpc import RRef

    # 创建本地张量
    x = torch.zeros(2, 2)
    # 用 RRef 封装本地张量,生成 RRef 对象
    rref = RRef(x)
  • 与其他工作节点共享 RRef

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    # 在工作节点“worker0”和“worker1”上均执行以下函数定义
    def f(rref):
    # 获取 RRef 指向的值并加 1
    return rref.to_here() + 1

    # 在工作节点“worker0”上执行以下代码
    import torch
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import RRef

    # 基于本地零矩阵创建 RRef
    rref = RRef(torch.zeros(2, 2))
    # 通过 RPC 调用,将 RRef 共享给工作节点“worker1”,引用计数会自动更新
    rpc.rpc_sync("worker1", f, args=(rref,))

PyRRef.backward 方法

  • backward 方法原型

    1
    backward(self: torch._C._distributed_rpc.PyRRef, dist_autograd_ctx_id: int = -1, retain_graph: bool = False) -> None
  • backward 方法以 RRef 作为反向传播的根节点,执行反向传播过程

  • 若提供了 dist_autograd_ctx_id(分布式自动求导上下文 ID),则会使用该上下文 ID,从 RRef 的所属节点开始执行分布式反向传播

    • 这种情况下,应使用 get_gradients() 函数获取梯度
  • dist_autograd_ctx_idNone,则默认当前为本地自动求导图,仅执行本地反向传播

    • 在本地反向传播场景中,调用此 API 的节点必须是该 RRef 的所属节点,且 RRef 所指向的值需为标量张量(scalar Tensor)
  • PyRRef.backward 方法参数

    • dist_autograd_ctx_idint 类型,可选):用于获取梯度的分布式自动求导上下文 ID,默认值为 -1
    • retain_graphbool 类型,可选):若设为 False,则计算梯度所用的计算图会被释放
      • 几乎不需要设为 True
      • 一般仅在需要多次执行反向传播时,才需将其设为 True,默认值为 False
  • PyRRef.backward 方法示例

    1
    2
    3
    4
    5
    6
    import torch.distributed.autograd as dist_autograd

    # 创建分布式自动求导上下文,并获取上下文 ID
    with dist_autograd.context() as context_id:
    # 以当前 RRef 为根节点,执行分布式反向传播
    rref.backward(context_id)

PyRRef 一些简单方法说明

  • confirmed_by_owner 方法:

    1
    confirmed_by_owner(self: torch._C._distributed_rpc.PyRRef)  bool
    • 返回该 RRef 是否已得到其所属节点的确认(真正拥有对象的节点为所属节点,对象时存储到所属节点上的)
      • OwnerRRef(所属节点侧远程引用)始终返回 True
      • UserRRef(用户侧远程引用)仅在所属节点知晓该 UserRRef 存在时,才返回 True
  • is_owner 方法:

    1
    is_owner(self: torch._C._distributed_rpc.PyRRef) -> bool
    • 返回当前节点是否为该 RRef 的所属节点(对象时存储到所属节点上的)
  • local_value 方法:

    1
    local_value(self: torch._C._distributed_rpc.PyRRef) -> object
    • 当前节点是该 RRef 的所属节点 ,则返回对本地值的引用;否则,抛出异常
  • owner 方法:

    1
    owner(self: torch._C._distributed_rpc.PyRRef) -> torch._C._distributed_rpc.WorkerInfo
    • 返回该 RRef 所属节点的工作节点信息(WorkerInfo 对象)
  • owner_name 方法:

    1
    owner_name(self: torch._C._distributed_rpc.PyRRef) -> str
    • 返回该 RRef 所属节点的工作节点名称(字符串形式)

PyRRef.remote 方法

  • remote 方法原型:

    1
    remote(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • remote 方法创建一个辅助代理(helper proxy)

    • 该代理可以用于发起 remote 调用,以 RRef 的所属节点为目标节点,在该 RRef 所指向的对象上执行函数
  • 更具体地说,rref.remote().func_name(*args, **kwargs) 等价于以下代码:

    1
    2
    3
    4
    5
    6
    def run(rref, func_name, args, kwargs):
    # 获取 RRef 本地值,并调用其指定方法
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

    # 向 RRef 所属节点发起 remote 调用,执行 run 函数
    rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • remote 方法参数:

    • timeoutfloat 类型,可选):rref.remote() 的超时时间(单位:秒)
      • 若在该时间内未成功创建 RRef,则下次尝试使用该 RRef(如调用 to_here() 方法)时,会引发超时错误
      • 若未提供该参数,则使用默认的 RPC 超时时间
  • remote 方法示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    from torch.distributed import rpc

    # 在远程节点“worker1”上执行 torch.add,创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
    # 调用 RRef 指向对象的 size() 方法,获取结果(远程执行,本地获取)
    rref.remote().size().to_here() # 返回结果:torch.Size([2, 2])
    # 注:返回类型
    # rref.remote(): <torch.distributed.rpc.rref_proxy.RRefProxy object at 0x11fa6e470>
    # rref.remote().size(): UserRRef(RRefId = GloballyUniqueId(created_on=0, local_id=4), ForkId = GloballyUniqueId(created_on=0, local_id=5))
    # rref.remote().size().to_here(): torch.Size([2])

    # 调用 RRef 指向对象的 view() 方法,重塑张量形状并获取结果
    rref.remote().view(1, 4).to_here() # 返回结果:tensor([[1., 1., 1., 1.]])

PyRRef.rpc_async 方法

  • rpc_async 方法原型:

    1
    rpc_async(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • rpc_async 方法创建一个辅助代理

    • 该代理可发起 rpc_async 调用,以 RRef 的所属节点为目标节点,在该 RRef 所指向的对象上执行函数
  • 更具体地说,rref.rpc_async().func_name(*args, **kwargs) 等价于以下代码:

    1
    2
    3
    4
    5
    6
    def run(rref, func_name, args, kwargs):
    # 获取 RRef 本地值,并调用其指定方法
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

    # 向 RRef 所属节点发起异步 RPC 调用,执行 run 函数
    rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • rpc_async 方法参数

    • timeoutfloat 类型,可选):rref.rpc_async() 的超时时间(单位:秒)
      • 若在该时间内调用未完成,则会引发超时异常
      • 若未提供该参数,则使用默认的 RPC 超时时间
  • rpc_async 方法示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    from torch.distributed import rpc

    # 在远程节点“worker1”上执行 torch.add,创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
    # 异步调用 size() 方法,等待结果返回
    rref.rpc_async().size().wait() # 返回结果:torch.Size([2, 2])
    # 注:返回类型
    # rref.rpc_async(): <torch.distributed.rpc.rref_proxy.RRefProxy object at 0x11bcfa3e0>
    # rref.rpc_async().size(): <torch.jit.Future object at 0x11bcc65c0>
    # rref.rpc_async().size().wait(): torch.Size([2])

    # 异步调用 view() 方法,重塑张量形状并等待结果返回
    rref.rpc_async().view(1, 4).wait() # 返回结果:tensor([[1., 1., 1., 1.]])

PyRRef.rpc_sync 方法

  • rpc_sync 方法原型:

    1
    rpc_sync(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • rpc_sync 方法创建一个辅助代理

    • 该代理可发起 rpc_sync 调用,以 RRef 的所属节点为目标节点,在该 RRef 所指向的对象上执行函数
  • 更具体地说,rref.rpc_sync().func_name(*args, **kwargs) 等价于以下代码:

    1
    2
    3
    4
    5
    6
    def run(rref, func_name, args, kwargs):
    # 获取 RRef 本地值,并调用其指定方法
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

    # 向 RRef 所属节点发起同步 RPC 调用,执行 run 函数
    rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • rpc_sync 方法参数:

    • timeoutfloat 类型,可选):rref.rpc_sync() 的超时时间(单位:秒)
      • 若在该时间内调用未完成,则会引发超时异常
      • 若未提供该参数,则使用默认的 RPC 超时时间
  • rpc_sync 方法示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    from torch.distributed import rpc

    # 在远程节点“worker1”上执行 torch.add,创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
    # 同步调用 size() 方法,获取张量尺寸
    rref.rpc_sync().size() # 返回结果:torch.Size([2, 2])
    # 注:返回类型:
    # rref.rpc_sync(): <torch.distributed.rpc.rref_proxy.RRefProxy object at 0x11bcfa440>
    # rref.rpc_sync().size(): torch.Size([2])

    # 同步调用 view() 方法,重塑张量形状
    rref.rpc_sync().view(1, 4) # 返回结果:tensor([[1., 1., 1., 1.]])

PyRRef.to_here 方法

  • to_here 方法原型:

    1
    to_here(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • to_here 方法是阻塞式调用,将 RRef 指向的值从其所属节点复制到本地节点,并返回该值

    • 若当前节点是该 RRef 的所属节点,则直接返回对本地值的引用
  • to_here 方法参数:

    • timeoutfloat 类型,可选):to_here 方法的超时时间(单位:秒)
      • 若在该时间内调用未完成,则会引发超时异常
      • 若未提供该参数,则使用默认的 RPC 超时时间(60 秒)

附录:torch.distributed.rpc.PyRRef 和 torch._C._distributed_rpc.PyRRef 的区别

  • 这两个是同一个类的不同引用路径,但有重要的使用区别,两者本质关系为

    1
    2
    3
    4
    # 推荐使用的方式;稳定,版本间兼容性好;可能包含额外的封装、错误检查、文档
    torch.distributed.rpc.PyRRef # 公共API接口,官方支持,文档完整
    # 不推荐直接使用;内部实现细节;可能在PyTorch版本更新时发生变化;直接的C++绑定,可能缺少一些Python层的便利功能
    torch._C._distributed_rpc.PyRRef # 底层C++实现,以`_`开头表示私有
  • 实际上 torch.distributed.rpc.PyRRef 就是对 torch._C._distributed_rpc.PyRRef 的封装或直接引用

  • 验证两者的等价关系

    1
    2
    3
    4
    5
    import torch.distributed.rpc as rpc
    import torch._C._distributed_rpc as _rpc

    # 通常这两个是相同的对象
    print(rpc.PyRRef is _rpc.PyRRef) # 可能输出 True

为什么会看到内部 API torch._C._distributed_rpc.PyRRef?

  • 调试时的堆栈跟踪可能显示内部路径
  • 某些高级用法或源码分析
  • 类型检查时可能遇到