From 6c88281b6a58bac05a2c29e5c9ab6b952d3a38c6 Mon Sep 17 00:00:00 2001 From: Rick Ratzel <3039903+rlratzel@users.noreply.github.com> Date: Wed, 13 Mar 2024 09:16:46 -0500 Subject: [PATCH 1/3] Adds nx-cugraph benchmarks for APIs added to prior releases that were never benchmarked (#4228) * Adds benchmark for nx-cugraph `pagerank` with a personalization dict (see results below) * Adds several placeholder benchmarks for other nx-cugraph APIs added to prior releases that were never benchmarked. ![image](https://github.com/rapidsai/cugraph/assets/3039903/4692e2a2-e14a-489d-84f7-772eda6fc316) Authors: - Rick Ratzel (https://github.com/rlratzel) - Ralph Liu (https://github.com/nv-rliu) Approvers: - Erik Welch (https://github.com/eriknw) URL: https://github.com/rapidsai/cugraph/pull/4228 --- .../nx-cugraph/pytest-based/bench_algos.py | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/benchmarks/nx-cugraph/pytest-based/bench_algos.py b/benchmarks/nx-cugraph/pytest-based/bench_algos.py index 97eb32e2aaa..3b085a9bfdb 100644 --- a/benchmarks/nx-cugraph/pytest-based/bench_algos.py +++ b/benchmarks/nx-cugraph/pytest-based/bench_algos.py @@ -242,6 +242,28 @@ def get_highest_degree_node(graph_obj): return max(degrees, key=lambda t: t[1])[0] +def build_personalization_dict(pagerank_dict): + """ + Returns a dictionary that can be used as the personalization value for a + call to nx.pagerank(). The pagerank_dict passed in is used as the initial + source of values for each node, and this function simply treats the list of + dict values as two halves (halves A and B) and swaps them so (most if not + all) nodes/keys are assigned a different value from the dictionary. + """ + num_half = len(pagerank_dict) // 2 + A_half_items = list(pagerank_dict.items())[:num_half] + B_half_items = list(pagerank_dict.items())[num_half:] + + # Support an odd number of items by initializing with B_half_items, which + # will always be one bigger if the number of items is odd. This will leave + # the one remainder (in the case of an odd number) unchanged. + pers_dict = dict(B_half_items) + pers_dict.update({A_half_items[i][0]: B_half_items[i][1] for i in range(num_half)}) + pers_dict.update({B_half_items[i][0]: A_half_items[i][1] for i in range(num_half)}) + + return pers_dict + + ################################################################################ # Benchmarks def bench_from_networkx(benchmark, graph_obj): @@ -431,6 +453,26 @@ def bench_pagerank(benchmark, graph_obj, backend_wrapper): assert type(result) is dict +def bench_pagerank_personalized(benchmark, graph_obj, backend_wrapper): + G = get_graph_obj_for_benchmark(graph_obj, backend_wrapper) + + # FIXME: This will run for every combination of inputs, even if the + # graph/dataset does not change. Ideally this is run once per + # graph/dataset. + pagerank_dict = nx.pagerank(G) + personalization_dict = build_personalization_dict(pagerank_dict) + + result = benchmark.pedantic( + target=backend_wrapper(nx.pagerank), + args=(G,), + kwargs={"personalization": personalization_dict}, + rounds=rounds, + iterations=iterations, + warmup_rounds=warmup_rounds, + ) + assert type(result) is dict + + def bench_single_source_shortest_path_length(benchmark, graph_obj, backend_wrapper): G = get_graph_obj_for_benchmark(graph_obj, backend_wrapper) node = get_highest_degree_node(graph_obj) @@ -804,3 +846,73 @@ def bench_weakly_connected_components(benchmark, graph_obj, backend_wrapper): warmup_rounds=warmup_rounds, ) assert type(result) is list + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_complete_bipartite_graph(benchmark, graph_obj, backend_wrapper): + pass + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_connected_components(benchmark, graph_obj, backend_wrapper): + pass + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_is_connected(benchmark, graph_obj, backend_wrapper): + pass + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_node_connected_component(benchmark, graph_obj, backend_wrapper): + pass + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_number_connected_components(benchmark, graph_obj, backend_wrapper): + pass + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_is_isolate(benchmark, graph_obj, backend_wrapper): + pass + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_isolates(benchmark, graph_obj, backend_wrapper): + pass + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_number_of_isolates(benchmark, graph_obj, backend_wrapper): + pass + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_complement(benchmark, graph_obj, backend_wrapper): + pass + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_reverse(benchmark, graph_obj, backend_wrapper): + pass + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_is_arborescence(benchmark, graph_obj, backend_wrapper): + pass + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_is_branching(benchmark, graph_obj, backend_wrapper): + pass + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_is_forest(benchmark, graph_obj, backend_wrapper): + pass + + +@pytest.mark.skip(reason="benchmark not implemented") +def bench_is_tree(benchmark, graph_obj, backend_wrapper): + pass From 6b28aefcabbe10a44de38bd1b3e54f3c717dd559 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Wed, 13 Mar 2024 09:18:37 -0500 Subject: [PATCH 2/3] nx-cugraph: add more shortest path algorithms (#4199) This begins by adding more unweighted shortest path algorithms. Next we'll do weighted via `sssp`, then generic. Note that there are some performance improvements that can be made: - add bidirectional search between source and target - for `bidirectional_shortest_path` and `has_path` - alternatively, perform `bfs` from `source` until `target` is reached - run `all_pairs*` in batched groups Authors: - Erik Welch (https://github.com/eriknw) - Ralph Liu (https://github.com/nv-rliu) Approvers: - Brad Rees (https://github.com/BradReesWork) - Rick Ratzel (https://github.com/rlratzel) - Don Acosta (https://github.com/acostadon) URL: https://github.com/rapidsai/cugraph/pull/4199 --- python/nx-cugraph/README.md | 25 +- python/nx-cugraph/_nx_cugraph/__init__.py | 52 +++- python/nx-cugraph/lint.yaml | 4 +- .../nx_cugraph/algorithms/__init__.py | 2 +- .../algorithms/bipartite/__init__.py | 1 - .../nx_cugraph/algorithms/bipartite/basic.py | 31 -- .../algorithms/centrality/eigenvector.py | 9 +- .../nx_cugraph/algorithms/centrality/katz.py | 9 +- .../algorithms/link_analysis/hits_alg.py | 9 +- .../algorithms/link_analysis/pagerank_alg.py | 7 +- .../algorithms/shortest_paths/__init__.py | 4 +- .../algorithms/shortest_paths/generic.py | 165 ++++++++++ .../algorithms/shortest_paths/unweighted.py | 174 ++++++++++- .../algorithms/shortest_paths/weighted.py | 286 ++++++++++++++++++ python/nx-cugraph/nx_cugraph/interface.py | 14 + python/nx-cugraph/nx_cugraph/utils/misc.py | 14 +- python/nx-cugraph/scripts/update_readme.py | 0 17 files changed, 723 insertions(+), 83 deletions(-) delete mode 100644 python/nx-cugraph/nx_cugraph/algorithms/bipartite/basic.py create mode 100644 python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/generic.py create mode 100644 python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/weighted.py mode change 100644 => 100755 python/nx-cugraph/scripts/update_readme.py diff --git a/python/nx-cugraph/README.md b/python/nx-cugraph/README.md index 8201dc34eb2..1bf310c8c88 100644 --- a/python/nx-cugraph/README.md +++ b/python/nx-cugraph/README.md @@ -95,8 +95,6 @@ Below is the list of algorithms that are currently supported in nx-cugraph.
 bipartite
- ├─ basic
- │   └─ is_bipartite
  └─ generators
      └─ complete_bipartite_graph
 centrality
@@ -152,9 +150,26 @@ Below is the list of algorithms that are currently supported in nx-cugraph.
  ├─ overall_reciprocity
  └─ reciprocity
 shortest_paths
- └─ unweighted
-     ├─ single_source_shortest_path_length
-     └─ single_target_shortest_path_length
+ ├─ generic
+ │   ├─ has_path
+ │   ├─ shortest_path
+ │   └─ shortest_path_length
+ ├─ unweighted
+ │   ├─ all_pairs_shortest_path
+ │   ├─ all_pairs_shortest_path_length
+ │   ├─ bidirectional_shortest_path
+ │   ├─ single_source_shortest_path
+ │   ├─ single_source_shortest_path_length
+ │   ├─ single_target_shortest_path
+ │   └─ single_target_shortest_path_length
+ └─ weighted
+     ├─ all_pairs_bellman_ford_path
+     ├─ all_pairs_bellman_ford_path_length
+     ├─ bellman_ford_path
+     ├─ bellman_ford_path_length
+     ├─ single_source_bellman_ford
+     ├─ single_source_bellman_ford_path
+     └─ single_source_bellman_ford_path_length
 traversal
  └─ breadth_first_search
      ├─ bfs_edges
diff --git a/python/nx-cugraph/_nx_cugraph/__init__.py b/python/nx-cugraph/_nx_cugraph/__init__.py
index b2f13d25ff3..bc7f63fcd49 100644
--- a/python/nx-cugraph/_nx_cugraph/__init__.py
+++ b/python/nx-cugraph/_nx_cugraph/__init__.py
@@ -33,15 +33,22 @@
     # "description": "TODO",
     "functions": {
         # BEGIN: functions
+        "all_pairs_bellman_ford_path",
+        "all_pairs_bellman_ford_path_length",
+        "all_pairs_shortest_path",
+        "all_pairs_shortest_path_length",
         "ancestors",
         "average_clustering",
         "barbell_graph",
+        "bellman_ford_path",
+        "bellman_ford_path_length",
         "betweenness_centrality",
         "bfs_edges",
         "bfs_layers",
         "bfs_predecessors",
         "bfs_successors",
         "bfs_tree",
+        "bidirectional_shortest_path",
         "bull_graph",
         "caveman_graph",
         "chvatal_graph",
@@ -70,6 +77,7 @@
         "from_scipy_sparse_array",
         "frucht_graph",
         "generic_bfs_edges",
+        "has_path",
         "heawood_graph",
         "hits",
         "house_graph",
@@ -77,7 +85,6 @@
         "icosahedral_graph",
         "in_degree_centrality",
         "is_arborescence",
-        "is_bipartite",
         "is_branching",
         "is_connected",
         "is_forest",
@@ -110,7 +117,14 @@
         "reciprocity",
         "reverse",
         "sedgewick_maze_graph",
+        "shortest_path",
+        "shortest_path_length",
+        "single_source_bellman_ford",
+        "single_source_bellman_ford_path",
+        "single_source_bellman_ford_path_length",
+        "single_source_shortest_path",
         "single_source_shortest_path_length",
+        "single_target_shortest_path",
         "single_target_shortest_path_length",
         "star_graph",
         "tadpole_graph",
@@ -128,7 +142,11 @@
     },
     "additional_docs": {
         # BEGIN: additional_docs
+        "all_pairs_bellman_ford_path": "Negative cycles are not yet supported. ``NotImplementedError`` will be raised if there are negative edge weights. We plan to support negative edge weights soon. Also, callable ``weight`` argument is not supported.",
+        "all_pairs_bellman_ford_path_length": "Negative cycles are not yet supported. ``NotImplementedError`` will be raised if there are negative edge weights. We plan to support negative edge weights soon. Also, callable ``weight`` argument is not supported.",
         "average_clustering": "Directed graphs and `weight` parameter are not yet supported.",
+        "bellman_ford_path": "Negative cycles are not yet supported. ``NotImplementedError`` will be raised if there are negative edge weights. We plan to support negative edge weights soon. Also, callable ``weight`` argument is not supported.",
+        "bellman_ford_path_length": "Negative cycles are not yet supported. ``NotImplementedError`` will be raised if there are negative edge weights. We plan to support negative edge weights soon. Also, callable ``weight`` argument is not supported.",
         "betweenness_centrality": "`weight` parameter is not yet supported, and RNG with seed may be different.",
         "bfs_edges": "`sort_neighbors` parameter is not yet supported.",
         "bfs_predecessors": "`sort_neighbors` parameter is not yet supported.",
@@ -147,11 +165,28 @@
         "katz_centrality": "`nstart` isn't used (but is checked), and `normalized=False` is not supported.",
         "louvain_communities": "`seed` parameter is currently ignored, and self-loops are not yet supported.",
         "pagerank": "`dangling` parameter is not supported, but it is checked for validity.",
+        "shortest_path": "Negative weights are not yet supported, and method is ununsed.",
+        "shortest_path_length": "Negative weights are not yet supported, and method is ununsed.",
+        "single_source_bellman_ford": "Negative cycles are not yet supported. ``NotImplementedError`` will be raised if there are negative edge weights. We plan to support negative edge weights soon. Also, callable ``weight`` argument is not supported.",
+        "single_source_bellman_ford_path": "Negative cycles are not yet supported. ``NotImplementedError`` will be raised if there are negative edge weights. We plan to support negative edge weights soon. Also, callable ``weight`` argument is not supported.",
+        "single_source_bellman_ford_path_length": "Negative cycles are not yet supported. ``NotImplementedError`` will be raised if there are negative edge weights. We plan to support negative edge weights soon. Also, callable ``weight`` argument is not supported.",
         "transitivity": "Directed graphs are not yet supported.",
         # END: additional_docs
     },
     "additional_parameters": {
         # BEGIN: additional_parameters
+        "all_pairs_bellman_ford_path": {
+            "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
+        },
+        "all_pairs_bellman_ford_path_length": {
+            "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
+        },
+        "bellman_ford_path": {
+            "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
+        },
+        "bellman_ford_path_length": {
+            "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
+        },
         "eigenvector_centrality": {
             "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
         },
@@ -169,6 +204,21 @@
         "pagerank": {
             "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
         },
+        "shortest_path": {
+            "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
+        },
+        "shortest_path_length": {
+            "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
+        },
+        "single_source_bellman_ford": {
+            "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
+        },
+        "single_source_bellman_ford_path": {
+            "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
+        },
+        "single_source_bellman_ford_path_length": {
+            "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
+        },
         # END: additional_parameters
     },
 }
