CA——多任务学习总结

多任务学习,multi-task learning

多任务学习的结构

ESMM及扩展

MMoE及相关变体

多任务学习的损失函数权重设置

基于人工经验的人工优化

  • 一般思想是对齐损失函数均值,并按照业务偏好有所倾斜,如果愿意花时间尝试,往往能拿到不错的结果

基于贝叶斯推论的权重优化

  • 基于不确定性的权重设置方法(Uncertainty Weighting)
  • 基本公式
    $$
    L(\mathbf{W}, \sigma_1, \sigma_2,…,\sigma_K) = \sum_{k=1}^K \frac{1}{2\sigma^2}L(\mathbf{W}) + \log \sigma^2
    $$
    • 上述公式可通过推导得出Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics
      • 可以推导,无论是回归问题还是分类问题(也可以是分类和回归问题的混合),都可以按照上面的方法设置损失函数(分类问题证明中会使用到一个近似值,不是严格推导)
      • 推导是在假设
    • \(\sigma\)是可学习的参数,初始设置固定值,然后使用梯度更新学习即可
    • 使用简单,实际使用时效果也确实不错,建议人工调参也可以在先使用该方案拿到权重量级后继续

帕累托最优权重优化

损失函数优化

损失函数归一化

可用于解决由于不同任务损失函数量级差异带来的问题

普通版本

$$
L_{norm} = \frac{L_k(\mathbf{W})}{L_0(\mathbf{W_0})}
$$

  • 使用各个任务自己的第一次输出的损失函数作为基础损失函数,其中\(L_0(\mathbf{W_0}\)为第一次计算loss的到的损失函数
滑动平均版本

$$
L_{base} = \alpha L_{base} + (1-\alpha) L_k \\
L_{norm} = \frac{L_k(\mathbf{W})}{L_{base}} \\
$$

  • 使用滑动平均来记录基础损失函数,该方案可进一步减少由于初始化损失误差过大带来的问题

梯度归一化(GradNorm)

  • 原始论文:GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks
  • 核心思想是对各任务的损失函数进行加权(\(w_i\))求和得到更新共享参数的损失函数
  • “GradNorm”这个名字的由来是因为权重\(w_i\)是与梯度2范数的期望等有关的?
  • 公式中\(w_i\)是各个任务损失函数对共享参数损失函数的权重,该权重初始值为1,在训练过程中逐步更新,每一步最后都保持该权重加和为\(T\)(\(T\)为任务数量,即保证权重均值为1)
  • 问题:文中没有明确各个任务各自的参数如何更新,猜测各自更新即可

最佳实践

  • 一般情况下,根据业务特点,尽量使用类似于ESMM结构
  • 权重设置尝试次序:
    • 对损失函数进行归一化(梯度归一化好像效果一般?)
    • 权重设置时,先使用不确定性权重(Uncertainty Weighting)跑一版,得到基线
    • 在Uncertainty Weighting的基础上,人工可以根据业务需要微调,可以偏向于需要的任务
    • 帕累托最优实现复杂,且不一定有收益
      • 复杂体现在:需要在每个batch上重新求解最优化问题,得到当前的loss权重(用上一个batch的梯度求解这一个batch的最优权重)