Skip to content

Commit

Permalink
Fix the setup for SPMD graph dump test
Browse files Browse the repository at this point in the history
  • Loading branch information
rpsilva-aws committed Dec 9, 2024
1 parent 9a98a36 commit c48ea1f
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions test/spmd/test_spmd_graph_dump.py
Original file line number Diff line number Diff line change
@@ -23,8 +23,7 @@ def setUpClass(cls):
def test_dump_with_output_sharding(self):
save_file = os.getenv('XLA_SAVE_TENSORS_FILE')
save_format = os.getenv('XLA_SAVE_TENSORS_FMT')
if not save_file:
assert False, "This test should be run with XLA_SAVE_TENSORS_FILE"
assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE"
should_dump_output_sharding = (save_format == 'hlo')
save_file += '.0'
device = xm.xla_device()
@@ -35,12 +34,10 @@ def test_dump_with_output_sharding(self):
xla_sharded_x = xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)),
partition_spec)
xla_res = xla_x + xla_y
xm.mark_step()
with open(save_file, 'rb') as f:
current_line = sum(1 for line in f)
with open(save_file, 'rb') as f:
xm.mark_step()
lines = f.readlines()
self.assertGreater(len(lines), current_line)
self.assertGreater(len(lines), 0)
if should_dump_output_sharding:
self.assertIn('OUTPUT_SHARDING_END', str(lines[-2]))
else:

0 comments on commit c48ea1f

Please sign in to comment.