Skip to content

Commit

Permalink
Add xla.step context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed May 15, 2024
1 parent c6074ab commit 1a496c0
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 27 deletions.
51 changes: 24 additions & 27 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,27 @@ To update your existing training loop, make the following changes:

```diff
-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
...

+ # Move the model paramters to your XLA device
+ model.to(xm.xla_device())
+
+ # MpDeviceLoader preloads data to the XLA device
+ xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())

- for inputs, labels in train_loader:
+ for inputs, labels in xla_train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
- optimizer.step()
+
+ # `xm.optimizer_step` combines gradients across replicas
+ xm.optimizer_step()
+ model.to(xla.device())

for inputs, labels in train_loader:
+ with xla.step():
+ # Transfer data to the XLA device. This happens asynchronously.
+ inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
- optimizer.step()
+ # `xm.optimizer_step` combines gradients across replicas
+ xm.optimizer_step(optimizer)

if __name__ == '__main__':
- mp.spawn(_mp_fn, args=(), nprocs=world_size)
Expand All @@ -69,8 +67,7 @@ If you're using `DistributedDataParallel`, make the following changes:
```diff
import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla as xla
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.distributed.xla_backend

Expand All @@ -89,15 +86,15 @@ If you're using `DistributedDataParallel`, make the following changes:

- model = model.to(rank)
- ddp_model = DDP(model, device_ids=[rank])
+ xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())

- for inputs, labels in train_loader:
+ for inputs, labels in xla_train_loader:
optimizer.zero_grad()
outputs = ddp_model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()

for inputs, labels in train_loader:
+ with xla.step():
+ inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
optimizer.zero_grad()
outputs = ddp_model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()

if __name__ == '__main__':
- mp.spawn(_mp_fn, args=(), nprocs=world_size)
Expand Down
51 changes: 51 additions & 0 deletions test/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,57 @@ def test_sync(self):

self.assertEqual(met.counter_value('MarkStep'), 1)

def test_step(self):
with torch.step():
torch.ones((3, 3), device=xla.device())

self.assertEqual(met.counter_value('MarkStep'), 1)

def test_step_exception(self):
try:
with torch.step():
torch.ones((3, 3), device=xla.device())
raise RuntimeError("Expected error")
except:
pass

self.assertEqual(met.counter_value('MarkStep'), 1)

# Should roughly match example given in README
def test_trivial_model(self):
class TrivialModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 10)

def forward(self, x):
return self.linear(x)

model = TrivialModel().to(xla.device())

batch_size = 16
num_samples = 100

input_data = torch.randn(num_samples, 10)
target_data = torch.randn(num_samples, 10)

# Create a DataLoader
dataset = TensorDataset(input_data, target_data)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for inputs, targets in data_loader:
with xla.step():
inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
# optimizer.step()
xm.optimizer_step(optimizer)


if __name__ == "__main__":
absltest.main()
13 changes: 13 additions & 0 deletions torch_xla/torch_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,16 @@ def device_count() -> int:
def sync():
"""Launches all pending graph operations."""
xm.mark_step()


@contextlib.contextmanager
def step():
"""Wraps code that should be dispatched to the runtime.
Experimental: `xla.step` is still a work in progress. Some code that currently
works with `xla.step` but does not follow best practices will become errors in
future releases. See https://github.com/pytorch/xla/issues/6751 for context.
"""
yield
xm.mark_step()

0 comments on commit 1a496c0

Please sign in to comment.