PyTorch——FSDP的使用


整体说明

  • FSDP(Fully Sharded Data Parallel)是 PyTorch 推出的分布式训练技术,专为大规模模型训练设计
  • FSDP 通过分片模型参数、梯度和优化器状态到多个 GPU,显著降低单卡内存占用,支持训练远超单卡内存容量的大型模型(如千亿参数级模型)
  • 与传统的 DDP(Distributed Data Parallel)相比
    • DDP 会在每个 GPU 上保存完整的模型参数、梯度和优化器状态(数据并行)
    • FSDP 仅在每个 GPU 上保存部分分片,大幅提升内存效率(模型并行)
  • FSDP 的核心特性包括
    • 全分片机制 :模型参数、梯度、优化器状态均被分片存储在不同 GPU,仅在计算时临时聚合所需分片
    • 自动通信优化 :通过重叠计算与通信、按需聚合分片,减少分布式训练的通信开销
    • 混合精度支持 :原生支持 FP16/BF16 混合精度训练,进一步节省内存
    • 灵活的包装策略 :可通过 wrap_policy 指定需要分片的子模块,支持部分模块分片、部分模块复制(如小模型组件)
    • 与模型并行结合 :可与 Tensor Parallel 等模型并行技术结合,支持超大规模模型训练
  • FSDP 的工作流程
    • 1)初始化 :将模型划分为多个子模块,每个子模块的参数被分片到不同 GPU
    • 2)前向传播 :当需要某子模块参数时,FSDP 自动聚合所需分片到当前 GPU,计算完成后释放临时聚合的参数
    • 3)反向传播 :梯度按相同策略分片存储,避免全量梯度占用内存
    • 4)参数更新 :优化器仅更新本地保存的分片参数,通过通信同步确保全局一致性
  • 注意事项:
    • 模型保存/加载时需使用 FSDP 提供的 state_dict() 方法,避免直接保存原始模型参数(分片后参数分布在不同进程)
    • 可结合 activation checkpointing 进一步节省内存

