Skip to content

Commit

Permalink
feat: update PerceptionFrameConfig
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 committed Dec 28, 2023
1 parent a72959c commit 903601c
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def add_frame_result(
unix_time: int,
ground_truth_now_frame: FrameGroundTruth,
estimated_objects: List[ObjectType],
critical_ground_truth_objects: Optional[List[ObjectType]] = None,
frame_config: Optional[PerceptionFrameConfig] = None,
critical_ground_truth_objects: Optional[List[ObjectType]] = None,
) -> PerceptionFrameResult:
"""Get perception result at current frame.
Expand Down Expand Up @@ -111,7 +111,7 @@ def add_frame_result(
critical_ground_truth_objects = ground_truth_now_frame.objects.copy()

if frame_config is None:
frame_config = PerceptionFrameConfig(self.config)
frame_config = PerceptionFrameConfig.from_eval_cfg(self.config)

result = PerceptionFrameResult(
unix_time=unix_time,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@

from __future__ import annotations

from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

from perception_eval.common.label import set_target_lists
from perception_eval.common.threshold import check_thresholds
from perception_eval.config.params import PerceptionFilterParam

if TYPE_CHECKING:
from perception_eval.common.label import LabelType
from perception_eval.config import PerceptionEvaluationConfig


Expand Down Expand Up @@ -52,15 +53,15 @@ class PerceptionFrameConfig:
def __init__(
self,
evaluator_config: PerceptionEvaluationConfig,
target_labels: List[str],
ignore_attributes: Optional[List[str]] = None,
target_labels: List[Union[str, LabelType]],
max_x_position_list: Optional[List[float]] = None,
max_y_position_list: Optional[List[float]] = None,
max_distance_list: Optional[List[float]] = None,
min_distance_list: Optional[List[float]] = None,
max_distance_list: Optional[List[float]] = None,
min_point_numbers: Optional[List[int]] = None,
confidence_threshold_list: Optional[List[float]] = None,
target_uuids: Optional[List[str]] = None,
ignore_attributes: Optional[List[str]] = None,
thresholds: Optional[List[float]] = None,
) -> None:
"""
Expand All @@ -82,56 +83,42 @@ def __init__(
target_uuids (Optional[List[str]]): The list of target uuid. Defaults to None.
"""
self.evaluation_task = evaluator_config.evaluation_task

self.target_labels = set_target_lists(target_labels, evaluator_config.label_converter)
self.ignore_attributes = ignore_attributes

num_elements: int = len(self.target_labels)
if max_x_position_list and max_y_position_list:
self.max_x_position_list: List[float] = check_thresholds(max_x_position_list, num_elements)
self.max_y_position_list: List[float] = check_thresholds(max_y_position_list, num_elements)
self.max_distance_list = None
self.min_distance_list = None
elif max_distance_list and min_distance_list:
self.max_distance_list: List[float] = check_thresholds(max_distance_list, num_elements)
self.min_distance_list: List[float] = check_thresholds(min_distance_list, num_elements)
self.max_x_position_list = None
self.max_y_position_list = None
elif evaluator_config.evaluation_task.is_2d():
self.max_x_position_list = None
self.max_y_position_list = None
self.max_distance_list = None
self.min_distance_list = None
if all([isinstance(label, str) for label in target_labels]):
self.target_labels = set_target_lists(target_labels, evaluator_config.label_converter)
else:
raise RuntimeError("Either max x/y position or max/min distance should be specified")

if min_point_numbers is None:
self.min_point_numbers = None
else:
self.min_point_numbers: List[int] = check_thresholds(min_point_numbers, num_elements)

if confidence_threshold_list is None:
self.confidence_threshold_list = None
else:
self.confidence_threshold_list: List[float] = check_thresholds(confidence_threshold_list, num_elements)

self.target_uuids: Optional[List[str]] = target_uuids

self.filtering_params: Dict[str, Any] = {
"target_labels": self.target_labels,
"ignore_attributes": self.ignore_attributes,
"max_x_position_list": self.max_x_position_list,
"max_y_position_list": self.max_y_position_list,
"max_distance_list": self.max_distance_list,
"min_distance_list": self.min_distance_list,
"min_point_numbers": self.min_point_numbers,
"confidence_threshold_list": self.confidence_threshold_list,
"target_uuids": self.target_uuids,
}
self.target_labels = target_labels

self.filter_param = PerceptionFilterParam(
evaluation_task=self.evaluation_task,
target_labels=self.target_labels,
max_x_position_list=max_x_position_list,
max_y_position_list=max_y_position_list,
min_distance_list=min_distance_list,
max_distance_list=max_distance_list,
min_point_numbers=min_point_numbers,
confidence_threshold_list=confidence_threshold_list,
target_uuids=target_uuids,
ignore_attributes=ignore_attributes,
)

num_elements: int = len(self.target_labels)

if thresholds is None:
self.thresholds = None
else:
self.thresholds = check_thresholds(thresholds, num_elements)

@classmethod
def from_eval_cfg(cls, eval_cfg: PerceptionEvaluationConfig) -> PerceptionFrameConfig:
return cls(
evaluator_config=eval_cfg,
target_labels=eval_cfg.target_labels,
max_x_position_list=eval_cfg.filter_param.max_x_position_list,
max_y_position_list=eval_cfg.filter_param.max_y_position_list,
min_distance_list=eval_cfg.filter_param.min_distance_list,
max_distance_list=eval_cfg.filter_param.max_distance_list,
min_point_numbers=eval_cfg.filter_param.min_point_numbers,
confidence_threshold_list=eval_cfg.filter_param.confidence_threshold_list,
target_uuids=eval_cfg.filter_param.target_uuids,
ignore_attributes=eval_cfg.filter_param.ignore_attributes,
thresholds=eval_cfg.metrics_param.plane_distance_thresholds,
)
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def evaluate(
objects=critical_ground_truth_objects,
is_gt=True,
ego2map=self.ego2map,
**self.frame_config.filtering_params,
**self.frame_config.filter_param.as_dict(),
)
self.tp_object_results, self.fp_object_results = self.__get_positive_object_results(
object_results=object_results,
Expand Down
2 changes: 1 addition & 1 deletion perception_eval/test/perception_fp_validation_lsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def callback(self, unix_time: int, estimated_objects: List[ObjectType]) -> None:
unix_time=unix_time,
ground_truth_now_frame=ground_truth_now_frame,
estimated_objects=estimated_objects,
critical_ground_truth_objects=critical_ground_truth_objects,
frame_config=frame_config,
critical_ground_truth_objects=critical_ground_truth_objects,
)
self.display(frame_result)

Expand Down
2 changes: 1 addition & 1 deletion perception_eval/test/perception_lsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def callback(
unix_time=unix_time,
ground_truth_now_frame=ground_truth_now_frame,
estimated_objects=estimated_objects,
critical_ground_truth_objects=critical_ground_truth_objects,
frame_config=frame_config,
critical_ground_truth_objects=critical_ground_truth_objects,
)
self.visualize(frame_result)

Expand Down
2 changes: 1 addition & 1 deletion perception_eval/test/perception_lsim2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def callback(
unix_time=unix_time,
ground_truth_now_frame=ground_truth_now_frame,
estimated_objects=estimated_objects,
critical_ground_truth_objects=critical_ground_truth_objects,
frame_config=frame_config,
critical_ground_truth_objects=critical_ground_truth_objects,
)
self.visualize(frame_result)

Expand Down

0 comments on commit 903601c

Please sign in to comment.