diff --git a/aaaaaa/ui.py b/aaaaaa/ui.py index f637d79..697dd36 100644 --- a/aaaaaa/ui.py +++ b/aaaaaa/ui.py @@ -97,6 +97,11 @@ def on_ad_model_update(model: str): visible=True, placeholder="Comma separated class names to detect, ex: 'person,cat'. default: COCO 80 classes", ) + if "yolo" in model.lower(): + return gr.update( + visible=True, + placeholder="Comma separated class numbers to detect or separated class names, ex: '0,1' for first 2 classes, or 'head, hip", + ) return gr.update(visible=False, placeholder="") @@ -203,7 +208,7 @@ def one_ui_group(n: int, is_img2img: bool, webui_info: WebuiInfo): w.ad_model_classes = gr.Textbox( label="ADetailer detector classes" + suffix(n), value="", - visible=False, + visible=True, elem_id=eid("ad_model_classes"), ) diff --git a/adetailer/ultralytics.py b/adetailer/ultralytics.py index 7c7a1a7..a70185d 100644 --- a/adetailer/ultralytics.py +++ b/adetailer/ultralytics.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING import cv2 +import numpy as np from PIL import Image from torchvision.transforms.functional import to_pil_image @@ -25,9 +26,30 @@ def ultralytics_predict( from ultralytics import YOLO model = YOLO(model_path) - apply_classes(model, model_path, classes) + class_indices = [] + if classes: + parsed = [c.strip() for c in classes.split(",") if c.strip()] + for c in parsed: + if c.isdigit(): + class_indices.append(int(c)) + elif c in model.names.values(): + # Find the index for the class name + for idx, name in model.names.items(): + if name == c: + class_indices.append(idx) + break + pred = model(image, conf=confidence, device=device) + if class_indices and len(pred[0].boxes) > 0: + cls = pred[0].boxes.cls.cpu().numpy() + mask = np.isin(cls, class_indices) + + # Apply mask to boxes + pred[0].boxes.data = pred[0].boxes.data[mask] + if pred[0].masks is not None: + pred[0].masks.data = pred[0].masks.data[mask] + bboxes = pred[0].boxes.xyxy.cpu().numpy() if bboxes.size == 0: return PredictOutput() @@ -50,11 +72,27 @@ 
def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str):
    """Restrict *model* to the detector classes listed in *classes*.

    Args:
        model: Loaded detector. For ``-world`` checkpoints it must provide
            ``set_classes``; otherwise ``names`` (a mapping of class index
            to class name) is read when a class name has to be resolved.
        model_path: Path of the checkpoint. A stem containing ``"-world"``
            selects the open-vocabulary branch.
        classes: Comma separated tokens. Surrounding whitespace is stripped
            and empty tokens are dropped. An empty or blank string leaves
            the model untouched.

    For ``-world`` checkpoints the tokens are forwarded unchanged to
    ``model.set_classes``. For every other checkpoint each token is turned
    into a class index: all-digit tokens are used as the index itself, any
    other token is matched against ``model.names`` (first match wins) and
    silently skipped when it is unknown. The resulting list is stored on
    ``model.classes``.

    Index resolution is best effort: any exception raised while resolving
    or storing the indices is printed and swallowed.
    """
    if not classes:
        return

    parsed = [c.strip() for c in classes.split(",") if c.strip()]
    if not parsed:
        return

    # Open-vocabulary checkpoints take the class *names* directly; resolving
    # them against ``model.names`` would discard any name outside the
    # currently loaded vocabulary.
    if "-world" in Path(model_path).stem:
        model.set_classes(parsed)
        return

    try:
        name_to_index = None
        class_indices = []
        for token in parsed:
            if token.isdigit():
                class_indices.append(int(token))
                continue

            # Built lazily so an all-numeric request never touches
            # ``model.names``; ``setdefault`` keeps the first index when a
            # name occurs more than once.
            if name_to_index is None:
                name_to_index = {}
                for idx, name in model.names.items():
                    name_to_index.setdefault(name, idx)

            if token in name_to_index:
                class_indices.append(name_to_index[token])

        # NOTE(review): this only stores the indices on the model object.
        # Whether the detector consults ``model.classes`` during prediction
        # is not visible here - confirm, or pass the indices to the predict
        # call explicitly.
        model.classes = class_indices
    except Exception as e:
        # Deliberately broad: class filtering must not abort detection.
        print(f"Error setting classes: {e}")