Math——KL中得分函数和路径梯度的理解


整体说明

  • 在强化学习(尤其是策略梯度方法)中,当作者想要计算一个关于概率分布的期望的梯度时,通常会遇到两种主要的梯度类型:路径导数得分函数导数

背景:KL 计算梯度时为什么会有两个部分?

  • 假设我们想计算 KL 散度的梯度:
    $$
    \nabla \mathbb{KL}(\pi, \pi_{\text{ref} }) = \nabla \mathbb{E}_{y\sim\pi} \left[ \log \frac{\pi(y)}{\pi_{\text{ref} }(y)} \right]
    $$
    • 这里的问题在于:参数 \(\theta\) 既出现在“采样分布”中(即 \(y \sim \pi_\theta\)),也出现在“被积函数”中(即 \(\log \frac{\pi_\theta(y)}{\pi_{\text{ref} }(y)}\))
  • 这就导致了梯度由两部分组成:
    • 路径导数 :对“被积函数”求导(假设分布固定)
    • 得分函数导数 :对“采样分布”求导(假设函数值固定)

路径导数(Pathwise Derivative)

  • Pathwise Derivative 含义:
    • 路径导数衡量的是:当作者改变策略的参数 \(\theta\) 时,被积函数本身(即 \(\log \frac{\pi(y)}{\pi_{\text{ref} }(y)}\))如何变化
    • 它假设作者采样的样本 \(y\) 是固定的,不考虑采样分布的变化
  • 在 KL 散度中的表现
    • 在论文中,路径导数部分就是直接对 KL 估计值求导:
      $$
      \nabla \overline{\mathrm{KL} } = \nabla \left( \log \frac{\pi(y)}{\pi_{\text{ref} }(y)} \right)
      $$
  • 这是大多数初学者容易想到的方式:定义 loss = KL 估计,然后 .backward(),认为这就是 KL 散度求导的全部
    • 但问题是,这样做忽略了“采样分布也在变”这一事实

得分函数导数(Score Function Derivative)

  • Score Function Derivative 的含义
    • 得分函数导数衡量的是:当作者改变策略的参数 \(\theta\) 时,采样分布 \(\pi\) 的变化对期望值的影响
  • Score Function Derivative 基于以下恒等式:
    $$
    \nabla \mathbb{E}_{y\sim\pi}[f(y)] = \mathbb{E}_{y\sim\pi}[f(y) \nabla \log \pi(y)]
    $$
    • 其中 \(\nabla \log \pi(y)\) 就是得分函数

在 KL 散度中的表现

  • 在 KL 散度中,这部分就是:
    $$
    \overline{\mathrm{KL} } \cdot \nabla \log \pi(y) = \left( \log \frac{\pi(y)}{\pi_{\text{ref} }(y)} \right) \cdot \nabla \log \pi(y)
    $$
    • 这部分反映了:如果作者更大概率采样到某个 \(y\) ,它对 KL 期望的影响
    • 这是很多开源项目缺失的部分

两者结合:完整的 KL 梯度

  • 一个无偏的 KL 梯度估计应该是:
    $$
    \hat{g} = \underbrace{\nabla \overline{\mathrm{KL} } }_{\text{路径导数} } + \underbrace{\overline{\mathrm{KL} } \cdot \nabla \log \pi(y)}_{\text{得分函数导数} }
    $$
    • 路径导数:对被积函数求导
    • 得分函数导数:对采样分布求导

为什么直接对 KL 估计求导是错的?

  • 论文中举了一个极端的例子:对普通估计 \(\overline{\mathrm{KL} }_{\text{vanilla} }\) 求导:
    $$
    \mathbb{E}[\nabla \overline{\mathrm{KL} }_{\text{vanilla} }] = \mathbb{E}[\nabla \log \pi(y)] = 0
    $$
    • 这意味着期望为零,完全没有优化效果
    • 这就是因为路径导数为零,而得分函数导数被忽略了

总结

  • 得分函数导数和路径导数的总结如下:
    类型 含义 在 KL 梯度中的形式
    路径导数 对 loss 函数本身求导 \(\nabla \overline{\mathrm{KL} }\)
    得分函数导数 对采样分布求导 \(\overline{\mathrm{KL} } \cdot \nabla \log \pi(y)\)
    完整梯度 两者之和 \(\nabla \overline{\mathrm{KL} } + \overline{\mathrm{KL} } \cdot \nabla \log \pi(y)\)
  • 如果在实际编码中想实现正确的 KL 梯度,不能只写 kl.backward(),而应该用类似:
    1
    2
    3
    kl = (log_probs - ref_log_probs).mean()
    loss = (kl.detach() * log_probs).mean() # 得分函数部分,路径导数为0,可以不加
    # 或者直接用 squared estimate

