Jiahong 的个人博客

凡事预则立,不预则废


  • Home

  • Tags

  • Archives

  • Navigation

  • Search

NLP——VeRL框架使用总结

  • 参考链接:
    • 源码地址:github.com/volcengine/verl
    • 官方教程文档:https://verl.readthedocs.io/
      • 官方配置链接:verl.readthedocs.io/en/latest/examples/config.html
      • 官方调优链接:verl.readthedocs.io/en/latest/perf/perf_tuning.html
      • 官方设备调优链接:verl.readthedocs.io/en/latest/perf/device_tuning.html
      • 其他官方示例文档:
        • https://github.com/volcengine/verl/blob/main/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh
        • https://github.com/volcengine/verl/blob/main/examples/tuning/14b/qwen2-14b_grpo_4_h800_fsdp_vllm.sh
        • https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen2-7b.sh
    • DeepWiki 解读:deepwiki.com/volcengine/verl
    • 官方公开讲座(青稞社区):verl 源码解读与 HybridFlow 编程范式讲解
    • 字节跳动Seed官方解读:最高提升20倍吞吐量!豆包大模型团队发布全新 RLHF 框架,现已开源!
    • 相关论文:HybridFlow: A Flexible and Efficient RLHF Framework, EuroSys 2025, HKU & ByteDance
      • 论文解读:HybridFlow / veRL 原文浅析 - Chayenne Zhao的文章 - 知乎
    • 其他解读:
      • 基于 Ray 的分离式架构:veRL、OpenRLHF 工程设计 - 杨远航的文章 - 知乎
      • verl:一个集SFT与RL于一体的灵活大模型post-training框架 (快速入门) - Cyril-KI的文章 - 知乎
      • [AI Infra] VeRL 框架入门&代码带读 - 不关岳岳的事的文章 - 知乎
      • 跟着 verl 代码学习 GRPO 算法流程 - 想当大侠的文章 - 知乎
      • 跟着 verl 代码学习 PPO 算法流程 - 想当大侠的文章 - 知乎
      • 从零开始的verl框架解析 - Nasusu的文章 - 知乎
    • verl 参数速览 - Chayenne Zhao的文章 - 知乎
    • 不错的系列文章:
      • RLHF Infra — Verl 学习(一):Overview - swtheking的文章 - 知乎
      • RLHF Infra — Verl 学习(二):Initialization - swtheking的文章 - 知乎
      • RLHF Infra — Verl 学习(三):Sample Generation - swtheking的文章 - 知乎
      • RLHF Infra — Verl 学习(四): Train Data Organize & Reward Model - swtheking的文章 - 知乎
      • RLHF Infra — Verl 学习(五): Review Verl - swtheking的文章 - 知乎
      • RLHF Infra — Verl 学习(六)Fully Async Policy Trainer - swtheking的文章 - 知乎

环境安装

  • 参考链接:verl.readthedocs.io/en/latest/start/install
  • 建议使用 docker 镜像安装方式,亲测本地直接安装坑很多,且安装后还会陆陆续续出现错误
  • 特别注意:官方镜像加载后还需要执行本地安装 pip3 install --no-deps -e .
    • 不执行这一步会提示 verl 库找不到
    • 建议将代码拉到本地 host 机器,然后用镜像挂载 host 路径
  • 注:官方镜像可能缺失一些依赖包,比如我就遇到缺少 vllm 库,遇到这种情况直接安装即可
    • 最新测试过可用的镜像为:verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.0-fa2.7.4,仅需要自己安装一个 vllm 即可,还有个较小的包按需要安装
      1
      2
      3
      4
      5
      sudo docker create --gpus all --net=host --shm-size="10g" --cap-add=SYS_ADMIN -v ../verl:/workspace/verl -v ~/llm:/workspace/llm --name verl verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.0-fa2.7.4 sleep infinity
      sudo docker start verl
      sudo docker exec -it verl bash
      cd verl && pip3 install --no-deps -e .
      sudo docker stop verl

模型训练

  • Quick Start 可参考:verl.readthedocs.io/en/latest/start/quickstart
  • 多节点启动:verl.readthedocs.io/en/latest/start/multinode

源码阅读

verl 库的目标

  • 将原始问题建模为一个有向图 DataFlow 问题
  • 统一实现,让算法开发者仅需要考虑自身的代码优化即可

数据流的流向过程

  • 原始论文的图片

SPMD 的初始化

  • 在 RayPPOTrainer.init_workers() 内找到相关流程
  • 对每个 资源池分别初始化(for resource_pool, class_dict in self.resource_pool_to_cls.items():)
  • 每个资源池进行如下操作(self.ray_worker_group_cls)
    • 进一步地,执行函数 self._init_with_resource_pool
    • for 循环依次处理每个 GPU(每个 GPU 启动一个进程),每个进程配置好对应的分别是环境变量
    • 每个 GPU 对应一个 worker

数据的分发是如何实现的

  • 每个 Worker 的函数都会接受来自上游的数据,处理数据并输出
    • 注意传入每个 Worker 的数据已经是分布式处理过的,仅仅是 1/WORLD_SIZE,这里的数据分发是使用 @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) 定义的
    • @register 是一个注解,用于实现数据分发过程和收集过程,dispatch_modee=Dispatch.DP_COMPUTE_PROTO 会对应的绑定两个函数(分别负责分发和收集)

每个 Worker 的大致工作流程(Multi Controller 逻辑核心)

  • 具体函数:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking):
    class Functor:
    def __call__(this, *args, **kwargs):
    args, kwargs = dispatch_fn(self, *args, **kwargs)
    padding_count = kwargs.pop(_padding_size_key, 0)
    output = execute_fn(method_name, *args, **kwargs)
    if blocking:
    output = ray.get(output)
    output = collect_fn(self, output)
    if padding_count > 0:
    if isinstance(output, DataProto):
    indices = [i for i in range(len(output))][:-padding_count]
    output = output.select_idxs(indices)
    elif isinstance(output, list):
    output = output[:-padding_count]
    return output

    # use class type to pass the method_name to get a better observability
    return type(method_name, (Functor,), {})()
  • 核心函数名为 func_generator,这个函数会接受5个参数 method_name,dispatch_fn,collect_fn,execute_fn, blocking

    • dispatch_fn 负责 dispatch 参数
    • execute_fn 负责根据 dispatch 后的参数调用 method_name 函数(使用 getattr 方法实现)
    • blocking 决定在这里是否等待 execute_fn 执行完成
    • collect_fn 负责收集 collect_fn 函数返回的分组结果
  • 注:原始代码中的 type(method_name, (Functor,), {})() 表示:动态创建一个名为 method_name、继承自 Functor 且无自定义属性的类,然后实例化该类,最终得到一个继承自 Functor 的实例对象


verl 编程接口

数据集修改(最简单)

  • 保证数据集符合 verl 的格式即可,verl 要求数据是 .parquet 格式,且包含下面 5 列

    • prompt:是一个 message list,每个 message 是 {"role":"...", "content": "..."} 的格式
    • data_source:数据来源,比如 gsm8k 来自 openai/gsm8k
    • ability:数据分类,比如 gsm8k 属于 math 类
    • reward_model:是一个字典,比如 {'ground_truth': '72', 'style':'rule'} 说明使用规则型 reward 模型
    • extro_info:是一个字典,作为额外的信息在训练中使用,可以包含一些自定义信息,比如 PPO 官方示例中的 gsm8k 数据处理就是将 prompt 的 answer 放进去了,完整格式为:{'answer': '...', 'index': 0, 'question': '[原始问题]', 'split': 'train'}
      • 注:extro_info 的 [原始问题] 比 prompt 的 content 少一些模板内容
  • 注:支持 VLM 时,需要 images 和 videos 这样的字段

  • 注:建议使用 pandas 加载数据后多看:

    1
    2
    import pandas as pd
    df = pd.read_parquet(file_path)
  • 数据处理的参考模板见:examples/data_preprocess/ 目录下,比如 gsm8k 数据集的处理文件是 examples/data_preprocess/gsm8k.py

  • 特别地:还可以自定义数据类,通过参数将定义类的 Python 文件路径和类名传入并注册到 verl 中即可,详情见:verl 源码解读与 HybridFlow 编程范式讲解:40:06

自定义 Reward

  • reward fuction 的参数定义:

    1
    2
    3
    4
    5
    custom_reward_function:
    path: null # 指定源码路径
    name: compute_score # 指定函数
    reward_model:
    reward_manager: naive # 指定 reward_manager 类 NaiveRewardManager
  • 可以通过参数传入,示例如下:

    1
    2
    3
    --custom_reward_function.path=./examples/reward_fn/custom_reward_fn.py \
    --custom_reward_function.name=compute_score \
    --reward_model.reward_manager=naive
  • 函数定义可参考 NaiveRewardManager 类的定义

自定义损失函数

  • 全局搜索找到 .backward() 函数调用的地方,这里就是损失定义的地方
    • 在这里可以修改函数 compute_policy_loss
    • 也可以添加其他损失项,比如 交叉熵损失
  • verl 的损失函数定义方式和 llama_factory 的模板类有点相似,是通过将 loss 注册到 POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {} 中实现的
  • 可以通过修改 POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {} 所在文件增加自己的损失函数

修改整个训练逻辑(最复杂)

  • 核心是修改 fit 函数
  • DAPO 的实现类 RayDAPOTrainer 就是继承 RayPPOTrainer 后实现的
    1
    2
    3
    4
    5
    6
    7
    8
    class RayDAPOTrainer(RayPPOTrainer):
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    def fit(self):
    # ...
    # DAPO 的 fit 实现

模型融合

  • 训练完成模型是按照 GPU,以分片的形式存储的,所以需要进行模型融合

    1
    2
    3
    4
    python3 -m verl.model_merger merge \
    --backend fsdp \
    --local_dir checkpoints/verl_examples/gsm8k/global_step_410/actor \
    --target_dir checkpoints/verl_examples/gsm8k/global_step_410/actor/huggingface
    • 将模型路径替换为目标路径
    • 融合结果会存储到 target_dir 下,即 huggingface 目录下,执行后会存贮 model.safetensors 文件和 tokenizer.json 文件
  • 注:模型融合不一定需要安装 verl 的所有的依赖,某些情况下,安装所有 verl 依赖是昂贵的,容易出错,建议简单安装(参考:官方安装说明)

    1
    2
    3
    4
    5
    # 安装底层框架依赖
    USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh # 仅使用 FSDP,不适用 Megatron(Megatron 安装容易出错)

    # 安装 verl
    pip install --no-deps -e . # 不安装依赖,在使用模型融合命令时遇到缺失的再安装,否则安装依赖容易出错

使用 verl 进行模型评估

  • 评估分成生成回答和评估结果两个部分

  • 生成回答

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    python3 -m verl.trainer.main_generation \
    trainer.nnodes=1 \
    trainer.n_gpus_per_node=2 \
    data.path=/path/to/test.parquet \
    data.prompt_key=prompt \
    data.batch_size=1024 \
    data.n_samples=8 \
    data.output_path=/path/to/output.parquet \
    model.path=/path/to/model \
    rollout.temperature=0.6 \
    rollout.top_p=0.95
    • 注意:这里会为每个 Prompt 生成 8 个样本
    • 路径替换为目标模型和目标输出文件名(注意:输出必须到文件名)
  • 评估结果

    1
    2
    3
    4
    5
    6
    python3 -m recipe.r1.main_eval \
    data.path=/path/to/output.parquet \
    data.prompt_key=prompt \
    data.response_key=responses \
    custom_reward_function.path=./recipe/r1/reward_score.py \
    custom_reward_function.name=reward_func
    • 注意:原始的 ./recipe/r1/reward_score.py 文件中不含有 gsm8k 数据集,只需要在数学类型中加入 “openai/gsm8k” 即可
    • 执行该命令可能需要安装 math-verify 包,执行 pip install math-verify 即可
  • 亲测:对 Qwen2.5-0.5B-Instruct 模型在 gsm8k 上训练,从 step=30 到 step=410 (batch_size=256, epoch=15),测试集上的准确率从 0.45 提升至 0.53 左右


附录:如何传入多个数据集?

  • 传入下面的参数?
    1
    2
    3
    4
    5
    train_files="['$train_data_path1','$train_data_path2']"
    test_files="['$valid_data_path1','$valid_data_path2']"

    data.train_files="$train_files" \
    data.val_files="$test_files" \

附录:其他注意事项或技巧

  • 控制保留的 ckpt 数量
    1
    2
    trainer.max_actor_ckpt_to_keep=10
    trainer.max_critic_ckpt_to_keep=10

附录:错误记录

HTTPRequestEntityTooLarge 错误

  • 问题详情:HTTPRequestEntityTooLarge: Request Entity Too Large
  • 原因:Ray 打包文件上传时上传了太多东西,导致实体过大,需要在 verl/trainer/runtime_env.yaml 中增加需要移出的文件 至 excludes
    • 一般都是 *.safetensors 相关的文件导致
  • 详情参考:github.com/volcengine/verl/issues/696

NCCL 错误

  • 表现是单机多卡没错误,多机多卡就会出现错误,错误信息为:

    1
    torch.distributed.DistBackendError: NCCL error in: /pytorch/torch/csrc/distributed/c10d/NCCLUtils.hpp:268, unhandled system error (run with NCCL_DEBUG=INFO for details), NCCL version 2.21.5
  • 一般是 NCCL 相关的环境变量配置有问题,需要检查一下,被修改过后成功运行的参数包括

    1
    2
    3
    4
    NCCL_SOCKET_IFNAME
    NCCL_SOCKET_IFNAME
    NCCL_IB_DISABLE
    NCCL_NET_GDR_LEVEL
  • 注:分布式训练中经常遇到 NCCL 相关的错误,下面是 NCCL 相关的官方错误说明:docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html


附录:特殊参数说明和记录

  • log_prob_micro_batch_size_per_gpu:表示 ref 或 rollout(actor) 一次前向推理时的真实 样本数

    from https://verl.readthedocs.io/en/latest/examples/config.html#actor-rollout-reference-policy
    The batch size for one forward pass in the computation of ref_log_prob. The value represent the local num per gpu.

    • actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu 表示 ref_log_prob 的配置
    • actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu 表示 log_prob 的配置
    • 注:log_prob 的计算是一个前向过程,但 batch_size 较大时显存会比较大,所以进一步进行拆分
  • 注:更多 batch_size 相关介绍:
    • 参考链接:聊聊verl中的batch_size
  • mini_batch,ppo_mini_batch_size(mini_batch_size) :一个 mini_batch 表示一次 PPO 参数更新
  • micro_batch,ppo_micro_batch_size_per_gpu :一次前向/反向过程的批次大小,多个 micro_batch 会累加梯度,直到足够一次 mini_batch 再更新一次模型

NLP——vLLM使用相关笔记

  • 参考链接:
    • GitHub 地址:github.com/vllm-project/vllm
    • 文档地址:Welcome to vLLM
    • 中文文档地址:vLLM 中文站

vLLM 采样参数:SamplingParams

  • SamplingParams 是控制模型“如何生成”的核心对象,常常包含下面的几个参数
    • n : 每个输入提示生成的输出序列数量(默认为 1)
    • best_of : 从生成的一组序列中选择最好的 k 个(用于集束搜索等)
    • temperature : 采样温度,控制随机性;0 表示贪心采样(确定性),值越高越随机
    • top_p : 核采样概率阈值,控制候选词的累积概率
    • top_k : 仅从概率最高的 k 个 token 中采样
    • max_tokens : 每个输出序列生成的最大 token 数
    • stop : 停止生成的字符串列表(遇到这些词即停止)
    • ignore_eos : 是否忽略结束符(EOS),强制生成直到达到最大长度
  • 更多详细参数见附录

输入格式:Prompts

  • vLLM 支持两种形式的输入,可以在同一个 batch 中混合使用:
    • 直接传入字符串,例如 "Hello, world"
      • vLLM 会自动调用内置 Tokenizer 进行编码
    • 传入已经编码好的 Token ID 列表
      • 这在需要自定义 Tokenizer 或复用已编码数据时非常有用
  • 还可以在一个列表中混合输入以上两种输入

vLLM 使用示例

  • 本文将通过三个维度的代码示例来展示 vLLM 的核心能力:
    • 高层同步接口 (LLM) :最常用的离线批量推理方式
    • 高层异步接口 (AsyncLLM) :适用于构建高并发服务的异步流式处理
    • 底层引擎接口 (LLMEngine) :展示如何手动控制调度循环 (Step-level control)

离线批量推理:LLM

  • 这是最简单的用法,适用于处理数据集
  • LLM 类封装了引擎的初始化和调度循环:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    from vllm import LLM, SamplingParams

    # # 初始化 LLM
    # tensor_parallel_size: 使用的 GPU 数量
    # gpu_memory_utilization: 显存占用比例 (0.0 - 1.0)
    llm = LLM(
    model="path_to_model",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9
    )

    # # 定义采样参数
    # sampling_params_greedy = SamplingParams(temperature=0, max_tokens=10) # 贪心采样策略
    sampling_params_creative = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=50)

    prompts = [
    "Hello, my name is", # 索引 0
    "The capital of France is" # 索引 1
    ]

    # # 执行批量解码 (Batch Decoding)
    # generate 函数是同步阻塞的,直到所有请求完成
    outputs = llm.generate(prompts, sampling_params_creative)

    # # 处理输出结果
    for i, output in enumerate(outputs):
    prompt = output.prompt
    # output.outputs 是一个列表,包含 'n' 个生成的序列 (这里 n=1,每个 Prompt 仅生成一个)
    generated_text = output.outputs[0].text

    print(f"--- Sample {i+1} ---")
    print(f"Prompt: {prompt!r}")
    print(f"Generated: {generated_text!r}")
    print(f"Finish Reason: {output.outputs[0].finish_reason}") # e.g., 'stop', 'length'

异步流式推理:AsyncLLM

  • AsyncLLM 是 LLM 的异步版本,基于 AsyncLLMEngine 构建
  • AsyncLLM 允许你在 Python 的 asyncio 循环中非阻塞地提交请求并获取结果,非常适合搭建 API 服务
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    import asyncio
    from vllm import AsyncLLM, SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs

    async def run_async_inference(): # 使用 async 关键字定义一个协程函数
    # 1. 配置引擎参数
    # AsyncEngineArgs 允许更细粒度地控制引擎行为,如 max_num_seqs (最大并发序列数)
    engine_args = AsyncEngineArgs(
    model="path_to_model",
    tensor_parallel_size=1,
    disable_log_requests=True
    )

    # 2. 初始化异步 LLM
    # AsyncLLM 内部维护了一个后台循环来处理请求
    llm = AsyncLLM.from_engine_args(engine_args)

    # 3. 定义采样参数
    sampling_params = SamplingParams(temperature=0.7, max_tokens=20)

    # 4. 定义异步生成任务
    # request_id 是必须的,用于在引擎内部追踪请求,需保证唯一性
    async def generate_stream(request_id, prompt):
    results_generator = llm.generate(
    prompt,
    sampling_params,
    request_id=request_id
    )

    # 异步迭代生成结果 (Streaming)
    final_output = None
    async for request_output in results_generator:
    # 这里可以实现流式推送到前端
    final_output = request_output

    return final_output

    # 5. 模拟并发请求 (多样本解码)
    # 同时发送文本提示和 Token 提示
    tasks = [
    generate_stream("req_001", "To be or not to be,"),
    generate_stream("req_002", "The capital of France is") # TokensPrompt
    ]

    # 等待所有任务完成
    results = await asyncio.gather(*tasks)

    for res in results:
    print(f"Request ID: {res.request_id}")
    print(f"Output: {res.outputs[0].text}")

    # 运行异步主函数
    if __name__ == "__main__":
    asyncio.run(run_async_inference())

底层引擎手动调度:LLMEngine

  • LLMEngine 是 vLLM 最底层的核心
    • 通常用户不需要直接操作它,除非你需要极度定制化的调度逻辑(例如自定义 Web Server 或特殊的强化学习循环)
  • LLM 类本质上就是在这个类外面包了一层 while 循环
  • 这个示例展示了 vLLM 内部是如何通过 step() 函数一步步完成推理的
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    from vllm import LLMEngine, SamplingParams, RequestOutput
    from vllm.engine.arg_utils import EngineArgs
    from vllm.utils import random_uuid

    def run_core_engine_loop():
    # 1. 初始化引擎参数与实例
    engine_args = EngineArgs(model="path_to_model")
    engine = LLMEngine.from_engine_args(engine_args)

    sampling_params = SamplingParams(temperature=0, max_tokens=10)

    # 2. 手动添加请求 (Add Requests)
    # 必须手动管理 request_id
    engine.add_request( # 注意:add_request 函数不会启动推理,需要等待 step 函数来执行
    request_id="req_text",
    prompt="Artificial Intelligence is",
    sampling_params=sampling_params
    )

    engine.add_request(
    request_id="req_text",
    prompt="The capital of France is",
    sampling_params=sampling_params
    )

    # 3. 手动执行调度循环 (The Step Loop)
    # 只要引擎中还有未完成的请求,就继续循环
    while engine.has_unfinished_requests():
    # step() 执行一次推理迭代:
    # 1. 调度器决定哪些请求进入 GPU 计算
    # 2. 执行模型的前向传播 (Model Forward)
    # 3. 采样下一个 Token
    # 4. 更新 KV Cache
    request_outputs: list[RequestOutput] = engine.step() # 注意 step 是一次仅采样一个 Token!streaming 也是借助 step 函数实现的;平时不需要 step 函数是因为封装到底层了

    # 打印当前步的中间结果 (Streaming 效果)
    for output in request_outputs:
    if output.finished:
    print(f"[{output.request_id}] Finished: {output.outputs[0].text}")
    else:
    # 仅打印当前生成的最新 token(简化展示)
    # 实际 output.outputs[0].text 包含完整的累积文本
    pass

    # 运行
    if __name__ == "__main__":
    run_core_engine_loop()

vLLM 部署及参数说明

部署命令

  • 推荐的 Linux 启动命令(可根据实际情况修改路径和显卡数量):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    vllm serve /data/models/Llama-3-8B-Instruct \
    --served-model-name llama3-8b \
    --host 0.0.0.0 \
    --port 8000 \
    --dtype auto \
    --tensor-parallel-size 1 \
    --pipeline-parallel-size 1 \
    --gpu-memory-utilization 0.90 \
    --swap-space 4 \
    --max-model-len 8192 \
    --max-num-seqs 256 \
    --max-num-batched-tokens 8192 \
    --trust-remote-code \
    --enable-chunked-prefill \
    --disable-custom-all-reduce \
    --quantization awq \
    --enforce-eager \
    --api-key "sk-your-secure-password"
  • TLDR:参数配置建议:

    • 如果追求极致吞吐量(Throughput) :增大 --max-num-batched-tokens
      • 这允许一次性处理更多数据,但可能会导致生成过程中的停顿感(因为大批量的预填充会抢占计算资源)
    • 如果追求低延迟和流畅度(Latency) :建议保持适中的 --max-num-batched-tokens,并 开启 --enable-chunked-prefill
      • 这样可以将大的预填充任务打散,避免计算尖峰,确保正在生成的对话不会卡顿
    • 显存限制 :请注意,max-num-batched-tokens 的大小直接影响 KV Cache 的瞬时显存需求
      • 如果该值过大而显存不足,可能会触发 OOM 或强制调度器减少并发序列数(--max-num-seqs)

参数详细解析

  • vllm serve <path_to_model>
    • 这是 vLLM 的启动入口命令
    • 后面的 path_to_model 路径是模型在本地文件系统中的绝对路径(也可以是 Hugging Face 的模型 ID)
  • --served-model-name <model_name>
    • 指定服务对外显示的名称,建议使用类似 “llama-8b” 等类似名称标注
    • 当客户端调用 OpenAI 兼容 API 时,model 字段需要匹配这个名字
    • 如果不设置,默认使用模型路径作为名字
  • --host 0.0.0.0
    • 指定服务绑定的 IP 地址
    • 0.0.0.0 表示允许来自任何网络接口的连接(对外网开放);如果仅限本地访问,可设置为 127.0.0.1
  • --port 8000
    • 指定服务监听的端口号
  • --dtype auto
    • 指定模型权重的加载精度
    • 设置为 auto 时,vLLM 会根据配置文件(config.json)自动检测(通常是 float16 或 bfloat16)
    • 也可以强制指定为 float16、bfloat16 或 float32
  • --tensor-parallel-size 1 (TP)
    • 张量并行度,即把一个模型的层拆分到几张显卡上并行计算
    • 通常设置为单机内的 GPU 数量
  • --pipeline-parallel-size 1 (PP)
    • 流水线并行度,即把模型的不同层分配到不同的显卡上
    • 通常用于模型过大,单卡显存塞不下且 TP 无法解决时
    • 注:模型部署不建议开 PP,使用 TP 即可,一般情况下 PP 保持为 1
  • --gpu-memory-utilization 0.90
    • GPU 显存使用率上限,注意:这是影响并发能力的核心参数
    • vLLM 会预先占用这部分比例的显存(此处为 90%)
      • 其中一部分用于加载模型权重,剩余的所有空间都会被预分配给 KV Cache(键值缓存)
    • 如果设得太高容易 OOM(显存溢出),设得太低则浪费显存,导致并发量上不去
  • --swap-space 4 (新增重要参数)
    • CPU 交换空间大小(单位:GiB)
    • 当 GPU 显存不足以存放 KV Cache 时,vLLM 会将部分 KV Block 换出到 CPU 内存中
    • 设置此参数可以防止在请求突发高峰时发生 OOM 崩溃
  • --max-model-len 8192
    • 模型的最大上下文长度(输入+输出)
      • 如果未指定,vLLM 会尝试从模型配置中读取
    • 显式指定可以限制显存占用,避免处理过长的序列导致崩溃
  • --max-num-seqs 256
    • 最大并发序列数,即同一时刻 vLLM 能处理的请求数量(Batch Size)
    • 这个值越高,吞吐量越大,但每个请求的延迟可能会增加
  • --max-num-batched-tokens 8192
    • 一次迭代(iteration)中处理的最大 Token 总数
    • 这包括了 Prefill(预填充)阶段和 Decode(解码)阶段的所有 Token
    • 通常默认为 max(max_model_len, 2048),建议根据卡的性能灵活配置,以最大化效率
  • --trust-remote-code
    • 允许执行模型仓库中的自定义 Python 代码
    • 对于某些非标准架构的模型(如 ChatGLM、Qwen 的早期版本等),必须开启此选项才能正确加载模型架构
  • --enable-chunked-prefill
    • 为了解决长 Prompt 导致的“队头阻塞”问题(即一个超长 Prompt 占满计算资源,导致短请求延迟增加),引入了分块预填充机制
    • 开启分块预填充,这是一个优化参数,允许将长 Prompt 的 Prefill 阶段拆分成多个小块,与 Decode 阶段混合调度
    • 这可以显著降低长文本输入时的首字延迟(TTFT) ,因为允许解码(Decode)任务和预填充(Prefill)任务更平滑地交错执行,显著降低了其他并发请求的 Inter-Token Latency(ITL,Token 间延迟),使生成过程更加流畅
  • --max-num-partial-prefills
    • 并发预填充数,当启用了分块预填充(Chunked Prefill)后,这个参数变得非常重要
    • 限制了在同一时刻,有多少个请求可以处于“部分预填充”状态,
    • 默认为 1:意味着在任何给定的迭代中,调度器最多只允许 1 个请求进行部分预填充计算(与其他正在解码的请求并行),这有助于防止过多的上下文切换开销,同时保证显存管理的稳定性
  • --long-prefill-token-threshold
    • 长预填充阈值,这是一个辅助参数,用于配合分块预填充使用
    • 定义了多少 Token 数量的 Prompt 被视为“长请求”
    • 当 Prompt 长度超过此阈值时,vLLM 才会考虑对其应用特殊的调度策略或分块逻辑。默认值为 0,意味着所有请求都遵循统一的规则
  • --disable-custom-all-reduce
    • 禁用 vLLM 自定义的 All-Reduce 内核
    • 通常在某些 GPU 架构不支持或驱动不兼容导致 NCCL 通信错误时使用
    • 如果硬件环境标准,通常不需要加这个,但在排查多卡通信问题时很有用
  • --quantization awq (新增重要参数)
    • 指定量化格式
    • 如果模型是量化版本(如 AWQ, GPTQ, SqueezeLLM),必须指定此参数
    • 例如加载 Llama-3-8B-AWQ 时,需设置为 awq
    • 如果是非量化模型,请去掉此参数
  • --enforce-eager (新增重要参数)
    • 强制使用 PyTorch 的 Eager 模式,禁用 CUDA Graph
    • 虽然 CUDA Graph 能加速小 Batch 的推理,但在某些特定显卡或驱动版本上可能会导致显存分配错误或死锁
    • 开启此项有助于调试和提高稳定性
  • --api-key "sk-your-secure-password" (新增重要参数)
    • 设置访问 API 的密钥
    • 在生产环境中,为了防止未授权访问,配置 API Key 是必须的安全措施
    • 客户端请求头需携带 Authorization: Bearer sk-your-secure-password

附录:如果是量化模型,不添加 --quantization 参数 会怎样?

  • 以 W8A8-QuaRot(Weight 8-bit / Activation 8-bit,使用了 QuaRot 旋转算法进行离群值抑制)量化模型 为例,如果在启动 vLLM 时不指定 --quantization 参数,通常会发生以下三种情况之一(具体取决于模型的 config.json 配置和 vLLM 的版本)
  • 注:QuaRot 是一种算法技术,它生成的模型最终通常以 FP8 (E4M3/E5M2) 或 Int8 的格式存储
情况1:直接报错并无法启动(最常见的情况)
  • 这是最可能发生的结果
  • vLLM 启动时会读取模型的 config.json
    • 如果该配置文件中包含 quantization_config 字段(例如标记为 fp8、compressed-tensors 或自定义格式),但 vLLM 在默认模式下无法自动匹配到合适的 Kernel(内核),或者检测到硬件不支持该量化格式(例如在非 Hopper 架构显卡上加载 FP8),程序会直接抛出 ValueError 或 RuntimeError
  • 终端会打印类似 ValueError: Unknown quantization method... 或 RuntimeError: Shape mismatch... 的错误日志,服务启动失败
情况2:加载成功但输出乱码,Garbage Output
  • 这种情况比较危险,因为它看起来“跑起来了”,但完全不可用
  • 如果模型的 config.json 中缺失 了量化相关的元数据,或者 vLLM 错误地将其识别为标准模型,它会尝试以默认精度(通常是 float16 或 bfloat16)来解释权重数据
  • 数据曲解:原本是 8-bit 的整数或 FP8 数据,被当成了 16-bit 的浮点数读取
  • 模型可以接受输入,但吐出来的全是乱码、重复符号或毫无逻辑的字符
情况3:自动识别成功(理想情况,但有前提)
  • 在较新的 vLLM 版本中,如果模型打包规范(例如使用 llm-compressor 或 AutoGPTQ 正确导出),config.json 中会有明确的 quantization_config 字段
  • config.json 里的 quantization 字段(如 fp8 或 compressed-tensors)被当前版本的 vLLM 原生支持
  • 如果显卡支持该精度(例如 w8a8 的 QuaRot 通常对应 FP8 ,这通常需要 NVIDIA Ada Lovelace (RTX 4090) 或 Hopper (H100) 架构的 GPU)
  • 此时即使你不写 --quantization,vLLM 也会根据配置文件自动启用对应的量化内核,服务正常运行
推荐做法
  • 先看模型文件夹下的 config.json,寻找 quantization_config 字段

  • 如果格式是 FP8(常见于 QuaRot 转换的模型):

    1
    --quantization fp8 --kv-cache-dtype fp8
    • 注意:支持 fp8 通常需要 H100/L40/RTX4090 等新显卡
  • 如果格式是 Compressed-Tensors / Neural Magic 格式:

    • vLLM 通常能自动识别,但如果报错,可能需要指定:
      1
      --quantization compressed-tensors

