整体介绍
- 子通信组允许在分布式环境中灵活地划分进程,实现更精细的通信控制
torch.distributed.new_group是 PyTorch 分布式训练中用于创建子通信组的函数- 使用场景包括:
- 部分进程通信:当需要让部分进程单独通信(如模型并行中不同层的参数同步)
- 灵活分组:动态划分进程组,适应复杂的分布式策略(如混合数据并行+模型并行)
- 使用步骤包括:
- 1)初始化全局进程组:先通过
init_process_group初始化全局通信环境 - 2)创建子组:调用
new_group划分进程 - 3)子组内通信:使用返回的
ProcessGroup对象进行通信操作
- 1)初始化全局进程组:先通过
- 子通信组使用的核心注意事项
- 进程一致性:所有进程必须调用
new_group,即使不加入子组(此时可传入ranks不包含自身,或后续不使用返回的组对象) - 后端兼容性:子组的
backend需与全局后端兼容(如 GPU 通信推荐nccl)
- 进程一致性:所有进程必须调用
new_group 函数定义
new_group函数形式说明:1
2
3
4
5torch.distributed.new_group(
ranks=None, # 参与新组的进程编号列表
timeout=datetime.timedelta(seconds=1800), # 超时时间
backend=None # 通信后端,默认为全局后端
)new_group核心参数说明如下文ranks(可选,列表/元组):- 指定加入新组的进程编号(全局进程编号,非局部编号)
- 若为
None,则默认包含所有进程(等价于全局组) - 例如:
ranks=[0,1,2]表示仅 0、1、2 号进程加入新组
timeout**(可选,datetime.timedelta):- 组内通信的超时时间,超时未完成会抛出异常
- 默认为 30 分钟(1800 秒)
backend**(可选,字符串):- 指定该组使用的通信后端(如
nccl、gloo等) - 若为
None,则继承全局初始化的后端(init_process_group中指定的 backend)
- 指定该组使用的通信后端(如
new_group返回值- 返回一个
ProcessGroup对象 ,用于后续子组内的通信操作(如allreduce、broadcast等)
- 返回一个
附录:为什么所有进程都要调用子通信组初始化函数
- 在 PyTorch 分布式的最佳实践中,即使不加入子进程组的 rank(如例子中的 rank=3),也必须调用
dist.new_group(ranks=[0, 1], ...)- 这是由分布式通信的一致性要求决定的
- 必须调用的核心原因是:避免死锁
- PyTorch 分布式通信的底层实现要求所有进程必须参与子组的创建过程 ,无论是否加入该子组
- 若部分进程调用
new_group而其他进程不调用,会导致进程间同步失衡,触发分布式死锁(所有进程会阻塞等待未调用的进程) - 即使某进程明确不加入子组(不在
ranks列表中),也需要通过调用new_group完成“知晓该子组存在”的协议同步
- 不加入子组的进程如何处理返回的
ProcessGroup对象?- 对于不加入子组的进程(如 rank=3),调用
new_group后会返回一个有效的ProcessGroup对象,但该进程不属于该组 - 此时的最佳实践是:保留该对象但不使用它进行通信(或在通信前先判断是否属于子组)
- 可通过
dist.get_rank(group=subgroup)检查:若返回-1,说明当前进程不属于该子组,应跳过子组内的通信操作
- 对于不加入子组的进程(如 rank=3),调用
子通信组示例代码
- 代码示例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23import torch
import torch.distributed as dist
from datetime import timedelta
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
# 所有进程(无论是否加入子组)必须调用 `dist.new_group`,否则会导致分布式死锁
subgroup = dist.new_group(ranks=[0, 1], timeout=timedelta(seconds=30))
# 检查当前进程是否属于子组,非子组成员进程可通过 `dist.get_rank(group=subgroup) != -1` 判断身份,避免无效通信
is_in_subgroup = dist.get_rank(group=subgroup) != -1
if is_in_subgroup:
# 子组成员执行通信操作
tensor = torch.tensor([rank], device="cuda")
dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=subgroup)
print(f"Rank {rank}(子组成员):通信结果 = {tensor.item()}")
else:
# 非子组成员跳过通信,或执行其他逻辑
print(f"Rank {rank}(非子组成员):不参与子组通信")
dist.destroy_process_group()