diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index 3cecbe9cbda..1bd39e91783 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -70,9 +70,9 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { auto sharding_spec = std::make_shared(sharding, tensor_shape); auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); - auto shard_indices = ShardingUtil::GetShardIndicesForDevices( + auto replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices( shard_shape, tensor.sizes().vec(), sharding, devices); - EXPECT_EQ(shard_indices.size(), devices.size()); + EXPECT_EQ(replica_and_indices.size(), devices.size()); /* Tiled indices should be: dim=0 dim=1 device=0 [0:4, 0:4] @@ -81,11 +81,15 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { device=3 [4:8, 4:7] */ std::vector> slice_starts = {{0, 0}, {0, 4}, {4, 0}, {4, 4}}; std::vector> slice_ends = {{4, 4}, {4, 7}, {8, 4}, {8, 7}}; - for (int device = 0; device < shard_indices.size(); ++device) { - EXPECT_EQ(shard_indices[device].size(), tensor.sizes().size()); - for (int dim = 0; dim < shard_indices[device].size(); ++dim) { - EXPECT_TRUE(shard_indices[device][dim].is_slice()); - auto slice = shard_indices[device][dim].slice(); + for (int device = 0; device < replica_and_indices.size(); ++device) { + auto& shard_replica_id = replica_and_indices[device].first; + EXPECT_EQ(shard_replica_id, + 0); // Shard replica_id is always 0 for tiled sharding. + auto& shard_indices = replica_and_indices[device].second; + EXPECT_EQ(shard_indices.size(), tensor.sizes().size()); + for (int dim = 0; dim < shard_indices.size(); ++dim) { + EXPECT_TRUE(shard_indices[dim].is_slice()); + auto slice = shard_indices[dim].slice(); EXPECT_EQ(slice.start(), slice_starts[device][dim]); EXPECT_EQ(slice.stop(), slice_ends[device][dim]); EXPECT_EQ(slice.step(), 1); @@ -94,12 +98,15 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { sharding = xla::HloSharding::Replicate().ToProto(); sharding_spec->sharding = sharding; shard_shape = ShardingUtil::GetShardShape(sharding_spec); - shard_indices = ShardingUtil::GetShardIndicesForDevices( + replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices( shard_shape, tensor.sizes().vec(), sharding, devices); - EXPECT_EQ(shard_indices.size(), devices.size()); + EXPECT_EQ(replica_and_indices.size(), devices.size()); for (int i = 0; i < devices.size(); ++i) { - EXPECT_EQ(shard_indices[i].size(), 1); - EXPECT_TRUE(shard_indices[i][0].is_ellipsis()); + auto& replica_id = replica_and_indices[i].first; + EXPECT_EQ(replica_id, i); // Shard replica_id should equal global ordinal. + auto& shard_indices = replica_and_indices[i].second; + EXPECT_EQ(shard_indices.size(), 1); + EXPECT_TRUE(shard_indices[0].is_ellipsis()); } } diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 07fa79bc658..ce2cae18dd6 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -68,7 +68,7 @@ def test_xla_shards(self): for i, shard in enumerate(shards): self.assertEqual(shard.data.device, torch.device('cpu')) self.assertEqual(shard.data.shape, (shard_len,)) - start, end = (i, i + 1) * shard_len + start, end = i * shard_len, (i + 1) * shard_len expected = torch.arange(start, end, dtype=torch.float32) self.assertTrue(torch.allclose(shard.data, expected)) if isinstance(shard.indices, list): @@ -77,6 +77,8 @@ def test_xla_shards(self): else: self.assertIsInstance(shard.indices, type(Ellipsis)) self.assertTrue(torch.allclose(shard.data, t[shard.indices])) + # Tiled sharding makes all shards have replica_id 0. + self.assertEqual(shard.replica_id, 0) def test_padded_xla_shards(self): num_element = self.n_devices + 1 # Ensure padding with two or more devices @@ -105,6 +107,8 @@ def test_padded_xla_shards(self): else: self.assertIsInstance(shard.indices, type(Ellipsis)) self.assertTrue(torch.allclose(shard.unpadded_data, t[shard.indices])) + # Tiled sharding makes all shards have replica_id 0. + self.assertEqual(shard.replica_id, 0) def test_replicated_xla_shards(self): num_element = self.n_devices @@ -120,6 +124,36 @@ def test_replicated_xla_shards(self): self.assertIsInstance(shard.indices, type(Ellipsis)) self.assertTrue(torch.allclose(shard.data, t[shard.indices])) self.assertTrue(torch.allclose(shard.data, shard.unpadded_data)) + # Replicated sharding sets the shard replica_id to the device ordinal + self.assertEqual(shard.replica_id, i) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Multiple devices required for partial replication") + def test_partially_replicated_xla_shards(self): + num_element = 256 + mesh = self._get_mesh((self.n_devices // 2, 2)) + t = torch.arange(num_element, dtype=torch.float32).reshape((16, 16)) + # Partial replication along the 0th tensor axis, shard 2-way on the 1st + xt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (None, 1)) + shard_len = t.shape[1] // 2 + + shards = xt.local_shards + self.assertEqual(len(shards), self.n_devices) + for i, shard in enumerate(shards): + self.assertEqual(shard.data.device, torch.device('cpu')) + self.assertEqual(shard.data.shape, (t.shape[0], shard_len)) + self.assertEqual(len(shard.indices), len(t.shape)) + start, end = (i % 2) * shard_len, ((i % 2) + 1) * shard_len + # All shards should contain the full range for dim 0 + self.assertEqual(shard.indices[0], slice(0, t.shape[0], 1)) + # The index range should be sharded for dim 1 + self.assertEqual(shard.indices[1], slice(start, end, 1)) + self.assertTrue(torch.allclose(shard.data, t[shard.indices])) + self.assertTrue(torch.allclose(shard.data, shard.unpadded_data)) + # The replica_id should be coincide with the replication group for the + # device. Given the mesh shape, the shard replica_id will be the device's + # row in the mesh, which is device_id // 2 + self.assertEqual(shard.replica_id, i // 2) def test_load_local_shards(self): num_element = self.n_devices diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 5970c0388cc..b3957d7a68f 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1667,7 +1667,7 @@ void InitXlaModuleBindings(py::module m) { // underlying ComputationClient::GetDataShards. As such, the shards will // contain any padding that was applied to ensure they all have the same // shape. Note that this padding is _not_ included in the global indices - // returned by `_get_local_shard_indices`. + // returned by `_get_local_shard_replica_and_indices`. m.def("_get_local_shards", [](const at::Tensor& input) -> std::tuple, std::vector> { @@ -1699,13 +1699,16 @@ void InitXlaModuleBindings(py::module m) { } return std::make_tuple(shards, str_devices); }); - // Returns the indices of the shards into the global tensor as either - // a Python list of slices for each dimension or a Python Ellipsis object - // indicating that the tensor is replicated. These indices will not reflect - // any padding that has been applied to the shards. The order of the returned - // indices matches the order of the shards returned from `_get_local_shards`. - m.def("_get_local_shard_indices", - [](const at::Tensor& input) -> std::vector { + // For each local shard, returns the tuple: + // (replica_id: int, indices: Union[List[Slice], Ellipsis]), + // where `replica_id` is the replica the shard belongs to and `indices` index + // into the global tensor. The value of `indices` is either a Python list of + // slices for each dimension or an Ellipsis object indicating that the tensor + // is replicated. These indices will not reflect any padding that has been + // applied to the shards. The order of the returned indices matches the order + // of the shards returned from `_get_local_shards`. + m.def("_get_local_shard_replica_and_indices", + [](const at::Tensor& input) -> std::vector> { XLATensorPtr xtensor = bridge::GetXlaTensor(input); XLA_CHECK(xtensor->sharding_spec() != nullptr) << "Tensor is not sharded"; @@ -1720,29 +1723,33 @@ void InitXlaModuleBindings(py::module m) { auto sharding_spec = xtensor->sharding_spec(); auto sharding = xtensor->sharding_spec()->sharding; auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); - auto indices = ShardingUtil::GetShardIndicesForDevices( - shard_shape, input.sizes().vec(), sharding, shard_devices); + auto replica_and_indices = + ShardingUtil::GetShardReplicaAndIndicesForDevices( + shard_shape, input.sizes().vec(), sharding, shard_devices); // Convert each vector to List[py::slice] or py::ellipsis - std::vector result; + std::vector> result; result.reserve(shard_devices.size()); - for (auto& device_indices : indices) { - XLA_CHECK(device_indices.size() > 0) + for (auto& device_replica_and_indices : replica_and_indices) { + auto& replica_id = device_replica_and_indices.first; + auto& indices = device_replica_and_indices.second; + XLA_CHECK(indices.size() > 0) << "Unexpected empty shard indices for tensor " << input; - if (device_indices[0].is_ellipsis()) { - result.push_back(py::ellipsis()); + if (indices[0].is_ellipsis()) { + result.push_back(std::make_pair(replica_id, py::ellipsis())); } else { - std::vector device_slices; - for (auto& tensor_index : device_indices) { + std::vector index_slices; + for (auto& tensor_index : indices) { XLA_CHECK(tensor_index.is_slice()) << "Unexpected TensorIndex type: " << tensor_index; auto slice = tensor_index.slice(); ssize_t start = slice.start().expect_int(); ssize_t stop = slice.stop().expect_int(); ssize_t step = slice.step().expect_int(); - device_slices.push_back(py::slice(start, stop, step)); + index_slices.push_back(py::slice(start, stop, step)); } - result.push_back(py::cast(device_slices)); + result.push_back( + std::make_pair(replica_id, py::cast(index_slices))); } } return result; diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 72013a595d8..f7da463fb64 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -435,24 +435,29 @@ ShardingUtil::GetShardIndicesForMinibatchTensor( return shard_indices; } -std::vector> -ShardingUtil::GetShardIndicesForDevices( +std::vector>> +ShardingUtil::GetShardReplicaAndIndicesForDevices( const std::vector& shard_shape, const std::vector& tensor_shape, const xla::OpSharding sharding, const std::vector& devices) { + using namespace at::indexing; + // `shard_indices[dev][dim]` represents the index slice for dimension `dim` // that belongs on device `devices[dev]` if the tensor is sharded. If // `sharding` is REPLICATED, `shard_indices[dev]` will only have a single // Ellipsis element to indicate that the tensor is replicated across all // dimensions. - std::vector> shard_indices( + std::vector>> shard_indices( devices.size()); auto tile_shape = sharding.tile_assignment_dimensions(); if (sharding.type() == xla::OpSharding::REPLICATED) { // Use Ellipsis to indicate all dimensions are replicated - auto ellipsis = at::indexing::TensorIndex(at::indexing::Ellipsis); - auto indices = std::vector({ellipsis}); - std::fill_n(shard_indices.begin(), shard_indices.size(), indices); + auto ellipsis = TensorIndex(Ellipsis); + auto indices = std::vector({ellipsis}); + for (int i = 0; i < devices.size(); ++i) { + int global_ordinal = ParseDeviceString(devices[i]).ordinal(); + shard_indices[i] = std::make_pair(global_ordinal, indices); + } } else if (sharding.type() == xla::OpSharding::OTHER) { auto device_index = build_index_map(devices); std::vector tile_assignment_devices( @@ -472,6 +477,10 @@ ShardingUtil::GetShardIndicesForDevices( continue; } + // The replica id for this shard. This value is only updated from 0 if + // the sharding is partially replicated. + int replica_id = 0; + // Given the shard's row-major index `i`, we need to calculate shard's // coordinates (n_0, ..., n_d) in the tiling to generate the index // slices. Using `N_j = tile_shape[j]` and `0 <= n_j < N_j`, the @@ -482,26 +491,27 @@ ShardingUtil::GetShardIndicesForDevices( // n_0)))`. Then `offset_d = i`, `n_j = offset_j % N_j`, and // `offset_{j-1} = offset_j / N_j`. int offset = i; - std::vector indices; + std::vector indices; for (int j = tile_shape.size() - 1; j >= 0; j--) { + int64_t n_j = offset % tile_shape[j]; if (sharding.replicate_on_last_tile_dim() && j == tile_shape.size() - 1) { // the last tile assignment dimension is replicated, which implies // that the consecutive `tile_shape[j]` devices hold the replicated. + replica_id = n_j; offset /= tile_shape[j]; continue; } - int64_t n_j = offset % tile_shape[j]; // Clamp the slice bounds to the tensor shape to accurately reflect // the shard size without padding. int start = std::min(n_j * shard_shape[j], tensor_shape[j]); int end = std::min((n_j + 1) * shard_shape[j], tensor_shape[j]); - auto slice = at::indexing::Slice(start, end); - indices.push_back(at::indexing::TensorIndex(slice)); + auto slice = Slice(start, end); + indices.push_back(TensorIndex(slice)); offset /= tile_shape[j]; } std::reverse(indices.begin(), indices.end()); - shard_indices[device_index[core]] = indices; + shard_indices[device_index[core]] = std::make_pair(replica_id, indices); } } else { TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type(); @@ -534,8 +544,12 @@ std::vector ShardingUtil::ShardTensor( if (minibatch) { shard_indices = GetShardIndicesForMinibatchTensor(shard_shape, devices); } else { - shard_indices = GetShardIndicesForDevices( + auto replica_and_indices = GetShardReplicaAndIndicesForDevices( shard_shape, tensor.sizes().vec(), sharding, devices); + // Extract only the indices, the replica_id is unnecessary for sharding. + std::transform(replica_and_indices.begin(), replica_and_indices.end(), + std::back_inserter(shard_indices), + [](auto& pair) { return pair.second; }); } for (size_t i = 0; i < shard_indices.size(); i++) { diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 8b7cc7d02f8..4a595f4e99b 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -93,11 +93,14 @@ class ShardingUtil { // Uses the provided `sharding` spec and expected shard shape to determine the // index slices for the shards which belong on `devices`. Only supports // `REPLICATED` and `OTHER` sharding types. - static std::vector> - GetShardIndicesForDevices(const std::vector& shard_shape, - const std::vector& tensor_shape, - const xla::OpSharding sharding, - const std::vector& devices); + // For each input device, returns a pair of the shard's replica_id and a + // vector of TensorIndex denoting the offset of the device's shard into the + // global tensor. + static std::vector>> + GetShardReplicaAndIndicesForDevices(const std::vector& shard_shape, + const std::vector& tensor_shape, + const xla::OpSharding sharding, + const std::vector& devices); // Returns the indices for the shards. Supports `OTHER` sharding types and // called when input is sharded along the batch axis. diff --git a/torch_xla/experimental/distributed_checkpoint.py b/torch_xla/experimental/distributed_checkpoint.py index 09be65d4b0a..5b1ee97b7d6 100644 --- a/torch_xla/experimental/distributed_checkpoint.py +++ b/torch_xla/experimental/distributed_checkpoint.py @@ -333,9 +333,10 @@ def _create_write_items_for_xla_sharded_tensor( items = [] # Since local shards are currently moved to CPU on creation, we need to get # the shard indices indirectly to avoid unnecessarily consuming host memory. - shard_indices = torch_xla._XLAC._get_local_shard_indices(t.global_tensor) + replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices( + t.global_tensor) prop = TensorProperties.create_from_tensor(t) - for shard_ind, indices in enumerate(shard_indices): + for shard_ind, (_, indices) in enumerate(replica_and_indices): write_item = _create_write_item_from_indices(fqn, shard_ind, indices, t.size(), prop) items.append(write_item) @@ -389,7 +390,10 @@ def _create_xla_read_items(sharded_state_dict: STATE_DICT_TYPE, md = metadata.state_dict_metadata[fqn] # Since local shards are currently moved to CPU on creation, we need to get # the shard indices indirectly to avoid unnecessarily consuming host memory. - shard_indices = torch_xla._XLAC._get_local_shard_indices(t.global_tensor) - chunks = [_create_chunk_from_shard_index(index) for index in shard_indices] + replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices( + t.global_tensor) + chunks = [ + _create_chunk_from_shard_index(ind) for _, ind in replica_and_indices + ] items.extend(create_read_items_for_chunk_list(fqn, md, chunks)) return items diff --git a/torch_xla/experimental/xla_sharded_tensor.py b/torch_xla/experimental/xla_sharded_tensor.py index ce423b3918f..1c3eaf34916 100644 --- a/torch_xla/experimental/xla_sharded_tensor.py +++ b/torch_xla/experimental/xla_sharded_tensor.py @@ -23,8 +23,14 @@ class XLAShard: # The device this shard's data originated from. shard_device: str - # TODO(jonbolin): Expose replica rank with partial replication - # rank: int + # The replica this shard belongs to, as determined by the sharding. The + # replica is determined differently for each sharding type: + # - TILED: Since the tensor isn't replicated, replica_id is always 0. + # - PARTIAL: replica_id is taken from the OpSharding and is a value in + # the range [0, num_replica). + # - REPLICATED: Since the tensor is fully replicated, replica_id is the + # device's global ordinal. + replica_id: int @property def unpadded_data(self) -> torch.Tensor: @@ -110,8 +116,13 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs): @property def local_shards(self) -> List[XLAShard]: shards, devices = torch_xla._XLAC._get_local_shards(self.global_tensor) - indices = torch_xla._XLAC._get_local_shard_indices(self.global_tensor) - return [XLAShard(s, i, d) for s, i, d in zip(shards, indices, devices)] + replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices( + self.global_tensor) + zipped = zip(shards, replica_and_indices, devices) + return [ + XLAShard(data, indices, dev, replica) + for data, (replica, indices), dev in zipped + ] # Load the given list of local shards into the underlying tensor's data # on the local devices.