diff --git a/test/cpp/BUILD b/test/cpp/BUILD index fd53eefc377e..c8aeb729d784 100644 --- a/test/cpp/BUILD +++ b/test/cpp/BUILD @@ -78,9 +78,9 @@ ptxla_cc_test( ":torch_xla_test", "//torch_xla/csrc/runtime:runtime", "//torch_xla/csrc/runtime:debug_macros", - "//torch_xla/csrc/runtime:multi_wait", - "//torch_xla/csrc/runtime:thread_pool", "//torch_xla/csrc:tensor", + "//torch_xla/csrc:thread_pool", + "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", "@xla//xla:shape_util", "@xla//xla/client:xla_builder", diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index 6d7a54add0ce..39fbb7201b05 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -3,15 +3,15 @@ #include +#include "absl/synchronization/blocking_counter.h" #include "test/cpp/cpp_test_util.h" #include "test/cpp/torch_xla_test.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/runtime.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/torch_util.h" #include "xla/client/xla_builder.h" #include "xla/shape_util.h" @@ -57,7 +57,7 @@ void TestSingleReplication( std::vector> results(device_strings.size()); - torch_xla::runtime::util::MultiWait mwait(device_strings.size()); + absl::BlockingCounter counter(device_strings.size()); torch_xla::runtime::ComputationClient::ExecuteComputationOptions exec_options; for (size_t i = 0; i < device_strings.size(); ++i) { auto executor = [&, i]() { @@ -68,11 +68,11 @@ void TestSingleReplication( torch_xla::runtime::ComputationClient::Data>( tensors_data[i])}, device_strings[i], exec_options); + counter.DecrementCount(); }; - torch_xla::runtime::env::ScheduleIoClosure( - mwait.Completer(std::move(executor))); + torch_xla::thread::Schedule(std::move(executor)); } - mwait.Wait(); + counter.Wait(); for (size_t i = 0; i < results.size(); ++i) { auto literals = diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 128107e9e7ad..b18014ab2dfb 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -269,15 +269,14 @@ ptxla_cc_library( "//torch_xla/csrc/runtime:metrics", "//torch_xla/csrc/runtime:metrics_analysis", "//torch_xla/csrc/runtime:metrics_reader", - "//torch_xla/csrc/runtime:multi_wait", "//torch_xla/csrc/runtime:profiler", "//torch_xla/csrc/runtime:sys_util", - "//torch_xla/csrc/runtime:thread_pool", "//torch_xla/csrc/runtime:util", "//torch_xla/csrc/runtime:xla_coordinator", "//torch_xla/csrc/runtime:xla_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:variant", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/profiler/lib:traceme_encode", @@ -320,6 +319,16 @@ cc_library( ], ) +cc_library( + name = "thread_pool", + srcs = ["thread_pool.cc"], + hdrs = ["thread_pool.h"], + deps = [ + "//torch_xla/csrc/runtime:sys_util", + "@tsl//tsl/platform:env" + ], +) + ptxla_cc_library( name = "unwrap_data", srcs = ["unwrap_data.cpp"], diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 8c45d68f8029..4758579bbb6a 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -20,6 +20,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/types/variant.h" #include "pybind11/attr.h" #include "pybind11/cast.h" @@ -43,11 +44,9 @@ #include "torch_xla/csrc/runtime/metrics.h" #include "torch_xla/csrc/runtime/metrics_analysis.h" #include "torch_xla/csrc/runtime/metrics_reader.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/runtime/xla_coordinator.h" #include "torch_xla/csrc/runtime/xla_util.h" diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index cbeea6abeb7c..d705ea0bdc5a 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -82,13 +82,12 @@ cc_library( ":computation_client", ":debug_macros", ":env_vars", - ":multi_wait", ":profiler", ":stablehlo_helper", ":tensor_source", ":tf_logging", - ":thread_pool", ":xla_coordinator", + "//torch_xla/csrc:thread_pool", "@xla//xla:literal", "@xla//xla:shape_util", "@xla//xla/client:xla_computation", @@ -102,6 +101,7 @@ cc_library( "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/platform/cloud:gcs_file_system", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", ], ) @@ -187,15 +187,6 @@ cc_library( ], ) -cc_library( - name = "multi_wait", - srcs = ["multi_wait.cc"], - hdrs = ["multi_wait.h"], - deps = [ - "@xla//xla:types", - ], -) - # Profiler silently fails unless we link these backends cc_library( name = "profiler_backends", @@ -279,16 +270,6 @@ cc_library( ], ) -cc_library( - name = "thread_pool", - srcs = ["thread_pool.cc"], - hdrs = ["thread_pool.h"], - deps = [ - ":metrics", - ":tf_logging", - ], -) - cc_library( name = "tensor_source", hdrs = ["tensor_source.h"], diff --git a/torch_xla/csrc/runtime/multi_wait.cc b/torch_xla/csrc/runtime/multi_wait.cc deleted file mode 100644 index c4d0def062b0..000000000000 --- a/torch_xla/csrc/runtime/multi_wait.cc +++ /dev/null @@ -1,73 +0,0 @@ -#include "torch_xla/csrc/runtime/multi_wait.h" - -#include -#include - -namespace torch_xla { -namespace runtime { -namespace util { - -void MultiWait::Done() { - bool notify = false; - { - std::lock_guard lock(mutex_); - completed_count_ += 1; - notify = completed_count_ == count_; - } - if (notify) { - cv_.notify_all(); - } -} - -void MultiWait::Wait() { - std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { return completed_count_ >= count_; }); - if (exptr_ != nullptr) { - std::rethrow_exception(exptr_); - } -} - -void MultiWait::Wait(double wait_seconds) { - std::unique_lock lock(mutex_); - if (!cv_.wait_for(lock, std::chrono::duration(wait_seconds), - [this] { return completed_count_ >= count_; })) { - throw std::runtime_error("Timeout"); - } - if (exptr_ != nullptr) { - std::rethrow_exception(exptr_); - } -} - -void MultiWait::Reset(size_t count) { - std::lock_guard lock(mutex_); - count_ = count; - completed_count_ = 0; - exptr_ = nullptr; -} - -std::function MultiWait::Completer(std::function func) { - auto completer = [this, func = std::move(func)]() { Complete(func); }; - return completer; -} - -std::function MultiWait::Completer(std::shared_ptr mwait, - std::function func) { - auto completer = [mwait = std::move(mwait), func = std::move(func)]() { - mwait->Complete(func); - }; - return completer; -} - -void MultiWait::Complete(const std::function& func) { - try { - func(); - } catch (...) { - std::lock_guard lock(mutex_); - exptr_ = std::current_exception(); - } - Done(); -} - -} // namespace util -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/multi_wait.h b/torch_xla/csrc/runtime/multi_wait.h deleted file mode 100644 index 9637850d555a..000000000000 --- a/torch_xla/csrc/runtime/multi_wait.h +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef XLA_CLIENT_MULTI_WAIT_H_ -#define XLA_CLIENT_MULTI_WAIT_H_ - -#include -#include -#include -#include - -#include "xla/types.h" - -namespace torch_xla { -namespace runtime { -namespace util { - -// Support waiting for a number of tasks to complete. -class MultiWait { - public: - explicit MultiWait(size_t count) : count_(count) {} - - // Signal the completion of a single task. - void Done(); - - // Waits until at least count (passed as constructor value) completions - // happened. - void Wait(); - - // Same as above, but waits up to wait_seconds. - void Wait(double wait_seconds); - - // Resets the threshold counter for the MultiWait object. The completed count - // is also reset to zero. - void Reset(size_t count); - - // Creates a completer functor which signals the mult wait object once func - // has completed. Handles exceptions by signaling the multi wait with the - // proper status value. This API returns a function which captures a MultiWait - // reference, so care must be taken such that the reference remains valid for - // the whole lifetime of the returned function. - std::function Completer(std::function func); - - // Similar as the above API, but with explicit capture of the MultiWait shared - // pointer. - static std::function Completer(std::shared_ptr mwait, - std::function func); - - private: - void Complete(const std::function& func); - - std::mutex mutex_; - std::condition_variable cv_; - size_t count_ = 0; - size_t completed_count_ = 0; - std::exception_ptr exptr_; -}; - -} // namespace util -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_MULTI_WAIT_H_ diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index ab998d3ce0da..9ad731eba82a 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -5,18 +5,18 @@ #include #include "absl/strings/ascii.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" #include "pjrt_computation_client.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_vars.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/tensor_source.h" #include "torch_xla/csrc/runtime/tf_logging.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/runtime/xla_coordinator.h" +#include "torch_xla/csrc/thread_pool.h" #include "tsl/profiler/lib/traceme.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" @@ -620,9 +620,9 @@ PjRtComputationClient::ExecuteComputation( } CreateDataHandlesCounter()->AddValue(datas.size()); - auto mwait = std::make_shared(1); - auto lockfn = [&, this, device, returned_future = std::move(*returned_future), - timed]() mutable { + thread::Schedule(std::move([&, this, device, + returned_future = std::move(*returned_future), + timed]() mutable { TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device; // Grab the shared lock and block the `WaitDeviceOps` until buffer is @@ -643,9 +643,7 @@ PjRtComputationClient::ExecuteComputation( timed.reset(); TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished"; }); - }; - - env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn))); + })); TF_VLOG(1) << "Returning " << datas.size() << " results"; return datas; @@ -669,7 +667,7 @@ PjRtComputationClient::ExecuteReplicated( XLA_CHECK(devices.size() == arguments.size()) << "ExecuteReplicated over " << devices.size() << " devices, but " << arguments.size() << " arguments devices."; - auto mwait_argument = std::make_shared(devices.size()); + absl::BlockingCounter counter(devices.size()); std::vector> argument_handles(devices.size()); { tsl::profiler::TraceMe activity( @@ -690,11 +688,11 @@ PjRtComputationClient::ExecuteReplicated( buffers.push_back(pjrt_data->buffer.get()); } argument_handles[i] = std::move(buffers); + counter.DecrementCount(); }; - env::ScheduleIoClosure(util::MultiWait::Completer( - mwait_argument, std::move(buffer_converter))); + thread::Schedule(std::move(buffer_converter)); } - mwait_argument->Wait(); + counter.Wait(); } xla::ExecuteOptions execute_options; @@ -749,9 +747,9 @@ PjRtComputationClient::ExecuteReplicated( } } - auto mwait = std::make_shared(1); - auto lockfn = [&, this, returned_futures = std::move(*returned_futures), - timed]() mutable { + thread::Schedule(std::move([&, this, + returned_futures = std::move(*returned_futures), + timed]() mutable { // Grab the shared lock and block the `WaitDeviceOps` until buffer is // ready. Since this is the SPMD code path. There is no points to grab // devices lock for every individual device. @@ -772,8 +770,7 @@ PjRtComputationClient::ExecuteReplicated( timed.reset(); TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished"; }); - }; - env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn))); + })); TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results " << "with dimensions [" << absl::StrJoin(dims, ",") << "]."; diff --git a/torch_xla/csrc/runtime/thread_pool.cc b/torch_xla/csrc/runtime/thread_pool.cc deleted file mode 100644 index fa0212e3a26a..000000000000 --- a/torch_xla/csrc/runtime/thread_pool.cc +++ /dev/null @@ -1,183 +0,0 @@ -#include "torch_xla/csrc/runtime/thread_pool.h" - -#include -#include -#include -#include - -#include "torch_xla/csrc/runtime/metrics.h" -#include "torch_xla/csrc/runtime/tf_logging.h" - -namespace torch_xla { -namespace runtime { -namespace env { -namespace { - -class ThreadPool { - public: - explicit ThreadPool(size_t num_threads) { - threads_.reserve(num_threads); - for (size_t i = 0; i < num_threads; ++i) { - threads_.emplace_back([this]() { Worker(); }); - } - } - - ~ThreadPool() { - { - std::lock_guard lock(mutex_); - exiting_ = true; - cv_.notify_all(); - } - for (auto& thread : threads_) { - thread.join(); - } - } - - void Schedule(std::function closure) { - // If we have more work scheduled than waiting worker threads, just schedule - // it on a separate thread. This prevents tricky thread-pool-size-deadlocks - // caused by an undersized thread pool and closures that end up doing sync - // waits on the pool threads. - bool scheduled = false; - { - std::lock_guard lock(mutex_); - if (work_.size() < waiting_) { - work_.emplace_back(std::move(closure)); - scheduled = true; - } - } - if (scheduled) { - cv_.notify_one(); - } else { - ScheduleOnThread(std::move(closure)); - } - } - - private: - void Worker() { - while (true) { - std::function closure = GetWork(); - if (closure == nullptr) { - break; - } - try { - closure(); - } catch (const std::exception& ex) { - XLA_COUNTER("ThreadPoolException", 1); - TF_LOG(ERROR) << "Exception from running thread pool closure: " - << ex.what(); - } - } - } - - void ScheduleOnThread(std::function closure) { - std::thread thread(std::move(closure)); - thread.detach(); - } - - std::function GetWork() { - std::unique_lock lock(mutex_); - ++waiting_; - cv_.wait(lock, [this] { return exiting_ || !work_.empty(); }); - --waiting_; - if (work_.empty()) { - return nullptr; - } - std::function closure(std::move(work_.front())); - work_.pop_front(); - return closure; - } - - std::vector threads_; - std::mutex mutex_; - std::condition_variable cv_; - bool exiting_ = false; - std::deque> work_; - size_t waiting_ = 0; -}; - -ThreadPool* GetThreadPool() { - static size_t num_threads = sys_util::GetEnvInt( - "XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency()); - static ThreadPool* pool = new ThreadPool(num_threads); - return pool; -} - -ThreadPool* GetIoThreadPool() { - static size_t num_threads = sys_util::GetEnvInt( - "XLA_IO_THREAD_POOL_SIZE", std::thread::hardware_concurrency()); - static ThreadPool* pool = new ThreadPool(num_threads); - return pool; -} - -} // namespace - -class Completion::Data { - public: - void Wait() { - std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { return completed_; }); - if (exptr_ != nullptr) { - std::rethrow_exception(exptr_); - } - } - - static std::function GetCompleter(std::shared_ptr data, - std::function closure) { - auto closure_wrapper = [closure = std::move(closure), data]() { - std::exception_ptr exptr; - try { - closure(); - } catch (...) { - exptr = std::current_exception(); - } - data->Complete(exptr); - }; - return closure_wrapper; - } - - private: - void Complete(std::exception_ptr exptr) { - std::lock_guard lock(mutex_); - exptr_ = std::move(exptr); - completed_ = true; - cv_.notify_all(); - } - - std::mutex mutex_; - std::condition_variable cv_; - bool completed_ = false; - std::exception_ptr exptr_; -}; - -Completion::Completion(std::shared_ptr data) : data_(std::move(data)) {} - -Completion::~Completion() {} - -void Completion::Wait() { data_->Wait(); } - -void ScheduleClosure(std::function closure) { - GetThreadPool()->Schedule(std::move(closure)); -} - -void ScheduleIoClosure(std::function closure) { - GetIoThreadPool()->Schedule(std::move(closure)); -} - -Completion ScheduleClosureWithCompletion(std::function closure) { - auto data = std::make_shared(); - GetThreadPool()->Schedule( - Completion::Data::GetCompleter(data, std::move(closure))); - return Completion(std::move(data)); -} - -Completion ScheduleIoClosureWithCompletion(std::function closure) { - auto data = std::make_shared(); - GetIoThreadPool()->Schedule( - Completion::Data::GetCompleter(data, std::move(closure))); - return Completion(std::move(data)); -} - -} // namespace env -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/thread_pool.h b/torch_xla/csrc/runtime/thread_pool.h deleted file mode 100644 index 072e28594cce..000000000000 --- a/torch_xla/csrc/runtime/thread_pool.h +++ /dev/null @@ -1,39 +0,0 @@ -#ifndef XLA_CLIENT_THREAD_POOL_H_ -#define XLA_CLIENT_THREAD_POOL_H_ - -#include -#include -#include - -namespace torch_xla { -namespace runtime { -namespace env { - -class Completion { - public: - class Data; - - explicit Completion(std::shared_ptr data); - - ~Completion(); - - void Wait(); - - private: - std::shared_ptr data_; -}; - -// Schedules a closure to be run. The closure should not block waiting for other -// events. -void ScheduleClosure(std::function closure); -Completion ScheduleClosureWithCompletion(std::function closure); - -// Schedules a closure which might wait for IO or other events/conditions. -void ScheduleIoClosure(std::function closure); -Completion ScheduleIoClosureWithCompletion(std::function closure); - -} // namespace env -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_THREAD_POOL_H_ diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index b30cbe7c01ec..96465abf44c1 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -38,7 +38,6 @@ #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 6e46899aea3c..e46bf7e022cb 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -12,18 +12,18 @@ #include #include +#include "absl/synchronization/blocking_counter.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/layout_manager.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/torch_util.h" #include "torch_xla/csrc/xla_backend_impl.h" #include "torch_xla/csrc/xla_sharding_util.h" @@ -366,16 +366,16 @@ void CopyTensors(const void* src_buffer, const xla::Shape& src_shape, std::vector iter_dims = GetIterationDimensions(dest_shape); std::vector parts = CreateCopyPartitions(dest_shape.dimensions(), iter_dims.front()); - auto mwait = std::make_shared(parts.size()); + absl::BlockingCounter counter(parts.size()); for (size_t i = 0; i < parts.size(); ++i) { auto copy_fn = [&, i]() { SlicedCopy(dest_shape.dimensions(), src_data, src_strides, dest_data, dest_strides, iter_dims, parts[i]); + counter.DecrementCount(); }; - runtime::env::ScheduleClosure( - runtime::util::MultiWait::Completer(mwait, std::move(copy_fn))); + thread::Schedule(std::move(copy_fn)); } - mwait->Wait(); + counter.Wait(); } } diff --git a/torch_xla/csrc/thread_pool.cc b/torch_xla/csrc/thread_pool.cc new file mode 100644 index 000000000000..e440afce7bda --- /dev/null +++ b/torch_xla/csrc/thread_pool.cc @@ -0,0 +1,21 @@ +#include "torch_xla/csrc/thread_pool.h" + +#include + +#include "torch_xla/csrc/runtime/sys_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/threadpool.h" + +namespace torch_xla { +namespace thread { + +void Schedule(std::function fn) { + static size_t num_threads = torch_xla::runtime::sys_util::GetEnvInt( + "XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency()); + static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla", + num_threads); + pool.Schedule(std::move(fn)); +} + +} // namespace thread +} // namespace torch_xla diff --git a/torch_xla/csrc/thread_pool.h b/torch_xla/csrc/thread_pool.h new file mode 100644 index 000000000000..22074e6886f3 --- /dev/null +++ b/torch_xla/csrc/thread_pool.h @@ -0,0 +1,16 @@ +#ifndef XLA_CLIENT_THREAD_POOL_H_ +#define XLA_CLIENT_THREAD_POOL_H_ + +#include + +namespace torch_xla { +namespace thread { + +// Schedules a closure to be run. The closure should not block waiting for other +// events. +void Schedule(std::function fn); + +} // namespace thread +} // namespace torch_xla + +#endif // XLA_CLIENT_THREAD_POOL_H_ diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 02ea28874e34..0033176a172e 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -48,10 +48,10 @@ #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/torch_util.h" #include "torch_xla/csrc/xla_backend_impl.h" #include "torch_xla/csrc/xla_sharding_util.h" @@ -757,7 +757,7 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( } }; - runtime::env::ScheduleIoClosure(async->mwait.Completer(std::move(syncfn))); + thread::Schedule(async->mwait.Completer(std::move(syncfn))); return placeholders; } @@ -1029,7 +1029,7 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( } }; - runtime::env::ScheduleIoClosure(async->mwait.Completer(std::move(syncfn))); + thread::Schedule(async->mwait.Completer(std::move(syncfn))); return async; } diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index c7a870be3196..90eec4012d68 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -10,6 +10,7 @@ #include #include +#include "absl/synchronization/blocking_counter.h" #include "torch_xla/csrc/cross_replica_reduces.h" #include "torch_xla/csrc/debug_util.h" #include "torch_xla/csrc/device.h" @@ -18,7 +19,6 @@ #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/runtime/cache.h" #include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/tensor.h" #include "torch_xla/csrc/torch_util.h" diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index ae5863160734..36fc1810d8b3 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -5,6 +5,7 @@ #include #include +#include "absl/synchronization/blocking_counter.h" #include "torch/csrc/lazy/core/ir_util.h" #include "torch_xla/csrc/aten_autograd_ops.h" #include "torch_xla/csrc/aten_xla_bridge.h" @@ -13,12 +14,11 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/runtime.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/tensor.h" #include "torch_xla/csrc/tensor_methods.h" #include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/xla_graph_executor.h" #include "tsl/profiler/lib/traceme.h" #include "xla/execution_options_util.h" @@ -326,7 +326,7 @@ ShardingUtil::InputHandler( // the first local index with the first global device ordinal. auto device_index = build_index_map(devices); - auto mwait = std::make_shared(devices.size()); + absl::BlockingCounter counter(devices.size()); for (int i = 0; i < devices.size(); i++) { auto argument_setter = [&, i]() { @@ -339,11 +339,11 @@ ShardingUtil::InputHandler( int device_i = device_index[global_ordinal]; arguments_by_device[device_i][argument_i] = shard; } + counter.DecrementCount(); }; - runtime::env::ScheduleIoClosure( - runtime::util::MultiWait::Completer(mwait, std::move(argument_setter))); + thread::Schedule(std::move(argument_setter)); } - mwait->Wait(); + counter.Wait(); return arguments_by_device; }