diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py
index 118293c093..bbae24a4f7 100644
--- a/python/raft-dask/raft_dask/common/comms.py
+++ b/python/raft-dask/raft_dask/common/comms.py
@@ -18,8 +18,7 @@
 import time
 import uuid
 import warnings
-from collections import Counter, OrderedDict, defaultdict
-from typing import Dict
+from collections import OrderedDict
 
 from dask.distributed import default_client
 from dask_cuda.utils import nvml_device_index
@@ -691,9 +690,9 @@ def _func_ucp_ports(client, workers):
 
 def _func_worker_ranks(client, workers):
     """
-    For each worker connected to the client,
-    compute a global rank which is the sum
-    of the NVML device index and the worker rank offset.
+    For each worker connected to the client, compute a global rank which takes
+    into account the NVML device index and the worker IP
+    (group workers on same host and order by NVML device).
 
     Parameters
     ----------
@@ -703,13 +702,13 @@ def _func_worker_ranks(client, workers):
     # TODO: Add Test this function
     # Running into build issues preventing testing
     nvml_device_index_d = client.run(_get_nvml_device_index, workers=workers)
-    worker_ips = [
-        _get_worker_ip(worker_address)
-        for worker_address in nvml_device_index_d
+    # Sort workers first by IP and then by the nvml device index:
+    worker_info_list = [
+        (_get_worker_ip(worker), nvml_device_index, worker)
+        for worker, nvml_device_index in nvml_device_index_d.items()
     ]
-    ranks = _map_nvml_device_id_to_contiguous_range(nvml_device_index_d)
-    worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips)
-    return _append_rank_offset(ranks, worker_ip_offset_dict)
+    worker_info_list.sort()
+    return {wi[2]: i for i, wi in enumerate(worker_info_list)}
 
 
 def _get_nvml_device_index():
@@ -730,73 +729,3 @@ def _get_worker_ip(worker_address):
         worker_address (str): Full address string of the worker
     """
     return ":".join(worker_address.split(":")[0:2])
-
-
-def _map_nvml_device_id_to_contiguous_range(nvml_device_index_d: dict) -> dict:
-    """
-    For each worker address in nvml_device_index_d, map the corresponding
-    worker rank in the range(0, num_workers_per_node) where rank is decided
-    by the NVML device index. Worker with the lowest NVML device index gets
-    rank 0, and worker with the highest NVML device index gets rank
-    num_workers_per_node-1.
-
-    Parameters
-    ----------
-    nvml_device_index_d : dict
-        Dictionary of worker addresses mapped to their nvml device index.
-
-    Returns
-    -------
-    dict
-        Updated dictionary with worker addresses mapped to their rank.
-    """
-
-    rank_per_ip: Dict[str, int] = defaultdict(int)
-
-    # Sort by NVML index to ensure that the worker
-    # with the lowest NVML index gets rank 0.
-    for worker, _ in sorted(nvml_device_index_d.items(), key=lambda x: x[1]):
-        ip = _get_worker_ip(worker)
-
-        nvml_device_index_d[worker] = rank_per_ip[ip]
-        rank_per_ip[ip] += 1
-
-    return nvml_device_index_d
-
-
-def _get_rank_offset_across_nodes(worker_ips):
-    """
-    Get a dictionary of worker IP addresses mapped to the cumulative count of
-    their occurrences in the worker_ips list. The cumulative count serves as
-    the rank offset.
-
-    Parameters
-    ----------
-        worker_ips (list): List of worker IP addresses.
-    """
-    worker_count_dict = Counter(worker_ips)
-    worker_offset_dict = {}
-    current_offset = 0
-    for worker_ip, worker_count in worker_count_dict.items():
-        worker_offset_dict[worker_ip] = current_offset
-        current_offset += worker_count
-    return worker_offset_dict
-
-
-def _append_rank_offset(rank_dict, worker_ip_offset_dict):
-    """
-    For each worker address in the rank dictionary, add the
-    corresponding worker offset from the worker_ip_offset_dict
-    to the rank value.
-
-    Parameters
-    ----------
-        rank_dict (dict): Dictionary of worker addresses mapped to their ranks.
-        worker_ip_offset_dict (dict): Dictionary of worker IP addresses
-                                      mapped to their offsets.
-    """
-    for worker_ip, worker_offset in worker_ip_offset_dict.items():
-        for worker_address in rank_dict:
-            if worker_ip in worker_address:
-                rank_dict[worker_address] += worker_offset
-    return rank_dict