diff --git a/python/cugraph/cugraph/dask/common/mg_utils.py b/python/cugraph/cugraph/dask/common/mg_utils.py index 6acda48c9da..b04f293dc0e 100644 --- a/python/cugraph/cugraph/dask/common/mg_utils.py +++ b/python/cugraph/cugraph/dask/common/mg_utils.py @@ -12,7 +12,7 @@ # limitations under the License. import os - +import gc import numba.cuda @@ -68,3 +68,8 @@ def get_visible_devices(): else: visible_devices = _visible_devices.strip().split(",") return visible_devices + + +def run_gc_on_dask_cluster(client): + gc.collect() + client.run(gc.collect) diff --git a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py index 89bcd543e37..319435575cc 100644 --- a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py +++ b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py @@ -38,9 +38,9 @@ get_persisted_df_worker_map, persist_dask_df_equal_parts_per_worker, ) +from cugraph.dask.common.mg_utils import run_gc_on_dask_cluster from cugraph.dask import get_n_workers import cugraph.dask.comms.comms as Comms -from dask.distributed import get_worker class simpleDistributedGraphImpl: @@ -98,7 +98,6 @@ def _make_plc_graph( edge_id_type, edge_type_id, ): - print(f"_make_plc_graph called from worker = {get_worker().name}", flush=True) weights = None edge_ids = None edge_types = None @@ -172,7 +171,6 @@ def __from_edgelist( store_transposed=False, legacy_renum_only=False, ): - if not isinstance(input_ddf, dask_cudf.DataFrame): raise TypeError("input should be a dask_cudf dataFrame") @@ -276,7 +274,6 @@ def __from_edgelist( ) value_col = None else: - source_col, dest_col, value_col = symmetrize( input_ddf, source, @@ -380,9 +377,9 @@ def __from_edgelist( for w, delayed_task in delayed_tasks_d.items() } del delayed_tasks_d - gc.collect() - _client.run(gc.collect) + run_gc_on_dask_cluster(_client) wait(list(self._plc_graph.values())) + run_gc_on_dask_cluster(_client) @property def renumbered(self): @@ -948,7 +945,6 @@ def convert_to_cudf(cp_arrays: cp.ndarray) -> cudf.Series: def _call_plc_select_random_vertices( mg_graph_x, sID: bytes, random_state: int, num_vertices: int ) -> cudf.Series: - cp_arrays = pylibcugraph_select_random_vertices( graph=mg_graph_x, resource_handle=ResourceHandle(Comms.get_handle(sID).getHandle()), @@ -964,7 +960,6 @@ def _mg_call_plc_select_random_vertices( random_state: int, num_vertices: int, ) -> dask_cudf.Series: - result = [ client.submit( _call_plc_select_random_vertices,