From 2ee45adc4ebe810e741e69fb40cd754f8a9f4fa9 Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Thu, 5 Dec 2024 20:55:01 +0000
Subject: [PATCH] Stricter testing of sharding annotations after scan

For `init` and `final_carry`, the sharding spec should look like
`devices=[4]0,1,2,3`. For `xs` and `ys`, the sharding spec should look
like `devices=[1,4]0,1,2,3`.
---
 test/scan/test_scan_spmd.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/test/scan/test_scan_spmd.py b/test/scan/test_scan_spmd.py
index cde7fb7bb65..19ec991cbb3 100644
--- a/test/scan/test_scan_spmd.py
+++ b/test/scan/test_scan_spmd.py
@@ -39,10 +39,19 @@ def fn(carry, x):
     # Check the input and output sharding. Note that we do this after
     # `torch_xla.sync()` to ensure the output tensors are materialized and
     # have taken on sharding annotations propagated by the compiler.
-    for tensor in [init, xs, final_carry, ys]:
-      self.assertIn('ShardingSpec: {devices=[',
+    N = xr.global_runtime_device_count()
+    for tensor in [init, final_carry]:
+      self.assertIn(f'devices=[{N}]0,',
+                    torch_xla._XLAC._get_xla_sharding_spec(tensor))
+      self.assertIn('OpSharding: {' + f'devices=[{N}]0,',
                     torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
-      self.assertIn('OpSharding: {devices=[',
+
+    # `xs` and `ys` are replicated along the first dim and sharded along
+    # the second dim.
+    for tensor in [xs, ys]:
+      self.assertIn(f'devices=[1,{N}]0,',
+                    torch_xla._XLAC._get_xla_sharding_spec(tensor))
+      self.assertIn('OpSharding: {' + f'devices=[1,{N}]0,',
                     torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
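
For context, the following is a minimal sketch (not part of the patch) of
how sharding spec strings of these two shapes arise in PyTorch/XLA SPMD.
The mesh construction, tensor shapes, and variable names here are
illustrative assumptions rather than code taken from the test:

    # Minimal sketch, assuming a 4-device SPMD runtime. The mesh setup and
    # tensor shapes are illustrative, not taken from test_scan_spmd.py.
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as spmd

    xr.use_spmd()
    N = xr.global_runtime_device_count()
    mesh = spmd.Mesh(list(range(N)), (N,), ('model',))

    # A 1D tensor sharded across all devices, analogous to `init` and
    # `final_carry`: its spec reads '{devices=[4]0,1,2,3}' when N == 4.
    carry = torch.zeros(1024, device=xm.xla_device())
    spmd.mark_sharding(carry, mesh, ('model',))
    print(torch_xla._XLAC._get_xla_sharding_spec(carry))

    # A 2D tensor replicated along the leading (scan) dim and sharded along
    # the second dim, analogous to `xs` and `ys`: its spec reads
    # '{devices=[1,4]0,1,2,3}' when N == 4.
    stacked = torch.zeros(8, 1024, device=xm.xla_device())
    spmd.mark_sharding(stacked, mesh, (None, 'model'))
    print(torch_xla._XLAC._get_xla_sharding_spec(stacked))

On a four-device host the two prints contain `devices=[4]0,1,2,3` and
`devices=[1,4]0,1,2,3`, which are exactly the substrings the strengthened
assertions match against.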