From 36484f48fe5dd60939fb7e0610d9c5091c0780b2 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Fri, 15 Mar 2024 14:25:10 -0700 Subject: [PATCH] MAINT: Simplify NCCL worker rank identification (#2228) This PR is based on @seberg work in https://github.com/rapidsai/raft/pull/1928 . From the PR: This is a follow up on https://github.com/rapidsai/raft/pull/1926, since the rank sorting seemed a bit hard to understand. It does modify the logic in the sense that the host is now sorted by IP as a way to group based on it. But I don't really think that host sorting was ever a goal? If the goal is really about being deterministic, then this should be more (or at least clearer) deterministic about order of worker IPs. OTOH, if the NVML device order doesn't matter, we could just sort the workers directly. The original https://github.com/rapidsai/raft/pull/1587 mentions: NCCL>1.11 expects a process with rank r to be mapped to r % num_gpus_per_node which is something that neither approach seems to quite assure, if such a requirement exists, I would want to do one of: Ensure we can guarantee this, but this requires initializing workers that are not involved in the operation. At least raise an error, because if NCCL will end up raising the error it will be very confusing. Authors: - Vibhu Jawa (https://github.com/VibhuJawa) - Sebastian Berg (https://github.com/seberg) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2228 --- python/raft-dask/raft_dask/common/comms.py | 95 +++---------------- python/raft-dask/raft_dask/test/test_comms.py | 15 ++- 2 files changed, 27 insertions(+), 83 deletions(-) diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index 118293c093..b2f7d1fb74 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,8 +18,7 @@ import time import uuid import warnings -from collections import Counter, OrderedDict, defaultdict -from typing import Dict +from collections import OrderedDict from dask.distributed import default_client from dask_cuda.utils import nvml_device_index @@ -691,9 +690,11 @@ def _func_ucp_ports(client, workers): def _func_worker_ranks(client, workers): """ - For each worker connected to the client, - compute a global rank which is the sum - of the NVML device index and the worker rank offset. + For each worker connected to the client, compute a global rank which takes + into account the NVML device index and the worker IP + (group workers on same host and order by NVML device). + Note that the reason for sorting was nvbug 4149999 and is presumably + fixed afterNCCL 2.19.3. Parameters ---------- @@ -703,13 +704,13 @@ def _func_worker_ranks(client, workers): # TODO: Add Test this function # Running into build issues preventing testing nvml_device_index_d = client.run(_get_nvml_device_index, workers=workers) - worker_ips = [ - _get_worker_ip(worker_address) - for worker_address in nvml_device_index_d + # Sort workers first by IP and then by the nvml device index: + worker_info_list = [ + (_get_worker_ip(worker), nvml_device_index, worker) + for worker, nvml_device_index in nvml_device_index_d.items() ] - ranks = _map_nvml_device_id_to_contiguous_range(nvml_device_index_d) - worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips) - return _append_rank_offset(ranks, worker_ip_offset_dict) + worker_info_list.sort() + return {wi[2]: i for i, wi in enumerate(worker_info_list)} def _get_nvml_device_index(): @@ -730,73 +731,3 @@ def _get_worker_ip(worker_address): worker_address (str): Full address string of the worker """ return ":".join(worker_address.split(":")[0:2]) - - -def _map_nvml_device_id_to_contiguous_range(nvml_device_index_d: dict) -> dict: - """ - For each worker address in nvml_device_index_d, map the corresponding - worker rank in the range(0, num_workers_per_node) where rank is decided - by the NVML device index. Worker with the lowest NVML device index gets - rank 0, and worker with the highest NVML device index gets rank - num_workers_per_node-1. - - Parameters - ---------- - nvml_device_index_d : dict - Dictionary of worker addresses mapped to their nvml device index. - - Returns - ------- - dict - Updated dictionary with worker addresses mapped to their rank. - """ - - rank_per_ip: Dict[str, int] = defaultdict(int) - - # Sort by NVML index to ensure that the worker - # with the lowest NVML index gets rank 0. - for worker, _ in sorted(nvml_device_index_d.items(), key=lambda x: x[1]): - ip = _get_worker_ip(worker) - - nvml_device_index_d[worker] = rank_per_ip[ip] - rank_per_ip[ip] += 1 - - return nvml_device_index_d - - -def _get_rank_offset_across_nodes(worker_ips): - """ - Get a dictionary of worker IP addresses mapped to the cumulative count of - their occurrences in the worker_ips list. The cumulative count serves as - the rank offset. - - Parameters - ---------- - worker_ips (list): List of worker IP addresses. - """ - worker_count_dict = Counter(worker_ips) - worker_offset_dict = {} - current_offset = 0 - for worker_ip, worker_count in worker_count_dict.items(): - worker_offset_dict[worker_ip] = current_offset - current_offset += worker_count - return worker_offset_dict - - -def _append_rank_offset(rank_dict, worker_ip_offset_dict): - """ - For each worker address in the rank dictionary, add the - corresponding worker offset from the worker_ip_offset_dict - to the rank value. - - Parameters - ---------- - rank_dict (dict): Dictionary of worker addresses mapped to their ranks. - worker_ip_offset_dict (dict): Dictionary of worker IP addresses - mapped to their offsets. - """ - for worker_ip, worker_offset in worker_ip_offset_dict.items(): - for worker_address in rank_dict: - if worker_ip in worker_address: - rank_dict[worker_address] += worker_offset - return rank_dict diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 68c9fee556..b62d7185b2 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -354,3 +354,16 @@ def test_device_multicast_sendrecv(n_trials, client): wait(dfs, timeout=5) assert list(map(lambda x: x.result(), dfs)) + + +@pytest.mark.nccl +@pytest.mark.parametrize( + "subset", [slice(-1, None), slice(1), slice(None, None, -2)] +) +def test_comm_init_worker_subset(client, subset): + # Basic test that initializing a subset of workers is fine + cb = Comms(comms_p2p=True, verbose=True) + + workers = list(client.scheduler_info()["workers"].keys()) + workers = workers[subset] + cb.init(workers=workers)