Jiahong 的个人博客

凡事预则立,不预则废


  • Home

  • Tags

  • Archives

  • Navigation

  • Search

PyTorch——分布式编程之子通信组


整体介绍

  • 子通信组允许在分布式环境中灵活地划分进程,实现更精细的通信控制
  • torch.distributed.new_group 是 PyTorch 分布式训练中用于创建子通信组的函数
  • 使用场景包括:
    • 部分进程通信:当需要让部分进程单独通信(如模型并行中不同层的参数同步)
    • 灵活分组:动态划分进程组,适应复杂的分布式策略(如混合数据并行+模型并行)
  • 使用步骤包括:
    • 1)初始化全局进程组:先通过 init_process_group 初始化全局通信环境
    • 2)创建子组:调用 new_group 划分进程
    • 3)子组内通信:使用返回的 ProcessGroup 对象进行通信操作
  • 子通信组使用的核心注意事项
    • 进程一致性:所有进程必须调用 new_group ,即使不加入子组(此时可传入 ranks 不包含自身,或后续不使用返回的组对象)
    • 后端兼容性:子组的 backend 需与全局后端兼容(如 GPU 通信推荐 nccl)

new_group 函数定义

  • new_group 函数形式说明:

    1
    2
    3
    4
    5
    torch.distributed.new_group(
    ranks=None, # 参与新组的进程编号列表
    timeout=datetime.timedelta(seconds=1800), # 超时时间
    backend=None # 通信后端,默认为全局后端
    )
  • new_group 核心参数说明如下文

    • ranks(可选,列表/元组):
      • 指定加入新组的进程编号(全局进程编号,非局部编号)
      • 若为 None,则默认包含所有进程(等价于全局组)
      • 例如:ranks=[0,1,2] 表示仅 0、1、2 号进程加入新组
    • timeout**(可选,datetime.timedelta):
      • 组内通信的超时时间,超时未完成会抛出异常
      • 默认为 30 分钟(1800 秒)
    • backend**(可选,字符串):
      • 指定该组使用的通信后端(如 nccl、gloo 等)
      • 若为 None,则继承全局初始化的后端(init_process_group 中指定的 backend)
  • new_group 返回值

    • 返回一个 ProcessGroup 对象 ,用于后续子组内的通信操作(如 allreduce、broadcast 等)

附录:为什么所有进程都要调用子通信组初始化函数

  • 在 PyTorch 分布式的最佳实践中,即使不加入子进程组的 rank(如例子中的 rank=3),也必须调用 dist.new_group(ranks=[0, 1], ...)
    • 这是由分布式通信的一致性要求决定的
  • 必须调用的核心原因是:避免死锁
    • PyTorch 分布式通信的底层实现要求所有进程必须参与子组的创建过程 ,无论是否加入该子组
    • 若部分进程调用 new_group 而其他进程不调用,会导致进程间同步失衡,触发分布式死锁(所有进程会阻塞等待未调用的进程)
    • 即使某进程明确不加入子组(不在 ranks 列表中),也需要通过调用 new_group 完成“知晓该子组存在”的协议同步
  • 不加入子组的进程如何处理返回的 ProcessGroup 对象?
    • 对于不加入子组的进程(如 rank=3),调用 new_group 后会返回一个有效的 ProcessGroup 对象,但该进程不属于该组
    • 此时的最佳实践是:保留该对象但不使用它进行通信(或在通信前先判断是否属于子组)
    • 可通过 dist.get_rank(group=subgroup) 检查:若返回 -1,说明当前进程不属于该子组,应跳过子组内的通信操作

子通信组示例代码

  • 代码示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    import torch
    import torch.distributed as dist
    from datetime import timedelta

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()

    # 所有进程(无论是否加入子组)必须调用 `dist.new_group`,否则会导致分布式死锁
    subgroup = dist.new_group(ranks=[0, 1], timeout=timedelta(seconds=30))

    # 检查当前进程是否属于子组,非子组成员进程可通过 `dist.get_rank(group=subgroup) != -1` 判断身份,避免无效通信
    is_in_subgroup = dist.get_rank(group=subgroup) != -1

    if is_in_subgroup:
    # 子组成员执行通信操作
    tensor = torch.tensor([rank], device="cuda")
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=subgroup)
    print(f"Rank {rank}(子组成员):通信结果 = {tensor.item()}")
    else:
    # 非子组成员跳过通信,或执行其他逻辑
    print(f"Rank {rank}(非子组成员):不参与子组通信")

    dist.destroy_process_group()

PyTorch——分布式训练Debug笔记


整体说明

  • 为了方便在分布式训练中查看代码信息,需要一些日志打印

优雅打印对象的函数

  • 示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    import torch
    from typing import Any
    import os
    import inspect

    def get_current_location() -> str:
    """
    获取当前执行位置的「文件绝对路径」和「所在函数名称」
    :return: (file_path, func_name)
    - file_path: 当前文件的绝对路径
    - func_name: 所在函数名称(模块级代码返回 "<module>",匿名函数返回 "<lambda>")
    """
    # 获取调用栈:index=1 对应「调用当前函数的位置」(即目标位置)
    try:
    # 栈帧结构:inspect.stack()[index] -> FrameInfo 对象
    frame_info = inspect.stack()[1]
    frame = frame_info.frame # 提取栈帧

    # 1. 获取文件绝对路径
    file_path = os.path.abspath(frame.f_code.co_filename)

    # 2. 获取函数名称
    func_name = frame.f_code.co_name

    # 特殊处理:模块级代码(无函数包裹)的函数名显示为 "<module>"
    # (inspect 默认返回 "<module>",无需额外处理)

    return f"python file path:{file_path} #function_name:{func_name}"

    finally:
    # 手动清理栈帧引用,避免内存泄漏(关键!)
    del frame_info
    del frame

    def print_obj_info(obj: Any, indent: int = 0) -> None:
    """
    打印对象的详细信息,包括类型、大小/长度、关键属性及嵌套对象信息
    :param obj: 待打印的对象
    :param indent: 缩进级别(用于嵌套结构格式化)
    """
    # 缩进格式化
    prefix = " " * indent
    type_name = type(obj).__name__

    # 基础信息:类型 + 核心属性
    base_info = f"{prefix}[{type_name}] "

    # 1. 列表类型(含嵌套列表)
    if isinstance(obj, list):
    base_info += f"len={len(obj)}"
    print(base_info)
    # 递归打印前3个元素(避免超长输出),超过则提示
    for i, item in enumerate(obj[:2]):
    print(f"{prefix} - 索引{i}:", end=" ")
    print_obj_info(item, indent + 2)
    if len(obj) > 2:
    print(f"{prefix} - ... 还有{len(obj)-2}个元素")

    # 2. 字典类型
    elif isinstance(obj, dict):
    base_info += f"len={len(obj)}, keys={list(obj.keys())}"
    print(base_info)
    # 递归打印每个value
    for k, v in obj.items():
    print(f"{prefix} - key='{k}':", end=" ")
    print_obj_info(v, indent + 2)

    # 3. PyTorch Tensor类型
    elif isinstance(obj, torch.Tensor):
    base_info += f"shape={tuple(obj.shape)}, dtype={obj.dtype}, device={obj.device}"
    print(base_info)

    # 4. 其他普通类型(数字、字符串、布尔等)
    else:
    # 补充长度信息(字符串)和值信息
    if hasattr(obj, "__len__") and not isinstance(obj, (int, float, bool)):
    obj_str = f"{obj}"
    log_obj_str = obj_str[:100]
    base_info += f"len={len(obj_str)}, value={log_obj_str}" + (f", 还有{len(obj_str)-100} 个字符" if len(obj_str) > 100 else "")
    else:
    base_info += f"value={obj}"
    print(base_info)


    def test_function():
    # 测试数据
    test_obj = {
    "int_val": 42,
    "str_val": "hello" * 30,
    "tensor1": torch.randn(3, 4),
    "nested_list": [1, torch.tensor([2,3]), [4.5, "6"]],
    "bool_val": True
    }

    # 调用函数

    print("="*30)
    print(f"call_location={get_current_location()}")
    print_obj_info(test_obj)
    print("="*60)

    if __name__ == "__main__":
    test_function()

    # ==============================
    # call_location=python file path:/path_to_log_helper.py #function_name:test_function
    # [dict] len=5, keys=['int_val', 'str_val', 'tensor1', 'nested_list', 'bool_val']
    # - key='int_val': [int] value=42
    # - key='str_val': [str] len=150, value=hellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohello, 还有50 个字符
    # - key='tensor1': [Tensor] shape=(3, 4), dtype=torch.float32, device=cpu
    # - key='nested_list': [list] len=3
    # - 索引0: [int] value=1
    # - 索引1: [Tensor] shape=(2,), dtype=torch.int64, device=cpu
    # - ... 还有1个元素
    # - key='bool_val': [bool] value=True
    # ============================================================

反向解析并打印某未知函数

  • 有时候调用的函数是经过多次封装得到的,比如 Megatron-LM 项目中存在大量的封装代码
  • 打印函数信息示例
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    import functools
    import inspect
    def print_function_info(function):
    print("=" * 30 + "Inspecting function")

    # 如果是 partial 封装过的对象,需要特殊逻辑 function.func 取出真实的函数
    if isinstance(function, functools.partial):
    print("[Type]: functools.partial")
    print(f"[Original Function]: {function.func.__name__}")
    print(f"[Preset Args]: {function.args}")
    print(f"[Preset Keywords]: {function.keywords}")
    real_func = function.func
    else:
    print(f"[Type]: {type(function)}")
    if hasattr(function, '__name__'):
    print(f"[Name]: {function.__name__}")
    real_func = function

    print("-" * 20)
    try:
    source = inspect.getsource(real_func) # 根据真实函数取出其源代码
    print("[Source Code]:")
    print(source)
    except Exception as e:
    print(f"[Source Code]: Unable to retrieve source. ({e})")
    print("=" * 30)
    print_function_info(my_diy_function)

参数值检查

  • 参数值输出示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    # 关闭梯度确保不影响梯度值
    def print_tensor(model_head)
    with torch.no_grad():
    weights = model_head.weight.detach().cpu()
    weights_flat = weights.view(-1)
    num_params = min(1000, weights_flat.numel())
    first_1000_params = weights_flat[:num_params].tolist()
    print("="*30 + "print init_param")
    print(f"Shape of weights: {weights.shape}")
    print(f"First {num_params} parameters of weights:\n{first_1000_params}")

随机种子查看和设置

  • 随机种子涉及到 shuffle,模型参数初始化等操作,如果要对齐两个配置相同的模型,种子也需要对齐

torch Seed 打印

  • torch Seed 打印代码:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    # 打印 torch 的随机种子情况
    def print_torch_seeds():
    print("=" * 30 + "PyTorch Random Seeds Status")
    print("=" * 30)
    cpu_seed = torch.initial_seed()
    print(f"[CPU] Seed: {cpu_seed}")

    if torch.cuda.is_available():
    try:
    gpu_seed = torch.cuda.initial_seed()
    current_device = torch.cuda.current_device()
    device_name = torch.cuda.get_device_name(current_device)

    print(f"[GPU] Seed: {gpu_seed}")
    print(f" Device: {current_device} ({device_name})")
    except Exception as e:
    print(f"[GPU] Error getting seed: {e}")
    else:
    print("[GPU] CUDA is not available.")

    print("=" * 30)
    print_torch_seeds()

torch Seed 设置

  • 全局 torch Seed 设置代码:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    import torch

    # 固定CPU种子
    torch.manual_seed(42)

    # 固定所有GPU的种子(单GPU/多GPU通用)
    if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42) # 替代 torch.cuda.manual_seed(42)(单GPU)

    # GPU上生成随机排列
    perm = torch.randperm(10, device="cuda") # 注意:需要指定 "cuda" 才会在 GPU 上执行
    print("GPU随机排列:", perm) # 每次运行结果一致
    print("draw a random number:", torch.rand()) # 每次运行结果一致
  • 使用独立的 torch 生成器(独立管理自己的随机生成器):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    import torch

    # 创建独立的生成器并设置种子
    generator = torch.Generator()
    generator.manual_seed(42)

    # 生成随机排列时指定生成器
    perm1 = torch.randperm(10, generator=generator)
    perm2 = torch.randperm(10, generator=generator)

    print("独立生成器-第一次:", perm1) # tensor([2, 7, 3, 1, 0, 9, 4, 5, 8, 6])
    print("独立生成器-第二次:", perm2) # tensor([2, 0, 7, 9, 8, 4, 3, 6, 1, 5])

    # 重置生成器种子,结果重复
    generator.manual_seed(42)
    perm3 = torch.randperm(10, generator=generator)
    print("重置生成器后:", perm3) # tensor([2, 7, 3, 1, 0, 9, 4, 5, 8, 6])(和perm1一致)
    • 说明:torch.Generator 是 PyTorch 中统一的随机数生成器(RNG)核心对象,几乎所有 PyTorch 内置的随机操作都支持通过 generator 参数指定该生成器

PyTorch——分布式编程框架总结

  • 参考链接:
    • 简单入门可参考:「指北」PyTorch分布式训练 - Will Lee的文章 - 知乎

整体说明

  • PyTorch 原生支持了 DP(DataParallel)和DDP(DistributedDataParallel)是常用的数据并行分布式训练工具
  • PyTorch 原生还支持了 FSDP 作为模型并行的分布式训练工具,通过分片模型参数、梯度和优化器状态到多个 GPU,显著降低单卡内存占用
  • HuggingFace Accelerate 是一个轻量级库,专为简化 PyTorch 模型在各种硬件配置上的训练和推理而设计,支持选择 DeepSpeed 和 FSDP 等
  • HuggingFace 的 Trainer 是 transformers 库中一个核心且功能强大的类,它为 PyTorch 模型提供了完整的训练和评估循环,极大地简化了训练过程,让用户可以专注于模型、数据集和训练参数的配置,而无需手动编写复杂的训练代码
    • Trainer 比 Accelerate 更高一级,把循环等也封装了,进需要用户配置参数数据集等即可
  • 其他相关的分布式封装框架有 Horovod、Ray 等
    • Horovod 是 Uber 开源的跨平台的分布式训练工具,名字来自于俄国传统民间舞蹈,舞者手牵手围成一个圈跳舞,与 Horovod 设备之间的通信模式很像
    • Ray 是更高层级的分布式训练框架(利用其他框架),目标是融合数据处理、模型训练、超参数调优和模型服务等各个阶段
  • 注意:篇幅有限,本文主要是记录一些简单的使用示例和说明,更详尽的使用细节需要去官网查看
  • PyTorch 分布式系统的启动方式见:/Notes/PyTorch/PyTorch——分布式程序启动方式汇总

DataParallel(DP) 使用示例

  • DP 是单进程多线程模式,简单易用,适合单机多 GPU 场景
  • DP 的每次前向过程都会进行一次从 GPU 间的参数复制(效率较慢)
  • 参考博客(有比较清晰的图片):Training Neural Nets on Larger Batches: Practical Tips for 1-GPU, Multi-GPU & Distributed setups
  • DP 的使用示例如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset

    class DiyModel(nn.Module):
    def __init__(self):
    super().__init__()
    self.fc = nn.Linear(10, 2)
    def forward(self, x):
    return self.fc(x)

    class DiyDataset(Dataset):
    def __len__(self):
    return 1000
    def __getitem__(self, idx):
    return torch.randn(10), torch.randint(0, 2, (1,)).item()

    # 第一步:数据准备(与常规方式一致)
    dataset = DiyDataset()
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # 第二步:初始化模型、损失函数、优化器
    model = DiyModel()
    # 将模型放到DP中(自动分发到多个GPU),核心步骤,相对普通训练方式,仅需要修改这里
    model = nn.DataParallel(model) # 注意:DP 仅增加这一步即可
    model = model.cuda() # 或 .to('cuda')

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # 第三步:训练循环
    for epoch in range(3):
    for inputs, labels in dataloader:
    inputs = inputs.cuda()
    labels = labels.cuda()

    # 注:model 是 DP封装过的模型,所以能执行 DP 的个性化操作
    # 在这里 model(inputs) 会执行四个流程:
    # * 数据分发
    # * 模型复制(主线程GPU到其他线程)
    # * 并行前向推理
    # * 输出汇总四个流程
    outputs = model(inputs)
    loss = criterion(outputs, labels)

    optimizer.zero_grad()

    # 注:loss 本质是从 Model 出来的(loss 的梯度是传到 outputs 上再进一步计算的,而 outputs 是模型的输出,执行梯度计算时能考虑分布式),所以也能执行 DP 的个性化操作
    # 这里 loss.backward 执行四个流程:
    # * 各GPU损失梯度计算:根据各GPU的outputs与总output的关系来计算各自的损失梯度(不同GPU不一样)
    # * 损失梯度分发(将各自的梯度分发到各自 GPU 上
    # * 并行后向推理(计算 GPU 自身的局部梯度)
    # * 梯度汇总
    loss.backward()
    optimizer.step() # 仅更新主线程GPU上的模型
    print(f"Epoch {epoch}, Loss: {loss.item()}")

nn.DataParallel(model) 发生了什么?

  • TLDR:model = nn.DataParallel(model) 的核心作用是通过包装模型实现多 GPU 数据并行 :
    • 执行这一行后,模型对象已经变了,需要通过 model.module.fc 才能访问原模型属性(访问 model.fc 会出错)
    • DP 中,仅需要这一步即可实现数据并行 ,后续的代码都无需修改,自动适配了;但代码后面做了很多数据分发和梯度合并的工作
    • 这一行后,程序会自动管理设备分配、模型复制、数据拆分与合并、梯度汇总与参数同步
  • 1)初始化:确定并行设备与包装模型,nn.DataParallel 实例化时会完成以下核心操作:
    • 检测可用GPU :默认情况下,DataParallel 会自动检测当前可见的GPU(通过 CUDA_VISIBLE_DEVICES 环境变量控制),并将其ID列表存储在 device_ids 属性中(默认值为 range(torch.cuda.device_count()))
    • 指定主GPU :主GPU(device[0])是默认的“主导设备”,负责汇总计算结果、更新参数,并作为数据/模型的初始落脚点。若未指定 device_ids,则第一个可见GPU(通常是 cuda:0)会被设为主GPU
    • 包装原始模型 :原始模型会被存入 DataParallel 的 module 属性中(这也是后续访问原模型属性需通过 model.module 的原因),同时 DataParallel 会接管模型的 forward 和 backward 逻辑
  • 2)模型的移动与复制:DataParallel 会自动处理模型在设备间的分布:
    • 主GPU加载模型 :若原始模型在CPU上,DataParallel 会先将其移动到主GPU(device_ids[0]);若模型已在某个GPU上,会检查是否与主GPU一致,不一致则移动
    • 副本同步到其他GPU :在首次执行前向传播时,DataParallel 会将主GPU上的模型参数复制到 device_ids 中的其他GPU,确保所有GPU上的模型初始参数完全一致
  • 3)前向传播:数据拆分与并行计算,当调用 output = model(input) 时,DataParallel 会按以下流程处理:
    • 数据校验与准备 :检查输入数据是否在主GPU上(若不在,会自动移到主GPU)。输入可以是Tensor、列表、字典等结构,只要包含需要拆分的批量数据(如 batch_size 维度的Tensor)
    • 数据拆分(Split) :沿批量维度(默认是第0维,即 batch_size 所在维度)将输入数据均匀拆分到 device_ids 中的所有GPU。例如,若 batch_size=8 且使用2个GPU,则每个GPU会收到 batch_size=4 的子数据
    • 并行计算 :每个GPU上的模型副本会独立处理分配到的子数据,执行各自的 forward 计算,得到子输出
    • 结果收集与合并(Gather) :所有GPU的子输出会被发送回主GPU,然后按拆分的逆过程合并(如拼接),最终形成与单GPU计算格式一致的输出(例如,将2个 batch_size=4 的子输出拼接为 batch_size=8 的完整输出)
  • 4)反向传播:梯度汇总与参数更新
    当调用 loss.backward() 时,梯度计算与参数更新流程如下:
    • 各GPU独立计算梯度 :每个GPU会基于自己处理的子数据和子输出,独立计算模型参数的梯度(存储在各自GPU的模型副本中)
    • 梯度汇总到主GPU :DataParallel 会自动将所有GPU上的梯度求和(sum)并汇总到主GPU的模型中(即 model.module 的参数梯度)
    • 主GPU更新参数 :优化器(如 optimizer.step())仅作用于主GPU上的模型参数,完成参数更新
    • 参数同步到其他GPU :主GPU更新后的参数会被自动广播(broadcast)到其他GPU的模型副本中,确保所有GPU的模型参数保持一致,为下一轮计算做准备
  • 5)特殊细节与注意事项
    • 模型属性访问 :由于原始模型被包装在 DataParallel 的 module 属性中,访问原模型的层或参数时需通过 model.module(例如,model.module.fc 而非 model.fc)。若仅使用单GPU,DataParallel 仍会包装模型,因此建议始终通过 model.module 访问原模型
    • 单GPU场景 :若只有1个可见GPU,DataParallel 不会进行拆分计算(本质上是单GPU运行),但仍会包装模型,此时 model.module 与原始模型等价
    • 数据类型与设备兼容 :输入数据必须与主GPU设备兼容(例如,若主GPU是 cuda:0,输入数据需为 cuda:0 上的Tensor),否则会触发设备不匹配错误
    • 局限性 :DataParallel 是单进程多线程模式,受Python GIL限制,多GPU效率可能不如多进程的 DistributedDataParallel(DDP),且不支持跨节点并行