附录:SamplingParams 参数项详解

  • vLLM 的 SamplingParams 参数很多,覆盖了多个方面:
    • 从基础生成控制(长度、终止)
    • 采样策略(随机性、候选集)
    • 重复控制(惩罚)
    • 输出格式(detokenize、 Special Token )
    • 高级自定义(logits 处理器、结构化输出)的全维度参数
  • 这些参数既兼容 OpenAI API 规范,又扩展了 beam search、结构化输出、不良词过滤等特有功能
  • 一些简单的常用理解:
    • 追求确定性可以配置:temperature=0 + top_k=1;
      • 问题:temperature=0 其实就已经是贪心采样了,但是我们一般还是会使用 top_k=1 进一步明确 贪心采样
    • 追求多样性可以配置:temperature=0.7 + top_p=0.9;
      • 理解:temperature=0.7 + top_p=0.9 是很常用的参数
    • 避免重复:presence_penalty=0.5 + frequency_penalty=0.3;
      • presence_penalty 惩罚是否出现过
      • frequency_penalty 惩罚出现频次

SamplingParams 源码配置

  • 以下源码参考自:github.com/vllm-project
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    class SamplingParams(
    PydanticMsgspecMixin,
    msgspec.Struct,
    omit_defaults=True, # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
    ): # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """

    n: int = 1
    """Number of outputs to return for the given prompt request.

    NOTE:
    `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
    are generated and streamed cumulatively per request. To see all `n`
    outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
    in `SamplingParams`."""
    presence_penalty: float = 0.0
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    frequency_penalty: float = 0.0
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    repetition_penalty: float = 1.0
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
    temperature: float = 1.0
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
    top_p: float = 1.0
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
    top_k: int = 0
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
    min_p: float = 0.0
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
    seed: int | None = None
    """Random seed to use for the generation."""
    stop: str | list[str] | None = None
    """String(s) that stop the generation when they are generated. The returned
    output will not contain the stop strings."""
    stop_token_ids: list[int] | None = None
    """Token IDs that stop the generation when they are generated. The returned
    output will contain the stop tokens unless the stop tokens are special
    tokens."""
    ignore_eos: bool = False
    """Whether to ignore the EOS token and continue generating
    tokens after the EOS token is generated."""
    max_tokens: int | None = 16
    """Maximum number of tokens to generate per output sequence."""
    min_tokens: int = 0
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated"""
    logprobs: int | None = None
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
    prompt_logprobs: int | None = None
    """Number of log probabilities to return per prompt token.
    When set to -1, return all `vocab_size` log probabilities."""
    flat_logprobs: bool = False
    """Whether to return logprobs in flatten format (i.e. FlatLogprob)
    for better performance.
    NOTE: GC costs of FlatLogprobs is significantly smaller than
    list[dict[int, Logprob]]. After enabled, PromptLogprobs and
    SampleLogprobs would populated as FlatLogprobs."""
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    """Whether to detokenize the output."""
    skip_special_tokens: bool = True
    """Whether to skip special tokens in the output."""
    spaces_between_special_tokens: bool = True
    """Whether to add spaces between special tokens in the output."""
    # `list[LogitsProcessor] | None` type. We use Any here because
    # `list[LogitsProcessor] | None` type is not supported by msgspec.
    logits_processors: Any | None = None
    """Functions that modify logits based on previously generated tokens, and
    optionally prompt tokens as a first argument."""
    include_stop_str_in_output: bool = False
    """Whether to include the stop strings in output text."""
    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)

    # Fields used to construct logits processors
    structured_outputs: StructuredOutputsParams | None = None
    """Parameters for configuring structured outputs."""
    logit_bias: dict[int, float] | None = None
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
    allowed_token_ids: list[int] | None = None
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
    extra_args: dict[str, Any] | None = None
    """Arbitrary additional args, that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""

    # Fields used for bad words
    bad_words: list[str] | None = None
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
    _bad_words_token_ids: list[list[int]] | None = None

    skip_reading_prefix_cache: bool | None = None

基础参数说明

  • n: int = 1:
    • 为单个 Prompt 请求返回的生成结果数量
    • vLLM 默认一个个输出结果,当 n > 1 时,所有 n 个结果会按请求累积流式一个个返回;
      • 问题:这里的流式,不是通常意义上的流式,而是针对 Response n 粒度的流式?
    • 若希望仅在生成完成后一次性获取所有 n 个结果,需将 output_kind 设置为 RequestOutputKind.FINAL_ONLY
  • max_tokens: int | None = 16:
    • 每个输出序列允许生成的最大 token 数量
    • 若设为 None,需确保模型有明确的终止条件(如 EOS 或 stop 词),否则可能无限生成
  • min_tokens: int = 0
    • 每个输出序列在生成 EOS(结束符)或 stop_token_ids 之前必须生成的最小 token 数
    • 作用 :避免生成过短的结果,例如设置 min_tokens=5 时,即使模型提前触发终止条件 ,也会继续生成直到达到 5 个 token
  • ignore_eos: bool = False
    • 是否忽略 EOS token,强制模型在生成 EOS 后继续生成
    • 适用于需要生成超长文本、绕过模型默认终止逻辑的场景(如生成完整文档而非单句)

采样策略参数

  • 采样策略参数参数控制模型生成 token 时的随机性和候选范围,是最常用的参数,当不做采样时,no_sample
  • temperature: float = 1.0
    • 控制采样的随机性,本质是对 logits(token 概率对数)进行缩放(注意:是在 Softmax 前进行缩放)
    • temperature = 0:贪心采样(Greedy Sampling),直接选择概率最高的 token,结果完全确定;
    • 0 < temperature < 1:降低随机性(提高确定性),结果更聚焦、确定(如 0.7 是平衡随机性和确定性的常用值);
    • temperature > 1:提高随机性,结果更发散、创意性更强,但可能出现无意义内容
    • temperature 越小越容易出现重复现象
    • 注意 :当 temperature=0 时,top_p/top_k 等参数会失效(贪心采样无需候选集)
      • vLLM 中没有 do_sample 参数 参照了 HF Transformer 相似的思路,但是实现方式不同,通过 temperature 隐晦地实现了是否贪心采样的控制
      • temperature=0 强制 do_sample=False(贪心采样,只选概率最高的 token);
      • temperature>0 等价 do_sample=True(启用随机采样,按概率分布选 token)
  • top_p: float = 1.0
    • 核采样(Nucleus Sampling),控制待选 token 的累积概率阈值,取值范围 (0, 1]
    • 将所有 token 按概率从高到低排序,累加概率直到达到 top_p,仅从这些 token 中采样
      • top_p=0.9 时,仅选择累计概率前 90% 的 token 作为候选;
      • top_p=1.0 时,包含所有 token(等同于不限制)
    • 相比 top_k 更灵活,能自适应调整候选集大小(高概率 token 少则候选集小,反之则大)
  • top_k: int = 0
    • 限制采样的候选 token 数量,仅从概率最高的 top_k 个 token 中选择
      • top_k=0(或 -1):不限制,包含所有 token;
      • top_k=50:仅从概率前 50 的 token 中采样
    • 对比 top_p :top_k 是固定数量限制,top_p 是概率累积限制,通常两者二选一使用
      • 两者组合时:先按 top_k 筛选,再按 top_p 过滤
  • min_p: float = 0.0
    • 基于最高概率 token 的相对概率阈值,筛选候选 token,取值范围 [0, 1]
    • 设本次采样遇到的最高概率 token 的概率为 P_max(注意:是个随分布变化的值),仅保留概率 \(\ge\) min_p * P_max 的 token
      • min_p=0.1 且 P_max=0.5 时,仅保留概率 \(\ge\) 0.05 的 token;
      • min_p=0 时禁用该规则
    • 优势 :相比 top_k/top_p,能避免极端情况下的不合理筛选(如 top_k 可能漏掉低概率但有意义的 token,top_p 可能包含过多低概率 token)
  • seed: int | None = None
    • 生成随机数的种子,用于复现生成结果
    • 设置固定 seed 后,相同 Prompt 和参数下,模型会生成完全相同的结果(解决采样随机性导致的不可复现问题)

重复/惩罚类参数

  • 用于控制模型生成时的重复率,避免生成冗余、重复的文本
  • presence_penalty: float = 0.0
    • 基于 token 是否“出现过”的惩罚,与出现次数无关
      • 正值(如 0.5):惩罚已出现的 token,鼓励生成新内容;
      • 负值(如 -0.5):奖励已出现的 token,鼓励重复;
      • 0:无惩罚/奖励
    • 适用场景 :需要避免模型重复提及相同实体(如人名、地名)的场景
  • frequency_penalty: float = 0.0
    • 基于 token 出现“频率”的惩罚,出现次数越多,惩罚越重
      • 正值:抑制高频 token,减少重复;
      • 负值:强化高频 token,增加重复;
      • 0:无惩罚/奖励
    • 区别于 presence_penalty :前者是“有无”惩罚,后者是“多少”惩罚,例如重复 3 次的 token 会比重复 1 次的 token 受到更重的频率惩罚
  • repetition_penalty: float = 1.0
    • 基于 prompt 和已生成文本中 token 出现的惩罚,核心是调整 token 的概率
      • 取值 > 1:惩罚重复 token(概率 = 原概率 / repetition_penalty),鼓励新内容;
      • 取值 < 1:奖励重复 token(概率 = 原概率 * repetition_penalty),鼓励重复;
      • 1:无惩罚/奖励
    • 覆盖范围(特别注意) :同时作用于 prompt 和生成文本中的 token,是更通用的重复控制参数
      • 理解:这里的含义是在 prompt 中的 Token 也会当做是否重复的判断依据进行累计

终止条件参数

  • 控制模型何时停止生成,避免无限制输出
  • stop: str | list[str] | None = None
    • 触发生成终止的字符串(单个或列表),返回的结果中默认不包含这些停止字符串
    • stop=["\n", "###"] 时,模型生成到换行符或 ### 时立即停止
  • stop_token_ids: list[int] | None = None
    • 触发生成终止的 token ID 列表(底层 token-level 的终止条件)
    • 返回结果中会包含 stop_token_ids 对应的 stop token(Special Token 服从本规则)
      • 如果是 Special Token,可能是不会在输出结果中的,有自己的规则
      • 与 stop(字符串级)互补,分别用于指定字符串或者 Token
  • include_stop_str_in_output: bool = False
    • 是否将 stop 参数指定的停止字符串包含在输出文本中
      • 注意:这里只影响 stop,与 stop_token_ids 无关,stop_token_ids 不受此参数影响
    • 若设为 True,停止字符串会出现在最终输出里
  • 理解终止条件参数,vLLM 的 SamplingParams 内部会维护一个参数:_all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
    • _all_stop_token_ids 存储所有终止 token ID
      • 包括 stop_token_ids 转换后的 ID、EOS token 等
      • 这个参数无需用户手动设置,由 post_init 自动初始化

日志概率(logprobs)参数

  • 用于获取 token 生成的概率信息,适用于需要分析模型决策过程的场景(如评估生成可靠性)
  • logprobs: int | None = None
    • 每个输出 token 返回的最高概率 token 的数量(包含选中的 token)
      • logprobs=None:不返回概率;
      • logprobs=k(\(k \in \mathbb{Z}^+\)):返回概率最高的 k 个 token 的 log 概率(实际返回 k+1 个,因为包含选中的 token);
        • 理解:这里选中的 Token 不一定是概率最高的, 所以被选中的一定会返回
      • logprobs=-1:返回全词表(vocab_size 维度)所有 token 的 log 概率
    • Following OpenAI API :始终返回选中 token 的 log 概率
  • prompt_logprobs: int | None = None
    • 每个 Prompt token 返回的最高概率 token 的数量
      • 取值规则同 logprobs,-1 表示返回全词表概率
    • 问题:prompt 为什么也会对应概率?
      • prompt_logprobs 是专门针对输入的 prompt 部分(而非生成的 completion 部分)返回的每个 token 的对数概率信息
      • logprobs 则通常指生成部分的对数概率
  • flat_logprobs: bool = False
    • 是否展平返回 logprobs,优化性能
    • 优势 :FlatLogprob 的 GC(垃圾回收)成本远低于 list[dict[int, Logprob]] 格式,适合高并发场景;
    • 启用后 PromptLogprobs 和 SampleLogprobs 均会以 FlatLogprob 格式返回

输出格式与处理参数

  • 控制生成结果的格式、是否过滤 Special Token 等
  • detokenize: bool = True
    • 是否将生成的 token ID 转换为文本
    • 注意 :仅在引擎层暴露,OpenAI API 不支持仅返回 token ID,默认开启,得到的就是文本而不是 Token ID
  • skip_special_tokens: bool = True
    • 是否在输出中跳过 Special Token (如 、、 等)
    • 注意默认是 True(跳过),避免输出包含无意义的特殊标记
  • spaces_between_special_tokens: bool = True
    • 是否在 Special Token 之间添加空格
    • 优化 Special Token 的可读性,例如 <|endoftext|><|user|> 会变成 <|endoftext|> <|user|>
    • 理解:为什么这里默认是 True,目前我们几乎不用,但确从不需要设置?猜测如下(待确定):
      • Hugging Face Tokenizer 基类的 通用默认值 是 True;
      • LLaMA/Qwen 等模型的 专属默认值 是 False(通过代码硬编码覆盖了通用默认值)
  • output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
    • 输出类型,控制流式返回的方式:
      • output_kind=RequestOutputKind.CUMULATIVE(默认):累积式输出(如第 1 次返回第 1 个 token,第 2 次返回前 2 个 token,依此类推);
      • output_kind=RequestOutputKind.FINAL_ONLY:仅在生成完成后返回最终完整结果(此时不是异步生成了)
      • output_kind=RequestOutputKind.DELTA:仅返回增量
    • 问题:这个参数的使用待测试确认
  • output_text_buffer_length: int = 0
    • 内部参数,存储输出文本缓冲区长度,无需用户设置,由 post_init 初始化

Prompt 处理参数

  • truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
    • Prompt 的左截断规则(仅保留最后 k 个 token):
      • -1:使用模型支持的最大截断长度;
      • 正整数 k:仅保留 Prompt 最后 k 个 token;
      • None:禁用截断
    • 常用参数,适配模型的上下文窗口限制,避免 Prompt 过长导致超出模型容量
  • skip_reading_prefix_cache: bool | None = None
    • 是否跳过读取前缀缓存(prefix cache),用于优化 Prompt 处理性能,通常无需用户手动设置
  • logits_processors: Any | None = None
    • 修改 logits 的自定义处理器列表(函数),可基于已生成的 token(或 Prompt token)调整 token 概率
    • 因 msgspec 不支持 list[LogitsProcessor] | None,故用 Any 替代;适用于自定义生成逻辑(如强制生成特定 token、限制生成内容)
    • 问题:待确认这个参数
  • structured_outputs: StructuredOutputsParams | None = None
    • 结构化输出参数,用于控制模型生成符合特定格式的内容(如 JSON、XML)
    • 需要结构化结果的场景(如数据提取、API 调用返回)
  • logit_bias: dict[int, float] | None = None
    • token 级别的概率偏置,键为 token ID,值为偏置值
    • 调整指定 token 的生成概率(正值提高概率,负值降低概率),例如 logit_bias={123: 5.0} 会大幅提高 ID 为 123 的 token 被选中的概率
    • 问题:待尝试这个参数
  • allowed_token_ids: list[int] | None = None
    • 允许生成的 token ID 列表,后续生成时,会仅保留这些 token 的概率,其余 token 概率置 0
    • 严格限制生成内容的范围(如仅允许生成数字、特定词汇)
    • 问题:待尝试这个参数
  • extra_args: dict[str, Any] | None = None
    • 自定义额外参数,供第三方插件、自定义采样逻辑使用,vLLM 内置采样逻辑不使用该参数

不良词过滤参数

  • bad_words: list[str] | None = None
    • 禁止生成的词汇列表,核心逻辑是:当生成的 token 即将完成某个 bad word 的 token 序列时,禁止生成该序列的最后一个 token
    • 比如 bad_words=["暴力"] 时,模型会避免生成“暴力”这个词(通过阻止其最后一个 token 的生成),直接停止
    • 问题:待测试这个参数
  • _bad_words_token_ids: list[list[int]] | None = None
    • 内部参数,存储 bad_words 转换后的 token ID 序列,无需用户设置,由 post_init 初始化
    • 问题:待测试这个参数

NLP——技术报告解读-DeepSeek-R1

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning, arXiv 20250122 & 20260104, DeepSeek-AI
      • 20260104 补充了更多细节【本文还未补充,待后续有时间再更新】
    • 补充文档:(DeepSeek-R1-Supplements)Supplementary Information for: DeepSeek-R1 Incentivizes Reasoning in LLMs via Reinforcement Learning, DeepSeek-AI
      • 补充材料阅读笔记见另一篇
    • 中文完整版: 梁文锋Nature论文的同行评审和团队回应- 上
    • rebuttal过程

Paper Summary

  • 评价:
    • 划时代的一篇文章,25 年春节前后 DeepSeek 给大家带来的冲击是巨大的,众多社区一起复现 DeepSeek-R1 的 Aha Moment 的空前盛况
    • 本文及其附录都是非常值得深入阅读的文章
  • 论文介绍了 DeepSeek 的第一代推理模型 DeepSeek-R1-Zero 和 DeepSeek-R1
  • DeepSeek-R1-Zero 是一个通过大规模 RL 训练、无需 SFT 作为初步步骤的模型,展现出卓越的推理能力
    • 通过强化学习,DeepSeek-R1-Zero 自然地涌现出许多强大而有趣的推理行为
    • 但 DeepSeek-R1-Zero 也面临可读性差和语言混合等挑战
  • 为了解决这些问题并进一步提升推理性能,论文引入了 DeepSeek-R1,它在强化学习之前引入了多阶段训练和冷启动数据
    • DeepSeek-R1 在推理任务上的性能与 OpenAI-o1-1217 相当
    • 为了支持研究社区,论文开源了 DeepSeek-R1-Zero、DeepSeek-R1,以及六个基于 Qwen 和 Llama 从 DeepSeek-R1 蒸馏得到的稠密模型(1.5B、7B、8B、14B、32B、70B)

Introduction and Discussion

  • LLM 经历了快速的迭代和演进,逐步缩小了与人工通用智能(Artificial General Intelligence, AGI)之间的差距
  • 训练后阶段(post-training)已成为完整训练流程中的一个重要组成部分
    • 它被证明能够提升推理任务的准确性、与社会价值观对齐并适应用户偏好,同时相对于预训练所需计算资源相对较少
  • 在推理能力的背景下,OpenAI 的 o1 (OpenAI) 系列模型首次通过增加 CoT 推理过程的长度引入了推理时扩展(inference-time scaling)
    • 这种方法在数学、编程和科学推理等各种推理任务中取得了显著改进
    • 但有效的测试时扩展(test-time scaling)的挑战仍然是研究社区的一个开放性问题
  • 之前的几项工作探索了各种方法,包括基于过程的奖励模型(process-based reward models)(2022; 2023)、强化学习 (2024) 以及蒙特卡洛树搜索(Monte Carlo Tree Search)和束搜索(Beam Search)等搜索算法 (2024; 2024; 2024)
    • 然而,这些方法都没有达到与 OpenAI 的 o1 系列模型相媲美的通用推理性能
  • 在论文中,论文迈出了第一步,使用纯 RL 来改进语言模型的推理能力
    • 论文的目标是探索大语言模型在没有任何监督数据的情况下发展推理能力的潜力,重点关注它们通过纯强化学习过程进行的自我进化
    • 具体来说,论文使用 DeepSeek-V3-Base 作为基础模型,并采用 GRPO (2024) 作为强化学习框架来提高模型在推理中的性能
    • 在训练过程中,DeepSeek-R1-Zero 自然地涌现出许多强大而有趣的推理行为
    • 经过数千个强化学习步骤后,DeepSeek-R1-Zero 在推理基准测试中表现出卓越的性能
      • 例如,在 AIME 2024 上的 pass@1 分数从 15.6% 增加到 71.0%,并且通过多数投票(majority voting),分数进一步提高到 86.7%,与 OpenAI-o1-0912 的性能相匹配
  • 然而,DeepSeek-R1-Zero 遇到了可读性差和语言混合等挑战
    • 为了解决这些问题并进一步提升推理性能,论文引入了 DeepSeek-R1,它结合了少量冷启动数据(cold-start data)和多阶段训练流程
    • 具体来说
      • 论文首先收集数千个冷启动数据来微调 DeepSeek-V3-Base 模型
      • 随后,论文像 DeepSeek-R1-Zero 一样执行面向推理的强化学习
    • 当强化学习过程接近收敛时,论文通过对强化学习检查点(checkpoint)进行拒绝采样(rejection sampling)来创建新的监督微调数据,并结合来自 DeepSeek-V3 在写作、事实问答(factual QA)和自我认知(self-cognition)等领域的有监督数据,然后重新训练 DeepSeek-V3-Base 模型
    • 在使用新数据微调后,该检查点会经历额外的强化学习过程,考虑所有场景的 Prompts
    • 经过这些步骤,论文获得了一个称为 DeepSeek-R1 的检查点,其性能与 OpenAI-o1-1217 相当
  • 论文进一步探索了从 DeepSeek-R1 到更小稠密模型(dense models)的蒸馏(distillation)
    • 使用 Qwen2.5-32B (Qwen) 作为基础模型,直接从 DeepSeek-R1 进行蒸馏优于在其上应用强化学习
    • 这表明由更大基础模型发现的推理模式对于提高推理能力至关重要
  • 论文开源了蒸馏后的 Qwen 和 Llama (2024) 系列
    • 值得注意的是,论文蒸馏的 14B 模型大幅优于最先进的开源模型 QwQ-32B-Preview (Qwen),并且蒸馏的 32B 和 70B 模型在稠密模型的推理基准测试中创造了新纪录
  • 补充:来自辅助材料的说明
    • DeepSeek-V3-Base 指基础模型
    • DeepSeek-V3 指经过指令微调的模型
    • DeepSeek-R1 与 DeepSeek-R1-Zero 均在 DeepSeek-V3-Base 的基础上训练而成
      • 且 DeepSeek-R1 还利用了 DeepSeek-V3 监督微调数据中的非推理类数据

Approach

Overview

  • 先前的工作严重依赖大量的监督数据来提升模型性能
  • 在本研究中,论文证明了即使不使用 SFT 作为冷启动,通过大规模 RL 也能显著提升推理能力
    • 此外,加入少量冷启动数据可以进一步提升性能
  • 在接下来的小节中,论文将介绍:
    • (1) DeepSeek-R1-Zero,它直接在基础模型(DeepSeek-V3-Base)上应用强化学习,不使用任何监督微调数据;
    • (2) DeepSeek-R1,它从一个经过数千个长 CoT 示例微调过的检查点开始应用强化学习;
    • (3) 将 DeepSeek-R1 的推理能力蒸馏到小型稠密模型中

DeepSeek-R1-Zero: Reinforcement Learning on the Base Model

  • 强化学习在推理任务中已展现出显著的有效性,这在论文先前的工作 (2024; 2023) 中得到了证明
    • 但这些工作严重依赖监督数据,而收集这些数据非常耗时
  • 在本节中,论文探索了 LLM 在没有任何监督数据的情况下发展推理能力的潜力,重点关注其通过纯强化学习过程进行的自我进化
  • 论文首先简要概述论文的强化学习算法,随后展示一些令人兴奋的结果,并希望这能为研究社区提供有价值的见解
Reinforcement Learning Algorithm
Group Relative Policy Optimization, GRPO
  • 为了节省强化学习的训练成本,论文采用了组相对策略优化(GRPO)(2024)
  • 该方法省去了通常与策略模型大小相同的评论家模型,转而从组分数中估计基线
  • 具体来说,对于每个问题 \(q\),GRPO 从旧策略 \(\pi_{\theta_{old} }\) 中采样一组输出 \(\{o_{1},o_{2},\cdots,o_{G}\}\),然后通过最大化以下目标来优化策略模型 \(\pi_{\theta}\):
    $$
    \mathcal{J}_{GRPO}(\theta)=\mathbb{E}_{[q\sim P(Q),\{o_{i}\}_{i=1}^{ G}\sim\pi_{\theta_{old} }(O|q)]} \frac{1}{G}\sum_{i=1}^{G}\left(\min\left(\frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old} }(o_{i}|q)}A_{i},\text{clip}\left(\frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old} }(o_{i}|q)},1-\varepsilon,1+\varepsilon\right)A_{i}\right)-\beta\mathbb{D}_{KL}\left(\pi_{\theta}||\pi_{ref}\right)\right),
    $$
    • 其中 \(\varepsilon\) 和 \(\beta\) 是超参数,\(A_{i}\) 是优势函数,使用与组内每个输出对应的一组奖励 \(\{r_{1},r_{2},\ldots,r_{G}\}\) 计算得出:
      $$
      A_{i}=\frac{r_{i}-\text{mean}(\{r_{1},r_{2},\cdots,r_{G}\})}{\text{std}(\{r_{1},r_{2},\cdots,r_{G}\})}.
      $$
  • KL 散度项 \(\mathbb{D}_{KL}\left(\pi_{\theta}||\pi_{ref}\right)\) 定义为:
    $$
    \mathbb{D}_{KL}\left(\pi_{\theta}||\pi_{ref}\right)=\frac{\pi_{ ref}(o_{i}|q)}{\pi_{\theta}(o_{i}|q)}-\log\frac{\pi_{ref}(o_{i}|q)}{\pi_{ \theta}(o_{i}|q)}-1.
    $$
Reward Modeling
  • 奖励是训练信号的来源,它决定了强化学习的优化方向
  • 为了训练 DeepSeek-R1-Zero,论文采用了一个基于规则的奖励系统,主要包括两种类型的奖励:
    • 准确性奖励 (Accuracy rewards) :准确性奖励模型评估响应是否正确
      • 例如,对于具有确定性结果的数学问题,模型需要以指定格式(例如,在方框内)提供最终答案,从而能够基于规则可靠地验证正确性
      • 类似地,对于 LeetCode 问题,可以使用编译器根据预定义的测试用例生成反馈
    • 格式奖励 (Format rewards) :除了准确性奖励模型,论文还采用了一个格式奖励模型,强制模型将其思维过程放在 <think> 和 </think> 标签之间
  • 在开发 DeepSeek-R1-Zero 时,论文没有使用基于结果的或基于过程的神经奖励模型(neural reward model)
    • 因为论文发现神经奖励模型在大规模强化学习过程中可能遭受奖励破解(reward hacking)问题 ,并且重新训练奖励模型需要额外的训练资源,并使整个训练流程复杂化
Training Template
  • 为了训练 DeepSeek-R1-Zero,论文首先设计了一个简单的模板,引导基础模型遵循论文指定的指令
  • 如表 1 所示,该模板要求 DeepSeek-R1-Zero 先生成一个推理过程,然后是最终答案
  • 论文有意将约束限制在这种结构格式上,避免任何特定于内容的偏见——例如强制进行反思性推理或推广特定的问题解决策略,以确保论文能够准确观察模型在强化学习过程中的自然进展
Performance, Self-evolution Process and Aha Moment of DeepSeek-R1-Zero
Performance of DeepSeek-R1-Zero
  • 图 2 描绘了 DeepSeek-R1-Zero 在 AIME 2024 基准测试上的性能随强化学习训练过程的变化轨迹
  • 如图所示,随着强化学习训练的进行,DeepSeek-R1-Zero 表现出稳定且一致的性能提升
    • AIME 2024 的平均 pass@1 分数显著增加,从最初的 15.6% 跃升至令人印象深刻的 71.0%,达到了与 OpenAI-o1-0912 相当的性能水平
    • 这一显著改进凸显了论文的强化学习算法在随时间优化模型性能方面的有效性
    • 注:图 2 中 cons@k 是多数投票的结果(cons 表示 consensus,即共识):详情见 NLP——技术报告解读-DeepSeek-R1-Supplements
  • 表 2 提供了 DeepSeek-R1-Zero 与 OpenAI 的 o1-0912 模型在各种推理相关基准测试上的比较分析
    • 研究结果表明,强化学习使 DeepSeek-R1-Zero 能够在不需要任何监督微调数据的情况下获得强大的推理能力
    • 这是一个值得注意的成就,因为它强调了模型仅通过强化学习就能有效学习和泛化的能力
  • 此外,通过应用多数投票(majority voting),可以进一步增强 DeepSeek-R1-Zero 的性能
    • 例如,在 AIME 基准测试上使用多数投票时,DeepSeek-R1-Zero 的性能从 71.0% 提升至 86.7%(图 2 中 cons@16 的结果),从而超过了 OpenAI-o1-0912 的性能
    • DeepSeek-R1-Zero 在有和没有多数投票的情况下都能实现如此有竞争力的性能,这突显了其强大的基础能力及其在推理任务中进一步发展的潜力
Self-evolution Process of DeepSeek-R1-Zero
  • DeepSeek-R1-Zero 的自我进化过程是一个迷人的演示(demonstration),展示了强化学习如何驱动模型自主提高其推理能力
    • 通过直接从基础模型启动强化学习,我们可以在不受监督微调阶段影响的情况下密切监控模型的进展
    • 这种方法清晰地展示了模型随时间演变的过程,特别是在其处理复杂推理任务的能力方面
  • 如图 3 所示,DeepSeek-R1-Zero 的思考时间在整个训练过程中持续改善
    • 这种改进不是外部调整的结果,而是模型内部的内在发展
    • DeepSeek-R1-Zero 自然地获得了通过利用延长的测试时间计算来解决日益复杂的推理任务的能力
    • 这种计算范围从生成数百到数千个推理 Token,使模型能够更深入地探索和完善其思维过程
  • 这种自我进化最显著的方面之一是随着测试时间计算的增加而出现的复杂行为
    • 诸如反思(模型重新审视和重新评估其先前步骤)以及探索替代性问题解决方法等行为自发产生
    • 这些行为不是显式编程的,而是模型与强化学习环境交互的结果
    • 这种自发的发展显著增强了 DeepSeek-R1-Zero 的推理能力,使其能够更高效、更准确地应对更具挑战性的任务
Aha Moment of DeepSeek-R1-Zero
  • 在 DeepSeek-R1-Zero 的训练过程中观察到一个特别有趣的现象是 “顿悟时刻”(aha moment) 的出现
  • 如表 3 所示,这个时刻发生在模型的一个中间版本中
    • 在此阶段,DeepSeek-R1-Zero 学会了通过重新评估其初始方法为问题分配更多的思考时间
    • 这种行为不仅证明了模型不断增长的推理能力,也是强化学习如何导致意外和复杂结果的一个引人入胜的例子
  • 这个时刻不仅是模型的“顿悟时刻”,对观察其行为的研究人员来说也是如此
    • 它强调了强化学习的力量和美感:论文不是明确地教导模型如何解决问题,而是简单地提供正确的激励,它就会自主地发展出高级的问题解决策略
    • “顿悟时刻”有力地提醒论文强化学习在人工智能系统中解锁新智能水平的潜力,为未来更自主和自适应的模型铺平道路
  • 个人理解:后面的一些文章逐步分析并证明,一些顿悟时刻实际上并不是一个突然发生的过程,而是逐步发生的,只是在特定任务上看起来像是突然发生一样
DeepSeek-R1-Zero 的缺点 (Drawback of DeepSeek-R1-Zero)**
  • 尽管 DeepSeek-R1-Zero 表现出强大的推理能力并自主发展出意想不到的强大推理行为,但它也面临几个问题
    • DeepSeek-R1-Zero 存在可读性差和语言混合等挑战
  • 为了使推理过程更具可读性并与开放社区分享,论文探索了 DeepSeek-R1,这是一种利用强化学习和对人类友好的冷启动数据的方法

