From db9bbed15d4db80fd546046199b5c8ba83dff901 Mon Sep 17 00:00:00 2001 From: Panchovix Date: Sun, 1 Dec 2024 22:07:20 -0300 Subject: [PATCH 1/2] Allow selecting classes for YOLO models Currently, when using models such as the examples from https://github.com/aperveyev/booru_yolo/tree/main/models, every class the model finds is used. This PR lets the user enter a class index or class name so that only that class is used instead of every class. For now, all classes found are still printed, but only the ones the user entered are used. --- aaaaaa/ui.py | 7 +++++- adetailer/ultralytics.py | 48 +++++++++++++++++++++++++++++++++++----- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/aaaaaa/ui.py b/aaaaaa/ui.py index f637d79..77a65fd 100644 --- a/aaaaaa/ui.py +++ b/aaaaaa/ui.py @@ -97,6 +97,11 @@ def on_ad_model_update(model: str): visible=True, placeholder="Comma separated class names to detect, ex: 'person,cat'. default: COCO 80 classes", ) + elif "yolo" in model.lower(): + return gr.update( + visible=True, + placeholder="Comma separated class numbers to detect or separated class names, ex: '0,1' for first 2 classes, or 'head, hip", + ) return gr.update(visible=False, placeholder="") @@ -203,7 +208,7 @@ def one_ui_group(n: int, is_img2img: bool, webui_info: WebuiInfo): w.ad_model_classes = gr.Textbox( label="ADetailer detector classes" + suffix(n), value="", - visible=False, + visible=True, elem_id=eid("ad_model_classes"), ) diff --git a/adetailer/ultralytics.py b/adetailer/ultralytics.py index 7c7a1a7..dff3dab 100644 --- a/adetailer/ultralytics.py +++ b/adetailer/ultralytics.py @@ -9,7 +9,7 @@ from adetailer import PredictOutput from adetailer.common import create_mask_from_bbox - +import numpy as np if TYPE_CHECKING: import torch from ultralytics import YOLO, YOLOWorld @@ -23,10 +23,30 @@ def ultralytics_predict( classes: str = "", ) -> PredictOutput[float]: from ultralytics import YOLO - model = YOLO(model_path) - apply_classes(model, model_path, classes) + class_indices = [] + if classes: + 
parsed = [c.strip() for c in classes.split(",") if c.strip()] + for c in parsed: + if c.isdigit(): + class_indices.append(int(c)) + elif c in model.names.values(): + # Find the index for the class name + for idx, name in model.names.items(): + if name == c: + class_indices.append(idx) + break + pred = model(image, conf=confidence, device=device) + + if class_indices and len(pred[0].boxes) > 0: + cls = pred[0].boxes.cls.cpu().numpy() + mask = np.isin(cls, class_indices) + + # Apply mask to boxes + pred[0].boxes.data = pred[0].boxes.data[mask] + if pred[0].masks is not None: + pred[0].masks.data = pred[0].masks.data[mask] bboxes = pred[0].boxes.xyxy.cpu().numpy() if bboxes.size == 0: @@ -50,11 +70,27 @@ def ultralytics_predict( def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str): - if not classes or "-world" not in Path(model_path).stem: + if not classes: return + parsed = [c.strip() for c in classes.split(",") if c.strip()] - if parsed: - model.set_classes(parsed) + if not parsed: + return + + try: + class_indices = [] + for c in parsed: + if c.isdigit(): + class_indices.append(int(c)) + elif c in model.names.values(): + for idx, name in model.names.items(): + if name == c: + class_indices.append(idx) + break + + model.classes = class_indices + except Exception as e: + print(f"Error setting classes: {e}") def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]: From 72b68025e1a7bd9bac424cee556e9258a1733136 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Dec 2024 01:07:49 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- aaaaaa/ui.py | 2 +- adetailer/ultralytics.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/aaaaaa/ui.py b/aaaaaa/ui.py index 77a65fd..697dd36 100644 --- a/aaaaaa/ui.py +++ b/aaaaaa/ui.py @@ -97,7 +97,7 @@ def 
on_ad_model_update(model: str): visible=True, placeholder="Comma separated class names to detect, ex: 'person,cat'. default: COCO 80 classes", ) - elif "yolo" in model.lower(): + if "yolo" in model.lower(): return gr.update( visible=True, placeholder="Comma separated class numbers to detect or separated class names, ex: '0,1' for first 2 classes, or 'head, hip", diff --git a/adetailer/ultralytics.py b/adetailer/ultralytics.py index dff3dab..a70185d 100644 --- a/adetailer/ultralytics.py +++ b/adetailer/ultralytics.py @@ -4,12 +4,13 @@ from typing import TYPE_CHECKING import cv2 +import numpy as np from PIL import Image from torchvision.transforms.functional import to_pil_image from adetailer import PredictOutput from adetailer.common import create_mask_from_bbox -import numpy as np + if TYPE_CHECKING: import torch from ultralytics import YOLO, YOLOWorld @@ -23,6 +24,7 @@ def ultralytics_predict( classes: str = "", ) -> PredictOutput[float]: from ultralytics import YOLO + model = YOLO(model_path) class_indices = [] if classes: @@ -38,11 +40,11 @@ def ultralytics_predict( break pred = model(image, conf=confidence, device=device) - + if class_indices and len(pred[0].boxes) > 0: cls = pred[0].boxes.cls.cpu().numpy() mask = np.isin(cls, class_indices) - + # Apply mask to boxes pred[0].boxes.data = pred[0].boxes.data[mask] if pred[0].masks is not None: @@ -72,11 +74,11 @@ def ultralytics_predict( def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str): if not classes: return - + parsed = [c.strip() for c in classes.split(",") if c.strip()] if not parsed: return - + try: class_indices = [] for c in parsed: @@ -87,7 +89,7 @@ def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str) if name == c: class_indices.append(idx) break - + model.classes = class_indices except Exception as e: print(f"Error setting classes: {e}")