Skip to content

Commit

Permalink
Distribute Literal->Tensor copies across thread pool (#5825)
Browse files Browse the repository at this point in the history
* Distribute Literal->Tensor copies across thread pool

* Update for #5799
  • Loading branch information
jonb377 authored and bhavya01 committed Apr 22, 2024
1 parent f9dfabb commit 17eddce
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 11 deletions.
3 changes: 2 additions & 1 deletion test/cpp/test_xla_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ namespace {
bool XlaDataValuesEqual(torch::lazy::BackendDataPtr a,
torch::lazy::BackendDataPtr b,
at::ScalarType element_type) {
std::vector<at::Tensor> tensors = XlaDataToTensors({a, b}, element_type);
std::vector<at::Tensor> tensors =
XlaDataToTensors({a, b}, {element_type, element_type});
return TensorCompare(tensors[0], tensors[1]);
}
} // namespace
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1784,8 +1784,8 @@ void InitXlaModuleBindings(py::module m) {
shard_handles) {
shards.push_back(
XlaDataToTensors({shard_handle},
MaybeUpcastToHostTorchType(
shard_handle->shape().element_type()))
{MaybeUpcastToHostTorchType(
shard_handle->shape().element_type())})
.front());
str_devices.push_back(shard_handle->device());
}
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,8 @@ at::Tensor XLATensor::ToTensor(bool detached) {
XLAGraphExecutor::Get()->DeviceBarrier(GetDevice());
// The GetXlaData() call will trigger an ApplyPendingGraph() if an IR
// XlaNode is available on the tensor.
std::vector<at::Tensor> tensors = XlaDataToTensors({GetXlaData()}, dtype());
std::vector<at::Tensor> tensors =
XlaDataToTensors({GetXlaData()}, {dtype()});
tensor = std::move(tensors.front());
if (!detached) {
SetTensorData(tensor);
Expand Down
15 changes: 10 additions & 5 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -796,13 +796,18 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(

std::vector<at::Tensor> XlaDataToTensors(
absl::Span<const torch::lazy::BackendDataPtr> xla_data,
at::ScalarType dest_element_type) {
absl::Span<const at::ScalarType> dest_element_type) {
std::vector<xla::Literal> literals = ReleaseGilAndTransferData(xla_data);
std::vector<at::Tensor> tensors;
tensors.reserve(literals.size());
for (auto& literal : literals) {
tensors.push_back(MakeTensorFromXlaLiteral(literal, dest_element_type));
std::vector<at::Tensor> tensors(literals.size());
absl::BlockingCounter counter(literals.size());
for (size_t i = 0; i < tensors.size(); ++i) {
auto copy_fn = [&, i]() {
tensors[i] = MakeTensorFromXlaLiteral(literals[i], dest_element_type[i]);
counter.DecrementCount();
};
thread::Schedule(std::move(copy_fn));
}
counter.Wait();
return tensors;
}

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
// TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice
std::vector<at::Tensor> XlaDataToTensors(
absl::Span<const torch::lazy::BackendDataPtr> xla_data,
at::ScalarType dest_element_type);
absl::Span<const at::ScalarType> dest_element_type);

bool TensorCompare(const at::Tensor& t1, const at::Tensor& t2);

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_backend_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
const torch::lazy::BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type) const override {
// TODO(JackCaoG): handle the logical_scalar_type == nullptr case
return XlaDataToTensors({data}, *logical_scalar_type)[0];
return XlaDataToTensors({data}, {*logical_scalar_type})[0];
}

std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
Expand Down

0 comments on commit 17eddce

Please sign in to comment.