Skip to content

Commit

Permalink
Remove another persist and decrease memory footprint of drop_duplicates
Browse files Browse the repository at this point in the history
  • Loading branch information
VibhuJawa committed Nov 21, 2023
1 parent 129a226 commit e275714
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ def get_args():
client, cluster = start_dask_client(
dask_worker_devices=dask_worker_devices,
jit_unspill=False,
rmm_pool_size=28e9,
rmm_pool_size=20e9,
rmm_async=True,
)
enable_spilling()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from cugraph.dask.common.part_utils import (
get_persisted_df_worker_map,
persist_dask_df_equal_parts_per_worker,
_chunk_lst,
)
from cugraph.dask import get_n_workers
import cugraph.dask.comms.comms as Comms
Expand Down Expand Up @@ -188,7 +189,6 @@ def __from_edgelist(
workers = _client.scheduler_info()["workers"]
# Repartition to 2 partitions per GPU for memory-efficient processing
input_ddf = input_ddf.repartition(npartitions=len(workers) * 2)
input_ddf = input_ddf.map_partitions(lambda df: df.copy())
# The dataframe will be symmetrized iff the graph is undirected;
# otherwise, the initial dataframe will be returned
if edge_attr is not None:
Expand Down Expand Up @@ -325,10 +325,9 @@ def __from_edgelist(
is_symmetric=not self.properties.directed,
)
ddf = ddf.repartition(npartitions=len(workers) * 2)
persisted_keys_d = persist_dask_df_equal_parts_per_worker(
ddf, _client, return_type="dict"
)
del ddf
ddf_keys = ddf.to_delayed()
workers = _client.scheduler_info()["workers"].keys()
ddf_keys_ls = _chunk_lst(ddf_keys, len(workers))

delayed_tasks_d = {
w: delayed(simpleDistributedGraphImpl._make_plc_graph)(
Expand All @@ -339,7 +338,7 @@ def __from_edgelist(
dst_col_name,
store_transposed,
)
for w, edata in persisted_keys_d.items()
for w, edata in zip(workers, ddf_keys_ls)
}
self._plc_graph = {
w: _client.compute(
Expand All @@ -348,8 +347,9 @@ def __from_edgelist(
for w, delayed_task in delayed_tasks_d.items()
}
wait(list(self._plc_graph.values()))
del persisted_keys_d
del ddf_keys
del delayed_tasks_d
gc.collect()
_client.run(gc.collect)

@property
Expand Down

0 comments on commit e275714

Please sign in to comment.