diff --git a/src/openeo_aggregator/backend.py b/src/openeo_aggregator/backend.py index d9ebb2bb..d815c074 100644 --- a/src/openeo_aggregator/backend.py +++ b/src/openeo_aggregator/backend.py @@ -107,17 +107,15 @@ def _get_process_registry(self) -> ProcessRegistry: # TODO: fail instead of warn? _log.warning(f"Failed to get processes from {con.id}", exc_info=True) - # TODO: not only check process name, but also parameters and return type? - # TODO: return union of processes instead of intersection? - intersection = None + # TODO #4: combined set of processes: union, intersection or something else? + # TODO #4: not only check process name, but also parameters and return type? + combined_processes = {} for bid, backend_processes in processes_per_backend.items(): - if intersection is None: - intersection = backend_processes - else: - intersection = {k: v for (k, v) in intersection.items() if k in backend_processes} + # Combine by taking union (with higher preference for earlier backends) + combined_processes = {**backend_processes, **combined_processes} process_registry = ProcessRegistry() - for pid, spec in intersection.items(): + for pid, spec in combined_processes.items(): process_registry.add_spec(spec=spec) return process_registry diff --git a/tests/test_backend.py b/tests/test_backend.py index 76f77ef3..f5af6bd0 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -141,6 +141,29 @@ def test_get_process_registry(self, multi_backend_connection, backend1, backend2 catalog = AggregatorCollectionCatalog(backends=multi_backend_connection) processing = AggregatorProcessing(backends=multi_backend_connection, catalog=catalog) registry = processing.get_process_registry(api_version="1.0.0") - assert registry.get_specs() == [ + assert sorted(registry.get_specs(), key=lambda p: p["id"]) == [ + {"id": "add", "parameters": [{"name": "x"}, {"name": "y"}]}, {"id": "mean", "parameters": [{"name": "data"}]}, + {"id": "multiply", "parameters": [{"name": "x"}, {"name": "y"}]}, + ] + + def test_get_process_registry_parameter_differences( + self, multi_backend_connection, backend1, backend2, + requests_mock + ): + requests_mock.get(backend1 + "/processes", json={"processes": [ + {"id": "add", "parameters": [{"name": "x"}, {"name": "y"}]}, + {"id": "mean", "parameters": [{"name": "array"}]}, + ]}) + requests_mock.get(backend2 + "/processes", json={"processes": [ + {"id": "multiply", "parameters": [{"name": "x"}, {"name": "y"}]}, + {"id": "mean", "parameters": [{"name": "values"}]}, + ]}) + catalog = AggregatorCollectionCatalog(backends=multi_backend_connection) + processing = AggregatorProcessing(backends=multi_backend_connection, catalog=catalog) + registry = processing.get_process_registry(api_version="1.0.0") + assert sorted(registry.get_specs(), key=lambda p: p["id"]) == [ + {"id": "add", "parameters": [{"name": "x"}, {"name": "y"}]}, + {"id": "mean", "parameters": [{"name": "array"}]}, + {"id": "multiply", "parameters": [{"name": "x"}, {"name": "y"}]}, ] diff --git a/tests/test_views.py b/tests/test_views.py index f1de997f..2027a82b 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -97,7 +97,11 @@ def test_processes_basic(self, api100, requests_mock, backend1, backend2): ]}) res = api100.get("/processes").assert_status_code(200).json assert res == { - "processes": [{"id": "mean", "parameters": [{"name": "data"}]}], + "processes": [ + {"id": "multiply", "parameters": [{"name": "x"}, {"name": "y"}]}, + {"id": "mean", "parameters": [{"name": "data"}]}, + {"id": "add", "parameters": [{"name": "x"}, {"name": "y"}]}, + ], "links": [], }