DistributedDataParallel(DDP)使用示例

  • 更详细的使用说明见:官网说明文档
  • DDP是多进程模式,支持单机/多机多GPU,效率更高,是PyTorch推荐的分布式训练方式

第一步:编写训练脚本

  • 注:这一步只是定义脚本,这个脚本不能通过简单的 python 命令启动(需通过torch.distributed.launch启动)
  • DDP 的示例脚本如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    import os

    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torch.utils.data import DataLoader, Dataset
    from torch.utils.data.distributed import DistributedSampler # 用于DDP的数据采样

    class DiyModel(nn.Module):
    def __init__(self):
    super().__init__()
    self.fc = nn.Linear(10, 2)

    def forward(self, x):
    return self.fc(x)

    class DiyDataset(Dataset):
    def __len__(self):
    return 1000

    def __getitem__(self, idx):
    return torch.randn(10), torch.randint(0, 2, (1,)).item()

    # 初始化分布式环境
    dist.init_process_group(backend='gloo') # 多GPU推荐用'nccl'后端(NVIDIA专门优化过),CPU使用'gloo'或'mpi',亲测Mac上'gloo'可直接使用
    rank = dist.get_rank() # 全局进程编号(0,1,2...),也可以用 env_rank = int(os.environ["RANK"])
    world_size = dist.get_world_size() # 也可以使用 world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"]) # 当前进程的local_rank

    # 定义设备并设置
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(device) # 为当前进程分配 GPU

    # 数据准备
    dataset = DiyDataset()
    sampler = DistributedSampler(dataset) # 重点:sampler确保各进程数据不重叠
    dataloader = DataLoader(dataset, batch_size=32, sampler=sampler) # sampler传入时,shuffle参数不生效

    # 模型初始化(每个进程单独初始化,再用DDP包装)
    model = DiyModel().to(device)
    model = torch.nn.parallel.DistributedDataParallel( # 核心步骤
    model,
    device_ids=[local_rank],
    output_device=local_rank
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # 训练循环
    for epoch in range(3):
    sampler.set_epoch(epoch) # 每个epoch打乱数据,避免每个epoch数据顺序一致(epoch内部顺序不随机时容易导致模型学到错误的样本顺序规律)
    for inputs, labels in dataloader:
    # 使用to(device)将数据移动到指定设备
    inputs = inputs.to(device)
    labels = labels.to(device)

    # 前向过程,会同时追踪需要的同步的数据
    outputs = model(inputs)
    loss = criterion(outputs, labels)

    optimizer.zero_grad()

    # 后向过程
    # 注:DDP 在初始化时会为模型参数注册特殊的钩子(hook),这些钩子会在 backward() 过程中自动触发
    # 当本地梯度计算完成后,钩子会启动 All-Reduce 操作,将所有进程的同一份参数的梯度进行汇总(默认求平均)
    # 完成上述流程后,每个进程上的同一份参数会拥有相同的梯度值,存储在 `grad` 属性中
    loss.backward()
    optimizer.step() # 普通的 optimizer 更新每个进程自己的参数

    # 只在主进程(rank=0)打印信息
    if rank == 0:
    print(f"Epoch {epoch}, Loss: {loss.item()}")

    dist.destroy_process_group() # 销毁进程组
附录:DDP 的 DP 间求平均的细节
  • DDP 中在 DP 间累积梯度后,做了平均,具体实现参见 github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/reducer.cpp

    1
    2
    3
    4
    5
    // 取值与 DP_size 有关(注意: 这里的 size 就是 DDP 中的 world_size,也就是 DP_size)
    div_factor_ = process_group_->getSize();
    ...
    // 做除法
    bucket_view.div_(div_factor_);
  • 问题:为什么 DDP 中,loss.backward() 要对梯度求平均而不是求和?

    • 理解:本质应该是一样的,求和求平均都可以,目前实现本质也是先求和再做除法
  • 特别说明:如果是想做 Token 粒度的平均(每个样本的可学习 Token 数不一致),需要多维护一个 Token 数量的变量并执行一次 all_reduce 通信

    • 当然,为了实现与不做 DP 完全一致的效果,这里其实是应该对 Token 也做聚合,再做除法才行的

第二步:启动脚本(通过 torch.distributed.launch 命令)

  • 假设脚本名为ddp_demo.py

  • 单机4卡,使用4个GPU启动:

    1
    python -m torch.distributed.launch --nproc_per_node=4 ddp_demo.py
    • --nproc_per_node:指定单机的GPU数量
  • 多机多卡,2个机器,各有4个GPU启动:

    1
    2
    3
    4
    5
    # 机器1
    python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="localhost" --master_port=29500 DDP_demo.py

    # 机器2
    python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="localhost" --master_port=29500 DDP_demo.py
    • --nnodes:总节点数
    • --node_rank:当前节点编号
    • --master_addr:指定master节点的IP地址,默认值是”localhost”,但是显示给出来更明确
    • --master_port:指定master节点的端口号,默认值是 29500,但是显示给出来更明确

补充:torchrun 命令启动脚本(结论是在 Mac 系统下得到的,暂时没有在 Linux 上尝试)

  • torchrun 启动单机多卡(与 torch.distributed.launch 相同):

    1
    2
    3
    4
    5
    # 回顾 `torch.distributed.launch` 的启动方式
    python -m torch.distributed.launch --nproc_per_node=4 ddp_demo.py

    # torchrun 的启动方式:替换 `python -m torch.distributed.launch` 为 `torchrun` 即可
    torchrun --nproc_per_node=4 ddp_demo.py
  • torchrun 启动多机多卡(相对 torch.distributed.launch而言,torchrun 会自动推断部分参数(如节点数、进程数))

    • (待确认:Mac 系统下失败)torchrun 无需手动指定 node_rank,只需在所有节点上指定相同的 --rdzv_id(任务ID)和 --rdzv_endpoint(主节点地址):

      1
      2
      # 所有节点统一执行相同命令(不需要区分不同节点使用不同命令,会自动分配node_rank)
      torchrun --nproc_per_node=4 --nnodes=2 --rdzv_id=123 --rdzv_backend=c10d --rdzv_endpoint="localhost:29500" DDP_demo.py
      • --rdzv_id:任务唯一ID(任意整数)
      • rdzv_backend:后端(默认c10d)
      • rdzv_endpoint:主节点IP:端口
      • torchrun还可以用 --max_restarts 等参数指定最大重启次数
    • (Mac 系统下成功)在 Mac 系统下,上面的命令执行会出现问题(上面命令在两个窗口分别打开)

      1
      2
      torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="localhost" --master_port=29500 DDP_test.py
      torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="localhost" --master_port=29500 DDP_test.py
      • 亲测上面的代码可以成功
  • 更多启动详情见:/Notes/PyTorch/PyTorch——分布式程序启动方式汇总


HuggingFace Accelerate 介绍及使用示例

  • 详情见:/Notes/PyTorch/PyTorch——HF-Accelerate使用总结

HuggingFace Trainer 介绍及使用示例

  • 详情见:/Notes/NLP/NLP——HF-Trainer使用总结

PyTorch——分布式程序启动方式汇总


整体介绍

  • torch.distributed.dist.init_process_group 是 PyTorch 分布式 包中用于初始化进程组的核心函数,它在分布式训练中负责协调多个进程之间的通信
  • 本文重点讲解 torch.distributed.dist.init_process_group 函数的使用

dist.init_process_group 函数使用注意事项

  • 必须在所有进程中调用该函数,且参数需保持一致(除 rank 外)
  • 初始化后需调用 dist.destroy_process_group() 进行清理,否则在复杂的程序中,容易出现资源泄露问题
  • 实际使用中推荐通过 torch.distributed.launch 或 torchrun 工具启动,他们会自动设置环境变量
    • 使用 torch.multiprocessing.spawn 启动则需要自己管理参数或环境变量
  • 不同后端有不同的适用场景:GPU 集群优先用 nccl,CPU 集群用 gloo(Mac 上实验使用 gloo)
  • 确保所有进程能够访问到 init_method 指定的地址或文件
  • 初始化时的 IP 问题,有两个特点:
    • MASTER_ADDR 必须是 rank=0 的机器所在的 IP(torchrun --node_rank=0的机器所在的 IP),该进程负责作为 master 完成交流和初始化操作
    • 在 dist.init_process_group 执行后所有进程的地位是等价的,都可以作为 master (虚拟的含义:处理主要事务)
      • 比如,dist.broadcast(tensor, src) 中的 src 可以被指定为任意值
      • 亲测,在单机启动后,想用哪个进程作为 master(处理主要事务)都可以
    • 一般建议使用 rank=0 的进程作为 master 处理主要事务

函数原型及其参数讲解

  • dist.init_process_group 函数原型
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    torch.distributed.init_process_group(
    backend: Optional[str] = None,
    init_method: Optional[str] = None,
    timeout: Optional[timedelta] = None,
    world_size: int = -1,
    rank: int = -1,
    store: Optional[Store] = None,
    group_name: str = "",
    pg_options: Optional[Any] = None,
    device_id: Optional[torch.device] = None,
    )

dist.init_process_group 主要参数说明

  • 1)backend(必填)

    • 指定通信后端,决定了进程间通信使用的底层协议
    • 可选值:'nccl'(推荐 GPU 通信)、'gloo'(CPU 和 GPU 均可)、'mpi'(需 MPI 库支持)
    • 注意:'nccl' 仅支持 GPU 且性能最优,'gloo' 对 CPU 支持更好
  • 2)init_method(可选)

    • 指定进程组的初始化方式,常用方式如下面
    • 'env://'(推荐):从环境变量读取配置(需设置 MASTER_ADDR、MASTER_PORT 等)
    • 'tcp://ip:port':指定主节点的 IP 和端口
    • 'file:///path':通过共享文件系统初始化(需所有进程可访问该路径)
    • init_method 参数的默认值为 None,当使用默认值时,其行为取决于是否设置了环境变量 TORCH_DISTRIBUTED_INIT_METHOD:
      • 如果设置了 TORCH_DISTRIBUTED_INIT_METHOD 环境变量
        • 函数会自动使用该环境变量的值作为初始化方法(等价于显式传入 init_method=环境变量值)
        • 例如,若环境变量设置为 tcp://127.0.0.1:23456,则会以该 TCP 地址进行初始化
      • 如果未设置 TORCH_DISTRIBUTED_INIT_METHOD 环境变量
        • 此时会触发 默认初始化逻辑 ,函数会尝试从环境变量中读取分布式配置(等价于 init_method='env://')
        • 此时要求必须设置以下环境变量才能正常初始化:
          • MASTER_ADDR:主节点的 IP 地址
          • MASTER_PORT:主节点的端口号(需所有进程可访问)
          • WORLD_SIZE:总进程数(可选,部分启动工具会自动设置)
          • RANK:当前进程的全局编号(0 为主进程,可选,部分启动工具会自动设置)
        • 如果这些环境变量未正确设置,会抛出类似 RuntimeError: Expected env:// init method but no MASTER_ADDR or MASTER_PORT found 的错误
          • 但无需担心:使用 torchrun(或 python -m torch.distributed.launch)启动时,若不指定 --master_addr 和 --master_port,则 torchrun 会默认 环境变量设置为 MASTER_ADDR:MASTER_PORT=127.0.0.1:29500
  • 3)world_size(可选)

    • 总进程数,即参与分布式训练的进程总数
    • 使用 'env://' 时可由环境变量 WORLD_SIZE 指定
  • 4)rank(可选)

    • 当前进程的编号(0 到 world_size-1),0 通常为主进程
    • 使用 'env://' 时可由环境变量 RANK 指定
  • 5)timeout

    • 通信超时时间,默认为 30 分钟(1800 秒)
    • 对于耗时较长的操作可适当增大
  • 6)store

    • 与 init_method 参数互斥,不指明 store 时,由 init_method 参数传入的方式决定初始化什么类型的 Store
    • init_method
    • 详情见附录
  • 7)pg_options

    • pg_options 参数用于配置进程组的特定选项,它是一个可选参数,允许用户为不同的后端指定特定的配置选项
    • 该参数的类型通常是Optional[Any],具体的选项内容取决于所使用的通信后端,主要用于:
      • 配置后端特定的优化参数
      • 设置通信超时时间
      • 调整内存使用策略
      • 配置网络相关参数等
    • 举例:
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      import torch.distributed as dist

      # 示例:为NCCL后端配置特定选项
      pg_options = {
      'timeout': 1800, # 30分钟超时
      'init_method_timeout': 300, # 初始化超时
      }

      dist.init_process_group(
      backend="nccl",
      world_size=4,
      rank=0,
      pg_options=pg_options
      )
  • 8)device_id

    • device_id 参数用于将进程”绑定”到单个特定设备,从而实现后端特定优化,是一个 torch.device 类型的可选参数,主要用于 GPU 训练场景
    • NCCL 后端的特殊效果:在 NCCL 后端下,device_id 参数有两个重要影响
      • 1)立即形成通信器 :通信器会立即形成(直接调用ncclCommInit*而非延迟初始化)
      • 2)内存占用优化 :每个进程会在指定的GPU上占用显存,而不是在第一个可访问的GPU上[citation:7]
    • 示例:
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      import torch
      import torch.distributed as dist

      # 指定当前进程使用的GPU设备
      device_id = torch.device(f"cuda:{local_rank}")

      dist.init_process_group(
      backend="nccl",
      world_size=world_size,
      rank=rank,
      device_id=device_id
      )

示例:单节点多进程(使用 torch.multiprocessing)

  • 下面的代码启动单节点时,无需配置环境变量,也不需要特殊的启动命令,使用 python 命令即可
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def init_process(rank, world_size):
    # 设置环境变量
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # 初始化进程组
    dist.init_process_group(
    backend='nccl', # 使用NCCL后端(GPU)
    init_method='env://',
    world_size=world_size,
    rank=rank
    )

    # 后续分布式操作...
    print(f"Process {rank} initialized")

    # 销毁进程组
    dist.destroy_process_group()

    if __name__ == "__main__":
    world_size = 4 # 4个进程
    mp.spawn(init_process, args=(world_size,), nprocs=world_size, join=True)

示例:多节点分布式(通过环境变量配置)

  • 在每个节点上运行的脚本中:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    import torch.distributed as dist
    import os

    # 环境变量通常由分布式启动工具设置
    # 如 torch.distributed.launch 或 torchrun
    dist.init_process_group(backend='nccl')

    # 可通过以下方式获取当前进程信息
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get('LOCAL_RANK', 0)) # 节点内的进程编号

    print(f"Rank {rank}/{world_size}, Local rank {local_rank}")
  • 上面的代码运行多节点时需要提前指定环境变量(不同机器不同),或通过 torchrun 等命令启动


使用时的一些常规规范写法

  • 分布式启动后,常用到下面三个参数:

    1
    2
    3
    rank: 当前进程在的全局进程编号
    word_size: 分布式系统总进程数
    local_rank: 当前进程在当前节点上的本地进程编号
  • (标准用法)当使用 torchrun 命令(或 torch.distributed.launch 命令时):

    • 可通过环境变量获取这三个参数(此时是默认设置的)

      1
      2
      3
      rank = int(os.environ["RANK"])
      world_size = int(os.environ["WORLD_SIZE"])
      local_rank = int(os.environ["LOCAL_RANK"])
    • 此时启动环境,可以仅指定后端即可

      1
      dist.init_process_group(backend='nccl')
    • 启动命令需要在不同的机器上执行以下,且在命令中传入当前机器对应的参数(自动转化成环境变量),比如:

      1
      2
      3
      4
      5
      # 机器1
      python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="localhost" --master_port=29500 DDP_demo.py

      # 机器2
      python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="localhost" --master_port=29500 DDP_demo.py
      • 易错点,要非常小心:
        • python -m torch.distributed.launch(或 torchrun) 的启动参数必须在脚本名 DDP_demo.py 之前
        • DDP_demo.py 之后的参数都是传递给 DDP_demo.py 的,不再是 python -m torch.distributed.launch(或 torchrun)的参数
        • 不建议使用环境变量的方式定义 --master_port 等参数(包括 MASTER_PORT=12368 torchrun xx.py 等方式也不建议)
          • 因为 torchrun 会自动覆盖环境变量 MASTER_PORT(即使没有显示传入 --master_port 参数也会用默认值 27500 覆盖)
  • (使用较少)当使用 torch.multiprocessing.spawn 启动多进程时,需要自己主动管理环境变量,或通过参数传入

  • 无论使用哪种方式使用,均可使用下面的代码获取参数

    1
    2
    3
    rank = dist.get_rank() 
    world_size = dist.get_world_size()
    local_rank = int(os.environ["LOCAL_RANK"]) # 若未配置则需要使用参数显示传入

通过 torch.multiprocessing.spawn 启动完整示例

  • 在代码里面使用 torch.multiprocessing.spawn 函数启动多进程(不使用任何环境变量,整体管理较为复杂)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def train_fn(local_rank, args, world_size):
    rank = args.start_rank + local_rank # 全局进程编号 = 起始rank + 本地进程编号
    print(rank) # 输出全局进程编号
    # 初始化分布式进程组
    dist.init_process_group(
    backend="nccl",
    init_method=f"tcp://{args.master_addr}:{args.master_port}", # 多机多卡需要指定服务器地址,不能写死成 local
    world_size=world_size,
    rank=rank, # 全局进程编号 = 起始rank + 本地进程编号
    )
    # 绑定本地GPU
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    # ...

    # 对于多机多卡,为了复用同一套代码(方便管理),使用传入参数的方式启动代码,方便传参实现不同 node 机器的启动
    parser = argparse.ArgumentParser()
    parser.add_argument("--master_addr", type=str, default="localhost", help="主节点IP")
    parser.add_argument("--master_port", type=str, default="12355", help="主节点端口")
    parser.add_argument("--world_size", type=int, required=True, help="总进程数")
    parser.add_argument("--start_rank", type=int, required=True, help="当前机器进程的起始全局rank")
    parser.add_argument("--num_gpus", type=int, default=torch.cuda.device_count(), help="当前机器的GPU数量")
    args = parser.parse_args()
    print(args)
    # 启动当前机器的所有进程(num_gpus个),每台机器启动自己的进程数即可
    mp.spawn(
    train_fn,
    args=(args, args.world_size), # 传递给train_fn的参数
    nprocs=args.num_gpus, # 进程数=当前机器的GPU数
    join=True # 等待所有子进程完成
    )
  • 此时启动命令可以仅使用普通的 python 命令(后面的参数可选)

    1
    python demo.py --xxx xx
    • 每个节点通过传入不同参数即可指定分布式进程数量,允许不同节点不同数量(start_rank 数和启动时的 num_gpus 参数控制)

通过 torch.distributed.launch 命令启动

  • 假设脚本名为ddp_demo.py

  • 单机4卡,使用4个GPU启动:

    1
    python -m torch.distributed.launch --nproc_per_node=4 ddp_demo.py
    • --nproc_per_node:指定单机的GPU数量
  • 多机多卡,2个机器,各有4个GPU启动:

    1
    2
    3
    4
    5
    # 机器1
    python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="localhost" --master_port=29500 DDP_demo.py

    # 机器2
    python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="localhost" --master_port=29500 DDP_demo.py
    • --nnodes:总节点数
    • --node_rank:当前节点编号
    • --master_addr:指定master节点的IP地址,默认值是”localhost”,但是显示给出来更明确
    • --master_port:指定master节点的端口号,默认值是 29500,但是显示给出来更明确

