From 07d54f5a2c4c88b8eafdb3e9cdce50a7541f08a6 Mon Sep 17 00:00:00 2001 From: Ralph Liu Date: Wed, 20 Mar 2024 12:22:25 -0700 Subject: [PATCH] Add changes --- python/cugraph/cugraph/dask/common/part_utils.py | 6 +++++- python/cugraph/cugraph/dask/comms/comms.py | 15 ++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/python/cugraph/cugraph/dask/common/part_utils.py b/python/cugraph/cugraph/dask/common/part_utils.py index 25311902b29..0f5d46e83bc 100644 --- a/python/cugraph/cugraph/dask/common/part_utils.py +++ b/python/cugraph/cugraph/dask/common/part_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -116,6 +116,10 @@ def persist_dask_df_equal_parts_per_worker( ddf_keys = dask_df.to_delayed() workers = client.scheduler_info()["workers"].keys() + worker_to_rank = Comms.rank_to_worker(client) + # assure rank-worker mappings are in ascending order + workers = dict(sorted(worker_to_rank.items())).values() + ddf_keys_ls = _chunk_lst(ddf_keys, len(workers)) persisted_keys_d = {} for w, ddf_k in zip(workers, ddf_keys_ls): diff --git a/python/cugraph/cugraph/dask/comms/comms.py b/python/cugraph/cugraph/dask/comms/comms.py index d623f20a038..3897ab4c959 100644 --- a/python/cugraph/cugraph/dask/comms/comms.py +++ b/python/cugraph/cugraph/dask/comms/comms.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Copyright (c) 2018-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -265,3 +265,16 @@ def get_n_workers(sID=None, dask_worker=None): dask_worker = get_worker() sessionstate = get_raft_comm_state(sID, dask_worker) return sessionstate["nworkers"] + + +def rank_to_worker(client): + """ + Return a mapping of dask workers to ranks. + """ + workers = client.scheduler_info()["workers"].keys() + worker_info = __instance.worker_info(workers) + rank_to_worker = {} + for w in worker_info: + rank_to_worker[worker_info[w]["rank"]] = w + + return rank_to_worker