Skip to content

Commit

Permalink
update hlo op name in expected hlo str
Browse files Browse the repository at this point in the history
  • Loading branch information
Siyuan Liu committed Feb 7, 2024
1 parent 24ea8c1 commit 3f86ba4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,13 +819,13 @@ def test_mark_sharding_ir(self):
(0, 1))
hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
self.assertIn(
'%custom-call.10 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.9), custom_call_target="Sharding", sharding=',
'%custom-call.7 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.6), custom_call_target="Sharding", sharding=',
hlo)

actual += 0
hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
self.assertIn(
'%add.15 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.13, f32[1,128]{1,0} %broadcast.14)',
'%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.10, f32[1,128]{1,0} %broadcast.11)',
hlo)

self.assertTrue(torch.allclose(expected, actual.cpu()))
Expand Down

0 comments on commit 3f86ba4

Please sign in to comment.