Comparing Implementations of Diffusion Models - HuggingFace Diffusers vs. CompVis Stable Diffusion
This post is a note to myself comparing the implementations of diffusion models in HuggingFace’s Diffusers and CompVis’s Stable Diffusion. I often need to switch between these two implementations, so I want to keep track of the differences between them.
The source code of the two libraries can be found at:
- HuggingFace Diffusers: https://github.com/huggingface/diffusers
- CompVis’s Stable Diffusion: https://github.com/CompVis/stable-diffusion and CompVis’s LDM: https://github.com/CompVis/latent-diffusion
Basic Functions
Below are the basic functions of a standard diffusion model pipeline, including:
- Loading components such as the tokenizer, scheduler, VAE, and UNet.
- Converting images to latent space.
- Forward and backward diffusion process.
- Calculating loss.
Note that the code snippets below refer to specific functions and are not meant to form a complete script. Read the comments in the code to understand the context.
Diffusers
The snippets below are taken from the example training script train_text_to_image.py in the Diffusers repository.
# Import the necessary modules
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
# load components of the model
# Load tokenizer
if args.tokenizer_name:
    tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
elif args.pretrained_model_name_or_path:
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# Inside the training loop
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
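The script continues from this point by computing the training loss. Below is a minimal sketch of that step, assuming the standard setting where the scheduler’s prediction_type decides whether the regression target is the sampled noise or the v-prediction; it is a simplified rendering, not a verbatim copy of the official script.

import torch.nn.functional as F

# Choose the regression target based on the scheduler's prediction type
if noise_scheduler.config.prediction_type == "epsilon":
    target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
    target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

# The simple DDPM loss: MSE between the prediction and the target (noise or velocity)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
loss.backward()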
CompVis’s Stable Diffusion
In the CompVis codebase, the training parameters are packed into YAML config files in the configs folder, and the training entry point is main.py.
The training loop itself is driven by PyTorch Lightning’s Trainer class.
Lightning Trainer
The Lightning Trainer does much more than just “training”. Under the hood, it handles all loop details for you; some examples include:
- Automatically enabling/disabling grads
- Running the training, validation and test dataloaders
- Calling the Callbacks at the appropriate times
- Putting batches and computations on the correct devices
Here’s the pseudocode for what the trainer does under the hood (showing the train loop only):
# enable grads
torch.set_grad_enabled(True)

losses = []
for batch in train_dataloader:
    # calls hooks like this one
    on_train_batch_start()

    # train step
    loss = training_step(batch)

    # clear gradients
    optimizer.zero_grad()

    # backward
    loss.backward()

    # update parameters
    optimizer.step()

    losses.append(loss)
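For context, here is a minimal, hypothetical LightningModule sketch (not taken from the CompVis code) showing where training_step and configure_optimizers plug into the loop above; in CompVis, the LatentDiffusion class plays this role.

import pytorch_lightning as pl
import torch

class ToyDiffusionModule(pl.LightningModule):
    """Hypothetical minimal module; LatentDiffusion fills this role in CompVis."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        # called by the Trainer once per batch, as in the pseudocode above
        x_noisy, t, noise = batch
        pred = self.model(x_noisy, t)
        loss = torch.nn.functional.mse_loss(pred, noise)
        self.log("train/loss", loss)
        return loss  # Lightning handles zero_grad / backward / step

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)

# trainer = pl.Trainer(max_epochs=10)
# trainer.fit(ToyDiffusionModule(my_unet), train_dataloader)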
In the config file, we can find the components of the model, such as the VAE, the UNet, and the noise schedule. For example, in configs/latent-diffusion/celebahq-ldm-vq-4.yaml, each component is defined by a target field pointing to its class path, together with its training parameters under params.
model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial resolution of the latents is 64 for f4
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ckpt_path: models/first_stage_models/vq-f4/model.ckpt
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 48
    num_workers: 5
    wrap: false
    train:
      target: taming.data.faceshq.CelebAHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.CelebAHQValidation
      params:
        size: 256

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
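Each target entry is resolved to a Python class by its dotted import path and instantiated with the corresponding params. The CompVis repo does this with instantiate_from_config in ldm/util.py; a simplified sketch of the mechanism looks roughly like this:

import importlib
from omegaconf import OmegaConf

def get_obj_from_str(string):
    # "ldm.models.diffusion.ddpm.LatentDiffusion" -> the class object
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # build the object named by `target`, passing `params` as keyword arguments
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

# e.g. the training script loads the YAML and builds the LatentDiffusion model
config = OmegaConf.load("configs/latent-diffusion/celebahq-ldm-vq-4.yaml")
model = instantiate_from_config(config.model)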
How to train the model?
In my opinion, Lightning code is difficult to read and trace. I found a post on Reddit laying out the call path behind even a simple training loop, and it is painfully long:
Trainer.fit() -> Trainer._fit_impl() -> Trainer._run() -> Trainer._run_stage() -> Trainer._run_train() -> FitLoop.run() -> FitLoop.advance() -> TrainingEpochLoop.run() -> TrainingEpochLoop.advance() -> TrainingBatchLoop.run() -> TrainingBatchLoop.advance() -> OptimizerLoop.run() -> OptimizerLoop.advance() -> OptimizerLoop._run_optimization() -> OptimizerLoop._make_closure() -> OptimizerLoop._make_step_fn()
The training procedure is implemented in the LatentDiffusion class in ldm/models/diffusion/ddpm.py, in the training_step function.
More specifically, the forward pass is as follows:
# convert images to latent space
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
# get conditioning
c = self.get_learned_conditioning(cond_key)
# random timestep
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
# add noise
noise = default(noise, lambda: torch.randn_like(x_start))
# forward diffusion
# x_start is the input clean image
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
# apply model, backward diffusion
model_output = self.apply_model(x_noisy, t, cond)
# choose type of target, there are two types of output of the model, image or noise
# in the default setting of Latent Diffusion Models, the output is the epsilon rather than the image
if self.parameterization == "x0":
    target = x_start
elif self.parameterization == "eps":
    target = noise
else:
    raise NotImplementedError()
# calculate loss
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
logvar_t = self.logvar[t].to(self.device)
loss = loss_simple / torch.exp(logvar_t) + logvar_t
loss = self.l_simple_weight * loss.mean()
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
loss += (self.original_elbo_weight * loss_vlb)
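For reference, q_sample is the closed-form forward diffusion step and loss_simple is the standard DDPM noise-prediction MSE. A minimal sketch of what these amount to (tensor names are illustrative, assuming precomputed cumulative alphas):

import torch

def q_sample(x_start, t, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    a = sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
    b = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
    return a * x_start + b * noise

# with the "eps" parameterization, the simple per-sample loss is just
# loss_simple = ((model_output - noise) ** 2).mean(dim=(1, 2, 3))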
Example
In the following, I provide a simple implementation of Textual Inversion using CompVis’s Stable Diffusion; the technique is already implemented in HuggingFace’s Diffusers.
The full script, including data, can be found at https://github.com/tuananhbui89/diffusion_demo/tree/main/textual_inversion
It is worth noting that in the original Textual Inversion, the final goal is to obtain a special token (e.g., sks dog) that serves two purposes: (1) it is associated with the visual representation of the personal data, and (2) it is in text form, so users can easily use it to generate new images. To do that, the original implementation replaces the embedding matrix with a new one that contains the learnable token embedding; in my implementation, I skip this step and directly optimize the embedding vector.
import os

import numpy as np
import torch
import matplotlib.pyplot as plt
from einops import rearrange
from PIL import Image
from torch.autograd import Variable

# PreprocessImage and sample_model are helpers defined in the full script linked above.

def train_inverse(model, sampler, train_data_dir, devices, args):
    """
    Given a model and a set of reference images, learn an embedding vector that
    will generate images similar to the reference images.
    Args:
        model: the model to be trained
        sampler: the sampler to be used for sampling
        train_data_dir: the reference images to be used for training
        args: the arguments for training
    Returns:
        emb: the learned embedding vector
    """
    # create a textual embedding variable to optimize
    prompt = f'a photo of {args.concept}'
    emb = model.get_learned_conditioning([prompt])
    org_emb = emb.clone()
    emb = Variable(emb, requires_grad=True).to(devices[0])

    # create an optimizer to optimize the prompt
    opt = torch.optim.Adam([emb], lr=args.lr)

    # Dataset and DataLoaders creation:
    train_dataset = PreprocessImage(
        data_root=train_data_dir,
        size=args.resolution,
        repeats=args.repeats,
        center_crop=args.center_crop,
        set="train",
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )

    fixed_start_code = torch.randn((1, 4, 64, 64)).to(devices[0])

    # create a lambda function for cleaner use of the sampling code (only denoising till time step t)
    quick_sample_till_t = lambda cond, s, code, t: sample_model(model, sampler,
        cond, args.image_size, args.image_size, args.ddim_steps, s, args.ddim_eta,
        start_code=code, till_T=t, verbose=False)

    # create a function to decode and save the image
    def decode_and_save_image(model, z, path):
        x = model.decode_first_stage(z)
        x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
        x = rearrange(x, 'b c h w -> b h w c')
        image = Image.fromarray((x[0].cpu().numpy() * 255).astype(np.uint8))
        plt.imshow(image)
        plt.xticks([])
        plt.yticks([])
        plt.savefig(path)
        plt.close()

    os.makedirs('evaluation_folder', exist_ok=True)
    os.makedirs('evaluation_folder/textual_inversion', exist_ok=True)
    os.makedirs(f'evaluation_folder/textual_inversion/{args.concept}', exist_ok=True)
    os.makedirs(f'{args.models_path}/embedding_textual_inversion', exist_ok=True)

    # train the embedding
    for epoch in range(args.epochs):
        for i, batch in enumerate(train_dataloader):
            opt.zero_grad()
            model.zero_grad()
            model.train()

            # Convert images to latent space
            batch_images = batch['pixel_values'].to(devices[0])
            encoder_posterior = model.encode_first_stage(batch_images)
            batch_z = model.get_first_stage_encoding(encoder_posterior).detach()

            # get conditioning - here it is simply the trainable embedding vector
            cond = torch.repeat_interleave(emb, batch_z.shape[0], dim=0)

            # random timestep: a DDIM step index, mapped to a DDPM timestep in [0, 1000)
            t_enc = torch.randint(0, args.ddim_steps, (1,), device=devices[0]).long()
            og_num = round((int(t_enc) / args.ddim_steps) * 1000)
            og_num_lim = round((int(t_enc + 1) / args.ddim_steps) * 1000)
            t_enc_ddpm = torch.randint(og_num, og_num_lim, (batch_z.shape[0],), device=devices[0])

            # add noise
            noise = torch.randn_like(batch_z) * args.noise_scale

            # forward diffusion
            x_noisy = model.q_sample(x_start=batch_z, t=t_enc_ddpm, noise=noise)

            # backward diffusion
            model_output = model.apply_model(x_noisy, t_enc_ddpm, cond)

            # calculate loss
            # in the default setting of Latent Diffusion Models, the output is the epsilon rather than the image
            loss = torch.nn.functional.mse_loss(model_output, noise)

            # optimize
            loss.backward()
            opt.step()

            if i % 100 == 0:
                print(f'Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}')

        # inference with the learned embedding (and the original embedding for comparison)
        with torch.no_grad():
            model.eval()
            z_r_till_T = quick_sample_till_t(emb.to(devices[0]), args.start_guidance, fixed_start_code, int(args.ddim_steps))
            decode_and_save_image(model, z_r_till_T, path=f'evaluation_folder/textual_inversion/{args.concept}/gen_{epoch}.png')
            z_r_till_T = quick_sample_till_t(org_emb.to(devices[0]), args.start_guidance, fixed_start_code, int(args.ddim_steps))
            decode_and_save_image(model, z_r_till_T, path=f'evaluation_folder/textual_inversion/{args.concept}/gen_{epoch}_original.png')

        torch.save(emb, f'{args.models_path}/embedding_textual_inversion/emb_{args.concept}_{epoch}.pt')

    return emb.detach()
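To make the function above concrete, here is a hypothetical usage sketch. The config and checkpoint paths, the device, and every hyperparameter value are illustrative placeholders; only the argument names match the function above, and the loading pattern follows the usual CompVis recipe (an OmegaConf config plus instantiate_from_config and a DDIMSampler).

from argparse import Namespace

import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

# load the Stable Diffusion model (paths are placeholders)
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
model = instantiate_from_config(config.model)
ckpt = torch.load("models/ldm/stable-diffusion-v1/model.ckpt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"], strict=False)

devices = [torch.device("cuda:0")]
model = model.to(devices[0])
sampler = DDIMSampler(model)

# illustrative hyperparameters matching the argument names used in train_inverse
args = Namespace(concept="sks dog", lr=5e-3, resolution=512, repeats=100,
                 center_crop=False, train_batch_size=1, dataloader_num_workers=2,
                 image_size=512, ddim_steps=50, ddim_eta=0.0, noise_scale=1.0,
                 start_guidance=3.0, epochs=10, models_path="models")

emb = train_inverse(model, sampler, "data/sks_dog", devices, args)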