From a1c61e1a3e930f5b8bf9c24bdc718f59d5c198f2 Mon Sep 17 00:00:00 2001 From: Niklas Freund Date: Thu, 7 Nov 2024 10:12:33 +0100 Subject: [PATCH] feat: Implement filter_sensor_uids_by_type --- raillabel_providerkit/_util/_filters.py | 27 +++++++++ .../_util/test_filters.py | 59 +++++++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 raillabel_providerkit/_util/_filters.py create mode 100644 tests/test_raillabel_providerkit/_util/test_filters.py diff --git a/raillabel_providerkit/_util/_filters.py b/raillabel_providerkit/_util/_filters.py new file mode 100644 index 0000000..5b4a4b2 --- /dev/null +++ b/raillabel_providerkit/_util/_filters.py @@ -0,0 +1,27 @@ +# Copyright DB Netz AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import raillabel + + +def filter_sensor_uids_by_type( + sensors: list[raillabel.format.Sensor], sensor_type: raillabel.format.SensorType +) -> set[str]: + """Get the uids of all given sensors matching the given SensorType. + + Parameters + ---------- + sensors : list[raillabel.format.Sensor] + The sensors to filter. + sensor_type : raillabel.format.SensorType + The SensorType to filter by. + + Returns + ------- + set[str] + The list of uids of matching sensors. + + """ + return {sensor.uid for sensor in sensors if sensor.type == sensor_type} diff --git a/tests/test_raillabel_providerkit/_util/test_filters.py b/tests/test_raillabel_providerkit/_util/test_filters.py new file mode 100644 index 0000000..b9bc5ab --- /dev/null +++ b/tests/test_raillabel_providerkit/_util/test_filters.py @@ -0,0 +1,59 @@ +# Copyright DB Netz AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import raillabel + +from raillabel_providerkit._util._filters import filter_sensor_uids_by_type + + +@pytest.fixture +def sensor_types() -> list[raillabel.format.SensorType]: + return [sensor_type for sensor_type in raillabel.format.SensorType] + + +def test_filter_sensor_uids_by_type__empty(sensor_types): + sensors = [] + for sensor_type in sensor_types: + assert len(filter_sensor_uids_by_type(sensors, sensor_type)) == 0 + + +def test_filter_sensor_uids_by_type__exactly_one_match(sensor_types): + # Create a list of sensors where each sensor type occurs exactly once + sensors = [] + for i in range(len(sensor_types)): + sensors.append(raillabel.format.Sensor(uid=f"test_{i}", type=sensor_types[i])) + + # Ensure the filter works for each sensor type + for sensor_type in sensor_types: + results = filter_sensor_uids_by_type(sensors, sensor_type) + assert len(results) == 1 + # Assert the result is of correct type + matches = 0 + for sensor in sensors: + if sensor.uid in results: + assert sensor.type == sensor_type + matches += 1 + assert matches == len(results) + + +def test_filter_sensor_uids_by_type__multiple_matches(sensor_types): + # Create a list of sensors where each sensor type occurs three times + sensors = [] + i = 0 + for sensor_type in sensor_types: + for j in range(3): + sensors.append(raillabel.format.Sensor(uid=f"test_{i}", type=sensor_type)) + i += 1 + + # Ensure the filter works for each sensor type + for sensor_type in sensor_types: + results = filter_sensor_uids_by_type(sensors, sensor_type) + assert len(results) == 3 + # Assert the results are of correct type + matches = 0 + for sensor in sensors: + if sensor.uid in results: + assert sensor.type == sensor_type + matches += 1 + assert matches == len(results)