# 变分推断

TIP

这一节里有的公式似乎没有区分积分和求和,那就不区分了吧反正意思对了就行 2333。

EM 算法中有:

logp(x)=ELBO(q,x)+KL(q(z)p(zx))(1)\log p(x) = \text{ELBO}(q,x) + \text{KL}(q(z) \| p(z \mid x)) \tag{1}

在 E 步中,我们需要找到一个 q(z)=p(zx)q(z) = p(z | x),从而使得 logp(x)=ELBO(q,x)\log p(x) = \text{ELBO}(q,x)。这里假设 p(zx)p(z | x) 是可以计算的,但这个假设有可能是不成立的,后验可能是 intractable 的。


这里解释一下为啥 intractable。由贝叶斯定理:

p(zx)=p(xz)p(z)p(x)p(z \mid x) = \frac{p(x \mid z)p(z)}{p(x)}

也就是说这里需要算输入数据的分布 p(x)p(x)

p(x)=zp(xz)p(z)dzp(x)=\int_z p(x \mid z)p(z) dz

而在很多情况下这是没法算的(不然还要 GAN 这种调参调到死的玩意儿干啥)。


因此有了变分推断(Variational Inference),也称变分贝叶斯(Variational Bayesian),变分推断可以看作 EM 算法的扩展版,主要处理不能精确求出 p(zx)p(z \mid x) 的情况。

变分推断的思想是,在 E 步中,寻找一个简单分布 q(z)q^\ast(z) 来近似 p(zx)p(z \mid x)

q(z)=argminq(z)QKL(q(z)p(zx))(2)q^\ast(z) = \arg \min_{q(z) \in Q} \text{KL}(q(z) \| p(z \mid x)) \tag{2}

其中 QQ 为候选的概率分布族。当 KL(q(z)p(zx))\text{KL}(q(z) \| p(z \mid x)) 无限接近于 0 时,q(z)q^\ast(z)p(zx)p(z \mid x) 就无限接近。但如刚刚所说,p(zx)p(z \mid x) 难以直接计算,因此我们不能直接优化这个 KL 散度。

结合式 (1) 和式 (2),有:

q(z)=argminq(z)Q(logp(x)ELBO(q,x))=argmaxq(z)QELBO(q,x)\begin{aligned} q^\ast(z) &= \arg \min_{q(z) \in Q} (\log p(x) - \text{ELBO}(q,x)) \\ &= \arg \max_{q(z) \in Q} \text{ELBO}(q,x) \end{aligned}

所以最小化 KL 散度被转化为了最大化 ELBO。这里 ELBO 是一个以函数 qq 为自变量的函数,即泛函。

这里的公式都没有写参数 θ\theta,因为变分推断中的参数都是随机变量,可以直接算进隐变量 zz 里。

# ELBO

上一节中算出 KL(q(z)p(zx))\text{KL}(q(z) \| p(z \mid x)) 为:

KL(q(z)p(zx))=zq(z)logp(zx)q(z)=Eq(z)[logq(z)]Eq(z)[logp(zx)]=Eq(z)[logq(z)]Eq(z)[logp(x,z)p(x)]=Eq(z)[logq(z)]Eq(z)[logp(x,z)]+logp(x)\begin{aligned} \text{KL}(q(z) \| p(z \mid x)) &= - \sum_z q(z) \log \frac{p(z \mid x)}{q(z)} \\ &= \mathbb{E}_{q(z)}[\log q(z)] - \mathbb{E}_{q(z)}[\log p(z \mid x)] \\ &= \mathbb{E}_{q(z)}[\log q(z)] - \mathbb{E}_{q(z)}[\log \frac{p(x,z)}{p(x)}] \\ &= \mathbb{E}_{q(z)}[\log q(z)] - \mathbb{E}_{q(z)}[\log p(x,z)] + \log p(x) \end{aligned}

则 ELBO 为:

