From b9475d9ec910dd38dc4298951b0ddcc887ab3bdd Mon Sep 17 00:00:00 2001 From: jonb377 Date: Thu, 30 Nov 2023 08:40:05 +1100 Subject: [PATCH] Parallelize d2h transfers (#5824) --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index f9e46dce55d..1aa017bc33d 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -464,6 +464,8 @@ std::vector PjRtComputationClient::TransferFromServer( metrics::TimedSection timed(TransferFromServerMetric()); tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromServer", tsl::profiler::TraceMeLevel::kInfo); + std::vector> futures; + futures.reserve(handles.size()); std::vector literals; literals.reserve(handles.size()); int64_t total_size = 0; @@ -473,12 +475,16 @@ std::vector PjRtComputationClient::TransferFromServer( auto new_handle = ReplicateShardedData(handle); const PjRtData& pjrt_data = dynamic_cast(*new_handle); - auto& literal = + xla::Literal& literal = literals.emplace_back(host_output_shape(pjrt_data.buffer.get())); - XLA_CHECK_OK(pjrt_data.buffer->ToLiteralSync(&literal)); + futures.push_back(pjrt_data.buffer->ToLiteral(&literal)); total_size += literal.size_bytes(); } + for (auto& future : futures) { + tsl::Status status = future.Await(); + XLA_CHECK_OK(status); + } InboundDataMetric()->AddSample(total_size); return literals;