diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 07733e3b5bc..ce2cae18dd6 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -900,139 +900,6 @@ def test_op_sharding_cache(self): xs.mark_sharding(v, mesh, (0, None)) self.assertEqual(met.counter_value("CreateOpSharding"), 2) - def test_from_cpu_shards_replicated(self): - from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards - - # Create an OpSharding with all devices on a single axis - mesh = self._get_mesh((self.n_devices,)) - partition_spec = (None,) - op_sharding = mesh.get_op_sharding(partition_spec) - shards = [torch.arange(4)] * self.n_devices - - # No shape should result in the shape of a single shard. - global_tensor = from_cpu_shards(shards, op_sharding) - self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) - - # Specify a valid shape for the global tensor - global_tensor = from_cpu_shards(shards, op_sharding, shards[0].shape) - self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) - - # All invalid shapes should raise - with self.assertRaises(RuntimeError): - from_cpu_shards(shards, op_sharding, torch.Size((5,))) - with self.assertRaises(RuntimeError): - from_cpu_shards(shards, op_sharding, torch.Size((3,))) - with self.assertRaises(RuntimeError): - from_cpu_shards(shards, op_sharding, torch.Size((2, 2))) - - def test_from_cpu_shards_tiled(self): - from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards - - # Create an OpSharding with all devices on a single axis - mesh = self._get_mesh((self.n_devices,)) - partition_spec = (0,) - op_sharding = mesh.get_op_sharding(partition_spec) - shards = [torch.LongTensor([i]) for i in range(self.n_devices)] - - global_tensor = from_cpu_shards(shards, op_sharding) - self.assertTrue( - torch.allclose(global_tensor.cpu(), torch.arange(self.n_devices))) - - # Test incorrect number of shards - with self.assertRaises(RuntimeError): - from_cpu_shards(shards[:-1], op_sharding) - - # Test an invalid global shape - too many values. - with self.assertRaises(RuntimeError): - from_cpu_shards(shards, op_sharding, torch.Size((self.n_devices * 2,))) - - # Test an invalid global shape - incorrect rank - with self.assertRaises(RuntimeError): - from_cpu_shards(shards, op_sharding, torch.Size((1, self.n_devices))) - - # Test a valid global shape - restrict the number of meaningful values - # to 1, treating the rest as padding. - global_tensor = from_cpu_shards(shards, op_sharding, torch.Size((1,))) - self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1))) - - def test_from_cpu_shards_2d(self): - from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards - - # Create an appropriate 2D mesh for the number of devices - if self.n_devices >= 4: - mesh_shape = (self.n_devices // 2, 2) - else: - mesh_shape = (1, self.n_devices) - mesh_2d = self._get_mesh(mesh_shape) - - # Replicated sharding - shards = [torch.LongTensor([self.n_devices])] * self.n_devices - partition_spec = (None, None) - op_sharding = mesh_2d.get_op_sharding(partition_spec) - global_tensor = from_cpu_shards(shards, op_sharding) - self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) - - if self.n_devices > 1: - # Tiled sharding - shards = [torch.LongTensor([[i]]) for i in range(self.n_devices)] - partition_spec = (0, 1) - op_sharding = mesh_2d.get_op_sharding(partition_spec) - global_tensor = from_cpu_shards(shards, op_sharding) - expected = torch.arange(self.n_devices).reshape(2, self.n_devices // 2) - self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) - - # Partially replicated sharding - shards = [torch.LongTensor([[i]]) for i in range(2)] * ( - self.n_devices // 2) - partition_spec = (None, 1) - op_sharding = mesh_2d.get_op_sharding(partition_spec) - global_tensor = from_cpu_shards(shards, op_sharding) - # Partial replication along the 0th axis represents a global tensor - # of torch.Tensor([[0, 1]]). - expected = torch.arange(2).reshape(1, 2) - self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) - - def test_from_cpu_shards_global_shape(self): - from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards - - mesh = self._get_mesh((self.n_devices,)) - numel = self.n_devices**2 - # The global tensor is torch.arange(numel). - shards = [ - torch.arange(self.n_devices) + (i * self.n_devices) - for i in range(self.n_devices) - ] - partition_spec = (0,) - op_sharding = mesh.get_op_sharding(partition_spec) - - # No global shape specified will include all data from the shards - global_tensor = from_cpu_shards(shards, op_sharding) - self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(numel))) - - # Too large of a global shape will error out - with self.assertRaises(RuntimeError): - from_cpu_shards(shards, op_sharding, torch.Size((numel + 1,))) - - if self.n_devices > 1: - # When the global tensor has fewer elements than the sum of its shards, - # there are two cases: - - # Case 1: If the global shape is within n_devices of numel, the excess - # data is treated as padding and ignored. - for delta in range(self.n_devices): - size = torch.Size((numel - delta,)) - global_tensor = from_cpu_shards(shards, op_sharding, size) - expected = torch.arange(size[0]) - self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) - - # Case 2: Otherwise, it is not possible to have that much padding in a - # sharded tensor, and the shards are incompatible with the shape. - with self.assertRaises(RuntimeError): - shape = torch.Size((numel - self.n_devices,)) - from_cpu_shards(shards, op_sharding, shape) - with self.assertRaises(RuntimeError): - from_cpu_shards(shards, op_sharding, torch.Size((1,))) - if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index fe18f9508b6..b3957d7a68f 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -34,7 +34,6 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/ir_dump_util.h" -#include "torch_xla/csrc/layout_manager.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/computation_client.h" @@ -1664,72 +1663,6 @@ void InitXlaModuleBindings(py::module m) { } return std::nullopt; }); - // Reassemble the CPU shards into a global tensor. A new sharded tensor is - // created from the local shards with the provided sharding annotation - // attached. The order of the shards should coincide with the order of - // devices returned by `torch_xla.runtime.local_runtime_devices()`. - m.def( - "_global_tensor_from_cpu_shards", - [](const std::vector& shards, const xla::OpSharding& sharding, - std::optional>& global_shape) -> at::Tensor { - XLA_CHECK(UseVirtualDevice()) - << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; - auto local_devices = runtime::GetComputationClient()->GetLocalDevices(); - XLA_CHECK(local_devices.size() == shards.size()) - << "Must specify a shard for each local device"; - XLA_CHECK(!global_shape.has_value() || - global_shape.value().size() == shards[0].sizes().size()) - << "Global shape rank must agree with shard rank: expected rank " - << shards[0].sizes().size() << ", got " - << global_shape.value().size(); - - if (!global_shape.has_value()) { - // Set a default value for the global shape based on the sharding - // type. - if (sharding.type() == xla::OpSharding::OTHER) { - // Infer the global shape to be the shard shape scaled by the tiling - // dimensionality. - auto tile_shape = sharding.tile_assignment_dimensions(); - global_shape = std::vector(); - for (int dim = 0; dim < shards[0].sizes().size(); ++dim) { - auto global_dim = tile_shape[dim] * shards[0].sizes()[dim]; - global_shape->push_back(global_dim); - } - } else if (sharding.type() == xla::OpSharding::REPLICATED) { - global_shape = shards[0].sizes().vec(); - } else { - XLA_ERROR() << "Unsupported OpSharding type: " << sharding.type(); - } - } - - auto device = GetVirtualDevice(); - auto primitive_type = - MakeXlaPrimitiveType(shards[0].type().scalarType(), &device); - xla::Shape tensor_shape = MakeArrayShapeFromDimensions( - global_shape.value(), /*dynamic_dimensions=*/{}, primitive_type, - static_cast(device.type())); - auto sharding_spec = - std::make_shared(sharding, tensor_shape); - - // Verify that the shard shape is correct for the global shape and - // sharding spec. - auto expected_shard_shape = ShardingUtil::GetShardShape(sharding_spec); - for (auto shard : shards) { - XLA_CHECK(shard.sizes() == expected_shard_shape) - << "Input shard shape must include padding: " << shard.sizes() - << " vs " << expected_shard_shape; - } - - auto data_handle = WrapXlaData(ShardingUtil::CreateShardedData( - shards, local_devices, sharding_spec)); - XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle)); - xla_tensor->SetShardingSpec(*sharding_spec); - auto tensor = bridge::AtenFromXlaTensor(std::move(xla_tensor)); - return torch::autograd::make_variable(tensor, - shards[0].requires_grad()); - }, - py::arg("shards"), py::arg("sharding"), - py::arg("global_shape") = py::none()); // Returns the local shards of the tensor, with values taken from the // underlying ComputationClient::GetDataShards. As such, the shards will // contain any padding that was applied to ensure they all have the same diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index cde74256eee..f7da463fb64 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -706,8 +706,7 @@ void ShardingUtil::PrepareOutputShardingPropagation( } runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( - const std::vector& local_shards, - const std::vector& devices, + std::vector& local_shards, std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec) { XLA_CHECK(local_shards.size() == devices.size()) << "A device must be speficied for each shard"; diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 32060c7fc09..4a595f4e99b 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -147,8 +147,7 @@ class ShardingUtil { // Transfers the individual shards to the devices and returns a DataPtr for // the PjRtShardedData wrapping the shards. static runtime::ComputationClient::DataPtr CreateShardedData( - const std::vector& shards, - const std::vector& devices, + std::vector& shards, std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec); }; diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 21d0e2e570a..95f4a88128b 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -87,14 +87,6 @@ def get_op_sharding(self, Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. """ - partition_spec = _translate_named_partition_spec(self, partition_spec) - flat_specs = np.hstack([d for d in partition_spec]) - specs = [d for d in flat_specs if d is not None] - assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \ - f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." - assert len(specs) == len(np.unique(specs)), \ - f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." - tile_assignment = _get_tile_assignment(self, partition_spec) if len(tile_assignment.shape) > len(partition_spec): # Use partial replication for sharding a tensor over a higher-rank mesh @@ -490,12 +482,19 @@ def mark_sharding( assert num_devices > 0, "This requires XLA supported device(s)." assert mesh.size() == num_devices, \ f"{mesh.mesh_shape} is not mappable over {num_devices} devices." + partition_spec = _translate_named_partition_spec(mesh, partition_spec) # We only allow fully specified `partition_spec` to be applicable, as opposed # to filling in the unspecified replicated dims. Fully specified `partiion_spec` # should be of the same rank as `t`. This is to support partial replication # where the group assignment may vary with different input ranks. assert len(t.shape) == len(partition_spec), \ f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." + flat_specs = np.hstack([d for d in partition_spec]) + specs = [d for d in flat_specs if d is not None] + assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \ + f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." + assert len(specs) == len(np.unique(specs)), \ + f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." op_sharding = mesh.get_op_sharding(partition_spec)