通过 torchrun 命令启动

  • 以下结论是在 Mac 系统下得到的,暂时没有在 Linux 上尝试

  • torchrun 和 torch.distributed.launch 都是 PyTorch 中用于启动分布式训练的工具

  • torchrun 是 PyTorch 1.10 后推出的新一代工具,旨在替代 torch.distributed.launch

  • 两者的核心区别:

    特性 torch.distributed.launch torchrun
    推出时间 较早版本(已逐步废弃) PyTorch 1.10+ 推出(推荐使用)
    进程管理 依赖用户手动管理进程(如指定 --node_rank 等) 自动管理进程,支持弹性训练(节点故障后自动恢复)
    配置方式 大部分参数需通过命令行传入 支持从环境变量、命令行、配置文件读取参数
    容错能力 无弹性训练支持,进程崩溃后需手动重启 支持弹性训练(--max_restarts 等参数)
    日志管理 日志输出较为基础 提供更规范的日志管理,区分不同进程的输出
  • torchrun 启动单机多卡(与 torch.distributed.launch 相同):

    1
    2
    3
    4
    5
    # 回顾 `torch.distributed.launch` 的启动方式
    python -m torch.distributed.launch --nproc_per_node=4 ddp_demo.py

    # torchrun 的启动方式:替换 `python -m torch.distributed.launch` 为 `torchrun` 即可
    torchrun --nproc_per_node=4 ddp_demo.py
  • torchrun 启动多机多卡(相对 torch.distributed.launch而言,torchrun 会自动推断部分参数(如节点数、进程数))

    • 回顾 torch.distributed.launch 需要手动指定总节点数(--nnodes)和当前节点序号(--node_rank)

      1
      2
      3
      4
      5
      # 机器1
      python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="localhost" --master_port=29500 DDP_demo.py

      # 机器2
      python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="localhost" --master_port=29500 DDP_demo.py
    • (待确认:Mac 系统下失败)torchrun 无需手动指定 node_rank,只需在所有节点上指定相同的 --rdzv_id(任务ID)和 --rdzv_endpoint(主节点地址):

      1
      2
      # 所有节点统一执行相同命令(不需要区分不同节点使用不同命令,会自动分配node_rank)
      torchrun --nproc_per_node=4 --nnodes=2 --rdzv_id=123 --rdzv_backend=c10d --rdzv_endpoint="localhost:29500" DDP_demo.py
      • --rdzv_id:任务唯一ID(任意整数)
      • rdzv_backend:后端(默认c10d)
      • rdzv_endpoint:主节点IP:端口
      • torchrun还可以用 --max_restarts 等参数指定最大重启次数
    • (Mac 系统下成功)在 Mac 系统下,上面的命令执行会出现问题(上面命令在两个窗口分别打开)

      1
      2
      torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="localhost" --master_port=29500 DDP_test.py
      torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="localhost" --master_port=29500 DDP_test.py
      • 亲测上面的代码可以成功
  • 特别地,部分介绍文档会说 torch.distributed.launch 和 torchrun 对代码的使用有一些不同,主要是初始化不同:

    • 初始化方式有下面两种:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      # 方式一:
      rank = int(os.environ["RANK"])
      world_size = int(os.environ["WORLD_SIZE"])
      local_rank = int(os.environ["LOCAL_RANK"])
      dist.init_process_group(backend='gloo', world_size=world_size, rank=rank)

      # 方式二:
      dist.init_process_group(backend='gloo')
      rank = dist.get_rank()
      world_size = dist.get_world_size()
      local_rank = int(os.environ["LOCAL_RANK"])
    • 特别说明:部分文档介绍说 torchrun 只能用方式二, torch.distributed.launch 只能用方式一

    • 亲测:

      • torch.distributed.launch 可用方式一和方式二初始化两种方式均可用,具体使用命令为:

        1
        2
        3
        4
        5
        # 机器1
        python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="localhost" --master_port=29500 DDP_demo.py

        # 机器2
        python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="localhost" --master_port=29500 DDP_demo.py
      • torchrun 可用方式二定义,在使用下面的命令时,方式一方式二均可:

        1
        torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="localhost" --master_port=29500 DDP_test.py
  • 使用建议:

    • 优先使用 torchrun :它是 PyTorch 官方推荐的新一代工具,简化了分布式配置,支持弹性训练,且兼容未来的功能更新
    • 避免使用 torch.distributed.launch :该工具已逐步被废弃,不再添加新功能,仅为兼容性保留

附录:python -m torch.distributed.launch 命令具体在做什么?

  • 本节讲述 python -m torch.distributed.launch,实际上 torchrun 是 python -m torch.distributed.launch 的升级版本,做的事情差不多,但还多了些功能

启动多个进程

  • 根据 --nproc_per_node 参数指定的数量,为每个 GPU 启动一个独立的进程

    • 例如,如果 --nproc_per_node=4,则会在当前节点启动 4 个进程,每个进程绑定到一个 GPU 上
  • 每个进程会独立执行训练脚本(如 train.py),并通过 dist.init_process_group 初始化分布式环境

    • dist.init_process_group 会让当前进程于其他进程简历通信
  • 代码中,通过下面的命令执行即可使用进程自己的 GPU(实现 GPU 的分配):

    1
    2
    3
    4
    import os
    import torch
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
  • 注:这也是在使用普通的 python 命令启动时,需要在代码里自己手动启动多进程的原因

设置环境变量(分布式环境所需的)

  • python -m torch.distributed.launch 命令会设置分布式训练所需的环境变量(如 WORLD_SIZE、RANK、LOCAL_RANK、MASTER_ADDR、MASTER_PORT 等)
  • 每个进程的环境变量不同,RANK、LOCAL_RANK 等环境变量均是有差异的

代码内部在做什么?(非启动命令的工作)

  • 代码内部通过 dist.init_process_group 初始化分布式通信后端(如 nccl 或 gloo),确保所有进程能够协同工作
    • backend:通信后端(如 nccl 用于 GPU,gloo 用于 CPU)
    • init_method:初始化方法(如 tcp:// 指定 master 地址和端口)
  • 代码示例 :
    1
    dist.init_process_group(backend='nccl', init_method='env://')

附录:init_process_group 的 store 参数详细用法说明

  • 在 torch.distributed.init_process_group() 函数中,store 参数是一个可选参数,用于指定分布式进程间通信所使用的键值存储后端
  • 该参数允许用户显式创建和配置存储实例,而不是依赖默认的自动创建机制
  • store 参数的作用:进程间协调,store 参数指定的存储系统用于:
    • 存储分布式训练过程中的元数据
    • 协调各个进程的初始化过程
    • 实现进程间的同步和通信
    • 管理进程组的状态信息

支持的 Store 类型

  • PyTorch 分布式包支持三种主要的键值存储类型,这些 Store 的核心功能一致(同步元数据),但实现方式不同,选择时需根据分布式环境的网络、存储和调度方式决定
  • 下面是最常见的三种 Store 的使用示例
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    import torch.distributed as dist

    # 第一类:TCPStore
    # # server store 持有数据,client store 可以连接到 server store(by TCP)访问数据,
    store = dist.TCPStore(
    host_name="localhost", # 服务器(master)地址,所有参与分布式训练的进程必须能通过此地址访问到 master
    port=12345, # 需确保该端口在主机上未被占用,且所有进程可访问此端口
    world_size=4, # 指明 store users 数,默认为 None(表示没有固定的 store users 数量),注:一般与分布式环境的 world_size保持一致
    is_master=True, # 是否为master节点,只有主进程会启动 TCP 服务器,其他进程(非主进程)会作为客户端连接到主进程
    timeout=300, # 超时时间(秒)
    wait_for_workers=True, # (bool, optional), 默认为 True,是否等待所有 workers 连接到 store 服务器,只有当 world_size 为固定值时生效
    )
    # # 第二类:FileStore
    # store = dist.FileStore("/tmp/distributed_store", world_size=4)

    # # 第三类:HashStore(通常用于单机多进程)
    # store = dist.HashStore()

    # 使用自定义store初始化进程组
    dist.init_process_group(
    backend="nccl",
    world_size=4,
    rank=0,
    store=store
    )

使用 init_method 指定初始化 Store 类型

  • 在 PyTorch 的 dist.init_process_group 中,不指明 store 参数时,init_method 参数决定了分布式进程初始化时使用的 Store 类型
  • 不同的 init_method 对应不同的 Store 实现,用于在进程间同步元数据(如进程编号、通信地址等)
  • 不同初始化方法如下:
    • init_method='tcp://master_ip:port':(对应 TCPStore)
      • 基于 TCP 协议的集中式 Store,需要指定一个主节点(master)的 IP 和端口
      • 主进程会在指定地址创建 TCP 监听,其他进程通过该地址连接主进程,完成元数据交换
      • 适用于大多数分布式场景(单机多卡、多机多卡),无需依赖外部服务
    • init_method='file:///path/to/shared_file'(对应 FileStore)
      • 基于共享文件系统的 Store,所有进程通过读写同一个共享文件同步元数据
      • 要求所有进程可访问同一个共享文件系统(如 NFS、本地文件系统,单机多卡场景常用)
      • 无需网络通信,但依赖文件系统的可靠性和性能
    • init_method='env://'(由环境变量指定的 Store(通常是 TCPStore))
      • 不直接指定 Store 类型,而是通过环境变量(如 MASTER_ADDR、MASTER_PORT、WORLD_SIZE、RANK 等)配置初始化信息
      • 本质上仍会创建 TCPStore,但参数由环境变量而非函数参数传入
      • 常用于容器化环境(如 Kubernetes)或需要动态配置的场景
    • 第三方分布式框架集成(如 Slurm、MPI)(对应 SlurmStore、MPIStore 等(取决于框架))
      • 当使用 Slurm 或 MPI 启动分布式任务时,PyTorch 可自动检测并使用对应框架的 Store
      • 例如,init_method='slurm://' 会使用 SlurmStore,通过 Slurm 的环境变量和接口同步元数据,无需手动指定主节点
  • 总结核心差异总结
    init_method 类型 Store 实现 依赖条件 适用场景
    tcp://master_ip:port TCPStore 网络连通性 通用分布式场景(单机/多机)
    file:///path FileStore 共享文件系统 单机多卡或共享存储的多机场景
    env:// 通常为 TCPStore 环境变量配置 容器化、动态配置场景
    第三方框架(如 Slurm) 框架专属 Store 对应调度框架环境 集群管理系统(Slurm/MPI)

使用 PrefixStore 进行多进程组管理

  • 使用 PrefixStore 可以为不同的进程组创建隔离的命名空间 :
    1
    2
    3
    4
    5
    6
    7
    8
    9
    # 创建基础store
    base_store = dist.TCPStore("localhost", 12345, world_size, is_master=True)

    # 为不同进程组创建前缀store
    pg1_store = dist.PrefixStore("group1_", base_store)
    pg2_store = dist.PrefixStore("group2_", base_store)

    # 使用不同的store初始化不同的进程组
    dist.init_process_group(backend="nccl", store=pg1_store, ...)

注意事项

  • 当同时指定了环境变量和 store 参数时,store 参数会优先使用:

    1
    2
    3
    4
    5
    6
    # 即使设置了环境变量,也会使用显式指定的store
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    custom_store = dist.TCPStore("192.168.1.100", 29500, world_size, is_master)
    dist.init_process_group(backend="nccl", store=custom_store) # 使用custom_store
  • 和 init_method 参数是互斥的,只能指定其中一个:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    # 方法1:使用init_method
    dist.init_process_group(
    backend="nccl",
    init_method="tcp://192.168.1.100:29500"
    )

    # 方法2:使用store参数
    store = dist.TCPStore("192.168.1.100", 29500, world_size, is_master)
    dist.init_process_group(
    backend="nccl",
    store=store
    )
  • 可根据环境动态配置不同 store

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    def create_store(backend, world_size, rank):
    if backend == "nccl":
    # GPU训练使用TCPStore
    return dist.TCPStore(
    host_name=os.environ.get('MASTER_ADDR', 'localhost'),
    port=int(os.environ.get('MASTER_PORT', '29500')),
    world_size=world_size,
    is_master=(rank == 0)
    )
    else:
    # CPU训练可以使用FileStore
    return dist.FileStore("/tmp/dist_store", world_size)

附录:init_process_group 的 device_id 参数详细用法说明

  • device_id 参数用于将进程”绑定”到单个特定设备,从而实现后端特定优化,是一个 torch.device 类型的可选参数,主要用于 GPU 训练场景

  • device_id 参数在 NCCL 后端下有特殊效果:在 NCCL 后端下,device_id 参数有两个重要影响

    • 1)立即形成通信器 :通信器会立即形成(直接调用ncclCommInit*而非延迟初始化)
    • 2)内存占用优化 :每个进程会在指定的 GPU 上占用显存,而不是在第一个可访问的 GPU 上
    • If you want to know NCCL initialization error early, you can also use this field
  • 使用 gloo 后端时,不需要设置 device_id

  • 基本用法

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    import torch
    import torch.distributed as dist

    # 指定当前进程使用的GPU设备
    device_id = torch.device(f"cuda:{local_rank}")

    dist.init_process_group(
    backend="nccl",
    world_size=world_size,
    rank=rank,
    device_id=device_id # 不指定 device_id 默认会在可访问的第一个 GPU 上占用内存?
    )
  • 多 GPU 训练场景

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    import torch
    import torch.distributed as dist
    import os

    def setup_distributed():
    # 获取本地rank
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))

    # 设置当前进程的GPU设备
    torch.cuda.set_device(local_rank) # 根据不同 rank 获取不同的设备
    device_id = torch.device(f"cuda:{local_rank}")

    # 初始化进程组,绑定到特定GPU
    dist.init_process_group(
    backend="nccl",
    world_size=world_size,
    rank=rank,
    device_id=device_id # 绑定到特定设备
    )

    return device_id

附录:init_process_group 中 NCCL 的延迟初始化

  • NCCL 是 NVIDIA 集合通信库(NVIDIA Collective Communications Library)

    • 提供多 GPU / 多节点通信原语(如 all-reduce、broadcast、all-gather、reduce-scatter 等)
    • 针对 PCIe/NVLink 与 NVIDIA 网络优化,用于加速深度学习分布式训练
    • 开源地址:github.com/NVIDIA/nccl
  • 延迟初始化(Lazy Initialization)是指 NCCL 通信器不在 init_process_group() 调用时立即创建,而是推迟到 第一次实际需要进行集合通信操作时才创建

    • 注:若 init_process_group() 函数指定了 device_id 参数,则 NCCL 会立即初始化到当前设备上
  • 立即初始化(指定 device_id 时)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    device_id = torch.device(f"cuda:{local_rank}")
    dist.init_process_group(backend="nccl", device_id=device_id)

    # 在这一行执行时,NCCL立即:
    # 1. 调用 ncclCommInit* 系列函数
    # 2. 创建通信器对象
    # 3. 分配通信缓冲区
    # 4. 建立进程间的通信通道
    # 5. 进行通信拓扑优化
  • 延迟初始化(不指定 device_id 时)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    dist.init_process_group(backend="nccl")  # 没有device_id

    # 在这一行执行时,NCCL只是:
    # 1. 记录进程组的元信息
    # 2. 设置必要的环境变量
    # 3. 但不创建实际的通信器

    # 真正的初始化发生在第一次集合通信时:
    tensor = torch.randn(10).cuda()
    dist.all_reduce(tensor) # 在这里才真正初始化 NCCL 通信器!自动识别需要绑定的 GPU 并初始化,这里会很慢,因为要初始化 NCCL

    dist.all_reduce(tensor) # 这里就很快了,这次不需要初始化 NCCL

    # # 注:以下操作都会触发NCCL初始化:
    # dist.all_reduce(tensor) # 全规约
    # dist.all_gather([tensor]) # 全收集
    # dist.reduce(tensor, dst=0) # 规约
    # dist.broadcast(tensor, src=0) # 广播
    # dist.all_to_all([tensor], [tensor]) # 全到全
    • 延迟初始化可以允许灵活、动态、自动地选择需要的 GPU,但可能会造成一些误解
    • 注意:一旦 NCCL 初始化以后,就绑定了 GPU 了,再切换 GPU 可能会出现错误
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      def multi_device_lazy_init_issues():
      dist.init_process_group(backend="nccl")

      # 问题场景:进程需要使用多个GPU
      tensor_gpu0 = torch.randn(10).cuda(0)
      tensor_gpu1 = torch.randn(10).cuda(1)

      # 第一次通信在GPU 0上
      dist.all_reduce(tensor_gpu0) # NCCL在GPU 0上初始化

      # 尝试在GPU 1上通信
      try:
      dist.all_reduce(tensor_gpu1) # 可能失败或性能差
      except Exception as e:
      print(f"Multi-device issue: {e}")
  • 目前尚无代码可直接检测 NCCL 通信器是否已创建,只能通过第一次通信时间来大致判断

  • 生产环境推荐立即初始化,方便代码阅读;开发时可以使用延迟初始化以获得更快的启动

NLP——Muon

  • 参考链接:
    • 原始博客(Muon 最早出自该博客):kellerjordan.github.io/posts/muon, Muon: An optimizer for hidden layers in neural networks, 20241208
    • 苏神的解读:Muon优化器赏析:从向量到矩阵的本质跨越,注:苏神博客中写的方案与原始博客方案不完全相同,增加了一些技巧
    • Muon 的改进论文:Muon is Scalable for LLM Training, 20250224, Moonshot AI

整体讨论

  • 在 AdamW 已经大行其道的今天(24年底),已经很少有人在优化器上下功夫了,Muon (MomentUm Orthogonalized by Newton-Schulz) 就是其中一个不可多得的优秀方法
  • Muon 最早由 Keller Jordan 2024年12月8日 在其博客 Muon: An optimizer for hidden layers in neural networks 中发表,后在业内引起了广泛讨论

Muon is Scalable for LLM Training 论文内容

Paper Summary

  • 基于矩阵正交化(Matrix Orthogonalization)的 Muon 优化器(2024)在小规模语言模型训练中表现出色,但其在大规模模型上的可扩展性尚未得到验证
  • 论文发现了两种关键技术可以扩展 Muon:
    • (1)加入权重衰减(weight decay)
    • (2)精心调整每个参数的更新尺度
  • 这些技术改进让 Muon 能够直接用于大规模训练,而无需调整超参数
  • 扩展定律实验表明,在计算最优训练条件下,Muon 的计算效率比 AdamW 高约 2 倍
  • 基于这些改进,论文推出了 Moonlight,这是一个使用 Muon 训练、包含 3B/16B 参数的混合专家模型(Mixture-of-Expert, MoE),训练数据量为 5.7T tokens
  • 论文的模型改进了当前的帕累托前沿(Pareto frontier) ,在更少的训练浮点运算(FLOPs)下实现了更好的性能
  • 论文开源了分布式 Muon 实现,该实现内存最优且通信高效(特别地:还发布了预训练、指令微调及中间检查点)

