Skip to content

Commit

Permalink
Try using contiguous rank to fix cuda_visible_devices
Browse files Browse the repository at this point in the history
  • Loading branch information
VibhuJawa committed Oct 24, 2023
1 parent 945355d commit e85e010
Showing 1 changed file with 40 additions and 3 deletions.
43 changes: 40 additions & 3 deletions python/raft-dask/raft_dask/common/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import time
import uuid
import warnings
from collections import Counter, OrderedDict
from collections import Counter, OrderedDict, defaultdict
from typing import Dict

from dask.distributed import default_client
from dask_cuda.utils import nvml_device_index
Expand Down Expand Up @@ -698,8 +699,12 @@ def _func_worker_ranks(client):
----------
client (object): Dask client object.
"""
ranks = client.run(_get_nvml_device_index)
worker_ips = [_get_worker_ip(worker_address) for worker_address in ranks]
nvml_device_index_d = client.run(_get_nvml_device_index)
worker_ips = [
_get_worker_ip(worker_address)
for worker_address in nvml_device_index_d
]
ranks = _map_nvml_device_id_to_contiguous_range(nvml_device_index_d)
worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips)
return _append_rank_offset(ranks, worker_ip_offset_dict)

Expand All @@ -724,6 +729,38 @@ def _get_worker_ip(worker_address):
return ":".join(worker_address.split(":")[0:2])


def _map_nvml_device_id_to_contiguous_range(nvml_device_index_d: dict) -> dict:
"""
For each worker address in nvml_device_index_d, map the corresponding
worker rank in the range(0, num_workers_per_node) where rank is decided
by the NVML device index. Worker with the lowest NVML device index gets
rank 0, and worker with the highest NVML device index gets rank
num_workers_per_node-1.
Parameters
----------
nvml_device_index_d : dict
Dictionary of worker addresses mapped to their nvml device index.
Returns
-------
dict
Updated dictionary with worker addresses mapped to their rank.
"""

rank_per_ip: Dict[str, int] = defaultdict(int)

# Sort by NVML index to ensure that the worker
# with the lowest NVML index gets rank 0.
for worker, _ in sorted(nvml_device_index_d.items(), key=lambda x: x[1]):
ip = _get_worker_ip(worker)

nvml_device_index_d[worker] = rank_per_ip[ip]
rank_per_ip[ip] += 1

return nvml_device_index_d


def _get_rank_offset_across_nodes(worker_ips):
"""
Get a dictionary of worker IP addresses mapped to the cumulative count of
Expand Down

0 comments on commit e85e010

Please sign in to comment.