From c48ea1f4a5203e35f8b3c9a09692c1e615c1572f Mon Sep 17 00:00:00 2001 From: rpsilva-aws Date: Mon, 9 Dec 2024 20:58:36 +0000 Subject: [PATCH] Fix the setup for SPMD graph dump test --- test/spmd/test_spmd_graph_dump.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/test/spmd/test_spmd_graph_dump.py b/test/spmd/test_spmd_graph_dump.py index e3cadd7b3ce..9ef67ef0a85 100644 --- a/test/spmd/test_spmd_graph_dump.py +++ b/test/spmd/test_spmd_graph_dump.py @@ -23,8 +23,7 @@ def setUpClass(cls): def test_dump_with_output_sharding(self): save_file = os.getenv('XLA_SAVE_TENSORS_FILE') save_format = os.getenv('XLA_SAVE_TENSORS_FMT') - if not save_file: - assert False, "This test should be run with XLA_SAVE_TENSORS_FILE" + assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE" should_dump_output_sharding = (save_format == 'hlo') save_file += '.0' device = xm.xla_device() @@ -35,12 +34,10 @@ def test_dump_with_output_sharding(self): xla_sharded_x = xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)), partition_spec) xla_res = xla_x + xla_y + xm.mark_step() with open(save_file, 'rb') as f: - current_line = sum(1 for line in f) - with open(save_file, 'rb') as f: - xm.mark_step() lines = f.readlines() - self.assertGreater(len(lines), current_line) + self.assertGreater(len(lines), 0) if should_dump_output_sharding: self.assertIn('OUTPUT_SHARDING_END', str(lines[-2])) else: