diff --git a/perception_eval/perception_eval/manager/perception_evaluation_manager.py b/perception_eval/perception_eval/manager/perception_evaluation_manager.py index 0c9dcd92..14cc76fe 100644 --- a/perception_eval/perception_eval/manager/perception_evaluation_manager.py +++ b/perception_eval/perception_eval/manager/perception_evaluation_manager.py @@ -79,8 +79,8 @@ def add_frame_result( unix_time: int, ground_truth_now_frame: FrameGroundTruth, estimated_objects: List[ObjectType], - critical_ground_truth_objects: Optional[List[ObjectType]] = None, frame_config: Optional[PerceptionFrameConfig] = None, + critical_ground_truth_objects: Optional[List[ObjectType]] = None, ) -> PerceptionFrameResult: """Get perception result at current frame. @@ -111,7 +111,7 @@ def add_frame_result( critical_ground_truth_objects = ground_truth_now_frame.objects.copy() if frame_config is None: - frame_config = PerceptionFrameConfig(self.config) + frame_config = PerceptionFrameConfig.from_eval_cfg(self.config) result = PerceptionFrameResult( unix_time=unix_time, diff --git a/perception_eval/perception_eval/result/perception/perception_frame_config.py b/perception_eval/perception_eval/result/perception/perception_frame_config.py index 5b5ea011..ad229281 100644 --- a/perception_eval/perception_eval/result/perception/perception_frame_config.py +++ b/perception_eval/perception_eval/result/perception/perception_frame_config.py @@ -14,16 +14,17 @@ from __future__ import annotations -from typing import Any -from typing import Dict from typing import List from typing import Optional from typing import TYPE_CHECKING +from typing import Union from perception_eval.common.label import set_target_lists from perception_eval.common.threshold import check_thresholds +from perception_eval.config.params import PerceptionFilterParam if TYPE_CHECKING: + from perception_eval.common.label import LabelType from perception_eval.config import PerceptionEvaluationConfig @@ -52,15 +53,15 @@ class PerceptionFrameConfig: def __init__( self, evaluator_config: PerceptionEvaluationConfig, - target_labels: List[str], - ignore_attributes: Optional[List[str]] = None, + target_labels: List[Union[str, LabelType]], max_x_position_list: Optional[List[float]] = None, max_y_position_list: Optional[List[float]] = None, - max_distance_list: Optional[List[float]] = None, min_distance_list: Optional[List[float]] = None, + max_distance_list: Optional[List[float]] = None, min_point_numbers: Optional[List[int]] = None, confidence_threshold_list: Optional[List[float]] = None, target_uuids: Optional[List[str]] = None, + ignore_attributes: Optional[List[str]] = None, thresholds: Optional[List[float]] = None, ) -> None: """ @@ -82,56 +83,42 @@ def __init__( target_uuids (Optional[List[str]]): The list of target uuid. Defaults to None. """ self.evaluation_task = evaluator_config.evaluation_task - - self.target_labels = set_target_lists(target_labels, evaluator_config.label_converter) - self.ignore_attributes = ignore_attributes - - num_elements: int = len(self.target_labels) - if max_x_position_list and max_y_position_list: - self.max_x_position_list: List[float] = check_thresholds(max_x_position_list, num_elements) - self.max_y_position_list: List[float] = check_thresholds(max_y_position_list, num_elements) - self.max_distance_list = None - self.min_distance_list = None - elif max_distance_list and min_distance_list: - self.max_distance_list: List[float] = check_thresholds(max_distance_list, num_elements) - self.min_distance_list: List[float] = check_thresholds(min_distance_list, num_elements) - self.max_x_position_list = None - self.max_y_position_list = None - elif evaluator_config.evaluation_task.is_2d(): - self.max_x_position_list = None - self.max_y_position_list = None - self.max_distance_list = None - self.min_distance_list = None + if all([isinstance(label, str) for label in target_labels]): + self.target_labels = set_target_lists(target_labels, evaluator_config.label_converter) else: - raise RuntimeError("Either max x/y position or max/min distance should be specified") - - if min_point_numbers is None: - self.min_point_numbers = None - else: - self.min_point_numbers: List[int] = check_thresholds(min_point_numbers, num_elements) - - if confidence_threshold_list is None: - self.confidence_threshold_list = None - else: - self.confidence_threshold_list: List[float] = check_thresholds(confidence_threshold_list, num_elements) - - self.target_uuids: Optional[List[str]] = target_uuids - - self.filtering_params: Dict[str, Any] = { - "target_labels": self.target_labels, - "ignore_attributes": self.ignore_attributes, - "max_x_position_list": self.max_x_position_list, - "max_y_position_list": self.max_y_position_list, - "max_distance_list": self.max_distance_list, - "min_distance_list": self.min_distance_list, - "min_point_numbers": self.min_point_numbers, - "confidence_threshold_list": self.confidence_threshold_list, - "target_uuids": self.target_uuids, - } + self.target_labels = target_labels + + self.filter_param = PerceptionFilterParam( + evaluation_task=self.evaluation_task, + target_labels=self.target_labels, + max_x_position_list=max_x_position_list, + max_y_position_list=max_y_position_list, + min_distance_list=min_distance_list, + max_distance_list=max_distance_list, + min_point_numbers=min_point_numbers, + confidence_threshold_list=confidence_threshold_list, + target_uuids=target_uuids, + ignore_attributes=ignore_attributes, + ) num_elements: int = len(self.target_labels) - if thresholds is None: self.thresholds = None else: self.thresholds = check_thresholds(thresholds, num_elements) + + @classmethod + def from_eval_cfg(cls, eval_cfg: PerceptionEvaluationConfig) -> PerceptionFrameConfig: + return cls( + evaluator_config=eval_cfg, + target_labels=eval_cfg.target_labels, + max_x_position_list=eval_cfg.filter_param.max_x_position_list, + max_y_position_list=eval_cfg.filter_param.max_y_position_list, + min_distance_list=eval_cfg.filter_param.min_distance_list, + max_distance_list=eval_cfg.filter_param.max_distance_list, + min_point_numbers=eval_cfg.filter_param.min_point_numbers, + confidence_threshold_list=eval_cfg.filter_param.confidence_threshold_list, + target_uuids=eval_cfg.filter_param.target_uuids, + ignore_attributes=eval_cfg.filter_param.ignore_attributes, + thresholds=eval_cfg.metrics_param.plane_distance_thresholds, + ) diff --git a/perception_eval/perception_eval/result/perception/perception_pass_fail_result.py b/perception_eval/perception_eval/result/perception/perception_pass_fail_result.py index f19bad5b..e05bd828 100644 --- a/perception_eval/perception_eval/result/perception/perception_pass_fail_result.py +++ b/perception_eval/perception_eval/result/perception/perception_pass_fail_result.py @@ -84,7 +84,7 @@ def evaluate( objects=critical_ground_truth_objects, is_gt=True, ego2map=self.ego2map, - **self.frame_config.filtering_params, + **self.frame_config.filter_param.as_dict(), ) self.tp_object_results, self.fp_object_results = self.__get_positive_object_results( object_results=object_results, diff --git a/perception_eval/test/perception_fp_validation_lsim.py b/perception_eval/test/perception_fp_validation_lsim.py index 781a5673..6bddbd46 100644 --- a/perception_eval/test/perception_fp_validation_lsim.py +++ b/perception_eval/test/perception_fp_validation_lsim.py @@ -82,8 +82,8 @@ def callback(self, unix_time: int, estimated_objects: List[ObjectType]) -> None: unix_time=unix_time, ground_truth_now_frame=ground_truth_now_frame, estimated_objects=estimated_objects, - critical_ground_truth_objects=critical_ground_truth_objects, frame_config=frame_config, + critical_ground_truth_objects=critical_ground_truth_objects, ) self.display(frame_result) diff --git a/perception_eval/test/perception_lsim.py b/perception_eval/test/perception_lsim.py index c59e9b52..63dcd874 100644 --- a/perception_eval/test/perception_lsim.py +++ b/perception_eval/test/perception_lsim.py @@ -123,8 +123,8 @@ def callback( unix_time=unix_time, ground_truth_now_frame=ground_truth_now_frame, estimated_objects=estimated_objects, - critical_ground_truth_objects=critical_ground_truth_objects, frame_config=frame_config, + critical_ground_truth_objects=critical_ground_truth_objects, ) self.visualize(frame_result) diff --git a/perception_eval/test/perception_lsim2d.py b/perception_eval/test/perception_lsim2d.py index ddeb5c2f..9398eb2c 100644 --- a/perception_eval/test/perception_lsim2d.py +++ b/perception_eval/test/perception_lsim2d.py @@ -129,8 +129,8 @@ def callback( unix_time=unix_time, ground_truth_now_frame=ground_truth_now_frame, estimated_objects=estimated_objects, - critical_ground_truth_objects=critical_ground_truth_objects, frame_config=frame_config, + critical_ground_truth_objects=critical_ground_truth_objects, ) self.visualize(frame_result)