From 6af7a92243c8f33fcae5ad0a44baf5ec9d2bfac7 Mon Sep 17 00:00:00 2001
From: ktro2828
Date: Fri, 29 Dec 2023 15:04:18 +0900
Subject: [PATCH] feat: add `MatchingPolicy`

Signed-off-by: ktro2828
---
 .../perception_eval/common/label/__init__.py  |   2 +
 .../perception_eval/common/label/utils.py     |  15 +++
 .../perception_eval/common/schema.py          |  17 +++
 .../config/perception_evaluation_config.py    |   3 +
 .../manager/perception_evaluation_manager.py  |   8 +-
 .../perception_eval/matching/__init__.py      |   2 +
 .../matching/matching_policy.py               | 122 ++++++++++++++++++
 .../matching/object_matching.py               |   5 +
 .../result/perception/perception_result.py    |  97 +++-----------
 .../test/perception_fp_validation_lsim.py     |   4 +-
 perception_eval/test/perception_lsim.py       |   4 +-
 11 files changed, 196 insertions(+), 83 deletions(-)
 create mode 100644 perception_eval/perception_eval/matching/matching_policy.py

diff --git a/perception_eval/perception_eval/common/label/__init__.py b/perception_eval/perception_eval/common/label/__init__.py
index 7c46f87a..eb9b2f54 100644
--- a/perception_eval/perception_eval/common/label/__init__.py
+++ b/perception_eval/perception_eval/common/label/__init__.py
@@ -4,6 +4,7 @@
 from .types import LabelType
 from .types import SemanticLabel
 from .types import TrafficLightLabel
+from .utils import is_same_label
 from .utils import set_target_lists
 
 __all__ = (
@@ -13,5 +14,6 @@
     "LabelType",
     "SemanticLabel",
     "TrafficLightLabel",
+    "is_same_label",
     "set_target_lists",
 )
diff --git a/perception_eval/perception_eval/common/label/utils.py b/perception_eval/perception_eval/common/label/utils.py
index 7f58d6ec..e03f9c1e 100644
--- a/perception_eval/perception_eval/common/label/utils.py
+++ b/perception_eval/perception_eval/common/label/utils.py
@@ -5,6 +5,8 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
+    from perception_eval.object import ObjectType
+
     from .converter import LabelConverter
     from .types import LabelType
 
@@ -32,3 +34,16 @@ def set_target_lists(
     if target_labels is None or len(target_labels) == 0:
         return [label for label in label_converter.label_type]
     return [label_converter.convert_name(name) for name in target_labels]
+
+
+def is_same_label(object1: ObjectType, object2: ObjectType) -> bool:
+    """Return `True` if both objects have the same label.
+
+    Args:
+        object1 (ObjectType): An object.
+        object2 (ObjectType): An object.
+
+    Returns:
+        bool: `True` if the labels are the same.
+    """
+    return object1.semantic_label == object2.semantic_label
diff --git a/perception_eval/perception_eval/common/schema.py b/perception_eval/perception_eval/common/schema.py
index d43c4fe6..c2c22980 100644
--- a/perception_eval/perception_eval/common/schema.py
+++ b/perception_eval/perception_eval/common/schema.py
@@ -17,10 +17,14 @@
 from enum import Enum
 import logging
 from typing import Dict
+from typing import TYPE_CHECKING
 from typing import Union
 
 from perception_eval.common.evaluation_task import EvaluationTask
 
+if TYPE_CHECKING:
+    from perception_eval.object import ObjectType
+
 
 class FrameID(Enum):
     # 3D
@@ -91,6 +95,19 @@ def from_task(cls, task: Union[str, EvaluationTask]) -> FrameID:
         raise ValueError(f"Unexpected task: {task}")
 
 
+def is_same_frame_id(object1: ObjectType, object2: ObjectType) -> bool:
+    """Return `True` if both objects have the same frame ID.
+
+    Args:
+        object1 (ObjectType): An object.
+        object2 (ObjectType): An object.
+
+    Returns:
+        bool: `True` if the frame IDs are the same.
+    """
+    return object1.frame_id == object2.frame_id
+
+
 class Visibility(Enum):
     """Visibility status class.
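For reference, a minimal sketch of how the two new helpers compose, assuming `estimation` and `ground_truth` are perception objects from the evaluation pipeline (e.g. `DynamicObject` instances):

    from perception_eval.common.label import is_same_label
    from perception_eval.common.schema import is_same_frame_id

    # A candidate pair is worth scoring only when the frame IDs agree and the
    # labels are compatible; MatchingPolicy.is_matchable (introduced below)
    # combines the frame-ID check with a policy-dependent label check.
    matchable = is_same_label(estimation, ground_truth) and is_same_frame_id(estimation, ground_truth)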
diff --git a/perception_eval/perception_eval/config/perception_evaluation_config.py b/perception_eval/perception_eval/config/perception_evaluation_config.py
index d6ea1688..7a15e7b5 100644
--- a/perception_eval/perception_eval/config/perception_evaluation_config.py
+++ b/perception_eval/perception_eval/config/perception_evaluation_config.py
@@ -19,6 +19,7 @@
 from typing import Tuple
 from typing import Union
 
+from perception_eval.matching import MatchingPolicy
 from perception_eval.metrics import MetricsScoreConfig
 
 from .evaluation_config_base import EvaluationConfigBase
@@ -84,6 +85,8 @@ def __init__(
             load_raw_data=load_raw_data,
         )
 
+        self.matching_policy = MatchingPolicy.from_dict(config_dict)
+
         self.metrics_config = MetricsScoreConfig(self.metrics_param)
 
     def _extract_params(self, cfg: Dict[str, Any]) -> Tuple[PerceptionFilterParam, PerceptionMetricsParam]:
diff --git a/perception_eval/perception_eval/manager/perception_evaluation_manager.py b/perception_eval/perception_eval/manager/perception_evaluation_manager.py
index 14cc76fe..e114d35f 100644
--- a/perception_eval/perception_eval/manager/perception_evaluation_manager.py
+++ b/perception_eval/perception_eval/manager/perception_evaluation_manager.py
@@ -35,6 +35,7 @@
     from perception_eval.common.label import LabelType
     from perception_eval.config import PerceptionEvaluationConfig
     from perception_eval.dataset import FrameGroundTruth
+    from perception_eval.matching import MatchingPolicy
     from perception_eval.object import ObjectType
     from perception_eval.result import DynamicObjectWithPerceptionResult
     from perception_eval.result import PerceptionFrameConfig
@@ -66,6 +67,10 @@ def __init__(self, config: PerceptionEvaluationConfig) -> None:
     def target_labels(self) -> List[LabelType]:
         return self.config.target_labels
 
+    @property
+    def matching_policy(self) -> MatchingPolicy:
+        return self.config.matching_policy
+
     @property
     def metrics_config(self):
         return self.config.metrics_config
@@ -170,8 +175,7 @@ def _filter_objects(
             estimated_objects=estimated_objects,
             ground_truth_objects=frame_ground_truth.objects,
             target_labels=self.target_labels,
-            # allow_matching_unknown=self.label_param["allow_matching_unknown"], TODO
-            # matchable_thresholds=self.filtering_params["max_matchable_radii"],
+            matching_policy=self.matching_policy,
         )
 
         if self.filter_param.target_uuids is not None:
diff --git a/perception_eval/perception_eval/matching/__init__.py b/perception_eval/perception_eval/matching/__init__.py
index 0044d8d6..6eec5a6b 100644
--- a/perception_eval/perception_eval/matching/__init__.py
+++ b/perception_eval/perception_eval/matching/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .matching_policy import MatchingPolicy
 from .object_matching import CenterDistanceMatching
 from .object_matching import IOU2dMatching
 from .object_matching import IOU3dMatching
@@ -20,6 +21,7 @@
 from .object_matching import PlaneDistanceMatching
 
 __all__ = (
+    "MatchingPolicy",
     "CenterDistanceMatching",
     "IOU2dMatching",
     "IOU3dMatching",
diff --git a/perception_eval/perception_eval/matching/matching_policy.py b/perception_eval/perception_eval/matching/matching_policy.py
new file mode 100644
index 00000000..1616912a
--- /dev/null
+++ b/perception_eval/perception_eval/matching/matching_policy.py
@@ -0,0 +1,122 @@
+from __future__ import annotations
+
+from enum import Enum
+from numbers import Number
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
+from perception_eval.common.label import is_same_label
+from perception_eval.common.schema import is_same_frame_id
+from perception_eval.common.threshold import get_label_threshold
+
+from .object_matching import CenterDistanceMatching
+from .object_matching import IOU2dMatching
+from .object_matching import IOU3dMatching
+from .object_matching import MatchingMode
+from .object_matching import PlaneDistanceMatching
+
+if TYPE_CHECKING:
+    from perception_eval.common.label import LabelType
+    from perception_eval.object import ObjectType
+
+    from .object_matching import MatchingMethod
+
+
+class MatchingLabelPolicy(Enum):
+    STRICT = "STRICT"
+    ALLOW_UNKNOWN = "ALLOW_UNKNOWN"
+    ALLOW_ANY = "ALLOW_ANY"
+
+    @classmethod
+    def from_str(cls, name: str) -> MatchingLabelPolicy:
+        name = name.upper()
+        assert name in cls.__members__, f"{name} is not an enum member"
+        return cls.__members__[name]
+
+    def is_matchable(self, estimation: ObjectType, ground_truth: ObjectType) -> bool:
+        if ground_truth.semantic_label.is_fp() or self == MatchingLabelPolicy.ALLOW_ANY:
+            return True
+        elif self == MatchingLabelPolicy.ALLOW_UNKNOWN:
+            return is_same_label(estimation, ground_truth) or estimation.semantic_label.is_unknown()
+        else:  # STRICT
+            return is_same_label(estimation, ground_truth)
+
+
+class MatchingPolicy:
+    def __init__(
+        self,
+        matching_mode: Optional[Union[str, MatchingMode]] = None,
+        label_policy: Optional[Union[str, MatchingLabelPolicy]] = None,
+        matchable_thresholds: Optional[List[Number]] = None,
+    ) -> None:
+        if matching_mode is None:
+            self.matching_mode = MatchingMode.CENTERDISTANCE
+        elif isinstance(matching_mode, str):
+            self.matching_mode = MatchingMode.from_str(matching_mode)
+        else:
+            self.matching_mode = matching_mode
+
+        self.matching_module, self.maximize = self.get_matching_module(self.matching_mode)
+
+        if label_policy is None:
+            self.label_policy = MatchingLabelPolicy.STRICT
+        elif isinstance(label_policy, str):
+            self.label_policy = MatchingLabelPolicy.from_str(label_policy)
+        else:
+            self.label_policy = label_policy
+
+        self.matchable_thresholds = matchable_thresholds
+
+    @classmethod
+    def from_dict(cls, cfg: Dict[str, Any]) -> MatchingPolicy:
+        matching_mode = cfg.get("matching_mode")
+        label_policy = cfg.get("matching_label_policy")
+        matchable_thresholds = cfg.get("matchable_thresholds")
+        return cls(matching_mode=matching_mode, label_policy=label_policy, matchable_thresholds=matchable_thresholds)
+
+    @staticmethod
+    def get_matching_module(matching_mode: MatchingMode) -> Tuple[Callable, bool]:
+        if matching_mode == MatchingMode.CENTERDISTANCE:
+            matching_method_module: CenterDistanceMatching = CenterDistanceMatching
+            maximize: bool = False
+        elif matching_mode == MatchingMode.PLANEDISTANCE:
+            matching_method_module: PlaneDistanceMatching = PlaneDistanceMatching
+            maximize: bool = False
+        elif matching_mode == MatchingMode.IOU2D:
+            matching_method_module: IOU2dMatching = IOU2dMatching
+            maximize: bool = True
+        elif matching_mode == MatchingMode.IOU3D:
+            matching_method_module: IOU3dMatching = IOU3dMatching
+            maximize: bool = True
+        else:
+            raise ValueError(f"Unsupported matching mode: {matching_mode}")
+
+        return matching_method_module, maximize
+
+    def is_matchable(self, estimation: ObjectType, ground_truth: ObjectType) -> bool:
+        return self.label_policy.is_matchable(estimation, ground_truth) and is_same_frame_id(estimation, ground_truth)
+
+    def get_matching_score(
+        self,
+        estimation: ObjectType,
+        ground_truth: ObjectType,
+        target_labels: List[LabelType],
+    ) -> Optional[float]:
+        threshold: Optional[float] = get_label_threshold(
+            ground_truth.semantic_label,
+            target_labels,
+            self.matchable_thresholds,
+        )
+
+        matching_method: MatchingMethod = self.matching_module(estimation, ground_truth)
+
+        if threshold is None or matching_method.is_better_than(threshold):
+            return matching_method.value
+        else:
+            return None
diff --git a/perception_eval/perception_eval/matching/object_matching.py b/perception_eval/perception_eval/matching/object_matching.py
index fa3f7a7d..fd274473 100644
--- a/perception_eval/perception_eval/matching/object_matching.py
+++ b/perception_eval/perception_eval/matching/object_matching.py
@@ -53,6 +53,11 @@ class MatchingMode(Enum):
     def __str__(self) -> str:
         return self.value
 
+    @classmethod
+    def from_str(cls, name: str) -> MatchingMode:
+        assert name in cls.__members__, f"{name} is not an enum member"
+        return cls.__members__[name]
+
 
 class MatchingMethod(ABC):
     """A base class for matching method class.
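For reference, `MatchingLabelPolicy` replaces the old boolean `allow_matching_unknown` flag with three modes; a short sketch of their semantics, relying on the `SemanticLabel.is_fp()` / `is_unknown()` predicates already present in the codebase:

    from perception_eval.matching.matching_policy import MatchingLabelPolicy

    # Parsing is case-insensitive: "strict" -> STRICT.
    strict = MatchingLabelPolicy.from_str("strict")                # labels must match exactly
    allow_unknown = MatchingLabelPolicy.from_str("allow_unknown")  # also accepts estimations labeled unknown
    allow_any = MatchingLabelPolicy.from_str("allow_any")          # any label pair may match

    # In every mode, a ground truth labeled as FP is always matchable, so FP
    # validation (see perception_fp_validation_lsim.py below) works under STRICT.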
diff --git a/perception_eval/perception_eval/result/perception/perception_result.py b/perception_eval/perception_eval/result/perception/perception_result.py
index 9e1440e5..e4790b4a 100644
--- a/perception_eval/perception_eval/result/perception/perception_result.py
+++ b/perception_eval/perception_eval/result/perception/perception_result.py
@@ -14,16 +14,16 @@
 
 from __future__ import annotations
 
-from typing import Callable
 from typing import List
 from typing import Optional
 from typing import Tuple
+from typing import TYPE_CHECKING
 
 import numpy as np
 from perception_eval.common.evaluation_task import EvaluationTask
-from perception_eval.common.label import LabelType
+from perception_eval.common.label import is_same_label
+from perception_eval.common.schema import is_same_frame_id
 from perception_eval.common.status import MatchingStatus
-from perception_eval.common.threshold import get_label_threshold
 from perception_eval.matching import CenterDistanceMatching
 from perception_eval.matching import IOU2dMatching
 from perception_eval.matching import IOU3dMatching
@@ -34,7 +34,11 @@
 from perception_eval.object import distance_objects_bev
 from perception_eval.object import DynamicObject
 from perception_eval.object import DynamicObject2D
-from perception_eval.object import ObjectType
+
+if TYPE_CHECKING:
+    from perception_eval.common.label import LabelType
+    from perception_eval.matching import MatchingPolicy
+    from perception_eval.object import ObjectType
 
 
 class DynamicObjectWithPerceptionResult:
@@ -258,19 +262,14 @@ def is_label_correct(self) -> bool:
         Returns:
             bool: Whether label is correct
         """
-        if self.ground_truth_object:
-            return self.estimated_object.semantic_label == self.ground_truth_object.semantic_label
-        else:
-            return False
+        return is_same_label(self.estimated_object, self.ground_truth_object) if self.ground_truth_object else False
 
 
 def get_object_results(
     evaluation_task: EvaluationTask,
     estimated_objects: List[ObjectType],
     ground_truth_objects: List[ObjectType],
     target_labels: Optional[List[LabelType]] = None,
-    allow_matching_unknown: bool = True,
-    matching_mode: MatchingMode = MatchingMode.CENTERDISTANCE,
-    matchable_thresholds: Optional[List[float]] = None,
+    matching_policy: MatchingPolicy = MatchingPolicy(),
 ) -> List[DynamicObjectWithPerceptionResult]:
     """Returns list of DynamicObjectWithPerceptionResult.
@@ -306,15 +305,7 @@ def get_object_results(
     if evaluation_task == EvaluationTask.CLASSIFICATION2D:
         return _get_object_results_with_id(estimated_objects, ground_truth_objects)
 
-    matching_method_module, maximize = _get_matching_module(matching_mode)
-    score_table = _get_score_table(
-        estimated_objects,
-        ground_truth_objects,
-        allow_matching_unknown,
-        matching_method_module,
-        target_labels,
-        matchable_thresholds,
-    )
+    score_table = _get_score_table(estimated_objects, ground_truth_objects, target_labels, matching_policy)
 
     # assign correspond GT to estimated objects
     object_results: List[DynamicObjectWithPerceptionResult] = []
@@ -327,7 +318,7 @@ def get_object_results(
 
         est_idx, gt_idx = (
             np.unravel_index(np.nanargmax(score_table), score_table.shape)
-            if maximize
+            if matching_policy.maximize
             else np.unravel_index(np.nanargmin(score_table), score_table.shape)
         )
 
@@ -377,8 +368,8 @@ def _get_object_results_with_id(
         )
         if (
             est_object.uuid == gt_object.uuid
-            and est_object.semantic_label == gt_object.semantic_label
-            and est_object.frame_id == gt_object.frame_id
+            and is_same_label(est_object, gt_object)
+            and is_same_frame_id(est_object, gt_object)
         ):
             object_results.append(
                 DynamicObjectWithPerceptionResult(
@@ -418,41 +409,11 @@ def _get_fp_object_results(
     return object_results
 
 
-def _get_matching_module(matching_mode: MatchingMode) -> Tuple[Callable, bool]:
-    """Returns the matching function and boolean flag whether choose maximum value or not.
-
-    Args:
-        matching_mode (MatchingMode): MatchingMode instance.
-
-    Returns:
-        matching_method_module (Callable): MatchingMethod instance.
-        maximize (bool): Whether much bigger is better.
-    """
-    if matching_mode == MatchingMode.CENTERDISTANCE:
-        matching_method_module: CenterDistanceMatching = CenterDistanceMatching
-        maximize: bool = False
-    elif matching_mode == MatchingMode.PLANEDISTANCE:
-        matching_method_module: PlaneDistanceMatching = PlaneDistanceMatching
-        maximize: bool = False
-    elif matching_mode == MatchingMode.IOU2D:
-        matching_method_module: IOU2dMatching = IOU2dMatching
-        maximize: bool = True
-    elif matching_mode == MatchingMode.IOU3D:
-        matching_method_module: IOU3dMatching = IOU3dMatching
-        maximize: bool = True
-    else:
-        raise ValueError(f"Unsupported matching mode: {matching_mode}")
-
-    return matching_method_module, maximize
-
-
 def _get_score_table(
     estimated_objects: List[ObjectType],
     ground_truth_objects: List[ObjectType],
-    allow_matching_unknown: bool,
-    matching_method_module: Callable,
-    target_labels: Optional[List[LabelType]],
-    matchable_thresholds: Optional[List[float]],
+    target_labels: List[LabelType],
+    matching_policy: MatchingPolicy,
 ) -> np.ndarray:
     """Returns score table, in shape (num_estimation, num_ground_truth).
@@ -471,27 +432,9 @@ def _get_score_table(
     num_row: int = len(estimated_objects)
     num_col: int = len(ground_truth_objects)
     score_table = np.full((num_row, num_col), np.nan)
-    for i, est_obj in enumerate(estimated_objects):
-        for j, gt_obj in enumerate(ground_truth_objects):
-            if gt_obj.semantic_label.is_fp():
-                is_label_ok = True
-            elif allow_matching_unknown:
-                is_label_ok = est_obj.semantic_label == gt_obj.semantic_label or est_obj.semantic_label.is_unknown()
-            else:
-                is_label_ok = est_obj.semantic_label == gt_obj.semantic_label
-
-            is_same_frame_id = est_obj.frame_id == gt_obj.frame_id
-
-            if is_label_ok and is_same_frame_id:
-                threshold: Optional[float] = get_label_threshold(
-                    gt_obj.semantic_label, target_labels, matchable_thresholds
-                )
-
-                matching_method: MatchingMethod = matching_method_module(
-                    estimated_object=est_obj, ground_truth_object=gt_obj
-                )
-
-                if threshold is None or (threshold is not None and matching_method.is_better_than(threshold)):
-                    score_table[i, j] = matching_method.value
+    for i, estimation in enumerate(estimated_objects):
+        for j, ground_truth in enumerate(ground_truth_objects):
+            if matching_policy.is_matchable(estimation, ground_truth):
+                score_table[i, j] = matching_policy.get_matching_score(estimation, ground_truth, target_labels)
 
     return score_table
diff --git a/perception_eval/test/perception_fp_validation_lsim.py b/perception_eval/test/perception_fp_validation_lsim.py
index 1e0a7bd8..817be1f3 100644
--- a/perception_eval/test/perception_fp_validation_lsim.py
+++ b/perception_eval/test/perception_fp_validation_lsim.py
@@ -41,10 +41,10 @@ def __init__(self, dataset_paths: List[int], result_root_directory: str) -> None
             "target_labels": ["car", "bicycle", "pedestrian", "motorbike"],
             "max_x_position": 102.4,
             "max_y_position": 102.4,
-            "max_matchable_radii": [5.0, 3.0, 3.0, 3.0],
+            "matching_label_policy": "strict",
+            "matchable_thresholds": [5.0, 3.0, 3.0, 3.0],
             "merge_similar_labels": False,
             "label_prefix": "autoware",
-            "allow_matching_unknown": True,
         }
 
         evaluation_config = PerceptionEvaluationConfig(
diff --git a/perception_eval/test/perception_lsim.py b/perception_eval/test/perception_lsim.py
index 2a987567..43d5e55d 100644
--- a/perception_eval/test/perception_lsim.py
+++ b/perception_eval/test/perception_lsim.py
@@ -70,11 +70,11 @@ def __init__(
             "iou_2d_thresholds": [0.5, 0.5, 0.5, 0.5],  # = [[0.5, 0.5, 0.5, 0.5]]
             "iou_3d_thresholds": [0.5],  # = [[0.5, 0.5, 0.5, 0.5]]
             "min_point_numbers": [0, 0, 0, 0],
-            "max_matchable_radii": 5.0,  # = [5.0, 5.0, 5.0, 5.0]
+            "matching_label_policy": "strict",
+            "matchable_thresholds": [5.0, 5.0, 5.0, 5.0],
             # label parameters
             "label_prefix": "autoware",
             "merge_similar_labels": False,
-            "allow_matching_unknown": True,
         }
 
         evaluation_config = PerceptionEvaluationConfig(
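As a usage note, the renamed config keys above flow into `MatchingPolicy.from_dict` via `PerceptionEvaluationConfig`; a minimal sketch of building a policy directly, with threshold values borrowed from the test config:

    from perception_eval.matching import MatchingPolicy

    policy = MatchingPolicy.from_dict(
        {
            "matching_mode": "CENTERDISTANCE",  # optional; center distance is also the default
            "matching_label_policy": "strict",  # or "allow_unknown" / "allow_any"
            "matchable_thresholds": [5.0, 3.0, 3.0, 3.0],
        }
    )
    assert policy.maximize is False  # for center distance, smaller scores are better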