Introduction and Discussion

  • LLM (2024;DeepSeek-2024;2024;2024)的快速发展显著推动了通用人工智能的进步
  • 由于扩展定律(2020;2022)的存在,训练强大的 LLM 仍然是一个计算密集且资源需求高的过程
  • 优化器在高效训练 LLM 中扮演着关键角色,其中 Adam(2015)及其变体 AdamW(2019)是大多数大规模训练的标准选择
  • 近期优化算法的发展显示出超越 AdamW 的潜力(2024;2024;2024;2025;2018a;2018b;2024;2022;2024;2025)
  • 其中,K. Jordan 等人(2024)提出了 Muon ,它通过牛顿-舒尔茨迭代(Newton-Schulz iteration)使用正交化梯度动量(orthogonalized gradient momentum)更新矩阵参数
  • Muon 在小规模语言模型训练中的初步实验表现出色,但正如这篇博客(Muon: An optimizer for hidden layers in neural networks, 20241208)所讨论的,仍存在几个关键挑战未解决:
    • (1)如何将基于矩阵正交化的优化器有效扩展到具有数十亿参数、训练数据量达数万亿 tokens 的大模型;
    • (2)如何在分布式环境中计算近似正交化;
    • (3)此类优化器是否能泛化到不同训练阶段,包括预训练和监督微调(Supervised Finetuning, SFT)
  • 在本技术报告中,论文通过系统性研究解决了这些挑战
  • 论文的工作基于 Muon,同时通过分析解决了其在大规模训练场景中的局限性。论文的技术贡献包括:
    • Muon 有效扩展的分析(Analysis for Effective Scaling of Muon) :
      • 通过广泛分析,论文发现权重衰减对 Muon 的可扩展性至关重要
      • 论文提出了对 Muon 参数级更新规则的尺度调整
        • 这些调整使得 Muon 无需超参数调优即可直接使用,并显著提高了训练稳定性
    • 高效的分布式实现(Efficient Distributed Implementation) :
      • 论文开发了基于 ZeRO-1(2020)风格的分布式 Muon 版本,实现了最优内存效率和降低的通信开销,同时保留了算法的数学特性
    • 扩展定律验证(Scaling Law Validation) :
      • 论文进行了扩展定律研究,比较 Muon 与强基线 AdamW,结果显示 Muon 性能更优(图 1a)
      • 根据扩展定律结果,Muon 在仅需约 52% 的训练 FLOPs 时,即可达到与 AdamW 训练模型相当的性能
  • 论文的全面实验表明,Muon 可以有效地替代 AdamW 作为大规模 LLM 训练的实际优化器,在训练效率和模型性能上均带来显著提升
  • 基于这项工作,论文发布了 Moonlight,这是一个使用 Muon 训练的 16B 参数 MoE 模型,同时开源了实现代码和中间训练检查点,以促进 LLM 可扩展优化技术的进一步研究

Methods

Background

  • Muon 优化器 Muon(2024)是一种针对矩阵参数优化的神经网络优化器

  • 在迭代步 \( t \) 时,给定当前权重 \(\mathbf{W}_{t-1}\)、动量 \(\mu\)、学习率 \(\eta_t\) 和目标函数 \(\mathcal{L}_t\),Muon 的更新规则如下:
    $$
    \begin{split}
    \mathbf{M}_t &= \mu\mathbf{M}_{t-1} + \nabla\mathcal{L}_t(\mathbf{W}_{t-1}) \\
    \mathbf{O}_t &= \text{Newton-Schulz}(\mathbf{M}_t)^{\mathrm{i} } \\
    \mathbf{W}_t &= \mathbf{W}_{t-1} - \eta_t\mathbf{O}_t
    \end{split} \tag{1}
    $$

    • 其中,\(\mathbf{M}_t\) 是第 \( t \) 步的梯度动量(初始时 \(\mathbf{M}_0\) 为零矩阵)
    • 在公式1中,Newton-Schulz 迭代过程(2024)用于近似计算 \((\mathbf{M}_t\mathbf{M}_t^{\mathrm{T} })^{-1/2}\mathbf{M}_t\)
    • 设 \(\mathbf{M}_t\) 的奇异值分解(SVD)为 \(\mathbf{U}\boldsymbol{\Sigma}\mathbf{V}^{\mathrm{T} }\),则 \((\mathbf{M}_t\mathbf{M}_t^{\mathrm{T} })^{-1/2}\mathbf{M}_t = \mathbf{U}\mathbf{V}^{\mathrm{T} }\),即将 \(\mathbf{M}_t\) 正交化
    • 直观上,正交化能确保更新矩阵是同构的,避免权重沿少数主导方向学习(2024)
  • Newton-Schulz 迭代的矩阵正交化(Newton-Schulz Iterations for Matrix Orthogonalization) :公式1通过迭代过程计算

    • 初始时,设:
      $$\mathbf{X}_0 = \mathbf{M}_t / |\mathbf{M}_t|_{\mathrm{F} }$$
      • 注:\(|\mathbf{M}_t|_{\mathrm{F}}\) 是 F 范数,在 PyTorch 中的实现为 M.norm(),定义如下:
        $$ |A|_F = \sqrt{\sum_{i=1}^{m} \sum_{j=1}^{n} |a_{ij}|^2} $$
    • 在每步迭代 \( k \) 中,按以下方式更新 \(\mathbf{X}_k\):
      $$
      \mathbf{X}_k = a\mathbf{X}_{k-1} + b(\mathbf{X}_{k-1}\mathbf{X}_{k-1}^{\mathrm{T} }) \mathbf{X}_{k-1} + c(\mathbf{X}_{k-1}\mathbf{X}_{k-1}^{\mathrm{T} })^{2} \mathbf{X}_{k-1} \tag{2}
      $$
      • 其中,\(\mathbf{X}_N\) 是经过 \( N \) 次迭代后的结果,\( a \)、\( b \)、\( c \) 为系数
      • 为确保公式2正确收敛,需调整系数使多项式 \( f(x) = ax + bx^{3} + cx^{5} \) 在 1 附近有固定点
      • 在Jordan等人(2024)的原始设计中,系数设为 \( a=3.4445 \)、\( b=-4.7750 \)、\( c=2.0315 \),以加速小初始奇异值的收敛(论文沿用这一设置)
    • 原始博客 Muon: An optimizer for hidden layers in neural networks, 20241208 中 Newton-Schulz 算法的实现如下:
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      # Pytorch code
      def newtonschulz5(G, steps=5, eps=1e-7):
      assert G.ndim == 2
      a, b, c = (3.4445, -4.7750, 2.0315)
      X = G.bfloat16()
      X /= (X.norm() + eps)
      if G.size(0) > G.size(1):
      X = X.T
      for _ in range(steps):
      A = X @ X.T
      B = b * A + c * A @ A
      X = a * X + B @ X
      if G.size(0) > G.size(1):
      X = X.T
      return X
  • 范数约束下的最速下降法(Steepest Descent Under Norm Constraints)

    • Bernstein等人(2024)提出将深度学习优化过程视为范数约束下的最速下降
      • 注:最速下降法(Steepest Descent Method)和共轭梯度法(Conjugate Gradient Method, CG)类似,都是求解无约束最优化问题的优化方法
    • 从这一视角看,Muon 与 Adam(2015;2019)的区别在于范数约束的不同:
      • Adam 是动态调整的 Max-of-Max 范数约束下的最速下降,而 Muon 提供的是静态 Schatten-\( p \) 范数约束(Franz,2024)
      • 当公式1精确计算时,Muon 的范数约束为谱范数
      • 神经网络权重作为输入空间或隐藏空间的算子,通常(局部)是欧几里得的(Cesista,2024),因此权重的范数约束应为诱导算子范数(或矩阵的谱范数)
        • 为什么神经网络可以看成算子?
      • 从这个意义上说,Muon 的范数约束比 Adam 更合理

Scaling Up Muon

Weight Decay

  • 虽然 Muon 在小规模模型上表现优于 AdamW(2024),但作者发现当扩展到更大模型和更多数据时,性能提升会减弱
  • 作者观察到权重和层输出的 RMS 值持续增长 ,甚至超出 bf16 的高精度范围,可能损害模型性能
  • 为解决这一问题,论文将 AdamW(2019)的标准权重衰减机制引入 Muon:
    $$
    \mathbf{W}_t = \mathbf{W}_{t-1} - \eta_t(\mathbf{O}_t \color{red}{+ \lambda\mathbf{W}_{t-1}}) \tag{3}
    $$
  • 通过实验对比带和不带权重衰减的 Muon,论文训练了一个 800M 参数、100B token 的模型(约为最优训练 token 量的 5 倍)
  • 图2展示了使用 AdamW、原始 Muon(无权重衰减)和带权重衰减的 Muon 的验证损失曲线
  • 原始 Muon 初始收敛更快 ,但部分权重随时间增长过大 ,可能限制模型长期性能
  • 加入权重衰减后,Muon 的表现优于原始 Muon 和 AdamW,在 “over-train regime”(过训练区域)中实现了更低的验证损失
  • 因此,论文将更新规则调整为公式3 ,其中 \(\lambda\) 为权重衰减比率

Consistent Update RMS

  • Adam 和 AdamW(2015;2019)的一个重要特性是其理论更新 RMS 保持在 1 附近
  • 但 Muon 的更新 RMS 随参数形状变化,如下引理所示:
  • 引理1 :对于形状为 \([A,B]\) 的满秩矩阵参数,其理论 Muon 更新 RMS 为 \(\sqrt{1/\max(A,B)}\)
    • 注:\(A\) 和 \(B\) 矩阵的维度,比如 \(A = 4\) 表示 \(4 \times 4\) 大小的矩阵; 而 \(\max(A,B)\) 是一个数字,即 \(A\) 和 \(B\) 中的较大值
    • 引理1 的证明和变量含义可见 附录A
  • 论文监测了训练中 Muon 的更新 RMS,发现其通常接近上述理论值
  • 这种不一致性在扩展模型规模时可能引发问题:
    • 当 \(\max(A,B)\) 过大(如稠密 MLP 矩阵)时,更新过小,限制模型表征能力;
    • 当 \(\max(A,B)\) 过小(如将 GQA(2019)或 MLA(DeepSeek-2024)中的每个 KV 头视为独立参数)时,更新过大,导致训练不稳定
  • 为保持不同形状矩阵间更新 RMS 的一致性,论文提出对每个矩阵的 Muon 更新乘以 \(\sqrt{\max(A,B)}\) 以抵消引理1 的影响
  • 第3.1节的实验证明这一策略对优化有益

Matching Update RMS of AdamW

  • Muon 专为矩阵参数设计 ,实际训练中 AdamW 用于处理非矩阵参数(如 RMSNorm、LM Head 和 Embedding 参数)
    • 作者希望优化器超参数(学习率 \(\eta\)、权重衰减 \(\lambda\))能在矩阵和非矩阵参数间共享
    • 自回归模型中,LM Head 通常也只包含一个矩阵权重参数(d_model x vocab_size维度大小)吧?
  • 论文提出将 Muon 的更新 RMS 调整至与 AdamW 相近的范围
    • 根据经验观察,AdamW 的更新 RMS 通常在 0.2 至 0.4 之间
    • 问题:为什么要将 RMS 调整到 AdamW 相近范围,其他范围不行吗?
      • 回答:是从实验验出来的,附录 A 中有实验说明
  • 论文基于观察,通过以下公式调整将 Muon 的更新 RMS 缩放至该范围:
    $$
    \mathbf{W}_t = \mathbf{W}_{t-1} - \eta_t(\color{blue}{0.2 \cdot} \mathbf{O}_t \color{blue}{\cdot \sqrt{\max(A,B)}} + \color{red}{\lambda\mathbf{W}_{t-1}})
    $$
  • 这一选择的实证验证见附录A
  • 此外,调整后 Muon 可直接复用为AdamW调优的学习率和权重衰减

Other Hyper-parameters

  • Muon还有两个可调超参数:
    • Newton-Schulz 迭代步数 \( N \) :
      • 实验发现,当 \( N=10 \) 时,迭代结果比 \( N=5 \) 更精确,但性能未提升
      • 因此,论文为效率考虑选择 \( N=5 \)
    • 动量 \(\mu\) :
      • 动量调优未带来一致性能提升,故沿用 Jordan等人(2024)的 0.95

Distributed Muon

ZeRO-1 与 Megatron-LM
  • ZeRO-1(2020)技术将昂贵优化器状态(如主权重、动量)分区存储于集群中
  • Megatron-LM(2020)将 ZeRO-1 集成到其并行设计中
  • 基于 Megatron-LM 的并行策略(如张量并行 TP、流水线并行 PP、专家并行 EP 和数据并行 DP),ZeRO-1 的通信负载从全局收集简化为仅需在数据并行组内收集

    Based on Megatron-LM’s sophisticated parallel strategies, e.g. Tensor-Parallel (TP), Pipeline Parallel (PP), Expert Parallel (EP) and Data Parallel (DP), the communication workload of ZeRO-1 can be reduced from gathering all over the distributed world to only gathering over the data parallel group.

Method

  • ZeRO-1 对 AdamW 高效,因其按元素计算更新
  • 但 Muon 需完整梯度矩阵计算更新,故原始 ZeRO-1 不直接适用于 Muon
  • 论文提出基于 ZeRO-1 的分布式 Muon 方案,称为 Distributed Muon,它在DP上分区优化器状态,并引入两项额外操作:
    • 1) DP Gather :将本地 DP 分区的主权重(大小为模型权重的 1/DP)对应的分区梯度收集为完整梯度矩阵
    • 2) 计算完整更新(Calculate Full Update) :对完整梯度矩阵执行 Newton-Schulz 迭代(如2.1节所述),随后丢弃部分更新矩阵,仅保留与本地参数对应的分区
  • Distributed Muon 的实现如算法1所示,新增操作以蓝色标注

Analysis

  • 论文从多角度对比 Distributed Muon 与经典 ZeRO-1 分布式 AdamW(简称 Distributed AdamW):
  • 内存占用(Memory Usage) :Muon 仅需一个动量缓冲区,AdamW 需两个,故 Muon 的额外内存占用为 AdamW 的一半
  • 通信开销(Communication Overhead) :每设备仅需为本地 DP 分区参数 \(\mathbf{p}\) 执行额外 DP 收集,通信成本低于 \(\mathbf{G}\) 的 reduce-scatter 或 \(\mathbf{P}\) 的 all-gather
    • 此外,Muon 的 Newton-Schulz 迭代以 bf16 执行,通信开销比 fp32 降低 50%
    • 总体而言,Distributed Muon 的通信量为 Distributed AdamW 的 1 至 1.25 倍
  • 延迟(Latency) :Distributed Muon 因额外通信和 Newton-Schulz 迭代,端到端延迟高于 Distributed AdamW。但这并非主要问题,因为:
    • (a)Newton-Schulz 仅需约 5 次迭代即可获得良好结果(见2.2节);
    • (b)优化器导致的延迟仅占模型前向-反向传播时间的 1% 至 3%
    • 此外,技术可进一步降低延迟,比如:
      • overlapping gather and computation
      • overlapping optimizer reduce-scatter with parameter gather
  • 在大规模分布式集群中,Distributed Muon 的延迟开销与 AdamW 相当
  • 论文将很快向开源 Megatron-LM 提交实现 Distributed Muon 的 PR

Experiments

Consistent Update RMS

  • 如第 2.2 节所述,论文的目标是让所有矩阵参数的更新 RMS 保持一致,并与 AdamW 的更新 RMS 匹配
  • 论文通过两种方法控制 Muon 的更新 RMS,并与仅保持与 AdamW 一致 RMS 的基线进行比较:
  • Baseline :论文将更新矩阵乘以 \(0.2 \cdot \sqrt{H}\)(\(H\) 为模型隐藏层大小),以保持与 AdamW 一致的更新 RMS。注意,对于大多数矩阵,\(\max(A,B)\) 等于 \(H\)
    $$
    \mathbf{W}_{t} = \mathbf{W}_{t-1} - \eta_{t}(0.2 \cdot \mathbf{O}_{t} \cdot \sqrt{H} + \lambda \mathbf{W}_{t-1})
    $$
  • 更新归一化(Update Norm) :论文直接对通过牛顿-舒尔茨迭代计算的更新进行归一化,使其 RMS 严格等于 0.2:
    $$
    \mathbf{W}_{t} = \mathbf{W}_{t-1} - \eta_{t}(0.2 \cdot \mathbf{O}_{t} / \text{RMS}(\mathbf{O}_{t}) + \lambda \mathbf{W}_{t-1})
    $$
  • 调整学习率(Adjusted LR) :对于每个更新矩阵,论文根据其形状将学习率缩放 \(0.2 \cdot \sqrt{\max(A,B)}\) 倍:
    $$
    \mathbf{W}_{t} = \mathbf{W}_{t-1} - \eta_{t}(0.2 \cdot \mathbf{O}_{t} \cdot \sqrt{\max(A,B)} + \lambda \mathbf{W}_{t-1})
    $$
  • 分析 :论文设计了实验来说明 Muon 更新 RMS 在训练早期的影响,因为在更大规模的模型训练中,异常行为会很快出现
    • 论文使用第 3.2 节描述的 800M 参数小模型进行实验
    • 当矩阵维度差异较大时,更新 RMS 不一致的问题会更加明显,为了突出这一问题,论文略微修改了模型架构
      • 将 Swiglu MLP 替换为标准的两层 MLP,将其矩阵参数的形状从 \([H, 2.6H]\) 改为 \([H, 4H]\)
    • 论文评估了模型的损失,并监测了一些参数的 RMS,特别是注意力查询和 MLP:
      • 注意力查询(形状 \([H, H]\))
      • MLP(形状 \([H, 4H]\))
      • 论文在 20B token 的训练计划中训练了 4B token 后评估模型
    • 从表 1 中,论文观察到以下几点:
      • 1)更新归一化和调整学习率方法均优于基线;
      • 2)对于形状为 \([H, 4H]\) 的 MLP 权重矩阵,更新归一化和调整学习率得到的权重 RMS 大约是基线的两倍
        • 这是因为 \(\sqrt{\max(H, 4H)} / \sqrt{H} = 2\),因此更新归一化和调整学习率的更新 RMS 大约是基线的两倍;
      • 3)对于形状为 \([H, H]\) 的注意力查询权重矩阵,更新归一化仍然对更新进行归一化,而调整学习率则不会,因为 \(\sqrt{\max(H, H)} / \sqrt{H} = 1\)
        • 因此,调整学习率得到的权重 RMS 与基线相似,而更新归一化的权重 RMS 则与其 MLP 类似;
  • 基于这些发现,论文选择调整学习率方法用于后续实验,因为它的计算成本更低

Scaling Law of Muon

  • 为了与 AdamW 进行公平比较,论文在 Llama 架构的一系列密集模型上进行了缩放定律实验
  • 构建一个强大的基线对于优化器研究至关重要,因此论文按照计算最优训练设置(2022)对 AdamW 的超参数进行了网格搜索(网格搜索实验详见附录 B)
  • 模型架构和超参数的细节见表 2
  • 对于 Muon,如第 2.2 节所述,由于论文已将 Muon 的更新 RMS 与 AdamW 匹配,因此直接复用了 AdamW 基线的最优超参数
  • 拟合的缩放定律曲线见图 3,拟合方程详见表 3
  • 如图 1a 所示,在计算最优设置下,Muon 仅需约 52% 的训练 FLOPs 即可达到与 AdamW 相当的性能

Pretraining with Muon

  • 模型架构 :为了评估 Muon 在现代模型架构中的表现,论文从头开始预训练了一个基于 deepseek-v3-small 架构(2024)的模型 ,因为该模型性能强大且原始结果可作为参考
    • 论文的预训练模型激活参数为 2.24B,总参数为 15.29B(包含嵌入层时为 3B 激活参数和 16B 总参数)
    • 对架构的微小修改详见附录 C
  • 预训练数据 :预训练数据的细节可参考 Kimi k1.5: Scaling Reinforcement Learning with LLMs, 20250603
    • 预训练的最大上下文长度为 8K
  • 预训练过程 :模型训练分为多个阶段
    • 在阶段 1 和 2 中,论文使用 1e-3 的 Auxfree Bias Update Rate,阶段 3 中为 0.0
      • 问题:这里的 Auxfree Bias Update 是什么?
      • 回答:是在 DeepSeek MoE 训练中使用到的无辅助损失负载均衡技巧(在此之前,常规的负载均衡技巧会使用 辅助负载均衡损失 auxiliary load-balancing loss,Auxfree 表示不需要这个辅助负载均衡项)
      • 注:论文训练的模型架构和 Deepseek-v3-Small 模型一致,这一个 2.4B/16B 参数的 MoE 模型,训练了 1.33T token;
    • 所有权重衰减均设为 0.1
    • 更多训练细节和讨论见附录 D
    • 具体训练流程为:
      • 1)0 到 33B token :在此阶段,学习率在 2k 步内线性增加到 4.2e-4,批量大小保持在 2048 个样本;
      • 2)33B 到 5.2T token :在此阶段,学习率从 4.2e-4 以余弦方式衰减到 4.2e-5
        • 批量大小在 200B token 前保持为 2048,之后增加到 4096;
      • 3)5.2T 到 5.7T token(冷却阶段):在此阶段,学习率在 100 步内增加到 1e-4,随后在 500B token 内线性衰减到 0,批量大小保持为 4096。此阶段使用最高质量的数据,重点关注数学、代码和推理任务
  • 评估基准 :论文的评估涵盖四类主要基准,每类设计用于评估模型的不同能力:
    • 英语语言理解和推理 :MMLU(5-shot)(2021)、MMLU-pro(5-shot)(2024)、BBH(3-shot)(2022)、TriviaQA(5-shot)(2017);
    • 代码生成 :HumanEval(pass@1)(2021)、MBPP(pass@1)(2021);
    • 数学推理 :GSM8K(4-shot)(2021)、MATH(2021)、CMATH(2023);
    • 中文语言理解和推理 :C-Eval(5-shot)(2023)、CMMLU(5-shot)(2024)
  • 性能 :论文将使用 Muon 训练的模型命名为“Moonlight”。论文在 1.2T token 处评估 Moonlight,并与以下同规模公开模型进行比较:
    • Deepseek-v3-Small(2024):一个 2.4B/16B 参数的 MoE 模型,训练了 1.33T token;
    • Moonlight-A :与 Moonlight 训练设置相同,但使用 AdamW 优化器
  • 对于 Moonlight 和 Moonlight-A,论文使用了总预训练 5.7T token 中的 1.2T token 中间检查点,此时学习率尚未衰减到最小值,模型也未进入冷却阶段
  • 如表 4 所示:
    • Moonlight-A(论文的 AdamW 训练基线模型)与同类公开模型相比表现强劲
    • Moonlight 的性能显著优于 Moonlight-A,证明了 Muon 的可扩展性
    • 论文观察到 Muon 在数学和代码相关任务上表现尤为突出,鼓励研究社区进一步研究这一现象
  • 当 Moonlight 完全训练到 5.7T token 后,论文将其与同规模的公开模型进行比较,结果如表 5 所示:
    • LLAMa3-3B(2024):一个 3B 参数的密集模型,训练了 9T token;
    • Qwen2.5-3B(2024):一个 3B 参数的密集模型,训练了 18T token;
    • Deepseek-v2-Lite(2024):一个 2.4B/16B 参数的 MoE 模型,训练了 5.7T token
  • 如表 5 所示,Moonlight 在相同 token 数量下优于同类模型
    • 即使与训练数据量更大的密集模型相比,Moonlight 仍具有竞争力
    • 详细比较见附录 E
  • Moonlight 的性能在 MMLU 和 GSM8k 上与其他知名语言模型进一步对比,如图 1b 和附录 E 图 8 所示
  • 值得注意的是,Moonlight 位于模型性能与训练预算的帕累托前沿,优于许多其他规模的模型

