diff --git a/test/test_devices.py b/test/test_devices.py index e978db3f47e..4588940fb80 100644 --- a/test/test_devices.py +++ b/test/test_devices.py @@ -50,12 +50,10 @@ def test_step(self): self.assertEqual(met.counter_value('MarkStep'), 1) def test_step_exception(self): - try: + with self.assertRaisesRegex(RuntimeError, 'Expected error'): with xla.step(): torch.ones((3, 3), device=xla.device()) - raise RuntimeError("Expected error") - except RuntimeError: - pass + raise RuntimeError('Expected error') self.assertEqual(met.counter_value('MarkStep'), 1)