diff --git a/python/nx-cugraph/lint.yaml b/python/nx-cugraph/lint.yaml
index fdd24861da7..3239fa151d9 100644
--- a/python/nx-cugraph/lint.yaml
+++ b/python/nx-cugraph/lint.yaml
@@ -50,7 +50,7 @@ repos:
       - id: black
       # - id: black-jupyter
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.2.2
+    rev: v0.3.2
     hooks:
       - id: ruff
         args: [--fix-only, --show-fixes]  # --unsafe-fixes]
@@ -77,7 +77,7 @@ repos:
         additional_dependencies: [tomli]
         files: ^(nx_cugraph|docs)/
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.2.2
+    rev: v0.3.2
     hooks:
       - id: ruff
   - repo: https://github.com/pre-commit/pre-commit-hooks
diff --git a/python/nx-cugraph/nx_cugraph/algorithms/__init__.py b/python/nx-cugraph/nx_cugraph/algorithms/__init__.py
index 7aafa85f5b7..b4a10bcf0a1 100644
--- a/python/nx-cugraph/nx_cugraph/algorithms/__init__.py
+++ b/python/nx-cugraph/nx_cugraph/algorithms/__init__.py
@@ -22,7 +22,7 @@
     traversal,
     tree,
 )
-from .bipartite import complete_bipartite_graph, is_bipartite
+from .bipartite import complete_bipartite_graph
 from .centrality import *
 from .cluster import *
 from .components import *
