Commit

Dynamo bridge should only sync input XLATensor during execution (pyto…
JackCaoG authored and amithrm committed Mar 1, 2024
1 parent 387d380 commit da4993c
Showing 2 changed files with 19 additions and 5 deletions.
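
In short: previously, executing the compiled function called xm.mark_step(wait=True) whenever any input tensor needed materialization, which flushes every pending lazy operation on the device. With this commit, only the input tensors that actually need materialization are synced through _xla_sync_multi, so unrelated pending tensors stay lazy. A minimal sketch of the new sync path follows; it is not part of the commit, assumes an XLA device is available, and uses illustrative tensor names.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
inputs = [torch.randn(2, 2, device=device) * 2]   # pending lazy input
unrelated = torch.randn(2, 2, device=device) + 1  # pending tensor the compiled fn never sees

# Old behavior: if any input needed materialization, a full mark_step executed
# every pending operation on the device, including `unrelated`.
# xm.mark_step(wait=True)

# New behavior: sync only the inputs that need materialization; `unrelated`
# keeps its pending IR until something else forces execution.
needs_sync = torch_xla._XLAC._check_tensor_need_materialization(inputs)
to_sync = [t for t, pending in zip(inputs, needs_sync) if pending]
if to_sync:
  torch_xla._XLAC._xla_sync_multi(
      to_sync, devices=[], wait=True, sync_xla_data=True)
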
12 changes: 11 additions & 1 deletion test/dynamo/test_dynamo.py
@@ -89,9 +89,19 @@ def test_simple_model(self):
self.assertNotIn('xla::add', met.counter_names())
self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo_2.cpu()))
# verify that dynamo can handle different inputs
res_xla_dynamo_3 = fn_simple_dynamo(xla_x + xla_y, xla_y * 3)
xla_z = torch.randn(5, 10, device=device)
xla_xy = xla_x + xla_y
xla_y3 = xla_y * 3
res_xla_dynamo_3 = fn_simple_dynamo(xla_xy, xla_y3)
res_cpu_3 = self.fn_simple(x + y, y * 3)
self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu()))
# executing the compiled function should only materialize input XLATensor
self.assertIn('XLAData: None',
torch_xla._XLAC._get_xla_tensor_debug_info(xla_z))
self.assertNotIn('XLAData: None',
torch_xla._XLAC._get_xla_tensor_debug_info(xla_xy))
self.assertNotIn('XLAData: None',
torch_xla._XLAC._get_xla_tensor_debug_info(xla_y3))

def test_fn_without_input(self):

12 changes: 8 additions & 4 deletions torch_xla/core/dynamo_bridge.py
@@ -389,10 +389,14 @@ def optimized_mod(*args):

# mark_step needs to be blocking since we want to access the args' XLAData,
# and they can't be placeholders.
if any(
torch_xla._XLAC._check_tensor_need_materialization(
[a for a in args if isinstance(a, torch.Tensor)])):
xm.mark_step(wait=True)
input_tensors_to_sync = [
args[i] for i, x in enumerate(
torch_xla._XLAC._check_tensor_need_materialization(
[a for a in args if isinstance(a, torch.Tensor)])) if x
]
if len(input_tensors_to_sync) > 0:
torch_xla._XLAC._xla_sync_multi(
input_tensors_to_sync, devices=[], wait=True, sync_xla_data=True)

# If input sharding has changed from the previous program, dynamo currently
# cannot detect this. It will mistakenly believe the program is the same. We need
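
The new test checks materialization state through _get_xla_tensor_debug_info: the substring 'XLAData: None' means the tensor still holds a pending lazy computation with no backing device buffer. A standalone sketch of that check, not part of the commit and assuming an XLA device is available:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.randn(5, 10, device=device) + 1  # pending lazy computation

# Before any sync the tensor has no device data behind it.
print('XLAData: None' in torch_xla._XLAC._get_xla_tensor_debug_info(t))  # True

# Syncing just this tensor materializes it without flushing the whole graph.
torch_xla._XLAC._xla_sync_multi([t], devices=[], wait=True, sync_xla_data=True)
print('XLAData: None' in torch_xla._XLAC._get_xla_tensor_debug_info(t))  # False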