DeepSeek-R1: Reinforcement Learning with Cold Start

  • 受 DeepSeek-R1-Zero 有希望的结果的启发,两个自然的问题出现了:
    • 1)通过加入少量高质量数据作为冷启动,能否进一步提高推理性能或加速收敛?
    • 2)论文如何训练一个用户友好的模型,不仅能产生清晰连贯的思维链(CoT),还能展现出强大的通用能力?
  • 为了解决这些问题,论文设计了一个训练 DeepSeek-R1 的流程
    • 该流程包括四个阶段,概述如下文所示
  • 补充:来自其他博主制作的非常好的 DeepSeek-R1 训练过程:
    • 注意:根据 DeepSeek-V3 辅助材料给出的结论,下图中存在问题(已补充),DeepSeek-R1 和 DeepSeek-R1-Zero 均是从 DeepSeek-V3-Base 训练而来,图中给的是 DeepSeek-V3 (这是 DeepSeek-V3-Base 的微调版本);部分训练数据(监督微调数据中的非推理类数据)确实来源于 DeepSeek-V3
Cold Start
  • 与 DeepSeek-R1-Zero 不同,为了防止从基础模型开始强化学习训练时早期不稳定的冷启动阶段,对于 DeepSeek-R1,论文构建并收集了少量长思维链(CoT)数据来微调模型,微调后的模型作为初始的强化学习 Actor
  • 为了收集此类数据,论文探索了几种方法:
    • 使用带有长 CoT 示例的少样本提示(few-shot prompting),直接提示模型生成带有反思和验证的详细答案,以可读格式收集 DeepSeek-R1-Zero 的输出,以及通过人工标注员的后处理来细化结果
  • 在这项工作中,论文收集了数千个冷启动数据来微调 DeepSeek-V3-Base,作为强化学习的起点
  • 与 DeepSeek-R1-Zero 相比,冷启动数据的优势包括:
    • 可读性 (Readability) :DeepSeek-R1-Zero 的一个关键限制是其内容通常不适合阅读
      • 响应可能混合多种语言或缺乏用于向用户突出显示答案的 markdown 格式
      • 在为 DeepSeek-R1 创建冷启动数据时,论文设计了一种可读的模式 ,包括在每个响应末尾进行总结 ,并过滤掉对读者不友好的响应
      • 在这里,论文将输出格式定义为 \(|\)special_token\(|\)\(<\)reasoning_process\(>\)\(|\)special_token\(|\)\(<\)summary\(>\),其中推理过程是针对查询的 CoT,而总结(summary)用于总结推理结果
    • 潜力 (Potential) :通过利用人类先验知识精心设计冷启动数据的模式,论文观察到相对于 DeepSeek-R1-Zero 更好的性能
      • 作者相信,对于推理模型来说,迭代训练(iterative training)是一种更好的方式
Reasoning-oriented Reinforcement Learning
  • 在基于冷启动数据对 DeepSeek-V3-Base 进行微调之后,论文应用了与 DeepSeek-R1-Zero 相同的大规模强化学习训练过程
  • 此阶段侧重于增强模型的推理能力 ,特别是在编码、数学、科学和逻辑推理等推理密集型任务中,这些任务涉及具有明确解决方案的明确定义的问题
  • 在训练过程中,论文观察到 CoT 经常出现语言混合 ,特别是当强化学习提示涉及多种语言时
    • 为了缓解语言混合问题,论文在强化学习训练期间引入了语言一致性奖励(language consistency reward),该奖励计算为 CoT 中目标语言单词的比例
    • 尽管消融实验表明这种对齐会导致模型性能略有下降 ,但这种奖励符合人类偏好,使其更具可读性
    • 问题:CoT 不需要让人可以阅读吧?
  • 论文通过直接求和将推理任务的准确性和语言一致性奖励结合起来,形成最终奖励
  • 论文在微调后的模型上应用强化学习训练,直到其在推理任务上达到收敛
Rejection Sampling and Supervised Fine-Tuning
  • 当面向推理的强化学习收敛时,论文利用得到的检查点来为后续轮次收集 SFT 数据
  • 与主要关注推理的初始冷启动数据不同,此阶段合并了来自其他领域的数据 ,以增强模型在写作、角色扮演和其他通用任务中的能力
  • 具体来说,论文按照下述方式生成数据并微调模型(分推理数据和非推理数据)
Reasoning data
  • 论文策划(curate)推理提示词(prompts),并通过从上述强化学习训练的检查点执行拒绝采样(rejection sampling)来生成推理轨迹(trajectories)
  • 在前一阶段,论文只包含了可以使用基于规则的奖励进行评估的数据
  • 在此阶段,论文通过合并额外的数据来扩展数据集,其中一些数据使用生成式奖励模型,通过将真实值(ground-truth)和模型预测输入到 DeepSeek-V3 中进行判断
  • 由于模型输出有时混乱且难以阅读,论文过滤掉了具有混合语言、长段落和代码块的思维链
  • 对于每个提示,论文采样多个响应,并仅保留正确的响应
  • 总共,论文收集了约 60 万个与推理相关的训练样本
Non-Reasoning data
  • 对于非推理数据,例如写作、事实问答(factual QA)、自我认知(self-cognition)和翻译,论文采用 DeepSeek-V3 的流程并重用部分 DeepSeek-V3 的监督微调数据集
  • 对于某些非推理任务,论文调用 DeepSeek-V3,通过在回答问题时进行提示来生成潜在的思维链
  • 但对于更简单的查询 ,例如“你好”,论文不提供 CoT 响应
  • 最终,论文总共收集了约 20 万个与推理无关的训练样本
  • 论文使用上述策划的大约 80 万个样本的数据集 ,对 DeepSeek-V3-Base 进行了两个轮次(epochs)的微调
    • 注:这个 2 个 epochs 的设定和 大约 100W 样本的微调量级,已经成为后来一些文章的标配
Reinforcement Learning for all Scenarios
  • 为了进一步使模型与人类偏好对齐,论文实施了第二轮强化学习阶段,旨在提高模型的有用性(helpfulness)和无害性(harmlessness) ,同时完善其推理能力
  • 论文结合使用奖励信号和多样化的提示词分布(diverse prompt distributions)来训练模型
    • 对于推理数据:
      • 论文遵循 DeepSeek-R1-Zero 中概述的方法论,该方法利用基于规则的奖励来指导数学、代码和逻辑推理领域的学习过程
    • 对于通用数据
      • 论文用奖励模型来捕捉复杂和细微场景中的人类偏好
        • 论文基于 DeepSeek-V3 的流程,并采用了类似的偏好对(preference pairs)和训练提示词分布
      • 对于有用性,论文只关注最终总结(summary),确保评估强调响应对于用户的实用性和相关性,同时最小化对底层推理过程的干扰
      • 对于无害性,论文评估模型的整个响应,包括推理过程和总结,以识别和减轻生成过程中可能出现的任何潜在风险、偏见或有害内容
      • 最终,奖励信号和多样化数据分布的整合使论文能够训练出一个在推理方面表现出色,同时优先考虑有用性和无害性的模型

Distillation: Empower Small Models with Reasoning Capability

  • 为了让更高效的小型模型具备像 DeepSeek-R1 一样的推理能力,论文直接使用 DeepSeek-R1 策划的 80 万个样本(详见 章节2.3.3)对开源模型如 Qwen 和 Llama 进行了微调
  • 论文的研究结果表明,这种简单的蒸馏方法显著增强了小型模型的推理能力
  • 论文这里使用的基础模型是 Qwen2.5-Math-1.5B, Qwen2.5-Math-7B, Qwen2.5-14B, Qwen2.5-32B, Llama-3.1-8B 和 Llama-3.3-70B-Instruct
    • 论文选择 Llama-3.3 是因为其推理能力略优于 Llama-3.1
    • 问题:为什么 70B 量级的模型,选择的不是 Qwen2.5-72B-Instruct ?
  • 对于蒸馏模型 ,论文仅应用了 SFT ,没有包含 RL 阶段 ,尽管加入强化学习可以大幅提升模型性能
    • 论文这里的主要目标是证明蒸馏技术的有效性 ,将强化学习阶段的探索留给更广泛的研究社区

Experiment

Benchmarks

  • 标准测试基准方面:
    • 论文在 MMLU (2020)、MMLU-Redux (2024)、MMLU-Pro (2024)、C-Eval (2023)、CMMLU (2023)、IFFval (2023)、FRAMES (2024)、GPQA Diamond (2023)、SimpleQA (OpenAI)、C-SimpleQA (2024)、SWE-Bench Verified (OpenAI, 2024d)、Aider、LiveCodeBench (2024) (2024-08 - 2025-01)、Codeforces、中国高中数学奥林匹克竞赛(Chinese National High School Mathematics Olympiad, CNMO 2024) 和美国数学邀请赛 2024(American Invitational Mathematics Examination 2024, AIME 2024)(MAA, 2024) 上评估模型
  • 除了标准基准测试外,论文还使用 LLM 作为评判者,在开放式生成任务上评估论文的模型
    • 论文遵循 AlpacaEval 2.0 (2024) 和 Arena-Hard (2024) 的原始配置,它们利用 GPT-4-Turbo-1106 作为配对比较的评判者
    • 在这里,论文仅将最终摘要提供给评估,以避免长度偏差
    • 对于蒸馏模型,论文报告了在 AIME 2024、MATH-500、GPQA Diamond、Codeforces 和 LiveCodeBench 上的代表性结果

Evaluation Prompts

  • 遵循 DeepSeek-V3 的设置,使用 simple-evals 框架提供的提示来评估标准基准测试,如 MMLU、DROP、GPQA Diamond 和 SimpleQA
  • 对于 MMLU-Redux,论文在零样本(zero-shot)设置中采用 Zero-Eval 提示格式 (2024)
  • 对于 MMLU-Pro、C-Eval 和 CLUE-WSC,由于原始提示是少样本(few-shot)的,论文略微修改了提示以适应零样本设置
    • 少样本中的思维链(CoT)可能会损害 DeepSeek-R1 的性能
    • 问题:这里怎么理解?
  • 其他数据集遵循其创建者提供的默认提示的原始评估协议
  • 对于代码和数学基准测试
    • HumanEval-Mul 数据集涵盖了八种主流编程语言(Python、Java、C++、C#、JavaScript、TypeScript、PHP 和 Bash)
    • 使用 CoT 格式评估模型在 LiveCodeBench 上的性能,数据收集时间为 2024 年 8 月至 2025 年 1 月
    • 使用 10 场 Div.2 比赛的题目以及专家精心设计的测试用例来评估 Codeforces 数据集,然后计算预期评分和参赛者百分比
    • 通过无代理框架(agentless framework)(2024) 获得 SWE-Bench 验证结果
    • 使用“diff”格式衡量 AIDER 相关基准测试
  • DeepSeek-R1 的输出在每个基准测试中最多限制为 32,768 个 Token

Baselines

  • 论文对多个强基线模型进行了全面评估,包括 DeepSeek-V3、Claude-Sonnet-3.5-1022、GPT-4o-0513、OpenAI-o1-mini 和 OpenAI-o1-1217
  • 由于在中国大陆访问 OpenAI-o1-1217 API 具有挑战性,论文根据官方报告报告其性能
  • 对于蒸馏模型,论文还比较了开源模型 QwQ-32B-Preview (Qwen)

Evaluation Setup

  • 论文将模型的最大生成长度设置为 32,768 个 Token
  • 论文发现使用贪婪解码(greedy decoding)来评估长输出推理模型会导致更高的重复率和不同检查点之间的显著变异性
    • 因此,论文默认使用 pass@\(k\) 评估 (2021),并使用非零温度(non-zero temperature)报告 pass@1
    • 理解:零温度表示贪婪解码
  • 具体来说,论文使用采样温度 0.6 和 top-\(p\) 值 0.95 为每个问题生成 \(k\) 个回复(通常在 4 到 64 之间,取决于测试集大小)。然后 pass@1 计算为
    $$
    \text{pass@}1=\frac{1}{k}\sum_{i=1}^{k}p_{i}
    $$
    • 其中 \(p_{i}\) 表示第 \(i\) 个回复的正确性
    • 这种方法提供了更可靠的性能估计
  • 对于 AIME 2024,论文还使用 64 个样本报告了共识(consensus),即多数投票(majority vote)结果 (2022),记为 cons@64

3.1 DeepSeek-R1 评估 (DeepSeek-R1 Evaluation)

  • 评估结果如表 4 所示:
  • 对于面向教育的知识基准测试,如 MMLU、MMLU-Pro 和 GPQA Diamond,DeepSeek-R1 相较于 DeepSeek-V3 展现出更优越的性能
    • 这一改进主要归功于通过大规模强化学习在 STEM 相关问题上准确率的显著提升
  • DeepSeek-R1 在 FRAMES(一个依赖长上下文的问答任务)上表现卓越,展示了其强大的文档分析能力
    • 这凸显了推理模型在 AI 驱动的搜索和数据分析任务中的潜力
  • 在事实性基准测试 SimpleQA 上,DeepSeek-R1 的表现优于 DeepSeek-V3,证明了其处理基于事实的查询的能力
    • OpenAI-o1 在该基准测试上超越 GPT-4o 也观察到了类似的趋势
    • 由于在安全强化学习(safety RL)后倾向于拒绝回答某些查询 ,DeepSeek-R1 在中文 SimpleQA 基准测试上的表现不如 DeepSeek-V3
    • 若没有安全强化学习,DeepSeek-R1 的准确率可以超过 70%
  • DeepSeek-R1 在 IF-Eval(一个旨在评估模型遵循格式指令能力的基准测试)上也取得了令人印象深刻的结果
    • 这些改进可以归因于在 SFT 和强化学习训练的最后阶段包含了遵循指令的数据
  • 在 AlpacaEval2.0 和 ArenaHard 上观察到了卓越的性能,表明 DeepSeek-R1 在写作任务和开放域问答方面的优势
    • 其显著超越 DeepSeek-V3 的表现凸显了大规模强化学习的泛化益处,它不仅提升了推理能力,还提高了跨不同领域的性能
    • 特别地,DeepSeek-R1 生成的摘要长度简洁,在 ArenaHard 上平均为 689 个 Token,在 AlpacaEval 2.0 上平均为 2218 个字符
      • 这表明 DeepSeek-R1 在基于 GPT 的评估中避免了引入长度偏差,进一步巩固了其在多项任务中的鲁棒性
  • 在数学任务和编码算法任务(如 LiveCodeBench 和 Codeforces)上,DeepSeek-R1 表现出与 OpenAI-o1-1217 相当的性能,大幅超越其他模型
    • 专注于推理的模型在这些基准测试中占据主导地位
  • 特别地,在面向工程的编码任务上 ,OpenAI-o1-1217 在 Aider 上优于 DeepSeek-R1 ,但在 SWE Verified 上取得了相当的性能
    • 作者认为 DeepSeek-R1 的工程性能将在下一个版本中得到改善,因为目前相关的强化学习训练数据量仍然非常有限

3.2 蒸馏模型评估 (Distilled Model Evaluation)

  • 如表 5 所示
    • 仅通过蒸馏 DeepSeek-R1 的输出,高效的 DeepSeek-R1-7B(即 DeepSeek-R1-Distill-Qwen-7B,下文类似缩写)就能全面超越如 GPT-4o-0513 这样的非推理模型
    • DeepSeek-R1-14B 在所有评估指标上均超越了 QwQ-32B-Preview,而 DeepSeek-R1-32B 和 DeepSeek-R1-70B 在大多数基准测试上显著超过了 o1-mini
    • 这些结果展示了蒸馏的强大潜力
  • 论文发现对这些蒸馏模型应用强化学习能带来显著的进一步增益
    • 作者认为这值得进一步探索,因此在此仅展示简单 SFT 蒸馏模型的结果

Discussion

Distillation v.s. Reinforcement Learning

  • 在 3.2 节中,我们可以看到通过蒸馏 DeepSeek-R1,小模型能够取得令人印象深刻的结果
    • 但还有一个问题悬而未决:模型能否不通过蒸馏,而是通过论文讨论的大规模强化学习训练达到相当的性能?
  • 为了回答这个问题,论文在 Qwen-32B-Base 上使用数学、代码和 STEM 数据进行了大规模强化学习训练,训练超过 10K 步,得到了 DeepSeek-R1-Zero-Qwen-32B
  • 实验结果如表 6 所示,表明 32B 基础模型在经过大规模强化学习训练后,性能与 QwQ-32B-Preview 相当
    • 但从 DeepSeek-R1 蒸馏得到的 DeepSeek-R1-Distill-Qwen-32B 在所有基准测试上的表现均显著优于 DeepSeek-R1-Zero-Qwen-32B
  • 因此,我们可以得出两个结论:
    • 首先,将更强大的模型蒸馏到较小的模型中能产生优异的结果 ,而依赖论文提到的大规模强化学习的小模型需要巨大的计算能力 ,甚至可能无法达到蒸馏的性能
    • 其次,虽然蒸馏策略既经济又有效 ,但要突破智能的边界可能仍然需要更强大的基础模型和更大规模的强化学习

Unsuccessful Attempts

  • 在开发 DeepSeek-R1 的早期阶段,论文也遇到了一些失败和挫折
    • 论文在此分享论文的失败经验以提供见解,但这并不意味着这些方法无法开发出有效的推理模型
Process Reward Model, PRM
  • PRM 是一种合理的方法,可以指导模型采用更好的方法来解决推理任务 (2023; 2022;);但在实践中,PRM 有三个主要局限性可能阻碍其最终成功
    • 第一,在通用推理中明确定义细粒度的步骤具有挑战性
    • 第二,判断当前中间步骤是否正确是一项艰巨的任务
      • 使用模型进行自动标注可能无法产生令人满意的结果,而手动标注不利于扩大规模
    • 第三,一旦引入基于模型的 PRM,就不可避免地会导致奖励黑客攻击(reward hacking)(2022)
      • 并且重新训练奖励模型需要额外的训练资源,并使整个训练流程复杂化
  • 总之,虽然 PRM 在重排模型生成的 top-N 响应或辅助引导式搜索 (2024) 方面表现出良好的能力,但在论文实验的大规模强化学习过程中,与其引入的额外计算开销相比,其优势有限
    • 问题:如何理解 PRM 在重排模型生成的 top-N 响应或辅助引导式搜索 方面表现出良好的能力?
Monte Carlo Tree Search,MCTS
  • 受 AlphaGo (2017a) 和 AlphaZero (2017b) 的启发,论文探索了使用蒙特卡洛树搜索(MCTS)来增强测试时计算的可扩展性
  • MCTS 方法涉及将答案分解成更小的部分,以便模型能够系统地探索解决方案空间
  • 为了促进这一点,论文提示模型生成多个标签,这些标签对应于搜索所需的特定推理步骤
  • 对于训练:
    • 首先使用收集的提示,通过由预训练价值模型引导的 MCTS 来寻找答案
      • 理解:这里的价值模型决定了每次选择哪些节点进行扩展
    • 随后使用产生的 问题-答案 对来训练行动者模型和价值模型,并迭代地改进这个过程
  • 但这种方法在扩大训练规模时遇到了几个挑战
    • 首先,与搜索空间相对明确的象棋不同,Token 生成呈现出一个指数级更大的搜索空间
      • 为了解决这个问题,论文为每个节点设置了最大扩展限制,但这可能导致模型陷入局部最优
    • 其次,价值模型直接影响生成的质量,因为它指导搜索过程的每一步
      • 训练一个细粒度的价值模型本身就很困难,这使得模型难以迭代改进
      • 虽然 AlphaGo 的核心成功依赖于训练一个价值模型来逐步提高其性能,但由于 Token 生成的复杂性,这一原则在论文的设置中难以复制
  • 总之,虽然 MCTS 在与预训练价值模型配对时可以在推理过程中提高性能,但通过自我搜索迭代地提升模型性能仍然是一个重大挑战

Conclusion, Limitations, and Future Work

  • 在本工作中,论文分享了通过 RL 来增强模型推理能力的探索历程
    • DeepSeek-R1-Zero 代表了一种不依赖冷启动数据(cold-start data)的纯强化学习方法,在各种任务上均取得了强劲的性能
    • DeepSeek-R1 更加强大,它利用了冷启动数据以及迭代式的强化学习微调
    • 最终,DeepSeek-R1 在一系列任务上达到了与 OpenAI-o1-1217 相当的性能
  • 论文进一步探索了将推理能力蒸馏(distillation)到小型稠密模型(small dense models)中的方法
    • 论文使用 DeepSeek-R1 作为教师模型(teacher model)来生成 80 万条训练样本,并对多个小型稠密模型进行了微调
    • 结果令人鼓舞:
      • DeepSeek-R1-Distill-Qwen-1.5B 在数学基准测试中超越了 GPT-4o 和 Claude-3.5-Sonnet,在 AIME 上达到了 28.9%,在 MATH 上达到了 83.9%
      • 其他稠密模型也取得了令人印象深刻的结果,显著超越了基于相同底层检查点(underlying checkpoints)的其他指令微调模型(instruction-tuned models)
  • 未来,论文计划在以下几个方向为 DeepSeek-R1 投入研究
    • 通用能力 (General Capability):
      • 目前,DeepSeek-R1 在函数调用(function calling)、多轮对话(multi-turn)、复杂角色扮演(complex role-playing)和 JSON 输出等任务上的能力尚不及 DeepSeek-V3
      • 接下来,论文计划探索如何利用长思维链(long Chain-of-Thought, CoT)来增强这些领域的任务
    • 语言混合 (Language Mixing):
      • DeepSeek-R1 目前针对中文和英文进行了优化,这可能导致在处理其他语言的查询时出现语言混合问题
      • 例如,即使用户查询使用的不是英文或中文,DeepSeek-R1 也可能使用英文进行推理和回复
      • 论文旨在未来的更新中解决这一局限性
    • 提示工程 (Prompting Engineering):
      • 在评估 DeepSeek-R1 时,论文观察到它对 Prompts 很敏感
      • 少样本提示(few-shot prompting)consistently 会降低其性能
      • 因此,论文建议用户在使用零样本(zero-shot)设置时直接描述问题并指定输出格式 ,以获得最佳结果
    • 软件工程任务 (Software Engineering Tasks):
      • 由于评估时间较长,影响了强化学习过程的效率,大规模强化学习尚未广泛应用于软件工程任务
      • 因此,DeepSeek-R1 在软件工程基准测试中并未显示出相对于 DeepSeek-V3 的巨大改进
      • 未来的版本将通过对软件工程数据实施拒绝采样(rejection sampling),或在强化学习过程中引入异步评估(asynchronous evaluations)来提高效率,从而解决这一问题

NLP——Model-Growth-Initialization


整体说明

  • 模型增长初始化(Model Growth Initialization,MGI)是一种让大模型在不从头开始训练的前提下迅速“长大”并具备良好初始性能的技术
  • Model Growth Initialization 的核心思想是:先训练一个小模型,然后在深度或宽度上扩展成更大的模型(如增加层数、宽度或专家数量),并把小模型已学到的知识完整复用到大模型里 ,从而显著节省训练成本、提升收敛速度
  • TLDR:Model Growth Initialization 就是“把小模型当预制件,复制粘贴成大模型,再微调”,用最小的算力让大模型站在小模型的肩膀上起跑
  • 一个通俗的比喻:可以把模型训练比作盖楼:
    • 传统做法 :平地起高楼,从地基开始逐层盖(随机初始化)
    • MGI 做法 :先盖一栋“小楼”并装修完毕,然后把整栋小楼连同装修一起复制+堆叠 ,瞬间变成一栋“大厦”,再只对新增部分做微调

Background

  • 直接训练千亿参数的大模型成本极高(如 Llama-3 需消耗 770 万 GPU 小时),而 Model Growth Initialization 通过复用小模型的知识,显著降低计算成本
    • 例如,使用已训练的 7B 模型初始化 30B 模型时,可减少约 50% 的训练时间

结构扩展策略

  • 模型扩展通常分为三类:
    • 深度扩展 :增加 Transformer 层的数量
    • 宽度扩展 :增加神经元数量、头数或 FFN 维度(如 Net2Net 通过复制神经元并均分权重)
    • 混合扩展 :同时调整深度和宽度(在扩展层数的同时增加隐藏层维度)

参数初始化原则

  • 扩展后的模型需保持与原模型行为一致 ,避免训练震荡
  • Model Growth Initialization 的初始化原则是确保模型在结构扩展(如增加层数、宽度或专家数量)后,既能继承原小模型的知识,又能保持训练稳定性(避免梯度爆炸/消失或性能骤降),同时为新参数提供合理的学习起点
  • 这些原则的核心逻辑是:在“继承知识”与“学习新能力”之间找平衡
    • 通过功能保留和身份映射确保模型初始稳定
    • 通过部分保留和跨层传递实现精准知识迁移
    • 通过优化器状态一致保证训练连续性

Function-Preserving Initialization, FPI 原则

  • 让扩展后的大模型在初始状态下,对任意输入的输出与原小模型完全一致(或近似一致),实现“无损知识迁移”
  • 避免扩展后模型“忘记”原有的能力,为后续训练提供稳定起点
  • 通过精确的参数复制与调整,确保扩展后的模型计算逻辑与原模型等效
  • 示例:
    • 宽度扩展示例 :若原模型某层有2个神经元(h1, h2),输出为 y = w1*h1 + w2*h2,扩展到3个神经元时,新增神经元h3的权重复制h2的参数,同时将w2拆分为w2/2和w2/2,使新输出 y = w1*h1 + (w2/2)*h2 + (w2/2)*h3 与原输出完全一致
    • 深度扩展示例 :新增 Transformer 层时,将其参数初始化为“恒等映射”(如自注意力的输出投影矩阵设为单位矩阵,偏置设为 0),确保新增层对输入不做任何修改,等效于原模型的计算流程(实际上就是后面要介绍的 IMI 方法)

Identity Mapping Initialization, IMI 方法

  • 让新增的层/参数在初始状态下“不干扰”模型原有计算,仅在训练中逐步学习新功能
  • 防止新增结构破坏原模型的优化状态(如损失函数突增),降低训练震荡风险
  • 将新增组件的参数初始化为“中性值”,使其对模型输出的影响为零或极小
  • 示例:
    • 新增 Transformer 层 :将多头注意力的输出权重矩阵初始化为单位矩阵(确保输入=输出),前馈网络(FFN)的中间层权重设为 0(使 FFN 等效于“跳过连接”)
    • MoE 模型新增专家 :将新专家的输入/输出投影权重初始化为0,使其在初始阶段不参与计算,仅通过训练逐步被激活
  • 注:“自注意力输出投影矩阵设为单位矩阵、偏置设为 0” 的做法,本身是 Identity Mapping Initialization 的具体操作,因为它直接让该组件成为 “恒等变换”
    • 但当这种操作被用于确保 “扩展后的整体模型与原模型功能一致” 时,它成为实现 FPI 的手段之一
  • 注:FPI(是一个原则) 保证新旧模型输出完全一样,而 IMI(是一种方法) 只保证新增部分是恒等映射,不保证整体输出不变
    • IMI 只保证新增部分是恒等映射,原始参数往往可能也会被修改

Optimizer State Consistency

  • 确保扩展后的模型优化器(如Adam)状态与原模型兼容,避免训练进程中断
  • 使扩展后的模型训练能“无缝衔接”原训练过程,减少重新收敛的时间
  • 复用原模型的优化器参数(如动量、二阶矩估计),并对新增参数初始化合理的优化器状态
    • 对复用的参数,直接继承原优化器的动量值,确保梯度更新方向与原训练一致;
    • 对新增参数,将优化器的动量初始化为 0(或小值),避免其初期干扰整体更新节奏

附录:一些典型方法与技术细节

Net2Net:开创性的结构扩展框架

  • 深度扩展(Net2DeeperNet) :直接复制Transformer层(如将L层模型扩展为2L层),确保每层输入输出形状一致
  • 宽度扩展(Net2WiderNet) :新增神经元的权重复制相邻神经元,并调整输出权重使总和不变。例如,原输出为y = e*h1 + f*h2,扩展后变为y = e*h1 + (f/2)*h2 + (f/2)*h3

Stacking Your Transformers:深度堆叠优化

  • G_stack操作符 :通过堆叠多个小模型(如 7B 到 70B),使大模型在 194B tokens 即可收敛到传统 300B tokens 的损失,速度提升 54.6%
  • 增长规划(Growth Schedule) :分阶段扩展模型,例如先训练 16B 模型,再逐步扩展至 101B,同时调整学习率和优化器状态

MoE 模型的扩展

  • Mixtral-8x7B :从 Mistral-7B 初始化,直接复用其 FFN 层作为专家,并通过微调不同任务的 FFN 生成多样化专家
  • 参数共享与隔离 :专家层共享底层编码器参数,但各自保留独立的前馈网络,平衡效率与多样性

初始化策略的精细化设计

  • 部分保留初始化(Partial Preservation Init) :保留原模型部分层的参数,随机初始化新增层
  • 交叉层知识传递(AKI) :不仅考虑当前层参数,还结合下一层参数进行初始化

附录:MGI 中的 FPI 原则详细介绍

  • 功能保留初始化(Function-Preserving Initialization,FPI) 特指通过精确的参数复制与调整,确保扩展后的大模型在初始状态下对任意输入的输出与原小模型完全一致(或高度近似),从而实现“无损知识迁移”
  • FPI 是模型增长初始化(Model Growth Initialization,MGI)中的关键要求

核心机制与示例

  • 宽度扩展 :当某层神经元数量从 2 个扩展到 3 个时,新增神经元的权重复制原模型中邻近神经元的参数,并通过权重拆分(如将原权重w拆分为w/2和w/2),使新输出与原输出完全一致。例如:
    • 原模型:\( y = w_1 h_1 + w_2 h_2 \)
    • 扩展后模型:\( y = w_1 h_1 + (w_2/2) h_2 + (w_2/2) h_3 \)
  • 深度扩展 :在新增 Transformer 层时,将其参数初始化为“恒等映射”,确保新增层对输入不做任何修改,等效于原模型的计算流程
    • 如自注意力的输出投影矩阵设为单位矩阵可实现恒等映射
  • MoE 模型扩展
    • 新增专家的输入/输出投影权重初始化为 0,使其在初始阶段不参与计算,仅通过训练逐步被激活,避免干扰原模型的优化状态

为什么 MGI 中需要 FPI?

  • 降低训练震荡风险 :通过确保扩展后的模型初始输出与原模型一致,避免损失函数突增或优化器状态中断
  • 加速大模型收敛 :复用小模型的知识,使大模型在训练初期即可继承成熟的特征提取能力,减少从头学习的时间成本

NLP——SEAL

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:(SEAL)Self-Adapting Language Models, arXiv 20250612, MIT
    • 主页:jyopari.github.io/posts/seal

Paper Summary

  • 背景 & 问题:LLM 虽然强大,但缺乏动态调整其权重以应对新任务、知识或示例的机制
  • 论文提出了 自适应大语言模型(SEAL, Self-Adapting LLMs) 框架,通过生成自身的微调数据和更新指令,使 LLM 能够自我调整
    • 给定新输入时,模型会 produces a self-edit
      • a generation,可能以不同方式重组信息、指定优化超参数,或调用工具进行数据增强和基于梯度的更新

        Given a new input, the model produces a self-edit—a generation that may restructure the information in different ways, specify optimization hyperparameters, or invoke tools for data augmentation and gradient-based updates

    • 通过 SFT ,这些 self-edit 会带来持久的权重更新,从而实现长期适应(lasting adaptation)
    • 为了训练模型生成有效的 self-edit,论文使用强化学习循环(loop) ,将更新后模型在下游任务中的表现作为奖励信号
    • 与依赖独立适应模块或辅助网络的现有方法不同,SEAL 直接利用模型的生成能力参数化并控制其自身的适应过程
  • 在知识整合和 Few-shot 泛化的实验中,SEAL 展现了语言模型在新数据下实现自我导向适应的潜力

Introduction and Discussion

  • 在大规模文本语料库上预训练的 LLM 在语言理解和生成方面表现出卓越能力(2020; 2023; 2024; 2025)
    • 但将这些强大模型适配到特定任务(2020)、整合新信息(2020)或掌握新推理技能(2025)仍然具有挑战性,主要由于任务特定数据的稀缺性
  • 论文探讨了一个有趣的假设:LLM 能否通过转换或生成自身的训练数据和学习过程来实现自我适应?(can an LLM self-adapt by transforming or generating its own training data and learning procedure?)
  • 以人类学生学习为例,学生通常通过整理笔记来备考,这些笔记是对原始内容的重新解读和增强 :这种将外部知识转化为更易理解形式的能力是人类学习的普遍特征
  • 但当前 LLM 的训练和部署方式与人类学习形成鲜明对比:面对新任务时,LLM 只能通过微调或上下文学习(ICL, In-Context Learning)(2022; 2024; 2023)直接利用原始数据
    • 这些数据可能并非最优格式(或数量),且现有方法无法让模型开发定制化的数据转换和学习策略
  • 为实现语言模型的高效适应,论文提出赋予 LLM 生成自身训练数据和微调指令的能力
    • 具体而言,论文引入了一种强化学习算法,训练 LLM 生成 self-edit(即指定数据和优化超参数的自然语言指令(如图1 所示)
    • 论文将此类模型称为 自适应大语言模型(Self-Adapting LLMs,SEAL)
  • 论文在两个应用中评估 SEAL
    • 首先,测试其在整合新事实知识任务中的表现:模型通过生成合成数据而非直接微调原文
      • 在无上下文版本的 SQuAD(2016)问答任务中,准确率从 33.5% 提升至 47.0%
      • 值得注意的是,SEAL 自生成的数据表现优于 GPT-4.1 生成的合成数据
    • 其次,论文在简化版 ARC-AGI 基准(2025)上评估 Few-shot 学习能力,模型通过调用工具自主选择数据增强和优化超参数(如学习率、训练轮次等)
      • 实验表明,SEAL 自动配置工具的能力显著优于标准上下文学习和未经强化学习训练的 self-edit 方法
    • 这些结果共同证明,SEAL 是实现语言模型自我适应的通用框架

Methods

  • 论文提出了 自适应性大语言模型(Self-Adapting LLMs, SEAL) ,这是一种框架,使语言模型能够通过生成自己的合成数据和优化参数(即“自我编辑”)来改进自身
    • 模型通过 Token 生成直接产生这些自我编辑,生成过程以模型上下文中的数据为基础
    • 自我编辑的生成通过 RL 进行训练,模型的奖励信号来源于应用自我编辑后模型在目标任务上的性能提升
    • 因此,SEAL 可以视为一种包含两个嵌套循环的算法:
      • 外部强化学习循环优化自我编辑的生成
      • 内部更新循环利用生成的自我编辑通过梯度下降更新模型
  • 论文的方法可以看作是一种元学习(meta-learning)实例,其目标是学习如何生成有效的自我编辑

General Framework

  • 设 \(\theta\) 表示语言模型 \(\text{LM}_{\theta}\) 的参数,SEAL 在单个任务实例 \((C, \tau)\) 上运行,其中 \(C\) 是包含任务相关信息的上下文,\(\tau\) 是用于评估模型适应能力的下游任务
    • 在知识整合任务中,\(C\) 是需要整合到模型内部知识中的段落,\(\tau\) 是关于该段落的一组问题和答案;
    • 在 Few-shot 学习中,\(C\) 包含新任务的少量示例,\(\tau\) 是查询输入和真实输出
  • 给定 \(C\),模型生成一个自我编辑(SE),其形式因领域而异(见原论文 3.2 节),并通过 SFT 更新参数:
    $$\theta’ \leftarrow \text{SFT}(\theta, \text{SE})$$
    • 问题:SE 是什么,如何理解 SE?
    • 回答:SE 是一些指令,给定一个新输入(如一段文本或少量示例),模型会生成一组数据或优化参数(即 self-edit),这些输出 SE 会被用于更新模型自身参数,从而提升模型
  • 论文使用强化学习优化自我编辑的生成过程:模型采取动作(生成 SE),根据 \(\text{LM}_{\theta’}\) 在 \(\tau\) 上的表现获得奖励 \(r\),并更新其策略以最大化期望奖励:
    $$
    \mathcal{L}_{\text{RL} }(\theta_t) := -\mathbb{E}_{(C,\tau)\sim\mathcal{D} } \left[ \mathbb{E}_{\text{SE}\sim\text{LM}_{\theta_t}(:C)} \left[ r(\text{SE}, \tau, \theta_t) \right] \right]. \tag{1}
    $$
    • 与标准强化学习设置不同,论文的奖励取决于模型参数 \(\theta\)(因为 \(\theta\) 会更新为 \(\theta’\) 并随后被评估)
    • 因此,强化学习的状态必须包含策略的参数,即 \((C, \theta)\),尽管策略的观察仅限于 \(C\)(将 \(\theta\) 直接放入上下文不可行)
    • 这意味着从旧模型 \(\theta_{\text{old} }\) 收集的(状态、动作、奖励)三元组可能与当前模型 \(\theta_{\text{current} }\) 不匹配
    • 为此,论文采用同策略(on-policy)方法,即自我编辑从当前模型中采样,并且奖励也基于当前模型计算
  • 论文尝试了多种同策略方法,如组相对策略优化(GRPO)和近端策略优化(PPO),但发现训练不稳定
    • 最终,论文采用了 ReST\(^{EM}\)(2023),这是一种基于过滤行为克隆的简化方法,也称为“拒绝采样 + SFT”
    • ReST\(^{EM}\) 可以视为一种期望最大化(EM)过程:
      • E 步 :从当前策略中采样候选输出;
      • M 步 :仅对获得正奖励的样本进行监督微调
  • 这种方法在二元奖励下优化了目标函数(1)的近似:
    $$
    r(\text{SE}, \tau, \theta_t) = \begin{cases}
    1 & \text{If on } \tau \text{ adaptation using SE improves } \text{LM}_{\theta_t} \text{‘s performance}, \\
    0 & \text{Otherwise}.
    \end{cases} \tag{2}
    $$
    • 具体来说,在优化(1)时,论文需要计算梯度 \(\nabla_{\theta_t} \mathcal{L}_{\text{RL} }\)
    • 但由于奖励项 \(r(\text{SE}, \tau, \theta_t)\) 依赖于 \(\theta_t\) 且不可微,论文将其视为固定值
    • 在这种近似下,对于包含 \(N\) 个上下文和每个上下文 \(M\) 个采样自我编辑的小批量,蒙特卡洛估计为:
      $$
      \nabla_{\theta_t} \mathcal{L}_{\text{RL} } \approx -\frac{1}{NM} \sum_{i=1}^{N} \sum_{j=1}^{M} r_{ij} \nabla_{\theta_t} \log p_{\theta_t}(\text{SE}_{ij} \mid C_i), \tag{3}
      $$
      • 其中 \(p_{\theta_t}\) 表示模型的自回归分布
      • \(y_s^{(i,j)}\) 是自我编辑 \(\text{SE}_{ij}\) 的第 \(s\) 个 Token
      • 由于 \(r=0\) 的序列可以忽略,因此 ReST\(^{EM}\) 通过简单的“对优质自我编辑进行 SFT”优化了目标(1)
  • SEAL 的训练流程总结在算法1 中
  • 此外,尽管论文使用单一模型同时生成自我编辑和学习这些编辑,但也可以将这两个角色解耦
    • 在这种“教师-学生”框架中,学生模型通过教师模型提出的编辑进行更新,而教师模型则通过强化学习训练,以生成能最大化学生改进的编辑

Domain Instantiations

  • 论文在两个领域实现了 SEAL 框架:知识整合(Knowledge Incorporation)和 Few-shot 学习。这两个领域展示了模型适应的两种互补形式:
    • 1)将新信息整合到模型权重中,使其无需依赖上下文即可回忆(通过无上下文 SQuAD 评估);
    • 2)在少量示例后泛化到新任务(通过 ARC 评估)
