Skip to content

Commit

Permalink
feat(perception): add prediction evaluation
Browse files Browse the repository at this point in the history
Signed-off-by: Hayato Mizushima <[email protected]>
  • Loading branch information
hayato-m126 committed Nov 5, 2024
1 parent f310093 commit 5547782
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
25 changes: 16 additions & 9 deletions driving_log_replayer_v2/scripts/perception_evaluator_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from autoware_perception_msgs.msg import DetectedObject
from autoware_perception_msgs.msg import DetectedObjects
from autoware_perception_msgs.msg import PredictedObject
from autoware_perception_msgs.msg import PredictedObjects
from autoware_perception_msgs.msg import Shape as MsgShape
from autoware_perception_msgs.msg import TrackedObject
from autoware_perception_msgs.msg import TrackedObjects
Expand Down Expand Up @@ -108,7 +110,7 @@ def __init__(
self.__evaluator = PerceptionEvaluationManager(evaluation_config=evaluation_config)
self.__sub_perception = self.create_subscription(
self.__msg_type,
"/perception/object_recognition/" + self.__topic_ns + "/objects",
"/perception/object_recognition/" + self.__topic_ns + "objects",
self.perception_cb,
1,
)
Expand All @@ -124,12 +126,17 @@ def check_evaluation_task(self) -> bool:
if self.__evaluation_task in ["detection", "fp_validation"]:
self.__frame_id_str = "base_link"
self.__msg_type = DetectedObjects
self.__topic_ns = "detection"
self.__topic_ns = "detection/"
return True
if self.__evaluation_task == "tracking":
self.__frame_id_str = "map"
self.__msg_type = TrackedObjects
self.__topic_ns = "tracking"
self.__topic_ns = "tracking/"
return True
if self.__evaluation_task == "prediction":
self.__frame_id_str = "map"
self.__msg_type = PredictedObjects
self.__topic_ns = "" # prediction name space is ""
return True
self.get_logger().error(f"Unexpected evaluation task: {self.__evaluation_task}")
return False
Expand Down Expand Up @@ -172,7 +179,7 @@ def write_metrics(self) -> None:
def list_dynamic_object_from_ros_msg(
self,
unix_time: int,
objects: list[DetectedObject] | list[TrackedObject],
objects: list[DetectedObject] | list[TrackedObject] | list[PredictedObject],
) -> list[DynamicObject] | str:
# return str(error_msg) when footprint points are invalid
estimated_objects: list[DynamicObject] = []
Expand All @@ -191,7 +198,7 @@ def list_dynamic_object_from_ros_msg(
)

uuid = None
if isinstance(perception_object, TrackedObject):
if isinstance(perception_object, TrackedObject | PredictedObject):
uuid = eval_conversions.uuid_from_ros_msg(perception_object.object_id.uuid)

shape_type = ShapeType.BOUNDING_BOX
Expand Down Expand Up @@ -230,12 +237,12 @@ def list_dynamic_object_from_ros_msg(
estimated_objects.append(estimated_object)
return estimated_objects

def perception_cb(self, msg: DetectedObjects | TrackedObjects) -> None:
def perception_cb(self, msg: DetectedObjects | TrackedObjects | PredictedObjects) -> None:
map_to_baselink = self.lookup_transform(msg.header.stamp)
# DetectedObjectとTrackedObjectで違う型ではあるが、estimated_objectを作る上で使用している項目は共通で保持しているので同じ関数で処理できる
# Although there are multiple msg types to receive, the items used to create the estimated_object are held in common, so they can be processed by the same function.
unix_time: int = eval_conversions.unix_time_from_ros_msg(msg.header)
# Tracking objectはtimestampがズレていることがあるのでGTの補間を行う
if isinstance(msg, TrackedObjects):
# Tracking object以降はtimestampがズレていることがあるのでGTの補間を行う
if isinstance(msg, TrackedObjects | PredictedObjects):
interpolation: bool = True
else:
interpolation = False
Expand Down
2 changes: 1 addition & 1 deletion sample/perception/scenario.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Evaluation:
Distance: 50.0- # [m] null [Do not filter by distance] or lower_limit-(upper_limit) [Upper limit can be omitted. If omitted value is 1.7976931348623157e+308]
PerceptionEvaluationConfig:
evaluation_config_dict:
evaluation_task: detection # detection or tracking. Evaluate the objects specified here
evaluation_task: detection # detection, tracking, or prediction. Evaluate the objects specified here
target_labels: [car, bicycle, pedestrian, motorbike] # evaluation label
max_x_position: 102.4 # Maximum x position of object to be evaluated
max_y_position: 102.4 # Maximum y position of object to be evaluated
Expand Down

0 comments on commit 5547782

Please sign in to comment.