diff --git a/python/nx-cugraph/nx_cugraph/algorithms/bipartite/__init__.py b/python/nx-cugraph/nx_cugraph/algorithms/bipartite/__init__.py
index e028299c675..bfc7f1d4d42 100644
--- a/python/nx-cugraph/nx_cugraph/algorithms/bipartite/__init__.py
+++ b/python/nx-cugraph/nx_cugraph/algorithms/bipartite/__init__.py
@@ -10,5 +10,4 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .basic import *
 from .generators import *
diff --git a/python/nx-cugraph/nx_cugraph/algorithms/bipartite/basic.py b/python/nx-cugraph/nx_cugraph/algorithms/bipartite/basic.py
deleted file mode 100644
index 46c6b54075b..00000000000
--- a/python/nx-cugraph/nx_cugraph/algorithms/bipartite/basic.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import cupy as cp
-
-from nx_cugraph.algorithms.cluster import _triangles
-from nx_cugraph.convert import _to_graph
-from nx_cugraph.utils import networkx_algorithm
-
-__all__ = [
-    "is_bipartite",
-]
-
-
-@networkx_algorithm(version_added="24.02", _plc="triangle_count")
-def is_bipartite(G):
-    G = _to_graph(G)
-    # Counting triangles may not be the fastest way to do this, but it is simple.
-    node_ids, triangles, is_single_node = _triangles(
-        G, None, symmetrize="union" if G.is_directed() else None
-    )
-    return int(cp.count_nonzero(triangles)) == 0
diff --git a/python/nx-cugraph/nx_cugraph/algorithms/centrality/eigenvector.py b/python/nx-cugraph/nx_cugraph/algorithms/centrality/eigenvector.py
index 65a8633667a..c32b6fbb708 100644
--- a/python/nx-cugraph/nx_cugraph/algorithms/centrality/eigenvector.py
+++ b/python/nx-cugraph/nx_cugraph/algorithms/centrality/eigenvector.py
@@ -36,17 +36,12 @@ def eigenvector_centrality(
     G, max_iter=100, tol=1.0e-6, nstart=None, weight=None, *, dtype=None
 ):
     """`nstart` parameter is not used, but it is checked for validity."""
-    G = _to_graph(G, weight, np.float32)
+    G = _to_graph(G, weight, 1, np.float32)
     if len(G) == 0:
         raise nx.NetworkXPointlessConcept(
             "cannot compute centrality for the null graph"
         )
-    if dtype is not None:
-        dtype = _get_float_dtype(dtype)
-    elif weight in G.edge_values:
-        dtype = _get_float_dtype(G.edge_values[weight].dtype)
-    else:
-        dtype = np.float32
+    dtype = _get_float_dtype(dtype, graph=G, weight=weight)
     if nstart is not None:
         # Check if given nstart is valid even though we don't use it
         nstart = G._dict_to_nodearray(nstart, dtype=dtype)
