Fix a DDP graph capture issue #8489

Merged · 3 commits · Dec 17, 2024
3 changes: 1 addition & 2 deletions README.md
@@ -96,8 +96,7 @@ If you're using `DistributedDataParallel`, make the following changes:
+ dist.init_process_group("xla", init_method='xla://')
+
+ model.to(xm.xla_device())
+ # `gradient_as_bucket_view=True` required for XLA
+ ddp_model = DDP(model, gradient_as_bucket_view=True)
+ ddp_model = DDP(model)

- model = model.to(rank)
- ddp_model = DDP(model, device_ids=[rank])
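For context, a minimal end-to-end sketch of a migrated script described by the README hunk above. This is not part of the PR: the `_mp_fn`/`xmp.spawn` scaffolding, the toy model, and the training loop are illustrative assumptions, while the `init_process_group("xla", init_method='xla://')`, `xm.xla_device()`, and plain `DDP(model)` calls mirror the diff.

```python
# Illustrative sketch only -- not part of this PR.
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch.nn.parallel import DistributedDataParallel as DDP


def _mp_fn(index):
    # The 'xla://' init_method infers rank and world size from the XLA runtime.
    dist.init_process_group("xla", init_method='xla://')

    device = xm.xla_device()
    model = nn.Linear(16, 16).to(device)
    # gradient_as_bucket_view is left at its default (False) after this PR.
    ddp_model = DDP(model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    for _ in range(5):
        optimizer.zero_grad()
        inputs = torch.randn(8, 16, device=device)
        targets = torch.randn(8, 16, device=device)
        loss = loss_fn(ddp_model(inputs), targets)
        loss.backward()
        optimizer.step()
        xm.mark_step()  # cut and execute the XLA graph for this step


if __name__ == '__main__':
    xmp.spawn(_mp_fn)
```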
523 changes: 522 additions & 1 deletion contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/learn/pjrt.md
@@ -82,7 +82,7 @@ def _mp_fn(index):

+ # Optional for TPU v4 and GPU
+ xm.broadcast_master_param(model)
model = DDP(model, gradient_as_bucket_view=True)
model = DDP(model)

loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=.001)
9 changes: 4 additions & 5 deletions docs/source/perf/ddp.md
@@ -40,10 +40,10 @@ device](../API_GUIDE.md#running-on-a-single-xla-device).
world_size = xr.world_size()
```

4. Pass `gradient_as_bucket_view=True` to the DDP wrapper.
4. Wrap the model with DDP.

``` python
ddp_model = DDP(model, gradient_as_bucket_view=True)
ddp_model = DDP(model)
```

5. Finally launch your model with xla specific launcher.
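Step 5 ends the visible hunk before its code sample. A hedged sketch of the launcher invocation, assuming the `xmp.spawn` entry point from `torch_xla.distributed.xla_multiprocessing` and reusing the `demo_basic(rank)` function shown in the next hunk:

```python
# Sketch only, not part of the diff: launch demo_basic on every local XLA device.
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    # demo_basic(rank) is the function defined in the ddp.md excerpt below;
    # each spawned process receives its ordinal as the first argument.
    xmp.spawn(demo_basic)
```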
@@ -107,8 +107,7 @@ def demo_basic(rank):
# create model and move it to XLA device
device = xm.xla_device()
model = ToyModel().to(device)
# currently, graident_as_bucket_view is needed to make DDP work for xla
ddp_model = DDP(model, gradient_as_bucket_view=True)
ddp_model = DDP(model)

loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
@@ -246,6 +245,6 @@ the native xla data parallel approach, here is the
[tutorial](../API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing).

Here are some of the known issues that are under investigation: \*
`gradient_as_bucket_view=True` needs to be enforced. \* There are some
`gradient_as_bucket_view=False` needs to be enforced. \* There are some
issues while being used with `torch.utils.data.DataLoader`.
`test_train_mp_mnist.py` with real data crashes before exiting.
2 changes: 1 addition & 1 deletion examples/data_parallel/train_resnet_ddp.py
@@ -18,7 +18,7 @@ def __init__(self):
dist.init_process_group('xla', init_method='xla://')
super().__init__()
self.model = DDP(
self.model, gradient_as_bucket_view=True, broadcast_buffers=False)
self.model, broadcast_buffers=False)
self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)


11 changes: 4 additions & 7 deletions test/distributed_util.py
@@ -111,18 +111,15 @@ def ddp_correctness(init_method: str = 'env://',
steps = 5 # To save test time.
cpu_model = LargeNet()

# TODO(@alanwaketan): Investigate whether we can omit the gradient_as_bucket_view option.
# TODO: There're issues in the captured graph when gradient_as_bucket_view is True
# bucket_cap_mb is set to 1 mb such that we can still have multiple all_reduces while avoiding
# using models that are too larger (25 mb).
# To be noted, DDP currently uses one bucket for the first iteration. See pytorch#73732.
ddp_model = DDP(
copy.deepcopy(cpu_model).to(device),
gradient_as_bucket_view=True,
bucket_cap_mb=1)
ddp_model = DDP(copy.deepcopy(cpu_model).to(device), bucket_cap_mb=1)
# ddp_model.register_comm_hook(state=None, hook=comp_hook)

cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=1e-4)
ddp_optimizer = optim.SGD(ddp_model.parameters(), lr=1e-4)
cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=1e-1)
ddp_optimizer = optim.SGD(ddp_model.parameters(), lr=1e-1)
loss_fn = nn.MSELoss()

local_batch_size = 2
76 changes: 76 additions & 0 deletions test/test_inplace_update.py
@@ -0,0 +1,76 @@
import io
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from test_utils import temporary_env


class InplaceUpdateTest(unittest.TestCase):

def test_aten_op_after_full_update(self):
device = xm.xla_device()
t = torch.ones(2, 1, device=device)
w = torch.ones(1, 2, device=device)
t.zero_()
y = torch.matmul(t, w)
expected = torch.zeros(2, 2, device=device)
xm.mark_step()
self.assertTrue(torch.all(torch.eq(y, expected)))

def test_aten_op_after_partial_update(self):
device = xm.xla_device()
t = torch.ones(2, 1, device=device)
w = torch.ones(1, 2, device=device)
t[0][0] = 0
y = torch.matmul(t, w)
expected = torch.tensor([[0, 0], [1, 1]], device=device)
xm.mark_step()
self.assertTrue(torch.all(torch.eq(y, expected)))

def test_non_aten_op_after_full_update(self):
device = xm.xla_device()
t = torch.ones(2, 1, device=device)
w = torch.ones(1, 2, device=device)
t.zero_()
y = torch_xla._XLAC._xla_dot_general(t, w, (([1], [0]), ()))
expected = torch.zeros(2, 2, device=device)
xm.mark_step()
self.assertTrue(torch.all(torch.eq(y, expected)))

def test_non_aten_op_after_partial_update(self):
device = xm.xla_device()
t = torch.ones(2, 1, device=device)
w = torch.ones(1, 2, device=device)
t[0][0] = 0
y = torch_xla._XLAC._xla_dot_general(t, w, (([1], [0]), ()))
expected = torch.tensor([[0, 0], [1, 1]], device=device)
xm.mark_step()
self.assertTrue(torch.all(torch.eq(y, expected)))

def test_xm_save(self):
with temporary_env(
XLA_DISABLE_FUNCTIONALIZATION="0", XLA_ENABLE_PARAM_ALIASING="0"):
xla_device = xm.xla_device()
t1 = torch.tensor([1], device=xla_device)
t2 = t1.detach()
xm.mark_step()

t2.add_(t2)
xm.mark_step()

# mark_step() causes t1 and t2 to be out of sync on the XLA side.

fobj = io.BytesIO()
xm.save({'t1': t1}, fobj)
fobj.seek(0)
saved = torch.load(fobj)

self.assertEqual(t1.item(), saved['t1'].item())


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
2 changes: 1 addition & 1 deletion test/test_train_mp_imagenet.py
@@ -258,7 +258,7 @@ def train_imagenet():
xm.broadcast_master_param(model)

if FLAGS.ddp:
model = DDP(model, gradient_as_bucket_view=True, broadcast_buffers=False)
model = DDP(model, broadcast_buffers=False)

writer = None
if xm.is_master_ordinal():
2 changes: 1 addition & 1 deletion test/test_train_mp_mnist.py
@@ -138,7 +138,7 @@ def train_mnist(flags, **kwargs):
xm.broadcast_master_param(model)

if flags.ddp:
model = DDP(model, gradient_as_bucket_view=True)
model = DDP(model)
writer = None
if xm.is_master_ordinal():
writer = test_utils.get_summary_writer(flags.logdir)
30 changes: 30 additions & 0 deletions test/test_utils.py
@@ -1,4 +1,5 @@
import collections
from contextlib import contextmanager
import itertools
import math
import os
@@ -390,3 +391,32 @@ def runAtenTest(self, tensors, fn, device=None, rel_err=1e-2, abs_err=1e-5):
results = xu.as_list(fn(*tensors))
xla_results = xu.as_list(fn(*xla_tensors))
self.compareResults(results, xla_results, rel_err=rel_err, abs_err=abs_err)


@contextmanager
def temporary_env(**kwargs):
"""
Temporarily set environment variables within the context.

Args:
**kwargs: Key-value pairs representing environment variables to set.
For example: temporary_env(PATH='/new/path', DEBUG='1')
"""
original_env = {}

# Store original values and set new ones
for key, value in kwargs.items():
original_env[key] = os.environ.get(key, None)
os.environ[key] = value

try:
yield
finally:
# Restore original environment variables
for key, old_value in original_env.items():
if old_value is None:
# The variable was not originally set
del os.environ[key]
else:
# Restore the original value
os.environ[key] = old_value
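
As an aside (not part of the diff), the restore-on-exit behavior of the new `temporary_env` helper can be exercised like this; `DEBUG_FLAG` is a hypothetical variable name used only for illustration:

```python
# Illustrative usage of the temporary_env helper added above (not in the diff).
import os
from test_utils import temporary_env

os.environ.pop("DEBUG_FLAG", None)          # hypothetical variable name
with temporary_env(DEBUG_FLAG="1"):
    assert os.environ["DEBUG_FLAG"] == "1"  # set inside the context
assert "DEBUG_FLAG" not in os.environ       # removed again on exit
```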
5 changes: 5 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.cpp
@@ -80,6 +80,11 @@ XLATensorImpl* GetXlaTensorImpl(const at::Tensor& tensor) {
} // namespace

XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) {
if (tensor.defined() &&
at::functionalization::impl::isFunctionalTensor(tensor)) {
// To make sure we have the most updated version of tensor.
at::functionalization::impl::sync(tensor);
}
XLATensorImpl* impl = GetXlaTensorImpl(tensor);
if (impl == nullptr) {
return XLATensorPtr();