Fix a DDP graph capture issue

yaochengji committed Dec 12, 2024
1 parent 8e37d50 commit 31ca070

Showing 10 changed files with 594 additions and 19 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -96,8 +96,7 @@ If you're using `DistributedDataParallel`, make the following changes:
+ dist.init_process_group("xla", init_method='xla://')
+
+ model.to(xm.xla_device())
-+ # `gradient_as_bucket_view=True` required for XLA
-+ ddp_model = DDP(model, gradient_as_bucket_view=True)
++ ddp_model = DDP(model)

- model = model.to(rank)
- ddp_model = DDP(model, device_ids=[rank])
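For reference, the DDP setup the updated README describes boils down to the following. This is a minimal sketch, not the README's full example; the `nn.Linear` toy model is illustrative, and a working PJRT-enabled PyTorch/XLA install is assumed:

```python
import torch.distributed as dist
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch.nn.parallel import DistributedDataParallel as DDP

# The `xla://` init_method derives rank and world size from the XLA runtime.
dist.init_process_group("xla", init_method="xla://")

model = nn.Linear(128, 10)  # illustrative toy model
model.to(xm.xla_device())

# After this change, `gradient_as_bucket_view=True` is no longer required.
ddp_model = DDP(model)
```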
523 changes: 522 additions & 1 deletion contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/learn/pjrt.md
@@ -82,7 +82,7 @@ def _mp_fn(index):

+ # Optional for TPU v4 and GPU
+ xm.broadcast_master_param(model)
-model = DDP(model, gradient_as_bucket_view=True)
+model = DDP(model)

loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=.001)
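To round out the `_mp_fn` fragment above, here is a hedged sketch of the rest of a training step on an XLA device. The batch shapes, step count, and synthetic data are illustrative; `xm.mark_step()` is what cuts and executes the lazily traced graph at each step boundary:

```python
import torch
import torch_xla.core.xla_model as xm


def train_loop(ddp_model, loss_fn, optimizer, device, steps=10):
  for _ in range(steps):
    data = torch.randn(16, 128, device=device)   # illustrative shapes
    target = torch.randn(16, 10, device=device)

    optimizer.zero_grad()
    loss = loss_fn(ddp_model(data), target)
    loss.backward()   # DDP all-reduces gradients across replicas here
    optimizer.step()
    xm.mark_step()    # materialize the traced graph for this step
```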
9 changes: 4 additions & 5 deletions docs/source/perf/ddp.md
@@ -40,10 +40,10 @@ device](../API_GUIDE.md#running-on-a-single-xla-device).
world_size = xr.world_size()
```

-4. Pass `gradient_as_bucket_view=True` to the DDP wrapper.
+4. Wrap the model with DDP.

``` python
-ddp_model = DDP(model, gradient_as_bucket_view=True)
+ddp_model = DDP(model)
```

5. Finally launch your model with xla specific launcher.
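One way to do that final step, sketched under the assumption that the tutorial's entry point is the `demo_basic(rank)` function shown in the next hunk; `xmp.spawn` starts one process per local XLA device and passes each process's index as the first argument:

```python
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
  # Each spawned process receives its index, which demo_basic uses as the rank.
  xmp.spawn(demo_basic)
```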
@@ -107,8 +107,7 @@ def demo_basic(rank):
# create model and move it to XLA device
device = xm.xla_device()
model = ToyModel().to(device)
-# currently, graident_as_bucket_view is needed to make DDP work for xla
-ddp_model = DDP(model, gradient_as_bucket_view=True)
+ddp_model = DDP(model)

loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
@@ -246,6 +245,6 @@ the native xla data parallel approach, here is the
[tutorial](../API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing).

Here are some of the known issues that are under investigation: \*
-`gradient_as_bucket_view=True` needs to be enforced. \* There are some
+`gradient_as_bucket_view=False` needs to be enforced. \* There are some
issues while being used with `torch.utils.data.DataLoader`.
`test_train_mp_mnist.py` with real data crashes before exiting.
2 changes: 1 addition & 1 deletion examples/data_parallel/train_resnet_ddp.py
@@ -18,7 +18,7 @@ def __init__(self):
dist.init_process_group('xla', init_method='xla://')
super().__init__()
self.model = DDP(
-self.model, gradient_as_bucket_view=True, broadcast_buffers=False)
+self.model, broadcast_buffers=False)
self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)


11 changes: 4 additions & 7 deletions test/distributed_util.py
@@ -111,18 +111,15 @@ def ddp_correctness(init_method: str = 'env://',
steps = 5 # To save test time.
cpu_model = LargeNet()

-# TODO(@alanwaketan): Investigate whether we can omit the gradient_as_bucket_view option.
+# TODO: There're issues in the captured graph when gradient_as_bucket_view is True
# bucket_cap_mb is set to 1 mb such that we can still have multiple all_reduces while avoiding
# using models that are too larger (25 mb).
# To be noted, DDP currently uses one bucket for the first iteration. See pytorch#73732.
-ddp_model = DDP(
-    copy.deepcopy(cpu_model).to(device),
-    gradient_as_bucket_view=True,
-    bucket_cap_mb=1)
+ddp_model = DDP(copy.deepcopy(cpu_model).to(device), bucket_cap_mb=1)
# ddp_model.register_comm_hook(state=None, hook=comp_hook)

-cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=1e-4)
-ddp_optimizer = optim.SGD(ddp_model.parameters(), lr=1e-4)
+cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=1e-1)
+ddp_optimizer = optim.SGD(ddp_model.parameters(), lr=1e-1)
loss_fn = nn.MSELoss()

local_batch_size = 2
55 changes: 55 additions & 0 deletions test/test_inplace_update.py
@@ -0,0 +1,55 @@
import os
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm


class InplaceUpdateTest(unittest.TestCase):

def test_aten_op_after_full_update(self):
device = xm.xla_device()
t = torch.ones(2, 1, device=device)
w = torch.ones(1, 2, device=device)
t.zero_()
y = torch.matmul(t, w)
expected = torch.zeros(2, 2, device=device)
xm.mark_step()
self.assertTrue(torch.all(torch.eq(y, expected)))

def test_aten_op_after_partial_update(self):
device = xm.xla_device()
t = torch.ones(2, 1, device=device)
w = torch.ones(1, 2, device=device)
t[0][0] = 0
y = torch.matmul(t, w)
expected = torch.tensor([[0, 0], [1, 1]], device=device)
xm.mark_step()
self.assertTrue(torch.all(torch.eq(y, expected)))

def test_non_aten_op_after_full_update(self):
device = xm.xla_device()
t = torch.ones(2, 1, device=device)
w = torch.ones(1, 2, device=device)
t.zero_()
y = torch_xla._XLAC._xla_dot_general(t, w, (([1], [0]), ()))
expected = torch.zeros(2, 2, device=device)
xm.mark_step()
self.assertTrue(torch.all(torch.eq(y, expected)))

def test_non_aten_op_after_partial_update(self):
device = xm.xla_device()
t = torch.ones(2, 1, device=device)
w = torch.ones(1, 2, device=device)
t[0][0] = 0
y = torch_xla._XLAC._xla_dot_general(t, w, (([1], [0]), ()))
expected = torch.tensor([[0, 0], [1, 1]], device=device)
xm.mark_step()
self.assertTrue(torch.all(torch.eq(y, expected)))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
2 changes: 1 addition & 1 deletion test/test_train_mp_imagenet.py
@@ -258,7 +258,7 @@ def train_imagenet():
xm.broadcast_master_param(model)

if FLAGS.ddp:
-model = DDP(model, gradient_as_bucket_view=True, broadcast_buffers=False)
+model = DDP(model, broadcast_buffers=False)

writer = None
if xm.is_master_ordinal():
2 changes: 1 addition & 1 deletion test/test_train_mp_mnist.py
@@ -138,7 +138,7 @@ def train_mnist(flags, **kwargs):
xm.broadcast_master_param(model)

if flags.ddp:
-model = DDP(model, gradient_as_bucket_view=True)
+model = DDP(model)
writer = None
if xm.is_master_ordinal():
writer = test_utils.get_summary_writer(flags.logdir)
4 changes: 4 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.cpp
@@ -80,6 +80,10 @@ XLATensorImpl* GetXlaTensorImpl(const at::Tensor& tensor) {
} // namespace

XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) {
+if (tensor.defined() && at::functionalization::impl::isFunctionalTensor(tensor)) {
+  // Make sure we have the most up-to-date version of the tensor.
+  at::functionalization::impl::sync(tensor);
+}
XLATensorImpl* impl = GetXlaTensorImpl(tensor);
if (impl == nullptr) {
return XLATensorPtr();