FSDP 单机多卡使用示例

  • FSDP 单机多卡示例代码

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
    ) # 用括号的导入方式跟普通方式效果完全相同,区别在于括号的方式看起来更容易管理和注释某一行
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
    import torch.multiprocessing as mp
    from torch.utils.data.distributed import DistributedSampler

    # 1. 定义数据集
    class DiyDataset(Dataset):
    def __len__(self):
    return 1000

    def __getitem__(self, idx):
    x = torch.randn(10) # 输入特征
    y = torch.randint(0, 2, (1,)).item() # 分类标签(二分类)
    return x, y

    # 2. 定义模型(包含多个子模块,便于后续使用 FSDP 分片)
    class SubModule(nn.Module):
    def __init__(self, input_dim, output_dim):
    super().__init__()
    self.fc = nn.Linear(input_dim, output_dim)
    self.relu = nn.ReLU()

    def forward(self, x):
    return self.relu(self.fc(x))

    class MainModel(nn.Module):
    def __init__(self):
    super().__init__()
    self.layer1 = SubModule(10, 20)
    self.layer2 = SubModule(20, 30)
    self.layer3 = SubModule(30, 2) # 输出维度为2(二分类)

    def forward(self, x):
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    return x

    # 3. 分布式训练主函数
    def train(rank, world_size):
    # 初始化分布式环境 通过`dist.init_process_group`配置 NCCL 后端(GPU 推荐),并绑定进程到对应 GPU
    dist.init_process_group(
    backend='nccl', # GPU 推荐使用 nccl 后端,与 DataParallel 类似
    init_method='tcp://localhost:12355', # init_method 是必填参数(或通过环境变量指定),用于初始化分布式进程组,确保单机内的多个进程(每个进程对应一个 GPU)能够相互发现并建立通信,localhost表示只有本地一台机器
    rank=rank,
    world_size=world_size
    )
    torch.cuda.set_device(rank) # 绑定当前进程到对应的GPU

    # 4. 配置 FSDP 参数
    # 定义自动包装策略:当子模块参数数量超过 1e6 时进行分片
    auto_wrap_policy = size_based_auto_wrap_policy(min_num_params=1e6)

    # 初始化模型并应用 FSDP
    model = MainModel().cuda(rank)
    model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy, # 自动递归包装子模块,避免手动指定 `wrap_module`,简化大模型配置 `size_based_auto_wrap_policy` 函数生成的对象,表示按参数大小进行分片,前面定义的是 >1e6 的子模块被分片
    cpu_offload=CPUOffload(offload_params=False), # False(默认值)表示仅将梯度卸载到 CPU;True 表示将梯度和参数都卸载到 CPU;前向传播时加载,反向传播后卸载
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # 反向传播前预取参数,控制 反向传播预取
    sharding_strategy=FSDP.ShardingStrategy.FULL_SHARD, # 全分片策略(`FULL_SHARD`策略),即参数、梯度、优化器状态均分片(内存使用最小的策略);还可以使用 `SHARD_GRAD_OP`(平衡内存和性能)
    device_id=rank, # 当前设备 ID
    )

    # 5. 数据加载(分布式采样)
    dataset = DiyDataset()
    sampler = DistributedSampler(dataset, shuffle=True) # 确保各卡数据不重复,类似 DDP 的做法
    dataloader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler,
    num_workers=2
    )

    # 6. 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss().cuda(rank)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # 7. 训练循环
    model.train()
    for epoch in range(3):
    sampler.set_epoch(epoch) # 每个 epoch 打乱数据
    total_loss = 0.0
    for x, y in dataloader:
    x = x.cuda(rank)
    y = y.cuda(rank)

    optimizer.zero_grad() # 清零本地梯度分片
    outputs = model(x) # 前向传播(FSDP 自动聚合所需参数分片)
    loss = criterion(outputs, y) # 损失计算,loss 与 outputs 绑定,从而与模型绑定,故而可以使用 FSDP 的特性
    loss.backward() # 反向传播:触发梯度计算与分片梯度同步(核心行),类似操作可以参考 DataParallel 的实现方式
    optimizer.step() # 用同步后的梯度更新本地参数分片

    total_loss += loss.item()

    # 仅在主进程(rank=0)打印日志
    if rank == 0:
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

    # 清理分布式环境
    dist.destroy_process_group()

    # 8. 启动多进程训练
    def main():
    world_size = 2 # 使用 2 个 GPU,注意这里是单机模式,可以直接写死各种配置
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

    if __name__ == "__main__":
    main()
    • 反向传播预取说明:
      • FSDP 中,模型参数被分片存储在不同 GPU 上
      • 在反向传播时,计算某层梯度可能需要其他 GPU 上的参数分片(如计算梯度需要完整的参数信息)
      • backward_prefetch 决定了何时提前获取这些所需的参数分片,从而让数据传输(通信)与计算重叠进行,减少空闲时间
  • 运行上述单机多卡代码,正常使用 python 命令即可,不需要特殊命令启动:

    1
    python fsdp_example.py

