Math——KL散度的近似估计

  • 参考链接:
    • PPO 一作的博客:Approximating KL Divergence, 2020,在博客中解释了作者使用一些近似方法,本文主要参考该博客的内容,有一些自己的总结和思考

KL散度的定义

  • KL散度定义为:
    $$
    D_\text{KL}(q||p) = \sum_x q(x) \log \frac{q(x)}{p(x)} = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right]
    $$
    • 很多地方也常常写为 \(D_\text{KL}(q,p)\) 或 \(KL[q,p]\)
    • 注意:这里使用的 \(p,q\) 顺序可能和常规的文章可能不同

一些假设

  • 作者假设可以计算任意 \(x\) 的概率(或概率密度) \(p(x)\) 和 \(q(x)\),但无法解析地计算关于 \(x\) 的和(即期望)
  • 现实场景中,无法解析计算的原因是有:
    • 精确计算需要过多计算资源或内存
    • 不存在闭式表达式
    • 为了简化代码,仅存储对数概率而非整个分布。如果KL散度仅用作诊断(如强化学习中常见的情况),这是一个合理的选择

好的估计量应该是怎样的?

  • 好的估计量应具有无偏性(期望相同)和低方差
  • 比如对于来自 \(q\) 的样本, \(\log \frac{q(x)}{p(x)}\) 就是一个无偏估计,但它的方差很高
    • 问题:为什么说这个式子方差高?
    • 回答:一个直观的理解是 因为对于一半样本它是负值,而KL散度总是正值,后续会通过实验验证
    • 注意 KL 散度的积分权重和分子是相同的(这是由其含义和非负性决定的),若对换分子分母,得到的是 KL 的负数值

一些前置定义

  • 定义 一个 比例 \(r\)
    $$r = \frac{p(x)}{q(x)}$$
    • 特别注意 :这里的定义与 KL 散度中括号内的分子分母相反,对应到原始 KL 散度中,值为:
      $$
      \begin{align}
      D_\text{KL}(q||p) &= \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] \\
      &= \mathbb{E}_{x \sim q} \left[ - \log \frac{p(x)}{q(x)} \right] \\
      &= \mathbb{E}_{x \sim q} \left[ - \log r \right]
      \end{align}
      $$

朴素估计量(\(k_1\))

  • 朴素估计量(\(k_1\)) 表达式:
    $$
    k_1 = -\log r = \log \frac{q(x)}{p(x)}
    $$
  • 无偏 : \(\mathbb{E}[k_1] = D_\text{KL}(q||p)\)
  • 高方差 :因 \(\log r\) 在 \(r>1\) 时为负, \(r<1\) 时为正,导致样本间波动大
  • 适用场景 :理论分析或对无偏性要求严格的场景,但实际应用中可能因高方差不稳定

二次估计量(\(k_2\))

  • 二次估计量(\(k_2\)) 表达式:
    $$
    k_2 = \frac{1}{2} (\log r)^2
    $$
  • 有偏(低偏差) : \(\mathbb{E}[k_2] \approx D_\text{KL}(q||p) + O(\theta^3)\),当 \(p \approx q\) 时偏差极小(如实验中的0.2%)
  • 低方差 :因平方项强制为正,减少了样本间的波动
  • 适用场景 : \(p\) 和 \(q\) 接近时的高效估计,适合作为诊断指标(如强化学习中的策略评估)
    • 理解:这里的诊断主要是指仅仅作为判断条件,而不是作为损失函数的主要优化目标

为什么说估计量 \(k_2\) 具有低偏差?

  • 它的期望是一个 \(f\)-散度。 \(f\)-散度定义为:
    $$D_f(p,q) = \mathbb{E}_{x \sim q} \left[ f\left( \frac{p(x)}{q(x)} \right) \right]$$
    • 其中 \(f\) 是凸函数
    • KL散度和其他许多著名的概率距离都是 \(f\)-散度, KL 散度中 \(f(\cdot) = -\log (\cdot)\)(\(\log(\cdot)\) 是凹函数,凹函数取负号就是凸函数)
  • 当 \(q\) 接近 \(p\) 时,所有具有可微 \(f\) 的 \(f\)-散度在二阶近似下都类似于KL散度。具体来说,对于参数化分布 \(p_\theta\) :
    $$
    D_f(p_0, p_\theta) = \frac{f’’(1)}{2} \theta^T F \theta + O(\theta^3)
    $$
    • 其中 \(F\) 是 \(p_\theta\) 在 \(p_\theta = p_0\) 处评估的Fisher信息矩阵
    • \(\mathbb{E}_q[k_2] = \mathbb{E}_q \left[ \frac{1}{2}(\log r)^2 \right]\) 对应 \(f(x) = \frac{1}{2}(\log x)^2\) 的 \(f\)-散度,而 \(D_\text{KL}(q||p)\) 对应 \(f(x) = -\log x\)。容易验证两者都有 \(f’’(1)=1\),因此对于 \(p \approx q\),两者看起来像相同的二次距离函数
  • \(k_2\) 的取值总是大于等于0,可以通过求导证明:当 \(x>0\) 时,有 \(x - \log x - 1 \ge 0\) 恒成立,最小值在 \(x=1,y=0\)处,其函数图像如下:

