Skip to content

Commit

Permalink
[AdHoc] Test fix: clear metrics and simplify the replicated counter
Browse files Browse the repository at this point in the history
  • Loading branch information
rpsilva-aws committed Dec 10, 2024
1 parent e8720ba commit 55cdf38
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion test/spmd/test_spmd_lowering_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def _get_computation_hlo_txt(self, ctx):
return xb.get_computation_hlo(comp)

def test_basic(self):
met.clear_all()
save_file = os.getenv('XLA_SAVE_TENSORS_FILE')
save_format = os.getenv('XLA_SAVE_TENSORS_FMT')
assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE"
Expand Down Expand Up @@ -99,7 +100,7 @@ def fn(x, y):
assert expected_output[1] == f"f32[32,2048] {b_sharding_spec}"
assert expected_output[2] == f"f32[2048] {a_sharding_spec}"
assert expected_output[3] == f"f32[32,2048] {b_sharding_spec}"
self.assertTrue(met.counter_value("ExecuteReplicated") == 1)
self.assertTrue(met.counter_value("ExecuteReplicated") > 0)
self.assertTrue(met.counter_value("ExecuteComputation") is None)

def test_device_parameter_id_tensor_mapping(self):
Expand Down

0 comments on commit 55cdf38

Please sign in to comment.