扩散模型蒸馏之一致性模型

December 21, 2023
扩散模型

Latent Consistency Model(LCM)是近期非常流行的扩散模型蒸馏方法。本文将介绍其背后的基础模型 Consistency Model 以及 LCM 的原理。

图片取自 Latent Consistency Model 论文

🤗 写在前面 #

扩散模型的蒸馏研究事实上在去年就已经开始了,当时的 On Distillation of Guided Diffusion Models 还拿了CVPR 2023 的 Best Paper Nomination。 后续其实也有蛮多蒸馏工作出现,例如:

  1. SnapFusion: Text-to-Image Diffusion Model on Mobile Devices within Two Seconds
    • 【没开源】
  2. BOOT: Data-free Distillation of Denoising Diffusion Models with Bootstrapping
    • 【coming soon半年了】
  3. InstaFlow: One Step is Enough for High-Quality Diffusion-Based Text-to-Image Generation
    • 【严格来说,这个和 LCM 算同期,一开始没有 code,后面开源了模型,没有训练 code】

不过遗憾的是,这些工作都没有获得广泛的关注,直到 Latent Consistency Model 出现。事实上,LCM 刚出现的时候我还在吐槽,怎么他就比了 Guided-Distill 一个方法。 后面转念一想,前面那么多工作其实都没有开源,甚至 Guided-Distill 也是作者自己尝试复现的。因此那些工作没有受到关注也就可以理解了。

总的来说,LCM 可以说是第一个完全开源的扩散模型蒸馏方法(据我所知),而且因为训练速度非常快,可以在 64 batch size上训几百个 step 就可以看出显著效果, 所以它瞬间就席卷了 Stable Diffusion 社区。

Consistency Model (CM) #

在了解 LCM 之前,我们可以先了解一下他的基底模型,也就是 Consistency Model(CM)的原理。 事实上,在了解完 CM 之后,我们就会发现 LCM 其实非常自然,基本上可以看作是 CM 在 text-to-image 的直接扩展。

那么,consistency model 是什么呢?简单来说,

  1. 它是 song yang 博士提出的一类新的生成模型。
  2. 它是扩散模型的一种,但与 DDPM 不同。
  3. 它可以通过蒸馏的方式从 DDPM 的模型提取而来,也可以从零开始训练。
  4. 它具有快速采样的特点,因此可以用来蒸馏 DDPM 训练得到的模型,例如 Stable Diffusion 模型,从而实现加速的目的。

理解 Motivation #

在阅读 consistency model 的论文的时候,我们一上来会看到这么一个定义(右上角的 teaser)

这个定义对于对 score matching 那一套不熟悉的人,可能会感觉有点绕。用扩散模型的概念来说,PF-ODE 对应的是 DDIM 这类确定性的采样过程。 而 consistency model 就是要求我们的扩散模型在一个采样路径上的每一个 noisy sample 预测的 x0 都是一样的。

可能有人要问,我自己一开始也问了,DDPM 的训练目标不就是这个吗,随机采样一个 timestep,然后预测 x0,希望和 gt 的 x0 一致。

答案当时是不一样的,不然 consistency model 还提出来干什么?关键点在于 consistency model 要求训练出来的模型 对于一个采样路径上的每个点都保持 consistency,即预测结果保持一致。

🙋‍♂️ 提问:那么怎么理解这个一致性呢,它为什么能帮助提升采样速度?

回忆 DDPM 的训练目标,是下面所展示的一个 KL 散度:

$$ \mathbb{E}_q[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right)}_{L_T}+\sum_{t>1} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)}_{L_{t-1}} \underbrace{-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}_{L_0}] $$

它包含 T 项,对应 DDPM 一个完整的去噪路径,但DDPM 实际训练的时候并不是采样一个完整的路径,然后算这个 loss 训练的,它是随机采样的时刻,然后只算这个时刻的 loss 部分,然后通过蒙特卡洛估计去算这个 loss 的期望。

😯 那么很明显,这里有一个问题,DDPM 采样的时候不同时刻对应的 x0 是不一样的,所以这里的 loss 并不能很好的估计出这个 KL 散度的期望。 这么训出来的模型,对于同一个采样路径上的 noisy sample 是没有 “映射到同一个 x0” 的约束的。

没有这个约束的话,那么我们的采样过程可能就会打架,第10 个 step 可能预测的是一个 x0,下一步预测的 x0 就变了,这导致采样路径很不稳定,做了很多无用功。与之相反,consistency model 因为有了一致性,所以可以步子迈的大一些,采样速度也就快了。

那么为什么 DDPM 步子不能太大呢?因为 DDPM 没有一致性,第一步步子迈的大了,虽然你更接近第一步对应的 x0 了,但从第一步的终点出发,并不一定是往第一步的 x0 走,所以得一小步一小步走,利用“函数的平滑性“,相近的点预测的 x0 应该是比较接近的。

🤖《此处应该有一个配图》

原理 🌧️ 细节 #

