def _map_nvml_device_id_to_contiguous_range(nvml_device_index_d: dict) -> dict:
    """
    Map each worker address to a contiguous per-node rank.

    Workers are grouped by the node address returned by ``_get_worker_ip``.
    Within each group, ranks are handed out in ascending NVML device index
    order: the worker with the lowest NVML device index gets rank 0 and the
    worker with the highest gets rank ``num_workers_per_node - 1``. Workers
    that share a node and an NVML device index keep their relative input
    order, because the sort is stable.

    The input dictionary is not modified; a new dictionary is returned.

    Parameters
    ----------
    nvml_device_index_d : dict
        Dictionary of worker addresses mapped to their nvml device index.

    Returns
    -------
    dict
        New dictionary with the same worker addresses, in the same key
        order as ``nvml_device_index_d``, mapped to their per-node rank.
    """
    # Next unassigned rank for each node address; starts at 0 on first use.
    next_rank_per_ip: Dict[str, int] = defaultdict(int)
    rank_by_worker: Dict[str, int] = {}

    # Visit workers by ascending NVML index so that the worker with the
    # lowest NVML index on each node gets rank 0.
    for worker in sorted(nvml_device_index_d, key=nvml_device_index_d.get):
        ip = _get_worker_ip(worker)
        rank_by_worker[worker] = next_rank_per_ip[ip]
        next_rank_per_ip[ip] += 1

    # Emit the result in the caller's key order rather than sorted order,
    # so anything iterating the mapping sees workers in the same sequence
    # as in the input.
    return {worker: rank_by_worker[worker] for worker in nvml_device_index_d}