From a8b56674030780c359c3536ec71c6c641e94d8c6 Mon Sep 17 00:00:00 2001 From: unexcellent <> Date: Sat, 16 Nov 2024 18:03:17 +0100 Subject: [PATCH] feat: implement IncludeObjectTypeFilter --- raillabel/filter/__init__.py | 2 ++ raillabel/filter/_filter_abc.py | 4 ++-- .../filter/exclude_annotation_id_filter.py | 6 +++-- .../filter/exclude_annotation_type_filter.py | 6 +++-- raillabel/filter/exclude_object_id_filter.py | 6 +++-- raillabel/filter/filter.py | 17 +++++++------ .../filter/include_annotation_id_filter.py | 6 +++-- .../filter/include_annotation_type_filter.py | 6 +++-- raillabel/filter/include_object_id_filter.py | 6 +++-- .../filter/include_object_type_filter.py | 24 +++++++++++++++++++ tests/filter/test_filter.py | 13 ++++++++++ 11 files changed, 75 insertions(+), 21 deletions(-) create mode 100644 raillabel/filter/include_object_type_filter.py diff --git a/raillabel/filter/__init__.py b/raillabel/filter/__init__.py index 370f3f5..8fb800a 100644 --- a/raillabel/filter/__init__.py +++ b/raillabel/filter/__init__.py @@ -12,6 +12,7 @@ from .include_annotation_type_filter import IncludeAnnotationTypeFilter from .include_frame_id_filter import IncludeFrameIdFilter from .include_object_id_filter import IncludeObjectIdFilter +from .include_object_type_filter import IncludeObjectTypeFilter from .start_time_filter import StartTimeFilter __all__ = [ @@ -26,4 +27,5 @@ "ExcludeAnnotationTypeFilter", "IncludeObjectIdFilter", "ExcludeObjectIdFilter", + "IncludeObjectTypeFilter", ] diff --git a/raillabel/filter/_filter_abc.py b/raillabel/filter/_filter_abc.py index 70582aa..cb7b1b7 100644 --- a/raillabel/filter/_filter_abc.py +++ b/raillabel/filter/_filter_abc.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from uuid import UUID -from raillabel.format import Bbox, Cuboid, Frame, Poly2d, Poly3d, Seg3d +from raillabel.format import Bbox, Cuboid, Frame, Poly2d, Poly3d, Scene, Seg3d class _FilterAbc(ABC): @@ -17,7 +17,7 @@ class _AnnotationLevelFilter(_FilterAbc): @abstractmethod def passes_filter( - self, annotation_id: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d + self, annotation_id: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, scene: Scene ) -> bool: """Assess if an annotation passes this filter.""" raise NotImplementedError diff --git a/raillabel/filter/exclude_annotation_id_filter.py b/raillabel/filter/exclude_annotation_id_filter.py index 7973e50..a41830f 100644 --- a/raillabel/filter/exclude_annotation_id_filter.py +++ b/raillabel/filter/exclude_annotation_id_filter.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from uuid import UUID -from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d +from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d from ._filter_abc import _AnnotationLevelFilter @@ -17,6 +17,8 @@ class ExcludeAnnotationIdFilter(_AnnotationLevelFilter): annotation_ids: set[UUID] | list[UUID] - def passes_filter(self, annotation_id: UUID, _: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool: + def passes_filter( + self, annotation_id: UUID, _: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, __: Scene + ) -> bool: """Assess if an annotation passes this filter.""" return annotation_id not in self.annotation_ids diff --git a/raillabel/filter/exclude_annotation_type_filter.py b/raillabel/filter/exclude_annotation_type_filter.py index 89e2357..80180dd 100644 --- a/raillabel/filter/exclude_annotation_type_filter.py +++ b/raillabel/filter/exclude_annotation_type_filter.py @@ -7,7 +7,7 @@ from typing import Literal from uuid import UUID -from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d +from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d from ._filter_abc import _AnnotationLevelFilter @@ -21,7 +21,9 @@ class ExcludeAnnotationTypeFilter(_AnnotationLevelFilter): | list[Literal["bbox", "cuboid", "poly2d", "poly3d", "seg3d"]] ) - def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool: + def passes_filter( + self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, __: Scene + ) -> bool: """Assess if an annotation passes this filter.""" annotation_type_str = None diff --git a/raillabel/filter/exclude_object_id_filter.py b/raillabel/filter/exclude_object_id_filter.py index 9ed7963..f840d8e 100644 --- a/raillabel/filter/exclude_object_id_filter.py +++ b/raillabel/filter/exclude_object_id_filter.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from uuid import UUID -from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d +from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d from ._filter_abc import _AnnotationLevelFilter @@ -17,6 +17,8 @@ class ExcludeObjectIdFilter(_AnnotationLevelFilter): object_ids: list[UUID] - def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool: + def passes_filter( + self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, __: Scene + ) -> bool: """Assess if an annotation passes this filter.""" return annotation.object_id not in self.object_ids diff --git a/raillabel/filter/filter.py b/raillabel/filter/filter.py index 174ef23..7fa2c9a 100644 --- a/raillabel/filter/filter.py +++ b/raillabel/filter/filter.py @@ -30,7 +30,7 @@ def filter_(scene: Scene, filters: list[_FilterAbc]) -> Scene: frame_filters, annotation_filters = _separate_filters(filters) filtered_scene = Scene(metadata=deepcopy(scene.metadata)) - filtered_scene.frames = _filter_frames(scene.frames, frame_filters, annotation_filters) + filtered_scene.frames = _filter_frames(scene, frame_filters, annotation_filters) filtered_scene.sensors = _get_used_sensors(scene, filtered_scene) filtered_scene.objects = _get_used_objects(scene, filtered_scene) @@ -53,31 +53,31 @@ def _separate_filters( def _filter_frames( - frames: dict[int, Frame], + scene: Scene, frame_filters: list[_FrameLevelFilter], annotation_filters: list[_AnnotationLevelFilter], ) -> dict[int, Frame]: filtered_frames = {} - for frame_id, frame in frames.items(): + for frame_id, frame in scene.frames.items(): if _frame_passes_all_filters(frame_id, frame, frame_filters): filtered_frames[frame_id] = Frame( timestamp=deepcopy(frame.timestamp), sensors=deepcopy(frame.sensors), frame_data=deepcopy(frame.frame_data), - annotations=_filter_annotations(frame, annotation_filters), + annotations=_filter_annotations(frame, annotation_filters, scene), ) return filtered_frames def _filter_annotations( - frame: Frame, annotation_filters: list[_AnnotationLevelFilter] + frame: Frame, annotation_filters: list[_AnnotationLevelFilter], scene: Scene ) -> dict[UUID, Bbox | Cuboid | Poly2d | Poly3d | Seg3d]: annotations = {} for annotation_id, annotation in frame.annotations.items(): - if _annotation_passes_all_filters(annotation_id, annotation, annotation_filters): + if _annotation_passes_all_filters(annotation_id, annotation, annotation_filters, scene): annotations[annotation_id] = deepcopy(annotation) return annotations @@ -93,8 +93,11 @@ def _annotation_passes_all_filters( annotation_id: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, annotation_filters: list[_AnnotationLevelFilter], + scene: Scene, ) -> bool: - return all(filter_.passes_filter(annotation_id, annotation) for filter_ in annotation_filters) + return all( + filter_.passes_filter(annotation_id, annotation, scene) for filter_ in annotation_filters + ) def _get_used_sensors( diff --git a/raillabel/filter/include_annotation_id_filter.py b/raillabel/filter/include_annotation_id_filter.py index dc8f311..7cac89b 100644 --- a/raillabel/filter/include_annotation_id_filter.py +++ b/raillabel/filter/include_annotation_id_filter.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from uuid import UUID -from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d +from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d from ._filter_abc import _AnnotationLevelFilter @@ -17,6 +17,8 @@ class IncludeAnnotationIdFilter(_AnnotationLevelFilter): annotation_ids: set[UUID] | list[UUID] - def passes_filter(self, annotation_id: UUID, _: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool: + def passes_filter( + self, annotation_id: UUID, _: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, __: Scene + ) -> bool: """Assess if an annotation passes this filter.""" return annotation_id in self.annotation_ids diff --git a/raillabel/filter/include_annotation_type_filter.py b/raillabel/filter/include_annotation_type_filter.py index 5e4908e..300be80 100644 --- a/raillabel/filter/include_annotation_type_filter.py +++ b/raillabel/filter/include_annotation_type_filter.py @@ -7,7 +7,7 @@ from typing import Literal from uuid import UUID -from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d +from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d from ._filter_abc import _AnnotationLevelFilter @@ -21,7 +21,9 @@ class IncludeAnnotationTypeFilter(_AnnotationLevelFilter): | list[Literal["bbox", "cuboid", "poly2d", "poly3d", "seg3d"]] ) - def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool: + def passes_filter( + self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, __: Scene + ) -> bool: """Assess if an annotation passes this filter.""" annotation_type_str = None diff --git a/raillabel/filter/include_object_id_filter.py b/raillabel/filter/include_object_id_filter.py index 96aa7e5..e4bccb9 100644 --- a/raillabel/filter/include_object_id_filter.py +++ b/raillabel/filter/include_object_id_filter.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from uuid import UUID -from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d +from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d from ._filter_abc import _AnnotationLevelFilter @@ -17,6 +17,8 @@ class IncludeObjectIdFilter(_AnnotationLevelFilter): object_ids: list[UUID] - def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool: + def passes_filter( + self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, __: Scene + ) -> bool: """Assess if an annotation passes this filter.""" return annotation.object_id in self.object_ids diff --git a/raillabel/filter/include_object_type_filter.py b/raillabel/filter/include_object_type_filter.py new file mode 100644 index 0000000..aa1a352 --- /dev/null +++ b/raillabel/filter/include_object_type_filter.py @@ -0,0 +1,24 @@ +# Copyright DB InfraGO AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from uuid import UUID + +from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d + +from ._filter_abc import _AnnotationLevelFilter + + +@dataclass +class IncludeObjectTypeFilter(_AnnotationLevelFilter): + """Filter out all annotations in the scene, that do NOT match the type (like 'person').""" + + object_types: list[str] + + def passes_filter( + self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, scene: Scene + ) -> bool: + """Assess if an annotation passes this filter.""" + return scene.objects[annotation.object_id].type in self.object_types diff --git a/tests/filter/test_filter.py b/tests/filter/test_filter.py index f196dcf..5a20f43 100644 --- a/tests/filter/test_filter.py +++ b/tests/filter/test_filter.py @@ -138,5 +138,18 @@ def test_exclude_object_ids(): assert actual == SceneBuilder.empty().add_bbox(object_name="person_0001").result +def test_include_object_types(): + scene = ( + SceneBuilder.empty() + .add_bbox(object_name="person_0001") + .add_cuboid(object_name="train_0001") + .result + ) + filters = [raillabel.filter.IncludeObjectTypeFilter(["person"])] + + actual = raillabel.filter.filter_(scene, filters) + assert actual == SceneBuilder.empty().add_bbox(object_name="person_0001").result + + if __name__ == "__main__": pytest.main([__file__, "-vv"])