Skip to content

Commit

Permalink
Sync graph creation
Browse files Browse the repository at this point in the history
  • Loading branch information
VibhuJawa committed Dec 5, 2023
1 parent 062f685 commit 28db408
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
7 changes: 6 additions & 1 deletion python/cugraph/cugraph/dask/common/mg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.

import os

import gc
import numba.cuda


Expand Down Expand Up @@ -68,3 +68,8 @@ def get_visible_devices():
else:
visible_devices = _visible_devices.strip().split(",")
return visible_devices


def run_gc_on_dask_cluster(client):
gc.collect()
client.run(gc.collect)
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
get_persisted_df_worker_map,
persist_dask_df_equal_parts_per_worker,
)
from cugraph.dask.common.mg_utils import run_gc_on_dask_cluster
from cugraph.dask import get_n_workers
import cugraph.dask.comms.comms as Comms
from dask.distributed import get_worker


class simpleDistributedGraphImpl:
Expand Down Expand Up @@ -98,7 +98,6 @@ def _make_plc_graph(
edge_id_type,
edge_type_id,
):
print(f"_make_plc_graph called from worker = {get_worker().name}", flush=True)
weights = None
edge_ids = None
edge_types = None
Expand Down Expand Up @@ -172,7 +171,6 @@ def __from_edgelist(
store_transposed=False,
legacy_renum_only=False,
):

if not isinstance(input_ddf, dask_cudf.DataFrame):
raise TypeError("input should be a dask_cudf dataFrame")

Expand Down Expand Up @@ -276,7 +274,6 @@ def __from_edgelist(
)
value_col = None
else:

source_col, dest_col, value_col = symmetrize(
input_ddf,
source,
Expand Down Expand Up @@ -380,9 +377,9 @@ def __from_edgelist(
for w, delayed_task in delayed_tasks_d.items()
}
del delayed_tasks_d
gc.collect()
_client.run(gc.collect)
run_gc_on_dask_cluster(_client)
wait(list(self._plc_graph.values()))
run_gc_on_dask_cluster(_client)

@property
def renumbered(self):
Expand Down Expand Up @@ -948,7 +945,6 @@ def convert_to_cudf(cp_arrays: cp.ndarray) -> cudf.Series:
def _call_plc_select_random_vertices(
mg_graph_x, sID: bytes, random_state: int, num_vertices: int
) -> cudf.Series:

cp_arrays = pylibcugraph_select_random_vertices(
graph=mg_graph_x,
resource_handle=ResourceHandle(Comms.get_handle(sID).getHandle()),
Expand All @@ -964,7 +960,6 @@ def _mg_call_plc_select_random_vertices(
random_state: int,
num_vertices: int,
) -> dask_cudf.Series:

result = [
client.submit(
_call_plc_select_random_vertices,
Expand Down

0 comments on commit 28db408

Please sign in to comment.