Skip to content

Commit

Permalink
test: re-implement understand ai loader load test
Browse files Browse the repository at this point in the history
  • Loading branch information
tklockau committed Oct 28, 2024
1 parent 8c82d85 commit 73b12cb
Showing 1 changed file with 34 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import raillabel
import raillabel.load_
import raillabel.load_.loader_classes
import raillabel.load_.loader_classes.loader_raillabel

from raillabel_providerkit.convert.loader_classes.loader_understand_ai import LoaderUnderstandAi

Expand Down Expand Up @@ -30,5 +34,34 @@ def test_validate_schema__real_life_file__errors(json_data):
assert "topic" in actual[0]


def test_load(json_data):
input_data_raillabel = remove_non_parsed_fields(json_data["openlabel_v1_short"])
input_data_uai = json_data["understand_ai_t4_short"]

scene_ground_truth = raillabel.load_.loader_classes.loader_raillabel.LoaderRailLabel().load(input_data_raillabel, validate=False)
scene = LoaderUnderstandAi().load(input_data_uai, validate_schema=False)

scene.metadata = scene_ground_truth.metadata
assert scene == scene_ground_truth

def remove_non_parsed_fields(raillabel_data: dict) -> dict:
"""Return RailLabel file with frame_data and poly3ds removed."""

for frame in raillabel_data["openlabel"]["frames"].values():

if "frame_data" in frame["frame_properties"]:
del frame["frame_properties"]["frame_data"]

for object_id, object in list(frame["objects"].items()):
if "poly3d" not in object["object_data"]:
continue

del object["object_data"]["poly3d"]
if len(object["object_data"]) == 0:
del frame["objects"][object_id]

return raillabel_data


if __name__ == "__main__":
pytest.main([__file__, "--disable-pytest-warnings", "--cache-clear"])
pytest.main([__file__, "-vv"])

0 comments on commit 73b12cb

Please sign in to comment.