From 95e3c307b73a5a1f6698bf40ed24b76f04c9b1b2 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Thu, 24 Oct 2024 15:25:02 -0700 Subject: [PATCH] nx-cugraph: faster `shortest_path` For larger graphs, nearly all the time was spent creating the dict of lists of paths. I couldn't find a better way to create these, nor could I find an approach to compute more in PLC or cupy. So, the solution in this PR is to avoid computing until needed! This now returns a `Mapping` instead of a `dict`. Will anybody care or notice that the return type isn't strictly a dict? This is currently only for unweighted bfs. We should also do weighted sssp paths. Also, this currently recurses. If we like the approach in this PR, we should update computing the paths on demand to not recurse. --- .../algorithms/shortest_paths/unweighted.py | 82 +++++++++++++------ 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/unweighted.py b/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/unweighted.py index e9c515632ca..d6bf1dd4a30 100644 --- a/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/unweighted.py +++ b/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/unweighted.py @@ -10,6 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections import itertools import cupy as cp @@ -19,7 +20,7 @@ from nx_cugraph import _nxver from nx_cugraph.convert import _to_graph -from nx_cugraph.utils import _groupby, index_dtype, networkx_algorithm +from nx_cugraph.utils import index_dtype, networkx_algorithm __all__ = [ "bidirectional_shortest_path", @@ -179,35 +180,66 @@ def _bfs( elif not reverse_path: paths.reverse() else: - if return_type == "path": - distances = distances[mask] - groups = _groupby(distances, [predecessors[mask], node_ids]) - - # `pred_node_iter` does the equivalent as these nested for loops: - # for length in range(1, len(groups)): - # preds, nodes = groups[length] - # for pred, node in zip(preds.tolist(), nodes.tolist()): - if G.key_to_id is None: - pred_node_iter = concat( - zip(*(x.tolist() for x in groups[length])) - for length in range(1, len(groups)) - ) - else: - pred_node_iter = concat( - zip(*(G._nodeiter_to_iter(x.tolist()) for x in groups[length])) - for length in range(1, len(groups)) - ) - # Consider making utility functions for creating paths - paths = {source: [source]} + key_iter = node_ids.tolist() + pred_iter = predecessors[mask].tolist() + if G.key_to_id is not None: + key_iter = G._nodeiter_to_iter(key_iter) + pred_iter = G._nodeiter_to_iter(pred_iter) + key_to_pred = dict(zip(key_iter, pred_iter)) + key_to_pred[source] = None if reverse_path: - for pred, node in pred_node_iter: - paths[node] = [node, *paths[pred]] + paths = ReversePathMapping({source: [source]}, key_to_pred) else: - for pred, node in pred_node_iter: - paths[node] = [*paths[pred], node] + paths = PathMapping({source: [source]}, key_to_pred) if return_type == "path": return paths if return_type == "length": return lengths # return_type == "length-path" return lengths, paths + + +class PathMapping(collections.abc.Mapping): + """Compute path for nodes as needed using predecessors. + + The path for each node contains itself at the beginning of tha path. + """ + + def __init__(self, data, key_to_pred): + self._data = data + self._key_to_pred = key_to_pred + + def __getitem__(self, key): + if key not in self._data: + val = self._data[key] = [*self[self._key_to_pred[key]], key] + return val + return self._data[key] + + def __iter__(self): + return iter(self._key_to_pred) + + def __len__(self): + return len(self._key_to_pred) + + +class ReversePathMapping(collections.abc.Mapping): + """Compute path for nodes as needed using predecessors. + + The path for each node contains itself at the end of tha path. + """ + + def __init__(self, data, key_to_pred): + self._data = data + self._key_to_pred = key_to_pred + + def __getitem__(self, key): + if key not in self._data: + val = self._data[key] = [key, *self[self._key_to_pred[key]]] + return val + return self._data[key] + + def __iter__(self): + return iter(self._key_to_pred) + + def __len__(self): + return len(self._key_to_pred)