DL——重参数化技巧


重参数化解决的问题

  • 问题 :假设用NN建模一个分布,比如正太分布可以表达为 \(\mathcal{N}(\mu_\theta,\sigma_\theta)\),此时如果直接从NN建模的分布中采样,由于采样动作是离散的,那么这个采样结果不包含NN分布的梯度信息的,NN反向传播时无法传播回去,也无法实现对参数 \(\theta\) 的更新
  • 重参数化技巧 :通过一些技巧设计采样方式,使得采样过程可导,让采样结果包含NN分布的梯度信息(即实现既可按照NN分布采样又可回传梯度信息

重参数化的基本思想

  • 不能梯度回传的本质原因是因为采样过程是一种选择动作,这种选择动作本身没有梯度信息,把采样过程挪到计算图之外
  • 用形式来表示,将 \(z \sim f(\theta)\) 构建为形如 \(z = g(\theta, \epsilon), \epsilon \sim p\) 的形式(其中p是与参数无关的某个分布,比如高斯分布)

连续变量分布采样的重参数化

  • 以正太分布为例,原始NN分布采样形式:
    $$ z \sim \mathcal{N}(\mu_\theta,\sigma_\theta) $$
  • 重参数技巧采样:
    $$
    \begin{align}
    \epsilon \sim \mathcal{N}(0,1) \\
    z = \mu_\theta + \sigma_\theta \cdot \epsilon
    \end{align}
    $$

离散变量分布采样的重参数化

以下内容主要参考自重参数化技巧(Gumbel-Softmax)以及其中的回复讨论

原版 softmax(原始问题):

1
2
3
logits = model(x)
probs = softmax(logits)
r = torch.multinomial(probs, num_samples)
  • 采到的 r 都是整数 ID,后面可以用 r 去查 embedding table。缺点是采样这一步把计算图弄断了

Gumbel-Max Trick:

1
2
3
4
5
6
7
8
def sample_gumbel(shape, eps=1e-20, tens_type=torch.FloatTensor):
"""Sample from Gumbel(0, 1)"""
U = Variable(tens_type(*shape).uniform_(), requires_grad=False)
return -torch.log(-torch.log(U + eps) + eps)

logits = model(x)
g = sample_gumbel(logits.size())
r = torch.argmax(logits + g)
  • 采到的 r 都是整数 ID,后面可以用 r 去查 embedding table,计算图连起来了,但 argmax 仍不可导
  • 为什么一定要用sample_gumbel分布而不是其他分布?
    • 因为只有使用gumbel分布采样才能保证与原始softmax后的多项式分布采样完全等价,即 argmax(logits + Gumbel随机变量)与多项式分布采样严格等价 ,相关证明见:漫谈重参数:从正态分布到Gumbel Softmax
  • Gumbel分布的具体定义是什么?
    • 一般Gumbel分布的PDF和CDF:
      $$
      \begin{align}
      \text{PDF}: \quad f(x;\mu,\beta) = e^{-(z+e^{-z})},\quad z=\frac{x-\mu}{\beta} \\
      \text{CDF}: \quad F(x;\mu,\beta) = e^{-e^{-z}}, \quad z=\frac{x-\mu}{\beta}
      \end{align}
      $$
      • \(\mu\) 是位置参数(location parameter)
      • \(\beta\) 是尺度参数(scale parameter)
    • 标准Gumbel分布中, \(\mu=0,\ \beta=1\),此时有 \(z=x\)
      $$
      \begin{align}
      \text{PDF}: \quad f(x;\mu,\beta) = e^{-(x+e^{-x})} \\
      \text{CDF}: \quad F(x;\mu,\beta) = e^{-e^{-x}}
      \end{align}
      $$
  • 在这个场景中,我们使用标准Gumbel分布即可
  • 采样标准Gumbel分布时,可以直接使用逆变换采样(Inverse Transform Sampling)
    • 先按照均匀分布采样: \(u = \mathcal{U}(0,1)\)
    • 对Gumbel分布原始CDF取逆Gumbel分布采样结果: \(z = -ln(-ln(u))\)

Gumbel-Softmax Trick:

1
2
3
logits = model(x)
g = sample_gumbel(logits.size())
r = F.softmax(logits + g)
  • 采到的 r 都是概率分布,后面可以用 r 把 embedding table 里的各个条目加权平均混合起来,假装是一个单词拿去用。虽然计算图可导了,但是训练和推断不一致!训练时模型见到的都是各个 word embedding 的混合,而非独立的 word embedding!推断时则使用的是独立的 word embedding!

Gumbel-Softmax Trick + Straight-Though Estimator:

1
2
3
4
5
logits = model(x)
g = sample_gumbel(logits.size())
r = F.softmax(logits + g)
r_hard = torch.argmax(r)
r = (r_hard - r).detach() + r
  • 采到的 r 都是整数 ID,后面可以用 r 去查 embedding table
  • 前向传播使用 r_hard 获得独立的单词,反向传播使用 r(即 softmax 的结果)的梯度。一切都很完美
  • Straight-Through Estimator 的意思是说,如果你遇到某一层不可导,你就当它的梯度是 identity,直接把梯度漏下去,即假定当前层的梯度为1
  • 实际上此时正向传播和反向传播面对的公式也不一样
    • 正向传播时得到的是r_hard
    • 反向传播时,由于(r_hard - r).detach()使得梯度为0,所以回传的实际是r的反向梯度

argmax动作的梯度回传

  • argmax操作的形式:
    $$
    \begin{align}
    i^* &= \mathop{\arg\max}_i (\vec{x}) \\
    \text{where} \quad \vec{x}=&(x_1, x_2, \cdots, x_n), \quad x_i = f(\theta)_i
    \end{align}
    $$
    • 注:以上argmax的写法不严谨,严谨的是 \(i^* = \mathop{argmax}_i x_i, \ x_i \in \vec{x}\)
  • 近似形式:
    $$
    \begin{equation}
    \mathop{\arg\max}_i (\vec{x}) \approx \sum_{i=1}^n i\times \text{softmax}(\vec{x})_i
    \end{equation}
    $$
  • argmax本质也可以看做一种离散采样,只是没有随机性,该采样选择使得目标值最大的离散变量
  • 详情见:函数光滑化杂谈:不可导函数的可导逼近