Skip to content

Commit

Permalink
Stricter testing of sharding annotations after scan
Browse files Browse the repository at this point in the history
For `init` and `final_carry`, the sharding spec should look like
`devices=[4]0,1,2,3`.

For `xs` and `ys`, the sharding spec should look like
`devices=[1,4]0,1,2,3`.
  • Loading branch information
tengyifei committed Dec 5, 2024
1 parent 4c99d21 commit 2ee45ad
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions test/scan/test_scan_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,19 @@ def fn(carry, x):
# Check the input and output sharding. Note that we do this after
# `torch_xla.sync()` to ensure the output tensors are materialized and
# have taken on sharding annotations propagated by the compiler.
for tensor in [init, xs, final_carry, ys]:
self.assertIn('ShardingSpec: {devices=[',
N = xr.global_runtime_device_count()
for tensor in [init, final_carry]:
self.assertIn(f'devices=[{N}]0,',
torch_xla._XLAC._get_xla_sharding_spec(tensor))
self.assertIn('OpSharding: {'
f'devices=[{N}]0,',
torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
self.assertIn('OpSharding: {devices=[',
# For xs and ys, they are replicated at the first dim and sharded at the second dim.
for tensor in [xs, ys]:
self.assertIn(f'devices=[1,{N}]0,',
torch_xla._XLAC._get_xla_sharding_spec(tensor))
self.assertIn('OpSharding: {'
f'devices=[1,{N}]0,',
torch_xla._XLAC._get_xla_tensor_debug_info(tensor))


Expand Down

0 comments on commit 2ee45ad

Please sign in to comment.