NLP——TRL库的使用

本文主要介绍 TRL 库的使用


整体说明

  • TRL(Transformer Reinforcement Learning)是 huggingface 中的一个完整的库,用于微调和对齐大型语言模型,可用于优化 Transformer 语言和扩散模型
  • 这个库支持 SFT、PPO、DPO 等模型微调、对齐流程
  • TRL 库目前被很多开源框架依赖,是 LLM 领域的标准基础框架
  • TRL 支持了很多开源的微调方法,而且还在持续更新,详情见:huggingface.co/docs/trl 的 API 部分
  • TRL 集成了很多底层框架
    • 待补充

安装 TRL 库

  • 通过 pip 安装 TRL 库:

    1
    pip install trl
  • 也可以通过 Git 克隆并直接从源代码安装:

    1
    2
    3
    git clone https://github.com/huggingface/trl.git
    cd trl
    pip install .
    • 一些未发布功能和修复在最新版本里面,此时需要通过上述源码方式安装
    • 安装方式说明:
      • pip install . 会安装并复制文件到默认目录(标准稳定版 pip 包安装)
      • pip install -e . 则会创建链接到当前目录,同时对当前的目录会立即生效到包上(--editable

SFT 示例

  • 参考链接:huggingface.co/docs/trl/sft_trainer
  • SFTTrainer用于在自定义数据集上进行监督微调。示例代码如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    from datasets import load_dataset
    from trl import SFTConfig, SFTTrainer
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # 定义基础模型
    model = AutoModelForCausalLM.from_pretrained("模型名称")
    tokenizer = AutoTokenizer.from_pretrained("模型名称")

    # 加载训练数据集
    dataset = load_dataset("trl-lib/Capybara", split="train")
    # 配置训练参数
    training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
    # 初始化SFTTrainer
    trainer = SFTTrainer(
    model=model, # 基础模型
    tokenizer=tokenizer, # 对应的tokenizer
    args=training_args,
    train_dataset=dataset,
    )
    # 开始训练
    trainer.train()

奖励模型(Reward Model)训练示例

  • RewardTrainer用于训练奖励模型,该模型可以评估文本生成的质量。示例代码如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    from trl import RewardConfig, RewardTrainer
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from datasets import load_dataset

    # 加载预训练的模型和分词器
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", num_labels=1)
    # 加载适合于奖励模型的数据集
    dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
    # 配置训练参数
    training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
    # 初始化RewardTrainer
    trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=training_args
    )
    # 开始训练
    trainer.train()

PPO 训练示例

  • PPOTrainer用于基于近端策略优化算法对语言模型进行微调。示例代码如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    import torch
    from transformers import AutoTokenizer
    from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
    from trl.core import respond_to_batch

    # 加载预训练的模型和分词器
    model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
    model_ref = create_reference_model(model)
    tokenizer = AutoTokenizer.from_pretrained('gpt2')

    # 初始化PPO训练器
    ppo_config = PPOConfig(batch_size=1)

    # 编码一个查询
    query_txt = "This morning I went to the "
    query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

    # 获取模型响应
    response_tensor = respond_to_batch(model, query_tensor)

    # 创建PPO训练器
    ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

DPO 训练示例

  • DPOTrainer用于根据人类偏好直接优化语言模型。示例代码如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import DPOTrainer
    import datasets

    # 加载模型和分词器
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    # 加载数据集
    dataset = datasets.load_dataset("trl-lib/ultrafeedback_binarized", split="train")

    # 初始化DPOTrainer
    trainer = DPOTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    )

    # 开始训练
    trainer.train()

GRPO 训练示例