Skip to content

Commit

Permalink
fix: Refactor validate_rail_side for raillabel 4
Browse files Browse the repository at this point in the history
  • Loading branch information
nalquas committed Nov 25, 2024
1 parent da7a317 commit 59d8fec
Showing 1 changed file with 14 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from __future__ import annotations

from uuid import UUID

import numpy as np
import raillabel
from raillabel.filter import (
Expand Down Expand Up @@ -31,7 +33,7 @@ def validate_rail_side(scene: raillabel.Scene) -> list[str]:
errors: list[str] = []

# Get a list of camera uids
cameras = list(scene.filter([IncludeSensorTypeFilter("camera")]).sensors.keys())
cameras = list(scene.filter([IncludeSensorTypeFilter(["camera"])]).sensors.keys())

# Check per camera
for camera in cameras:
Expand Down Expand Up @@ -123,9 +125,9 @@ def _check_rails_for_swap(
return None


def _count_rails_per_track_in_frame(frame: raillabel.format.Frame) -> dict[str, tuple[int, int]]:
def _count_rails_per_track_in_frame(frame: raillabel.format.Frame) -> dict[UUID, tuple[int, int]]:
# For each track, the left and right rail counts are stored as a list (left, right)
counts: dict[str, list[int, int]] = {}
counts: dict[UUID, list[int]] = {}

# For each track, count the left and right rails
unfiltered_annotations = list(frame.annotations.values())
Expand All @@ -148,11 +150,14 @@ def _count_rails_per_track_in_frame(frame: raillabel.format.Frame) -> dict[str,
continue

# Return results
return {key: tuple(value) for key, value in counts.items()}
return {
object_id: (object_counts[0], object_counts[1])
for object_id, object_counts in counts.items()
}


def _filter_for_poly2ds(
unfiltered_annotations: list[type[raillabel.format._ObjectAnnotation]],
unfiltered_annotations: list,
) -> list[raillabel.format.Poly2d]:
return [
annotation
Expand Down Expand Up @@ -271,9 +276,12 @@ def _find_x_by_y(y: float, poly2d: raillabel.format.Poly2d) -> float | None:


def _get_track_from_frame(
frame: raillabel.format.Frame, object_uid: str, rail_side: str
frame: raillabel.format.Frame, object_uid: UUID, rail_side: str
) -> raillabel.format.Poly2d | None:
for annotation in frame.annotations.values():
if not isinstance(annotation, raillabel.format.Poly2d):
continue

if annotation.object_id != object_uid:
continue

Expand Down

0 comments on commit 59d8fec

Please sign in to comment.