From 7d693699d712cd42dddd8c842690de7f4a5e6e36 Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Fri, 20 Sep 2024 13:42:02 +0200 Subject: [PATCH] Issue #150/#155 extract split_at_multiple logic from DeepGraphSplitter --- .../partitionedjobs/crossbackend.py | 90 +++++++++++-------- tests/partitionedjobs/test_crossbackend.py | 65 ++++++++++++-- 2 files changed, 111 insertions(+), 44 deletions(-) diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py index 667ed59..2be30c6 100644 --- a/src/openeo_aggregator/partitionedjobs/crossbackend.py +++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py @@ -96,7 +96,7 @@ def _default_get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) } -class _SubGraphData(NamedTuple): +class _PGSplitSubGraph(NamedTuple): """Container for result of ProcessGraphSplitterInterface.split""" split_node: NodeId @@ -109,7 +109,7 @@ class _PGSplitResult(NamedTuple): primary_node_ids: Set[NodeId] primary_backend_id: BackendId - secondary_graphs: List[_SubGraphData] + secondary_graphs: List[_PGSplitSubGraph] class ProcessGraphSplitterInterface(metaclass=abc.ABCMeta): @@ -161,7 +161,7 @@ def split(self, process_graph: FlatPG) -> _PGSplitResult: primary_has_load_collection = False primary_graph_node_ids = set() - secondary_graphs: List[_SubGraphData] = [] + secondary_graphs: List[_PGSplitSubGraph] = [] for node_id, node in process_graph.items(): if node["process_id"] == "load_collection": bid = backend_per_collection[node["arguments"]["id"]] @@ -169,7 +169,7 @@ def split(self, process_graph: FlatPG) -> _PGSplitResult: primary_graph_node_ids.add(node_id) primary_has_load_collection = True else: - secondary_graphs.append(_SubGraphData(split_node=node_id, node_ids={node_id}, backend_id=bid)) + secondary_graphs.append(_PGSplitSubGraph(split_node=node_id, node_ids={node_id}, backend_id=bid)) else: primary_graph_node_ids.add(node_id) @@ -579,7 +579,7 @@ def from_edges( def node(self, node_id: NodeId) -> _GVNode: if node_id not in self._graph: - raise GraphSplitException(f"Invalid node id {node_id}.") + raise GraphSplitException(f"Invalid node id {node_id!r}.") return self._graph[node_id] def iter_nodes(self) -> Iterator[Tuple[NodeId, _GVNode]]: @@ -712,7 +712,7 @@ def get_flow_weights(node_id: NodeId) -> Dict[NodeId, fractions.Fraction]: def split_at(self, split_node_id: NodeId) -> Tuple[_GraphViewer, _GraphViewer]: """ Split graph at given node id (must be articulation point), - creating two new graphs, containing original nodes and adaptation of the split node. + creating two new graph viewers, containing original nodes and adaptation of the split node. :return: two _GraphViewer objects: the upstream subgraph and the downstream subgraph """ @@ -729,7 +729,7 @@ def next_nodes(node_id: NodeId) -> Iterable[NodeId]: up_node_ids = set(self._walk(seeds=[split_node_id], next_nodes=next_nodes)) if split_node.flows_to.intersection(up_node_ids): - raise GraphSplitException(f"Graph can not be split at {split_node_id}: not an articulation point.") + raise GraphSplitException(f"Graph can not be split at {split_node_id!r}: not an articulation point.") up_graph = {n: self.node(n) for n in up_node_ids} # Replacement of original split node: no `flows_to` links @@ -810,6 +810,26 @@ def produce_split_locations(self, limit: int = 20) -> Iterator[List[NodeId]]: # All nodes can be handled as is, no need to split yield [] + def split_at_multiple(self, split_nodes: List[NodeId]) -> Dict[Union[NodeId, None], _GraphViewer]: + """ + Split the graph viewer at multiple nodes in the order as provided. + Each split produces an upstream and downstream graph, + the downstream graph is used for the next split, + so the split nodes should be ordered as such. + + Returns dictionary with: + - key: split node_ids or None for the final downstream graph + - value: corresponding sub graph viewers as values. + """ + result = {} + graph_to_split = self + for split_node_id in split_nodes: + up, down = graph_to_split.split_at(split_node_id=split_node_id) + result[split_node_id] = up + graph_to_split = down + result[None] = graph_to_split + return result + class DeepGraphSplitter(ProcessGraphSplitterInterface): """ @@ -820,6 +840,16 @@ def __init__(self, supporting_backends: SupportingBackendsMapper, primary_backen self._supporting_backends_mapper = supporting_backends self._primary_backend = primary_backend + def _pick_backend(self, backend_candidates: Union[frozenset[BackendId], None]) -> BackendId: + if backend_candidates is None: + if self._primary_backend: + return self._primary_backend + else: + raise GraphSplitException("DeepGraphSplitter._pick_backend: No backend candidates.") + else: + # TODO: better backend selection mechanism + return sorted(backend_candidates)[0] + def split(self, process_graph: FlatPG) -> _PGSplitResult: graph = _GraphViewer.from_flat_graph( flat_graph=process_graph, supporting_backends=self._supporting_backends_mapper @@ -828,36 +858,26 @@ def split(self, process_graph: FlatPG) -> _PGSplitResult: for split_nodes in graph.produce_split_locations(): _log.debug(f"DeepGraphSplitter.split: evaluating split nodes: {split_nodes=}") - secondary_graphs: List[_SubGraphData] = [] - graph_to_split = graph - for split_node_id in split_nodes: - up, down = graph_to_split.split_at(split_node_id=split_node_id) - # Use upstream graph as secondary graph - node_ids = set(nid for nid, _ in up.iter_nodes()) - backend_candidates = up.get_backend_candidates_for_node_set(node_ids) - # TODO: better backend selection? - # TODO handle case where backend_candidates is None? - backend_id = sorted(backend_candidates)[0] - _log.debug( - f"DeepGraphSplitter.split: secondary graph: from {split_node_id=}: {backend_id=} {node_ids=}" - ) - secondary_graphs.append( - _SubGraphData( - split_node=split_node_id, - node_ids=node_ids, - backend_id=backend_id, - ) - ) + split_views = graph.split_at_multiple(split_nodes=split_nodes) - # Prepare for next split (if any) - graph_to_split = down + # Extract nodes and backend ids for each subgraph + subgraph_node_ids = {k: set(n for n, _ in v.iter_nodes()) for k, v in split_views.items()} + subgraph_backend_ids = { + k: self._pick_backend(backend_candidates=v.get_backend_candidates_for_node_set(subgraph_node_ids[k])) + for k, v in split_views.items() + } + _log.debug(f"DeepGraphSplitter.split: {subgraph_node_ids=} {subgraph_backend_ids=}") + + # Handle primary graph + split_views.pop(None) + primary_node_ids = subgraph_node_ids.pop(None) + primary_backend_id = subgraph_backend_ids.pop(None) - # Remaining graph is primary graph - primary_graph = graph_to_split - primary_node_ids = set(n for n, _ in primary_graph.iter_nodes()) - backend_candidates = primary_graph.get_backend_candidates_for_node_set(primary_node_ids) - primary_backend_id = sorted(backend_candidates)[0] - _log.debug(f"DeepGraphSplitter.split: primary graph: {primary_backend_id=} {primary_node_ids=}") + # Handle secondary graphs + secondary_graphs = [ + _PGSplitSubGraph(split_node=k, node_ids=subgraph_node_ids[k], backend_id=subgraph_backend_ids[k]) + for k in split_views.keys() + ] if self._primary_backend is None or primary_backend_id == self._primary_backend: _log.debug(f"DeepGraphSplitter.split: current split matches constraints") diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py index 2060e17..0be8664 100644 --- a/tests/partitionedjobs/test_crossbackend.py +++ b/tests/partitionedjobs/test_crossbackend.py @@ -22,7 +22,7 @@ _GraphViewer, _GVNode, _PGSplitResult, - _SubGraphData, + _PGSplitSubGraph, run_partitioned_job, ) @@ -786,6 +786,13 @@ def test_split_at_basic(self): def test_split_at_complex(self): graph = _GraphViewer.from_edges( + # a + # / \ + # b c X + # \ / \ | + # d e f Y + # \ / + # g [("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e"), ("e", "g"), ("f", "g"), ("X", "Y")] ) up, down = graph.split_at("e") @@ -830,6 +837,44 @@ def test_split_at_non_articulation_point(self): ("c", _GVNode()), ] + def test_split_at_multiple_empty(self): + graph = _GraphViewer.from_edges([("a", "b")]) + result = graph.split_at_multiple([]) + assert {n: sorted(g.iter_nodes()) for (n, g) in result.items()} == { + None: [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))], + } + + def test_split_at_multiple_single(self): + graph = _GraphViewer.from_edges([("a", "b"), ("b", "c")]) + result = graph.split_at_multiple(["b"]) + assert {n: sorted(g.iter_nodes()) for (n, g) in result.items()} == { + "b": [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))], + None: [("b", _GVNode(flows_to="c")), ("c", _GVNode(depends_on="b"))], + } + + def test_split_at_multiple_basic(self): + graph = _GraphViewer.from_edges( + [("a", "b"), ("b", "c"), ("c", "d")], + supporting_backends_mapper=supporting_backends_from_node_id_dict({"a": "A"}), + ) + result = graph.split_at_multiple(["b", "c"]) + assert {n: sorted(g.iter_nodes()) for (n, g) in result.items()} == { + "b": [("a", _GVNode(flows_to="b", backend_candidates="A")), ("b", _GVNode(depends_on="a"))], + "c": [("b", _GVNode(flows_to="c")), ("c", _GVNode(depends_on="b"))], + None: [("c", _GVNode(flows_to="d")), ("d", _GVNode(depends_on="c"))], + } + + def test_split_at_multiple_invalid(self): + """Split nodes should be in downstream order""" + graph = _GraphViewer.from_edges( + [("a", "b"), ("b", "c"), ("c", "d")], + ) + # Downstream order: works + _ = graph.split_at_multiple(["b", "c"]) + # Upstream order: fails + with pytest.raises(GraphSplitException, match="Invalid node id 'b'"): + _ = graph.split_at_multiple(["c", "b"]) + def test_produce_split_locations_simple(self): """Simple produce_split_locations use case: no need for splits""" flat = { @@ -956,7 +1001,7 @@ def test_simple_split(self): primary_node_ids={"lc1", "lc2", "merge"}, primary_backend_id="b2", secondary_graphs=[ - _SubGraphData( + _PGSplitSubGraph( split_node="lc1", node_ids={"lc1"}, backend_id="b1", @@ -995,7 +1040,7 @@ def test_simple_deep_split(self): assert result == _PGSplitResult( primary_node_ids={"lc2", "temporal2", "bands1", "merge"}, primary_backend_id="b2", - secondary_graphs=[_SubGraphData(split_node="bands1", node_ids={"lc1", "bands1"}, backend_id="b1")], + secondary_graphs=[_PGSplitSubGraph(split_node="bands1", node_ids={"lc1", "bands1"}, backend_id="b1")], ) def test_shallow_triple_split(self): @@ -1026,8 +1071,8 @@ def test_shallow_triple_split(self): primary_node_ids={"lc1", "lc2", "lc3", "merge1", "merge2"}, primary_backend_id="b2", secondary_graphs=[ - _SubGraphData(split_node="lc1", node_ids={"lc1"}, backend_id="b1"), - _SubGraphData(split_node="lc3", node_ids={"lc3"}, backend_id="b3"), + _PGSplitSubGraph(split_node="lc1", node_ids={"lc1"}, backend_id="b1"), + _PGSplitSubGraph(split_node="lc3", node_ids={"lc3"}, backend_id="b3"), ], ) @@ -1067,16 +1112,18 @@ def test_triple_split(self): primary_node_ids={"merge2", "merge1", "lc3", "spatial3"}, primary_backend_id="b3", secondary_graphs=[ - _SubGraphData(split_node="bands1", node_ids={"bands1", "lc1"}, backend_id="b1"), - _SubGraphData(split_node="merge1", node_ids={"bands1", "merge1", "temporal2", "lc2"}, backend_id="b2"), + _PGSplitSubGraph(split_node="bands1", node_ids={"bands1", "lc1"}, backend_id="b1"), + _PGSplitSubGraph( + split_node="merge1", node_ids={"bands1", "merge1", "temporal2", "lc2"}, backend_id="b2" + ), ], ) @pytest.mark.parametrize( ["primary_backend", "secondary_graph"], [ - ("b1", _SubGraphData(split_node="lc2", node_ids={"lc2"}, backend_id="b2")), - ("b2", _SubGraphData(split_node="lc1", node_ids={"lc1"}, backend_id="b1")), + ("b1", _PGSplitSubGraph(split_node="lc2", node_ids={"lc2"}, backend_id="b2")), + ("b2", _PGSplitSubGraph(split_node="lc1", node_ids={"lc1"}, backend_id="b1")), ], ) def test_split_with_primary_backend(self, primary_backend, secondary_graph):