
Commit

add another test for aliasing across mark_step
JackCaoG committed May 3, 2024
1 parent 22d1fb4 commit 8c1e37a
Showing 2 changed files with 12 additions and 0 deletions.
11 changes: 11 additions & 0 deletions test/test_input_output_aliases.py
@@ -38,6 +38,17 @@ def test_aliasing_with_cloned(self):
    torch.allclose(t1 - 1, t1_cloned)
    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)

  def test_aliasing_across_mark_step(self):
    xla_device = xm.xla_device()
    met.clear_all()
    t1 = torch.randn(4, 5).to(xla_device)
    t1 += 1
    xm.mark_step()
    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
    t1 *= 100
    xm.mark_step()
    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)

  def test_aliasing_with_multiple_inplace_update(self):
    BATCH_SIZE = 1
    SEQ_LEN = 128
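For context, a minimal self-contained sketch of the new test as it could be run on its own; the imports, class name, and __main__ block are assumptions based on typical torch_xla test files and are not part of this diff.

# Sketch only: imports and class name are assumed, not taken from the diff.
import unittest

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


class InputOutputAliasesSketch(unittest.TestCase):

  def test_aliasing_across_mark_step(self):
    xla_device = xm.xla_device()
    met.clear_all()
    # First in-place update: the first mark_step should alias (donate) the
    # input buffer to the output, so the counter reaches 1.
    t1 = torch.randn(4, 5).to(xla_device)
    t1 += 1
    xm.mark_step()
    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
    # Second in-place update across another mark_step should be aliased
    # again, bringing the cumulative count to 2.
    t1 *= 100
    xm.mark_step()
    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)


if __name__ == '__main__':
  unittest.main()

In the actual repository this runs via python3 test/test_input_output_aliases.py, which the second change below adds to test/tpu/run_tests.sh.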
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -21,6 +21,7 @@ python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
python3 test/test_pallas.py
python3 test/test_input_output_aliases.py
python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py
