From 476edafe7320e080c03a719b906c182d52cc6fa9 Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Thu, 29 Feb 2024 14:41:09 +0000 Subject: [PATCH] Fixup test dynamo failures --- test/dynamo/test_dynamo.py | 36 ++++++++++++++++----------------- torch_xla/core/dynamo_bridge.py | 2 +- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index a22b1b8f096..a1078af2f54 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -119,6 +119,23 @@ def test_simple_model_automoves_tensors(self): self.assertTrue(res_cpu_3.device == res_xla_dynamo_different.device) self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_different)) + def test_resnet_all_cpu_tensor_moved_to_xla(self): + met.clear_all() + input = torch.randn(4, 3, 224, 224) + resnet18 = torchvision.models.resnet18() + resnet18.eval() + dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla') + + # input and model weight on cpu + with warnings.catch_warnings(record=True) as w: + res = dynamo_resnet18_cpu(input) + # there should be 18 parameters + 1 input all moved to XLA Device.
+ self.assertTrue(len(w) == 0) + + # Ops should work automatically and XLA should "just work" + self.assertTrue(len(met.counter_names()) > 1) + self.assertIn('MarkStep', met.counter_names()) + def test_fn_without_input(self): def fn_without_input(device): @@ -542,7 +559,7 @@ def test_resnet18(self): met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3) -class DynamErrorMessageTest(unittest.TestCase): +class DynamoErrorMessageTest(unittest.TestCase): def test_mixed_cpu_tensor(self): device = xm.xla_device() @@ -566,23 +583,6 @@ def test_mixed_cpu_tensor(self): self.assertTrue( 'found two different devices' in context.exception.__str__()) - def test_all_cpu_tensor(self): - met.clear_all() - input = torch.randn(4, 3, 224, 224) - resnet18 = torchvision.models.resnet18() - resnet18.eval() - dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla') - # input and model weight on cpu - with warnings.catch_warnings(record=True) as w: - res = dynamo_resnet18_cpu(input) - # there should be 18 paramters + 1 input - self.assertGreater(len(w), 15) - self.assertIn('Found tensor with shape torch.Size', str(w[0].message)) - # no XLA operation should happens except a empty mark_step. Partitioner should offload all CPU - # ops to CPU. - self.assertEqual(len(met.counter_names()), 1) - self.assertIn('MarkStep', met.counter_names()) - if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py index 4c4f73998c8..73345578be0 100644 --- a/torch_xla/core/dynamo_bridge.py +++ b/torch_xla/core/dynamo_bridge.py @@ -613,7 +613,7 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, input_args): for xla_arg in xla_args: assert xla_arg.device.type == 'xla', "Found tensor with shape " + str( - xla_arg.size()) + " on " + str(xla_arg.device) + xla_arg.size()) + " on non-XLA device: " + str(xla_arg.device) cloned_args = [ torch.clone(xla_arg) if isinstance(xla_arg, torch.Tensor) else xla_arg