Skip to content

Commit

Permalink
Parallelize d2h transfers (#5824)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 authored Nov 29, 2023
1 parent eb728a8 commit b9475d9
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,8 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
metrics::TimedSection timed(TransferFromServerMetric());
tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromServer",
tsl::profiler::TraceMeLevel::kInfo);
std::vector<xla::PjRtFuture<tsl::Status>> futures;
futures.reserve(handles.size());
std::vector<xla::Literal> literals;
literals.reserve(handles.size());
int64_t total_size = 0;
Expand All @@ -473,12 +475,16 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
auto new_handle = ReplicateShardedData(handle);
const PjRtData& pjrt_data = dynamic_cast<const PjRtData&>(*new_handle);

auto& literal =
xla::Literal& literal =
literals.emplace_back(host_output_shape(pjrt_data.buffer.get()));
XLA_CHECK_OK(pjrt_data.buffer->ToLiteralSync(&literal));
futures.push_back(pjrt_data.buffer->ToLiteral(&literal));

total_size += literal.size_bytes();
}
for (auto& future : futures) {
tsl::Status status = future.Await();
XLA_CHECK_OK(status);
}
InboundDataMetric()->AddSample(total_size);

return literals;
Expand Down

0 comments on commit b9475d9

Please sign in to comment.