GGDM: Learnable Diffusion Model Sampler

November 8, 2023
扩散模型

这篇是 Google 发表在 ICLR 2022 的一篇论文,它讨论了如何构造一个可学习的 Sampler 用于提升采样质量。

文章主要解决三个问题:

  1. 如何选取待优化的参数?
  2. 损失函数怎么定义?
  3. 怎么优化?

GGDM #

1. 如何选取待优化的参数?

主要是各个分布的均值方差,具体不看了,这个 sampler 后面也没什么人用了。

2. 损失函数怎么定义?

使用 Kernel Inception Distance (KID),不能用 pixel wise 的 loss,会让结果变差。

3. 怎么优化?

优化的要必须过完整个 inference process,因此显存问题必须解决,作者在这里使用了 JAX 的 Gradient Rematerialization,应该就是 Gradient Checkpointing 吧。除此之外,inference process 过程分布采样的操作可以用重参数化解决。