fix concurrency issues
will-cromar committed Nov 29, 2023
1 parent f411a37 · commit cb64bee
Showing 1 changed file with 42 additions and 21 deletions.
torch_xla/csrc/runtime/ifrt_computation_client.cc: 63 changes (42 additions & 21 deletions)
@@ -5,6 +5,7 @@
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/types/span.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
@@ -578,11 +579,9 @@ IfrtComputationClient::ExecuteReplicated(
// TODO: devices isn't doing anything helpful here
absl::Span<const std::string> devices,
const ExecuteReplicatedOptions& options) {
// XLA_ERROR() << __FUNCTION__ << " not implemented";
// Shared ownership of the timed section ensures that it will only get logged
// once both `ExecuteReplicated` and the async work in `Execute` are
// complete; a copy is held from the lambda that releases it when done.
// TODO: fix timing
auto timed =
std::make_shared<metrics::TimedSection>(ExecuteReplicatedMetric());
tsl::profiler::TraceMe activity("IfrtComputationClient::ExecuteReplicated",
@@ -591,40 +590,62 @@ IfrtComputationClient::ExecuteReplicated(
dynamic_cast<const IfrtComputation&>(computation);

std::vector<tsl::RCReference<xla::ifrt::Array>> argument_handles(arguments.size());
pool_.ParallelFor(arguments.size(), 10000,
[&](int64_t start, int64_t end) {
for (int32_t i = start; i < end; ++i) {
auto ifrt_data = std::dynamic_pointer_cast<IfrtData>(arguments[i]);
argument_handles[i] = ifrt_data->buffer;
}
});
{
absl::BlockingCounter counter(arguments.size());
pool_.ParallelFor(arguments.size(), 10000,
[&](int64_t start, int64_t end) {
for (int32_t i = start; i < end; ++i) {
auto ifrt_data = std::dynamic_pointer_cast<IfrtData>(arguments[i]);
argument_handles[i] = ifrt_data->buffer;
counter.DecrementCount();
}
});
counter.Wait();
}

xla::ExecuteOptions execute_options;
execute_options.untuple_result = options.explode_tuple;
execute_options.strict_shape_checking = true;
// TODO(yeounoh) currently only support single-slice execution
execute_options.multi_slice_config = nullptr;

TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for "
<< spmd_device_str;
auto op_tracker = operation_manager_.StartOperation(spmd_device_str);
TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for "
<< spmd_device_str << " Done";

xla::ifrt::LoadedExecutable::ExecuteResult result =
ifrt_computation.executable
->Execute(absl::MakeSpan(argument_handles), execute_options, std::nullopt)
.value();

xla::ifrt::Future<xla::Status> returned_future = result.status;
auto results = result.outputs;
result.status.OnReady(
std::move([timed, op_tracker = std::move(op_tracker)](
xla::Status status) mutable {
timed.reset();
TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished with status " << status;
}));

auto outputs = result.outputs;

XLA_CHECK(ifrt_computation.output_shardings_.has_value());
auto& output_shardings = *(ifrt_computation.output_shardings_);
XLA_CHECK_EQ(output_shardings.size(), results.size());

std::vector<ComputationClient::DataPtr> data_handles(results.size());
pool_.ParallelFor(results.size(), 10000,
[&](int64_t start, int64_t end) {
for (int32_t i = start; i < end; ++i) {
data_handles[i] =
std::make_shared<IfrtData>(spmd_device_str, results[i], output_shardings[i]);
}
});
XLA_CHECK_EQ(output_shardings.size(), outputs.size());

std::vector<ComputationClient::DataPtr> data_handles(outputs.size());
{
absl::BlockingCounter counter(outputs.size());
pool_.ParallelFor(outputs.size(), 10000,
[&](int64_t start, int64_t end) {
for (int32_t i = start; i < end; ++i) {
data_handles[i] =
std::make_shared<IfrtData>(spmd_device_str, outputs[i], output_shardings[i]);
counter.DecrementCount();
}
});
counter.Wait();
}

TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs.";
return data_handles;
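
For reference, the synchronization pattern this commit introduces, as a self-contained sketch rather than the project's code: work is fanned out over worker threads in chunks, every finished element decrements an absl::BlockingCounter, and the caller blocks on the counter before consuming the results. The ScheduleChunks helper below is hypothetical and only stands in for the tsl thread pool's ParallelFor used in ifrt_computation_client.cc.

#include <algorithm>
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

#include "absl/synchronization/blocking_counter.h"

// Hypothetical dispatcher: split [0, n) into contiguous chunks and hand each
// chunk to its own thread. It returns immediately and does NOT wait for the
// work to finish.
std::vector<std::thread> ScheduleChunks(
    int64_t n, int64_t chunk, std::function<void(int64_t, int64_t)> fn) {
  std::vector<std::thread> workers;
  for (int64_t start = 0; start < n; start += chunk) {
    workers.emplace_back(fn, start, std::min(n, start + chunk));
  }
  return workers;
}

int main() {
  std::vector<int> inputs(1000, 1);
  std::vector<int> outputs(inputs.size());

  // One count per element; every element that is produced decrements the
  // counter exactly once.
  absl::BlockingCounter counter(static_cast<int>(inputs.size()));
  auto workers = ScheduleChunks(
      inputs.size(), /*chunk=*/128, [&](int64_t start, int64_t end) {
        for (int64_t i = start; i < end; ++i) {
          outputs[i] = inputs[i] * 2;
          counter.DecrementCount();
        }
      });

  // Block until every element has been written. Without this wait the caller
  // could hand `outputs` to the next stage while worker threads are still
  // filling it in -- the kind of race the added counter.Wait() guards against.
  counter.Wait();

  for (auto& t : workers) t.join();
  return 0;
}

Because the chunk ranges partition [0, n) and each index decrements the counter exactly once, the counter reaches zero exactly when the last element has been written, which is what makes it safe to read argument_handles and data_handles immediately after the Wait() calls in the new blocks above.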
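
The comment about the shared TimedSection describes a second pattern worth spelling out: the timer is owned by a std::shared_ptr that is copied into the asynchronous completion callback, so the elapsed time is recorded only once both the synchronous call has returned and the async work has finished. Below is a minimal standalone sketch under that assumption; Timer and RunAsync are hypothetical stand-ins for metrics::TimedSection and the IFRT future's OnReady callback.

#include <chrono>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <thread>

// Logs the elapsed time when the last shared owner releases it.
struct Timer {
  std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
  ~Timer() {
    auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                  std::chrono::steady_clock::now() - start)
                  .count();
    std::cout << "elapsed: " << ms << " ms\n";
  }
};

// Hypothetical async launcher: runs `work` on another thread, then invokes
// `on_ready`, roughly playing the role of the future's OnReady callback.
std::future<void> RunAsync(std::function<void()> work,
                           std::function<void()> on_ready) {
  return std::async(std::launch::async,
                    [work = std::move(work), on_ready = std::move(on_ready)] {
                      work();
                      on_ready();
                    });
}

std::future<void> Execute() {
  auto timed = std::make_shared<Timer>();
  auto fut = RunAsync(
      [] { std::this_thread::sleep_for(std::chrono::milliseconds(50)); },
      // The callback owns a copy of `timed`; resetting it releases that copy
      // once the async work is done.
      [timed]() mutable { timed.reset(); });
  // Execute's own copy is released when this function returns, so the Timer
  // is destroyed -- and the time logged -- only after BOTH the return and the
  // callback have happened.
  return fut;
}

int main() {
  Execute().wait();
  return 0;
}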