FSDP 多机多卡使用示例

  • FSDP 多机多卡使用示例(其实也可以用多机多卡直接改成单机单卡形式)
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
    from torch.utils.data.distributed import DistributedSampler
    import torch.multiprocessing as mp
    import argparse

    # 数据集和模型定义(正常定义)
    class DiyDataset(Dataset):
    def __len__(self):
    return 10000
    def __getitem__(self, idx):
    return torch.randn(128), torch.randint(0, 10, (1,)).item()

    class SubModule(nn.Module):
    def __init__(self, in_dim, out_dim):
    super().__init__()
    self.fc = nn.Linear(in_dim, out_dim)
    self.bn = nn.BatchNorm1d(out_dim)
    self.relu = nn.ReLU()
    def forward(self, x):
    return self.relu(self.bn(self.fc(x)))

    class MainModel(nn.Module):
    def __init__(self):
    super().__init__()
    self.layer1 = SubModule(128, 256)
    self.layer2 = SubModule(256, 512)
    self.layer3 = SubModule(512, 1024)
    self.final = nn.Linear(1024, 10)
    def forward(self, x):
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    return self.final(x)

    # 子进程函数(每个进程对应一个GPU)
    def train_fn(local_rank, args, world_size):
    rank = args.start_rank + local_rank # 全局进程编号 = 起始rank + 本地进程编号
    print(rank) # 输出全局进程编号
    # 初始化分布式进程组
    dist.init_process_group(
    backend="gloo", # GPU 时使用 `nccl`
    init_method=f"tcp://{args.master_addr}:{args.master_port}", # 多机多卡需要指定服务器地址,不能写死成 local
    world_size=world_size,
    rank=rank, # 全局进程编号 = 起始rank + 本地进程编号
    )

    # 绑定本地GPU
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    # 初始化 FSDP 模型
    auto_wrap_policy = size_based_auto_wrap_policy(min_num_params=2e6)
    model = MainModel().to(device)

    model = FSDP( # 整体配置同单进程模式,为了简化,这里不写 cpu_offload 和 backward_prefetch 的配置
    model,
    auto_wrap_policy=auto_wrap_policy, # 自动递归包装子模块,避免手动指定 `wrap_module`,简化大模型配置
    sharding_strategy=FSDP.ShardingStrategy.FULL_SHARD,
    device_id=local_rank,
    )

    # 数据加载和训练逻辑(与之前类似)
    dataset = DiyDataset()
    sampler = DistributedSampler(dataset, world_size=world_size, rank=args.start_rank + local_rank) # 为不同 GPU 分发不同机器,分成 world_size 个并取 第 args.start_rank + rank 份
    dataloader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=4)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    model.train()
    for epoch in range(5):
    sampler.set_epoch(epoch)
    total_loss = 0.0
    for x, y in dataloader:
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

    if args.start_rank + local_rank == 0: # 主进程打印日志
    print(f"Epoch {epoch+1}, Avg Loss: {total_loss/len(dataloader):.4f}")

    dist.destroy_process_group()

    def main():
    # 对于多机多卡,为了复用同一套代码(方便管理),使用传入参数的方式启动代码,方便传参实现不同 node 机器的启动
    parser = argparse.ArgumentParser()
    parser.add_argument("--master_addr", type=str, default="localhost", help="主节点IP")
    parser.add_argument("--master_port", type=str, default="12355", help="主节点端口")
    parser.add_argument("--world_size", type=int, required=True, help="总进程数")
    parser.add_argument("--start_rank", type=int, required=True, help="当前机器进程的起始全局rank")
    parser.add_argument("--num_gpus", type=int, default=torch.cuda.device_count(), help="当前机器的GPU数量")
    args = parser.parse_args()
    print(args)
    mp.spawn(
    train_fn,
    args=(args, args.world_size), # 传递给train_fn的参数
    nprocs=args.num_gpus, # 进程数=当前机器的GPU数,启动当前机器的所有进程(num_gpus个),每台机器启动自己的进程数即可
    join=True # 等待所有子进程完成
    )

    if __name__ == "__main__":
    main()

关于启动命令项的说明


