Skip to content

Commit

Permalink
feat: refactor validate_rail_side for Issue
Browse files Browse the repository at this point in the history
  • Loading branch information
nalquas committed Dec 2, 2024
1 parent 1dce0d4 commit fb3e597
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
IncludeSensorTypeFilter,
)

from raillabel_providerkit.validation import Issue, IssueIdentifiers, IssueType

def validate_rail_side(scene: raillabel.Scene) -> list[str]:

def validate_rail_side(scene: raillabel.Scene) -> list[Issue]:
"""Validate whether all tracks have <= one left and right rail, and that they have correct order.
Parameters
Expand All @@ -30,7 +32,7 @@ def validate_rail_side(scene: raillabel.Scene) -> list[str]:
errors present.
"""
errors: list[str] = []
errors = []

camera_uids = list(scene.filter([IncludeSensorTypeFilter(["camera"])]).sensors.keys())

Expand All @@ -47,11 +49,11 @@ def validate_rail_side(scene: raillabel.Scene) -> list[str]:
counts_per_track = _count_rails_per_track_in_frame(frame)

for object_uid, (left_count, right_count) in counts_per_track.items():
context = {
"frame_uid": frame_uid,
"object_uid": object_uid,
"camera_uid": camera_uid,
}
context = IssueIdentifiers(
frame=frame_uid,
sensor=camera_uid,
object=object_uid,
)

count_errors = _check_rail_counts(context, left_count, right_count)
exactly_one_left_and_right_rail_exist = count_errors != []
Expand All @@ -64,33 +66,37 @@ def validate_rail_side(scene: raillabel.Scene) -> list[str]:
if left_rail is None or right_rail is None:
continue

errors.extend(
_check_rails_for_swap_or_intersection(left_rail, right_rail, frame_uid)
)
errors.extend(_check_rails_for_swap_or_intersection(left_rail, right_rail, context))

return errors


def _check_rail_counts(context: dict, left_count: int, right_count: int) -> list[str]:
def _check_rail_counts(context: IssueIdentifiers, left_count: int, right_count: int) -> list[Issue]:
errors = []
if left_count > 1:
errors.append(
f"In sensor {context['camera_uid']} frame {context['frame_uid']}, the track with"
f" object_uid {context['object_uid']} has more than one ({left_count}) left rail."
Issue(
type=IssueType.RAIL_SIDE,
reason=f"This track has {left_count} left rails.",
identifiers=context,
)
)
if right_count > 1:
errors.append(
f"In sensor {context['camera_uid']} frame {context['frame_uid']}, the track with"
f" object_uid {context['object_uid']} has more than one ({right_count}) right rail."
Issue(
type=IssueType.RAIL_SIDE,
reason=f"This track has {right_count} right rails.",
identifiers=context,
)
)
return errors


def _check_rails_for_swap_or_intersection(
left_rail: raillabel.format.Poly2d,
right_rail: raillabel.format.Poly2d,
frame_uid: str | int = "unknown",
) -> list[str]:
context: IssueIdentifiers,
) -> list[Issue]:
if left_rail.object_id != right_rail.object_id:
return []

Expand All @@ -103,23 +109,22 @@ def _check_rails_for_swap_or_intersection(
if left_x is None or right_x is None:
return []

object_uid = left_rail.object_id
sensor_uid = left_rail.sensor_id if left_rail.sensor_id is not None else "unknown"

if left_x >= right_x:
return [
f"In sensor {sensor_uid} frame {frame_uid}, the track with"
f" object_uid {object_uid} has its rails swapped."
f" At the maximum common y={max_common_y}, the left rail has x={left_x}"
f" while the right rail has x={right_x}."
Issue(
type=IssueType.RAIL_SIDE,
reason="The left and right rails of this track are swapped.",
identifiers=context,
)
]

intersect_interval = _find_intersect_interval(left_rail, right_rail)
if intersect_interval is not None:
if _polylines_are_intersecting(left_rail, right_rail):
return [
f"In sensor {sensor_uid} frame {frame_uid}, the track with"
f" object_uid {object_uid} intersects with itself."
f" The left and right rail intersect in y interval {intersect_interval}."
Issue(
type=IssueType.RAIL_SIDE,
reason="The left and right rails of this track intersect.",
identifiers=context,
)
]

return []
Expand Down Expand Up @@ -162,9 +167,9 @@ def _filter_for_poly2ds(
]


def _find_intersect_interval(
def _polylines_are_intersecting(
line1: raillabel.format.Poly2d, line2: raillabel.format.Poly2d
) -> tuple[float, float] | None:
) -> bool:
"""If the two polylines intersect anywhere, return the y interval where they intersect."""
y_values_with_points_in_either_polyline: list[float] = sorted(
_get_y_of_all_points_of_poly2d(line1).union(_get_y_of_all_points_of_poly2d(line2))
Expand All @@ -181,18 +186,18 @@ def _find_intersect_interval(
continue

if x1 == x2:
return (y, y)
return True

new_order = x1 < x2

order_has_flipped = order is not None and new_order != order and last_y is not None
if order_has_flipped:
return (last_y, y) # type: ignore # noqa: PGH003
return True

order = new_order
last_y = y

return None
return False


def _find_max_y(poly2d: raillabel.format.Poly2d) -> float:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright DB InfraGO AG and contributors
# SPDX-License-Identifier: Apache-2.0

from uuid import UUID

import pytest
from raillabel.format import Poly2d, Point2d
from raillabel.scene_builder import SceneBuilder
Expand All @@ -9,6 +11,7 @@
validate_rail_side,
_count_rails_per_track_in_frame,
)
from raillabel_providerkit.validation import Issue, IssueIdentifiers, IssueType


def test_count_rails_per_track_in_frame__empty(empty_frame):
Expand Down Expand Up @@ -155,6 +158,7 @@ def test_validate_rail_side__no_errors(ignore_uuid):


def test_validate_rail_side__rail_sides_switched(ignore_uuid):
SENSOR_ID = "rgb_center"
scene = (
SceneBuilder.empty()
.add_annotation(
Expand All @@ -169,7 +173,7 @@ def test_validate_rail_side__rail_sides_switched(ignore_uuid):
sensor_id="IGNORE_THIS",
),
object_name="track_0001",
sensor_id="rgb_center",
sensor_id=SENSOR_ID,
)
.add_annotation(
annotation=Poly2d(
Expand All @@ -183,16 +187,25 @@ def test_validate_rail_side__rail_sides_switched(ignore_uuid):
sensor_id="IGNORE_THIS",
),
object_name="track_0001",
sensor_id="rgb_center",
sensor_id=SENSOR_ID,
)
.result
)

actual = validate_rail_side(scene)
assert len(actual) == 1
assert actual == [
Issue(
type=IssueType.RAIL_SIDE,
reason="The left and right rails of this track are swapped.",
identifiers=IssueIdentifiers(
frame=1, sensor=SENSOR_ID, object=UUID("5c59aad4-0000-4000-0000-000000000000")
),
)
]


def test_validate_rail_side__rail_sides_intersect_at_top(ignore_uuid):
SENSOR_ID = "rgb_center"
scene = (
SceneBuilder.empty()
.add_annotation(
Expand All @@ -209,7 +222,7 @@ def test_validate_rail_side__rail_sides_intersect_at_top(ignore_uuid):
sensor_id="IGNORE_THIS",
),
object_name="track_0001",
sensor_id="rgb_center",
sensor_id=SENSOR_ID,
)
.add_annotation(
annotation=Poly2d(
Expand All @@ -225,13 +238,21 @@ def test_validate_rail_side__rail_sides_intersect_at_top(ignore_uuid):
sensor_id="IGNORE_THIS",
),
object_name="track_0001",
sensor_id="rgb_center",
sensor_id=SENSOR_ID,
)
.result
)

