Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing-su committed Mar 16, 2024
2 parents 3f1d1b9 + bbd774e commit 643242a
Show file tree
Hide file tree
Showing 15 changed files with 360 additions and 53 deletions.
1 change: 0 additions & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ on:
jobs:
lint:
runs-on: ubuntu-latest
if: github.repository == 'Bing-su/adetailer' || github.repository == ''

steps:
- uses: actions/checkout@v4
Expand Down
6 changes: 5 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@ repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-added-large-files
args: [--maxkb=100]
- id: check-merge-conflict
- id: check-case-conflict
- id: check-ast
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
- id: end-of-file-fixer
- id: mixed-line-ending

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
rev: v0.3.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## 2024-03-16

- YOLO World v2, YOLO9 지원가능한 버전으로 ultralytics 업데이트
- inpaint full res인 경우 인페인트 모드에서 동작하게 변경
- inpaint full res가 아닌 경우, 사용자가 입력한 마스크와 교차점이 있는 마스크만 선택하여 사용함

## 2024-03-01

- v24.3.0
Expand Down
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,11 @@ ADetailer works in three simple steps.
1. Create an image.
2. Detect objects with a detection model and create a mask image.
3. Inpaint using the image from 1 and the mask from 2.

## Development

ADetailer is developed and tested only against the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) repository, using the Stable Diffusion 1.5 model.

## License

ADetailer is a derivative work that uses two AGPL-licensed works (stable-diffusion-webui, ultralytics) and is therefore distributed under the AGPL license.
2 changes: 1 addition & 1 deletion adetailer/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "24.3.0"
__version__ = "24.3.1"
32 changes: 22 additions & 10 deletions adetailer/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,28 @@
from functools import cached_property, partial
from typing import Any, Literal, NamedTuple, Optional

from pydantic import (
BaseModel,
Extra,
NonNegativeFloat,
NonNegativeInt,
PositiveInt,
confloat,
conint,
validator,
)
try:
from pydantic.v1 import (
BaseModel,
Extra,
NonNegativeFloat,
NonNegativeInt,
PositiveInt,
confloat,
conint,
validator,
)
except ImportError:
from pydantic import (
BaseModel,
Extra,
NonNegativeFloat,
NonNegativeInt,
PositiveInt,
confloat,
conint,
validator,
)


@dataclass
Expand Down
25 changes: 17 additions & 8 deletions adetailer/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from typing import Any, Optional

from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw
from rich import print
from torchvision.transforms.functional import to_pil_image

REPO_ID = "Bingsu/adetailer"
_download_failed = False
Expand Down Expand Up @@ -44,7 +45,7 @@ def scan_model_dir(path_: str | Path) -> list[Path]:

def get_models(
model_dir: str | Path, extra_dir: str | Path = "", huggingface: bool = True
) -> OrderedDict[str, str | None]:
) -> OrderedDict[str, str]:
model_paths = [*scan_model_dir(model_dir), *scan_model_dir(extra_dir)]

models = OrderedDict()
Expand All @@ -56,17 +57,17 @@ def get_models(
"hand_yolov8n.pt": hf_download("hand_yolov8n.pt"),
"person_yolov8n-seg.pt": hf_download("person_yolov8n-seg.pt"),
"person_yolov8s-seg.pt": hf_download("person_yolov8s-seg.pt"),
"yolov8x-world.pt": hf_download(
"yolov8x-world.pt", repo_id="Bingsu/yolo-world-mirror"
"yolov8x-worldv2.pt": hf_download(
"yolov8x-worldv2.pt", repo_id="Bingsu/yolo-world-mirror"
),
}
)
models.update(
{
"mediapipe_face_full": None,
"mediapipe_face_short": None,
"mediapipe_face_mesh": None,
"mediapipe_face_mesh_eyes_only": None,
"mediapipe_face_full": "mediapipe_face_full",
"mediapipe_face_short": "mediapipe_face_short",
"mediapipe_face_mesh": "mediapipe_face_mesh",
"mediapipe_face_mesh_eyes_only": "mediapipe_face_mesh_eyes_only",
}
)