附录:集成了 FSDP 分布式框架 HuggingFace Accelerate

  • HuggingFace Accelerate 是集成 了 FSDP 技术的工具,属于 accelerate
  • 两者辨析:FSDP 是一种分布式训练技术,Accelerate 是集成该技术的工具
    • FSDP 是由 PyTorch 官方推出的分布式训练策略,核心是通过将模型参数、梯度和优化器状态进行分片存储,显著降低单卡内存占用,支持训练超大规模模型(如千亿参数模型)
      • 属于“分布式训练方法”的范畴
    • HuggingFace Accelerate 是一个简化分布式训练的工具库,其核心功能是屏蔽不同硬件环境(单卡、多卡、GPU/TPU)和分布式策略(如数据并行、模型并行、FSDP 等)的底层细节,让用户用少量代码即可实现分布式训练
      • Accelerate 内部集成了 FSDP ,将其作为一种可选的分布式策略供用户使用
  • Accelerate 对 FSDP 起到封装和简化作用 :使用原生 PyTorch 的 FSDP 时,需要手动编写较多分布式配置代码(如初始化进程组、包装模型、设置通信策略等)
    • Accelerate 通过抽象接口,将 FSDP 的使用简化:
    • 用户只需通过 Accelerator 类,并在配置中指定使用 FSDP,即可自动完成 FSDP 的初始化和模型包装
    • 例如,通过 accelerate launch 命令启动训练时,Accelerate 会根据配置自动选择包括 FSDP 在内的最佳分布式策略,无需用户手动调用 torch.distributedtorch.distributed.fsdp 的底层 API
  • Accelerate 则不局限于 FSDP,还支持数据并行、DeepSpeed 等其他策略,其价值在于统一接口 ,让用户在不同分布式策略之间无缝切换,无需大幅修改代码
  • HuggingFace Accelerate 的使用见:PyTorch——Accelerate使用总结

附录:FSDP 和 FSDP2 的区别

  • FSDP2 是 FSDP 的全新版本
  • 架构上:
    • FSDP1 采用 FlatParameter 设计,将一组张量展平、连接并分块,使得每个设备上的数据推理和重新分片变得复杂
    • FSDP2 使用 DTensor 基础架构,通过在 dim-0 上对每个参数进行分片,每个参数在数据并行工作器之间按 dim-0 进行分块,提供了更简单的分片表示,也使得对单个参数的操作更便捷,还实现了无需通信的分片状态字典
  • 内存管理上:
    • FSDP1 使用 recordStream 机制,导致 GPU 内存使用不够优化且非确定性,有时还需要 CPU 同步
      • recordStream 机制是一种用于通过 CUDA 流(Stream)记录和同步张量生命周期 的内存管理技术,主要用于优化 GPU 显存使用和计算效率,但也存在一定局限性
      • recordStream 跟踪张量在 CUDA 流中的使用时机,确保张量在计算完成后再被释放或重用于其他操作,避免因显存提前回收导致的计算错误
      • FSDP V1 在训练过程中需要频繁在不同设备(或分片)间迁移参数(如 Forward 时 AllGather 完整参数、Backward 后释放显存)
    • FSDP2 借助 DTensor,实现了更低且确定性的 GPU 内存使用,避免了 recordStream 机制,无需 CPU 同步,能更有效地管理内存
  • API设计:上有所增减
  • 在实际测试中,FSDP2 相比 FSDP1实现了更高的MFU(模型浮点利用率),峰值内存降低了 7%,且能保持相同的损失曲线(Llama-7B 模型在 8×H100)
  • FSDP2 不直接支持完整的状态字典,用户需要使用 DTensor 的 API 或更高层次的 API 将包含 DTensor 的分片状态字典重新分片为完整状态字典,而 FSDP1 支持完整的状态字典
    • FSDP1 使用 FlatParameter 将多个小参数张量展平并拼接成一个大张量,然后对这个大张量进行分片
      • 在这种情况下,FSDP1 可以通过 state_dict_config 配置来保存和加载完整的状态字典,其中完整的状态字典中的值是未分片的 PyTorch 张量
    • FSDP2 采用了基于 DTensor 的逐参数维度分片方式,每个参数张量直接按维度进行分片,其参数表示直接与分片状态字典匹配
      • 在 FSDP2 中,调用 model.state_dict() 返回的是一个分片状态字典,且不涉及任何计算或通信,FSDP2 也只支持加载分片状态字典
      • 如果需要完整的状态字典,用户需要使用 DTensor 的相关 API,如 dtensor.full_tensor(),或使用更高层次的 API,如 PyTorch 分布式检查点的分布式状态字典 API,在 FSDP2 之外进行转换