Skip to content

Commit

Permalink
[Core][Test] fix function name typo in custom allreduce (vllm-project…
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored May 10, 2024
1 parent fcc2994 commit 4e12131
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/distributed/test_custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def graph_allreduce(world_size, rank, distributed_init_port):
init_test_distributed_environment(1, world_size, rank,
distributed_init_port)

custom_all_reduce.init_custom_all_reduce()
custom_all_reduce.init_custom_ar()
for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with custom_all_reduce.capture():
Expand Down Expand Up @@ -61,7 +61,7 @@ def eager_allreduce(world_size, rank, distributed_init_port):
distributed_init_port)

sz = 1024
custom_all_reduce.init_custom_all_reduce()
custom_all_reduce.init_custom_ar()
fa = custom_all_reduce.get_handle()
inp = torch.ones(sz, dtype=torch.float32, device=device)
out = fa.all_reduce_unreg(inp)
Expand Down
4 changes: 4 additions & 0 deletions vllm/distributed/device_communicators/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def init_custom_ar() -> None:
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.")
return

# we only use a subset of GPUs here
# so we only need to check the nvlink connectivity of these GPUs
num_dev = world_size
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
Expand Down

0 comments on commit 4e12131

Please sign in to comment.