Diffusion Models from Scratch

Goal: Build a DDPM (Denoising Diffusion Probabilistic Model) from scratch in PyTorch. Understand why diffusion works, how the forward and reverse processes work, and implement the training loop.

Based on: Denoising Diffusion Probabilistic Models (Ho et al., 2020) and gmongaras/Diffusion_models_from_scratch

Prerequisites: PyTorch, neural networks, and familiarity with what a VAE is (tutorial 17 helps)


The Core Intuition

Diffusion models learn to reverse a corruption process. If you can learn to denoise, you can generate images by starting with noise and progressively denoising it.

Two processes:

  • Forward process: Gradually add noise to a real image until it’s pure Gaussian noise (known, fixed)
  • Reverse process: Neural network learns to undo this corruption step-by-step (learned)

The key insight: predicting noise is easier than predicting the image directly.


The Forward Process (Adding Noise)

The forward process applies a Markov chain of noise additions:

import torch
import numpy as np
 
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Linear beta schedule as in original DDPM paper."""
    return torch.linspace(start, end, timesteps)
 
def get_alphas(betas):
    """Compute alphas = 1 - betas, cumulative product."""
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
    return alphas, alphas_cumprod, alphas_cumprod_prev
 
# Parameters
T = 1000  # timesteps
betas = linear_beta_schedule(T)
alphas, alphas_cumprod, alphas_cumprod_prev = get_alphas(betas)
 
def q_sample(x_start, t, noise):
    """
    Forward diffusion: add noise to image at timestep t (closed form).
    x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise
    """
    # Index the schedule on the same device as t to avoid CPU/GPU mismatches
    acp = alphas_cumprod.to(t.device)
    sqrt_alphas_cumprod_t = torch.sqrt(acp[t])
    sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1 - acp[t])
    
    return sqrt_alphas_cumprod_t[:, None, None, None] * x_start + \
           sqrt_one_minus_alphas_cumprod_t[:, None, None, None] * noise

Key point: You can sample x_t for ANY timestep t directly — no need to iterate. This is the reparameterization trick.
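A quick numerical sanity check of the schedule defined above: the signal coefficient sqrt(alphas_cumprod[t]) should start near 1 (x_1 is almost the clean image) and decay to nearly 0 at t = T (x_T is almost pure noise). This is a minimal sketch reusing the linear schedule from above:

```python
import torch

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

T = 1000
betas = linear_beta_schedule(T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# Signal coefficient decays toward 0, noise coefficient toward 1:
print(torch.sqrt(alphas_cumprod[0]))    # near 1: x_1 is almost the clean image
print(torch.sqrt(alphas_cumprod[-1]))   # near 0: x_T is essentially pure noise
```

If sqrt(alphas_cumprod[T-1]) were not close to 0, the prior used at sampling time (pure Gaussian noise) would not match the end of the forward process.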


The Loss Function

DDPM predicts the noise ε_θ. The loss is:

def p_losses(denoise_model, x_start, t, noise):
    """
    Training: predict the noise that was added.
    loss = MSE(epsilon, epsilon_theta(x_t, t))
    """
    x_t = q_sample(x_start, t, noise)
    predicted_noise = denoise_model(x_t, t)
    
    return torch.nn.functional.mse_loss(predicted_noise, noise)

Why predict noise instead of the image? The forward process has fixed variance, so the target ε is always unit-scale Gaussian at every timestep; predicting x_0 directly would force the network to match targets whose effective scale varies with t.
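The objective above can be exercised end to end without a trained network. Here is a self-contained sketch using a stand-in "model" (a hypothetical lambda that predicts zero noise, purely illustrative) to show the shapes and that the loss is a scalar:

```python
import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(0.0001, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def q_sample(x_start, t, noise):
    # Closed-form forward diffusion, as defined earlier
    a = torch.sqrt(alphas_cumprod[t])[:, None, None, None]
    b = torch.sqrt(1 - alphas_cumprod[t])[:, None, None, None]
    return a * x_start + b * noise

def p_losses(denoise_model, x_start, t, noise):
    x_t = q_sample(x_start, t, noise)
    return F.mse_loss(denoise_model(x_t, t), noise)

# Stand-in predictor that always outputs zero noise (untrained, for illustration)
dummy = lambda x_t, t: torch.zeros_like(x_t)

x = torch.randn(4, 3, 32, 32)          # fake batch of "images"
t = torch.randint(0, T, (4,))          # one random timestep per image
loss = p_losses(dummy, x, t, torch.randn_like(x))
print(loss.shape)                       # scalar: torch.Size([])
```

For the zero predictor the loss is just the mean squared magnitude of the sampled noise, so it hovers around 1.0; a trained network drives it well below that.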


The Reverse Process (Denoising)

@torch.no_grad()
def p_sample(model, x_t, t):
    """
    Single denoising step (t is a Python int here).
    x_{t-1} = (1/sqrt(alphas[t])) * (x_t - betas[t]/sqrt(1-alphas_cumprod[t]) * epsilon_theta(x_t, t)) 
               + sqrt(betas[t]) * z
    """
    betas_t = betas[t]
    sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1 - alphas_cumprod[t])
    
    # The model expects a batch of timesteps, one per image
    t_batch = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
    predicted_noise = model(x_t, t_batch)
    
    # Compute the mean of p(x_{t-1} | x_t)
    model_mean = (1 / torch.sqrt(alphas[t])) * \
                 (x_t - betas_t / sqrt_one_minus_alphas_cumprod_t * predicted_noise)
    
    # Add noise, except at t=0 (the final step). sigma_t^2 = beta_t is the simpler
    # of the two variance choices in Ho et al.; the posterior variance beta-tilde_t also works.
    posterior_variance = betas_t
    noise = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
    
    return model_mean + torch.sqrt(posterior_variance) * noise
 
@torch.no_grad()
def p_sample_loop(model, shape):
    """Generate images by starting from noise and denoising step-by-step."""
    x = torch.randn(shape, device=betas.device)
    
    for t in reversed(range(T)):
        x = p_sample(model, x, t)
    
    return x
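A self-contained smoke test of the sampler above, using a deliberately short schedule and a stand-in zero-noise "model" (hypothetical, untrained) to confirm that the loop runs and preserves the image shape:

```python
import torch

T = 50                                   # tiny schedule just so the demo runs fast
betas = torch.linspace(0.0001, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

@torch.no_grad()
def p_sample(model, x_t, t):
    # Batched timestep tensor for the model, scalar t for schedule indexing
    t_batch = torch.full((x_t.shape[0],), t, dtype=torch.long)
    eps = model(x_t, t_batch)
    mean = (x_t - betas[t] / torch.sqrt(1 - alphas_cumprod[t]) * eps) / torch.sqrt(alphas[t])
    noise = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
    return mean + torch.sqrt(betas[t]) * noise

@torch.no_grad()
def p_sample_loop(model, shape):
    x = torch.randn(shape)               # start from pure noise
    for t in reversed(range(T)):
        x = p_sample(model, x, t)
    return x

dummy = lambda x_t, t: torch.zeros_like(x_t)   # stand-in predictor
samples = p_sample_loop(dummy, (2, 3, 32, 32))
print(samples.shape)                     # torch.Size([2, 3, 32, 32])
```

With a real trained U-Net in place of `dummy`, the same loop produces images.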

The UNet Backbone

The noise predictor is typically a U-Net with time embeddings:

import torch.nn as nn
 
class SinusoidalPositionEmbeddings(nn.Module):
    """Time embedding: maps timestep to feature space."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    
    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        embeddings = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = t[:, None] * embeddings[None, :]
        embeddings = torch.cat([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        return embeddings
 
class ResidualBlock(nn.Module):
    """Block with time conditioning."""
    def __init__(self, in_ch, out_ch, time_emb_dim, groups=8):
        super().__init__()
        self.conv = nn.Sequential(
            nn.GroupNorm(groups, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(groups, out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
        )
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_ch))
        
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1)
        else:
            self.shortcut = nn.Identity()
    
    def forward(self, x, t_emb):
        h = self.conv(x)
        h = h + self.time_mlp(t_emb)[:, :, None, None]
        return h + self.shortcut(x)
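A quick shape check of the time embedding above, sketched as a standalone snippet: a batch of integer timesteps maps to a batch of `dim`-dimensional vectors, and t=0 embeds as all-zero sines followed by all-one cosines regardless of frequency.

```python
import torch
import torch.nn as nn

class SinusoidalPositionEmbeddings(nn.Module):
    """Same embedding as above: maps timestep to feature space."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        freqs = torch.exp(torch.arange(half_dim, device=t.device)
                          * -(torch.log(torch.tensor(10000.0)) / (half_dim - 1)))
        args = t[:, None].float() * freqs[None, :]
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

emb = SinusoidalPositionEmbeddings(128)
t = torch.tensor([0, 250, 999])          # a batch of timesteps
out = emb(t)
print(out.shape)                          # torch.Size([3, 128])
print(out[0, :64].abs().max())            # 0: all sines of t=0
```

The first half of each vector holds the sines, the second half the cosines, so nearby timesteps get similar embeddings while distant ones stay distinguishable.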

Training Loop

from torchvision.utils import save_image

def train_diffusion():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = UNet(dim=64, time_dim=128).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    dataset = YourImageDataset()  # e.g., CIFAR-10
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
    
    for epoch in range(100):
        for batch, (images, _) in enumerate(dataloader):
            images = images.to(device)
            batch_size = images.shape[0]
            
            # Sample a random timestep for each image
            t = torch.randint(0, T, (batch_size,), device=device).long()
            
            # Sample noise
            noise = torch.randn_like(images)
            
            # Compute loss
            loss = p_losses(model, images, t, noise)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if batch % 100 == 0:
                print(f"Epoch {epoch} | Batch {batch} | Loss: {loss.item():.4f}")
        
        # Sample every 10 epochs
        if epoch % 10 == 0:
            samples = p_sample_loop(model, (4, 3, 32, 32))
            save_image(samples, f'sample_{epoch}.png', nrow=2)

Key Insights

  1. Predict noise, not image: Scaling matters. Noise prediction is mathematically cleaner because the variance of the forward process is fixed.

  2. The reparameterization trick: You can jump to any timestep directly. This makes training efficient — you don’t need to simulate the entire chain.

  3. Betas schedule matters: Linear (original DDPM) vs cosine (improved). The schedule affects sample quality and convergence speed.

  4. Connection to score matching: DDPM is equivalent to score matching with Langevin dynamics. The reverse process learns the score function ∇log p(x_t).


DDPM vs DDIM (Speed)

DDIM (Song et al., 2021) enables 10-50x faster sampling by using a non-Markovian reverse process:

@torch.no_grad()
def ddim_sample(model, x_t, t, prev_t, eta=0.0):
    """
    DDIM step: deterministic (eta=0) or DDPM-like stochastic (eta=1).
    t and prev_t are Python ints with prev_t < t; pass prev_t = -1 for the last step.
    """
    alphas_cumprod_t = alphas_cumprod[t]
    # By convention alpha_bar_{-1} = 1 (fully denoised)
    alphas_cumprod_prev_t = alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0)
    
    t_batch = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
    predicted_noise = model(x_t, t_batch)
    
    # Predicted x_0 from the noise estimate
    pred_x_0 = (x_t - torch.sqrt(1 - alphas_cumprod_t) * predicted_noise) \
               / torch.sqrt(alphas_cumprod_t)
    
    # Variance
    sigma_t = eta * torch.sqrt((1 - alphas_cumprod_prev_t) / (1 - alphas_cumprod_t) * \
                               (1 - alphas_cumprod_t / alphas_cumprod_prev_t))
    
    # Previous sample
    pred_x_t_prev = torch.sqrt(alphas_cumprod_prev_t) * pred_x_0 + \
                    torch.sqrt(1 - alphas_cumprod_prev_t - sigma_t**2) * predicted_noise
    
    if eta > 0:
        pred_x_t_prev = pred_x_t_prev + sigma_t * torch.randn_like(x_t)
    return pred_x_t_prev
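The speedup comes from stepping over a strided subsequence of timesteps rather than all T of them. A self-contained sketch of a deterministic (eta=0) DDIM loop, again with a stand-in zero-noise "model" (hypothetical) just to show the mechanics:

```python
import torch

T = 1000
betas = torch.linspace(0.0001, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

@torch.no_grad()
def ddim_sample_loop(model, shape, num_steps=50):
    """Deterministic DDIM over num_steps evenly spaced timesteps (vs. T for DDPM)."""
    timesteps = torch.linspace(T - 1, 0, num_steps).long()
    x = torch.randn(shape)
    for i, t in enumerate(timesteps):
        a_t = alphas_cumprod[t]
        # alpha_bar of the *next kept* timestep; 1.0 once we step past t=0
        a_prev = alphas_cumprod[timesteps[i + 1]] if i + 1 < len(timesteps) \
                 else torch.tensor(1.0)
        t_batch = torch.full((shape[0],), int(t), dtype=torch.long)
        eps = model(x, t_batch)
        x0 = (x - torch.sqrt(1 - a_t) * eps) / torch.sqrt(a_t)   # predicted x_0
        x = torch.sqrt(a_prev) * x0 + torch.sqrt(1 - a_prev) * eps
    return x

dummy = lambda x, t: torch.zeros_like(x)   # stand-in predictor
out = ddim_sample_loop(dummy, (2, 3, 32, 32), num_steps=20)
print(out.shape)   # torch.Size([2, 3, 32, 32])
```

Here 20 steps stand in for the full 1000, which is exactly where the 10-50x speedup comes from; with eta=0 the trajectory is deterministic given the initial noise.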

Exercises

  1. Implement cosine beta schedule and compare convergence with linear
  2. Add attention to the UNet (see Attention is All You Need tutorial)
  3. Condition on class labels using CFG (Classifier-Free Guidance)
  4. Implement DDIM sampling and verify 10x speedup with minimal quality loss

See Also