Bregman散度估计量(\(k_3\))

  • Bregman散度估计量(\(k_3\)) 表达式:
    $$
    k_3 = (r - 1) - \log r
    $$
  • 无偏 : \(\mathbb{E}[k_3] = D_\text{KL}(q||p)\)
  • 最低方差 :结合了 \(r-1\) 的线性项与 \(\log r\) 的校正,进一步降低波动
  • 适用场景 :对无偏性和低方差同时要求的场景(如精确的梯度估计或敏感的参数优化)

Bregman散度估计量(\(k_3\)) 是怎么设计出来的?

  • 我们的目标:是找到一个无偏低方差的KL散度估计量
  • 降低方差的通用方法是使用控制变量:取 \(k_1\) 并加上一个期望为零但与 \(k_1\) 负相关的量
  • 幸运的是,作者发现 \(\frac{p(x)}{q(x)} - 1 = r - 1\) 是一个期望为0的量),于是,对于任意 \(\lambda\),下面的表达式都是 \(D_\text{KL}(q||p)\) 的无偏估计量:
    $$-\log r + \lambda(r - 1)$$
    • 注:作者可以通过计算最小化这个估计量的方差来求解 \(\lambda\)。但不幸的是,作者得到的表达式依赖于 \(p\) 和 \(q\),并且难以解析计算
  • 所以,作者使用更简单的策略选择一个好的 \(\lambda\)
    • 作者注意到由于 \(\log\) 是凹函数,有 \(\log(x) \leq x - 1\)
    • 因此,如果作者设 \(\lambda=1\),上述表达式保证为正。它测量了 \(\log(x)\) 与其切线之间的垂直距离
    • 于是作者提出了估计量 \(k_3 = (r - 1) - \log r\)

更多扩展和思考

  • 这种通过观察凸函数与其切平面之间的差异来测量距离的思想出现在许多地方。它被称为Bregman散度 ,具有许多优美的性质
  • 可以推广上述思想,为任何 \(f\)-散度得到一个良好的、总是正的估计量
  • 另一个KL散度是 \(KL[p,q]\)(注意这里 \(p\) 和 \(q\) 交换了位置 ,与 \(D_\text{KL}(q||p)\) 不同)
  • 由于 \(f\) 是凸函数,且 \(\mathbb{E}_q[r] = 1\),以下表达式是 \(f\)-散度的估计量:
    $$f(r) - f’(1)(r - 1)$$
  • 这个量总是正的,因为它是 \(f\) 在 \(r=1\) 处与切线的距离,而凸函数位于其切线上方。现在 \(KL[p,q]\) 对应 \(f(x) = x \log x\),它有 \(f’(1) = 1\),于是有估计量 \(r \log r - (r - 1)\)
  • 总结一下,作者提出以下估计量(对于样本 \(x \sim q\),且 \(r = \frac{p(x)}{q(x)}\)):
    $$
    \begin{align}
    D_\text{KL}(p||q): &\quad r \log r - (r - 1)\\
    D_\text{KL}(q||p): &\quad (r - 1) - \log r
    \end{align}
    $$
    • 注意上面的 KL 散度先后顺序,KL 散度不是对称的

Experiments

  • 现在让作者比较三个 \(D_\text{KL}(q||p)\) 估计量的偏差和方差
  • 定义 \(q = \mathcal{N}(0,1)\), \(p = \mathcal{N}(0.1,1)\)(此时真实的KL散度为0.005)
    统计量/真实值 \(k_1\) \(k_2\) \(k_3\)
    偏差(期望与真实值差) 0 0.002 0
    标准差 20 1.42 1.42
    • 注意到 \(k_2\) 的偏差在这里极低:仅为0.2%
    • 以上标准差是使用 KL 单独作归一化后的(除以 KL 散度)
  • 定义 \(p = \mathcal{N}(1,1)\) (此时真实KL散度为0.5)
    统计量/真实值 \(k_1\) \(k_2\) \(k_3\)
    偏差(期望与真实值差) 0 0.25 0
    标准差 2 1.73 1.7
    • 这里 \(k_2\) 的偏差大得多
    • \(k_3\) 在保持无偏的同时甚至比 \(k_2\) 具有更低的标准差,因此它似乎是一个严格更好的估计量
  • 作者给出的上述实验的代码:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    import torch.distributions as dis
    p = dis.Normal(loc=0, scale=1)
    q = dis.Normal(loc=0.1, scale=1)
    x = q.sample(sample_shape=(10_000_000,))
    truekl = dis.kl_divergence(p, q)
    print("true", truekl)
    logr = p.log_prob(x) - q.log_prob(x)
    k1 = -logr
    k2 = logr ** 2 / 2
    k3 = (logr.exp() - 1) - logr
    for k in (k1, k2, k3):
    print((k.mean() - truekl) / truekl, k.std() / truekl)

