diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py
index 519eef247a7d..fb13ba943236 100644
--- a/test/dynamo/test_dynamo.py
+++ b/test/dynamo/test_dynamo.py
@@ -89,9 +89,19 @@ def test_simple_model(self):
     self.assertNotIn('xla::add', met.counter_names())
     self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo_2.cpu()))
     # verify that dynamo can handle different inputs
-    res_xla_dynamo_3 = fn_simple_dynamo(xla_x + xla_y, xla_y * 3)
+    xla_z = torch.randn(5, 10, device=device)
+    xla_xy = xla_x + xla_y
+    xla_y3 = xla_y * 3
+    res_xla_dynamo_3 = fn_simple_dynamo(xla_xy, xla_y3)
     res_cpu_3 = self.fn_simple(x + y, y * 3)
     self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu()))
+    # executing the compiled function should only materialize input XLATensors
+    self.assertIn('XLAData: None',
+                  torch_xla._XLAC._get_xla_tensor_debug_info(xla_z))
+    self.assertNotIn('XLAData: None',
+                     torch_xla._XLAC._get_xla_tensor_debug_info(xla_xy))
+    self.assertNotIn('XLAData: None',
+                     torch_xla._XLAC._get_xla_tensor_debug_info(xla_y3))
 
   def test_fn_without_input(self):
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index a89075a2e082..c87ede2ff5be 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -389,10 +389,14 @@ def optimized_mod(*args):
 
     # mark_step needs to be blocking since we want to access args's XLADatas
    # and they can't be placeholder.
-    if any(
-        torch_xla._XLAC._check_tensor_need_materialization(
-            [a for a in args if isinstance(a, torch.Tensor)])):
-      xm.mark_step(wait=True)
+    input_tensors_to_sync = [
+        args[i] for i, x in enumerate(
+            torch_xla._XLAC._check_tensor_need_materialization(
+                [a for a in args if isinstance(a, torch.Tensor)])) if x
+    ]
+    if len(input_tensors_to_sync) > 0:
+      torch_xla._XLAC._xla_sync_multi(
+          input_tensors_to_sync, devices=[], wait=True, sync_xla_data=True)
 
     # If input sharding has changed from the previous program, dynamo current can
     # not detect this. It will mistakenly believe the program is the same. We need
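
A minimal sketch (not part of the patch) of the selective-sync behavior this change introduces: only input tensors that still need materialization are handed to _xla_sync_multi, instead of forcing a full blocking xm.mark_step(). The helper names, their arguments, and the 'XLAData: None' debug-info check are taken from the diff above; the tensor shapes and variable names are illustrative only.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
xla_a = torch.randn(5, 10, device=device)
xla_b = torch.randn(5, 10, device=device)
xla_c = xla_a + xla_b  # lazy IR only; no device data has been produced yet

tensors = [xla_a, xla_b, xla_c]
# _check_tensor_need_materialization returns one flag per input tensor
# (same call pattern as in dynamo_bridge.py above).
needs_sync = torch_xla._XLAC._check_tensor_need_materialization(tensors)
to_sync = [t for t, flag in zip(tensors, needs_sync) if flag]

if len(to_sync) > 0:
  # Materialize only the pending tensors; unlike xm.mark_step(wait=True),
  # this leaves the rest of the live lazy graph untouched.
  torch_xla._XLAC._xla_sync_multi(
      to_sync, devices=[], wait=True, sync_xla_data=True)

# Synced tensors now hold device data, so their debug info no longer
# reports 'XLAData: None' (this is what the new test assertions check).
print('XLAData: None' in torch_xla._XLAC._get_xla_tensor_debug_info(xla_c))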