Expose replica_id in XLAShard (#5654)
* Expose shard rank in XLAShard

* Improve tests and fix distributed checkpointing

* Change 'rank' to 'replica_id'

* Improve documentation
jonb377 authored Oct 3, 2023
1 parent 745985a commit fba326e
Showing 7 changed files with 136 additions and 56 deletions.
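
As a rough illustration of the user-facing effect of this change (a sketch only, assuming an SPMD-enabled PJRT runtime with all devices addressable from the host), the new `replica_id` field can be read off each local shard of a sharded tensor:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

# Illustrative sketch: tile a 1D tensor across all devices, then read the
# new replica_id off each local shard.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (num_devices,))
t = torch.arange(16, dtype=torch.float32).to(xm.xla_device())
xt = xs.mark_sharding(t, mesh, (0,))
for shard in xt.local_shards:
  # For tiled sharding every shard reports replica_id == 0.
  print(shard.shard_device, shard.replica_id, shard.indices)
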
29 changes: 18 additions & 11 deletions test/cpp/test_xla_sharding.cpp
@@ -70,9 +70,9 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
auto shard_indices = ShardingUtil::GetShardIndicesForDevices(
auto replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices(
shard_shape, tensor.sizes().vec(), sharding, devices);
EXPECT_EQ(shard_indices.size(), devices.size());
EXPECT_EQ(replica_and_indices.size(), devices.size());
/* Tiled indices should be:
dim=0 dim=1
device=0 [0:4, 0:4]
@@ -81,11 +81,15 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
device=3 [4:8, 4:7] */
std::vector<std::vector<int>> slice_starts = {{0, 0}, {0, 4}, {4, 0}, {4, 4}};
std::vector<std::vector<int>> slice_ends = {{4, 4}, {4, 7}, {8, 4}, {8, 7}};
for (int device = 0; device < shard_indices.size(); ++device) {
EXPECT_EQ(shard_indices[device].size(), tensor.sizes().size());
for (int dim = 0; dim < shard_indices[device].size(); ++dim) {
EXPECT_TRUE(shard_indices[device][dim].is_slice());
auto slice = shard_indices[device][dim].slice();
for (int device = 0; device < replica_and_indices.size(); ++device) {
auto& shard_replica_id = replica_and_indices[device].first;
EXPECT_EQ(shard_replica_id,
0); // Shard replica_id is always 0 for tiled sharding.
auto& shard_indices = replica_and_indices[device].second;
EXPECT_EQ(shard_indices.size(), tensor.sizes().size());
for (int dim = 0; dim < shard_indices.size(); ++dim) {
EXPECT_TRUE(shard_indices[dim].is_slice());
auto slice = shard_indices[dim].slice();
EXPECT_EQ(slice.start(), slice_starts[device][dim]);
EXPECT_EQ(slice.stop(), slice_ends[device][dim]);
EXPECT_EQ(slice.step(), 1);
@@ -94,12 +94,15 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
sharding = xla::HloSharding::Replicate().ToProto();
sharding_spec->sharding = sharding;
shard_shape = ShardingUtil::GetShardShape(sharding_spec);
shard_indices = ShardingUtil::GetShardIndicesForDevices(
replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices(
shard_shape, tensor.sizes().vec(), sharding, devices);
EXPECT_EQ(shard_indices.size(), devices.size());
EXPECT_EQ(replica_and_indices.size(), devices.size());
for (int i = 0; i < devices.size(); ++i) {
EXPECT_EQ(shard_indices[i].size(), 1);
EXPECT_TRUE(shard_indices[i][0].is_ellipsis());
auto& replica_id = replica_and_indices[i].first;
EXPECT_EQ(replica_id, i); // Shard replica_id should equal global ordinal.
auto& shard_indices = replica_and_indices[i].second;
EXPECT_EQ(shard_indices.size(), 1);
EXPECT_TRUE(shard_indices[0].is_ellipsis());
}
}

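For reference, the slice bounds asserted in this test can be reproduced with a few lines of plain Python. This is a standalone sketch of the same clamped-slice arithmetic for an 8x7 tensor tiled 2x2 with shard shape 4x4, not a call into the C++ helper:

# Standalone sketch: expected shard slices for an 8x7 tensor, 2x2 tiling,
# shard shape 4x4, with slice ends clamped to the tensor bounds.
tensor_shape = (8, 7)
shard_shape = (4, 4)
tile_shape = (2, 2)
for device in range(tile_shape[0] * tile_shape[1]):
  n0, n1 = device // tile_shape[1], device % tile_shape[1]
  slices = [(min(n * s, t), min((n + 1) * s, t))
            for n, s, t in zip((n0, n1), shard_shape, tensor_shape)]
  print(device, slices)
# device=0 -> [0:4, 0:4], device=1 -> [0:4, 4:7],
# device=2 -> [4:8, 0:4], device=3 -> [4:8, 4:7], matching the test.
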
36 changes: 35 additions & 1 deletion test/spmd/test_xla_sharding.py
@@ -68,7 +68,7 @@ def test_xla_shards(self):
for i, shard in enumerate(shards):
self.assertEqual(shard.data.device, torch.device('cpu'))
self.assertEqual(shard.data.shape, (shard_len,))
start, end = (i, i + 1) * shard_len
start, end = i * shard_len, (i + 1) * shard_len
expected = torch.arange(start, end, dtype=torch.float32)
self.assertTrue(torch.allclose(shard.data, expected))
if isinstance(shard.indices, list):
@@ -77,6 +77,8 @@ def test_xla_shards(self):
else:
self.assertIsInstance(shard.indices, type(Ellipsis))
self.assertTrue(torch.allclose(shard.data, t[shard.indices]))
# Tiled sharding makes all shards have replica_id 0.
self.assertEqual(shard.replica_id, 0)

def test_padded_xla_shards(self):
num_element = self.n_devices + 1 # Ensure padding with two or more devices
@@ -105,6 +107,8 @@ def test_padded_xla_shards(self):
else:
self.assertIsInstance(shard.indices, type(Ellipsis))
self.assertTrue(torch.allclose(shard.unpadded_data, t[shard.indices]))
# Tiled sharding makes all shards have replica_id 0.
self.assertEqual(shard.replica_id, 0)

def test_replicated_xla_shards(self):
num_element = self.n_devices
@@ -120,6 +124,36 @@ def test_replicated_xla_shards(self):
self.assertIsInstance(shard.indices, type(Ellipsis))
self.assertTrue(torch.allclose(shard.data, t[shard.indices]))
self.assertTrue(torch.allclose(shard.data, shard.unpadded_data))
# Replicated sharding sets the shard replica_id to the device ordinal
self.assertEqual(shard.replica_id, i)

@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
"Multiple devices required for partial replication")
def test_partially_replicated_xla_shards(self):
num_element = 256
mesh = self._get_mesh((self.n_devices // 2, 2))
t = torch.arange(num_element, dtype=torch.float32).reshape((16, 16))
# Partial replication along the 0th tensor axis, shard 2-way on the 1st
xt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (None, 1))
shard_len = t.shape[1] // 2

shards = xt.local_shards
self.assertEqual(len(shards), self.n_devices)
for i, shard in enumerate(shards):
self.assertEqual(shard.data.device, torch.device('cpu'))
self.assertEqual(shard.data.shape, (t.shape[0], shard_len))
self.assertEqual(len(shard.indices), len(t.shape))
start, end = (i % 2) * shard_len, ((i % 2) + 1) * shard_len
# All shards should contain the full range for dim 0
self.assertEqual(shard.indices[0], slice(0, t.shape[0], 1))
# The index range should be sharded for dim 1
self.assertEqual(shard.indices[1], slice(start, end, 1))
self.assertTrue(torch.allclose(shard.data, t[shard.indices]))
self.assertTrue(torch.allclose(shard.data, shard.unpadded_data))
# The replica_id should coincide with the replication group for the
# device. Given the mesh shape, the shard replica_id will be the device's
# row in the mesh, which is device_id // 2
self.assertEqual(shard.replica_id, i // 2)

def test_load_local_shards(self):
num_element = self.n_devices
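The tests above also suggest a simple host-side round trip: the global tensor can be reassembled from the local shards using the indices each shard carries. A minimal sketch, assuming every element of the tensor is covered by at least one local shard (true when all devices are addressable from the host):

import torch

def assemble_global(shards, global_shape, dtype=torch.float32):
  # Sketch: scatter each shard's unpadded data back into a host tensor at
  # the global indices the shard carries. Replicated shards (same indices,
  # different replica_id) simply overwrite with identical data.
  out = torch.empty(global_shape, dtype=dtype)
  for shard in shards:
    index = (tuple(shard.indices)
             if isinstance(shard.indices, list) else shard.indices)
    out[index] = shard.unpadded_data
  return out
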
45 changes: 26 additions & 19 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1667,7 +1667,7 @@ void InitXlaModuleBindings(py::module m) {
// underlying ComputationClient::GetDataShards. As such, the shards will
// contain any padding that was applied to ensure they all have the same
// shape. Note that this padding is _not_ included in the global indices
// returned by `_get_local_shard_indices`.
// returned by `_get_local_shard_replica_and_indices`.
m.def("_get_local_shards",
[](const at::Tensor& input)
-> std::tuple<std::vector<at::Tensor>, std::vector<std::string>> {
@@ -1699,13 +1699,16 @@
}
return std::make_tuple(shards, str_devices);
});
// Returns the indices of the shards into the global tensor as either
// a Python list of slices for each dimension or a Python Ellipsis object
// indicating that the tensor is replicated. These indices will not reflect
// any padding that has been applied to the shards. The order of the returned
// indices matches the order of the shards returned from `_get_local_shards`.
m.def("_get_local_shard_indices",
[](const at::Tensor& input) -> std::vector<py::object> {
// For each local shard, returns the tuple:
// (replica_id: int, indices: Union[List[Slice], Ellipsis]),
// where `replica_id` is the replica the shard belongs to and `indices` index
// into the global tensor. The value of `indices` is either a Python list of
// slices for each dimension or an Ellipsis object indicating that the tensor
// is replicated. These indices will not reflect any padding that has been
// applied to the shards. The order of the returned indices matches the order
// of the shards returned from `_get_local_shards`.
m.def("_get_local_shard_replica_and_indices",
[](const at::Tensor& input) -> std::vector<std::pair<int, py::object>> {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
XLA_CHECK(xtensor->sharding_spec() != nullptr)
<< "Tensor is not sharded";
@@ -1720,29 +1723,33 @@
auto sharding_spec = xtensor->sharding_spec();
auto sharding = xtensor->sharding_spec()->sharding;
auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
auto indices = ShardingUtil::GetShardIndicesForDevices(
shard_shape, input.sizes().vec(), sharding, shard_devices);
auto replica_and_indices =
ShardingUtil::GetShardReplicaAndIndicesForDevices(
shard_shape, input.sizes().vec(), sharding, shard_devices);

// Convert each vector<TensorIndex> to List[py::slice] or py::ellipsis
std::vector<py::object> result;
std::vector<std::pair<int, py::object>> result;
result.reserve(shard_devices.size());
for (auto& device_indices : indices) {
XLA_CHECK(device_indices.size() > 0)
for (auto& device_replica_and_indices : replica_and_indices) {
auto& replica_id = device_replica_and_indices.first;
auto& indices = device_replica_and_indices.second;
XLA_CHECK(indices.size() > 0)
<< "Unexpected empty shard indices for tensor " << input;
if (device_indices[0].is_ellipsis()) {
result.push_back(py::ellipsis());
if (indices[0].is_ellipsis()) {
result.push_back(std::make_pair(replica_id, py::ellipsis()));
} else {
std::vector<py::object> device_slices;
for (auto& tensor_index : device_indices) {
std::vector<py::object> index_slices;
for (auto& tensor_index : indices) {
XLA_CHECK(tensor_index.is_slice())
<< "Unexpected TensorIndex type: " << tensor_index;
auto slice = tensor_index.slice();
ssize_t start = slice.start().expect_int();
ssize_t stop = slice.stop().expect_int();
ssize_t step = slice.step().expect_int();
device_slices.push_back(py::slice(start, stop, step));
index_slices.push_back(py::slice(start, stop, step));
}
result.push_back(py::cast(device_slices));
result.push_back(
std::make_pair(replica_id, py::cast(index_slices)));
}
}
return result;
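For illustration, a minimal sketch of consuming the renamed binding directly. The argument stands for an already-sharded XLA tensor (as in `XLAShardedTensor.local_shards`), and the binding is an internal `_XLAC` API:

import torch_xla

def describe_shards(global_tensor):
  # Sketch: each entry pairs a shard's replica_id with its global indices,
  # in the same order as the shards returned by _get_local_shards.
  pairs = torch_xla._XLAC._get_local_shard_replica_and_indices(global_tensor)
  for replica_id, indices in pairs:
    if indices is Ellipsis:
      print(f'replica {replica_id}: replicated across the full tensor')
    else:
      print(f'replica {replica_id}: slices {indices}')
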
38 changes: 26 additions & 12 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -435,24 +435,29 @@ ShardingUtil::GetShardIndicesForMinibatchTensor(
return shard_indices;
}

std::vector<std::vector<at::indexing::TensorIndex>>
ShardingUtil::GetShardIndicesForDevices(
std::vector<std::pair<int, std::vector<at::indexing::TensorIndex>>>
ShardingUtil::GetShardReplicaAndIndicesForDevices(
const std::vector<int64_t>& shard_shape,
const std::vector<int64_t>& tensor_shape, const xla::OpSharding sharding,
const std::vector<std::string>& devices) {
using namespace at::indexing;

// `shard_indices[dev][dim]` represents the index slice for dimension `dim`
// that belongs on device `devices[dev]` if the tensor is sharded. If
// `sharding` is REPLICATED, `shard_indices[dev]` will only have a single
// Ellipsis element to indicate that the tensor is replicated across all
// dimensions.
std::vector<std::vector<at::indexing::TensorIndex>> shard_indices(
std::vector<std::pair<int, std::vector<TensorIndex>>> shard_indices(
devices.size());
auto tile_shape = sharding.tile_assignment_dimensions();
if (sharding.type() == xla::OpSharding::REPLICATED) {
// Use Ellipsis to indicate all dimensions are replicated
auto ellipsis = at::indexing::TensorIndex(at::indexing::Ellipsis);
auto indices = std::vector<at::indexing::TensorIndex>({ellipsis});
std::fill_n(shard_indices.begin(), shard_indices.size(), indices);
auto ellipsis = TensorIndex(Ellipsis);
auto indices = std::vector<TensorIndex>({ellipsis});
for (int i = 0; i < devices.size(); ++i) {
int global_ordinal = ParseDeviceString(devices[i]).ordinal();
shard_indices[i] = std::make_pair(global_ordinal, indices);
}
} else if (sharding.type() == xla::OpSharding::OTHER) {
auto device_index = build_index_map(devices);
std::vector<int64_t> tile_assignment_devices(
@@ -472,6 +477,10 @@ ShardingUtil::GetShardIndicesForDevices(
continue;
}

// The replica id for this shard. This value is only updated from 0 if
// the sharding is partially replicated.
int replica_id = 0;

// Given the shard's row-major index `i`, we need to calculate shard's
// coordinates (n_0, ..., n_d) in the tiling to generate the index
// slices. Using `N_j = tile_shape[j]` and `0 <= n_j < N_j`, the
@@ -482,26 +491,27 @@ ShardingUtil::GetShardIndicesForDevices(
// n_0)))`. Then `offset_d = i`, `n_j = offset_j % N_j`, and
// `offset_{j-1} = offset_j / N_j`.
int offset = i;
std::vector<at::indexing::TensorIndex> indices;
std::vector<TensorIndex> indices;
for (int j = tile_shape.size() - 1; j >= 0; j--) {
int64_t n_j = offset % tile_shape[j];
if (sharding.replicate_on_last_tile_dim() &&
j == tile_shape.size() - 1) {
// The last tile assignment dimension is replicated, which implies that
// the consecutive `tile_shape[j]` devices hold replicated shards.
replica_id = n_j;
offset /= tile_shape[j];
continue;
}
int64_t n_j = offset % tile_shape[j];
// Clamp the slice bounds to the tensor shape to accurately reflect
// the shard size without padding.
int start = std::min(n_j * shard_shape[j], tensor_shape[j]);
int end = std::min((n_j + 1) * shard_shape[j], tensor_shape[j]);
auto slice = at::indexing::Slice(start, end);
indices.push_back(at::indexing::TensorIndex(slice));
auto slice = Slice(start, end);
indices.push_back(TensorIndex(slice));
offset /= tile_shape[j];
}
std::reverse(indices.begin(), indices.end());
shard_indices[device_index[core]] = indices;
shard_indices[device_index[core]] = std::make_pair(replica_id, indices);
}
} else {
TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
@@ -534,8 +544,12 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
if (minibatch) {
shard_indices = GetShardIndicesForMinibatchTensor(shard_shape, devices);
} else {
shard_indices = GetShardIndicesForDevices(
auto replica_and_indices = GetShardReplicaAndIndicesForDevices(
shard_shape, tensor.sizes().vec(), sharding, devices);
// Extract only the indices; the replica_id is unnecessary for sharding.
std::transform(replica_and_indices.begin(), replica_and_indices.end(),
std::back_inserter(shard_indices),
[](auto& pair) { return pair.second; });
}

for (size_t i = 0; i < shard_indices.size(); i++) {
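The coordinate arithmetic in the loop above is easiest to see with a small worked example. The following standalone Python sketch mirrors the same row-major decomposition for a 2x2 tiling with 2-way replication on the last tile dimension (tile_assignment_dimensions = [2, 2, 2]); it is not the C++ helper itself:

# Standalone sketch of the row-major decomposition used above:
# n_j = offset_j % N_j and offset_{j-1} = offset_j // N_j, with the last
# tile dimension consumed as the replica id when it is replicated.
tile_shape = [2, 2, 2]  # tile_assignment_dimensions, last dim replicated
replicate_on_last_tile_dim = True
for i in range(8):  # 8 entries in the tile assignment
  offset, coords, replica_id = i, [], 0
  for j in reversed(range(len(tile_shape))):
    n_j = offset % tile_shape[j]
    if replicate_on_last_tile_dim and j == len(tile_shape) - 1:
      replica_id = n_j  # consecutive devices share the same tile
    else:
      coords.append(n_j)
    offset //= tile_shape[j]
  coords.reverse()
  print(i, coords, replica_id)
# Devices 0,1 land on tile (0,0) with replica_id 0,1; devices 2,3 on (0,1);
# devices 4,5 on (1,0); devices 6,7 on (1,1).
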
13 changes: 8 additions & 5 deletions torch_xla/csrc/xla_sharding_util.h
@@ -93,11 +93,14 @@ class ShardingUtil {
// Uses the provided `sharding` spec and expected shard shape to determine the
// index slices for the shards which belong on `devices`. Only supports
// `REPLICATED` and `OTHER` sharding types.
static std::vector<std::vector<at::indexing::TensorIndex>>
GetShardIndicesForDevices(const std::vector<int64_t>& shard_shape,
const std::vector<int64_t>& tensor_shape,
const xla::OpSharding sharding,
const std::vector<std::string>& devices);
// For each input device, returns a pair of the shard's replica_id and a
// vector of TensorIndex denoting the offset of the device's shard into the
// global tensor.
static std::vector<std::pair<int, std::vector<at::indexing::TensorIndex>>>
GetShardReplicaAndIndicesForDevices(const std::vector<int64_t>& shard_shape,
const std::vector<int64_t>& tensor_shape,
const xla::OpSharding sharding,
const std::vector<std::string>& devices);

// Returns the indices for the shards. Supports `OTHER` sharding types and
// called when input is sharded along the batch axis.
12 changes: 8 additions & 4 deletions torch_xla/experimental/distributed_checkpoint.py
@@ -333,9 +333,10 @@ def _create_write_items_for_xla_sharded_tensor(
items = []
# Since local shards are currently moved to CPU on creation, we need to get
# the shard indices indirectly to avoid unnecessarily consuming host memory.
shard_indices = torch_xla._XLAC._get_local_shard_indices(t.global_tensor)
replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
t.global_tensor)
prop = TensorProperties.create_from_tensor(t)
for shard_ind, indices in enumerate(shard_indices):
for shard_ind, (_, indices) in enumerate(replica_and_indices):
write_item = _create_write_item_from_indices(fqn, shard_ind, indices,
t.size(), prop)
items.append(write_item)
@@ -389,7 +390,10 @@ def _create_xla_read_items(sharded_state_dict: STATE_DICT_TYPE,
md = metadata.state_dict_metadata[fqn]
# Since local shards are currently moved to CPU on creation, we need to get
# the shard indices indirectly to avoid unnecessarily consuming host memory.
shard_indices = torch_xla._XLAC._get_local_shard_indices(t.global_tensor)
chunks = [_create_chunk_from_shard_index(index) for index in shard_indices]
replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
t.global_tensor)
chunks = [
_create_chunk_from_shard_index(ind) for _, ind in replica_and_indices
]
items.extend(create_read_items_for_chunk_list(fqn, md, chunks))
return items
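
Conceptually, the checkpoint integration consumes only the index half of each pair: every shard's slices are turned into the chunk offsets and lengths the planner works with. A rough sketch of that mapping (a hypothetical helper, not the actual `_create_chunk_from_shard_index` implementation):

def shard_index_to_offsets_and_sizes(indices, global_shape):
  # Hypothetical helper: map a shard's global index slices to the
  # (offsets, sizes) pair that describes a checkpoint chunk.
  if indices is Ellipsis:  # a replicated shard covers the full tensor
    return tuple(0 for _ in global_shape), tuple(global_shape)
  offsets = tuple(s.start for s in indices)
  sizes = tuple(s.stop - s.start for s in indices)
  return offsets, sizes
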
19 changes: 15 additions & 4 deletions torch_xla/experimental/xla_sharded_tensor.py
@@ -23,8 +23,14 @@ class XLAShard:
# The device this shard's data originated from.
shard_device: str

# TODO(jonbolin): Expose replica rank with partial replication
# rank: int
# The replica this shard belongs to, as determined by the sharding. The
# replica is determined differently for each sharding type:
# - TILED: Since the tensor isn't replicated, replica_id is always 0.
# - PARTIAL: replica_id is taken from the OpSharding and is a value in
# the range [0, num_replica).
# - REPLICATED: Since the tensor is fully replicated, replica_id is the
# device's global ordinal.
replica_id: int

@property
def unpadded_data(self) -> torch.Tensor:
@@ -110,8 +116,13 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
@property
def local_shards(self) -> List[XLAShard]:
shards, devices = torch_xla._XLAC._get_local_shards(self.global_tensor)
indices = torch_xla._XLAC._get_local_shard_indices(self.global_tensor)
return [XLAShard(s, i, d) for s, i, d in zip(shards, indices, devices)]
replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
self.global_tensor)
zipped = zip(shards, replica_and_indices, devices)
return [
XLAShard(data, indices, dev, replica)
for data, (replica, indices), dev in zipped
]

# Load the given list of local shards into the underlying tensor's data
# on the local devices.
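One practical consequence of the new field: shards with a nonzero replica_id duplicate data already held by some replica-0 shard, so a caller that wants exactly one host copy per shard group can filter on it. A short sketch:

def unique_shards(sharded_tensor):
  # Sketch: keep one copy per replica group. Tiled shards all have
  # replica_id 0, so nothing is dropped; replicated and partially
  # replicated shards are deduplicated down to their first replica.
  return [s for s in sharded_tensor.local_shards if s.replica_id == 0]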