Dynamics of Singular Spectrum(奇异谱)

  • 为了验证 Muon 可以在更多样化的方向上优化权重矩阵的直觉,论文对使用 Muon 和 AdamW 训练的权重矩阵进行了谱分析
  • 对于一个具有奇异值(singular values) \(\sigma = (\sigma_{1}, \sigma_{2}, \cdots, \sigma_{n})\) 的权重矩阵,论文计算其 SVD 熵(2000;2007)如下:
    $$
    H(\sigma) = -\frac{1}{\log n} \sum_{i=1}^{n} \frac{\sigma_{i}^{2} }{\sum_{j=1}^{n} \sigma_{j}^{2} } \log \frac{\sigma_{i}^{2} }{\sum_{j=1}^{n} \sigma_{j}^{2} }
    $$
    • 直观上看,singular values 越平均,SVD 熵越大
  • 如图 4 所示,论文可视化了预训练 1.2T token 过程中不同检查点的权重矩阵的平均 SVD 熵
    • 可以看到,在所有训练检查点和所有权重矩阵组中,Muon 的 SVD 熵均高于 AdamW,这验证了 Muon 可以为权重矩阵提供更多样化的更新谱的直觉
    • 这种差异在专家选择的路由权重中更为显著,表明混合专家模型可以从 Muon 中获益更多
  • 此外,论文在附录 F 中展示了 1.2T token 检查点处各权重矩阵的奇异值分布。论文发现,对于超过 90% 的权重矩阵,Muon 优化时的 SVD 熵高于 AdamW,这为 Muon 在探索多样化优化方向上的卓越能力提供了强有力的实证证据

SFT with Muon

  • 本节论文展示了 Muon 优化器在标准 LLM 训练 SFT 阶段的消融研究
  • 论文的结果表明,Muon 带来的优势在 SFT 阶段仍然存在
    • 具体而言,同时使用 Muon 预训练和 Muon 微调的模型在消融研究中表现最佳
  • 然而,论文也观察到,当 SFT 优化器与预训练优化器不同时,Muon 在 SFT 中并未显示出显著优于 AdamW 的优势
    • 理解:SFT 阶段优化器的选择还与预训练阶段优化器的选择有关?
  • 这表明仍有很大的探索空间,论文将其留待未来研究
Ablation Studies on the Interchangeability of Pretrain and SFT Optimizers(预训练和 SFT 优化器互换性的消融研究)
  • 为了进一步研究 Muon 的潜力,论文使用 Muon 和 AdamW 优化器分别对 Moonlight@1.2T 和 Moonlight-A@1.2T 进行了微调
  • 这些模型在开源的 tulu-3-sft-mixture 数据集(2024)上微调了两个 epoch,数据序列长度为 4k
  • 学习率采用线性衰减计划,从 \(5 \times 10^{-5}\) 逐渐降至 0
  • 结果如表 6 所示,Moonlight@1.2T 的表现优于 Moonlight-A@1.2T

SFT with Muon on public pretrained models

  • 论文进一步将 Muon 应用于公开预训练模型 Qwen2.5-7B 基础模型(2024)的 SFT ,使用了开源的 tulu-3-sft-mixture 数据集(2024)
  • 数据集以 8k 序列长度打包,论文采用了余弦衰减学习率计划,从 \(2 \times 10^{-5}\) 逐渐降至 \(2 \times 10^{-6}\)
  • 结果如表 7 所示: Muon 微调模型的性能与 Adam 微调模型相当
  • 这些结果表明,为了获得最佳性能,在预训练阶段应用 Muon 比在监督微调阶段更有效

Discussions

  • 未来研究有几个可能的方向可以进一步探索和扩展当前的发现
  • 将所有参数纳入 Muon 框架(Incorporating All Parameters into the Muon Framework) :
    • 目前,Muon 优化器与 Adam 优化器结合使用,某些参数仍由 Adam 优化
    • 这种混合方法虽然可行,但仍有改进空间
    • 将所有参数优化完全集成到 Muon 框架中是一个重要的研究方向
  • 将 Muon 扩展到 Schatten 范数(Extending Muon to Schatten Norms) :
    • Muon 优化器可以解释为谱范数下的最陡下降法
    • 鉴于 Schatten 范数的广泛适用性和多功能性,将 Muon 扩展到一般 Schatten 范数是一个有前景的方向
    • 这一扩展可能解锁额外的优化能力,并可能产生优于当前基于谱范数实现的结果
  • 理解和解决预训练与微调的不匹配(Understanding and Solving the Pretraining-Finetuning Mismatch) :
    • 在实践中观察到一个显著现象,使用 AdamW 预训练的模型在使用 Muon 微调时表现不佳,反之亦然
    • 这种优化器不匹配对有效利用大量 AdamW 预训练检查点(训练 Muon)构成了重大障碍 ,因此需要进行严格的理论研究
    • 精确理解其底层机制对于设计稳健有效的解决方案至关重要

Conclusions

  • 在本技术报告中,论文全面研究了 Muon 在 LLM 训练中的可扩展性
  • 通过系统分析和改进,论文成功将 Muon 应用于一个 3B/16B 参数的 MoE 模型,训练了 5.7T token
  • 论文的结果表明,Muon 可以有效地替代 AdamW 作为大规模 LLM 训练的标准优化器,在训练效率和模型性能上均具有显著优势
  • 通过开源论文的实现、Moonlight 模型和中间训练检查点,作者希望促进可扩展优化技术的进一步研究,并加速 LLM 训练方法的开发

附录 A Update RMS

引理 1 的证明

  • 不失一般性,考虑正交矩阵 \( U \in \mathbb{R}^{n \times n} \) 和 \( V \in \mathbb{R}^{m \times m} \),其中 \( n \geq m \geq r \)
  • 论文将证明对于 \( X = U_{[:,:r]} V_{[:r,:]} \)(Muon 的更新具有相同形式),其均方根值为 \( \sqrt{r/mn} \)
    • 注:\( X^{n\times m} = {U_{[:,:r]}}^{n \times r} \cdot {V_{[:r,:]}}^{r \times m} \)
  • 根据矩阵乘法的定义:
    $$ X_{i,j} = \sum_{k=1}^{r} U_{i,k} V_{k,j} $$
    • 仅考虑 \(r\) 之前的值
  • 均方根可以表示为:
    $$
    \begin{align}
    \text{RMS}(X^{n\times m})^2 &= \frac{1}{mn} \sum_{i=1}^{n} \sum_{j=1}^{m} X_{i,j}^2 \\
    &= \frac{1}{mn} \sum_{i=1}^{n} \sum_{j=1}^{m} \sum_{k=1}^{r} U_{i,k}^2 V_{k,j}^2 \\
    &= \frac{1}{mn} \sum_{k=1}^{r} \left( \sum_{i=1}^{n} U_{i,k}^2 \right) \left( \sum_{j=1}^{m} V_{k,j}^2 \right) \\
    &= \frac{1}{mn} \sum_{k=1}^{r} 1 \\
    &= \frac{r}{mn}
    \end{align}
    $$
    • 注:\( U \in \mathbb{R}^{n \times n} \) 是正交矩阵,有 \(\sum_{i=1}^{n} U_{i,k}^2 = 1\)
      • 证明:\(U^\top U = I\),从而 \(\sum_{i=1}^{n} U_{i,k}^2 = 1\) 是一个对角线元素
  • 因此,\( \text{RMS}(X) = \sqrt{r/mn} \)
  • 对于常见的满秩矩阵情况,\( r = m \),此时 \( \text{RMS}(X) = \sqrt{1/n} \)

Muon 与 AdamW 的更新均方根一致性

  • 如 2.2 节所述,作者希望匹配 Muon 和 AdamW 优化器的更新均方根
  • 这一假设通过小规模模型实验得到验证(问题:为什么刚好匹配 AdamW 优化器的均方根更好?)
  • 论文将 Muon 的更新均方根设置为 \([0.05, 0.1, 0.2, 0.4, 0.8]\),并以 AdamW 为基线
  • 表 8 展示了在 2k 步(约 20 亿 token)时的损失和代表性权重矩阵的均方根结果
  • 实验表明,0.2 和 0.4 的均方根设置表现相似且显著优于其他设置
  • 这与论文观察到的 AdamW 更新均方根范围(0.2 至 0.4)一致,因此论文选择将 Muon 的更新均方根控制在 0.2

附录 B AdamW Baseline Scaling Law

  • 为确保实验的公平性和准确性,论文在专有数据集上进行了一系列实验,以确定 AdamW 的最优缩放定律参数
  • 这包括在计算预算(FLOPs,\( C \))约束下,确定最优模型大小(\( N \))、训练 token 数量(\( D \))、学习率(\( \eta \))和批大小(\( B \))(2022;2020)
  • 表 9 展示了论文系统参数搜索的结果
  • 超参数搜索 :为系统性地确定 AdamW 基线的最优缩放定律超参数,论文采用了多阶段搜索协议
    • 首先,根据先前研究的经验准则,选择多个计算预算(FLOPs 级别),并初始化模型大小、学习率和批大小
      • 对于每个固定的 FLOPs 约束 ,论文调整模型大小 \( N \) ,同时反向调整训练 token 数量 \( D \) ,以保持 \( C = 6ND \) ,从而探索模型容量与数据效率之间的权衡
      • 每种配置训练至收敛,并记录验证损失以确定 \( N \) 和 \( D \) 的帕累托最优组合
    • 随后,固定最优的 \( N-D \) 对 ,通过网格搜索优化学习率和批大小 ,确保配置的稳定性和收敛性
    • 为减少局部最优并增强鲁棒性,此迭代过程重复 2-3 次,逐步缩小超参数空间
      • 问题:重复 2-3 次的目的是什么?具体重复了哪些步骤?
  • 图 5 进一步展示了优化过程,描绘了不同 FLOPs 预算下损失随训练 token、学习率和批大小的变化情况
    • 每个碗形曲线代表特定 FLOPs 级别的损失曲面,其全局最小值对应最优超参数配置

附录 C Model Architecture

  • Muon 对模型架构无特定要求,论文采用了与 Deepseek-V3-Small(DeepSeek-2024)相似的模型,因为其作为基线模型具有开放的权重
  • 论文在 Moonlight 模型中进行了几处小修改,具体如下:
  • 多 token 预测(Multi-token Prediction, MTP)
    • MTP 在论文的实验中未显示出对预训练的显著益处
    • 为简化,Moonlight 模型未引入 MTP 层
  • 无偏置更新(Auxfree Bias Update)
    • 在 DeepSeek-V3-Small 中,无偏置更新通过以下公式实现:
      $$ b_i = b_i + u \times \text{sign}(e_i) $$
      • \( u \) 为更新比例
      • \( b_i \) 为第 \( i \) 个专家的偏置
      • \( e_i \) 为专家的违反比例
    • 论文略微修改了更新规则:
      $$ b_i = b_i + u \times (\text{sign}(e_i) - \text{sign}(e).\text{mean}()) $$
      • \( \text{sign}(e).\text{mean}() \) 为所有专家违反比例符号的平均值,以控制偏置的幅度,同时不改变 topk 选择逻辑
  • 门控缩放因子(Gate Scaling Factor)
    • Deepseek-V2-Lite 未使用门控缩放因子,而 Deepseek-V3 使用了 2.5 的缩放因子
    • 论文采用 2.44 的缩放因子以控制与密集模型相似的输出均方根
    • 计算门控缩放因子的代码如图 6 所示

附录 D Training Stability

  • 无损失或梯度范数尖峰 :Moonlight 的训练过程非常平稳,未出现损失或梯度范数尖峰
    • 损失和梯度范数曲线如图 7 所示(Moonlight 为蓝色,AdamW 训练的 Moonlight-A 为红色)
  • 最大注意力对数(Max Attention Logit) :在训练过程中,论文观察到尽管训练损失和梯度范数始终保持稳定,但在某些层的初始训练阶段,最大注意力对数(全局批次中最大的对数值)明显上升,超过阈值 100
    • 值得注意的是,AdamW 在控制这一指标上表现更优
    • 为进一步研究这一现象的影响,论文引入了大注意力对数比例指标,定义为批次中超过 100 的注意力对数比例
      • 如图 7 所示,该比例始终保持在较低水平(约 \( 10^{-4} \)),表明极端大的对数值是稀疏的
      • 此外,随着训练的进行,最大对数值逐渐下降,表明优化动态趋于健康
  • RMSNorm 伽马权重衰减(RMSNorm Gamma Weight Decay) :值得注意的是,对 RMSNorm 伽马参数应用权重衰减对确保训练稳定性至关重要,因为它能有效防止每层输出均方根过高

附录 E Comparison with More Expensive Models

  • 表 10 对比了论文的 Moonlight 模型(使用 Muon 优化)与公开可用的更高计算资源训练的模型,包括 LLama3.1-8B(2024)、Gemma-9B(Gemma 2024)和 Qwen2.5-7B(2024)
  • 图 8 展示了 Moonlight 与同类模型在 GSM8k 性能基准上的对比

附录 F Singular Value Distributions of Weight Matrices

  • 论文通过绘制每个矩阵奇异值的降序排列线图来可视化权重矩阵的奇异值分布,并将其归一化为最大值
  • 如图 9 和图 10 所示,论文发现对于大多数权重矩阵,Muon 优化的奇异值分布比 AdamW 更平坦,进一步验证了 Muon 能提供更多样化的更新谱的假设

NLP——BLEU指标和ROUGE指标


BLEU 和 ROUGE 指标整体说明

  • 在 NLP 领域,BLEU (Bilingual Evaluation Understudy) 和 ROUGE (Recall-Oriented Understudy for Gisting Evaluation) 指标广泛应用于机器翻译和文本摘要任务中
  • 核心思路:BLEU 和 ROUGE 通过比较模型生成的文本与人工参考文本之间的相似性来衡量文本质量
  • 这两个指标通常没有严格的、官方的中文名,在中文学术界和工业界,大家普遍直接使用它们的英文缩写
    • BLEU :通常直接称为 BLEU 值 或 BLEU 分数
      • 注:BLEU 全称 Bilingual Evaluation Understudy 可以直译为“双语评估替补”(Understudy 常翻译为“替补”含义),但这种直译并不常用
    • ROUGE :通常直接称为 ROUGE 值 或 ROUGE 分数
      • 注:ROUGE 全称 Recall-Oriented Understudy for Gisting Evaluation 可以直译为“面向召回的摘要评估替补”(Gisting 翻译为“摘要”),这种直译也很少被使用
  • BLEU :精确率(Precision)导向,常用于机器翻译 ,关注生成文本的忠实度
  • ROUGE :召回率(Recall)导向,常用于文本摘要 ,关注生成文本对参考文本信息的覆盖度
  • 两者都是衡量文本相似性的重要指标,但它们的侧重点和适用场景有所不同。在实际应用中,通常会结合使用多种评估指标来全面评估模型性能

BLEU 指标

  • 音标:/bluː/(类似英文单词 “blue” 的发音)
  • BLEU 主要用于评估机器翻译的质量,侧重于精确率 (Precision)
  • 核心思想:机器翻译的文本与人工翻译的参考文本越相似 ,其质量越高
  • 计算方式: BLEU 的计算涉及以下几个步骤:
  • 第一步:N-gram 精度 (N-gram Precision) :
    • 首先,将候选翻译和参考翻译都进行分词
    • 计算不同长度的 N-gram(例如,unigram (1-gram), bigram (2-gram), trigram (3-gram), 甚至 higher N-grams,通常到 4-gram)
    • 对于每个 N-gram 长度 \(n\),计算修改后的 N-gram 精度 \(p_n\)
      • 注:修改后的精度是为了避免候选翻译中重复词语过多而导致分数虚高,比如全都翻译为 \(the\) 这种常见词语,可能导致 N-gram 精度虚假的被为“100%”
      • 修改动作:通过计算候选翻译中与参考翻译匹配的 N-gram 数量,并将其限制(裁剪)为该 N-gram 在任何一个参考翻译(可能有多个参考翻译,命中任意翻译都算正确)中出现的最大次数
    • 公式为:
      $$p_n = \frac{\sum_{C \in \text{Candidates}} \sum_{\text{n-gram} \in C} \text{Count}_{\text{clip}}(\text{n-gram})}{\sum_{C’ \in \text{Candidates}} \sum_{\text{n-gram}’ \in C’} \text{Count}(\text{n-gram}’)}$$
      • \(\text{Count}_{\text{clip}}(\text{n-gram})\) 是 N-gram 在候选翻译中出现并被限制为在参考翻译中最大出现次数的计数
      • \(\text{Count}(\text{n-gram}’)\) 是 N-gram 在候选翻译中出现的总计数
      • 理解:相当于统计候选翻译的所有 N-gram 数量,看命中参考翻译 N-gram 精确度是多少(在参考翻译中存在则视为准确,否则不准确)
  • 第二步:短句惩罚 (Brevity Penalty, BP) :
    • N-gram 精度 倾向于奖励更短的翻译,因为它的分母更小(比如只翻译一个 “the”,N-gram 精度为100%)
    • 为了惩罚那些过短的翻译(即使它们完美匹配了部分 N-gram),引入了短句惩罚
    • 如果候选翻译的总长度 \(c\) 小于参考翻译中最接近的参考长度 \(r\),则应用惩罚
    • 公式为:
      $$BP = \begin{cases}
      1 & \text{if } c > r \\
      e^{(1 - r/c)} & \text{if } c \le r
      \end{cases}$$
      • 长度比参考文本短越多,惩罚越大
  • 第三步:最终 BLEU 分数 :
    • 将不同 N-gram 长度的修改精度取对数,然后加权平均,再乘以短句惩罚
    • 公式为:
      $$\text{BLEU} = BP \cdot \exp\left(\sum_{n=1}^{N} w_n \log p_n\right)$$
      • 其中,\(n\) 是最大 N-gram 长度(通常为 4),\(w_n\) 是每个 N-gram 长度的权重(通常均匀分配)
      • 实践经验:在实际使用时,一些特定的例子中可能出现高阶匹配不上而导致 BLEU 分数为 0(比如 4-gram 匹配数量为0,这时候看起来取对数的精度会几乎为负无穷),此时可以 NLTK 库中通过 smoothing_function 传入一些平滑策略来解决问题
    • 通常,每个 N-gram 长度的权重是 \(\frac{1}{N}\)(例如,对于 1-gram 到 4-gram,每个权重为 0.25)
      • 该值在 NLTK 库的 sentence_bleu 中是可以通过参数修改的