Knowledge Incorporation
  • 论文的目标是高效地将段落中的信息整合到模型权重中
  • 一种有前景的方法是让语言模型生成基于段落的内容,然后对原始段落和生成内容进行微调(2022; 2023; 2024)
  • 论文采用典型格式:段落的逻辑推论
    • 这种方法将上下文 \(C\) 转换为一组推论 \(\text{SE} = \{s_1, s_2, \ldots, s_n\}\),并通过 Prompt 模型“列出从内容中衍生的几条推论”实现
    • 输出可能包括原始段落的推论、逻辑结果或重述
    • 在 附录C 中,论文还探索了其他 Prompt ,如“以不同方式重写段落”或“以问答格式重写”,并发现无论基础 Prompt 如何,论文的方法都能带来相似的性能提升
  • 这些自我生成的语句构成 SFT 的训练数据:论文对每个序列 \(s_i\) 计算标准因果语言建模损失,并更新模型参数,得到 \(\theta’\)
    • 由于每次更新的数据量较小且总更新次数较多,论文使用低秩适配器(LoRA)进行高效轻量级微调
    • 最后,适应后的模型 \(\text{LM}_{\theta’}\) 在任务 \(\tau\) 上评估。这一过程如图2 所示
  • 在强化学习训练中,适应模型在 \(\tau\) 上的准确率定义了奖励 \(r\),驱动外部强化学习优化。这训练模型以最适合微调的方式重组段落
Few-Shot Learning
  • 抽象与推理语料库(Abstraction and Reasoning Corpus,ARC)是一个测试抽象推理和从极少示例中泛化的基准
    • 每个任务包括少量输入-输出示例和一个需要预测正确输出的保留测试输入
  • 论文采用测试时训练(TTT)协议,其中少量示例的增强用于执行基于梯度的适应。论文训练 SEAL 学习这些决策,而不是依赖手动调整的启发式方法。为此,论文定义了一组工具 ,每个工具是预定义的函数,用于转换数据或指定训练参数,包括:
    • 数据增强(Data augmentations) :旋转、翻转、反射、转置、调整大小操作(如更改网格分辨率)以及链式或重复变换
    • 优化参数(Optimization parameters) :学习率、训练周期数以及损失是否计算所有 Token 或仅输出 Token
  • 模型通过任务的少量示例 Prompt 生成自我编辑(即指定调用哪些工具及如何配置的指令),如图3 所示
    • 然后通过基于 LoRA 的微调应用这些自我编辑来适应模型
    • 适应后的模型在保留测试输入上评估,结果决定训练自我编辑生成策略的奖励

Results

  • 本节论文通过两个适应领域实证评估 SEAL: Few-shot 学习和知识整合。完整的训练、超参数和评估细节见 附录A 和 附录B

Few-Shot Learning

  • 论文使用 Llama-3.2-1B-Instruct(2023)进行实验,这是一个未针对 ARC 进行预训练的小型开源模型
    • 由于大多数 ARC 任务对未经预训练的模型具有挑战性,论文从 ARC 训练集和评估集中分别筛选了 11 个和 8 个任务,确保它们在最优 TTT 配置下可解
  • 模型通过 ReST\(^{EM}\) 训练,每个训练任务采样 15 个自我编辑
    • 每个自我编辑独立生成 15 个更新模型,随后在对应的保留测试示例上评估
    • 论文仅强化那些导致模型正确适应的自我编辑(即适应后能正确输出测试输入的模型)
  • 训练后,论文通过为每个保留评估任务生成 5 个自我编辑并独立应用它们来评估模型
    • 然后报告导致正确输出的自我编辑百分比,反映学习到的自我编辑生成策略的质量
  • 论文对比了以下基线:
    • 1)ICL(上下文学习) :Llama-3.2-1B-Instruct 直接通过少量示例 Prompt 并查询测试输入
    • 2)TTT + 自我编辑(无 RL 预训练) :模型使用少量示例和合成增强进行测试时训练,但未通过 RL 优化增强或训练配置
    • 3)Oracle TTT :模型使用人工设计的最优配置进行测试时训练,作为方法的上限
  • 结果如表1 所示
    • SEAL 显著提高了适应成功率:72.5% vs. 20%(无 RL 训练的自我编辑)和 0%(无适应),但性能仍低于 Oracle TTT,表明还有改进空间

Knowledge Incorporation

  • 论文在 Quen2.5-7B 上实验,整合来自 SQuAD 段落的新事实内容
  • 论文进行了两轮 ReST\(^{EM}\) 训练,每轮包含 50 个上下文(详见 附录B)。论文对比了以下基线方法:
    • 1)基础模型 :未经适应的预训练模型,无法访问段落
    • 2)仅训练段落 :模型直接对段落进行标准语言建模损失微调,无合成数据
    • 3)训练段落 + 合成数据 :模型训练段落及自我生成的推论,与 SEAL 相同但无 RL 训练
    • 4)训练段落 + GPT-4.1 合成数据 :模型训练段落及通过 OpenAI API 从 GPT-4.1 生成的推论
  • 表2 报告了两种设置下的平均无上下文 SQuAD 准确率:单段落(\(n=1\))和持续预训练(CPT, \(n=200\))
    • 在单段落设置中,直接微调段落仅比冻结基础模型略优(33.5% vs. 32.7%),表明原始数据不足
    • 使用 GPT-4.1 生成的合成数据将准确率提升至 46.3%,比仅段落基线高 12.8 个百分点
    • 使用 Quen-2.5-7B 生成的合成数据达到 39.7%,提高 6.2 个百分点
    • 经过强化学习后,SEAL 进一步将准确率提升至 47.0% ,显著优于 GPT-4.1 的合成数据,尽管模型规模小得多
  • 在 CPT 设置中,模型在一次持续预训练中整合了 \(n=200\) 个段落的信息,随后在所有 974 个对应问题上评估
  • SEAL 再次超越所有基线,达到 43.8% 准确率。尽管绝对性能低于单段落设置(可能由于梯度干扰增加),但相对改进一致,表明 SEAL 的编辑策略能够泛化到原始 RL 设置之外
  • 图4 跟踪了每轮 RL 迭代后的准确率。两轮迭代后 SEAL 即超越 GPT-4.1 数据;后续迭代收益递减,表明策略快速收敛为将段落提炼为易学习的原子事实(定性示例见图5)。所有结果均使用调优超参数(见附录B)

Limitations

Catastrophic Forgetting

  • 论文提出 self-edit 语言模型(SEAL)的一个关键动机是实现持续学习(Continual Learning)的终极目标——让模型能够随着时间的推移不断整合新信息,无论是通过与环境的主动交互还是通过标准训练
  • 虽然之前的实验评估了 SEAL 在独立编辑场景下的适应能力,但更雄心勃勃的目标是支持连续的编辑序列 :模型能否在保留已有知识的同时,反复适应新信息?
  • 这一问题直接关联到灾难性遗忘(2014, 2015)的挑战,即新更新会破坏过去的学习成果
    • 当前的训练设置并未显式优化知识保留,但论文旨在建立一个基线,评估SEAL在没有专门机制的情况下处理连续 self-edit 的能力
    • 为了测试这一点,论文在知识整合领域模拟了一个持续学习场景:模型接收一系列测试段落,每个段落触发一次新的 self-edit
    • 每次更新后,论文重新评估模型在所有已见任务上的表现,以衡量其知识保留能力
  • 如图6所示,随着编辑次数的增加,模型在早期任务上的表现逐渐下降,这表明SEAL仍然容易受到灾难性遗忘的影响
    • 尽管如此,它能够在多次更新后避免完全崩溃,这表明未来仍有改进空间
    • 未来的工作可以通过奖励塑形(2020, 2024)来增强这一能力,例如惩罚对早期任务的回归,或整合持续学习策略,如零空间约束编辑(2025)或表示叠加(2019)

Computational overhead

  • TTT(Test-Time Training)奖励循环的计算成本显著高于其他用于 LLM 的 RL 方法
  • 例如,基于人类偏好的奖励信号通常只需要一次模型前向传播,而基于验证解的奖励可能仅依赖简单的模式匹配(如正则表达式)
  • 相比之下,论文的方法需要对整个模型进行微调和评估以计算奖励——每次 self-edit 评估大约需要30-45秒,带来了显著的开销(详见附录B.5)

Context-dependent evaluation

  • 当前的实例化假设每个上下文都配有一个明确的下游任务:Few-shot 示例附带一个保留的查询对,每个段落捆绑了参考问答
  • 这种耦合简化了奖励计算,但阻碍了 SEAL 的 RL 训练扩展到未标注语料库
  • 一个潜在的解决方案是让模型不仅生成 self-edit,还为每个段落生成自己的评估问题(例如草拟问答项或合成测试用例),同时保留原始内容在上下文中
  • 这些模型编写的查询可以提供强化学习所需的即时监督,从而将适用性扩展到缺乏外部问答集的通用训练领域

Related Work

  • 合成数据生成(Synthetic Data Generation) :
    • 合成数据在训练中的应用日益广泛,从大规模预训练数据集(2023; 2024; 2024)到任务特定的数据增强(2023; 2024)和指令微调集(2023; 2023)
    • Yang 等人(2025)通过基于图的 Prompt 生成合成数据
    • SEAL 在此基础上,利用强化学习训练生成策略,直接最大化合成数据在梯度更新中的下游效用,而非依赖手动调整的静态启发式方法
  • 知识更新(Knowledge Updating) :
    • 近期研究尝试通过权重更新修改或注入事实知识
    • 部分方法直接定位与特定事实相关的参数(2022; 2022; 2023)
    • 另一些则利用上下文信息生成额外的微调数据(2024; 2024; 2025; 2025)
    • 论文采用后者,参考 Akyurek 等人(2024)提出的逻辑蕴涵生成和 Lampinen 等人(2025)展示的蕴涵微调优于上下文学习的结果
    • SEAL 通过强化学习训练模型生成更优的微调数据,进一步扩展了这些方法
    • Park 等人(2025)表明,直接生成问答对(QA)的 Prompt 优于蕴涵式 Prompt
    • 由于 SEAL 框架对 self-edit 数据的格式无关,它同样可以训练生成 QA 对或其他输出格式
  • 测试时训练(TTT, Test-Time Training) :
    • 测试时训练基于输入临时调整模型权重(2020; 2022; 2024)。Akyurek 等人(2025)表明,TTT 与上下文学习结合可在 Few-shot 设置中超越标准 ICL
    • SEAL 在内部优化中整合了 TTT,利用其高效性执行多次更新,并奖励带来最大性能提升的数据生成策略
  • LLM 的强化学习(Reinforcement Learning for LLMs) :
    • 强化学习在改进 LLM 行为中发挥核心作用,最初通过 RLHF(2022)实现
    • 近期研究利用可验证奖励直接优化任务成功率(2022; 2024; 2025)
    • SEAL 将强化学习应用于优化 self-edit 数据的生成,而非最终答案或推理轨迹的修订
  • 元学习与自修改系统(Meta-Learning and Self-Modifying Systems) :
    • SEAL 通过外部优化循环学习适应策略(即如何生成有效的 self-edit),体现了元学习原则(2001; 2017; 2025),其目标是学习如何高效地从任务上下文中学习
    • 元学习同样已应用于强化学习领域在该领域中,模型通过元目标进行训练,以快速适应新任务
    • 这类工作的一个自然延伸是自指网络(self-referential networks),即模型自行修改自身参数(1992; 2022)
    • 在大型语言模型领域,近期的研究已将元学习原则应用于改进 LLM 的适应性[2024;2023]
    • 值得注意的是,Hu等人(2023)训练了一个较小的模型,使其在对语料库进行微调时输出特定于标记的权重,以解决与我们类似的知识整合任务
    • 然而,SEAL 通过利用模型现有的生成能力来参数化更新,从而在跨领域场景中展现出更强的通用性
  • 自我改进(Self-Improvement) :
    • 近期研究涵盖自我改进或自训练的多种方法
    • RLAIF(2022; 2024)和自我奖励语言模型(2024; 2025)利用模型自身提供奖励信号,基于判断输出比生成更容易的观察(2025)
    • 其他工作通过多数投票或模型置信度作为强化学习奖励,在无真实标签的情况下提升数学任务性能(2023; 2024; 2025; 2025)
    • 然而,这些方法受限于模型的当前评估能力和自一致性,相比之下,SEAL 通过与外部数据的交互实现自我改进,为更具扩展性的路径提供了可能

Discussion and Conclusion

  • Villalobos等人(2024)预测,到2028年,前沿 LLM 将完成对所有公开人类生成文本的训练
  • 作者认为,这一迫近的“数据墙”将迫使人们采用合成数据增强 ,一旦网络规模的语料库耗尽,进展将取决于模型自主生成高效用训练信号的能力
  • 自然的下一步是元训练一个专用的 SEAL 合成数据生成模型,生成新的预训练语料库,使未来模型能够在无需额外人类文本的情况下扩展并实现更高的数据效率
  • 我们可以设想一个未来场景:大语言模型能够消化新数据(如学术论文),并利用上下文中的数据和已有知识生成大量解释和推论
    • 这种自我表达和自我优化的迭代循环可能使模型即使在缺乏外部监督的情况下,也能在罕见或代表性不足的主题上持续改进
  • 此外,尽管现代推理模型通常通过强化学习生成思维链(CoT)轨迹,但 SEAL 可以提供一种互补机制,让模型学习何时以及如何更新自身权重
    • 这两种方法可以协同作用:模型可能在推理过程中选择更新权重以引导当前轨迹,或在完成推理后将关键见解提炼到参数中——通过内部化学习提升未来推理能力
  • 这种持续优化的循环对于构建代理系统(Agentic Systems)也很有前景——这些模型在长期交互中运行,并动态适应不断变化的目标
    • 代理模型必须逐步获取和保留知识
    • 论文的方法通过支持结构化的自我修改来实现这种行为:在交互后,代理可以合成一个 self-edit,触发权重更新
    • 这可能使代理能够随时间发展,使其行为与先验经验对齐,并减少对重复监督的依赖
  • SEAL 证明了大语言模型在预训练后不必保持静态:通过学习生成自己的合成 self-edit 数据并通过轻量级权重更新应用这些数据,它们可以自主整合新知识并适应新任务
  • 展望未来,作者计划将 SEAL 框架扩展到预训练、持续学习和代理模型中,最终使语言模型能够在数据受限的世界中自我学习和扩展

附录A 实验细节:Few-shot 学习(Experimental Details: Few-shot Learning)

  • 模型与设置 :论文使用Llama-3.2-1B-Instruct(2024)作为基础语言模型。由于该模型未针对ARC进行专门训练,其解决ARC任务的能力有限。为了进行受控评估,论文从训练集和评估集中筛选了一小部分可通过最优TTT超参数解决的 ARC 问题
    • 训练集 :从训练集中选择了11个ARC任务作为RL优化的环境
    • 评估集 :从评估集中选择了8个不同的ARC问题用于衡量泛化性能。这8个问题明确筛选为适合TTT的场景
  • 这些集合的选择是为了隔离 self-edit 学习的效果,而非通用ARC能力
  • 训练流程 :论文在11个训练问题上使用单轮强化学习训练SEAL。对于每个问题,模型生成15个 self-edit 配置,包括:
    • 数据增强策略 :控制是否包含基本示例、基于尺寸的增强、链式推理和重复输入
    • 优化配置 :指定学习率、训练轮数和损失函数类型(例如是否对所有 Token 或仅最终输出 Token 计算损失)
  • 每个配置通过测试时训练(TTT)评估,并分配二元奖励:如果适应后的模型生成正确解则为1,否则为0
  • 评估细节 :对于8个保留的评估任务,模型生成5个 self-edit 配置,共40个配置。成功率衡量为适应后生成正确输出的配置百分比

附录B 实验细节:知识整合(Experimental Details: Knowledge Incorporation)

  • 模型与设置 :论文在知识整合实验中使用Quen-2.5-7B基础模型(2025)。论文将SQuAD数据集v1.1(2016)重新用于无上下文段落的问答任务
  • RL训练流程 :论文运行2轮 ReST\(^{EM}\) 训练(2025)
    • 每轮从 SQuAD 训练集中抽取 50 个 上下文-问题-答案 三元组。对于每个上下文,采样 5 个 self-edit 生成(温度=1),并通过 3 个随机种子评估每个生成
  • 合成数据生成与微调细节 :在所有模型中,论文通过 Prompt 生成段落的推论来生成合成数据
  • 评估细节 :论文在 SQuAD 评估集的 200 个段落子集上评估,共974个问题
    • 使用 GPT-4.1(2025)通过 OpenAI API 进行自动评分
  • 计算资源 :所有实验在 2×H100 或 2×H200 上运行
    • 使用 DeepSpeed ZeRO-3(2020)进行 ReST\(^{EM}\) 训练的 SFT,使用 vLLM(2023)进行高效推理

附录C Prompting

  • 近期研究表明,强化学习基线和结果对 Prompt 高度敏感。论文在知识整合设置中测试了 4 种额外的 self-edit Prompt;五种 Prompt 如下:
    • 1)推论(Implications)
    • 2)长推论(Implications-long)
    • 3)超长推论(Implications-very-long)
    • 4)重写(Rewrite)
    • 5)自问答(Self-QA)
  • 结果显示,尽管通过 Prompt 生成长响应可以提高性能,但以这些 Prompt 为基础的 RL 训练能带来更大的改进
    • 在所有情况下,ReST\(^{EM}\)将性能提升了约 6 到 11 个百分点

NLP——Not-Just-Scaling-Laws

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:Not-Just-Scaling Laws: Towards a Better Understanding of the Downstream Impact of Language Model Design Decisions, arXiv 20250525, CMU
    • 代码和数据开源:https://anonymous.4open.science/r/llm-pretraining-behaviours-FE80/

Paper Summary

  • 本论文首次对开源语言模型在不同任务上的性能进行了系统分析,将其性能与架构和数据等联系起来
    • 但本文的分析存在一些不太严谨的地方,因为各个公司的实现方式可能是不一样的,比如基建或者各种配方的使用等
  • 一般来说:语言模型能力的提升通常归因于模型规模或训练数据的增加
  • 但在某些情况下:
    • 使用精选数据训练的小模型或采用不同架构决策的模型 可以超越在更多 token 上训练的更大模型
    • 引出问题:这是什么原因造成的呢?
  • 为了量化这些设计选择的影响,论文对 92 个不同规模的开源预训练模型进行了元分析 (meta-analyze),这些模型包括:
    • SOTA 开源权重模型
    • 性能较差的模型(less performant models)
    • 采用非常规设计决策的模型(less conventional design decisions)
  • 论文发现:
    • 通过纳入除模型大小和训练 token 数量之外的特征 ,论文预测下游任务性能的能力相对提高了 3-28%
      • 注:这个提升是与仅使用规模特征相比
    • 对模型设计决策的分析揭示了对数据构成的见解
      • 例如代码占比在 15-25% 时语言任务和代码任务之间的权衡,以及网络数据对真实性 (truthfulness) 的负面影响
  • 论文的框架为更系统地研究模型开发选择如何塑造最终能力奠定了基础

Introduction and Discussion

  • 语言模型训练的效果关键取决于预训练 (pretraining) 期间所做的决策
    • 例如,扩展数据 (scaling up data) 的有效性取决于其构成
    • 理解:即使处理了一万亿个 token,如果它们全部由单词 “the” 组成,那也是无效的
  • 研究发现,语言模型的性能可以通过 Scaling Laws (2020, 第 2 节) 进行相当准确的预测
    • 即基于模型参数数量和训练所用 token 数量对模型性能进行外推
  • 但仅基于这两个方面的扩展定律并不总能解释下游任务性能 (2024; 2024)
  • 研究界在理解训练决策如何影响下游性能方面已经取得了进展,特别是在数据构成方面。例如,对照研究 (controlled studies) 表明
    • 在代码数据上训练可以提高在某些推理基准测试上的性能 (2024; 2024);
    • 数据的元特征 (meta-features),如年龄和毒性过滤器 (toxicity filters) 的使用,会影响许多问答 (QA) 任务的性能 (2024);
    • 多语言数据的平衡会影响英语和其他语言的性能 (2023; 2025)
  • 这些工作揭示了宝贵的见解,但它们往往只关注改变训练方案 (training recipe) 的单个方面,而保持其他方面不变
    • 尽管严谨,但这在计算和开发时间上成本高昂
  • 论文转而提出一个问题:论文能否利用过去来自开源语言模型的发现来检验训练决策如何共同影响下游性能?
  • 为此,论文首先对来自不同系列的 92 个基础预训练 LM 的模型架构和数据相关的特征进行了 编目 (catalog) (章节3)
    • 由此产生的模型特征数据库涵盖了 2019 年至 2024 年间发布的大多数主要的、原始的开源权重 Decoder-only 预训练模型
  • 然后,论文开发了方法来 预测 这些模型在广泛基准测试上的性能
    • 预测依据既包括传统的扩展因素,也包括架构决策和数据构成 (章节4)
  • 具体来说,论文训练回归模型 (regression models)
    • 模型输入:提取的特征
    • 模型输出:预测基准测试结果
  • 进一步使用模型可解释性技术 (model interpretability techniques) 来识别在做出这些预测时最显著的特征
  • 论文在预测 12 个流行 LLM 基准测试的性能上评估了这种方法,并证明决定模型性能的 不仅仅是扩展 (not just scaling)
    • 在所有基准测试上,包含所有特征的回归器 (regressor) 的性能都优于仅基于扩展模型特征的回归器 (章节5.1)
  • 论文对特征重要性 (feature importance) 的分析揭示了数据领域 (data domains) 对任务性能的潜在影响,再次证实了经验性结果,例如预训练中使用代码的最佳比例 (章节5.2)
  • 此外,论文发现从模型生成文本中提取的特征(例如问题相关词的频率或类似网络文本的比例),有助于预测各种基准测试的性能
    • 这表明模型的生成模式可以反映其预训练数据中的潜在偏差 (underlying biases) ,进而影响下游性能
  • 通过记录整个社区训练的开源模型并提取见解,论文为模型开发者提供了一个实用的资源,以从集体经验中学习
  • 论文在 (章节8) 中讨论了这一点以及未来的工作

Scaling Laws

Definition

  • 论文在此将扩展定律定义为语言模型系列的参数数量 \(N\) 和 token 数量 \(D\) 与收敛时期望的语言建模损失 \(L(N,D)\) 之间的关系
  • 重要的是,这些定律通常是在保持所有其他因素不变的情况下进行研究的:
    • 保持相同的模型架构、训练数据和模型参数
  • 最初,Kaplan 等 (2020) 表明,在广泛的基于 Transformer 的模型中,这种关系可以表示为幂律 (power law):
    $$L(N,D)=\left(\left(\frac{N_{c} }{N}\right)^{\frac{\alpha_{N} }{\theta_{D} } }+\frac{ D_{c} }{D}\right)^{\alpha_{D} }$$
  • 后来,Hoffmann 等 (2022a) 提出了一个类似的定律,其拟合的系数不同,但同样基于幂律
  • 但扩展定律并非绝对,其确切函数形式和拟合系数可能取决于架构类型、规模范围 (Pearce and Song, 2024) 或其他考虑因素,如推理成本 (inference costs)
  • 更多讨论请参见 (章节7.2)

Maybe it’s Not Just Scaling?

  • 参数数量和训练 token 数量真的是准确预测模型下游性能所需的全部吗?直觉上答案是否定的
  • 模型训练涉及许多设计决策,所有这些都可能对模型性能产生影响
  • 模型架构细节 (Model Architecture Details)
    • 虽然大多数现代语言模型都遵循 Transformer 架构,但存在一些细节差异
      • 例如,层归一化 (layer normalization) 的种类 (2019) 和位置 (2020),以及位置编码 (positional encoding) 的类型 (2021; 2022) 都会对模型性能产生显著差异
    • 先前的工作,例如 Gu and Dao (2023),已经凭经验证明,在保持所有其他因素相等的情况下,做出更好架构决策的模型 (2023a) 优于做出更差决策的模型 (2017)
  • 数据构成 (Data Composition)
    • 数据构成和质量在模型的最终质量中起着重要作用
      • 例如,过去的工作表明,训练一定数量的代码可以提高英语推理任务的性能 (2023)
    • 同样,有工作表明,筛选“教育”内容可以实现更高效的学习,并在基于知识的问答任务上获得更高的性能 (2023)
  • Task Setting
    • 最后,所有上述因素与模型性能的衡量方式之间存在相互作用
    • 虽然先前关于扩展定律的工作主要测量损失值,但下游用户通常关心的是任务性能,而不是预训练数据集上的验证损失 (validation loss)
    • 尽管对于许多任务来说,两者之间通常存在相关性,但某些任务可能更难仅从模型的损失来预测 (2024)
    • 此外,某些任务表现出异常的扩展行为,例如反向扩展 (inverse scaling) 或 U 型扩展 (U-shaped scaling) (2023; 2023; 2024),或者仅仅是更不可预测的性能 (2024)
  • 论文提问:论文能否通过设计一套新的、不仅仅依赖于基于扩展的因素的“定律”来更有效地预测 LLM 的性能?