actual = validate_rail_side(scene)
assert len(actual) == 1
assert actual == [
Issue(
type=IssueType.RAIL_SIDE,
reason="The left and right rails of this track intersect.",
identifiers=IssueIdentifiers(
frame=1, sensor=SENSOR_ID, object=UUID("5c59aad4-0000-4000-0000-000000000000")
),
)
]


def test_validate_rail_side__rail_sides_correct_with_early_end_of_one_side(ignore_uuid):
Expand Down Expand Up @@ -276,6 +297,7 @@ def test_validate_rail_side__rail_sides_correct_with_early_end_of_one_side(ignor


def test_validate_rail_side__two_left_rails(ignore_uuid):
SENSOR_ID = "rgb_center"
scene = (
SceneBuilder.empty()
.add_annotation(
Expand All @@ -290,7 +312,7 @@ def test_validate_rail_side__two_left_rails(ignore_uuid):
sensor_id="IGNORE_THIS",
),
object_name="track_0001",
sensor_id="rgb_center",
sensor_id=SENSOR_ID,
)
.add_annotation(
annotation=Poly2d(
Expand All @@ -304,16 +326,25 @@ def test_validate_rail_side__two_left_rails(ignore_uuid):
sensor_id="IGNORE_THIS",
),
object_name="track_0001",
sensor_id="rgb_center",
sensor_id=SENSOR_ID,
)
.result
)

actual = validate_rail_side(scene)
assert len(actual) == 1
assert actual == [
Issue(
type=IssueType.RAIL_SIDE,
reason="This track has 2 left rails.",
identifiers=IssueIdentifiers(
frame=1, sensor=SENSOR_ID, object=UUID("5c59aad4-0000-4000-0000-000000000000")
),
)
]


def test_validate_rail_side__two_right_rails(ignore_uuid):
SENSOR_ID = "rgb_center"
scene = (
SceneBuilder.empty()
.add_annotation(
Expand All @@ -325,10 +356,10 @@ def test_validate_rail_side__two_right_rails(ignore_uuid):
closed=False,
attributes={"railSide": "rightRail"},
object_id=ignore_uuid,
sensor_id="IGNORE_THIS",
sensor_id=SENSOR_ID,
),
object_name="track_0001",
sensor_id="rgb_center",
sensor_id=SENSOR_ID,
)
.add_annotation(
annotation=Poly2d(
Expand All @@ -339,16 +370,24 @@ def test_validate_rail_side__two_right_rails(ignore_uuid):
closed=False,
attributes={"railSide": "rightRail"},
object_id=ignore_uuid,
sensor_id="IGNORE_THIS",
sensor_id=SENSOR_ID,
),
object_name="track_0001",
sensor_id="rgb_center",
sensor_id=SENSOR_ID,
)
.result
)

actual = validate_rail_side(scene)
assert len(actual) == 1
assert actual == [
Issue(
type=IssueType.RAIL_SIDE,
reason="This track has 2 right rails.",
identifiers=IssueIdentifiers(
frame=1, sensor=SENSOR_ID, object=UUID("5c59aad4-0000-4000-0000-000000000000")
),
)
]


def test_validate_rail_side__two_sensors_with_two_right_rails_each(ignore_uuid):
Expand Down Expand Up @@ -460,4 +499,4 @@ def test_validate_rail_side__two_sensors_with_one_right_rail_each(ignore_uuid):


if __name__ == "__main__":
pytest.main([__file__, "--disable-pytest-warnings", "--cache-clear", "-v"])
pytest.main([__file__, "-vv"])

0 comments on commit fb3e597

Please sign in to comment.