Skip to content

Commit

Permalink
feat: update tlr new interface (#291)
Browse files Browse the repository at this point in the history
  • Loading branch information
hayato-m126 authored May 24, 2024
2 parents 7c04a66 + 209a998 commit 7609f1a
Show file tree
Hide file tree
Showing 15 changed files with 406 additions and 119 deletions.
2 changes: 2 additions & 0 deletions .driving_log_replayer.cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
"Minoda",
"pyproject",
"CENTERDISTANCE",
"rightdiagonal",
"leftdiagonal",
"nums",
"pydantic",
"Kotaro",
Expand Down
2 changes: 2 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"${env:HOME}/ros_ws/awf/install/tier4_localization_msgs/local/lib/python3.10/dist-packages",
"${env:HOME}/ros_ws/awf/install/tier4_api_msgs/local/lib/python3.10/dist-packages",
"${env:HOME}/ros_ws/awf/install/tier4_perception_msgs/local/lib/python3.10/dist-packages",
"${env:HOME}/ros_ws/awf/install/autoware_perception_msgs/local/lib/python3.10/dist-packages",
"${env:HOME}/ros_ws/awf/install/lanelet2_extension_python/local/lib/python3.10/dist-packages"
],
"python.analysis.extraPaths": [
Expand All @@ -56,6 +57,7 @@
"${env:HOME}/ros_ws/awf/install/tier4_localization_msgs/local/lib/python3.10/dist-packages",
"${env:HOME}/ros_ws/awf/install/tier4_api_msgs/local/lib/python3.10/dist-packages",
"${env:HOME}/ros_ws/awf/install/tier4_perception_msgs/local/lib/python3.10/dist-packages",
"${env:HOME}/ros_ws/awf/install/autoware_perception_msgs/local/lib/python3.10/dist-packages",
"${env:HOME}/ros_ws/awf/install/lanelet2_extension_python/local/lib/python3.10/dist-packages"
],
"files.associations": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def from_str(cls, value: str) -> CriteriaMethod:
"""
name: str = value.upper()
assert name in cls.__members__, "value must be NUM_TP, METRICS_SCORE, or METRICS_SCORE_MAPH"
assert (
name in cls.__members__
), "value must be NUM_TP, LABEL, METRICS_SCORE, or METRICS_SCORE_MAPH"
return cls.__members__[name]


Expand Down
35 changes: 4 additions & 31 deletions driving_log_replayer/driving_log_replayer/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from autoware_adapi_v1_msgs.srv import InitializeLocalization
from autoware_auto_perception_msgs.msg import ObjectClassification
from autoware_auto_perception_msgs.msg import TrafficLight
from builtin_interfaces.msg import Time as Stamp
from geometry_msgs.msg import Point
from geometry_msgs.msg import Pose
Expand Down Expand Up @@ -342,36 +341,10 @@ def get_most_probable_classification(
cls,
array_classification: list[ObjectClassification],
) -> ObjectClassification:
highest_probability = 0.0
highest_classification = None
for classification in array_classification:
if classification.probability >= highest_probability:
highest_probability = classification.probability
highest_classification = classification
return highest_classification

@classmethod
def get_traffic_light_label_str(cls, light: TrafficLight) -> str:
if light.color == TrafficLight.RED:
return "red"
if light.color == TrafficLight.AMBER:
return "yellow"
if light.color == TrafficLight.GREEN:
return "green"
return "unknown"

@classmethod
def get_most_probable_signal(
cls,
lights: list[TrafficLight],
) -> TrafficLight:
highest_probability = 0.0
highest_light = None
for light in lights:
if light.confidence >= highest_probability:
highest_probability = light.confidence
highest_light = light
return highest_light
index: int = array_classification.index(
max(array_classification, key=lambda x: x.probability),
)
return array_classification[index]


def evaluator_main(func: Callable) -> Callable:
Expand Down
19 changes: 15 additions & 4 deletions driving_log_replayer/driving_log_replayer/lanelet2_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,23 @@
from shapely.geometry import Polygon


def road_lanelets_from_file(map_path: str) -> Any:
def load_map(map_path: str) -> lanelet2.core.LaneletMap:
projection = MGRSProjector(lanelet2.io.Origin(0.0, 0.0))
lanelet_map = lanelet2.io.load(map_path, projection)
all_lanelets = query.laneletLayer(lanelet_map)
return lanelet2.io.load(map_path, projection)


def load_all_lanelets(map_path: str) -> Any:
lanelet_map = load_map(map_path)
return query.laneletLayer(lanelet_map)


def road_lanelets_from_file(map_path: str) -> Any:
# return type lanelet2_extension_python._lanelet2_extension_python_boost_python_utility.lanelet::ConstLanelets
return query.roadLanelets(all_lanelets)
return query.roadLanelets(load_all_lanelets(map_path))


def traffic_light_from_file(map_path: str) -> list:
return query.trafficLights(load_all_lanelets(map_path))


def to_shapely_polygon(lanelet: Lanelet) -> Polygon:
Expand Down
5 changes: 1 addition & 4 deletions driving_log_replayer/driving_log_replayer/perception.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,7 @@ def __post_init__(self) -> None:
distance_range=self.condition.Filter.Distance,
)

def set_frame(
self,
frame: PerceptionFrameResult,
) -> dict:
def set_frame(self, frame: PerceptionFrameResult) -> dict:
frame_success = "Fail"
# ret_frame might be filtered frame result or original frame result.
result, ret_frame = self.criteria.get_result(frame)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from builtin_interfaces.msg import Time
from geometry_msgs.msg import Point
from geometry_msgs.msg import Polygon as RosPolygon
from geometry_msgs.msg import Pose
Expand All @@ -36,6 +37,10 @@ def unix_time_from_ros_msg(ros_header: Header) -> int:
return ros_header.stamp.sec * pow(10, 6) + ros_header.stamp.nanosec // 1000


def unix_time_from_ros_timestamp(ros_timestamp: Time) -> int:
return ros_timestamp.sec * pow(10, 6) + ros_timestamp.nanosec // 1000


def position_from_ros_msg(ros_position: Point) -> tuple[int, int, int]:
return (ros_position.x, ros_position.y, ros_position.z)

Expand Down
181 changes: 162 additions & 19 deletions driving_log_replayer/driving_log_replayer/traffic_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,136 @@
# limitations under the License.

from dataclasses import dataclass
import logging
from pathlib import Path
import sys
from typing import Literal

from autoware_perception_msgs.msg import TrafficSignalElement
from perception_eval.evaluation import PerceptionFrameResult
from pydantic import BaseModel
from pydantic import field_validator
import simplejson as json

from driving_log_replayer.criteria import PerceptionCriteria
from driving_log_replayer.perception_eval_conversions import summarize_pass_fail_result
from driving_log_replayer.result import EvaluationItem
from driving_log_replayer.result import ResultBase
from driving_log_replayer.scenario import number
from driving_log_replayer.scenario import Scenario

TRAFFIC_LIGHT_LABEL_MAPPINGS: list[tuple[set, str]] = [
({"green"}, "green"),
({"green", "straight"}, "green_straight"),
({"green", "left"}, "green_left"),
({"green", "right"}, "green_right"),
({"yellow"}, "yellow"),
({"yellow", "straight"}, "yellow_straight"),
({"yellow", "left"}, "yellow_left"),
({"yellow", "right"}, "yellow_right"),
({"yellow", "straight", "left"}, "yellow_straight_left"),
({"yellow", "straight", "right"}, "yellow_straight_right"),
({"red"}, "red"),
({"red", "straight"}, "red_straight"),
({"red", "left"}, "red_left"),
({"red", "right"}, "red_right"),
({"red", "straight", "left"}, "red_straight_left"),
({"red", "straight", "right"}, "red_straight_right"),
({"red", "straight", "left", "right"}, "red_straight_left_right"),
({"red", "right", "diagonal"}, "red_rightdiagonal"),
({"red", "left", "diagonal"}, "red_leftdiagonal"),
]

class Conditions(BaseModel):

def get_traffic_light_label_str(elements: list[TrafficSignalElement]) -> str: # noqa
label_infos = []
for element in elements:
if element.shape == TrafficSignalElement.CIRCLE:
if element.color == TrafficSignalElement.RED:
label_infos.append("red")
elif element.color == TrafficSignalElement.AMBER:
label_infos.append("yellow")
elif element.color == TrafficSignalElement.GREEN:
label_infos.append("green")
continue

if element.shape == TrafficSignalElement.UP_ARROW:
label_infos.append("straight")
elif element.shape == TrafficSignalElement.LEFT_ARROW:
label_infos.append("left")
elif element.shape == TrafficSignalElement.RIGHT_ARROW:
label_infos.append("right")
elif element.shape in (
TrafficSignalElement.UP_LEFT_ARROW,
TrafficSignalElement.DOWN_LEFT_ARROW,
):
label_infos.append("left")
label_infos.append("diagonal")
elif element.shape in (
TrafficSignalElement.UP_RIGHT_ARROW,
TrafficSignalElement.DOWN_RIGHT_ARROW,
):
label_infos.append("right")
label_infos.append("diagonal")

label_infos = set(label_infos)

for info_set, label in TRAFFIC_LIGHT_LABEL_MAPPINGS:
if label_infos == info_set:
return label

return "unknown"


def get_most_probable_element(
elements: list[TrafficSignalElement],
) -> TrafficSignalElement:
index: int = elements.index(max(elements, key=lambda x: x.confidence))
return elements[index]


class Filter(BaseModel):
Distance: tuple[float, float] | None = None
# add filter condition here

@field_validator("Distance", mode="before")
@classmethod
def validate_distance_range(cls, v: str | None) -> tuple[number, number] | None:
if v is None:
return None

err_msg = f"{v} is not valid distance range, expected ordering min-max with min < max."

s_lower, s_upper = v.split("-")
if s_upper == "":
s_upper = sys.float_info.max

lower = float(s_lower)
upper = float(s_upper)

if lower >= upper:
raise ValueError(err_msg)
return (lower, upper)


class Criteria(BaseModel):
PassRate: number
CriteriaMethod: Literal["num_tp", "metrics_score"] | None = None
CriteriaLevel: Literal["perfect", "hard", "normal", "easy"] | number | None = None
CriteriaMethod: (
Literal["num_tp", "label", "metrics_score", "metrics_score_maph"] | list[str] | None
) = None
CriteriaLevel: (
Literal["perfect", "hard", "normal", "easy"] | list[str] | number | list[number] | None
) = None
Filter: Filter


class Conditions(BaseModel):
Criterion: list[Criteria]


class Evaluation(BaseModel):
UseCaseName: Literal["traffic_light"]
UseCaseFormatVersion: Literal["0.2.0", "0.3.0"]
UseCaseFormatVersion: Literal["1.0.0"]
Datasets: list[dict]
Conditions: Conditions
PerceptionEvaluationConfig: dict
Expand All @@ -45,6 +154,36 @@ class TrafficLightScenario(Scenario):
Evaluation: Evaluation


class FailResultHolder:
def __init__(self, save_dir: str) -> None:
self.save_path: str = Path(save_dir, "fail_info.json")
self.buffer = []

def add_frame(self, frame_result: PerceptionFrameResult) -> None:
if frame_result.pass_fail_result.get_fail_object_num() <= 0:
return
info = {"fp": [], "fn": []}
info["timestamp"] = frame_result.frame_ground_truth.unix_time
for fp_result in frame_result.pass_fail_result.fp_object_results:
est_label = fp_result.estimated_object.semantic_label.label.value
gt_label = (
fp_result.ground_truth_object.semantic_label.label.value
if fp_result.ground_truth_object is not None
else None
)
info["fp"].append({"est": est_label, "gt": gt_label})
for fn_object in frame_result.pass_fail_result.fn_objects:
info["fn"].append({"est": None, "gt": fn_object.semantic_label.label.value})

info_str = f"Fail timestamp: {info}"
logging.info(info_str)
self.buffer.append(info)

def save(self) -> None:
with self.save_path.open("w") as f:
json.dump(self.buffer, f, ensure_ascii=False, indent=4)


@dataclass
class Perception(EvaluationItem):
success: bool = True
Expand All @@ -55,11 +194,12 @@ def __post_init__(self) -> None:
self.criteria: PerceptionCriteria = PerceptionCriteria(
methods=self.condition.CriteriaMethod,
levels=self.condition.CriteriaLevel,
distance_range=self.condition.Filter.Distance,
)

def set_frame(self, frame: PerceptionFrameResult) -> dict:
frame_success = "Fail"
result, _ = self.criteria.get_result(frame)
result, ret_frame = self.criteria.get_result(frame)

if result is None:
self.no_gt_no_obj += 1
Expand All @@ -76,28 +216,30 @@ def set_frame(self, frame: PerceptionFrameResult) -> dict:
return {
"PassFail": {
"Result": {"Total": self.success_str(), "Frame": frame_success},
"Info": {
"TP": len(frame.pass_fail_result.tp_object_results),
"FP": len(frame.pass_fail_result.fp_object_results),
"FN": len(frame.pass_fail_result.fn_objects),
},
"Info": summarize_pass_fail_result(ret_frame.pass_fail_result),
},
}


class TrafficLightResult(ResultBase):
def __init__(self, condition: Conditions) -> None:
super().__init__()
self.__perception = Perception(condition=condition)
self.__perception_criterion: list[Perception] = []
for i, criteria in enumerate(condition.Criterion):
self.__perception_criterion.append(
Perception(name=f"criteria{i}", condition=criteria),
)

def update(self) -> None:
summary_str = f"{self.__perception.summary}"
if self.__perception.success:
self._success = True
self._summary = f"Passed: {summary_str}"
else:
self._success = False
self._summary = f"Failed: {summary_str}"
all_summary: list[str] = []
all_success: list[bool] = []
for criterion in self.__perception_criterion:
tmp_success = criterion.success
prefix_str = "Passed: " if tmp_success else "Failed: "
all_summary.append(prefix_str + criterion.summary)
all_success.append(tmp_success)
self._summary = ", ".join(all_summary)
self._success = all(all_success)

def set_frame(
self,
Expand All @@ -110,7 +252,8 @@ def set_frame(
"FrameName": frame.frame_name,
"FrameSkip": skip,
}
self._frame |= self.__perception.set_frame(frame)
for criterion in self.__perception_criterion:
self._frame[criterion.name] = criterion.set_frame(frame)
self.update()

def set_final_metrics(self, final_metrics: dict) -> None:
Expand Down
Loading

0 comments on commit 7609f1a

Please sign in to comment.