From 02d8f9fe81cfe0475fd9a9007f8abe37fbbea741 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 16 May 2024 18:51:38 +0000 Subject: [PATCH] Fix errors --- test/test_devices.py | 16 +++++++++------- torch_xla/torch_xla.py | 6 ++++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/test/test_devices.py b/test/test_devices.py index 5ee76df3e7e..e978db3f47e 100644 --- a/test/test_devices.py +++ b/test/test_devices.py @@ -2,7 +2,10 @@ from absl.testing import absltest, parameterized import torch +from torch import nn +from torch.utils.data import TensorDataset, DataLoader import torch_xla as xla +import torch_xla.core.xla_model as xm import torch_xla.runtime as xr import torch_xla.debug.metrics as met @@ -14,8 +17,8 @@ def setUpClass(cls): xr.set_device_type('CPU') os.environ['CPU_NUM_DEVICES'] = '4' - def tearDown(self): - met.clear_metrics() + def setUp(self): + met.clear_all() @parameterized.parameters((None, torch.device('xla:0')), (0, torch.device('xla:0')), @@ -41,17 +44,17 @@ def test_sync(self): self.assertEqual(met.counter_value('MarkStep'), 1) def test_step(self): - with torch.step(): + with xla.step(): torch.ones((3, 3), device=xla.device()) self.assertEqual(met.counter_value('MarkStep'), 1) def test_step_exception(self): try: - with torch.step(): + with xla.step(): torch.ones((3, 3), device=xla.device()) raise RuntimeError("Expected error") - except: + except RuntimeError: pass self.assertEqual(met.counter_value('MarkStep'), 1) @@ -83,14 +86,13 @@ def forward(self, x): loss_fn = nn.MSELoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - for inputs, targets in data_loader: + for inputs, labels in loader: with xla.step(): inputs, labels = inputs.to(xla.device()), labels.to(xla.device()) optimizer.zero_grad() outputs = model(inputs) loss = loss_fn(outputs, labels) loss.backward() - # optimizer.step() xm.optimizer_step(optimizer) diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index 3755300afe6..fab890d0a4d 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -62,5 +62,7 @@ def step(): works with `xla.step` but does not follow best practices will become errors in future releases. See https://github.com/pytorch/xla/issues/6751 for context. """ - yield - xm.mark_step() + try: + yield + finally: + xm.mark_step()