附录:得分函数的定义和名字来源

  • 得分函数的定义:在统计学中,对于一个参数化的概率模型 \(p(y; \theta)\)(即论文中的 \(\pi(y)\)),得分函数定义为:
    $$
    s(\theta; y) = \nabla_{\theta} \log p(y; \theta)
    $$
    • 也就是说,它是对数似然函数对参数 \(\theta\) 的一阶导数
  • 为什么叫“得分”?这个名字来源于似然函数的最大化过程
    • 在最大似然估计中,作者希望找到 \(\theta\) 使得 \(p(y|\theta)\) 最大
    • 对数似然 \(\log p(y|\theta)\) 的最大值点满足:
      $$
      \nabla_{\theta} \log p(y|\theta) = 0
      $$
    • 这个方程叫做得分方程
    • 所以,\(\nabla_{\theta} \log p(y|\theta)\) 就是“得分”本身
    • 类比一下:
      • 有一个模型,数据 \(y\) 是证据
      • 每个参数 \(\theta\) 的 “得分” 就是它能让数据出现得多好(的梯度)
      • 得分越高,说明这个参数方向越能解释数据
  • 直观理解:可以把 \(\log p(y|\theta)\) 想象成模型对数据 \(y\) 的“满意度”(对数概率)
    • 如果 \(\theta\) 稍微变化,满意度变化大,说明这个数据点对这个参数敏感
    • 这个敏感度就是“得分”
  • 得分函数有两个非常重要的性质:
    • 1)期望为零:
      $$
      \mathbb{E}_{y \sim p(y|\theta)} [\nabla_{\theta} \log p(y|\theta)] = 0
      $$
      • 这是论文中多次用到的重要性质,也是策略梯度定理的基础,证明详情见本文附录
    • 2)方差是 Fisher 信息量
      $$
      \mathrm{Var}[\nabla_{\theta} \log p(y|\theta)] = \mathcal{I}(\theta)
      $$
      • 其中 \(\mathcal{I}(\theta)\) 是 Fisher 信息矩阵,衡量了参数 \(\theta\) 的估计精度

得分函数在强化学习中的角色

  • 在强化学习和论文的上下文中,下面的式子被叫做得分函数
    $$
    \nabla_{\theta} \log \pi_{\theta}(a|s)
    $$
  • 原因与上述得分函数名字来源完全相同:
    • \(\pi_{\theta}(a|s)\) 是一个概率模型(策略)
    • 它的对数梯度就是统计学意义上的得分函数
  • 这就是为什么策略梯度定理长这样:
    $$
    \nabla_{\theta} J(\theta) = \mathbb{E}_{\pi} [Q(s,a) \nabla_{\theta} \log \pi_{\theta}(a|s)]
    $$
    • \(Q(s,a)\) 是“权重”
    • \(\nabla_{\theta} \log \pi_{\theta}(a|s)\) 是“得分函数”

KL 梯度 为什么包含得分函数

  • 原始论文 On a few pitfalls in KL divergence gradient estimation for RL, 20250611, Meta FAIR Appendix B 中给出了完整的 KL 梯度:
    $$
    \nabla \mathbb{KL} = \mathbb{E}_{y\sim\pi}[\nabla \overline{\mathrm{KL} }] + \mathbb{E}_{y\sim\pi}[\overline{\mathrm{KL} } \cdot \nabla \log \pi(y)]
    $$
    • 第二项中的 \(\nabla \log \pi(y)\) 正是得分函数
  • 它的作用:
    • 当 \(\overline{\mathrm{KL} }\) 很大时(策略偏离参考策略很远)
    • 得分函数会告诉作者应该往哪个方向调整参数 ,使得采样分布变化,从而降低 KL

最后:得分函数与似然函数等的区别

  • 想象我们在玩一个猜数字游戏:
    • 似然函数 :我们猜的数字离正确答案有多近
    • 对数似然 :把距离转换成“得分”
    • 得分函数 :告诉我们应该往大猜还是往小猜,猜多少
  • 在 RL 中:
    • 策略 :我们猜数字的规则
    • 得分函数 :告诉我们应该怎么调整规则才能更可能猜到正确答案

附录:得分函数期望为 0 的性质证明

  • 证明目标:
    $$
    \mathbb{E}_{a \sim \pi_\theta(\cdot|s)} \left[ \nabla_\theta \log \pi_\theta(a|s) \right] = 0
    $$

预备知识

  • 对于给定的状态 \(s\),策略 \(\pi_\theta(a|s)\) 是一个关于动作 \(a\) 的概率分布(离散或连续),满足:
    $$
    \sum_a \pi_\theta(a|s) = 1
    $$

  • $$
    \int \pi_\theta(a|s) da = 1
    $$
  • 接下来对 \(\theta\) 求梯度,希望证明 score function 的期望为零

证明

  • 对于离散动作空间:
    $$
    \sum_a \pi_\theta(a|s) = 1
    $$
  • 两边对 \(\theta\) 求梯度(假设可交换求和与求导):
    $$
    \nabla_\theta \sum_a \pi_\theta(a|s) = \nabla_\theta (1) = 0
    $$
  • 即:
    $$
    \sum_a \nabla_\theta \pi_\theta(a|s) = 0
    $$
  • 现在,利用对数导数技巧:
    $$
    \nabla_\theta \pi_\theta(a|s) = \pi_\theta(a|s) \cdot \nabla_\theta \log \pi_\theta(a|s)
    $$
  • 代入上式:
    $$
    \sum_a \pi_\theta(a|s) \cdot \nabla_\theta \log \pi_\theta(a|s) = 0
    $$
  • 这正是期望的定义:
    $$
    \mathbb{E}_{a \sim \pi_\theta(\cdot|s)} \left[ \nabla_\theta \log \pi_\theta(a|s) \right] = 0
    $$
  • 注:连续情况证明同上

得分函数期望为 0 的直观理解和策略梯度法中的作用

  • 这个性质本质上是因为概率分布的归一化条件:所有概率之和为 1,梯度必须保持这个约束,所以 score function 的期望为零
  • 在策略梯度定理的推导中,这个性质被用来证明减去任意与动作无关的基线函数不会引入偏差
    $$
    \mathbb{E} \left[ \nabla_\theta \log \pi_\theta(a|s) \cdot b(s) \right] = \mathbb{E}_{s} \left[ b(s) \cdot \mathbb{E}_{a \sim \pi_\theta(\cdot|s)} [ \nabla_\theta \log \pi_\theta(a|s) ] \right] = 0
    $$