ELBO(q,x)=logp(x)KL(q(z)p(zx))=Eq(z)[logp(x,z)]Eq(z)[logq(z)]\begin{aligned} \text{ELBO}(q,x) &= \log p(x) - \text{KL}(q(z) \| p(z \mid x)) \\ &= \mathbb{E}_{q(z)}[\log p(x,z)] - \mathbb{E}_{q(z)}[\log q(z)] \end{aligned}

EM 算法中:

θt+1=argmaxθzpθt(zx)logpθ(x,z)=argmaxθEpθt(zx)[logpθ(x,z)]\begin{aligned} \theta_{t+1} &= \arg \max_\theta \sum_z p_{\theta_t}(z \mid x) \log p_\theta(x,z) \\ &= \arg \max_\theta \mathbb{E}_{p_{\theta_t}(z \mid x)}[\log p_\theta(x,z)] \end{aligned}

可以看到变分推断中的 ELBO 相比 EM 算法中的 ELBO 大概多了 Eq(z)[logq(z)]- \mathbb{E}_{q(z)}[\log q(z)] 这一项。这是因为 EM 算法中 q(z)q(z) 是常数项,而在变分推断中并不是。

btw,Eq(z)[logq(z)]- \mathbb{E}_{q(z)}[\log q(z)] 就是 q(z)q(z) 的熵,可以表示为 H[q(z)]H[q(z)]

# 平均场分布族

候选分布族 QQ 的复杂性决定了优化问题的复杂性。我们选 QQ 的时候可以选我们知道到的、简单的、最好是独立同分布的概率分布。通常会选平均场(mean-field)分布族,即 zz 可以分拆为多组相互独立的变量,概率密度 q(z)q(z) 可以分解为:

q(z)=m=1Mqm(zm)q(z) = \prod_{m=1}^M q_m(z_m)

其中 zmz_m 是隐变量的子集,可以是单变量,也可以是一组多元变量。

那么 ELBO(q,x)\text{ELBO}(q,x) 可以写为:

ELBO(q,x)=q(z)logp(x,z)q(z)dz=q(z)(logp(x,z)logq(z))dz=m=1Mqm(zm)logp(x,z)dzpart1m=1Mqm(zm)m=1Mlogqm(zm)dzpart2\begin{aligned} \text{ELBO}(q,x) &= \int q(z) \log \frac{p(x,z)}{q(z)} dz \\ &= \int q(z) (\log p(x,z) - \log q(z)) dz \\ &= \underbrace{\int \prod_{m=1}^M q_m(z_m) \log p(x,z) dz}_{\text{part1}} - \underbrace{\int \prod_{m=1}^M q_m(z_m) \sum_{m=1}^M \log q_m(z_m) dz}_{\text{part2}} \\ \end{aligned}

对于 part1:

part1=z1 ⁣zMm=1Mqm(zm)logp(x,z)dz1dzM\text{part1} = \int_{z_1} \dots \int_{z_M} \prod_{m=1}^M q_m(z_m) \log p(x,z) dz_1 \dots dz_M

如果只对隐变量的某个子集 zjz_j 感兴趣:

part1=qj(zj)( ⁣zmjmjMqm(zm)logp(x,z)mjMdzm)dzj=qj(zj)( ⁣zmjlogp(x,z)mjMqm(zm)dzm)dzj=qj(zj)(mjMqm(zm)logp(x,z)dzm)dzj\begin{aligned} \text{part1} &= \int q_j(z_j) \left ( \int \dots \int_{z_{m \ne j}} \prod_{m \ne j}^M q_m(z_m) \log p(x,z) \prod_{m \ne j}^M dz_m \right ) dz_j \\ &= \int q_j(z_j) \left ( \int \dots \int_{z_{m \ne j}} \log p(x,z) \prod_{m \ne j}^M q_m(z_m) dz_m \right ) dz_j \\ &= \int q_j(z_j) \left ( \int \prod_{m \ne j}^M q_m(z_m) \log p(x,z) dz_m \right ) dz_j \end{aligned}

再令:

logp(x,zj)=mjqm(zm)logp(x,z)dzm(3)\log \overline{p}(x, z_j) = \int \prod_{m \ne j} q_m(z_m) \log p(x,z) dz_m \tag{3}

