DL——Teacher-Forcing方法

本文主要介绍Transformer和Attention相关内容


整体总结

  • 教师强制(Teacher Forcing) 是一种在训练序列生成模型(包括循环神经网络 RNN、长短期记忆网络 LSTM 等)时使用的方法
  • 其核心思想是在训练过程中强制模型使用真实的目标序列作为输入 ,而非模型自身的预测结果,从而解决序列生成任务中可能出现的误差累积问题
  • 大模型的 SFT 方法就是一种 Teacher Forcing 方法,属于一种 Token-level 的行为克隆

Teacher Forcing 的基本原理

  • 在序列生成任务(如机器翻译、文本生成、语音识别等)中,模型需要根据历史输入和已生成的序列来预测下一个输出
  • 传统训练方式下,若直接使用模型前一步的预测结果作为下一步的输入,一旦某一步预测错误,后续预测可能会因误差累积而“偏离轨道”,导致训练不稳定
  • 教师强制的做法 :在每一步训练中,强制使用真实的目标序列(而非模型上一步的预测值)作为下一步的输入
    • 例如:在机器翻译中,当生成第二个词时,不使用模型预测的第一个词,而是直接使用参考译文中的第一个词,以此类推

Teacher Forcing 的具体流程(以LSTM为例)

  • 假设我们有一个序列生成任务,目标序列为 \( y_1, y_2, y_3, \dots, y_T \),模型输入为 \( x_1, x_2, \dots, x_T \),则训练过程如下:
    • 第一步 :输入 \( x_1 \),模型预测 \( \hat{y}_1 \),与真实值 \( y_1 \) 计算损失并更新参数
    • 第二步不使用 \( \hat{y}_1 \),而是将真实值 \( y_1 \) 作为输入,结合 \( x_2 \),模型预测 \( \hat{y}_2 \),计算损失并更新参数
    • 后续步骤 :重复上述过程,每一步都用真实的 \( y_{t-1} \) 作为当前步的部分输入,直至生成 \( \hat{y}_T \)。

Teacher Forcing 的优缺点分析

优点

  • 训练更稳定 :避免因早期预测错误导致的误差累积,模型更容易收敛
  • 加速收敛 :真实目标序列提供了更准确的监督信号,减少了训练迭代次数
  • 降低训练难度 :尤其适合复杂序列任务(如长文本生成),避免模型“发散”

缺点

  • 训练与推理偏差 :推理时(如实际生成文本)无法获取真实目标序列,需依赖模型自身预测,可能导致“暴露偏差(Exposure Bias)”(即训练时的输入分布与推理时不一致)
  • 缺乏抗噪能力 :模型可能过度依赖真实标签,对预测误差的鲁棒性较差

其他相关训练方法的对比

  • 教师强制 :始终使用真实标签作为输入
  • 为解决教师强制的“暴露偏差”问题,有人提出了 Scheduled Sampling(计划采样) 方法:
    • Scheduled Sampling :在训练初期以高概率使用真实标签,随着训练推进,逐渐增加使用模型预测值的概率,使模型逐步适应推理时的输入分布
    • Scheduled Sampling通过平衡“教师指导”和“自主预测”,减少训练与推理的差异,提升模型泛化能力

代码示例

  • PyTorch实现简单教师强制训练
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    import torch
    import torch.nn as nn
    import torch.optim as optim

    class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
    super(LSTMModel, self).__init__()
    self.lstm = nn.LSTM(input_size, hidden_size)
    self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
    lstm_out, hidden = self.lstm(x, hidden)
    output = self.fc(lstm_out)
    return output, hidden

    def train_with_teacher_forcing(model, input_seq, target_seq, criterion, optimizer):
    model.train()
    hidden = model.init_hidden()
    optimizer.zero_grad()
    loss = 0

    for t in range(target_seq.size(0)):
    output, hidden = model(input_seq[t].unsqueeze(0), hidden)
    input_seq[t+1] = target_seq[t] # 下一时间步的输入使用真实目标值(Teacher Forcing 方法的核心代码)
    loss += criterion(output, target_seq[t].unsqueeze(0))

    loss.backward()
    optimizer.step()
    return loss.item()