diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 32753a30a..512fd5fbf 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -840,9 +840,10 @@ def from_lmm( if lmm == LMM.PALIGEMMA: assert isinstance(result, str) - xyxy, class_id, class_name = from_paligemma(result, **kwargs) + xyxy, class_id, class_name, mask = from_paligemma(result, **kwargs) data = {CLASS_NAME_DATA_FIELD: class_name} - return cls(xyxy=xyxy, class_id=class_id, data=data) + mask = mask if mask is not None else None + return cls(xyxy=xyxy, class_id=class_id, mask=mask, data=data) if lmm == LMM.FLORENCE_2: assert isinstance(result, dict) diff --git a/supervision/detection/lmm.py b/supervision/detection/lmm.py index 7879902f3..65929ee04 100644 --- a/supervision/detection/lmm.py +++ b/supervision/detection/lmm.py @@ -69,25 +69,72 @@ def validate_lmm_parameters( def from_paligemma( result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]] = None -) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]: +) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray, Optional[np.ndarray]]: + """ + Parse results from Paligemma model which can contain object detection and segmentation. + + Args: + result (str): Model output string containing loc and optional seg tokens + resolution_wh (Tuple[int, int]): Target resolution (width, height) + classes (Optional[List[str]]): List of class names to filter results + + Returns: + Tuple containing: + - xyxy (np.ndarray): Bounding box coordinates + - class_id (Optional[np.ndarray]): Class IDs if classes provided + - class_name (np.ndarray): Class names + - mask (Optional[np.ndarray]): Segmentation masks if available + """ w, h = resolution_wh - pattern = re.compile( - r"(?) ([\w\s\-]+)" - ) - matches = pattern.findall(result) - matches = np.array(matches) if matches else np.empty((0, 5)) - xyxy, class_name = matches[:, [1, 0, 3, 2]], matches[:, 4] - xyxy = xyxy.astype(int) / 1024 * np.array([w, h, w, h]) - class_name = np.char.strip(class_name.astype(str)) - class_id = None + segmentation_pattern = re.compile( + r"\s*" + + "".join(r"" for _ in range(16)) + + r"\s+([\w\s\-]+)" + ) - if classes is not None: - mask = np.array([name in classes for name in class_name]).astype(bool) - xyxy, class_name = xyxy[mask], class_name[mask] - class_id = np.array([classes.index(name) for name in class_name]) + detection_pattern = re.compile( + r"(?) ([\w\s\-]+)" + ) - return xyxy, class_id, class_name + segmentation_matches = segmentation_pattern.findall(result) + if segmentation_matches: + matches = np.array(segmentation_matches) + xyxy = matches[:, [1, 0, 3, 2]].astype(int) / 1024 * np.array([w, h, w, h]) + class_name = np.char.strip(matches[:, -1].astype(str)) + class_id = None + + seg_tokens = matches[:, 4:-1].astype(int) + masks = [] + for tokens in seg_tokens: + mask = np.zeros((h, w), dtype=bool) + masks.append(mask) + masks = np.array(masks) + + if classes is not None: + mask = np.array([name in classes for name in class_name]).astype(bool) + xyxy = xyxy[mask] + class_name = class_name[mask] + masks = masks[mask] + class_id = np.array([classes.index(name) for name in class_name]) + + return xyxy, class_id, class_name, masks + + detection_matches = detection_pattern.findall(result) + if detection_matches: + matches = np.array(detection_matches) + xyxy = matches[:, [1, 0, 3, 2]].astype(int) / 1024 * np.array([w, h, w, h]) + class_name = np.char.strip(matches[:, 4].astype(str)) + class_id = None + + if classes is not None: + mask = np.array([name in classes for name in class_name]).astype(bool) + xyxy, class_name = xyxy[mask], class_name[mask] + class_id = np.array([classes.index(name) for name in class_name]) + + return xyxy, class_id, class_name, None + + return np.empty((0, 4)), None, np.array([]), None def from_florence_2(