
Deadlock attempting to do concurrent send, receive #72

Open

pspillai opened this issue Sep 24, 2024 · 2 comments

@pspillai
I am trying to implement a concurrent asynchronous send and receive between multiple processes. This results in deadlock. Minimal code to reproduce this is as follows:

import torch
import torch.nn.parallel
import torch.distributed as dist
import intel_extension_for_pytorch as ipex       # enables the xpu device
import oneccl_bindings_for_pytorch               # registers the 'ccl' backend
import os

os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0))
os.environ['WORLD_SIZE'] = str(os.environ.get('PMI_SIZE', 1))

print(os.environ['RANK'], os.environ['WORLD_SIZE'])
backend = 'ccl'
dist.init_process_group(backend)
my_rank = dist.get_rank()
my_size = dist.get_world_size()
print("my rank = %d  my size = %d" % (my_rank, my_size))

dev = f"xpu:{my_rank}"
torch.xpu.set_device(my_rank)
A = torch.ones(1, 2, dtype=torch.float32).to(dev)
_ = A[0, 0].item()    # .item() forces synchronization with the device
B = torch.zeros(1, 2, dtype=torch.float32).to(dev)
_ = B[0, 0].item()

dist.barrier()

dist.all_reduce(A)

print("START")
# Each rank posts a send and a receive to the other rank, then waits on both.
o1 = dist.isend(A, 1 - my_rank)
o2 = dist.irecv(B, 1 - my_rank)
o1.wait()
o2.wait()

print("DONE")

Run with

mpirun -n 2 python -u test.py

It appears that the isend and irecv on each process are serialized. This particular example can complete if one process does the send first and the other does the recv first, but I think the transfers are still serialized in that case, so the two transfers are not concurrent.

I tried using batch_isend_irecv to define a list of transfers, but this resulted in the same deadlock; a sketch of that variant is below.
Without concurrent transfers, it is almost impossible to implement efficient distributed compute-and-shift algorithms, Cannon's algorithm, and the like.
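
The batched variant is roughly the following, using torch.distributed.P2POp with the same tensors and peer as in the repro above:

p2p_ops = [
    dist.P2POp(dist.isend, A, 1 - my_rank),   # post the send to the peer
    dist.P2POp(dist.irecv, B, 1 - my_rank),   # post the receive from the peer
]
reqs = dist.batch_isend_irecv(p2p_ops)
for req in reqs:
    req.wait()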

@gaopengff
Contributor

Currently torch-ccl only supports having one rank send while the other receives at a given time. If you change the code to

if my_rank == 0:
    o1 = dist.isend(A,1-my_rank)
    o1.wait()
else:
    o2 = dist.irecv(B,1-my_rank)
    o2.wait()

it will work. Did you run your test with CUDA's NCCL backend? If it works with CUDA, I think this is a design issue of torch-ccl.

@pspillai
Author

pspillai commented Oct 9, 2024

Yes, if the send and receive ordering is matched, it will work, but this causes the transmissions to be serialized, wasting half of the available bandwidth. (There should be no reason why the two transfers cannot be done concurrently).
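
For clarity, the matched ordering I mean is roughly the following; it completes, but the two transfers run back to back instead of overlapping:

# Matched ordering workaround: rank 0 posts its send first, rank 1 its recv first.
# This avoids the deadlock, but the two transfers happen one after the other.
if my_rank == 0:
    dist.isend(A, 1 - my_rank).wait()
    dist.irecv(B, 1 - my_rank).wait()
else:
    dist.irecv(B, 1 - my_rank).wait()
    dist.isend(A, 1 - my_rank).wait()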

I have not tested on nccl. However, looking at the sample code for torch.distributed.batch_isend_irecv (https://pytorch.org/docs/stable/distributed.html#torch.distributed.batch_isend_irecv) and the source code (https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#batch_isend_irecv), it looks like batch_isend_irecv just calls the isend/irecv operations in the order provided, which in this example is the same for each rank, so I expect it to work fine on nccl.
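
If I am reading that source correctly, the non-CUDA path reduces to roughly the following (paraphrased, not the exact code): each P2POp is simply issued in list order.

# Rough paraphrase of the non-CUDA branch of batch_isend_irecv:
reqs = []
for p2p_op in p2p_op_list:
    work = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
    if work:
        reqs.append(work)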

Not surprisingly, batch_isend_irecv locks up with this example using ccl.
