Vectorize local shard retrieval #5826

Merged: 2 commits, Dec 1, 2023

101 changes: 61 additions & 40 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1760,48 +1760,67 @@ void InitXlaModuleBindings(py::module m) {
// contain any padding that was applied to ensure they all have the same
// shape. Note that this padding is _not_ included in the global indices
// returned by `_get_local_shard_replica_and_indices`.
// For each input tensor, returns a list of shards and their corresponding
// device strings.
m.def("_get_local_shards",
[](const at::Tensor& input)
-> std::tuple<std::vector<at::Tensor>, std::vector<std::string>> {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
XLA_CHECK(xtensor->GetXlaData() != nullptr)
<< "Shard data is not available";
XLA_CHECK(xtensor->sharding_spec() != nullptr)
<< "Tensor is not sharded";
XLA_CHECK(UseVirtualDevice())
<< "Virtual device must be enabled to use _get_local_shards";
auto handle =
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
xtensor->GetXlaData());
std::vector<runtime::ComputationClient::DataPtr> shard_handles =
runtime::GetComputationClient()->GetDataShards(handle);
std::vector<at::Tensor> shards;
std::vector<std::string> str_devices;
shards.reserve(shard_handles.size());
str_devices.reserve(shard_handles.size());
// Transfer shards from the device and create CPU tensors.
for (const runtime::ComputationClient::DataPtr shard_handle :
shard_handles) {
shards.push_back(
XlaDataToTensors({shard_handle},
{MaybeUpcastToHostTorchType(
shard_handle->shape().element_type())})
.front());
str_devices.push_back(shard_handle->device());
[](const std::vector<at::Tensor>& input)
-> std::vector<std::vector<std::pair<at::Tensor, std::string>>> {
std::vector<runtime::ComputationClient::DataPtr> handles;
std::vector<at::ScalarType> element_types;
// Find all shard handles for transfer
for (auto& tensor : input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
XLA_CHECK(xtensor->GetXlaData() != nullptr)
<< "Shard data is not available";
XLA_CHECK(xtensor->sharding_spec() != nullptr)
<< "Tensor is not sharded";
auto handle =
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
xtensor->GetXlaData());
std::vector<runtime::ComputationClient::DataPtr> shard_handles =
runtime::GetComputationClient()->GetDataShards(handle);
handles.insert(handles.end(), shard_handles.begin(),
shard_handles.end());
element_types.insert(element_types.end(), shard_handles.size(),
MaybeUpcastToHostTorchType(
shard_handles[0]->shape().element_type()));
}

std::vector<at::Tensor> cpu_shards =
XlaDataToTensors(WrapXlaData(handles), element_types);
// Populate the resulting vector of shards and device strings
std::vector<std::vector<std::pair<at::Tensor, std::string>>> result;
int shards_per_tensor =
runtime::GetComputationClient()->GetLocalDevices().size();
result.reserve(cpu_shards.size() / shards_per_tensor);
for (int i = 0; i < cpu_shards.size(); i += shards_per_tensor) {
std::vector<std::pair<at::Tensor, std::string>> shard_devices;
for (int shard = 0; shard < shards_per_tensor; ++shard) {
at::Tensor cpu_shard = cpu_shards[i + shard];
std::string source_device = handles[i + shard]->device();
std::pair<at::Tensor, std::string> shard_dev(cpu_shard,
source_device);
shard_devices.push_back(shard_dev);
}
result.push_back(shard_devices);
}
return std::make_tuple(shards, str_devices);
return result;
});
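As a rough usage sketch (not part of this diff): the vectorized binding now takes a list of tensors and returns one list of (shard, device) pairs per input, so a single call transfers the shards of every tensor. The tensor names and sharding setup below are assumed for illustration.

# Illustrative only: assumes `t1` and `t2` are tensors already sharded via
# torch_xla.distributed.spmd (e.g. with xs.mark_sharding) under SPMD mode.
import torch_xla

per_tensor_shards = torch_xla._XLAC._get_local_shards([t1, t2])
for shard_dev_pairs in per_tensor_shards:
    for shard, device in shard_dev_pairs:
        # `shard` is a CPU tensor; `device` is the source device string.
        print(device, shard.shape)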
// For each local shard, returns the tuple:
// For each input tensor's local shards, returns the tuple:
// (replica_id: int, indices: Union[List[Slice], Ellipsis]),
// where `replica_id` is the replica the shard belongs to and `indices` index
// into the global tensor. The value of `indices` is either a Python list of
// slices for each dimension or an Ellipsis object indicating that the tensor
// is replicated. These indices will not reflect any padding that has been
// applied to the shards. The order of the returned indices matches the order
// of the shards returned from `_get_local_shards`.
m.def("_get_local_shard_replica_and_indices",
[](const at::Tensor& input) -> std::vector<std::pair<int, py::object>> {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
m.def(
"_get_local_shard_replica_and_indices",
[](const std::vector<at::Tensor>& input_tensors)
-> std::vector<std::vector<std::pair<int, py::object>>> {
std::vector<std::vector<std::pair<int, py::object>>> result;
for (auto& tensor : input_tensors) {
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
XLA_CHECK(xtensor->sharding_spec() != nullptr)
<< "Tensor is not sharded";
auto handle =
@@ -1817,18 +1836,18 @@ void InitXlaModuleBindings(py::module m) {
auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
auto replica_and_indices =
ShardingUtil::GetShardReplicaAndIndicesForDevices(
shard_shape, input.sizes().vec(), sharding, shard_devices);
shard_shape, tensor.sizes().vec(), sharding, shard_devices);

// Convert each vector<TensorIndex> to List[py::slice] or py::ellipsis
std::vector<std::pair<int, py::object>> result;
result.reserve(shard_devices.size());
std::vector<std::pair<int, py::object>> tensor_ind;
tensor_ind.reserve(shard_devices.size());
for (auto& device_replica_and_indices : replica_and_indices) {
auto& replica_id = device_replica_and_indices.first;
auto& indices = device_replica_and_indices.second;
XLA_CHECK(indices.size() > 0)
<< "Unexpected empty shard indices for tensor " << input;
<< "Unexpected empty shard indices for tensor " << tensor;
if (indices[0].is_ellipsis()) {
result.push_back(std::make_pair(replica_id, py::ellipsis()));
tensor_ind.push_back(std::make_pair(replica_id, py::ellipsis()));
} else {
std::vector<py::object> index_slices;
for (auto& tensor_index : indices) {
@@ -1840,12 +1859,14 @@ void InitXlaModuleBindings(py::module m) {
ssize_t step = slice.step().expect_int();
index_slices.push_back(py::slice(start, stop, step));
}
result.push_back(
tensor_ind.push_back(
std::make_pair(replica_id, py::cast(index_slices)));
}
}
return result;
});
result.push_back(tensor_ind);
}
return result;
});
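For reference, a hedged sketch of how the vectorized replica/indices binding is read back from Python, mirroring the comment above; the sample tensor `t` and its sharding are assumed.

# Illustrative only: `t` is assumed to be a tensor sharded across the local devices.
replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices([t])[0]
for replica_id, indices in replica_and_indices:
    # `indices` is either a list of per-dimension slices into the global tensor,
    # or Ellipsis when the shard is a full replica.
    if indices is Ellipsis:
        print(f"replica {replica_id}: replicated")
    else:
        print(f"replica {replica_id}: {indices}")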
// Load a list of local shards into an explicitly-sharded tensor. A shard must
// be provided for each device.
m.def("_load_local_shards", [](const at::Tensor& tensor,
9 changes: 4 additions & 5 deletions torch_xla/distributed/spmd/xla_sharded_tensor.py
@@ -115,13 +115,12 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
# which results from the sharding.
@property
def local_shards(self) -> List[XLAShard]:
shards, devices = torch_xla._XLAC._get_local_shards(self.global_tensor)
replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
self.global_tensor)
zipped = zip(shards, replica_and_indices, devices)
shard_dev = torch_xla._XLAC._get_local_shards([self.global_tensor])[0]
replica_ind = torch_xla._XLAC._get_local_shard_replica_and_indices(
[self.global_tensor])[0]
return [
XLAShard(data, indices, dev, replica)
for data, (replica, indices), dev in zipped
for (data, dev), (replica, indices) in zip(shard_dev, replica_ind)
]
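A minimal usage sketch of the property above (not part of this diff); `xt` and the attribute names on XLAShard are assumed for illustration.

# Illustrative only: assumes `xt` is an XLAShardedTensor, e.g. obtained via
# xs.wrap_if_sharded on a tensor that was sharded with xs.mark_sharding.
for shard in xt.local_shards:
    # Each XLAShard carries the CPU shard data, its indices into the global
    # tensor, the source device, and the replica id.
    print(shard.shard_device, shard.replica_id, shard.data.shape)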

# Load the given list of local shards into the underlying tensor's data
65 changes: 55 additions & 10 deletions torch_xla/experimental/distributed_checkpoint/_helpers.py
@@ -3,8 +3,10 @@
# their APIs.

import dataclasses
from itertools import starmap

import torch
import torch_xla
import torch_xla.distributed.spmd as xs

from torch.distributed.checkpoint.planner import SavePlan
@@ -24,7 +26,7 @@
from torch.distributed.checkpoint.metadata import (MetadataIndex,
STATE_DICT_TYPE)
from torch_xla.distributed.spmd import XLAShardedTensor, ShardingType
from torch.utils._pytree import tree_map
from torch.utils._pytree import tree_flatten, tree_unflatten

PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
@@ -217,16 +219,59 @@ class _CpuShards:
global_shape: torch.Size


def _sharded_cpu_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
def _cpu_shards_from_tensors(tensors: List[torch.Tensor]):
"""
Converts a state_dict on XLA device to a sharded state_dict on CPU.
Transfer all shards for the input tensors to CPU, and create a _CpuShards
object for each.
"""

def move_state_dict_to_cpu(v):
v = xs.wrap_if_sharded(v)
if not _is_sharded_tensor(v):
v = _unwrap_xla_sharded_tensor(v)
return v.cpu() if isinstance(v, torch.Tensor) else v
return _CpuShards(shards=v.local_shards, global_shape=v.global_tensor.shape)
def create_cpu_shards(global_tensor: torch.Tensor,
shards_dev: List[Tuple[torch.Tensor, str]],
replica_ind: List[Tuple[int, Union[List[slice],
type(Ellipsis)]]]):
shards = [
xs.XLAShard(data, indices, dev, replica)
for (data, dev), (replica, indices) in zip(shards_dev, replica_ind)
]
global_shape = global_tensor.shape
return _CpuShards(shards=shards, global_shape=global_shape)

shards_devs = torch_xla._XLAC._get_local_shards(tensors)
rep_inds = torch_xla._XLAC._get_local_shard_replica_and_indices(tensors)
return list(starmap(create_cpu_shards, zip(tensors, shards_devs, rep_inds)))


return tree_map(move_state_dict_to_cpu, state_dict)
def _sharded_cpu_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
"""
Converts a state_dict on XLA device to a sharded state_dict on CPU.
"""
flat, tree_spec = tree_flatten(state_dict)
flat = [xs.wrap_if_sharded(x) for x in flat]
sharded = [
_unwrap_xla_sharded_tensor(x) for x in flat if _is_sharded_tensor(x)
]

# Move all sharded tensors to CPU
cpu_shards = _cpu_shards_from_tensors(sharded)
cpu_shards_iter = iter(cpu_shards)

# Move all unsharded tensors to CPU
unsharded_tensors = [
_unwrap_xla_sharded_tensor(x)
for x in flat
if isinstance(x, torch.Tensor) and not _is_sharded_tensor(x)
]
cpu_tensors = torch_xla._XLAC._xla_get_cpu_tensors(unsharded_tensors)
cpu_tensors_iter = iter(cpu_tensors)

# Combine the results. The order between the iterators and the flattened
# state_dict is consistent, so simply interweave the iterators.
def to_cpu(x: Any):
if _is_sharded_tensor(x):
return next(cpu_shards_iter)
elif isinstance(x, torch.Tensor):
return next(cpu_tensors_iter)
return x

flat = [to_cpu(x) for x in flat]
return tree_unflatten(flat, tree_spec)
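A hedged sketch of how the rewritten helper is typically called when checkpointing; the model, its sharding, and the private import path are assumed rather than shown in this diff.

# Illustrative only: assumes `model` lives on the XLA device with parameters
# sharded via torch_xla.distributed.spmd.
from torch_xla.experimental.distributed_checkpoint._helpers import (
    _sharded_cpu_state_dict)

cpu_state_dict = _sharded_cpu_state_dict(model.state_dict())
# Sharded entries become _CpuShards (per-shard CPU data plus the global shape),
# unsharded tensors become plain CPU tensors, and non-tensor values pass through.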
4 changes: 2 additions & 2 deletions torch_xla/experimental/distributed_checkpoint/planners.py
@@ -329,7 +329,7 @@ def _create_write_items_for_xla_sharded_tensor(
# Since local shards are currently moved to CPU on creation, we need to get
# the shard indices indirectly to avoid unnecessarily consuming host memory.
replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
t.global_tensor)
[t.global_tensor])[0]
prop = TensorProperties.create_from_tensor(t)
for shard_ind, (_, indices) in enumerate(replica_and_indices):
write_item = _create_write_item_from_indices(fqn, shard_ind, indices,
@@ -386,7 +386,7 @@ def _create_xla_read_items(sharded_state_dict: STATE_DICT_TYPE,
# Since local shards are currently moved to CPU on creation, we need to get
# the shard indices indirectly to avoid unnecessarily consuming host memory.
replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
t.global_tensor)
[t.global_tensor])[0]
chunks = [
_create_chunk_from_shard_index(ind) for _, ind in replica_and_indices
]