Skip to content

Commit

Permalink
Fix Jaccard hang (#4080)
Browse files Browse the repository at this point in the history
This PR leverages `client.map` to launch all tasks simultaneously, avoiding hangs caused by sequential `client.submit` calls

closes #3926

Authors:
  - Joseph Nke (https://github.com/jnke2016)
  - Rick Ratzel (https://github.com/rlratzel)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4080
  • Loading branch information
jnke2016 authored Jan 20, 2024
1 parent ec65907 commit 77d833a
Showing 1 changed file with 26 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -35,11 +35,9 @@
from cugraph.structure.number_map import NumberMap
from cugraph.structure.symmetrize import symmetrize
from cugraph.dask.common.part_utils import (
get_persisted_df_worker_map,
persist_dask_df_equal_parts_per_worker,
)
from cugraph.dask.common.mg_utils import run_gc_on_dask_cluster
from cugraph.dask import get_n_workers
import cugraph.dask.comms.comms as Comms


Expand Down Expand Up @@ -825,12 +823,13 @@ def get_two_hop_neighbors(self, start_vertices=None):
_client = default_client()

def _call_plc_two_hop_neighbors(sID, mg_graph_x, start_vertices):
return pylibcugraph_get_two_hop_neighbors(
results_ = pylibcugraph_get_two_hop_neighbors(
resource_handle=ResourceHandle(Comms.get_handle(sID).getHandle()),
graph=mg_graph_x,
start_vertices=start_vertices,
do_expensive_check=False,
)
return results_

if isinstance(start_vertices, int):
start_vertices = [start_vertices]
Expand All @@ -845,31 +844,31 @@ def _call_plc_two_hop_neighbors(sID, mg_graph_x, start_vertices):
else:
start_vertices_type = self.input_df.dtypes[0]

if not isinstance(start_vertices, (dask_cudf.Series)):
start_vertices = dask_cudf.from_cudf(
start_vertices = start_vertices.astype(start_vertices_type)

def create_iterable_args(
session_id, input_graph, start_vertices=None, npartitions=None
):
session_id_it = [session_id] * npartitions
graph_it = input_graph.values()
start_vertices = cp.array_split(start_vertices.values, npartitions)
return [
session_id_it,
graph_it,
start_vertices,
npartitions=min(self._npartitions, len(start_vertices)),
)
start_vertices = start_vertices.astype(start_vertices_type)
]

n_workers = get_n_workers()
start_vertices = start_vertices.repartition(npartitions=n_workers)
start_vertices = persist_dask_df_equal_parts_per_worker(
start_vertices, _client
result = _client.map(
_call_plc_two_hop_neighbors,
*create_iterable_args(
Comms.get_session_id(),
self._plc_graph,
start_vertices,
self._npartitions,
),
pure=False,
)
start_vertices = get_persisted_df_worker_map(start_vertices, _client)

result = [
_client.submit(
_call_plc_two_hop_neighbors,
Comms.get_session_id(),
self._plc_graph[w],
start_vertices[w][0],
workers=[w],
allow_other_workers=False,
)
for w in start_vertices.keys()
]
else:
result = [
_client.submit(
Expand All @@ -896,7 +895,8 @@ def convert_to_cudf(cp_arrays):
return df

cudf_result = [
_client.submit(convert_to_cudf, cp_arrays) for cp_arrays in result
_client.submit(convert_to_cudf, cp_arrays, pure=False)
for cp_arrays in result
]

wait(cudf_result)
Expand Down

0 comments on commit 77d833a

Please sign in to comment.