重参数化解决的问题
- 问题 :假设用NN建模一个分布,比如正太分布可以表达为 \(\mathcal{N}(\mu_\theta,\sigma_\theta)\),此时如果直接从NN建模的分布中采样,由于采样动作是离散的,那么这个采样结果不包含NN分布的梯度信息的,NN反向传播时无法传播回去,也无法实现对参数 \(\theta\) 的更新
- 重参数化技巧 :通过一些技巧设计采样方式,使得采样过程可导,让采样结果包含NN分布的梯度信息(即实现既可按照NN分布采样 ,又可回传梯度信息)
重参数化的基本思想
- 不能梯度回传的本质原因是因为采样过程是一种选择动作,这种选择动作本身没有梯度信息,把采样过程挪到计算图之外
- 用形式来表示,将 \(z \sim f(\theta)\) 构建为形如 \(z = g(\theta, \epsilon), \epsilon \sim p\) 的形式(其中p是与参数无关的某个分布,比如高斯分布)
连续变量分布采样的重参数化
- 以正太分布为例,原始NN分布采样形式:
$$ z \sim \mathcal{N}(\mu_\theta,\sigma_\theta) $$ - 重参数技巧采样:
$$
\begin{align}
\epsilon \sim \mathcal{N}(0,1) \\
z = \mu_\theta + \sigma_\theta \cdot \epsilon
\end{align}
$$
离散变量分布采样的重参数化
以下内容主要参考自重参数化技巧(Gumbel-Softmax)以及其中的回复讨论
原版 softmax(原始问题):
1 | logits = model(x) |
- 采到的 r 都是整数 ID,后面可以用 r 去查 embedding table。缺点是采样这一步把计算图弄断了
Gumbel-Max Trick:
1 | def sample_gumbel(shape, eps=1e-20, tens_type=torch.FloatTensor): |
- 采到的 r 都是整数 ID,后面可以用 r 去查 embedding table,计算图连起来了,但 argmax 仍不可导
- 为什么一定要用sample_gumbel分布而不是其他分布?
- 因为只有使用gumbel分布采样才能保证与原始softmax后的多项式分布采样完全等价,即 argmax(logits + Gumbel随机变量)与多项式分布采样严格等价 ,相关证明见:漫谈重参数:从正态分布到Gumbel Softmax
- Gumbel分布的具体定义是什么?
- 一般Gumbel分布的PDF和CDF:
$$
\begin{align}
\text{PDF}: \quad f(x;\mu,\beta) = e^{-(z+e^{-z})},\quad z=\frac{x-\mu}{\beta} \\
\text{CDF}: \quad F(x;\mu,\beta) = e^{-e^{-z}}, \quad z=\frac{x-\mu}{\beta}
\end{align}
$$- \(\mu\) 是位置参数(location parameter)
- \(\beta\) 是尺度参数(scale parameter)
- 标准Gumbel分布中, \(\mu=0,\ \beta=1\),此时有 \(z=x\)
$$
\begin{align}
\text{PDF}: \quad f(x;\mu,\beta) = e^{-(x+e^{-x})} \\
\text{CDF}: \quad F(x;\mu,\beta) = e^{-e^{-x}}
\end{align}
$$
- 一般Gumbel分布的PDF和CDF:
- 在这个场景中,我们使用标准Gumbel分布即可
- 采样标准Gumbel分布时,可以直接使用逆变换采样(Inverse Transform Sampling) :
- 先按照均匀分布采样: \(u = \mathcal{U}(0,1)\)
- 对Gumbel分布原始CDF取逆Gumbel分布采样结果: \(z = -ln(-ln(u))\)
Gumbel-Softmax Trick:
1 | logits = model(x) |
- 采到的 r 都是概率分布,后面可以用 r 把 embedding table 里的各个条目加权平均混合起来,假装是一个单词拿去用。虽然计算图可导了,但是训练和推断不一致!训练时模型见到的都是各个 word embedding 的混合,而非独立的 word embedding!推断时则使用的是独立的 word embedding!
Gumbel-Softmax Trick + Straight-Though Estimator:
1 | logits = model(x) |
- 采到的 r 都是整数 ID,后面可以用 r 去查 embedding table
- 前向传播使用 r_hard 获得独立的单词,反向传播使用 r(即 softmax 的结果)的梯度。一切都很完美
- Straight-Through Estimator 的意思是说,如果你遇到某一层不可导,你就当它的梯度是 identity,直接把梯度漏下去,即假定当前层的梯度为1
- 实际上此时正向传播和反向传播面对的公式也不一样
- 正向传播时得到的是
r_hard - 反向传播时,由于
(r_hard - r).detach()使得梯度为0,所以回传的实际是r的反向梯度
- 正向传播时得到的是
argmax动作的梯度回传
- argmax操作的形式:
$$
\begin{align}
i^* &= \mathop{\arg\max}_i (\vec{x}) \\
\text{where} \quad \vec{x}=&(x_1, x_2, \cdots, x_n), \quad x_i = f(\theta)_i
\end{align}
$$- 注:以上argmax的写法不严谨,严谨的是 \(i^* = \mathop{argmax}_i x_i, \ x_i \in \vec{x}\)
- 近似形式:
$$
\begin{equation}
\mathop{\arg\max}_i (\vec{x}) \approx \sum_{i=1}^n i\times \text{softmax}(\vec{x})_i
\end{equation}
$$ - argmax本质也可以看做一种离散采样,只是没有随机性,该采样选择使得目标值最大的离散变量
- 详情见:函数光滑化杂谈:不可导函数的可导逼近