diff --git a/raillabel_providerkit/validation/validate_rail_side/validate_rail_side.py b/raillabel_providerkit/validation/validate_rail_side/validate_rail_side.py index 94f4a09..e791bd1 100644 --- a/raillabel_providerkit/validation/validate_rail_side/validate_rail_side.py +++ b/raillabel_providerkit/validation/validate_rail_side/validate_rail_side.py @@ -43,9 +43,6 @@ def validate_rail_side(scene: raillabel.Scene) -> list[str]: # Count rails per track counts_per_track = _count_rails_per_track_in_frame(frame) - # Find rail x limits per track - track_limits_per_track = _get_track_limits_in_frame(frame) - # Add errors if there is more than one left or right rail for object_uid, (left_count, right_count) in counts_per_track.items(): if left_count > 1: @@ -59,21 +56,66 @@ def validate_rail_side(scene: raillabel.Scene) -> list[str]: f" object_uid {object_uid} has more than one ({right_count}) right rail." ) - # If left and right rails exist, check if the track has its rails swapped - if left_count >= 1 and right_count >= 1: - # Add errors if any track has its rails swapped - (max_x_of_left, min_x_of_right) = track_limits_per_track[object_uid] - if max_x_of_left > min_x_of_right: - errors.append( - f"In sensor {camera} frame {frame_uid}, the track with" - f" object_uid {object_uid} has its rails swapped." - f" The right-most left rail has x={max_x_of_left} while" - f" the left-most right rail has x={min_x_of_right}." - ) + # If exactly one left and right rail exists, check if the track has its rails swapped + # or intersects with itself + if left_count == 1 == right_count: + # Get the two annotations in question + left_rail: raillabel.format.Poly2d | None = _get_track_from_frame( + frame, object_uid, "leftRail" + ) + right_rail: raillabel.format.Poly2d | None = _get_track_from_frame( + frame, object_uid, "rightRail" + ) + if left_rail is None or right_rail is None: + continue + + swap_error: str | None = _check_rails_for_swap(left_rail, right_rail, frame_uid) + if swap_error is not None: + errors.append(swap_error) return errors +def _check_rails_for_swap( + left_rail: raillabel.format.Poly2d, + right_rail: raillabel.format.Poly2d, + frame_uid: str | int = "unknown", +) -> str | None: + # Ensure the rails belong to the same track + if left_rail.object.uid != right_rail.object.uid: + return None + + max_common_y = _find_max_common_y(left_rail, right_rail) + if max_common_y is None: + return None + + left_x = _find_x_by_y(max_common_y, left_rail) + right_x = _find_x_by_y(max_common_y, right_rail) + if left_x is None or right_x is None: + return None + + object_uid = left_rail.object.uid + sensor_uid = left_rail.sensor.uid if left_rail.sensor is not None else "unknown" + + if left_x >= right_x: + return ( + f"In sensor {sensor_uid} frame {frame_uid}, the track with" + f" object_uid {object_uid} has its rails swapped." + f" At the maximum common y={max_common_y}, the left rail has x={left_x}" + f" while the right rail has x={right_x}." + ) + + intersect_interval = _find_intersect_interval(left_rail, right_rail) + if intersect_interval is not None: + return ( + f"In sensor {sensor_uid} frame {frame_uid}, the track with" + f" object_uid {object_uid} intersects with itself." + f" The left and right rail intersect in y interval {intersect_interval}." + ) + + return None + + def _count_rails_per_track_in_frame(frame: raillabel.format.Frame) -> dict[str, tuple[int, int]]: # For each track, the left and right rail counts are stored as a tuple (left, right) counts: dict[str, tuple[int, int]] = {} @@ -105,39 +147,6 @@ def _count_rails_per_track_in_frame(frame: raillabel.format.Frame) -> dict[str, return counts -def _get_track_limits_in_frame(frame: raillabel.format.Frame) -> dict[str, tuple[float, float]]: - # For each track, the largest x of any left rail and the smallest x of any right rail is stored - # as a tuple (max_x_of_left, min_x_of_right) - track_limits: dict[str, tuple[float, float]] = {} - - for object_uid, unfiltered_annotations in frame.object_data.items(): - # Ensure we work only on Poly2d annotations - poly2ds: list[raillabel.format.Poly2d] = _filter_for_poly2ds( - list(unfiltered_annotations.values()) - ) - - # Get the largest x of any left rail and the smallest x of any right rail - max_x_of_left: float = float("-inf") - min_x_of_right: float = float("inf") - for poly2d in poly2ds: - rail_x_values: list[float] = [point.x for point in poly2d.points] - match poly2d.attributes["railSide"]: - case "leftRail": - max_x_of_rail_points: float = np.max(rail_x_values) - max_x_of_left = max(max_x_of_rail_points, max_x_of_left) - case "rightRail": - min_x_of_rail_points: float = np.min(rail_x_values) - min_x_of_right = min(min_x_of_rail_points, min_x_of_right) - case _: - # NOTE: This is ignored because it is covered by validate_onthology - continue - - # Store the calculated limits of current track - track_limits[object_uid] = (max_x_of_left, min_x_of_right) - - return track_limits - - def _filter_for_poly2ds( unfiltered_annotations: list[type[raillabel.format._ObjectAnnotation]], ) -> list[raillabel.format.Poly2d]: @@ -146,3 +155,163 @@ def _filter_for_poly2ds( for annotation in unfiltered_annotations if isinstance(annotation, raillabel.format.Poly2d) ] + + +def _find_intersect_interval( + line1: raillabel.format.Poly2d, line2: raillabel.format.Poly2d +) -> tuple[float, float] | None: + # If the two polylines intersect anywhere, return the y interval where they intersect. + + # Get all y values where either polyline has points + y_values: list[float] = sorted( + _get_y_of_all_points_of_poly2d(line1).union(_get_y_of_all_points_of_poly2d(line2)) + ) + + order: bool | None = None + last_y: float | None = None + for y in y_values: + x1 = _find_x_by_y(y, line1) + x2 = _find_x_by_y(y, line2) + + if x1 is None or x2 is None: + order = None + continue + + if x1 == x2: + return (y, y) + + new_order = x1 < x2 + + if order is not None and new_order != order and last_y is not None: + # The order has flipped. There is an intersection between previous and current y + return (last_y, y) + + order = new_order + last_y = y + + return None + + +def _find_max_y(poly2d: raillabel.format.Poly2d) -> float: + return np.max([point.y for point in poly2d.points]) + + +def _find_max_common_y( + line1: raillabel.format.Poly2d, line2: raillabel.format.Poly2d +) -> float | None: + if len(line1.points) == 0 or len(line2.points) == 0: + # One of the lines is empty + return None + + max_y_of_line1: float = _find_max_y(line1) + if _y_in_poly2d(max_y_of_line1, line2): + # The highest y is the bottom of line 1 + return max_y_of_line1 + + max_y_of_line2: float = _find_max_y(line2) + if _y_in_poly2d(max_y_of_line2, line1): + # The highest y is the bottom of line 2 + return max_y_of_line2 + + # There is no y overlap + return None + + +def _find_x_by_y(y: float, poly2d: raillabel.format.Poly2d) -> float | None: + """Find the x value of the first point where the polyline passes through y. + + Parameters + ---------- + y : float + The y value to check. + poly2d : raillabel.format.Poly2d + The Poly2D whose points will be checked against. + + Returns + ------- + float | None + The x value of a point (x,y) that poly2d passes through, + or None if poly2d doesn't go through y. + + """ + # 1. Find the first two points between which y is located + points = poly2d.points + p1: raillabel.format.Point2d | None = None + p2: raillabel.format.Point2d | None = None + for i in range(len(points) - 1): + current = points[i] + next_ = points[i + 1] + if (current.y >= y >= next_.y) or (current.y <= y <= next_.y): + p1 = current + p2 = next_ + break + + # 2. Abort if no valid points have been found + if not (p1 and p2): + return None + + # 3. Return early if p1=p2 (to avoid division by zero) + if p1.x == p2.x: + return p1.x + + # 4. Calculate m and n for the line g(x)=mx+n connecting p1 and p2 + m = (p2.y - p1.y) / (p2.x - p1.x) + n = p1.y - (m * p1.x) + + # 5. Return early if m is 0, as that means p2.y=p1.y, which implies p2.y=p1.y=y + if m == 0: + return p1.x + + # 6. Calculate the x we were searching for and return it + return (y - n) / m + + +def _get_track_from_frame( + frame: raillabel.format.Frame, object_uid: str, rail_side: str +) -> raillabel.format.Poly2d | None: + if object_uid not in frame.object_data: + return None + + for annotation in frame.object_data[object_uid].values(): + if not isinstance(annotation, raillabel.format.Poly2d): + continue + + if "railSide" not in annotation.attributes: + continue + + if annotation.attributes["railSide"] == rail_side: + return annotation + + return None + + +def _get_y_of_all_points_of_poly2d(poly2d: raillabel.format.Poly2d) -> set[float]: + y_values: set[float] = set() + for point in poly2d.points: + y_values.add(point.y) + return y_values + + +def _y_in_poly2d(y: float, poly2d: raillabel.format.Poly2d) -> bool: + """Check whether the polyline created by the given Poly2d passes through the given y value. + + Parameters + ---------- + y : float + The y value to check. + poly2d : raillabel.format.Poly2d + The Poly2D whose points will be checked against. + + Returns + ------- + bool + Does the Poly2d pass through the given y value? + + """ + # For every point (except the last), check if the y is between them + for i in range(len(poly2d.points) - 1): + current = poly2d.points[i] + next_ = poly2d.points[i + 1] + if (current.y >= y >= next_.y) or (current.y <= y <= next_.y): + return True + return False diff --git a/tests/test_raillabel_providerkit/validation/validate_rail_side/test_validate_rail_side.py b/tests/test_raillabel_providerkit/validation/validate_rail_side/test_validate_rail_side.py index 06f5e09..eb053be 100644 --- a/tests/test_raillabel_providerkit/validation/validate_rail_side/test_validate_rail_side.py +++ b/tests/test_raillabel_providerkit/validation/validate_rail_side/test_validate_rail_side.py @@ -7,8 +7,6 @@ from raillabel_providerkit.validation.validate_rail_side.validate_rail_side import ( validate_rail_side, _count_rails_per_track_in_frame, - _get_track_limits_in_frame, - _filter_for_poly2ds, ) @@ -140,29 +138,20 @@ def test_count_rails_per_track_in_frame__many_rails_for_two_tracks( assert results[object2.uid] == (LEFT_COUNT, RIGHT_COUNT) -def test_get_track_limits_in_frame__empty(empty_frame): - frame = empty_frame - results = _get_track_limits_in_frame(frame) - assert len(results) == 0 - - -def test_get_track_limits_in_frame__one_track_two_rails( - empty_frame, example_camera_1, example_track_1 -): - frame = empty_frame - sensor = example_camera_1 +def test_validate_rail_side__no_errors(empty_scene, empty_frame, example_camera_1, example_track_1): + scene = empty_scene object = example_track_1 - - MAX_X_OF_LEFT = 42 - MIN_X_OF_RIGHT = 73 - + scene.objects[object.uid] = object + sensor = example_camera_1 + scene.sensors[sensor.uid] = sensor + frame = empty_frame frame.annotations["325b1f55-a2ef-475f-a780-13e1a9e823c3"] = raillabel.format.Poly2d( uid="325b1f55-a2ef-475f-a780-13e1a9e823c3", object=object, sensor=sensor, points=[ raillabel.format.Point2d(0, 0), - raillabel.format.Point2d(MAX_X_OF_LEFT, 1), + raillabel.format.Point2d(0, 1), ], closed=False, attributes={"railSide": "leftRail"}, @@ -172,20 +161,21 @@ def test_get_track_limits_in_frame__one_track_two_rails( object=object, sensor=sensor, points=[ - raillabel.format.Point2d(1000, 0), - raillabel.format.Point2d(MIN_X_OF_RIGHT, 1), + raillabel.format.Point2d(1, 0), + raillabel.format.Point2d(1, 1), ], closed=False, attributes={"railSide": "rightRail"}, ) + scene.frames[frame.uid] = frame - results = _get_track_limits_in_frame(frame) - assert len(results) == 1 - assert object.uid in results.keys() - assert results[object.uid] == (MAX_X_OF_LEFT, MIN_X_OF_RIGHT) + actual = validate_rail_side(scene) + assert len(actual) == 0 -def test_validate_rail_side__no_errors(empty_scene, empty_frame, example_camera_1, example_track_1): +def test_validate_rail_side__rail_sides_switched( + empty_scene, empty_frame, example_camera_1, example_track_1 +): scene = empty_scene object = example_track_1 scene.objects[object.uid] = object @@ -201,7 +191,7 @@ def test_validate_rail_side__no_errors(empty_scene, empty_frame, example_camera_ raillabel.format.Point2d(0, 1), ], closed=False, - attributes={"railSide": "leftRail"}, + attributes={"railSide": "rightRail"}, ) frame.annotations["be7d136a-8364-4fbd-b098-6f4a21205d22"] = raillabel.format.Poly2d( uid="be7d136a-8364-4fbd-b098-6f4a21205d22", @@ -212,15 +202,15 @@ def test_validate_rail_side__no_errors(empty_scene, empty_frame, example_camera_ raillabel.format.Point2d(1, 1), ], closed=False, - attributes={"railSide": "rightRail"}, + attributes={"railSide": "leftRail"}, ) scene.frames[frame.uid] = frame actual = validate_rail_side(scene) - assert len(actual) == 0 + assert len(actual) == 1 -def test_validate_rail_side__rail_sides_switched( +def test_validate_rail_side__rail_sides_intersect_at_top( empty_scene, empty_frame, example_camera_1, example_track_1 ): scene = empty_scene @@ -234,22 +224,26 @@ def test_validate_rail_side__rail_sides_switched( object=object, sensor=sensor, points=[ - raillabel.format.Point2d(0, 0), - raillabel.format.Point2d(0, 1), + raillabel.format.Point2d(20, 0), + raillabel.format.Point2d(20, 10), + raillabel.format.Point2d(10, 20), + raillabel.format.Point2d(10, 100), ], closed=False, - attributes={"railSide": "rightRail"}, + attributes={"railSide": "leftRail"}, ) frame.annotations["be7d136a-8364-4fbd-b098-6f4a21205d22"] = raillabel.format.Poly2d( uid="be7d136a-8364-4fbd-b098-6f4a21205d22", object=object, sensor=sensor, points=[ - raillabel.format.Point2d(1, 0), - raillabel.format.Point2d(1, 1), + raillabel.format.Point2d(10, 0), + raillabel.format.Point2d(10, 10), + raillabel.format.Point2d(20, 20), + raillabel.format.Point2d(20, 100), ], closed=False, - attributes={"railSide": "leftRail"}, + attributes={"railSide": "rightRail"}, ) scene.frames[frame.uid] = frame @@ -257,6 +251,46 @@ def test_validate_rail_side__rail_sides_switched( assert len(actual) == 1 +def test_validate_rail_side__rail_sides_correct_with_early_end_of_one_side( + empty_scene, empty_frame, example_camera_1, example_track_1 +): + scene = empty_scene + object = example_track_1 + scene.objects[object.uid] = object + sensor = example_camera_1 + scene.sensors[sensor.uid] = sensor + frame = empty_frame + frame.annotations["325b1f55-a2ef-475f-a780-13e1a9e823c3"] = raillabel.format.Poly2d( + uid="325b1f55-a2ef-475f-a780-13e1a9e823c3", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(70, 0), + raillabel.format.Point2d(30, 20), + raillabel.format.Point2d(15, 40), + raillabel.format.Point2d(10, 50), + raillabel.format.Point2d(10, 100), + ], + closed=False, + attributes={"railSide": "leftRail"}, + ) + frame.annotations["be7d136a-8364-4fbd-b098-6f4a21205d22"] = raillabel.format.Poly2d( + uid="be7d136a-8364-4fbd-b098-6f4a21205d22", + object=object, + sensor=sensor, + points=[ + raillabel.format.Point2d(20, 50), + raillabel.format.Point2d(20, 100), + ], + closed=False, + attributes={"railSide": "rightRail"}, + ) + scene.frames[frame.uid] = frame + + actual = validate_rail_side(scene) + assert len(actual) == 0 + + def test_validate_rail_side__two_left_rails( empty_scene, empty_frame, example_camera_1, example_track_1 ):