From 52ae346c9dfedaf0df61a6f9d376ec9c5856e497 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Wed, 15 Nov 2023 22:00:28 +0100 Subject: [PATCH 1/9] add nsfw image censoring activatable via config, uses CompVis/stable-diffusion-safety-checker --- modules/async_worker.py | 9 +++++-- modules/censor.py | 54 +++++++++++++++++++++++++++++++++++++++++ modules/config.py | 5 ++++ 3 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 modules/censor.py diff --git a/modules/async_worker.py b/modules/async_worker.py index a6807547c..362cddde6 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -30,6 +30,7 @@ def worker(): import fooocus_extras.ip_adapter as ip_adapter import fooocus_extras.face_crop + from modules.censor import censor_batch from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion from modules.private_logger import log from modules.expansion import safe_str @@ -50,12 +51,16 @@ def progressbar(number, text): print(f'[Fooocus] {text}') outputs.append(['preview', (number, text, None)]) - def yield_result(imgs, do_not_show_finished_images=False): + def yield_result(imgs, do_not_show_finished_images=False, progressbar_index=13): global global_results if not isinstance(imgs, list): imgs = [imgs] + if modules.config.default_black_out_nsfw: + progressbar(progressbar_index, 'Checking for NSFW content ...') + imgs = censor_batch(imgs) + global_results = global_results + imgs if do_not_show_finished_images: @@ -711,7 +716,7 @@ def callback(step, x0, x, total_steps, y): d.append((f'LoRA [{n}] weight', w)) log(x, d, single_line_number=3) - yield_result(imgs, do_not_show_finished_images=len(tasks) == 1) + yield_result(imgs, do_not_show_finished_images=len(tasks) == 1, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)))) except fcbh.model_management.InterruptProcessingException as e: if shared.last_stop == 'skip': print('User skipped') diff --git a/modules/censor.py b/modules/censor.py new file mode 100644 index 000000000..fac6db09a --- /dev/null +++ b/modules/censor.py @@ -0,0 +1,54 @@ +# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py + +import numpy as np +import torch +import modules.core as core + +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from transformers import AutoFeatureExtractor +from PIL import Image + +safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_feature_extractor = None +safety_checker = None + + +def numpy_to_pil(image): + image = (image * 255).round().astype("uint8") + + #pil_image = Image.fromarray(image, 'RGB') + pil_image = Image.fromarray(image) + + return pil_image + + +# check and replace nsfw content +def check_safety(x_image): + global safety_feature_extractor, safety_checker + + if safety_feature_extractor is None: + safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) + safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + + return x_checked_image, has_nsfw_concept + + +def censor_single(x): + x_checked_image, has_nsfw_concept = check_safety(x) + + # replace image with black pixels, keep dimensions + # workaround due to different numpy / pytorch image matrix format + if has_nsfw_concept[0]: + imageshape = x_checked_image.shape + x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype = np.uint8) + + return x_checked_image + + +def censor_batch(images): + images = [censor_single(image) for image in images] + + return images diff --git a/modules/config.py b/modules/config.py index 7216a26d8..ac81ef0b0 100644 --- a/modules/config.py +++ b/modules/config.py @@ -268,6 +268,11 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_ default_value=-1, validator=lambda x: isinstance(x, int) ) +default_black_out_nsfw = get_config_item_or_set_default( + key='default_black_out_nsfw', + default_value=False, + validator=lambda x: isinstance(x, bool) +) def add_ratio(x): From 3dc69a5293b045cd560467baef388836c7b66dd6 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 18 Nov 2023 12:14:33 +0100 Subject: [PATCH 2/9] fix progressbar call for nsfw output --- modules/async_worker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index 280895543..c57a12ae4 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -56,12 +56,12 @@ def progressbar(async_task, number, text): print(f'[Fooocus] {text}') async_task.yields.append(['preview', (number, text, None)]) - def yield_result(async_task, imgs, do_not_show_finished_images=False): + def yield_result(async_task, imgs, do_not_show_finished_images=False, progressbar_index=13): if not isinstance(imgs, list): imgs = [imgs] if modules.config.default_black_out_nsfw: - progressbar(progressbar_index, 'Checking for NSFW content ...') + progressbar(async_task, progressbar_index, 'Checking for NSFW content ...') imgs = censor_batch(imgs) async_task.results = async_task.results + imgs @@ -725,7 +725,7 @@ def callback(step, x0, x, total_steps, y): d.append((f'LoRA [{n}] weight', w)) log(x, d, single_line_number=3) - yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)))) + yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))) except fcbh.model_management.InterruptProcessingException as e: if shared.last_stop == 'skip': print('User skipped') From cdaaec3e71866c0ceedcd368c2ca329a223681a0 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 16 Dec 2023 18:54:46 +0100 Subject: [PATCH 3/9] use config to set cache dir for safety checker --- .gitignore | 1 + models/safety_checker_models/put_safety_checker_models_here | 0 modules/censor.py | 5 +++-- modules/config.py | 1 + 4 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 models/safety_checker_models/put_safety_checker_models_here diff --git a/.gitignore b/.gitignore index 05ce1df87..da9cf9743 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ config.txt config_modification_tutorial.txt user_path_config.txt user_path_config-deprecated.txt +/models/safety_checker_models /modules/*.png /repositories /venv diff --git a/models/safety_checker_models/put_safety_checker_models_here b/models/safety_checker_models/put_safety_checker_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/modules/censor.py b/modules/censor.py index fac6db09a..e2352218c 100644 --- a/modules/censor.py +++ b/modules/censor.py @@ -7,6 +7,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from transformers import AutoFeatureExtractor from PIL import Image +import modules.config safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_feature_extractor = None @@ -27,8 +28,8 @@ def check_safety(x_image): global safety_feature_extractor, safety_checker if safety_feature_extractor is None: - safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) - safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models) + safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models) safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) diff --git a/modules/config.py b/modules/config.py index ec0b49c38..c6515bd99 100644 --- a/modules/config.py +++ b/modules/config.py @@ -126,6 +126,7 @@ def get_dir_or_set_default(key, default_value): path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/') path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion') path_outputs = get_dir_or_set_default('path_outputs', '../outputs/') +path_safety_checker_models = get_dir_or_set_default('path_safety_checker_models', '../models/safety_checker_models/') def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False): From e97008d4fb5f3e74439f78620a9d6eee60e192f2 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 16 Dec 2023 22:53:47 +0100 Subject: [PATCH 4/9] add checkbox black_out_nsfw makes both enabling via config and checkbox possible, where config overrides the checkbox value --- modules/advanced_parameters.py | 8 ++++---- modules/async_worker.py | 2 +- webui.py | 12 ++++++++++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/modules/advanced_parameters.py b/modules/advanced_parameters.py index ea04db6c6..7c4025260 100644 --- a/modules/advanced_parameters.py +++ b/modules/advanced_parameters.py @@ -1,15 +1,15 @@ -disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \ +disable_preview, black_out_nsfw, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \ scheduler_name, generate_image_grid, overwrite_step, overwrite_switch, overwrite_width, overwrite_height, \ overwrite_vary_strength, overwrite_upscale_strength, \ mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint, \ debugging_cn_preprocessor, skipping_cn_preprocessor, controlnet_softness, canny_low_threshold, canny_high_threshold, \ refiner_swap_method, \ freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2, \ - debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_strength, inpaint_respective_field = [None] * 32 + debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_strength, inpaint_respective_field = [None] * 33 def set_all_advanced_parameters(*args): - global disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \ + global disable_preview, black_out_nsfw, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \ scheduler_name, generate_image_grid, overwrite_step, overwrite_switch, overwrite_width, overwrite_height, \ overwrite_vary_strength, overwrite_upscale_strength, \ mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint, \ @@ -18,7 +18,7 @@ def set_all_advanced_parameters(*args): freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2, \ debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_strength, inpaint_respective_field - disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \ + disable_preview, black_out_nsfw, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \ scheduler_name, generate_image_grid, overwrite_step, overwrite_switch, overwrite_width, overwrite_height, \ overwrite_vary_strength, overwrite_upscale_strength, \ mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint, \ diff --git a/modules/async_worker.py b/modules/async_worker.py index 02ae5d032..ada35ecc9 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -60,7 +60,7 @@ def yield_result(async_task, imgs, do_not_show_finished_images=False, progressba if not isinstance(imgs, list): imgs = [imgs] - if modules.config.default_black_out_nsfw: + if modules.config.default_black_out_nsfw or advanced_parameters.black_out_nsfw: progressbar(async_task, progressbar_index, 'Checking for NSFW content ...') imgs = censor_batch(imgs) diff --git a/webui.py b/webui.py index 00b1e44e4..18d55c55c 100644 --- a/webui.py +++ b/webui.py @@ -373,9 +373,17 @@ def refresh_seed(r, seed_string): overwrite_upscale_strength = gr.Slider(label='Forced Overwrite of Denoising Strength of "Upscale"', minimum=-1, maximum=1.0, step=0.001, value=-1, info='Set as negative number to disable. For developer debugging.') - disable_preview = gr.Checkbox(label='Disable Preview', value=False, + disable_preview = gr.Checkbox(label='Disable Preview', value=modules.config.default_black_out_nsfw, + interactive=not modules.config.default_black_out_nsfw, info='Disable preview during generation.') + black_out_nsfw = gr.Checkbox(label='Black Out NSFW', value=modules.config.default_black_out_nsfw, + interactive=not modules.config.default_black_out_nsfw, + info='Use black image if NSFW is detected.') + + black_out_nsfw.change(lambda x: gr.update(value=x, interactive=not x), + inputs=black_out_nsfw, outputs=disable_preview, queue=False, show_progress=False) + with gr.Tab(label='Control'): debugging_cn_preprocessor = gr.Checkbox(label='Debug Preprocessors', value=False, info='See the results from preprocessors.') @@ -426,7 +434,7 @@ def refresh_seed(r, seed_string): freeu_s2 = gr.Slider(label='S2', minimum=0, maximum=4, step=0.01, value=0.95) freeu_ctrls = [freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2] - adps = [disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, + adps = [disable_preview, black_out_nsfw, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, scheduler_name, generate_image_grid, overwrite_step, overwrite_switch, overwrite_width, overwrite_height, overwrite_vary_strength, overwrite_upscale_strength, mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint, From a4bc5c7d630c7ec4f5a54c69e2c01ee4b5891916 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Wed, 17 Jan 2024 22:40:52 +0100 Subject: [PATCH 5/9] fix: add missing diffusers package --- requirements_versions.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements_versions.txt b/requirements_versions.txt index b2111c1f5..5e9e85d6e 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -16,3 +16,4 @@ opencv-contrib-python==4.8.0.74 httpx==0.24.1 onnxruntime==1.16.3 timm==0.9.2 +diffusers==0.25.1 From f338d5fc1611ec973657fe6a73763ffee79f6763 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Fri, 17 May 2024 23:56:02 +0200 Subject: [PATCH 6/9] feat: extract safety checker, remove dependency to diffusers --- .../stable_diffusion/safety_checker.py | 126 ++++++++++++++++++ modules/censor.py | 11 +- requirements_versions.txt | 1 - 3 files changed, 129 insertions(+), 9 deletions(-) create mode 100644 extras/diffusers/pipelines/stable_diffusion/safety_checker.py diff --git a/extras/diffusers/pipelines/stable_diffusion/safety_checker.py b/extras/diffusers/pipelines/stable_diffusion/safety_checker.py new file mode 100644 index 000000000..ea38bf038 --- /dev/null +++ b/extras/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -0,0 +1,126 @@ +# from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py + +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class StableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + main_input_name = "clip_input" + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) + self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if torch.is_tensor(images) or torch.is_tensor(images[0]): + images[idx] = torch.zeros_like(images[idx]) # black image + else: + images[idx] = np.zeros(images[idx].shape) # black image + + if any(has_nsfw_concepts): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + images[has_nsfw_concepts] = 0.0 # black image + + return images, has_nsfw_concepts diff --git a/modules/censor.py b/modules/censor.py index e2352218c..ca47693ac 100644 --- a/modules/censor.py +++ b/modules/censor.py @@ -1,10 +1,7 @@ # modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py - import numpy as np -import torch -import modules.core as core -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from extras.diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from transformers import AutoFeatureExtractor from PIL import Image import modules.config @@ -16,8 +13,6 @@ def numpy_to_pil(image): image = (image * 255).round().astype("uint8") - - #pil_image = Image.fromarray(image, 'RGB') pil_image = Image.fromarray(image) return pil_image @@ -27,7 +22,7 @@ def numpy_to_pil(image): def check_safety(x_image): global safety_feature_extractor, safety_checker - if safety_feature_extractor is None: + if safety_feature_extractor is None or safety_checker is None: safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models) safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models) @@ -52,4 +47,4 @@ def censor_single(x): def censor_batch(images): images = [censor_single(image) for image in images] - return images + return images \ No newline at end of file diff --git a/requirements_versions.txt b/requirements_versions.txt index 5e9e85d6e..b2111c1f5 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -16,4 +16,3 @@ opencv-contrib-python==4.8.0.74 httpx==0.24.1 onnxruntime==1.16.3 timm==0.9.2 -diffusers==0.25.1 From 0f78f8d8cc8800a910225bda9f2898c4b580ff77 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Fri, 17 May 2024 23:56:55 +0200 Subject: [PATCH 7/9] feat: make code compatible again after merge with main --- language/en.json | 2 ++ modules/async_worker.py | 32 +++++++++++++++++++++----------- modules/config.py | 10 +++++----- webui.py | 18 ++++++++++-------- 4 files changed, 38 insertions(+), 24 deletions(-) diff --git a/language/en.json b/language/en.json index fefc79c47..d420a6ab4 100644 --- a/language/en.json +++ b/language/en.json @@ -54,6 +54,8 @@ "Disable seed increment": "Disable seed increment", "Disable automatic seed increment when image number is > 1.": "Disable automatic seed increment when image number is > 1.", "Read wildcards in order": "Read wildcards in order", + "Black Out NSFW": "Black Out NSFW", + "Use black image if NSFW is detected.": "Use black image if NSFW is detected.", "\ud83d\udcda History Log": "\uD83D\uDCDA History Log", "Image Style": "Image Style", "Fooocus V2": "Fooocus V2", diff --git a/modules/async_worker.py b/modules/async_worker.py index 73cceadef..0d95725c2 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -43,7 +43,7 @@ def worker(): import fooocus_version import args_manager - from modules.censor import censor_batch + from modules.censor import censor_batch, censor_single from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays from modules.private_logger import log from extras.expansion import safe_str @@ -69,11 +69,11 @@ def progressbar(async_task, number, text): print(f'[Fooocus] {text}') async_task.yields.append(['preview', (number, text, None)]) - def yield_result(async_task, imgs, do_not_show_finished_images=False, progressbar_index=13): + def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, progressbar_index=13): if not isinstance(imgs, list): imgs = [imgs] - if modules.config.default_black_out_nsfw or advanced_parameters.black_out_nsfw: + if censor and (modules.config.default_black_out_nsfw or black_out_nsfw): progressbar(async_task, progressbar_index, 'Checking for NSFW content ...') imgs = censor_batch(imgs) @@ -165,6 +165,7 @@ def handler(async_task): disable_preview = args.pop() disable_intermediate_results = args.pop() disable_seed_increment = args.pop() + black_out_nsfw = args.pop() adm_scaler_positive = args.pop() adm_scaler_negative = args.pop() adm_scaler_end = args.pop() @@ -577,8 +578,11 @@ def handler(async_task): if direct_return: d = [('Upscale (Fast)', 'upscale_fast', '2x')] + if modules.config.default_black_out_nsfw or black_out_nsfw: + progressbar(async_task, 100, 'Checking for NSFW content ...') + uov_input_image = censor_single(uov_input_image) uov_input_image_path = log(uov_input_image, d, output_format=output_format) - yield_result(async_task, uov_input_image_path, do_not_show_finished_images=True) + yield_result(async_task, uov_input_image_path, black_out_nsfw, False, do_not_show_finished_images=True) return tiled = True @@ -642,8 +646,7 @@ def handler(async_task): ) if debugging_inpaint_preprocessor: - yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), - do_not_show_finished_images=True) + yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, do_not_show_finished_images=True) return progressbar(async_task, 13, 'VAE Inpaint encoding ...') @@ -706,7 +709,7 @@ def handler(async_task): cn_img = HWC3(cn_img) task[0] = core.numpy_to_pytorch(cn_img) if debugging_cn_preprocessor: - yield_result(async_task, cn_img, do_not_show_finished_images=True) + yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True) return for task in cn_tasks[flags.cn_cpds]: cn_img, cn_stop, cn_weight = task @@ -718,7 +721,7 @@ def handler(async_task): cn_img = HWC3(cn_img) task[0] = core.numpy_to_pytorch(cn_img) if debugging_cn_preprocessor: - yield_result(async_task, cn_img, do_not_show_finished_images=True) + yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True) return for task in cn_tasks[flags.cn_ip]: cn_img, cn_stop, cn_weight = task @@ -729,7 +732,7 @@ def handler(async_task): task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path) if debugging_cn_preprocessor: - yield_result(async_task, cn_img, do_not_show_finished_images=True) + yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True) return for task in cn_tasks[flags.cn_ip_face]: cn_img, cn_stop, cn_weight = task @@ -743,7 +746,7 @@ def handler(async_task): task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path) if debugging_cn_preprocessor: - yield_result(async_task, cn_img, do_not_show_finished_images=True) + yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True) return all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face] @@ -843,6 +846,12 @@ def callback(step, x0, x, total_steps, y): imgs = [inpaint_worker.current_task.post_process(x) for x in imgs] img_paths = [] + + if modules.config.default_black_out_nsfw or black_out_nsfw: + progressbar(async_task, int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)), + 'Checking for NSFW content ...') + imgs = censor_batch(imgs) + for x in imgs: d = [('Prompt', 'prompt', task['log_positive_prompt']), ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']), @@ -892,7 +901,8 @@ def callback(step, x0, x, total_steps, y): d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) img_paths.append(log(x, d, metadata_parser, output_format)) - yield_result(async_task, img_paths, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))) + + yield_result(async_task, img_paths, black_out_nsfw, False, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) except ldm_patched.modules.model_management.InterruptProcessingException as e: if async_task.last_stop == 'skip': print('User skipped') diff --git a/modules/config.py b/modules/config.py index 2db23dbfd..5a18e9635 100644 --- a/modules/config.py +++ b/modules/config.py @@ -451,6 +451,11 @@ def init_temp_path(path: str | None, default_path: str) -> str: ], validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x) ) +default_black_out_nsfw = get_config_item_or_set_default( + key='default_black_out_nsfw', + default_value=False, + validator=lambda x: isinstance(x, bool) +) default_save_metadata_to_images = get_config_item_or_set_default( key='default_save_metadata_to_images', default_value=False, @@ -466,11 +471,6 @@ def init_temp_path(path: str | None, default_path: str) -> str: default_value='', validator=lambda x: isinstance(x, str) ) -default_black_out_nsfw = get_config_item_or_set_default( - key='default_black_out_nsfw', - default_value=False, - validator=lambda x: isinstance(x, bool) -) example_inpaint_prompts = [[x] for x in example_inpaint_prompts] diff --git a/webui.py b/webui.py index 29eed6068..ab6ad0913 100644 --- a/webui.py +++ b/webui.py @@ -445,6 +445,15 @@ def update_history_link(): value=False) read_wildcards_in_order = gr.Checkbox(label="Read wildcards in order", value=False) + black_out_nsfw = gr.Checkbox(label='Black Out NSFW', + value=modules.config.default_black_out_nsfw, + interactive=not modules.config.default_black_out_nsfw, + info='Use black image if NSFW is detected.') + + black_out_nsfw.change(lambda x: gr.update(value=x, interactive=not x), + inputs=black_out_nsfw, outputs=disable_preview, queue=False, + show_progress=False) + if not args_manager.args.disable_metadata: save_metadata_to_images = gr.Checkbox(label='Save Metadata to Images', value=modules.config.default_save_metadata_to_images, info='Adds parameters to generated images allowing manual regeneration.') @@ -455,13 +464,6 @@ def update_history_link(): save_metadata_to_images.change(lambda x: gr.update(visible=x), inputs=[save_metadata_to_images], outputs=[metadata_scheme], queue=False, show_progress=False) - black_out_nsfw = gr.Checkbox(label='Black Out NSFW', value=modules.config.default_black_out_nsfw, - interactive=not modules.config.default_black_out_nsfw, - info='Use black image if NSFW is detected.') - - black_out_nsfw.change(lambda x: gr.update(value=x, interactive=not x), - inputs=black_out_nsfw, outputs=disable_preview, queue=False, show_progress=False) - with gr.Tab(label='Control'): debugging_cn_preprocessor = gr.Checkbox(label='Debug Preprocessors', value=False, info='See the results from preprocessors.') @@ -640,7 +642,7 @@ def inpaint_mode_change(mode): ctrls += [input_image_checkbox, current_tab] ctrls += [uov_method, uov_input_image] ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image] - ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment] + ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment, black_out_nsfw] ctrls += [adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg] ctrls += [sampler_name, scheduler_name] ctrls += [overwrite_step, overwrite_switch, overwrite_width, overwrite_height, overwrite_vary_strength] From 7568b72d9b8dd285238bb0c680b84d6a80fe7932 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 18 May 2024 01:59:15 +0200 Subject: [PATCH 8/9] feat: move censor to extras, optimize safety checker file handling --- .gitignore | 1 - {modules => extras}/censor.py | 16 +- extras/safety_checker/configs/config.json | 171 ++++++++++++++++++ .../configs/preprocessor_config.json | 20 ++ .../models}/safety_checker.py | 0 modules/async_worker.py | 2 +- modules/config.py | 8 + 7 files changed, 211 insertions(+), 7 deletions(-) rename {modules => extras}/censor.py (65%) create mode 100644 extras/safety_checker/configs/config.json create mode 100644 extras/safety_checker/configs/preprocessor_config.json rename extras/{diffusers/pipelines/stable_diffusion => safety_checker/models}/safety_checker.py (100%) diff --git a/.gitignore b/.gitignore index e423ef81a..859149866 100644 --- a/.gitignore +++ b/.gitignore @@ -18,7 +18,6 @@ config.txt config_modification_tutorial.txt user_path_config.txt user_path_config-deprecated.txt -/models/safety_checker_models /modules/*.png /repositories /fooocus_env diff --git a/modules/censor.py b/extras/censor.py similarity index 65% rename from modules/censor.py rename to extras/censor.py index ca47693ac..2047db246 100644 --- a/modules/censor.py +++ b/extras/censor.py @@ -1,12 +1,16 @@ # modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py import numpy as np +import os -from extras.diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from transformers import AutoFeatureExtractor +from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker +from transformers import CLIPFeatureExtractor, CLIPConfig from PIL import Image import modules.config -safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_checker_repo_root = os.path.join(os.path.dirname(__file__), 'safety_checker') +config_path = os.path.join(safety_checker_repo_root, "configs", "config.json") +preprocessor_config_path = os.path.join(safety_checker_repo_root, "configs", "preprocessor_config.json") + safety_feature_extractor = None safety_checker = None @@ -23,8 +27,10 @@ def check_safety(x_image): global safety_feature_extractor, safety_checker if safety_feature_extractor is None or safety_checker is None: - safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models) - safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models) + safety_checker_model = modules.config.downloading_safety_checker_model() + safety_feature_extractor = CLIPFeatureExtractor.from_json_file(preprocessor_config_path) + clip_config = CLIPConfig.from_json_file(config_path) + safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_checker_model, config=clip_config) safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) diff --git a/extras/safety_checker/configs/config.json b/extras/safety_checker/configs/config.json new file mode 100644 index 000000000..aa454d222 --- /dev/null +++ b/extras/safety_checker/configs/config.json @@ -0,0 +1,171 @@ +{ + "_name_or_path": "clip-vit-large-patch14/", + "architectures": [ + "SafetyChecker" + ], + "initializer_factor": 1.0, + "logit_scale_init_value": 2.6592, + "model_type": "clip", + "projection_dim": 768, + "text_config": { + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "attention_dropout": 0.0, + "bad_words_ids": null, + "bos_token_id": 0, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.0, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 2, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 77, + "min_length": 0, + "model_type": "clip_text_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 12, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 12, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": 1, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "task_specific_params": null, + "temperature": 1.0, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "transformers_version": "4.21.0.dev0", + "typical_p": 1.0, + "use_bfloat16": false, + "vocab_size": 49408 + }, + "text_config_dict": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12 + }, + "torch_dtype": "float32", + "transformers_version": null, + "vision_config": { + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "attention_dropout": 0.0, + "bad_words_ids": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.0, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "hidden_act": "quick_gelu", + "hidden_size": 1024, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "image_size": 224, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 4096, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "min_length": 0, + "model_type": "clip_vision_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 16, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 24, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "patch_size": 14, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "task_specific_params": null, + "temperature": 1.0, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "transformers_version": "4.21.0.dev0", + "typical_p": 1.0, + "use_bfloat16": false + }, + "vision_config_dict": { + "hidden_size": 1024, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "patch_size": 14 + } +} diff --git a/extras/safety_checker/configs/preprocessor_config.json b/extras/safety_checker/configs/preprocessor_config.json new file mode 100644 index 000000000..5294955ff --- /dev/null +++ b/extras/safety_checker/configs/preprocessor_config.json @@ -0,0 +1,20 @@ +{ + "crop_size": 224, + "do_center_crop": true, + "do_convert_rgb": true, + "do_normalize": true, + "do_resize": true, + "feature_extractor_type": "CLIPFeatureExtractor", + "image_mean": [ + 0.48145466, + 0.4578275, + 0.40821073 + ], + "image_std": [ + 0.26862954, + 0.26130258, + 0.27577711 + ], + "resample": 3, + "size": 224 +} diff --git a/extras/diffusers/pipelines/stable_diffusion/safety_checker.py b/extras/safety_checker/models/safety_checker.py similarity index 100% rename from extras/diffusers/pipelines/stable_diffusion/safety_checker.py rename to extras/safety_checker/models/safety_checker.py diff --git a/modules/async_worker.py b/modules/async_worker.py index 0d95725c2..fa10ff8ad 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -43,7 +43,7 @@ def worker(): import fooocus_version import args_manager - from modules.censor import censor_batch, censor_single + from extras.censor import censor_batch, censor_single from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays from modules.private_logger import log from extras.expansion import safe_str diff --git a/modules/config.py b/modules/config.py index 5a18e9635..8b2772427 100644 --- a/modules/config.py +++ b/modules/config.py @@ -685,5 +685,13 @@ def downloading_upscale_model(): ) return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin') +def downloading_safety_checker_model(): + load_file_from_url( + url='https://huggingface.co/mashb1t/misc/resolve/main/stable-diffusion-safety-checker.bin', + model_dir=path_safety_checker_models, + file_name='stable-diffusion-safety-checker.bin' + ) + return os.path.join(path_safety_checker_models, 'stable-diffusion-safety-checker.bin') + update_files() From 49795fe0306149106c43cb75bd1eb8afc45badce Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 18 May 2024 15:37:58 +0200 Subject: [PATCH 9/9] refactor: rename folder safety_checker_models to safety_checker --- .../put_safety_checker_models_here | 0 modules/config.py | 6 +++--- 2 files changed, 3 insertions(+), 3 deletions(-) rename models/{safety_checker_models => safety_checker}/put_safety_checker_models_here (100%) diff --git a/models/safety_checker_models/put_safety_checker_models_here b/models/safety_checker/put_safety_checker_models_here similarity index 100% rename from models/safety_checker_models/put_safety_checker_models_here rename to models/safety_checker/put_safety_checker_models_here diff --git a/modules/config.py b/modules/config.py index 8b2772427..73e33e4a0 100644 --- a/modules/config.py +++ b/modules/config.py @@ -195,7 +195,7 @@ def get_dir_or_set_default(key, default_value, as_array=False, make_directory=Fa path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/') path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion') path_wildcards = get_dir_or_set_default('path_wildcards', '../wildcards/') -path_safety_checker_models = get_dir_or_set_default('path_safety_checker_models', '../models/safety_checker_models/') +path_safety_checker = get_dir_or_set_default('path_safety_checker', '../models/safety_checker/') path_outputs = get_path_output() @@ -688,10 +688,10 @@ def downloading_upscale_model(): def downloading_safety_checker_model(): load_file_from_url( url='https://huggingface.co/mashb1t/misc/resolve/main/stable-diffusion-safety-checker.bin', - model_dir=path_safety_checker_models, + model_dir=path_safety_checker, file_name='stable-diffusion-safety-checker.bin' ) - return os.path.join(path_safety_checker_models, 'stable-diffusion-safety-checker.bin') + return os.path.join(path_safety_checker, 'stable-diffusion-safety-checker.bin') update_files()