理解了consistency model 的 motivation 之后,下一步就是如何实现这个约束了。对于此,consistency model 支持两种训练方式:

  1. 从训练好的模型蒸馏得到
  2. 从头开始训练

对于前者,我们的训练 loss 就是约束一个采样路径上的相邻点的输出保持一致。实现起来很简单,采样一个 noisy sample, $x_t$ 然后可以预测一个 $x_0$,接着通过一个 Diffusion ODE Solver(比如 DDIM),得到$x_{t-1}$,再预测一个$\hat{x}_0$,然后就可以算 loss 了, $loss=(x_0- \hat{x}_0)^2$

Consistency Model 的数学定义

Consistency model 就是一个函数(用神经网络拟合),它可以用来求解 PF ODE(即扩散模型的采样过程)。对于一个采样过程里的所有样本点, ${x_t}_{t\in[\epsilon, T]}$,Song Yang 把 Consistency function 定义为:

$$ \boldsymbol{f}:\left(\mathbf{x}_t, t\right) \mapsto \mathbf{x}_\epsilon $$
这个定义包含两个含义
  1. 一个是所有输入的输出都一致,即 “its outputs are consistent for arbitrary pairs …”
  2. 所有输入都映射到了采样过程的初始位置 $x_\epsilon$。
Alt text

Boundary Condition

对于上述定义,我们知道 consistency model 有一个边界条件:

$$ \boldsymbol{f}\left(\mathbf{x}_\epsilon, t\right) = \mathbf{x}_\epsilon $$

对于这个边界条件,如果 consistency model 是个神经网络的话,我们可以直接通过合理定义 consistency function 的参数化方法进行约束。如右图公式 4和公式 5。

在这里,作者采用的是公式 5 的方式。

Tips: 不过实际实验下来,直接用朴素的参数方法也没什么影响,即 $t=\epsilon$ 也用网络的输出,而不是像公式 4 一样,直接置为输入。

Alt text

EMA & Target Network

回忆,consistency model 蒸馏的训练方法,我们需要预测两个 $x_0$,朴素的做法就是用 teacher model 初始化 student model,然后算 loss,不断优化即可。

事实上,我们的优化目标可以看作一种 bootstrapping 方法,类似强化学习里的 time difference 算法,第二步预测的 x0 可以看作是对 x0 的更准确的估计,因此我们希望第一步的 x0 和第二步的 x0 保持一致。

Alt text

另一方面,如果我们第一步第二步都用同一个网络的话会产生累计误差。怎么理解?因为我们是用同一个模型输出 GT(第二步的 x0) 和 Prediction(第一步的 x0),如果优化一次学偏了,那么我们的 GT 就会被带偏,这样累计下来就会越来越偏。

因此,作者在这里使用两个不同的模型(都用 teacher model 初始化),分别预测 GT 和 Prediction,其中预测 GT 的用 EMA 的方式更新,预测 prediction 的网络正常用梯度下降更新。

用强化学习的术语来说,预测 GT 的网络称为 target network(因为它预测是 target),预测 prediction 的网络称为 online network(因为它是 online更新的,使用 loss,target network 的 EMA更新是 offline )。

对应公式的话,$\theta$ 就是 online network,另一个 $\theta -$ 是 target network。

一些不那么重要的点 #

证明 #

凭什么上面提到的蒸馏方法训练出来的模型能够满足 consistency model 的定义?

作者在论文里给出了证明。

从零开始训练 Consistency Model #

TODO

Latent Consistency Model (LCM) #

理解了consistency model 之后,将其应用到 stable diffusion 只需要解决一个问题,怎么处理 classifier free guidance,其他都是一模一样的。

实现细节 #

Classifier free guidance #

在 LCM 中,作者的做法就是将 guidance scale 作为一个参数加入到 Student model 里,具体实现是调制进 timestep embedding。然后 target network 预测的 GT 改成加了 classifier free guidance 的 model out。

训练的时候,guidance scale 就是随机取的。

Skip Timestep #

朴素的 consistency model 可能是直接蒸馏 1000 步的模型,通过 1000 步的采样序列,但事实上 consistency model 并没有要求一定要这样,因此作者在这里使用更高阶的 sampler 采样了 50 步的序列,用于训练。 实现了加速训练的效果。

LCM-LoRA #

LCM-LoRA 就是用 LoRA 训练 UNet 网络,作者发现这么训出来的 LoRA 具有一定的迁移能力,套到别的模型也能产生加速效果,并且可以和其他 LoRA 组合。

除此之外,可能作者发现即使训练的时候把 guidance scale 喂给网络,但是 inference 的时候改变 guidance scale 也没啥用,然后 LCM-LoRA 就没有这个 condition 了。

另外,EMA作者发现没啥用也去掉了。

讨论和 Note #

  1. LCM 训练非常快,64 batch size,每训练 10 个 step 都可以看到显著变化,训练 100 个 step,4 个 step 的图已经能看了。
  2. LCM-LoRA 学习率要大一点(比如1e-4),太小训了没变化。
  3. 数据集的分辨率和 teacher model 一致,不一致可能训不出来。