diff --git a/pyproject.toml b/pyproject.toml index d89728e..49e3b29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,8 @@ dependencies = [ "jsonschema>=4.4.0", "fastjsonschema>=2.16.2", "raillabel>=3.1.0, <4.0.0", - "pyyaml>=6.0.0" + "pyyaml>=6.0.0", + "numpy>=1.24.4", ] [project.urls] diff --git a/raillabel_providerkit/_util/_filters.py b/raillabel_providerkit/_util/_filters.py new file mode 100644 index 0000000..5b4a4b2 --- /dev/null +++ b/raillabel_providerkit/_util/_filters.py @@ -0,0 +1,27 @@ +# Copyright DB Netz AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import raillabel + + +def filter_sensor_uids_by_type( + sensors: list[raillabel.format.Sensor], sensor_type: raillabel.format.SensorType +) -> set[str]: + """Get the uids of all given sensors matching the given SensorType. + + Parameters + ---------- + sensors : list[raillabel.format.Sensor] + The sensors to filter. + sensor_type : raillabel.format.SensorType + The SensorType to filter by. + + Returns + ------- + set[str] + The list of uids of matching sensors. + + """ + return {sensor.uid for sensor in sensors if sensor.type == sensor_type} diff --git a/raillabel_providerkit/validation/validate_rail_side/__init__.py b/raillabel_providerkit/validation/validate_rail_side/__init__.py new file mode 100644 index 0000000..ea61820 --- /dev/null +++ b/raillabel_providerkit/validation/validate_rail_side/__init__.py @@ -0,0 +1,3 @@ +# Copyright DB Netz AG and contributors +# SPDX-License-Identifier: Apache-2.0 +"""Package for validating a scene for rail side errors.""" diff --git a/raillabel_providerkit/validation/validate_rail_side/validate_rail_side.py b/raillabel_providerkit/validation/validate_rail_side/validate_rail_side.py new file mode 100644 index 0000000..d01bd76 --- /dev/null +++ b/raillabel_providerkit/validation/validate_rail_side/validate_rail_side.py @@ -0,0 +1,317 @@ +# Copyright DB Netz AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import numpy as np +import raillabel + +from raillabel_providerkit._util._filters import filter_sensor_uids_by_type + + +def validate_rail_side(scene: raillabel.Scene) -> list[str]: + """Validate whether all tracks have <= one left and right rail, and that they have correct order. + + Parameters + ---------- + scene : raillabel.Scene + Scene, that should be validated. + + Returns + ------- + list[str] + list of all rail side errors in the scene. If an empty list is returned, then there are no + errors present. + + """ + errors: list[str] = [] + + # Get a list of camera uids + cameras = filter_sensor_uids_by_type( + list(scene.sensors.values()), raillabel.format.SensorType.CAMERA + ) + + # Check per camera + for camera in cameras: + # Filter scene for track annotations in the selected camera sensor + filtered_scene = raillabel.filter( + scene, include_object_types=["track"], include_sensors=[camera] + ) + + # Check per frame + for frame_uid, frame in filtered_scene.frames.items(): + # Count rails per track + counts_per_track = _count_rails_per_track_in_frame(frame) + + # Add errors if there is more than one left or right rail + for object_uid, (left_count, right_count) in counts_per_track.items(): + if left_count > 1: + errors.append( + f"In sensor {camera} frame {frame_uid}, the track with" + f" object_uid {object_uid} has more than one ({left_count}) left rail." + ) + if right_count > 1: + errors.append( + f"In sensor {camera} frame {frame_uid}, the track with" + f" object_uid {object_uid} has more than one ({right_count}) right rail." + ) + + # If exactly one left and right rail exists, check if the track has its rails swapped + # or intersects with itself + if left_count == 1 == right_count: + # Get the two annotations in question + left_rail: raillabel.format.Poly2d | None = _get_track_from_frame( + frame, object_uid, "leftRail" + ) + right_rail: raillabel.format.Poly2d | None = _get_track_from_frame( + frame, object_uid, "rightRail" + ) + if left_rail is None or right_rail is None: + continue + + swap_error: str | None = _check_rails_for_swap(left_rail, right_rail, frame_uid) + if swap_error is not None: + errors.append(swap_error) + + return errors + + +def _check_rails_for_swap( + left_rail: raillabel.format.Poly2d, + right_rail: raillabel.format.Poly2d, + frame_uid: str | int = "unknown", +) -> str | None: + # Ensure the rails belong to the same track + if left_rail.object.uid != right_rail.object.uid: + return None + + max_common_y = _find_max_common_y(left_rail, right_rail) + if max_common_y is None: + return None + + left_x = _find_x_by_y(max_common_y, left_rail) + right_x = _find_x_by_y(max_common_y, right_rail) + if left_x is None or right_x is None: + return None + + object_uid = left_rail.object.uid + sensor_uid = left_rail.sensor.uid if left_rail.sensor is not None else "unknown" + + if left_x >= right_x: + return ( + f"In sensor {sensor_uid} frame {frame_uid}, the track with" + f" object_uid {object_uid} has its rails swapped." + f" At the maximum common y={max_common_y}, the left rail has x={left_x}" + f" while the right rail has x={right_x}." + ) + + intersect_interval = _find_intersect_interval(left_rail, right_rail) + if intersect_interval is not None: + return ( + f"In sensor {sensor_uid} frame {frame_uid}, the track with" + f" object_uid {object_uid} intersects with itself." + f" The left and right rail intersect in y interval {intersect_interval}." + ) + + return None + + +def _count_rails_per_track_in_frame(frame: raillabel.format.Frame) -> dict[str, tuple[int, int]]: + # For each track, the left and right rail counts are stored as a tuple (left, right) + counts: dict[str, tuple[int, int]] = {} + + # For each track, count the left and right rails + for object_uid, unfiltered_annotations in frame.object_data.items(): + # Ensure we work only on Poly2d annotations + poly2ds: list[raillabel.format.Poly2d] = _filter_for_poly2ds( + list(unfiltered_annotations.values()) + ) + + # Count left and right rails + left_count: int = 0 + right_count: int = 0 + for poly2d in poly2ds: + rail_side = poly2d.attributes["railSide"] + if rail_side == "leftRail": + left_count += 1 + elif rail_side == "rightRail": + right_count += 1 + else: + # NOTE: This is ignored because it is covered by validate_onthology + continue + + # Store counts of current track + counts[object_uid] = (left_count, right_count) + + # Return results + return counts + + +def _filter_for_poly2ds( + unfiltered_annotations: list[type[raillabel.format._ObjectAnnotation]], +) -> list[raillabel.format.Poly2d]: + return [ + annotation + for annotation in unfiltered_annotations + if isinstance(annotation, raillabel.format.Poly2d) + ] + + +def _find_intersect_interval( + line1: raillabel.format.Poly2d, line2: raillabel.format.Poly2d +) -> tuple[float, float] | None: + # If the two polylines intersect anywhere, return the y interval where they intersect. + + # Get all y values where either polyline has points + y_values: list[float] = sorted( + _get_y_of_all_points_of_poly2d(line1).union(_get_y_of_all_points_of_poly2d(line2)) + ) + + order: bool | None = None + last_y: float | None = None + for y in y_values: + x1 = _find_x_by_y(y, line1) + x2 = _find_x_by_y(y, line2) + + if x1 is None or x2 is None: + order = None + continue + + if x1 == x2: + return (y, y) + + new_order = x1 < x2 + + if order is not None and new_order != order and last_y is not None: + # The order has flipped. There is an intersection between previous and current y + return (last_y, y) + + order = new_order + last_y = y + + return None + + +def _find_max_y(poly2d: raillabel.format.Poly2d) -> float: + return np.max([point.y for point in poly2d.points]) + + +def _find_max_common_y( + line1: raillabel.format.Poly2d, line2: raillabel.format.Poly2d +) -> float | None: + if len(line1.points) == 0 or len(line2.points) == 0: + # One of the lines is empty + return None + + max_y_of_line1: float = _find_max_y(line1) + if _y_in_poly2d(max_y_of_line1, line2): + # The highest y is the bottom of line 1 + return max_y_of_line1 + + max_y_of_line2: float = _find_max_y(line2) + if _y_in_poly2d(max_y_of_line2, line1): + # The highest y is the bottom of line 2 + return max_y_of_line2 + + # There is no y overlap + return None + + +def _find_x_by_y(y: float, poly2d: raillabel.format.Poly2d) -> float | None: + """Find the x value of the first point where the polyline passes through y. + + Parameters + ---------- + y : float + The y value to check. + poly2d : raillabel.format.Poly2d + The Poly2D whose points will be checked against. + + Returns + ------- + float | None + The x value of a point (x,y) that poly2d passes through, + or None if poly2d doesn't go through y. + + """ + # 1. Find the first two points between which y is located + points = poly2d.points + p1: raillabel.format.Point2d | None = None + p2: raillabel.format.Point2d | None = None + for i in range(len(points) - 1): + current = points[i] + next_ = points[i + 1] + if (current.y >= y >= next_.y) or (current.y <= y <= next_.y): + p1 = current + p2 = next_ + break + + # 2. Abort if no valid points have been found + if not (p1 and p2): + return None + + # 3. Return early if p1=p2 (to avoid division by zero) + if p1.x == p2.x: + return p1.x + + # 4. Calculate m and n for the line g(x)=mx+n connecting p1 and p2 + m = (p2.y - p1.y) / (p2.x - p1.x) + n = p1.y - (m * p1.x) + + # 5. Return early if m is 0, as that means p2.y=p1.y, which implies p2.y=p1.y=y + if m == 0: + return p1.x + + # 6. Calculate the x we were searching for and return it + return (y - n) / m + + +def _get_track_from_frame( + frame: raillabel.format.Frame, object_uid: str, rail_side: str +) -> raillabel.format.Poly2d | None: + if object_uid not in frame.object_data: + return None + + for annotation in frame.object_data[object_uid].values(): + if not isinstance(annotation, raillabel.format.Poly2d): + continue + + if "railSide" not in annotation.attributes: + continue + + if annotation.attributes["railSide"] == rail_side: + return annotation + + return None + + +def _get_y_of_all_points_of_poly2d(poly2d: raillabel.format.Poly2d) -> set[float]: + y_values: set[float] = set() + for point in poly2d.points: + y_values.add(point.y) + return y_values + + +def _y_in_poly2d(y: float, poly2d: raillabel.format.Poly2d) -> bool: + """Check whether the polyline created by the given Poly2d passes through the given y value. + + Parameters + ---------- + y : float + The y value to check. + poly2d : raillabel.format.Poly2d + The Poly2D whose points will be checked against. + + Returns + ------- + bool + Does the Poly2d pass through the given y value? + + """ + # For every point (except the last), check if the y is between them + for i in range(len(poly2d.points) - 1): + current = poly2d.points[i] + next_ = poly2d.points[i + 1] + if (current.y >= y >= next_.y) or (current.y <= y <= next_.y): + return True + return False diff --git a/tests/test_raillabel_providerkit/_util/test_filters.py b/tests/test_raillabel_providerkit/_util/test_filters.py new file mode 100644 index 0000000..c6c812b --- /dev/null +++ b/tests/test_raillabel_providerkit/_util/test_filters.py @@ -0,0 +1,61 @@ +# Copyright DB Netz AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest +import raillabel + +from raillabel_providerkit._util._filters import filter_sensor_uids_by_type + + +@pytest.fixture +def sensor_types() -> list[raillabel.format.SensorType]: + return [sensor_type for sensor_type in raillabel.format.SensorType] + + +def test_filter_sensor_uids_by_type__empty(sensor_types): + sensors = [] + for sensor_type in sensor_types: + assert len(filter_sensor_uids_by_type(sensors, sensor_type)) == 0 + + +def test_filter_sensor_uids_by_type__exactly_one_match(sensor_types): + # Create a list of sensors where each sensor type occurs exactly once + sensors = [] + for i in range(len(sensor_types)): + sensors.append(raillabel.format.Sensor(uid=f"test_{i}", type=sensor_types[i])) + + # Ensure the filter works for each sensor type + for sensor_type in sensor_types: + results = filter_sensor_uids_by_type(sensors, sensor_type) + assert len(results) == 1 + # Assert the result is of correct type + matches = 0 + for sensor in sensors: + if sensor.uid in results: + assert sensor.type == sensor_type + matches += 1 + assert matches == len(results) + + +def test_filter_sensor_uids_by_type__multiple_matches(sensor_types): + # Create a list of sensors where each sensor type occurs three times + sensors = [] + i = 0 + for sensor_type in sensor_types: + for j in range(3): + sensors.append(raillabel.format.Sensor(uid=f"test_{i}", type=sensor_type)) + i += 1 + + # Ensure the filter works for each sensor type + for sensor_type in sensor_types: + results = filter_sensor_uids_by_type(sensors, sensor_type) + assert len(results) == 3 + # Assert the results are of correct type + matches = 0 + for sensor in sensors: + if sensor.uid in results: + assert sensor.type == sensor_type + matches += 1 + assert matches == len(results) diff --git a/tests/test_raillabel_providerkit/validation/validate_rail_side/test_validate_rail_side.py b/tests/test_raillabel_providerkit/validation/validate_rail_side/test_validate_rail_side.py new file mode 100644 index 0000000..eb053be --- /dev/null +++ b/tests/test_raillabel_providerkit/validation/validate_rail_side/test_validate_rail_side.py @@ -0,0 +1,469 @@ +# Copyright DB Netz AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import raillabel + +from raillabel_providerkit.validation.validate_rail_side.validate_rail_side import ( + validate_rail_side, + _count_rails_per_track_in_frame, +) + + +@pytest.fixture +def example_camera_1() -> raillabel.format.Sensor: + return raillabel.format.Sensor( + uid="rgb_center", + type=raillabel.format.SensorType.CAMERA, + ) + + +@pytest.fixture +def example_camera_2() -> raillabel.format.Sensor: + return raillabel.format.Sensor( + uid="ir_center", + type=raillabel.format.SensorType.CAMERA, + ) + + +@pytest.fixture +def example_track_1() -> raillabel.format.Object: + return raillabel.format.Object( + uid="a1082ef9-555b-4b69-a888-7da531d8a2eb", name="track0001", type="track" + ) + + +@pytest.fixture +def example_track_2() -> raillabel.format.Object: + return raillabel.format.Object( + uid="6e92e7af-3bc8-4225-b538-16d19e3f8aa7", name="track0002", type="track" + ) + + +def test_count_rails_per_track_in_frame__empty(empty_frame): + frame = empty_frame + results = _count_rails_per_track_in_frame(frame) + assert len(results) == 0 + + +def test_count_rails_per_track_in_frame__many_rails_for_one_track( + empty_frame, example_camera_1, example_track_1 +): + frame = empty_frame + sensor = example_camera_1 + object = example_track_1 + + LEFT_COUNT = 32 + RIGHT_COUNT = 42 + + for i in range(LEFT_COUNT): + uid = f"test_left_{i}" + frame.annotations[uid] = raillabel.format.Poly2d( + uid=uid, + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(0, 0), + raillabel.format.Point2d(0, 1), + ], + closed=False, + attributes={"railSide": "leftRail"}, + ) + + for i in range(RIGHT_COUNT): + uid = f"test_right_{i}" + frame.annotations[uid] = raillabel.format.Poly2d( + uid=uid, + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(1, 0), + raillabel.format.Point2d(1, 1), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + + results = _count_rails_per_track_in_frame(frame) + assert len(results) == 1 + assert object.uid in results.keys() + assert results[object.uid] == (LEFT_COUNT, RIGHT_COUNT) + + +def test_count_rails_per_track_in_frame__many_rails_for_two_tracks( + empty_frame, example_camera_1, example_track_1, example_track_2 +): + frame = empty_frame + sensor = example_camera_1 + object1 = example_track_1 + object2 = example_track_2 + + LEFT_COUNT = 32 + RIGHT_COUNT = 42 + + for object in [object1, object2]: + for i in range(LEFT_COUNT): + uid = f"test_left_{i}_object_{object.uid}" + frame.annotations[uid] = raillabel.format.Poly2d( + uid=uid, + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(0, 0), + raillabel.format.Point2d(0, 1), + ], + closed=False, + attributes={"railSide": "leftRail"}, + ) + + for i in range(RIGHT_COUNT): + uid = f"test_right_{i}_object_{object.uid}" + frame.annotations[uid] = raillabel.format.Poly2d( + uid=uid, + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(1, 0), + raillabel.format.Point2d(1, 1), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + + results = _count_rails_per_track_in_frame(frame) + assert len(results) == 2 + assert object1.uid in results.keys() + assert object2.uid in results.keys() + assert results[object1.uid] == (LEFT_COUNT, RIGHT_COUNT) + assert results[object2.uid] == (LEFT_COUNT, RIGHT_COUNT) + + +def test_validate_rail_side__no_errors(empty_scene, empty_frame, example_camera_1, example_track_1): + scene = empty_scene + object = example_track_1 + scene.objects[object.uid] = object + sensor = example_camera_1 + scene.sensors[sensor.uid] = sensor + frame = empty_frame + frame.annotations["325b1f55-a2ef-475f-a780-13e1a9e823c3"] = raillabel.format.Poly2d( + uid="325b1f55-a2ef-475f-a780-13e1a9e823c3", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(0, 0), + raillabel.format.Point2d(0, 1), + ], + closed=False, + attributes={"railSide": "leftRail"}, + ) + frame.annotations["be7d136a-8364-4fbd-b098-6f4a21205d22"] = raillabel.format.Poly2d( + uid="be7d136a-8364-4fbd-b098-6f4a21205d22", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(1, 0), + raillabel.format.Point2d(1, 1), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + scene.frames[frame.uid] = frame + + actual = validate_rail_side(scene) + assert len(actual) == 0 + + +def test_validate_rail_side__rail_sides_switched( + empty_scene, empty_frame, example_camera_1, example_track_1 +): + scene = empty_scene + object = example_track_1 + scene.objects[object.uid] = object + sensor = example_camera_1 + scene.sensors[sensor.uid] = sensor + frame = empty_frame + frame.annotations["325b1f55-a2ef-475f-a780-13e1a9e823c3"] = raillabel.format.Poly2d( + uid="325b1f55-a2ef-475f-a780-13e1a9e823c3", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(0, 0), + raillabel.format.Point2d(0, 1), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + frame.annotations["be7d136a-8364-4fbd-b098-6f4a21205d22"] = raillabel.format.Poly2d( + uid="be7d136a-8364-4fbd-b098-6f4a21205d22", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(1, 0), + raillabel.format.Point2d(1, 1), + ], + closed=False, + attributes={"railSide": "leftRail"}, + ) + scene.frames[frame.uid] = frame + + actual = validate_rail_side(scene) + assert len(actual) == 1 + + +def test_validate_rail_side__rail_sides_intersect_at_top( + empty_scene, empty_frame, example_camera_1, example_track_1 +): + scene = empty_scene + object = example_track_1 + scene.objects[object.uid] = object + sensor = example_camera_1 + scene.sensors[sensor.uid] = sensor + frame = empty_frame + frame.annotations["325b1f55-a2ef-475f-a780-13e1a9e823c3"] = raillabel.format.Poly2d( + uid="325b1f55-a2ef-475f-a780-13e1a9e823c3", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(20, 0), + raillabel.format.Point2d(20, 10), + raillabel.format.Point2d(10, 20), + raillabel.format.Point2d(10, 100), + ], + closed=False, + attributes={"railSide": "leftRail"}, + ) + frame.annotations["be7d136a-8364-4fbd-b098-6f4a21205d22"] = raillabel.format.Poly2d( + uid="be7d136a-8364-4fbd-b098-6f4a21205d22", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(10, 0), + raillabel.format.Point2d(10, 10), + raillabel.format.Point2d(20, 20), + raillabel.format.Point2d(20, 100), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + scene.frames[frame.uid] = frame + + actual = validate_rail_side(scene) + assert len(actual) == 1 + + +def test_validate_rail_side__rail_sides_correct_with_early_end_of_one_side( + empty_scene, empty_frame, example_camera_1, example_track_1 +): + scene = empty_scene + object = example_track_1 + scene.objects[object.uid] = object + sensor = example_camera_1 + scene.sensors[sensor.uid] = sensor + frame = empty_frame + frame.annotations["325b1f55-a2ef-475f-a780-13e1a9e823c3"] = raillabel.format.Poly2d( + uid="325b1f55-a2ef-475f-a780-13e1a9e823c3", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(70, 0), + raillabel.format.Point2d(30, 20), + raillabel.format.Point2d(15, 40), + raillabel.format.Point2d(10, 50), + raillabel.format.Point2d(10, 100), + ], + closed=False, + attributes={"railSide": "leftRail"}, + ) + frame.annotations["be7d136a-8364-4fbd-b098-6f4a21205d22"] = raillabel.format.Poly2d( + uid="be7d136a-8364-4fbd-b098-6f4a21205d22", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(20, 50), + raillabel.format.Point2d(20, 100), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + scene.frames[frame.uid] = frame + + actual = validate_rail_side(scene) + assert len(actual) == 0 + + +def test_validate_rail_side__two_left_rails( + empty_scene, empty_frame, example_camera_1, example_track_1 +): + scene = empty_scene + object = example_track_1 + scene.objects[object.uid] = object + sensor = example_camera_1 + scene.sensors[sensor.uid] = sensor + frame = empty_frame + frame.annotations["325b1f55-a2ef-475f-a780-13e1a9e823c3"] = raillabel.format.Poly2d( + uid="325b1f55-a2ef-475f-a780-13e1a9e823c3", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(0, 0), + raillabel.format.Point2d(0, 1), + ], + closed=False, + attributes={"railSide": "leftRail"}, + ) + frame.annotations["be7d136a-8364-4fbd-b098-6f4a21205d22"] = raillabel.format.Poly2d( + uid="be7d136a-8364-4fbd-b098-6f4a21205d22", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(1, 0), + raillabel.format.Point2d(1, 1), + ], + closed=False, + attributes={"railSide": "leftRail"}, + ) + scene.frames[frame.uid] = frame + + actual = validate_rail_side(scene) + assert len(actual) == 1 + + +def test_validate_rail_side__two_right_rails( + empty_scene, empty_frame, example_camera_1, example_track_1 +): + scene = empty_scene + object = example_track_1 + scene.objects[object.uid] = object + sensor = example_camera_1 + scene.sensors[sensor.uid] = sensor + frame = empty_frame + frame.annotations["325b1f55-a2ef-475f-a780-13e1a9e823c3"] = raillabel.format.Poly2d( + uid="325b1f55-a2ef-475f-a780-13e1a9e823c3", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(0, 0), + raillabel.format.Point2d(0, 1), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + frame.annotations["be7d136a-8364-4fbd-b098-6f4a21205d22"] = raillabel.format.Poly2d( + uid="be7d136a-8364-4fbd-b098-6f4a21205d22", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(1, 0), + raillabel.format.Point2d(1, 1), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + scene.frames[frame.uid] = frame + + actual = validate_rail_side(scene) + assert len(actual) == 1 + + +def test_validate_rail_side__two_sensors_with_two_right_rails_each( + empty_scene, empty_frame, example_camera_1, example_camera_2, example_track_1 +): + scene = empty_scene + object = example_track_1 + scene.objects[object.uid] = object + sensor1 = example_camera_1 + sensor2 = example_camera_2 + for sensor in [sensor1, sensor2]: + scene.sensors[sensor.uid] = sensor + frame = empty_frame + frame.annotations["325b1f55-a2ef-475f-a780-13e1a9e823c3"] = raillabel.format.Poly2d( + uid="325b1f55-a2ef-475f-a780-13e1a9e823c3", + object=object, + sensor=sensor1, + points=[ + raillabel.format.Point2d(0, 0), + raillabel.format.Point2d(0, 1), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + frame.annotations["be7d136a-8364-4fbd-b098-6f4a21205d22"] = raillabel.format.Poly2d( + uid="be7d136a-8364-4fbd-b098-6f4a21205d22", + object=object, + sensor=sensor1, + points=[ + raillabel.format.Point2d(1, 0), + raillabel.format.Point2d(1, 1), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + frame.annotations["f6db5b28-bdcd-437f-bf39-c044bb516de8"] = raillabel.format.Poly2d( + uid="f6db5b28-bdcd-437f-bf39-c044bb516de8", + object=object, + sensor=sensor2, + points=[ + raillabel.format.Point2d(0, 0), + raillabel.format.Point2d(0, 1), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + frame.annotations["89f8cf2c-1dc9-4956-9661-f1054ff069f9"] = raillabel.format.Poly2d( + uid="89f8cf2c-1dc9-4956-9661-f1054ff069f9", + object=object, + sensor=sensor2, + points=[ + raillabel.format.Point2d(1, 0), + raillabel.format.Point2d(1, 1), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + scene.frames[frame.uid] = frame + + actual = validate_rail_side(scene) + assert len(actual) == 2 + + +def test_validate_rail_side__two_sensors_with_one_right_rail_each( + empty_scene, empty_frame, example_camera_1, example_camera_2, example_track_1 +): + scene = empty_scene + object = example_track_1 + scene.objects[object.uid] = object + sensor1 = example_camera_1 + sensor2 = example_camera_2 + for sensor in [sensor1, sensor2]: + scene.sensors[sensor.uid] = sensor + frame = empty_frame + frame.annotations["325b1f55-a2ef-475f-a780-13e1a9e823c3"] = raillabel.format.Poly2d( + uid="325b1f55-a2ef-475f-a780-13e1a9e823c3", + object=object, + sensor=sensor1, + points=[ + raillabel.format.Point2d(0, 0), + raillabel.format.Point2d(0, 1), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + frame.annotations["f6db5b28-bdcd-437f-bf39-c044bb516de8"] = raillabel.format.Poly2d( + uid="f6db5b28-bdcd-437f-bf39-c044bb516de8", + object=object, + sensor=sensor2, + points=[ + raillabel.format.Point2d(0, 0), + raillabel.format.Point2d(0, 1), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + scene.frames[frame.uid] = frame + + actual = validate_rail_side(scene) + assert len(actual) == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "--disable-pytest-warnings", "--cache-clear", "-v"])