Skip to content

Commit

Permalink
Fix a DDP graph capture issue (#8489)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaochengji authored Dec 17, 2024
1 parent 0121444 commit b1869a8
Show file tree
Hide file tree
Showing 11 changed files with 646 additions and 19 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ If you're using `DistributedDataParallel`, make the following changes:
+ dist.init_process_group("xla", init_method='xla://')
+
+ model.to(xm.xla_device())
+ # `gradient_as_bucket_view=True` required for XLA
+ ddp_model = DDP(model, gradient_as_bucket_view=True)
+ ddp_model = DDP(model)

- model = model.to(rank)
- ddp_model = DDP(model, device_ids=[rank])
Expand Down
523 changes: 522 additions & 1 deletion contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/learn/pjrt.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _mp_fn(index):

+ # Optional for TPU v4 and GPU
+ xm.broadcast_master_param(model)
model = DDP(model, gradient_as_bucket_view=True)
model = DDP(model)

loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=.001)
Expand Down
9 changes: 4 additions & 5 deletions docs/source/perf/ddp.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ device](../API_GUIDE.md#running-on-a-single-xla-device).
world_size = xr.world_size()
```

4. Pass `gradient_as_bucket_view=True` to the DDP wrapper.
4. Wrap the model with DDP.

``` python
ddp_model = DDP(model, gradient_as_bucket_view=True)
ddp_model = DDP(model)
```

5. Finally launch your model with xla specific launcher.
Expand Down Expand Up @@ -107,8 +107,7 @@ def demo_basic(rank):
# create model and move it to XLA device
device = xm.xla_device()
model = ToyModel().to(device)
# currently, graident_as_bucket_view is needed to make DDP work for xla
ddp_model = DDP(model, gradient_as_bucket_view=True)
ddp_model = DDP(model)

loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
Expand Down Expand Up @@ -246,6 +245,6 @@ the native xla data parallel approach, here is the
[tutorial](../API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing).

Here are some of the known issues that are under investigation: \*
`gradient_as_bucket_view=True` needs to be enforced. \* There are some
`gradient_as_bucket_view=False` needs to be enforced. \* There are some
issues while being used with `torch.utils.data.DataLoader`.
`test_train_mp_mnist.py` with real data crashes before exiting.
2 changes: 1 addition & 1 deletion examples/data_parallel/train_resnet_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self):
dist.init_process_group('xla', init_method='xla://')
super().__init__()
self.model = DDP(
self.model, gradient_as_bucket_view=True, broadcast_buffers=False)
self.model, broadcast_buffers=False)
self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)


Expand Down
11 changes: 4 additions & 7 deletions test/distributed_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,15 @@ def ddp_correctness(init_method: str = 'env://',
steps = 5 # To save test time.
cpu_model = LargeNet()

# TODO(@alanwaketan): Investigate whether we can omit the gradient_as_bucket_view option.
# TODO: There are issues in the captured graph when gradient_as_bucket_view is True
# bucket_cap_mb is set to 1 mb such that we can still have multiple all_reduces while avoiding
# using models that are too large (25 mb).
# To be noted, DDP currently uses one bucket for the first iteration. See pytorch#73732.
ddp_model = DDP(
copy.deepcopy(cpu_model).to(device),
gradient_as_bucket_view=True,
bucket_cap_mb=1)
ddp_model = DDP(copy.deepcopy(cpu_model).to(device), bucket_cap_mb=1)
# ddp_model.register_comm_hook(state=None, hook=comp_hook)

cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=1e-4)
ddp_optimizer = optim.SGD(ddp_model.parameters(), lr=1e-4)
cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=1e-1)
ddp_optimizer = optim.SGD(ddp_model.parameters(), lr=1e-1)
loss_fn = nn.MSELoss()

local_batch_size = 2
Expand Down
76 changes: 76 additions & 0 deletions test/test_inplace_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import io
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from test_utils import temporary_env


class InplaceUpdateTest(unittest.TestCase):
    """Checks that ops issued after an in-place update observe the new values.

    Each matmul-style test builds a (2, 1) and a (1, 2) tensor of ones on the
    XLA device, mutates the first one in place (fully or partially), multiplies
    the two, and compares the product with the expected matrix after
    `xm.mark_step()`.
    """

    @staticmethod
    def _ones_operands(device):
        """Return the (2, 1) left and (1, 2) right operands, both all ones."""
        lhs = torch.ones(2, 1, device=device)
        rhs = torch.ones(1, 2, device=device)
        return lhs, rhs

    def _assert_elementwise_equal(self, actual, expected):
        """Assert that every element of `actual` equals `expected`."""
        self.assertTrue(torch.all(torch.eq(actual, expected)))

    def test_aten_op_after_full_update(self):
        """`torch.matmul` after zeroing the whole left operand."""
        dev = xm.xla_device()
        lhs, rhs = self._ones_operands(dev)
        lhs.zero_()
        product = torch.matmul(lhs, rhs)
        want = torch.zeros(2, 2, device=dev)
        xm.mark_step()
        self._assert_elementwise_equal(product, want)

    def test_aten_op_after_partial_update(self):
        """`torch.matmul` after zeroing a single element of the left operand."""
        dev = xm.xla_device()
        lhs, rhs = self._ones_operands(dev)
        lhs[0][0] = 0
        product = torch.matmul(lhs, rhs)
        want = torch.tensor([[0, 0], [1, 1]], device=dev)
        xm.mark_step()
        self._assert_elementwise_equal(product, want)

    def test_non_aten_op_after_full_update(self):
        """`_xla_dot_general` after zeroing the whole left operand."""
        dev = xm.xla_device()
        lhs, rhs = self._ones_operands(dev)
        lhs.zero_()
        product = torch_xla._XLAC._xla_dot_general(lhs, rhs, (([1], [0]), ()))
        want = torch.zeros(2, 2, device=dev)
        xm.mark_step()
        self._assert_elementwise_equal(product, want)

    def test_non_aten_op_after_partial_update(self):
        """`_xla_dot_general` after zeroing a single element of the left operand."""
        dev = xm.xla_device()
        lhs, rhs = self._ones_operands(dev)
        lhs[0][0] = 0
        product = torch_xla._XLAC._xla_dot_general(lhs, rhs, (([1], [0]), ()))
        want = torch.tensor([[0, 0], [1, 1]], device=dev)
        xm.mark_step()
        self._assert_elementwise_equal(product, want)

    def test_xm_save(self):
        """A tensor saved with `xm.save` round-trips to the value `item()` reports."""
        env_overrides = {
            "XLA_DISABLE_FUNCTIONALIZATION": "0",
            "XLA_ENABLE_PARAM_ALIASING": "0",
        }
        with temporary_env(**env_overrides):
            dev = xm.xla_device()
            source = torch.tensor([1], device=dev)
            detached = source.detach()
            xm.mark_step()

            detached.add_(detached)
            xm.mark_step()

            # mark_step() causes `source` and `detached` to be out of sync on
            # the XLA side.

            buffer = io.BytesIO()
            xm.save({'t1': source}, buffer)
            buffer.seek(0)
            restored = torch.load(buffer)

            self.assertEqual(source.item(), restored['t1'].item())


if __name__ == '__main__':
    # unittest.main() calls sys.exit() itself unless exit=False is passed, in
    # which case the program object is never returned and the explicit exit
    # below is unreachable. Disable that so the exit code is derived from the
    # test result here.
    test = unittest.main(exit=False)
    sys.exit(0 if test.result.wasSuccessful() else 1)
2 changes: 1 addition & 1 deletion test/test_train_mp_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def train_imagenet():
xm.broadcast_master_param(model)

if FLAGS.ddp:
model = DDP(model, gradient_as_bucket_view=True, broadcast_buffers=False)
model = DDP(model, broadcast_buffers=False)

writer = None
if xm.is_master_ordinal():
Expand Down
2 changes: 1 addition & 1 deletion test/test_train_mp_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def train_mnist(flags, **kwargs):
xm.broadcast_master_param(model)

if flags.ddp:
model = DDP(model, gradient_as_bucket_view=True)
model = DDP(model)
writer = None
if xm.is_master_ordinal():
writer = test_utils.get_summary_writer(flags.logdir)
Expand Down
30 changes: 30 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections
from contextlib import contextmanager
import itertools
import math
import os
Expand Down Expand Up @@ -390,3 +391,32 @@ def runAtenTest(self, tensors, fn, device=None, rel_err=1e-2, abs_err=1e-5):
results = xu.as_list(fn(*tensors))
xla_results = xu.as_list(fn(*xla_tensors))
self.compareResults(results, xla_results, rel_err=rel_err, abs_err=abs_err)


@contextmanager
def temporary_env(**kwargs):
    """Temporarily set environment variables within the context.

    On exit every variable is put back to the value it had before the context
    was entered, or removed if it was not set. This happens whether the body
    finishes normally or raises.

    Args:
      **kwargs: Key-value pairs representing environment variables to set.
        Values must be strings, as `os.environ` requires.
        For example: temporary_env(PATH='/new/path', DEBUG='1')
    """
    original_env = {}

    try:
        # Apply the overrides inside the `try` so that a failure part-way
        # through (e.g. a non-string value) still rolls back the variables
        # that were already changed.
        for key, value in kwargs.items():
            original_env[key] = os.environ.get(key)
            os.environ[key] = value
        yield
    finally:
        # Restore original environment variables
        for key, old_value in original_env.items():
            if old_value is None:
                # The variable was not originally set. pop() instead of `del`
                # because the body may already have removed it.
                os.environ.pop(key, None)
            else:
                # Restore the original value
                os.environ[key] = old_value
5 changes: 5 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ XLATensorImpl* GetXlaTensorImpl(const at::Tensor& tensor) {
} // namespace

XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) {
if (tensor.defined() &&
at::functionalization::impl::isFunctionalTensor(tensor)) {
// Make sure we have the most up-to-date version of the tensor.
at::functionalization::impl::sync(tensor);
}
XLATensorImpl* impl = GetXlaTensorImpl(tensor);
if (impl == nullptr) {
return XLATensorPtr();
Expand Down

0 comments on commit b1869a8

Please sign in to comment.