Jiahong 的个人博客

凡事预则立,不预则废


  • Home

  • Tags

  • Archives

  • Navigation

  • Search

PyTorch——Accelerate使用总结

  • 参考链接:
    • 官方文档:docs.pytorch.org/docs/stable/fsdp.html

整体说明

  • HuggingFace Accelerate 是一个轻量级库(accelerate),专为简化 PyTorch 模型在各种硬件配置上的训练和推理而设计
  • 它能自动处理分布式训练、混合精度训练等复杂设置,让开发者无需深入了解底层硬件细节,就能轻松将模型部署到单 GPU、多 GPU、TPU 甚至 CPU 集群等环境中(专注于模型逻辑和训练流程即可)
  • 安装简便,仅需一行代码:pip install accelerate
  • Accelerate 的核心功能包括下面几个
    • 自动识别可用硬件(GPU、TPU 等),并根据硬件情况优化训练配置
    • 无缝支持数据并行、模型并行等分布式训练模式 ,适配多 GPU 或集群环境
    • 可选择 FSDP 或 DeepSpeed 等底层框架,仅需简单修改启动命令即可
    • 支持 FP16、BF16 等混合精度训练 ,在减少显存占用的同时,保证模型训练精度
    • 只需对原有 PyTorch 代码进行少量修改 ,即可实现硬件加速和分布式训练
  • 一般来说仅需要两行代码改动,其他都有命令行进行配置
    • accelerator.prepare() :核心函数,用于包装模型、优化器、数据加载器等组件,自动适配分布式和混合精度设置
    • accelerator.backward() :替代传统的 loss.backward(),在分布式环境中确保梯度正确同步

HuggingFace Accelerate 使用代码示例

  • 以下是一个使用 Accelerate 进行模型训练的基础示例,训练代码(train.py)如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader
    from accelerate import Accelerator # 导入 Accelerator

    class DiyModel(nn.Module):
    def __init__(self):
    super().__init__()
    self.fc = nn.Linear(10, 2) # 二分类任务
    def forward(self, x):
    return self.fc(x)

    class DiyDataset(Dataset):
    def __len__(self):
    return 1000
    def __getitem__(self, idx):
    x = torch.randn(10)
    y = torch.randint(0, 2, (1,)).item() # 随机标签(0 或 1)
    return x, y

    # 混合精度配置:可通过 `Accelerator(mixed_precision="fp16")` 启用 FP16 混合精度训练,减少显存占用
    accelerator = Accelerator() # 核心代码,初始化 Accelerator

    # 以下所有定义都不涉及使用 Accelerator
    model = DiyModel()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    dataset = DiyDataset()
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # 重点:使用 accelerator 包装多个组件(注:这一行会自动处理分布式和混合精度)
    model, optimizer, dataloader, criterion = accelerator.prepare(
    model, optimizer, dataloader, criterion
    )

    # 特别注意的一点不同是:训练循环时,使用 accelerator.backward()
    model.train()
    for epoch in range(3):
    total_loss = 0.0
    for x, y in dataloader:
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, y)
    accelerator.backward(loss) # 替代 loss.backward()
    optimizer.step()
    total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

Accelerate 启动训练

  • 可简单通过 Accelerate 命令行工具配置训练环境并启动:

    1
    accelerate launch --num_processes=2 train.py  # 使用 2 个进程(如 2 个 GPU)
    • --num_processes:指定进程数(通常等于 GPU 数量)
    • 若使用单 GPU 或 CPU,可直接运行 python train.py,Accelerator 会自动适配环境
  • 更多启动命令参见下文


accelerate launch 命令详细说明

  • accelerate launch 是 HuggingFace Accelerate 库的核心命令,用于启动分布式训练脚本,它能自动处理多卡、多机等复杂分布式环境的配置

  • accelerate launch 的核心作用是:

    • 1)初始化分布式环境(进程组、通信后端等)
    • 2)根据参数自动选择分布式策略(数据并行/FSDP/DeepSpeed 等)
    • 3)将环境配置传递给训练脚本中的 Accelerator 实例,使其能正确处理模型、数据的分布式适配
  • accelerate launch 命令基本语法

    1
    accelerate launch [启动参数] your_script.py [脚本参数]
    • [启动参数]:控制分布式训练的配置(如使用的 GPU 数量、分布式策略等)
    • [脚本参数]:传递给你的训练脚本(your_script.py)的自定义参数(如 --epochs 10、--batch_size 32 等)

硬件与进程配置参数

  • --num_processes N:指定总进程数(通常等于参与训练的 GPU 总数)
    • 例如:--num_processes 4 表示使用 4 个 GPU
  • --num_machines N:指定机器数量(多机分布式训练时使用),默认值为 1(单机器)
  • --machine_rank N:当使用多机时,指定当前机器的序号(从 0 开始)
    • 例如:主节点用 --machine_rank 0,从节点用 --machine_rank 1
  • --main_process_ip IP地址:多机训练时,主节点的 IP 地址(供从节点连接)
  • --main_process_port 端口号:主节点的通信端口(默认 29500,需确保端口未被占用)

分布式策略选择参数

  • Accelerate 支持多种分布式策略,通过参数指定:
  • 默认策略(自动选择) :不指定任何策略时,Accelerate 会根据硬件自动选择当前硬件下的最佳策略:
    • 单卡:直接使用单进程训练
    • 多卡:默认使用 PyTorch 的 nn.DataParallel 或 DistributedDataParallel(数据并行)
  • --use_fsdp:
    • 启用 FSDP(完全分片数据并行) ,适合超大规模模型(需 PyTorch ≥ 1.11);
    • 常用搭配参数如下:
      • --fsdp_fully_shard:完全分片参数、梯度和优化器状态(最大程度节省内存)
      • --fsdp_transformer_layer_cls_to_wrap "类名":指定 Transformer 层的类名(如 GPT2 的 GPT2Layer、BERT 的 BertLayer),用于自动分片模型层
      • --fsdp_sharding_strategy 策略:分片策略,可选 FULL_SHARD(完全分片)、SHARD_GRAD_OP(梯度和优化器分片)等
  • --use_deepspeed:
    • 启用 DeepSpeed 分布式框架,支持 ZeRO 优化、混合精度等(需提前安装 deepspeed)
    • 通常需配合 DeepSpeed 配置文件使用,通过 --deepspeed 配置文件路径 指定

混合精度训练参数

  • --mixed_precision [mode]:指定混合精度策略
  • 可选模式 [mode] 为:
    • no:不使用混合精度(默认)
    • fp16:使用 FP16 混合精度
    • bf16:使用 BF16 混合精度(需 GPU 支持,如 A100、RTX 3090 等)
    • fp8:使用 FP8 混合精度(需 PyTorch ≥ 2.0 且 GPU 支持)

其他实用参数

  • --config_file 配置文件路径:通过 YAML 配置文件指定所有参数(推荐复杂场景使用),无需在命令行逐个输入
  • --debug:启用调试模式,打印详细的分布式初始化日志,便于排查问题
  • --gradient_accumulation_steps N:指定梯度累积步数(等价于在代码中设置,但通过命令行更灵活)

附录:在启动命令中使用配置文件用法(推荐)

  • 对于复杂配置(如 FSDP/DeepSpeed 细节),建议使用 YAML 配置文件,步骤如下:

  • 第一步:生成默认配置文件 :

    1
    accelerate launch --config_file accelerate_config.yaml --generate_config
    • --generate_config 表示运行后会交互式提问,自动生成配置文件
    • 注:也可以自己编辑 accelerate_config.yaml 文件
  • 第二步:示例配置文件(FSDP 场景) ::

    1
    2
    3
    4
    5
    6
    7
    8
    compute_environment: LOCAL_MACHINE  # 本地机器环境
    distributed_type: FSDP # 使用 FSDP 策略
    fsdp_config:
    fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP # 自动包装 Transformer 层
    fsdp_transformer_layer_cls_to_wrap: "GPT2Layer" # 模型层类名
    fsdp_sharding_strategy: FULL_SHARD # 完全分片
    mixed_precision: fp16 # 启用 FP16 混合精度
    num_processes: 4 # 4 个进程(4 卡)
  • 第三步:使用配置文件启动 :

    1
    accelerate launch --config_file accelerate_config.yaml your_script.py --epochs 10

附录:启动命令的常见场景示例

  • 示例一:单机器多卡基础数据并行示例

    1
    accelerate launch --num_processes 4 train.py --batch_size 32
  • 示例二:启用 FSDP 训练大模型

    1
    2
    3
    4
    5
    6
    7
    accelerate launch \
    --num_processes 4 \
    --use_fsdp \
    --fsdp_fully_shard \
    --fsdp_transformer_layer_cls_to_wrap "BertLayer" \
    --mixed_precision bf16 \
    train_bert.py
  • 示例三:启用 DeepSpeed 与 ZeRO-3 优化(其中 ds_config.json 为 DeepSpeed 配置文件,定义 ZeRO 阶段、梯度裁剪等)

    1
    2
    3
    4
    5
    accelerate launch \
    --num_processes 8 \
    --use_deepspeed \
    --deepspeed ds_config.json \
    train.py

PyTorch——compile函数的理解和使用


整体说明

  • torch.compile() 函数是 PyTorch 2.0 引入的一个重要功能,用于对模型进行编译优化,以提升训练和推理性能
    • 将 PyTorch 模型从“解释型”的逐行执行模式,转变为“编译型”的、一次性优化的执行模式
    • 当模型被编译后,它在后续的推理或训练中会运行得更快,并且使用的内存可能更少
  • torch.compile() 的核心作用是通过对模型计算图进行一系列优化(如算子融合、常量折叠、内存优化等),生成更高效的代码,从而加速模型的执行
    • 这个过程是自动的,并且大部分时间是无侵入性的,不需要修改模型的代码
  • torch.compile() 效果提升:
    • 因模型结构和硬件而异,通常对于有大量小操作(如 Transformer 模型)或对 GPU 算力要求高的模型效果更显著
    • 在大规模训练和推理场景中效果显著
  • 首次运行编译后的模型会有一定的启动延迟,用于图优化和代码生成(compile 函数是惰性执行的),但后续重复调用会更快
  • 编译过程可能会增加模型的内存占用,需根据实际情况调整

调用 torch.compile(model) 后会发生什么?

  • 具体来说,调用 torch.compile(model) 会发生以下过程:
  • 1)捕获模型计算图(Graph Capturing) :分析模型的前向传播逻辑,记录张量的操作序列和依赖关系,构建计算图表示
    • 当你第一次调用 torch.compile 编译后的模型时,它会追踪模型的前向传播过程
    • 这就像在记录模型中的每一步操作,比如矩阵乘法、卷积、激活函数等
    • PyTorch 会创建一个计算图(computational graph) ,这个图代表了模型从输入到输出的所有计算路径
    • 这个过程是惰性的,即只在第一次实际运行模型时发生(所以第一次调用比较慢,后面会比较快)
    • 计算图被捕获后,编译器会对计算图进行一系列优化(接下面)
  • 2)图优化(Graph Optimization) :包含一系列优化,例如:
    • 算子融合(Operator Fusion) : 将多个连续的小算子合并为一个大算子(如将卷积+批归一化+激活函数融合),减少 kernel 调用次数和内存读写
    • 常量折叠(Constant Folding) :计算图中固定不变的常量表达式会被预先计算,避免重复计算
    • 死代码消除(Dead Code Elimination) : 移除计算图中未被使用的节点或操作,减少不必要计算
    • 内存优化(Memory Optimization) : 优化张量的内存分配和复用,重新安排计算顺序,以减少中间结果所需的内存,从而更好地利用缓存
    • 优化后的计算图会用来生成高效的、针对特定硬件(如 GPU)的可执行代码(具体见下节)
  • 3)代码生成 :
    • 编译器将高级的 PyTorch 操作转换成底层的、更接近硬件指令的特定后端的代码(如 CPU 上的 C++/OpenMP 代码,GPU 上的 CUDA 代码)
      • 例如,它可能会将 PyTorch 的张量操作转换成 CUDA 内核
    • 支持多种后端(如 inductor、aot_eager 等),默认使用 inductor 后端(也叫做 TorchInductor 后端),它能生成高效的 GPU/CPU 代码
  • 4)返回编译后的模型 :
    • torch.compile(model) 返回一个经过包装的模型对象,其接口与原模型一致(可直接调用 forward 方法或进行训练),但内部执行的是优化后的代码