Building a Database of Publicly-Available Language Models

  • 为了解决论文的研究问题,论文构建了一个包含 11M 到 110B 参数的公开可用语言模型的数据库(包括嵌入参数),仅限于不同的 Decoder-only 基础预训练模型
    • 注:不同是指训练数据和架构的独特组合。在去重数据集上训练的模型被单独计数,但具有不同课程/初始化的变体不计入
  • 本节描述了论文的纳入标准、模型特征化以及评估方法

Data Collection

  • 为了确保论文的分析是一致的,论文应用了以下标准:
  • Pretrained-only:
    • 仅包含从头开始预训练的基础模型,排除了微调变体、合并模型以及经过额外后训练的模型
  • Architecture:
    • 仅包含基于 Transformer 的 Decoder-only 模型以保持一致性
    • 排除了 MoE 或其他架构
  • Publicly available information:
    • 仅包含具有公开可用元数据的模型,这些元数据通过配置文件或论文记录
    • 特别是,纳入需要总参数数量和训练所用总 token 数量这两个信息
    • 模型和模型系列的完整列表可在附录A 中找到

Characterizing Models and Data

  • 论文通过每个模型所做的架构选择以及其预训练数据的选择来表示每个模型
  • 形式上,令 \(\mathcal{A}\) 为与模型架构相关的特征集合, \(\mathcal{D}\) 为与模型预训练数据集相关的特征集合
  • 对于每个任务 \(T\) ,作者希望用预测值 \(\widehat{s_{T} }\) 来近似模型 \(M\) 的真实得分 \(s_{T}\) :
    $$\widehat{s_{T} }(M)=f_{\theta}([\mathcal{A}_{M};\mathcal{D}_{M}]).$$
  • 当 \(\mathcal{A}=\{\#\text{params}\}\) ,\(\mathcal{D}=\{\#\text{tokens}\}\) ,且 \(f_{\theta}\) 是幂律时,这就简化为典型的扩展定律
  • 论文总共记录了 92 个开放模型,涵盖模型特征、高层数据集特征以及从该模型的无上下文生成中派生出的特征等维度
  • 有关完整的特征集和定义,请参见附录B
Features from Model Documentation
  • 论文首先通过阅读源论文/博客(如有,请参见附录A了解原始引用)以及在 Hugging Face Hub (2020) 上列出的数据来收集每个模型的信息
  • Architectural Features:
    • 这些特征捕获了决定模型结构的设计决策
    • 例如,总参数(包括嵌入参数)、Transformer层数、嵌入和前馈维度,以及细节,例如使用的层归一化类型或注意力变体
  • Data Features:
    • 这些特征总结了预训练数据的组成
    • 代表性示例包括训练所用的总 token 数以及来自图2中定义的各个领域的 token 百分比细分,以及英语 token 的比例
    • 论文的预训练数据领域源自开放预训练数据集(2020; 2024)中常见的子领域
    • 论文使用顶级领域(网络、代码、书籍、参考、学术),因为这往往是论文中描述数据组成的粒度
Exploring Data Composition via Generation
  • 尽管许多模型记录了一些数据组成细节,但相对较少的模型发布了其完整的预训练语料库,导致论文研究中许多模型的这些值缺失
  • 为了填补这些空白,论文探索了一种替代方法:
    • 分析模型在无提示情况下生成的文本,以估计其训练数据的特征
    • 论文假设模型的生成风格和内容反映了其训练数据的分布
    • 对于每个模型,论文使用温度 \(T=0.8\) 和 top-p \(p=0.9\) 的核采样生成 5k-10k 个无上下文的生成文本(每个生成文本最多 256 个 token)
    • 然后,论文使用标准 NLP 工具和基于 LM 的分类器从这些生成文本中提取语言学和领域特征
    • 论文在附录E和F中验证了这种方法
  • 论文还提取了 low-level 语言特征,例如每句词数(words per sentence)、成分树深度(constituency tree depth)和依存长度(dependency length)
    • 论文的验证分析(附录G)表明:
      • 领域层级特征与实际预训练数据构成具有较强的相关性
        • 例如,网页内容相关性:\(r = 0.916\),\(p = 7.55 \times 10^{-12}\)),
      • low-level 风格特征的相关性较弱
    • 然而,所有特征的整体 Model-level 相关性表现强劲(通常 \(r > 0.8\))
    • 这一结果支持我们将“自由生成内容”(free-generations)用作预训练数据构成的替代指标(proxies),同时也说明不能用自由生成特征替代预训练特征
  • 注:一些关键术语补充说明:
    • constituency tree depth(成分树深度) :句法分析中的核心概念,指“成分树”(用于表示句子句法结构的树形图,如主谓、动宾等成分的层级关系)从根节点到最深叶节点的路径长度,反映句子句法结构的复杂程度
    • dependency length(依存长度) :依存句法分析中的指标,指句子中两个存在依存关系的词语(如中心词与修饰词)在文本序列中的距离,常用来衡量句子结构的线性复杂度

Evaluation Datasets and Metrics

  • 为评估设计选择对推理能力的影响,我们在 Open LLM 排行榜(2024)的数据集上对模型进行了测评,这些数据集涵盖了推理能力的多个不同维度(见表 1)
    • 其中,部分模型的测评结果直接从该排行榜获取;
    • 对于未列入该排行榜的模型,我们使用 Eleuther LM 评估工具包(2023),在完全相同的设置下开展测评
    • 此外,若某项任务或子任务存在多个版本,我们会对所有版本均进行测评,并通过求平均值得到该任务的整体得分
    • 有关评估数据集及测评设置的完整列表,详见附录 C
  • 对于评估数据集 \(T\) ,其中第 \(i\) 个样本是 \(y_{i}\) ,模型为 \(M\) ,论文如下定义 \(s_{T}(M)\) :
  • 准确率(Accuracy) :对于大多数任务,论文使用未归一化的精确匹配准确率
    $$ s_{T,\text{acc} }=\frac{1}{|T|}\sum_{i=1}^{|T|}\mathbb{I}\{y_{i}=\hat{y}_{i}\} $$
    • 对于 Humaneval,论文使用 pass@1,但为方便起见,将其与准确率任务归为一组
  • Brier分数(Brier score) 对于较小模型难以达到非零准确率的任务,论文遵循(2023)的做法,使用多类 Brier 分数作为多项选择任务的替代连续指标(1950) (注意:对于 Brier 分数来说,值越低越好,多类 Brier 分数范围在 0-2 之间)
    • 对于一个有 \(K\) 个类别的任务,令 \(p_{ik}\) 为样本 \(i\) 上类别 \(k\) 的预测概率。则
      $$ s_{T,BS}=\frac{1}{|T|}\sum_{i=1}^{|T|}\sum_{k=1}^{K}(p_{ik}-\mathbb{I}\{y_{i}=k\})^{2} $$

异质性(Heterogeneity) in Task-specific Scaling

  • 在加入其他因素之前,作者检查了所选任务之间沿 \(N\) 和 \(D\) 扩展的差异
  • 论文为每个任务拟合了一个(2020)风格的定律
  • 如图3所 示,论文看到不同的任务在遵循扩展趋势的程度以及它们各自的扩展轮廓上可能表现出显著差异
  • 例如,TruthfulQA 似乎表现出 U 形扩展,而 Humaneval 有更多的“异常值”模型
  • 任务的 \(R^{2}\) 值完整列表可在附录D 中找到

Predictive Modeling

  • 接下来,给定论文的数据库,论文拟合一个回归器来尝试预测性能
  • 在传统的扩展定律中,回归器是基于幂律拟合的
  • 然而,论文现在要处理大量特征,其中一些可能无法通过简单的参数形式很好地捕捉
  • 因此,论文遵循先前关于性能预测的工作(2020; 2021),利用基于 XGBoost (2016)的树形回归器
    • 论文还试验了LightGBM (2017),发现其性能相似。结果见附录K
  • 对于每个评估基准,训练一个模型 ,以基于架构特征 \(\mathcal{A}\) 和数据特征 \(\mathcal{D}\) 来预测该任务上的性能指标
    • 理解:每个评估基准都有一个单独的 XGB 模型
  • 对于每个任务设置,由于模型数量相对较少,论文执行 3 折交叉验证,并在每折的训练集上进行嵌套内部交叉验证
    • 内部交叉验证在一小组超参数上进行网格搜索,允许模型随任务略有变化。更多细节请参见附录I
  • 评估(Evaluation) 为了评估预测器,论文使用所有模型和折迭的平均绝对误差(Mean Absolute Error)
    • 对于一个有 \(N\) 个模型被评估的任务
      $$ \text{MAE}_{T}=\frac{1}{|T|}\sum_{i=1}^{N}|s_{T}(M_{i})-\widehat{s_{T} }(M_{i})|$$
    • 论文将扩展定律预测器以及全特征预测器相互比较,同时也与中位数基线(median baseline)(它只是为该折迭测试集中的每个模型预测训练集中模型的中位数得分)和对数线性基线(log-linear baseline)(它将一个对数线性函数拟合到参数数量和 token 数量)进行比较
  • 迭代特征选择(Iterative Feature Selection) 由于完整的特征集非常大,论文根据哪个特征能最大程度地减少 MAE(在 5 个随机种子上平均),从完整集合中贪心地顺序选择特征
    • 不断添加特征,直到观察到的减少量不再至少为 \(1\times 10^{-4}\)
    • 论文开始时仅使用两个扩展定律特征,并将其称为扩展定律(scaling-laws) 模型,尽管它不具有传统幂律的形式
      • 注:不具有传统幂律形式的解释:由于论文使用基于树的预测器来适应多样化的特征类型(包括非数值型特征),论文的方法优先考虑在观察到的界限(10M-100B 参数,50B-3T token)内进行插值,而不是外推(探索其他预测方法仍然是未来的工作)
    • 然后,通过合并额外的架构或数据特征,我们可以直接量化这些额外特征带来的增量预测能力
      • 论文将具有该组特征的模型称为全特征(all-features) 模型
      • 在所有情况下,论文使用相同的超参数网格、相同的随机种子和分割来运行模型
  • 显著性检验(Significance Testing) 由于基线之间的相对差异很小,论文在多个种子(50个)上测试两个预测器
    • 然后,论文对每个种子的总体 MAE 值运行配对 \(t\) 检验,并使用错误发现率(1995)对跨任务的多重比较进行校正

Results

Predictor Performance

  • 加入与规模无关的特征能持续提升基准测试性能
    • 论文发现,在传统的扩展定律特征之外加入额外特征,能在多个基准测试上显著提升预测准确度,如表 2 所示
    • 在所有评估案例中,全特征预测器均优于仅使用扩展定律的预测器,相对误差减少的幅度大约从 3%(MathQA)到 28%(Lambada)不等
    • 值得注意的是,在语言建模和常识推理任务中观察到了最强的改进效果
  • 某些任务更强烈地依赖于非规模特征
    • 这种改进模式表明,架构和训练数据特征,对于预测与特定数据“类型”更紧密相关的某些任务,其表现可能更具信息量
    • 在代码生成任务(13% 的改进)和基于自然语言的推理任务(例如 Lambada,28% 的改进)上都观察到了巨大的改进
    • 即使是领域较窄的任务,如数学推理(GSM8k,+16%)或知识密集型评估(MMLU,+11-14%),也看到了一致但更温和的增强
    • 然而,使用 Brier 分数的基准测试显示出较小的改进(约 3-6%)
      • 这可能是因为 Brier 分数本身对模型性能中的涌现效应敏感性较低,特定任务的选择限制了改进空间,或者是这两个因素共同作用的结果

What Features Does Task Performance Depend On?

  • 为了理解影响任务性能的因素,作者检查了 Shapley (1953) (SHAP) 值,这些值显示了特征值如何影响预测
    • 注:SHAP Value(SHapley Additive exPlanations,沙普利可加解释)是一种基于合作博弈论的模型解释方法,核心目标是量化机器学习模型中每个特征对预测结果的贡献程度 ,让复杂模型(如随机森林、神经网络)的决策过程变得可解释
    • “SHAP Value”:每个特征对这个 “收益” 的贡献值,即该特征让预测结果偏离基准值的程度(正值表示推动预测值升高 ,负值表示推动预测值降低)
  • Arc Challenge、HumanEval、Winogrande 和 TruthfulQA 的结果如图 4 所示,其余基准测试的结果见附录 L
  • 少量代码大有裨益,但过多则对自然语言推理(NLI)有害
    • 预训练数据中代码的百分比是一个关键的非规模特征
    • 较高的代码组成有益于 Humaneval 性能,但对包括 Arc Challenge、Hellaswag、Winogrande 和 Lambada 在内的自然语言推理任务产生负面影响
    • 如图 5 所示,代码比例超过 20-25% 的模型在 Humaneval 上显示出增益,但在语言基准测试上受到惩罚
    • 15-25% 的中等代码比例似乎能平衡这些相互竞争的需求
  • 其他数据领域显示出任务特定的效应
    • 从自由生成特征中,论文观察到最近使用合成数据训练的模型(Phi (2023)、SmoILM (2024))生成了更多疑问词,这表明训练数据中包含问答内容
    • 类似参考书或包含大量问题的生成内容与 Arc Challenge 和 Winogrande 的更好性能相关,而类似网络文本的生成内容则与更差的 TruthfulQA 性能相关(图 4)
  • 非规模的架构决策影响较小
    • 大多数高影响力特征是与数据相关的或与规模相关的架构特征(例如,维度)
    • 在某些情况下,层归一化的类型和位置嵌入都被认为具有显著影响

Validating Performance Predictions with Confirmatory Experiments(验证性实验)

  • 为了验证元分析的发现,论文还使用 Dolma 数据集上训练的 460M 参数模型进行了验证性的预训练实验
  • 论文旨在验证两个关于数据分布的发现:
    • (1) 当仅考虑自然语言推理时,约 8% 的代码比例是最优的 ,但在平衡代码和自然语言时,15-25% 可能是最佳比例;
    • (2) TruthfulQA 性能随着网络数据比例的增加而降低
      • 理解:网络数据的虚假信息多
  • 由于这是一个小规模模型,准确度差异可能不显著,论文将相关数据集转换为基于损失的评估
  • 由于计算限制,论文将每个检查点训练 10B 个 token,但使用按 100B token 运行规模调整的余弦学习率调度
  • 详细信息和精确的损失图见附录 M
  • 总体而言,在图 6 中,论文发现验证性运行在很大程度上验证了论文的元分析预测
    • 唯一不足:尽管准确度的趋势符合预期,但 TruthfulQA 的基于边际的损失在 50% 网络数据检查点上略低于 30% 检查点
  • 这提供了初步证据,表明论文的分析方法可以用于先验地智能预测语言模型训练设计决策

Related Work

Empirical Data Composition Results

  • 先前的研究已经探讨了预训练中代码数据的作用 (2023; 2024) 以及领域消融 (2024)
  • 数据过滤可以在单纯扩展规模的基础上进一步提升性能 (2023; 2024)
  • 论文的结果表明,代码数据在中等比例(最佳比例 15–25%)下能增强自然语言推理能力,这修正了先前 25% 的估计 (2024)
  • 论文通过汇集现有模型的见解来识别有前景的测试方向,从而对实证消融研究进行了补充

Observational and Task-Specific Scaling Law Fitting(理解:观测性与任务特定的扩展定律拟合)

  • 任务特定的扩展定律研究表明,参数分配会影响机器翻译的结果 (2021),而多任务处理对英语-目标语言对有益 (2023)
  • 关于下游任务的研究强调了预训练数据与下游数据之间对齐的重要性 (2021; 2024)
  • 各种研究探讨了数据重复 (2024)、多领域数据 (2024) 以及稀疏性 (2023)、精度 (2024) 和推理成本 (2022a) 等因素,而一些研究发现训练超参数具有稳定性 (DeepSeek-2024a)
  • Ruan 等 (2024) 也使用开源模型的观测数据来预测任务性能,但他们是根据模型在其他任务上的表现来预测某一任务的性能
  • 论文在识别性能的通用自然语言能力和编码能力两个方面得出了类似的结果,但论文的动机是将这些能力追溯到预训练决策

Pretraining Data Selection

  • 领域混合在预训练中已被研究,其他工作将其表述为回归问题 (2024; 2025) 或在训练过程中使用代理模型来选择领域权重 (2023; 2023; 2024b; 2025)
  • 相比之下,论文回顾性地分析了领域构成和训练决策如何影响跨任务的性能,这是在训练期间为单个模型优化数据权重的补充视角

Tracing Capabilities to Data

  • 特定的语言模型能力已被关联到预训练数据中的模式
  • 数值推理和句法规则学习的性能取决于训练数据中数字术语的频率 (2020; 2021)
  • Ruis 等 (2024) 发现,对推理有影响的数据分散在众多文档中,并且与程序性内容相关
  • 类似地,Chen 等 (2024) 观察到 “并行结构”与上下文学习能力密切相关
    • 问题:这里的并行结构是什么?
  • 论文目前关注更广泛的数据领域,但论文的框架可以通过更细粒度的任务或更精细的数据特征进行扩展

Future Work

  • 展望未来,有几个明确的方向
    • 首先,论文的数据库(章节3)可以随着新模型和基准测试的发布而进一步扩展,论文将发布代码和数据以帮助推动社区进行更系统的数据记录工作
    • 其次,作者希望论文的工作将有助于发现在更受控环境中测试的假设
      • 现有模型交织了许多设计决策,而进一步仅涉及单一变化轴的受控预训练实验可以进一步阐明每个特征的影响
    • 最后,在论文的研究中,绝大多数预训练模型专注于密集 Transformer 架构,而混合专家 (2024a; DeepSeek-2024b) 和状态空间模型 (Gu and Dao, 2023) 等替代架构也引起了显著的研究兴趣
      • 如何恰当地对这些更多样化的模型架构进行特征化,并在性能预测中使用这些信息,是一个有趣的挑战,可能会揭示更多的见解
  • 尽管预训练数据分析和选择迄今为止主要集中于实证发现,但通过大规模实证研究更好地理解训练如何影响模型能力,也可以促进可解释性实验和对学习表征的可能干预,其中受控的变化轴提供了案例研究

Limitations

  • 论文当前的工作有几个局限性,可以在未来的工作中改进
  • 第一,尽管论文记录了许多开源模型,但论文的样本量仍然有限,特别是对于较大(>50B)参数的模型
    • 这限制了论文得出关于大型模型扩展行为的稳健结论的能力
    • 而且论文拥有的模型在参数数量、数据大小和数据分布上并不均匀,某些规模范围和数据分布被过度代表
      • 哪些模型被开源也可能存在选择效应,并且在不同的时间段,流行的架构决策或数据构成可能存在时间效应
  • 第二,论文的方法论也带来了一些局限性
    • 因为论文没有系统地训练所有论文自己的模型(尽管论文在附录 A 中有一些自己的模型),所以论文的分析本质上是观察性的
    • 虽然我们可以观察到设计选择与性能之间的有趣关系,但要做出因果断言需要实验验证
    • 此外,虽然基于树的回归器能有效捕捉复杂的特征交互,但它们限制了论文外推超出数据集中所见模型大小(参数和 token 数量)范围的能力
  • 第三,论文注意到论文工作的范围也有局限性
    • 论文专注于 Decoder-only 的基预训练密集 Transformer 模型,这排除了重要的架构变体,例如混合专家模型、非基于 Transformer 的架构以及经过后训练的模型
    • 此外,论文主要检查英语模型,因为论文在这项工作中不关注多语言性
    • 论文的特征集虽然广泛,但可能仍未捕捉到模型设计和训练的所有相关细节,特别是目前的优化细节
  • 这些局限性为未来的工作指明了方向:
    • 扩展数据库以包含更多样化的模型类型和语言覆盖范围;
    • 开发更具针对性的函数形式,以便在输入异构特征集的同时实现更好的外推;
    • 使用新的预训练模型进行有针对性的实验,以验证特定设计选择的影响

Ethical Considerations

  • 在这项工作中,论文专注于理解模型为何在标准基准测试上表现良好,但并未关注其他重要的考量因素,例如安全性或社会偏见
  • 而且论文的分析侧重于英语模型和基准测试
    • 这一局限性反映但也可能强化了该领域现有的对英语的偏向,可能导致对其他语言有效架构的开发投入不足

附录 A:List of all models

  • All models are listed in Table 3.

附录 B:List of all architectural and data features

B.1 Architectural Features

  • (注意,本部分的特征是从官方文档(例如 Hugging Face 的模型/数据卡片或原始论文)中收集的)
  • 总参数量 (Total parameters) :模型中的参数总数(包括嵌入参数)
    • 注意,论文仅包含 Decoder-only 的密集模型
  • 维度 (Dimension) :嵌入维度
  • 头数 (Num heads) :注意力头的数量
  • MLP 比率 (MLP ratio) :\(\frac{\text{FFN dimension} }{\text{Embedding dimension} }\) 的比率
  • 位置嵌入 (Positional Embeddings) :位置嵌入的类型
    • 这可以是非参数的(正弦或固定嵌入)、学习的(仅作为每个位置的向量学习)、rope (rope 嵌入) 或 alibi(技术上不是嵌入,但因其功能目的而包含在此)
  • 层归一化 (LayerNorm) :应用的层归一化类型
    • 这可以是非参数的(仅基于算术的归一化)、参数的(类似,但有一些可学习的参数,如扩展/偏置)和 RMSNorm(参数版本的简化版)
  • 注意力变体 (Attention variant) :使用的注意力的大致类型
    • 这可以是 full(普通注意力)、local(每个 Token 位置仅关注其周围的位置)、mqa(多查询注意力)或 gqa(分组查询注意力)
  • 偏置 (Biases) :模型的某些部分是否存在偏置项
    • 可以是 none(无偏置)、attn only(仅在注意力层中)、ln only(仅在层归一化中)
  • 块类型 (Block type) :变压器块是否完全并行计算
    • Sequential 表示不并行,而 parallel 表示在注意力或 FFN 层中存在某种并行性
  • 激活函数 (Activation) :使用的激活函数
    • 可以是 relu、gelu/gelu 变体、silu 或 swiglu
  • 序列长度 (Sequence length) :序列长度
  • 批次实例数 (Batch instances) :预训练期间使用的批次大小

B.2 Data Features

  • (注意,本部分的特征是从官方文档(例如 Hugging Face 的模型/数据卡片或原始论文)中收集的)
  • 总 Token 数 (B) (Total tokens (B)) :预训练期间使用的 Token 总数,以十亿计(转换为对数尺度)
  • 预训练数据中网络数据百分比 (% Web in Pretraining) :来自通用网络来源的预训练数据百分比
  • 预训练数据中代码百分比 (% Code in Pretraining) :由代码组成的预训练数据百分比
  • 预训练数据中书籍百分比 (% Books in Pretraining) :来自书籍的预训练数据百分比
  • 预训练数据中参考文献百分比 (% Reference in Pretraining) :来自参考文献来源的预训练数据百分比
  • 预训练数据中学术内容百分比 (% Academic in Pretraining) :来自学术来源的预训练数据百分比
  • 预训练数据中英文百分比 (% English in Pretraining) :预训练数据中英文文本的百分比

B.3 Freegen-derived Features

  • 这些特征源自模型的生成文本
  • 对于每个模型,提取 5-10k 个生成文本,并聚合以下指标(通过均值和标准差)
    • 但二元组熵、教育分类器分数和领域分类是例外,因为它们是在所有生成文本上计算一次的
  • 论文使用 Stanza (2020) 在按语言对生成文本进行分类后生成基于解析的特征
    • 论文仅将 stanza 支持的语言包含在解析特征所基于的最终生成文本集中
B.3.1 生成长度和基本统计量 (Generation Length & Basic Statistics)
  • 平均字符长度 (Mean Character Length) :每个生成文本的平均字符数(上限为 2048)
  • 平均生成 Token 数 (Mean Tokens Generated) :每个生成文本的平均 Token 数
  • 平均句子数 (Mean Sentences) :每个生成文本的平均句子数
  • 平均词数 (Mean Words) :每个生成文本的平均词数
  • 平均每句词数 (Mean Words per Sentence) :每个句子的平均词数
B.3.2 Constituency Parse Features
  • 最深解析树平均深度 (Mean Depth of Deepest Parse Tree) :每个生成文本的平均最大选区树深度
  • 解析树平均深度 (Mean Depth of Parse Trees) :所有句子/短语的平均选区树深度
  • 词平均深度 (Mean Word Depth) :选区树内词的平均深度
  • 词深度变异平均 (Mean Word Depth Variation) :跨句子/短语的词深度标准差的平均值
B.3.3 Dependency Parse Features
  • 依存头距离 90% 分位数平均值 (Mean 90th-Percentile Dependency Head Distances) :对于每个生成文本,计算词与其依存头之间的线性距离的 90% 分位数,然后对这些值取平均
  • 最大依存头距离平均值 (Mean Maximum Dependency Head Distances) :每个生成文本中任何词到其依存头的最大距离的平均值
  • 依存头距离中位数平均值 (Mean Median Dependency Head Distances) :每个生成文本的依存头距离中位数的平均值
  • 最大依存根距离平均值 (Mean Maximum Dependency Root Distances) :每个生成文本中任何词到句子根的最大距离的平均值
  • 平均依存根距离平均值 (Mean Mean Dependency Root Distances) :每个生成文本中词到句子根的平均距离的平均值
  • 依存根距离中位数平均值 (Mean Median Dependency Root Distances) :每个生成文本中词到句子根的距离中位数的平均值
B.3.4 Domain Classification Features
  • 生成学术类文本百分比 (% Generated Academic-like Text) :被分类为学术类的生成文本百分比
  • 生成书籍类文本百分比 (% Generated Books-like Text) :被分类为书籍类的生成文本百分比
  • 生成代码类文本百分比 (% Generated Code-like Text) :被分类为代码类的生成文本百分比
  • 生成参考类文本百分比 (% Generated Reference-like Text) :被分类为参考类的生成文本百分比
  • 生成专业文本百分比 (% Generated Specialized Text) :被分类为专业类(例如,乐谱、象棋 PGN、生物医学数据)的生成文本百分比
  • 生成网络类文本百分比 (% Generated Web-like Text) :被分类为网络类的生成文本百分比
B.3.5 Classifier and Language Metrics
  • 教育分类器分数平均值 (Mean Educational Classifier Score) :教育分类器给出的平均分数
  • 生成英文文本百分比 (% Generated English Text) :生成的英文文本的平均百分比

B.3.6 Lexical Diversity and Entropy Metrics

  • 平均二元组熵 (Mean Bigram Entropy) :跨生成文本计算二元组的平均熵
  • 型符比 (Type-Token Ratio) :唯一 Token 数与总 Token 数的平均比率
  • 唯一 Token 数 (Unique Tokens) :每个生成文本的平均唯一 Token 数
B.3.7 Lexical and Stylistic Features
  • 实词-功能词比率 (Content-Function Ratio) :实词(名词、动词、形容词、副词)与功能词的比率
  • 疑问词比率 (Question Words Ratio) :每 10 万个词中疑问相关词(例如 how, what, why, when, where, who, which, whose)的比率
  • 祈使词比率 (Imperative Words Ratio) :每 10 万个词中祈使词(例如 do, make, consider, take, use, ensure, check, build, apply, run, create, find, go, try, turn, start, stop, put, keep, leave, get, move)的比率
  • 连词比率 (Conjunctions Ratio) :每 10 万个词中连词(例如 and, but, or, so, because, although, however, therefore, yet)的比率
  • 指令词比率 (Instruction Words Ratio) :每 10 万个词中指令导向短语(例如 “Question:”, “Answer:”, “Instruction:”, “User:”, “Assistant:”, “Q:”, “A:”)的比率
  • 数字比率 (Numbers Ratio) :生成文本中数字 Token 的比率

附录 C:List of all evaluations and settings

  • 尽管论文理想情况下会评估模型和任务的全部组合,但论文发现由于一些模型与 LM Evaluation Harness 不兼容以及计算限制,论文无法在每個数据集上评估所有 92 个模型
  • 论文在表 4 中列出了论文目前每个基准测试的评估数量,并将在数据库中继续补充评估结果

附录 D:Task Deviations from Kaplan-style Scaling Laws

  • 表 5 记录了针对每个模型性能拟合幂律分布所得到的决定系数(\(R^2\) value)

附录 E:Free-generation Domain Classification

  • 论文使用 GPT-4o-mini 将模型生成文本分类到顶级领域
  • 论文发现这种多阶段提示(清单 1,清单 2)在 Dolma 按领域采样的样本上具有合理的精确度 (2024),因此使用它对自由生成文本进行分类

附录 F:Domain Classifier Validation

  • 为了验证基于 4o-mini 的分类器的可靠性,论文请论文的一位作者根据附录 E 中使用的相同标注标准,对来自三个预训练数据集(the Pile, the SmoILM corpus, 和 RefinedWeb)的 300 个选定样本进行标注
  • 被模型或人类标注者标注为“unknown”或“incoherent”的样本被排除,因为这些样本不包含在领域混合的计算中
  • 过滤后,论文分析了 258 个文本样本,发现人类标注者和模型的绝对一致率为 85.8%,Cohen’s \(\kappa\) 为 0.746,表明人类分类和模型分类之间具有高度一致性

附录 G:Free-generation Validation

  • 为了验证论文的自由生成方法作为预训练数据组成的代理,论文分析了模型的自由生成特征与其预训练数据之间的相关性
  • 对于在三个开放预训练数据集(the Pile, the SmoILM corpus, 和 Refinedweb)上训练的模型,论文比较了它们的自由生成特征与相同标注器和基于 LM 的分类器(附录 E)在预训练语料库随机采样的 100 万文档子集上产生的特征
  • 由于成本原因,对于领域分类,每个语料库使用了 100 万中的 5000 个示例
    • 这 100 万文档是通过水库采样均匀采样的
  • 此外,论文计算了两个整体模型层面的相关性,它们衡量了每个模型的完整生成配置文件与其训练数据的匹配程度:
    • 1)领域层面相关性 (Domain-level correlations): 对于每个领域类别(web, code, academic, books, reference),论文计算了模型文档记录的预训练数据中该领域的百分比与被分类到该类别的自由生成文本百分比之间的相关性
    • 2)特征层面相关性 (Feature-level correlations): 对于语言特征(连词比率、祈使动词比率、疑问词比率、数字比率、字符长度),论文评估了自由生成中的每个特征与预训练数据中相应特征的相关性
  • 这些模型层面的相关性明显高于单个特征的相关性,表明模型比任何单一维度都更好地保留了其训练数据的整体“指纹”
  • 表6 和 表7