Expand Down Expand Up @@ -133,3 +134,11 @@ def create_bbox_from_mask(
if bbox is not None:
bboxes.append(list(bbox))
return bboxes


def ensure_pil_image(image: Any, mode: str = "RGB") -> Image.Image:
    """Return *image* as a PIL image whose mode is *mode*.

    Values that are not already ``PIL.Image.Image`` instances are passed
    through ``to_pil_image`` first. The result is converted with
    ``Image.convert`` only when its current mode differs from *mode*;
    otherwise the same object is returned unchanged.
    """
    pil_image = image if isinstance(image, Image.Image) else to_pil_image(image)
    if pil_image.mode == mode:
        return pil_image
    return pil_image.convert(mode)
18 changes: 13 additions & 5 deletions adetailer/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from enum import IntEnum
from functools import partial, reduce
from math import dist
from typing import Any

import cv2
import numpy as np
from PIL import Image, ImageChops

from adetailer.args import MASK_MERGE_INVERT
from adetailer.common import PredictOutput
from adetailer.common import PredictOutput, ensure_pil_image


class SortBy(IntEnum):
Expand Down Expand Up @@ -83,12 +84,19 @@ def offset(img: Image.Image, x: int = 0, y: int = 0) -> Image.Image:
return ImageChops.offset(img, x, -y)


def is_all_black(img: Image.Image | np.ndarray) -> bool:
    """Return True when *img* contains no non-zero value at all.

    Args:
        img: A PIL image or a NumPy array. PIL images are turned into an
            array first; arrays are used as they are.

    Returns:
        ``True`` if every element is zero, ``False`` otherwise.
    """
    # np.asarray is a no-op for arrays and converts PIL images, so both
    # accepted input kinds go through the same path.
    arr = np.asarray(img)
    # ndarray.any() is used instead of cv2.countNonZero because the latter
    # only accepts single-channel input; an RGB/RGBA mask must not crash here.
    return not arr.any()


def has_intersection(im1: Any, im2: Any) -> bool:
    """Return True when the bitwise AND of the two masks is not all black.

    Both inputs are normalised with ``ensure_pil_image(..., "L")`` and turned
    into NumPy arrays before being combined with ``cv2.bitwise_and``.
    """
    gray1, gray2 = (np.array(ensure_pil_image(im, "L")) for im in (im1, im2))
    overlap = cv2.bitwise_and(gray1, gray2)
    return not is_all_black(overlap)


def bbox_area(bbox: list[float]) -> float:
    """Return ``(bbox[2] - bbox[0]) * (bbox[3] - bbox[1])``.

    NOTE(review): presumably *bbox* is ``[x1, y1, x2, y2]``, which makes this
    the box area -- confirm against the callers.
    """
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    return width * height


Expand Down
4 changes: 4 additions & 0 deletions controlnet_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
get_cn_models,
)

from .restore import CNHijackRestore, cn_allow_script_control