使用示例

  • 使用非常简单,仅需一行代码
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    import torch
    import torch.nn as nn

    class DiyModel(nn.Module):
    def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3, 32, kernel_size=3)
    self.relu = nn.ReLU()

    def forward(self, x):
    return self.relu(self.conv(x))

    model = DiyModel()
    # 编译模型
    compiled_model = torch.compile(model)

    # 使用编译后的模型(接口与原模型万全一致)
    x = torch.randn(1, 3, 224, 224)
    output = compiled_model(x)

PyTorch——gather函数使用

  • 参考链接:
    • 一篇浅显易懂的博客:PyTorch中的高级索引方法——gather详解

gather函数形式

  • 包含 torch.gather 和 tensor.gather 两种形式,基本思路等价,他们的函数签名如下

    1
    2
    torch.gather(input, dim, index, *, sparse_grad=False, out=None)
    tensor.gather(dim, index, *, sparse_grad=False, out=None)
    • 注:在 PyTorch 的函数签名中,* 是 Python 语法中的一个特殊标记,用于表示强制关键字参数(keyword-only arguments)。这意味着在 * 之后的所有参数(如 sparse_grad 和 out)必须通过关键字(即使用参数名)来传递,而不能通过位置参数传递
  • 参数解释

    • input (Tensor): 输入张量
    • dim (int): 沿着哪个维度进行收集
    • index (LongTensor): 索引张量,包含要收集的元素的索引
    • sparse_grad (bool, 可选): 如果为True,梯度将是稀疏张量
    • out (Tensor, 可选): 输出张量(若该值不为 None,则会将返回值存储到 out 引用中,此时 out 和 返回值 是同一个对象)
  • 特别说明:gather 和普通的矩阵索引操作一样,操作支持反向传播


基本原理

  • 对于 3D 张量,gather操作可以表示为:

    1
    2
    3
    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
    out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
    out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
  • 特别说明(记忆方法)

    • index[i][j][k] 用于索引 out[i][j][k]
    • dim=0 :index 指定行号(第0维) ,列号(第1维)和通道(第2维)保持不变
    • dim=1 :index 指定列号(第1维) ,行号(第0维)和通道(第2维)保持不变
    • dim=2 :index 指定通道(第2维) ,行号(第0维)和列号(第1维)保持不变
    • 特别地:输出形状 = index 形状

代码示例

  • 一维张量示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    import torch

    # 基本用法
    input_tensor = torch.tensor([10, 20, 30, 40, 50])
    index_tensor = torch.tensor([0, 2, 4, 1])
    result = torch.gather(input_tensor, dim=0, index=index_tensor)
    print(result) # tensor([10, 30, 50, 20])

    # 使用tensor.gather方法
    result2 = input_tensor.gather(0, index_tensor)
    print(result2) # tensor([10, 30, 50, 20])
  • 二维张量示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    # 二维张量
    input_2d = torch.tensor([[1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]])

    # 沿着dim=0收集(按行收集)
    index_2d = torch.tensor([[0, 1, 2],
    [2, 0, 1]])
    result_dim0 = torch.gather(input_2d, dim=0, index=index_2d)
    print("dim=0 结果:")
    print(result_dim0)
    # tensor([[1, 5, 9],
    # [7, 2, 6]])

    # 沿着dim=1收集(按列收集)
    index_2d_col = torch.tensor([[0, 2],
    [1, 0],
    [2, 1]])
    result_dim1 = torch.gather(input_2d, dim=1, index=index_2d_col)
    print("dim=1 结果:")
    print(result_dim1)
    # tensor([[1, 3],
    # [5, 4],
    # [9, 8]])

一些实际应用场景

  • 获取最大值,获取对应标签等实例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    # 场景1: 获取每行的最大值索引对应的值
    scores = torch.tensor([[0.1, 0.8, 0.3],
    [0.6, 0.2, 0.9],
    [0.4, 0.7, 0.1]])

    # 获取每行最大值的索引
    max_indices = torch.argmax(scores, dim=1, keepdim=True)
    print("最大值索引:", max_indices) # tensor([[1], [2], [1]])

    # 使用gather获取最大值
    max_values = torch.gather(scores, dim=1, index=max_indices)
    print("最大值:", max_values) # tensor([[0.8], [0.9], [0.7]])

    # 场景2: 根据标签获取对应的预测概率
    predictions = torch.tensor([[0.2, 0.3, 0.5],
    [0.1, 0.8, 0.1],
    [0.6, 0.2, 0.2]])
    labels = torch.tensor([2, 1, 0]) # 真实标签

    # 获取每个样本对应标签的预测概率
    label_probs = torch.gather(predictions, dim=1, index=labels.unsqueeze(1))
    print("标签概率:", label_probs) # tensor([[0.5], [0.8], [0.6]])

注意事项

  • 索引值必须在 [0, input.size(dim)) 范围内
  • 除了指定的dim维度外,input 和 index 的其他维度大小必须相同
  • 输出张量的形状与 index 张量相同

PyTorch——einsum函数使用


整体说明:

  • torch.einsum 是 PyTorch 中一个非常强大且灵活的函数,用于执行基于爱因斯坦求和约定(Einstein summation convention)的张量运算
    • 通过这种约定,你可以简洁地表示复杂的多维数组操作,如矩阵乘法、转置、点积等,而不需要显式地编写循环
  • torch.einsum 提供了极大的灵活性来处理各种复杂的张量运算,但需要注意的是,不恰当的使用可能导致性能下降,因为它可能会隐藏潜在的优化机会(比如矩阵乘法建议直接调用矩阵乘法)
  • torch.einsum的本质就是爱因斯坦求和约定(一种爱因斯坦发表论文中提到的表达式省略写法),是矩阵乘法的一种表示
    • 比如 C = torch.einsum("m,d->nd", A,B)表示矩阵 C[m,d] = A[n]*B[d],这是个很常用的省略写法

基本用法

  • torch.einsum 的基本语法如下:

    1
    torch.einsum(equation, *operands)
    • equation:一个字符串,指定了输入张量的下标标签以及输出张量的计算方式
    • operands:可变数量的张量参数,它们将根据 equation 进行运算
  • 在 equation 参数中:

    • 每个输入张量的维度由字母标记,不同张量之间使用逗号分隔(亲测字符串中间可以随便加空格,不影响最终结果)
      • 相同字母只能表示相同维度Size,但是相同维度Size不要求必须相同字母
      • 下标数量和输入矩阵维度数量一定要对齐
    • 输出的维度是在箭头 -> 右侧指定,如果没有指定输出,则自动推断(注意,需要自动推断的场景是没有 -> 的场景,有->的场景不需要推断,->后面为空时表示输出是一个标量)
    • 推断思路是:重复下标都去掉,不重复下边按照顺序保留
      • ij,jk == ij,jk->ik
      • j,j == j,j->
      • ijd,jk == ijd,jk->idk
      • idi,jk == idi,jk->djk
  • 举例:给定两个二维张量 A 和 B,要进行矩阵乘法并求和可以写作:

    1
    torch.einsum('ij,jk->ik', A, B) # 等价于 torch.einsum('ij,jk', A, B)
    • 'ij' 表示张量 A 的维度,
    • 'jk' 表示张量 B 的维度,
    • 'ik' 表示输出张量的维度
    • 重复的下标(在这个例子中的 j)意味着沿着这些维度进行乘积和求和(后面会详细讲解)
  • 后面会有详细讲解


爱因斯坦求和约定讲解

  • 以 看图学 AI:einsum 爱因斯坦求和约定到底是怎么回事? 中的一个例子为例,torch.einsum('ij,jk->ik', A, B) 求解过程相当于下面的图片展示的形式:

  • 爱因斯坦求和约定的基本理解:

    • 对于任意的表达式,结果都等价于一个多重循环乘积(可能包含求和)的过程:
      • 输入:函数参数,字符串 -> 左边表示矩阵的输入下标
      • 输出:返回值,字符串 -> 右边表示矩阵的输出下标
      • 右边的下标有时候可以省略,此时需要自动推断
  • C = torch.einsum("ij,jk->ijk", A,B) 结果相当于下面的函数(性能上并不想当,因为下面是一种速度较慢的函数)

    1
    2
    3
    4
    for i in range(A[0]):
    for j in range(A[1]): # 也可以用 range(B[0])
    for k in range(B[1]):
    C[i][j][k] = A[i][j] * B[j][k]
  • C = torch.einsum("ij,jk->ik", A,B) 结果相当于下面的函数

    1
    2
    3
    4
    for i in range(A[0]):
    for j in range(A[1]): # 也可以用 range(B[0])
    for k in range(B[1]):
    C[i][k] += A[i][j] * B[j][k] # C[i][k]从0开始累加
  • C = torch.einsum("ii->i", A) 结果相当于下面的函数

    1
    2
    for i in range(A[0]): # 也可用A[1]
    C[i] = A[i][i]
  • 实际上,所有的表达式都可以表达成同一个形式,只要初始化结果的各个元素为0 ,然后统一使用加法即可


附录:一些简单示例

  • 矩阵乘法 :

    1
    2
    3
    A = torch.randn(3, 4)
    B = torch.randn(4, 5)
    C = torch.einsum('ij,jk->ik', A, B) # 等价于 torch.mm(A, B),C[i][k] += A[i][j] * B[j][k]
  • 向量内积 :

    1
    2
    3
    u = torch.randn(3)
    v = torch.randn(3)
    C = torch.einsum('i,i->', u, v) # 等价于 torch.dot(u, v),C += A[i] * B[i]
  • 张量转置 :

    1
    2
    A = torch.randn(2, 3, 4)
    C = torch.einsum('ijk->kji', A) # 转置张量 permute(2,1,0),C[k][j][i] += A[i][j][k]
  • 批量矩阵乘法 :

    1
    2
    3
    A = torch.randn(3, 2, 5)
    B = torch.randn(3, 5, 4)
    C = torch.einsum('bij,bjk->bik', A, B) # 对每一批次做矩阵乘法, C[b][i][k] += A[b][i][k] * B[b][k][k]
  • 乘法+求和 :

    1
    2
    3
    A = torch.randn(3, 4)
    B = torch.randn(4, 5)
    C = torch.einsum('ij,jk->', A, B) # 输出为一个标量 C += A[i][j] * B[j][k]

附录:高阶用法之省略号

  • 省略号容易误解,不建议使用!

附录:einops 包

  • 除了爱因斯坦求和(包含在 torch 包中)外,还有许多相似的爱因斯坦操作(包含在 einops 包中)

  • einops 包中的函数支持跨框架的数据格式,不仅仅是 PyTorch,比如 NumPy,TensorFlow 等

  • 最常见的有 rearrange 函数,可用于做下面的工作:

    • 对张量做更高阶的 reshape 或 view 操作
    • 对张量做 permute 或 transpose 操作
  • 特别地,爱因斯坦操作还有些对应的网络层,如 rearrange 函数对应的 einops.layers.torch.Rearrange 类可以像 nn.Linear 或 nn.ReLU 一样加入到 nn.Sequential() 中作为一个网络模块使用

  • rearrange 函数的简单示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    import torch
    import einops

    A = torch.randn(6,2,3,4)
    B = einops.rearrange(A, 'b c d e -> b (c d e)') # 等价于 A.reshape(6, -1)
    print(B.shape)
    # 输出:torch.Size([6, 24])

    A = torch.randn(6,2,3,4)
    B = A.reshape(6, -1)
    print(B.shape)
    # 输出:torch.Size([6, 24])

    A = torch.randn(6,2,3,4)
    B = einops.rearrange(A, 'b c d e -> b e d c') # 等价于 A.permute(0, 3, 2, 1)
    print(B.shape)
    # 输出:torch.Size([6, 4, 3, 2])

    A = torch.randn(6,2,3,4)
    B = A.permute(0, 3, 2, 1)
    print(B.shape)
    # 输出:torch.Size([6, 4, 3, 2])

    # 更复杂的用法
    A = torch.randn(2, 3, 9, 8)
    B = einops.rearrange(A, 'b c (d1 d2) (e1 e2) -> b c (d1 e1) d2 e2', d1=3, e1=2)
    print(B.shape)
    # 输出:torch.Size([2, 3, 6, 3, 4])
  • reduce 函数的简单示例(可选的 reduction 参数有 'sum','max','min','mean' 等):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    import torch
    import einops

    A = torch.randn(6,2,3,4)
    B = einops.reduce(A, 'b c d e ->b c d', reduction='sum') # 等价于 A.sum(-1)
    print(B.shape)
    # 输出:torch.Size([6, 2, 3])

    A = torch.randn(6,2,3,4)
    B = einops.reduce(A, 'b c d e ->b c', reduction='sum') # 等价于 A.sum(-1).sum(-1)
    print(B.shape)
    # 输出:torch.Size([6, 2])

    A = torch.randn(6,2,3,4)
    B = einops.reduce(A, 'b c d e ->b c 1 1', reduction='mean') # 等价于 A.sum(-1).sum(-1)
    print(B.shape)
    # 输出:torch.Size([6, 2, 1, 1])