代码实现 (使用 NLTK 库):

  • 基于 NLTK 库实现 BLUE 值统计
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    from nltk.translate.bleu_score import sentence_bleu # 评估句子
    # from nltk.translate.bleu_score import corpus_bleu # 对于 corpus_bleu,可以评估整个语料库
    from nltk.tokenize import word_tokenize
    import nltk

    nltk.download('punkt_tab') # 若本地不存在该库,这里会从远程下载,'punkt_tab' 数据包包含了 Punkt 分词器(Tokenizer)所需的预训练模型数据

    def calculate_bleu(reference, candidate):
    """
    计算单个候选句子的BLEU分数
    :param reference: 一个参考句子列表,每个参考句子是一个词语列表
    例如:[['this', 'is', 'a', 'test'], ['this', 'is', 'test']]
    :param candidate: 一个候选句子,是一个词语列表
    例如:['this', 'is', 'a', 'test']
    :return: BLEU分数
    """
    # NLTK 的 sentence_bleu 函数期望参考是列表的列表,候选是列表
    # 如果只有一个参考,也要将其包装在列表中:[reference_sentence]
    # 如果有多个参考,则直接传入参考列表

    # 示例:
    # reference = [['The', 'cat', 'sat', 'on', 'the', 'mat'], [...], ...]
    # candidate = ['The', 'cat', 'was', 'on', 'the', 'mat']
    # score = sentence_bleu(reference, candidate)

    # 实际使用时,通常对文本进行分词
    tokenized_reference = [word_tokenize(ref) for ref in reference]
    tokenized_candidate = word_tokenize(candidate) # 分词结果:['The', 'cat', 'was', 'on', 'the', 'mat', '.']

    # NLTK 提供了 weights 参数来控制不同 N-gram 的权重,默认是 (0.25, 0.25, 0.25, 0.25)
    score = sentence_bleu(tokenized_reference, tokenized_candidate)

    # weights = [1,0,0,0] # 为不同 n-gram 设置不同的权重
    # score = sentence_bleu(tokenized_reference, tokenized_candidate, weights=weights)
    return score

    # 示例
    references = [
    "The cat is on the mat.",
    "There is a cat on the mat.",
    "A cat sat on the mat."
    ]
    candidate = "The cat was on the mat."

    bleu_score = calculate_bleu(references, candidate) # 实现:先分词,再评估 n-gram
    print(f"BLEU Score: {bleu_score}")

    # BLEU Score: 0.488923022434901

BLEU 使用说明

* BLEU 倾向于精确率,对于短的(已经有惩罚了)、精确匹配的句子可能给出高分,但可能忽略了语义的完整性或流畅性
* 更多的参考翻译通常会提高 BLEU 分数
* BLEU 在单句评估上指标不太稳定(修改单个单词可能出现非常大的变化),更适合评估整个语料库的平均表现

ROUGE 指标

  • 音标:/ruːʒ/(类似法语单词 “rouge”)
  • ROUGE 主要用于评估文本摘要的质量,侧重于召回率 (Recall)
  • 核心思想:模型生成的摘要包含了多少人工参考摘要中的重要信息
  • ROUGE 有多种变体,最常用的是:
    • ROUGE-N :基于 N-gram 的重叠
      • ROUGE-1 :Unigram(单个词)的召回率
      • ROUGE-2 :Bigram(两个词序列)的召回率
    • ROUGE-L :基于最长公共子序列 (Longest Common Subsequence, LCS) 的召回率
      • 它不要求 N-gram 必须连续,但要求保持相对顺序,更能捕捉句子的结构相似性
    • ROUGE-SU :基于跳跃二元组 (Skip-bigram) 和 unigram 的重叠,允许 N-gram 中间跳过词语
  • 计算方式: 以 ROUGE-N 为例(其他变体类似,但匹配方式不同):
  • 召回率 (Recall) :
    $$R_{\text{N}} = \frac{\sum_{\text{n-gram} \in \text{Ref}} \text{Count}_{\text{match}}(\text{n-gram})}{\sum_{\text{n-gram} \in \text{Ref}} \text{Count}(\text{n-gram})}$$
    • \(\text{Count}_{\text{match}}(\text{n-gram})\) 是在候选摘要和参考摘要中都出现的 N-gram 数量
    • \(\text{Count}(\text{n-gram})\) 是参考摘要中 N-gram 的总数量
  • 精确率 (Precision) :
    $$P_{\text{N}} = \frac{\sum_{\text{n-gram} \in \text{Cand}} \text{Count}_{\text{match}}(\text{n-gram})}{\sum_{\text{n-gram} \in \text{Cand}} \text{Count}(\text{n-gram})}$$
    • \(\text{Count}_{\text{match}}(\text{n-gram})\) 同上
    • \(\text{Count}(\text{n-gram})\) 是候选摘要中 N-gram 的总数量
  • F1 分数 (F1-score) :通常使用 F1 分数来综合召回率和精确率
    $$F_1 = \frac{(1 + \beta^2) \cdot P \cdot R}{\beta^2 \cdot P + R}$$
    • \(\beta\) 通常取 1,表示精确率和召回率同等重要,此时 \(F_1 = \frac{2 \cdot P \cdot R}{P + R}\)
  • ROUGE-L (LCS-based): ROUGE-L 的计算基于最长公共子序列 (LCS) 的长度
    • \(LCS(X, Y)\) 表示序列 \(X\) 和 \(Y\) 的最长公共子序列的长度
    • 召回率:
      $$R_{LCS} = \frac{LCS(\text{candidate}, \text{reference})}{\text{length}(\text{reference})}$$
    • 精确率:
      $$P_{LCS} = \frac{LCS(\text{candidate}, \text{reference})}{\text{length}(\text{candidate})}$$
    • F1 分数:
      $$F_{LCS} = \frac{2 \cdot P_{LCS} \cdot R_{LCS}}{P_{LCS} + R_{LCS}}$$

代码实现 (使用 rouge-score 库)

  • rouge-score 是一个常用的 Python 库,用于计算 ROUGE 分数
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    from rouge_score import rouge_scorer

    def calculate_rouge(reference, candidate):
    """
    计算 ROUGE-1, ROUGE-2, ROUGE-L 分数
    :param reference: 参考摘要字符串
    :param candidate: 候选摘要字符串
    :return: 包含 ROUGE 分数的字典
    """
    # scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL', 'rougeLsum'], use_stemmer=True)
    # 通常使用 rouge1, rouge2, rougeL
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=False)
    # use_stemmer=True 可以对词语进行词干化,可能会影响结果,取决于具体需求,例如 "running" 和 "ran" 可能会被视为匹配

    scores = scorer.score(reference, candidate)

    # scores 返回的是一个字典,例如:
    # {'rouge1': Score(precision=..., recall=..., fmeasure=...),
    # 'rouge2': Score(precision=..., recall=..., fmeasure=...),
    # 'rougeL': Score(precision=..., recall=..., fmeasure=...)}

    # 我们可以提取 F-measure 分数
    result = {}
    for key, score in scores.items():
    result[key] = {
    'precision': score.precision,
    'recall': score.recall,
    'fmeasure': score.fmeasure
    }
    return result

    # 示例
    reference_summary = "The quick brown fox jumps over the lazy dog."
    candidate_summary = "The quick brown fox jumps on the log."

    rouge_scores = calculate_rouge(reference_summary, candidate_summary)

    for metric, scores in rouge_scores.items():
    print(f"{metric.upper()}:")
    print(f" Precision: {scores['precision']:.4f}")
    print(f" Recall: {scores['recall']:.4f}")
    print(f" F-measure: {scores['fmeasure']:.4f}")

    # ROUGE1:
    # Precision: 0.7500
    # Recall: 0.6667
    # F-measure: 0.7059
    # ROUGE2:
    # Precision: 0.5714
    # Recall: 0.5000
    # F-measure: 0.5333
    # ROUGEL:
    # Precision: 0.7500
    # Recall: 0.6667
    # F-measure: 0.7059

ROUGE 使用说明

  • ROUGE 更侧重召回率,因此对于生成式摘要任务更为适用,因为摘要通常比原文短,我们更关心模型是否捕捉到了原文的关键信息
  • ROUGE-L 能够更好地处理词序变化 ,因为它基于最长公共子序列
  • ROUGE-N(特别是 ROUGE-1 和 ROUGE-2)是最常用的 ROUGE 变体

附录:NLTK 库中 BLEU 平滑策略

  • method0(): 无平滑 (No smoothing)
  • method1(): 添加epsilon计数 (Add epsilon counts)
  • method2(): 添加1到分子和分母 (Add 1 to both numerator and denominator)
  • method3(): NIST 几何序列平滑 (NIST geometric sequence smoothing)

NLP——稀疏化Attention之SWA-NSA-MoBA

  • 参考链接:
    • 稀疏化Attention:SWA->NSA->MoBA * 假如给我一只AI的文章 * 知乎

整体说明

  • 稀疏化注意力机制(Sparse Attention)旨在解决传统注意力机制(如 Transformer 中的 Full Attention)在长序列处理时计算复杂度高(\(O(N^2)\))的问题
  • 本文是稀疏化注意力机制从 SWA(Sparse Wave Analysis)到 NSA(Native Sparse Attention)再到 MoBA(Mixture of Block Attention)的改进过程
  • 稀疏注意力机制的设计核心:
    • 通过稀疏表示、动态分层策略和块注意力混合,逐步降低了长序列处理的计算复杂度
    • 针对硬件进行优化,NSA 和 MoBA 均针对现代硬件进行了优化,显著提升了计算速度
    • 关注性能与效率的平衡,在保持模型性能的同时,减少计算资源和内存占用

SWA(Sparse Wave Analysis)

  • SWA 是一种基于稀疏波数分析的方法,最初应用于信号处理领域,用于恢复多模态和色散特性。其核心思想是通过稀疏表示(Sparse Representation)来减少计算量,同时保持对关键信息的捕捉。SWA通过压缩感知技术,利用信号的稀疏性,从有限的数据中恢复出完整的频率-波数表示,从而减少计算复杂度
  • SWA 的核心特点 :
    • 利用信号的稀疏性,减少计算量
    • 适用于多模态和频率分散的场景
    • 通过优化策略(如基底追踪去噪)实现高效恢复
  • SWA 主要针对信号处理领域,未直接应用于大语言模型
  • SWA 缺乏对长上下文建模的针对性优化

NSA(Native Sparse Attention)

  • NSA 由 DeepSeek 提出,是一种针对大语言模型的稀疏注意力机制。NSA通过动态分层稀疏策略,结合粗粒度的 Token 压缩和细粒度的 Token 选择,显著降低了长序列处理的计算复杂度,同时保持了模型性能
  • NSA 的核心改进包括:
    • 动态分层稀疏策略 :通过粗粒度压缩和细粒度选择,兼顾全局上下文和局部信息的精确性
    • 硬件优化 :算法设计与现代硬件对齐,显著提升计算速度
    • 端到端可训练 :支持从预训练到推理的全流程优化,减少计算量
  • 在长文本任务和指令推理中,NSA 的性能优于 Full Attention,且计算速度大幅提升
  • NSA 虽然能效率提升显著,但在某些复杂任务中,稀疏策略可能导致信息丢失

MoBA(Mixture of Block Attention)

  • MoBA 由月之暗面提出,是一种混合块注意力机制,灵感来源于 MoE(Mixture of Experts)结构
  • MoBA 通过将注意力计算限制在最相关的上下文块中,进一步优化了计算效率
  • MoBA 支持全注意力和稀疏注意力的自由切换
  • MoBA 的核心改进:
    • 块注意力混合 :将长序列划分为多个块,仅对最相关的块进行计算,减少冗余计算
    • 灵活切换 :支持全注意力和稀疏注意力的动态切换,适应不同任务需求
    • 高效训练 :MoBA v2在短文本和长文本任务中均表现出色,且训练过程稳定
  • 在长上下文建模中,MoBA 显著降低了计算开销,同时保持了模型性能
  • MoBA 还表现出良好的扩展性和稳定性
  • 缺点是块划分策略需要精细设计,否则可能影响模型对全局上下文的理解

NLP——认识RWKV

前言:RWKV 作为挑战 Transformer 架构的国人开源项目,有前景,本文先简单介绍,有时间回来详细补课


整体说明

  • RWKV,全称 Receptance Weighted Key Value,中文名元始智能 ,是一种语言模型架构(由纯中国团队开发的,开源的语言架构)
  • TLDR:RWKV 结合了 RNN 和 Transformer 的优势
    • 传统 Transformer:计算复杂度随序列长度呈现二次方,且随着序列长度变长显存也一直在增长
    • RWKV 的核心思路:通过线性注意力机制和循环结构实现高效的并行训练与推理,同时保持 RNN 的低显存占用和恒定推理速度,还自然地做到了长度外推
  • 作者是 Bo Peng,知乎主页:PENG Bo
  • 评价:RWKV 作为首个中国纯字眼开源的非 Transformer 架构大模型,凭借高效的计算设计和持续的技术迭代,已在自然语言处理领域占据一席之地(开源社区活跃)
  • 其动态状态演化机制(如 RWKV-7)和多语言能力使其在长文本处理和低显存场景上具有显著优势
  • 期待 RWKV 成为替代 Transformer 架构的下一代语言模型架构

RWKV 核心优势和亮点

  • 线性复杂度 :计算复杂度为 \(O(Td)\)(\(T\) 为序列长度,\(d\) 为特征维度),显著低于Transformer的\(O(T^2)\)
    • 支持处理“无限”上下文长度,尤其适合长文本生成和多轮对话
  • 低资源消耗 :显存占用恒定

RWKV 主要架构版本迭代

  • RWKV-1/2/3 :从2021-2022年开始,逐步发布了前置版本,不是很成熟
  • RWKV-4(2023年):首个成熟版本,通过 Token-shift 技术实现循环与并行训练的结合,性能与同规模 Transformer 相当,论文被 EMNLP 2023 收录
  • RWKV-5/6(2024年):引入矩阵值状态和动态机制,提升长序列处理能力,如 RWKV-6-World-14B 在多语言评测中超越 Llama2 13B
  • RWKV-7(2025年):最新架构,采用动态状态演化(Dynamic State Evolution),超越传统注意力范式,支持持续学习和更复杂的上下文理解。例如,RWKV-7-World-2.9B在MMLU测试中得分54.56%,显著优于前代模型

RWKV 发展的时间线

  • 2020 年,BlinkDL 开始研究 Transformer,发现引入显式 decay 和 Token-shift 两个改进方向
  • 2021 年 8 月,RWKV 架构初版 RWKV-V1 被提交到 RWKV-LM 仓库
  • 2022 年,RWKV-V2 版本首次为 RWKV 实现 RNN 模式;2022年底,发布首个模型
  • 2023 年 6 月,RWKV 正式成立商业公司;2023 年 9 月 20 日,开源项目正式加入 Linux 开源基金会;2023 年 10 月,RWKV-4 架构论文被 EMNLP 2023 收录
  • 2024 年 7 月 19 日,RWKV 开源基金会宣布向全球开放 RWKV-6-World-14B 模型(超过 Llama2 13B);12 月,完成数千万人民币天使轮融资
  • 2025 年 2 月 22 日,参加在上海举办的首届 “RWKV-7 架构与未来趋势” 开发者大会
  • 注:目前团队从 3 人扩展至近 20 人,2024 年获天际资本数千万人民币天使轮融资,用于技术迭代和产品落地

RWKV 当前的缺点

  • 提示词敏感性 :基底模型对提示格式较为敏感,需优化输入顺序以提升生成质量
  • 回顾性任务局限 :在需要回溯前文的任务中表现较弱,需通过提示工程或微调弥补

NLP——Agentic-Design-Patterns-阅读笔记

  • 参考链接:
    • 原始书籍地址:Agentic Design Patterns,20250911
    • 中文版 PDF 地址:智能体设计模式
    • 中文在线阅读地址:智能体设计模式(在线阅读)
    • 英文版 PDF 地址:github.com/sarwarbeing-ai/Agentic_Design_Patterns

简单介绍

  • 本书发布日期是 20250911,作者是 Antonio Gulli

前置讨论


第一章:

NLP——ASearcher

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:(ASearcher)Beyond Ten Turns: Unlocking Long-Horizon Agentic Search with Large-Scale Asynchronous RL, 20250811 & 20250813 & 20250910, THU, Ant Research
    • GitHub 开源:github.com/inclusionAI/ASearcher

Paper Summary

  • 论文核心:介绍了一个用于大规模 RL 训练搜索智能体的开源项目 ASearcher
    • 论文的贡献包括一个完全异步的智能体 RL 训练系统和一个用于大规模高质量问答对构建的数据合成智能体
    • 实验验证 ASearcher 在不同的模型规模和评估设置下均优于最先进的开源智能体
      • 包括 Qwen2.5-7B、Qwen2.5-14B 的基础模型
      • QWQ-32B(基于提示)的 LLM 智能体
  • 背景信息:LLM-based LLM 集成了外部工具,在处理复杂的、知识密集型任务方面展现出了卓越的能力
    • 在众多工具选择中,搜索工具 (search tools) 在访问海量外部知识方面扮演着关键角色
  • 问题提出:
    • 但开源智能体在实现专家级的(expert-level)搜索智能 (Search Intelligence)方面仍然存在不足
      • 搜索智能 即解决模糊 Query 、生成精确搜索、分析结果并进行彻底探索的能力
    • 现有方法在可扩展性、效率和数据质量方面存在缺陷
      • 例如,现有 Online RL 方法中的小步数限制(例如 \(\leq 10\))限制了复杂策略的学习
  • 论文的主要贡献包括:
    • (1) 可扩展的完全异步 RL 训练 (Scalable fully asynchronous RL training) ,能够在保持高训练效率的同时实现长视野搜索 (long-horizon search)
    • (2) 一个基于提示的 LLM 智能体 (prompt-based LLM agent) ,能够自主合成高质量且具有挑战性的问答对 (QA),创建大规模 QA 数据集
  • 通过 RL 训练,论文基于提示的 QwQ-32B 智能体取得了显著改进,在 xBench 和 GAIA 上的 Avg@4 指标分别提升了 46.7% 和 20.8%
  • 论文的智能体表现出极长的视野搜索能力,在训练期间工具调用次数超过 40 步,输出 Token 数超过 150k
  • 通过简单的智能体设计且无需外部 LLM,ASearcher-Web-QwQ 在 xBench 和 GAIA 上分别取得了 42.1 和 52.8 的 Avg@4 分数,超越了现有的开源 32B 智能体
  • 论文已在 github.com/inclusionAI/ASearcher 开源论文的模型、训练数据和代码

