diff --git a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py index 319435575cc..c2ae4356c2b 100644 --- a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py +++ b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py @@ -825,12 +825,13 @@ def get_two_hop_neighbors(self, start_vertices=None): _client = default_client() def _call_plc_two_hop_neighbors(sID, mg_graph_x, start_vertices): - return pylibcugraph_get_two_hop_neighbors( + results_ = pylibcugraph_get_two_hop_neighbors( resource_handle=ResourceHandle(Comms.get_handle(sID).getHandle()), graph=mg_graph_x, start_vertices=start_vertices, - do_expensive_check=False, + do_expensive_check=True, ) + return results_ if isinstance(start_vertices, int): start_vertices = [start_vertices] @@ -845,31 +846,26 @@ def _call_plc_two_hop_neighbors(sID, mg_graph_x, start_vertices): else: start_vertices_type = self.input_df.dtypes[0] - if not isinstance(start_vertices, (dask_cudf.Series)): - start_vertices = dask_cudf.from_cudf( + start_vertices = start_vertices.astype(start_vertices_type) + + def create_iterable_args(session_id, input_graph, start_vertices=None, npartitions=None): + session_id_it = [session_id] * npartitions + graph_it = input_graph.values() + start_vertices = cp.array_split(start_vertices.values, npartitions) + return [ + session_id_it, + graph_it, start_vertices, - npartitions=min(self._npartitions, len(start_vertices)), - ) - start_vertices = start_vertices.astype(start_vertices_type) + ] - n_workers = get_n_workers() - start_vertices = start_vertices.repartition(npartitions=n_workers) - start_vertices = persist_dask_df_equal_parts_per_worker( - start_vertices, _client + result = _client.map( + _call_plc_two_hop_neighbors, + *create_iterable_args( + Comms.get_session_id(), self._plc_graph, start_vertices, self._npartitions + ), + pure=False, ) - start_vertices = get_persisted_df_worker_map(start_vertices, _client) - result = [ - _client.submit( - _call_plc_two_hop_neighbors, - Comms.get_session_id(), - self._plc_graph[w], - start_vertices[w][0], - workers=[w], - allow_other_workers=False, - ) - for w in start_vertices.keys() - ] else: result = [ _client.submit( @@ -896,7 +892,7 @@ def convert_to_cudf(cp_arrays): return df cudf_result = [ - _client.submit(convert_to_cudf, cp_arrays) for cp_arrays in result + _client.submit(convert_to_cudf, cp_arrays, pure=False) for cp_arrays in result ] wait(cudf_result)