Skip to content

Commit

Permalink
add nsfw image censoring
Browse files Browse the repository at this point in the history
activatable via config, uses CompVis/stable-diffusion-safety-checker
  • Loading branch information
mashb1t committed Nov 15, 2023
1 parent 943098f commit 52ae346
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
9 changes: 7 additions & 2 deletions modules/async_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def worker():
import fooocus_extras.ip_adapter as ip_adapter
import fooocus_extras.face_crop

from modules.censor import censor_batch
from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion
from modules.private_logger import log
from modules.expansion import safe_str
Expand All @@ -50,12 +51,16 @@ def progressbar(number, text):
print(f'[Fooocus] {text}')
outputs.append(['preview', (number, text, None)])

def yield_result(imgs, do_not_show_finished_images=False):
def yield_result(imgs, do_not_show_finished_images=False, progressbar_index=13):
global global_results

if not isinstance(imgs, list):
imgs = [imgs]

if modules.config.default_black_out_nsfw:
progressbar(progressbar_index, 'Checking for NSFW content ...')
imgs = censor_batch(imgs)

global_results = global_results + imgs

if do_not_show_finished_images:
Expand Down Expand Up @@ -711,7 +716,7 @@ def callback(step, x0, x, total_steps, y):
d.append((f'LoRA [{n}] weight', w))
log(x, d, single_line_number=3)

yield_result(imgs, do_not_show_finished_images=len(tasks) == 1)
yield_result(imgs, do_not_show_finished_images=len(tasks) == 1, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))))
except fcbh.model_management.InterruptProcessingException as e:
if shared.last_stop == 'skip':
print('User skipped')
Expand Down
54 changes: 54 additions & 0 deletions modules/censor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py

import numpy as np
import torch
import modules.core as core

from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image

safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None


def numpy_to_pil(image):
image = (image * 255).round().astype("uint8")

#pil_image = Image.fromarray(image, 'RGB')
pil_image = Image.fromarray(image)

return pil_image


# check and replace nsfw content
def check_safety(x_image):
global safety_feature_extractor, safety_checker

if safety_feature_extractor is None:
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)

return x_checked_image, has_nsfw_concept


def censor_single(x):
x_checked_image, has_nsfw_concept = check_safety(x)

# replace image with black pixels, keep dimensions
# workaround due to different numpy / pytorch image matrix format
if has_nsfw_concept[0]:
imageshape = x_checked_image.shape
x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype = np.uint8)

return x_checked_image


def censor_batch(images):
images = [censor_single(image) for image in images]

return images
5 changes: 5 additions & 0 deletions modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,11 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_
default_value=-1,
validator=lambda x: isinstance(x, int)
)
default_black_out_nsfw = get_config_item_or_set_default(
key='default_black_out_nsfw',
default_value=False,
validator=lambda x: isinstance(x, bool)
)


def add_ratio(x):
Expand Down

0 comments on commit 52ae346

Please sign in to comment.