PyTorch——torch.distributed.rpc的使用


整体说明

  • torch.distributed.rpc 库是 PyTorch 用于 RPC 通信的包,初始化 RPC 环境的核心函数是 init_rpc
  • 在 PyTorch 中
    • torch.distributed.init_process_group 用于初始化进程组环境, 数据并行 和 集体通信 ,比如同步梯度
    • torch.distributed.rpc.init_rpc 则用于初始化 RPC 通信环境, 模型并行 和 远程过程调用 ,比如在另一台机器上执行一个函数
  • 在很多复杂的分布式训练场景中(如全分片数据并行+流水线并行),torch.distributed 和 torch.distributed.rpc 这两个模块甚至会同时被使用 ,以发挥它们各自的优势

回顾:torch.distributed.init_process_group 初始化进程组

  • torch.distributed.init_process_group 函数是 PyTorch Distributed (c10d) 包的初始化入口

    • 注:c10d 是 “Caffe2 TenSoR Distributed” 的缩写,c10d 代表了 PyTorch 中负责分布式通信的核心底层模块,提供了跨进程、跨节点的数据传输和同步机制;实现了诸如 ProcessGroup(进程组)、Backend(通信后端,如 NCCL、Gloo 等)等核心组件,是 torch.distributed 高层 API 的底层支撑
  • torch.distributed.init_process_group 函数会启动一个进程组,用于执行集体通信

  • torch.distributed.init_process_group 启动的进程组的通信范式是集体通信 :

    • 所有进程都必须参与同一个操作,并且需要等待所有进程都完成该操作后才能继续
    • 常见的操作包括:
      • all_reduce:所有进程的Tensor进行汇总(如求和)后,将结果返回给所有进程。(用于梯度同步)
      • broadcast:将一个进程的Tensor广播给所有其他进程
      • barrier:同步所有进程,确保大家都到达同一个点
      • scatter, gather, all_gather 等
  • 特点是训练代码在每个进程(通常是每个GPU)上几乎是完全相同的

    • 每个进程都做同样的事情:读一部分数据、前向传播、反向传播、然后通过集体通信同步梯度
    • 示例如下:
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      # 示例:在每个进程上初始化进程组
      import torch.distributed as dist

      def setup(rank, world_size):
      dist.init_process_group(
      backend='nccl', # 或 'gloo', 'mpi'
      init_method='env://',
      rank=rank,
      world_size=world_size
      )
      # 创建模型、DDP包装、定义采样器等...
      torch.cuda.set_device(rank)

      # 假设使用 spawn 启动多个进程
      # 每个进程都会运行 setup 函数和后续相同的训练循环
  • 常用于:

    • 数据并行训练:(最经典的用法)
      • 在 DistributedDataParallel (DDP) 中,每个 GPU 上的模型副本计算完梯度后,使用 all_reduce 来对所有梯度进行求和平均,确保每个模型副本都得到相同的更新后的梯度
    • 分布式数据加载:使用 DistributedSampler 确保每个进程读取数据集中不重复的部分
    • 其他需要进程间紧密同步的计算任务

前置说明:RPC 框架建立在进程组之上

  • torch.distributed.rpc.init_rpc 启动的 RPC 建立在进程组之上:

    • RPC 框架底层依赖于 torch.distributed 的进程组来进行初始的握手和协调
    • 当调用 rpc.init_rpc 时,它会在内部检查是否已经存在一个初始化好的进程组
      • 如果存在,它就复用这个进程组
      • 如果不存在,它会*隐式地调用 dist.init_process_group * ,使用默认的 gloo 后端(除非你通过 rpc_backend_options 指定)来创建一个进程组
  • 在大多数情况下,只需要调用 rpc.init_rpc 即可,它会把底层需要的进程组也初始化好

    • 如果需要对进程组的后端(如使用nccl)或初始化方法进行更精细的控制,可以选择 显式地先调用 dist.init_process_group,然后再调用rpc.init_rpc ,代码示例如下:
      1
      2
      3
      4
      5
      6
      # 显式控制进程组后端的示例
      def setup(rank, world_size):
      # 1. 首先,用 NCCL 后端初始化进程组(为了更好的GPU通信性能)
      dist.init_process_group(backend='nccl', ...)
      # 2. 然后,初始化RPC框架。此时RPC检测到已有进程组,不会再次初始化
      rpc.init_rpc(...)
  • 亲测:init_rpc 使用的 master_ip:master_port,必须和 init_process_group 使用的一样,否则无法启动

    • 暂未看到官方明确说明(待补充),猜测原因是因为 RPC 框架依赖已经初始化好的进程组(若没有初始化进程组,init_rpc 会自己新建进程组)
    • 详细示例见附录
  • RPC 的 world_size 可以比 已初始化的进程组的多(虽然有依赖,但两者是相对独立的体系)【具体深层逻辑待梳理】

    • 即 init_rpc 的 world_size 可以比 init_process_group 的大
    • 具体实现时,部分 rank 进程不调用 init_process_group 即可实现

