Skip to content

Commit

Permalink
Stricter testing of sharding annotations after scan
Browse files Browse the repository at this point in the history
For `init` and `final_carry`, the sharding spec should look like
`devices=[4]0,1,2,3`.

For `xs` and `ys`, the sharding spec should look like
`devices=[1,4]0,1,2,3`.
  • Loading branch information
tengyifei committed Dec 5, 2024
1 parent 4c99d21 commit 78df1d2
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions test/scan/test_scan_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,21 @@ def fn(carry, x):
# Check the input and output sharding. Note that we do this after
# `torch_xla.sync()` to ensure the output tensors are materialized and
# have taken on sharding annotations propagated by the compiler.
for tensor in [init, xs, final_carry, ys]:
self.assertIn('ShardingSpec: {devices=[',
N = xr.global_runtime_device_count()
for tensor in [init, final_carry]:
self.assertIn('ShardingSpec: {'
f'devices=[{N}]0,',
torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
self.assertIn('OpSharding: {devices=[',
self.assertIn('OpSharding: {'
f'devices=[{N}]0,',
torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
    # xs and ys are replicated along the first dim and sharded along the second dim.
for tensor in [xs, ys]:
self.assertIn('ShardingSpec: {'
f'devices=[1,{N}]0,',
torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
self.assertIn('OpSharding: {'
f'devices=[1,{N}]0,',
torch_xla._XLAC._get_xla_tensor_debug_info(tensor))


Expand Down

0 comments on commit 78df1d2

Please sign in to comment.