With the rapid development of generative AI models, the risk of misuse and harmful generated content has become critical. In this post, I will discuss four main approaches to preventing unwanted content in foundation models and then dive into the implementation of the Safety Checker, one of the most popular and practical safeguards used in real-world GenAI applications such as Stable Diffusion and Midjourney.

Approaches to Prevent Unwanted Content in Generative Models

Preventing unwanted content generated by foundation models, such as Stable Diffusion, involves several strategies. Below, we explore four main approaches:

Pre-training

This approach involves filtering unwanted content out of the training data and retraining the model from scratch. While effective, it is prohibitively expensive. For instance, as noted in this blog post, training a Stable Diffusion-level model from scratch with MosaicML costs approximately $50,000. This estimate excludes data preprocessing costs, making the approach economically infeasible for frequent updates.

Major updates or retraining may occur in response to significant ethical concerns, but they are not suitable for addressing frequent issues such as user-generated reports. For example, after the incident involving fake explicit images of Taylor Swift, other celebrities requested similar protections, as reported in this article. In another example, a recent statement signed by more than 13,000 creatives from around the world, including famous actors, singers, and authors, warned artificial intelligence (AI) companies that the unlicensed use of their work to train generative AI models is a “major, unjust threat” to their livelihoods. Addressing such cases through pre-training would be highly impractical.

The cost of training a Stable-Diffusion level model from scratch.

Post-training

In this approach, inappropriate content is detected and censored in the generated output by a safety checker deployed by the model's developer. Post-training is more economical than pre-training, requiring minimal effort and no changes to the model's training pipeline. For instance, a safety checker, discussed in the next section, relies on embedding similarity matching (e.g., using CLIP or other multimodal embeddings). Such systems can be updated quickly in response to user requests. This method is particularly effective for closed-source models like OpenAI's or Midjourney's, where users access the model only via an API, and it should therefore be the first line of defense for addressing user requests.
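To make this concrete, here is a minimal, self-contained sketch of embedding-similarity matching with CLIP. The blocked phrases and the 0.2 threshold are made-up placeholders for illustration; the real Safety Checker (covered in the next section) ships pre-computed concept embeddings and per-concept thresholds instead of raw text prompts.

import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

# Illustrative placeholder list; a deployed filter would maintain its own concepts.
blocked_phrases = ["an image depicting graphic violence", "an image depicting explicit nudity"]

def is_flagged(image: Image.Image, phrases, threshold=0.2):
    inputs = processor(text=phrases, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        img_emb = model.get_image_features(pixel_values=inputs["pixel_values"])
        txt_emb = model.get_text_features(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )
    # Cosine similarity between the image and every blocked concept.
    sims = F.normalize(img_emb, dim=-1) @ F.normalize(txt_emb, dim=-1).T
    return bool((sims > threshold).any()), sims

# Responding to a new user report is as simple as appending one more phrase;
# no retraining of the generator is required.
blocked_phrases.append("a fake explicit photo of a specific celebrity")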

However, for open-source models like Stable Diffusion, where users have access to the model parameters and source code, this approach can be bypassed with a few lines of code, as shown in the figure below, making it ineffective. Even closed-source models are vulnerable to recent black-box jailbreak techniques, which exploit the transferability of adversarial examples between surrogate and target models (e.g., between open-source models like Stable Diffusion and closed-source models like DALL-E).

Bypassing the NSFW detector of Stable Diffusion.

Fine-tuning

The goal of fine-tuning is to adjust the model's parameters so that it "unlearns" its ability to generate unwanted content. Conceptually, the model may retain the capability to generate such content, but access to it becomes hidden, making it difficult for public users to exploit. From a research perspective, this approach is robust, reliable, and presents numerous opportunities for exploration. Practically, fine-tuning is also efficient, requiring only a few thousand iterations (a process that typically completes within a few hours) and no additional data (e.g., our methods :D).

Self-protection with Unlearnable Invisible Masks

Unlike the previous three approaches, which are developer-centric, this method is user-centric. The idea is that users add an unlearnable, invisible mask to their personal data before publishing it publicly, e.g., through a default setting in a camera app (entrepreneur idea alert :D). Even if the data is later collected for training, the mask prevents the model from learning from it, thereby hindering it from generating related content. This approach has garnered attention from the research community, with notable works such as MIST, Anti-Dreambooth, FT-Shield, and Meta-Cloak. It could be a million-dollar idea if it worked reliably, but unfortunately, to the best of my knowledge, even the SOTA methods can still be bypassed by transformation-based or autoencoder-based jailbreak techniques.
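For intuition only, below is a toy sketch of the general recipe behind these protections: add a small, norm-bounded perturbation that pushes the image's embedding away from its original position in a surrogate encoder (here CLIP's vision tower). This does not reproduce MIST, Anti-Dreambooth, FT-Shield, or Meta-Cloak, and the epsilon budget is applied in the processor's normalized space purely for simplicity.

import torch
from PIL import Image
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14").eval()
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
for p in encoder.parameters():
    p.requires_grad_(False)

def add_protective_mask(image: Image.Image, eps=4 / 255, steps=20, step_size=1 / 255):
    x = processor(images=image, return_tensors="pt")["pixel_values"]
    with torch.no_grad():
        original = encoder(pixel_values=x).image_embeds
    delta = torch.zeros_like(x, requires_grad=True)
    for _ in range(steps):
        emb = encoder(pixel_values=x + delta).image_embeds
        # Gradient ascent on the negative similarity, i.e. push the embedding away.
        loss = -torch.nn.functional.cosine_similarity(emb, original).mean()
        loss.backward()
        with torch.no_grad():
            delta += step_size * delta.grad.sign()
            delta.clamp_(-eps, eps)
            delta.grad = None
    return (x + delta).detach()  # masked image, still in the processor's normalized space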

Summary

For open-source models like Stable Diffusion, the fine-tuning approach offers the most promise and economic viability. On the other hand, for closed-source models like OpenAI's or Midjourney's, post-training methods are the most efficient and should take precedence.

In the next section, we delve deeper into the mechanics of the Safety Checker. The implementation of this module can be found in the diffusers library (version 0.32.0 at the time of writing).

What the heck is the Safety Checker?

The StableDiffusionSafetyChecker module is designed to detect NSFW (Not Safe For Work) or otherwise inappropriate content in generated images. If NSFW content is detected, a safety action is taken: the offending output images are censored (replaced with black images) and a warning message is logged.
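For example, with the default diffusers pipeline the per-image flags are exposed as nsfw_content_detected in the pipeline output, and any flagged image comes back fully black (the model id and prompt below are only examples):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

out = pipe("a photo of an astronaut riding a horse")
print(out.nsfw_content_detected)  # e.g. [False], one flag per generated image
out.images[0].save("sample.png")  # a flagged image would be all black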

Below is the implementation of this module in the diffusers library, reproduced here for readers' convenience:

import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel

from ...utils import logging


logger = logging.get_logger(__name__)


def cosine_distance(image_embeds, text_embeds):
    normalized_image_embeds = nn.functional.normalize(image_embeds)
    normalized_text_embeds = nn.functional.normalize(text_embeds)
    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())


class StableDiffusionSafetyChecker(PreTrainedModel):
    config_class = CLIPConfig
    main_input_name = "clip_input"

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)

        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)

        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)

    @torch.no_grad()
    def forward(self, clip_input, images):
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()

        result = []
        batch_size = image_embeds.shape[0]
        for i in range(batch_size):
            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concept_idx in range(len(special_cos_dist[0])):
                concept_cos = special_cos_dist[i][concept_idx]
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["special_scores"][concept_idx] > 0:
                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
                    adjustment = 0.01

            for concept_idx in range(len(cos_dist[0])):
                concept_cos = cos_dist[i][concept_idx]
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["concept_scores"][concept_idx] > 0:
                    result_img["bad_concepts"].append(concept_idx)

            result.append(result_img)

        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]

        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                if torch.is_tensor(images) or torch.is_tensor(images[0]):
                    images[idx] = torch.zeros_like(images[idx])  # black image
                else:
                    images[idx] = np.zeros(images[idx].shape)  # black image

        if any(has_nsfw_concepts):
            logger.warning(
                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts

Here is how this module is used in the Stable Diffusion pipeline, e.g., in pipelines/stable_diffusion/pipeline_stable_diffusion.py:

def run_safety_checker(self, image, device, dtype):
    if self.safety_checker is None:
        has_nsfw_concept = None
    else:
        if torch.is_tensor(image):
            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
        else:
            feature_extractor_input = self.image_processor.numpy_to_pil(image)
        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
    return image, has_nsfw_concept

How to use the Safety Checker in the ldm library

In the ldm library, the safety checker and its feature extractor are loaded from the pre-trained diffusers checkpoint (CompVis/stable-diffusion-safety-checker) and applied directly to the decoded samples, e.g., in scripts/txt2img.py (https://github.com/CompVis/stable-diffusion/blob/main/scripts/txt2img.py):

from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor

# load safety model
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

def check_safety(x_image):
    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
    assert x_checked_image.shape[0] == len(has_nsfw_concept)
    for i in range(len(has_nsfw_concept)):
        if has_nsfw_concept[i]:
            x_checked_image[i] = load_replacement(x_checked_image[i])
    return x_checked_image, has_nsfw_concept

...
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
...
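The load_replacement helper is not shown above; instead of returning a black image, the CompVis script swaps a flagged sample for a placeholder picture of the same size. A rough sketch of such a helper (the asset path here is hypothetical):

import numpy as np
from PIL import Image

def load_replacement(x):
    # x is a single HxWxC float image with values in [0, 1]; on any error,
    # fall back to returning the original (flagged) image.
    try:
        h, w, _ = x.shape
        y = Image.open("assets/placeholder.png").convert("RGB").resize((w, h))
        return (np.array(y) / 255.0).astype(x.dtype)
    except Exception:
        return x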

Key Components

  • CLIP Vision Model: Uses a pretrained CLIP vision encoder to extract features from images.
  • Visual Projection Layer: Projects the CLIP features into a specific embedding space.
  • Two sets of learned concept embeddings:
    • concept_embeds: 17 different concepts (presumably NSFW concepts).
    • special_care_embeds: 3 special concepts that require extra attention.
    • These embeddings are preloaded and marked as non-trainable (requires_grad=False).
  • Corresponding weights for both types of embeddings:
    • concept_embeds_weights: 17 values, one per concept.
    • special_care_embeds_weights: 3 values, one per special concept.
    • These weights act as per-concept detection thresholds and determine how strictly the model filters each concept.

The Detection Process

Step 1: Input Processing

  • The input clip_input is processed by the vision model to extract pooled features (pooled_output).
  • The pooled features are projected into a lower-dimensional embedding space using visual_projection.
pooled_output = self.vision_model(clip_input)[1]  # [batch_size, hidden_size]
image_embeds = self.visual_projection(pooled_output)  # [batch_size, projection_dim]

Step 2: Cosine Similarity Calculation

  • Cosine similarities are computed between the image embeddings and:
    • the special care embeddings (special_cos_dist), and
    • the general concept embeddings (cos_dist).
  • Despite its name, cosine_distance returns a similarity score (larger means a closer match), which indicates how strongly the input image matches each undesirable concept.
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
cos_dist = cosine_distance(image_embeds, self.concept_embeds)
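As a quick sanity check, the following standalone snippet reproduces the cosine_distance helper on random tensors to show the shapes involved; the 768-dimensional width assumes the CLIP ViT-L/14 projection used by the shipped checkpoint:

import torch
import torch.nn as nn

def cosine_distance(image_embeds, text_embeds):
    # Same helper as in the safety checker: row-normalize, then matrix-multiply.
    return torch.mm(nn.functional.normalize(image_embeds), nn.functional.normalize(text_embeds).t())

image_embeds = torch.randn(4, 768)         # batch of 4 projected image embeddings
special_care_embeds = torch.randn(3, 768)  # 3 special-care concepts
concept_embeds = torch.randn(17, 768)      # 17 general NSFW concepts

print(cosine_distance(image_embeds, special_care_embeds).shape)  # torch.Size([4, 3])
print(cosine_distance(image_embeds, concept_embeds).shape)       # torch.Size([4, 17])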

Step 3: Threshold Detection

  • Compares each similarity score against its learned threshold (concept_embeds_weights / special_care_embeds_weights).
  • If any special-care concept is triggered, the adjustment factor is raised to 0.01, making the subsequent general-concept check stricter for that image.
  • Marks an image as NSFW if any general concept exceeds its threshold.
round(concept_cos - concept_threshold + adjustment, 3) > 0
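As a concrete example with made-up numbers: suppose an image scores concept_cos = 0.35 against a special-care threshold of 0.32. Then 0.35 - 0.32 + 0.0 = 0.03 > 0, the image is marked for special care, and adjustment is raised to 0.01. A general concept scoring 0.295 against a threshold of 0.30 would normally pass (-0.005 <= 0), but with the adjustment it becomes 0.295 - 0.30 + 0.01 = 0.005 > 0, so the image is flagged as NSFW.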

Step 4: Safety Action

  • If NSFW content is detected, replaces the image with a black image (zeros)
if has_nsfw_concept:
    images[idx] = torch.zeros_like(images[idx])

How to get the concept_embeds?

The concept_embeds are initialized as parameters in the model, but their actual values are loaded from a pre-trained checkpoint. The code shown above only contains placeholder initialization:

self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)

self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)

The actual concept embeddings come from the safety checker model checkpoint that ships with Stable Diffusion. You can find this model on the HuggingFace Hub as CompVis/stable-diffusion-safety-checker.

To get the actual values, load the pre-trained safety checker (the same class, converted and saved by the library's conversion scripts such as convert_from_ckpt.py) and then access the embeddings:

from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")

concept_embeds = safety_checker.concept_embeds.cpu().numpy()
special_care_embeds = safety_checker.special_care_embeds.cpu().numpy()
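Inspecting the loaded tensors confirms the shapes and reveals the per-concept thresholds (the exact values depend on the checkpoint; the 768 shown here assumes the CLIP ViT-L/14 projection dimension):

print(concept_embeds.shape)                         # (17, 768)
print(special_care_embeds.shape)                    # (3, 768)
print(safety_checker.concept_embeds_weights)        # 17 per-concept thresholds
print(safety_checker.special_care_embeds_weights)   # 3 special-care thresholds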

How to bypass the Safety Checker?

After reading the code, we can see that bypassing the Safety Checker is trivial: simply force has_nsfw_concepts to be False for every image.

has_nsfw_concepts = [False] * len(images)
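In practice this is usually done without editing the library source at all, for example by dropping the checker from the pipeline or replacing it with a no-op callable (dummy_checker below is a hypothetical helper):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# Option 1: drop the checker entirely; run_safety_checker then returns the images untouched.
pipe.safety_checker = None

# Option 2: keep the call site but make the checker a no-op.
def dummy_checker(clip_input, images):
    return images, [False] * len(images)

pipe.safety_checker = dummy_checker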

Conclusion

In this post, we discussed the Safety Checker in Stable Diffusion, its implementation, and how to bypass it. We also walked through the detection process and the key components of the module. I believe that understanding this module in greater detail can lead to interesting research ideas, for example: how to quickly update the Safety Checker to detect new NSFW concepts, how to bypass it without modifying the code, or whether it can serve as a surrogate for other NSFW detection models.

Thank you for reading!