[SPMD] Add a test case for mark_shard scalar tensors
Summary:
Support mark_shard scalar tensors.

Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_mark_shard_scalar
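For context, a minimal usage sketch of what the new test exercises. This is a sketch only: it assumes the public xs.Mesh and xr.global_runtime_device_count() helpers in place of the test's self._get_mesh fixture, and is not part of this commit.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# Build a (1, N) device mesh over all addressable devices (assumed setup).
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))

# A scalar (rank-0) tensor has no dimensions to shard, so the only valid
# partition spec is the empty tuple and the resulting sharding is replicated.
x = torch.tensor(1.0).to(xm.xla_device())
xt = xs.mark_sharding(x, mesh, ())
print(xt.sharding_spec)  # expected: "{replicated}"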
alanwaketan committed Dec 14, 2023
1 parent 419dd87 commit dc58b2e
Showing 2 changed files with 28 additions and 0 deletions.
22 changes: 22 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -1047,6 +1047,28 @@ def test_from_cpu_shards_global_shape(self):
    with self.assertRaises(RuntimeError):
      from_cpu_shards(shards, op_sharding, torch.Size((1,)))

  def test_mark_shard_scalar(self):
    x = torch.tensor(1.0).to(xm.xla_device())
    self.assertEqual(len(x.shape), 0)

    xt = xs.mark_sharding(x, self._get_mesh((1, self.n_devices)), ())
    self.assertEqual(xt, x)
    self.assertEqual(xt.sharding_type, xs.ShardingType.REPLICATED)
    self.assertEqual(xt.sharding_spec, "{replicated}")

    shards = xt.local_shards
    self.assertEqual(len(shards), self.n_devices)
    # All shards are replicated.
    for i, shard in enumerate(shards):
      self.assertEqual(shard.data.device, torch.device('cpu'))
      self.assertTrue(torch.allclose(shard.data, torch.tensor(1.0)))
      self.assertIsInstance(shard.indices, type(Ellipsis))
      self.assertEqual(shard.replica_id, i)

    # It looks like the mesh_shape attribute is never implemented.
    with self.assertRaises(AttributeError):
      xt.mesh_shape


if __name__ == '__main__':
  test = unittest.main()
6 changes: 6 additions & 0 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -111,6 +111,12 @@ def get_op_sharding(self,
    Return the OpSharding for the given partition spec. This is an expensive
    operation as the mesh grows, so the value is cached for reuse.
    """
    # A scalar tensor can only be replicated.
    # The mark_sharding API has already verified that
    # len(t.shape) == len(partition_spec).
    if len(partition_spec) == 0:
      return torch_xla._XLAC.OpSharding([], [], [], ShardingType.REPLICATED)

    tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args(
        partition_spec)
    return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
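A short sketch of the new early-return path in isolation. Only get_op_sharding's empty-spec branch comes from this diff; the mesh construction below is assumed from the public SPMD API.

import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))

# The empty partition spec is what gets passed when sharding a scalar tensor;
# get_op_sharding now short-circuits and returns a REPLICATED OpSharding
# instead of deriving a tile assignment from the mesh.
op_sharding = mesh.get_op_sharding(())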
