Skip to content

Commit

Permalink
Add a barrier before cugraph Graph creation (#4046)
Browse files Browse the repository at this point in the history
This PR introduces a short-term fix for #4037.


CC: @jnke2016 , @rlratzel

Authors:
  - Vibhu Jawa (https://github.com/VibhuJawa)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)
  - Joseph Nke (https://github.com/jnke2016)

URL: #4046
  • Loading branch information
VibhuJawa authored Dec 6, 2023
1 parent f80480d commit a5718c6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
7 changes: 6 additions & 1 deletion python/cugraph/cugraph/dask/common/mg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.

import os

import gc
import numba.cuda


Expand Down Expand Up @@ -68,3 +68,8 @@ def get_visible_devices():
else:
visible_devices = _visible_devices.strip().split(",")
return visible_devices


def run_gc_on_dask_cluster(client):
    """Force a garbage-collection pass locally and on every Dask worker.

    Parameters
    ----------
    client
        A ``dask.distributed.Client`` connected to the cluster; its
        ``run`` method executes ``gc.collect`` on each worker process.

    Notes
    -----
    Collects on the local (client) process first, then fans out to the
    workers via ``client.run``.  Intended to release device/host memory
    held by dead references before a large allocation (e.g. graph
    creation).
    """
    # Free garbage in the client process itself.
    gc.collect()
    # Then trigger a collection inside every worker process.
    client.run(gc.collect)
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
from cugraph.dask.common.part_utils import (
get_persisted_df_worker_map,
persist_dask_df_equal_parts_per_worker,
_chunk_lst,
)
from cugraph.dask.common.mg_utils import run_gc_on_dask_cluster
from cugraph.dask import get_n_workers
import cugraph.dask.comms.comms as Comms

Expand Down Expand Up @@ -171,7 +171,6 @@ def __from_edgelist(
store_transposed=False,
legacy_renum_only=False,
):

if not isinstance(input_ddf, dask_cudf.DataFrame):
raise TypeError("input should be a dask_cudf dataFrame")

Expand Down Expand Up @@ -275,7 +274,6 @@ def __from_edgelist(
)
value_col = None
else:

source_col, dest_col, value_col = symmetrize(
input_ddf,
source,
Expand Down Expand Up @@ -350,9 +348,11 @@ def __from_edgelist(
is_symmetric=not self.properties.directed,
)
ddf = ddf.repartition(npartitions=len(workers) * 2)
ddf_keys = ddf.to_delayed()
workers = _client.scheduler_info()["workers"].keys()
ddf_keys_ls = _chunk_lst(ddf_keys, len(workers))
persisted_keys_d = persist_dask_df_equal_parts_per_worker(
ddf, _client, return_type="dict"
)
del ddf

delayed_tasks_d = {
w: delayed(simpleDistributedGraphImpl._make_plc_graph)(
Expand All @@ -367,19 +367,19 @@ def __from_edgelist(
self.edge_id_type,
self.edge_type_id_type,
)
for w, edata in zip(workers, ddf_keys_ls)
for w, edata in persisted_keys_d.items()
}
del persisted_keys_d
self._plc_graph = {
w: _client.compute(
delayed_task, workers=w, allow_other_workers=False, pure=False
)
for w, delayed_task in delayed_tasks_d.items()
}
wait(list(self._plc_graph.values()))
del ddf_keys
del delayed_tasks_d
gc.collect()
_client.run(gc.collect)
run_gc_on_dask_cluster(_client)
wait(list(self._plc_graph.values()))
run_gc_on_dask_cluster(_client)

@property
def renumbered(self):
Expand Down Expand Up @@ -945,7 +945,6 @@ def convert_to_cudf(cp_arrays: cp.ndarray) -> cudf.Series:
def _call_plc_select_random_vertices(
mg_graph_x, sID: bytes, random_state: int, num_vertices: int
) -> cudf.Series:

cp_arrays = pylibcugraph_select_random_vertices(
graph=mg_graph_x,
resource_handle=ResourceHandle(Comms.get_handle(sID).getHandle()),
Expand All @@ -961,7 +960,6 @@ def _mg_call_plc_select_random_vertices(
random_state: int,
num_vertices: int,
) -> dask_cudf.Series:

result = [
client.submit(
_call_plc_select_random_vertices,
Expand Down

0 comments on commit a5718c6

Please sign in to comment.