torch.distributed.rpc.init_rpc 初始化 PyTorch RPC 框架

  • torch.distributed.rpc.init_rpc 函数是 PyTorch RPC 框架 的初始化入口

  • torch.distributed.rpc.init_rpc 函数核心目的是启用远程过程调用 ,允许一个进程调用另一个进程(可能在不同的机器上)上定义的函数或方法

  • torch.distributed.rpc.init_rpc 初始化的 PyTorch RPC 框架 通信范式是点对点通信

    • 一个进程(工作者)可以主动请求另一个进程执行某个任务,而不需要所有进程都参与
    • 核心 API 包括:
      • rpc_sync():同步调用,调用方会等待远程函数执行完毕并返回结果
      • rpc_async():异步调用,调用方立即返回一个 Future 对象,未来再从中获取结果
      • remote():异步地在远程工作者上创建一个对象(例如一个模块),并返回一个指向该远程对象的引用(RRef)
  • 常用于:

    • 模型并行:将一个大模型的不同部分放在不同的机器或 GPU 上
      • 前向传播时,数据需要从一个机器流到下一个机器
      • RPC 可以管理这种跨机器的计算流
    • 参数服务器(PS):
      • 一个或多个进程作为参数服务器存储模型参数,其他工作者通过 RPC 向参数服务器拉取参数或推送梯度
    • 强化学习:
      • 多个环境模拟器(Actor)通过RPC将经验数据发送给一个中心 learner 进行训练
    • 分布式流水线并行(待补充):
      • 与 torch.distributed.pipeline.sync.Pipe 结合使用,RPC 负责处理不同设备间层的通信
  • 特点是每个进程的角色和执行的代码可能完全不同 ,一个简单的代码示例如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    # start shell: torchrun --nproc_per_node=2 torch_distributed_rpc_demo.py
    import os
    import torch
    import torch.distributed as dist
    import torch.distributed.rpc as rpc

    # 正常定义函数,可用于远程调用(在 worker0 发送命令从 worker1 执行函数)
    def worker_function(a, b):
    worker_info = rpc.get_worker_info()
    print(f"Worker {worker_info.name} received a call with a={a}, b={b}")
    return a + b

    def main():
    # 获取环境变量
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # 1. 初始化进程组和 RPC
    dist.init_process_group(backend='gloo', rank=rank, world_size=world_size) # 建议提前初始化好进程组,如没有这行,下面的函数也会自动初始化一个 进程组
    rpc.init_rpc(
    name=f"worker_{rank}", # 若不唯一,会报错 RuntimeError: RPC worker name worker is not unique. Workers 1 and 0 share the same name
    rank=rank,
    world_size=world_size
    )

    print(f"Worker {rank} is ready.")

    if rank == 0:
    # 2. worker_0 发起 RPC 调用
    print("--- Worker 0 sending RPC calls ---")

    # 同步调用
    result_sync = rpc.rpc_sync(
    to=f"worker_1", # 真正的执行过程发生在 worker1 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 0 received sync result from Worker 1: {result_sync}")

    # 异步调用
    future_async = rpc.rpc_async(
    to=f"worker_1",
    func=worker_function,
    args=(30, 40)
    )
    print("Worker 0 is doing other work while waiting for async result...")
    result_async = future_async.wait()
    print(f"Worker 0 received async result from Worker 1: {result_async}")

    else:
    # 3. worker_1 只是等待并处理 RPC 请求
    print("--- Worker 1 is waiting for RPC calls ---")
    # 由于 worker_0 正在调用,这里不需要做任何事,RPC 框架会自动处理

    rpc.shutdown() # 默认等待所有 RPC 调用完成后再关闭
    dist.destroy_process_group() # 注:若没有显示调用 dist.init_process_group,也不能调用这一行,否则会报错

    if __name__ == "__main__":
    main()
    • 上述代码的启动脚本(单机双卡):torchrun --nproc_per_node=2 torch_distributed_rpc_demo.py

rpc.init_rpc 函数 详细用法

  • PyTorch 的 torch.distributed.rpc.init_rpc 函数原型

    1
    2
    3
    4
    5
    6
    7
    torch.distributed.rpc.init_rpc(
    name,
    backend=None,
    rank=-1,
    world_size=None,
    rpc_backend_options=None
    )
  • 参数说明如下:

  • name(必填)

    • 当前工作进程的全局唯一名称 ,不能重复,重复会报错,后续会作为交互对象的指定
    • 名称只能包含数字、字母、下划线、冒号和/或短划线,且必须少于128个字符
    • 例如 "trainer0", "parameter_server1"
  • backend(选填)

    • 指定使用的RPC 后端实现
    • 可选值来自 torch.distributed.rpc.BackendType 枚举
      • BackendType.TENSORPIPE,更现代且推荐的 RPC 后端
      • BackendType.PROCESS_GROUP,不常用,对应的是 PyTorch 的后端,如 NCCL 和 Gloo
  • rank(必填)

    • 当前进程在组中的全局唯一 ID
    • 通常是一个整数,例如 0, 1, 2…
    • 必须与 world_size 一起正确设置
  • world_size(必填)

    • 参与作业的进程总数
    • 所有进程的 world_size 必须相同
  • rpc_backend_options(选填)

    • 传递给底层 RPC 代理的高级选项 ,类型必须与 backend 参数选定的值严格对齐,否则会报错
      • backend=BackendType.TENSORPIPE 对应 rpc_backend_options 类别为:TensorPipeRpcBackendOptions
      • backend=BackendType.PROCESS_GROUP 对应 rpc_backend_options 类别为:ProcessGroupRpcBackendOptions
    • 可常用于设置 RPC 超时、初始化方法等
      • 若未指定,默认超时为60秒,并使用 init_method="env://"(这意味着需要设置 MASTER_ADDR 和 MASTER_PORT 环境变量)
  • 使用示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import TensorPipeRpcBackendOptions

    # 配置 TensorPipe 后端选项
    options = TensorPipeRpcBackendOptions(
    num_worker_threads=8, # 8 个工作线程,指定处理 RPC 请求的工作线程数,影响并发处理能力,默认值是 CPU 数量
    rpc_timeout=10000, # RPC 超时时间设置为 10 秒
    init_method="tcp://127.0.0.1:29500", # 初始化地址,格式和 init_process_group 的 init_method 参数一致
    device_maps={ "worker1": {0: 1}, "worker2": {1: 0} } # 本地 GPU 0 映射到 worker1 的 GPU 1;本地 GPU 1 映射到 worker2 的 GPU 0
    # devices=?, # 默认使用所有可用设备
    security_token="my_rpc_token" # str 类型安全令牌,用于验证节点身份的安全令牌,多机环境下防止未授权节点接入
    )

    # 初始化 RPC,传入配置
    rpc.init_rpc(
    name="worker0", # 当前节点名称
    backend=rpc.BackendType.TENSORPIPE,
    rpc_backend_options=options,
    rank=0, # 全局 rank
    world_size=2 # 总节点数
    )

TensorPipeRpcBackendOptions 的初始化细节

  • TensorPipeRpcBackendOptions 的函数原型为:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    def __init__(
    self,
    *, # Python 中可变位置参数的用法,仅使用一个 * 表示后续的参数必须按照 关键字参数形式调用
    num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS,
    rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC,
    init_method: str = rpc_contants.DEFAULT_INIT_METHOD,
    device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
    devices: Optional[List[DeviceType]] = None,
    _transports: Optional[List] = None,
    _channels: Optional[List] = None,
    )
  • 除了前文示例指定的参数外,其他参数介绍如下:

  • devices 参数

    • 指定当前节点上可用于 RPC 通信的设备(主要是 GPU),限制 RPC 操作只能使用列表中声明的设备
    • 默认值是 None(表示允许使用所有可见设备)
    • 当设置为非空列表时(如 [0, 1] 或 [torch.device('cuda:0'), torch.device('cuda:1')]),TensorPipe 只会使用列表中的设备进行数据传输,忽略其他设备
    • 主要用于多 GPU 场景下的资源隔离,例如限制 RPC 通信仅使用特定 GPU,避免占用训练用 GPU 资源
  • _transports 参数(以下划线开头,是私有参数,谨慎使用)

    • 指定 TensorPipe 可使用的传输协议(底层物理通信方式),用于跨节点或跨进程的数据传输
    • (推荐)默认值是自动检测并选择最优传输方式(通常为 ["tcp", "ibv"],即 TCP 和 InfiniBand)
    • 支持的传输协议:
      • "tcp":基于 TCP/IP 的网络传输,兼容性好,适用于普通网络环境
      • "ibv":基于 InfiniBand 的传输,低延迟、高带宽,适用于高性能计算集群
      • "shm":表示优先使用共享内存作为传输介质
      • "uv":基于 libuv 的本地传输(用于同一节点内的进程通信,通常自动启用)
    • 列表顺序表示优先级,TensorPipe 会优先尝试前面的传输方式
    • 仅在需要强制指定传输协议时使用(如集群仅支持 InfiniBand 时设置 ["ibv"]),默认自动选择即可满足多数场景
  • _channels 参数(以下划线开头,是私有参数,谨慎使用)

    • 指定 TensorPipe 可使用的通道类型(数据在设备间的内存传输方式),用于同一节点内的设备间通信(如 CPU-GPU、GPU-GPU)
    • (推荐)默认是自动检测并选择最优通道(通常为 ["cuda_ipc", "cuda_basic", "shm"])
    • 支持的通道类型有:
      • "cuda_ipc":GPU 进程间通信(同一节点内不同进程的 GPU 直接通信,效率最高)
      • "cuda_basic":通过 CPU 中转的 GPU 通信(当 cuda_ipc 不可用时降级使用)
      • "shm":共享内存通信(同一节点内的 CPU 进程间通信),这里强调通信通道,_transports 的 "shm" 强调通信介质
      • "basic":通过系统内存拷贝的通信(通用 fallback 方式)
    • 列表顺序表示优先级,优先使用高效的通道(如 cuda_ipc 优于 cuda_basic)
    • 主要用于调试或特殊硬件环境下的兼容性调整,默认配置已针对性能优化

初始化常见错误排查

  • 初始化失败:检查 MASTER_ADDR 和 MASTER_PORT 是否设置正确且所有机器可达,防火墙是否放行了指定端口,以及 rank 和 world_size 是否配置正确
  • 超时错误:增加 rpc_backend_options 中的 rpc_timeout 值
  • 版本不匹配:确保所有参与计算的 PyTorch 版本一致

RPC 框架的交互模式

  • PyTorch 的 torch.distributed.rpc 模块为 分布式模型训练 提供了核心的远程过程调用(RPC)和远程引用(RRef)机制
  • 以下是 RPC 相关核心方法:
  • rpc_sync()(同步方法,阻塞)
    • 在远程 worker 上同步调用函数
    • 返回函数调用的直接结果
  • rpc_async()(异步方法,非阻塞)
    • 在远程 worker 上异步调用函数
    • 返回一个 Future 对象,可通过 future.wait() 获取结果
  • remote()(非阻塞)
    • 异步在远程 worker 上创建对象
    • 返回一个 RRef (远程引用) 指向创建的对象
  • RRef.to_here()(阻塞)
    • 将 RRef 引用的值从所有者复制到本地节点
    • 返回引用的对象本身

同步远程调用 (rpc_sync)

  • rpc_sync 会阻塞调用者,直到远程函数执行完成并返回结果

  • 函数原型:

    1
    2
    3
    4
    5
    6
    7
    8
    # 在目标 worker (例如 to="worker1") 上同步运行函数 func 并获取结果
    result = torch.distributed.rpc.rpc_sync(
    to, # (str or WorkerInfo or int) 目标 worker 的名称、rank 或 WorkerInfo
    func, # (Callable) 待执行函数,是一个可调用函数(如 Python 函数、内置运算符或 TorchScript 函数)
    args=None, # (tuple) 传给函数的可变位置参数
    kwargs=None, # (dict) 传给函数的可变关键字参数
    timeout=-1.0 # (float, optional) RPC 超时时间(秒),-1.0 表示使用默认值(timeout=UNSET_RPC_TIMEOUT, 此时走初始化逻辑,或_set_rpc_timeout 设置的值),0 表示无限超时
    )
  • 使用示例:

    1
    2
    3
    # 在 worker 0 上:
    ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
    # 在 worker 1 上需要定义相应的函数或直接使用内置函数, worker 1 执行函数体,并返回结果

异步远程调用 (rpc_async)

  • rpc_async 会立即返回一个 Future 对象,你可以在之后需要时等待并获取结果,适用于非阻塞调用

  • 函数原型:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    # 异步调用远程函数,返回 Future 对象
    future = torch.distributed.rpc.rpc_async(
    to, # (str or WorkerInfo or int) 目标 worker 的名称、rank 或 WorkerInfo
    func, # (Callable) 待执行函数,是一个可调用函数(如 Python 函数、内置运算符或 TorchScript 函数)
    args=None, # (tuple) 传给函数的可变位置参数
    kwargs=None, # (dict) 传给函数的可变关键字参数
    timeout=-1.0 # (float, optional) RPC 超时时间(秒),-1.0 表示使用默认值(timeout=UNSET_RPC_TIMEOUT, 此时走初始化逻辑,或_set_rpc_timeout 设置的值),0 表示无限超时
    # 注:_set_rpc_timeout 是 PyTorch 分布式 RPC 框架中用于动态修改全局默认超时时间的函数,通常在 init_rpc(RPC 环境初始化)之后调用
    # 注:异步操作不会阻塞,这里的 timeout 是从此刻开始,远程操作本身返回的时间,如果有错也是在 future.wait() 时抛出
    )
    # ... 执行其他操作 ...
    result = future.wait() # 等待并获取结果,这里也可以设置 future 从此刻开始的等待 timeout, 设置后与 rpc_async 中的 timeout 时间两者时间取小(注意两者起始时间不同)
  • 使用示例:

    1
    2
    3
    fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
    fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
    result = fut1.wait() + fut2.wait()

远程对象创建与引用 (remote 和 RRef)

  • remote 方法用于在远程 worker 上异步创建对象 ,并立即返回一个指向该对象的 RRef (Remote Reference)

  • RRef 的 to_here() 方法可以阻塞地将远程对象的值复制到本地

  • remote 函数原型

    1
    2
    3
    4
    5
    6
    7
    8
    9
    # 在远程 worker 上创建对象并返回 RRef,持有 func 函数的返回值
    rref = torch.distributed.rpc.remote(
    to, # (str or WorkerInfo or int) 目标 worker 的名称、rank 或 WorkerInfo,目标 worker(将成为所有者)
    func, # (Callable) 用于创建对象的函数,是一个可调用函数(如 Python 函数、内置运算符或 TorchScript 函数)
    args=None, # (tuple) 传给函数的可变位置参数
    kwargs=None, # (dict) 传给函数的可变关键字参数
    timeout=-1.0 # (float, optional) RPC 超时时间(秒),-1.0 表示使用默认值(timeout=UNSET_RPC_TIMEOUT, 此时走初始化逻辑,或_set_rpc_timeout 设置的值),0 表示无限超时
    # 注:异步操作不会阻塞,这里的 timeout 是从此刻开始,远程操作本身返回的时间,如果有错也是在 rref.to_here() 时抛出
    )
  • RRef 对象的一些使用方式:

    • 下面的两种调用相同
      1
      2
      3
      4
      5
      6
      7
      8
      9
      rref = rpc.remote("worker1", MyClass, args=(arg1, arg2))

      # 方式一:直接调用远程对象的函数(func_name 必须是 rref 持有对象的函数)
      rref.rpc_async().func_name(*args, **kwargs)

      # 方式二:原生的 rpc 实现过程,较为复杂,效果和方式一完全相同
      def run(rref, func_name, args, kwargs):
      return getattr(rref.local_value(), func_name)(*args, **kwargs)
      rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • RRef 的更多使用详情见附录

shutdown 函数的使用

  • 在所有 RPC 工作完成后、进程退出前,必须显式调用 torch.distributed.rpc.shutdown ();
  • 推荐使用 graceful=True(默认) 以等待所有未完成 RPC 结束并避免死锁与 SIGABRT 风险
  • 函数原型
    1
    2
    3
    4
    5
    6
    7
    # torch.distributed.rpc.shutdown
    def shutdown(
    graceful=True, # (bool),默认值为 True,表示是否优雅的结束
    # 如果 True,则 1) 等待所有针对 UserRRef 的系统消息处理完毕并删除这些引用;2) 阻塞直到本地与远程所有 RPC 进程均调用此方法,并等待所有未完成工作完成;
    # 当 graceful=False 时:不等待未完成的 RPC 任务和其他进程,直接强制关闭本地 RPC 系统。可能导致未完成任务失败或资源泄漏,仅建议在紧急退出场景使用
    timeout=DEFAULT_SHUTDOWN_TIMEOUT # 限制 “优雅等待” 的最大时长:
    )

重要注意事项

  • 较早的 PyTorch 版本(如 1.4 之前)API 可能不能使用
  • 分布式作业中的每个进程都必须成功调用 init_rpc,并在最后调用 rpc.shutdown() 来优雅地释放资源
  • RPC 框架在处理张量时,通常视为张量位于 CPU 上进行操作,以避免设备不匹配的错误
    • 如果需要在 GPU 上操作,可能需要在函数内部显式地将张量移动到 GPU
  • 网络问题、函数执行失败或超时都可能导致 RPC 调用失败,建议添加适当的异常处理
  • RPC 调用会有通信开销
    • 应尽量减少小规模的频繁调用,考虑批量操作或传输更大尺寸的数据以减少开销

附录:init_rpc 的初始化 IP 和端口要与 init_process_group 一致

  • 亲测:init_rpc 使用的 master_ip:master_port,必须和 init_process_group 使用的一样,否则无法启动
    • 暂未看到官方明确说明(待补充),猜测原因是因为 RPC 框架依赖已经初始化好的进程组(若没有初始化进程组,init_rpc 会自己新建进程组)
  • 示例如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    # start shell: torchrun --nproc_per_node=2 torch_distributed_rpc_demo.py
    import os
    import torch.distributed as dist
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import TensorPipeRpcBackendOptions

    def worker_function(a, b):
    worker_info = rpc.get_worker_info()
    print(f"Worker {worker_info.name} received a call with a={a}, b={b}")
    return a + b

    def main():
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    master_addr = os.environ.get('MASTER_ADDR', '未设置')
    master_port = int(os.environ.get('MASTER_PORT', '未设置'))
    print(f"master addr: {master_addr}:{master_port}")

    dist.init_process_group(backend='gloo', rank=rank, world_size=2,init_method=f"tcp://{master_addr}:{master_port}") # 提前初始化进程组

    rpc_backend_options = TensorPipeRpcBackendOptions(
    init_method=f"tcp://{master_addr}:{master_port}", # 与进程组初始化(init_process_group)时相同 IP和端口,可以正常初始化,端口不一致时会一直卡在这里不动
    num_worker_threads=2,
    rpc_timeout=10
    )

    # rpc_backend_options = None
    print(f"rank {rank} start init_rpc..., dist.is_initialized={dist.is_initialized()}")
    rpc.init_rpc(
    name=f"worker_{rank}",
    rank=rank,
    world_size=world_size,
    rpc_backend_options=rpc_backend_options
    )
    print(f"rank {rank} have done init_rpc")

    print(f"Worker {rank} is ready.")

    if rank == 0:
    result_sync = rpc.rpc_sync(
    to=f"worker_1", # 真正的执行过程发生在 worker1 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 0 received sync result from Worker 1: {result_sync}")
    else:
    print("--- Worker 1 is waiting for RPC calls ---")

    rpc.shutdown() # 等待所有 RPC 完成后自动关闭
    dist.destroy_process_group() # 注:若没有显示调用 dist.init_process_group,也不用这一行

    if __name__ == "__main__":
    main()

附录:RPC 的 world_size 可以比 已初始化的进程组的多

  • RPC 的 world_size 可以比 已初始化的进程组的多(虽然有依赖,但两者是相对独立的体系)【具体深层逻辑待梳理】
    • 即 init_rpc 的 world_size 可以比 init_process_group 的大
    • 具体实现时,部分 rank 进程不调用 init_process_group 即可实现
  • 示例如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    # start shell: torchrun --nproc_per_node=4 torch_distributed_rpc_demo.py
    import os
    import torch.distributed as dist
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import TensorPipeRpcBackendOptions

    def worker_function(a, b):
    worker_info = rpc.get_worker_info()
    print(f"Worker {worker_info.name} received a call with a={a}, b={b}")
    return a + b

    def main():
    # 获取环境变量
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    master_addr = os.environ.get('MASTER_ADDR', '未设置')
    master_port = int(os.environ.get('MASTER_PORT', '未设置'))
    print(f"master addr: {master_addr}:{master_port}")

    if rank < 2: # 仅 rank=0,1 的节点初始化进程组
    print(f"rank {rank} start init_process_group...")
    dist.init_process_group(
    backend='gloo',
    rank=rank,
    world_size=2, # world_size 设置为2,甚至仅初始化一个也可以,rank<2 改成 rank<1 即可
    init_method=f"tcp://{master_addr}:{master_port}"
    ) # 提前初始化好进程组
    print(f"rank {rank} have done init_process_group")
    print(dist)

    rpc_backend_options = TensorPipeRpcBackendOptions(
    init_method=f"tcp://{master_addr}:{master_port}", # 与进程组初始化(init_process_group)时相同 IP和端口,可以正常初始化,端口不一致时会一直卡在这里不动
    num_worker_threads=2,
    rpc_timeout=10
    )

    print(f"rank {rank} start init_rpc..., dist.is_initialized={dist.is_initialized()}")
    rpc.init_rpc(
    name=f"worker_{rank}",
    rank=rank,
    world_size=world_size, # world_size 由传入的确定,实际上可能是 4
    rpc_backend_options=rpc_backend_options
    )
    print(f"rank {rank} have done init_rpc")

    print(f"Worker {rank} is ready.")

    if rank == 0:
    print("--- Worker 0 sending RPC calls ---")

    # 同步调用
    result_sync = rpc.rpc_sync(
    to=f"worker_1", # 真正的执行过程发生在 worker1 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 0 received sync result from Worker 1: {result_sync}")

    # 异步调用
    future_async = rpc.rpc_async(
    to=f"worker_1",
    func=worker_function,
    args=(30, 40)
    )
    print("Worker 0 is doing other work while waiting for async result...")
    result_async = future_async.wait()
    print(f"Worker 0 received async result from Worker 1: {result_async}")

    elif rank == 3:
    print("--- Worker 3 sending RPC calls ---")

    # 同步调用
    result_sync = rpc.rpc_sync(
    to=f"worker_2", # 真正的执行过程发生在 worker2 上
    func=worker_function,
    args=(10, 20)
    )
    print(f"Worker 3 received sync result from Worker 2: {result_sync}")

    else:
    print("--- Worker 1 is waiting for RPC calls ---")

    rpc.shutdown() # 等待所有 RPC 完成后自动关闭
    if dist.is_initialized(): # 必须判断,否则可能报错,因为部分进程为初始化 init_process_group
    dist.destroy_process_group() # 注:若没有显示调用 dist.init_process_group,不用这一行

    if __name__ == "__main__":
    main()

附录:RRef 的使用方式细节

  • 参考自:PyTorch 官方文档-RRef
  • Warning:当前使用 CUDA 张量时,暂不支持 RRef,需要把张量从 GPU 转移到 CPU 才可以
  • RRef(Remote REFerence,远程引用)是指向远程工作节点上某一类型(如张量 Tensor)值的引用
    • 该句柄会确保被引用的远程值在其所属节点上保持有效
    • 在多机训练场景中,RRef 可用于持有对其他工作节点上 nn.Module(神经网络模块)的引用,并在训练过程中调用相应函数来获取或修改这些模块的参数
    • 更多细节可参考:Remote Reference Protocol
  • 注:RRef 也是一个 Python 类(torch.distributed.rpc.RRef),但有时泛指所有远程引用,而不是某一个类

torch.distributed.rpc.PyRRef 类型总体说明

  • 类型定义原型:

    1
    2
    3
    4
    5
    6
    7
    # __pybind11_builtins.pybind11_object 表明 PyRRef 是一个通过 pybind11 技术从 C++ 创建的 Python 类
    class PyRRef(__pybind11_builtins.pybind11_object):
    # ...

    # RRef 的定义
    class RRef(PyRRef, Generic[T]):
    pass
  • PyRRef 类用于封装对远程工作节点上某一类型值的引用

    • 此句柄会确保被引用的远程值在对应工作节点上保持有效
    • 当满足以下任一条件时,UserRRef(用户侧远程引用)将被销毁:
      • 1)在应用代码和本地 RRef 上下文环境中,均不存在对该 UserRRef 的引用;
      • 2)应用已调用 graceful shutdown 操作
    • 调用已销毁 RRef 的方法会导致未定义行为
      • RRef 实现仅提供“尽力而为”的错误检测机制(best-effort error detection),因此在调用 rpc.shutdown()(RPC 关闭函数)后,应用不应再使用 UserRRef
  • 警告:RRef 只能由 RPC 模块进行序列化和反序列化操作

    • 若不通过 RPC 模块(如使用 Python pickle 模块、PyTorch 的 save()/load() 函数、JIT 的 save()/load() 函数等)对 RRef 进行序列化或反序列化,将会引发错误

PyRRef 类型初始化参数(不常用,一般不自己定义 RRef)

  • value(object 类型):需用此 RRef 封装的值
  • type_hint(Type 类型,可选):传递给 TorchScript 编译器的 Python 类型提示,用于指定 value 的类型

PyRRef 类型使用示例

  • 以下示例为简化说明,省略了 RPC 初始化和关闭相关代码。有关这些操作的详细内容,请参考 RPC 文档

  • 使用 rpc.remote() 创建 RRef

    1
    2
    3
    4
    5
    6
    7
    import torch
    import torch.distributed.rpc as rpc

    # 在远程工作节点“worker1”上执行 torch.add 操作,并创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) # Debug 时看到的可能是 torch._C._distributed_rpc.PyRRef 类型,实际和 torch.distributed.rpc.PyRRef 是同一个类的不同路径
    # 从 RRef 中获取值的副本
    x = rref.to_here()
  • 基于本地对象创建 RRef

    1
    2
    3
    4
    5
    6
    7
    import torch
    from torch.distributed.rpc import RRef

    # 创建本地张量
    x = torch.zeros(2, 2)
    # 用 RRef 封装本地张量,生成 RRef 对象
    rref = RRef(x)
  • 与其他工作节点共享 RRef

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    # 在工作节点“worker0”和“worker1”上均执行以下函数定义
    def f(rref):
    # 获取 RRef 指向的值并加 1
    return rref.to_here() + 1

    # 在工作节点“worker0”上执行以下代码
    import torch
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import RRef

    # 基于本地零矩阵创建 RRef
    rref = RRef(torch.zeros(2, 2))
    # 通过 RPC 调用,将 RRef 共享给工作节点“worker1”,引用计数会自动更新
    rpc.rpc_sync("worker1", f, args=(rref,))

PyRRef.backward 方法

  • backward 方法原型

    1
    backward(self: torch._C._distributed_rpc.PyRRef, dist_autograd_ctx_id: int = -1, retain_graph: bool = False) -> None
  • backward 方法以 RRef 作为反向传播的根节点,执行反向传播过程

  • 若提供了 dist_autograd_ctx_id(分布式自动求导上下文 ID),则会使用该上下文 ID,从 RRef 的所属节点开始执行分布式反向传播

    • 这种情况下,应使用 get_gradients() 函数获取梯度
  • 若 dist_autograd_ctx_id 为 None,则默认当前为本地自动求导图,仅执行本地反向传播

    • 在本地反向传播场景中,调用此 API 的节点必须是该 RRef 的所属节点,且 RRef 所指向的值需为标量张量(scalar Tensor)
  • PyRRef.backward 方法参数

    • dist_autograd_ctx_id(int 类型,可选):用于获取梯度的分布式自动求导上下文 ID,默认值为 -1
    • retain_graph(bool 类型,可选):若设为 False,则计算梯度所用的计算图会被释放
      • 几乎不需要设为 True
      • 一般仅在需要多次执行反向传播时,才需将其设为 True,默认值为 False
  • PyRRef.backward 方法示例

    1
    2
    3
    4
    5
    6
    import torch.distributed.autograd as dist_autograd

    # 创建分布式自动求导上下文,并获取上下文 ID
    with dist_autograd.context() as context_id:
    # 以当前 RRef 为根节点,执行分布式反向传播
    rref.backward(context_id)

PyRRef 一些简单方法说明

  • confirmed_by_owner 方法:

    1
    confirmed_by_owner(self: torch._C._distributed_rpc.PyRRef)  bool
    • 返回该 RRef 是否已得到其所属节点的确认(真正拥有对象的节点为所属节点,对象时存储到所属节点上的)
      • OwnerRRef(所属节点侧远程引用)始终返回 True;
      • UserRRef(用户侧远程引用)仅在所属节点知晓该 UserRRef 存在时,才返回 True
  • is_owner 方法:

    1
    is_owner(self: torch._C._distributed_rpc.PyRRef) -> bool
    • 返回当前节点是否为该 RRef 的所属节点(对象时存储到所属节点上的)
  • local_value 方法:

    1
    local_value(self: torch._C._distributed_rpc.PyRRef) -> object
    • 若当前节点是该 RRef 的所属节点 ,则返回对本地值的引用;否则,抛出异常
  • owner 方法:

    1
    owner(self: torch._C._distributed_rpc.PyRRef) -> torch._C._distributed_rpc.WorkerInfo
    • 返回该 RRef 所属节点的工作节点信息(WorkerInfo 对象)
  • owner_name 方法:

    1
    owner_name(self: torch._C._distributed_rpc.PyRRef) -> str
    • 返回该 RRef 所属节点的工作节点名称(字符串形式)

PyRRef.remote 方法

  • remote 方法原型:

    1
    remote(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • remote 方法创建一个辅助代理(helper proxy)

    • 该代理可以用于发起 remote 调用,以 RRef 的所属节点为目标节点,在该 RRef 所指向的对象上执行函数
  • 更具体地说,rref.remote().func_name(*args, **kwargs) 等价于以下代码:

    1
    2
    3
    4
    5
    6
    def run(rref, func_name, args, kwargs):
    # 获取 RRef 本地值,并调用其指定方法
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

    # 向 RRef 所属节点发起 remote 调用,执行 run 函数
    rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • remote 方法参数:

    • timeout(float 类型,可选):rref.remote() 的超时时间(单位:秒)
      • 若在该时间内未成功创建 RRef,则下次尝试使用该 RRef(如调用 to_here() 方法)时,会引发超时错误
      • 若未提供该参数,则使用默认的 RPC 超时时间
  • remote 方法示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    from torch.distributed import rpc

    # 在远程节点“worker1”上执行 torch.add,创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
    # 调用 RRef 指向对象的 size() 方法,获取结果(远程执行,本地获取)
    rref.remote().size().to_here() # 返回结果:torch.Size([2, 2])
    # 注:返回类型
    # rref.remote(): <torch.distributed.rpc.rref_proxy.RRefProxy object at 0x11fa6e470>
    # rref.remote().size(): UserRRef(RRefId = GloballyUniqueId(created_on=0, local_id=4), ForkId = GloballyUniqueId(created_on=0, local_id=5))
    # rref.remote().size().to_here(): torch.Size([2])

    # 调用 RRef 指向对象的 view() 方法,重塑张量形状并获取结果
    rref.remote().view(1, 4).to_here() # 返回结果:tensor([[1., 1., 1., 1.]])

PyRRef.rpc_async 方法

  • rpc_async 方法原型:

    1
    rpc_async(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • rpc_async 方法创建一个辅助代理

    • 该代理可发起 rpc_async 调用,以 RRef 的所属节点为目标节点,在该 RRef 所指向的对象上执行函数
  • 更具体地说,rref.rpc_async().func_name(*args, **kwargs) 等价于以下代码:

    1
    2
    3
    4
    5
    6
    def run(rref, func_name, args, kwargs):
    # 获取 RRef 本地值,并调用其指定方法
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

    # 向 RRef 所属节点发起异步 RPC 调用,执行 run 函数
    rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • rpc_async 方法参数

    • timeout(float 类型,可选):rref.rpc_async() 的超时时间(单位:秒)
      • 若在该时间内调用未完成,则会引发超时异常
      • 若未提供该参数,则使用默认的 RPC 超时时间
  • rpc_async 方法示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    from torch.distributed import rpc

    # 在远程节点“worker1”上执行 torch.add,创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
    # 异步调用 size() 方法,等待结果返回
    rref.rpc_async().size().wait() # 返回结果:torch.Size([2, 2])
    # 注:返回类型
    # rref.rpc_async(): <torch.distributed.rpc.rref_proxy.RRefProxy object at 0x11bcfa3e0>
    # rref.rpc_async().size(): <torch.jit.Future object at 0x11bcc65c0>
    # rref.rpc_async().size().wait(): torch.Size([2])

    # 异步调用 view() 方法,重塑张量形状并等待结果返回
    rref.rpc_async().view(1, 4).wait() # 返回结果:tensor([[1., 1., 1., 1.]])

PyRRef.rpc_sync 方法

  • rpc_sync 方法原型:

    1
    rpc_sync(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • rpc_sync 方法创建一个辅助代理

    • 该代理可发起 rpc_sync 调用,以 RRef 的所属节点为目标节点,在该 RRef 所指向的对象上执行函数
  • 更具体地说,rref.rpc_sync().func_name(*args, **kwargs) 等价于以下代码:

    1
    2
    3
    4
    5
    6
    def run(rref, func_name, args, kwargs):
    # 获取 RRef 本地值,并调用其指定方法
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

    # 向 RRef 所属节点发起同步 RPC 调用,执行 run 函数
    rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))
  • rpc_sync 方法参数:

    • timeout(float 类型,可选):rref.rpc_sync() 的超时时间(单位:秒)
      • 若在该时间内调用未完成,则会引发超时异常
      • 若未提供该参数,则使用默认的 RPC 超时时间
  • rpc_sync 方法示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    from torch.distributed import rpc

    # 在远程节点“worker1”上执行 torch.add,创建指向结果的 RRef
    rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
    # 同步调用 size() 方法,获取张量尺寸
    rref.rpc_sync().size() # 返回结果:torch.Size([2, 2])
    # 注:返回类型:
    # rref.rpc_sync(): <torch.distributed.rpc.rref_proxy.RRefProxy object at 0x11bcfa440>
    # rref.rpc_sync().size(): torch.Size([2])

    # 同步调用 view() 方法,重塑张量形状
    rref.rpc_sync().view(1, 4) # 返回结果:tensor([[1., 1., 1., 1.]])

PyRRef.to_here 方法

  • to_here 方法原型:

    1
    to_here(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) -> object
  • to_here 方法是阻塞式调用,将 RRef 指向的值从其所属节点复制到本地节点,并返回该值

    • 若当前节点是该 RRef 的所属节点,则直接返回对本地值的引用
  • to_here 方法参数:

    • timeout(float 类型,可选):to_here 方法的超时时间(单位:秒)
      • 若在该时间内调用未完成,则会引发超时异常
      • 若未提供该参数,则使用默认的 RPC 超时时间(60 秒)

附录:torch.distributed.rpc.PyRRef 和 torch._C._distributed_rpc.PyRRef 的区别

  • 这两个是同一个类的不同引用路径,但有重要的使用区别,两者本质关系为

    1
    2
    3
    4
    # 推荐使用的方式;稳定,版本间兼容性好;可能包含额外的封装、错误检查、文档
    torch.distributed.rpc.PyRRef # 公共API接口,官方支持,文档完整
    # 不推荐直接使用;内部实现细节;可能在PyTorch版本更新时发生变化;直接的C++绑定,可能缺少一些Python层的便利功能
    torch._C._distributed_rpc.PyRRef # 底层C++实现,以`_`开头表示私有
  • 实际上 torch.distributed.rpc.PyRRef 就是对 torch._C._distributed_rpc.PyRRef 的封装或直接引用

  • 验证两者的等价关系

    1
    2
    3
    4
    5
    import torch.distributed.rpc as rpc
    import torch._C._distributed_rpc as _rpc

    # 通常这两个是相同的对象
    print(rpc.PyRRef is _rpc.PyRRef) # 可能输出 True

为什么会看到内部 API torch._C._distributed_rpc.PyRRef?

  • 调试时的堆栈跟踪可能显示内部路径
  • 某些高级用法或源码分析
  • 类型检查时可能遇到

PyTorch——torch.no_grad的用法


整体说明

  • 在 PyTorch 中,torch.no_grad()可用作装饰器 @torch.no_grad() 或上下文管理器 with torch.no_grad()(两者形式不同,但作用相同),用于禁用梯度计算
  • 如果 PyTorch 版本 >= 1.9,可以考虑使用 torch.inference_mode() 来替代 torch.no_grad(),以获得更好的性能

torch.no_grad()的作用

  • torch.no_grad() 的主要作用是临时关闭自动求导机制(autograd)。在被装饰的函数或代码块中,所有涉及张量的操作都不会构建计算图(computation graph),从而节省内存和计算资源:
    • 自动求导机制 :PyTorch 默认会记录张量操作的历史信息(即计算图),以便支持反向传播(backward())来计算梯度
    • 关闭梯度计算 :在推理阶段或其他不需要梯度的场景下,关闭自动求导可以减少内存占用,提高运行效率

使用场景

模型推理(Inference)

  • 在推理阶段,我们只需要前向传播(forward pass),而不需要计算梯度。因此,可以使用 @torch.no_grad() 来优化性能
    1
    2
    3
    4
    5
    6
    7
    8
    9
    @torch.no_grad()
    def evaluate_model(model, test_loader):
    model.eval() # 设置模型为评估模式,改回训练模式可以调用 model.train()
    total_loss = 0
    for data, target in test_loader:
    output = model(data)
    loss = loss_function(output, target)
    total_loss += loss.item()
    return total_loss

更新模型参数时不计算梯度

  • 在某些情况下,我们需要手动更新模型参数(例如权重剪枝、量化等),但不希望这些操作影响梯度计算
    1
    2
    3
    4
    @torch.no_grad()
    def update_weights(model):
    for param in model.parameters():
    param.add_(1.0) # 在参数上加 1,不会记录到计算图中

计算评估指标时不计算梯度

  • 当计算评估指标(如准确率、F1 分数等)时,不需要梯度计算
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    @torch.no_grad()
    def compute_accuracy(model, data_loader):
    correct = 0
    total = 0
    for inputs, labels in data_loader:
    outputs = model(inputs)
    _, predicted = torch.max(outputs, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
    return correct / total

附录:装饰器和上下文管理器的示例

作为装饰器

  • 装饰整个函数,使其在执行期间禁用梯度计算
    1
    2
    3
    @torch.no_grad()
    def inference(model, input_data):
    return model(input_data)

作为上下文管理器

  • 仅在特定代码块中禁用梯度计算
    1
    2
    3
    def inference(model, input_data):
    with torch.no_grad():
    return model(input_data)

附录:推理场景 torch.inference_mode() 的使用

  • 从 PyTorch 1.9 开始,引入了 torch.inference_mode(),它是 torch.no_grad() 的更高效替代品,专门用于推理阶段。与 torch.no_grad() 相比:
    • 性能更高 :torch.inference_mode() 会跳过一些额外的检查,进一步提升性能
    • 不可嵌套 :torch.inference_mode() 不能像 torch.no_grad() 那样嵌套使用
    • 推荐使用 :如果只用于推理,建议优先使用 torch.inference_mode()
  • 示例:
    1
    2
    3
    4
    @torch.inference_mode()
    def evaluate_model(model, test_loader):
    model.eval()
    ...

PyTorch——关闭梯度的方法


整体说明

  • 不记录梯度的方法包括:torch.no_grad()、torch.set_grad_enabled(False)、tensor.detach()、tensor.requires_grad = False 等
    • 注:除了不记录梯度,这些方法还会释放计算图占用的内存,显著降低内存开销
  • 特殊注意 :model.eval() 仅影响特定层(如Dropout、BatchNorm)的行为,不会禁用梯度计算 ,必须配合 with torch.no_grad(): 使用才能完全关闭梯度
  • 各种模式选择建议
    • 若要临时停止梯度计算,推荐使用with torch.no_grad()上下文管理器或者torch.set_grad_enabled()
    • 若想永久性地停止某个张量的梯度计算,可使用detach()方法或者直接设置requires_grad=False
    • 对大型模型进行微调时,在模型层面设置requires_grad能有效节省内存
    • 模型推理阶段,要同时使用model.eval()和with torch.no_grad()

方法一:使用 with torch.no_grad() 上下文管理器

  • with torch.no_grad() 管理器能够暂停所有计算图的构建,进而显著降低内存的使用量并加快计算速度
    1
    2
    3
    4
    5
    6
    import torch

    x = torch.tensor([1.0], requires_grad=True)
    with torch.no_grad():
    y = x * 2
    print(y.requires_grad) # 输出 False

方法二:使用 @torch.no_grad() 作为装饰器

  • 装饰整个函数,使其在执行期间禁用梯度计算

    1
    2
    3
    @torch.no_grad()
    def inference(model, input_data):
    return model(input_data)
  • 注:@torch.no_grad() 装饰器和 with torch.no_grad() 上下文管理器的效果是一样的,一个针对方法,一个针对上下文


方法三:使用 detach() 方法

  • 运用detach()方法可以创建一个新的张量,这个新张量和计算图没有关联
    1
    2
    3
    4
    x = torch.tensor([1.0], requires_grad=True)
    y = x * 2
    z = y.detach() # z 和计算图无关
    print(z.requires_grad) # 输出 False

方法四:使用 torch.set_grad_enabled() 实现全局开关

  • 借助这个全局开关,能够控制整个代码块是否进行梯度计算
    1
    2
    3
    4
    5
    x = torch.tensor([1.0], requires_grad=True)
    torch.set_grad_enabled(False)
    y = x * 2
    torch.set_grad_enabled(True)
    print(y.requires_grad) # 输出 False

方法五:在张量层面设置 requires_grad=False

  • 你可以在创建张量时或者之后,把requires_grad属性设置为False,以此来阻止梯度的计算

    1
    2
    3
    4
    5
    6
    7
    x = torch.tensor([1.0], requires_grad=False)  # 创建时设置
    y = x * 2
    print(y.requires_grad) # 输出 False

    # 或者之后设置
    x = torch.tensor([1.0], requires_grad=True)
    x.requires_grad = False
  • 特别示例,也可以对模型参数直接操作:对于预训练模型进行微调时,你可以冻结部分层,只对特定层计算梯度

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    import torch.nn as nn

    model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 1)
    )

    # 冻结所有参数
    for param in model.parameters():
    param.requires_grad = False

    # 只训练最后一层
    for param in model[2].parameters():
    param.requires_grad = True

模型推理时的特殊场景

  • 特别地,推理时,常常用 model.eval() 和 with torch.no_grad() 或 @torch.no_grad() 结合

  • 在模型推理阶段,同时使用这两个方法能够有效减少内存占用并提高计算效率

    1
    2
    3
    model.eval()  # 关闭 Dropout 和 BatchNorm 等训练特有的层
    with torch.no_grad():
    outputs = model(inputs)
  • 在仅推理的场景中,更常用 torch.inference_mode() 来替代 torch.no_grad()(要求 PyTorch 版本 >= 1.9),以获得更好的性能(跳过一些推理阶段非必要的检查),详情见:PyTorch——torch.no_grad的用法

    • torch.inference_mode() 可以用作上下文管理器或者装饰器
    • 推理场景优先推荐使用 torch.inference_mode(),但 torch.inference_mode() 仅适用于推理场景,其他场景不可乱用
    • torch.inference_mode() 不能像 torch.no_grad() 那样嵌套使用

附录:PyTorch中禁用梯度计算的方法对比

  • 整体对比
    方法 是否不记录梯度 是否释放计算图内存 作用范围 使用场景
    with torch.no_grad(): 是 是 代码块 推理阶段(如模型预测)、不需要梯度的计算(如验证集评估)。
    torch.set_grad_enabled(False) 是 是 全局(直到恢复为True) 临时关闭整个代码段的梯度计算,例如批量推理。
    tensor.detach() 是 是(对新张量) 单个张量 从计算图中分离张量,例如生成对抗网络(GAN)中的生成器输出。
    tensor.requires_grad = False 是 是(设置后) 单个张量或模型参数 冻结预训练模型的部分层,只训练特定参数。
    model.eval() 否 否 模型(影响Dropout/BatchNorm),使用model.train()可恢复 推理阶段,关闭训练特有的层(如Dropout),但仍会记录梯度(需配合no_grad()使用)。

附录:前向过程不涉及梯度计算,为什么需要关闭梯度?

  • 在调用 loss.backward() 之前,PyTorch 不会计算梯度,但是模型的前向过程会构建计算图,这也会消耗额外的内存

附录:非常简单的代码示例

  • 简单的代码示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    import torch

    # 方法1: torch.no_grad()
    with torch.no_grad():
    y = x * 2 # y 不记录梯度,不构建计算图

    # 方法2: set_grad_enabled
    torch.set_grad_enabled(False)
    y = x * 2 # 全局禁用梯度
    torch.set_grad_enabled(True)

    # 方法3: detach()
    y = x.detach() * 2 # y 是脱离计算图的新张量

    # 方法4: requires_grad = False
    x.requires_grad = False
    y = x * 2 # x 不再需要梯度,y 也不记录

    # 方法5: model.eval()(需配合 no_grad())
    model.eval()
    with torch.no_grad():
    outputs = model(inputs) # 完全禁用梯度

PyTorch——叶子张量


整体说明

  • tensor.is_leaf=True 的张量被称为叶子张量(Leaf Tensor),也称为叶子变量(Leaf Variable),部分博客或书籍也称为叶子节点
  • 叶子张量分两种类型,即以下两种情况下的张量 tensor.is_leaf 返回 True:
    • 类型一:requires_grad 为 False 的张量都是叶子张量(由于 requires_grad 为 False,所以不会存储梯度)
    • 类型二:requires_grad 为 True 的张量,如果是由用户创建的 ,而不是通过其他张量运算得到的,那么它是叶子张量;
      • 通过其他张量运算得到的,都是非叶子张量
  • 特别说明:
    • detach() 函数可以将节点从计算图中剥离,使其成为叶子节点,此时 requires_grad 为 False,同时 tensor.is_leaf 会变为 True 了
    • 从 CPU 定义好后移动到 GPU 时产生的张量,或者从 GPU 定义后,挪到 CPU 上的向量,也都是非叶子张量,tensor.is_leaf 返回 False
      • 注意:不论是 GPU 还是 CPU,叶子张量的判定不变,只有用户定义的张量是叶子张量,挪动以后得都不是叶子张量(除非同时修改其
    • 只有叶子张量的 requires_grad 属性可以被修改,非叶子张量的 requires_grad 属性是不能被修改的
      • 理解:非叶子张量都是派生出来的,且 requires_grad=True 的张量,叶子张量计算梯度时依赖性非叶子张量的梯度,不能随便修改 requires_grad 属性
    • 叶子张量的 grad_fn=None (包括 requires_grad=True 和 requires_grad=False 的都是)
      • 理解:tensor.grad_fn 属性指向/存储生成 tensor 张量的计算操作(如加法、减法、乘法等),叶子张量要么是直接由用户定义出来的(此时 requires_grad=True),要么是不需要计算梯度的(此时 requires_grad=False)
    • 非叶子张量的梯度不会被保存(因为不需要使用)

叶子张量的特性

  • 在反向传播过程中,只有叶子张量的梯度会被保留并存储在张量的grad属性中
    • 用于后续优化器更新参数等操作
    • 比如:神经网络层中的权值 w 的张量均为叶子节点,反向传播 backward() 就是为了求它们的梯度,进而更新权值
  • 非叶子张量的梯度在反向传播使用完后通常会被清除 ,以节省内存
    • 比如:计算中中间的一些派生阶段就是非叶子张量

叶子张量相关示例

  • 各种叶子张量判断示例
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    import torch as torch

    a = torch.randn(3,3)
    print(a.is_leaf, a.requires_grad) # True False
    a.requires_grad = True
    print(a.is_leaf, a.requires_grad) # True True
    b = a.cuda()
    print(b.is_leaf, b.requires_grad) # False True
    c = b.detach()
    print(c.is_leaf, c.requires_grad) # True False
    d = a + 2
    print(d.is_leaf, d.requires_grad) # False True

    e = torch.randn(3,3, device="cuda", requires_grad=True)
    print(e.is_leaf, e.requires_grad) # True True
    f = e.to("cpu")
    print(f.is_leaf, f.requires_grad) # False True

附录:如何打印非叶子张量的梯度?

  • 在 PyTorch 里,当开启自动求导功能时,中间变量(非叶子张量)的梯度默认不会被保存,目的是节省内存
    • 只有叶子节点(比如直接创建的张量)的梯度会被保留
    • 在任意时刻时,非中间节点的梯度(grad属性)是都是 None
  • 要获取中间变量的梯度,有两种方法:
    • 运用 retain_grad() 方法保留梯度
    • 借助钩子(hook)来捕获梯度(可以打印或者赋值给全局变量)
  • 打印非节点张量的代码示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    import torch

    x = torch.tensor(2.0, requires_grad=True)
    y = x**2
    z = y**2

    # 方法一:使用 retain_grad() 显式保留梯度
    y.retain_grad()

    z.backward(retain_graph=True) # 若不使用 retain_graph=True,计算图会在 backward 被清空,则后续想要调用 backward() 前需要重新构造计算图
    print("方法一:", y.grad) # 输出: 16.0

    # 方法二:使用钩子 hook 捕捉梯度
    gradient_list = []
    def save_gradient(grad):
    gradient_list.append(grad)

    ## 如果 前面的 z.backward(retain_graph=True) 不使用 retain_graph=True,则 backward() 会清空计算图,这里就需要重新构造计算图,目前使用的是 z.backward(retain_graph=True) ,不需要下面的两句
    # y = x**2
    # z = y**2

    hook = y.register_hook(save_gradient)
    z.backward()
    hook.remove() # 重点:即时移除钩子,防止不必要的内存占用

    print("方法二:", gradient_list[0]) # 输出: 16.0

PyTorch——模型存储与加载


整体说明

  • 在 PyTorch 中,训练完成后,可以将模型保存到磁盘上持久化存储,模型保存方式有:
    • 保存/加载整个模型(包括模型结构和参数)
    • 保存/加载模型参数(状态字典)(推荐方式,包含参数和其他状态信息)
    • 保存/加载检查点(用于断点续训)
  • 对于跨渠道的保存和加载,可以存储为 TorchScript 格式,这是一种专为生产环境优化的模型序列化格式
    • TorchScript 格式能够将 PyTorch 模型转换为一种可序列化、可优化的中间表示形式,便于在不同环境(包括 C++ 部署)中运行
  • 保存模型时 .pt 和 .pth 两种存储格式之间没有本质区别,主要区别在于使用习惯:
    • 早期 PyTorch 文档和示例中更常用 .pth 扩展名
    • 后来随着 PyTorch 版本中,官方示例逐渐开始使用 .pt,逐渐成为更推荐的格式
    • 实际上两者完全等价,亲测:.pt 和 .pth 直接修改后缀就能混用

存储模型和加载模型示例

  • 有两种主要的方法来保存模型:保存整个模型或仅保存模型的状态字典(推荐)

整个模型存储(不常用)

  • 保存整个模型并加载整个模型

    1
    2
    3
    4
    5
    # 保存整个模型
    torch.save(model, 'simple_model.pt') # 保存

    # 加载整个模型
    model_loaded = torch.load('simple_model.pt') # 加载
    • 注意:这种方法要求模型类定义必须可用,读不到类名会出错
  • 特别提示(容易出错):容易导致模型代码定义和加载存储模型不一致情况

    • 加载规则:实际上只要模型的类名相同即可加载(按照类名匹配的),模型结构可以定义不一致
    • 模型加载后模型实际对象结构与存储真实结构一致,会丢失当前代码定义的模型所有结构(包括类属性也不存在)
    • 理论上仅仅需要定义一个类名即可正常加载,但是使用时是按照存储模型的结构来使用的

仅存储模型参数(常用)

  • 保存模型的状态字典 并加载:

    1
    2
    3
    4
    5
    6
    # 保存模型参数
    torch.save(model.state_dict(), 'simple_model_state_dict.pth')

    # 创建模型示例并加载模型参数
    model_loaded = SimpleModel() # 实例化模型
    model_loaded.load_state_dict(torch.load('simple_model_state_dict.pth'))
    • 待加载模型参数和定义的模型参数不同时会直接报错
    • 状态字典方法更灵活,允许你只保存模型参数,不包括模型结构,这在部署时特别有用

常见保存和加载方式的整体示例

  • 整体示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    import torch
    import torch.nn as nn
    import torch.optim as optim

    class DiyModel(nn.Module):
    def __init__(self):
    super(DiyModel, self).__init__()
    self.fc1 = nn.Linear(10, 20)
    self.fc2 = nn.Linear(20, 2)
    self.relu = nn.ReLU()

    def forward(self, x):
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    return x

    model = DiyModel()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # 1. 保存和加载整个模型(包括结构和参数)
    # 保存
    torch.save(model, 'entire_model.pt')

    # 加载
    loaded_entire_model = torch.load('entire_model.pt')
    loaded_entire_model.eval() # 设置为评估模式,方便后续的 Serving
    print("加载整个模型成功")

    # 2. 仅保存和加载模型参数(推荐方式)
    # 保存
    torch.save(model.state_dict(), 'model_parameters.pt')

    # 加载
    loaded_model = DiyModel() # 需要先创建模型实例
    loaded_model.load_state_dict(torch.load('model_parameters.pt'))
    loaded_model.eval() # 设置为评估模式
    print("加载模型参数成功")

    # 3. 保存和加载检查点(用于断点续训)
    # 保存:包含模型参数、优化器状态、 epoch 等信息
    checkpoint = {
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.23,
    # 可以添加其他需要保存的信息
    }
    torch.save(checkpoint, 'training_checkpoint.pt')

    # 加载检查点
    loaded_checkpoint = torch.load('training_checkpoint.pt')

    # 恢复模型和优化器状态
    restored_model = DiyModel()
    restored_optimizer = optim.SGD(restored_model.parameters(), lr=0.001, momentum=0.9)

    restored_model.load_state_dict(loaded_checkpoint['model_state_dict'])
    restored_optimizer.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
    epoch = loaded_checkpoint['epoch']
    loss = loaded_checkpoint['loss']

    restored_model.train() # 恢复训练时设置为训练模式
    print(f"加载检查点成功: epoch {epoch}, loss {loss}")

    # 4. 跨设备保存和加载(跨设备
    # 在 GPU 上保存,在 CPU 上加载
    if torch.cuda.is_available():
    model.cuda()
    torch.save(model.state_dict(), 'model_gpu.pt') # 同时也会保存一些 GPU 相关信息

    # 在CPU上加载GPU保存的模型
    cpu_model = DiyModel()
    # 特别需要注意的点:map_location指明设备,是必须的参数(实际上,为了兼容,建议所有的加载都加上 `map_location=device` 参数)
    cpu_model.load_state_dict(torch.load('model_gpu.pt', map_location=torch.device('cpu')))
    print("在 CPU 上加载 GPU 保存的模型成功")

TorchScript 格式

  • TorchScript 格式能够跨平台存储和加载(可在 C++ 环境中运行,无需 Python 依赖)
  • TorchScript 格式在生产环境中更常用
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    import torch
    import torch.nn as nn

    class DiyModel(nn.Module):
    def __init__(self):
    super(DiyModel, self).__init__()
    self.fc1 = nn.Linear(10, 20)
    self.fc2 = nn.Linear(20, 2)
    self.relu = nn.ReLU()

    def forward(self, x):
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    return x

    model = DiyModel() # 创建模型

    # 1. 将模型转换为TorchScript格式(注:这一步是必要的)
    # 方法1:跟踪(tracing) - 适用于无控制流的模型(不常用,不推荐,容易出错)
    example_input = torch.randn(1, 10)
    traced_script_module = torch.jit.trace(model, example_input)
    # tracing 方法引入一个输入数据来执行流程,同时基于执行流程生成模型静态图
    # 如果存在控制流,会只保留example_input下会遇到的控制流
    # 以后遇到其他数据也会都走这个控制流,从而导致错误发生
    # 所以仅适用于无控制流的模型,不推荐使用

    # 方法2:脚本(scripting) - 适用于有控制流的模型和无控制流的模型(常用,推荐,兼容性好)
    scripted_script_module = torch.jit.script(model) # 解析 Python 代码结构生成静态图,能完整保留模型结构
    # 特别说明:scripting形式和tracing保存的模型性能上差异不大

    # 2. 保存TorchScript模型
    traced_script_module.save("traced_model.pt")
    scripted_script_module.save("scripted_model.pt")

    # 3. 加载TorchScript模型
    loaded_traced_model = torch.jit.load("traced_model.pt")
    loaded_scripted_model = torch.jit.load("scripted_model.pt")

    # 4. 使用加载的模型进行推理
    loaded_traced_model.eval()
    loaded_scripted_model.eval()
    with torch.no_grad():
    input_data = torch.randn(1, 10)
    output1 = loaded_traced_model(input_data)
    output2 = loaded_scripted_model(input_data)

使用 tracing 方法保存含控制流模型

  • 下面是使用 tracing 方法保存含控制流模型出错的示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    import torch
    import torch.nn as nn

    class ControlFlowModel(nn.Module):
    def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(1, 1)
    self.fc2 = nn.Linear(1, 1)

    def forward(self, x):
    if x > 0: # 控制流语句:根据输入值选择不同分支
    output = self.fc1(x) # 分支1
    else:
    output = self.fc2(x) # 分支2
    return output

    model = ControlFlowModel()
    model.fc1.weight.data = torch.tensor([[2.0]]) # 分支1:输出 = 2*x
    model.fc1.bias.data = torch.tensor([0.0])
    model.fc2.weight.data = torch.tensor([[3.0]]) # 分支2:输出 = 3*x
    model.fc2.bias.data = torch.tensor([0.0])

    # 1. 使用 tracing 方法转换模型(示例输入为正数,触发分支1)
    example_input = torch.tensor([1.0]) # 正数:走fc1分支
    traced_model = torch.jit.trace(model, example_input)

    # 2. 测试不同输入的推理结果
    test_inputs = [
    torch.tensor([2.0]), # 正数(与示例输入同分支)
    torch.tensor([-1.0]) # 负数(与示例输入不同分支)
    ]

    print("Original Model Output:")
    for x in test_inputs:
    print(f"Input {x.item()}: Output {model(x).item()}")

    print("\nTracing Model Output:")
    for x in test_inputs:
    print(f"Input {x.item()}: Output {traced_model(x).item()}")
    # 输入 -1.0 时错误地输出了 -2.0,实际上应该输出 -3.0

    # Original Model Output:
    # Input 2.0: Output 4.0
    # Input -1.0: Output -3.0
    #
    # Tracing Model Output:
    # Input 2.0: Output 4.0
    # Input -1.0: Output -2.0

大模型的存储和加载

  • 大模型一般以 Safetensors 格式存储(Hugging Face 的默认存储形式),许多 CV 和 NLP 的开源模型都是这个格式
  • 超大规模模型还可以分片保存
  • 大模型存储和加载的示例:
    1
    # 待补充

附录:加载模型后的动作

  • 切换模式 :无论哪种保存方式,在加载模型后,如果模型包含 Batch Normalization、Dropout 等层,都必须调用 model.eval() 来确保这些层在推理时行为正确
    • 因为这样会关闭 Dropout 和 Batch Normalization 等层的行为变化
  • 禁用梯度 :调用模型 Serving 前建议禁用梯度
    • 加速并节省内存 :禁用梯度计算可以减少运行时的内存占用,因为在前向过程中不需要存储用于反向传播的信息,由于也不需要准备一些梯度计算的步骤,可小幅提升 Serving 速度
    • 增加代码可读性 :使用 torch.no_grad() 强调了当前操作的目的(即,仅执行前向计算而不更新模型权重),这提高了代码的可读性和意图的一致性(这对于维护和理解代码非常有帮助)
  • torch.no_grad()的实践Demo:
    1
    2
    3
    4
    5
    model.eval()  # 设置模型为评估模式,不会禁用梯度
    with torch.no_grad(): # 在评估过程中禁用梯度计算
    output = model(X_test)
    loss = criterion(output, Y_test)
    print(f'Test Loss: {loss.item():.4f}')

Python——Ray-多节点集群启动


整体说明

  • Ray 集群的启动有多种方法,本文简单总结这些方法
  • 核心概念补充:
    • 头节点(Head Node) :集群的主节点,负责管理整个集群的资源、任务调度和元数据存储
    • 工作节点(Worker Node) :通过连接头节点加入集群,提供计算资源(CPU/GPU/内存等)

配置文件启动

  • 设置配置文件 cluster.yaml

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    cluster_name: my-ray-cluster
    max_workers: 2

    # 头节点配置
    head_node:
    InstanceType: m5.large
    ImageId: ami-0abcdef1234567890 # 选择包含Ray的镜像

    # 工作节点配置
    worker_nodes:
    InstanceType: m5.xlarge

    # 启动命令(可选,自定义节点初始化)
    setup_commands:
    - pip install pandas # 安装依赖
  • 主节点启动 ray up cluster.yaml

  • 其他节点加入集群 ray attach cluster.yaml

  • 停止容器 ray down cluster.yaml


Kubernetes 部署

  • 待补充

Python 程序内启动

  • 主节点使用:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    import ray

    # 启动本地集群(自动创建头节点和工作进程)
    ray.init(
    num_cpus=4, # 模拟4核CPU
    num_gpus=1, # 模拟1块GPU
    port=6380, # 手动指定通信端口
    dashboard_host="0.0.0.0" # 允许外部访问Dashboard
    )

    # 验证集群
    print("集群节点数:", len(ray.nodes())) # 输出1(单节点模拟)

    # 关闭集群
    ray.shutdown()
  • 其他节点加入使用

    1
    2
    3
    4
    5
    6
    # 头节点启动Ray
    import ray
    ray.init()

    # 工作节点连接到头节点
    ray.init(address='头节点地址:6380')

使用 ray start 命令启动

  • ray start 是启动 Ray 集群节点的常用命令
  • 可分别分别使用 ray start 用于初始化头节点(Head Node)和工作节点(Worker Node)

启动头节点

  • 头节点是集群的入口,必须先启动

  • 基本命令格式:

    1
    ray start --head [其他可选参数]
  • 关键参数包括:

    参数 说明 示例
    --head 声明当前节点为头节点(必选) -
    --port 指定 Ray 内部通信端口(默认 6379,若被占用会自动切换) --port=6380
    --dashboard-host 允许外部访问 Ray dashboard 的主机地址(默认仅本地访问) --dashboard-host=0.0.0.0
    --dashboard-port Dashboard 端口(默认 8265) --dashboard-port=8266
    --num-cpus 手动指定该节点可用的 CPU 核心数(默认自动检测) --num-cpus=16
    --num-gpus 手动指定该节点可用的 GPU 数量(默认自动检测) --num-gpus=2
    --memory 限制节点可用内存(单位:字节,如 1000000000 表示 1GB) --memory=8000000000
    --object-store-memory 对象存储的内存上限(默认总内存的 30%) --object-store-memory=2000000000
    --block 启动后阻塞终端(不后台运行,便于调试) -
    --log-dir 指定日志目录(默认 ~/raylogs) --log-dir=/path/to/logs
  • 示例:启动一个允许外部访问 Dashboard、指定 CPU/GPU 资源的头节点:

    1
    2
    3
    4
    5
    6
    ray start --head \
    --port=6379 \
    --dashboard-host=0.0.0.0 \
    --dashboard-port=8265 \
    --num-cpus=12 \
    --num-gpus=1

启动工作节点

  • 工作节点需通过头节点的地址加入集群,命令格式:

    1
    ray start --address=<IP>:<port> [其他可选参数]
  • 关键参数说明:

    • --address:头节点的地址(必填,格式为 <IP>:<port>,即头节点启动时输出的地址)
    • 其他参数(如 --num-cpus、--num-gpus 等)与头节点相同,用于限制工作节点的资源
  • 示例:连接到 IP 为 192.168.1.100、端口为 6379 的头节点,同时指定工作节点的资源:

    1
    2
    3
    ray start --address='192.168.1.100:6379' \
    --num-cpus=8 \
    --num-gpus=0

使用 ray 命令验证集群状态

  • 可在任意节点执行下面脚本查看节点列表

    1
    ray status
    • 输出会显示集群中的所有节点及资源使用情况
  • 通过头节点的 http://<头节点IP>:8265(或其他指定端口) 查看集群监控、任务状态等

使用 ray 命令停止集群

  • 停止单个节点(包括工作节点或头节点):

    1
    ray stop
  • 特别注意:若头节点停止,整个集群会自动解散

附录:ray start 命令的其他高级配置

  • 可通过 --runtime-env 指定环境配置文件(如依赖安装、环境变量等):

    1
    ray start --head --runtime-env=runtime_env.yaml
  • 可通过 --redis-password 设置密码,防止未授权节点加入:

    1
    2
    3
    4
    5
    # 头节点
    ray start --head --redis-password='mysecret'

    # 工作节点
    ray start --address=<IP:port> --redis-password='mysecret'
  • 可通过 --log-dir 指定日志目录(默认 ~/raylogs):

    1
    ray start --head --log-dir=/path/to/logs

附录:通过 ray job submit 向已经启动的 Ray 集群提交任务

  • ray job submit 命令用于将作业提交到 Ray 集群

基本语法

  • 用法说明:

    1
    ray job submit [options] -- <entrypoint> [<entrypoint_args>]
    • [options] 是命令的可选参数
    • -- 之后的 <entrypoint> 是要执行的入口点脚本或命令
      • 可以是 -- python my_script.py 或 bash my_shell.sh
    • <entrypoint_args> 是传递给 <entrypoint> 的参数

常用参数

  • --address:指定 Ray 集群的地址,通常是集群头节点的地址和端口,如http://127.0.0.1:8265
  • --working-dir:指定作业的工作目录
    • 该目录下的文件会被同步到集群节点上,默认为当前目录
    • 启动脚本会有一些文件依赖,这里上传所有文件可以保证本地能访问的文件,集群也能访问
  • --runtime-env:用于指定作业的运行时环境,可以是一个 JSON 格式的字符串或 YAML 文件路径
    • 例如,可以通过该参数指定需要安装的 Python 包,如--runtime-env='{"pip": ["requests"]}'
  • --no-wait:提交作业后不等待作业完成,立即返回
    • 如果不指定该参数,命令会等待作业完成,并输出作业的日志和结果
  • --submission-id:指定作业的提交 ID
    • 如果不指定,Ray 会自动生成一个唯一的 ID

用法示例

  • 假设要提交一个 Python 脚本 my_script.py 到 Ray 集群,指定集群地址为 http://127.0.0.1:8265,工作目录为当前目录:

    1
    RAY_ADDRESS='http://127.0.0.1:8265' ray job submit --working-dir. -- python my_script.py
  • 假设提交一个 Python 脚本 train.py,并传递参数 --epochs 10 --batch-size 32,同时指定运行时环境需要安装 torch 和 numpy 包:

    1
    ray job submit --address=http://your-ray-cluster-address:8265 --runtime-env='{"pip": ["torch", "numpy"]}' -- python train.py --epochs 10 --batch-size 32

其他高级配置

  • py_modules :指定需要导入的自定义 Python 模块路径,支持将本地模块添加到 Python 路径(sys.path),示例如下:

    1
    runtime_env={"py_modules": ["./my_utils"]}  # 同步 my_utils 模块并添加到路径
  • env :指定预定义的环境名称(如 Ray 集群中已配置的共享环境),避免重复配置,示例如下:

    1
    runtime_env={"env": "shared-training-env"}  # 使用集群中预定义的环境
  • 特别说明:runtime_env 还支持其他非官方扩展,比如 verl 库中就为英伟达显卡配置了 nsight 工具的参数传入

    1
    2
    ./verl/trainer/main_ppo.py
    runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
1…212223…63
Joe Zhou

Joe Zhou

Stay Hungry. Stay Foolish.

630 posts
53 tags
GitHub E-Mail
© 2026 Joe Zhou
Powered by Hexo
|
Theme — NexT.Gemini v5.1.4