diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 2e1c5d007920..cc7f5c835a7f 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1760,38 +1760,53 @@ void InitXlaModuleBindings(py::module m) {
   // contain any padding that was applied to ensure they all have the same
   // shape. Note that this padding is _not_ included in the global indices
   // returned by `_get_local_shard_replica_and_indices`.
+  // For each input tensor, returns a list of shards and their corresponding
+  // device string.
   m.def("_get_local_shards",
-        [](const at::Tensor& input)
-            -> std::tuple<std::vector<at::Tensor>, std::vector<std::string>> {
-          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
-          XLA_CHECK(xtensor->GetXlaData() != nullptr)
-              << "Shard data is not available";
-          XLA_CHECK(xtensor->sharding_spec() != nullptr)
-              << "Tensor is not sharded";
-          XLA_CHECK(UseVirtualDevice())
-              << "Virtual device must be enabled to use _get_local_shards";
-          auto handle =
-              std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
-                  xtensor->GetXlaData());
-          std::vector<runtime::ComputationClient::DataPtr> shard_handles =
-              runtime::GetComputationClient()->GetDataShards(handle);
-          std::vector<at::Tensor> shards;
-          std::vector<std::string> str_devices;
-          shards.reserve(shard_handles.size());
-          str_devices.reserve(shard_handles.size());
-          // Tansfer shards from the device and create cpu tensors.
-          for (const runtime::ComputationClient::DataPtr shard_handle :
-               shard_handles) {
-            shards.push_back(
-                XlaDataToTensors({shard_handle},
-                                 {MaybeUpcastToHostTorchType(
-                                     shard_handle->shape().element_type())})
-                    .front());
-            str_devices.push_back(shard_handle->device());
+        [](const std::vector<at::Tensor>& input)
+            -> std::vector<std::vector<std::pair<at::Tensor, std::string>>> {
+          std::vector<runtime::ComputationClient::DataPtr> handles;
+          std::vector<at::ScalarType> element_types;
+          // Find all shard handles for transfer
+          for (auto& tensor : input) {
+            XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
+            XLA_CHECK(xtensor->GetXlaData() != nullptr)
+                << "Shard data is not available";
+            XLA_CHECK(xtensor->sharding_spec() != nullptr)
+                << "Tensor is not sharded";
+            auto handle =
+                std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
+                    xtensor->GetXlaData());
+            std::vector<runtime::ComputationClient::DataPtr> shard_handles =
+                runtime::GetComputationClient()->GetDataShards(handle);
+            handles.insert(handles.end(), shard_handles.begin(),
+                           shard_handles.end());
+            element_types.insert(element_types.end(), shard_handles.size(),
+                                 MaybeUpcastToHostTorchType(
+                                     shard_handles[0]->shape().element_type()));
+          }
+
+          std::vector<at::Tensor> cpu_shards =
+              XlaDataToTensors(WrapXlaData(handles), element_types);
+          // Populate the resulting vector of shards and device strings
+          std::vector<std::vector<std::pair<at::Tensor, std::string>>> result;
+          int shards_per_tensor =
+              runtime::GetComputationClient()->GetLocalDevices().size();
+          result.reserve(cpu_shards.size() / shards_per_tensor);
+          for (int i = 0; i < cpu_shards.size(); i += shards_per_tensor) {
+            std::vector<std::pair<at::Tensor, std::string>> shard_devices;
+            for (int shard = 0; shard < shards_per_tensor; ++shard) {
+              at::Tensor cpu_shard = cpu_shards[i + shard];
+              std::string source_device = handles[i + shard]->device();
+              std::pair<at::Tensor, std::string> shard_dev(cpu_shard,
+                                                           source_device);
+              shard_devices.push_back(shard_dev);
+            }
+            result.push_back(shard_devices);
           }
-          return std::make_tuple(shards, str_devices);
+          return result;
         });
-  // For each local shard, returns the tuple:
+  // For each input tensors' local shards, returns the tuple:
   // (replica_id: int, indices: Union[List[Slice], Ellipsis]),
   // where `replica_id` is the replica the shard belongs to and `indices` index
   // into the global tensor. The value of `indices` is either a Python list of
@@ -1799,9 +1814,13 @@ void InitXlaModuleBindings(py::module m) {
   // is replicated. These indices will not reflect any padding that has been
   // applied to the shards. The order of the returned indices matches the order
   // of the shards returned from `_get_local_shards`.
-  m.def("_get_local_shard_replica_and_indices",
-        [](const at::Tensor& input) -> std::vector<std::pair<int, py::object>> {
-          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+  m.def(
+      "_get_local_shard_replica_and_indices",
+      [](const std::vector<at::Tensor>& input_tensors)
+          -> std::vector<std::vector<std::pair<int, py::object>>> {
+        std::vector<std::vector<std::pair<int, py::object>>> result;
+        for (auto& tensor : input_tensors) {
+          XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
           XLA_CHECK(xtensor->sharding_spec() != nullptr)
               << "Tensor is not sharded";
           auto handle =
@@ -1817,18 +1836,18 @@ void InitXlaModuleBindings(py::module m) {
           auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
           auto replica_and_indices =
               ShardingUtil::GetShardReplicaAndIndicesForDevices(
-                  shard_shape, input.sizes().vec(), sharding, shard_devices);
+                  shard_shape, tensor.sizes().vec(), sharding, shard_devices);
           // Convert each vector to List[py::slice] or py::ellipsis
-          std::vector<std::pair<int, py::object>> result;
-          result.reserve(shard_devices.size());
+          std::vector<std::pair<int, py::object>> tensor_ind;
+          tensor_ind.reserve(shard_devices.size());
           for (auto& device_replica_and_indices : replica_and_indices) {
            auto& replica_id = device_replica_and_indices.first;
            auto& indices = device_replica_and_indices.second;
            XLA_CHECK(indices.size() > 0)
-               << "Unexpected empty shard indices for tensor " << input;
+               << "Unexpected empty shard indices for tensor " << tensor;
            if (indices[0].is_ellipsis()) {
-             result.push_back(std::make_pair(replica_id, py::ellipsis()));
+             tensor_ind.push_back(std::make_pair(replica_id, py::ellipsis()));
            } else {
              std::vector<py::object> index_slices;
              for (auto& tensor_index : indices) {
@@ -1840,12 +1859,14 @@ void InitXlaModuleBindings(py::module m) {
                ssize_t step = slice.step().expect_int();
                index_slices.push_back(py::slice(start, stop, step));
              }
-             result.push_back(
+             tensor_ind.push_back(
                  std::make_pair(replica_id, py::cast(index_slices)));
            }
          }
-          return result;
-        });
+          result.push_back(tensor_ind);
+        }
+        return result;
+      });
   // Load a list of local shards into an explicitly-sharded tensor. A shard must
   // be provided for each device.
   m.def("_load_local_shards", [](const at::Tensor& tensor,
diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py
index 2945502dcc26..44377027f5bc 100644
--- a/torch_xla/distributed/spmd/xla_sharded_tensor.py
+++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py
@@ -115,13 +115,12 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
   # which results from the sharding.
   @property
   def local_shards(self) -> List[XLAShard]:
-    shards, devices = torch_xla._XLAC._get_local_shards(self.global_tensor)
-    replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
-        self.global_tensor)
-    zipped = zip(shards, replica_and_indices, devices)
+    shard_dev = torch_xla._XLAC._get_local_shards([self.global_tensor])[0]
+    replica_ind = torch_xla._XLAC._get_local_shard_replica_and_indices(
+        [self.global_tensor])[0]
     return [
         XLAShard(data, indices, dev, replica)
-        for data, (replica, indices), dev in zipped
+        for (data, dev), (replica, indices) in zip(shard_dev, replica_ind)
     ]

   # Load the given list of local shards into the underlying tensor's data
diff --git a/torch_xla/experimental/distributed_checkpoint/_helpers.py b/torch_xla/experimental/distributed_checkpoint/_helpers.py
index 62c3c6f2ee0b..16a4b2181ee0 100644
--- a/torch_xla/experimental/distributed_checkpoint/_helpers.py
+++ b/torch_xla/experimental/distributed_checkpoint/_helpers.py
@@ -3,8 +3,10 @@
 # their APIs.

 import dataclasses
+from itertools import starmap

 import torch
+import torch_xla
 import torch_xla.distributed.spmd as xs

 from torch.distributed.checkpoint.planner import SavePlan
@@ -24,7 +26,7 @@
 from torch.distributed.checkpoint.metadata import (MetadataIndex,
                                                    STATE_DICT_TYPE)
 from torch_xla.distributed.spmd import XLAShardedTensor, ShardingType
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_flatten, tree_unflatten

 PATH_ITEM = Union[str, int]
 OBJ_PATH = Tuple[PATH_ITEM, ...]
@@ -217,16 +219,59 @@ class _CpuShards:
   global_shape: torch.Size


-def _sharded_cpu_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
+def _cpu_shards_from_tensors(tensors: List[torch.Tensor]):
   """
-  Converts a state_dict on XLA device to a sharded state_dict on CPU.
+  Transfer all shards for the input tensors to CPU, and create a _CpuShards
+  object for each.
   """

-  def move_state_dict_to_cpu(v):
-    v = xs.wrap_if_sharded(v)
-    if not _is_sharded_tensor(v):
-      v = _unwrap_xla_sharded_tensor(v)
-      return v.cpu() if isinstance(v, torch.Tensor) else v
-    return _CpuShards(shards=v.local_shards, global_shape=v.global_tensor.shape)
+  def create_cpu_shards(global_tensor: torch.Tensor,
+                        shards_dev: List[Tuple[torch.Tensor, str]],
+                        replica_ind: List[Tuple[int, Union[List[slice],
+                                                           type(Ellipsis)]]]):
+    shards = [
+        xs.XLAShard(data, indices, dev, replica)
+        for (data, dev), (replica, indices) in zip(shards_dev, replica_ind)
+    ]
+    global_shape = global_tensor.shape
+    return _CpuShards(shards=shards, global_shape=global_shape)
+
+  shards_devs = torch_xla._XLAC._get_local_shards(tensors)
+  rep_inds = torch_xla._XLAC._get_local_shard_replica_and_indices(tensors)
+  return list(starmap(create_cpu_shards, zip(tensors, shards_devs, rep_inds)))
+

-  return tree_map(move_state_dict_to_cpu, state_dict)
+def _sharded_cpu_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
+  """
+  Converts a state_dict on XLA device to a sharded state_dict on CPU.
+  """
+  flat, tree_spec = tree_flatten(state_dict)
+  flat = [xs.wrap_if_sharded(x) for x in flat]
+  sharded = [
+      _unwrap_xla_sharded_tensor(x) for x in flat if _is_sharded_tensor(x)
+  ]
+
+  # Move all sharded tensors to CPU
+  cpu_shards = _cpu_shards_from_tensors(sharded)
+  cpu_shards_iter = iter(cpu_shards)
+
+  # Move all unsharded tensors to CPU
+  unsharded_tensors = [
+      _unwrap_xla_sharded_tensor(x)
+      for x in flat
+      if isinstance(x, torch.Tensor) and not _is_sharded_tensor(x)
+  ]
+  cpu_tensors = torch_xla._XLAC._xla_get_cpu_tensors(unsharded_tensors)
+  cpu_tensors_iter = iter(cpu_tensors)
+
+  # Combine the results. The order between the iterators and the flattened
+  # state_dict is consistent, so simply interweave the iterators.
+  def to_cpu(x: Any):
+    if _is_sharded_tensor(x):
+      return next(cpu_shards_iter)
+    elif isinstance(x, torch.Tensor):
+      return next(cpu_tensors_iter)
+    return x
+
+  flat = [to_cpu(x) for x in flat]
+  return tree_unflatten(flat, tree_spec)
diff --git a/torch_xla/experimental/distributed_checkpoint/planners.py b/torch_xla/experimental/distributed_checkpoint/planners.py
index 6810ddb56a31..c417872c2f25 100644
--- a/torch_xla/experimental/distributed_checkpoint/planners.py
+++ b/torch_xla/experimental/distributed_checkpoint/planners.py
@@ -329,7 +329,7 @@ def _create_write_items_for_xla_sharded_tensor(
   # Since local shards are currently moved to CPU on creation, we need to get
   # the shard indices indirectly to avoid unnecessarily consuming host memory.
   replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
-      t.global_tensor)
+      [t.global_tensor])[0]
   prop = TensorProperties.create_from_tensor(t)
   for shard_ind, (_, indices) in enumerate(replica_and_indices):
     write_item = _create_write_item_from_indices(fqn, shard_ind, indices,
@@ -386,7 +386,7 @@ def _create_xla_read_items(sharded_state_dict: STATE_DICT_TYPE,
   # Since local shards are currently moved to CPU on creation, we need to get
   # the shard indices indirectly to avoid unnecessarily consuming host memory.
   replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
-      t.global_tensor)
+      [t.global_tensor])[0]
   chunks = [
       _create_chunk_from_shard_index(ind) for _, ind in replica_and_indices
   ]
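
Usage sketch (illustrative, not part of the patch): how a caller consumes the batched bindings introduced above, assuming an SPMD runtime and a list `tensors` of already-sharded XLA tensors. Variable names are hypothetical; the return structure mirrors the `local_shards` property in xla_sharded_tensor.py.

  shards_and_devs = torch_xla._XLAC._get_local_shards(tensors)
  replica_and_inds = torch_xla._XLAC._get_local_shard_replica_and_indices(tensors)
  for t, per_shard, per_index in zip(tensors, shards_and_devs, replica_and_inds):
    for (cpu_data, device), (replica, indices) in zip(per_shard, per_index):
      # cpu_data: host copy of one shard; device: source device string;
      # replica: replica id; indices: list of slices into the global tensor,
      # or Ellipsis when the shard is fully replicated.
      pass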
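A minimal, self-contained sketch of the flatten/interleave/unflatten pattern that the new `_sharded_cpu_state_dict` relies on, shown with plain CPU tensors so it runs without an XLA device; the `isinstance` predicate stands in for `_is_sharded_tensor`, and the bulk `clone` stands in for the batched device-to-host transfer.

  import torch
  from torch.utils._pytree import tree_flatten, tree_unflatten

  state_dict = {'w': torch.ones(2, 2), 'opt': {'step': 3, 'b': torch.zeros(4)}}
  flat, spec = tree_flatten(state_dict)

  # Process one class of leaves in bulk, then re-interleave them in the
  # original traversal order before rebuilding the pytree.
  moved = iter([t.clone() for t in flat if isinstance(t, torch.Tensor)])
  flat = [next(moved) if isinstance(x, torch.Tensor) else x for x in flat]
  restored = tree_unflatten(flat, spec)
  assert restored['opt']['step'] == 3 and restored['w'].shape == (2, 2)

Because `tree_flatten` yields leaves in a deterministic order, filtering the flat list and later re-walking it keeps the sharded and unsharded iterators aligned with their original positions, which is the property the helper depends on.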