Analytic-DPM 速览
November 8, 2023
Analytic-DPM 是 ICLR 2022 的最佳论文,作者是清华的 Fan Bao。这是一篇理论性很强的论文,因此这里不会给出详细的推导解释,只阐述这篇论文究竟做了一件什么事情。
关于这篇论文,这里我们还拿到了来自 Fan Bao 的 Presentation Slide,大家可以参考阅读一下。
TL & DR #
为了阐述论文到底在做什么,我们需要首先回忆一下,DDPM 的推导过程
- 定义一个加噪过程 $p(x_t|x_{t-1})$,由 $\beta$ 序列决定
- 根据这个加噪过程推导出 $p(x_t|x_0)$ 和 $q(x_{t-1}|x_0, x_t)$,前者直接展开就可以了,后者利用贝叶斯公式推导 $$ q(x_{t-1} \vert x_t, x_0) = q(x_t \vert x_{t-1}, x_0) \frac{ q(x_{t-1} \vert x_0) }{ q(x_t \vert x_0) } $$
- 反向去噪需要用到 $p(x_{t-1}|x_t)$, 我们需要估计这个分布的均值和方差,这个可以用网络来估计。
- 网络训练的话通过 ELBO 训练
- 通过推导可以发现,我们约束均值需要和 $q(x_{t-1}|x_0, x_t)$ 的均值一致,因此可以直接采用其均值,预测未知部分。
- 对于 DDPM 来说,方差不训练,直接采用固定的值,DDPM 测试了取 $p(x_t|x_{t-1})$ 的方差 或者 $q(x_{t-1}|x_0, x_t)$ 的方差,结果差不多。DDIM 则是使用了 0 方差,不过对应均值的表达式也变了。
可以看出,这里面方差的选取是一个仍然可以调整的变量,最优的方差是什么我们还不知道(何为最优呢?对于 Analytic-DPM 来说,选取这个方差我们得到的是 Loss 的最优解)。
Analytic-DPM 这篇论文的贡献就是,他们发现这个最优的方差是可以解析的算出来的。
他们的结论是:
$$
\begin{aligned}
& \boldsymbol{\mu}_n^*\left(\boldsymbol{x}_n\right)=\tilde{\boldsymbol{\mu}}_n\left(\boldsymbol{x}_n, \frac{1}{\sqrt{\bar{\alpha}_n}}\left(\boldsymbol{x}_n+\bar{\beta}_n \nabla_{\boldsymbol{x}_n} \log q_n\left(\boldsymbol{x}_n\right)\right)\right) \\
& \sigma_n^{* 2}=\lambda_n^2+\left(\sqrt{\frac{\bar{\beta}_n}{\alpha_n}}-\sqrt{\bar{\beta}_{n-1}-\lambda_n^2}\right)^2\left(1-\bar{\beta}_n \mathbb{E}_{q_n\left(\boldsymbol{x}_n\right)} \frac{\left\|\nabla_{\boldsymbol{x}_n} \log q_n\left(\boldsymbol{x}_n\right)\right\|^2}{d}\right),
\end{aligned}
$$
课代表总结:最优均值和 DDPM 的一致,保持不变。方差不一样了。
这里用了 Score-based 的表示,实际使用时,需要转换一下变成DDPM的$\epsilon$表示。
- $\lambda_n^2$ 是论文里对方差的一个通用表示,对于 DDPM 直接取 $\bar{\beta_n}$ 即可。
- 公式里的数学期望,可以通过蒙特卡洛估计计算,抽样的分布是$q_n(x_n)$,即$x_n$的数据分布,因此,可以直接用整个数据集的数据,每个样本加噪得到$x_n$即可。