附录 H:Free-generation Examples

  • Llama-2-70B,网页风格(web-like):

    • 给定斯蒂芬妮在与布雷克谢尔顿一起出席的活动上无法认出自己的第一首歌曲你创作和/或录制的第一首歌是什么?如果你再也没听过它,直到在参加一个大型乡村音乐活动时与数十万人进行Facetime视频通话,你会有什么感觉?这正是格温斯蒂芬妮本周在乡村电台研讨会活动上被授予荣誉时发生的事情。在纳什维尔,可以想象,当你拥有像她那样疯狂的职业生涯时,可能很难记住这些年所有的里程碑,但试想一下还要当场为所有人做到这一点。嗯,结果发现,双重才华的多莉帕顿原本要介绍格温是“一个打破壁垒的艺术家”,当时她引用了斯蒂芬妮第一首“乡村”歌曲的歌词……‘3000年’。外交辞令的“好声音”教练布雷克谢尔顿在她当时展示的视频中插话,问道“还有其他完成的工作吗……

    • phi-2,书籍风格(books-like): 太阳在傍晚的天空中明亮地闪烁,将温暖的光线投射在温馨的小厨房里,一对勤奋的父母正在那里准备一顿美味的饭菜。空气中充满了期待,他们开始了一次烹饪冒险,心中充满了爱和喜悦。他们卷起袖子,灵巧地挥舞着各种锋利的刀具,每把刀都是他们技能和精确度的延伸。就像艺术家拿着画笔一样,他们精巧地在食材中穿梭,与刀具共舞,如同舞伴。刀片的移动速度之快违背了日常动作的物理规律,优雅而有目的地在空中滑行。……

    • Qwen2-72B,参考(指令)风格(reference (instruct)-like): 儿童:家庭票(最多2名成人和2名儿童)23.00美元 (1)布莱克先生想和他的妻子以及两个儿子(他们都超过9岁)一起看电影,他应该支付多少?__ A.23美元 B.46美元 C.48美元 D.49美元 (2)布朗先生想和他5岁的双胞胎儿子一起看电影,他应该支付多少__。 A.18美元 B.23美元 C.28.5美元 D.32美元 (3)丽莎在这部电影中看不到雪。__ A.正确 B.错误 (4).简是一位数学老师,她有两个女儿(一个12岁,妹妹8岁)。她们想看电影。她应该支付多少?__ A.28.5美元 B.27美元 C.33美元 D.26.5美元 (5)迈克十六岁。今天是星期六,他想和他40岁的父母一起看电影。他们将为门票支付____。 A.22美元

    • OPT-6.7B,网页风格(web-like): 我不认为这是一个问题。这是一个陈述。我说养只狗吧,世界有狗更美好。我对自己说了完全一样的话。我不知道你为什么被踩了。因为reddit是一群讨厌狗的人

    • Codegen-16B-mono,代码风格(code-like):

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      from typing import Optional
      from fastapi import FastAPI
      from pydantic import BaseModel

      app = FastAPI()

      class Item(BaseModel):
      name: str
      price: float
      is_offer: Optional[bool] = None

      @app.get("/")
      def read_root():
      return {"Hello": "World_from_ FastAPI"}

      @app.get("/items/{item_id}")
      def read_item(item_id: int, q: Optional[str] = None):
      return {"item_id": item_id, "q": q}

      @app.put("/items/{item_id}")
      def update_item(item_id: int, item: Item):
      return {"item_name": item.name, "item_id": item_id}
    • Pythia-410M,学术风格(专利)(academic-like (patent)):

      • 本发明涉及用于从发电站的废物流中去除固体废物材料的新的和改进的方法和设备,其中该固体废物材料计划在发电厂完工后进行处置。先前已经进行了各种尝试来从工厂的废物流中去除固体废物。这是真的,例如,对于污泥的排放,这些污泥通常被收集到一个污泥罐中,并在一个通常连接到工厂的敞开排水出口的清洗罐中被洗出工厂。这项现有技术由美国专利号3,623,579进行了讨论,该专利授予了G.R.Clark并描述了一种通过在罐中絮凝和絮凝并搅动固体以打破固体颗粒之间的键合来处理废物流以去除固体废物的方法。此外,美国专利号4,016,823描述了一种装置,该装置描述了一种方法,其中液体污水从废物流中和从污水处理厂中被去除,在那里要被去除的固体废物将被处理以生产用于沐浴浴缸或肥皂的氨净化水,并且其中来自废水处理厂的污水被去除到污水处理厂,在那里这些污水与水混合或作为肥料处理。…

附录 I:Appendix I XGBoost Settings

  • 对于内部网格搜索,树的最大深度在[2,3,5]中,学习率在[0.01,0.1,0.3]中,树的数量在[50,100]中

附录 J:Selected Features by Task

  • 在表8中,论文展示了每个基准测试所选的特征

附录 K:LightGBM Results

  • 表2的 LightGBM 版本可以在表9中找到
  • 注意:未对 LGBM 进行显著性检验,因此这反映了一次运行的结果,尽管对两个预测器都仍在附录I的相同值上进行了超参数搜索
    • Brier 分数扩展 ×100 以便比较
    • 这里的两个预测器都使用 LGBM

附录 L:SHAP Plots for remaining benchmarks

  • 剩余基准测试的SHAP图可以在图7-图15中找到。请注意,对于Brier分数任务(ANLI,XNLI,MathQA,LogiQA2),分数越低越好

附录 M:Details on confirmatory pretraining runs

M.1 训练(Training)

  • 对于论文的验证性实验,论文使用 Megatron-Deepspeed 库从头开始训练了 460M 参数的 Llama-2 架构模型

  • 论文将训练 token 数量上限设为 10B,同时使用设置为 100B token 长度的余弦学习率调度(意味着每个检查点大约完成了“完整”预训练运行的 10%)

  • 训练在每个检查点一个节点上进行,使用 8 个 H100 GPU

    • 每个检查点大约需要 6 小时来训练
  • 对于论文的数据混合,论文使用 Dolma v1 数据集的子集构建了各种混合

    • 在网页与其他的实验中,论文固定了所有其他数据源的相对百分比,同时改变网页的百分比
  • 训练配置如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    training:
    num_layers: 14
    num_attention_heads: 12
    seq_length: 2048
    num_kv_heads: 12
    hidden_size: 1536
    ffn_hidden_size: 4128
    tune_steps: 1000
    lr: 0.00015
    min_lr: 1.0e-5
    weight_decay: 1e-2
    grad_clip: 1.0
    lr_warmup_steps: 100
    save_interval: 2000
    eval_interval: 2000
    train_epochs: 1
    tp: 1
    micro_batch_size: 16
    global_batch_size: 512
    seed: 42
  • 除了数据混合之外,所有实验都使用相同的超参数以确保公平比较

M.2 评估(Evaluation)

  • 为了评估不同数据混合对模型性能的影响,论文在以下任务上评估了论文的模型:
    • 1)自然语言推理(Natural language inference): Lambada, winogrande, arc challenge
    • 2)代码生成(Code generation): Humaneval
    • 3)数学(Math): GSM8K
    • 4)事实性(Factuality): TruthfulQA
  • 注意,由于时间限制,论文没有选择完整的评估集
  • 由于LM eval harness没有为所有任务实现困惑度/基于损失的评估,论文手动将多项选择任务转换为基于损失的指标,并在计算所有任务的损失时屏蔽提示或问题

M.3 转换为基于损失的指标(Conversion to Loss-Based Metrics)

  • 为了确保跨不同任务和模型的一致评估,论文将各种基准测试数据集转换为基于损失的指标
  • 这种方法允许在模型之间进行更直接的比较,并更清晰地解释改进
  • 以下是论文为每种数据集类型实现损失计算的方式:
  • 多项选择任务(ARC Challenge, Winogrande, HellaSwag, TruthfulQA): 对于这些数据集,论文计算了两个主要的基于损失的指标:
    • 平均损失(Average Loss): 论文计算了正确答案的负归一化对数概率。对于每个问题,论文将输入格式化为“问题+答案选项”,然后为每个选项计算按token长度归一化的序列对数概率。正确答案的负对数概率被用作损失
    • 基于边际的损失(Margin-based Loss): 特别是对于 TruthfulQA,论文计算了真实答案和非真实答案之间的边际。这被计算为最佳真实答案的对数概率与最佳非真实答案的对数概率之差的负值。损失越低表示区分真实和非真实信息的能力越好
  • 生成任务(GSM8K, HumanEval, Lambda): 对于生成任务,论文计算:
    • 回答损失(Answer Loss): 论文计算 solution Token 上的交叉熵损失
      • 注:对 Lambada 任务,仅使用最后一个 word
    • 所有对数概率均被序列长度归一化

M.4 Full Results

  • 表 10:代码与自然语言混合数据
  • 表 11:网络数据与其他数据混合的精确损失值

NLP——Reinforcement-Pre-Training

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:(RPT)Reinforcement Pre-Training, arXiv 20250609, Microsoft Research & PKU & THU

Paper Summary

  • 整体内容总结
    • 论文提出了一种新颖的,用于预训练大语言模型的新范式,强化预训练(Reinforcement Pre-Training, RPT)
    • 通过将下一词预测任务重构为可验证的推理任务,并应用基于正确性的强化学习奖励,RPT 使大语言模型能够在预训练期间利用扩展计算构建更强的基础推理能力
    • 实验表明:RPT 提升了下一词预测的准确性 ,在数学 和通用推理基准的 Zero-Shot 设置 中表现出色 ,并为后续强化学习微调 提供了更好的起点
    • RPT 通过从根本上重新思考预训练目标本身,为开发更强大、更通用的智能大语言模型提供了新的方向
  • 论文将 Next-Token Prediction(NTP)任务重新定义为一种通过强化学习训练的推理任务,模型在正确预测给定上下文的 Next Token 时会获得可验证的奖励
  • RPT 提供了一种可扩展的方法,能够利用海量文本数据实现 通用强化学习(general-purpose RL) ,而 无需依赖特定领域的标注答案
    • 通过激励 Next Token 推理能力,RPT 显著提升了语言模型在预测 Next Token 时的准确性
  • RPT 为后续的强化微调提供了强大的预训练基础
  • 实验结果表明,随着训练计算量的增加, Next-Token Prediction 的准确性持续提升
    • 作者认为,这些结果证明了 RPT 是一种有效且有前景的规模化范式,能够推动语言模型预训练的进步

Introduction and Discussion

  • LLM 在各种任务中展现出了卓越的能力,这主要得益于基于海量文本语料的 Next-Token Prediction 目标的可扩展性
    • 这种自监督范式已被证明是一种高效的通用预训练方法
  • RL 已成为一种强大的技术,可用于微调大语言模型,使其与人类偏好对齐或增强特定技能,例如复杂推理 (2022, 2023, 2024)
  • 但当前强化学习在大语言模型训练中的应用面临可扩展性和通用性挑战
    • RLHF 在对齐任务中表现有效,但其依赖昂贵的人类偏好数据,且学习到的奖励模型容易受到奖励破解(reward hacking)的影响,从而限制了可扩展性
    • RLVR 是一种基于可验证奖励的强化学习(Reinforcement Learning with Verifiable Rewards, RLVR)方法
      • RLVR 利用客观的、基于规则的奖励(通常来自问答对),这种方法能够缓解奖励破解问题
      • 但 RLVR 通常受限于带有可验证答案的标注数据的稀缺性,因此其 应用范围仅限于特定领域的微调 ,而 非通用预训练
  • 在这项工作中,作者提出了 强化预训练(RPT) ,这是一种新颖的范式,旨在弥合可扩展的自监督预训练与强化学习能力之间的差距
    • RPT 将 Next-Token Prediction 任务重新定义为一种 Next Token 推理过程
    • 对于预训练语料中的任何给定上下文,模型会被激励在预测 Next Token 之前对其进行推理
    • 模型会根据其预测结果与语料中真实 Next Token 的匹配程度,获得一种可验证的内在奖励
    • 这种方法将通常用于 Next-Token Prediction 的海量未标注文本数据,转化为一个适用于通用强化学习的庞大数据集,而无需依赖外部标注或特定领域的奖励函数
  • 这种方法具有以下几个关键优势:
    • 1)可扩展性与通用性 :RPT 利用了与标准 Next-Token Prediction 相同的海量未标注文本数据,将其转化为通用强化学习的大规模数据集,无需外部标注
    • 2)减少奖励破解风险 :通过使用直接的、基于规则的奖励信号(即预测 Next Token 的正确性),RPT 从根本上降低了复杂学习奖励模型中常见的奖励破解风险
    • 3)促进深度理解与泛化 :通过显式激励 Next Token 推理模式,RPT 鼓励模型深入理解上下文,而非简单地记忆 Next Token
      • 模型学会探索和验证关于“为何某个 Token 应该出现”的假设,从而构建更鲁棒的表示
    • 4)推理过程的计算分配 :预训练中的内部推理过程使模型能够为每个预测步骤分配更多的“思考”或计算资源,类似于在训练时为每个 Token 应用推理时扩展,从而直接提升 Next-Token Prediction 的准确性
  • 论文的实验表明:
    • RPT 显著提升了 Next-Token Prediction 的准确性
    • RPT 为后续的强化微调提供了更鲁棒的预训练基础,从而在最终任务中表现更优
    • 扩展曲线显示,在 RPT 框架下,增加训练计算量能够持续提升 Next-Token Prediction 的准确性 ,这表明 RPT 是一种可持续的规模化策略
    • 这些结果证明了强化预训练是一种有效且有前景的新范式,能够推动大语言模型预训练的进步
  • 论文的主要贡献如下:
    • 提出了一种新的规模化范式: 强化预训练(RPT) ,将 Next-Token Prediction 重新定义为通过强化学习训练的推理任务,并利用预训练语料直接生成的内在可验证奖励
    • RPT 提供了一种可扩展且通用的强化学习预训练方法,通过基于规则的奖励减少奖励破解风险,并通过激励 Next Token 推理模式(而非机械记忆(rote memorization))促进泛化能力
    • RPT 显著提升了 Next-Token Prediction 的准确性,并展现出良好的扩展特性,即性能随着训练计算量的增加而持续提升
    • RPT 为后续的强化微调提供了更强的预训练基础,并提升了在多种下游任务中的 Zero-Shot 性能

Preliminary

Next-Token Prediction, NTP

  • Next-Token Prediction 是现代大语言模型的基本训练目标 (2022)
  • 给定训练语料中的输入序列 \(x_{0}\cdots x_{T}\),模型的训练目标是最大化以下目标函数:
    $$
    \mathcal{J}_{\text{NTP} }(\theta)=\sum_{t=1}^{T}\log P(x_{t}\mid x_{0},x_{1},\ldots,x_{t-1};\theta),
    $$
    • 其中 \(\theta\) 表示语言模型的参数

Reinforcement Learning with Verifiable Rewards, RLVR

  • RLVR 利用强化学习目标来增强具有可验证答案的特定技能 (2023)
  • RLVR 需要一个标注的问答对数据集 \(\mathcal{D}=\{(q,a)\}\)
  • 对于特定的问答对 \((q,a)\in \mathcal{D}\),大语言模型 \(\pi_{\theta}\) 会生成一个响应 \(o\sim \pi_{\theta}(\cdot \mid q)\)
    • 然后使用一个确定性的验证器 \(\mathcal{V}\) 计算可验证奖励 \(r=\mathcal{V}(o,a)\),模型的训练目标是最大化期望奖励:
      $$
      \mathcal{J}_{\text{RLVR} }(\theta)=\mathbb{E}_{(q,a)\sim \mathcal{D},o\sim \pi_{\theta}(\cdot|q)}\left[r(o,a)\right].
      $$

Reinforcement Pre-Training

Pre-Training Task: Next-Token Reasoning

  • 论文提出了 Next Token 推理任务用于语言建模
  • 给定训练语料库中的输入序列 \(x_0 \cdots x_T\),对于每个位置 \(t \in \{1, \ldots, T\}\),前缀 \(x_{ < t}\) 被视为上下文,而真实的 Next Token 是 \(x_t\)
  • 在 Next Token 推理任务中,模型 \(\pi_\theta\) 需要在生成对 Next Token 的预测 \(y_t\) 之前生成一个思维链(chain-of-thought)推理序列,记为 \(c_t\)
  • 模型的整体响应为 \(o_t = (c_t, y_t)\),其中 \(o_t \sim \pi_\theta(\cdot \mid x_{ < t})\)
  • 如图 2 所示, Next Token 推理的长思维链过程可能涉及多种推理模式,例如头脑风暴、自我批判和自我修正
  • Next Token 推理任务将预训练语料库重构为大量的推理问题集合,使预训练从学习表面的 Token-level 关联转向理解其背后的隐藏知识,并使强化学习的扩展成为可能

Pre-Training with Reinforcement Learning

  • RPT 通过在线策略强化学习训练 LLM 执行 Next Token 推理,如图 3 所示
  • 对于上下文 \(x_{ < t}\),RPT 提示语言模型 \(\pi_\theta\) 生成 \(G\) 个响应(思维轨迹),记为 \(\{o^i_t\}_{i=1}^G\)
  • 每个响应 \(o^i_t = (c^i_t, y^i_t)\) 包含一个思维链推理序列 \(c^i_t\) 和一个最终的预测序列 \(y^i_t\)
  • 为了验证 \(y^i_t\) 的正确性,论文引入了前缀匹配奖励(prefix matching reward),该奖励支持验证跨越多 Token 或涉及词汇表外 Token 的预测
  • 符号定义如下:
    • \(x_{\geq t}\) 表示真实补全序列,其字节(byte)序列表示为 \(\overline{x}_{\geq t}\)
      • 问题:为什么不是单个 token? 论文每次仅预估下一个 token 吧?到底是每次生成单个 token 还是多个 token?
      • 猜测:这里是表示每次推理时,可以只看一个 token,也可以看多个 token
    • \(y^i_t\) 表示预测的序列,其字节(byte)序列表示为 \(\overline{y}^i_t\)
    • \(\overline{y}^i_t\) 的字节长度记为 \(l\)
    • 真实补全序列中 Token 的累积字节长度定义为有效边界,记为 \(\mathcal{L}_{gt}\)
      • 理解:这里的有效边界是一个整数集合,表示有效的长度值的集合
  • 形式上,对于上下文 \(x_{ < t}\) 的第 \(i\) 个输出,奖励 \(r^i_t\) 定义为:
    $$
    r^i_t = \begin{cases}
    1 & \text{if } \overline{y}^i_t = \overline{x}_{\geq t}[1:l] \text{ and } l \in \mathcal{L}_{gt} \\
    0 & \text{otherwise}
    \end{cases},
    $$
    • 如果预测的字节序列是真实补全字节序列的精确前缀且其长度 \(l\) 匹配任何有效 Token 边界 ,则奖励为 1
    • \(\overline{y}^i_t\) 表示预测的字节序列
    • \(\overline{x}_{\geq t}\) 表示真实补全的字节序列
  • 令 \(\mathcal{D}\) 为所有 \(\{x_{ < t}\}_{t=1}^T\) 的集合,模型训练的目标是最大化期望奖励:
    $$
    \mathcal{J}_{\text{RPT} }(\theta) = \mathbb{E}_{(x_{ < t}, x_{\geq t}) \sim \mathcal{D}, \{o^i_t\}_{i=1}^G \sim \pi_\theta(\cdot|x_{ < t})} \left[r^i_t \right].
    $$

Pre-Training Setup

  • 论文使用 OmniMATH 数据集(2024)进行强化预训练。OmniMATH 包含 4,428 个竞赛级数学问题及其解答,数据来自 AoPS Wiki 和 AoPS 论坛等官方网站
  • 由于许多 Token 即使无需推理也容易预测,论文在强化预训练前进行了 Token-level 数据过滤
    • 论文使用 Deepseek-R1-Distill-Queen-1.5B 作为小型代理模型,计算每个 Token 在前 16 个候选 Token 上的代理模型熵
    • 通过应用熵阈值,论文过滤掉低熵位置,优先训练那些需要更多计算努力预测的挑战性 Token(注:高熵的 token 是较难预测的)
  • 在所有实验中,论文以 Deepseek-R1-Distill-Queen-14B(2025)作为基础模型
    • R1-Distill-Queen-14B 因其基本的推理能力而成为强化学习的良好起点
  • 论文使用 verl 库(2025)实现训练框架,并使用 vllm 进行推理
  • 论文采用 GRPO 算法(2025),具体超参数详见附录 B
  • 训练时,论文采用 8k 的训练长度,学习率为 \(1 \times 10^{-6}\),KL 惩罚为零,批次大小为 256 个问题,每个问题采样 \(G=8\) 个响应,在 rollout 过程中使用温度为 0.8
  • 从每个响应中,论文直接提取最后一个 \(\backslash\)boxed{ } 内的完整序列作为模型对 Next Token 的预测
  • 从第 500 步开始 ,论文使用动态采样以提高训练效率(2025),主实验的总训练步数为 1,000
    • 补充:这里的动态采样是 DAPO 中的动态采样技术,把奖励全为 0 或者全为 1 的 Prompt/样本 丢弃掉
  • 提示模板及其变体在附录 D 中讨论

Evaluation of Pretrained Models

  • 模型预训练完成后,我们可以直接在下游任务上进行 Next-Token Prediction 和强化微调
  • 论文通过以下设置展示强化预训练如何提升大语言模型的语言建模能力和推理能力
  • 语言建模(Language Modeling)
    • 基于 Next Token 推理目标,论文的模型可以自然地用于语言建模
    • 论文报告 Next-Token Prediction 准确率,以评估 RPT 的语言建模性能和扩展性
  • 下游任务的强化微调(Reinforcement Fine-Tuning on Downstream Tasks)
    • 论文以预训练后微调的方式对 RPT 模型进行持续的强化微调
    • 由于 RPT 将预训练过程与强化学习对齐,预训练与后续强化微调之间的目标差距被最小化
    • 论文评估强化预训练过程是否进一步提升了最终任务的性能

Experiments

Language Modeling

  • 论文在 OmniMATH 的 200 个验证集样本上评估语言建模性能
  • 根据第 3.3 节描述的基于熵的数据过滤策略,论文根据难度对验证集中的 Token 位置进行分类
    • 论文使用 R1-Distill-Queen-14B 计算每个 Token 位置的熵,并根据熵是否超过阈值 0.5、1.0 和 1.5 将位置划分为简单、中等和困难三类
    • 为了比较,论文报告了 R1-Distill-Queen-14B 在两种评估方式下的性能:
      • (1) 标准 Next-Token Prediction ,选择概率最高的 Token ;
      • (2) Next Token 推理,生成思维链后再进行最终预测
    • 论文还包含了 Qwen2.5-14B 的结果,因为它是 R1-Distill-Queen-14B 的基础模型
  • 如表 1 所示,RPT-14B 在所有难度级别上的 Next-Token Prediction 准确率均高于 R1-Distill-Queen-14B
    • 值得注意的是,它的性能与显著更大的模型 R1-Distill-Queen-32B 相当(图 4)
  • 这些结果表明,强化预训练能有效捕捉 Token 生成背后的复杂推理信号,并具有提升大语言模型语言建模能力的强大潜力

Scaling Properties of Reinforcement Pre-Training

  • 本节论文研究强化预训练的扩展性
  • 自然语言语料库上的 Next Token 预训练损失在模型大小、训练 Token 数量和训练计算量方面通常遵循幂律衰减(2020, 2022)
  • 论文使用以下幂律形式建模训练计算量 \(C\) 与性能的关系:
    $$
    P(C) = \frac{A}{C^\alpha} + P^*, \tag{5}
    $$
    • 其中 \(P(C)\) 表示验证集上的 Next-Token Prediction 准确率,\(P^*\)、\(\alpha\) 和 \(A\) 是待估计的参数
  • 论文在不同训练步数(100、200、400、800、1000 和 1200)下评估 RPT 的 Next-Token Prediction 准确率,并将其转换为相应的训练计算量
  • 为了评估数据难度的影响,论文考虑了基于熵阈值 0.5(简单)、1.0(中等)和 1.5(困难)过滤的验证集分割
    • 更高的阈值对应更具挑战性的输入
    • 对于每个难度级别,论文根据公式 (5) 拟合结果,并使用决定系数 \(R^2\) 衡量拟合优度(Goodness of fit)
    • 理解:按照 15W 词表算:
      • 熵为 0.5 对应这最大的概率差不多是 0.970;
      • 熵为 1.0 对应这最大的概率差不多是 0.936;
      • 熵为 1.5 对应这最大的概率差不多是 0.901;
      • 更多可视化详情见附录
  • 如图 5 所示,随着训练计算量的增加,RPT 的 Next-Token Prediction 准确率持续提升
  • 所有难度级别的高 \(R^2\) 值表明拟合曲线能准确捕捉性能趋势(理解:说明在不同难度上,均能很好的拟合到公式 (5) 上)

Reinforcement Fine-Tuning with RPT

  • 为了研究 RPT 模型是否能通过 RLVR 更有效地微调,论文从 Skywork-OR1(2025)随机采样带有可验证答案的问题进行进一步训练
    • 论文使用 256 个样本进行训练,200 个样本进行测试
    • 遵循 Skywork-OR1 的数据过滤流程(2025),论文使用 R1-Distill-Queen-32B 识别训练中的挑战性实例
    • 论文将训练批次大小和 PPO 小批次大小均设为 64,并训练模型 15 个周期
    • 评估时,验证的最大 Token 数设为 32,000,温度为 0.6
  • 如表 2 所示:
    • 强化预训练模型在使用 RLVR 进一步训练时达到了更高的上限
    • 当使用相同的 Next-Token Prediction 目标持续训练相同数据时,模型的推理能力显著下降,而后续的 RLVR 仅带来缓慢的性能提升
      • 理解:可以观察到,直接进行普通 NPT 的持续预训练(目标和 RPT 相同)会导致推理能力大幅下降;猜测这里是训练太多次,发生了过拟合了!
        • 引申问题:这里的 普通 NPT 和 RPT 训练轮次是相同的吗?
  • 这些结果表明,在数据有限的情况下,强化预训练可以快速将从 Next Token 推理中学到的强化推理模式迁移到最终任务中

Zero-Shot Performance on End Tasks

  • 论文评估了 RPT-14B 在下游任务上的 Zero-Shot 性能
  • 作为比较,论文评估了 R1-Distill-Queen-14B 和 R1-Distill-Queen-32B 的 Next-Token Prediction 性能,以及 RPT-14B 与 R1-Distill-Queen-14B 的推理性能
  • 论文的评估涉及两个广泛认可的基准:
    • MMLU-Pro(2020),一个综合性多任务理解基准,评估大语言模型在不同领域的表现;
    • SuperGPQA(2025),一个涵盖 285 个学科的研究生级推理问题的大规模基准
  • 在推理设置下,论文将最大 Token 数设为 12,288,温度为 0.8
  • 遵循先前工作(2024, 2025),论文使用多项选择题格式进行评估并报告准确率
  • 如表 3 所示
    • RPT-14B 在所有基准上均优于 R1-Distill-Queen-14B(无论是标准 Next-Token Prediction 还是作为推理模型评估)
    • RPT-14B 还超越了显著更大的 R1-Distill-Queen-32B(在 Next-Token Prediction 模式下),在 SuperGPQA 上提升了 7 分,在 MMLU-Pro 上提升了约 22 分
  • 每个基准的详细分科结果见附录 C

Next-Token Reasoning Pattern Analysis

  • 论文分析了 Next Token 推理与显式问题解决(explicit problem solving)在推理模式上的差异
    • 根据先前研究(2024, 2025),论文统计了模型响应中包含推理关键词(如“break down”、“alternatively”)的比例
  • 论文的分析比较了两种模型在 OmniMATH 数据集上的思维过程:
    • R1-Distill-Queen-14B 用于问题解决
    • RPT-14B 用于 Next Token 推理
    • 对每个模型,采样 200 个响应
  • 论文将推理模式分为六类:
    • 转换(切换策略)、反思(自我检查)、分解(分解问题)、假设(提出并验证假设)、发散思维(探索可能性)和演绎(逻辑推理)

      transition (switching strategies), reflection (self-checking), breakdown (decomposing the problem), hypothesis (proposing and verifying assumptions), divergent thinking (exploring possibilities), and deduction (logical inference).

  • 如图 6 所示,RPT-14B 的 Next Token 推理过程与 R1-Distill-Queen-14B 的问题解决过程显著不同
    • RPT-14B 相对 R1-Distill-Queen-14B: 假设模式的使用量增加了 161.8%
    • RPT-14B 相对 R1-Distill-Queen-14B:演绎模式的使用量增加了 26.2%
    • 问题解决过程(R1-Distill-Queen-14B)更依赖分解模式(breakdown) ,这表明 Next Token 推理引发的推理过程在性质上与结构化问题解决不同
  • 表 4 展示了一个推理模式的例子
    • 该例子揭示了模型参与的是一个深思熟虑的过程,而非简单的模式匹配
    • 它分析了更广泛的语义上下文(“calculating vector magnitude”),识别关键短语(“go over some…”),然后进行头脑风暴并权衡多个可能的延续
    • 这涉及假设生成(“the next part is likely going to be…”)、替代方案考虑(“Alternatively, it could be…”)以及对结构线索(“markdown with headers”)甚至细粒度 Token-level 细节(“could have a space”)的反思
    • 这种多方面的推理,既包含高级语义理解,又包含低级文本特征,展示了模型通过推理探索推断 Next Token 的努力,与 RPT 培养超越表面关联的更深层次理解的目标一致
  • 更多例子见附录 F

Related Work

Scaling Paradigms of Large Language Models

  • LLM 的进步主要由两个扩展维度驱动:
    • 训练时计算(training-time compute)(2022a):通过大幅增加模型参数和训练数据,以下一词预测(next-token prediction)作为预训练任务
    • 测试时计算(test-time compute)(2025a):测试时扩展(2024)通过延长推理计算时间提升大语言模型的推理能力
  • RPT 独特地整合了上述原则,超越现有扩展范式,将每一词预测任务重构为推理任务

Reinforcement Learning for Large Language Models

  • 强化学习在大语言模型的后训练阶段发挥了关键作用
    • RLHF(2022)通过人类偏好数据微调预训练语言模型以提升对齐性
    • 除对齐外,大规模强化学习还被用于增强语言模型的推理能力(2025)
  • 最相关的工作(2025)鼓励语言模型为下一词预测生成有帮助的推理过程
    • 基于帮助性的奖励容易被生成的推理中重复目标词所“破解” ,这种捷径可能损害模型性能
      • 问题:如何理解这里所谓的奖励破解问题?
        • 参考:Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking
    • 相比之下,论文使用下一词预测的正确性作为基于规则的奖励信号,以最小化奖励破解风险

Future Work

  • RPT 的初步探索仍存在一些局限性
    • 论文的实验主要基于 14B 参数的模型,没有在更大的模型进行测试
    • 虽然 RPT 方法设计为通用,但当前预训练语料库 主要由数学文档组成;
      • 未来工作将探索其在更广泛的通用领域文本上的有效性
    • RPT 训练是从一个具备基础推理能力的模型初始化(R1-Distill-Qwen-14B)的;
      • 后续可以研究从标准基础语言模型开始的 RPT 训练
      • 这将为 RPT 基础性影响提供进一步分析和结论
  • 未来工作可从以下方向推进:
    • 扩展训练语料库的规模和领域覆盖范围,利用大规模通用互联网数据进行强化预训练
    • 增加训练计算资源以突破性能边界
    • 建立强化预训练的扩展定律(scaling laws),指导大语言模型的扩展
    • 探索将混合思维(hybrid thinking)(2025)与 RPT 结合,通过自适应触发下一词推理实现细粒度的适应性思考

附录 A Design Choices of Reward

  • 除了第 3 节描述的基于前缀匹配的奖励机制外,论文还研究了其他几种奖励函数变体以评估其对强化预训练的影响
    • 变体一:首词匹配(first-token matching)
      • 在此设置中,奖励仅反映模型预测 \( y_t^i \) 的首词是否与真实下一词 \( x_t \) 匹配,忽略预测中首词之后的所有词
    • 变体二:探索了“密集奖励”(dense reward)方案:
      • 正确预测的下一词(即 \( y_t^i[0] = x_t \))获得满分奖励(如 1);
      • 对于错误预测(\( y_t^i[0] \neq x_t \)),奖励为一个较小的正值 ,具体为语言模型生成该错误词的概率 \( P(y_t^i[0] \mid x_{ < t}; \theta) \)
        • 问题:为什么是错误词的概率?岂不是错误词的概率越大,奖励越大,应该是正确词的概率吧
      • 这提供了比二元奖励更密集的反馈信号
    • 变体三:条件性应用密集奖励结构:
      • 仅当给定前缀 \( x_{ < t} \) 的 \( G \) 次采样中至少有一次正确预测下一词时 ,才使用密集奖励;
      • 若所有 \( G \) 次采样均错误,则应用其他奖励方案(如零奖励或统一的小惩罚)
  • 实验表明,这些替代奖励与前缀匹配奖励相比,性能相当
    • 表明强化预训练框架对这些奖励信号的修改具有较强的鲁棒性 ,其核心优势可能对这些特定选择不敏感,至少在测试的变体范围内如此

