diff --git a/test/scan/test_scan_spmd.py b/test/scan/test_scan_spmd.py
index cde7fb7bb65..19ec991cbb3 100644
--- a/test/scan/test_scan_spmd.py
+++ b/test/scan/test_scan_spmd.py
@@ -39,10 +39,19 @@ def fn(carry, x):
     # Check the input and output sharding. Note that we do this after
     # `torch_xla.sync()` to ensure the output tensors are materialized and
     # have taken on sharding annotations propagated by the compiler.
-    for tensor in [init, xs, final_carry, ys]:
-      self.assertIn('ShardingSpec: {devices=[',
+    N = xr.global_runtime_device_count()
+    for tensor in [init, final_carry]:
+      self.assertIn(f'devices=[{N}]0,',
+                    torch_xla._XLAC._get_xla_sharding_spec(tensor))
+      self.assertIn('OpSharding: {'
+                    f'devices=[{N}]0,',
                     torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
-      self.assertIn('OpSharding: {devices=[',
+    # For xs and ys, they are replicated at the first dim and sharded at the second dim.
+    for tensor in [xs, ys]:
+      self.assertIn(f'devices=[1,{N}]0,',
+                    torch_xla._XLAC._get_xla_sharding_spec(tensor))
+      self.assertIn('OpSharding: {'
+                    f'devices=[1,{N}]0,',
                     torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
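
For context (not part of the patch): the asserted substrings follow XLA's textual sharding format, where `devices=[...]` gives the tile-assignment shape followed by the participating device ids. Below is a minimal sketch of how those substrings are formed, assuming a hypothetical 4-device mesh; the real test derives N from xr.global_runtime_device_count().

    # Illustrative only: construct the sharding-spec substrings the test
    # expects, for a hypothetical mesh of N = 4 devices.
    N = 4
    device_list = ','.join(str(d) for d in range(N))

    # init/final_carry are sharded across all N devices along dim 0, so
    # the tile-assignment shape is [N], followed by the device ids.
    carry_spec = f'devices=[{N}]{device_list}'
    assert carry_spec == 'devices=[4]0,1,2,3'

    # xs/ys are replicated at dim 0 (the scan dimension) and sharded at
    # dim 1, so the tile-assignment shape is [1,N].
    xs_spec = f'devices=[1,{N}]{device_list}'
    assert xs_spec == 'devices=[1,4]0,1,2,3'

Note that the test matches only the prefix up to the first device id (e.g. f'devices=[{N}]0,'), which is enough to pin down the tile-assignment shape without hard-coding the full device ordering.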