本文主要介绍Transformer和Attention相关内容
整体总结
- 教师强制(Teacher Forcing) 是一种在训练序列生成模型(包括循环神经网络 RNN、长短期记忆网络 LSTM 等)时使用的方法
- 其核心思想是在训练过程中强制模型使用真实的目标序列作为输入 ,而非模型自身的预测结果,从而解决序列生成任务中可能出现的误差累积问题
- 大模型的 SFT 方法就是一种 Teacher Forcing 方法,属于一种 Token-level 的行为克隆
Teacher Forcing 的基本原理
- 在序列生成任务(如机器翻译、文本生成、语音识别等)中,模型需要根据历史输入和已生成的序列来预测下一个输出
- 传统训练方式下,若直接使用模型前一步的预测结果作为下一步的输入,一旦某一步预测错误,后续预测可能会因误差累积而“偏离轨道”,导致训练不稳定
- 教师强制的做法 :在每一步训练中,强制使用真实的目标序列(而非模型上一步的预测值)作为下一步的输入
- 例如:在机器翻译中,当生成第二个词时,不使用模型预测的第一个词,而是直接使用参考译文中的第一个词,以此类推
Teacher Forcing 的具体流程(以LSTM为例)
- 假设我们有一个序列生成任务,目标序列为 \( y_1, y_2, y_3, \dots, y_T \),模型输入为 \( x_1, x_2, \dots, x_T \),则训练过程如下:
- 第一步 :输入 \( x_1 \),模型预测 \( \hat{y}_1 \),与真实值 \( y_1 \) 计算损失并更新参数
- 第二步 :不使用 \( \hat{y}_1 \),而是将真实值 \( y_1 \) 作为输入,结合 \( x_2 \),模型预测 \( \hat{y}_2 \),计算损失并更新参数
- 后续步骤 :重复上述过程,每一步都用真实的 \( y_{t-1} \) 作为当前步的部分输入,直至生成 \( \hat{y}_T \)。
Teacher Forcing 的优缺点分析
优点
- 训练更稳定 :避免因早期预测错误导致的误差累积,模型更容易收敛
- 加速收敛 :真实目标序列提供了更准确的监督信号,减少了训练迭代次数
- 降低训练难度 :尤其适合复杂序列任务(如长文本生成),避免模型“发散”
缺点
- 训练与推理偏差 :推理时(如实际生成文本)无法获取真实目标序列,需依赖模型自身预测,可能导致“暴露偏差(Exposure Bias)”(即训练时的输入分布与推理时不一致)
- 缺乏抗噪能力 :模型可能过度依赖真实标签,对预测误差的鲁棒性较差
其他相关训练方法的对比
- 教师强制 :始终使用真实标签作为输入
- 为解决教师强制的“暴露偏差”问题,有人提出了 Scheduled Sampling(计划采样) 方法:
- Scheduled Sampling :在训练初期以高概率使用真实标签,随着训练推进,逐渐增加使用模型预测值的概率,使模型逐步适应推理时的输入分布
- Scheduled Sampling通过平衡“教师指导”和“自主预测”,减少训练与推理的差异,提升模型泛化能力
代码示例
- PyTorch实现简单教师强制训练
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29import torch
import torch.nn as nn
import torch.optim as optim
class LSTMModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(LSTMModel, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x, hidden):
lstm_out, hidden = self.lstm(x, hidden)
output = self.fc(lstm_out)
return output, hidden
def train_with_teacher_forcing(model, input_seq, target_seq, criterion, optimizer):
model.train()
hidden = model.init_hidden()
optimizer.zero_grad()
loss = 0
for t in range(target_seq.size(0)):
output, hidden = model(input_seq[t].unsqueeze(0), hidden)
input_seq[t+1] = target_seq[t] # 下一时间步的输入使用真实目标值(Teacher Forcing 方法的核心代码)
loss += criterion(output, target_seq[t].unsqueeze(0))
loss.backward()
optimizer.step()
return loss.item()