From 5eff927c07c7ac0c3f1c93006762fea6978f010a Mon Sep 17 00:00:00 2001 From: unexcellent <> Date: Tue, 29 Oct 2024 18:44:48 +0100 Subject: [PATCH] refactor: remove filter functionality --- raillabel/__init__.py | 1 - raillabel/filter/_filter_classes/__init__.py | 20 - .../filter/_filter_classes/_filter_abc.py | 55 --- .../_filter_classes/_filter_annotation_ids.py | 21 - .../_filter_annotation_types.py | 24 -- .../_filter_classes/_filter_attributes.py | 39 -- .../filter/_filter_classes/_filter_end.py | 21 - .../filter/_filter_classes/_filter_frames.py | 20 - .../_filter_classes/_filter_object_ids.py | 21 - .../_filter_classes/_filter_object_types.py | 21 - .../filter/_filter_classes/_filter_sensors.py | 21 - .../filter/_filter_classes/_filter_start.py | 21 - raillabel/filter/filter.py | 216 ---------- tests/test_raillabel/filter/test_filter.py | 400 ------------------ 14 files changed, 901 deletions(-) delete mode 100644 raillabel/filter/_filter_classes/__init__.py delete mode 100644 raillabel/filter/_filter_classes/_filter_abc.py delete mode 100644 raillabel/filter/_filter_classes/_filter_annotation_ids.py delete mode 100644 raillabel/filter/_filter_classes/_filter_annotation_types.py delete mode 100644 raillabel/filter/_filter_classes/_filter_attributes.py delete mode 100644 raillabel/filter/_filter_classes/_filter_end.py delete mode 100644 raillabel/filter/_filter_classes/_filter_frames.py delete mode 100644 raillabel/filter/_filter_classes/_filter_object_ids.py delete mode 100644 raillabel/filter/_filter_classes/_filter_object_types.py delete mode 100644 raillabel/filter/_filter_classes/_filter_sensors.py delete mode 100644 raillabel/filter/_filter_classes/_filter_start.py delete mode 100644 raillabel/filter/filter.py delete mode 100644 tests/test_raillabel/filter/test_filter.py diff --git a/raillabel/__init__.py b/raillabel/__init__.py index bfcb251..c32f684 100644 --- a/raillabel/__init__.py +++ b/raillabel/__init__.py @@ -6,7 +6,6 @@ from . import format from .exceptions import * -from .filter.filter import filter from .format import Scene from .load.load import load from .save.save import save diff --git a/raillabel/filter/_filter_classes/__init__.py b/raillabel/filter/_filter_classes/__init__.py deleted file mode 100644 index 097c89c..0000000 --- a/raillabel/filter/_filter_classes/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright DB InfraGO AG and contributors -# SPDX-License-Identifier: Apache-2.0 -"""Package containing the loader classes for all supported formats.""" - -from importlib import import_module -from inspect import isclass -from pathlib import Path -from pkgutil import iter_modules - -# iterate through the modules in the current package -package_dir = str(Path(__file__).resolve().parent) -for _, module_name, _ in iter_modules([package_dir]): - # import the module and iterate through its attributes - module = import_module(f"{__name__}.{module_name}") - for attribute_name in dir(module): - attribute = getattr(module, attribute_name) - - if isclass(attribute): - # Add the class to this package's variables - globals()[attribute_name] = attribute diff --git a/raillabel/filter/_filter_classes/_filter_abc.py b/raillabel/filter/_filter_classes/_filter_abc.py deleted file mode 100644 index 3346d37..0000000 --- a/raillabel/filter/_filter_classes/_filter_abc.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright DB InfraGO AG and contributors -# SPDX-License-Identifier: Apache-2.0 - -import typing as t -from abc import ABC, abstractmethod, abstractproperty - -from ...format import Frame, _ObjectAnnotation - - -class _FilterABC(ABC): - """ABC for all filter classes. - - Creating a new filter - --------------------- - To create a new, custom filter create a new class in this dir, that inherits from _FilterABC. - Any class, that inherits from _FilterABC will automatically be loaded by the filter function. - Include the filter arguments (include_[...], exclude_[...], ...) in the PARAMETERS field. These - will be mutually exclusive.Select a level for the filter. The level determines where the filter - is going to be applied (e.g. at the frame level, annotation level, ...). Include the conditions - to pass the filter in the passes_filter() method, which returns True if the filter is passed. - The contents of the filter arguments can optionally be processed by the _process_filter_args(). - """ - - @property - @abstractproperty - def PARAMETERS(self) -> t.List[str]: - raise NotImplementedError - - @property - @abstractproperty - def LEVELS(self) -> t.List[str]: - raise NotImplementedError - - def __init__(self, kwargs) -> None: - set_parameter = None - for param in self.PARAMETERS: - if param in kwargs and param is not None: - if set_parameter is None: - setattr(self, param, self._process_filter_args(kwargs[param])) - set_parameter = param - else: - raise ValueError( - f"{set_parameter} and {param} are mutually exclusive, but were both set." - ) - - else: - setattr(self, param, None) - - @abstractmethod - def passes_filter(self, annotation: t.Union[t.Type[_ObjectAnnotation], Frame]) -> bool: - raise NotImplementedError - - def _process_filter_args(self, filter_args): - """Process filter arguments (optional).""" - return filter_args diff --git a/raillabel/filter/_filter_classes/_filter_annotation_ids.py b/raillabel/filter/_filter_classes/_filter_annotation_ids.py deleted file mode 100644 index b460fce..0000000 --- a/raillabel/filter/_filter_classes/_filter_annotation_ids.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright DB InfraGO AG and contributors -# SPDX-License-Identifier: Apache-2.0 - -import typing as t -from typing import ClassVar - -from ._filter_abc import _FilterABC, _ObjectAnnotation - - -class _FilterAnnotationIds(_FilterABC): - PARAMETERS: ClassVar = ["include_annotation_ids", "exclude_annotation_ids"] - LEVELS: ClassVar = ["annotation"] - - def passes_filter(self, annotation: t.Type[_ObjectAnnotation]) -> bool: - if self.include_annotation_ids is not None: - return annotation.uid in self.include_annotation_ids - - if self.exclude_annotation_ids is not None: - return annotation.uid not in self.exclude_annotation_ids - - return True diff --git a/raillabel/filter/_filter_classes/_filter_annotation_types.py b/raillabel/filter/_filter_classes/_filter_annotation_types.py deleted file mode 100644 index e44d673..0000000 --- a/raillabel/filter/_filter_classes/_filter_annotation_types.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright DB InfraGO AG and contributors -# SPDX-License-Identifier: Apache-2.0 - -import typing as t -from typing import ClassVar - -from ._filter_abc import _FilterABC, _ObjectAnnotation - - -class _FilterAnnotationTypes(_FilterABC): - PARAMETERS: ClassVar = ["include_annotation_types", "exclude_annotation_types"] - LEVELS: ClassVar = ["annotation"] - - def passes_filter(self, annotation: t.Type[_ObjectAnnotation]) -> bool: - if self.include_annotation_types is not None: - return annotation.__class__.__name__.lower() in self.include_annotation_types - - if self.exclude_annotation_types is not None: - return annotation.__class__.__name__.lower() not in self.exclude_annotation_types - - return True - - def _process_filter_args(self, filter_args): - return [arg.lower() for arg in filter_args] diff --git a/raillabel/filter/_filter_classes/_filter_attributes.py b/raillabel/filter/_filter_classes/_filter_attributes.py deleted file mode 100644 index 44a8263..0000000 --- a/raillabel/filter/_filter_classes/_filter_attributes.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright DB InfraGO AG and contributors -# SPDX-License-Identifier: Apache-2.0 - -import typing as t -from typing import ClassVar - -from ._filter_abc import _FilterABC, _ObjectAnnotation - - -class _FilterAttributes(_FilterABC): - PARAMETERS: ClassVar = ["include_attributes", "exclude_attributes"] - LEVELS: ClassVar = ["annotation"] - - def passes_filter(self, annotation: t.Type[_ObjectAnnotation]) -> bool: - if self.include_attributes is not None: - for attribute_id, attribute_val in self.include_attributes.items(): - if attribute_val is None: - if attribute_id not in annotation.attributes: - return False - - elif ( - attribute_id not in annotation.attributes - or annotation.attributes[attribute_id] != attribute_val - ): - return False - - elif self.exclude_attributes is not None: - for attribute_id, attribute_val in self.exclude_attributes.items(): - if attribute_val is None: - if attribute_id in annotation.attributes: - return False - - elif ( - attribute_id in annotation.attributes - and attribute_val == annotation.attributes[attribute_id] - ): - return False - - return True diff --git a/raillabel/filter/_filter_classes/_filter_end.py b/raillabel/filter/_filter_classes/_filter_end.py deleted file mode 100644 index cc97667..0000000 --- a/raillabel/filter/_filter_classes/_filter_end.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright DB InfraGO AG and contributors -# SPDX-License-Identifier: Apache-2.0 - -from decimal import Decimal -from typing import ClassVar - -from ._filter_abc import Frame, _FilterABC - - -class _FilterEnd(_FilterABC): - PARAMETERS: ClassVar = ["end_frame", "end_timestamp"] - LEVELS: ClassVar = ["frame"] - - def passes_filter(self, frame: Frame) -> bool: - if self.end_frame is not None: - return frame.uid <= self.end_frame - - if self.end_timestamp is not None: - return frame.timestamp <= Decimal(self.end_timestamp) - - return True diff --git a/raillabel/filter/_filter_classes/_filter_frames.py b/raillabel/filter/_filter_classes/_filter_frames.py deleted file mode 100644 index b2ee895..0000000 --- a/raillabel/filter/_filter_classes/_filter_frames.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright DB InfraGO AG and contributors -# SPDX-License-Identifier: Apache-2.0 - -from typing import ClassVar - -from ._filter_abc import Frame, _FilterABC - - -class _FilterFrame(_FilterABC): - PARAMETERS: ClassVar = ["include_frames", "exclude_frames"] - LEVELS: ClassVar = ["frame"] - - def passes_filter(self, frame: Frame) -> bool: - if self.include_frames is not None: - return int(frame.uid) in self.include_frames - - if self.exclude_frames is not None: - return int(frame.uid) not in self.exclude_frames - - return True diff --git a/raillabel/filter/_filter_classes/_filter_object_ids.py b/raillabel/filter/_filter_classes/_filter_object_ids.py deleted file mode 100644 index fe85bc4..0000000 --- a/raillabel/filter/_filter_classes/_filter_object_ids.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright DB InfraGO AG and contributors -# SPDX-License-Identifier: Apache-2.0 - -import typing as t -from typing import ClassVar - -from ._filter_abc import _FilterABC, _ObjectAnnotation - - -class _FilterObjectIds(_FilterABC): - PARAMETERS: ClassVar = ["include_object_ids", "exclude_object_ids"] - LEVELS: ClassVar = ["annotation"] - - def passes_filter(self, annotation: t.Type[_ObjectAnnotation]) -> bool: - if self.include_object_ids is not None: - return annotation.object.uid in self.include_object_ids - - if self.exclude_object_ids is not None: - return annotation.object.uid not in self.exclude_object_ids - - return True diff --git a/raillabel/filter/_filter_classes/_filter_object_types.py b/raillabel/filter/_filter_classes/_filter_object_types.py deleted file mode 100644 index 01d708d..0000000 --- a/raillabel/filter/_filter_classes/_filter_object_types.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright DB InfraGO AG and contributors -# SPDX-License-Identifier: Apache-2.0 - -import typing as t -from typing import ClassVar - -from ._filter_abc import _FilterABC, _ObjectAnnotation - - -class _FilterObjectTypes(_FilterABC): - PARAMETERS: ClassVar = ["include_object_types", "exclude_object_types"] - LEVELS: ClassVar = ["annotation"] - - def passes_filter(self, annotation: t.Type[_ObjectAnnotation]) -> bool: - if self.include_object_types is not None: - return annotation.object.type in self.include_object_types - - if self.exclude_object_types is not None: - return annotation.object.type not in self.exclude_object_types - - return True diff --git a/raillabel/filter/_filter_classes/_filter_sensors.py b/raillabel/filter/_filter_classes/_filter_sensors.py deleted file mode 100644 index 4aa65f9..0000000 --- a/raillabel/filter/_filter_classes/_filter_sensors.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright DB InfraGO AG and contributors -# SPDX-License-Identifier: Apache-2.0 - -import typing as t -from typing import ClassVar - -from ._filter_abc import _FilterABC, _ObjectAnnotation - - -class _FilterSensors(_FilterABC): - PARAMETERS: ClassVar = ["include_sensors", "exclude_sensors"] - LEVELS: ClassVar = ["frame_data", "annotation"] - - def passes_filter(self, annotation: t.Type[_ObjectAnnotation]) -> bool: - if self.include_sensors is not None: - return annotation.sensor.uid in self.include_sensors - - if self.exclude_sensors is not None: - return annotation.sensor.uid not in self.exclude_sensors - - return True diff --git a/raillabel/filter/_filter_classes/_filter_start.py b/raillabel/filter/_filter_classes/_filter_start.py deleted file mode 100644 index 7a0bd17..0000000 --- a/raillabel/filter/_filter_classes/_filter_start.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright DB InfraGO AG and contributors -# SPDX-License-Identifier: Apache-2.0 - -from decimal import Decimal -from typing import ClassVar - -from ._filter_abc import Frame, _FilterABC - - -class _FilterStart(_FilterABC): - PARAMETERS: ClassVar = ["start_frame", "start_timestamp"] - LEVELS: ClassVar = ["frame"] - - def passes_filter(self, frame: Frame) -> bool: - if self.start_frame is not None: - return frame.uid >= self.start_frame - - if self.start_timestamp is not None: - return frame.timestamp >= Decimal(self.start_timestamp) - - return True diff --git a/raillabel/filter/filter.py b/raillabel/filter/filter.py deleted file mode 100644 index ef5df9c..0000000 --- a/raillabel/filter/filter.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright DB InfraGO AG and contributors -# SPDX-License-Identifier: Apache-2.0 - -import pickle -import typing as t - -from .. import format -from . import _filter_classes - - -def filter(scene: format.Scene, **kwargs) -> format.Scene: - """Return a copy of the scene with the annotations filtered. - - Parameters - ---------- - scene: raillabel.Scene - Scene, which should be copied and filtered. - include_object_types: str or list of str, optional - List of class/type names that should be included in the filtered scene. If set, no - other classes/types will be copied. Mutually exclusive with exclude_object_types. - exclude_object_types: str or list of str, optional - List of class/type names that should be excluded in the filtered scene. If set, all - other classes/types will be copied. Mutually exclusive with include_object_types. - include_annotation_types: str or list of str, optional - List of annotation types (i.e. bboxs, cuboids, poly2ds, seg3ds) that should be included - in the filtered scene. If set, no other annotation types will be copied. Mutually - exclusive with exclude_annotation_types. - exclude_annotation_types: str or list of str, optional - List of annotation types (i.e. bboxs, cuboids, poly2ds, seg3ds) that should be excluded - in the filtered scene. If set, all other annotation types will be copied. Mutually - exclusive with include_annotation_types. - include_annotation_ids: str or list of str, optional - List of annotation UIDs that should be included in the filtered scene. If set, no other - annotation UIDs will be copied. Mutually exclusive with exclude_annotation_ids. - exclude_annotation_ids: str or list of str, optional - List of annotation UIDs that should be excluded in the filtered scene. If set, all - other annotation UIDs will be copied. Mutually exclusive with include_annotation_ids. - include_object_ids: str or list of str, optional - List of object UIDs that should be included in the filtered scene. If set, no other - objects will be copied. Mutually exclusive with exclude_object_ids. - exclude_object_ids: str or list of str, optional - List of object UIDs that should be excluded in the filtered scene. If set, all other - objects will be copied. Mutually exclusive with include_object_ids. - include_sensors: str or list of str - List of sensors that should be included in the filtered scene. If set, no other - sensors will be copied. Mutually exclusive with exclude_sensors. - exclude_sensors: str or list of str, optional - List of sensors that should be excluded in the filtered scene. If set, all other - sensors will be copied. Mutually exclusive with include_sensors. - include_attributes: dict, optional - Dict of attributes that should be included in the filtered scene. Dict keys are the - attribute names, values are the specific values that should be included. If the - value is set so None, all annotations with the attribute are included regardless of - value. Mutually exclusive with exclude_attributes. - exclude_attributes: dict, optional - Dict of attributes that should be excluded in the filtered scene. Dict keys are the - attribute names, values are the specific values that should be excluded. If the value - is set so None, all annotations with the attribute are excluded regardless of value. - Mutually exclusive with include_attributes. - include_frames: int or list of int, optional - List of frame UIDs that should be included in the filtered scene. If set, no other - frames will be copied. Mutually exclusive with exclude_frames. - exclude_frames: int or list of int, optional - List of frame UIDs that should be excluded in the filtered scene. If set, all other - frames will be copied. Mutually exclusive with include_frames. - start_frame: int, optional - Frame at which the filtered scene should start. Mutually exclusive with s - tart_timestamp. - end_frame: int, optional - Frame at which the filtered scene should end (inclusive). Mutually exclusive with - end_timestamp. - start_timestamp: decimal.Decimal, optional - Unix timestamp at which the filtered scene should start (inclusive). Mutually exclusive - with start_frame. - end_timestamp: decimal.Decimal, optional - Unix timestamp at which the filtered scene should end (inclusive). Mutually exclusive - with end_frame. - - Raises - ------ - ValueError - if two mutually exclusive parameters are set. - TypeError - if an unexpected keyword argument has been set. - - """ - filters_by_level = _collect_filter_classes(kwargs) - filtered_scene, used_sensors, used_objects = _filter_scene(_copy(scene), filters_by_level) - filtered_scene = _remove_unused(filtered_scene, used_sensors, used_objects) - - return filtered_scene - - -# --- Prepare filter classes - - -def _collect_filter_classes(kwargs) -> t.Tuple[t.List[t.Type], t.List[str]]: - filters = [] - supported_kwargs = [] - for cls in _filter_classes.__dict__.values(): - if ( - isinstance(cls, type) - and issubclass(cls, _filter_classes._FilterABC) - and cls != _filter_classes._FilterABC - ): - filters.append(cls(kwargs)) - supported_kwargs.extend(cls.PARAMETERS) - - _check_for_unsupported_arg(kwargs, supported_kwargs) - - return _seperate_filters_by_level(filters) - - -def _check_for_unsupported_arg(kwargs: t.List[str], supported_kwargs: t.List[str]): - for arg in kwargs: - if arg not in supported_kwargs: - raise TypeError( - f"filter() got an unexpected keyword argument '{arg}'. Supported keyword " - + f"arguments: {sorted(supported_kwargs)}" - ) - - -def _seperate_filters_by_level(filters: t.List[t.Type]) -> t.Dict[str, t.List[t.Type]]: - all_filter_levels = [level for f in filters for level in f.LEVELS] - - filters_by_level = {level: [] for level in all_filter_levels} - for level in filters_by_level: - for filter_class in filters: - if level in filter_class.LEVELS: - filters_by_level[level].append(filter_class) - - return filters_by_level - - -# --- Filter scene - - -def _filter_scene( - scene: format.Scene, filters_by_level: t.Dict[str, t.List[t.Type]] -) -> t.Tuple[format.Scene, t.Set[str], t.Set[str]]: - used_sensors = set() - used_objects = set() - - for frame_id, frame in list(scene.frames.items()): - if not _passes_filters(frame, filters_by_level["frame"]): - del scene.frames[frame_id] - continue - - for frame_data_id, frame_data in list(frame.frame_data.items()): - if _passes_filters(frame_data, filters_by_level["frame_data"]): - used_sensors.add(frame_data.sensor.uid) - - else: - del scene.frames[frame_id].frame_data[frame_data_id] - - for annotation_id, annotation in list(frame.annotations.items()): - if _passes_filters(annotation, filters_by_level["annotation"]): - used_objects.add(annotation.object.uid) - used_sensors.add(annotation.sensor.uid) - - else: - del scene.frames[frame_id].annotations[annotation_id] - - return scene, used_sensors, used_objects - - -# --- Remove unused - - -def _remove_unused( - scene: format.Scene, used_sensors: t.Set[str], used_objects: t.Set[str] -) -> format.Scene: - scene = _remove_unused_sensors(scene, used_sensors) - scene = _remove_unused_objects(scene, used_objects) - - for frame_id in scene.frames: - scene.frames[frame_id] = _remove_unused_sensor_references( - scene.frames[frame_id], used_sensors - ) - - return scene - - -def _remove_unused_sensors(scene: format.Scene, used_sensors: t.Set[str]) -> format.Scene: - for sensor_id in list(scene.sensors): - if sensor_id not in used_sensors: - del scene.sensors[sensor_id] - - return scene - - -def _remove_unused_objects(scene: format.Scene, used_objects: t.Set[str]) -> format.Scene: - for object_id in list(scene.objects): - if object_id not in used_objects: - del scene.objects[object_id] - - return scene - - -def _remove_unused_sensor_references(frame: format.Frame, used_sensors: t.Set[str]) -> format.Frame: - for sensor_id in list(frame.sensors): - if sensor_id not in used_sensors: - del frame.sensors[sensor_id] - - return frame - - -# --- Helper functions - - -def _passes_filters(data, filters): - return all(f.passes_filter(data) for f in filters) - - -def _copy(object): - return pickle.loads(pickle.dumps(object, -1)) diff --git a/tests/test_raillabel/filter/test_filter.py b/tests/test_raillabel/filter/test_filter.py deleted file mode 100644 index 9259abf..0000000 --- a/tests/test_raillabel/filter/test_filter.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright DB InfraGO AG and contributors -# SPDX-License-Identifier: Apache-2.0 - -import os -import sys -from pathlib import Path - -import pytest - -sys.path.insert(1, str(Path(__file__).parent.parent.parent.parent)) - -import raillabel - - -def delete_sensor_from_data(data: dict, sensor_id: str) -> dict: - del data["openlabel"]["streams"][sensor_id] - del data["openlabel"]["coordinate_systems"][sensor_id] - del data["openlabel"]["coordinate_systems"]["base"]["children"][ - data["openlabel"]["coordinate_systems"]["base"]["children"].index(sensor_id) - ] - - for frame_id in data["openlabel"]["frames"]: - if sensor_id not in data["openlabel"]["frames"][frame_id]["frame_properties"]["streams"]: - continue - - del data["openlabel"]["frames"][frame_id]["frame_properties"]["streams"][sensor_id] - - return data - - -def test_filter_unexpected_kwarg(json_paths): - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"]) - - with pytest.raises(TypeError): - raillabel.filter(scene, unsupported_kwarg=[]) - - -def test_mutual_exclusivity(json_paths): - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"]) - - with pytest.raises(ValueError): - raillabel.filter(scene, include_frames=[0], exclude_frames=[1, 2]) - - -def test_filter_frames(json_paths, json_data): - data = json_data["openlabel_v1_short"] - - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"]) - - # Deletes the excluded data - del data["openlabel"]["frames"]["1"] - del data["openlabel"]["objects"]["6fe55546-0dd7-4e40-b6b4-bb7ea3445772"] - data = delete_sensor_from_data(data, "radar") - - # Loads the ground truth filtered data - scene_filtered_ground_truth = raillabel.Scene.fromdict(data) - - # Tests for include filter - scene_filtered = raillabel.filter(scene, include_frames=[0]) - assert scene_filtered == scene_filtered_ground_truth - - # Tests for exclude filter - scene_filtered = raillabel.filter(scene, exclude_frames=[1]) - assert scene_filtered == scene_filtered_ground_truth - - -def test_filter_start(json_paths, json_data): - data = json_data["openlabel_v1_short"] - - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"]) - - # Deletes the excluded data - del data["openlabel"]["frames"]["0"] - del data["openlabel"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"] - data = delete_sensor_from_data(data, "radar") - - # Loads the ground truth filtered data - scene_filtered_ground_truth = raillabel.Scene.fromdict(data) - - # Tests for frame filter - scene_filtered = raillabel.filter(scene, start_frame=1) - assert scene_filtered == scene_filtered_ground_truth - - # Tests for timestamp filter - scene_filtered = raillabel.filter(scene, start_timestamp="1632321743.134150") - assert scene_filtered == scene_filtered_ground_truth - - -def test_filter_end(json_paths, json_data): - data = json_data["openlabel_v1_short"] - - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"]) - - # Deletes the excluded data - del data["openlabel"]["frames"]["1"] - del data["openlabel"]["objects"]["6fe55546-0dd7-4e40-b6b4-bb7ea3445772"] - data = delete_sensor_from_data(data, "radar") - - # Loads the ground truth filtered data - scene_filtered_ground_truth = raillabel.Scene.fromdict(data) - - # Tests for frame filter - scene_filtered = raillabel.filter(scene, end_frame=0) - assert scene_filtered == scene_filtered_ground_truth - - # Tests for timestamp filter - scene_filtered = raillabel.filter(scene, end_timestamp="1632321743.233250") - assert scene_filtered == scene_filtered_ground_truth - - -def test_filter_object_ids(json_paths, json_data): - data = json_data["openlabel_v1_short"] - - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"]) - - # Deletes the excluded data - del data["openlabel"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - del data["openlabel"]["frames"]["0"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - del data["openlabel"]["frames"]["1"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - data = delete_sensor_from_data(data, "radar") - - # Loads the ground truth filtered data - scene_filtered_ground_truth = raillabel.Scene.fromdict(data) - - # Tests for include filter - scene_filtered = raillabel.filter( - scene, - include_object_ids=[ - "6fe55546-0dd7-4e40-b6b4-bb7ea3445772", - "b40ba3ad-0327-46ff-9c28-2506cfd6d934", - ], - ) - - assert scene_filtered == scene_filtered_ground_truth - - # Tests for exclude filter - scene_filtered = raillabel.filter( - scene, exclude_object_ids=["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - ) - assert scene_filtered == scene_filtered_ground_truth - - -def test_filter_object_types(json_paths, json_data): - data = json_data["openlabel_v1_short"] - - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"]) - - # Deletes the excluded data - del data["openlabel"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - del data["openlabel"]["frames"]["0"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - del data["openlabel"]["frames"]["1"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - data = delete_sensor_from_data(data, "radar") - - # Loads the ground truth filtered data - scene_filtered_ground_truth = raillabel.Scene.fromdict(data) - - # Tests for include filter - scene_filtered = raillabel.filter(scene, include_object_types=["person"]) - assert scene_filtered == scene_filtered_ground_truth - - # Tests for exclude filter - scene_filtered = raillabel.filter(scene, exclude_object_types=["train"]) - assert scene_filtered == scene_filtered_ground_truth - - -def test_filter_annotation_ids(json_paths, json_data): - data = json_data["openlabel_v1_short"] - - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"]) - - # Deletes the excluded data - del data["openlabel"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - del data["openlabel"]["frames"]["0"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - del data["openlabel"]["frames"]["1"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - data = delete_sensor_from_data(data, "radar") - - # Loads the ground truth filtered data - scene_filtered_ground_truth = raillabel.Scene.fromdict(data) - - # Tests for include filter - scene_filtered = raillabel.filter( - scene, - include_annotation_ids=[ - "78f0ad89-2750-4a30-9d66-44c9da73a714", - "68b4e02c-40c8-4de0-89ad-bc00ed05a043", - "bebfbae4-61a2-4758-993c-efa846b050a5", - "3f63201c-fb33-4487-aff6-ae0aa5fa976c", - "dc2be700-8ee4-45c4-9256-920b5d55c917", - "c1087f1d-7271-4dee-83ad-519a4e3b78a8", - "50be7fe3-1f43-47ca-b65a-930e6cfacfeb", - "6ba42cbc-484e-4b8d-a022-b23c2bb6643c", - "5f28fa18-8f2a-4a40-a0b6-c0bbedc00f2e", - "e2503c5d-9fe4-4666-b510-ef644c5a766b", - "450ceb81-9778-4e63-bf89-42f3ed9f6747", - ], - ) - - assert scene_filtered == scene_filtered_ground_truth - - # Tests for exclude filter - scene_filtered = raillabel.filter( - scene, - exclude_annotation_ids=[ - "14f58fb0-add7-4ed9-85b3-74615986d854", - "536ac83a-32c8-4fce-8499-ef32716c64a6", - "e53bd5e3-980a-4fa7-a0f9-5a2e59ba663c", - "550df2c3-0e66-483e-bcc6-f3013b7e581b", - "12b21c52-06ea-4269-9805-e7167e7a74ed", - ], - ) - assert scene_filtered == scene_filtered_ground_truth - - -def test_filter_annotation_types(json_paths, json_data): - data = json_data["openlabel_v1_short"] - - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"]) - - # Deletes the excluded data - del data["openlabel"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - - del data["openlabel"]["frames"]["0"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"][ - "object_data" - ]["cuboid"] - del data["openlabel"]["frames"]["0"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"][ - "object_data" - ]["vec"] - del data["openlabel"]["frames"]["0"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - - del data["openlabel"]["frames"]["1"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - data = delete_sensor_from_data(data, "radar") - - # Loads the ground truth filtered data - scene_filtered_ground_truth = raillabel.Scene.fromdict(data) - - # Tests for include filter - scene_filtered = raillabel.filter(scene, include_annotation_types=["bbox", "poly2d", "Num"]) - - assert scene_filtered == scene_filtered_ground_truth - - # Tests for exclude filter - scene_filtered = raillabel.filter(scene, exclude_annotation_types=["cuboid", "Poly3d", "seg3d"]) - assert scene_filtered == scene_filtered_ground_truth - - -def test_filter_sensors(json_paths, json_data): - data = json_data["openlabel_v1_short"] - - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"]) - - # Deletes the excluded data - del data["openlabel"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - - data = delete_sensor_from_data(data, "lidar") - data = delete_sensor_from_data(data, "radar") - - del data["openlabel"]["frames"]["0"]["frame_properties"]["frame_data"]["num"][-1] - del data["openlabel"]["frames"]["0"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"][ - "object_data" - ]["cuboid"] - del data["openlabel"]["frames"]["0"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"][ - "object_data" - ]["vec"] - del data["openlabel"]["frames"]["0"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - - del data["openlabel"]["frames"]["1"]["frame_properties"]["frame_data"]["num"][-1] - del data["openlabel"]["frames"]["1"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - - # Loads the ground truth filtered data - scene_filtered_ground_truth = raillabel.Scene.fromdict(data) - - # Tests for include filter - scene_filtered = raillabel.filter(scene, include_sensors=["rgb_middle", "ir_middle"]) - - assert scene_filtered == scene_filtered_ground_truth - - # Tests for exclude filter - scene_filtered = raillabel.filter(scene, exclude_sensors=["lidar"]) - assert scene_filtered == scene_filtered_ground_truth - - -def test_filter_include_attribute_ids(json_paths, json_data): - data = json_data["openlabel_v1_short"] - - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"]) - - # Deletes the excluded data - del data["openlabel"]["frames"]["0"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"] - del data["openlabel"]["frames"]["0"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"][ - "object_data" - ]["bbox"][1] - del data["openlabel"]["frames"]["0"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"][ - "object_data" - ]["poly2d"][1] - del data["openlabel"]["frames"]["0"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"][ - "object_data" - ]["cuboid"][1] - del data["openlabel"]["frames"]["0"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"][ - "object_data" - ]["vec"][1] - - del data["openlabel"]["frames"]["1"]["objects"]["6fe55546-0dd7-4e40-b6b4-bb7ea3445772"][ - "object_data" - ]["bbox"][1] - del data["openlabel"]["frames"]["1"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"][ - "object_data" - ]["cuboid"][1] - del data["openlabel"]["frames"]["1"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"][ - "object_data" - ]["vec"][1] - - data = delete_sensor_from_data(data, "ir_middle") - data = delete_sensor_from_data(data, "radar") - - # Loads the ground truth filtered data - scene_filtered_ground_truth = raillabel.Scene.fromdict(data) - - # Tests for include filter - scene_filtered = raillabel.filter(scene, include_attributes={"test_text_attr0": None}) - - assert scene_filtered == scene_filtered_ground_truth - - -def test_filter_exclude_attribute_ids(json_paths, json_data): - data = json_data["openlabel_v1_short"] - - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"]) - - # Deletes the excluded data - del data["openlabel"]["frames"]["0"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"][ - "object_data" - ]["bbox"][0] - del data["openlabel"]["frames"]["0"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"][ - "object_data" - ]["poly2d"][0] - del data["openlabel"]["frames"]["0"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"][ - "object_data" - ]["cuboid"][0] - del data["openlabel"]["frames"]["0"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"][ - "object_data" - ]["vec"][0] - - del data["openlabel"]["frames"]["1"]["objects"]["6fe55546-0dd7-4e40-b6b4-bb7ea3445772"][ - "object_data" - ]["bbox"][0] - del data["openlabel"]["frames"]["1"]["objects"]["6fe55546-0dd7-4e40-b6b4-bb7ea3445772"][ - "object_data" - ]["poly2d"][0] - del data["openlabel"]["frames"]["1"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"][ - "object_data" - ]["cuboid"][0] - del data["openlabel"]["frames"]["1"]["objects"]["22dedd49-6dcb-413b-87ef-00ccfb532e98"][ - "object_data" - ]["vec"][0] - data = delete_sensor_from_data(data, "radar") - - # Loads the ground truth filtered data - scene_filtered_ground_truth = raillabel.Scene.fromdict(data) - - # Tests for exclude filter - scene_filtered = raillabel.filter(scene, exclude_attributes={"test_text_attr0": None}) - assert scene_filtered == scene_filtered_ground_truth - - -def test_filter_exclude_attribute_values(json_paths, json_data): - data = json_data["openlabel_v1_short"] - - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"]) - - # Deletes the excluded data - del data["openlabel"]["frames"]["0"]["objects"]["b40ba3ad-0327-46ff-9c28-2506cfd6d934"][ - "object_data" - ]["poly2d"][0] - data = delete_sensor_from_data(data, "radar") - - # Loads the ground truth filtered data - scene_filtered_ground_truth = raillabel.Scene.fromdict(data) - - # Tests for exclude filter - scene_filtered = raillabel.filter(scene, exclude_attributes={"test_num_attr0": 2}) - assert scene_filtered == scene_filtered_ground_truth - - -# Executes the test if the file is called -if __name__ == "__main__": - os.system("clear") - pytest.main([__file__, "--disable-pytest-warnings", "--cache-clear"])