From 59de52948b4a31bbdd5533d27d9c6e39bbae44b2 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 9 Nov 2023 18:46:11 +0000 Subject: [PATCH] remove multiwait --- test/cpp/BUILD | 2 +- torch_xla/csrc/BUILD | 2 +- torch_xla/csrc/init_python_bindings.cpp | 2 +- torch_xla/csrc/runtime/BUILD | 11 +-- torch_xla/csrc/runtime/multi_wait.cc | 73 ------------------- torch_xla/csrc/runtime/multi_wait.h | 60 --------------- .../csrc/runtime/pjrt_computation_client.cc | 28 +++---- torch_xla/csrc/runtime/thread_pool.cc | 3 +- torch_xla/csrc/tensor_util.cpp | 10 +-- torch_xla/csrc/xla_graph_executor.h | 2 +- torch_xla/csrc/xla_sharding_util.cpp | 10 +-- 11 files changed, 29 insertions(+), 174 deletions(-) delete mode 100644 torch_xla/csrc/runtime/multi_wait.cc delete mode 100644 torch_xla/csrc/runtime/multi_wait.h diff --git a/test/cpp/BUILD b/test/cpp/BUILD index fd53eefc377..2e796f516ee 100644 --- a/test/cpp/BUILD +++ b/test/cpp/BUILD @@ -78,9 +78,9 @@ ptxla_cc_test( ":torch_xla_test", "//torch_xla/csrc/runtime:runtime", "//torch_xla/csrc/runtime:debug_macros", - "//torch_xla/csrc/runtime:multi_wait", "//torch_xla/csrc/runtime:thread_pool", "//torch_xla/csrc:tensor", + "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", "@xla//xla:shape_util", "@xla//xla/client:xla_builder", diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 128107e9e7a..635da87fc9b 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -269,7 +269,6 @@ ptxla_cc_library( "//torch_xla/csrc/runtime:metrics", "//torch_xla/csrc/runtime:metrics_analysis", "//torch_xla/csrc/runtime:metrics_reader", - "//torch_xla/csrc/runtime:multi_wait", "//torch_xla/csrc/runtime:profiler", "//torch_xla/csrc/runtime:sys_util", "//torch_xla/csrc/runtime:thread_pool", @@ -278,6 +277,7 @@ ptxla_cc_library( "//torch_xla/csrc/runtime:xla_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:variant", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/profiler/lib:traceme_encode", diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 1391b73a16c..0ba4fd1b297 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -20,6 +20,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/types/variant.h" #include "pybind11/attr.h" #include "pybind11/cast.h" @@ -43,7 +44,6 @@ #include "torch_xla/csrc/runtime/metrics.h" #include "torch_xla/csrc/runtime/metrics_analysis.h" #include "torch_xla/csrc/runtime/metrics_reader.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/sys_util.h" diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index b7aebccd221..395dd0433c5 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -82,7 +82,6 @@ cc_library( ":computation_client", ":debug_macros", ":env_vars", - ":multi_wait", ":profiler", ":stablehlo_helper", ":tensor_source", @@ -102,6 +101,7 @@ cc_library( "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/platform/cloud:gcs_file_system", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", ], ) @@ -187,15 +187,6 @@ cc_library( ], ) -cc_library( - name = "multi_wait", - srcs = ["multi_wait.cc"], - hdrs = ["multi_wait.h"], - deps = [ - "@xla//xla:types", - ], -) - # Profiler silently fails unless we link these backends cc_library( name = "profiler_backends", diff --git a/torch_xla/csrc/runtime/multi_wait.cc b/torch_xla/csrc/runtime/multi_wait.cc deleted file mode 100644 index c4d0def062b..00000000000 --- a/torch_xla/csrc/runtime/multi_wait.cc +++ /dev/null @@ -1,73 +0,0 @@ -#include "torch_xla/csrc/runtime/multi_wait.h" - -#include -#include - -namespace torch_xla { -namespace runtime { -namespace util { - -void MultiWait::Done() { - bool notify = false; - { - std::lock_guard lock(mutex_); - completed_count_ += 1; - notify = completed_count_ == count_; - } - if (notify) { - cv_.notify_all(); - } -} - -void MultiWait::Wait() { - std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { return completed_count_ >= count_; }); - if (exptr_ != nullptr) { - std::rethrow_exception(exptr_); - } -} - -void MultiWait::Wait(double wait_seconds) { - std::unique_lock lock(mutex_); - if (!cv_.wait_for(lock, std::chrono::duration(wait_seconds), - [this] { return completed_count_ >= count_; })) { - throw std::runtime_error("Timeout"); - } - if (exptr_ != nullptr) { - std::rethrow_exception(exptr_); - } -} - -void MultiWait::Reset(size_t count) { - std::lock_guard lock(mutex_); - count_ = count; - completed_count_ = 0; - exptr_ = nullptr; -} - -std::function MultiWait::Completer(std::function func) { - auto completer = [this, func = std::move(func)]() { Complete(func); }; - return completer; -} - -std::function MultiWait::Completer(std::shared_ptr mwait, - std::function func) { - auto completer = [mwait = std::move(mwait), func = std::move(func)]() { - mwait->Complete(func); - }; - return completer; -} - -void MultiWait::Complete(const std::function& func) { - try { - func(); - } catch (...) { - std::lock_guard lock(mutex_); - exptr_ = std::current_exception(); - } - Done(); -} - -} // namespace util -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/multi_wait.h b/torch_xla/csrc/runtime/multi_wait.h deleted file mode 100644 index 9637850d555..00000000000 --- a/torch_xla/csrc/runtime/multi_wait.h +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef XLA_CLIENT_MULTI_WAIT_H_ -#define XLA_CLIENT_MULTI_WAIT_H_ - -#include -#include -#include -#include - -#include "xla/types.h" - -namespace torch_xla { -namespace runtime { -namespace util { - -// Support waiting for a number of tasks to complete. -class MultiWait { - public: - explicit MultiWait(size_t count) : count_(count) {} - - // Signal the completion of a single task. - void Done(); - - // Waits until at least count (passed as constructor value) completions - // happened. - void Wait(); - - // Same as above, but waits up to wait_seconds. - void Wait(double wait_seconds); - - // Resets the threshold counter for the MultiWait object. The completed count - // is also reset to zero. - void Reset(size_t count); - - // Creates a completer functor which signals the mult wait object once func - // has completed. Handles exceptions by signaling the multi wait with the - // proper status value. This API returns a function which captures a MultiWait - // reference, so care must be taken such that the reference remains valid for - // the whole lifetime of the returned function. - std::function Completer(std::function func); - - // Similar as the above API, but with explicit capture of the MultiWait shared - // pointer. - static std::function Completer(std::shared_ptr mwait, - std::function func); - - private: - void Complete(const std::function& func); - - std::mutex mutex_; - std::condition_variable cv_; - size_t count_ = 0; - size_t completed_count_ = 0; - std::exception_ptr exptr_; -}; - -} // namespace util -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_MULTI_WAIT_H_ diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 29b56af5c50..6712394ccfb 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -5,12 +5,12 @@ #include #include "absl/strings/ascii.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" #include "pjrt_computation_client.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_vars.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/tensor_source.h" @@ -619,9 +619,9 @@ PjRtComputationClient::ExecuteComputation( } CreateDataHandlesCounter()->AddValue(datas.size()); - auto mwait = std::make_shared(1); - auto lockfn = [&, this, device, returned_future = std::move(*returned_future), - timed]() mutable { + Schedule(std::move([&, this, device, + returned_future = std::move(*returned_future), + timed]() mutable { TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device; // Grab the shared lock and block the `WaitDeviceOps` until buffer is @@ -642,9 +642,7 @@ PjRtComputationClient::ExecuteComputation( timed.reset(); TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished"; }); - }; - - Schedule(util::MultiWait::Completer(mwait, std::move(lockfn))); + })); TF_VLOG(1) << "Returning " << datas.size() << " results"; return datas; @@ -668,7 +666,7 @@ PjRtComputationClient::ExecuteReplicated( XLA_CHECK(devices.size() == arguments.size()) << "ExecuteReplicated over " << devices.size() << " devices, but " << arguments.size() << " arguments devices."; - auto mwait_argument = std::make_shared(devices.size()); + absl::BlockingCounter mwait(devices.size()); std::vector> argument_handles(devices.size()); { tsl::profiler::TraceMe activity( @@ -689,11 +687,11 @@ PjRtComputationClient::ExecuteReplicated( buffers.push_back(pjrt_data->buffer.get()); } argument_handles[i] = std::move(buffers); + mwait.DecrementCount(); }; - Schedule(util::MultiWait::Completer( - mwait_argument, std::move(buffer_converter))); + Schedule(std::move(buffer_converter)); } - mwait_argument->Wait(); + mwait.Wait(); } xla::ExecuteOptions execute_options; @@ -748,9 +746,8 @@ PjRtComputationClient::ExecuteReplicated( } } - auto mwait = std::make_shared(1); - auto lockfn = [&, this, returned_futures = std::move(*returned_futures), - timed]() mutable { + Schedule(std::move([&, this, returned_futures = std::move(*returned_futures), + timed]() mutable { // Grab the shared lock and block the `WaitDeviceOps` until buffer is // ready. Since this is the SPMD code path. There is no points to grab // devices lock for every individual device. @@ -771,8 +768,7 @@ PjRtComputationClient::ExecuteReplicated( timed.reset(); TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished"; }); - }; - Schedule(util::MultiWait::Completer(mwait, std::move(lockfn))); + })); TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results " << "with dimensions [" << absl::StrJoin(dims, ",") << "]."; diff --git a/torch_xla/csrc/runtime/thread_pool.cc b/torch_xla/csrc/runtime/thread_pool.cc index 6cfc37a6614..a51916dd2a9 100644 --- a/torch_xla/csrc/runtime/thread_pool.cc +++ b/torch_xla/csrc/runtime/thread_pool.cc @@ -16,7 +16,8 @@ namespace runtime { void Schedule(std::function fn) { static size_t num_threads = sys_util::GetEnvInt( "XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency()); - static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla", num_threads); + static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla", + num_threads); pool.Schedule(std::move(fn)); } diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index f435fed48df..e598e4c0b3e 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -12,13 +12,13 @@ #include #include +#include "absl/synchronization/blocking_counter.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/layout_manager.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" @@ -366,16 +366,16 @@ void CopyTensors(const void* src_buffer, const xla::Shape& src_shape, std::vector iter_dims = GetIterationDimensions(dest_shape); std::vector parts = CreateCopyPartitions(dest_shape.dimensions(), iter_dims.front()); - auto mwait = std::make_shared(parts.size()); + absl::BlockingCounter mwait(parts.size()); for (size_t i = 0; i < parts.size(); ++i) { auto copy_fn = [&, i]() { SlicedCopy(dest_shape.dimensions(), src_data, src_strides, dest_data, dest_strides, iter_dims, parts[i]); + mwait.DecrementCount(); }; - runtime::Schedule( - runtime::util::MultiWait::Completer(mwait, std::move(copy_fn))); + runtime::Schedule(std::move(copy_fn)); } - mwait->Wait(); + mwait.Wait(); } } diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index c7a870be319..90eec4012d6 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -10,6 +10,7 @@ #include #include +#include "absl/synchronization/blocking_counter.h" #include "torch_xla/csrc/cross_replica_reduces.h" #include "torch_xla/csrc/debug_util.h" #include "torch_xla/csrc/device.h" @@ -18,7 +19,6 @@ #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/runtime/cache.h" #include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/tensor.h" #include "torch_xla/csrc/torch_util.h" diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index e40fa0ca879..6d1ec6853a8 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -5,6 +5,7 @@ #include #include +#include "absl/synchronization/blocking_counter.h" #include "torch/csrc/lazy/core/ir_util.h" #include "torch_xla/csrc/aten_autograd_ops.h" #include "torch_xla/csrc/aten_xla_bridge.h" @@ -13,7 +14,6 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/tensor.h" @@ -326,7 +326,7 @@ ShardingUtil::InputHandler( // the first local index with the first global device ordinal. auto device_index = build_index_map(devices); - auto mwait = std::make_shared(devices.size()); + absl::BlockingCounter mwait(devices.size()); for (int i = 0; i < devices.size(); i++) { auto argument_setter = [&, i]() { @@ -339,11 +339,11 @@ ShardingUtil::InputHandler( int device_i = device_index[global_ordinal]; arguments_by_device[device_i][argument_i] = shard; } + mwait.DecrementCount(); }; - runtime::Schedule( - runtime::util::MultiWait::Completer(mwait, std::move(argument_setter))); + runtime::Schedule(std::move(argument_setter)); } - mwait->Wait(); + mwait.Wait(); return arguments_by_device; }