diff --git a/README.md b/README.md
index 70bdcfd57d9..b11c93ce45f 100644
--- a/README.md
+++ b/README.md
@@ -33,29 +33,27 @@ To update your existing training loop, make the following changes:
 ```diff
 -import torch.multiprocessing as mp
++import torch_xla as xla
 +import torch_xla.core.xla_model as xm
-+import torch_xla.distributed.parallel_loader as pl
 +import torch_xla.distributed.xla_multiprocessing as xmp
 
  def _mp_fn(index):
    ...
 
-+  # Move the model paramters to your XLA device
-+  model.to(xm.xla_device())
-+
-+  # MpDeviceLoader preloads data to the XLA device
-+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())
-
--  for inputs, labels in train_loader:
-+  for inputs, labels in xla_train_loader:
-    optimizer.zero_grad()
-    outputs = model(inputs)
-    loss = loss_fn(outputs, labels)
-    loss.backward()
--    optimizer.step()
-+
-+    # `xm.optimizer_step` combines gradients across replicas
-+    xm.optimizer_step()
++  # Move the model parameters to your XLA device
++  model.to(xla.device())
+
+  for inputs, labels in train_loader:
++    with xla.step():
++      # Transfer data to the XLA device. This happens asynchronously.
++      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
+      optimizer.zero_grad()
+      outputs = model(inputs)
+      loss = loss_fn(outputs, labels)
+      loss.backward()
+-      optimizer.step()
++      # `xm.optimizer_step` combines gradients across replicas
++      xm.optimizer_step(optimizer)
 
  if __name__ == '__main__':
 -  mp.spawn(_mp_fn, args=(), nprocs=world_size)
@@ -69,8 +67,7 @@ If you're using `DistributedDataParallel`, make the following changes:
 ```diff
  import torch.distributed as dist
 -import torch.multiprocessing as mp
-+import torch_xla.core.xla_model as xm
-+import torch_xla.distributed.parallel_loader as pl
++import torch_xla as xla
 +import torch_xla.distributed.xla_multiprocessing as xmp
 +import torch_xla.distributed.xla_backend
 
@@ -89,15 +86,15 @@ If you're using `DistributedDataParallel`, make the following changes:
 -  model = model.to(rank)
 -  ddp_model = DDP(model, device_ids=[rank])
-+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())
-
--  for inputs, labels in train_loader:
-+  for inputs, labels in xla_train_loader:
-    optimizer.zero_grad()
-    outputs = ddp_model(inputs)
-    loss = loss_fn(outputs, labels)
-    loss.backward()
-    optimizer.step()
+
+  for inputs, labels in train_loader:
++    with xla.step():
++      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
+      optimizer.zero_grad()
+      outputs = ddp_model(inputs)
+      loss = loss_fn(outputs, labels)
+      loss.backward()
+      optimizer.step()
 
  if __name__ == '__main__':
 -  mp.spawn(_mp_fn, args=(), nprocs=world_size)
diff --git a/test/test_devices.py b/test/test_devices.py
index e1fc804736d..f840f0abbba 100644
--- a/test/test_devices.py
+++ b/test/test_devices.py
@@ -40,6 +40,57 @@ def test_sync(self):
 
     self.assertEqual(met.counter_value('MarkStep'), 1)
 
+  def test_step(self):
+    with xla.step():
+      torch.ones((3, 3), device=xla.device())
+
+    self.assertEqual(met.counter_value('MarkStep'), 1)
+
+  def test_step_exception(self):
+    try:
+      with xla.step():
+        torch.ones((3, 3), device=xla.device())
+        raise RuntimeError("Expected error")
+    except RuntimeError:
+      pass
+
+    self.assertEqual(met.counter_value('MarkStep'), 1)
+
+  # Should roughly match example given in README
+  def test_trivial_model(self):
+    class TrivialModel(nn.Module):
+      def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(10, 10)
+
+      def forward(self, x):
+        return self.linear(x)
+
+    model = TrivialModel().to(xla.device())
+
+    batch_size = 16
+    num_samples = 100
+
+    input_data = torch.randn(num_samples, 10)
+    target_data = torch.randn(num_samples, 10)
+
+    # Create a DataLoader
+    dataset = TensorDataset(input_data, target_data)
+    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+    loss_fn = nn.MSELoss()
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+    for inputs, labels in loader:
+      with xla.step():
+        inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
+        optimizer.zero_grad()
+        outputs = model(inputs)
+        loss = loss_fn(outputs, labels)
+        loss.backward()
+        # `xm.optimizer_step` combines gradients across replicas
+        xm.optimizer_step(optimizer)
+
 
 if __name__ == "__main__":
   absltest.main()
diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py
index 141d7e3e5a7..5f79fc27f6c 100644
--- a/torch_xla/torch_xla.py
+++ b/torch_xla/torch_xla.py
@@ -50,3 +50,18 @@ def device_count() -> int:
 def sync():
   """Launches all pending graph operations."""
   xm.mark_step()
+
+
+@contextlib.contextmanager
+def step():
+  """Wraps code that should be dispatched to the runtime.
+
+  Experimental: `xla.step` is still a work in progress. Some code that currently
+  works with `xla.step` but does not follow best practices will become errors in
+  future releases. See https://github.com/pytorch/xla/issues/6751 for context.
+  """
+  try:
+    yield
+  finally:
+    xm.mark_step()
+
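
For reference, here is a minimal, self-contained sketch of how the new `xla.step()` context manager composes with `xm.optimizer_step`. This is not part of the diff above: it is a single-process toy loop, and the model, data, and hyperparameters are illustrative only.

```python
import torch
import torch.nn as nn
import torch_xla as xla
import torch_xla.core.xla_model as xm

# Illustrative model and optimizer; any nn.Module would do.
model = nn.Linear(10, 10).to(xla.device())
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for _ in range(10):
  # Illustrative random batch; a real loop would read from a DataLoader.
  inputs, labels = torch.randn(16, 10), torch.randn(16, 10)
  with xla.step():
    # Ops traced inside `xla.step()` are compiled and dispatched as one
    # graph when the context exits (it calls `xm.mark_step()` for you).
    inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    # Combines gradients across replicas before applying the update.
    xm.optimizer_step(optimizer)
```

Because `step()` wraps its `yield` in `try`/`finally`, `mark_step` still runs when the step body raises, which is the behavior `test_step_exception` asserts.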