diff --git a/python/nx-cugraph/nx_cugraph/algorithms/centrality/katz.py b/python/nx-cugraph/nx_cugraph/algorithms/centrality/katz.py
index 4a0684f72ee..1c6ed61703d 100644
--- a/python/nx-cugraph/nx_cugraph/algorithms/centrality/katz.py
+++ b/python/nx-cugraph/nx_cugraph/algorithms/centrality/katz.py
@@ -49,15 +49,10 @@ def katz_centrality(
         # Redundant with the `_can_run` check below when being dispatched by NetworkX,
         # but we raise here in case this funcion is called directly.
         raise NotImplementedError("normalized=False is not supported.")
-    G = _to_graph(G, weight, np.float32)
+    G = _to_graph(G, weight, 1, np.float32)
     if (N := len(G)) == 0:
         return {}
-    if dtype is not None:
-        dtype = _get_float_dtype(dtype)
-    elif weight in G.edge_values:
-        dtype = _get_float_dtype(G.edge_values[weight].dtype)
-    else:
-        dtype = np.float32
+    dtype = _get_float_dtype(dtype, graph=G, weight=weight)
     if nstart is not None:
         # Check if given nstart is valid even though we don't use it
         nstart = G._dict_to_nodearray(nstart, 0, dtype)
diff --git a/python/nx-cugraph/nx_cugraph/algorithms/link_analysis/hits_alg.py b/python/nx-cugraph/nx_cugraph/algorithms/link_analysis/hits_alg.py
index e61a931c069..e529b83ab1a 100644
--- a/python/nx-cugraph/nx_cugraph/algorithms/link_analysis/hits_alg.py
+++ b/python/nx-cugraph/nx_cugraph/algorithms/link_analysis/hits_alg.py
@@ -46,15 +46,10 @@ def hits(
     weight="weight",
     dtype=None,
 ):
-    G = _to_graph(G, weight, np.float32)
+    G = _to_graph(G, weight, 1, np.float32)
     if (N := len(G)) == 0:
         return {}, {}
-    if dtype is not None:
-        dtype = _get_float_dtype(dtype)
-    elif weight in G.edge_values:
-        dtype = _get_float_dtype(G.edge_values[weight].dtype)
-    else:
-        dtype = np.float32
+    dtype = _get_float_dtype(dtype, graph=G, weight=weight)
     if nstart is not None:
         nstart = G._dict_to_nodearray(nstart, 0, dtype)
     if max_iter <= 0:
diff --git a/python/nx-cugraph/nx_cugraph/algorithms/link_analysis/pagerank_alg.py b/python/nx-cugraph/nx_cugraph/algorithms/link_analysis/pagerank_alg.py
index 40224e91d57..41203a2bc22 100644
--- a/python/nx-cugraph/nx_cugraph/algorithms/link_analysis/pagerank_alg.py
+++ b/python/nx-cugraph/nx_cugraph/algorithms/link_analysis/pagerank_alg.py
@@ -48,12 +48,7 @@ def pagerank(
     G = _to_graph(G, weight, 1, np.float32)
     if (N := len(G)) == 0:
         return {}
-    if dtype is not None:
-        dtype = _get_float_dtype(dtype)
-    elif weight in G.edge_values:
-        dtype = _get_float_dtype(G.edge_values[weight].dtype)
-    else:
-        dtype = np.float32
+    dtype = _get_float_dtype(dtype, graph=G, weight=weight)
     if nstart is not None:
         nstart = G._dict_to_nodearray(nstart, 0, dtype=dtype)
         if (total := nstart.sum()) == 0:
diff --git a/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/__init__.py b/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/__init__.py
index b7d6b742176..9d87389a98e 100644
--- a/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/__init__.py
+++ b/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -10,4 +10,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from .generic import *
 from .unweighted import *
+from .weighted import *
diff --git a/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/generic.py b/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/generic.py
new file mode 100644
index 00000000000..68dbbace93d
--- /dev/null
+++ b/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/generic.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import networkx as nx
+import numpy as np
+
+import nx_cugraph as nxcg
+from nx_cugraph.convert import _to_graph
+from nx_cugraph.utils import _dtype_param, _get_float_dtype, networkx_algorithm
+
+from .unweighted import _bfs
+from .weighted import _sssp
+
+__all__ = [
+    "shortest_path",
+    "shortest_path_length",
+    "has_path",
+]
+
+
+@networkx_algorithm(version_added="24.04", _plc="bfs")
+def has_path(G, source, target):
+    # TODO PERF: make faster in core
+    try:
+        nxcg.bidirectional_shortest_path(G, source, target)
+    except nx.NetworkXNoPath:
+        return False
+    return True
+
+
+@networkx_algorithm(
+    extra_params=_dtype_param, version_added="24.04", _plc={"bfs", "sssp"}
+)
+def shortest_path(
+    G, source=None, target=None, weight=None, method="dijkstra", *, dtype=None
+):
+    """Negative weights are not yet supported, and method is ununsed."""
+    if method not in {"dijkstra", "bellman-ford"}:
+        raise ValueError(f"method not supported: {method}")
+    if weight is None:
+        method = "unweighted"
+    if source is None:
+        if target is None:
+            # All pairs
+            if method == "unweighted":
+                paths = nxcg.all_pairs_shortest_path(G)
+            else:
+                # method == "dijkstra":
+                # method == 'bellman-ford':
+                paths = nxcg.all_pairs_bellman_ford_path(G, weight=weight, dtype=dtype)
+            if nx.__version__[:3] <= "3.4":
+                paths = dict(paths)
+        # To target
+        elif method == "unweighted":
+            paths = nxcg.single_target_shortest_path(G, target)
+        else:
+            # method == "dijkstra":
+            # method == 'bellman-ford':
+            # XXX: it seems weird that `reverse_path=True` is necessary here
+            G = _to_graph(G, weight, 1, np.float32)
+            dtype = _get_float_dtype(dtype, graph=G, weight=weight)
+            paths = _sssp(
+                G, target, weight, return_type="path", dtype=dtype, reverse_path=True
+            )
+    elif target is None:
+        # From source
+        if method == "unweighted":
+            paths = nxcg.single_source_shortest_path(G, source)
+        else:
+            # method == "dijkstra":
+            # method == 'bellman-ford':
+            paths = nxcg.single_source_bellman_ford_path(
+                G, source, weight=weight, dtype=dtype
+            )
+    # From source to target
+    elif method == "unweighted":
+        paths = nxcg.bidirectional_shortest_path(G, source, target)
+    else:
+        # method == "dijkstra":
+        # method == 'bellman-ford':
+        paths = nxcg.bellman_ford_path(G, source, target, weight, dtype=dtype)
+    return paths
+
+
+@shortest_path._can_run
+def _(G, source=None, target=None, weight=None, method="dijkstra", *, dtype=None):
+    return (
+        weight is None
+        or not callable(weight)
+        and not nx.is_negatively_weighted(G, weight=weight)
+    )
+
+
+@networkx_algorithm(
+    extra_params=_dtype_param, version_added="24.04", _plc={"bfs", "sssp"}
+)
+def shortest_path_length(
+    G, source=None, target=None, weight=None, method="dijkstra", *, dtype=None
+):
+    """Negative weights are not yet supported, and method is ununsed."""
+    if method not in {"dijkstra", "bellman-ford"}:
+        raise ValueError(f"method not supported: {method}")
+    if weight is None:
+        method = "unweighted"
+    if source is None:
+        if target is None:
+            # All pairs
+            if method == "unweighted":
+                lengths = nxcg.all_pairs_shortest_path_length(G)
+            else:
+                # method == "dijkstra":
+                # method == 'bellman-ford':
+                lengths = nxcg.all_pairs_bellman_ford_path_length(
+                    G, weight=weight, dtype=dtype
+                )
+        # To target
+        elif method == "unweighted":
+            lengths = nxcg.single_target_shortest_path_length(G, target)
+            if nx.__version__[:3] <= "3.4":
+                lengths = dict(lengths)
+        else:
+            # method == "dijkstra":
+            # method == 'bellman-ford':
+            lengths = nxcg.single_source_bellman_ford_path_length(
+                G, target, weight=weight, dtype=dtype
+            )
+    elif target is None:
+        # From source
+        if method == "unweighted":
+            lengths = nxcg.single_source_shortest_path_length(G, source)
+        else:
+            # method == "dijkstra":
+            # method == 'bellman-ford':
+            lengths = dict(
+                nxcg.single_source_bellman_ford_path_length(
+                    G, source, weight=weight, dtype=dtype
+                )
+            )
+    # From source to target
+    elif method == "unweighted":
+        G = _to_graph(G)
+        lengths = _bfs(G, source, None, "Source", return_type="length", target=target)
+    else:
+        # method == "dijkstra":
+        # method == 'bellman-ford':
+        lengths = nxcg.bellman_ford_path_length(G, source, target, weight, dtype=dtype)
+    return lengths
+
+
+@shortest_path_length._can_run
+def _(G, source=None, target=None, weight=None, method="dijkstra", *, dtype=None):
+    return (
+        weight is None
+        or not callable(weight)
+        and not nx.is_negatively_weighted(G, weight=weight)
+    )
diff --git a/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/unweighted.py b/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/unweighted.py
index 2012495953e..714289c5b4b 100644
--- a/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/unweighted.py
+++ b/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/unweighted.py
@@ -10,33 +10,127 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
+
 import cupy as cp
 import networkx as nx
 import numpy as np
 import pylibcugraph as plc
 
 from nx_cugraph.convert import _to_graph
-from nx_cugraph.utils import index_dtype, networkx_algorithm
+from nx_cugraph.utils import _groupby, index_dtype, networkx_algorithm
+
+__all__ = [
+    "bidirectional_shortest_path",
+    "single_source_shortest_path",
+    "single_source_shortest_path_length",
+    "single_target_shortest_path",
+    "single_target_shortest_path_length",
+    "all_pairs_shortest_path",
+    "all_pairs_shortest_path_length",
+]
 
-__all__ = ["single_source_shortest_path_length", "single_target_shortest_path_length"]
+concat = itertools.chain.from_iterable
 
 
 @networkx_algorithm(version_added="23.12", _plc="bfs")
 def single_source_shortest_path_length(G, source, cutoff=None):
-    return _single_shortest_path_length(G, source, cutoff, "Source")
+    G = _to_graph(G)
+    return _bfs(G, source, cutoff, "Source", return_type="length")
 
 
 @networkx_algorithm(version_added="23.12", _plc="bfs")
 def single_target_shortest_path_length(G, target, cutoff=None):
-    return _single_shortest_path_length(G, target, cutoff, "Target")
+    G = _to_graph(G)
+    rv = _bfs(G, target, cutoff, "Target", return_type="length")
+    if nx.__version__[:3] <= "3.4":
+        return iter(rv.items())
+    return rv
+
+
+@networkx_algorithm(version_added="24.04", _plc="bfs")
+def all_pairs_shortest_path_length(G, cutoff=None):
+    # TODO PERF: batched bfs to compute many at once
+    G = _to_graph(G)
+    for n in G:
+        yield (n, _bfs(G, n, cutoff, "Source", return_type="length"))
 
 
-def _single_shortest_path_length(G, source, cutoff, kind):
+@networkx_algorithm(version_added="24.04", _plc="bfs")
+def bidirectional_shortest_path(G, source, target):
+    # TODO PERF: do bidirectional traversal in core
     G = _to_graph(G)
+    if source not in G or target not in G:
+        raise nx.NodeNotFound(f"Either source {source} or target {target} is not in G")
+    return _bfs(G, source, None, "Source", return_type="path", target=target)
+
+
+@networkx_algorithm(version_added="24.04", _plc="bfs")
+def single_source_shortest_path(G, source, cutoff=None):
+    G = _to_graph(G)
+    return _bfs(G, source, cutoff, "Source", return_type="path")
+
+
+@networkx_algorithm(version_added="24.04", _plc="bfs")
+def single_target_shortest_path(G, target, cutoff=None):
+    G = _to_graph(G)
+    return _bfs(G, target, cutoff, "Target", return_type="path", reverse_path=True)
+
+
+@networkx_algorithm(version_added="24.04", _plc="bfs")
+def all_pairs_shortest_path(G, cutoff=None):
+    # TODO PERF: batched bfs to compute many at once
+    G = _to_graph(G)
+    for n in G:
+        yield (n, _bfs(G, n, cutoff, "Source", return_type="path"))
+
+
+def _bfs(
+    G, source, cutoff, kind, *, return_type, reverse_path=False, target=None, scale=None
+):
+    """BFS for unweighted shortest path algorithms.
+
+    Parameters
+    ----------
+    source: node label
+
+    cutoff: int, optional
+
+    kind: {"Source", "Target"}
+
+    return_type: {"length", "path", "length-path"}
+
+    reverse_path: bool
+
+    target: node label
+
+    scale: int or float, optional
+        The amount to scale the lengths
+    """
+    # DRY: _sssp in weighted.py has similar code
     if source not in G:
-        raise nx.NodeNotFound(f"{kind} {source} is not in G")
-    if G.src_indices.size == 0:
-        return {source: 0}
+        # Different message to pass networkx tests
+        if return_type == "length":
+            raise nx.NodeNotFound(f"{kind} {source} is not in G")
+        raise nx.NodeNotFound(f"{kind} {source} not in G")
+    if target is not None:
+        if source == target or cutoff is not None and cutoff <= 0:
+            if return_type == "path":
+                return [source]
+            if return_type == "length":
+                return 0
+            # return_type == "length-path"
+            return 0, [source]
+        if target not in G or G.src_indices.size == 0:
+            raise nx.NetworkXNoPath(f"Node {target} not reachable from {source}")
+    elif G.src_indices.size == 0 or cutoff is not None and cutoff <= 0:
+        if return_type == "path":
+            return {source: [source]}
+        if return_type == "length":
+            return {source: 0}
+        # return_type == "length-path"
+        return {source: 0}, {source: [source]}
+
     if cutoff is None:
         cutoff = -1
     src_index = source if G.key_to_id is None else G.key_to_id[source]
@@ -46,8 +140,68 @@ def _single_shortest_path_length(G, source, cutoff, kind):
         sources=cp.array([src_index], index_dtype),
         direction_optimizing=False,  # True for undirected only; what's recommended?
         depth_limit=cutoff,
-        compute_predecessors=False,
+        compute_predecessors=return_type != "length",
         do_expensive_check=False,
     )
     mask = distances != np.iinfo(distances.dtype).max
-    return G._nodearrays_to_dict(node_ids[mask], distances[mask])
+    node_ids = node_ids[mask]
+    if return_type != "path":
+        lengths = distances = distances[mask]
+        if scale is not None:
+            lengths = scale * lengths
+        lengths = G._nodearrays_to_dict(node_ids, lengths)
+        if target is not None:
+            if target not in lengths:
+                raise nx.NetworkXNoPath(f"Node {target} not reachable from {source}")
+            lengths = lengths[target]
+    if return_type != "length":
+        if target is not None:
+            d = dict(zip(node_ids.tolist(), predecessors[mask].tolist()))
+            dst_index = target if G.key_to_id is None else G.key_to_id[target]
+            if dst_index not in d:
+                raise nx.NetworkXNoPath(f"Node {target} not reachable from {source}")
+            cur = dst_index
+            paths = [dst_index]
+            while cur != src_index:
+                cur = d[cur]
+                paths.append(cur)
+            if (id_to_key := G.id_to_key) is not None:
+                if reverse_path:
+                    paths = [id_to_key[cur] for cur in paths]
+                else:
+                    paths = [id_to_key[cur] for cur in reversed(paths)]
+            elif not reverse_path:
+                paths.reverse()
+        else:
+            if return_type == "path":
+                distances = distances[mask]
+            groups = _groupby(distances, [predecessors[mask], node_ids])
+
+            # `pred_node_iter` does the equivalent as these nested for loops:
+            # for length in range(1, len(groups)):
+            #     preds, nodes = groups[length]
+            #     for pred, node in zip(preds.tolist(), nodes.tolist()):
+            if G.key_to_id is None:
+                pred_node_iter = concat(
+                    zip(*(x.tolist() for x in groups[length]))
+                    for length in range(1, len(groups))
+                )
+            else:
+                pred_node_iter = concat(
+                    zip(*(G._nodeiter_to_iter(x.tolist()) for x in groups[length]))
+                    for length in range(1, len(groups))
+                )
+            # Consider making utility functions for creating paths
+            paths = {source: [source]}
+            if reverse_path:
+                for pred, node in pred_node_iter:
+                    paths[node] = [node, *paths[pred]]
+            else:
+                for pred, node in pred_node_iter:
+                    paths[node] = [*paths[pred], node]
+    if return_type == "path":
+        return paths
+    if return_type == "length":
+        return lengths
+    # return_type == "length-path"
+    return lengths, paths
diff --git a/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/weighted.py b/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/weighted.py
new file mode 100644
index 00000000000..32323dd45f3
--- /dev/null
+++ b/python/nx-cugraph/nx_cugraph/algorithms/shortest_paths/weighted.py
@@ -0,0 +1,286 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import networkx as nx
+import numpy as np
+import pylibcugraph as plc
+
+from nx_cugraph.convert import _to_graph
+from nx_cugraph.utils import (
+    _dtype_param,
+    _get_float_dtype,
+    _groupby,
+    networkx_algorithm,
+)
+
+from .unweighted import _bfs
+
+__all__ = [
+    "bellman_ford_path",
+    "bellman_ford_path_length",
+    "single_source_bellman_ford",
+    "single_source_bellman_ford_path",
+    "single_source_bellman_ford_path_length",
+    "all_pairs_bellman_ford_path",
+    "all_pairs_bellman_ford_path_length",
+]
+
+
+def _add_doc(func):
+    func.__doc__ = (
+        "Negative cycles are not yet supported. ``NotImplementedError`` will be raised "
+        "if there are negative edge weights. We plan to support negative edge weights "
+        "soon. Also, callable ``weight`` argument is not supported."
+    )
+    return func
+
+
+@networkx_algorithm(extra_params=_dtype_param, version_added="24.04", _plc="sssp")
+@_add_doc
+def bellman_ford_path(G, source, target, weight="weight", *, dtype=None):
+    G = _to_graph(G, weight, 1, np.float32)
+    dtype = _get_float_dtype(dtype, graph=G, weight=weight)
+    return _sssp(G, source, weight, target, return_type="path", dtype=dtype)
+
+
+@bellman_ford_path._can_run
+def _(G, source, target, weight="weight", *, dtype=None):
+    return (
+        weight is None
+        or not callable(weight)
+        and not nx.is_negatively_weighted(G, weight=weight)
+    )
+
+
+@networkx_algorithm(extra_params=_dtype_param, version_added="24.04", _plc="sssp")
+@_add_doc
+def bellman_ford_path_length(G, source, target, weight="weight", *, dtype=None):
+    G = _to_graph(G, weight, 1, np.float32)
+    dtype = _get_float_dtype(dtype, graph=G, weight=weight)
+    return _sssp(G, source, weight, target, return_type="length", dtype=dtype)
+
+
+@bellman_ford_path_length._can_run
+def _(G, source, target, weight="weight", *, dtype=None):
+    return (
+        weight is None
+        or not callable(weight)
+        and not nx.is_negatively_weighted(G, weight=weight)
+    )
+
+
+@networkx_algorithm(extra_params=_dtype_param, version_added="24.04", _plc="sssp")
+@_add_doc
+def single_source_bellman_ford_path(G, source, weight="weight", *, dtype=None):
+    G = _to_graph(G, weight, 1, np.float32)
+    dtype = _get_float_dtype(dtype, graph=G, weight=weight)
+    return _sssp(G, source, weight, return_type="path", dtype=dtype)
+
+
+@single_source_bellman_ford_path._can_run
+def _(G, source, weight="weight", *, dtype=None):
+    return (
+        weight is None
+        or not callable(weight)
+        and not nx.is_negatively_weighted(G, weight=weight)
+    )
+
+
+@networkx_algorithm(extra_params=_dtype_param, version_added="24.04", _plc="sssp")
+@_add_doc
+def single_source_bellman_ford_path_length(G, source, weight="weight", *, dtype=None):
+    G = _to_graph(G, weight, 1, np.float32)
+    dtype = _get_float_dtype(dtype, graph=G, weight=weight)
+    return _sssp(G, source, weight, return_type="length", dtype=dtype)
+
+
+@single_source_bellman_ford_path_length._can_run
+def _(G, source, weight="weight", *, dtype=None):
+    return (
+        weight is None
+        or not callable(weight)
+        and not nx.is_negatively_weighted(G, weight=weight)
+    )
+
+
+@networkx_algorithm(extra_params=_dtype_param, version_added="24.04", _plc="sssp")
+@_add_doc
+def single_source_bellman_ford(G, source, target=None, weight="weight", *, dtype=None):
+    G = _to_graph(G, weight, 1, np.float32)
+    dtype = _get_float_dtype(dtype, graph=G, weight=weight)
+    return _sssp(G, source, weight, target, return_type="length-path", dtype=dtype)
+
+
+@single_source_bellman_ford._can_run
+def _(G, source, target=None, weight="weight", *, dtype=None):
+    return (
+        weight is None
+        or not callable(weight)
+        and not nx.is_negatively_weighted(G, weight=weight)
+    )
+
+
+@networkx_algorithm(extra_params=_dtype_param, version_added="24.04", _plc="sssp")
+@_add_doc
+def all_pairs_bellman_ford_path_length(G, weight="weight", *, dtype=None):
+    # TODO PERF: batched bfs to compute many at once
+    G = _to_graph(G, weight, 1, np.float32)
+    dtype = _get_float_dtype(dtype, graph=G, weight=weight)
+    for n in G:
+        yield (n, _sssp(G, n, weight, return_type="length", dtype=dtype))
+
+
+@all_pairs_bellman_ford_path_length._can_run
+def _(G, weight="weight", *, dtype=None):
+    return (
+        weight is None
+        or not callable(weight)
+        and not nx.is_negatively_weighted(G, weight=weight)
+    )
+
+
+@networkx_algorithm(extra_params=_dtype_param, version_added="24.04", _plc="sssp")
+@_add_doc
+def all_pairs_bellman_ford_path(G, weight="weight", *, dtype=None):
+    # TODO PERF: batched bfs to compute many at once
+    G = _to_graph(G, weight, 1, np.float32)
+    dtype = _get_float_dtype(dtype, graph=G, weight=weight)
+    for n in G:
+        yield (n, _sssp(G, n, weight, return_type="path", dtype=dtype))
+
+
+@all_pairs_bellman_ford_path._can_run
+def _(G, weight="weight", *, dtype=None):
+    return (
+        weight is None
+        or not callable(weight)
+        and not nx.is_negatively_weighted(G, weight=weight)
+    )
+
+
+def _sssp(G, source, weight, target=None, *, return_type, dtype, reverse_path=False):
+    """SSSP for weighted shortest paths.
+
+    Parameters
+    ----------
+    return_type : {"length", "path", "length-path"}
+
+    """
+    # DRY: _bfs in unweighted.py has similar code
+    if source not in G:
+        raise nx.NodeNotFound(f"Node {source} not found in graph")
+    if target is not None:
+        if source == target:
+            if return_type == "path":
+                return [source]
+            if return_type == "length":
+                return 0
+            # return_type == "length-path"
+            return 0, [source]
+        if target not in G or G.src_indices.size == 0:
+            raise nx.NetworkXNoPath(f"Node {target} not reachable from {source}")
+    elif G.src_indices.size == 0:
+        if return_type == "path":
+            return {source: [source]}
+        if return_type == "length":
+            return {source: 0}
+        # return_type == "length-path"
+        return {source: 0}, {source: [source]}
+
+    if callable(weight):
+        raise NotImplementedError("callable `weight` argument is not supported")
+
+    if weight not in G.edge_values:
+        # No edge values, so use BFS instead
+        return _bfs(G, source, None, "Source", return_type=return_type, target=target)
+
+    # Check for negative values since we don't support negative cycles
+    edge_vals = G.edge_values[weight]
+    if weight in G.edge_masks:
+        edge_vals = edge_vals[G.edge_masks[weight]]
+    if (edge_vals < 0).any():
+        raise NotImplementedError("Negative edge weights not yet supported")
+    edge_val = edge_vals[0]
+    if (edge_vals == edge_val).all() and (
+        edge_vals.size == G.src_indices.size or edge_val == 1
+    ):
+        # Edge values are all the same, so use scaled BFS instead
+        return _bfs(
+            G,
+            source,
+            None,
+            "Source",
+            return_type=return_type,
+            target=target,
+            scale=edge_val,
+            reverse_path=reverse_path,
+        )
+
+    src_index = source if G.key_to_id is None else G.key_to_id[source]
+    node_ids, distances, predecessors = plc.sssp(
+        resource_handle=plc.ResourceHandle(),
+        graph=G._get_plc_graph(weight, 1, dtype),
+        source=src_index,
+        cutoff=np.inf,
+        compute_predecessors=True,  # TODO: False is not yet supported
+        # compute_predecessors=return_type != "length",
+        do_expensive_check=False,
+    )
+    mask = distances != np.finfo(distances.dtype).max
+    node_ids = node_ids[mask]
+    if return_type != "path":
+        lengths = G._nodearrays_to_dict(node_ids, distances[mask])
+        if target is not None:
+            if target not in lengths:
+                raise nx.NetworkXNoPath(f"Node {target} not reachable from {source}")
+            lengths = lengths[target]
+    if return_type != "length":
+        if target is not None:
+            d = dict(zip(node_ids.tolist(), predecessors[mask].tolist()))
+            dst_index = target if G.key_to_id is None else G.key_to_id[target]
+            if dst_index not in d:
+                raise nx.NetworkXNoPath(f"Node {target} not reachable from {source}")
+            cur = dst_index
+            paths = [dst_index]
+            while cur != src_index:
+                cur = d[cur]
+                paths.append(cur)
+            if (id_to_key := G.id_to_key) is not None:
+                if reverse_path:
+                    paths = [id_to_key[cur] for cur in paths]
+                else:
+                    paths = [id_to_key[cur] for cur in reversed(paths)]
+            elif not reverse_path:
+                paths.reverse()
+        else:
+            groups = _groupby(predecessors[mask], node_ids)
+            if (id_to_key := G.id_to_key) is not None:
+                groups = {id_to_key[k]: v for k, v in groups.items() if k >= 0}
+            paths = {source: [source]}
+            preds = [source]
+            while preds:
+                pred = preds.pop()
+                pred_path = paths[pred]
+                nodes = G._nodearray_to_list(groups[pred])
+                if reverse_path:
+                    for node in nodes:
+                        paths[node] = [node, *pred_path]
+                else:
+                    for node in nodes:
+                        paths[node] = [*pred_path, node]
+                preds.extend(nodes & groups.keys())
+    if return_type == "path":
+        return paths
+    if return_type == "length":
+        return lengths
+    # return_type == "length-path"
+    return lengths, paths
diff --git a/python/nx-cugraph/nx_cugraph/interface.py b/python/nx-cugraph/nx_cugraph/interface.py
index d044ba6960d..0d893ac286b 100644
--- a/python/nx-cugraph/nx_cugraph/interface.py
+++ b/python/nx-cugraph/nx_cugraph/interface.py
@@ -67,6 +67,7 @@ def key(testpath):
         no_multigraph = "multigraphs not currently supported"
         louvain_different = "Louvain may be different due to RNG"
         no_string_dtype = "string edge values not currently supported"
+        sssp_path_different = "sssp may choose a different valid path"
 
         xfail = {
             # This is removed while strongly_connected_components() is not
@@ -77,6 +78,19 @@ def key(testpath):
             #     "test_strongly_connected.py:"
             #     "TestStronglyConnected.test_condensation_mapping_and_members"
             # ): "Strongly connected groups in different iteration order",
+            key(
+                "test_cycles.py:TestMinimumCycleBasis.test_unweighted_diamond"
+            ): sssp_path_different,
+            key(
+                "test_cycles.py:TestMinimumCycleBasis.test_weighted_diamond"
+            ): sssp_path_different,
+            key(
+                "test_cycles.py:TestMinimumCycleBasis.test_petersen_graph"
+            ): sssp_path_different,
+            key(
+                "test_cycles.py:TestMinimumCycleBasis."
+                "test_gh6787_and_edge_attribute_names"
+            ): sssp_path_different,
         }
 
         from packaging.version import parse
diff --git a/python/nx-cugraph/nx_cugraph/utils/misc.py b/python/nx-cugraph/nx_cugraph/utils/misc.py
index aa06d7fd29b..eab4b42c2cc 100644
--- a/python/nx-cugraph/nx_cugraph/utils/misc.py
+++ b/python/nx-cugraph/nx_cugraph/utils/misc.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -22,7 +22,9 @@
 import numpy as np
 
 if TYPE_CHECKING:
-    from ..typing import Dtype
+    import nx_cugraph as nxcg
+
+    from ..typing import Dtype, EdgeKey
 
 try:
     from itertools import pairwise  # Python >=3.10
@@ -190,10 +192,14 @@ def _get_int_dtype(
         raise ValueError("Value is too large to store as integer: {val}") from exc
 
 
-def _get_float_dtype(dtype: Dtype):
+def _get_float_dtype(
+    dtype: Dtype, *, graph: nxcg.Graph | None = None, weight: EdgeKey | None = None
+):
     """Promote dtype to float32 or float64 as appropriate."""
     if dtype is None:
-        return np.dtype(np.float32)
+        if graph is None or weight not in graph.edge_values:
+            return np.dtype(np.float32)
+        dtype = graph.edge_values[weight].dtype
     rv = np.promote_types(dtype, np.float32)
     if np.float32 != rv != np.float64:
         raise TypeError(
diff --git a/python/nx-cugraph/scripts/update_readme.py b/python/nx-cugraph/scripts/update_readme.py
old mode 100644
new mode 100755

From fda91fac9df7429febfa61251db1044aa1149fc1 Mon Sep 17 00:00:00 2001
From: Bradley Dice 
Date: Wed, 13 Mar 2024 10:22:13 -0500
Subject: [PATCH 3/3] Add upper bound to prevent usage of NumPy 2 (#4233)

NumPy 2 is expected to be released in the near future. For the RAPIDS 24.04 release, we will pin to `numpy>=1.23,<2.0a0`. This PR adds an upper bound to affected RAPIDS repositories.

xref: https://github.com/rapidsai/build-planning/issues/29

Authors:
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)
  - Ray Douglass (https://github.com/raydouglass)

URL: https://github.com/rapidsai/cugraph/pull/4233
---
 conda/environments/all_cuda-118_arch-x86_64.yaml | 2 +-
 conda/environments/all_cuda-122_arch-x86_64.yaml | 2 +-
 conda/recipes/cugraph-dgl/meta.yaml              | 2 +-
 conda/recipes/cugraph-pyg/meta.yaml              | 2 +-
 conda/recipes/cugraph-service/meta.yaml          | 2 +-
 dependencies.yaml                                | 2 +-
 python/cugraph-dgl/pyproject.toml                | 2 +-
 python/cugraph-pyg/pyproject.toml                | 2 +-
 python/cugraph-service/server/pyproject.toml     | 4 ++--
 python/cugraph/pyproject.toml                    | 4 ++--
 python/nx-cugraph/pyproject.toml                 | 2 +-
 python/pylibcugraph/pyproject.toml               | 2 +-
 12 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml
index 6aed308c498..f0eff82e1ae 100644
--- a/conda/environments/all_cuda-118_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-118_arch-x86_64.yaml
@@ -42,7 +42,7 @@ dependencies:
 - ninja
 - notebook>=0.5.0
 - numba>=0.57
-- numpy>=1.23
+- numpy>=1.23,<2.0a0
 - numpydoc
 - nvcc_linux-64=11.8
 - openmpi
diff --git a/conda/environments/all_cuda-122_arch-x86_64.yaml b/conda/environments/all_cuda-122_arch-x86_64.yaml
index 4a095058219..93972f40d8b 100644
--- a/conda/environments/all_cuda-122_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-122_arch-x86_64.yaml
@@ -48,7 +48,7 @@ dependencies:
 - ninja
 - notebook>=0.5.0
 - numba>=0.57
-- numpy>=1.23
+- numpy>=1.23,<2.0a0
 - numpydoc
 - openmpi
 - packaging>=21
diff --git a/conda/recipes/cugraph-dgl/meta.yaml b/conda/recipes/cugraph-dgl/meta.yaml
index 09322a9c7d3..5e28e69a0d7 100644
--- a/conda/recipes/cugraph-dgl/meta.yaml
+++ b/conda/recipes/cugraph-dgl/meta.yaml
@@ -25,7 +25,7 @@ requirements:
     - cugraph ={{ version }}
     - dgl >=1.1.0.cu*
     - numba >=0.57
-    - numpy >=1.23
+    - numpy >=1.23,<2.0a0
     - pylibcugraphops ={{ minor_version }}
     - python
     - pytorch
diff --git a/conda/recipes/cugraph-pyg/meta.yaml b/conda/recipes/cugraph-pyg/meta.yaml
index 624f5753fd2..4ada5e31211 100644
--- a/conda/recipes/cugraph-pyg/meta.yaml
+++ b/conda/recipes/cugraph-pyg/meta.yaml
@@ -28,7 +28,7 @@ requirements:
   run:
     - rapids-dask-dependency ={{ minor_version }}
     - numba >=0.57
-    - numpy >=1.23
+    - numpy >=1.23,<2.0a0
     - python
     - pytorch >=2.0
     - cupy >=12.0.0
diff --git a/conda/recipes/cugraph-service/meta.yaml b/conda/recipes/cugraph-service/meta.yaml
index c04c1a7c7fa..8698d4f6985 100644
--- a/conda/recipes/cugraph-service/meta.yaml
+++ b/conda/recipes/cugraph-service/meta.yaml
@@ -60,7 +60,7 @@ outputs:
         - dask-cuda ={{ minor_version }}
         - dask-cudf ={{ minor_version }}
         - numba >=0.57
-        - numpy >=1.23
+        - numpy >=1.23,<2.0a0
         - python
         - rapids-dask-dependency ={{ minor_version }}
         - thriftpy2 >=0.4.15
diff --git a/dependencies.yaml b/dependencies.yaml
index e6cf6c9e93c..d8be5352c7d 100644
--- a/dependencies.yaml
+++ b/dependencies.yaml
@@ -449,7 +449,7 @@ dependencies:
           - &dask rapids-dask-dependency==24.4.*
           - &dask_cuda dask-cuda==24.4.*
           - &numba numba>=0.57
-          - &numpy numpy>=1.23
+          - &numpy numpy>=1.23,<2.0a0
           - &ucx_py ucx-py==0.37.*
       - output_types: conda
         packages:
diff --git a/python/cugraph-dgl/pyproject.toml b/python/cugraph-dgl/pyproject.toml
index c6f76325761..f17292c5e70 100644
--- a/python/cugraph-dgl/pyproject.toml
+++ b/python/cugraph-dgl/pyproject.toml
@@ -25,7 +25,7 @@ classifiers = [
 dependencies = [
     "cugraph==24.4.*",
     "numba>=0.57",
-    "numpy>=1.23",
+    "numpy>=1.23,<2.0a0",
     "pylibcugraphops==24.4.*",
 ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
 
diff --git a/python/cugraph-pyg/pyproject.toml b/python/cugraph-pyg/pyproject.toml
index cbee5ed4b58..150ecbf506b 100644
--- a/python/cugraph-pyg/pyproject.toml
+++ b/python/cugraph-pyg/pyproject.toml
@@ -29,7 +29,7 @@ classifiers = [
 dependencies = [
     "cugraph==24.4.*",
     "numba>=0.57",
-    "numpy>=1.23",
+    "numpy>=1.23,<2.0a0",
     "pylibcugraphops==24.4.*",
 ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
 
diff --git a/python/cugraph-service/server/pyproject.toml b/python/cugraph-service/server/pyproject.toml
index a32b18a9551..d6cf48432cb 100644
--- a/python/cugraph-service/server/pyproject.toml
+++ b/python/cugraph-service/server/pyproject.toml
@@ -26,7 +26,7 @@ dependencies = [
     "dask-cuda==24.4.*",
     "dask-cudf==24.4.*",
     "numba>=0.57",
-    "numpy>=1.23",
+    "numpy>=1.23,<2.0a0",
     "rapids-dask-dependency==24.4.*",
     "rmm==24.4.*",
     "thriftpy2",
@@ -46,7 +46,7 @@ cugraph-service-server = "cugraph_service_server.__main__:main"
 [project.optional-dependencies]
 test = [
     "networkx>=2.5.1",
-    "numpy>=1.23",
+    "numpy>=1.23,<2.0a0",
     "pandas",
     "pytest",
     "pytest-benchmark",
diff --git a/python/cugraph/pyproject.toml b/python/cugraph/pyproject.toml
index 113c316ccbf..a6d3d841298 100644
--- a/python/cugraph/pyproject.toml
+++ b/python/cugraph/pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
     "dask-cudf==24.4.*",
     "fsspec[http]>=0.6.0",
     "numba>=0.57",
-    "numpy>=1.23",
+    "numpy>=1.23,<2.0a0",
     "pylibcugraph==24.4.*",
     "raft-dask==24.4.*",
     "rapids-dask-dependency==24.4.*",
@@ -53,7 +53,7 @@ classifiers = [
 [project.optional-dependencies]
 test = [
     "networkx>=2.5.1",
-    "numpy>=1.23",
+    "numpy>=1.23,<2.0a0",
     "pandas",
     "pytest",
     "pytest-benchmark",
diff --git a/python/nx-cugraph/pyproject.toml b/python/nx-cugraph/pyproject.toml
index 07ec0eab264..dbdc8dd19e1 100644
--- a/python/nx-cugraph/pyproject.toml
+++ b/python/nx-cugraph/pyproject.toml
@@ -33,7 +33,7 @@ classifiers = [
 dependencies = [
     "cupy-cuda11x>=12.0.0",
     "networkx>=3.0",
-    "numpy>=1.23",
+    "numpy>=1.23,<2.0a0",
     "pylibcugraph==24.4.*",
 ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
 
diff --git a/python/pylibcugraph/pyproject.toml b/python/pylibcugraph/pyproject.toml
index eb7323d19e5..d5f568a7a90 100644
--- a/python/pylibcugraph/pyproject.toml
+++ b/python/pylibcugraph/pyproject.toml
@@ -42,7 +42,7 @@ classifiers = [
 [project.optional-dependencies]
 test = [
     "cudf==24.4.*",
-    "numpy>=1.23",
+    "numpy>=1.23,<2.0a0",
     "pandas",
     "pytest",
     "pytest-benchmark",