diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index d705ea0bdc5..be201333c32 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -82,6 +82,7 @@ cc_library( ":computation_client", ":debug_macros", ":env_vars", + ":operation_manager", ":profiler", ":stablehlo_helper", ":tensor_source", @@ -187,6 +188,18 @@ cc_library( ], ) +cc_library( + name = "operation_manager", + srcs = ["operation_manager.cc"], + hdrs = ["operation_manager.h"], + visibility = ["//visibility:private"], + deps = [ + ":debug_macros", + ":tf_logging", + "@com_google_absl//absl/types:span", + ], +) + # Profiler silently fails unless we link these backends cc_library( name = "profiler_backends", diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 9af461bfd56..28e09be6c68 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -338,7 +338,7 @@ class ComputationClient { // Block until pass in devices' async operation are finished. If empty, all // the local devices will be waited for. - virtual void WaitDeviceOps(const std::vector& devices) = 0; + virtual void WaitDeviceOps(absl::Span devices) = 0; // Check whether the XlaCoordinator has been initialized. 
virtual bool CoordinatorInitialized() const = 0; diff --git a/torch_xla/csrc/runtime/operation_manager.cc b/torch_xla/csrc/runtime/operation_manager.cc new file mode 100644 index 00000000000..817bfacbba3 --- /dev/null +++ b/torch_xla/csrc/runtime/operation_manager.cc @@ -0,0 +1,83 @@ +#include "torch_xla/csrc/runtime/operation_manager.h" + +#include + +#include "absl/types/span.h" +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/tf_logging.h" + +namespace torch_xla { +namespace runtime { + +OperationManager::OperationManager(absl::Span devices) { + for (auto& device : devices) { + op_counters_.try_emplace(device, device); + } +} + +OperationManager::OperationTracker::OperationTracker(Counter* counter) + : counter_(counter) { + XLA_CHECK(counter_); + counter_->Increment(); +} + +OperationManager::OperationTracker::~OperationTracker() { + counter_->Decrement(); +} + +std::unique_ptr +OperationManager::StartOperation(std::string device) { + return std::make_unique(&op_counters_.at(device)); +} + +void OperationManager::WaitForDevices(absl::Span devices) { + std::vector> locks; + locks.reserve(devices.size()); + + for (const std::string& device_str : devices) { + TF_VLOG(5) << "Blocking new operations on " << device_str; + auto lock = op_counters_.at(device_str).BlockNewOperations(); + locks.emplace_back(std::move(lock)); + + TF_VLOG(3) << "Waiting for device execution for " << device_str + << " to finish"; + op_counters_.at(device_str).Wait(); + TF_VLOG(3) << "Finished operations on device " << device_str; + } +} + +void OperationManager::Counter::Increment() { + // Block new operations after BlockNewOperations() is called. count_ is + // already atomic, so we don't need an exclusive lock to prevent + // data races. 
+ std::shared_lock lock(pending_operations_mu_); + auto current = count_.fetch_add(1, std::memory_order_acq_rel) + 1; + TF_VLOG(5) << "Incremented operations for " << device_ << " to " << current; +} + +void OperationManager::Counter::Decrement() { + auto current = count_.fetch_sub(1, std::memory_order_acq_rel) - 1; + TF_VLOG(5) << "Decremented operations for " << device_ << " to " << current; + + if (current == 0) { + std::unique_lock cv_lock(cv_mu_); + TF_VLOG(3) << "All operations complete for " << device_; + cv_.notify_all(); + } +} + +std::unique_lock +OperationManager::Counter::BlockNewOperations() { + return std::unique_lock(pending_operations_mu_); +} + +void OperationManager::Counter::Wait() { + TF_VLOG(5) << "Waiting for " << count_ << " operations on " << device_; + std::unique_lock cv_lock(cv_mu_); + cv_.wait(cv_lock, + [this] { return count_.load(std::memory_order_acquire) == 0; }); + TF_VLOG(5) << "Done waiting for " << device_; +} + +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/operation_manager.h b/torch_xla/csrc/runtime/operation_manager.h new file mode 100644 index 00000000000..e4e8136f27e --- /dev/null +++ b/torch_xla/csrc/runtime/operation_manager.h @@ -0,0 +1,86 @@ +#ifndef XLA_CLIENT_OPERATION_MANAGER_H_ +#define XLA_CLIENT_OPERATION_MANAGER_H_ + +#include +#include +#include +#include +#include + +#include "absl/types/span.h" + +namespace torch_xla { +namespace runtime { + +// Track inflight operations for each device. 
+class OperationManager { + public: + OperationManager() = default; + OperationManager(absl::Span); + + OperationManager(const OperationManager&) = delete; + OperationManager& operator=(const OperationManager&) = delete; + + OperationManager(OperationManager&&) = default; + OperationManager& operator=(OperationManager&&) = default; + + class Counter { + public: + Counter(const std::string& device) : device_(device){}; + + Counter(const Counter&) = delete; + Counter& operator=(const Counter&) = delete; + + // Register an operation. Blocks while a `BlockNewOperations` lock is held. + void Increment(); + + // Mark an inflight operation completed; wakes `Wait` when none remain. + void Decrement(); + + // Wait until all operations are complete. Does not block new operations + // (see BlockNewOperations). + void Wait(); + + // Returns a lock that prevents new operations on the device. + std::unique_lock BlockNewOperations(); + + private: + std::string device_; + + std::shared_mutex pending_operations_mu_; + std::atomic count_{0}; + + std::mutex cv_mu_; + std::condition_variable cv_; + }; + + class OperationTracker { + public: + // Register an operation in the `counter_`. + OperationTracker(Counter* counter); + + // Mark an operation complete in `counter_`. + ~OperationTracker(); + + OperationTracker(const OperationTracker&) = delete; + OperationTracker& operator=(const OperationTracker&) = delete; + + private: + std::string device_; + Counter* counter_; + }; + + // Register a new operation for `device`, which must be a tracked device. + std::unique_ptr StartOperation(std::string device); + + // Block new operations on `devices`, then wait for inflight ones to finish. 
+ void WaitForDevices(absl::Span devices); + + private: + std::unordered_map op_counters_; +}; + +} // namespace runtime +} // namespace torch_xla + +#endif diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 9ad731eba82..57fbb3cc861 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -11,6 +11,7 @@ #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_vars.h" +#include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/tensor_source.h" @@ -206,10 +207,11 @@ PjRtComputationClient::PjRtComputationClient() { global_ordinals_[device->id()] = global_ordinals_.size(); std::string device_str = PjRtDeviceToString(device); string_to_device_.emplace(device_str, device); - device_locks_.emplace(device_str, std::make_unique()); } - // manually create the device_locks for SPMD device - device_locks_.emplace(spmd_device_str, std::make_unique()); + + auto tracked_devices = GetLocalDevices(); + tracked_devices.emplace_back(spmd_device_str); + operation_manager_ = std::move(OperationManager(std::move(tracked_devices))); } PjRtComputationClient::~PjRtComputationClient() { @@ -601,6 +603,11 @@ PjRtComputationClient::ExecuteComputation( // Required as of cl/518733871 execute_options.use_major_to_minor_data_layout_for_callbacks = true; + TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device; + auto op_tracker = operation_manager_.StartOperation(device); + TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device + << " Done"; + std::optional> returned_future; std::vector> results = pjrt_computation.executable @@ -608,6 +615,12 @@ PjRtComputationClient::ExecuteComputation( returned_future) .value(); 
+ returned_future->OnReady(std::move( + [timed, op_tracker = std::move(op_tracker)](xla::Status unused) mutable { + timed.reset(); + TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished"; + })); + std::vector datas; datas.reserve(results.size()); for (auto& result : results) { @@ -620,31 +633,6 @@ PjRtComputationClient::ExecuteComputation( } CreateDataHandlesCounter()->AddValue(datas.size()); - thread::Schedule(std::move([&, this, device, - returned_future = std::move(*returned_future), - timed]() mutable { - TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " - << device; - // Grab the shared lock and block the `WaitDeviceOps` until buffer is - // ready. - // TODO(JackCaoG): This lock should acquired outside of the lockfn and - // passed in. It is possible that lockfn started after ExecuteComputation - // released the xla_graph_executor lock, which will create a short windows - // where device is unlcoked while execution is still running. - auto lock = lock_device_shared(device); - TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device - << " Done"; - // Signal that `ExecuteSharded` has completed for the ExecuteTime - // metric. Copies the `timed` shared pointer into the lambda. - XLA_CHECK(returned_future.IsValid()) - << "returned_future in ExecuteComputation is empty"; - returned_future.OnReady( - [timed, lock = std::move(lock)](xla::Status unused) mutable { - timed.reset(); - TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished"; - }); - })); - TF_VLOG(1) << "Returning " << datas.size() << " results"; return datas; } @@ -704,6 +692,15 @@ PjRtComputationClient::ExecuteReplicated( // Required as of cl/518733871 execute_options.use_major_to_minor_data_layout_for_callbacks = true; + // Grab the shared lock and block the `WaitDeviceOps` until buffer is + // ready. Since this is the SPMD code path, there is no point in grabbing + // the device lock for every individual device. 
+ TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for " + << spmd_device_str; + auto op_tracker = operation_manager_.StartOperation(spmd_device_str); + TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for " + << spmd_device_str << " Done"; + std::optional>> returned_futures( devices.size()); std::vector>> results; @@ -715,6 +712,13 @@ PjRtComputationClient::ExecuteReplicated( ->Execute(std::move(argument_handles), execute_options, returned_futures) .value(); + + (*returned_futures)[0].OnReady( + std::move([timed, op_tracker = std::move(op_tracker)]( + xla::Status unused) mutable { + timed.reset(); + TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished"; + })); } std::vector> data_handles; @@ -747,31 +751,6 @@ PjRtComputationClient::ExecuteReplicated( } } - thread::Schedule(std::move([&, this, - returned_futures = std::move(*returned_futures), - timed]() mutable { - // Grab the shared lock and block the `WaitDeviceOps` until buffer is - // ready. Since this is the SPMD code path. There is no points to grab - // devices lock for every individual device. - TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for " - << spmd_device_str; - auto lock = lock_device_shared(spmd_device_str); - TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for " - << spmd_device_str << " Done"; - // Signal that `ExecuteReplicated` has completed for one of the devices - // the ExecuteReplicatedTime metric. Here, we assume that all devices - // will finish execution roughly at the same time, hence only use one of - // the returned_futures. Copies the `timed` shared pointer into the - // lambda. 
- XLA_CHECK(returned_futures[0].IsValid()) - << "returned_future in ExecuteReplicated is empty"; - returned_futures[0].OnReady( - [timed, lock = std::move(lock)](xla::Status unused) mutable { - timed.reset(); - TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished"; - }); - })); - TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results " << "with dimensions [" << absl::StrJoin(dims, ",") << "]."; return data_handles; @@ -826,37 +805,11 @@ xla::PjRtDevice* PjRtComputationClient::StringToPjRtDevice( return pjrt_device; } -std::shared_lock PjRtComputationClient::lock_device_shared( - const std::string& device) { - std::shared_lock lock(*device_locks_[device]); - return lock; -} - -std::unique_lock PjRtComputationClient::lock_device( - const std::string& device) { - std::unique_lock lock(*device_locks_[device]); - return lock; -} - void PjRtComputationClient::WaitDeviceOps( - const std::vector& devices) { - std::unordered_set wait_devices; - if (!devices.empty()) { - for (auto& device_str : devices) { - wait_devices.insert(device_str); - } - } else { - for (auto& device_str : GetLocalDevices()) { - wait_devices.insert(device_str); - } - } - for (const std::string& device_str : wait_devices) { - TF_VLOG(3) << "Waiting for device execution for " << device_str - << " to finish"; - lock_device(device_str); - TF_VLOG(3) << "Waiting for device execution for " << device_str - << " to finish.. Done"; - } + absl::Span devices) { + TF_VLOG(3) << "Waiting for " << absl::StrJoin(devices, ", "); + operation_manager_.WaitForDevices(devices.empty() ? 
GetLocalDevices() + : devices); } std::map PjRtComputationClient::GetMetrics() const { diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index b66e4ff5097..e5eaf4039ef 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -10,6 +10,7 @@ #include "absl/types/span.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/util.h" #include "xla/client/xla_computation.h" #include "xla/literal.h" @@ -86,7 +87,7 @@ class PjRtComputationClient : public ComputationClient { std::shared_ptr> GetReplicationDevices() override; - void WaitDeviceOps(const std::vector& devices) override; + void WaitDeviceOps(absl::Span devices) override; std::map GetMetrics() const override; @@ -112,13 +113,9 @@ class PjRtComputationClient : public ComputationClient { std::unordered_map global_ordinals_; std::unordered_map string_to_device_; std::shared_ptr> replication_devices_; - std::unordered_map> - device_locks_; + OperationManager operation_manager_; xla::PjRtDevice* StringToPjRtDevice(const std::string& device); - std::shared_lock lock_device_shared( - const std::string& device); - std::unique_lock lock_device(const std::string& device); std::string PjRtDeviceToString(xla::PjRtDevice* const device) const; std::vector PjRtDevicesToString(