diff --git a/tests/test_raillabel_providerkit/convert/loader_classes/test_loader_understand_ai.py b/tests/test_raillabel_providerkit/convert/loader_classes/test_loader_understand_ai.py index 77ea33a..d3a8e06 100644 --- a/tests/test_raillabel_providerkit/convert/loader_classes/test_loader_understand_ai.py +++ b/tests/test_raillabel_providerkit/convert/loader_classes/test_loader_understand_ai.py @@ -2,6 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 import pytest +import raillabel +import raillabel.load_ +import raillabel.load_.loader_classes +import raillabel.load_.loader_classes.loader_raillabel from raillabel_providerkit.convert.loader_classes.loader_understand_ai import LoaderUnderstandAi @@ -30,5 +34,34 @@ def test_validate_schema__real_life_file__errors(json_data): assert "topic" in actual[0] +def test_load(json_data): + input_data_raillabel = remove_non_parsed_fields(json_data["openlabel_v1_short"]) + input_data_uai = json_data["understand_ai_t4_short"] + + scene_ground_truth = raillabel.load_.loader_classes.loader_raillabel.LoaderRailLabel().load(input_data_raillabel, validate=False) + scene = LoaderUnderstandAi().load(input_data_uai, validate_schema=False) + + scene.metadata = scene_ground_truth.metadata + assert scene == scene_ground_truth + +def remove_non_parsed_fields(raillabel_data: dict) -> dict: + """Return RailLabel file with frame_data and poly3ds removed.""" + + for frame in raillabel_data["openlabel"]["frames"].values(): + + if "frame_data" in frame["frame_properties"]: + del frame["frame_properties"]["frame_data"] + + for object_id, object in list(frame["objects"].items()): + if "poly3d" not in object["object_data"]: + continue + + del object["object_data"]["poly3d"] + if len(object["object_data"]) == 0: + del frame["objects"][object_id] + + return raillabel_data + + if __name__ == "__main__": - pytest.main([__file__, "--disable-pytest-warnings", "--cache-clear"]) + pytest.main([__file__, "-vv"])