What is the Safety Checker in Stable Diffusion?
With the rapid development of generative AI models, the risk of these models being misused to generate harmful or inappropriate content has become critical. In this post, I will discuss four main approaches to preventing unwanted content in foundation models and then dive into the implementation of the Safety Checker, the most popular and practical approach used in real-world GenAI applications such as Stable Diffusion and Midjourney.
Approaches to Prevent Unwanted Content in Generative Models
Preventing unwanted content generated by foundation models, such as Stable Diffusion, involves several strategies. Below, we explore four main approaches:
Pre-training
This approach involves filtering unwanted content out of the training data and retraining the model from scratch. While effective, it is prohibitively expensive. For instance, as noted in this blog post, training a Stable Diffusion-level model from scratch with MosaicML costs approximately $50,000. This estimate excludes data preprocessing costs, making the approach economically infeasible for frequent updates.
Major updates or retraining may occur in response to significant ethical concerns, but they are not suitable for addressing frequent issues like user-generated reports. For example, after the incident involving fake explicit images of Taylor Swift, other celebrities requested similar protections, as reported in this article. In another example, more than 13,000 creatives from around the world, including well-known actors, singers, and authors, recently signed a statement warning artificial intelligence (AI) companies that the unlicensed use of their work to train generative AI models is a “major, unjust threat” to their livelihoods. Addressing such cases through pre-training would be highly impractical.
Post-training
In this approach, inappropriate content is detected and censored in the generated output by a safety checker deployed by the model’s developer. Post-training is far more economical than pre-training, requiring minimal effort and few changes to the model’s development pipeline. For instance, a safety checker, discussed in the next section, relies on embedding similarity matching (e.g., using CLIP or other multimodal embeddings), and such systems can be updated quickly in response to user requests. This method is particularly effective for closed-source models like OpenAI’s or Midjourney’s, where users access the model only via an API; for these models, post-training should therefore be the highest-priority way to address user requests.
However, for open-source models like Stable Diffusion, where users have access to the model parameters and source code, this approach can be bypassed with a few lines of code, as shown in the figure below, making it ineffective. Even closed-source models are vulnerable to recent black-box jailbreak techniques, which exploit the transferability of adversarial examples between surrogate and target models (e.g., between an open-source model like Stable Diffusion and a closed-source one like DALL-E).
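To make the point concrete, here is a minimal sketch, assuming the standard diffusers API and a publicly available checkpoint (the model id is only an example), of how a user with code access can simply drop the checker when loading the pipeline:

# A minimal sketch: with open weights and source code, the post-hoc filter
# can be removed at load time in a couple of lines.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",   # example checkpoint
    safety_checker=None,               # drop the Safety Checker entirely
    requires_safety_checker=False,     # suppress the warning about the missing checker
)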
Fine-tuning
The goal of fine-tuning is to adjust the model’s parameters so that it “unlearns” its ability to generate unwanted content. Conceptually, the model may still retain the capability to generate such content, but access to it is suppressed, making it difficult for public users to exploit. From a research perspective, this approach is robust, reliable, and presents numerous opportunities for exploration. Practically, fine-tuning is also efficient, requiring only a few thousand iterations (typically a few hours of training) and no additional data (e.g., our methods :D).
Self-protection with Unlearnable Invisible Masks
Unlike the previous three approaches, which are developer-centric, this method is user-centric. The idea is that users add an unlearnable, invisible mask to their personal data before publishing it publicly, e.g., through a default setting in a camera app (entrepreneur idea alert :D). Even if the data is later collected to train a model, the mask prevents the model from learning from it, thereby hindering the model from generating related content. This approach has garnered attention from the research community, with notable works like MIST, Anti-Dreambooth, FT-Shield, and Meta-Cloak. It could be a million-dollar idea if it worked, but unfortunately, to the best of my knowledge, even the SOTA methods can still be bypassed by transformation-based or autoencoder-based jailbreak techniques.
Summary
For open-source models like Stable Diffusion, the fine-tuning approach offers the most promise and economic viability. For closed-source models like OpenAI’s or Midjourney’s, on the other hand, post-training methods are the most efficient and should take precedence.
In the next section, we delve deeper into the mechanics of the Safety Checker. The implementation of this module can be found in the diffusers library (version 0.32.0 at the time of writing).
What the heck is the Safety Checker?
The StableDiffusionSafetyChecker module is designed to detect NSFW (Not Safe For Work) or otherwise inappropriate content in generated images. When NSFW content is detected, a safety action is taken: the offending output images are censored and a warning message is returned.

Below is the implementation of this module in the diffusers library, reproduced here for the reader’s convenience:
import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel

from ...utils import logging


logger = logging.get_logger(__name__)


def cosine_distance(image_embeds, text_embeds):
    normalized_image_embeds = nn.functional.normalize(image_embeds)
    normalized_text_embeds = nn.functional.normalize(text_embeds)
    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())


class StableDiffusionSafetyChecker(PreTrainedModel):
    config_class = CLIPConfig
    main_input_name = "clip_input"

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)

        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)

        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)

    @torch.no_grad()
    def forward(self, clip_input, images):
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()

        result = []
        batch_size = image_embeds.shape[0]
        for i in range(batch_size):
            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}

            # increase this value to create a stronger `nfsw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concept_idx in range(len(special_cos_dist[0])):
                concept_cos = special_cos_dist[i][concept_idx]
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["special_scores"][concept_idx] > 0:
                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
                    adjustment = 0.01

            for concept_idx in range(len(cos_dist[0])):
                concept_cos = cos_dist[i][concept_idx]
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["concept_scores"][concept_idx] > 0:
                    result_img["bad_concepts"].append(concept_idx)

            result.append(result_img)

        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]

        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                if torch.is_tensor(images) or torch.is_tensor(images[0]):
                    images[idx] = torch.zeros_like(images[idx])  # black image
                else:
                    images[idx] = np.zeros(images[idx].shape)  # black image

        if any(has_nsfw_concepts):
            logger.warning(
                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts
And here is how the module is used in the Stable Diffusion pipeline, e.g., in pipelines/stable_diffusion/pipeline_stable_diffusion.py:
def run_safety_checker(self, image, device, dtype):
    if self.safety_checker is None:
        has_nsfw_concept = None
    else:
        if torch.is_tensor(image):
            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
        else:
            feature_extractor_input = self.image_processor.numpy_to_pil(image)
        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
    return image, has_nsfw_concept
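From an end user’s perspective, the checker surfaces through the pipeline output. A short usage sketch (the model id and prompt are only illustrative):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

out = pipe("a photo of an astronaut riding a horse on mars")
print(out.nsfw_content_detected)  # one boolean flag per generated image, e.g. [False]
image = out.images[0]             # a fully black image if its flag is True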
How to use the Safety Checker in the ldm library

In the ldm library, the safety_checker is first loaded (via the diffusers implementation) from a pretrained checkpoint and then applied to the generated samples through a small check_safety helper, e.g., in this file: https://github.com/CompVis/stable-diffusion/blob/main/scripts/txt2img.py
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor

# load safety model
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)


def check_safety(x_image):
    # numpy_to_pil and load_replacement are small helpers defined elsewhere in txt2img.py
    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
    assert x_checked_image.shape[0] == len(has_nsfw_concept)
    for i in range(len(has_nsfw_concept)):
        if has_nsfw_concept[i]:
            x_checked_image[i] = load_replacement(x_checked_image[i])
    return x_checked_image, has_nsfw_concept
...
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
...
Key Components

- CLIP Vision Model: uses a pretrained CLIP vision encoder to extract features from images.
- Visual Projection Layer: projects the CLIP features into a specific embedding space.
- Two sets of learned concept embeddings:
  - concept_embeds: 17 different concepts (presumably NSFW concepts).
  - special_care_embeds: 3 special concepts that require extra attention.
  - These embeddings are preloaded and marked as non-trainable (requires_grad=False).
- Corresponding weights for both types of embeddings:
  - concept_embeds_weights: 17 weights for the 17 concepts.
  - special_care_embeds_weights: 3 weights for the 3 special concepts.
  - These weights determine how strictly the model filters specific concepts.
The Detection Process

Step 1: Input Processing

- The input clip_input is processed by the vision model to extract pooled features (pooled_output).
- The pooled features are projected into a lower-dimensional embedding space using visual_projection.
pooled_output = self.vision_model(clip_input)[1]
image_embeds = self.visual_projection(pooled_output)
Step 2: Cosine Similarity Calculation

- Cosine similarities are computed between the image embedding and:
  - the special-care embeddings (special_cos_dist), and
  - the general concept embeddings (cos_dist).
- These scores indicate how closely the input image matches each undesirable concept. Note that despite its name, cosine_distance returns a cosine similarity (a dot product of L2-normalized embeddings), so higher values mean a closer match.
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
cos_dist = cosine_distance(image_embeds, self.concept_embeds)
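A tiny self-contained sanity check of this behaviour (the function is copied from the listing above):

# Identical vectors score 1.0 and orthogonal vectors score 0.0,
# confirming that cosine_distance is a similarity, not a distance.
import torch
import torch.nn as nn

def cosine_distance(image_embeds, text_embeds):
    normalized_image_embeds = nn.functional.normalize(image_embeds)
    normalized_text_embeds = nn.functional.normalize(text_embeds)
    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())

x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
print(cosine_distance(x, x))  # tensor([[1., 0.], [0., 1.]])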
Step 3: Threshold Detection

- Each similarity is compared against its learned per-concept threshold (concept_embeds_weights and special_care_embeds_weights).
- An adjustment factor (0.01) is added to the general concept scores once any special-care concept has been detected, making the filter stricter for that image.
- An image is marked as NSFW if any score exceeds its threshold:
round(concept_cos - concept_threshold + adjustment, 3) > 0
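Since the thresholds are ordinary (non-trainable) parameters, the strictness of the filter can be tuned without touching forward(). A hedged sketch, assuming the checker is loaded from the CompVis/stable-diffusion-safety-checker checkpoint discussed below:

# Lowering the per-concept thresholds flags images at lower similarity scores,
# i.e., a stricter (but more false-positive-prone) filter.
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
checker.concept_embeds_weights.data -= 0.01
checker.special_care_embeds_weights.data -= 0.01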
Step 4: Safety Action
- If NSFW content is detected, replaces the image with a black image (zeros)
if has_nsfw_concept:
images[idx] = torch.zeros_like(images[idx])
How to get the self.concept_embeds?

The concept_embeds are declared as parameters of the model, but their actual values are loaded from a pretrained checkpoint; the code above only contains placeholder initialization:
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
The actual concept embeddings come from the safety checker model checkpoint that ships with Stable Diffusion. You can find this model on the HuggingFace Hub as CompVis/stable-diffusion-safety-checker.
To get the actual values, load the pretrained safety checker (the class is defined in safety_checker.py in diffusers) and access the embeddings directly:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
concept_embeds = safety_checker.concept_embeds.cpu().numpy()
special_care_embeds = safety_checker.special_care_embeds.cpu().numpy()
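As a quick follow-up, one can inspect what the checkpoint actually ships (the shapes in the comments assume the CLIP ViT-L/14 configuration this checkpoint is built on):

print(concept_embeds.shape)                         # e.g., (17, 768): one CLIP-space vector per concept
print(special_care_embeds.shape)                    # e.g., (3, 768)
print(safety_checker.concept_embeds_weights)        # the 17 per-concept thresholds
print(safety_checker.special_care_embeds_weights)   # the 3 special-care thresholds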
How to bypass the Safety Checker?
After reading the code, it is clear that the Safety Checker can be bypassed by simply setting has_nsfw_concepts to always be False:
has_nsfw_concepts = [False] * len(images)
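In practice, rather than editing the library source, the same effect can be achieved by monkeypatching the pipeline’s run_safety_checker method shown earlier. A minimal sketch, assuming the standard diffusers API and an example checkpoint:

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

def run_safety_checker_noop(self, image, device, dtype):
    # Same signature as the original method, but nothing is ever flagged.
    return image, [False] * len(image)

# Bind the no-op implementation to this pipeline instance.
pipe.run_safety_checker = run_safety_checker_noop.__get__(pipe)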
Conclusion
In this post, we discussed the Safety Checker in Stable Diffusion, walked through its implementation, explained its key components and detection process, and showed how easily it can be bypassed. I believe that understanding this module in greater detail can lead to interesting research ideas, for example: how to quickly update the Safety Checker to detect new NSFW concepts, how to bypass it without modifying any code, or whether it can serve as a surrogate for other NSFW detection models.
Thank you for reading!