Skip to content

Commit

Permalink
Remove dask_cudf dataframe for the _make_plc_graph while creating…
Browse files Browse the repository at this point in the history
… `cugraph.Graph` (#3895)

This PR attempts to fix #3790 

Please note that I  have not being able to cause failure locally so it is really hard for me to know if it actually fixes anything or not  . 

MRE being used to test locally: https://gist.github.com/VibhuJawa/4b1ec24022b6e2dd7879cd2e8d3fab67


CC: @jnke2016 , @rlratzel ,

CC:  @rjzamora , Please let me know what i can do better here.

Authors:
  - Vibhu Jawa (https://github.com/VibhuJawa)
  - Brad Rees (https://github.com/BradReesWork)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)
  - Joseph Nke (https://github.com/jnke2016)

URL: #3895
  • Loading branch information
VibhuJawa authored Oct 5, 2023
1 parent 061ada4 commit d03cb0f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 15 deletions.
62 changes: 54 additions & 8 deletions python/cugraph/cugraph/dask/common/part_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,19 +99,65 @@ def _chunk_lst(ls, num_parts):
return [ls[i::num_parts] for i in range(num_parts)]


def persist_dask_df_equal_parts_per_worker(dask_df, client):
def persist_dask_df_equal_parts_per_worker(
dask_df, client, return_type="dask_cudf.DataFrame"
):
"""
Persist dask_df with equal parts per worker
Args:
dask_df: dask_cudf.DataFrame
client: dask.distributed.Client
return_type: str, "dask_cudf.DataFrame" or "dict"
Returns:
persisted_keys: dict of {worker: [persisted_keys]}
"""
if return_type not in ["dask_cudf.DataFrame", "dict"]:
raise ValueError("return_type must be either 'dask_cudf.DataFrame' or 'dict'")

ddf_keys = dask_df.to_delayed()
workers = client.scheduler_info()["workers"].keys()
ddf_keys_ls = _chunk_lst(ddf_keys, len(workers))
persisted_keys = []
persisted_keys_d = {}
for w, ddf_k in zip(workers, ddf_keys_ls):
persisted_keys.extend(
client.persist(ddf_k, workers=w, allow_other_workers=False)
persisted_keys_d[w] = client.compute(
ddf_k, workers=w, allow_other_workers=False, pure=False
)
dask_df = dask_cudf.from_delayed(persisted_keys, meta=dask_df._meta).persist()
wait(dask_df)
client.rebalance(dask_df)
return dask_df

persisted_keys_ls = [
item for sublist in persisted_keys_d.values() for item in sublist
]
wait(persisted_keys_ls)
if return_type == "dask_cudf.DataFrame":
dask_df = dask_cudf.from_delayed(
persisted_keys_ls, meta=dask_df._meta
).persist()
wait(dask_df)
return dask_df

return persisted_keys_d


def get_length_of_parts(persisted_keys_d, client):
"""
Get the length of each partition
Args:
persisted_keys_d: dict of {worker: [persisted_keys]}
client: dask.distributed.Client
Returns:
length_of_parts: dict of {worker: [length_of_parts]}
"""
length_of_parts = {}
for w, p_keys in persisted_keys_d.items():
length_of_parts[w] = [
client.submit(
len, p_key, pure=False, workers=[w], allow_other_workers=False
)
for p_key in p_keys
]

for w, len_futures in length_of_parts.items():
length_of_parts[w] = client.gather(len_futures)
return length_of_parts


async def _extract_partitions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from cugraph.structure.symmetrize import symmetrize
from cugraph.dask.common.part_utils import (
get_persisted_df_worker_map,
get_length_of_parts,
persist_dask_df_equal_parts_per_worker,
)
from cugraph.dask import get_n_workers
Expand Down Expand Up @@ -318,9 +319,14 @@ def __from_edgelist(
is_symmetric=not self.properties.directed,
)
ddf = ddf.repartition(npartitions=len(workers) * 2)
ddf = persist_dask_df_equal_parts_per_worker(ddf, _client)
num_edges = len(ddf)
ddf = get_persisted_df_worker_map(ddf, _client)
persisted_keys_d = persist_dask_df_equal_parts_per_worker(
ddf, _client, return_type="dict"
)
del ddf
length_of_parts = get_length_of_parts(persisted_keys_d, _client)
num_edges = sum(
[item for sublist in length_of_parts.values() for item in sublist]
)
delayed_tasks_d = {
w: delayed(simpleDistributedGraphImpl._make_plc_graph)(
Comms.get_session_id(),
Expand All @@ -331,14 +337,16 @@ def __from_edgelist(
store_transposed,
num_edges,
)
for w, edata in ddf.items()
for w, edata in persisted_keys_d.items()
}
# FIXME: For now, don't delete the copied dataframe to avoid crash
self._plc_graph = {
w: _client.compute(delayed_task, workers=w, allow_other_workers=False)
w: _client.compute(
delayed_task, workers=w, allow_other_workers=False, pure=False
)
for w, delayed_task in delayed_tasks_d.items()
}
wait(list(self._plc_graph.values()))
del persisted_keys_d
del delayed_tasks_d
_client.run(gc.collect)

Expand Down Expand Up @@ -1192,5 +1200,7 @@ def _get_column_from_ls_dfs(lst_df, col_name):
if len_df == 0:
return lst_df[0][col_name]
output_col = cudf.concat([df[col_name] for df in lst_df], ignore_index=True)
# FIXME: For now, don't delete the copied dataframe to avoid cras
for df in lst_df:
df.drop(columns=[col_name], inplace=True)
gc.collect()
return output_col

0 comments on commit d03cb0f

Please sign in to comment.