Introduction and Discussion

  • LLM-based LLM 的最新进展表明
    • 通过利用单个或多个外部工具 (2024; 2025;),Agent 在解决复杂的、知识密集型问题方面具有卓越能力
  • 其中,搜索工具 (search tools) 尤为关键,它使智能体能够访问海量外部知识以增强问题解决能力 (2023; 2024; 2025)。然而,专家级地使用搜索需要高级智能
    • 例如,考虑这个问题:“截至 2024 年 12 月 31 日,中国在 2012 年伦敦奥运会上获得的金牌、银牌和铜牌数量分别是多少?”
      • 这个问题看似简单,但实际上具有挑战性,因为网络上存在相互矛盾的答案(例如,“38 金、27 银、22 铜” vs “39 金、31 银、22 铜”)
    • 一个搜索智能体必须从不同来源中筛选噪声和矛盾的答案,识别冲突的根本原因(例如官方报告中因兴奋剂检测不合格而被取消资格),并最终确定正确答案
  • 具有挑战性的现实世界任务要求智能体能够解决输入 Query 中的高度不确定性、生成精确的搜索 Query 、从海量数据中分析和提取关键见解、解决不一致性并进行深入探索
    • 论文将这种高级能力称为 “搜索智能 (Search Intelligence)”
  • 专有智能体和模型已经通过大规模 RL 训练 (2025;) 展现出复杂的搜索行为迹象
    • 但用于开发搜索智能体的开源方法仍然面临显著限制
    • 一系列工作采用强化学习 或监督微调 方法来激励工具使用能力 (2025;)
  • 基于提示的 LLM 智能体 (prompt-based LLM agents) 可以在无需训练的情况下执行大量工具调用 (2025;)
    • 但在实践中,论文发现现有的 Online RL 方法未能激励复杂且有效的搜索策略
    • 论文还发现基于提示的 LLM 智能体可能会因为 LLM 能力不足而失败
      • 例如无法从噪声网页中精确提取关键信息,以及无法验证错误结论
  • 最近一些工作进一步在基于提示的 LLM 智能体基础上,利用 Offline RL 方法来改进这些智能体 (2025;)
    • 但这种 Offline RL 范式在更广泛的领域中被证明表现不如 Online RL (2024; 2021; 2024)
  • 在数学和代码等推理任务中,Online RL 使得模型能够基于正确性反馈迭代优化推理过程,从而演化出复杂行为 (2025;)
    • 这引出了一个关键问题:Online RL 方法如何有效地在开源智能体中解锁搜索智能?
  • 论文识别了两个阻碍搜索智能体有效进行 Online RL 训练的关键障碍:
    • 搜索步数不足限制了复杂策略的学习
      • 现有工作,例如 Search-R1 (2025),人为限制了搜索步数,例如每条轨迹 \(\leq 10\) 步,这阻止了智能体探索更深的搜索路径
      • 但复杂的 Query 通常需要多轮工具调用和多步推理,这在严格的步数限制下无法学习
    • 缺乏大规模、高质量的问答对 (question-answer, QA pairs):
      • 推理任务的 RL 训练需要丰富、具有挑战性且正确的 QA 对 (2025;)
      • 但大多数现有的用于搜索智能体的开源数据集往往过时(例如 HotpotQA)、过于简化或规模太小,无法通过 RL 激发复杂的搜索行为 (2018; 2020; 2025;)
  • 为了应对这些挑战,论文提出了 ASearcher,一个旨在为搜索智能体实现大规模智能体 RL 训练 (large-scale agentic RL training) 的开源项目。论文的贡献包括:
    • 通过完全异步智能体 RL 训练实现长视野搜索
      • 在批生成 RL 训练系统 (2025;) 中设置较大的步数限制时,批次内的长轨迹很容易导致显著的闲置时间,从而减慢整个训练过程
      • 基于 AREaL (2025),论文的完全异步系统通过将轨迹执行与模型更新解耦,避免了长轨迹阻塞训练
      • 这允许放宽步数限制(例如,128 步/轨迹),使得智能体能够在不牺牲训练效率的情况下探索更深的搜索路径
      • 论文的智能体 ASearcher-Web-QwQ 实现了极长的视野搜索,在 RL 训练期间工具调用次数超过 40 步,生成的 Token 数超过 150k
    • 一个可扩展的 QA 合成智能体
      • 论文设计了一个 LLM-based 智能体,能够自主生成需要多轮工具使用的具有挑战性、不确定性和事实依据的 (challenging, uncertain, and grounded) QA 对
      • 从种子问题开始,该智能体通过模糊关键信息或注入外部事实来迭代地模糊查询 (fuzzes queries) 以增加复杂性
      • 每个构建的问题都经过多阶段验证 (multistage validation) 以确保质量和难度
      • 论文从 14k 个种子 QA 中生成了 134k 个高质量样本,其中 25.6k 个需要借助外部工具来解决
  • 使用 ASearcher,论文在两种设置下训练配备搜索引擎和浏览器的智能体:
    • 从基础模型开始进行 RL 训练 (Qwen2.5-7B/14B),以证明论文的训练流程能够激励强大且可泛化的搜索策略;
    • 微调由强大 LRM (QwQ-32B) 驱动的基于提示的智能体 ,以验证论文的训练流程在微调大规模基于提示的 LLM 智能体时的可扩展性
  • 论文在多跳 QA 基准测试和具有挑战性的基准测试上评估论文的智能体
    • 包括 GAIA (2023)、xbench-DeepSearch (2025) 和 Frames (2024)
  • 仅使用本地知识库训练的 ASearcher-Local-7B/14B,在现实的网络搜索中展现出惊人的泛化能力,并在多跳和单跳 QA 任务上达到了最先进的性能
    • 基于 QwQ-32B 构建的 ASearcher-Web-QwQ 在 xBench-DeepSearch 和 GAIA 上分别取得了 42.1 和 52.8 的 Avg@4 分数,超越了一系列开源智能体
    • 在评估 Pass@4 时,ASearcher-Web-QwQ 在 GAIA 和 xBench-DeepSearch 上分别达到了 70.1 和 68.0
    • 通过 RL 训练,ASearcher-Web-QwQ 在 xBench-DeepSearch 和 GAIA 上分别获得了 46.7% 和 20.8% 的提升
  • ASearcher 提出了一个面向基于 LRM 和 LLM 的搜索智能体的大规模开源在线智能体 RL 流程,通过可扩展的训练和高质量的数据解锁了搜索智能
    • 希望论文的发现不仅能推动搜索智能体的发展,也能为面向复杂现实世界任务的 LLM 智能体带来更广泛的创新启发

Limitations of Existing Open-source Approaches

  • 在本节中,论文针对一个来自 GAIA (2023) 的极具挑战性的问题进行了详细的案例研究
  • 具体来说,论文在图 3 中分析了 Search-R1-32B (2025) 和 Search-o1 (QwQ) (2025)
    • 详细的轨迹在附录 A 中提供

Solution Path of the Sample Question

  • 在图 3 中,论文的案例研究针对一个需要找到具有 4 个未知变量 的特定动物的问题
  • 为了识别正确答案,搜索智能体应首先根据条件“以哥本哈根命名的属 (genus named for Copenhagen)”找出所提及的物种
    • 根据该物种维基百科页面上的引文识别正确的 2021 年文章,然后找出两位提及人物的论文
  • 最终,正确答案应通过交叉引用 2021 年的文章和论文来确定;总而言之,这个例子具有挑战性的原因有几个:
    • 高不确定性 (High Uncertainty): 问题涉及多个未知变量,这些变量可能指向许多不同的实体
      • 例如,“2021 年的文章”可能指向 2021 年发表的任何文章,并且只能通过检查 肺泡物种 (alvei species) 维基百科页面中的“多中心、随机、双盲研究 (multicenter, randomized, double-blind study)”来确定
    • 需要精确的信息提取 (Requirement for Exact Information Extraction): 为了找到答案,智能体应列出网页上提到的所有动物并进行跨文档比较
      • 这要求智能体从海量、充满噪声的网页内容中精确提取关键信息,而不是简单地总结网页
    • 误导性答案 (Misleading Answers): 在解决此任务的过程中,可能会出现多个误导性答案(例如“猪 (pigs)”)
      • 智能体应通过检查所有相关网页和文档中的预期答案来严格验证其结论
  • 现有 Online RL 方法未能学习复杂搜索策略 (Existing Online RL Approaches Fail to Learn Complex Search Strategies)
    • 在图 3 中,Search-R1-32B 无法将复杂 Query 分解为单个组成部分,因此只能进行涉及太多未知信息的模糊 Query
      • 该智能体还存在严重的幻觉 (hallucinations),产生了搜索结果不支持结论
      • 最后,它未能解析所有未知信息
    • 这个案例研究表明,现有的 Online RL 方法仅能激励初级的搜索策略
    • 同样值得注意的是,由于在训练期间步数限制被设置为一个较小的值(例如 4 步)该模型仅表现出较短的工具使用视野
  • 基于提示的 LLM 智能体可能因 LLM 能力不足而失败 (Prompt-based LLM Agents Could Fail Due to Insufficient Capability of the LLM)
    • 在图 3 中,Search-o1 (QwQ) 可以通过大量工具调用找到物种名称,以及 2021 年的文章和相关论文
    • 但在试图寻找答案时,Search-o1 (QwQ) 很容易遗漏关键信息,从而得出错误的结论
      • 即使智能体找到了直接指向正确答案的信息,它仍然会被先前错误的结论所误导
      • 最后,该智能体无法验证先前结论的正确性
    • 这个案例研究揭示,尽管一个未在智能体任务上明确训练的开源模型可以执行大量的工具调用 ,但它无法基于检索到的内容和历史上下文进行专家级的推理
  • ASearcher-Web-QwQ (论文端到端 RL 智能体 ASearcher-Web-QwQ 的搜索策略)
    • 如图 3 所示,ASearcher-Web-QwQ 将复杂 Query 分解为精确的 Query
      • 与 Search-o1 (QwQ) 在每次搜索 Query 后访问大量网站不同,ASearcher-Web-QwQ 专注于一次访问一个网站
      • 问题:这样会不会太慢
    • ASearcher-Web-QwQ 总结了网站的所有相关信息
      • 所有候选答案都被列出并由智能体仔细分析
      • 当搜索结果没有直接指向期望目标时,例如,当使用“Olga Tapia Hafnia alvei animal studies”进行搜索以查找与 Olga Tapia 论文相关的动物时,智能体没有获得明确的信息,但能够通过与其他论文建立联系来推断出正确答案
      • 在找到正确答案“小鼠 (Mice)”后,智能体在报告最终答案之前花费了额外的步数来验证先前的结论
    • 总之,ASearcher 成功训练出了一个展现出专家级搜索行为的搜索智能体 :
      • 不确定性感知推理 (Uncertainty-aware reasoning): 智能体详尽地列出并检查所有不确定实体的可能性
      • 精确的关键信息提取 (Precise Key Information Extraction): 智能体能够从海量、充满噪声的网页内容中识别关键信息
      • 跨文档推理 (Cross-document Inference): 智能体能够通过建立多个文档之间的联系来推断关键结论
      • 基于事实的验证 (Grounded Verification): 智能体通过访问或搜索相关材料来验证先前结论的正确性

ASearcher

  • 论文提出了 ASearcher,一个通过大规模 RL 训练来解锁搜索智能(Search Intelligence)的开源项目
  • 如图 3 所示,ASearcher 训练了一个能够通过彻底解决所有不确定性并执行多轮工具调用来解决复杂问题的搜索智能体
  • 在后续的小节中,论文将介绍 ASearcher 中的智能体设计、训练数据及数据合成智能体,以及完全异步的强化学习训练

Agent Design

  • 论文在 ASearcher 中采用了一种简单的智能体设计,如图 2 所示
  • 工具 (Tools).
    • 给定一个用户 Query ,智能体可以使用两个基本工具:一个搜索引擎和一个网络浏览器
      • 搜索引擎:接收文本 Query 作为输入,并返回相关的摘要片段及其对应的 URL
      • 网络浏览器:接收一个 URL 并返回网页的内容
    • 为了有效解决复杂问题,模型应策略性地结合这些工具,并从海量数据中提取关键信息
  • 网页摘要 (Webpage Summarization).
    • 网页可能包含过长的内容,因此论文利用智能体将网页总结成一个简洁的摘要
    • 在训练时,这个摘要过程也会被优化,允许智能体通过强化学习训练来提高摘要能力
  • 使用基础 LLM 和高级 LRM 实例化 ASearcher (Instantiating ASearcher with Base LLMs and Advanced LRMs).
    • 在 ASearcher 框架内,论文研究了两种搜索智能体的具体实例化方式:
      • 一种是使用基础大语言模型(Base LLMs) ,例如 Qwen2.5-7B/14B;
      • 另一种是使用高级大推理模型(Large Reasoning Models, LRMs) ,例如 QwQ-32B
    • 这两种不同类型的实例化在历史管理和提示(Prompting)方面需要不同的设计选择
      • 对于基础 LLM ,论文遵循先前的工作 (2025;),采用仅追加(append-only)风格的提示方式
        • 从一个系统提示(System Prompt)开始,所有由 LLM 生成的响应、搜索结果和网页摘要都被追加到历史记录中
        • 智能体按时间顺序接收完整的历史记录作为输入,并输出一些推理文本和动作。
        • 这种方法确保了推理时的效率
      • 对于LRM ,LRM 本身已经具备了指令跟随能力
        • 论文使用不同的提示来指导 LRM 进行工具选择、摘要和回答
        • 论文还注意到 LRM 通常会生成长响应,有时历史记录会很长
          • 问题:需要确保输入的紧凑性,以保证 LRM 有足够的预算来生成 Token
          • 解法:在历史记录中,丢弃思维过程,而是保留总结后的想法和工具调用
        • 在提示 LRM 时,只将最近 25k 个字符的历史记录作为附加上下文提供给 LRM
          • 这些简单的设计确保了 LRM 的输入最多为 10k 个 Token
          • 问题:25k 不是已经比 10k 大了吗?
  • 端到端强化学习 (End-to-End Reinforcement Learning).
    • 智能体所有由 LLM 生成的响应,包括思维过程、工具调用和摘要,都是以端到端的方式使用强化学习进行训练的

Training Data

  • 论文的训练数据主要有两个来源
    • 开源数据集:仔细筛选,以确保其难度和质量
    • 合成数据:高质量的问答对(Question-Answer pairs, QA pairs),专门设计用于指导智能体学习可泛化的搜索策略
开源数据 (Open-source Data).
  • 论文从 HotpotQA (2018) 和 2WikiMultiHopQA (2020) 的训练集开始,这两个都是多跳问答数据集
  • 论文采用了基于模型的过滤流程
    • 使用 RL 在完整的开源数据集上训练一个模型,再使用训练好的模型为每个问题生成 16 个响应
    • 最后,论文过滤掉满足以下任一标准的问题:
      • 模型在 16 个响应中未能找到一个正确答案
      • 模型达到了 \(\ge\) 50% 的准确率,意味着问题挑战性不足
      • 模型仅用少量搜索轮次(即 \(\le\) 1 轮)就找到了正确答案
  • 这种过滤方法确保论文只保留最具挑战性但又可解决、且需要使用工具的问题
  • 最终,从总共 304k 个问答对中 ,论文保留了 16k 个具有挑战性的样本用于 RL 训练
  • 此外,论文还纳入了一组专为访问特定网页而设计的问答对
  • 特别是,论文加入了 WebWalkerQA (2025) 的一小部分子集,以帮助模型学习如何在嘈杂的真实网络搜索环境中定位答案
Data Synthesis Agent
  • 论文进一步开发了一个数据合成智能体来创建高质量的问答对
  • 如图 4 所示,数据合成智能体从一个种子问题开始,迭代地修改问题以增加复杂性
  • 为了确保合成的问题与可靠来源严格对齐,在问题合成过程中获得的一系列支持事实(supporting facts)被保留下来,并持续更新以进行质量验证
  • 在每一步,给定当前的问题和一个支持事实列表,智能体自动在以下两个关键动作之间进行选择:
    • 动作 1:注入(Injection) 旨在通过插入与问题相关的事实来丰富问题的上下文
      • 智能体首先选择问题中的一个实体,然后从外部来源(如维基百科)获取关于该选定实体的一条相关事实
      • 接着,通过将该事实注入到问题中,提出一个新的问题
      • 这个注入动作增加了问题的复杂性
    • 动作 2:模糊化(Fuzzing) 模糊问题中的某些细节,以增加问题的不确定性水平
      • 例如,“Catskill Mountain Railroad”(Catskill 山铁路)可能被替换为 “a historic mountain railway”(一条有历史意义的铁路)
      • 通过多次对问题进行模糊化处理,问题的不确定性水平和难度都会逐渐增加
  • 为了确保合成问题的高质量并精确评估其难度,论文为评估合成问题加入了一个严格的质量验证(quality verification)阶段:
    • 步骤 1. 基本质量(Basic Quality). 论文使用一个 LLM 来评估每个问题的基本质量
      • 此验证包括检查问题的清晰度,并根据支持事实验证问答对的准确性
      • 此质量控制步骤确保每个问答对都正确地基于可靠来源
    • 步骤 2. 难度测量(Difficulty Measurement). 论文使用一个前沿的 LRM(例如 QwQ-32B)直接为合成问题生成多个答案,而不使用任何外部工具
      • 此验证过程也作为问题难度的衡量标准
    • 步骤 3. 答案唯一性(Answer Uniqueness). 模糊化动作可能会过度放松约束,损害答案的唯一性
      • 为了防止因多个正确答案而产生的歧义,论文评估在难度测量步骤中生成的任何 mismatched answers 是否可以作为替代的有效答案
  • 论文在表 1 中提供了两个说明性示例。从一个简单的问题开始,注入动作用相关的事实细节替换特定的实体
    • 例如,“Michael P. Hein” 被扩展为 “who served as the first County Executive of Ulster County, New York…”
    • 模糊化动作通过泛化精确信息来引入模糊性,例如将确切的年份 “1934” 替换为 “the early 1930s”,或者将 “Catskill Mountain Railroad” 替换为 “a historic mountain railway”
  • 通过迭代的注入和模糊化,数据合成智能体产生出涉及复杂信息和高不确定性的问题,需要大量的搜索和推理才能找到正确答案
    • 在完成问题合成过程后,论文过滤掉那些 LRM 可以不依赖搜索工具直接生成正确答案的问题
    • 由于这些问题仅基于模型的内在知识就能回答,它们对于增强搜索能力几乎没有价值
  • 从 14,107 个种子问题开始,论文对每个问题平均执行了 6.3 次注入和 3.2 次模糊化
    • 从合成池中,论文为每个种子问题最多选择三个高质量的变体
    • 这个筛选过程产生了包含 25,624 个条目的最终数据集,所选问题平均每个包含 4.27 次注入和 2.10 次模糊化

Asynchronous Agentic RL Training

Challenges of Scaling Up Trajectory Length in RL
  • 实验表明复杂任务需要大量的工具调用,因此具有较大轮次限制的 RL 训练对于训练高级搜索智能体是必要的
  • 训练期间轨迹执行时间的方差很大,这可能导致批量生成 RL 系统出现显著的闲置时间
  • 复杂任务需要长轨迹 (Complex Tasks Require Long Trajectories).
    • 智能体任务通常需要大量的 LLM 生成和多次工具调用来解决复杂问题,导致轨迹执行时间延长
    • 如图 6(左)所示,论文在 GAIA (2023)、xBench-Deepsearch (2025) 和 Frames (2024) 上评估了论文经过 RL 训练的 QwQ-32B 智能体,强制智能体使用不同最小轮次数量的工具
    • 结果表明,准确率随着轮次的增加而提高,证实了复杂任务需要更长的轨迹来进行有效的问题解决
  • 轨迹执行时间的高方差 (High Variance in Trajectory Execution Time).
    • 长轨迹也带来了执行时间的显著方差
    • 论文分析了 QwQ 智能体 RL 训练期间的工具调用次数和 Token 生成数量(图 6),观察到最长的轨迹可能比短轨迹多出数十次工具调用和两个数量级以上的 Token
    • 这种差异导致每个轨迹的运行时间高度不可预测,进一步降低了训练效率
  • 智能体 RL 训练的效率问题 (Efficiency Issues of Agentic RL Training).
    • 长时间的执行和高运行时间方差都会降低 RL 训练效率
    • 论文以 one-step-off RL 训练系统 (one-step-off RL training system,也称为 One-Off,来自 DeepCoder,2025) 作为批量生成 RL 系统的代表性例子
      • 参考链接:DeepCoder: A Fully Open-Source 14B Coder at O3-mini Level
    • 在 one-step-off RL 训练中,第 N 步的训练和第 N+1 步的轨迹生成是并发执行的
    • 如图 7 所示,尽管该系统将轨迹 rollout 与模型训练重叠,但批量生成仍然受限于最慢的轨迹(例如轨迹 7),导致 GPU 闲置时间和利用率不足
完全异步 RL 训练 (Fully Asynchronous RL Training).
  • 为了确保高效的智能体 RL 训练,论文采用了完全异步的训练范式
    • 论文的方法在两个不同方面引入了异步
  • 异步轨迹 Rollout (Asynchronous Trajectory Rollouts).
    • 轨迹 rollout 是并行收集的,并且不直接相互干扰
    • 每个轨迹独立地向相应服务器发送工具调用请求,并向 LLM 推理引擎发送 LLM 生成请求
    • 来自不同轨迹的并发请求由服务器自动处理
    • 完全独立的轨迹执行确保了一个轨迹在生成 LLM 响应和等待工具调用响应时不需要等待其他轨迹,从而提高了训练效率
  • 解耦的 Rollout 和训练 (Decoupled Rollout and Training).
    • 除了异步 rollout 之外,轨迹 rollout 和模型更新也是完全解耦的
    • 在图 7 中,论文将论文的完全异步 RL 训练与 one-step-off RL 训练进行了比较,后者在批次内利用异步 rollout
    • 在完全异步 RL 训练中,长轨迹不会阻塞生成,并且可以跨越多个版本,显著减少了 GPU 闲置时间,并在生成过程中实现了近乎完全的 GPU 利用率
    • 在训练侧,一旦收集到足够的轨迹形成一个批次,就会立即启动一个训练步骤
    • 如图 7 所示,训练过程不会等待极长的轨迹 7,而是继续处理轨迹 9

