扩散模型中的 PNDM 采样
November 7, 2023
PNDM 采样算法是一个发表于 ICLR 2022 的工作,最早公开的时间是 2022 年的 2 月份。作者是来自浙江大学的 Luping Liu (DPM-Solver 的作者是清华的 Cheng Lu,别搞混了 hh)。
这篇论文一开始看的时候感觉好多公式,还挺复杂的,但是当理解 DDIM 之后就会发现,这篇论文其实 DDIM 非常简单的一个扩展,论文中的理论部分其实并没有特多新的东西,不过比较惊喜的是,效果却非常好。
关于 DDIM 采样算法的推导 可以查看上一篇 博客 的介绍。
另一方面,DDIM 2020年 10月就公开了,这么简单的扩展居然到了 2022年 2 月份才出现,可见 DDIM 的超前程度,orz。
Background #
我们知道扩散模型的采样过程可以视作一个梯度下降的过程,其中梯度可以用 DDPM 训练出来的 UNet 进行求解,下降的过程就是往 data density 越高的地方走,因此可以实现一个 generation 的效果。
另一方面,Song Yang 大神在 Score-Based Generative Modeling through Stochastic Differential Equations 中证明了这个采样过程可以用随机微分方程(SDE)或者常微分方程(ODE)进行建模。
这意味着,当我们写出扩散模型的微分方程后,我们就可以用发展了几十年的求解微分方程的数值方法对其进行求解。虽然我已经把本科《数值分析》课学的全还给老师了,但是从直觉上求解常微分方程是更容易的(事实也是如此),因此后续用于加速扩散模型采样的算法大部分都是基于 ODE 来做的,并且他们算法的设计大量参考了诸如 forward Euler method, Runge-Kutta method, linear multi-step method, predictor–corrector method 等数值算法。
首先,让我们来回忆一下微分方程
《不要紧张…》
微分方程就是一个包含了一个函数 $f$ 和其各阶导数的 $f'$, $f''$ … $f^n$ 的一个方程。例如 $f'(x)= xf(x)$
如果微分方程里包含随机变量,那么它就是随机微分方程,如果没有,那它就是常微分方程。随机微分方程比较复杂,非数学系应该都不太学,不过常微分方程高数里是有的。在扩散模型的加速采样中,我们大部分情况都只考虑常微分方程,也就是 ODE。
特别地,在扩散模型里,我们只需要考虑ODE一个更简单的形式,也就是 linear differential equation,在这里微分方程里的所有函数/导函数,都是线性组合的,不会出现诸如 $f'(x) = e^{f(x)}$ 这种非线性的情况。具体地,这类 ODE 可以写成下面这样:
$$ a_0(x) y+a_1(x) y^{\prime}+a_2(x) y^{\prime \prime}+\cdots+a_n(x) y^{(n)}+b(x)=0 $$
对于扩散模型来说,从 DDPM 的角度来看,我们是一个通过定义一个固定的加噪过程,然后通过这个固定加噪过程去推导去噪过程,这个加噪过程可以描述为:
$$ x_t = \sqrt{1-\beta_t} x_{t-1} + \sqrt{\beta_t} \epsilon $$
如果把加噪的时间步变成无穷大,那么这就变成了一个连续的加噪过程,并且从上述加噪过程,我们也可以看出数据 x 随时间 t 的变化(也就是导数)应该满足:
- 与当前的数据 x 有关
- 要再加上一个高斯噪声
这用公式写下来就是,
$$ \frac{dx}{dt} = f(x,t) + g(t) \epsilon $$
其中 $\epsilon$ 在微分方程里可以写成 $dw/dt$,其中w 是一个 standard Wiener process (a.k.a., Brownian motion)。因此我们就得到了扩散模型的 SDE 形式
$$ dx = f(x,t)dt + g(t) dw $$
类似的,如果我们把噪声项去掉,我们就可以得到 ODE 形式,可以证明这两种形式的 reverse SDE/ODE 具有相同的边缘分布 $q(x_t|x_0)$,回忆 DDIM 中我们提到 DDPM 的训练只有边缘分布有关,只要保证边缘分布一直,则采样过程可以复用用一个预训练模型。
$$ dx = f(x,t)dt $$
PNDM #
好了,逐渐超纲.. 说这么多,其实只是想灌输一个概念,就是扩散模型的采样过程可以用 ODE 的数值方法求解,那么最简单的 ODE 数值方法是什么呢?
现在有请欧拉同志 (此处应有掌声 👏)
欧拉法特别简单,其实就是一个梯度下降法,利用当前步的 x 和微分方程估计当前步的梯度,然后做一步梯度下降
$$ x_{t+1} = x_{t} + h \frac{dx}{dt}_{t=t} $$
其中 $h$ 是步长。提问为什么一定能估计出梯度呢?其实上面也提到了扩散模型的 ODE 方程,方程是线性的,并且还是一阶的,我们自然可以解出导数。
事实上,不管是复杂的进阶算法,例如Runge-Kutta, linear multi-step,还是欧拉法,他们都共享了一个统一的更新过程,即
$$ x_{t+1} = x_{t} + h f(x_{t}, t) $$
不同算法的主要区别就在于这个“梯度”的估计方法$f(x_{t}, t)$,PNDM 的作者把更新过程称作《transfer part》,梯度计算过程则成为《gradient part》。
在理解这个之后,PNDM 就非常简单了,首先 DDIM 可以视作求解一个 ODE 的离散形式(参见原论文的章节 3.1),它的 transfer part 可以写成(原论文的公式 11)
其中 $\epsilon_t$ 就是 gradient part。
PNDM 的改进就是把这个 gradient part 替换成了进阶算法使用的 gradient part。如 Runge-Kutta, linear multi-step。
算法总结 #
如下所示(原论文 3.4章)
- 公式 12是 linear multi-step method 的 gradient part,其中 $e_t'$ 对应 $\epsilon_t$ ,$\phi$ 对应上面的 transfer part。
- 算法 1 是 DDIM 的算法,算法 2 是 PNDM 的算法。
- 因为公式 12的 gradient part 需要连续 4 个 step 的 $e_t$,因此一开始不能用这个 gradient part,所以作者在前三步用了 Runge-Kutta 的 gradient step(不过使用 Runge-Kutta不是必须的,任何不需要连续 4 个 step的 gradient step 都可以)。
性能 #
PNDM 的结果是 S-PNDM 和 F-PNDM 那几行。FID 越低越好,应该不需要解释了吧,牛逼就完事了。
3.1 章的推导 #
- 公式 8 到公式 9 利用了 $(a-b)(a+b) = a^2 - b^2$。
- 公式 9 到那个极限,只需要把 $\delta = 0$ 代进去就有了。
后记 #
关于 transfer part
标准的 transfer part 应该是利用原论文的公式 10 得到的$dx/dt$,使用欧拉法的话,是通过$x_{t-\delta} = x_t + h dx/dt$ 进行求解,但如果我们把步长$h$ 设置为1,推导出来这个公式正好就是公式 9,或者说公式 11。
换言之,DDIM 其实就是扩散模型 ODE 形式的步长为 1 的欧法求解方法。
不确定是否严谨,因为这个 ODE 在这篇论文里本身是由 DDIM 导出来的,可能还有其他 ODE 的形式。
关于 gradient part
严格来说,gradient part 部分应该是梯度,但是 $\epsilon_\theta$ 给出的并不是完全是梯度,那为什么它可以视作 gradient part 呢? 关于这一点,作者在 Property 3.1 下面给的解释是
That if $\epsilon_\theta$ is precise, the result of $x_{t-\delta}$ also precise, which means that $\epsilon_\theta$ can determine the direction of the denoising process to generate the final results. Therefore, such a choice also satisfies the definition of a gradient part. Now, we have our gradient part $\epsilon_\theta$ and transfer part $\phi$.
因为 $\epsilon_\theta$ 和 $dx/dt$ 是线性相关的,如下:
$$ \frac{d x}{d t}=-\bar{\alpha}^{\prime}(t)\left(\frac{x(t)}{2 \bar{\alpha}(t)}-\frac{\epsilon_\theta(x(t), t)}{2 \bar{\alpha}(t) \sqrt{1-\bar{\alpha}(t)}}\right) $$
所以对于 linear multi-step 的 gradient part 来说,对 $dx/dt$ 的任何操作的,等价于对 $\epsilon_\theta$ 操作。
对于 Runge-Kutta 法的来说,我们本质上是利用一个离 $x_{t+\delta}$ 更近的 $x$ 去估计梯度,那么我们可以把下面的写法改一下,改成 $x_1 = .., x_2 = .., x3=..$,每一步都利用之前的 x 就是,这样就绕过了每一步 $f$ 是算梯度的问题,因为本质上 $f$ 是为了求 x 服务的,如果我们可以直接求 x,那就不用管 f 怎么来的了。