This post is a note for myself comparing the implementations of diffusion models in HuggingFace’s Diffusers and CompVis’s Stable Diffusion. I switch between these two implementations quite often, so I want to keep track of the differences between them.

The source code of the two libraries can be found at:

  • HuggingFace Diffusers: https://github.com/huggingface/diffusers
  • CompVis Stable Diffusion: https://github.com/CompVis/stable-diffusion

Basic Functions

Below are the basic functions of a standard diffusion model pipeline, including:

  • Loading components such as the tokenizer, scheduler, VAE, and UNet.
  • Converting images to latent space.
  • Running the forward and backward diffusion processes.
  • Calculating the loss.

Note that the code snippets below refer only to specific functions and are not meant to form a complete script. Read the comments in the code to understand the context.

Diffusers

Taken from train_text_to_image.py in the Diffusers text-to-image training example (here).

# Import the necessary modules
import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler

# load components of the model
# Load tokenizer
if args.tokenizer_name:
    tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
elif args.pretrained_model_name_or_path:
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)

# Inside the training loop
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
latents = latents * vae.config.scaling_factor

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]

# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)

# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
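The snippet stops at the noise prediction; the same script then builds the regression target from the scheduler’s prediction_type and takes a mean-squared-error loss against the model output, roughly as follows (paraphrased from train_text_to_image.py; the optional SNR-weighting branch is omitted):

# Choose the target depending on the scheduler's prediction type
if noise_scheduler.config.prediction_type == "epsilon":
    target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
    target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

# Simple MSE loss between the prediction and the target (noise or velocity)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")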

CompVis’s Stable Diffusion

In the CompVis library, the training hyperparameters are packed into YAML config files in the configs folder, and the training entry point is main.py. Training is driven by PyTorch Lightning’s Trainer class (refer to here).

Lightning Trainer

The Lightning Trainer does much more than just “training”. Under the hood, it handles all the loop details for you; some examples include:

  1. Automatically enabling/disabling grads
  2. Running the training, validation and test dataloaders
  3. Calling the Callbacks at the appropriate times
  4. Putting batches and computations on the correct devices

Here’s the pseudocode for what the Trainer does under the hood (showing the train loop only):

# enable grads
torch.set_grad_enabled(True)

losses = []
for batch in train_dataloader:
    # calls hooks like this one
    on_train_batch_start()

    # train step
    loss = training_step(batch)

    # clear gradients
    optimizer.zero_grad()

    # backward
    loss.backward()

    # update parameters
    optimizer.step()

    losses.append(loss)
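To make the mapping concrete, here is a minimal, self-contained LightningModule, not specific to Stable Diffusion (the model and data are toy placeholders), showing where training_step and configure_optimizers plug into that loop:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class ToyRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        # this is the `training_step(batch)` call in the pseudocode above
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.net(x), y)
        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# toy data; Trainer.fit() drives the whole loop shown above
dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
trainer = pl.Trainer(max_epochs=1)
trainer.fit(ToyRegressor(), DataLoader(dataset, batch_size=16))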

In the config file, we can find the components of the model, such as the first-stage autoencoder (VAE), the UNet, and the noise-schedule parameters. For example, in configs/latent-diffusion/celebahq-ldm-vq-4.yaml, each component is declared via a target field pointing to the Python class to instantiate, together with its params (see the instantiation sketch after the config).

model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolutions 8, 16, 32, as the
        # spatial resolution of the latents is 64 for f4
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ckpt_path: models/first_stage_models/vq-f4/model.ckpt
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 48
    num_workers: 5
    wrap: false
    train:
      target: taming.data.faceshq.CelebAHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.CelebAHQValidation
      params:
        size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
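Each target string is resolved into a Python class at runtime. A minimal sketch of how this instantiation works, modeled on instantiate_from_config in ldm/util.py (the helper names match that file, but treat this as an approximation rather than the exact code):

import importlib

def get_obj_from_str(string):
    # "ldm.models.autoencoder.VQModelInterface" -> the class object
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # build the object named by `target`, passing `params` as keyword arguments
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

# e.g. the first_stage_config block above becomes the VQ autoencoder:
# first_stage_model = instantiate_from_config(config["model"]["params"]["first_stage_config"])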

How to train the model?

IMO, Lightning is difficult to read and understand. I found a post on Reddit tracing the call path from Trainer.fit() down to a single training step (it’s painful):

Trainer.fit() -> Trainer._fit_impl() -> Trainer._run() -> Trainer._run_stage() -> Trainer._run_train() -> FitLoop.run() -> FitLoop.advance() -> TrainingEpochLoop.run() -> TrainingEpochLoop.advance() -> TrainingBatchLoop.run() -> TrainingBatchLoop.advance() -> OptimizerLoop.run() -> OptimizerLoop.advance() -> OptimizerLoop._run_optimization() -> OptimizerLoop._make_closure() -> OptimizerLoop._make_step_fn()

The training procedure is hidden in the LatentDiffusion class in ldm/models/diffusion/ddpm.py, in the training_step function (refer to this line). More specifically, the forward pass is as follows:


# convert images to latent space
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()

# get conditioning
c = self.get_learned_conditioning(cond_key)

# random timestep
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()

# add noise
noise = default(noise, lambda: torch.randn_like(x_start))

# forward diffusion
# x_start is the input clean image
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

# apply model, backward diffusion
model_output = self.apply_model(x_noisy, t, cond)

# choose the type of target; the model can be parameterized to predict either
# the clean image (x0) or the noise (eps)
# in the default setting of Latent Diffusion Models, the output is the noise (epsilon) rather than the image
if self.parameterization == "x0":
    target = x_start
elif self.parameterization == "eps":
    target = noise
else:
    raise NotImplementedError()

# calculate loss
# loss_simple: per-sample MSE between the model output and the target
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

# logvar is a per-timestep log-variance (zero-initialized, optionally learnable)
logvar_t = self.logvar[t].to(self.device)
loss = loss_simple / torch.exp(logvar_t) + logvar_t

loss = self.l_simple_weight * loss.mean()

# VLB term, weighted per timestep; original_elbo_weight is 0 by default
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
loss += (self.original_elbo_weight * loss_vlb)
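Putting the pieces together, to my reading of p_losses in ddpm.py the total objective is roughly

$$
\mathcal{L} \;=\; w_{\text{simple}} \cdot \mathrm{mean}\!\left(\frac{\mathcal{L}_{\text{simple}}}{e^{\ell_t}} + \ell_t\right) \;+\; w_{\text{elbo}} \cdot \mathrm{mean}\!\left(\lambda_t \, \mathcal{L}_{\text{vlb}}\right),
$$

where $\ell_t$ is logvar[t], $w_{\text{simple}}$ is l_simple_weight, $\lambda_t$ is lvlb_weights[t], and $w_{\text{elbo}}$ is original_elbo_weight. With the default settings (logvar initialized to zero, l_simple_weight = 1, original_elbo_weight = 0), this reduces to a plain MSE between the model output and the target noise, i.e. essentially the same loss as in the Diffusers script above.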

Example

In the following, I provide a simple implementation of Textual Inversion using CompVis’s Stable Diffusion; a Diffusers implementation already exists here.

The full script, including data, can be found here: https://github.com/tuananhbui89/diffusion_demo/tree/main/textual_inversion

It is worth noting that in the original Textual Inversion, the final goal is to obtain a special token (e.g., sks dog) that serves two purposes: (1) it is associated with the visual representation of the personal data, and (2) it is in text form, so users can easily use it to generate new images. To do that, the original implementation replaces the text encoder’s embedding matrix with an extended one containing the new token; in my implementation, I skip this step and directly optimize the embedding vector (see the sketch below for the token-based setup I am skipping).
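For reference, the token-based setup that I skip looks roughly like this in the Diffusers implementation (a sketch based on textual_inversion.py; the placeholder and initializer tokens here are illustrative):

# add the placeholder token (e.g. "<sks>") to the tokenizer
num_added = tokenizer.add_tokens("<sks>")

# grow the text encoder's embedding matrix to cover the new token
text_encoder.resize_token_embeddings(len(tokenizer))

# initialize the new row from an existing "initializer" token (e.g. "dog")
token_embeds = text_encoder.get_input_embeddings().weight.data
placeholder_id = tokenizer.convert_tokens_to_ids("<sks>")
initializer_id = tokenizer.convert_tokens_to_ids("dog")
token_embeds[placeholder_id] = token_embeds[initializer_id].clone()

# during training, only this new row receives gradient updates;
# all other rows of the embedding matrix stay frozen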

# Imports needed by this function; sample_model and PreprocessImage are helpers
# defined elsewhere in the linked repo.
import os

import numpy as np
import torch
import matplotlib.pyplot as plt
from einops import rearrange
from PIL import Image
from torch.autograd import Variable


def train_inverse(model, sampler, train_data_dir, devices, args):
    """
    Given a model and a set of reference images, learn an embedding vector that will generate an image similar to the reference images.

    Args:
        model: the pretrained LDM model (kept frozen; only the embedding is optimized)
        sampler: the sampler used to generate preview images
        train_data_dir: the directory of reference images used for training
        devices: the devices to run on
        args: the arguments for training

    Returns:
        emb: the learned embedding vector
    """

    # create a textual embedding variable to optimize
    prompt = f'a photo of {args.concept}'
    emb = model.get_learned_conditioning([prompt])
    org_emb = emb.clone()
    emb = Variable(emb, requires_grad=True).to(devices[0])

    # create an optimizer to optimize the prompt
    opt = torch.optim.Adam([emb], lr=args.lr)

    # Dataset and DataLoaders creation:
    train_dataset = PreprocessImage(
        data_root=train_data_dir,
        size=args.resolution,
        repeats=args.repeats,
        center_crop=args.center_crop,
        set="train",
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )    
    
    fixed_start_code = torch.randn((1, 4, 64, 64)).to(devices[0])

    # create a lambda function for cleaner use of sampling code (only denoising till time step t)
    quick_sample_till_t = lambda cond, s, code, t: sample_model(model, sampler,
                                                                cond, args.image_size, args.image_size, args.ddim_steps, s, args.ddim_eta,
                                                                start_code=code, till_T=t, verbose=False)
    
    # create a function to decode and save the image
    def decode_and_save_image(model, z, path):
        x = model.decode_first_stage(z)
        x = torch.clamp((x + 1.0)/2.0, min=0.0, max=1.0)
        x = rearrange(x, 'b c h w -> b h w c')
        image = Image.fromarray((x[0].cpu().numpy()*255).astype(np.uint8))
        plt.imshow(image)
        plt.xticks([])
        plt.yticks([])
        plt.savefig(path)
        plt.close()
    
    os.makedirs('evaluation_folder', exist_ok=True)
    os.makedirs('evaluation_folder/textual_inversion', exist_ok=True)
    os.makedirs(f'evaluation_folder/textual_inversion/{args.concept}', exist_ok=True)
    os.makedirs(f'{args.models_path}/embedding_textual_inversion', exist_ok=True)

    # train the embedding
    for epoch in range(args.epochs):
        for i, batch in enumerate(train_dataloader):
            opt.zero_grad()
            model.zero_grad()
            model.train()

            # Convert images to latent space
            batch_images = batch['pixel_values'].to(devices[0])
            encoder_posterior = model.encode_first_stage(batch_images)
            batch_z = model.get_first_stage_encoding(encoder_posterior).detach()

            # get conditioning - here it is simply the trainable embedding vector,
            # repeated over the batch
            cond = torch.repeat_interleave(emb, batch_z.shape[0], dim=0)

            # random timestep
            t_enc = torch.randint(0, args.ddim_steps, (1,), device=devices[0]).long()

            # map the sampled DDIM step index to the corresponding range of DDPM
            # timesteps (out of 1000), then draw a DDPM timestep from that range
            og_num = round((int(t_enc)/args.ddim_steps)*1000)
            og_num_lim = round((int(t_enc+1)/args.ddim_steps)*1000)

            t_enc_ddpm = torch.randint(og_num, og_num_lim, (batch_z.shape[0],), device=devices[0])

            # add noise
            noise = torch.randn_like(batch_z) * args.noise_scale

            # forward diffusion
            x_noisy = model.q_sample(x_start=batch_z, t=t_enc_ddpm, noise=noise)

            # backward diffusion
            model_output = model.apply_model(x_noisy, t_enc_ddpm, cond)

            # calculate loss
            # in the default setting of Latent Diffusion Models, the output is the epsilon rather than the image
            loss = torch.nn.functional.mse_loss(model_output, noise)

            # optimize
            loss.backward()
            opt.step()

            if i % 100 == 0:
                print(f'Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}')
        
        # inference with the learned embedding
        with torch.no_grad():
            model.eval()

            z_r_till_T = quick_sample_till_t(emb.to(devices[0]), args.start_guidance, fixed_start_code, int(args.ddim_steps))
            decode_and_save_image(model, z_r_till_T, path=f'evaluation_folder/textual_inversion/{args.concept}/gen_{epoch}.png')

            z_r_till_T = quick_sample_till_t(org_emb.to(devices[0]), args.start_guidance, fixed_start_code, int(args.ddim_steps))
            decode_and_save_image(model, z_r_till_T, path=f'evaluation_folder/textual_inversion/{args.concept}/gen_{epoch}_original.png')

            torch.save(emb, f'{args.models_path}/embedding_textual_inversion/emb_{args.concept}_{epoch}.pt')

    return emb.detach()
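For completeness, a hypothetical usage sketch: the actual entry point and argument parsing live in the linked repo, load_model_from_config is the usual CompVis-style helper (as in scripts/txt2img.py), and the paths and the args object below are placeholders.

import torch
from omegaconf import OmegaConf
from ldm.models.diffusion.ddim import DDIMSampler

# placeholder paths; `args` holds the training hyperparameters parsed elsewhere
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
model = load_model_from_config(config, "models/ldm/sd-v1-4.ckpt")  # helper from the repo
sampler = DDIMSampler(model)

emb = train_inverse(model, sampler, train_data_dir="data/my_concept",
                    devices=["cuda:0"], args=args)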