Skip to content

Commit

Permalink
Fix errors
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed May 16, 2024
1 parent a906b0c commit 02d8f9f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
16 changes: 9 additions & 7 deletions test/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from absl.testing import absltest, parameterized
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import torch_xla as xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.debug.metrics as met

Expand All @@ -14,8 +17,8 @@ def setUpClass(cls):
xr.set_device_type('CPU')
os.environ['CPU_NUM_DEVICES'] = '4'

def tearDown(self):
met.clear_metrics()
def setUp(self):
met.clear_all()

@parameterized.parameters((None, torch.device('xla:0')),
(0, torch.device('xla:0')),
Expand All @@ -41,17 +44,17 @@ def test_sync(self):
self.assertEqual(met.counter_value('MarkStep'), 1)

def test_step(self):
with torch.step():
with xla.step():
torch.ones((3, 3), device=xla.device())

self.assertEqual(met.counter_value('MarkStep'), 1)

def test_step_exception(self):
try:
with torch.step():
with xla.step():
torch.ones((3, 3), device=xla.device())
raise RuntimeError("Expected error")
except:
except RuntimeError:
pass

self.assertEqual(met.counter_value('MarkStep'), 1)
Expand Down Expand Up @@ -83,14 +86,13 @@ def forward(self, x):
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for inputs, targets in data_loader:
for inputs, labels in loader:
with xla.step():
inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
# optimizer.step()
xm.optimizer_step(optimizer)


Expand Down
6 changes: 4 additions & 2 deletions torch_xla/torch_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,7 @@ def step():
works with `xla.step` but does not follow best practices will become errors in
future releases. See https://github.com/pytorch/xla/issues/6751 for context.
"""
yield
xm.mark_step()
try:
yield
finally:
xm.mark_step()

0 comments on commit 02d8f9f

Please sign in to comment.