About the paper

  • Published at ICCV 2023
  • Affiliations: Northeastern University and MIT.
  • Motivation: Remove specific concepts from the weights of a diffusion model while preserving its capability on other concepts. The concept can be a general concept (e.g., nudity), an artistic style (e.g., Van Gogh's) or a specific object (e.g., car, dog).
Examples of erasing nudity, Van Gogh style or an object from a Stable Diffusion model (Image source: Gandikota et al. (2023)).

Approach

The central idea is to reduce the probability of generating an image $x$ according to how likely it is to depict the concept $c$, with the strength controlled by a power factor $\eta$:

$$P_{\theta^*}(x) \propto \frac{P_\theta(x)}{P_\theta(c \mid x)^{\eta}}$$

where $P_\theta(x)$ is the distribution generated by the original model with parameters $\theta$, $P_\theta(c \mid x)$ is the probability of the concept $c$ given the image $x$, and $\theta^*$ are the parameters of the model after unlearning the concept $c$. The power factor $\eta$ controls the strength of the erasure: a larger $\eta$ means a stronger erasure.

Interpretation: if the concept $c$ is present in the image $x$ (so $P_\theta(c \mid x)$ is high), the likelihood of $x$ under the new model $P_{\theta^*}(x)$ is reduced. If the concept is absent (so $P_\theta(c \mid x)$ is low), the likelihood of $x$ under the new model is relatively increased.
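To make the reweighting concrete, here is a toy sketch over a hypothetical discrete set of four "images" with made-up values of $P_\theta(x)$ and $P_\theta(c \mid x)$ (not taken from the paper). It divides each probability by $P_\theta(c \mid x)^\eta$ and renormalizes:

```python
def erase_reweight(p_x, p_c_given_x, eta):
    """Toy version of P*(x) ∝ P(x) / P(c|x)^eta over a discrete set of images."""
    unnormalized = [p / pc ** eta for p, pc in zip(p_x, p_c_given_x)]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


# four hypothetical "images"; the first two clearly depict the concept c
p_x = [0.25, 0.25, 0.25, 0.25]
p_c_given_x = [0.9, 0.8, 0.1, 0.1]

for eta in (0.0, 1.0, 2.0):
    print(eta, [round(q, 3) for q in erase_reweight(p_x, p_c_given_x, eta)])
```

With $\eta = 0$ the distribution is unchanged; as $\eta$ grows, probability mass moves away from the images that depict $c$.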

By Bayes' rule, the probability of the concept $c$ given the image $x$ can be rewritten as follows:

$$P_\theta(c \mid x) = \frac{P_\theta(x \mid c)\, P_\theta(c)}{P_\theta(x)}$$

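Therefore, taking the gradient of the logarithm with respect to $x$ gives the following (the normalizing constant disappears, and $\nabla_x \log P_\theta(c) = 0$ because $P_\theta(c)$ does not depend on $x$):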

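$$
\begin{aligned}
\nabla_x \log P_{\theta^*}(x) &= \nabla_x \log P_\theta(x) - \eta \nabla_x \log P_\theta(c \mid x) \\
&= \nabla_x \log P_\theta(x) - \eta \left( \nabla_x \log P_\theta(x \mid c) + \nabla_x \log P_\theta(c) - \nabla_x \log P_\theta(x) \right) \\
&= \nabla_x \log P_\theta(x) - \eta \left( \nabla_x \log P_\theta(x \mid c) - \nabla_x \log P_\theta(x) \right)
\end{aligned}
$$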
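The derivation can be sanity-checked numerically in one dimension. The sketch below assumes a hypothetical toy model in which images with the concept follow $\mathcal{N}(2, 1)$, the other images follow $\mathcal{N}(0, 1)$, and $P(c) = 0.3$. It compares finite-difference gradients of both sides of the Bayes identity and forms the erased score from the last line of the derivation:

```python
import math

PI_C = 0.3  # hypothetical prior P(c)
MU_C = 2.0  # hypothetical mean of images that depict the concept


def gauss_pdf(x, mu, sigma=1.0):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


def p_x_given_c(x):
    return gauss_pdf(x, MU_C)


def p_x(x):
    # marginal P(x): mixture of concept and non-concept images
    return PI_C * gauss_pdf(x, MU_C) + (1 - PI_C) * gauss_pdf(x, 0.0)


def p_c_given_x(x):
    # Bayes' rule: P(c|x) = P(x|c) P(c) / P(x)
    return p_x_given_c(x) * PI_C / p_x(x)


def grad_log(f, x, h=1e-5):
    # central finite difference of log f at x
    return (math.log(f(x + h)) - math.log(f(x - h))) / (2 * h)


x, eta = 0.7, 1.5
lhs = grad_log(p_c_given_x, x)                     # grad log P(c|x)
rhs = grad_log(p_x_given_c, x) - grad_log(p_x, x)  # grad log P(c) = 0 drops out
erased_score = grad_log(p_x, x) - eta * rhs        # last line of the derivation
print(lhs, rhs, erased_score)
```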

In a diffusion model, each denoising step is modeled as a Gaussian distribution, so the gradient of the log-likelihood (the score) is:

$$\nabla_x \log P_\theta(x) = -\frac{1}{\sigma^2}(x - \mu)$$

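where $\mu$ and $\sigma$ are the mean and standard deviation of that Gaussian. Through the reparameterization trick, the score is tied to the noise $\epsilon$ predicted at each step, $\nabla_{x_t} \log P_\theta(x_t) \approx -\epsilon_\theta(x_t, t)/\sqrt{1-\bar\alpha_t}$ (the link between DDPM and the score-based approaches). Substituting this into the last line of the derivation, the common factor $-1/\sqrt{1-\bar\alpha_t}$ cancels and the noise prediction of the erased model becomes: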

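$$\epsilon_{\theta^*}(x_t, t) \leftarrow \epsilon_\theta(x_t, t) - \eta\left(\epsilon_\theta(x_t, c, t) - \epsilon_\theta(x_t, t)\right)$$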
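The score-noise link can be illustrated on the forward process $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, whose conditional distribution $q(x_t \mid x_0)$ is Gaussian with mean $\sqrt{\bar\alpha_t}\,x_0$ and standard deviation $\sqrt{1-\bar\alpha_t}$. The sketch below uses a hypothetical 1-D sample and schedule value, plugs them into the Gaussian score formula, and recovers the noise as $\epsilon = -\sqrt{1-\bar\alpha_t}\,\nabla_{x_t}\log q(x_t \mid x_0)$:

```python
import math
import random


def gaussian_score(x, mu, sigma):
    # grad_x log N(x; mu, sigma^2) = -(x - mu) / sigma^2
    return -(x - mu) / sigma ** 2


random.seed(0)
alpha_bar = 0.6  # hypothetical cumulative schedule value at step t
x0 = 1.2         # hypothetical clean sample (1-D for illustration)
eps = random.gauss(0.0, 1.0)

# reparameterized forward noising: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps
x_t = math.sqrt(alpha_bar) * x0 + math.sqrt(1 - alpha_bar) * eps

# score of q(x_t | x0), a Gaussian with mean sqrt(alpha_bar) * x0 and std sqrt(1 - alpha_bar)
score = gaussian_score(x_t, math.sqrt(alpha_bar) * x0, math.sqrt(1 - alpha_bar))
eps_from_score = -math.sqrt(1 - alpha_bar) * score
print(eps, eps_from_score)
```

The noise and the score differ only by the negative factor $-\sqrt{1-\bar\alpha_t}$, which is why the signs of the score equation carry over unchanged to the noise equation.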

Here, $\epsilon_{\theta^*}(x_t, t)$ denotes the noise predicted at step $t$ by the model after unlearning the concept $c$. Finally, to fine-tune the pretrained model $\theta$ into the new, cleaned model $\theta^*$, one can minimize the following loss function:

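$$\mathcal{L}(\theta^*) = \mathbb{E}_{x_0 \sim D}\left[\sum_{t=0}^{T-1} \left\| \epsilon_{\theta^*}(x_t, t) - \epsilon_\theta(x_t, t) + \eta\left(\epsilon_\theta(x_t, c, t) - \epsilon_\theta(x_t, t)\right) \right\|^2\right]$$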

where $x_0$ is an input image sampled from the data distribution $D$, $x_t$ is its noised version at step $t$, and $T$ is the number of diffusion steps.
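The loss is minimized (reaches zero) when the fine-tuned prediction equals the target $\epsilon_\theta(x_t,t) - \eta(\epsilon_\theta(x_t,c,t) - \epsilon_\theta(x_t,t))$. This target is the classifier-free guidance combination $\epsilon_{\text{uncond}} + s(\epsilon_{\text{cond}} - \epsilon_{\text{uncond}})$ with a negative scale $s = -\eta$. A sketch on hypothetical noise vectors, using plain lists instead of tensors:

```python
def cfg(eps_uncond, eps_cond, scale):
    # classifier-free guidance: eps_uncond + scale * (eps_cond - eps_uncond)
    return [u + scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]


def esd_target(eps_uncond, eps_cond, eta):
    # target for the erased model: eps_uncond - eta * (eps_cond - eps_uncond)
    return [u - eta * (c - u) for u, c in zip(eps_uncond, eps_cond)]


# hypothetical predictions of the frozen model for the same x_t and t
eps_uncond = [0.1, -0.4, 0.3]
eps_cond = [0.5, -0.2, 0.0]
eta = 1.0

target = esd_target(eps_uncond, eps_cond, eta)
print(target)
```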

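Instead of summing the loss over every time step, we can sample a single time step $t \sim \mathcal{U}(0, T-1)$ per training example and compute the loss only at that step. In expectation this equals the sum above divided by the constant $T$, so the minimizer is the same, and the loss function can be rewritten as follows: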

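$$\mathcal{L}(\theta^*) = \mathbb{E}_{x_0 \sim D}\left[\left\| \epsilon_{\theta^*}(x_t, t) - \epsilon_\theta(x_t, t) + \eta\left(\epsilon_\theta(x_t, c, t) - \epsilon_\theta(x_t, t)\right) \right\|^2\right]$$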

where $t \sim \mathcal{U}(0, T-1)$.
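A toy check of the single-step estimate, with a hypothetical per-step loss standing in for the squared norm at step $t$: averaging the loss at uniformly sampled steps and multiplying by $T$ approaches the full sum over $t = 0, \dots, T-1$:

```python
import random


def per_step_loss(t):
    # hypothetical per-time-step loss (stand-in for the squared norm at step t)
    return 1.0 + 0.1 * (t % 7)


T = 50
full_sum = sum(per_step_loss(t) for t in range(T))  # sum over t = 0 .. T-1

random.seed(0)
n_samples = 20000
# t ~ U(0, T-1): random.randrange(T) draws an integer uniformly from 0 .. T-1
estimate = T * sum(per_step_loss(random.randrange(T)) for _ in range(n_samples)) / n_samples
print(full_sum, estimate)
```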

The Final Objective Function

However, the paper does not use the loss function above; instead, the authors minimize the following one:

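$$\mathcal{L}(\theta^*) = \mathbb{E}_{x_0 \sim D}\left[\left\| \epsilon_{\theta^*}(x_t, c, t) - \epsilon_\theta(x_t, t) + \eta\left(\epsilon_\theta(x_t, c, t) - \epsilon_\theta(x_t, t)\right) \right\|^2\right]$$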

where $t \sim \mathcal{U}(0, T-1)$.
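A minimal sketch of this objective for one sample, with the three noise predictions given as hypothetical flat lists. It averages the squared residual over elements (as an MSE loss would), which only rescales the squared norm by a constant:

```python
def esd_loss(eps_new_cond, eps_orig_uncond, eps_orig_cond, eta):
    """Mean over elements of (eps*(x_t,c,t) - eps(x_t,t) + eta * (eps(x_t,c,t) - eps(x_t,t)))^2.

    eps_new_cond:    fine-tuned model, conditioned on c
    eps_orig_uncond: frozen original model, unconditional
    eps_orig_cond:   frozen original model, conditioned on c
    """
    residuals = [n - u + eta * (c - u)
                 for n, u, c in zip(eps_new_cond, eps_orig_uncond, eps_orig_cond)]
    return sum(r * r for r in residuals) / len(residuals)


# hypothetical predictions for one x_t and t
eps_orig_uncond = [0.1, -0.4, 0.3]
eps_orig_cond = [0.5, -0.2, 0.0]
eta = 1.0

# before fine-tuning, the new model still predicts the same as the original one
print(esd_loss(eps_orig_cond, eps_orig_uncond, eps_orig_cond, eta))
# the loss vanishes when the new prediction equals the negatively guided target
target = [u - eta * (c - u) for u, c in zip(eps_orig_uncond, eps_orig_cond)]
print(esd_loss(target, eps_orig_uncond, eps_orig_cond, eta))
```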

The difference between the two loss functions is that the first one uses the fine-tuned model's unconditional noise prediction $\epsilon_{\theta^*}(x_t, t)$ at time step $t$, while the second one uses its prediction $\epsilon_{\theta^*}(x_t, c, t)$ conditioned on the concept $c$.

Interpretation of the loss function: minimizing it forces the fine-tuned model's conditional noise $\epsilon_{\theta^*}(x_t, c, t)$ toward $\epsilon_\theta(x_t, t) - \eta(\epsilon_\theta(x_t, c, t) - \epsilon_\theta(x_t, t))$, that is, the original model's unconditional noise pushed away from the direction of the concept. The predicted noise is the signal that guides the diffusion model to generate $x_{t-1}$ (recall the denoising equation $x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(x_t, t)\right) + \sigma_t z$). Therefore, when the fine-tuned model is prompted with $c$, it generates an $x_{t-1}$ close to what the original model generates without $c$, steered further away from the concept.
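A sketch of one DDPM reverse step with hypothetical scalar schedule values shows how the noise prediction steers $x_{t-1}$. Because the step is linear in the predicted noise, using the erased prediction moves $x_{t-1}$ away from the unconditional step by $-\eta$ times the shift that the concept prediction would cause:

```python
import math


def ddpm_step(x_t, eps_pred, alpha_t, alpha_bar_t, sigma_t, z):
    # x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_pred) + sigma_t * z
    return (x_t - (1 - alpha_t) / math.sqrt(1 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_t) + sigma_t * z


# hypothetical schedule values and a 1-D latent; z = 0 keeps the comparison deterministic
alpha_t, alpha_bar_t, sigma_t = 0.98, 0.5, 0.1
x_t, z = 0.8, 0.0
eps_uncond, eps_cond, eta = 0.2, 0.6, 1.0
eps_erased = eps_uncond - eta * (eps_cond - eps_uncond)  # negatively guided target

x_uncond = ddpm_step(x_t, eps_uncond, alpha_t, alpha_bar_t, sigma_t, z)
x_cond = ddpm_step(x_t, eps_cond, alpha_t, alpha_bar_t, sigma_t, z)
x_erased = ddpm_step(x_t, eps_erased, alpha_t, alpha_bar_t, sigma_t, z)
print(x_uncond, x_cond, x_erased)
```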

Note: In the objective above, $x_t$ is a noised image from the training set $D$ at time step $t$. However, as stated in the paper, “We exploit the model’s knowledge of the concept to synthesize training samples, thereby eliminating the need for data collection”. Therefore, in the implementation, $x_t$ is generated by the model being fine-tuned, denoised up to time step $t$.

How to implement

Link to the original implementation: https://github.com/rohitgandikota/erasing

A minimal version of the training code is as follows:


import random
import torch

# Assumed to be set up beforehand, as in the original repository: `model` (the ESD model
# being fine-tuned), `model_orig` (a frozen copy of the original model), `sampler`,
# `sample_model`, `words` (concepts to erase) and the hyperparameters `train_method`, `lr`,
# `iterations`, `ddim_steps`, `ddim_eta`, `image_size`, `start_guidance`,
# `negative_guidance` and `device`.
def train_esd():

  # choose parameters to train based on train_method,
  # e.g., 'noxattn', 'selfattn', 'xattn', 'full'
  parameters = []
  for name, param in model.model.diffusion_model.named_parameters():
    # 'noxattn': train all layers except cross-attention (attn2), time_embed and output layers
    if train_method == 'noxattn':
      if name.startswith('out.') or 'attn2' in name or 'time_embed' in name:
        pass
      else:
        parameters.append(param)
    # and so on for other train_methods

  # set model to train mode
  model.train()

  # lambda for cleaner use of the sampling code (only denoises until DDIM step t)
  quick_sample_till_t = lambda conditioning, scale, start_code, t: sample_model(model, sampler, conditioning, image_size, image_size, ddim_steps, scale, ddim_eta, start_code=start_code, till_T=t, verbose=False)

  # optimizer that updates only the selected parameters
  opt = torch.optim.Adam(parameters, lr=lr)

  # train loop
  for i in range(iterations):
    # sample a concept from the list of concepts
    word = random.sample(words, 1)[0]

    # text embeddings for the unconditional (empty) prompt and the concept prompt;
    # emb_p conditions the sampler and the frozen model, emb_n conditions the ESD model,
    # and in this minimal version both encode the same prompt
    emb_0 = model.get_learned_conditioning([''])         # unconditional
    emb_p = model.get_learned_conditioning([word])       # positive
    emb_n = model.get_learned_conditioning([f'{word}'])  # negative

    # clear gradients
    opt.zero_grad()

    # sample a DDIM step t_enc and a DDPM time step (out of 1000) in the matching range
    t_enc = torch.randint(ddim_steps, (1,), device=device)
    og_num = round((int(t_enc) / ddim_steps) * 1000)
    og_num_lim = round((int(t_enc + 1) / ddim_steps) * 1000)
    t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=device)

    # initial Gaussian noise in the latent space (4 x 64 x 64 for 512 x 512 images)
    start_code = torch.randn((1, 4, 64, 64), device=device)

    with torch.no_grad():
      # generate a partially denoised latent of the concept with the ESD model
      z = quick_sample_till_t(emb_p, start_guidance, start_code, int(t_enc))
      # unconditional and conditional noise predictions of the frozen model at step t on z
      e_0 = model_orig.apply_model(z, t_enc_ddpm, emb_0)
      e_p = model_orig.apply_model(z, t_enc_ddpm, emb_p)

    # conditional ("negative") noise prediction of the ESD model; gradients flow only here
    e_n = model.apply_model(z, t_enc_ddpm, emb_n)

    # stop gradients of the unconditional and positive predictions
    e_0 = e_0.detach()
    e_p = e_p.detach()

    # mean squared error between e_n and the target e_0 - negative_guidance * (e_p - e_0)
    loss = ((e_n - e_0 + negative_guidance * (e_p - e_0)) ** 2).mean()

    # update the model to erase the concept
    loss.backward()
    opt.step()
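The code above only spells out the 'noxattn' branch. Below is a sketch of name-based filtering for the four options listed in its comment, under the assumption that 'selfattn' keeps the self-attention blocks (named attn1 in the Stable Diffusion UNet), 'xattn' keeps the cross-attention blocks (attn2) and 'full' keeps everything. The helper `select_trainable` and the parameter names are hypothetical examples:

```python
TRAIN_METHODS = ('noxattn', 'selfattn', 'xattn', 'full')


def select_trainable(names, train_method):
    """Return the parameter names to fine-tune for a given train_method (hypothetical helper).

    In the Stable Diffusion UNet, 'attn1' modules are self-attention and 'attn2' modules are
    cross-attention to the text embedding.
    """
    if train_method not in TRAIN_METHODS:
        raise ValueError(f'unknown train_method: {train_method}')
    selected = []
    for name in names:
        if train_method == 'noxattn':
            # all layers except cross-attention, the time embedding and the output layers
            keep = not (name.startswith('out.') or 'attn2' in name or 'time_embed' in name)
        elif train_method == 'selfattn':
            keep = 'attn1' in name
        elif train_method == 'xattn':
            keep = 'attn2' in name
        else:  # 'full'
            keep = True
        if keep:
            selected.append(name)
    return selected


# hypothetical parameter names in the style of the Stable Diffusion UNet
names = [
    'time_embed.0.weight',
    'input_blocks.1.0.in_layers.2.weight',
    'input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight',
    'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight',
    'out.2.weight',
]
for method in TRAIN_METHODS:
    print(method, select_trainable(names, method))
```

In the training function, the same test would be applied to `name` inside the `named_parameters()` loop, appending the matching `param` to the optimizer's parameter list.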

Something to note

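  • When generating $x_t$ with the fine-tuned model, the authors condition on emb_p (the text embedding of the concept prompt), sampled with guidance scale start_guidance, instead of emb_0 (the embedding of the empty, unconditional prompt). So $x_t \sim P_{\theta^*}(x_t \mid c)$ instead of $x_t \sim P_{\theta^*}(x_t)$.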

So the loss function in the implementation is as follows:

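$$\mathcal{L}(\theta^*) = \mathbb{E}_{x_t \sim P_{\theta^*}(\cdot \mid c)}\left[\left\| \epsilon_{\theta^*}(x_t, c, t) - \epsilon_\theta(x_t, t) + \eta\left(\epsilon_\theta(x_t, c, t) - \epsilon_\theta(x_t, t)\right) \right\|^2\right]$$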

where $t \sim \mathcal{U}(0, T-1)$.

However, this approach might cause instability: the sample z is generated by the model θ* that is being fine-tuned, so the distribution of training samples shifts every time θ* is updated (a moving target). Note that the gradient is not backpropagated through the sampling itself: z is generated inside torch.no_grad(), so the loss is backpropagated only through the single forward pass of θ* that produces e_n.