From 8c47daf051553bc407caf76322c08dca2ff070da Mon Sep 17 00:00:00 2001
From: jonb377
Date: Thu, 30 Nov 2023 17:55:43 -0800
Subject: [PATCH] Distribute Literal->Tensor copies across thread pool (#5825)

* Distribute Literal->Tensor copies across thread pool

* Update for #5799
---
 test/cpp/test_xla_sharding.cpp          |  3 ++-
 torch_xla/csrc/init_python_bindings.cpp |  4 ++--
 torch_xla/csrc/tensor.cpp               |  3 ++-
 torch_xla/csrc/tensor_util.cpp          | 15 ++++++++++-----
 torch_xla/csrc/tensor_util.h            |  2 +-
 torch_xla/csrc/xla_backend_impl.cpp     |  2 +-
 6 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp
index 08beb49b7b21..77f2b0c57688 100644
--- a/test/cpp/test_xla_sharding.cpp
+++ b/test/cpp/test_xla_sharding.cpp
@@ -27,7 +27,8 @@ namespace {
 bool XlaDataValuesEqual(torch::lazy::BackendDataPtr a,
                         torch::lazy::BackendDataPtr b,
                         at::ScalarType element_type) {
-  std::vector<at::Tensor> tensors = XlaDataToTensors({a, b}, element_type);
+  std::vector<at::Tensor> tensors =
+      XlaDataToTensors({a, b}, {element_type, element_type});
   return TensorCompare(tensors[0], tensors[1]);
 }
 }  // namespace
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 4db190c76b8b..2e1c5d007920 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1784,8 +1784,8 @@ void InitXlaModuleBindings(py::module m) {
            shard_handles) {
         shards.push_back(
             XlaDataToTensors({shard_handle},
-                             MaybeUpcastToHostTorchType(
-                                 shard_handle->shape().element_type()))
+                             {MaybeUpcastToHostTorchType(
+                                 shard_handle->shape().element_type())})
                 .front());
         str_devices.push_back(shard_handle->device());
       }
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 5d4d6825e03c..f48d3a83db30 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -467,7 +467,8 @@ at::Tensor XLATensor::ToTensor(bool detached) {
     XLAGraphExecutor::Get()->DeviceBarrier(GetDevice());
     // The GetXlaData() call will trigger an ApplyPendingGraph() if an IR
     // XlaNode is available on the tensor.
-    std::vector<at::Tensor> tensors = XlaDataToTensors({GetXlaData()}, dtype());
+    std::vector<at::Tensor> tensors =
+        XlaDataToTensors({GetXlaData()}, {dtype()});
     tensor = std::move(tensors.front());
     if (!detached) {
       SetTensorData(tensor);
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index e46bf7e022cb..c2183f7e7854 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -796,13 +796,18 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
 }
 
 std::vector<at::Tensor> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
-    at::ScalarType dest_element_type) {
+    absl::Span<const at::ScalarType> dest_element_type) {
   std::vector<xla::Literal> literals = ReleaseGilAndTransferData(xla_data);
-  std::vector<at::Tensor> tensors;
-  tensors.reserve(literals.size());
-  for (auto& literal : literals) {
-    tensors.push_back(MakeTensorFromXlaLiteral(literal, dest_element_type));
+  std::vector<at::Tensor> tensors(literals.size());
+  absl::BlockingCounter counter(literals.size());
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    auto copy_fn = [&, i]() {
+      tensors[i] = MakeTensorFromXlaLiteral(literals[i], dest_element_type[i]);
+      counter.DecrementCount();
+    };
+    thread::Schedule(std::move(copy_fn));
   }
+  counter.Wait();
   return tensors;
 }
diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h
index 81b4cd9a5652..f9ca29f7ab1d 100644
--- a/torch_xla/csrc/tensor_util.h
+++ b/torch_xla/csrc/tensor_util.h
@@ -34,7 +34,7 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
 // TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice
 std::vector<at::Tensor> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
-    at::ScalarType dest_element_type);
+    absl::Span<const at::ScalarType> dest_element_type);
 
 bool TensorCompare(const at::Tensor& t1, const at::Tensor& t2);
 
diff --git a/torch_xla/csrc/xla_backend_impl.cpp b/torch_xla/csrc/xla_backend_impl.cpp
index 4adb9f50eaee..c2cb2f432894 100644
--- a/torch_xla/csrc/xla_backend_impl.cpp
+++ b/torch_xla/csrc/xla_backend_impl.cpp
@@ -93,7 +93,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
       const torch::lazy::BackendDataPtr data,
       c10::optional<at::ScalarType> logical_scalar_type) const override {
     // TODO(JackCaoG): handle the logical_scalar_type == nullptr case
-    return XlaDataToTensors({data}, *logical_scalar_type)[0];
+    return XlaDataToTensors({data}, {*logical_scalar_type})[0];
   }
 
   std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
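
Reviewer note (not part of the patch): the core of this change is the fan-out
in XlaDataToTensors, where each xla::Literal -> at::Tensor copy is scheduled
on torch_xla's thread pool and an absl::BlockingCounter blocks the caller
until every copy has landed. Below is a minimal self-contained sketch of that
pattern. Literal, Tensor, MakeTensorFromLiteral, and LiteralsToTensors are
hypothetical stand-ins for their xla/ATen counterparts, and std::thread
stands in for the internal thread::Schedule pool; since pooled threads are
never joined by the caller, the counter is the real completion signal, and
the join() calls here exist only to satisfy std::thread's destructor.

#include <cstddef>
#include <thread>
#include <vector>

#include "absl/synchronization/blocking_counter.h"

// Hypothetical stand-ins for xla::Literal and at::Tensor.
struct Literal {};
struct Tensor {};

Tensor MakeTensorFromLiteral(const Literal& literal) { return Tensor{}; }

std::vector<Tensor> LiteralsToTensors(const std::vector<Literal>& literals) {
  // Pre-size the output so each worker writes to a distinct slot; no
  // further synchronization on `tensors` is needed.
  std::vector<Tensor> tensors(literals.size());
  absl::BlockingCounter counter(literals.size());
  std::vector<std::thread> workers;
  workers.reserve(literals.size());
  for (size_t i = 0; i < literals.size(); ++i) {
    // Capture `i` by value: the loop advances before the worker runs.
    workers.emplace_back([&tensors, &literals, &counter, i]() {
      tensors[i] = MakeTensorFromLiteral(literals[i]);
      counter.DecrementCount();
    });
  }
  // Block until every copy has completed before returning the results.
  counter.Wait();
  for (std::thread& worker : workers) {
    worker.join();
  }
  return tensors;
}

One design note: pre-sizing the output vector so worker i writes only to
tensors[i] is what makes the loop safe without a mutex; push_back from
multiple threads on a shared vector would be a data race, which is why the
patch replaces tensors.reserve()/push_back with indexed writes.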