Erasing Concepts from Diffusion Models
About the paper
- Published at ICCV 2023
- Affiliations: Northeastern University and MIT.
- Motivation: Remove specific concepts from a diffusion model's weights. The concept can be a style (e.g., Van Gogh's style), an object (e.g., car, dog), or unsafe content (e.g., nudity), while preserving the model's capability on all other concepts.
- Project page: https://erasing.baulab.info/
Approach
The central optimization problem is to reduce the probability of generating an image $x$ that is described by the concept $c$, scaled by a power factor $\eta$:

$$P_\theta(x) \propto \frac{P_{\theta^*}(x)}{P_{\theta^*}(c \mid x)^{\eta}}$$

where $\theta^*$ denotes the frozen weights of the original pre-trained model, $\theta$ denotes the weights being fine-tuned, and $\eta > 0$ controls the strength of the erasure.

It can be interpreted as: if the concept $c$ is likely to be recognized in an image $x$ (large $P_{\theta^*}(c \mid x)$), then the probability of generating $x$ under the new model is pushed down.

Because of Bayes' rule, the likelihood of the concept given the image can be expanded as

$$P_{\theta^*}(c \mid x) = \frac{P_{\theta^*}(x \mid c)\, P_{\theta^*}(c)}{P_{\theta^*}(x)}$$

Therefore, the above equation can be rewritten when taking the derivative of its logarithm w.r.t. $x_t$ (the prior $P_{\theta^*}(c)$ does not depend on $x_t$ and drops out):

$$\nabla \log P_\theta(x_t) = \nabla \log P_{\theta^*}(x_t) - \eta \left( \nabla \log P_{\theta^*}(x_t \mid c) - \nabla \log P_{\theta^*}(x_t) \right)$$

Because in the diffusion model each reverse step has been approximated by a Gaussian distribution, the gradient of the log-likelihood (the score) is expressed by the model's noise prediction:

$$\nabla \log P_{\theta^*}(x_t) \propto -\,\epsilon_{\theta^*}(x_t, t), \qquad \nabla \log P_{\theta^*}(x_t \mid c) \propto -\,\epsilon_{\theta^*}(x_t, c, t)$$

where $\epsilon_{\theta^*}(x_t, t)$ is the unconditional noise prediction and $\epsilon_{\theta^*}(x_t, c, t)$ is the noise prediction conditioned on the concept $c$ at time step $t$.

Instead of recursively sampling the noise through the whole reverse process to realize this modified distribution, the model can be fine-tuned so that its noise prediction directly matches the edited (negatively guided) score:

$$\epsilon_{\theta}(x_t, t) \leftarrow \epsilon_{\theta^*}(x_t, t) - \eta \left[ \epsilon_{\theta^*}(x_t, c, t) - \epsilon_{\theta^*}(x_t, t) \right]$$

where the right-hand side is computed entirely by the frozen model $\theta^*$, which gives the loss

$$\mathcal{L}_1 = \left\lVert \epsilon_{\theta}(x_t, t) - \left( \epsilon_{\theta^*}(x_t, t) - \eta \left[ \epsilon_{\theta^*}(x_t, c, t) - \epsilon_{\theta^*}(x_t, t) \right] \right) \right\rVert^2
$$
The Final Objective Function
However, in the paper, instead of using the loss function above, the authors proposed to fine-tune the conditional noise prediction:

$$\mathcal{L} = \left\lVert \epsilon_{\theta}(x_t, c, t) - \left( \epsilon_{\theta^*}(x_t, t) - \eta \left[ \epsilon_{\theta^*}(x_t, c, t) - \epsilon_{\theta^*}(x_t, t) \right] \right) \right\rVert^2$$

where $\epsilon_{\theta}(x_t, c, t)$ is the conditional noise prediction of the model being fine-tuned, and every term inside the parentheses comes from the frozen model $\theta^*$.
The difference between the two loss functions is that the first one is computed based on the unconditional noise prediction $\epsilon_{\theta}(x_t, t)$ of the fine-tuned model, while the final one is computed on the conditional prediction $\epsilon_{\theta}(x_t, c, t)$: erasing a concept should change what the model does when it is actually prompted with that concept.

Interpretation of the loss function: By minimizing the above loss function, we try to force the conditional noise prediction of the fine-tuned model to mimic the negatively guided noise of the frozen model, i.e., the prediction conditioned on $c$ is steered away from the concept rather than toward it.
Note: In the above objective function, gradients flow only through $\epsilon_{\theta}(x_t, c, t)$; both predictions of the original model are detached, since $\theta^*$ stays frozen throughout fine-tuning (see the `detach()` calls in the code below).
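The objective itself fits in a few lines of code. Here is a minimal, self-contained PyTorch sketch of the loss before looking at the full training script (the function name `esd_loss`, the toy tensor shapes, and `eta=1.0` are illustrative, not from the official repo):

```python
import torch
import torch.nn.functional as F

def esd_loss(e_n, e_p, e_0, eta):
    # Negatively guided target built from the frozen model's predictions:
    # eps*(x_t, t) - eta * (eps*(x_t, c, t) - eps*(x_t, t)).
    target = e_0.detach() - eta * (e_p.detach() - e_0.detach())
    # Push the fine-tuned model's conditional prediction e_n toward the target.
    return F.mse_loss(e_n, target)

# Toy tensors with an illustrative latent shape (batch, channels, height, width).
e_n = torch.randn(1, 4, 64, 64, requires_grad=True)  # ESD model, conditional
e_p = torch.randn(1, 4, 64, 64)  # frozen model, conditional
e_0 = torch.randn(1, 4, 64, 64)  # frozen model, unconditional
loss = esd_loss(e_n, e_p, e_0, eta=1.0)
loss.backward()
```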
How to implement
Link to the original implementation: https://github.com/rohitgandikota/erasing
A minimal version of the training code is as follows:
```python
def train_esd():
    # assumes model (the ESD copy), model_orig (frozen original), sampler,
    # words, and the hyperparameters below are in scope

    # choose parameters to train based on train_method,
    # e.g., 'noxattn', 'selfattn', 'xattn', 'full'
    parameters = []
    for name, param in model.model.diffusion_model.named_parameters():
        # train all layers except cross-attention (attn2) and time_embed layers
        if train_method == 'noxattn':
            if name.startswith('out.') or 'attn2' in name or 'time_embed' in name:
                pass
            else:
                parameters.append(param)
        # and so on for other train_methods

    # set the ESD model to train mode; model_orig stays frozen
    model.train()

    # lambda for cleaner use of the sampling code (denoise only until time step t)
    quick_sample_till_t = lambda conditioning, scale, start_code, t: sample_model(
        model, sampler, conditioning, image_size, image_size, ddim_steps,
        scale, ddim_eta, start_code=start_code, till_T=t, verbose=False)

    # set the optimizer to learn only the parameters that we want
    opt = torch.optim.Adam(parameters, lr=lr)

    # train loop
    for i in range(iterations):
        # sample a concept from the list of concepts to erase
        word = random.sample(words, 1)[0]

        # get text embeddings for the unconditional and conditional prompts
        emb_0 = model.get_learned_conditioning([''])          # unconditional (empty prompt)
        emb_p = model.get_learned_conditioning([word])        # positive (the concept)
        emb_n = model.get_learned_conditioning([f'{word}'])   # negative (same concept, fed to the ESD model)

        # clear gradients
        opt.zero_grad()

        # pick a random DDIM time step and map it to the DDPM time-step range
        t_enc = torch.randint(ddim_steps, (1,))
        og_num = round((int(t_enc) / ddim_steps) * 1000)
        og_num_lim = round(((int(t_enc) + 1) / ddim_steps) * 1000)
        t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,))

        with torch.no_grad():
            # generate an image with the concept from the ESD model,
            # partially denoised until time step t
            z = quick_sample_till_t(emb_p, start_guidance, start_code, int(t_enc))
            # get unconditional and conditional scores from the frozen model
            # at time step t for the image z generated above
            e_0 = model_orig.apply_model(z, t_enc_ddpm, emb_0)
            e_p = model_orig.apply_model(z, t_enc_ddpm, emb_p)

        # get the negative score from the ESD model (this one needs gradients)
        e_n = model.apply_model(z, t_enc_ddpm, emb_n)

        # stop gradients of the unconditional and positive scores
        e_0 = e_0.detach()
        e_p = e_p.detach()

        # compute the loss: match e_n to the negatively guided noise
        loss = ((e_n - (e_0 - negative_guidance * (e_p - e_0))) ** 2).mean()

        # update the model to erase the concept
        loss.backward()
        opt.step()
```
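The `# and so on for other train_methods` comment elides the remaining branches. The paper distinguishes ESD-x (fine-tune only the cross-attention layers, for prompt-specific erasure such as artistic styles) from ESD-u (fine-tune everything except cross-attention, for global erasure such as nudity). As a sketch of how those two branches could look, following the same name-matching pattern as above (the helper name `select_parameters` is illustrative, not from the official repo):

```python
def select_parameters(unet, train_method):
    # unet corresponds to model.model.diffusion_model in the script above
    parameters = []
    for name, param in unet.named_parameters():
        if train_method == 'xattn':
            # ESD-x: only cross-attention (attn2) layers
            if 'attn2' in name:
                parameters.append(param)
        elif train_method == 'noxattn':
            # ESD-u: everything except cross-attention, time embeddings,
            # and the final output block
            if not (name.startswith('out.') or 'attn2' in name or 'time_embed' in name):
                parameters.append(param)
    return parameters
```

Cross-attention is where the text prompt enters the U-Net, so restricting updates to `attn2` confines the erasure to prompts that mention the concept, while updating the rest of the U-Net removes the concept even when it is not named in the prompt.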
Something to note
- In the above objective function, $x_t$ would be an image from a training set, noised to time step $t$. However, as mentioned in the paper, "We exploit the model's knowledge of the concept to synthesize training samples, thereby eliminating the need for data collection". Therefore, in the implementation, $x_t$ is an image generated by the fine-tuned model itself and partially denoised until time step $t$.
- When generating the image $x_t$ from the fine-tuned model, the authors used `emb_p` (the embedding of the concept prompt) instead of `emb_0` (the embedding of the empty, unconditional prompt). So instead of sampling $x_t \sim P_\theta(x_t)$, the implementation samples $x_t \sim P_\theta(x_t \mid c)$.

So the loss function in the implementation is as follows:

$$\mathcal{L} = \left\lVert \epsilon_{\theta}(x_t^{c}, c, t) - \left( \epsilon_{\theta^*}(x_t^{c}, t) - \eta \left[ \epsilon_{\theta^*}(x_t^{c}, c, t) - \epsilon_{\theta^*}(x_t^{c}, t) \right] \right) \right\rVert^2$$

where $x_t^{c}$ denotes a sample generated by the fine-tuned model conditioned on the concept $c$ and denoised until time step $t$.

However, this approach might lead to a problem: because the samples $x_t^{c}$ come from the very model that is being fine-tuned, their distribution shifts as training progresses, so the model is trained on a moving target rather than on a fixed dataset.