Skip to content

Commit

Permalink
[xla] Add detailed tracing to RendezvousSingle
Browse files Browse the repository at this point in the history
Improve performance observability of rendezvous synchronization.

PiperOrigin-RevId: 703308803
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Dec 6, 2024
1 parent 0117e22 commit 7a89ffa
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 31 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/backends/gpu/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ cc_library(
"@local_tsl//tsl/platform:hash",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/profiler/lib:traceme",
],
)

Expand Down
28 changes: 23 additions & 5 deletions third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ limitations under the License.
#include "tsl/platform/hash.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::gpu {

Expand Down Expand Up @@ -208,6 +209,8 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,
// gives access to clique communicators.
auto initialize = [&](absl::Span<const RendezvousArg* const> args)
-> absl::StatusOr<LockableGpuClique::Lock> {
tsl::profiler::TraceMe trace("InitializeGpuClique");

TF_ASSIGN_OR_RETURN(auto clique_id, clique_id_callback(clique_key));

// Check that all ranks successfully synchronized device activity before
Expand Down Expand Up @@ -348,6 +351,8 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,
// gives access to clique communicators.
auto split = [&](absl::Span<const RankPair* const> rank_pairs)
-> absl::StatusOr<LockableGpuClique::Lock> {
tsl::profiler::TraceMe trace("SplitGpuClique");

// Collect mapping from ranks in parent clique to ranks in a new clique.
absl::btree_map<RankId, RankId> rank_mapping;
for (auto* rank_pair : rank_pairs) {
Expand Down Expand Up @@ -446,6 +451,13 @@ absl::StatusOr<std::shared_ptr<LockableGpuClique::Lock>> AcquireGpuClique(
<< "; acquired_cliques=" << acquired_cliques.size()
<< "; max_nchannels=" << max_nchannels;

tsl::profiler::TraceMe trace([&] {
return tsl::profiler::TraceMeEncode(
"AcquireGpuClique", {{"rank", rank.value()},
{"num_local_participants", num_local_participants},
{"clique_key", clique_key.ToString()}});
});

// Get the clique lock via the rendezvous to guarantee that all clique
// members participate in XLA run.
auto rendezvous_key = std::make_tuple(run_id, clique_key);
Expand All @@ -458,12 +470,18 @@ absl::StatusOr<std::shared_ptr<LockableGpuClique::Lock>> AcquireGpuClique(
RendezvousSingle<absl::StatusOr<LockableGpuClique::Lock>>(
rendezvous_name, rendezvous_key, num_local_participants,
[&] {
tsl::profiler::TraceMe trace("LockGpuClique");
ProcessGpuCliques& cliques = GetProcessGpuCliques();
absl::MutexLock lock(&cliques.mu);
// Returns empty lock if we do not have a clique for `clique_key`.
auto it = cliques.map.find(clique_key);
return it == cliques.map.end() ? LockableGpuClique::Lock()
: it->second.Acquire();

// Returns nullptr if we do not have a clique for `clique_key`.
auto lockable_clique = [&]() -> LockableGpuClique* {
absl::MutexLock lock(&cliques.mu);
auto it = cliques.map.find(clique_key);
return it == cliques.map.end() ? nullptr : &it->second;
}();

return lockable_clique ? lockable_clique->Acquire()
: LockableGpuClique::Lock();
},
WarnStuckTimeout(), TerminateTimeout()));

Expand Down
71 changes: 60 additions & 11 deletions third_party/xla/xla/service/rendezvous.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,69 +16,118 @@ limitations under the License.
#include "xla/service/rendezvous.h"

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <string_view>

#include "absl/strings/str_format.h"
#include "absl/synchronization/notification.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "tsl/platform/logging.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla {
namespace internal {

void AwaitAndLogIfStuck(std::atomic<int32_t>& ack, absl::Notification& ready,
std::string_view name, size_t num_threads,
// Blocks the calling thread, holding `state.mutex`, until `state.ready` is
// observed set or a wait on `state.cv` times out.
//
// Returns true once the rendezvous is ready and false if a wait timed out
// while the rendezvous was still not ready.
//
// Note: `timeout` bounds each individual wait on the condition variable. A
// wakeup that finds the rendezvous still not ready starts another wait with
// the full `timeout`.
static bool WaitForReadyWithTimeout(RendezvousStateSynchronization& state,
                                    absl::Duration timeout) {
  absl::MutexLock lock(&state.mutex);

  // Every iteration opens a fresh TraceMe scope, so the annotation reflects
  // how many participants were still missing when that wait started.
  while (!state.ready.load()) {
    const size_t pending = state.num_threads - state.ack.load();

    tsl::profiler::TraceMe trace([&] {
      return pending == 0
                 ? absl::StrFormat("Wait for rendezvous callback")
                 : absl::StrFormat("Wait %d of %d", pending,
                                   state.num_threads);
    });

    const bool wait_expired = state.cv.WaitWithTimeout(&state.mutex, timeout);

    // Readiness wins over expiry: report success whenever the flag is set.
    if (state.ready.load()) return true;

    // Not ready and the wait ran out of time: give up.
    if (wait_expired) return false;

    // Woken up without the rendezvous being ready: wait again.
  }

  return state.ready.load();
}

// Blocks until the rendezvous described by `state` is ready, escalating if
// that takes too long: if the rendezvous is not ready after
// `warn_stuck_timeout`, an error is logged; the function then waits for
// `terminate_timeout` and, if the rendezvous is still not ready, terminates
// the process via LOG(FATAL). Log messages distinguish "all participants
// arrived but the leader has not completed the rendezvous" from "not all
// participants arrived" by comparing `state.ack` with `state.num_threads`.
// `name` is used only to identify the rendezvous in log messages.
// NOTE(review): `id` is not referenced in the body visible here - confirm
// whether it is intentionally unused.
void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id,
std::string_view name,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
if (ready.WaitForNotificationWithTimeout(warn_stuck_timeout)) {
// Wait for `warn_stuck_timeout` for the rendezvous to be ready.
if (WaitForReadyWithTimeout(state, warn_stuck_timeout)) {
return;
}

// If we are stuck, log a warning and add a trace annotation.
tsl::profiler::TraceMe trace([&] {
return absl::StrFormat("Stuck Waiting for %d of %d",
state.num_threads - state.ack, state.num_threads);
});

// Check if all rendezvous participants arrived to the rendezvous point and
// incremented `ack` counter. We still can be stuck because the leader is
// waiting for completion of rendezvous callback, but it must not be confused
// with participants not arriving to the rendezvous point.
bool is_all_participants_arrived = ack.load() == num_threads;
bool is_all_participants_arrived = state.ack.load() == state.num_threads;

if (is_all_participants_arrived) {
LOG(ERROR) << absl::StreamFormat(
"This thread has been waiting for `%s` for %d seconds and may be "
"stuck. All %d threads joined the rendezvous, however the leader has "
"not marked the rendezvous as completed. Leader can be deadlocked "
"inside the rendezvous callback.",
name, absl::ToInt64Seconds(warn_stuck_timeout), num_threads);
name, absl::ToInt64Seconds(warn_stuck_timeout), state.num_threads);

} else {
LOG(ERROR) << absl::StreamFormat(
"This thread has been waiting for `%s` for %d seconds and may be "
"stuck. Expected %d threads to join the rendezvous, but not all of "
"them arrived on time.",
name, absl::ToInt64Seconds(warn_stuck_timeout), num_threads);
name, absl::ToInt64Seconds(warn_stuck_timeout), state.num_threads);
}

if (ready.WaitForNotificationWithTimeout(terminate_timeout)) {
// Wait for `terminate_timeout` for the rendezvous to be ready before killing
// the process.
if (WaitForReadyWithTimeout(state, terminate_timeout)) {
LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
"Perhaps the timeout is too short.";
return;
}

// Check again if all participants arrived to the rendezvous point.
is_all_participants_arrived = state.ack.load() == state.num_threads;

if (is_all_participants_arrived) {
LOG(FATAL) << absl::StreamFormat(
"Termination timeout for `%s` of %d seconds exceeded. Exiting to "
"ensure a consistent program state. All %d threads joined the "
"rendezvous, however the leader has not marked the rendezvous as "
"completed. Leader can be deadlocked inside the rendezvous callback.",
name, absl::ToInt64Seconds(terminate_timeout), num_threads);
name, absl::ToInt64Seconds(terminate_timeout), state.num_threads);

} else {
LOG(FATAL) << absl::StreamFormat(
"Termination timeout for `%s` of %d seconds exceeded. Exiting to "
"ensure a consistent program state. Expected %d threads to join the "
"rendezvous, but not all of them arrived on time.",
name, absl::ToInt64Seconds(terminate_timeout), num_threads);
name, absl::ToInt64Seconds(terminate_timeout), state.num_threads);
}
}

Expand Down
57 changes: 42 additions & 15 deletions third_party/xla/xla/service/rendezvous.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,20 +173,35 @@ void RendezvousSingle(

namespace internal {

// Base class for rendezvous state: holds only the counters and the
// synchronization primitives, none of the value or result storage.
struct RendezvousStateSynchronization {
  explicit RendezvousStateSynchronization(size_t n_threads)
      : num_threads(static_cast<int32_t>(n_threads)),
        ack(0),
        rel(0),
        ready(false) {}

  // Number of participants expected to join the rendezvous.
  int32_t num_threads;

  // Counters, both starting at zero. `ack` is compared against `num_threads`
  // to tell whether every participant has arrived.
  // NOTE(review): `rel` is not used in the code visible here - confirm its
  // role against the callers.
  std::atomic<int32_t> ack;
  std::atomic<int32_t> rel;

  // `mutex` protects `ready`; `cv` is what waiting participants block on.
  absl::Mutex mutex;
  absl::CondVar cv;

  // Signals availability of `result`.
  std::atomic<bool> ready ABSL_GUARDED_BY(mutex);
};

// A state for a single round of rendezvous. We expect exactly `num_threads` to
// arrive to a rendezvous and update corresponding slots in `values`. We
// pre-allocate storage for values so at run time each participant doesn't have
// to grab a lock and can simply write to the destination storage.
template <typename R, typename V>
struct RendezvousState {
explicit RendezvousState(size_t num_threads)
: ack(0), rel(0), values(num_threads, nullptr), result(nullptr) {}
struct RendezvousState : public RendezvousStateSynchronization {
explicit RendezvousState(size_t n_threads)
: RendezvousStateSynchronization(n_threads),
values(n_threads, nullptr),
result(nullptr) {}

std::atomic<int32_t> ack;
std::atomic<int32_t> rel;
// One pre-allocated slot per expected participant, initialized to nullptr.
std::vector<const V*> values;

absl::Notification ready; // signals availability of `result`
RendezvousResultType<R> result;
};

Expand Down Expand Up @@ -239,18 +254,26 @@ class RendezvousMap {
return state;
}();

// Notify awaiting participants without holding a lock.
// We notify awaiting participants without holding a rendezvous map lock, as
// the rendezvous callback might be an expensive operation and might block
// the progress of concurrent rendezvous for other keys.

// Publish rendezvous result to all participants.
state->result = std::move(result);
state->ready.Notify();

// Notify awaiting participants that result is ready.
absl::MutexLock lock(&state->mutex);
state->ready.store(true);
state->cv.SignalAll();
}

private:
absl::Mutex mutex_;
absl::flat_hash_map<K, std::shared_ptr<State>> state_ ABSL_GUARDED_BY(mutex_);
};

void AwaitAndLogIfStuck(std::atomic<int32_t>& ack, absl::Notification& ready,
std::string_view name, size_t num_threads,
void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id,
std::string_view name,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout);
} // namespace internal
Expand Down Expand Up @@ -292,6 +315,9 @@ RendezvousResultType<R> RendezvousSingle(std::string_view name, const K& key,
{{"num_threads", num_threads}, {"name", name}, {"id", id}});
});

// Signal all waiting threads that new participant has arrived.
state->cv.SignalAll();

// std::vector::operator[] creates data races, so we rely on data pointer
// here and when we create an absl::Span below.
*(state->values.data() + id) = &value;
Expand All @@ -304,14 +330,15 @@ RendezvousResultType<R> RendezvousSingle(std::string_view name, const K& key,
if (id < num_threads - 1) {
// Threads arriving before the last one wait for a result to be computed by
// the last joining thread.
internal::AwaitAndLogIfStuck(state->ack, state->ready, name, num_threads,
warn_stuck_timeout, terminate_timeout);
internal::AwaitAndLogIfStuck(*state, id, name, warn_stuck_timeout,
terminate_timeout);
} else {
// Last thread to arrive executes the function and completes rendezvous by
// making result available to all participants. All other participants will
// be notified via `state->ready` notification when result is ready, and we
// rely on the notification to create a memory barrier that makes access to
// be notified via `state->ready` flag when result is ready, and we rely on
// the store to a flag to create a memory barrier that makes access to
// `state->result` safe without any extra synchronization.
tsl::profiler::TraceMe trace("ExecuteRendezvousCallback");
absl::Span<const V*> values(state->values.data(), num_threads);
rendezvous.Complete(key, RendezvousResult<R>::Wrap(fn(values)));
}
Expand Down

0 comments on commit 7a89ffa

Please sign in to comment.