__all__ = [
"ControlNetExt",
"CNHijackRestore",
"cn_allow_script_control",
"controlnet_exists",
"controlnet_type",
"get_cn_models",
Expand Down
2 changes: 1 addition & 1 deletion install.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def run_pip(*args):
def install():
deps = [
# requirements
("ultralytics", "8.1.18", None),
("ultralytics", "8.1.29", None),
("mediapipe", "0.10.9", None),
("rich", "13.0.0", None),
# mediapipe
Expand Down
26 changes: 25 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,41 @@
name = "adetailer"
description = "An object detection and auto-mask extension for stable diffusion webui."
authors = [{ name = "dowon", email = "[email protected]" }]
requires-python = ">=3.8,<3.12"
requires-python = ">=3.8,<3.13"
readme = "README.md"
license = { text = "AGPL-3.0" }
dependencies = [
"ultralytics>=8.1",
"mediapipe>=10",
"pydantic<3",
"rich>=13",
"huggingface_hub",
]
keywords = [
"stable-diffusion",
"stable-diffusion-webui",
"adetailer",
"ultralytics",
]
dynamic = ["version"]

[project.urls]
repository = "https://github.com/Bing-su/adetailer"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.version]
path = "adetailer/__version__.py"

[tool.isort]
profile = "black"
known_first_party = ["launch", "modules"]

[tool.ruff]
target-version = "py38"

[tool.ruff.lint]
select = [
"A",
Expand Down
56 changes: 33 additions & 23 deletions scripts/!adetailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

import gradio as gr
import torch
from PIL import Image
from PIL import Image, ImageChops
from rich import print
from torchvision.transforms.functional import to_pil_image

import modules
from adetailer import (
Expand All @@ -27,25 +26,25 @@
ultralytics_predict,
)
from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, SkipImg2ImgOrig
from adetailer.common import PredictOutput
from adetailer.common import PredictOutput, ensure_pil_image
from adetailer.mask import (
filter_by_ratio,
filter_k_largest,
has_intersection,
is_all_black,
mask_preprocess,
sort_bboxes,
)
from adetailer.traceback import rich_traceback
from adetailer.ui import WebuiInfo, adui, ordinal, suffix
from controlnet_ext import (
CNHijackRestore,
ControlNetExt,
cn_allow_script_control,
controlnet_exists,
controlnet_type,
get_cn_models,
)
from controlnet_ext.restore import (
CNHijackRestore,
cn_allow_script_control,
)
from modules import images, paths, safe, script_callbacks, scripts, shared
from modules.devices import NansException
from modules.processing import (
Expand Down Expand Up @@ -565,27 +564,24 @@ def sort_bboxes(self, pred: PredictOutput) -> PredictOutput:
sortby_idx = BBOX_SORTBY.index(sortby)
return sort_bboxes(pred, sortby_idx)

def pred_preprocessing(self, pred: PredictOutput, args: ADetailerArgs):
def pred_preprocessing(self, p, pred: PredictOutput, args: ADetailerArgs):
pred = filter_by_ratio(
pred, low=args.ad_mask_min_ratio, high=args.ad_mask_max_ratio
)
pred = filter_k_largest(pred, k=args.ad_mask_k_largest)
pred = self.sort_bboxes(pred)
return mask_preprocess(
masks = mask_preprocess(
pred.masks,
kernel=args.ad_dilate_erode,
x_offset=args.ad_x_offset,
y_offset=args.ad_y_offset,
merge_invert=args.ad_mask_merge_invert,
)

@staticmethod
def ensure_rgb_image(image: Any):
if not isinstance(image, Image.Image):
image = to_pil_image(image)
if image.mode != "RGB":
image = image.convert("RGB")
return image
if self.is_img2img_inpaint(p) and not self.is_inpaint_only_masked(p):
invert = p.inpainting_mask_invert
image_mask = ensure_pil_image(p.image_mask, mode="L")
masks = self.inpaint_mask_filter(image_mask, masks, invert)
return masks

@staticmethod
def i2i_prompts_replace(
Expand Down Expand Up @@ -637,16 +633,30 @@ def get_each_tap_seed(seed: int, i: int):

@staticmethod
def is_img2img_inpaint(p) -> bool:
return hasattr(p, "image_mask") and bool(p.image_mask)
return hasattr(p, "image_mask") and p.image_mask is not None

@staticmethod
def is_inpaint_only_masked(p) -> bool:
return hasattr(p, "inpaint_full_res") and p.inpaint_full_res

@staticmethod
def inpaint_mask_filter(
img2img_mask: Image.Image, ad_mask: list[Image.Image], invert: int = 0
) -> list[Image.Image]:
if invert:
img2img_mask = ImageChops.invert(img2img_mask)
return [mask for mask in ad_mask if has_intersection(img2img_mask, mask)]

@rich_traceback
def process(self, p, *args_):
if getattr(p, "_ad_disabled", False):
return

if self.is_img2img_inpaint(p):
if self.is_img2img_inpaint(p) and is_all_black(p.image_mask):
p._ad_disabled = True
msg = "[-] ADetailer: img2img inpainting detected. adetailer disabled."
msg = (
"[-] ADetailer: img2img inpainting with no mask -- adetailer disabled."
)
print(msg)
return

Expand Down Expand Up @@ -700,7 +710,7 @@ def _postprocess_image_inner(
with change_torch_load():
pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs)

masks = self.pred_preprocessing(pred, args)
masks = self.pred_preprocessing(p, pred, args)
shared.state.assign_current_image(pred.preview)

if not masks:
Expand All @@ -726,7 +736,7 @@ def _postprocess_image_inner(
p2 = copy(i2i)
for j in range(steps):
p2.image_mask = masks[j]
p2.init_images[0] = self.ensure_rgb_image(p2.init_images[0])
p2.init_images[0] = ensure_pil_image(p2.init_images[0], "RGB")
self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)

if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
Expand Down Expand Up @@ -760,7 +770,7 @@ def postprocess_image(self, p, pp, *args_):
return

pp.image = self.get_i2i_init_image(p, pp)
pp.image = self.ensure_rgb_image(pp.image)
pp.image = ensure_pil_image(pp.image, "RGB")
init_image = copy(pp.image)
arg_list = self.get_args(p, *args_)
params_txt_content = Path(paths.data_path, "params.txt").read_text("utf-8")
Expand Down
Loading

0 comments on commit 643242a

Please sign in to comment.