附录 B Hyperparameters Used for Reinforcement Pre-Training

  • 表 5 展示了第 4 节中强化预训练的详细超参数
  • 论文遵循精确策略强化学习(2025)的设置,将熵损失系数设为 0

附录 C Detailed Results on End Tasks

  • 表 6 和表 7 展示了通用终端任务基准的详细分类性能
  • RPT-14B 模型在大多数类别中表现优于 R1-Distill-Qwen-14B 和 R1-Distill-Qwen-32B

附录 D Impact of Prompt Templates

  • 论文探索了不同提示模板对初始下一词推理性能的影响
  • 表 10 展示了七种模板变体,这些模板使用不同指令片段,并以不同形式包装上下文
  • 如表 8 所示,清晰的 Prompt 能很大程度提升初始表现的准确性
    • 第 4 节的强化预训练实验使用了“v0”模板,其他模板变体的优化留待未来工作

附录 E Keywords for Reasoning Pattern Analysis

  • 表 9 列出了第 4.5 节中用于推理模式分析的模式组和关键词

附录 F Case Studies

  • 表 11 展示了 RPT-14B 在下一词推理任务中的三个案例,包括模型对数学问题和文本上下文的推理过程
  • 这些案例揭示了模型如何通过多角度思考生成最终预测

附录:概率和熵的关系图

  • 关键词:entropy curve;熵和概率;概率和熵;曲线图;

  • 假定只有一个 token 的值较大,其他 token 概率相同,此时的熵和最大概率的关系是如何的?

  • 可视化最大概率和熵的关系的代码

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    import numpy as np
    import matplotlib.pyplot as plt

    # 设置中文显示
    plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]
    plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题

    def calculate_entropy(p, n=150000):
    remaining_p = (1 - p) / (n - 1)

    probabilities = np.full(n, remaining_p)
    probabilities[0] = p

    entropy = -np.sum(probabilities * np.log(probabilities))

    return entropy

    p_values = np.linspace(0.0001, 0.9999, 9999)

    entropy_values = [calculate_entropy(p) for p in p_values]

    plt.figure(figsize=(10, 6))
    plt.plot(p_values, entropy_values, 'b-', linewidth=2)

    max_entropy_idx = np.argmax(entropy_values)
    plt.scatter(p_values[max_entropy_idx], entropy_values[max_entropy_idx], color='red', s=50, zorder=5)
    plt.annotate(f'最大熵: p={p_values[max_entropy_idx]:.2f}, H={entropy_values[max_entropy_idx]:.4f}',
    xy=(p_values[max_entropy_idx], entropy_values[max_entropy_idx]),
    xytext=(p_values[max_entropy_idx]+0.1, entropy_values[max_entropy_idx]+0.2),
    arrowprops=dict(facecolor='black', shrink=0.05, width=1.5, headwidth=8))

    # print(entropy_values)
    points = [0.5, 1.0, 1.5]
    for point in points:
    for index, entropy in enumerate(entropy_values):
    if entropy <= point:
    print((p_values[index], entropy_values[index]))
    plt.scatter(p_values[index], entropy_values[index], color='red', s=50, zorder=5)
    break

    plt.title('概率值p与熵的关系图 (15W个候选值)')
    plt.xlabel('概率值p (第一个候选值的概率)')
    plt.ylabel('熵 (nats)')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.xlim(-0.05, 1.05)
    plt.ylim(0, max(entropy_values) * 1.1)

    plt.tight_layout()
    plt.show()

    # (0.9695, 0.4999866540719247)
    # (0.9361, 0.9991510505356568)
    # (0.9012, 1.4999746052311926)
  • 示意图:

NLP——将传统强化学习Trick用到LLM中的思考

  • 参考链接:
    • 英文链接:Old-School Deep RL Tricks for Modern LLM Training
      • 本文主要翻译该英文链接的内容,并包含自己的一些思考
    • 中文解读博客:炒冷饭:把祖传RL的tirck塞进LLM

整体说明

  • 本文讨论如何将深度强化学习技术移植到 RLHF/RLAIF 和使用工具的 LLM 智能体中
    • 具体设计到的技术包括 \(n\) 步回报、TD(\(\lambda\))、不确定性、安全性以及其他
  • 本文收集了著名的深度强化学习技术(前 LLM 时代),并将其适配到现代 LLM 的训练/推理中
  • 以下方法并非唯一途径(可将其视为实用的起点),可以根据自己的技术栈进行优化

为何要将深度强化学习理念引入 LLM?

  • 生成过程就是一条轨迹:隐藏状态 \(h_t\) 和动作 \(y_t\) (Token 或工具调用)
  • 奖励可能是稀疏的(仅在序列末尾有一个分数)或 Dense 的(规则、自我批判、任务进度)
  • 这正是经典深度强化学习所针对的场景:长 horizon 信用分配、带噪声的目标以及安全约束

用于长文本的多步回报、TD(\(\lambda\)) 和 GAE

  • 设 \(h_t\) 为 Token \(y_{1:t-1}\) 之后的解码器状态
  • 训练一个价值头 \(V_\phi(h)\) 用于预测未来奖励
  • \(n\)步回报
    $$
    G_t^{(n)}=\sum_{i=0}^{n-1}\gamma^i r_{t+i}+\gamma^n V_\phi(h_{t+n})
    $$
  • TD 更新
    $$
    V_\phi(h_t)\leftarrow V_\phi(h_t)+\alpha\big(G_t^{(n)}-V_\phi(h_t)\big)
    $$
  • TD(\(\lambda\))/GAE(Actor-critic)
    $$
    \hat A_t=\sum_{l=0}^{\infty}(\gamma\lambda)^l\delta_{t+l}
    $$
    • 其中 \(\delta_t=r_t+\gamma V_\phi(h_{t+1})-V_\phi(h_t)\)
  • 带 KL 控制的策略梯度
    $$
    \nabla_\theta J \approx \mathbb{E}\left[\sum_t \hat A_t \nabla_\theta \log \pi_\theta(y_t|h_t)\right] -\beta\nabla_\theta \mathrm{KL}\big(\pi_\theta|\pi_{\text{ref} }\big)
    $$

应用场景

  • 仅带有 end-of-sequence scores 的长文本生成/摘要(Long-form generation/summarization):
    • 通过 GAE 在 Token 间传播信用(propagate credit via GAE across tokens.)
  • 多轮对话助手(Multi-turn assistants):
    • 将每一轮视为一个步骤;
    • 设置 \(\gamma<1\) 以减少不必要的轮次(理解:对于最后一步给奖励的场景,\(\gamma<1\) 相当于鼓励缩短决策轮次)
  • 工具/代码智能体(Tool/Code agents):
    • 每次工具调用或单元测试结果作为一个步骤;
    • \(n\) 步回报可快速利用中间反馈
  • RAG/问答:
    • 将检索质量/格式检查作为 Dense 奖励,以解决稀疏信用问题(sparse-credit problems.)

Off-policy multi-step without explosions

  • 混合日志数据和新轨迹是常见做法(需结合修正进行多步学习)
  • V-trace(适用于分布式采样,具有稳定性) :
    $$
    \rho_t=\min\left(\bar\rho,\frac{\pi_\theta(y_t|h_t)}{\mu(y_t|h_t)}\right),\quad c_t=\min\left(\bar c,\frac{\pi_\theta(y_t|h_t)}{\mu(y_t|h_t)}\right)
    $$
    • 使用 \(c_t\) 截断 eligibility 并构建修正后的优势
  • Retrace(\(\lambda\))
    $$
    c_t=\lambda \min\left(1,\rho_t\right)
    $$
  • Tree-Backup(\(\lambda\)) :
    • 期望回溯(无重要性权重),方差更低,偏差较小
  • 典型截断值: \(\bar\rho\in[1,2]\) , \(\bar c=1\)

附录:V-trace 介绍

  • 原始论文 IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, ICML 2018, Google
    • IMPALA:Importance Weighted Actor-Learner Architectures
  • V-trace 是一种用于强化学习的 Off-policy 修正方法,主要用于解决分布式强化学习中 Actor 和 Learner 策略不一致的问题
  • 问题引入:
    • 在分布式强化学习框架 IMPALA 中,Actor 负责与环境交互生成经验轨迹,Learner 则根据这些轨迹来更新策略
    • 由于 Learner 的更新速度通常比 Actor 快,导致用于生成轨迹的策略滞后于 Learner 的当前策略,使得学习过程变成了 Off-policy 学习
    • V-trace 就是为了校正这种策略差异带来的不良影响而设计的
  • V-trace 核心思想 :
    • V-trace 通过引入重要性采样(Importance Sampling)来修正 Off-policy 学习中的偏差
    • V-trace 对传统的策略梯度公式进行了改进,通过裁剪重要性系数来稳定方差,使得训练更加稳定
    • 具体来说,V-trace 在计算值函数更新时,使用了经过裁剪的重要性采样比率,限制了重要性系数的最大值,以防止因策略分布差异过大导致的估计偏差过大
  • V-trace 数学公式 :
    • 在 V-trace 算法中, \(n\) 步下的目标价值函数可以表示为:
      $$ V_s = V(x_s) + \sum_{t=s}^{s+n-1} \gamma^{t-s}(\prod_{i=s}^{t-1} c_i )\delta_t V$$
      • \(\delta_t V\) 的表达式为:
        $$ \delta_t V = \rho_t(r_t + \gamma V(x_{t+1}) - V(x_t))$$
        • 其中:
          $$
          \rho_t = min(\bar{\rho},\frac{\pi(a_t|x_t)}{\mu(a_t|x_t)}) \\
          c_i = min(\bar{c},\frac{\pi(a_i|x_i)}{\mu(a_i|x_i)}) \\
          $$
      • \(\pi\) 是目标策略(Learner 的策略)
      • \(\mu\) 是行为策略(Actor 的策略)
      • \(\bar{\rho}\) 是重要性系数的裁剪阈值
      • \(c_i\) 是用于控制模型收敛速度的参数
  • V-trace 的作用 :
    • V-trace 使得 IMPALA 能够在高吞吐量的情况下,通过结合解耦的 Acting 和 Learning 过程,实现相当稳定的学习
    • V-trace 有效地解决了分布式强化学习中因策略时间差导致的训练不稳定问题,提高了算法的效率和鲁棒性

应用场景

  • 带大量历史日志的 RLHF/RLAIF:减少行为-目标不匹配导致的偏差
  • Distributed/asynchronous sampling(IMPALA-style):学习器和执行者不同步
    • 分布式/异步采样场景中
  • 离线+在线混合训练:安全复用旧数据,同时保持稳定性

Uncertainty and risk: optimize for reliability, not just averages

  • 奖励模型不确定性(集成/贝叶斯头) (Reward-model uncertainty (ensembles/Bayesian head))
    • 设 \(\hat r\) 为均值, \(\sigma^2\) 为方差
    • 对高风险样本进行降权:
      $$
      \tilde A_t=\frac{\hat A_t}{1+\alpha \sigma_t}
      $$
  • 条件风险价值(CVaR,聚焦尾部) (tail-focus)
    $$
    \mathrm{CVaR}_\alpha(R)=\mathbb{E}[R \mid R \le q_\alpha]
    $$
    • 通过奖励的分位数回归(quantile regression)实现,并在更新时对低分位数值进行掩码
  • 偏好可靠性 (Preference reliability)
    • 通过 Bradley-Terry 置信度对成对偏好进行加权,以减少噪声标签导致的过拟合
  • 参数不确定性 (Parameter uncertainty)
    • 策略/价值上的 Laplace-LoRA 或小型集成可提供每个状态的方差,用于控制步长或触发“重新生成与重新评分”

附录:条件风险价值(CVaR)

  • 条件风险价值(Conditional Value at Risk,CVaR),常被称为 Expected Shortfall
  • CVaR 含义是:
    • 在给定置信水平 \(\alpha\)(通常是一个较小的值,如 0.05 或 0.1)的情况下,所有风险超过 \(\alpha\) 的情况对应损失的期望值
    • 简单来说,CVaR 关注的是 “最坏情况中的平均损失”
      • 例如,当 \(\alpha=0.05\) 时,CVaR 表示在所有可能发生的结果中,损失最大的 5% 的情况的平均损失
  • CVaR 常被用于风险敏感型决策,尤其适用于需要规避极端不利结果的场景

应用场景

  • 安全性/合规性要求高的领域(医疗、金融、教育):减少罕见的灾难性失败
  • 噪声或主观的人类反馈:不确定性加权更新可稳定学习
  • 领域偏移/检索质量波动:检测分布外(OOD)情况并重新路由至重新评估

Safety as constraints, not just penalties

  • 设 \(c(h_t,y_t)\) 为安全成本(毒性、个人身份信息(PII)、事实风险)。通过拉格朗日方法进行优化:
    $$
    \max_\theta\ \mathbb{E}[R] - \lambda\big(\mathbb{E}[c]-\tau\big) \ -\ \beta\mathrm{KL}\left(\pi_\theta|\pi_{\text{ref} }\right)
    $$
  • 通过对偶上升更新 \(\lambda\)
  • 训练一个成本价值头 \(C_\psi(h)\)
  • 推理时添加屏蔽(分类器/正则表达式/规则)以过滤不安全 Token,或路由至拒绝模板
  • 训练时的约束与解码时的屏蔽结合使用效果最佳

应用场景

  • 企业/公共部门(Enterprise/public-sector):严格的个人身份信息(PII)/合规控制
  • 开放式对话/内容(Open-ended chat/content):减少毒性/偏见
  • 高事实性要求(High factuality requirements):将幻觉信号视为成本

Robustness to distribution shift and prompt attacks

  • 分布鲁棒优化(distributionally robust optimization,DRO):在训练提示分布 \(\hat P\) 周围的 \(f\)-divergence ball \(\mathcal{Q}\) 内最大化最坏情况下的奖励
    $$
    \max_\theta \ \min_{Q\in \mathcal{Q} } \ \mathbb{E}_{x\sim Q}\big[R(\pi_\theta; x)\big]
    $$
    • 问题:待补充理解
  • 实用方案(Practical recipe):
    • 对抗性重加权和对抗性提示生成
    • 添加红队测试循环和领域随机化(检索噪声、工具延迟、系统提示)

附录:DRO

  • 在强化学习中,分布鲁棒优化(Distributionally Robust Optimization, DRO) 是一种旨在提高策略对环境分布不确定性鲁棒性的方法
  • DRO 结合了分布鲁棒优化的核心思想与强化学习的框架,解决了传统 RL 方法在环境分布未知(真实的状态转移概率 \( p(s’|s,a) \) 未知)、存在扰动或偏移时性能下降的问题
  • DRO 通过建模环境分布的不确定性集合,在最坏情况下优化策略性能,从而提高策略对分布扰动、估计误差或偏移的鲁棒性
  • DRO 是解决 RL 中“分布不确定性”问题的重要框架,尤其适用于非平稳环境、安全关键场景或模型误差较大的场景,但需在鲁棒性与计算效率、性能之间进行精细权衡
DRO 背景:强化学习中的分布不确定性
  • 强化学习的核心是在马尔可夫决策过程(MDP)中学习最优策略,而 MDP 的关键组件(如状态转移分布、奖励分布、初始状态分布)往往存在不确定性:
    • 状态转移分布 :真实的状态转移概率 \( p(s’|s,a) \) 通常未知,只能通过有限样本估计,可能存在偏差;
    • 奖励分布 :奖励函数可能受噪声影响,或随环境动态变化;
    • 分布偏移 :实际部署时,环境分布可能与训练时不同(如领域自适应、非平稳环境)
  • 传统 RL 方法(如 Q-learning、PPO 等)通常假设可以通过采样准确估计真实分布,或直接使用经验分布,但这种假设在分布不确定时会导致策略不稳定、泛化能力差
DRO 的基本思想
  • DRO 的核心是 “在不确定性中求稳健” :
    • 当无法获知真实分布时,按如下步骤进行:
      • 第一步:定义一个包含真实分布的 “不确定性集合”(uncertainty set) (即所有可能的候选分布构成的集合)
      • 第二步:在这个集合中 针对最坏情况(worst-case)优化目标函数
      • 最终得到的解对集合内的所有分布都能保持较好性能,从而提高鲁棒性
  • 简单来说:传统优化是 “针对真实分布求最优”,而 DRO 是“针对最坏可能的分布求最优”
DRO:RL 问题形式化
  • 在 RL 中,DRO 的目标是学习一个策略,使其在不确定性集合内的所有可能分布下都能最大化(或保证)累积奖励
  • 以 MDP 为例,其核心形式化为:
    $$
    \max_{\pi} \min_{p \in \mathcal{U}} V^{\pi}(p)
    $$
  • 其中:
    • \( \pi \) 是待优化的策略;
    • \( \mathcal{U} \) 是不确定性集合(包含真实分布的候选分布集合);
    • \( V^{\pi}(p) \) 是策略 \( \pi \) 在分布 \( p \) 下的价值函数(累积奖励期望)
  • 目标是找到一个策略 \( \pi \),使得在不确定性集合 \( \mathcal{U} \) 中最坏的分布 \( p \) 下,价值函数 \( V^{\pi}(p) \) 尽可能大
DRO 中不确定性集合如何构造?
  • 不确定性集合 \( \mathcal{U} \) 的定义是DRO的核心,它直接决定了鲁棒性的范围和程度。构造方式通常基于统计距离(衡量分布间的差异),常见的包括:
    统计距离 定义与应用场景 特点
    KL散度(KL-divergence) \( \mathcal{U} = \{ p \mid \text{KL}(p \parallel \hat{p}) \leq \epsilon \} \),其中 \( \hat{p} \) 是经验分布,\( \epsilon \) 是不确定性预算。 适用于高维分布,计算相对简单,但不对称(\(KL(p||q) \neq KL(q||p)\))
    Wasserstein距离 \( \mathcal{U} = \{ p \mid W_c(p, \hat{p}) \leq \epsilon \} \),基于最优传输理论,衡量分布间的“运输成本”。 对异常值更稳健,适用于连续分布,但高维下计算复杂
    总变差距离 \( \mathcal{U} = \{ p \mid \text{TV}(p, \hat{p}) \leq \epsilon \} \),衡量分布最大差异。 直观但严格,导致集合较小,解可能过于保守
  • 不确定性集合的大小由参数 \( \epsilon \) 控制:
    • \( \epsilon \) 越大,集合包含的分布越多(鲁棒性越强,但可能过于保守,牺牲性能);
    • \( \epsilon \) 越小,集合越接近经验分布(性能可能更好,但鲁棒性下降)
DRO 在强化学习中的典型应用
  • Model-based RL中的鲁棒策略学习
    • 在 Model-based RL 中,若环境模型(如状态转移函数)存在误差,DRO 可通过构造模型分布的不确定性集合,优化最坏情况下的策略,避免模型误差导致的性能崩溃
  • Model-free RL中的分布偏移处理
    • Model-free 方法依赖采样数据,若采样分布与真实分布存在偏移(如探索不足、噪声干扰),DRO 可通过对采样分布的不确定性建模,使策略对偏移更稳健
  • 安全强化学习(Safe RL)
    • 在需要满足安全约束(如避免碰撞、能耗上限)的场景中,DRO 可确保策略在约束分布的最坏情况下仍不违反安全条件
DRO 相关的挑战与权衡
  • 计算复杂度 :DRO 需求解 min-max 优化问题(先最小化最坏分布,再最大化策略性能),比传统 RL 的单目标优化更复杂,尤其在高维状态/动作空间中难以高效求解
  • 鲁棒性与保守性的权衡 :不确定性集合过大可能导致策略过于保守(为了稳健牺牲了最优性能);过小则可能无法覆盖真实分布,失去鲁棒性
  • 不确定性集合的合理性 :如何基于有限数据构造“既包含真实分布,又不过大”的集合,是 DRO 的核心难点(通常依赖统计理论或领域知识)

应用场景

  • 公共 LLM 端点:抵御越狱/攻击提示
  • RAG 系统:证据质量和风格存在差异
  • 跨领域泛化:训练-部署偏移(training–serving shift)

Model-based flavor: value-guided decoding

  • 采样时使用价值头作为短 horizon 前瞻:
    $$
    \log \tilde \pi(y_t|h_t) = \log \pi_\theta(y_t|h_t) + \eta\big( V_\phi(h_{t+1}) - V_\phi(h_t) \big)
    $$
    • 公式理解:朝前面看一步,看看下一步 \(V_\phi(h_{t+1})\) 相对当前这步 \(V_\phi(h_{t})\) 带来的价值提升,提升越大的动作 \(y_{t}\),则赋予更大的采样概率(直接将增益加到原始对数概率上)
    • 问题:是不是在于原始策略输出概率增加前或后,增加一个归一化步骤会更好?
  • 对于工具智能体,展开 \(n\) 步工具计划并引导至 \(V_\phi\)
  • 这可在不重新训练整个策略的情况下,推动生成过程向高下游价值方向发展

应用场景

  • 代码/测试驱动生成:偏向通过测试/完成子任务
  • 冗长推理/约束写作:解码时更严格地遵循目标
  • 低重训练预算:无需完整强化学习周期即可获得“轻量级规划”收益

Offline and conservative RL from logs

  • IQL/AWAC 风格的优势加权更新
    $$
    \pi_\theta \leftarrow \arg\max_\theta \ \mathbb{E}\left[\exp\left(\frac{A_\beta(s,a)}{\lambda}\right)\log \pi_\theta(a|s)\right]
    $$
  • CQL 风格的抑制 :
    • 通过压低未见过的文本/动作上的 \(Q\) 值,抑制过度乐观的分布外(OOD)动作
    • 在偏好空间中,添加行为正则化器 \(\log \pi_\beta(y|h)\) 以保持策略接近日志行为
  • 通过双重鲁棒异策略估计器进行评估:
    • 建模奖励 \(\hat r\) ,并将重要性权重与控制变量结合使用

应用场景

  • 日志丰富,在线数据有限:在风险探索前充分利用历史数据的价值
  • 高风险领域:先进行保守改进,再逐步扩展
  • 新领域冷启动:初始时保持接近 \(\pi_\beta\)

Exploration and diversity (without going off the rails, 不偏离正轨)

  • 熵/温度:SAC 风格的熵奖励或受控采样温度
  • 内在动机:基于 \(h_t\) 的分歧/RND,以鼓励新颖的语义或新工具路径
    *
  • 多样性正则化器:去重;与提示的互信息

应用场景

  • 创意写作/广告/教育内容:在安全范围内实现风格/结构多样性
  • 工具链发现:找到可靠的新序列
  • 覆盖导向的评估:扩大提示集群覆盖范围

Hierarchy and skills: plan–act–verify

  • 慢规划器输出子目标(工具计划、大纲),快执行器实现子目标
  • 训练类选项策略,或通过模仿/离线强化学习预训练技能库,并通过高层控制器调用技能

应用场景

  • 多工具/多步骤工作流(检索 -> 规划 -> 执行 -> 验证)
  • 可分解的大型任务(数据 Pipelines、无人机调度(UAV scheduling)、城市分析(urban analytics))
  • 跨任务/领域的技能复用

Common pitfalls(陷阱)

  • Only end-of-sequence reward + weak value -> unstable advantages. Densify or invest in \(V\).
    • 理解:仅序列最后奖励+奖励模型比较弱,会造成不稳定的优势,需 Dense 化奖励或优化 \(V\)
  • Off-policy drift without correction -> biased updates. Use V-trace/Retrace.
    • 理解:无修正的 Off-policy 偏移是有偏更新,需使用 V-trace/Retrace 修正
  • Single deterministic reward model -> brittle. Prefer ensembles/quantiles.
    • 单一确定性奖励模型是脆弱的(理解:不稳定,方差大),建议优先选择集成/分位数奖励
  • Safety only at decoding time -> model still learns unsafe regions. Train with constraints too.
    • 仅在解码时考虑安全意味着模型仍会学习不安全区域,建议训练时也需添加约束

Closing Remarks

  • 以上这些是将“传统技巧”切实转化为 LLM 实践的方法
  • 如果只选择一种技术组合,可从 PPO + GAE + V-trace + Lagrangian safety 开始,然后添加不确定性感知加权(uncertainty-aware weighting)和价值引导解码(value-guided decoding)
  • 作者很想了解反例和更好的方案,欢迎调整并分享在你的系统中有效的(或无效的)方法

NLP——MHA2MLA(Partial-RoPE)

注:本文包含 AI 辅助创作

  • 参考链接:
    • 原始论文:Towards Economical Inference: Enabling DeepSeek’s Multi-Head Latent Attention in Any Transformer-based LLMs, Fudan & ECNU & Shanghai Al Lab
    • GitHub 开源链接:github.com/JT-Ushio/MHA2MLA

Paper Summary

  • 整体总结:
    • MHA2MLA 包含贡献感知的部分 RoPE 移除和 SVD 驱动的低秩投影
      • 论文实现了 KV 缓存的近乎无损压缩
    • 本文创新:通过针对性的参数重用和数据高效的微调来实现大语言模型架构迁移的可行性
  • 问题提出:
    • 基于 MHA 及 MHA 的变体(GQA 等)的标准 LLM 在成本上存在显著劣势
  • DeepSeek 的解法:MLA
    • Multi-head Latent Attention(MLA,多头潜在注意力)是 DeepSeek 提出的一种创新架构
    • MLA 通过将键值(Key-Value, KV)缓存显著压缩为一个潜在向量,从而实现高效且经济的推理
  • 本文核心解决的现实问题及挑战:
    • 使训练良好的大语言模型(例如 Llama)能够快速适配 MLA 而无需从头预训练,既具有重要意义又充满挑战
  • 论文提出了首个数据高效的精调方法,用于从 MHA 迁移至 MLA(MHA2MLA),该方法包含两个关键组件:
    • 对于 部分旋转位置编码(partial-RoPE) ,论文从对注意力分数贡献较小的查询和键的维度中移除 RoPE
    • 对于 低秩近似(low-rank approximation) ,论文基于预训练的键和值参数引入联合奇异值分解(SVD)近似
  • MHA2MLA 仅需使用一小部分(3% 到 6%)数据即可恢复性能,显著降低了推理成本,同时可与 KV 缓存量化等压缩技术无缝集成
    • 举例:Llama2-7B 的 KV 缓存大小减少了 92.19%,而在 LongBench 上的性能仅下降 0.5%

Introduction and Discussion

  • LLM 的快速发展显著加速了通往通用人工智能(Artificial General Intelligence, AGI)的进程,模型能力随参数数量增加而呈现可预测的扩展 (2020)
    • 然而,这些收益伴随着高昂的代价:训练的计算需求不断攀升,推理吞吐量下降,导致巨大的能源消耗和碳排放 (2019)
  • 随着下游任务日益复杂,长上下文处理和高计算量的推理已成为大语言模型应用的核心 (2024)
    • 一个关键瓶颈在于 MHA (2017) 机制固有的键值(Key-Value, KV)缓存的内存占用,其随序列长度和模型大小线性增长
    • 为缓解此问题,研究者探索了诸如分组查询注意力(Grouped-Query Attention, GQA)(2023) 和多查询注意力(Multi-Query Attention, MQA)(2019) 等变体
    • 然而,这些方法不仅减少了 KV 缓存大小,也减少了注意力中的参数数量,导致性能下降
  • DeepSeek 引入了 MLA (2024)
    • MLA 是一种配备低秩键值联合压缩的注意力机制
    • 经验上,MLA 实现了优于 MHA 的性能,同时显著减少了推理期间的 KV 缓存,从而提升了推理效率
  • 一个关键但尚未探索的问题随之产生:原本为 MHA 良好训练的大语言模型能否适配 MLA 以进行推理?
    • MHA 与 MLA 之间固有的架构差异使得零样本迁移不切实际,而从头预训练的惊人成本使得这种转变在技术上具有挑战性且在现有研究中探索不足
  • 为填补这一空白,论文提出了首个精心设计的 MHA2MLA 框架,该框架最大化地复用了预训练 MHA 网络的参数 ,同时将 KV 缓存存储和推理过程与 MLA 的范式对齐(图 1)
  • 论文的框架具有两项关键的技术创新:
    • 部分旋转位置编码(partial rotary position embedding, partial RoPE)
    • 低秩近似(low-rank approximation)
  • MHA2MLA 的主要目标是实现数据高效的性能恢复,即使用最少的精调数据来恢复由架构变更引起的能力下降
  • MLA 的推理加速机制与 RoPE 之间固有的不兼容性 necessitates 架构上的折衷
  • DeepSeek 的解决方案是在有限维度中保留位置编码(PEs)同时压缩其他维度,这需要在 MHA 中策略性地移除 RoPE 维度(将其转换为 NoPE)以实现与 MLA 的对齐
  • 虽然更高的移除比率提升了压缩效率,但也加剧了性能下降,形成了效率与能力之间的权衡
  • 通过系统性地探索 RoPE 移除策略,论文发现基于贡献的维度选择(保留按注意力分数影响排序的前 k 个维度)能最优地平衡这些竞争目标
  • 尽管先前的研究已经探索了从头训练部分 RoPE 大语言模型 (2024; 2021),但论文的工作开创了在大语言模型中进行全 RoPE 到部分 RoPE(Partial RoPE) 转换的数据高效精调方法
  • MLA 通过将键和值投影到低秩潜在表示空间(存储在 KV 缓存中)来减少内存占用
  • MHA2MLA 也可以对剥离了 RoPE 的值和键(NoPE 维度)应用低秩近似
  • 通过对对应于 NoPE 子空间的预训练参数矩阵 \(\boldsymbol{W}_{v}\) 和 \(\boldsymbol{W}_{k}\) 执行奇异值分解(Singular Value Decomposition, SVD),论文将这些组件压缩到一个潜在空间中,同时最大限度地保留原始模型学到的知识
  • 论文的主要贡献是:
    • 论文提出了 MHA2MLA,这是首个参数高效的精调框架,能够仅使用 \(3%\) 到 \(6%\) 的训练数据将预训练的基于 MHA 的大语言模型适配到 MLA 架构,而无需从头训练
    • 论文证明了 MHA2MLA 架构可以与 KV 缓存量化技术集成,以实现更经济的推理(最高减少 96.87%)
    • 论文在四种模型规模(从 135M 到 7B,涵盖 MHA 和 GQA)上进行了实验,并进行了详细的消融研究,为 MHA2MLA 提供了指导和见解

Preliminary

多头注意力机制 (Multi-Head Attention, MHA)

  • 给定一个输入序列 \(\{\boldsymbol{x}_{1},\ldots,\boldsymbol{x}_{l}\} \in \mathbb{R}^{l \times d}\),标准的 MHA (2017) 将每个 token \(\boldsymbol{x}_{i}\) 投影为查询向量 \(\boldsymbol{q}_{i}^{(h)} = \boldsymbol{x}_{i}\boldsymbol{W}_{q}^{(h)}\)、键向量 \(\boldsymbol{k}_{i}^{(h)} = \boldsymbol{x}_{i}\boldsymbol{W}_{k}^{(h)}\) 和值向量 \(\boldsymbol{v}_{i}^{(h)} = \boldsymbol{x}_{i}\boldsymbol{W}_{v}^{(h)}\),其中对于每个头 \(h \in \{1,\ldots,n_{h}\}\),有 \(\boldsymbol{W}_{q}^{(h)}, \boldsymbol{W}_{k}^{(h)}, \boldsymbol{W}_{v}^{(h)} \in \mathbb{R}^{d \times d_{h} }\)。旋转位置编码(Rotary Positional Encoding, RoPE)(2021) 被应用于查询和键(例如,\(\boldsymbol{q}_{i,\text{rope} }^{(h)} = \text{RoPE}(\boldsymbol{q}_{i}^{(h)})\)),随后进行缩放点积注意力计算:
    $$
    \boldsymbol{o}_{i}^{(h)} = \text{Softmax}\left( \boldsymbol{q}_{i,\text{rope} }^{(h)} \boldsymbol{k}_{\leq i,\text{rope} }^{(h)\top} \right) \boldsymbol{v}_{\leq i}^{(h)}, \\
    \text{MHA}(\boldsymbol{x}_{i}) = \left[ \boldsymbol{o}_{i}^{(1)}, \ldots, \boldsymbol{o}_{i}^{(n_{h})} \right] \boldsymbol{W}_{o},
    $$
    • 其中 \(\boldsymbol{W}_{o} \in \mathbb{R}^{(n_{h}d_{h}) \times d}\)
    • \([\cdot,\cdot]\) 表示向量拼接
    • 注:为简化符号,此处忽略了 \(\frac{1}{\sqrt{d} }\) 缩放因子
  • 在自回归推理过程中,MHA 需要存储大小为 \(O(2ln_{h}d_{h})\) 的键值(KV)缓存 \(\{\boldsymbol{k}_{\text{rope} }^{(h)}, \boldsymbol{v}^{(h)}\}_{h=1}^{n_{h} }\),该大小随序列长度 \(l\) 线性增长,造成了内存瓶颈

