Anti-Dreambooth
About the paper
- Published at ICCV 2023
- Affiliation: VinAI Research
- Main idea: Generate an imperceptible perturbation to add to personal images before uploading them to the internet, so that an adversary cannot use the uploaded images to train a personalized model (specifically with the Dreambooth method) that generates harmful images (e.g., nude images) of the person.
I gave a talk titled “Exploring Controllability of Conditioned Diffusion Models”, covering the “Dreambooth” paper, at a Postdoc-NeT-AI program event. The slides can be found here. Therefore, in this blog post, I focus on the implementation side of the paper.
How to implement
How to preprocess the data
It is worth noting that in the standard adversarial-example setting, the perturbation is added to the post-processed (already resized, cropped, and normalized) image. In this project, however, the perturbation should ideally be added to the raw image before preprocessing and remain robust to the preprocessing step, because that is the image the user actually uploads. In the current implementation, the perturbation is still added to the post-processed image.
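To make the distinction concrete, here is a minimal sketch (person.jpg and the zero-initialized delta are placeholders; preprocess mirrors the pipeline inside load_data shown below):

import torch
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(512), transforms.CenterCrop(512),
    transforms.ToTensor(), transforms.Normalize([0.5], [0.5]),
])

raw = Image.open("person.jpg").convert("RGB")        # the image the user would actually upload

# Current implementation: perturb the post-processed tensor (normalized to [-1, 1])
x = preprocess(raw)
delta = torch.zeros_like(x)                          # learned by PGD later in the pipeline
x_adv = (x + delta).clamp(-1, 1)

# Ideal setting: perturb the raw pixels first, then run preprocessing on top,
# so the perturbation has to survive resizing, cropping, and normalization
raw_t = transforms.ToTensor()(raw)                   # [C, H0, W0] in [0, 1]
delta_raw = torch.zeros_like(raw_t)
raw_adv = transforms.ToPILImage()((raw_t + delta_raw).clamp(0, 1))
x_adv_ideal = preprocess(raw_adv)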
- It first calls the load_data function, which reads PIL images, applies some transformations (e.g., resize, crop, normalize), and returns a tensor of shape [N, C, H, W]. The layout is channel-first, not channel-last: ToTensor produces [C, H, W] tensors and torch.stack adds the batch dimension (a quick shape check follows the snippet below). Ref: Line 360
from pathlib import Path

import torch
from PIL import Image
from torchvision import transforms


def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor:
    image_transforms = transforms.Compose(
        [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    images = [image_transforms(Image.open(i).convert("RGB")) for i in list(Path(data_dir).iterdir())]
    images = torch.stack(images)
    return images
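A quick check of the returned tensor confirms the channel-first layout (the directory path is a placeholder):

images = load_data("data/person", size=512, center_crop=True)
print(images.shape)                # torch.Size([N, 3, 512, 512]) (channel-first)
print(images.min(), images.max())  # roughly in [-1, 1] after Normalize([0.5], [0.5])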
- It then calls the DreamBoothDatasetFromTensor class with the input argument instance_images_tensor, which is the tensor returned from the load_data function. When the __getitem__ method of this class is called, the data is associated with a corresponding textual prompt. No transformation is applied in this class (a small usage sketch follows the snippet below). Ref: Line 31
def __getitem__(self, index):
    example = {}
    instance_image = self.instance_images_tensor[index % self.num_instance_images]
    example["instance_images"] = instance_image
    example["instance_prompt_ids"] = self.tokenizer(
        self.instance_prompt,
        truncation=True,
        padding="max_length",
        max_length=self.tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids

    if self.class_data_root:
        class_image = Image.open(self.class_images_path[index % self.num_class_images])
        if not class_image.mode == "RGB":
            class_image = class_image.convert("RGB")
        example["class_images"] = self.image_transforms(class_image)
        example["class_prompt_ids"] = self.tokenizer(
            self.class_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    return example
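As a quick sanity check, the dataset can be indexed directly; this sketch mirrors the constructor call shown later in train_one_epoch and assumes the class accepts None for the class-image arguments (the tokenizer checkpoint, prompt, and data directory are placeholders):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
images = load_data("data/person")          # placeholder directory, shape [N, 3, 512, 512]

dataset = DreamBoothDatasetFromTensor(
    images,                   # instance_images_tensor
    "a photo of sks person",  # instance prompt (example)
    tokenizer,
    None,                     # class_data_dir: skip prior preservation here
    None,                     # class_prompt
    512,                      # resolution
    True,                     # center_crop
)
example = dataset[0]
print(example["instance_images"].shape)      # torch.Size([3, 512, 512])
print(example["instance_prompt_ids"].shape)  # torch.Size([1, 77])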
The overall pipeline is shown in the code snippet below:
- Load the data from the instance_data_dir_for_adversarial directory. The output is a tensor of shape [N, C, H, W] (channel-first, not [N, H, W, C]):

# Load data from the instance data directory
# output: tensor of shape [N, C, H, W]
perturbed_data = load_data(
    args.instance_data_dir_for_adversarial,
    size=args.resolution,
    center_crop=args.center_crop,
)
- Clone the current model to avoid in-place updates to it
- Train the cloned (surrogate) model on the clean data
- Learn the perturbation against the updated surrogate model f_sur. The input is the entire data tensor, not just a batch! The output is the new perturbed data tensor.
- Restore the model and train it on the perturbed data
f = [unet, text_encoder]

for i in range(args.max_train_steps):
    # Clone the current model to avoid the in-place operation
    f_sur = copy.deepcopy(f)

    # Train the model with the clean data
    f_sur = train_one_epoch(
        args,
        f_sur,
        tokenizer,
        noise_scheduler,
        vae,
        clean_data,
        args.max_f_train_steps,
    )

    # Learn the perturbation with the updated model f_sur
    perturbed_data = pgd_attack(
        args,
        f_sur,
        tokenizer,
        noise_scheduler,
        vae,
        perturbed_data,
        original_data,
        target_latent_tensor,
        args.max_adv_train_steps,
    )

    # Restore the model and train with perturbed data
    f = train_one_epoch(
        args,
        f,
        tokenizer,
        noise_scheduler,
        vae,
        perturbed_data,
        args.max_f_train_steps,
    )
Inside the train_one_epoch function, the DreamBoothDatasetFromTensor class is instantiated to associate the data (i.e., the perturbed data) with the corresponding textual prompt.
train_dataset = DreamBoothDatasetFromTensor(
    data_tensor,
    args.instance_prompt,
    tokenizer,
    args.class_data_dir,
    args.class_prompt,
    args.resolution,
    args.center_crop,
)
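The rest of train_one_epoch is not reproduced in this post. Roughly, each training step looks like the sketch below; this is my reconstruction of a standard DreamBooth step (assuming epsilon prediction and that unet, text_encoder, vae, noise_scheduler, device, weight_dtype, and args are in scope as in the snippets above), not the repository code verbatim:

import itertools
import torch
import torch.nn.functional as F

params = itertools.chain(unet.parameters(), text_encoder.parameters())
optimizer = torch.optim.AdamW(params, lr=args.learning_rate)

for step in range(args.max_f_train_steps):
    # Index the dataset directly (no DataLoader, see the notes below)
    example = train_dataset[step % len(train_dataset)]
    pixel_values = example["instance_images"].unsqueeze(0).to(device, dtype=weight_dtype)
    input_ids = example["instance_prompt_ids"].to(device)

    # Standard denoising objective: encode, add noise, predict the noise, regress on it
    latents = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    ).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    encoder_hidden_states = text_encoder(input_ids)[0]
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()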
At the end, the perturbed images are saved to the instance_data_dir_for_adversarial directory.
Some notes:
- In the original Dreambooth project, the data is loaded through a DataLoader and is shuffled. In the Anti-Dreambooth project, however, the data is indexed directly from the DreamBoothDatasetFromTensor class and is not shuffled. Ref: Line 1061
- The reason for this modification is that the authors want to change the perturbed data on the fly after each epoch, which would be harder to control when going through a DataLoader (see the sketch below).
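The following toy sketch illustrates why direct indexing makes the on-the-fly swap trivial (TensorBackedDataset is a stand-in for DreamBoothDatasetFromTensor, not the repository class): the training loop only holds a reference to a tensor, so replacing that tensor between epochs is enough.

import torch

class TensorBackedDataset:
    """Toy stand-in: keeps a reference to a tensor and serves items deterministically."""
    def __init__(self, images: torch.Tensor):
        self.images = images
    def __len__(self):
        return len(self.images)
    def __getitem__(self, index):
        return self.images[index % len(self.images)]

data = torch.zeros(4, 3, 512, 512)
dataset = TensorBackedDataset(data)

for epoch in range(3):
    for step in range(len(dataset)):
        _ = dataset[step]                       # deterministic, unshuffled access
    # After the PGD step of the outer loop, swap in the new perturbed tensor;
    # there is no DataLoader worker or shuffling state to keep in sync.
    dataset.images = torch.randn_like(data)     # stands in for the output of pgd_attack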
Difference in prompting process between “Textual Inversion” and “Dreambooth” projects
In Dreambooth, there is an argument instance_prompt which is the prompt associated with the given images. For example, the default value is a photo of sks dog, where sks is the unique identifier that refers to the learned concept. The instance_prompt is then tokenized by the tokenizer, and the resulting token ids are used to condition the model during fine-tuning. Note that, unlike Textual Inversion, Dreambooth does not update a single row of the embedding matrix for the identifier; instead, the UNet (and optionally the text encoder) is fine-tuned so that the rare identifier token becomes bound to the new concept.
# In the DreamBoothDataset class
if self.encoder_hidden_states is not None:
    example["instance_prompt_ids"] = self.encoder_hidden_states
else:
    text_inputs = tokenize_prompt(
        self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
    )
    example["instance_prompt_ids"] = text_inputs.input_ids
    example["instance_attention_mask"] = text_inputs.attention_mask
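To see what actually reaches the text encoder, one can tokenize the instance prompt directly. A small sketch using the public CLIP tokenizer that Stable Diffusion v1 builds on (the checkpoint name is an assumption; in practice the tokenizer is typically loaded from the Stable Diffusion model path):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompt = "a photo of sks dog"

ids = tokenizer(
    prompt,
    truncation=True,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
).input_ids
print(ids.shape)  # torch.Size([1, 77])
# Inspect the first few (non-padding) tokens; "sks" is just a rare token/subword
# in the CLIP vocabulary, not a special placeholder added to it.
print(tokenizer.convert_ids_to_tokens(ids[0][:8].tolist()))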
So the difference between the two projects is that:
- In Dreambooth, a single instance prompt is used, while in Textual Inversion, a list of prompt templates is used.
- In Textual Inversion, it is important to specify the placeholder_token so that the learned token can be reused in other prompts, while in Dreambooth, the identifier (i.e., sks) is just a rare token inside the prompt that the fine-tuned model learns to associate with the concept. At inference time, a prompt containing the same identifier is used to generate images, for example, a photo of sks dog on the beach. So, to me, the whole instance prompt in Dreambooth plays a role similar to the placeholder token in Textual Inversion. However, in that case, what does the output look like if we use a prompt that does not contain the whole instance_prompt, for example, a sks dog walking on the beach? (See the inference sketch below.)
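One way to answer that question is simply to try both prompts against the same fine-tuned checkpoint. A minimal sketch with the diffusers StableDiffusionPipeline (the checkpoint path is hypothetical; any Dreambooth output directory would do):

import torch
from diffusers import StableDiffusionPipeline

# Path to a Dreambooth-fine-tuned model directory (hypothetical)
pipe = StableDiffusionPipeline.from_pretrained("path/to/dreambooth-output", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompts = [
    "a photo of sks dog on the beach",   # contains the full instance prompt
    "a sks dog walking on the beach",    # contains only the identifier
]
for p in prompts:
    image = pipe(p, num_inference_steps=50, guidance_scale=7.5).images[0]
    image.save(p.replace(" ", "_") + ".png")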
PGD attack
The PGD attack is implemented in the pgd_attack function. Its input is the current perturbed data tensor and its output is the updated perturbed data tensor.
Some notes:
- The weight type is torch.bfloat16 instead of torch.float32.
- unet, vae, and text_encoder are in train mode (because they were set to it in the train_one_epoch function).
- The perturbation is learned for the entire data tensor, not just a batch.
- The whole process is quite similar to a standard PGD attack, but without the random initialization (see the comparison sketch after the code).
# Create a copy of data and set requires_grad to True
perturbed_images = data_tensor.detach().clone()
perturbed_images.requires_grad_(True)

# Repeat the input_ids to match the batch size
input_ids = tokenizer(
    args.instance_prompt,
    truncation=True,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
).input_ids.repeat(len(data_tensor), 1)

# Loop over the number of steps
for step in range(num_steps):
    # Reset requires_grad to True because it was set to False in the last step
    perturbed_images.requires_grad = True

    latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample()
    latents = latents * vae.config.scaling_factor

    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]

    # Sample a random timestep for each image
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Get the text embedding for conditioning
    encoder_hidden_states = text_encoder(input_ids.to(device))[0]

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    unet.zero_grad()
    text_encoder.zero_grad()
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # target-shift loss
    if target_tensor is not None:
        xtm1_pred = torch.cat(
            [
                noise_scheduler.step(
                    model_pred[idx : idx + 1],
                    timesteps[idx : idx + 1],
                    noisy_latents[idx : idx + 1],
                ).prev_sample
                for idx in range(len(model_pred))
            ]
        )
        xtm1_target = noise_scheduler.add_noise(target_tensor, noise, timesteps - 1)
        loss = loss - F.mse_loss(xtm1_pred, xtm1_target)

    loss.backward()

    alpha = args.pgd_alpha
    eps = args.pgd_eps

    # Project to valid range
    adv_images = perturbed_images + alpha * perturbed_images.grad.sign()
    eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
    perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()
    print(f"PGD loss - step {step}, loss: {loss.detach().item()}")

return perturbed_images
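For comparison, the main ingredient missing relative to a textbook L-infinity PGD attack is the random start inside the eps-ball. A generic sketch (model, loss_fn, x, and y are placeholders, not names from the repository):

import torch

def pgd(model, loss_fn, x, y, eps=0.05, alpha=0.005, steps=10):
    """Generic L-infinity PGD with random initialization, for comparison only."""
    # Random start inside the eps-ball (the Anti-Dreambooth code skips this)
    x_adv = (x + torch.empty_like(x).uniform_(-eps, eps)).clamp(-1, 1).detach()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = loss_fn(model(x_adv), y)                # the attack maximizes the loss
        grad = torch.autograd.grad(loss, x_adv)[0]
        x_adv = x_adv.detach() + alpha * grad.sign()   # gradient ascent step
        eta = (x_adv - x).clamp(-eps, eps)             # project back to the eps-ball
        x_adv = (x + eta).clamp(-1, 1).detach()        # and to the valid pixel range [-1, 1]
    return x_adv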