Commit 59de529: remove multiwait

will-cromar committed Nov 14, 2023
1 parent bf4ff6f
Showing 11 changed files with 29 additions and 174 deletions.
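For orientation before the per-file diffs: every call site below swaps the runtime's hand-rolled util::MultiWait for absl::BlockingCounter. The sketch below is not from this repository; it only illustrates the fan-out/fan-in shape the new code uses, with a detached std::thread standing in for the runtime::Schedule() thread-pool helper.

#include <functional>
#include <thread>

#include "absl/synchronization/blocking_counter.h"

// Stand-in for torch_xla's runtime::Schedule(): any executor that eventually
// runs the closure on another thread behaves the same way here.
void Schedule(std::function<void()> fn) {
  std::thread(std::move(fn)).detach();
}

void ParallelWork(int n) {
  // Fan out n tasks, then block until each one has called DecrementCount().
  absl::BlockingCounter counter(n);
  for (int i = 0; i < n; ++i) {
    Schedule([&counter, i]() {
      // ... per-task work for index i ...
      counter.DecrementCount();
    });
  }
  counter.Wait();  // replaces MultiWait::Completer(...) plus mwait->Wait()
}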
2 changes: 1 addition & 1 deletion test/cpp/BUILD
@@ -78,9 +78,9 @@ ptxla_cc_test(
         ":torch_xla_test",
         "//torch_xla/csrc/runtime:runtime",
         "//torch_xla/csrc/runtime:debug_macros",
-        "//torch_xla/csrc/runtime:multi_wait",
         "//torch_xla/csrc/runtime:thread_pool",
         "//torch_xla/csrc:tensor",
+        "@com_google_absl//absl/synchronization",
         "@com_google_googletest//:gtest_main",
         "@xla//xla:shape_util",
         "@xla//xla/client:xla_builder",
2 changes: 1 addition & 1 deletion torch_xla/csrc/BUILD
@@ -269,7 +269,6 @@ ptxla_cc_library(
         "//torch_xla/csrc/runtime:metrics",
         "//torch_xla/csrc/runtime:metrics_analysis",
         "//torch_xla/csrc/runtime:metrics_reader",
-        "//torch_xla/csrc/runtime:multi_wait",
         "//torch_xla/csrc/runtime:profiler",
         "//torch_xla/csrc/runtime:sys_util",
         "//torch_xla/csrc/runtime:thread_pool",
@@ -278,6 +277,7 @@ ptxla_cc_library(
         "//torch_xla/csrc/runtime:xla_util",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:variant",
         "@tsl//tsl/profiler/lib:traceme",
         "@tsl//tsl/profiler/lib:traceme_encode",
2 changes: 1 addition & 1 deletion torch_xla/csrc/init_python_bindings.cpp
@@ -20,6 +20,7 @@
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_cat.h"
+#include "absl/synchronization/blocking_counter.h"
 #include "absl/types/variant.h"
 #include "pybind11/attr.h"
 #include "pybind11/cast.h"
@@ -43,7 +44,6 @@
 #include "torch_xla/csrc/runtime/metrics.h"
 #include "torch_xla/csrc/runtime/metrics_analysis.h"
 #include "torch_xla/csrc/runtime/metrics_reader.h"
-#include "torch_xla/csrc/runtime/multi_wait.h"
 #include "torch_xla/csrc/runtime/profiler.h"
 #include "torch_xla/csrc/runtime/runtime.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
11 changes: 1 addition & 10 deletions torch_xla/csrc/runtime/BUILD
@@ -82,7 +82,6 @@ cc_library(
         ":computation_client",
         ":debug_macros",
         ":env_vars",
-        ":multi_wait",
         ":profiler",
         ":stablehlo_helper",
         ":tensor_source",
@@ -102,6 +101,7 @@ cc_library(
         "@tsl//tsl/profiler/lib:traceme",
         "@tsl//tsl/platform/cloud:gcs_file_system",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
     ],
 )
@@ -187,15 +187,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "multi_wait",
-    srcs = ["multi_wait.cc"],
-    hdrs = ["multi_wait.h"],
-    deps = [
-        "@xla//xla:types",
-    ],
-)
-
 # Profiler silently fails unless we link these backends
 cc_library(
     name = "profiler_backends",
73 changes: 0 additions & 73 deletions torch_xla/csrc/runtime/multi_wait.cc

This file was deleted.

60 changes: 0 additions & 60 deletions torch_xla/csrc/runtime/multi_wait.h

This file was deleted.
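Judging only from the call sites changed in this commit, util::MultiWait::Completer wrapped a closure so that a shared counter was decremented once the closure finished, and Wait() blocked until the count reached zero. A hypothetical adapter over absl::BlockingCounter with the same calling convention might look like the sketch below; the new code simply inlines the DecrementCount() call instead of keeping such a helper. The deleted implementation may well have done more (for example around exception handling); this is only the shape visible from the diff.

#include <functional>
#include <memory>
#include <utility>

#include "absl/synchronization/blocking_counter.h"

// Hypothetical equivalent of the deleted Completer helper: wrap `fn` so the
// shared counter is decremented after it runs.
std::function<void()> Completer(std::shared_ptr<absl::BlockingCounter> counter,
                                std::function<void()> fn) {
  return [counter, fn = std::move(fn)]() {
    fn();
    counter->DecrementCount();
  };
}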

28 changes: 12 additions & 16 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -5,12 +5,12 @@
 #include <vector>
 
 #include "absl/strings/ascii.h"
+#include "absl/synchronization/blocking_counter.h"
 #include "absl/types/span.h"
 #include "pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/env_vars.h"
-#include "torch_xla/csrc/runtime/multi_wait.h"
 #include "torch_xla/csrc/runtime/profiler.h"
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
 #include "torch_xla/csrc/runtime/tensor_source.h"
@@ -619,9 +619,9 @@ PjRtComputationClient::ExecuteComputation(
   }
   CreateDataHandlesCounter()->AddValue(datas.size());
 
-  auto mwait = std::make_shared<util::MultiWait>(1);
-  auto lockfn = [&, this, device, returned_future = std::move(*returned_future),
-                 timed]() mutable {
+  Schedule(std::move([&, this, device,
+                      returned_future = std::move(*returned_future),
+                      timed]() mutable {
     TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for "
                << device;
     // Grab the shared lock and block the `WaitDeviceOps` until buffer is
@@ -642,9 +642,7 @@
     timed.reset();
     TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished";
   });
-  };
-
-  Schedule(util::MultiWait::Completer(mwait, std::move(lockfn)));
+  }));
 
   TF_VLOG(1) << "Returning " << datas.size() << " results";
   return datas;
@@ -668,7 +666,7 @@ PjRtComputationClient::ExecuteReplicated(
   XLA_CHECK(devices.size() == arguments.size())
       << "ExecuteReplicated over " << devices.size() << " devices, but "
       << arguments.size() << " arguments devices.";
-  auto mwait_argument = std::make_shared<util::MultiWait>(devices.size());
+  absl::BlockingCounter mwait(devices.size());
   std::vector<std::vector<xla::PjRtBuffer*>> argument_handles(devices.size());
   {
     tsl::profiler::TraceMe activity(
@@ -689,11 +687,11 @@
           buffers.push_back(pjrt_data->buffer.get());
         }
         argument_handles[i] = std::move(buffers);
+        mwait.DecrementCount();
       };
-      Schedule(util::MultiWait::Completer(
-          mwait_argument, std::move(buffer_converter)));
+      Schedule(std::move(buffer_converter));
     }
-    mwait_argument->Wait();
+    mwait.Wait();
   }
 
   xla::ExecuteOptions execute_options;
@@ -748,9 +746,8 @@
     }
   }
 
-  auto mwait = std::make_shared<util::MultiWait>(1);
-  auto lockfn = [&, this, returned_futures = std::move(*returned_futures),
-                 timed]() mutable {
+  Schedule(std::move([&, this, returned_futures = std::move(*returned_futures),
+                      timed]() mutable {
     // Grab the shared lock and block the `WaitDeviceOps` until buffer is
     // ready. Since this is the SPMD code path. There is no points to grab
     // devices lock for every individual device.
@@ -771,8 +768,7 @@
     timed.reset();
     TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished";
   });
-  };
-  Schedule(util::MultiWait::Completer(mwait, std::move(lockfn)));
+  }));
 
   TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results "
              << "with dimensions [" << absl::StrJoin(dims, ",") << "].";
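One detail worth noting in ExecuteComputation and ExecuteReplicated above: the old code built a util::MultiWait of size one around the completion closure but never waited on it in the code shown, so the new version just hands the closure to Schedule() directly, moving its state in through init-captures on a mutable lambda. Below is a minimal sketch of that shape, with a hypothetical templated executor in place of runtime::Schedule() (the real one takes std::function<void()>, which additionally requires the callable to be copyable).

#include <memory>
#include <string>
#include <thread>
#include <utility>

// Hypothetical executor; templated so it also accepts move-only callables.
template <typename Fn>
void Schedule(Fn&& fn) {
  std::thread(std::forward<Fn>(fn)).detach();
}

void ScheduleCompletion(std::string device, std::unique_ptr<int> result) {
  // The closure owns its state via init-captures, is mutable so that state
  // can be consumed, and nothing waits on a counter afterwards.
  Schedule([device = std::move(device), result = std::move(result)]() mutable {
    // ... release the device lock, stop timers, log completion for `device` ...
  });
}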
3 changes: 2 additions & 1 deletion torch_xla/csrc/runtime/thread_pool.cc
@@ -16,7 +16,8 @@ namespace runtime {
 void Schedule(std::function<void()> fn) {
   static size_t num_threads = sys_util::GetEnvInt(
       "XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency());
-  static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla", num_threads);
+  static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla",
+                                      num_threads);
   pool.Schedule(std::move(fn));
 }
 
10 changes: 5 additions & 5 deletions torch_xla/csrc/tensor_util.cpp
@@ -12,13 +12,13 @@
 #include <numeric>
 #include <thread>
 
+#include "absl/synchronization/blocking_counter.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/dtype.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/layout_manager.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
-#include "torch_xla/csrc/runtime/multi_wait.h"
 #include "torch_xla/csrc/runtime/runtime.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
@@ -366,16 +366,16 @@ void CopyTensors(const void* src_buffer, const xla::Shape& src_shape,
     std::vector<int64_t> iter_dims = GetIterationDimensions(dest_shape);
     std::vector<CopyPartition> parts =
         CreateCopyPartitions(dest_shape.dimensions(), iter_dims.front());
-    auto mwait = std::make_shared<runtime::util::MultiWait>(parts.size());
+    absl::BlockingCounter mwait(parts.size());
     for (size_t i = 0; i < parts.size(); ++i) {
       auto copy_fn = [&, i]() {
         SlicedCopy<SType, DType>(dest_shape.dimensions(), src_data, src_strides,
                                  dest_data, dest_strides, iter_dims, parts[i]);
+        mwait.DecrementCount();
       };
-      runtime::Schedule(
-          runtime::util::MultiWait::Completer(mwait, std::move(copy_fn)));
+      runtime::Schedule(std::move(copy_fn));
     }
-    mwait->Wait();
+    mwait.Wait();
   }
 }
 
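The CopyTensors change just above is the same counter pattern with one lifetime rule worth spelling out: mwait now lives on the stack and is captured by reference in every scheduled closure, so the enclosing scope must not return until Wait() has done so. A self-contained sketch of that rule, with plain std::thread standing in for runtime::Schedule():

#include <thread>
#include <vector>

#include "absl/synchronization/blocking_counter.h"

void DoubleAll(std::vector<int>& values) {
  absl::BlockingCounter mwait(static_cast<int>(values.size()));
  std::vector<std::thread> workers;  // stand-in for the runtime thread pool
  for (size_t i = 0; i < values.size(); ++i) {
    workers.emplace_back([&, i]() {
      values[i] *= 2;            // stand-in for SlicedCopy on partition i
      mwait.DecrementCount();    // last thing every worker does
    });
  }
  mwait.Wait();  // must return before `mwait` and `values` leave scope
  for (auto& t : workers) t.join();
}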
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_graph_executor.h
@@ -10,6 +10,7 @@
 #include <string>
 #include <unordered_map>
 
+#include "absl/synchronization/blocking_counter.h"
 #include "torch_xla/csrc/cross_replica_reduces.h"
 #include "torch_xla/csrc/debug_util.h"
 #include "torch_xla/csrc/device.h"
@@ -18,7 +19,6 @@
 #include "torch_xla/csrc/lowering_context.h"
 #include "torch_xla/csrc/runtime/cache.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
-#include "torch_xla/csrc/runtime/multi_wait.h"
 #include "torch_xla/csrc/runtime/util.h"
 #include "torch_xla/csrc/tensor.h"
 #include "torch_xla/csrc/torch_util.h"
10 changes: 5 additions & 5 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -5,6 +5,7 @@
 #include <cmath>
 #include <unordered_map>
 
+#include "absl/synchronization/blocking_counter.h"
 #include "torch/csrc/lazy/core/ir_util.h"
 #include "torch_xla/csrc/aten_autograd_ops.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
@@ -13,7 +14,6 @@
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
-#include "torch_xla/csrc/runtime/multi_wait.h"
 #include "torch_xla/csrc/runtime/runtime.h"
 #include "torch_xla/csrc/runtime/thread_pool.h"
 #include "torch_xla/csrc/tensor.h"
@@ -326,7 +326,7 @@ ShardingUtil::InputHandler(
   // the first local index with the first global device ordinal.
   auto device_index = build_index_map(devices);
 
-  auto mwait = std::make_shared<runtime::util::MultiWait>(devices.size());
+  absl::BlockingCounter mwait(devices.size());
 
   for (int i = 0; i < devices.size(); i++) {
     auto argument_setter = [&, i]() {
@@ -339,11 +339,11 @@
         int device_i = device_index[global_ordinal];
         arguments_by_device[device_i][argument_i] = shard;
       }
+      mwait.DecrementCount();
     };
-    runtime::Schedule(
-        runtime::util::MultiWait::Completer(mwait, std::move(argument_setter)));
+    runtime::Schedule(std::move(argument_setter));
   }
-  mwait->Wait();
+  mwait.Wait();
   return arguments_by_device;
 }
 
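The InputHandler change above uses the same two ingredients as the other call sites, plus one property that makes the unsynchronized writes safe: the output container is fully sized up front and, as the indexing suggests, each scheduled task writes only its own slots, so the completion counter is the only synchronization required. A small, self-contained illustration of that shape (again with std::thread standing in for runtime::Schedule()):

#include <thread>
#include <vector>

#include "absl/synchronization/blocking_counter.h"

std::vector<std::vector<int>> BuildPerDevice(int num_devices, int num_args) {
  // Pre-sized output: each task owns exactly one row.
  std::vector<std::vector<int>> arguments_by_device(
      num_devices, std::vector<int>(num_args));
  absl::BlockingCounter mwait(num_devices);
  std::vector<std::thread> workers;
  for (int i = 0; i < num_devices; ++i) {
    workers.emplace_back([&, i]() {
      for (int arg = 0; arg < num_args; ++arg) {
        arguments_by_device[i][arg] = i * 1000 + arg;  // distinct slot per task
      }
      mwait.DecrementCount();
    });
  }
  mwait.Wait();
  for (auto& t : workers) t.join();
  return arguments_by_device;
}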