(Work in progress: I will gradually add more content when I have more time :D Please stay tuned :D)

What are Diffusion Models?

Diffusion models are a class of generative models that generate data by progressively denoising a sample from pure noise. They are inspired by non-equilibrium thermodynamics and are based on a forward and reverse diffusion process:

  1. Forward Process (Diffusion Process): A data sample (e.g., an image) is gradually corrupted by adding Gaussian noise over multiple timesteps until it becomes nearly pure noise.
  2. Reverse Process (Denoising Process): A neural network learns to reverse this corruption by gradually removing noise step by step, reconstructing the original data distribution.
Diffusion - How molecules actually move

Analogy: Ink Dissolving in Water Imagine dropping a blob of ink into a glass of water:

  • Forward process (Diffusion Process): Initially, the ink is concentrated in one place (structured data). Over time, it spreads out randomly, blending with the water (adding noise). Eventually, the entire glass becomes a uniformly colored mixture, losing its original structure (complete noise).
  • Reverse process (Denoising Process): If we had a way to perfectly reverse time, we could watch the ink particles retrace their paths, reassembling into the original drop (generating the original data from noise). Diffusion models learn to perform this “reverse process” step by step using machine learning.

Non-Equilibrium Thermodynamics

Thermodynamics studies how energy moves and changes in a system. In equilibrium thermodynamics, systems are in balance—nothing is changing. Non-equilibrium thermodynamics, on the other hand, deals with systems that are constantly evolving, moving between states of disorder and order.

In diffusion models, the forward process (adding noise to data) and the reverse process (removing noise) resemble a non-equilibrium thermodynamic system because they describe an evolving state that moves from order (structured data) to disorder (pure noise) and back to order (reconstructed data).

Brownian Motion

Brownian motion describes the random movement of tiny particles (like pollen grains in water) due to collisions with molecules. This randomness is similar to how noise is added in diffusion models.

Advantages of Diffusion Models

Diffusion models offer several key advantages over traditional generative models like GANs and VAEs:

  1. High-Fidelity Samples: Unlike VAEs and GANs which generate samples in one step, diffusion models create samples gradually by denoising. This step-by-step process allows the model to first establish coarse image structure before refining fine details, resulting in higher quality outputs.

  2. Training Stability: Diffusion models are easier to train compared to GANs as they use a single tractable likelihood loss. They don’t suffer from training instabilities like mode collapse that often plague GANs.

  3. Sample Diversity: Similar to VAEs, diffusion models maximize likelihood which ensures coverage of all modes in the training dataset. This leads to more diverse outputs compared to GANs which can suffer from mode collapse.

  4. Flexible Architecture: The multi-step denoising process enables additional functionalities like inpainting or image-to-image generation by manipulating the input noise, without requiring architectural changes.

  5. Consistent Quality: The gradual denoising process is more robust and consistent compared to GANs where quality can vary significantly between samples.

The main trade-off is generation speed - diffusion models require multiple neural network passes to generate samples, making them slower than single-pass models like GANs and VAEs. However, various sampling optimization techniques have been developed to significantly reduce this computational overhead.

Disadvantages of Diffusion Models

While diffusion models have significant advantages, they also come with some trade-offs:

  • Slow Sampling: The reverse process requires multiple denoising steps, making inference slower compared to GANs.
  • Compute Intensive: Training requires large amounts of data and computational power.
  • Memory Usage: They require storing multiple intermediate noise distributions, making them more memory-intensive.
  • Complex Implementation: The multi-step nature of diffusion models makes them more complex to implement compared to single-step models.

Mathematical Foundation

Diffusion models are built on a deep interplay between differential equations, probability theory, and variational inference. To understand why the model works, we need to trace how these ideas connect: from describing how systems evolve over time, to modeling probability densities, to designing trainable objectives.

Differential Equations: ODEs and SDEs

We start with ordinary differential equations (ODEs), which describe how a system changes deterministically over time based on its current state.

\[\frac{dx}{dt} = f(x, t)\]

where \(x(t)\) is the state of the system (the function we want to solve for) and \(t\) is time. \(f(x, t)\) defines how \(x\) changes over time.

This is a useful starting point, but in real-world data generation, we must account for randomness. That brings us to stochastic differential equations (SDEs), which incorporate random fluctuations into the system.

\[dx = f(x, t) dt + g(x, t) dW_t\]

where the drift term \(f(x, t) dt\) captures the deterministic trends, while the diffusion term \(g(x, t) dW_t\) captures the random fluctuations via a Wiener process \(W_t\).

👉 Motivation: Diffusion models inject noise step by step, so SDEs provide the natural language to describe this stochastic corruption process. More specifically, the drift term \(f(x, t)\) is the shift of the mean of the distribution, and the diffusion term \(g(x, t)\) is the spread of the distribution - injecting Gaussian noise.

Forward and Reverse Diffusion Processes

Forward Process (Adding Noise)

The forward diffusion process transforms a data sample \(x_0\) into pure noise \(x_T\) over time:

\[dx = f(x, t)dt + g(t) dW_t\]

Intuitively, the drift term \(f(x, t) dt\) shifts the mean of the distribution by a deterministic amount \(f(x,t)\) (i.e., is a function of \(x\) and current time \(t\)) to a zero mean distribution. The diffusion term \(g(t) dW_t\) spreads the distribution by injecting Gaussian noise, increasing the variance of the distribution.

Note that \(g(t)\) is a function of the current time \(t\) only, which guarantees that each noised distribution remains Gaussian with a known mean and variance. This makes the forward process simple and tractable, so that we have an exact sampling formula at each time step \(t\), i.e., \(q (x_t \mid x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t)I)\).
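As a minimal sketch of this closed-form sampling (assuming a precomputed tensor `alphas_cumprod` holding \(\bar{\alpha}_t\) for every timestep; the function and variable names are illustrative, not from a specific library):

import torch

def sample_xt(x0: torch.Tensor, t: torch.Tensor, alphas_cumprod: torch.Tensor):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I)."""
    # Gather abar_t for each item in the batch and broadcast it to x0's shape
    abar_t = alphas_cumprod[t].view(-1, *([1] * (x0.dim() - 1)))
    eps = torch.randn_like(x0)                   # epsilon ~ N(0, I)
    x_t = abar_t.sqrt() * x0 + (1.0 - abar_t).sqrt() * eps
    return x_t, eps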

Reverse Process (Removing Noise)

To generate data from pure noise \(x_T\), we need to reverse the diffusion process using the reverse-time SDE (Anderson 1982).

\[dx = \left[ f(x,t) - \frac{1}{2} g^2(t) \nabla_x \log p_t(x) \right] dt + g(t) d\tilde{W}_t\]

where \(\nabla_x \log p_t(x)\) is the score function, which estimates the structure of data at time \(t\) - how likely different data points are at each step. \(d\tilde{W}_t\) is another Wiener process but in the reverse direction.

👉 Motivation: Since \(f(x,t)\) and \(g(t)\) are known, to reverse the noise we must know the score function. So we train a neural network to approximate the score function \(\nabla_x \log p_t(x)\). This is the core of the diffusion model.

Euler Method for Numerical Integration

To simulate the reverse process, we need a numerical method to integrate the SDE backward through time. The Euler–Maruyama method (a stochastic generalization of the Euler method) provides a simple, first-order approximation.

For the variance-preserving (VP) SDE used in diffusion models, the deterministic part of the reverse process (ignoring noise sampling) can be expressed as an ODE:

\[\frac{dx}{dt} = f(x,t) - \frac{1}{2} g^2(t) \nabla_x \log p_t(x)\]

Now, using the Euler method to integrate this ODE backward in time from \(t_{n+1}\) to \(t_n\):

\[x_{t_n} = x_{t_{n+1}} + (t_{n+1} - t_{n}) \frac{dx}{dt} \bigg|_{x=x_{t_{n+1}}, t=t_{n+1}}\]

Substituting the ODE into the Euler update, we get:

\[x_{t_n} = x_{t_{n+1}} + (t_{n+1} - t_{n}) \left[ f(x_{t_{n+1}}, t_{n+1}) - \frac{1}{2} g^2(t_{n+1}) \nabla_x \log p_{t_{n+1}}(x_{t_{n+1}}) \right]\]

This stepwise update forms the foundation of diffusion sampling algorithms like DDPM and DDIM, where the score term is replaced by the neural network’s prediction.

Implementation of the Euler method

In the context of consistency models, the update rule is derived from the Probability Flow ODE (PF-ODE):

\[\frac{dx}{dt} = -t s_{\phi}(x,t)\]

where \(s_{\phi}(x,t)\) is a learned score function related to the model output \(f_{\theta}(x,t)\). From the definition of the consistency function (please refer to the Consistency Models section), we have:

\[s_{\phi}(x,t) = \frac{f_{\theta}(x,t) - x}{t^2}\]

Substituting this into the PF-ODE Euler update:

\[\hat{x}_{t_n} = x_{t_{n+1}} - (t_n - t_{n+1})\, t_{n+1}\, s_{\phi}(x_{t_{n+1}}, t_{n+1})\]

we get: \(\hat{x}_{t_n} = x_{t_{n+1}} + (t_n - t_{n+1}) \frac{x_{t_{n+1}} - f_{\theta}(x_{t_{n+1}}, t_{n+1})}{t_{n+1}}\)

This formula defines a single-step explicit Euler integration: it can predict the next step \(x_{t_n}\) from the current step \(x_{t_{n+1}}\).

Python implementation:

@th.no_grad()   # th = torch; this snippet follows the consistency models codebase
def euler_solver(samples, t, next_t, x0):
    """One Euler step of the PF-ODE from time t to next_t (next_t < t when denoising).

    `teacher_denoise_fn` is the teacher denoiser f_{\theta}(x, t); `append_dims`
    broadcasts the time tensor to the same number of dimensions as the samples.
    """
    x = samples
    if teacher_model is None:
        denoiser = x0  # no teacher: use the ground-truth clean sample
    else:
        denoiser = teacher_denoise_fn(x, t)  # f_{\theta}(x,t) - consistency/teacher model output
    d = (x - denoiser) / append_dims(t, dims)          # PF-ODE derivative dx/dt at (x, t)
    samples = x + d * append_dims(next_t - t, dims)    # Euler update to time next_t

    return samples

👉 Interpretation:

  • d is the PF-ODE derivative \(\frac{dx}{dt}\) estimated at the current point \((x, t)\), i.e., the slope used for the step.
  • next_t - t is the time step size (negative for reverse time integration).
  • The Euler method uses only the slope at the start of the interval, making it simple but potentially less accurate for large steps.

Heun Method (Improved Euler / Predictor–Corrector)

The Heun method is a second-order numerical scheme that improves on Euler by estimating the derivative twice — at the beginning and at the end of the time interval — and then averaging them. This yields better stability and smaller integration error, especially for stiff or nonlinear dynamics like those in diffusion models.

The update consists of two stages:

  • Predictor (Euler step):
\[\tilde{x}_{t_n} = x_{t_{n+1}} + (t_n - t_{n+1}) d_{t_{n+1}}\]

where \(d_{t_{n+1}} = \frac{x_{t_{n+1}} - f_{\theta}(x_{t_{n+1}}, t_{n+1})}{t_{n+1}}\).

  • Corrector (Second-order correction): Evaluate a new derivative \(d_{t_n}\) at the predicted point \(\tilde{x}_{t_n}\), then update the final point:
\[\hat{x}_{t_n} = x_{t_{n+1}} + \frac{1}{2}(t_n - t_{n+1}) (d_{t_{n+1}} + d_{t_n})\]
@th.no_grad()
def heun_solver(samples, t, next_t, x0):
    x = samples
    if teacher_model is None:
        denoiser = x0
    else:
        denoiser = teacher_denoise_fn(x, t)  # f_{\theta}(x,t) at the current time t

    # Predictor: plain Euler step using the derivative at (x, t)
    d = (x - denoiser) / append_dims(t, dims)
    samples = x + d * append_dims(next_t - t, dims)

    if teacher_model is None:
        denoiser = x0
    else:
        denoiser = teacher_denoise_fn(samples, next_t)  # f_{\theta}(x,t) at the predicted point

    # Corrector: average the derivatives at the start and end of the interval
    next_d = (samples - denoiser) / append_dims(next_t, dims)
    samples = x + (d + next_d) * append_dims((next_t - t) / 2, dims)

    return samples

Fokker-Planck Equation: From Trajectories to Distributions

An SDE describes how individual trajectories evolve over time, but what about the distribution of the data as noise accumulates? The Fokker-Planck equation bridges the gap between trajectories and distributions, explaining how noise pushes the data distribution \(p_t(x)\) toward an isotropic Gaussian.

\[\frac{\partial p_t(x)}{\partial t} = -\nabla_x \cdot (f(x,t) p_t(x)) + \frac{1}{2} \nabla_x \cdot (g(t)^2 \nabla_x p_t(x))\]

where \(p_t(x)\) is the distribution of the data at time \(t\).

The first term \(-\nabla_x \cdot (f(x,t) p_t(x))\) describes how the probability density \(p_t(x)\) is transported by the drift \(f(x,t)\) (the velocity of the probability mass). The divergence operator \(\nabla_x \cdot\) measures how much mass is spreading out (positive divergence) or converging/concentrating (negative divergence) at a given point \(x\). The whole term \(- \nabla_x \cdot (f(x,t) p_t(x))\) is therefore the rate of change of \(p_t(x)\) at \(x\): a positive divergence means mass is flowing away from \(x\), so \(p_t(x)\) decreases (hence the negative sign), and vice versa.

The second term \(\frac{1}{2} \nabla_x \cdot (g(t)^2 \nabla_x p_t(x))\) represents the spreading and smoothing of the probability density \(p_t(x)\) over time due to random fluctuations. More specifically, \(g(t)\) controls the magnitude of the noise, and \(\nabla_x p_t(x)\) is the steepness or slope of \(p_t(x)\) at a point \(x\). The whole term is again a rate of change of \(p_t(x)\): the steeper the density around \(x\), the faster it is smoothed out. As in the first term, \(\nabla_x \cdot\) measures how much mass is spreading out (positive divergence) or concentrating (negative divergence), with two differences:

  • It contains the noise magnitude \(g(t)^2\) instead of the drift term \(f(x,t)\), introducing randomness into the system.
  • It is proportional to the gradient \(\nabla_x p_t(x)\), so sharp, high-gradient regions spread out more than flat regions (which have a smaller gradient \(\nabla_x p_t(x)\)).

Score Matching and Denoising

Since the reverse SDE depends on the score function \(\nabla_x \log p_t(x)\), which is intractable, we need a training objective that approximates it. Denoising score matching (Vincent 2011) trains a neural network \(s_{\theta}(x_t, t)\) to approximate the conditional score \(\nabla_{x_t} \log q(x_t \mid x_0)\), where \(q\) is the tractable forward process \(dx = f(x, t)dt + g(t)dW_t\), under the assumption that \(q(x_t \mid x_0) \approx p_t(x_t)\) at time \(t\) (which is a reasonable assumption).

This objective aims to minimize the difference:

\[\mathbb{E}_{x_0 \sim p(x_0),\, \epsilon \sim \mathcal{N}(0, I)} \left[ \left\| s_{\theta}(x_t, t) - \nabla_{x_t} \log q(x_t \mid x_0) \right\|^2 \right]\]

Variational Perspective and KL Minimization

Another way to frame diffusion models is as a variational inference problem. The forward process \(q(x_{0:T})\) is a known noising chain of distributions, and the reverse process \(p_{\theta}(x_{0:T})\) is learned. Therefore, we can use the variational lower bound (ELBO) to train the model; its dominant terms are KL divergences of the form

\[\mathbb{E}_{q(x_{0:T})} \left[ D_{KL} \left( q(x_{t-1} \mid x_t, x_0) \parallel p_{\theta}(x_{t-1} \mid x_t) \right) \right]\]

ELBO

Evidence lower bound (ELBO) is a key concept in variational inference, which is used in VAEs to approximate the log-likelihood of the data.

Let \(X\) and \(Z\) be random variables, jointly distributed with distribution \(p_\theta\). For example, \(p_\theta(X)\) is the marginal distribution of \(X\), and \(p_\theta(Z \mid X)\) is the conditional distribution of \(Z\) given \(X\). Then, for a sample \(x \sim p_{\text{data}}\), and any distribution \(q_\phi\), the ELBO is defined as

\[L(\phi, \theta; x) := \mathbb{E}_{z\sim q_\phi(\cdot|x)} \left[\ln \frac{p_\theta(x,z)}{q_\phi(z|x)}\right].\]

The ELBO can equivalently be written as

\[\begin{aligned} L(\phi, \theta; x) &= \mathbb{E}_{z\sim q_\phi(\cdot|x)}[\ln p_\theta(x,z)] + H[q_\phi(z \mid x)] \\ &= \ln p_\theta(x) - D_{KL}(q_\phi(z \mid x) || p_\theta(z \mid x)). \end{aligned}\]

In the first line, \(H[q_\phi(z \mid x)]\) is the entropy of \(q_\phi\), which relates the ELBO to the Helmholtz free energy. In the second line, \(\ln p_\theta(x)\) is called the evidence for \(x\), and \(D_{KL}(q_\phi(z \mid x) \mid\mid p_\theta(z \mid x))\) is the Kullback-Leibler divergence between \(q_\phi\) and \(p_\theta\). Since the Kullback-Leibler divergence is non-negative, \(L(\phi, \theta; x)\) forms a lower bound on the evidence (ELBO inequality)

\[\ln p_\theta(x) \geq \mathbb{E}_{z\sim q_\phi(\cdot|x)}\left[\ln \frac{p_\theta(x,z)}{q_\phi(z|x)}\right].\]

Deep-dive topics about VAEs might include:

Tweedie’s formula

Finally, Tweedie’s formula gives a neat probabilistic justification for the denoising score matching objective:

\[\mathbb{E} [ x_0 \mid x_t] = x_t + \sigma_t^2 s_{\theta}(x_t, t)\]

where \(s_{\theta}(x_t, t)\) is the score function to be learned by the neural network. It shows that the posterior mean of the clean data given a noisy sample is just the noisy sample plus a correction term proportional to the score function.
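As a quick numerical sanity check (a toy sketch, not from any paper: the 1D Gaussian prior, noise level, and variable names below are made up for illustration), Tweedie's formula can be compared against a Monte Carlo estimate of the posterior mean:

import numpy as np

sigma = 0.8
x0 = np.random.randn(1_000_000)                  # clean data: x0 ~ N(0, 1)
xt = x0 + sigma * np.random.randn(1_000_000)     # noisy data: x_t | x0 ~ N(x0, sigma^2)

def score(x):
    # True score of the marginal p(x_t) = N(0, 1 + sigma^2)
    return -x / (1.0 + sigma ** 2)

x_query = 1.5
near = np.abs(xt - x_query) < 0.01               # samples whose x_t lies near the query point
posterior_mean_mc = x0[near].mean()              # Monte Carlo estimate of E[x0 | x_t = x_query]
posterior_mean_tweedie = x_query + sigma ** 2 * score(x_query)
print(posterior_mean_mc, posterior_mean_tweedie) # both approach x_query / (1 + sigma^2)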

In some papers such as ESD (Gandikota et al. 2023), where we need to fine-tune the pretrained model and match the score function of the original and the fine-tuned model, they use Tweedie’s formula to justify the matching term.

Variants of Diffusion Models

The original formulation of diffusion models can be implemented in several ways. Two of the most influential variants are Denoising Diffusion Probabilistic Models (DDPMs) and Denoising Diffusion Implicit Models (DDIMs). Both share the same forward noising process but differ in how they perform the reverse (denoising) process during inference.

DDPM

DDPM (Ho et al. 2020) is the classic diffusion model:

  • The Reverse process is defined as a Markov chain, where each step \(x_t \to x_{t-1}\) involves sampling from a Gaussian distribution conditioned on \(x_t\).
  • Sampling is stochastic: even with the same starting noise \(x_T\), the generated data \(x_0\) is different.
  • While highly effective and stable (compared to GANs), DDPMs require hundreds to thousands of steps to slowly add/remove noise, which makes inference slow.

Read more about DDPM in another blog post here

DDIM

DDIM (Song et al. 2020) builds on DDPM but introduces a non-Markovian reverse process, enabling faster sampling. It also allows us to use the same training process as DDPM, e.g., we can use pretrained DDPM models to generate data.

The sampling process of DDIM is as follows:

\[x_{t-1} = \sqrt{\alpha_{t-1}} \left(\frac{x_t - \sqrt{1-\alpha_t}\epsilon_\theta^{(t)}(x_t)}{\sqrt{\alpha_t}}\right) + \sqrt{1-\alpha_{t-1}-\sigma_t^2} \cdot \epsilon_\theta^{(t)}(x_t) + \sigma_t\epsilon_t\]

where the first term represents the “predicted \(x_0\)”, the second term is the “direction pointing to \(x_t\)”, and the last term is random noise.

By setting \(\sigma_t = 0\) for all \(t\), DDIM becomes a deterministic process: the intermediate steps \(x_{T-1}, x_{T-2}, \ldots, x_1\) and the final sample \(x_0\) are fully determined by the starting noise \(x_T\).
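A minimal sketch of one deterministic DDIM update (\(\sigma_t = 0\)) under the notation above, assuming a trained noise-prediction network `eps_model` and a tensor `alpha` holding the cumulative \(\alpha_t\) (names are illustrative):

import torch

@torch.no_grad()
def ddim_step(x_t, t, t_prev, eps_model, alpha):
    """One deterministic DDIM step x_t -> x_{t_prev} with sigma_t = 0."""
    eps = eps_model(x_t, t)                                             # predicted noise
    x0_pred = (x_t - (1 - alpha[t]).sqrt() * eps) / alpha[t].sqrt()     # "predicted x_0"
    # Move to the previous timestep along the direction pointing to x_t
    return alpha[t_prev].sqrt() * x0_pred + (1 - alpha[t_prev]).sqrt() * eps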

Read more about DDIM in another blog post here

Score Matching

Score-based generative models (Song and Ermon 2019; Song et al. 2021, published in ICLR 2021) are a family of generative models that learn to estimate the score function, the gradient of the log-density of the data distribution: \(s_\theta(x, t) \approx \nabla_x \log p_t(x)\). Intuitively, the score function tells us which direction in the input space increases the likelihood of the data. By learning this function at different noise levels (at different time steps \(t\)), the model learns how to denoise a sample toward realistic data.

Training Objective

The ideal objective of score matching is:

\[\min_{\theta} \mathbb{E}_{p_t(x)} \left[ \| s_\theta(x, t) - \nabla_x \log p_t(x) \|^2 \right]\]

which can be shown to be equivalent to

\[\min_{\theta} \mathbb{E}_{p_t(x)} \left[ \text{trace}(\nabla_x s_{\theta}(x, t)) + \frac{1}{2} \| s_{\theta}(x, t) \|^2 \right]\]

With deep networks and high-dimensional data, this expectation over the data distribution is difficult to estimate, especially the \(\text{trace}(\nabla_x s_{\theta}(x, t))\) term.

One popular way to overcome this is Denoising Score Matching (DSM) (Vincent 2011). The idea is to first perturb the data \(x\) with a known Gaussian noise process \(q(x_t \mid x)\), and then train the score network \(s_\theta(x_t, t)\) to match the score of this conditional distribution, which points from the perturbed data \(x_t\) back toward the original data \(x\).

\[\min_{\theta} \mathbb{E}_{x \sim p_{\text{data}},\, x_t \sim q(x_t \mid x)} \left[ \| s_\theta(x_t, t) - \nabla_{x_t} \log q(x_t \mid x) \|^2 \right]\]

For Gaussian perturbations:

\[\nabla_{x_t} \log q(x_t \mid x) = -\frac{x_t - \alpha_t x}{\sigma_t^2}\]

Hence, we can train \(s_\theta(x_t, t)\) to match this target directly.
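A minimal training-step sketch of this objective (assuming a score network `score_model(x_t, t)` and broadcastable coefficients `alpha_t`, `sigma_t` of the perturbation kernel; all names are illustrative):

import torch

def dsm_loss(score_model, x0, t, alpha_t, sigma_t):
    """Denoising score matching with Gaussian perturbations.

    Target: grad_{x_t} log q(x_t | x0) = -(x_t - alpha_t * x0) / sigma_t**2 = -eps / sigma_t.
    """
    eps = torch.randn_like(x0)
    x_t = alpha_t * x0 + sigma_t * eps
    target = -eps / sigma_t
    diff = score_model(x_t, t) - target
    return (diff ** 2).flatten(start_dim=1).sum(dim=1).mean()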

Sampling

In Song and Ermon 2019, Langevin dynamics is used to sample data with the score function \(s_\theta(x, t)\). Given a fixed step size \(\epsilon\) and an initial value \(x_T \sim \mathcal{N}(0, I)\), Langevin dynamics recursively updates the sample as follows:

\[x_{t-1} = x_t + \frac{\epsilon}{2} s_{\theta}(x_t, t) + \sqrt{\epsilon} z_t\]

where \(z_t \sim \mathcal{N}(0, I)\).
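A minimal sketch of sampling with this update (assuming a trained `score_model`; in practice Song and Ermon anneal the step size with the noise level, which is omitted here, and all names and hyperparameters are illustrative):

import torch

@torch.no_grad()
def langevin_sample(score_model, shape, n_steps=1000, eps=2e-5):
    """Langevin dynamics: x <- x + (eps / 2) * score(x, t) + sqrt(eps) * z."""
    x = torch.randn(shape)                    # start from x_T ~ N(0, I)
    for t in reversed(range(n_steps)):
        z = torch.randn_like(x)
        x = x + 0.5 * eps * score_model(x, t) + (eps ** 0.5) * z
    return x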

\(\epsilon\)-prediction

Instead of predicting the score directly, DDPM variants predict the noise \(\epsilon\) added to the data, which is equivalent to predicting the score but more numerically stable.

The data corruption process is:

\[x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon\]

Then the score function relates to the noise predictor as:

\[s_\theta(x_t, t) = -\frac{\epsilon_{\theta}(x_t, t)}{\sqrt{1 - \bar{\alpha}_t}}\]

The training loss becomes:

\[\mathcal{L}_{DDPM}(\theta) = \mathbb{E}_{t, x_0, \epsilon} \left[ \| \epsilon_\theta(x_t, t) - \epsilon \|^2 \right]\]
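A minimal training-step sketch of this loss (assuming a noise-prediction network `eps_model(x_t, t)` and a precomputed `alphas_cumprod` tensor; names are illustrative):

import torch

def ddpm_loss(eps_model, x0, alphas_cumprod):
    """epsilon-prediction loss: || eps_theta(x_t, t) - eps ||^2."""
    T = alphas_cumprod.shape[0]
    t = torch.randint(0, T, (x0.shape[0],), device=x0.device)         # random timestep per sample
    abar = alphas_cumprod[t].view(-1, *([1] * (x0.dim() - 1)))        # broadcast to x0's shape
    eps = torch.randn_like(x0)
    x_t = abar.sqrt() * x0 + (1 - abar).sqrt() * eps                  # forward corruption
    return ((eps_model(x_t, t) - eps) ** 2).mean()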

From DDPM to Score Matching

It can be seen that the key idea of both DDPM and SMLD is to perturb the data with a known noise distribution and then denoise it back to the original data. In follow-up work (Song et al. 2021), the authors unified the two frameworks into a single SDE-based framework with an infinite number of noise scales.

First, the diffusion process can be defined as a SDE:

\[dx = f(x, t) dt + g(t) dW_t\]

where \(f(x, t)\) is the drift term and \(g(t)\) is the diffusion term and \(W_t\) is the Wiener process.

The reverse diffusion process, which denoises the data from \(x_T \sim \mathcal{N}(0, I)\) to \(x_0 \sim p(x)\), can be defined as a reverse-time SDE:

\[dx = [f(x, t) - g(t)^2 \nabla_x \log p_t(x)] dt + g(t) d\bar{W}_t\]

where \(\bar{W}_t\) is the reverse-time Wiener process.

Different diffusion model families have different choices of \(f(x, t)\) and \(g(t)\):

  • Variance Exploding (VE) SDE: \(f(x, t) = 0, g(t) = \sqrt{\frac{d\sigma_t^2}{dt}}\) as in SMLD
  • Variance Preserving (VP) SDE: \(f(x, t) = -\frac{1}{2}\beta_t x, g(t) = \sqrt{\beta_t}\) as in DDPM

Sampling with the Euler-Maruyama method

Once trained, the score network defines the reverse-time dynamics that transform Gaussian noise \(x_T \sim \mathcal{N}(0, I)\) to the data sample \(x_0 \sim p(x)\). We can simulate the reverse SDE numerically using the Euler-Maruyama method:

\[x_{t_n} = x_{t_{n+1}} + (t_n - t_{n+1}) \frac{dx}{dt} \bigg|_{x=x_{t_{n+1}},\, t=t_{n+1}}\] \[x_{t_n} = x_{t_{n+1}} + (t_n - t_{n+1}) \left[ f(x_{t_{n+1}}, t_{n+1}) - g(t_{n+1})^2 \nabla_x \log p_{t_{n+1}}(x_{t_{n+1}}) \right] + g(t_{n+1}) \sqrt{t_{n+1} - t_n}\, z_t\]

where \(z_t \sim \mathcal{N}(0, I)\).

This adds both a deterministic drift (using the learned score function) and a stochastic noise term to preserve the sample diversity.

If we remove the noise term (set \(z_t = 0\)), we obtain the Probability Flow ODE, a deterministic trajectory equivalent in distribution to the reverse-time SDE, which is the foundation of DDIM deterministic sampling.
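A minimal sketch of a reverse-time Euler-Maruyama sampling loop (the `sde.sde(x, t)` and `score_fn(x, t)` interfaces mirror the official implementation quoted below; the loop itself, the image-shaped broadcasting, and the step count are illustrative assumptions):

import torch

@torch.no_grad()
def em_sampler(score_fn, sde, shape, n_steps=1000, eps=1e-3):
    """Integrate the reverse SDE from t = T down to t = eps with Euler-Maruyama."""
    x = torch.randn(shape)                                   # x_T ~ N(0, I) (VP prior)
    dt = -(sde.T - eps) / n_steps                            # negative: reverse time
    for t in torch.linspace(sde.T, eps, n_steps):
        vec_t = torch.ones(shape[0]) * t
        drift, diffusion = sde.sde(x, vec_t)                 # forward drift f and diffusion g
        rev_drift = drift - diffusion[:, None, None, None] ** 2 * score_fn(x, vec_t)
        z = torch.randn_like(x)
        x = x + rev_drift * dt + diffusion[:, None, None, None] * abs(dt) ** 0.5 * z
    return x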

Implementation of Score Matching

The official implementation of Score Matching is https://github.com/yang-song/score_sde_pytorch.

In the Diffusers lib, the Score SDE VP and VE schedulers can be found here and here.

Notable functions/methods in the official implementation (PyTorch version):

  def discretize(self, x, t):
    """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

    Useful for reverse diffusion sampling and probability flow sampling.
    Defaults to Euler-Maruyama discretization.

    Args:
      x: a torch tensor
      t: a torch float representing the time step (from 0 to `self.T`)

    Returns:
      f, G
    """
    dt = 1 / self.N
    drift, diffusion = self.sde(x, t)
    f = drift * dt
    G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
    return f, G

  def reverse(self, score_fn, probability_flow=False):
    """Create the reverse-time SDE/ODE.

    Args:
      score_fn: A time-dependent score-based model that takes x and t and returns the score.
      probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
    """
    N = self.N
    T = self.T
    sde_fn = self.sde
    discretize_fn = self.discretize

    # Build the class for reverse-time SDE.
    class RSDE(self.__class__):
      def __init__(self):
        self.N = N
        self.probability_flow = probability_flow

      @property
      def T(self):
        return T

      def sde(self, x, t):
        """Create the drift and diffusion functions for the reverse SDE/ODE."""
        drift, diffusion = sde_fn(x, t)
        score = score_fn(x, t)
        drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
        # Set the diffusion function to zero for ODEs.
        diffusion = 0. if self.probability_flow else diffusion
        return drift, diffusion

      def discretize(self, x, t):
        """Create discretized iteration rules for the reverse diffusion sampler."""
        f, G = discretize_fn(x, t)
        rev_f = f - G[:, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)
        rev_G = torch.zeros_like(G) if self.probability_flow else G
        return rev_f, rev_G

    return RSDE()

IMPORTANT NOTE: For models trained with the VP SDE, the marginal distribution of \(x_t\) given clean data \(x_0\) is Gaussian: \(x_t \sim \mathcal{N}(\sqrt{\bar{\alpha}_t} x_0, \sigma_t^2 I)\), or in general SDE notation: \(x_t = \text{mean}(x_0, t) + \text{std}(t) \cdot z\) where \(z \sim \mathcal{N}(0, I)\).

Therefore, when training, we sample

mean, std = sde.marginal_prob(batch, t)
perturbed_data = mean + std * z

class VPSDE(SDE):
  def __init__(self, beta_min=0.1, beta_max=20, N=1000):
    """Construct a Variance Preserving SDE.

    Args:
      beta_min: value of beta(0)
      beta_max: value of beta(1)
      N: number of discretization steps
    """
    super().__init__(N)
    self.beta_0 = beta_min
    self.beta_1 = beta_max
    self.N = N
    self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)
    self.alphas = 1. - self.discrete_betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
    self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
    self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

  @property
  def T(self):
    return 1

  def sde(self, x, t):
    beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
    drift = -0.5 * beta_t[:, None, None, None] * x
    diffusion = torch.sqrt(beta_t)
    return drift, diffusion

  def marginal_prob(self, x, t):
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    mean = torch.exp(log_mean_coeff[:, None, None, None]) * x
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return mean, std

  def prior_sampling(self, shape):
    return torch.randn(*shape)

  def prior_logp(self, z):
    shape = z.shape
    N = np.prod(shape[1:])
    logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2.
    return logps

  def discretize(self, x, t):
    """DDPM discretization."""
    timestep = (t * (self.N - 1) / self.T).long()
    beta = self.discrete_betas.to(x.device)[timestep]
    alpha = self.alphas.to(x.device)[timestep]
    sqrt_beta = torch.sqrt(beta)
    f = torch.sqrt(alpha)[:, None, None, None] * x - x
    G = sqrt_beta
    return f, G

The SDE loss function. IMPORTANT: in the implementation, losses = (score * std + z)**2 is used. Why is that?

It is because of the Gaussian property where \(p(x_t \mid x_0) = \mathcal{N}(\mu_t, \sigma_t^2 I)\).

\[\nabla_x \log p(x_t \mid x_0) = -\frac{1}{\sigma_t^2} (x_t - \mu_t) = - \frac{z}{\sigma_t}\]

Therefore, the final loss function is:

\[\mathcal{L}_{SDE}(\theta) = \mathbb{E}_{t, x_0, z} \left[ \| s_\theta(x_t, t) + \frac{z}{\sigma_t} \|^2 \right]\]

def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5):
  """Create a loss function for training with arbirary SDEs.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    train: `True` for training loss and `False` for evaluation loss.
    reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions.
    continuous: `True` indicates that the model is defined to take continuous time steps. Otherwise it requires
      ad-hoc interpolation to take continuous time steps.
    likelihood_weighting: If `True`, weight the mixture of score matching losses
      according to https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended in our paper.
    eps: A `float` number. The smallest time step to sample from.

  Returns:
    A loss function.
  """
  reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)

  def loss_fn(model, batch):
    """Compute the loss function.

    Args:
      model: A score model.
      batch: A mini-batch of training data.

    Returns:
      loss: A scalar that represents the average loss value across the mini-batch.
    """
    score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
    t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
    z = torch.randn_like(batch)
    mean, std = sde.marginal_prob(batch, t)
    perturbed_data = mean + std[:, None, None, None] * z
    score = score_fn(perturbed_data, t)

    if not likelihood_weighting:
      losses = torch.square(score * std[:, None, None, None] + z)
      losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1)
    else:
      g2 = sde.sde(torch.zeros_like(batch), t)[1] ** 2
      losses = torch.square(score + z / std[:, None, None, None])
      losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * g2

    loss = torch.mean(losses)
    return loss

  return loss_fn

The score function



def get_score_fn(sde, model, train=False, continuous=False):
  """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    model: A score model.
    train: `True` for training and `False` for evaluation.
    continuous: If `True`, the score-based model is expected to directly take continuous time steps.

  Returns:
    A score function.
  """
  model_fn = get_model_fn(model, train=train)

  if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
    def score_fn(x, t):
      # Scale neural network output by standard deviation and flip sign
      if continuous or isinstance(sde, sde_lib.subVPSDE):
        # For VP-trained models, t=0 corresponds to the lowest noise level
        # The maximum value of time embedding is assumed to 999 for
        # continuously-trained models.
        labels = t * 999
        score = model_fn(x, labels)
        std = sde.marginal_prob(torch.zeros_like(x), t)[1]
      else:
        # For VP-trained models, t=0 corresponds to the lowest noise level
        labels = t * (sde.N - 1)
        score = model_fn(x, labels)
        std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]

      score = -score / std[:, None, None, None]
      return score

  elif isinstance(sde, sde_lib.VESDE):
    def score_fn(x, t):
      if continuous:
        labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
      else:
        # For VE-trained models, t=0 corresponds to the highest noise level
        labels = sde.T - t
        labels *= sde.N - 1
        labels = torch.round(labels).long()

      score = model_fn(x, labels)
      return score

  else:
    raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

  return score_fn

Implementation of SMLD

From the same score_sde_pytorch repository:


def get_smld_loss_fn(vesde, train, reduce_mean=False):
  """Legacy code to reproduce previous results on SMLD(NCSN). Not recommended for new work."""
  assert isinstance(vesde, VESDE), "SMLD training only works for VESDEs."

  # Previous SMLD models assume descending sigmas
  smld_sigma_array = torch.flip(vesde.discrete_sigmas, dims=(0,))
  reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)

  def loss_fn(model, batch):
    model_fn = mutils.get_model_fn(model, train=train)
    labels = torch.randint(0, vesde.N, (batch.shape[0],), device=batch.device)
    sigmas = smld_sigma_array.to(batch.device)[labels]
    noise = torch.randn_like(batch) * sigmas[:, None, None, None]
    perturbed_data = noise + batch
    score = model_fn(perturbed_data, labels)
    target = -noise / (sigmas ** 2)[:, None, None, None]
    losses = torch.square(score - target)
    losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * sigmas ** 2
    loss = torch.mean(losses)
    return loss

  return loss_fn

Implementation of DDPM


def get_ddpm_loss_fn(vpsde, train, reduce_mean=True):
  """Legacy code to reproduce previous results on DDPM. Not recommended for new work."""
  assert isinstance(vpsde, VPSDE), "DDPM training only works for VPSDEs."

  reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)

  def loss_fn(model, batch):
    model_fn = mutils.get_model_fn(model, train=train)
    labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device)
    sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device)
    sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device)
    noise = torch.randn_like(batch)
    perturbed_data = sqrt_alphas_cumprod[labels, None, None, None] * batch + \
                     sqrt_1m_alphas_cumprod[labels, None, None, None] * noise
    score = model_fn(perturbed_data, labels)
    losses = torch.square(score - noise)
    losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1)
    loss = torch.mean(losses)
    return loss

  return loss_fn

Flow Matching

Fundamental Concepts in Flow Matching

Normalizing Flow: A class of generative models that learns a transformation (or “flow”) to map a known prior distribution \(p_0\) to a target distribution \(p_1\) through a family of intermediate marginal distributions \(p_t\), where \(t \in [0, 1]\). A key requirement is that the transformation must be invertible (bijective).

Continuous Normalizing Flow: Uses an ordinary differential equation (ODE) to define continuous-time transformations between distributions.

Flow and Velocity Field:

  • The flow \(\psi_t(x)\) describes the trajectory of a point \(x\) over time.
  • The velocity field \(u_t(x)\) specifies the instantaneous direction and speed of movement
  • These are related by the ODE: \(\frac{d}{dt} \psi_t(x) = u_t (\psi_t (x))\)
  • The induced density \(p_t(x)\) evolves according to the continuity equation: \(\frac{\partial p_t(x)}{\partial t} + \nabla_x \cdot \big( u_t(x) \, p_t(x) \big) = 0.\)

This ODE describes how a point moves along the flow path: \(x_t \rightarrow x_{t+dt} = x_t + u_t(x_t)\, dt\) at time \(t\).

Key insight: The velocity field \(u_t(x)\) is the only component necessary to sample from \(p_t\) by solving the ODE. Therefore, flow matching aims to learn the velocity field \(u_t(x)\).

Derivation of the Flow Matching Objective

Starting Objective: Approximate the velocity field \(u_t(x)\) with the learned velocity field \(v_{\theta}(t, x)\).

\[\mathcal{L}_{FM} (\theta) = \mathbb{E}_{x_t \sim p_t(x)} \left[ \| v_{\theta}(t, x_t) - u_t(x_t) \|^2 \right]\]

Step 1: Expand the squared norm:

\[\mathcal{L}_{FM} (\theta) = \mathbb{E}_{x_t \sim p_t(x)} \left[ \| v_{\theta}(t, x_t) - u_t(x_t) \|^2 \right] = \mathbb{E}_{x_t \sim p_t(x)} \left[ \| v_{\theta}(t, x_t) \|^2 - 2 \langle v_{\theta}(t, x_t), u_t(x_t) \rangle + \| u_t(x_t) \|^2 \right]\]

Step 2: Express the velocity field as a conditional expectation:

\[u_t(x_t) = \int u_t(x_t \mid x_1) \frac{p_t (x_t \mid x_1) q(x_1)}{p_t(x_t)} dx_1\]

Interpretation: The velocity at \(x_t\) is a weighted average of conditional velocities \(u_t(x_t \mid x_1)\) from all possible data points \(x_1\). Points \(x_1\) that are “closer” to \(x_t\) (higher probability \(p_t (x_t \mid x_1)\)) contribute more to the velocity at \(x_t\).

Step 3: Substitute into the cross-term expectation (correlation between \(v_{\theta}(t, x_t)\) and \(u_t(x_t)\))

\[\mathbb{E}_{x_t \sim p_t(x)} \left[ \langle v_{\theta}(t, x_t), u_t(x_t) \rangle \right] = \int p_t(x_t) v_{\theta}(t, x_t) \cdot u_t(x_t) dx_t\]

Substitute \(u_t(x_t)\) to the above equation:

\[= \int \int v_{\theta}(t, x_t) \cdot u_t(x_t \mid x_1) \cdot p_t(x_t \mid x_1) \cdot q(x_1) dx_1 dx_t\] \[= \mathbb{E}_{x_t \sim p_t(x_t \mid x_1), x_1 \sim q(x_1)} \left[ v_{\theta}(t, x_t) \cdot u_t(x_t \mid x_1) \right]\]

Step 4: Rewrite the full objective using conditional expectation:

\[\mathcal{L}_{FM} (\theta) = \mathbb{E}_{x_1 \sim q(x_1), x_t \sim p_t(x_t \mid x_1)} \left[ \| v_{\theta}(t, x_t) \|^2 - 2 \langle v_{\theta}(t, x_t), u_t(x_t \mid x_1) \rangle + \| u_t(x_t) \|^2 \right]\]

Step 5: Add and subtract the term \(\| u_t(x_t \mid x_1) \|^2\)

\[\mathcal{L}_{FM} (\theta) = \mathbb{E}_{x_1 \sim q(x_1), x_t \sim p_t(x_t \mid x_1)} \left[ \| v_{\theta}(t, x_t) - u_t(x_t \mid x_1) \|^2 \right] + \mathbb{E}_{x_1 \sim q(x_1), x_t \sim p_t(x_t \mid x_1)} \left[ \| u_t(x_t) \|^2 - \| u_t(x_t \mid x_1) \|^2\right]\]

Step 6: Drop constant terms that are independent of \(\theta\):

\[\mathcal{L}_{FM} (\theta) = \mathbb{E}_{x_1 \sim q(x_1), x_t \sim p_t(x_t \mid x_1)} \left[ \| v_{\theta}(t, x_t) - u_t(x_t \mid x_1) \|^2 \right]\]

Practical Implementation:

Simply choose the linear interpolation path \(X_t = (1 - t) X_0 + t X_1\); then the conditional velocity field is:

\[u_t(x_t \mid x_1) = \frac{d}{dt} X_t = X_1 - X_0\]

This gives us a tractable training objective where we sample:

  • A time step \(t \sim \mathcal{U}(0, 1)\)
  • A data point \(x_1 \sim q(x_1)\)
  • A noise sample \(x_0 \sim \mathcal{N}(0, I)\)
  • Construct the interpolated sample \(x_t = (1 - t) x_0 + t x_1\)
  • Train \(v_{\theta}(t, x_t)\) to predict the target velocity \(x_1 - x_0\)

How to sample from a Flow Matching model

The sampling process of a Flow Matching model is similar to that of a diffusion model: we start from a noise sample \(x_0 \sim \mathcal{N}(0, I)\) and iteratively compute the next step \(x_{t+dt}\), here with a midpoint (second-order) Euler step:

\[x_{t+dt} = x_t + dt * v_{\theta}(t+dt/2, x_t+dt/2 * v_{\theta}(t, x_t))\]

where \(v_{\theta}(t, x_t)\) is the velocity field predicted by the neural network.

Flow Matching Code Example

A Standalone Flow Matching code - from [4]

import torch 
from torch import nn, Tensor

import matplotlib.pyplot as plt
from sklearn.datasets import make_moons

# Define the Flow
class Flow(nn.Module):
    def __init__(self, dim: int = 2, h: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, dim))
    
    def forward(self, t: Tensor, x_t: Tensor) -> Tensor:
        return self.net(torch.cat((t, x_t), -1))
    
    def step(self, x_t: Tensor, t_start: Tensor, t_end: Tensor) -> Tensor:
        t_start = t_start.view(1, 1).expand(x_t.shape[0], 1)
        
        return x_t + (t_end - t_start) * self(t=t_start + (t_end - t_start) / 2, x_t= x_t + self(x_t=x_t, t=t_start) * (t_end - t_start) / 2)

# Training
flow = Flow()

optimizer = torch.optim.Adam(flow.parameters(), 1e-2)
loss_fn = nn.MSELoss()

for _ in range(10000):
    x_1 = Tensor(make_moons(256, noise=0.15)[0])
    x_0 = torch.randn_like(x_1)
    t = torch.rand(len(x_1), 1)
    
    x_t = (1 - t) * x_0 + t * x_1
    dx_t = x_1 - x_0
    
    optimizer.zero_grad()
    loss_fn(flow(t=t, x_t=x_t), dx_t).backward()
    optimizer.step()

# Sampling
x = torch.randn(300, 2)
n_steps = 8
fig, axes = plt.subplots(1, n_steps + 1, figsize=(30, 4), sharex=True, sharey=True)
time_steps = torch.linspace(0, 1.0, n_steps + 1)

axes[0].scatter(x.detach()[:, 0], x.detach()[:, 1], s=10)
axes[0].set_title(f't = {time_steps[0]:.2f}')
axes[0].set_xlim(-3.0, 3.0)
axes[0].set_ylim(-3.0, 3.0)

for i in range(n_steps):
    x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i + 1])
    axes[i + 1].scatter(x.detach()[:, 0], x.detach()[:, 1], s=10)
    axes[i + 1].set_title(f't = {time_steps[i + 1]:.2f}')

plt.tight_layout()
plt.show()

In the above code, the forward function implements the velocity field \(v_{\theta}(t, x)\), and the step function advances from the current step \(X_t\) to the next step \(X_{t+dt}\) with a midpoint update

\[X_{t+dt} = X_t + dt * v_{\theta}(t+dt/2, X_t+dt/2 * v_{\theta}(t, X_t))\]

Conditional Flow Matching

(Note that “Conditional” in Conditional Flow Matching here means that a condition \(c\) is given, not the conditional vector field \(u_t(x \mid x_1)\) from the previous section.)

In conditional flow matching, we incorporate a condition \(c\) into the velocity field \(v_{\theta}(t, x, c)\). In practice, the three inputs \(t, x, c\) are concatenated together as the input to the neural network.

Sampling function:

\[X_{t+dt} = X_t + dt * v_{\theta}(t+dt/2, X_t+dt/2 * v_{\theta}(t, X_t, c), c)\]

The only change to the network is that the condition \(c\) is concatenated with \(t\) and \(x_t\) in the forward pass:

def forward(self, t: Tensor, c: Tensor, x_t: Tensor) -> Tensor:
        return self.net(torch.cat((t, c, x_t), -1))

The full example:
import torch 
from torch import nn, Tensor

import matplotlib.pyplot as plt
from sklearn.datasets import make_moons

# Define the Flow
class Flow(nn.Module):
    def __init__(self, dim: int = 2, h: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 2, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, dim))
    
    def forward(self, t: Tensor, c: Tensor, x_t: Tensor ) -> Tensor:
        return self.net(torch.cat((t, c, x_t), -1))
    
    def step(self, x_t: Tensor, t_start: Tensor, t_end: Tensor, c: Tensor) -> Tensor:
        t_start = t_start.view(1, 1).expand(x_t.shape[0], 1)
        
        return x_t + (t_end - t_start) * self(t=t_start + (t_end - t_start) / 2, c = c, x_t= x_t + self(c = c, x_t=x_t, t=t_start) * (t_end - t_start) / 2)

# Training
flow = Flow()

optimizer = torch.optim.Adam(flow.parameters(), 1e-2)
loss_fn = nn.MSELoss()

for _ in range(10000):
    x_1, c = make_moons(256, noise=0.15)
    x_1 = Tensor(x_1)
    c = Tensor(c)
    c = c.view(-1, 1)
    
    x_0 = torch.randn_like(x_1)
    t = torch.rand(len(x_1), 1)
    
    x_t = (1 - t) * x_0 + t * x_1
    dx_t = x_1 - x_0
    
    optimizer.zero_grad()
    loss_fn(flow(t=t, x_t=x_t, c=c), dx_t).backward()
    optimizer.step()

# Sampling
# --- evaluation / visualisation section --------------------------
n_samples = 256                    

sigma = 1.0
x      = torch.randn(n_samples, 2) * sigma     # (n_samples, 2)

# if you just want random labels –– otherwise load real labels here
c_eval = torch.randint(0, 2, (n_samples, 1), dtype=torch.float32)  # (n_samples, 1)

# colours for the scatter (same length as x)
colors  = ['blue' if lbl == 0 else 'orange' for lbl in c_eval.squeeze().tolist()]

# -----------------------------------------------------------------
n_steps      = 100
plot_every   = 20
plot_indices = list(range(0, n_steps + 1, plot_every))
if plot_indices[-1] != n_steps:
    plot_indices.append(n_steps)

fig, axes   = plt.subplots(1, len(plot_indices), figsize=(4 * len(plot_indices), 4),
                           sharex=True, sharey=True)
time_steps  = torch.linspace(0, 1.0, n_steps + 1)

# initial frame
axes[0].scatter(x[:, 0], x[:, 1], s=10, c=colors)
axes[0].set_title(f't = {time_steps[0]:.2f}')
axes[0].set_xlim(-3.0, 3.0)
axes[0].set_ylim(-3.0, 3.0)

plot_count = 0
with torch.no_grad():                         # no gradients while sampling
    for i in range(n_steps):
        x = flow.step(x_t=x,
                      t_start=time_steps[i],
                      t_end=time_steps[i + 1],
                      c=c_eval)               # use the same-sized label tensor
        if (i + 1) in plot_indices:
            plot_count += 1
            axes[plot_count].scatter(x[:, 0], x[:, 1], s=10, c=colors)
            axes[plot_count].set_title(f't = {time_steps[i + 1]:.2f}')
            axes[plot_count].set_xlim(-3.0, 3.0)
            axes[plot_count].set_ylim(-3.0, 3.0)

plt.tight_layout()
plt.show()

References:

Differences between Score Matching, Diffusion Models and Flow Matching

Summary of Main Differences

All three methods learn generative models by establishing a connection between a simple noise distribution and a complex data distribution, but they differ fundamentally in their formulation and training approach:

| Aspect | Diffusion Models (DDPM/DDIM) | Score Matching (SDE) | Flow Matching (CFM) |
|---|---|---|---|
| Core learning target | Learn to predict the noise \(\epsilon_t\) or the data \(x_t\) from the previous step \(x_{t-1}\) | Learn the score function \(\nabla_x \log p_t(x)\) | Learn the velocity field \(u_t(x)\) |
| Process type | Discrete Markov chain | Continuous SDE | Continuous ODE |
| Forward process | Add Gaussian noise step-by-step | Stochastic diffusion (SDE with drift + noise) | Deterministic interpolation path |
| Backward process | Reverse Markov chain | Reverse SDE | ODE integration |
| Tractability | Forward process tractable | Forward SDE tractable | Conditional paths tractable |
| Training paradigm | Denoising autoencoder | Denoising score matching | Conditional flow matching |
| Sampling | Iterative denoising (stochastic or deterministic) | SDE/ODE integration | ODE integration (typically straight paths) |
| Path geometry | Curved noising trajectory | Stochastic curved paths | Straight/optimal transport paths |
| Key advantage | Simple, well-understood | Theoretically grounded, flexible | Fast sampling, simple training |

Relationship:

  • DDIM can be viewed as a discretization of a probability flow ODE derived from the Score Matching SDE
  • Flow Matching with Gaussian probability paths recovers Score Matching formulations
  • Flow Matching with linear interpolation paths gives the deterministic version similar to DDIM
  • All three can be unified under a common framework of learning to transform distributions

Detailed Differences

1. Time Convention

Understanding time conventions is crucial for comparing these methods:

Flow Matching:

  • Continuous time \(t \in [0, 1]\)
  • \(t = 0\): noise distribution \(p_0(x) = \mathcal{N}(0, I)\)
  • \(t = 1\): data distribution \(p_1(x) = q(x)\)
  • Forward in time moves from noise → data

Diffusion Models (DDPM, DDIM):

  • Discrete time with steps \(t \in \{0, 1, 2, ..., T\}\)
  • \(t = 0\): data distribution \(q(x_0)\)
  • \(t = T\): noise distribution \(\mathcal{N}(0, I)\)
  • Forward in time moves from data → noise (opposite of Flow Matching!)
  • To align with Flow Matching convention, we use \(r = T - t\), so:
    • \(r = 0\) corresponds to noise
    • \(r = T\) corresponds to data

Score Matching (SDE):

  • Continuous time \(t \in [0, T]\) (often \(T = 1\))
  • \(t = 0\): data distribution \(p_0(x) = q(x)\)
  • \(t = T\): noise distribution \(p_T(x) \approx \mathcal{N}(0, \sigma^2 I)\)
  • Forward in time moves from data → noise
  • Using \(r = T - t\) for consistency: \(r = 0\) is noise, \(r = T\) is data

2. Forward Process vs. Probability Paths

Diffusion Models (DDPM, DDIM):

The forward process progressively corrupts data by adding Gaussian noise through a Markov chain:

\[q(x_t \mid x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} x_{t-1}, \beta_t I)\]

With reparameterization:

\[x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]

where \(\bar{\alpha}_t = \prod_{s=1}^t (1 - \beta_s)\).

  • Discrete steps: Each transition is a single Gaussian convolution
  • Tractable: \(q(x_t \mid x_0)\) has closed form
  • Markovian: Each step depends only on the previous state

Score Matching (SDE):

The forward process is a continuous stochastic differential equation (SDE):

\[dx = f(x, t)dt + g(t)dW_t\]

where:

  • \(f(x, t)\) is the drift coefficient (deterministic component)
  • \(g(t)\) is the diffusion coefficient (stochastic component)
  • \(W_t\) is the Wiener process (Brownian motion)

Common example (Variance Exploding - VE): \(dx = 0 \cdot dt + \sqrt{\frac{d\sigma_t^2}{dt}} dW_t\)

Common example (Variance Preserving - VP): \(dx = -\frac{1}{2}\beta_t x \, dt + \sqrt{\beta_t} dW_t\)

  • Continuous time: Infinitesimal noise additions
  • Stochastic: Includes random Brownian motion
  • Non-Markovian in discrete time but Markovian in continuous time

Flow Matching:

The forward process defines probability paths that interpolate between distributions:

\[p_t(x) = \int p_t(x \mid x_1) q(x_1) dx_1\]

Where the conditional probability path is often chosen as:

\[p_t(x_t \mid x_1) = \mathcal{N}(x_t; \mu_t(x_1), \sigma_t^2(x_1) I)\]

For linear interpolation (Optimal Transport path): \(x_t = (1-t)x_0 + t x_1, \quad x_0 \sim \mathcal{N}(0, I)\)

This gives: \(\mu_t(x_1) = t x_1, \quad \sigma_t = 1 - t\)

  • Deterministic paths (no stochastic component in the ODE)
  • Conditional paths are tractable by design
  • Straight trajectories (shortest path in many metrics)

Connections:

  • The velocity field \(u_t(x)\) in Flow Matching corresponds to the drift term \(f(x, t)\) in Score Matching
  • The conditional probability path \(p_t(x_t \mid x_1)\) in Flow Matching corresponds to the SDE solution initialized at \(x_1\)
  • The marginal probability path \(p_t(x)\) in Flow Matching corresponds to the SDE marginal when initialized from data \(x_0 \sim q(x)\)

Special Cases:

  • Flow Matching with Gaussian probability paths (with appropriate \(\mu_t, \sigma_t\)) recovers the forward SDE from Score Matching
  • Flow Matching with linear interpolation gives the deterministic probability flow ODE, similar to DDIM

3. Training Objective

Diffusion Models (DDPM):

Train a neural network to predict the noise added in the forward process:

\[\mathcal{L}_{DDPM}(\theta) = \mathbb{E}_{t, x_0, \epsilon} \left[ \| \epsilon_\theta(x_t, t) - \epsilon \|^2 \right]\]

where \(x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon\) (\(\epsilon\)-prediction).

Alternative formulation (\(x_0\)-prediction):

\[\mathcal{L}_{DDPM}(\theta) = \mathbb{E}_{t, x_0, \epsilon} \left[ \| \hat{x}_\theta(x_t, t) - x_0 \|^2 \right]\]
  • Target: Noise \(\epsilon\) or clean data \(x_0\)
  • Simple: Direct regression on known quantities
  • Weighted MSE: Can add time-dependent weighting

Score Matching (DSM - Denoising Score Matching):

The reverse SDE requires the score function \(\nabla_x \log p_t(x)\), which is intractable. Vincent (2011) proposed training a network \(s_\theta(x_t, t)\) to approximate the conditional score:

\[\mathcal{L}_{DSM}(\theta) = \mathbb{E}_{t, x_0, x_t \mid x_0} \left[ \| s_\theta(x_t, t) - \nabla_{x_t} \log q(x_t \mid x_0) \|^2 \right]\]

Under the assumption that \(q(x_t \mid x_0) \approx p_t(x_t)\) (reasonable for small noise), this approximates the true score.

For Gaussian perturbations \(q(x_t \mid x_0) = \mathcal{N}(\alpha_t x_0, \sigma_t^2 I)\):

\[\nabla_{x_t} \log q(x_t \mid x_0) = -\frac{x_t - \alpha_t x_0}{\sigma_t^2} = -\frac{\epsilon}{\sigma_t}\]

So the objective becomes:

\[\mathcal{L}_{DSM}(\theta) = \mathbb{E}_{t, x_0, \epsilon} \left[ \left\| s_\theta(\alpha_t x_0 + \sigma_t \epsilon, t) + \frac{\epsilon}{\sigma_t} \right\|^2 \right]\]
  • Target: Score function (gradient of log probability)
  • Theoretical: Grounded in score-based generative modeling theory
  • Flexible: Works with any forward SDE

Flow Matching (CFM):

Train a network to predict the velocity field:

\[\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t, x_1, x_t \mid x_1} \left[ \| v_\theta(t, x_t) - u_t(x_t \mid x_1) \|^2 \right]\]

For linear interpolation \(x_t = (1-t)x_0 + t x_1\):

\[u_t(x_t \mid x_1) = \frac{d x_t}{dt} = x_1 - x_0\]

So:

\[\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t, x_0, x_1} \left[ \| v_\theta(t, x_t) - (x_1 - x_0) \|^2 \right]\]
  • Target: Velocity (direction and magnitude of flow)
  • Direct: Straightforward regression on vector field
  • Efficient: Often requires fewer sampling steps

Equivalence:

The training objectives are closely related through reparameterizations:

  • Score to Noise: \(s_\theta(x_t, t) = -\frac{\epsilon_\theta(x_t, t)}{\sigma_t}\)
  • Velocity to Noise: \(v_\theta(t, x_t) = \frac{\alpha_t}{\sigma_t} \epsilon_\theta(x_t, t)\) (approximately, for certain schedulers)
  • Score Matching with Gaussian paths is equivalent to Flow Matching with appropriate probability path parameterization

4. Sampling Process

Diffusion Models:

DDPM (Stochastic Sampling):

Iteratively denoise by reversing the forward Markov chain:

\[x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z\]

where \(z \sim \mathcal{N}(0, I)\) and \(\sigma_t\) is the noise variance.

  • Stochastic: Adds noise at each step
  • Many steps: Typically 1000 steps (can be reduced with techniques)
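A minimal sketch of this stochastic update (assuming standard DDPM notation with per-step `alphas`, cumulative `alphas_cumprod`, and posterior standard deviations `sigmas`; names are illustrative):

import torch

@torch.no_grad()
def ddpm_step(x_t, t, eps_model, alphas, alphas_cumprod, sigmas):
    """One stochastic DDPM reverse step x_t -> x_{t-1}."""
    eps = eps_model(x_t, t)
    mean = (x_t - (1 - alphas[t]) / (1 - alphas_cumprod[t]).sqrt() * eps) / alphas[t].sqrt()
    z = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)   # no noise at the final step
    return mean + sigmas[t] * z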

DDIM (Deterministic Sampling):

Use a deterministic update rule:

\[x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \underbrace{\left( \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}} \right)}_{\text{predicted } x_0} + \sqrt{1 - \bar{\alpha}_{t-1}} \epsilon_\theta(x_t, t)\]
  • Deterministic: No added noise
  • Fewer steps: 10-50 steps often sufficient
  • Equivalent to probability flow ODE

Score Matching:

Reverse-time SDE:

\[dx = [f(x, t) - g(t)^2 \nabla_x \log p_t(x)] dt + g(t) d\bar{W}_t\]

where \(\bar{W}_t\) is a reverse-time Brownian motion.

  • Stochastic: Includes diffusion term
  • Continuous: Integrated numerically (Euler-Maruyama, etc.)

Probability Flow ODE (Deterministic):

\[\frac{dx}{dt} = f(x, t) - \frac{1}{2} g(t)^2 \nabla_x \log p_t(x)\]
  • Deterministic: No stochastic component
  • Same marginals as the SDE
  • Flexible solvers: Use any ODE solver (Runge-Kutta, adaptive methods)

Flow Matching:

Integrate the learned velocity field ODE:

\[\frac{dx}{dt} = v_\theta(t, x)\]

Starting from \(x_0 \sim \mathcal{N}(0, I)\) at \(t=0\), integrate to \(t=1\) with any ODE solver; a minimal Euler sketch is shown below.
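A minimal Euler-integration sketch (assuming a trained velocity network `v_model(t, x)` that takes a per-sample time column, as in the standalone example earlier; the step count and names are illustrative):

import torch

@torch.no_grad()
def fm_sample(v_model, shape, n_steps=50):
    """Integrate dx/dt = v_theta(t, x) from t = 0 (noise) to t = 1 (data) with Euler steps."""
    x = torch.randn(shape)
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((shape[0], 1), i * dt)   # current time, one column per sample
        x = x + dt * v_model(t, x)
    return x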

References

Noise scheduling

Noise scheduling in diffusion models refers to how noise is gradually added to data in the forward process and how it is removed in the reverse process. The choice of noise schedule significantly impacts the model’s performance, sample quality, and training efficiency.

We follow the DDIM convention, where \(0 < \bar{\alpha}_t < 1\), \(\beta_t = 1 - \bar{\alpha}_t\) is the per-step noise level at time \(t\), and \(\alpha_t = \prod_{i=1}^{t} \bar{\alpha}_i\) is the cumulative signal level at time \(t\). With this convention, \(x_t = \sqrt{\alpha_t}\, x_0 + \sqrt{1-\alpha_t}\, \epsilon\), and \(\alpha_T \approx 0\) when \(t \rightarrow T\) while \(\alpha_0 \approx 1\) when \(t \rightarrow 0\).

Common principles of noise scheduling:

  • Add a large amount of noise when \(t\) is large and a small amount when \(t\) is small; \(t=0\) means clean data, \(t=T\) means pure noise.
  • The rate of change \(\frac{d\beta_t}{dt}\) should also stay at a reasonable pace (but I am not sure :D)

Common noise schedules (compared numerically in the sketch after this list):

  • Linear: \(\alpha_t = 1 - \frac{t}{T}\) or \(\beta_t = \beta_{\min} + (\beta_{\max} - \beta_{\min})\frac{t}{T}\). Issue: early timesteps do not add enough noise, and late timesteps can add too much noise.
  • Cosine: \(\beta_t = \beta_{\min} + 0.5 (\beta_{\max} - \beta_{\min}) \left( 1 - \cos\left(\frac{t}{T} \pi\right)\right)\). The intuition is to add noise more gradually at the start and faster at the end.
  • Exponential: \(\beta_t = \beta_{\max} (\beta_{\min} / \beta_{\max})^{\frac{t}{T}}\)
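A small numerical sketch comparing the linear and cosine schedules (the values of \(\beta_{\min}\), \(\beta_{\max}\), and \(T\) are common DDPM defaults, not prescribed by this post):

import numpy as np

T, beta_min, beta_max = 1000, 1e-4, 0.02
t = np.arange(1, T + 1)

beta_linear = beta_min + (beta_max - beta_min) * t / T
beta_cosine = beta_min + 0.5 * (beta_max - beta_min) * (1 - np.cos(np.pi * t / T))

# Cumulative signal level alpha_t = prod_i (1 - beta_i): ~1 near t = 0, ~0 near t = T
alpha_linear = np.cumprod(1 - beta_linear)
alpha_cosine = np.cumprod(1 - beta_cosine)
print(alpha_linear[0], alpha_linear[-1])   # close to 1, close to 0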

Guided Diffusion

Resources:

Why Guidance?

Guidance is a method to control the generation process so that the output is sampled from a conditional distribution \(p(x \mid y)\), where \(y\) is a condition, such as a text prompt, rather than from the generic \(p(x)\).

Classifier Guidance

In order to get the conditional score function \(\nabla_x \ln p(x \mid y)\), we can use Bayes rule to decompose the score function into an unconditional component and a conditional one:

\[p(x \mid y) = \frac{p(y \mid x) p(x)}{p(y)}\] \[\log p(x \mid y) = \log p(y \mid x) + \log p(x) - \log p(y)\] \[\nabla_x \log p(x \mid y) = \nabla_x \log p(y \mid x) + \nabla_x \log p(x) - \nabla_x \log p(y)\]

where \(\nabla_x \log p(x)\) is the score function of the unconditional model. \(\nabla_x \log p(y) = 0\) since \(p(y)\) is independent of \(x\).

The term \(\nabla_x \log p(y \mid x)\) means the direction pointing to \(y\) given \(x\).

  • At the beginning of the inference process, i.e., large \(t\), when \(x_t\) still contains a lot of noise, \(\nabla_x \log p(y \mid x)\) is close to \(0\), meaning there is no clear information about \(y\).
  • In later stages, i.e., small \(t\), when \(x_t\) is less noisy and closer to \(x_0\), \(\nabla_x \log p(y \mid x)\) is larger, meaning \(x_t\) carries more information about \(y\), i.e., larger \(p(y \mid x)\).

How to obtain \(\nabla_x \log p(y \mid x)\)?

\(p(y \mid x)\) means the probability of a condition \(y\) given \(x\). In a simple case, where \(y\) is just an image class, like a cat, the probability \(p(y=\text{cat} \mid x)\) can simply be obtained from a pre-trained classifier.

However, in a more complex case, where \(y\) is a text prompt like a black cat with red eyes and blue fur, a pre-trained classifier is not expressive enough, i.e., it cannot distinguish between \(y_1\) a black cat with red eyes and blue fur and \(y_2\) a white cat with blue eyes and red fur, or mathematically \(p(y_1 \mid x) \approx p(y_2 \mid x)\).

In other words, the quality and diversity of the generated image \(x\) strongly depend on the capability of the conditional model \(p(y \mid x)\). For example:

  • If \(p_\phi(y \mid x)\) is a binary classifier hot dog or not hot dog, then the output image \(x \sim p_\theta(x \mid y)\) can only be a hot dog or not a hot dog, even if \(p_\theta(x)\) was trained on a massive dataset with many more classes than just these two.
  • If you want to generate an image \(x\) from a complex prompt \(y\), you need a powerful model like CLIP as the conditional model \(p_\phi(y \mid x)\).

To balance between specificity (i.e., high \(p(y \mid x)\)) and diversity/quality (i.e., \(p(x \mid y) \approx p(x)\)), we use a guidance scale \(\gamma\) to control the trade-off between the two.

\[\nabla_x \log p_{\textcolor{red}{\gamma}}(x \mid y) = \nabla_x \log p(x) + \gamma \nabla_x \log p(y \mid x)\]

where \(\gamma\) is the guidance scale. A large \(\gamma\) means the model is less creative but follows the condition \(y\) more closely.
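A minimal sketch of how this guided score could be assembled, assuming a score network score_model(x, t) and a noise-aware classifier classifier(x, t) returning class logits (both names and signatures are placeholders):

import torch

def classifier_guided_score(x, y, t, score_model, classifier, gamma=1.0):
    """Guided score: grad_x log p(x) + gamma * grad_x log p(y | x) (minimal sketch)."""
    x = x.detach().requires_grad_(True)
    log_probs = torch.log_softmax(classifier(x, t), dim=-1)
    log_p_y = log_probs[torch.arange(x.shape[0]), y].sum()    # sum of log p(y_i | x_i)
    grad_log_p_y = torch.autograd.grad(log_p_y, x)[0]
    with torch.no_grad():
        uncond_score = score_model(x, t)
    return uncond_score + gamma * grad_log_p_y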

Classifier-Free Guidance (CFG)

Classifier guidance has two key limitations:

  1. It requires a powerful external classifier \(p_\phi(y \mid x)\).
  2. The generative model \(p_\theta(x)\) may not match the domain of interest (e.g., trained on ImageNet but asked to generate medical images).

Classifier-free guidance solves both by eliminating the explicit classifier. Instead, we train the diffusion model itself in two modes:

  • Unconditional: modeling \(p_\theta(x)\)
  • Conditional: modeling \(p_\theta(x \mid y)\)

This is achieved simply by randomly dropping out the condition \(y\) during training (with some probability, e.g., 10–20%). Rather than performing truly unconditional generation \(p_\theta(x)\), which can be complicated in implementation, we can use a null condition \(\emptyset\), e.g., an empty string, so that \(p_\theta(x) = p_\theta(x \mid \emptyset)\). However, my intuition is that this might implicitly imply the null concept lies in a low-density region of the data manifold. More specifically, without the null condition, the interpolation between \(p_\theta(x \mid y_1)\) and \(p_\theta(x \mid y_2)\), such as \(p_\theta(x \mid (1-t)y_1 + ty_2)\), might be a smooth transition between the two concepts \(y_1\) and \(y_2\); with the null condition, the interpolation might instead jump from one concept to another.
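A minimal sketch of this condition dropout during training (cond_emb and null_emb are assumed embeddings of the prompt and of the null condition):

import torch

def maybe_drop_condition(cond_emb, null_emb, p_drop=0.1):
    """Randomly replace the condition with the null embedding (CFG-style training)."""
    keep = (torch.rand(cond_emb.shape[0], device=cond_emb.device) > p_drop).float()
    keep = keep.view(-1, *([1] * (cond_emb.ndim - 1)))        # broadcast over remaining dims
    return keep * cond_emb + (1 - keep) * null_emb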

At inference time, we can combine these two models into a guided score. The derivation is as follows:

We start with Bayes rule for the conditional score that we want to approximate:

\[p(y \mid x) = \frac{p(x \mid y) p(y)}{p(x)}\]

Applying log-likelihood and taking the derivative w.r.t. \(x\):

\[\nabla_x \log p(y \mid x) = \nabla_x \log p(x \mid y) + \nabla_x \log p(y) - \nabla_x \log p(x)\]

Dropping the term \(\nabla_x \log p(y)\) since \(p(y)\) is independent of \(x\):

\[\nabla_x \log p(y \mid x) = \nabla_x \log p(x \mid y) - \nabla_x \log p(x)\]

Replacing this into the Classifier-guidance formula:

\[\nabla_x \log p_{\textcolor{red}{\gamma}}(x \mid y) = \nabla_x \log p(x) + \gamma (\nabla_x \log p(x \mid y) - \nabla_x \log p(x))\]

that is:

\[\nabla_x \log p_{\textcolor{red}{\gamma}}(x \mid y) = (1 - \gamma) \nabla_x \log p(x) + \gamma \nabla_x \log p(x \mid y)\]

where \(\gamma \geq 0\) is the guidance scale.

  • \(\gamma = 0\) → purely unconditional generation \(p(x \mid \emptyset)\)
  • \(\gamma = 1\) → purely conditional generation \(p(x \mid y)\).
  • \(\gamma > 1\) → amplifies the effect of the condition but less creative, trading off diversity for fidelity.
  • Geometrically, CFG moves along the line through the unconditional and conditional score vectors (interpolating for \(\gamma \in [0, 1]\), extrapolating for \(\gamma > 1\)), pushing the sample further in the direction that aligns with \(y\).
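In practice this combination is usually applied to the noise prediction rather than the score (the two differ only by a scale factor). A minimal sketch, assuming a conditional noise network with the hypothetical signature eps_model(x, t, cond):

import torch

def cfg_noise_pred(eps_model, x_t, t, cond, null_cond, gamma=7.5):
    """eps_uncond + gamma * (eps_cond - eps_uncond), i.e., (1 - gamma) eps_uncond + gamma eps_cond."""
    eps_uncond = eps_model(x_t, t, null_cond)
    eps_cond = eps_model(x_t, t, cond)
    return eps_uncond + gamma * (eps_cond - eps_uncond)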

Why Classifier-Free Guidance works better than Classifier Guidance?
(Extending the intuition from Sander Dieleman) — the key lies in the difference between the gradient from a standard/external classifier \(\phi\) and the gradient from the generative model itself \(\theta\).

  • Classifier Guidance relies on
    \(\nabla_x \log p_{\phi}(y \mid x),\)
    where the classifier is trained independently from the generative process.

  • Classifier-Free Guidance (CFG) instead leverages the implicit classifier inside the generative model:
    \(\nabla_x \log p_{\theta}(x \mid y) + \nabla_x \log p_{\theta}(y) - \nabla_x \log p_{\theta}(x).\)

A well-known phenomenon of discriminative classifiers is shortcut learning (Geirhos et al., 2020): gradient-descent-trained classifiers often find “shortcuts” that optimize the loss but fail to align with human perception. For example, they may overfit to local texture cues rather than global shape, producing gradients that push generation toward features humans do not find semantically meaningful.

By contrast, the generative classifier (implicit in CFG) is trained to model and reconstruct the data distribution itself conditioned on human-meaningful labels/prompts. Its gradients are therefore better aligned with human perceptual semantics.

👉 In short: CFG works better under the human perspective (more semantically aligned generations), whereas Classifier Guidance works better under the classifier perspective (gradients aligned with a discriminative model that may exploit spurious correlations).

When is Classifier Guidance better?
I’ve explored a proposal linking this to Machine Unlearning (although not yet published). The idea is to guide unlearning with the gradient of an external classifier, rather than relying solely on the generative classifier as in the [ESD paper]. This approach can be particularly beneficial in two cases:

  • Unlearning rare concepts: where the generative model assigns very low probability \(p_{\theta}(x \mid y)\), making CFG ineffective.
  • Unlearning ambiguous or multi-expressed concepts: e.g., “nudity” vs “naked”. A discriminative classifier can unify these expressions under a shared semantic decision boundary, while the generative model may treat them as distinct.

Thus, while CFG dominates in general generation tasks, Classifier Guidance can provide unique advantages for targeted unlearning.

Why \(\gamma > 1\) but not between \(0\) and \(1\)?

In typical interpolation scenarios, we expect \(\gamma \in [0,1]\) to balance unconditional and conditional influences. Surprisingly, in CFG, values of \(\gamma > 1\) (e.g., 7 or 7.5) are commonly used — and empirically yield sharper, more faithful generations.

Two sets of samples from OpenAI's GLIDE model for the prompt 'A stained glass window of a panda eating bamboo.', taken from their paper. Guidance scale 1 (no guidance) on the left, guidance scale 3 on the right. Image source: Sander Dieleman's blog.
Two sets of samples from OpenAI's GLIDE model for the prompt 'A cozy living room with a painting of a corgi on the wall above a couch and a round coffee table in front of a couch and a vase of flowers on a coffee table.', taken from their paper. Guidance scale 1 (no guidance) on the left, guidance scale 3 on the right. Image source: Sander Dieleman's blog.

As demonstrated in Dhariwal & Nichol, 2021, larger \(\gamma\) produces outputs with higher fidelity and stronger alignment to prompts, albeit at the cost of reduced diversity. The intuition is that \(\log p_{\gamma}(y \mid x)\) becomes sharper than \(\log p_{1}(y \mid x)\) when \(\gamma > 1\), effectively amplifying conditional gradients and biasing the generation toward features that match the conditioning signal more strongly.

This explains why CFG practitioners often “turn up the guidance dial” above 1 — it helps the model stay on track with the prompt, even if some creativity/diversity is sacrificed.

Intuition Recap

  • Classifier guidance: adds an external force from a classifier \(p(y \mid x)\).
  • Classifier-free guidance: reuses the generative model itself, trained with and without condition, to simulate that force.
  • Both balance a trade-off between diversity and conditional alignment, controlled by the guidance scale \(\gamma\).

Latent Diffusion

Conditional Diffusion

Control-Net

Image Prompt

Beyond controlling the generation process with a text prompt, a hot topic in the community is controlling generation with image information/layout/prompts - which has huge potential in applications, e.g., image inpainting, image-to-image generation, etc. In the standard Stable Diffusion, the condition embedding \(c_t\) is just a text embedding \(c_t = E_t(y)\), where \(y\) is the text prompt and \(E_t\) is a pre-trained text encoder such as CLIP. IP-Adapter [1] proposes to use an additional image encoder to extract an image embedding from a reference image, \(c_i = E_i(x)\), and then project it into the original condition space. The objective function for IP-Adapter is:

\[\mathcal{L}_{IP} = \mathbb{E}_{z, c, \epsilon, t} \left[ \| \epsilon - \epsilon_\theta(z_t, c_i, c_t, t) \|_2^2 \right]\]

The cross-attention layer is also modified from the one in Stable Diffusion to include the image embedding \(c_i\) as a condition.

\[\text{Attention}(Q, K_i, K_t, V_i, V_t) = \lambda\, \text{softmax}\left(\frac{QK_i^T}{\sqrt{d}}\right)V_i + \text{softmax}\left(\frac{QK_t^T}{\sqrt{d}}\right)V_t\]

where \(Q=z W_Q\), \(K_i = c_i W_K^i\), \(K_t = c_t W_K^t\), \(V_i = c_i W_V^i\), \(V_t = c_t W_V^t\), and \(W_Q\), \(W_K^i\), \(W_K^t\), \(W_V^i\), \(W_V^t\) are the weights of the linear layers. The model becomes the original Stable Diffusion when \(\lambda = 0\).
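A minimal sketch of this decoupled cross-attention, with z as the query features, c_t / c_i as text and image embeddings, and the W_* matrices as the projection weights (single-head, no reshaping, purely illustrative):

import torch
import torch.nn.functional as F

def decoupled_cross_attention(z, c_t, c_i, W_Q, W_Kt, W_Vt, W_Ki, W_Vi, lam=1.0):
    """Text cross-attention plus lambda-scaled image cross-attention (minimal sketch)."""
    Q = z @ W_Q
    K_t, V_t = c_t @ W_Kt, c_t @ W_Vt
    K_i, V_i = c_i @ W_Ki, c_i @ W_Vi
    d = Q.shape[-1]
    attn_t = F.softmax(Q @ K_t.transpose(-2, -1) / d ** 0.5, dim=-1) @ V_t
    attn_i = F.softmax(Q @ K_i.transpose(-2, -1) / d ** 0.5, dim=-1) @ V_i
    return attn_t + lam * attn_i                              # lam = 0 recovers the text-only model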

References:


Diffusion Transformers

The Diffusion Transformers (DiTs) is a class of diffusion models that replace the traditional U-Net convolutional architecture with a Vision Transformer (ViT) architecture as a backbone.

Data Processing in DiT

Similar to the Latent Diffusion Model (LDM), the diffusion process in DiT operates in the latent space. Therefore, the first step is to use a pre-trained convolutional Variational Autoencoder (VAE), as in LDM, to convert the spatial input into the latent space (e.g., \(256 \times 256 \times 3\) to \(32 \times 32 \times 4\)).

Patchifying converts the (latent) spatial input into a sequence of \(T\) tokens/patches, each of dimension \(d\), by embedding each patch with a linear layer.

Positional Encoding adds standard sinusoidal positional embeddings to the token embeddings to provide the model with positional information.

Besides the visual tokens, DiT also uses conditioning information such as the timestep \(t\) and the textual prompt \(c\) associated with the input image. This information is added to the DiT block through an embedding layer.

The DiT Architecture

The following conditioning designs are studied in the DiT paper:

In-Context Conditioning (the far right in the figure above) Append the vector embeddings of \(t\) and \(c\) to the input sequence, treating them as additional tokens. This is similar to the CLS token in ViT.

Cross-Attention Concatenate the embeddings of \(t\) and \(c\) into a length-two sequence, kept separate from the image token sequence. An additional cross-attention layer then injects this conditioning information into the visual path.

Adaptive layer norm (adaLN) block Following the widespread success of adaptive normalization layers in diffusion models with U-Net backbones, DiT replaces the standard layer norm in transformer blocks with an adaptive layer norm. Rather than directly learning dimension-wise scale and shift parameters \(\gamma\) and \(\beta\), the adaLN block regresses them from the sum of the embeddings of the conditioning information \(t\) and \(c\).

adaLN-Zero block Prior work on ResNets has found that initializing each residual block as the identity function is beneficial. This version uses the same adaptive layer norm as the adaLN block, but additionally regresses a dimension-wise scale \(\alpha\) applied just before each residual connection and zero-initializes its regression layer, so every DiT block is initialized as the identity function.
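A minimal sketch of an adaLN-Zero modulation module (the class name and shapes are assumptions; it regresses scale, shift, and a residual gate from the conditioning embedding and zero-initializes the regression so the block starts as the identity):

import torch
import torch.nn as nn

class AdaLNZero(nn.Module):
    """Regress (gamma, beta, alpha) from cond; zero-init so the block starts as identity."""
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.proj = nn.Linear(cond_dim, 3 * dim)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):                               # x: (B, N, dim), cond: (B, cond_dim)
        gamma, beta, alpha = self.proj(cond).chunk(3, dim=-1)
        modulated = self.norm(x) * (1 + gamma.unsqueeze(1)) + beta.unsqueeze(1)
        return modulated, alpha.unsqueeze(1)                  # alpha gates the residual branch output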

Transformer Decoder After the final DiT block, we need to decode the sequence of image tokens into an output latent noise prediction and an output diagonal covariance prediction (two outputs). This is done by a standard linear layer with output dimension \(p \times p \times 2C\), where \(p\) is the patch size and \(C\) is the number of channels of the latent input.

References:


Diffusion Flux

References:

Image Inpainting with Diffusion Models

Training Pipeline

Training data for inpainting combines three components: the original image as ground truth, the masked image as input, and a prompt as the condition providing context for the missing region.

To ensure the model is robust, a variety of mask shapes and sizes can be used, including rectangular masks, free-form masks, and arbitrary shapes.

The loss function for inpainting is typically a combination of a pixel-wise reconstruction loss and a perceptual (or style) loss. If a GAN is involved, an adversarial loss is added to ensure the inpainted regions are perceptually realistic from the discriminator's perspective.

Challenges in Image Inpainting

Semantic and Structural Consistency: A primary challenge for generative models is to fill in missing regions in a way that is not only visually plausible but also semantically and structurally consistent with the rest of the image.

Semantic ambiguity means that the missing region can be filled in multiple ways; e.g., filling a gap in a street scene could mean extending the road, adding a pedestrian, or adding a vehicle. Even when an input prompt is given, the task remains difficult when concept leaking occurs, e.g., “a black cat on a white background” vs “a white cat on a black background”.

Long-range dependency and global structure: another significant hurdle. While generative models excel at local details, they can struggle with the broader context, lighting, and perspective.

Perceptual realism: another key challenge. Even if the inpainted regions are visually consistent, they may not align with human perception. For example, an inpainting might produce unrealistic shadows or reflections, overly smooth textures, or artificial artifacts.

Large missing regions: The size of missing area is directly proportional to the difficulty of the task.

Accelerating Diffusion Models

Consistency Models

The core idea behind Consistency Models (CMs) is elegantly simple yet powerful:

“Points on the same trajectory should map to the same initial point.”

Concept and Mathematical Definition

Formally, consider a solution trajectory \(\{ x_t \}_{t \in [\epsilon, T]}\) of the Probability Flow ODE:

\[\frac{dx}{dt} = \mu(x, t) - \frac{1}{2} \sigma(t)^2 \nabla_x \log p_t(x)\]

We define the consistency function as

\[f: (x_t, t) \mapsto x_{\epsilon}\]

Intuitively, this function maps any point along a diffusion trajectory to its corresponding starting point at time \(\epsilon\).

A valid consistency function must satisfy the self-consistency property:

\[f(x_t, t) = f(x_{t'}, t') \quad \forall t, t' \in [\epsilon, T]\]

That is, any two points on the same trajectory—no matter when they occur—should yield the same mapped output.

The objective of a consistency model \(f_{\theta}\) is to learn this mapping from data while enforcing this self-consistency constraint.

Determinism and Relation to Probability Flow ODE

Unlike the stochastic nature of the diffusion SDE, the Probability Flow ODE is deterministic.
Given a fixed starting point \(x_T\), the trajectory and its corresponding final point \(x_{\epsilon}\) are uniquely determined for all \(t \in [\epsilon, T]\).

Sampling with Consistency Models

Once trained, a consistency model \(f_{\theta}\) can generate samples in a single step:

  1. Sample a random latent point \(x_T \sim \mathcal{N}(0, I)\)
  2. Map it to data space with
    \(x_{\epsilon} = f_{\theta}(x_T, T)\)

This one-step sampling process is deterministic and efficient.
Alternatively, we can perform multi-step sampling by injecting small amounts of noise at each step, introducing stochasticity to improve sample diversity.
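A minimal sketch of this sampling procedure, assuming a trained consistency model f_theta(x, sigma) and a decreasing list of noise levels sigmas (one-step sampling corresponds to using only sigmas[0]):

import torch

def consistency_sample(f_theta, shape, sigmas, sigma_min=0.002):
    """One-step sampling followed by optional multi-step refinement via re-noising (sketch)."""
    x = torch.randn(shape) * sigmas[0]                        # x_T ~ N(0, T^2 I)
    x = f_theta(x, sigmas[0])                                 # one-step sample f_theta(x_T, T)
    for sigma in sigmas[1:]:                                  # optional multi-step refinement
        z = torch.randn(shape)
        x_noisy = x + (sigma ** 2 - sigma_min ** 2) ** 0.5 * z   # re-noise to level sigma
        x = f_theta(x_noisy, sigma)
    return x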

Training Consistency Models

There are two main strategies to train consistency models:

  1. Distillation from Pre-Trained Diffusion Models uses knowledge from a pre-trained diffusion model.
  2. Training from Scratch relies on an unbiased estimator of the score function.

In summary, the key step is obtaining two adjacent points on the PF-ODE trajectory and then enforcing the self-consistency property between them. In distillation, we leverage a pre-trained score model \(s_{\phi}(x,t)\) to approximate the ground-truth score function \(\nabla_x \log p_t(x)\). In training from scratch, we leverage the following unbiased estimator

\[\nabla_{x_t} \log p_t(x_t) = -\mathbb{E} \left[ \frac{x_t - x}{t^2} \mid x_t \right]\]

where \(x \sim \mathcal{D}\) and \(x_t \sim \mathcal{N}(x; t^2 I)\). At the end, we approximate \(x_{t_n} = x + t_n z\) and \(x_{t_{n+1}} = x + t_{n+1} z\) where \(z \sim \mathcal{N}(0, I)\).

Consistency Distillation from a Pre-Trained Diffusion Model

This approach leverages an existing diffusion model to generate adjacent points \((\hat{x}_{t_n}, x_{t_{n+1}})\) along a Probability Flow ODE trajectory.

The goal is to enforce:

\[f_{\theta}(\hat{x}_{t_n}, t_n) = f_{\theta}(x_{t_{n+1}}, t_{n+1})\]

so that \(f_{\theta}\) behaves as a true consistency function.

Step-by-step process:

Step 1 — Obtain point \(x_{t_{n+1}}\)
Sample from the SDE transition density: \(x_{t_{n+1}} \sim \mathcal{N}(x; t_{n+1}^2 I), \quad x \sim \mathcal{D}\)

Step 2 — Estimate the adjacent point \(\hat{x}_{t_n}\)
Using an ODE solver \(\Phi(x, t, \phi)\) parameterized by the pre-trained diffusion model:

\[\hat{x}_{t_n} = x_{t_{n+1}} + (t_n - t_{n+1}) \Phi(x_{t_{n+1}}, t_{n+1}, \phi)\]

If the Euler method is used,
\(\Phi(x, t, \phi) = -t\, s_{\phi}(x, t)\)
where \(s_{\phi}(x, t)\) is the score function.

Hence,
\(\hat{x}_{t_n} = x_{t_{n+1}} - (t_n - t_{n+1}) t_{n+1} s_{\phi}(x_{t_{n+1}}, t_{n+1})\)

Step 3 — Define the loss
The consistency loss ensures outputs from adjacent points match:

\[\mathcal{L} = d\big(f_{\theta}(\hat{x}_{t_n}, t_n), f_{\theta}(x_{t_{n+1}}, t_{n+1})\big)\]

where \(d(\cdot,\cdot)\) is a distance metric (e.g., L2 norm).

Step 4 — Update with EMA
The online parameters \(\theta\) are updated by gradient descent on this loss, while the target parameters \(\theta^-\) are updated as an Exponential Moving Average (EMA) of \(\theta\), as shown in Algorithm 2.
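A minimal sketch of the target-network EMA update (target_params is the target copy \(\theta^-\); the decay value is illustrative):

import torch

@torch.no_grad()
def ema_update(target_params, online_params, mu=0.999):
    """theta_minus <- mu * theta_minus + (1 - mu) * theta."""
    for p_tgt, p_onl in zip(target_params, online_params):
        p_tgt.mul_(mu).add_(p_onl, alpha=1 - mu)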

Training from scratch

When no pre-trained diffusion model is available, we can directly estimate the score function using an unbiased estimator:

\[\nabla_{x_t} \log p_t(x_t) = -\mathbb{E}\left[\frac{x_t - x}{t^2} \mid x_t\right]\]

where

  • \(x \sim \mathcal{D}\) (data distribution)
  • \(x_t \sim \mathcal{N}(x; t^2 I)\) (noisy version of data)

We can then approximate: \(x_{t_n} = x + t_n z, \quad x_{t_{n+1}} = x + t_{n+1} z, \quad z \sim \mathcal{N}(0, I)\)

This formulation allows consistency models to be trained entirely from noise-perturbed data samples.

Implementation of the Euler method

Implementation of the Euler method (from here). When no teacher model is available, the ground-truth sample \(x_0\) is used as the denoiser, which corresponds to the score estimate \(s(x, t) = \frac{x_0 - x}{t^2}\); therefore we don’t need a pre-trained diffusion model to estimate the score function.

Substituting this into the PF-ODE Euler update:

\[\hat{x}_{t_n} = x_{t_{n+1}} - (t_n - t_{n+1}) t_{n+1} s_{\phi}(x_{t_{n+1}}, t_{n+1})\]

We get:

\[\hat{x}_{t_n} = x_{t_{n+1}} + (t_n - t_{n+1}) \frac{x_{t_{n+1}} - x_0}{t_{n+1}}\]
@th.no_grad()
def euler_solver(samples, t, next_t, x0):
    x = samples
    if teacher_model is None:
        denoiser = x0
    else:
        denoiser = teacher_denoise_fn(x, t)  # teacher (pre-trained diffusion) denoiser output
    d = (x - denoiser) / append_dims(t, dims)
    samples = x + d * append_dims(next_t - t, dims)

    return samples

Heun method:

@th.no_grad()
def heun_solver(samples, t, next_t, x0):
    x = samples
    if teacher_model is None:
        denoiser = x0
    else:
        denoiser = teacher_denoise_fn(x, t)  # teacher (pre-trained diffusion) denoiser output

    # IMPORTANT - first Euler step of Heun's method
    d = (x - denoiser) / append_dims(t, dims)
    samples = x + d * append_dims(next_t - t, dims)
    if teacher_model is None:
        denoiser = x0
    else:
        denoiser = teacher_denoise_fn(samples, next_t)  # teacher denoiser at the next time step

    # Heun correction: average the slopes at t and next_t
    next_d = (samples - denoiser) / append_dims(next_t, dims)
    samples = x + (d + next_d) * append_dims((next_t - t) / 2, dims)

    return samples

Implementation of Consistency Models

The official implementation is available here. The train loop is defined in the train_util.py file.

  • diffusion.progdist_losses is the loss function for progressive distillation.
  • consistency_losses is the loss function for both consistency distillation and consistency training.
  • target_model is the EMA copy \(\theta^-\) of the online model, used to compute the consistency target.
  • teacher_model is the pre-trained diffusion model used in consistency distillation (None when training from scratch).
  • teacher_diffusion is the diffusion process wrapper associated with the teacher model.

forward_backward method

    def forward_backward(self, batch, cond):
        self.mp_trainer.zero_grad()
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            ema, num_scales = self.ema_scale_fn(self.global_step)
            if self.training_mode == "progdist":
                if num_scales == self.ema_scale_fn(0)[1]:
                    compute_losses = functools.partial(
                        self.diffusion.progdist_losses,
                        self.ddp_model,
                        micro,
                        num_scales,
                        target_model=self.teacher_model,
                        target_diffusion=self.teacher_diffusion,
                        model_kwargs=micro_cond,
                    )
                else:
                    compute_losses = functools.partial(
                        self.diffusion.progdist_losses,
                        self.ddp_model,
                        micro,
                        num_scales,
                        target_model=self.target_model,
                        target_diffusion=self.diffusion,
                        model_kwargs=micro_cond,
                    )
            elif self.training_mode == "consistency_distillation":
                compute_losses = functools.partial(
                    self.diffusion.consistency_losses,
                    self.ddp_model,
                    micro,
                    num_scales,
                    target_model=self.target_model,
                    teacher_model=self.teacher_model,
                    teacher_diffusion=self.teacher_diffusion,
                    model_kwargs=micro_cond,
                )
            elif self.training_mode == "consistency_training":
                compute_losses = functools.partial(
                    self.diffusion.consistency_losses,
                    self.ddp_model,
                    micro,
                    num_scales,
                    target_model=self.target_model,
                    model_kwargs=micro_cond,
                )
            else:
                raise ValueError(f"Unknown training mode {self.training_mode}")

            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

The consistency_losses method:

  • distiller = denoise_fn(x_t, t) and distiller_target = target_denoise_fn(x_t2, t2) are the consistency model output and the target model output, respectively.
  • t2 is the adjacent time step of t.
  • x_t2 is the predicted adjacent point of x_t.

    def consistency_losses(
        self,
        model,
        x_start,
        num_scales,
        model_kwargs=None,
        target_model=None,
        teacher_model=None,
        teacher_diffusion=None,
        noise=None,
    ):
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)

        dims = x_start.ndim

        # IMPORTANT - f_{\theta}(x,t) - consistency model output
        def denoise_fn(x, t):
            return self.denoise(model, x, t, **model_kwargs)[1]

        if target_model:

            @th.no_grad()
            def target_denoise_fn(x, t):
                return self.denoise(target_model, x, t, **model_kwargs)[1]

        else:
            raise NotImplementedError("Must have a target model")

        if teacher_model:

            @th.no_grad()
            def teacher_denoise_fn(x, t):
                return teacher_diffusion.denoise(teacher_model, x, t, **model_kwargs)[1]

        @th.no_grad()
        def heun_solver(samples, t, next_t, x0):
            x = samples
            if teacher_model is None:
                denoiser = x0
            else:
                denoiser = teacher_denoise_fn(x, t)

            
            # IMPORTANT - first Euler step of Heun's method
            d = (x - denoiser) / append_dims(t, dims)
            samples = x + d * append_dims(next_t - t, dims)
            if teacher_model is None:
                denoiser = x0
            else:
                denoiser = teacher_denoise_fn(samples, next_t)

            next_d = (samples - denoiser) / append_dims(next_t, dims)
            samples = x + (d + next_d) * append_dims((next_t - t) / 2, dims)

            return samples

        @th.no_grad()
        def euler_solver(samples, t, next_t, x0):
            x = samples
            if teacher_model is None:
                denoiser = x0
            else:
                denoiser = teacher_denoise_fn(x, t)
            d = (x - denoiser) / append_dims(t, dims)
            samples = x + d * append_dims(next_t - t, dims)

            return samples

        indices = th.randint(
            0, num_scales - 1, (x_start.shape[0],), device=x_start.device
        )

        t = self.sigma_max ** (1 / self.rho) + indices / (num_scales - 1) * (
            self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
        )
        t = t**self.rho

        t2 = self.sigma_max ** (1 / self.rho) + (indices + 1) / (num_scales - 1) * (
            self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
        )
        t2 = t2**self.rho

        x_t = x_start + noise * append_dims(t, dims)

        dropout_state = th.get_rng_state()
        distiller = denoise_fn(x_t, t)

        if teacher_model is None:
            x_t2 = euler_solver(x_t, t, t2, x_start).detach()
        else:
            x_t2 = heun_solver(x_t, t, t2, x_start).detach()

        th.set_rng_state(dropout_state)
        distiller_target = target_denoise_fn(x_t2, t2)
        distiller_target = distiller_target.detach()

        snrs = self.get_snr(t)
        weights = get_weightings(self.weight_schedule, snrs, self.sigma_data)
        if self.loss_norm == "l1":
            diffs = th.abs(distiller - distiller_target)
            loss = mean_flat(diffs) * weights
        elif self.loss_norm == "l2":
            diffs = (distiller - distiller_target) ** 2
            loss = mean_flat(diffs) * weights
        elif self.loss_norm == "l2-32":
            distiller = F.interpolate(distiller, size=32, mode="bilinear")
            distiller_target = F.interpolate(
                distiller_target,
                size=32,
                mode="bilinear",
            )
            diffs = (distiller - distiller_target) ** 2
            loss = mean_flat(diffs) * weights
        elif self.loss_norm == "lpips":
            if x_start.shape[-1] < 256:
                distiller = F.interpolate(distiller, size=224, mode="bilinear")
                distiller_target = F.interpolate(
                    distiller_target, size=224, mode="bilinear"
                )

            loss = (
                self.lpips_loss(
                    (distiller + 1) / 2.0,
                    (distiller_target + 1) / 2.0,
                )
                * weights
            )
        else:
            raise ValueError(f"Unknown loss norm {self.loss_norm}")

        terms = {}
        terms["loss"] = loss

        return terms

Diffusion Distillation

Progressive Distillation

The idea of progressive distillation is to distill a teacher sampler into a student that reaches the same result in half as many sampling steps, then repeat the procedure with the student as the new teacher, halving the step count at each round.

References:

Rectified Diffusion

References:

Caching in Diffusion Models

This technique takes advantage of the U-Net architecture used in diffusion models, particularly its skip connections, which transfer intermediate features from the encoder to the decoder.

The core idea:
👉 Store intermediate results from step \(t\) (e.g., decoder features) and reuse them at step \(t-1\) instead of recomputing the entire U-Net.

U-Net Refresher

A U-Net has two main components:

  • Encoder (Down Blocks) — progressively downsamples the input to a compact high-level representation.
  • Decoder (Up Blocks) — upsamples the features to reconstruct the image.

Each pair of down and up blocks \(D_i, U_i\) connects through:

  • a main path: \(D_1 \to D_d \to U_d \to U_1\)
  • skip connections: \(D_i \to U_i\)

At each layer, the output combines both paths: \(U_i = \text{Concat}(D_i, U_{i+1})\)

Observation: Feature Reuse Across Timesteps

During denoising, adjacent timesteps produce very similar high-level features: \(U_i^{(t)} \approx U_i^{(t-1)}\)

So instead of recomputing these expensive decoder features every step, we can cache them:

\[F_c^{(t)} \leftarrow U_{i+1}^{(t)} \\ U_i^{(t-1)} = \text{Concat}(D_i^{(t-1)}, F_c^{(t)})\]

This simple reuse cuts redundant computation and significantly speeds up inference.

Implementation Overview

I was more curious about the implementation details than the idea itself. You can find the full implementation in the DeepCache repository. First, we need to modify the Stable Diffusion pipeline, specifically the denoising loop, where cached features are passed to the U-Net.


    # predict the noise residual
    noise_pred, prv_features = self.unet(
        latent_model_input,
        t,
        encoder_hidden_states=prompt_embeds,
        cross_attention_kwargs=cross_attention_kwargs,
        replicate_prv_feature=prv_features,
        quick_replicate= cache_interval>1,
        cache_layer_id=cache_layer_id,
        cache_block_id=cache_block_id,
        return_dict=False,
    )

The U-Net model is defined in the unet_2d_condition.py file, where the forward method has been modified to support the caching feature. Note that caching is applied to the cross-attention (up) blocks only.

if quick_replicate and replicate_prv_feature is not None:
    # Downsampling - nothing change 

    # Middle - No middle 

    # Upsampling
    sample = replicate_prv_feature
    #down_block_res_samples = down_block_res_samples[:-1]
    if cache_block_id == len(self.down_blocks[cache_layer_id].attentions) :
        cache_block_id = 0
        cache_layer_id += 1
    else:
        cache_block_id += 1

    for i, upsample_block in enumerate(self.up_blocks):

        # Skip the blocks that are not the cache layer # IMPORTANT - This is where speed is gained
        if i < len(self.up_blocks) - 1 - cache_layer_id:
            continue

        if i == len(self.up_blocks) - 1 - cache_layer_id:
            trunc_upsample_block = cache_block_id + 1
        else:
            trunc_upsample_block = len(upsample_block.resnets)

        is_final_block = i == len(self.up_blocks) - 1

        res_samples = down_block_res_samples[-trunc_upsample_block:]
        down_block_res_samples = down_block_res_samples[: -trunc_upsample_block]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            #print(sample.shape, [res_sample.shape for res_sample in res_samples])
            sample, _ = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                enter_block_number=cache_block_id if i == len(self.up_blocks) - 1 - cache_layer_id else None,
            )
        else:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                upsample_size=upsample_size,
                scale=lora_scale,
            )

    prv_f = replicate_prv_feature

References:

Rectified Flows define the forward process as straight paths between the data distribution and a standard normal distribution [2], i.e.,

\[z_t = (1 - t) x_0 + t \epsilon\]

where \(\epsilon\) is a standard normal random variable and \(t\) is the time step in [0, 1].

Multi-modal Diffusion