From 55cdf381fc50374ee4abf80df0a83ea2ce5f4ba8 Mon Sep 17 00:00:00 2001 From: rpsilva-aws Date: Tue, 10 Dec 2024 21:39:11 +0000 Subject: [PATCH] [AdHoc] Test fix: clear metrics and simplify the replicated counter --- test/spmd/test_spmd_lowering_context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/spmd/test_spmd_lowering_context.py b/test/spmd/test_spmd_lowering_context.py index df872073073..2589e060ab9 100644 --- a/test/spmd/test_spmd_lowering_context.py +++ b/test/spmd/test_spmd_lowering_context.py @@ -27,6 +27,7 @@ def _get_computation_hlo_txt(self, ctx): return xb.get_computation_hlo(comp) def test_basic(self): + met.clear_all() save_file = os.getenv('XLA_SAVE_TENSORS_FILE') save_format = os.getenv('XLA_SAVE_TENSORS_FMT') assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE" @@ -99,7 +100,7 @@ def fn(x, y): assert expected_output[1] == f"f32[32,2048] {b_sharding_spec}" assert expected_output[2] == f"f32[2048] {a_sharding_spec}" assert expected_output[3] == f"f32[32,2048] {b_sharding_spec}" - self.assertTrue(met.counter_value("ExecuteReplicated") == 1) + self.assertTrue(met.counter_value("ExecuteReplicated") > 0) self.assertTrue(met.counter_value("ExecuteComputation") is None) def test_device_parameter_id_tensor_mapping(self):