PyTorch——分布式编程之子通信组


整体介绍

  • 子通信组允许在分布式环境中灵活地划分进程,实现更精细的通信控制
  • torch.distributed.new_group 是 PyTorch 分布式训练中用于创建子通信组的函数
  • 使用场景包括:
    • 部分进程通信:当需要让部分进程单独通信(如模型并行中不同层的参数同步)
    • 灵活分组:动态划分进程组,适应复杂的分布式策略(如混合数据并行+模型并行)
  • 使用步骤包括:
    • 1)初始化全局进程组:先通过 init_process_group 初始化全局通信环境
    • 2)创建子组:调用 new_group 划分进程
    • 3)子组内通信:使用返回的 ProcessGroup 对象进行通信操作
  • 子通信组使用的核心注意事项
    • 进程一致性:所有进程必须调用 new_group ,即使不加入子组(此时可传入 ranks 不包含自身,或后续不使用返回的组对象)
    • 后端兼容性:子组的 backend 需与全局后端兼容(如 GPU 通信推荐 nccl

new_group 函数定义

  • new_group 函数形式说明:

    1
    2
    3
    4
    5
    torch.distributed.new_group(
    ranks=None, # 参与新组的进程编号列表
    timeout=datetime.timedelta(seconds=1800), # 超时时间
    backend=None # 通信后端,默认为全局后端
    )
  • new_group 核心参数说明如下文

    • ranks(可选,列表/元组):
      • 指定加入新组的进程编号(全局进程编号,非局部编号)
      • 若为 None,则默认包含所有进程(等价于全局组)
      • 例如:ranks=[0,1,2] 表示仅 0、1、2 号进程加入新组
    • timeout**(可选,datetime.timedelta):
      • 组内通信的超时时间,超时未完成会抛出异常
      • 默认为 30 分钟(1800 秒)
    • backend**(可选,字符串):
      • 指定该组使用的通信后端(如 ncclgloo 等)
      • 若为 None,则继承全局初始化的后端(init_process_group 中指定的 backend)
  • new_group 返回值

    • 返回一个 ProcessGroup 对象 ,用于后续子组内的通信操作(如 allreducebroadcast 等)

附录:为什么所有进程都要调用子通信组初始化函数

  • 在 PyTorch 分布式的最佳实践中,即使不加入子进程组的 rank(如例子中的 rank=3),也必须调用 dist.new_group(ranks=[0, 1], ...)
    • 这是由分布式通信的一致性要求决定的
  • 必须调用的核心原因是:避免死锁
    • PyTorch 分布式通信的底层实现要求所有进程必须参与子组的创建过程 ,无论是否加入该子组
    • 若部分进程调用 new_group 而其他进程不调用,会导致进程间同步失衡,触发分布式死锁(所有进程会阻塞等待未调用的进程)
    • 即使某进程明确不加入子组(不在 ranks 列表中),也需要通过调用 new_group 完成“知晓该子组存在”的协议同步
  • 不加入子组的进程如何处理返回的 ProcessGroup 对象?
    • 对于不加入子组的进程(如 rank=3),调用 new_group 后会返回一个有效的 ProcessGroup 对象,但该进程不属于该组
    • 此时的最佳实践是:保留该对象但不使用它进行通信(或在通信前先判断是否属于子组)
    • 可通过 dist.get_rank(group=subgroup) 检查:若返回 -1,说明当前进程不属于该子组,应跳过子组内的通信操作

子通信组示例代码

  • 代码示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    import torch
    import torch.distributed as dist
    from datetime import timedelta

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()

    # 所有进程(无论是否加入子组)必须调用 `dist.new_group`,否则会导致分布式死锁
    subgroup = dist.new_group(ranks=[0, 1], timeout=timedelta(seconds=30))

    # 检查当前进程是否属于子组,非子组成员进程可通过 `dist.get_rank(group=subgroup) != -1` 判断身份,避免无效通信
    is_in_subgroup = dist.get_rank(group=subgroup) != -1

    if is_in_subgroup:
    # 子组成员执行通信操作
    tensor = torch.tensor([rank], device="cuda")
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=subgroup)
    print(f"Rank {rank}(子组成员):通信结果 = {tensor.item()}")
    else:
    # 非子组成员跳过通信,或执行其他逻辑
    print(f"Rank {rank}(非子组成员):不参与子组通信")

    dist.destroy_process_group()