一些总结和思考

  • 三种估计量的对比如下:
    估计量 无偏性 方差 偏差(当 \(p \neq q\)) 计算复杂度 适用场景
    \(k_1\) 无偏 0 理论分析,高精度需求
    \(k_2\) 有偏(低) 中低 随 \(KL\) 增大而增加 \(p \approx q\) 时的一些近似判别
    \(k_3\) 无偏 最低 0 中(需算 \(r\)) 高精度需求(如优化算法)
  • 方差排序 : \(k_3 < k_2 < k_1\)
  • 偏差权衡
    • 若 \(p \approx q\), \(k_2\) 的偏差可忽略,且计算简单
    • 若 \(KL\) 较大(如 \(p,q\) 差异显著), \(k_3\) 是唯一同时满足无偏和低方差的选项
  • 实践建议
    • 使用 \(k_3\) 作为默认选择(尤其对敏感任务)
    • 在快速迭代或 \(p \approx q\) 时,可用 \(k_2\) 作为轻量替代

不同估计值的函数图像

  • 函数图像如下:
  • 生成代码如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    import numpy as np
    import matplotlib.pyplot as plt

    # Define the functions
    def k1(r):
    return -np.log(r)

    def k2(r):
    return 0.5 * (np.log(r))**2

    def k3(r):
    return (r - 1) - np.log(r)

    x_max = 10

    # Generate r values (avoid r=0 for log)
    r = np.linspace(0.01, x_max, 500)

    # Calculate function values
    k1_vals = k1(r)
    k2_vals = k2(r)
    k3_vals = k3(r)

    # Create plot
    plt.figure(figsize=(10, 6))
    plt.plot(r, k1_vals, label='k1(r) = -log(r)', linewidth=2)
    plt.plot(r, k2_vals, label='k2(r) = 0.5*(log(r))^2', linewidth=2)
    plt.plot(r, k3_vals, label='k3(r) = (r-1) - log(r)', linewidth=2)

    # Add special points and lines
    plt.axvline(1, color='gray', linestyle='--', alpha=0.5)
    plt.plot(1, 0, 'ro') # All functions equal 0 at r=1

    # Plot formatting
    plt.title('KL Divergence Estimators as Functions of r=p(x)/q(x)', fontsize=14)
    plt.xlabel('r = p(x)/q(x)', fontsize=12)
    plt.ylabel('Estimator Value', fontsize=12)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.xlim(0, x_max)
    plt.ylim(-3, 10)

    # Show plot
    plt.tight_layout()
    plt.show()

附录:凹函数和凸函数简单介绍

凸函数(Convex Function)

  • 若函数 \( f \) 的定义域为某个凸集(如区间),且对定义域内任意两点 \( x_1, x_2 \) 和任意 \( \lambda \in [0, 1] \),满足:
    $$
    f(\lambda x_1 + (1-\lambda) x_2) \leq \lambda f(x_1) + (1-\lambda) f(x_2)
    $$
  • 则称 \( f \) 为凸函数
  • 直观理解:函数图像上任意两点间的线段始终位于函数图像上方(或重合),形如“碗状”或“线性”
  • 举例: \( f(x) = x^2 \)、\( f(x) = e^x \))

凹函数(Concave Function)

  • 若函数 \( f \) 的定义域为凸集,且对任意两点 \( x_1, x_2 \) 和 \( \lambda \in [0, 1] \),满足:
    $$
    f(\lambda x_1 + (1-\lambda) x_2) \geq \lambda f(x_1) + (1-\lambda) f(x_2)
    $$
  • 则称 \( f \) 为凹函数
  • 直观理解:函数图像上任意两点间的线段始终位于函数图像下方(或重合),形如“拱形”或“线性”
  • 举例:\( f(x) = -x^2 \)、\( f(x) = \ln x \)(定义域 \( x > 0 \))

补充说明

  1. 线性函数既是凸的也是凹的(因不等式取等号)
  2. 凹凸性反转 :若 \( f \) 是凸函数,则 \( -f \) 是凹函数,反之亦然
  3. 严格凸/凹 :当不等式在 \( x_1 \neq x_2 \) 且 \( \lambda \in (0,1) \) 时严格成立(如 \( < \) 或 \( > \)),则称函数为严格凸或严格凹