Training Details

  • MDP 公式化 (MDP Formulation). 论文遵循马尔可夫决策过程(Markov Decision Process, MDP)的公式化
    • 形式上,一个 MDP 由元组 \((S,A,T,R)\) 定义
      • \(S\) 代表状态空间,通常包含历史记录、搜索结果和检索到的网页
      • \(A\) 表示动作空间,一个动作包括智能体生成的 Token
        • 一些工具调用可以通过特定的标签从动作中提取,例如 <search> search query </search>
      • \(T(s^{\prime}|s,a)\) 是转移概率:其中 \(s^{\prime}\) 是在状态 \(s\) 应用动作 \(a\) 中的工具调用后的更新状态
    • 在每个时间步,智能体接收一个状态 \(s_{t}\),并根据策略 \(\pi:S\to A\) 生成一个动作 \(a_{t}\)
    • 智能体的目标是最大化回报
      $$ J(\pi)=\mathbb{E}\left[\sum_{t=0}^{\infty}R(s_{t},a_{t})\bigg{|}a_{t}\sim\pi(s_{t})\right]$$
  • GRPO 训练 (GRPO Training). 论文采用 GRPO (2024) 算法来训练搜索智能体
    • 对于每个输入问题 \(x\),生成 \(G\) 个轨迹 \(\tau_{1},\tau_{2},\cdots,\tau_{G}\)
      $$ \tau_{i}=(s^{i}_{0},a^{i}_{0},s^{i}_{1},\cdots,s^{i}_{T_{i} }) $$
    • 为了优化智能体,论文采用以下损失函数:
      $$
      \begin{align}
      \mathcal{J}_{GRPO}(\theta)=\mathbb{E}_{x\sim\mathcal{D}_{\epsilon}\{\tau_{i}\}_{i=1}^{G}\sim\pi_{\theta_{old} }(:\left|x\right\rangle}\left[\frac{ 1}{G}\sum_{i=1}^{G}\frac{1}{\sum_{t=0}^{T_{i}-1}|a^{i}_{t}|}\sum_{t=0}^{T_{i}-1}\sum_{j=1}^{|a^{i}_{t}|}\min\left(\frac{\pi_{\theta}(a^{i}_{t,j}|s_{t},a^{i}_{t,< j})}{\pi_{\theta_{old} }(a^{i}_{t,j}|s_{t},a^{i}_{t,< j})}\hat{A}_{i},\right.\right. \left.\left.\text{clip}\Bigg{(}\frac{\pi_{\theta}(a^{i}_{t,j}|s_{ t},a^{i}_{t,< j})}{\pi_{\theta_{old} }(a^{i}_{t,j}|s_{ t},a^{i}_{t,< j})},1-\epsilon ,1+\epsilon \Bigg{)}\hat{A}_{i}\Bigg{)}\right]\right.
      \end{align} \tag{1}
      $$
      • 其中 \(\epsilon\) 是一个超参数,\(\hat{A}_{i}\) 是第 \(i\) 个轨迹的优势函数(Advantage),基于每个组内所有轨迹的相对奖励计算得出
  • 动态过滤 (Dynamic Filtering). 为了提高训练效率,论文实施了动态过滤,以排除缺乏有意义的训练信号的 Query
    • 具体来说,论文移除所有响应产生相同奖励(导致优势为零)的 Query ,包括智能体已经达到高准确率的 Query 和答案标记错误的 Query
  • 奖励函数 (Reward Function). 对于奖励函数,论文采用稀疏奖励(Sparse-reward)设置,在轨迹完成时计算奖励
    • 若从基础 LLM 开始训练 ,奖励函数通过乘法结合了格式奖励(Format Reward)和 F1 分数
      • 问题:这里的 F1 分数是什么?是工具调用相关 精确率 和 召回率 的衡量吗?
      • 回答:从下文来看,是的
    • 若基于 LRM 的智能体(例如 QwQ)进行微调,论文使用 LLM-as-Judge (2023; 2024) 作为奖励函数,并省略格式奖励,因为这些模型本身就保持了适当的输出格式

Experiments

Experiment Setup

  • 基准测试 (Benchmarks)
    • 论文首先在单跳和多跳问答任务上评估智能体
      • 对于单跳问题,论文使用 Natural Questions (2019)、TriviaQA (2017) 和 PopQA (2022)
      • 对于多跳问题,论文使用 HotpotQA (2018)、2WikiMultiHopQA (2020)、MuSiQue (2022) 和 Bamboogle (2022)
    • 论文进一步在更具挑战性的基准测试上进行了评估,包括 Frames (2024)、GAIA (2023) 和 xBench-DeepSearch (2025) 作为额外的测试集
      • 从 HotpotQA、2WikiMultiHopQA 和 MuSiQue 的验证集中随机抽取 1000 个实例进行评估
      • 对于 Bamboogle、Frames、GAIA 和 xBench-DeepSearch,论文使用其完整的测试集
      • 对于 GAIA,论文使用来自纯文本验证子集 (2025) 的 103 个示例
  • 搜索工具 (Search Tools)
    • 论文在两种设置下评估搜索智能体,每种设置使用不同类型的搜索工具
      • 带有 RAG 的本地知识库 (local knowledge base with RAG)的交互:智能体与本地部署的 RAG 系统交互,从一个 Wikipedia 2018 语料库 (2020) 中检索相关信息
      • 基于网络的搜索和浏览 (web-based search and browsing) 的交互:智能体在交互式网络环境中运行,可以访问搜索引擎和浏览器工具
        • 对于更具挑战性的基准测试 GAIA、xBench-DeepSearch 和 Frames,论文仅在此基于网络的设置下进行评估
  • 基线 (Baselines)
    • 论文考虑与两类基准测试相对应的两组基线
      • 对于多跳和单跳问答基准测试,包括 Search-R1(7B/14B/32B) (2025)、R1-Searcher(7B) (2025)、Search-o1(QwQ-32B) (2025)、DeepResearcher (2025) 和 SimpleDeepSearcher (2025)
        • 还直接提示 Qwen-2.5-7B/32B 在不使用任何工具的情况下生成答案
      • 在更具挑战性的基准测试上,论文与强大的 32B 规模模型进行比较,包括直接使用 QwQ-32B 生成、Search-o1(QwQ-32B) (2025)、Search-R1-32B (2025)、WebThinker-QwQ (2025)、SimpleDeepSearcher-QwQ (2025) 和 WebDancer-32B (2025)
      • 所有基线都使用与论文智能体相同的工具进行评估,以确保公平比较
  • Evaluation metrics
    • 论文采用两个互补的评估指标:F1 分数和 LLM-as-Judge (LasJ)
    • F1 分数在词级别(Word Level)计算,衡量预测答案和参考答案之间的精确率和召回率的调和平均数
    • 对于 LLM-as-Judge,论文提示一个强大的 LLM (Qwen2.5-72B-Instruct) 根据特定任务的指令评估模型输出的正确性
    • 在 GAIA、xBench-DeepSearch 和 Frames 上,论文仅使用 LLM-as-Judge 并报告所有模型的 Avg@4 和 Pass@4 分数
  • ASearcher 的训练细节 (Training Details of ASearcher)
    • 轮次限制:7B 和 14B 模型为 32,ASearcher-Web-QwQ 为 128
    • 批次大小:7B 和 14B 模型为 128,ASearcher-Web-QwQ 为 64
    • 论文整理了两组训练数据,一组用于 7B/14B 训练,另一组用于 QwQ-32B 训练
      • 这两个数据集大小均为 35k 并已开源
      • ASearcher-Web-QwQ 的训练大约需要 7.6k H800 GPU 小时

Main Results

  • 论文在三种评估设置下展示了主要的实验结果:
    • (1) 在标准问答基准测试上使用带有检索增强生成 (RAG) 的本地知识库
    • (2) 在相同基准测试上使用基于网络的搜索和浏览
    • (3) 在更具挑战性的基准测试上使用基于网络的搜索和浏览
  • ASearcher ,实例化为 Qwen2.5-7B、Qwen2.5-14B 和 QwQ-32B,在 F1 和 LasJ 指标上始终优于相同模型规模的现有开源智能体
    • ASearcher-14B 在一系列多跳和单跳问答基准测试上取得了 7B、14B 和 32B 模型中的最佳性能,并且 ASearcher-QwQ 在这些具有挑战性的基准测试上显著优于几个规模相当的有力基线
    • 这些结果突显了 ASearcher 在不同任务和模型规模上的通用性和可扩展性
  • 在标准问答基准测试上使用带有 RAG 的本地知识库 (Local Knowledge Base with RAG on Standard QA Benchmarks)
    • 如表 2 所示,通过强化学习在本地知识库上训练的 ASearcher-Local,在一系列多跳和单跳问答基准测试上,在 7B 和 14B 规模上均取得了最佳性能
      • 在 7B 设置下,ASearcher 的平均 F1 达到 58.0 ,优于 Search-R1-7B (54.3) 和 R1-Searcher-7B (52.2) 等强基线
        • 其 LasJ 分数也达到 61.0 ,显著优于 Search-R1-7B (55.4) 和 R1-Searcher-7B (54.7)
      • 在 14B 规模上,增益更为显著,ASearcher-Local-14B 的 F1 达到 60.0 ,LasJ 达到 65.6 ,甚至超过了更大的 32B 基于检索的基线 Search-R1-32B
  • 在标准问答基准测试上使用基于网络的搜索和浏览 (Web-based Search and Browsing on Standard QA Benchmarks)。
    • 在表 3 中,论文在现实的基于网络的环境中评估智能体
    • 论文以 zero-shot 方式评估完全使用本地知识库训练的模型在网络设置中的表现,以直接检验通过 RL 学习的搜索策略的泛化能力
      • 在所有模型规模上,ASearcher 始终优于强基线
      • ASearcher-Web-14B 取得了最佳性能,平均 F1 为 61.5 ,超过了在此设置下最强的 32B 基线 SimpleDeepSearcher
      • ASearcher-Local-14B 模型在网络设置下测试时表现出强大的泛化能力,在 LasJ 指标上相对于相似或更大规模的所有基线模型均取得了显著增益
      • 这证实了 ASearcher 学习了可迁移到不同信息源的通用搜索策略
  • 在具有挑战性的基准测试上使用基于网络的搜索和浏览 (Web-based Search and Browsing on Challenging Benchmarks)
    • 表 4 显示了在需要高级问题解决能力和搜索策略的具有挑战性的问答任务上的实验结果
      • 这些基准测试专门设计用于评估智能体与真实网络交互并检索超出 LLM 内部知识的最新信息的能力
      • 因此,直接从模型(例如 QwQ-32B)生成答案在所有数据集上表现都很差
    • 论文的智能体 ASearcher-Web-QwQ 在 GAIA (52.8) 和 xBench-DeepSearch (42.1) 上取得了最佳的 Avg@4 分数
      • 优于之前的开源智能体最优水平
    • 这些结果进一步凸显了其在处理长视野规划、现实世界工具使用和开放领域探索方面的优越性
    • 除了 Avg@4,论文还报告了 Pass@4 分数,该分数计算智能体在 4 次试验中找到正确答案的问题比例
      ASearcher-Web-QwQ 在通过率方面也优于最先进的开源智能体
  • RL 训练的效果 (Effect of RL Training)
    • 如图 8 所示,ASearcher-Web-QwQ 在 GAIA、xBench-DeepSearch 和 Frames 上分别获得了 +9.1、+13.4 和 +12.0 的提升
    • 当考虑通过率(即 Pass@4)时,ASearcher-Web-QwQ 也获得了显著增益,尤其是在 xBench-DeepSearch 上提升了 17.0
    • 通过率的显著提升表明论文的训练流程训练智能体学习复杂的搜索策略,以执行精确搜索、提取关键信息并解决冲突信息

Training Dynamics

  • ASearcher-Local-7B/14B 的训练动态 (Training Dynamics of ASearcher-Local-7B/14B)
    • 在图 9 和图 10 中,论文分别绘制了 ASearcher-Local-7B 和 ASearcher-Local-14B 训练过程中生成的 Token 数量、搜索 Query 和网页浏览情况
    • 使用论文的训练方法,在 7B 和 14B 规模上都观察到了生成长度和工具调用次数的增加
      • 搜索 Query 次数扩展到 6 次,高于先前工作 (2025;) 报告的数字
    • 有趣的是,论文发现 7B 模型未能学习有效的网页浏览 ,而 14B 模型可以在训练后期学习访问网页来解决具有挑战性的问题
      • 论文假设 7B 模型在学习网页浏览方面的失败是因为模型容量太小 ,无法在零 RL 训练设置中稳定地学习总结冗长的网页
  • ASearcher-Web-QwQ 的训练动态 (Training Dynamics of ASearcher-Web-QwQ)
    • ASearcher-Web-QwQ 的训练动态如图 6 所示
    • 随着训练的进行,智能体学会执行更多的工具调用,在第 200 步左右达到约 40 次调用,峰值实例甚至达到 70 次调用
    • QwQ-32B 智能体通过训练生成了更多的 Token ,最多超过 150k 个 Token
    • 工具利用率和输出长度的这种扩展趋势突显了完全异步 RL 训练对于复杂现实世界智能体应用的潜力
      • 问题:这跟完全异步 RL 有什么关系?

Related Works

Search Agents

  • 一些工作已经构建了智能体工作流,使 LLM 能够利用外部工具来解决复杂任务
    • 著名的例子包括 Search-o1 (2025) 和 ReAgent (2025)
  • 基于提示的方法虽然对于快速开发有效,但根本上受到底层 LLM 能力的限制,并且无法通过环境反馈可靠地改进
  • 一些工作尝试为 LLM 构建 SFT 轨迹
    • 例如,(2023; 2024) 利用大 LLM 合成检索和推理轨迹来微调较小的模型
  • 最近,一些工作研究强化学习 (RL) 方法来增强 LLM-based 智能体,主要关注多跳问答基准测试,如 HotpotQA 和 2Wiki-Multihop
    • (2025;) 使用多跳问答数据进行 RL 训练,并观察到工具使用次数的增加
    • RAG-R1 (2025) 进一步结合了 SFT 和 RL 来增强搜索策略
  • 最近,研究人员开始关注更具挑战性的任务,通过 Offline RL (2025) 微调由大型推理模型 (LRM) 驱动的复杂基于提示的智能体,在具有真实网络数据的模拟轨迹上进行 SFT (2025;),以及为 RL 训练构建具有挑战性的问答对 (2025)

Synthetic Data for Search Agents

  • 除了依赖大规模人工标注,数据合成也已成为一种可扩展的方法来为搜索智能体准备训练数据
    • 一些方法通过与真实网页交互并使用 LRM 整理数据来生成合成但真实的问答轨迹 (2025;)
    • WebSailor (2025) 通过采样和模糊测试构建结构上具有挑战性的任务
    • WebShaper (2025) 利用集合论技术构建高质量的复杂问答对
  • ASearcher 开发了一个自主的 LLM 智能体来合成具有高不确定性的挑战性问答对,而不依赖复杂的知识图谱
    • ASearcher 中的数据合成智能体和合成训练数据都是完全开源的

附录 A:Full Case Study

  • 在本节中,论文对来自 GAIA (2023) 的一个极具挑战性的问题进行了详细的案例研究
    • 论文在图 11 中分析了 Search-R1-32B (2025) 和 Search-o1 (QwQ) (2025)
  • 示例问题的解决路径 (Solution Path of the Sample Question)
    • 在图 11 中,论文的案例研究针对一个在给定 2 个条件和 4 个未知变量的情况下寻找特定动物的问题进行
    • 为了识别正确答案,搜索智能体应首先根据条件 C1 找出提到的物种 U1 ,识别满足条件 C2 的正确文章 U2 ,然后找出 U3.1 和 U3.2 中列出的论文
    • 最后,正确答案应通过交叉引用文章 U2 和论文 U3.1&U3.2 来确定
    • 总结来说,这个示例具有挑战性主要有以下几个原因:
      • 高不确定性 (High Uncertainty): 问题涉及多个未知变量,这些变量可能指向许多不同的实体
        • 例如,2021 年的文章 U2 可能指向 2021 年发表的任何文章,并且只能在给定条件 C2 和肺泡物种 U1 的情况下确定
      • 对精确信息提取的要求 (Requirement for Exact Information Extraction): 为了找到答案,智能体应列出网页上提到的所有动物并进行跨文档比较
        • 这将要求智能体从海量、嘈杂的网络内容中精确提取关键信息,而不是简单地总结网页
      • 误导性答案 (Misleading Answers): 在解决此任务的过程中,可能会出现多个误导性答案,例如“猪”
        • 智能体应通过检查所有相关网页和文档中的预期答案来严格确认其结论
  • 现有的 Online RL 方法未能学习复杂的搜索策略 (Existing Online RL Approaches Fail to Learn Complex Search Strategies)
    • 在图 11 中,Search-R1-32B 无法将复杂 Query 分解为单独的组成部分,因此只进行了涉及过多未知信息的冗余 Query
      • 该智能体还存在严重的幻觉,产生了搜索结果不支持结论
      • 它未能解析所有未知变量
    • 此案例研究表明,现有的 Online RL 方法仅激励了初级的搜索策略
    • 同样值得注意的是,由于在训练期间轮次限制设置为较小的值(例如 4),模型仅表现出较短的工具使用视野
  • 基于提示的 LLM 智能体可能因 LLM 能力不足而失败 (Prompt-based LLM Agents Could Fail Due to Insufficient Capability of the LLM)
    • 在图 11 中,Search-o1 (QwQ) 可以通过大量的工具调用找到物种名称 U1 ,以及 2021 年的文章 U2 和论文 U3.1&U3.2
      • 但在尝试寻找答案时,Search-o1 (QwQ) 很容易遗漏关键信息
      • 因此,智能体得出了错误的结论
      • 而且,即使智能体找到了直接指向正确答案的信息,它仍然被先前错误的结论所误导
      • 最后,智能体无法验证先前结论的正确性
    • 这个案例研究表明,尽管一个未在智能体任务上明确训练的开源模型可以执行大量的工具调用 ,但它无法基于检索到的内容和历史上下文进行专家级的推理
  • ASearcher-Web-QwQ
    • 论文还分析了论文端到端 RL 智能体 ASearcher-Web-QwQ 的搜索策略
    • 如图 11 所示,ASearcher-Web-QwQ 将复杂 Query 分解为精确且聚焦的 Query
      • 与 Search-o1 (QwQ) 在每次搜索 Query 后访问大量网站不同,ASearcher-Web-QwQ 专注于访问最相关的网站
      • ASearcher-Web-QwQ 总结了网站的所有相关信息
        • 所有候选答案都被智能体列出并仔细分析
      • 当尝试在论文 U3.1&U3.2 中搜索相关事实时,智能体明确引用了关键信息
        • 当搜索结果没有直接指向期望的目标时,例如,当使用“Olga Tapia (U3.2) Hafnia alvei (U1) animal studies”进行搜索以查找与 Olga Tapia 论文相关的动物时,智能体没有得到明确的信息,但能够通过与其他论文 U3.1 建立联系来推断出正确答案
      • 在找到正确答案“Mice”之后,智能体在报告最终答案之前花费了额外的轮次来确认先前的结论
    • 总之,ASearcher 成功训练了一个展现出复杂行为的搜索智能体,这些行为体现了搜索智能:
      • 不确定性感知推理 (Uncertainty-aware reasoning): 智能体详尽地列出并检查所有不确定实体的可能性
      • 精确的关键信息提取 (Precise Key Information Extraction): 智能体能够从海量、嘈杂的网络内容中识别关键信息
      • 跨文档推理 (Cross-document Inference): 智能体能够通过建立多个文档之间的联系来推断关键结论
      • 严格确认 (Rigorous Confirmation): 智能体通过额外的工具调用来验证先前结论的正确性
1…192021…61
Joe Zhou

Joe Zhou

Stay Hungry. Stay Foolish.

608 posts
49 tags
GitHub E-Mail
© 2026 Joe Zhou
Powered by Hexo
|
Theme — NexT.Gemini v5.1.4