diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py
index 7a0b786ec4..118293c093 100644
--- a/python/raft-dask/raft_dask/common/comms.py
+++ b/python/raft-dask/raft_dask/common/comms.py
@@ -18,7 +18,8 @@
 import time
 import uuid
 import warnings
-from collections import Counter, OrderedDict
+from collections import Counter, OrderedDict, defaultdict
+from typing import Dict
 
 from dask.distributed import default_client
 from dask_cuda.utils import nvml_device_index
@@ -157,7 +158,7 @@ def worker_info(self, workers):
         Builds a dictionary of { (worker_address, worker_port) :
                                 (worker_rank, worker_port ) }
         """
-        ranks = _func_worker_ranks(self.client)
+        ranks = _func_worker_ranks(self.client, workers)
         ports = (
             _func_ucp_ports(self.client, workers) if self.comms_p2p else None
         )
@@ -688,7 +689,7 @@ def _func_ucp_ports(client, workers):
     return client.run(_func_ucp_listener_port, workers=workers)
 
 
-def _func_worker_ranks(client):
+def _func_worker_ranks(client, workers):
     """
     For each worker connected to the client,
     compute a global rank which is the sum
@@ -697,9 +698,16 @@
     Parameters
     ----------
     client (object): Dask client object.
-    """
-    ranks = client.run(_get_nvml_device_index)
-    worker_ips = [_get_worker_ip(worker_address) for worker_address in ranks]
+    workers (list): List of worker addresses.
+    """
+    # TODO: Add a test for this function
+    # Running into build issues preventing testing
+    nvml_device_index_d = client.run(_get_nvml_device_index, workers=workers)
+    worker_ips = [
+        _get_worker_ip(worker_address)
+        for worker_address in nvml_device_index_d
+    ]
+    ranks = _map_nvml_device_id_to_contiguous_range(nvml_device_index_d)
     worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips)
     return _append_rank_offset(ranks, worker_ip_offset_dict)
 
@@ -724,6 +732,38 @@ def _get_worker_ip(worker_address):
     return ":".join(worker_address.split(":")[0:2])
 
 
+def _map_nvml_device_id_to_contiguous_range(nvml_device_index_d: dict) -> dict:
+    """
+    For each worker address in nvml_device_index_d, map the corresponding
+    worker to a rank in range(0, num_workers_per_node), where the rank is
+    decided by the NVML device index. The worker with the lowest NVML device
+    index gets rank 0, and the worker with the highest NVML device index
+    gets rank num_workers_per_node-1.
+
+    Parameters
+    ----------
+    nvml_device_index_d : dict
+        Dictionary of worker addresses mapped to their nvml device index.
+
+    Returns
+    -------
+    dict
+        Updated dictionary with worker addresses mapped to their rank.
+    """
+
+    rank_per_ip: Dict[str, int] = defaultdict(int)
+
+    # Sort by NVML index to ensure that the worker
+    # with the lowest NVML index gets rank 0.
+    for worker, _ in sorted(nvml_device_index_d.items(), key=lambda x: x[1]):
+        ip = _get_worker_ip(worker)
+
+        nvml_device_index_d[worker] = rank_per_ip[ip]
+        rank_per_ip[ip] += 1
+
+    return nvml_device_index_d
+
+
 def _get_rank_offset_across_nodes(worker_ips):
     """
     Get a dictionary of worker IP addresses mapped to the cumulative count of
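
A note on the TODO above: _map_nvml_device_id_to_contiguous_range is pure Python, so it can be exercised without a Dask cluster or GPUs. The sketch below is not part of this diff; it assumes the module path from the diff header (raft_dask.common.comms), uses made-up worker addresses, and only checks the per-node remapping behavior described in the docstring.

# Hypothetical test sketch (not part of this change): workers on the same
# node IP should get contiguous ranks 0..n-1, ordered by NVML device index.
from raft_dask.common.comms import _map_nvml_device_id_to_contiguous_range


def test_map_nvml_device_id_to_contiguous_range():
    # Two nodes, two workers each; NVML indices are neither contiguous
    # nor zero-based on either node.
    nvml_device_index_d = {
        "tcp://10.0.0.1:40001": 3,
        "tcp://10.0.0.1:40002": 1,
        "tcp://10.0.0.2:40003": 7,
        "tcp://10.0.0.2:40004": 2,
    }

    ranks = _map_nvml_device_id_to_contiguous_range(nvml_device_index_d)

    # Lowest NVML index on each node -> rank 0, next lowest -> rank 1.
    assert ranks["tcp://10.0.0.1:40002"] == 0
    assert ranks["tcp://10.0.0.1:40001"] == 1
    assert ranks["tcp://10.0.0.2:40004"] == 0
    assert ranks["tcp://10.0.0.2:40003"] == 1

Note that the helper mutates and returns the same dictionary, so after the call nvml_device_index_d holds per-node ranks rather than NVML indices; _append_rank_offset then adds the cross-node IP offsets to produce the global ranks.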