GGDM: Learnable Diffusion Model Sampler
November 8, 2023
这篇是 Google 发表在 ICLR 2022 的一篇论文,它讨论了如何构造一个可学习的 Sampler 用于提升采样质量。
文章主要解决三个问题:
- 如何选取待优化的参数?
- 损失函数怎么定义?
- 怎么优化?
GGDM #
1. 如何选取待优化的参数?
主要是各个分布的均值方差,具体不看了,这个 sampler 后面也没什么人用了。
2. 损失函数怎么定义?
使用 Kernel Inception Distance (KID),不能用 pixel wise 的 loss,会让结果变差。
3. 怎么优化?
优化的要必须过完整个 inference process,因此显存问题必须解决,作者在这里使用了 JAX 的 Gradient Rematerialization,应该就是 Gradient Checkpointing 吧。除此之外,inference process 过程分布采样的操作可以用重参数化解决。