p(x,zj)\overline{p}(x, z_j) 可以看作一个关于 zjz_j 的未归一化的分布。

最终有:

part1=qj(zj)logp(x,zj)dzj\text{part1} = \int q_j(z_j) \log \overline{p}(x, z_j) dz_j

对于 part2:

part2=z1 ⁣zMm=1Mqm(zm)m=1Mlogqm(zm)dz1dzM=z1 ⁣zM[logq1(z1)+logqM(zM)](q1(z1)qM(zM))dz1dzM=z1q1(z1)logq1(z1)dz1++zMqM(zM)logqM(zM)dzM=m=1M(zmqm(zm)logqm(zm)dzm)\begin{aligned} \text{part2} &= \int_{z_1} \dots \int_{z_M} \prod_{m=1}^M q_m(z_m) \sum_{m=1}^M \log q_m(z_m) dz_1 \dots dz_M \\ &= \int_{z_1} \dots \int_{z_M} [\log q_1(z_1) + \dots \log q_M(z_M)] (q_1(z_1) \dots q_M(z_M)) dz_1 \dots dz_M \\ &= \int_{z_1} q_1(z_1) \log q_1(z_1) dz_1 + \dots + \int_{z_M} q_M(z_M) \log q_M(z_M) dz_M \\ &= \sum_{m=1}^M \left( \int_{z_m} q_m(z_m) \log q_m(z_m) dz_m \right ) \end{aligned}

同样,如果只对隐变量的某个子集 zjz_j 感兴趣:

part2=qj(zj)logqj(zj)dzj+const\text{part2} = \int q_j(z_j) \log q_j(z_j) dz_j + \text{const}

其中,const 为一个常数,即所有与 qj(zj)q_j(z_j) 无关的项。

现在 ELBO(q,x)\text{ELBO}(q,x) 可以写为:

ELBO(q,x)=part1part2=qj(zj)logp(x,zj)dzjqj(zj)logqj(zj)dzj+const=qj(zj)logp(x,zj)qj(zj)dzj+constKL(qj(zj)p(x,zj))+const\begin{aligned} \text{ELBO}(q,x) &= \text{part1} - \text{part2}\\ &= \int q_j(z_j) \log \overline{p}(x, z_j) dz_j - \int q_j(z_j) \log q_j(z_j) dz_j + \text{const} \\ &= \int q_j(z_j) \log \frac{\overline{p}(x, z_j)}{q_j(z_j)} dz_j + \text{const} \\[1em] & \rarr - \text{KL} (q_j(z_j) \| \overline{p}(x, z_j)) + \text{const} \end{aligned}

# 坐标上升法

也就是说,如果我们固定除了 zjz_j 以外的其他隐变量 zjz_{-j} 不变,那么 ELBO 可以被看做一个负 KL 散度加上一个常数。因此最小化 KL 散度 KL(qj(zj)p(x,zj))\text{KL} (q_j(z_j) \| \overline{p}(x, z_j)) 就等于最大化 ELBO。

因此最优的 qj(zj)q_j^*(z_j) 正比于对数联合概率密度 logp(x,z)\log p(x,z) 的期望的指数(由式 (3) 推出):

qj(zj)=p(x,zj)exp(Eqj[logp(x,z)])q_j^*(z_j) = \overline{p}(x, z_j) \propto \exp(\mathbb{E}_{q_{-j}}[\log p(x,z)])

可以用**坐标上升法(Coordinate Ascent Variational Inference,CAVI)**来迭代优化每个 qj(zj)q_j^*(z_j)(同时会假设其他隐变量固定不变)。坐标上升法流程为:

cavi

# 其他

我们通常会选择一些比较简单的分布 q(z)q(z) 来近似 p(zx)p(z \mid x)。但当 p(zx)p(z \mid x) 比较复杂时,往往很难用简单的 q(z)q(z) 去近似。这时可以用神经网络的强大拟合能力来近似 p(zx)p(z \mid x),这种思想被应用在了变分自编码器中

# 参考