分组查询注意力 (Grouped-Query Attention, GQA)

  • GQA (2023) 通过在 \(n_{g}\) 个组(\(n_{g} \ll n_{h}\))之间共享键/值来减少 KV 缓存。对于每个头 \(h\),它映射到组 \(g = \lfloor h / n_{g} \rfloor\):
    $$
    \boldsymbol{o}_{i}^{(h)} = \text{Softmax}\left( \boldsymbol{q}_{i,\text{rope} }^{(h)} \boldsymbol{k}_{\leq i,\text{rope} }^{(g)\top} \right) \boldsymbol{v}_{\leq i}^{(g)}, \\
    \text{GQA}(\boldsymbol{x}_{i}) = \begin{bmatrix} \boldsymbol{o}_{i}^{(1)}, \ldots, \boldsymbol{o}_{i}^{(n_{h})} \end{bmatrix} \boldsymbol{W}_{o}. \tag{2}
    $$

多头查询注意力 (Multi-Query Attention, MQA)

  • MQA (2019) 是 GQA 的一个特例,其中 \(n_{g} = 1\),即所有头共享一个全局的键/值
  • 虽然 GQA 和 MQA 方法将 KV 缓存减少到 \(O(2ln_{g}d_{h})\),但由于参数剪枝,它们会导致性能下降

多头潜在注意力 (Multi-Head Latent Attention, MLA)

  • MLA (DeepSeek-AI, 2024) 引入了一种混合架构,将位置编码(PE)与潜在 KV 压缩解耦
  • 对于每个头 \(h\),输入 \(\boldsymbol{x}_{i}\) 被投影为两个互补的分量:
  • 位置感知分量 (Position-Aware Component) :一部分维度保留 PE 以保持位置敏感性:
    $$
    \boldsymbol{q}_{i,\text{rope} }^{(h)}, \boldsymbol{k}_{i,\text{rope} } = \text{RoPE}\left( \boldsymbol{x}_{i}\boldsymbol{W}_{dq}\boldsymbol{W}_{qr}^{(h)}, \boldsymbol{x}_{i}\boldsymbol{W}_{kr} \right),
    $$
    • 其中 \(\boldsymbol{W}_{dq} \in \mathbb{R}^{d \times d_{q} }\),\(\boldsymbol{W}_{qr}^{(h)} \in \mathbb{R}^{d_{q} \times d_{r} }\),\(\boldsymbol{W}_{kr} \in \mathbb{R}^{d \times d_{r} }\) 将查询/键投影到保留 RoPE 的 \(d_{r}\) 维分量中
  • 位置无关分量 (Position-Agnostic Component) :剩余的 \(d_{c}\) 个维度被移除 PE(即 NoPE),并将 \(\boldsymbol{k}_{i,\text{nope} }^{(h)}\) 和 \(\boldsymbol{v}_{i}^{(h)}\) 压缩成一个共享的潜在向量 \(\boldsymbol{c}_{i,kv}^{(h)}\):
    $$
    \boldsymbol{q}_{i,\text{nope} }^{(h)} = \boldsymbol{x}_{i}\boldsymbol{W}_{dq}\boldsymbol{W}_{qc}^{(h)}, \\
    \boldsymbol{c}_{i,kv} = \boldsymbol{x}_{i}\boldsymbol{W}_{dkv}, \\
    \boldsymbol{k}_{i,\text{nope} }^{(h)}, \boldsymbol{v}_{i}^{(h)} = \boldsymbol{c}_{i,kv}\boldsymbol{W}_{uk}^{(h)}, \boldsymbol{c}_{i,kv}\boldsymbol{W}_{uv}^{(h)},
    $$
    • 其中 \(\boldsymbol{W}_{qc}^{(h)} \in \mathbb{R}^{d_{q} \times d_{c} }\),\(\boldsymbol{W}_{dkv} \in \mathbb{R}^{d \times d_{kv} }\),\(\boldsymbol{W}_{uk}^{(h)} \in \mathbb{R}^{d_{kv} \times d_{c} }\),\(\boldsymbol{W}_{uv}^{(h)} \in \mathbb{R}^{d_{kv} \times d_{h} }\)
  • 注意 \(d_{r} + d_{c} = d_{h}\)。MLA 的注意力输出结合了两个分量:
    $$
    \boldsymbol{o}_{i}^{(h)} = \text{Softmax}\left( \boldsymbol{q}_{i,\text{rope} }^{(h)} \boldsymbol{k}_{\leq i,\text{rope} }^{(h)\top} + \boldsymbol{q}_{i,\text{nope} } \boldsymbol{k}_{\leq i,\text{nope} }^{(h)\top} \right) \cdot \boldsymbol{v}_{\leq i}^{(h)} \\
    \text{MLA}(\boldsymbol{x}_{i}) = \begin{bmatrix} \boldsymbol{o}_{i}^{(1)}, \ldots, \boldsymbol{o}_{i}^{(n_{h})} \end{bmatrix} \cdot \boldsymbol{W}_{o}. \tag{3}
    $$
  • 与 MHA 及其变体不同,MLA 存储潜在向量 \(\boldsymbol{c}_{kv}\) 和 \(\boldsymbol{k}_{i,\text{rope} }^{(h)}\)(\(\mathcal{O}(ld_{r} + ld_{kv})\))而不是全秩的 \(\boldsymbol{k}_{i}\), \(\boldsymbol{v}_{i}\)(\(\mathcal{O}(2ln_{h}d_{h})\)),其中 \((d_{r} + d_{kv}) \ll 2n_{h}d_{h}\)
  • 为什么 MLA 需要分离 RoPE 和 NoPE?
    • MLA 在推理过程中对 NoPE 部分引入了矩阵合并技术,有效减少了内存使用
    • 对于点积操作 \(\boldsymbol{q}_{i,\text{nope} }^{(h)} \boldsymbol{k}_{j,\text{nope} }^{(h)\top}\),可以应用以下恒等变换:
      $$
      \boldsymbol{q}_{i,\text{nope} } \boldsymbol{k}_{j,\text{nope} }^{\top} = (\boldsymbol{x}_{i}\boldsymbol{W}_{dq}\boldsymbol{W}_{qc}) (\boldsymbol{c}_{j,kv}\boldsymbol{W}_{uk})^{\top} = \boldsymbol{x}_{i} (\boldsymbol{W}_{dq}\boldsymbol{W}_{qc}\boldsymbol{W}_{uk}^{\top}) \boldsymbol{c}_{j,kv}^{\top}
      $$
      • 注:为简化符号,论文省略了上标 \({}^{(h)}\)。矩阵 \(\boldsymbol{W}_{uv}\) 和 \(\boldsymbol{W}_{o}\) 也可以合并,请参阅 DeepSeek-AI 等人 (2024) 的附录 C
      • 其中 \((\boldsymbol{W}_{dq}\boldsymbol{W}_{qc}\boldsymbol{W}_{uk}^{\top})\) 可以预先合并为单个矩阵,而 \(\boldsymbol{c}_{j,kv}\) 已经存储在 KV 缓存中。对于 RoPE 部分,RoPE() 函数将输入向量乘以旋转矩阵(例如,\(\text{RoPE}(\boldsymbol{q}_{i}) = \boldsymbol{q}_{i}\boldsymbol{R}_{i}\),\(\boldsymbol{R}_{i}\) 的具体形式将在第 3.1 节介绍)
  • 因此,恒等变换变为:
    $$
    \boldsymbol{q}_{i,\text{rope} } \boldsymbol{k}_{j,\text{rope} }^{\top} = (\boldsymbol{x}_{i}\boldsymbol{W}_{dq}\boldsymbol{W}_{qr} \boldsymbol{R}_{i}) (\boldsymbol{x}_{j}\boldsymbol{W}_{kr} \boldsymbol{R}_{j})^{\top} = \boldsymbol{x}_{i} (\boldsymbol{W}_{dq}\boldsymbol{W}_{qr} \boldsymbol{R}_{j-i} \boldsymbol{W}_{kr}^{\top}) \boldsymbol{x}_{j}^{\top}
    $$
  • 由于 \((\boldsymbol{W}_{dq}\boldsymbol{W}_{qr} \boldsymbol{R}_{j-i} \boldsymbol{W}_{kr}^{\top})\) 与相对位置 \(j-i\) 相关,它不能被合并成一个固定矩阵。考虑到 LLM 中的相对距离可能非常长(例如 128K),RoPE 部分更适合使用原始形式进行计算

MHA2MLA

部分旋转位置编码(Partial-RoPE)

  • 为实现从标准 MHA 到 MLA 的迁移,论文提出了部分旋转位置编码微调(partial-RoPE finetuning)策略,该策略从目标比例维度中移除 RoPE,并将其转换为 NoPE
  • 关键的是,尽管先前的工作已经探索了从头开始训练具有部分 RoPE 的 LLM(实现了比全 RoPE 略好的困惑度 (2021; 2024)),但现有方法均未解决如何高效地将预训练的全 RoPE 模型(例如 Llama)适配到部分 RoPE,而无需昂贵的重新训练
  • 论文的工作通过系统评估部分 RoPE 的变体,以确定最数据高效的微调方案来恢复模型在适配后的性能,从而弥补了这一空白
  • RoPE 的数学形式化表示如下:
    • 对于维度为 \(d_h\) 的查询或键向量,RoPE 将向量划分为 \(\frac{d_h}{2}\) 个子空间,每个子空间包含两个连续维度(例如,第 \(k\) 个子空间包含维度 \(2k\) 和 \(2k+1\))
    • 每个子空间以不同的旋转速度(频率)旋转,其中第 \(k\) 个子空间的旋转角为 \(\theta_k = b^{-2k/d_h}\),\(b\) 是预定义的基数(例如,Llama 使用 \(b=10000\))
    • 因此,对查询和键应用 RoPE 变为:
      $$
      \boldsymbol{q}_{i,rope} =\left[\boldsymbol{R}_{i}^{[2k,2k+1]}(\theta_{k})\boldsymbol{q}_{i }^{[2k,2k+1]}\right]_{0\leq k < \frac{d_{h} }{2} }, \\
      \boldsymbol{k}_{i,rope} =\left[\boldsymbol{R}_{i}^{[2k,2k+1]}(\theta_{k})\boldsymbol{k}_{i }^{[2k,2k+1]}\right]_{0\leq k < \frac{d_{h} }{2} }.
      $$
全 RoPE 到部分 RoPE 的策略(Full-RoPE to Partial-RoPE Strategies)
  • 给定保留的旋转子空间数量 \(r\)(\(r=\frac{d_r}{2} \ll\) 总子空间数 \(\frac{d_h}{2}\)),论文提出了四种策略(如图 2 所示)来选择哪些 \(r\) 个子空间保留 RoPE 编码
  • 高频保留(High-Frequency Preservation) 保留 \(r\) 个旋转最快(高频)的子空间:
    $$
    \mathcal{S}_{\text{high} }=\left\{k,|,0\leq k<r\right\}.
    $$
    • 这与 Barbero 等人 (2024) 提出的 p-RoPE 方法一致,他们探索了 \(r\) 占总子空间数 25%、50% 和 75% 的设置,并观察到在从头训练的 LLM 中比全 RoPE 略有优势
  • 低频保留(Low-Frequency Preservation) 保留 \(r\) 个旋转最慢(低频)的子空间:
    $$
    \mathcal{S}_{\text{low} }=\left\{k,\Big{|},\frac{d_{h} }{2}-r\leq k<\frac{d_{h} } {2}\right\}.
    $$
    • 选择该策略作为高频策略的对照实验
  • 均匀采样(Uniform Sampling) 以等间隔选择 \(r\) 个子空间:
    $$
    \mathcal{S}_{\text{uniform} }=\left\{\left.\left\lfloor k\frac{d_{h} }{2r}\right \rfloor,\right|0\leq k<r\right\}
    $$
    • 这通过几何间距平衡了高频和低频分量。实践中,\(2r\) 通常能整除 \(d_h\)。这与 GPT-Neo (2021) 中使用的部分 RoPE 类似
  • 头部感知 2-范数贡献度(Head-wise 2-norm Contribution) Barbero 等人 (2024) 首次提出了 2-范数贡献度来研究这些频率是否被使用以及它们如何发挥作用。该方法基于这样的观察:根据柯西-施瓦茨不等式,第 \(k\) 个频率子空间对注意力对数几率(logits)的影响受相应查询和键分量的 2-范数上界限制,即 \(\left|\left\langle\mathbf{q}_{i}^{[2k,2k+1]},\mathbf{k}_{j}^{[2k,2k+1]}\right \rangle\right|\leq\left|\mathbf{q}_{i}^{[2k,2k+1]}\right|\left|\mathbf{k}_{j }^{[2k,2k+1]}\right|\)。对于每个头 \(h\),论文在长序列上计算 LLM 中每个子空间的平均 2-范数得分 \(^4\)。然后,论文提出按它们的 2-范数得分对所有子空间进行排序,并选择前 \(r\) 个:
    $$
    \mathcal{S}_{2\text{-norm} }=\operatorname*{top}\nolimits_{r} \left(\left|\mathsf{q}^{[2k,2k+1]}_{*}\right|\left|\mathsf{k}^{[2k,2k+1]}_{*}\right|\right).
    $$
    • 这种头部特定的选择自适应地保留了旋转关键的子空间
    • 图 3 可视化了 Llama2-7B 四个头部的 2-范数
  • 论文将在第 4.3 节分析这四种策略的有效性,并在附录 D 中对关键超参数 \(r\) 进行消融研究
  • 对于所有策略,未选择的子空间(\(k \notin \mathcal{S}\))变为 NoPE 维度,从而实现与 MLA 潜在压缩的无缝集成

Low-rank Approximation, Low-rank Approximation

  • 在从全 RoPE 转换为部分 RoPE 后,论文得到了 MLA 中 KV 缓存的第一个分量,表示为:
    $$ \boldsymbol{k}_{i,rope}=\left[\boldsymbol{R}^{[2k,2k+1]}_{i}(\theta_{k})\boldsymbol{k}^{[2k,2k+1]}_{i} \right]_{k\in\mathcal{S} } $$
  • 论文的下一个目标是推导第二个分量
    $$ \boldsymbol{c}_{i,kv} \in \mathbb{R}^{d_{kv} } $$
    • 它作为 \(\boldsymbol{k}_{i,\text{nope} }\) 和 \(\boldsymbol{v}_{i}\) 的低秩表示
  • 给定 MHA 中的键 \(\boldsymbol{k}_{i}=\boldsymbol{x}_{i}\boldsymbol{W}_{k}\) 和值 \(\boldsymbol{v}_{i}=\boldsymbol{x}_{i}\boldsymbol{W}_{v}\),论文首先提取 \(\boldsymbol{W}_{k}\) 中对应于 \(\boldsymbol{k}_{i,\text{nope} }\) 的子空间,即未包含在 \(\mathcal{S}\) 中的维度,得到:
    $$ \boldsymbol{k}_{i,\text{nope} }=\boldsymbol{x}_{i}\boldsymbol{W}_{k,\text{nope} } $$
  • 论文提出了两种基于奇异值分解(SVD)的策略(如图 4 所示)来在实现降秩的同时保留预训练知识:
  • 解耦 SVD(Decoupled SVD, SVD\({}_{\text{split} }\)) 分别将 \(\boldsymbol{W}_{k,\text{nope} }\) 和 \(\boldsymbol{W}_{v}\) 分解为截断 SVD,各分配 \(d_{kv}/2\) 个维度:
    $$
    \boldsymbol{W}_{k,\text{nope} }=\boldsymbol{U}_{k}\boldsymbol{\Sigma}_{k}\boldsymbol{V}^{\top}_{k}, \quad \boldsymbol{W}_{v}=\boldsymbol{U}_{v}\boldsymbol{\Sigma}_{v}\boldsymbol{V}^{\top}_{v},
    $$
    其中 \(\boldsymbol{U}_{k},\boldsymbol{U}_{v},\boldsymbol{V}_{k},\boldsymbol{V}_{v} \in \mathbb{R}^{d_{h} \times \frac{d_{kv} }{2} }\),\(\boldsymbol{\Sigma}_{k},\boldsymbol{\Sigma}_{v} \in \mathbb{R}^{\frac{d_{kv} }{2} \times \frac{d_{kv} }{2} }\)。下投影矩阵 \(\boldsymbol{W}_{d\cdot}\) 和上投影矩阵 \(\boldsymbol{W}_{u\cdot}\) 变为:
    $$
    \boldsymbol{W}_{dk} =\boldsymbol{U}_{k}\boldsymbol{\Sigma}^{1/2}_{k}, \quad \boldsymbol{W}_{uk} =\boldsymbol{\Sigma}^{1/2}_{k}\boldsymbol{V}^{\top}_{k},
    $$
    $$
    \boldsymbol{W}_{dv} =\boldsymbol{U}_{v}\boldsymbol{\Sigma}^{1/2}_{v}, \quad \boldsymbol{W}_{uv} =\boldsymbol{\Sigma}^{1/2}_{v}\boldsymbol{V}^{\top}_{v}.
    $$
    低秩表示 \(\boldsymbol{c}_{i,kv}\) 可以使用 \(\boldsymbol{c}_{i,kv}=[\boldsymbol{x}_{i}\boldsymbol{W}_{dk}, \boldsymbol{x}_{i}\boldsymbol{W}_{dv}]\) 构建
  • 联合 SVD(Joint SVD, SVD\({}_{\text{joint} }\)) 为保留 \(\boldsymbol{K}_{\text{nope} }\) 和 \(\boldsymbol{V}\) 之间的相互作用,论文联合分解拼接后的矩阵:
    $$
    [\boldsymbol{W}_{k,\text{nope} }, \boldsymbol{W}_{v}] = \boldsymbol{U}_{kv}\boldsymbol{\Sigma}_{kv}\boldsymbol{V}^{\top}_{kv},
    $$
    其中 \(\boldsymbol{U}_{kv}, \boldsymbol{V}_{kv} \in \mathbb{R}^{d_{h} \times d_{kv} }\),\(\boldsymbol{\Sigma}_{kv} \in \mathbb{R}^{d_{kv} \times d_{kv} }\)。潜在投影则为:
    $$
    \boldsymbol{W}_{dkv} = \boldsymbol{U}_{kv}\boldsymbol{\Sigma}^{1/2}_{kv},
    $$
    $$
    \boldsymbol{W}_{uk} = \boldsymbol{\Sigma}^{1/2}_{kv}\boldsymbol{V}_{kv}[:, :-d_{v}], \quad \boldsymbol{W}_{uv} = \boldsymbol{\Sigma}^{1/2}_{kv}\boldsymbol{V}_{kv}[:, d_{v}:].
    $$
    这联合优化了键和值的潜在空间,即 \(\boldsymbol{c}_{i,kv} = \boldsymbol{x}_{i}\boldsymbol{W}_{dkv}\),保留了对自回归生成至关重要的跨参数依赖性 \(^5\)。第 4.3 节显示 SVD\({}_{\text{joint} }\) 优于 SVD\({}_{\text{split} }\) ,验证了联合分解能更好地保留预训练知识

Experiment

  • 论文在不同规模(SmoILM-135M/360M/1B7, Llama2-7B)且使用 MHA 或 GQA 预训练的 LLM 上评估了论文的方法
    • 选择 SmoILM 系列是因为其预训练数据和框架都是开源的,这可以最大程度地减少微调数据和过程上的差异
    • 选择 Llama2-7B 是因为它是广泛使用的开源大语言模型之一(但其预训练数据未开源,微调数据可能存在潜在差异)
  • 论文分别使用 MHA2MLA 和 GQA2MLA 来表示架构迁移
    • 两者均采用数据高效的全参数微调(data-efficient full-parameter fine-tuning)
    • 默认使用基于头部的 2-范数贡献度选择(\(\mathcal{S}_{2\text{-norm} }\),\(r=\frac{d_{h} }{16}\))作为部分旋转位置编码(Partial-RoPE)策略
    • 联合奇异值分解(SVD\({}_{\text{joint} }\))作为低秩近似策略
  • 论文的实验旨在回答三个关键问题:
    • 1)MHA2MLA 如何最小化由架构转变引起的准确性下降?
    • 2)MHA2MLA 在 KV 缓存减少比率方面取得了什么成果?
    • 3)MHA2MLA 能否与 KV 缓存量化技术结合以实现复合收益?

Commonsense Reasoning Tasks

Main Results
  • 如表 1 所示,论文的方法在四种模型规模(135M 到 7B)和不同的 KV 缓存压缩比(通过潜在维度 \(d_{kv}\) 控制)下均实现了高效的架构迁移
  • 当比较论文的微调方法与原始大语言模型的性能时
    • 论文观察到四个基础模型的性能仅有微小变化:
      • 135M 模型下降 -0.25%
      • 360M 模型上升 +0.03%
      • 1B7 模型上升 +0.03%
      • 7B 模型上升 +0.37%
    • 这表明微调数据并未显著降低或提高原始模型的性能,为 MHA2MLA 框架提供了一个合适的实验环境
  • 随着 \(d_{kv}\) 减小(例如从 32 到 16 再到 8),KV 缓存减少量增加(即从 -68.75% 到 -81.25% 再到 -87.5%),但通过微调恢复性能损失变得更具挑战性
    • 图 5 显示了 135M(代表 GQA)和 7B(代表 MHA)在不同压缩比下的微调损失曲线
    • 随着压缩比增加,与基线的损失差异变大
    • 论文还观察到损失曲线的波动趋势几乎一致,这表明论文的架构迁移并未显著损害模型的内部知识
  • 更大的模型在迁移到 MLA 架构时经历的性能下降更小
    • 例如,压缩至 18.75% 时,性能下降分别为:
      • 135M 下降 2.41%
      • 360M 下降 2.69%
      • 1B7 下降 1.28%
      • 7B 下降 0.61%
    • 这揭示了 MHA2MLA 的潜在缩放定律 (potential scaling law of MHA2MLA)
  • 最后,从 135M 模型到 7B 模型,微调所需的 token 数量仅占预训练 token 的约 0.3% 到 0.6%,证明了论文方法的数据效率
  • 总体而言,无论是使用 GQA2MLA 还是 MHA2MLA,架构迁移都以极小的成本实现,从而带来高效且经济(economical)的推理
  • 表 1: 使用 MHA2MLA 或 GQA2MLA 的四个大语言模型的常识推理能力
    • 六个基准测试包括 MMLU (2021)、ARC 简单和挑战集 (ARC, 2018)、PIQA (2020)、Hellaswag (HS, 2019)、OpenBookQA (OBQA, 2018)、Winogrande (WG, 2021)

Long Context Tasks

Settings
  • 为评估模型的生成能力,论文采用 LongBench (2024) 作为生成性能的基准
  • 所有模型均使用贪心解码策略进行测试
  • 上下文窗口大小根据模型微调时使用的序列长度确定
  • 使用 HQQ ( 2023) 和 Quanto 以不同精度级别设置缓存,以评估原始模型的性能作为基线
  • 由于论文的方法与 KV 缓存量化兼容,论文还进行了额外实验来评估两种方法结合的效果
Main Results
  • 如表 2 所示,在 LongBench 上,与训练后量化方法相比,MHA2MLA 实现了具有竞争力或更优的效率-准确性曲线
    • 原生的 4 位量化在可比压缩比下仅带来适度的性能下降(-0.2% 到 -0.4%)
    • 原生的 2 位量化实现了 87.5% 的 KV 缓存减少,但出现了严重的性能崩溃(-6.2% 到 -9%)
    • 相比之下
      • MHA2MLA 在达到 87.5% 压缩(\(d_{kv}=16\))时仅造成 3% 的准确性损失
      • 进一步与 4 位量化协同作用,实现了 92.19%/96.87% 的压缩(\(d_{kv}=64/16\)+Int4HQQ),同时将性能下降限制在 -0.5%/-3.2%,优于所有 2 位基线
    • 这突显了 MHA2MLA 的潜在空间设计与数值精度降低是正交的,从而能够实现复合效率增益 (compound efficiency gains) 而不会产生破坏性干扰
  • 表 2: Llama2-7B 和 MHA2MLA 在 LongBench 上的评估结果。粗体表示压缩比大于或等于 Int2 量化,同时性能也高于 Int2
  • 图 5: 不同 KV 缓存存储比率下的微调损失曲线(颜色从浅到深代表 12.5%, 18.75%, 31.25%, 和 100%)

Ablation Study

四种部分旋转位置编码策略:\(\mathcal{S}_{\text{high} }\), \(\mathcal{S}_{\text{low} }\), \(\mathcal{S}_{\text{uniform} }\), \(\mathcal{S}_{\text{2-norm} }\)
  • 表 3 展示了四种将完整旋转位置编码(full-RoPE)转换为部分旋转位置编码(partial-RoPE)的策略结果
    • 当将这四种策略与完整旋转位置编码进行比较时
      • 低频保留策略 \(\mathcal{S}_{\text{low} }\) 遭受了最大的性能损失(135M 减少 -6.49%,1B7 减少 -1.21%)
      • 高频保留策略 \(\mathcal{S}_{\text{high} }\) 的性能下降显著较小(135M 减少 -0.85%,1B7 减少 -0.76%)
      • 强调了高频子空间的重要性
    • \(\mathcal{S}_{\text{uniform} }\) 和 \(\mathcal{S}_{\text{2-norm} }\) 都产生了更好的性能,\(\mathcal{S}_{\text{uniform} }\) 保留了跨频率谱的子空间
    • \(\mathcal{S}_{\text{2-norm} }\) 则根据子空间对注意力分数的贡献来保留子空间
    • 论文选择 \(\mathcal{S}_{\text{2-norm} }\) 作为默认配置,因为被移除的子空间(即 NoPE)更适合(基于 SVD 的)低秩近似
两种基于 SVD 的低秩近似:\(\text{SVD}_\text{split}\), \(\text{SVD}_\text{joint}\)
  • 表 3 中每个组的最后两行比较了两种 SVD 方法的效果
    • 在两个大语言模型上,\(\text{SVD}_\text{joint}\) 方法 consistently 优于 \(\text{SVD}_\text{split}\),在 135M 模型上平均性能提升 0.92%,在 1B7 模型上平均提升 0.74%
    • 这表明 \(\text{SVD}_\text{joint}\) 成为明确的默认选择

Related Work

Efficient Attention Architectures

  • 标准的多头注意力机制(Multi-Head Attention, MHA)(2017) 在上下文长度上具有二次复杂度,这促使了众多效率创新
  • MHA 变体,如多头查询注意力(Multi-Query Attention, MQA)和分组查询注意力(Grouped-Query Attention, GQA)(2023)
    • 通过在不同头之间共享键/值来减少内存开销
    • 但这是以参数剪枝和性能下降为代价的
  • 其他并行的工作,如线性 Transformer (2019; 2020; 2021)、RWKV (2023) 和 Mamba (2023)
    • 用线性循环或状态空间模型替代了 softmax 注意力,但在自回归生成中难以匹配标准注意力的表达能力
  • 多头潜在注意力(Multi-Head Latent Attention, MLA)(2024) 通过将 KV 缓存压缩为低秩潜在向量而无需剪枝注意力参数,从而脱颖而出
  • 论文的工作将 MLA 与主流架构(MHA/GQA)连接起来,通过数据高效的微调实现无缝迁移
  • 许多线性注意力变体放弃了 softmax 查询-键交互(例如,通过核近似),但保留查询-键点积结构(即使是分解形式)的架构仍然与论文的 MHA2MLA 框架兼容

Economical Key-Value Cache(经济的键值缓存)

  • KV 缓存的内存占用已成为长上下文推理的关键瓶颈。最近的进展分为三类:
    • 创新架构方法 ,如 MLA (DeepSeek-2024)、MiniCache (2024a) 和 MLKV (2024),跨层或头共享或压缩 KV 表示
      • 虽然有效,但跨层共享可能混淆不同的注意力模式,可能损害特定任务的性能
      • 只有 MLA 在 DeepSeek 的 LLM 中得到了成功验证
    • 量化技术 ,如 GPTQ (2022)、FlexGen (2023) 和 KIVI (2024b)
      • 以低比特格式(例如 2 比特)存储 KV 缓存,以精度损失为代价实现内存节省
    • 动态剪枝方法
      • A2SF (2024) 和 SnapKV (2024) 从 KV 缓存中剪枝“不太重要”的 Token
        • 但 Token 剪枝可能丢弃关键的长距离依赖
      • 头剪枝(例如 SliceGPT (2024)、Sheared (2024) 和 Simple Pruning (2024))则不可逆地降低了模型容量
  • 论文的 MHA2MLA 方法实现了标准基于 Transformer 的大语言模型向更经济的 MLA 架构的迁移,并已证明其能够与 KV 量化技术集成以实现约 97% 的缓存节省
    • 它在理论上也与其他方法(如剪枝)兼容

NLP——LLM-API调用示例


Qwen

  • Qwen API 申请:获取API Key
  • Qwen API 调用文档:Qwen-API Doc
  • 吐槽:Qwen 的文档和申请链接写的很差,阿里云东西太多,需要翻来翻去找
  • Qwen API 调用示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    def qwen_api():
    import requests

    url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
    headers = {
    "Authorization": "Bearer $API_KEY",
    "Content-Type": "application/json"
    }

    data = {
    "model": "qwen-plus",
    "messages": [
    {"role": "user", "content": "你好,请介绍一下你自己"}
    ],
    "max_tokens": 50,
    "temperature": 0.0, # 贪心采样示例
    "top": 0.2, # 贪心采样示例
    "logprobs": True, # 可以打开 logprobs 看每个 token 的 logprobs,使用 e^logprob 即可得到最终概率
    "top_logprobs": 2,
    }

    response = requests.post(url, headers=headers, json=data)
    print(response.json())

    if __name__ == "__main__":
    qwen_api()

LongCat

  • LongCat 文档:LongCat API开放平台快速开始

  • 文档写的清晰明了,Qwen 应该学习一下

  • LongCat API 调用示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    def longcat_api():
    import requests

    url = "https://api.longcat.chat/openai/v1/chat/completions"
    headers = {
    "Authorization": "Bearer $API_KEY",
    "Content-Type": "application/json"
    }

    data = {
    "model": "LongCat-Flash-Chat",
    "messages": [
    {"role": "user", "content": "你好,请介绍一下自己"}
    ],
    "max_tokens": 1000,
    "temperature": 0.7,
    # "logprobs": True, # 打开这个参数会报错
    }

    response = requests.post(url, headers=headers, json=data)
    print(response.json())

    if __name__ == "__main__":
    longcat_api()
    • 特别强调:目前 LongCat 不支持返回 logprobs 信息
1…456…61
Joe Zhou

Joe Zhou

Stay Hungry. Stay Foolish.

608 posts
49 tags
GitHub E-Mail
© 2026 Joe Zhou
Powered by Hexo
|
Theme — NexT.Gemini v5.1.4