Skip to content

Commit

Permalink
Vectorize local shard retrieval (#5826)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 authored Dec 1, 2023
1 parent f6a775c commit c919973
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 57 deletions.
101 changes: 61 additions & 40 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1760,48 +1760,67 @@ void InitXlaModuleBindings(py::module m) {
// contain any padding that was applied to ensure they all have the same
// shape. Note that this padding is _not_ included in the global indices
// returned by `_get_local_shard_replica_and_indices`.
// For each input tensor, returns a list of shards and their corresponding
// device string.
m.def("_get_local_shards",
[](const at::Tensor& input)
-> std::tuple<std::vector<at::Tensor>, std::vector<std::string>> {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
XLA_CHECK(xtensor->GetXlaData() != nullptr)
<< "Shard data is not available";
XLA_CHECK(xtensor->sharding_spec() != nullptr)
<< "Tensor is not sharded";
XLA_CHECK(UseVirtualDevice())
<< "Virtual device must be enabled to use _get_local_shards";
auto handle =
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
xtensor->GetXlaData());
std::vector<runtime::ComputationClient::DataPtr> shard_handles =
runtime::GetComputationClient()->GetDataShards(handle);
std::vector<at::Tensor> shards;
std::vector<std::string> str_devices;
shards.reserve(shard_handles.size());
str_devices.reserve(shard_handles.size());
// Tansfer shards from the device and create cpu tensors.
for (const runtime::ComputationClient::DataPtr shard_handle :
shard_handles) {
shards.push_back(
XlaDataToTensors({shard_handle},
{MaybeUpcastToHostTorchType(
shard_handle->shape().element_type())})
.front());
str_devices.push_back(shard_handle->device());
[](const std::vector<at::Tensor>& input)
-> std::vector<std::vector<std::pair<at::Tensor, std::string>>> {
std::vector<runtime::ComputationClient::DataPtr> handles;
std::vector<at::ScalarType> element_types;
// Find all shard handles for transfer
for (auto& tensor : input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
XLA_CHECK(xtensor->GetXlaData() != nullptr)
<< "Shard data is not available";
XLA_CHECK(xtensor->sharding_spec() != nullptr)
<< "Tensor is not sharded";
auto handle =
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
xtensor->GetXlaData());
std::vector<runtime::ComputationClient::DataPtr> shard_handles =
runtime::GetComputationClient()->GetDataShards(handle);
handles.insert(handles.end(), shard_handles.begin(),
shard_handles.end());
element_types.insert(element_types.end(), shard_handles.size(),
MaybeUpcastToHostTorchType(
shard_handles[0]->shape().element_type()));
}

std::vector<at::Tensor> cpu_shards =
XlaDataToTensors(WrapXlaData(handles), element_types);
// Populate the resulting vector of shards and device strings
std::vector<std::vector<std::pair<at::Tensor, std::string>>> result;
int shards_per_tensor =
runtime::GetComputationClient()->GetLocalDevices().size();
result.reserve(cpu_shards.size() / shards_per_tensor);
for (int i = 0; i < cpu_shards.size(); i += shards_per_tensor) {
std::vector<std::pair<at::Tensor, std::string>> shard_devices;
for (int shard = 0; shard < shards_per_tensor; ++shard) {
at::Tensor cpu_shard = cpu_shards[i + shard];
std::string source_device = handles[i + shard]->device();
std::pair<at::Tensor, std::string> shard_dev(cpu_shard,
source_device);
shard_devices.push_back(shard_dev);
}
result.push_back(shard_devices);
}
return std::make_tuple(shards, str_devices);
return result;
});
// For each local shard, returns the tuple:
// For each input tensors' local shards, returns the tuple:
// (replica_id: int, indices: Union[List[Slice], Ellipsis]),
// where `replica_id` is the replica the shard belongs to and `indices` index
// into the global tensor. The value of `indices` is either a Python list of
// slices for each dimension or an Ellipsis object indicating that the tensor
// is replicated. These indices will not reflect any padding that has been
// applied to the shards. The order of the returned indices matches the order
// of the shards returned from `_get_local_shards`.
m.def("_get_local_shard_replica_and_indices",
[](const at::Tensor& input) -> std::vector<std::pair<int, py::object>> {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
m.def(
"_get_local_shard_replica_and_indices",
[](const std::vector<at::Tensor>& input_tensors)
-> std::vector<std::vector<std::pair<int, py::object>>> {
std::vector<std::vector<std::pair<int, py::object>>> result;
for (auto& tensor : input_tensors) {
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
XLA_CHECK(xtensor->sharding_spec() != nullptr)
<< "Tensor is not sharded";
auto handle =
Expand All @@ -1817,18 +1836,18 @@ void InitXlaModuleBindings(py::module m) {
auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
auto replica_and_indices =
ShardingUtil::GetShardReplicaAndIndicesForDevices(
shard_shape, input.sizes().vec(), sharding, shard_devices);
shard_shape, tensor.sizes().vec(), sharding, shard_devices);

// Convert each vector<TensorIndex> to List[py::slice] or py::ellipsis
std::vector<std::pair<int, py::object>> result;
result.reserve(shard_devices.size());
std::vector<std::pair<int, py::object>> tensor_ind;
tensor_ind.reserve(shard_devices.size());
for (auto& device_replica_and_indices : replica_and_indices) {
auto& replica_id = device_replica_and_indices.first;
auto& indices = device_replica_and_indices.second;
XLA_CHECK(indices.size() > 0)
<< "Unexpected empty shard indices for tensor " << input;
<< "Unexpected empty shard indices for tensor " << tensor;
if (indices[0].is_ellipsis()) {
result.push_back(std::make_pair(replica_id, py::ellipsis()));
tensor_ind.push_back(std::make_pair(replica_id, py::ellipsis()));
} else {
std::vector<py::object> index_slices;
for (auto& tensor_index : indices) {
Expand All @@ -1840,12 +1859,14 @@ void InitXlaModuleBindings(py::module m) {
ssize_t step = slice.step().expect_int();
index_slices.push_back(py::slice(start, stop, step));
}
result.push_back(
tensor_ind.push_back(
std::make_pair(replica_id, py::cast(index_slices)));
}
}
return result;
});
result.push_back(tensor_ind);
}
return result;
});
// Load a list of local shards into an explicitly-sharded tensor. A shard must
// be provided for each device.
m.def("_load_local_shards", [](const at::Tensor& tensor,
Expand Down
9 changes: 4 additions & 5 deletions torch_xla/distributed/spmd/xla_sharded_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,12 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
# which results from the sharding.
@property
def local_shards(self) -> List[XLAShard]:
shards, devices = torch_xla._XLAC._get_local_shards(self.global_tensor)
replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
self.global_tensor)
zipped = zip(shards, replica_and_indices, devices)
shard_dev = torch_xla._XLAC._get_local_shards([self.global_tensor])[0]
replica_ind = torch_xla._XLAC._get_local_shard_replica_and_indices(
[self.global_tensor])[0]
return [
XLAShard(data, indices, dev, replica)
for data, (replica, indices), dev in zipped
for (data, dev), (replica, indices) in zip(shard_dev, replica_ind)
]

# Load the given list of local shards into the underlying tensor's data
Expand Down
65 changes: 55 additions & 10 deletions torch_xla/experimental/distributed_checkpoint/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
# their APIs.

import dataclasses
from itertools import starmap

import torch
import torch_xla
import torch_xla.distributed.spmd as xs

from torch.distributed.checkpoint.planner import SavePlan
Expand All @@ -24,7 +26,7 @@
from torch.distributed.checkpoint.metadata import (MetadataIndex,
STATE_DICT_TYPE)
from torch_xla.distributed.spmd import XLAShardedTensor, ShardingType
from torch.utils._pytree import tree_map
from torch.utils._pytree import tree_flatten, tree_unflatten

PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
Expand Down Expand Up @@ -217,16 +219,59 @@ class _CpuShards:
global_shape: torch.Size


def _sharded_cpu_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
def _cpu_shards_from_tensors(tensors: List[torch.Tensor]):
"""
Converts a state_dict on XLA device to a sharded state_dict on CPU.
Transfer all shards for the input tensors to CPU, and create a _CpuShards
object for each.
"""

def move_state_dict_to_cpu(v):
v = xs.wrap_if_sharded(v)
if not _is_sharded_tensor(v):
v = _unwrap_xla_sharded_tensor(v)
return v.cpu() if isinstance(v, torch.Tensor) else v
return _CpuShards(shards=v.local_shards, global_shape=v.global_tensor.shape)
def create_cpu_shards(global_tensor: torch.Tensor,
shards_dev: List[Tuple[torch.Tensor, str]],
replica_ind: List[Tuple[int, Union[List[slice],
type(Ellipsis)]]]):
shards = [
xs.XLAShard(data, indices, dev, replica)
for (data, dev), (replica, indices) in zip(shards_dev, replica_ind)
]
global_shape = global_tensor.shape
return _CpuShards(shards=shards, global_shape=global_shape)

shards_devs = torch_xla._XLAC._get_local_shards(tensors)
rep_inds = torch_xla._XLAC._get_local_shard_replica_and_indices(tensors)
return list(starmap(create_cpu_shards, zip(tensors, shards_devs, rep_inds)))


return tree_map(move_state_dict_to_cpu, state_dict)
def _sharded_cpu_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
"""
Converts a state_dict on XLA device to a sharded state_dict on CPU.
"""
flat, tree_spec = tree_flatten(state_dict)
flat = [xs.wrap_if_sharded(x) for x in flat]
sharded = [
_unwrap_xla_sharded_tensor(x) for x in flat if _is_sharded_tensor(x)
]

# Move all sharded tensors to CPU
cpu_shards = _cpu_shards_from_tensors(sharded)
cpu_shards_iter = iter(cpu_shards)

# Move all unsharded tensors to CPU
unsharded_tensors = [
_unwrap_xla_sharded_tensor(x)
for x in flat
if isinstance(x, torch.Tensor) and not _is_sharded_tensor(x)
]
cpu_tensors = torch_xla._XLAC._xla_get_cpu_tensors(unsharded_tensors)
cpu_tensors_iter = iter(cpu_tensors)

# Combine the results. The order between the iterators and the flattened
# state_dict is consistent, so simply interweave the iterators.
def to_cpu(x: Any):
if _is_sharded_tensor(x):
return next(cpu_shards_iter)
elif isinstance(x, torch.Tensor):
return next(cpu_tensors_iter)
return x

flat = [to_cpu(x) for x in flat]
return tree_unflatten(flat, tree_spec)
4 changes: 2 additions & 2 deletions torch_xla/experimental/distributed_checkpoint/planners.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def _create_write_items_for_xla_sharded_tensor(
# Since local shards are currently moved to CPU on creation, we need to get
# the shard indices indirectly to avoid unnecessarily consuming host memory.
replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
t.global_tensor)
[t.global_tensor])[0]
prop = TensorProperties.create_from_tensor(t)
for shard_ind, (_, indices) in enumerate(replica_and_indices):
write_item = _create_write_item_from_indices(fqn, shard_ind, indices,
Expand Down Expand Up @@ -386,7 +386,7 @@ def _create_xla_read_items(sharded_state_dict: STATE_DICT_TYPE,
# Since local shards are currently moved to CPU on creation, we need to get
# the shard indices indirectly to avoid unnecessarily consuming host memory.
replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
t.global_tensor)
[t.global_tensor])[0]
chunks = [
_create_chunk_from_shard_index(ind) for _, ind in replica_and_indices
]
Expand Down

0 comments on commit c919973

Please sign in to comment.