diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index db303302e09f..e826e8b50fa6 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -1047,6 +1047,28 @@ def test_from_cpu_shards_global_shape(self): with self.assertRaises(RuntimeError): from_cpu_shards(shards, op_sharding, torch.Size((1,))) + def test_mark_shard_scalar(self): + x = torch.tensor(1.0).to(xm.xla_device()) + self.assertEqual(len(x.shape), 0) + + xt = xs.mark_sharding(x, self._get_mesh((1, self.n_devices)), ()) + self.assertEqual(xt, x) + self.assertEqual(xt.sharding_type, xs.ShardingType.REPLICATED) + self.assertEqual(xt.sharding_spec, "{replicated}") + + shards = xt.local_shards + self.assertEqual(len(shards), self.n_devices) + # all shards are REPLICATED. + for i, shard in enumerate(shards): + self.assertEqual(shard.data.device, torch.device('cpu')) + self.assertTrue(torch.allclose(shard.data, torch.tensor(1.0))) + self.assertIsInstance(shard.indices, type(Ellipsis)) + self.assertEqual(shard.replica_id, i) + + # It looks like mesh_shape attribute is never implemented. + with self.assertRaises(AttributeError): + xt.mesh_shape + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 2fd4a2eb753a..90e0eaaa3e90 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -111,6 +111,12 @@ def get_op_sharding(self, Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. """ + # For scalar tensors, it can only be replicated. + # We have made sure len(t.shape) == len(partition_spec) + # in mark_sharding API. + if len(partition_spec) == 0: + return torch_xla._XLAC.OpSharding([], [], [], ShardingType.REPLICATED) + tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args( partition_spec) return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,