diff --git a/third_party/xla/xla/backends/gpu/collectives/BUILD b/third_party/xla/xla/backends/gpu/collectives/BUILD index 400606643acc36..7121377146e064 100644 --- a/third_party/xla/xla/backends/gpu/collectives/BUILD +++ b/third_party/xla/xla/backends/gpu/collectives/BUILD @@ -120,6 +120,7 @@ cc_library( "@local_tsl//tsl/platform:hash", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", ], ) diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.cc index a699472bfb4faa..be53a701c1192e 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.cc +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.cc @@ -57,6 +57,7 @@ limitations under the License. #include "tsl/platform/hash.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" namespace xla::gpu { @@ -208,6 +209,8 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, // gives access to clique communicators. auto initialize = [&](absl::Span args) -> absl::StatusOr { + tsl::profiler::TraceMe trace("InitializeGpuClique"); + TF_ASSIGN_OR_RETURN(auto clique_id, clique_id_callback(clique_key)); // Check that all ranks successfully synchronized device activity before @@ -348,6 +351,8 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, // gives access to clique communicators. auto split = [&](absl::Span rank_pairs) -> absl::StatusOr { + tsl::profiler::TraceMe trace("SplitGpuClique"); + // Collect mapping from ranks in parent clique to ranks in a new clique. absl::btree_map rank_mapping; for (auto* rank_pair : rank_pairs) { @@ -446,6 +451,13 @@ absl::StatusOr> AcquireGpuClique( << "; acquired_cliques=" << acquired_cliques.size() << "; max_nchannels=" << max_nchannels; + tsl::profiler::TraceMe trace([&] { + return tsl::profiler::TraceMeEncode( + "AcquireGpuClique", {{"rank", rank.value()}, + {"num_local_participants", num_local_participants}, + {"clique_key", clique_key.ToString()}}); + }); + // Get the clique lock via the rendezvous to guarantee that all clique // members participate in XLA run. auto rendezvous_key = std::make_tuple(run_id, clique_key); @@ -458,12 +470,18 @@ absl::StatusOr> AcquireGpuClique( RendezvousSingle>( rendezvous_name, rendezvous_key, num_local_participants, [&] { + tsl::profiler::TraceMe trace("LockGpuClique"); ProcessGpuCliques& cliques = GetProcessGpuCliques(); - absl::MutexLock lock(&cliques.mu); - // Returns empty lock if we do not have a clique for `clique_key`. - auto it = cliques.map.find(clique_key); - return it == cliques.map.end() ? LockableGpuClique::Lock() - : it->second.Acquire(); + + // Returns nullptr if we do not have a clique for `clique_key`. + auto lockable_clique = [&]() -> LockableGpuClique* { + absl::MutexLock lock(&cliques.mu); + auto it = cliques.map.find(clique_key); + return it == cliques.map.end() ? nullptr : &it->second; + }(); + + return lockable_clique ? lockable_clique->Acquire() + : LockableGpuClique::Lock(); }, WarnStuckTimeout(), TerminateTimeout())); diff --git a/third_party/xla/xla/service/rendezvous.cc b/third_party/xla/xla/service/rendezvous.cc index 3241a4c0ac25ac..b4be7d39e1c815 100644 --- a/third_party/xla/xla/service/rendezvous.cc +++ b/third_party/xla/xla/service/rendezvous.cc @@ -16,32 +16,76 @@ limitations under the License. #include "xla/service/rendezvous.h" #include +#include #include -#include #include #include #include "absl/strings/str_format.h" -#include "absl/synchronization/notification.h" +#include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "tsl/platform/logging.h" +#include "tsl/profiler/lib/traceme.h" namespace xla { namespace internal { -void AwaitAndLogIfStuck(std::atomic& ack, absl::Notification& ready, - std::string_view name, size_t num_threads, +// Waits for the rendezvous to be ready with a timeout. Returns true if the +// rendezvous is ready, false if the timeout is exceeded. +static bool WaitForReadyWithTimeout(RendezvousStateSynchronization& state, + absl::Duration timeout) { + absl::MutexLock lock(&state.mutex); + + // Keep checking if the rendezvous is ready inside a loop and update TraceMe + // annotation to track the rendezvous progress. + while (state.ready.load() == false) { + size_t num_pending = state.num_threads - state.ack.load(); + + tsl::profiler::TraceMe trace([&] { + if (num_pending == 0) { + return absl::StrFormat("Wait for rendezvous callback"); + } else { + return absl::StrFormat("Wait %d of %d", num_pending, state.num_threads); + } + }); + + bool timed_out = state.cv.WaitWithTimeout(&state.mutex, timeout); + bool ready = state.ready.load(); + + // We are done and ready. + if (ready) return true; + + // We are done with waiting because the timeout is exceeded. + if (timed_out && !ready) { + return false; + } + + // Otherwise we keep waiting. + } + + return state.ready.load(); +} + +void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id, + std::string_view name, absl::Duration warn_stuck_timeout, absl::Duration terminate_timeout) { - if (ready.WaitForNotificationWithTimeout(warn_stuck_timeout)) { + // Wait for `warn_stuck_timeout` for the rendezvous to be ready. + if (WaitForReadyWithTimeout(state, warn_stuck_timeout)) { return; } + // If we are stuck, log a warning and add a trace annotation. + tsl::profiler::TraceMe trace([&] { + return absl::StrFormat("Stuck Waiting for %d of %d", + state.num_threads - state.ack, state.num_threads); + }); + // Check if all rendezvous participants arrived to the rendezvous point and // incremented `ack` counter. We still can be stuck because the leader is // waiting for completion of rendezvous callback, but it must not be confused // with participants not arriving to the rendezvous point. - bool is_all_participants_arrived = ack.load() == num_threads; + bool is_all_participants_arrived = state.ack.load() == state.num_threads; if (is_all_participants_arrived) { LOG(ERROR) << absl::StreamFormat( @@ -49,36 +93,41 @@ void AwaitAndLogIfStuck(std::atomic& ack, absl::Notification& ready, "stuck. All %d threads joined the rendezvous, however the leader has " "not marked the rendezvous as completed. Leader can be deadlocked " "inside the rendezvous callback.", - name, absl::ToInt64Seconds(warn_stuck_timeout), num_threads); + name, absl::ToInt64Seconds(warn_stuck_timeout), state.num_threads); } else { LOG(ERROR) << absl::StreamFormat( "This thread has been waiting for `%s` for %d seconds and may be " "stuck. Expected %d threads to join the rendezvous, but not all of " "them arrived on time.", - name, absl::ToInt64Seconds(warn_stuck_timeout), num_threads); + name, absl::ToInt64Seconds(warn_stuck_timeout), state.num_threads); } - if (ready.WaitForNotificationWithTimeout(terminate_timeout)) { + // Wait for `terminate_timeout` for the rendezvous to be ready before killing + // the process. + if (WaitForReadyWithTimeout(state, terminate_timeout)) { LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. " "Perhaps the timeout is too short."; return; } + // Check again if all participants arrived to the rendezvous point. + is_all_participants_arrived = state.ack.load() == state.num_threads; + if (is_all_participants_arrived) { LOG(FATAL) << absl::StreamFormat( "Termination timeout for `%s` of %d seconds exceeded. Exiting to " "ensure a consistent program state. All %d threads joined the " "rendezvous, however the leader has not marked the rendezvous as " "completed. Leader can be deadlocked inside the rendezvous callback.", - name, absl::ToInt64Seconds(terminate_timeout), num_threads); + name, absl::ToInt64Seconds(terminate_timeout), state.num_threads); } else { LOG(FATAL) << absl::StreamFormat( "Termination timeout for `%s` of %d seconds exceeded. Exiting to " "ensure a consistent program state. Expected %d threads to join the " "rendezvous, but not all of them arrived on time.", - name, absl::ToInt64Seconds(terminate_timeout), num_threads); + name, absl::ToInt64Seconds(terminate_timeout), state.num_threads); } } diff --git a/third_party/xla/xla/service/rendezvous.h b/third_party/xla/xla/service/rendezvous.h index 0a06c11540280c..a1b6585d07c655 100644 --- a/third_party/xla/xla/service/rendezvous.h +++ b/third_party/xla/xla/service/rendezvous.h @@ -173,20 +173,35 @@ void RendezvousSingle( namespace internal { +// A base class for rendezvous state that holds synchronization primitives. +struct RendezvousStateSynchronization { + explicit RendezvousStateSynchronization(size_t num_threads) + : num_threads(num_threads), ack(0), rel(0), ready(false) {} + + int32_t num_threads; + + std::atomic ack; + std::atomic rel; + + absl::Mutex mutex; + absl::CondVar cv; + + // Signals availability of `result`. + std::atomic ready ABSL_GUARDED_BY(mutex); +}; + // A state for a single round of rendezvous. We expect exactly `num_treads` to // arrive to a rendezvous and update corresponding slots in `values`. We // pre-allocate storage for values so at run time each participant doesn't have // to grab a lock and can simple write to the destination storage. template -struct RendezvousState { - explicit RendezvousState(size_t num_threads) - : ack(0), rel(0), values(num_threads, nullptr), result(nullptr) {} +struct RendezvousState : public RendezvousStateSynchronization { + explicit RendezvousState(size_t n_threads) + : RendezvousStateSynchronization(n_threads), + values(n_threads, nullptr), + result(nullptr) {} - std::atomic ack; - std::atomic rel; std::vector values; - - absl::Notification ready; // signals availability of `result` RendezvousResultType result; }; @@ -239,9 +254,17 @@ class RendezvousMap { return state; }(); - // Notify awaiting participants without holding a lock. + // We notify awaiting participants without holding a rendezvous map lock, as + // the rendezvous callback might be an expensive operation and might block + // the progress of concurrent rendezvous for other keys. + + // Publish rendezvous result to all participants. state->result = std::move(result); - state->ready.Notify(); + + // Notify awaiting participants that result is ready. + absl::MutexLock lock(&state->mutex); + state->ready.store(true); + state->cv.SignalAll(); } private: @@ -249,8 +272,8 @@ class RendezvousMap { absl::flat_hash_map> state_ ABSL_GUARDED_BY(mutex_); }; -void AwaitAndLogIfStuck(std::atomic& ack, absl::Notification& ready, - std::string_view name, size_t num_threads, +void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id, + std::string_view name, absl::Duration warn_stuck_timeout, absl::Duration terminate_timeout); } // namespace internal @@ -292,6 +315,9 @@ RendezvousResultType RendezvousSingle(std::string_view name, const K& key, {{"num_threads", num_threads}, {"name", name}, {"id", id}}); }); + // Signal all waiting threads that new participant has arrived. + state->cv.SignalAll(); + // std::vector::operator[] creates data races, so we rely on data pointer // here and when we create an absl::Span below. *(state->values.data() + id) = &value; @@ -304,14 +330,15 @@ RendezvousResultType RendezvousSingle(std::string_view name, const K& key, if (id < num_threads - 1) { // Threads arriving before the last one wait for a result to be computed by // the last joining thread. - internal::AwaitAndLogIfStuck(state->ack, state->ready, name, num_threads, - warn_stuck_timeout, terminate_timeout); + internal::AwaitAndLogIfStuck(*state, id, name, warn_stuck_timeout, + terminate_timeout); } else { // Last thread to arrive executes the function and completes rendezvous by // making result available to all participants. All other participants will - // be notified via `state->ready` notification when result is ready, and we - // rely on the notification to create a memory barrier that makes access to + // be notified via `state->ready` flag when result is ready, and we rely on + // the store to a flag to create a memory barrier that makes access to // `state->result` safe without any extra synchronization. + tsl::profiler::TraceMe trace("ExecuteRendezvousCallback"); absl::Span values(state->values.data(), num_threads); rendezvous.Complete(key, RendezvousResult::Wrap(fn(values))); }