feat: add nsfw image censoring via config and checkbox (#958)
* add nsfw image censoring, activatable via config, using CompVis/stable-diffusion-safety-checker
* fix progressbar call for nsfw output
* use config to set cache dir for safety checker
* add checkbox black_out_nsfw, making it possible to enable censoring via either the config or the checkbox, with the config value overriding the checkbox (see the sketch below)
* fix: add missing diffusers package
* feat: extract safety checker, remove dependency on diffusers
* feat: make code compatible again after merge with main
* feat: move censor to extras, optimize safety checker file handling
* refactor: rename folder safety_checker_models to safety_checker
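One plausible reading of that override rule, as a minimal sketch (the helper below is my illustration; only censor_batch and the black_out_nsfw flag come from this commit):

# Hypothetical glue code, not part of the commit: combine the config flag
# with the UI checkbox so that an enabled config always wins.
from extras.censor import censor_batch

def maybe_censor(images, checkbox_black_out_nsfw, config_black_out_nsfw):
    # config overrides the checkbox: if the config enables censoring,
    # the checkbox value is ignored
    enabled = config_black_out_nsfw or checkbox_black_out_nsfw
    return censor_batch(images) if enabled else images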
Showing 9 changed files with 424 additions and 11 deletions.
extras/censor.py
@@ -0,0 +1,56 @@
# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py
import numpy as np
import os

from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor, CLIPConfig
from PIL import Image
import modules.config

safety_checker_repo_root = os.path.join(os.path.dirname(__file__), 'safety_checker')
config_path = os.path.join(safety_checker_repo_root, "configs", "config.json")
preprocessor_config_path = os.path.join(safety_checker_repo_root, "configs", "preprocessor_config.json")

# lazily initialized on first use so the model is only downloaded when needed
safety_feature_extractor = None
safety_checker = None


def numpy_to_pil(image):
    # convert a float image in [0, 1] to an 8-bit PIL image
    image = (image * 255).round().astype("uint8")
    pil_image = Image.fromarray(image)

    return pil_image


# check and replace nsfw content
def check_safety(x_image):
    global safety_feature_extractor, safety_checker

    if safety_feature_extractor is None or safety_checker is None:
        safety_checker_model = modules.config.downloading_safety_checker_model()
        safety_feature_extractor = CLIPFeatureExtractor.from_json_file(preprocessor_config_path)
        clip_config = CLIPConfig.from_json_file(config_path)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_checker_model, config=clip_config)

    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)

    return x_checked_image, has_nsfw_concept


def censor_single(x):
    x_checked_image, has_nsfw_concept = check_safety(x)

    # replace image with black pixels, keep dimensions
    # workaround due to different numpy / pytorch image matrix format
    if has_nsfw_concept[0]:
        imageshape = x_checked_image.shape
        x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype=np.uint8)

    return x_checked_image


def censor_batch(images):
    images = [censor_single(image) for image in images]

    return images
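A minimal usage sketch for the module above (my addition, not part of the commit; it assumes float images in [0, 1], matching the scaling in numpy_to_pil):

# Hypothetical caller, not part of the commit.
import numpy as np
from extras.censor import censor_batch

# stand-ins for generated images: float32 HxWxC arrays in [0, 1]
images = [np.random.rand(512, 512, 3).astype(np.float32) for _ in range(3)]

# the first call lazily downloads and loads the safety checker;
# each image is checked independently, and any image flagged as nsfw
# is replaced by an all-black array of the same height and width
censored = censor_batch(images)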
extras/safety_checker/configs/config.json
@@ -0,0 +1,171 @@
{
  "_name_or_path": "clip-vit-large-patch14/",
  "architectures": [
    "SafetyChecker"
  ],
  "initializer_factor": 1.0,
  "logit_scale_init_value": 2.6592,
  "model_type": "clip",
  "projection_dim": 768,
  "text_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.0,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 77,
    "min_length": 0,
    "model_type": "clip_text_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 12,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 1,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.21.0.dev0",
    "typical_p": 1.0,
    "use_bfloat16": false,
    "vocab_size": 49408
  },
  "text_config_dict": {
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12
  },
  "torch_dtype": "float32",
  "transformers_version": null,
  "vision_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.0,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 24,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 14,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.21.0.dev0",
    "typical_p": 1.0,
    "use_bfloat16": false
  },
  "vision_config_dict": {
    "hidden_size": 1024,
    "intermediate_size": 4096,
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "patch_size": 14
  }
}
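For orientation (my note, not part of the commit): this is the standard CLIP ViT-L/14 configuration, bundled so censor.py can build the checker without fetching config files from the Hub. A quick sanity check:

from transformers import CLIPConfig

# path assumed relative to the repository root
clip_config = CLIPConfig.from_json_file("extras/safety_checker/configs/config.json")
print(clip_config.projection_dim)             # 768
print(clip_config.vision_config.hidden_size)  # 1024 (ViT-L/14 vision tower)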
extras/safety_checker/configs/preprocessor_config.json
@@ -0,0 +1,20 @@
{
  "crop_size": 224,
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "resample": 3,
  "size": 224
}
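This is the stock CLIP preprocessing recipe: resize and center-crop to 224x224, then normalize with the CLIP mean/std listed above. A short sketch of how censor.py builds and applies it, using a blank placeholder image (my addition, not part of the commit):

from PIL import Image
from transformers import CLIPFeatureExtractor

# built the same way as in censor.py; path assumed relative to repo root
extractor = CLIPFeatureExtractor.from_json_file(
    "extras/safety_checker/configs/preprocessor_config.json")

inputs = extractor(Image.new("RGB", (512, 512)), return_tensors="pt")
print(inputs.pixel_values.shape)  # torch.Size([1, 3, 224, 224])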
extras/safety_checker/models/safety_checker.py
@@ -0,0 +1,126 @@
# from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
from transformers.utils import logging

logger = logging.get_logger(__name__)


def cosine_distance(image_embeds, text_embeds):
    normalized_image_embeds = nn.functional.normalize(image_embeds)
    normalized_text_embeds = nn.functional.normalize(text_embeds)
    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())


class StableDiffusionSafetyChecker(PreTrainedModel):
    config_class = CLIPConfig
    main_input_name = "clip_input"

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)

        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)

        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)

    @torch.no_grad()
    def forward(self, clip_input, images):
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()

        result = []
        batch_size = image_embeds.shape[0]
        for i in range(batch_size):
            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concept_idx in range(len(special_cos_dist[0])):
                concept_cos = special_cos_dist[i][concept_idx]
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["special_scores"][concept_idx] > 0:
                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
                    adjustment = 0.01

            for concept_idx in range(len(cos_dist[0])):
                concept_cos = cos_dist[i][concept_idx]
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["concept_scores"][concept_idx] > 0:
                    result_img["bad_concepts"].append(concept_idx)

            result.append(result_img)

        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]

        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                if torch.is_tensor(images) or torch.is_tensor(images[0]):
                    images[idx] = torch.zeros_like(images[idx])  # black image
                else:
                    images[idx] = np.zeros(images[idx].shape)  # black image

        if any(has_nsfw_concepts):
            logger.warning(
                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts

    @torch.no_grad()
    def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor):
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
        cos_dist = cosine_distance(image_embeds, self.concept_embeds)

        # increase this value to create a stronger `nsfw` filter
        # at the cost of increasing the possibility of filtering benign images
        adjustment = 0.0

        special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
        # special_scores = special_scores.round(decimals=3)
        special_care = torch.any(special_scores > 0, dim=1)
        special_adjustment = special_care * 0.01
        special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])

        concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
        # concept_scores = concept_scores.round(decimals=3)
        has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)

        images[has_nsfw_concepts] = 0.0  # black image

        return images, has_nsfw_concepts
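To make the thresholding concrete, here is a standalone sketch of the scoring math with invented numbers (the real cosine similarities and per-concept thresholds live in the downloaded weights): a concept fires when its cosine similarity exceeds its learned threshold, and any "special care" hit lowers every nsfw threshold by 0.01.

import torch

# invented values: 1 image, 3 special-care concepts, 17 nsfw concepts
special_cos_dist = torch.tensor([[0.10, 0.31, 0.05]])
special_thresholds = torch.tensor([0.30, 0.30, 0.30])
cos_dist = torch.full((1, 17), 0.195)
concept_thresholds = torch.full((17,), 0.20)

special_scores = special_cos_dist - special_thresholds   # [-0.20, 0.01, -0.25]
special_care = torch.any(special_scores > 0, dim=1)      # tensor([True])
special_adjustment = (special_care * 0.01).unsqueeze(1)  # broadcasts over concepts

concept_scores = cos_dist - concept_thresholds + special_adjustment
has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
print(has_nsfw_concepts)  # tensor([True]): 0.195 - 0.20 + 0.01 > 0, flagged
                          # only because the special-care hit lowered the bar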
Empty file.