diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index 118293c093..bbae24a4f7 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -18,8 +18,7 @@ import time import uuid import warnings -from collections import Counter, OrderedDict, defaultdict -from typing import Dict +from collections import OrderedDict from dask.distributed import default_client from dask_cuda.utils import nvml_device_index @@ -691,9 +690,9 @@ def _func_ucp_ports(client, workers): def _func_worker_ranks(client, workers): """ - For each worker connected to the client, - compute a global rank which is the sum - of the NVML device index and the worker rank offset. + For each worker connected to the client, compute a global rank which takes + into account the NVML device index and the worker IP + (group workers on same host and order by NVML device). Parameters ---------- @@ -703,13 +702,13 @@ def _func_worker_ranks(client, workers): # TODO: Add Test this function # Running into build issues preventing testing nvml_device_index_d = client.run(_get_nvml_device_index, workers=workers) - worker_ips = [ - _get_worker_ip(worker_address) - for worker_address in nvml_device_index_d + # Sort workers first by IP and then by the nvml device index: + worker_info_list = [ + (_get_worker_ip(worker), nvml_device_index, worker) + for worker, nvml_device_index in nvml_device_index_d.items() ] - ranks = _map_nvml_device_id_to_contiguous_range(nvml_device_index_d) - worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips) - return _append_rank_offset(ranks, worker_ip_offset_dict) + worker_info_list.sort() + return {wi[2]: i for i, wi in enumerate(worker_info_list)} def _get_nvml_device_index(): @@ -730,73 +729,3 @@ def _get_worker_ip(worker_address): worker_address (str): Full address string of the worker """ return ":".join(worker_address.split(":")[0:2]) - - -def _map_nvml_device_id_to_contiguous_range(nvml_device_index_d: dict) -> dict: - """ - For each worker address in nvml_device_index_d, map the corresponding - worker rank in the range(0, num_workers_per_node) where rank is decided - by the NVML device index. Worker with the lowest NVML device index gets - rank 0, and worker with the highest NVML device index gets rank - num_workers_per_node-1. - - Parameters - ---------- - nvml_device_index_d : dict - Dictionary of worker addresses mapped to their nvml device index. - - Returns - ------- - dict - Updated dictionary with worker addresses mapped to their rank. - """ - - rank_per_ip: Dict[str, int] = defaultdict(int) - - # Sort by NVML index to ensure that the worker - # with the lowest NVML index gets rank 0. - for worker, _ in sorted(nvml_device_index_d.items(), key=lambda x: x[1]): - ip = _get_worker_ip(worker) - - nvml_device_index_d[worker] = rank_per_ip[ip] - rank_per_ip[ip] += 1 - - return nvml_device_index_d - - -def _get_rank_offset_across_nodes(worker_ips): - """ - Get a dictionary of worker IP addresses mapped to the cumulative count of - their occurrences in the worker_ips list. The cumulative count serves as - the rank offset. - - Parameters - ---------- - worker_ips (list): List of worker IP addresses. - """ - worker_count_dict = Counter(worker_ips) - worker_offset_dict = {} - current_offset = 0 - for worker_ip, worker_count in worker_count_dict.items(): - worker_offset_dict[worker_ip] = current_offset - current_offset += worker_count - return worker_offset_dict - - -def _append_rank_offset(rank_dict, worker_ip_offset_dict): - """ - For each worker address in the rank dictionary, add the - corresponding worker offset from the worker_ip_offset_dict - to the rank value. - - Parameters - ---------- - rank_dict (dict): Dictionary of worker addresses mapped to their ranks. - worker_ip_offset_dict (dict): Dictionary of worker IP addresses - mapped to their offsets. - """ - for worker_ip, worker_offset in worker_ip_offset_dict.items(): - for worker_address in rank_dict: - if worker_ip in worker_address: - rank_dict[worker_address] += worker_offset - return rank_dict