diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 3b1cd1773af19..308b874280f55 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -25,7 +25,7 @@ def graph_allreduce(world_size, rank, distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) - custom_all_reduce.init_custom_all_reduce() + custom_all_reduce.init_custom_ar() for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: with custom_all_reduce.capture(): @@ -61,7 +61,7 @@ def eager_allreduce(world_size, rank, distributed_init_port): distributed_init_port) sz = 1024 - custom_all_reduce.init_custom_all_reduce() + custom_all_reduce.init_custom_ar() fa = custom_all_reduce.get_handle() inp = torch.ones(sz, dtype=torch.float32, device=device) out = fa.all_reduce_unreg(inp) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index cc5f8166877ce..5d26254fb832a 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -52,6 +52,10 @@ def init_custom_ar() -> None: "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" " is set.") return + + # we only use a subset of GPUs here + # so we only need to check the nvlink connectivity of these GPUs + num_dev = world_size # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES