I'll be giving an oral talk at #ICLR2025!
🗓 Session 1C — 🕦 11:30 AM SGT.
Title: Learning to Discretize Denoising Diffusion ODEs.
Come by if you're into #GenerativeAI / #DiffusionModels
Many thanks to my collaborators Dung Hoang, @anjiliu.bsky.social, @guyvdb.bsky.social, and @mniepert.bsky.social.
[10/n]
Paper: openreview.net/forum?id=xDr...
Code: github.com/vinhsuhi/LD3...
[9/n] Beyond Image Generation
LD3 can be applied to diffusion models in other domains, such as molecular docking.
[8/n] LD3 is fast
LD3 can be trained on a single GPU in under one hour. For smaller datasets like CIFAR-10, training can be completed in less than 6 minutes.
[7/n]
LD3 significantly improves sample quality.
[6/n]
This surrogate loss is theoretically close to the original distillation objective, leading to better convergence and avoiding underfitting.
[5/n] Soft constraint
A potential problem with the student is its limited capacity: it may be unable to match the teacher exactly. To address this, we propose a soft surrogate loss that simplifies the student's optimization task.
[4/n] How?
LD3 uses a teacher-student framework:
🔹Teacher: Runs the ODE solver with small step sizes.
🔹Student: Learns an optimal discretization to match the teacher's output.
🔹Optimization: Backpropagates through the ODE solver to refine the time steps.
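To make the teacher-student loop concrete, here is a minimal, dependency-free Python sketch on a toy problem. It is not the LD3 code, and it rests on several assumptions:

- a 1-D Gaussian "dataset" whose probability-flow ODE has a closed form,
- plain Euler as the solver,
- a plain squared-error loss in place of LD3's soft surrogate,
- a softmax-over-log-time parameterization of the time grid.

All function names are hypothetical. Central finite differences stand in for true backpropagation so that no autodiff library is needed.

```python
import math
import random

# Toy setting (assumption, not LD3's actual models): 1-D data ~ N(0, S**2)
# under an EDM-style schedule sigma(t) = t. Its probability-flow ODE is
# dx/dt = t * x / (S**2 + t**2).
S = 0.5
T_MAX, T_MIN = 80.0, 0.002


def velocity(x, t):
    return t * x / (S**2 + t**2)


def euler_solve(x, times):
    """Explicit Euler along a decreasing time grid."""
    for t_cur, t_next in zip(times[:-1], times[1:]):
        x = x + (t_next - t_cur) * velocity(x, t_cur)
    return x


def times_from_logits(theta):
    """Map free logits to a decreasing grid from T_MAX to T_MIN.

    Softmax weights act as positive step sizes in log-time, so the grid
    stays monotone and both endpoints stay fixed.
    """
    m = max(theta)
    weights = [math.exp(v - m) for v in theta]
    total = sum(weights)
    span = math.log(T_MAX) - math.log(T_MIN)
    times, log_t = [T_MAX], math.log(T_MAX)
    for w in weights:
        log_t -= span * w / total
        times.append(math.exp(log_t))
    times[-1] = T_MIN  # remove floating-point drift at the endpoint
    return times


def loss(theta, noises, targets):
    """Mean squared gap between student (few steps) and teacher outputs."""
    times = times_from_logits(theta)
    return sum((euler_solve(x, times) - y) ** 2
               for x, y in zip(noises, targets)) / len(noises)


def learn_discretization(num_steps=4, iters=200, lr=2.0, seed=0):
    rng = random.Random(seed)
    sigma_max = math.sqrt(S**2 + T_MAX**2)
    noises = [rng.gauss(0.0, sigma_max) for _ in range(64)]
    # Teacher: the same solver with many small (log-uniform) steps.
    fine_grid = times_from_logits([0.0] * 2000)
    targets = [euler_solve(x, fine_grid) for x in noises]

    theta = [0.0] * num_steps  # start from a geometric grid
    history = [loss(theta, noises, targets)]
    h = 1e-5
    for _ in range(iters):
        # Central finite differences stand in for backpropagation.
        grad = []
        for i in range(num_steps):
            up, down = theta[:], theta[:]
            up[i] += h
            down[i] -= h
            diff = loss(up, noises, targets) - loss(down, noises, targets)
            grad.append(diff / (2 * h))
        # Backtracking: shrink the step until the loss does not go up.
        step = lr
        while step > 1e-8:
            cand = [p - step * g for p, g in zip(theta, grad)]
            if loss(cand, noises, targets) <= history[-1]:
                theta = cand
                break
            step /= 2
        history.append(loss(theta, noises, targets))
    return times_from_logits(theta), history


grid, history = learn_discretization()
print("learned grid:", [round(t, 4) for t in grid])
print(f"loss: {history[0]:.3e} -> {history[-1]:.3e}")
```

Replacing the finite differences with an autodiff framework (e.g. `jax.grad`) would give actual backpropagation through the solver. The softmax parameterization is just one simple way to keep the learned time steps ordered with fixed endpoints.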
[3/n] Key idea
LD3 optimizes the time discretization for diffusion ODE solvers by minimizing the global truncation error, resulting in higher sample quality with fewer sampling steps.
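As a quick illustration of why the grid matters, the sketch below measures the global truncation error on the same kind of toy ODE. Here that error is the final-time gap between a few-step solver and the exact solution. It compares two hand-picked 4-step grids. The ODE, the Euler solver, and the helper names are assumptions for illustration; LD3 learns the grid instead of choosing it by hand.

```python
import math

# Toy PF ODE (assumption): data ~ N(0, S**2), sigma(t) = t, so
# dx/dt = t * x / (S**2 + t**2), with exact solution
# x(t) = x(T) * sqrt((S**2 + t**2) / (S**2 + T**2)).
S = 0.5
T_MAX, T_MIN = 80.0, 0.002


def euler_solve(x, times):
    """Explicit Euler along a decreasing time grid."""
    for t_cur, t_next in zip(times[:-1], times[1:]):
        x = x + (t_next - t_cur) * t_cur * x / (S**2 + t_cur**2)
    return x


def exact_solution(x_T, t):
    return x_T * math.sqrt((S**2 + t**2) / (S**2 + T_MAX**2))


def global_truncation_error(times, x_T=1.0):
    """|numerical - exact| at the final time, for one starting point."""
    return abs(euler_solve(x_T, times) - exact_solution(x_T, times[-1]))


n = 4
uniform = [T_MAX + (T_MIN - T_MAX) * i / n for i in range(n + 1)]
ratio = (T_MIN / T_MAX) ** (1 / n)
geometric = [T_MAX * ratio**i for i in range(n + 1)]

for name, grid in [("uniform", uniform), ("geometric", geometric)]:
    print(name, global_truncation_error(grid))
```

With the same number of steps, the geometric grid lands noticeably closer to the exact solution than the uniform one. A learned discretization tries to shrink that gap further.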
[2/n]
Diffusion models produce high-quality generations but are computationally expensive due to multi-step sampling. Existing acceleration methods either require costly retraining (distillation) or depend on manually designed time discretization heuristics. LD3 changes that.
🚀 Exciting news! Our paper "Learning to Discretize Denoising Diffusion ODEs" has been accepted as an Oral at #ICLR2025! 🎉
[1/n]
We propose LD3, a lightweight framework that learns the optimal time discretization for sampling from pre-trained Diffusion Probabilistic Models (DPMs).