Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Implement gloo abort for graceful shutdown #388

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gloo/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Source files from this directory collected into GLOO_COMMON_SRCS.
# error.cc holds the gloo::abort() implementation and its condition-variable
# registry.
set(GLOO_COMMON_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/logging.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/utils.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/error.cc"
)

set(GLOO_COMMON_HDRS
Expand Down
46 changes: 46 additions & 0 deletions gloo/common/error.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/**
* Copyright (c) 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <atomic>
#include <list>

#include "gloo/common/error.h"

namespace gloo {


// File-local state backing abort(). It lives in an anonymous namespace so the
// symbols have internal linkage: error.h does not declare them, and no other
// translation unit should be able to reach (or collide with) them.
namespace {

// Condition variables that abort() notifies. The pointers are borrowed, not
// owned: they are added by _register_cv() and removed by _deregister_cv().
std::list<std::condition_variable*> _cvs;

// Guards every access to _cvs.
std::mutex _cvs_mutex;

// Set to true by abort(). Nothing in this file ever clears it again.
std::atomic_bool _is_aborted_flag(false);

} // namespace

// Reports whether abort() has been called. The flag is only ever set (never
// cleared) in this file, so a true result stays true.
bool _is_aborted() {
  // The implicit conversion of std::atomic<bool> performs a sequentially
  // consistent load, exactly like an explicit load() with no arguments.
  return _is_aborted_flag;
}

// Marks gloo as aborted, wakes every registered condition variable, and then
// raises the abort through GLOO_THROW with the message "GLOO ABORTED".
//
// The aborted flag is never cleared in this file, so after the first call
// _is_aborted() keeps returning true.
//
// NOTE(review): the flag is stored without holding the mutex a waiter pairs
// with its condition variable. A waiter that has just evaluated its predicate
// but has not yet blocked could miss this notification and only wake on its
// own timeout -- confirm whether that window matters for callers.
void abort() {
  // Publish the flag before notifying so that woken waiters observe it.
  _is_aborted_flag.store(true);

  {
    // Hold the registry lock only while walking the list; it is released
    // before the exception is constructed and thrown.
    std::lock_guard<std::mutex> guard(_cvs_mutex);
    for (std::condition_variable* cv : _cvs) {
      if (cv != nullptr) {
        cv->notify_all();
      }
    }
  }

  GLOO_THROW("GLOO ABORTED");
}

// Adds `cv` to the list of condition variables that abort() notifies.
// The pointer is stored as-is; the list does not take ownership.
void _register_cv(std::condition_variable *cv) {
  std::lock_guard<std::mutex> registryLock(_cvs_mutex);
  _cvs.emplace_back(cv);
}

// Removes every entry equal to `cv` from the list that abort() notifies.
// Does nothing if `cv` was never registered.
void _deregister_cv(std::condition_variable *cv) {
  std::lock_guard<std::mutex> registryLock(_cvs_mutex);
  _cvs.remove_if([cv](const std::condition_variable* registered) {
    return registered == cv;
  });
}
} // namespace gloo
6 changes: 6 additions & 0 deletions gloo/common/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <chrono>
#include <exception>
#include <condition_variable>

#include "gloo/common/string.h"

Expand All @@ -20,6 +21,11 @@ namespace gloo {

const std::chrono::milliseconds kNoTimeout = std::chrono::milliseconds::zero();

// Returns true once abort() has been called.
bool _is_aborted();
// Sets the aborted flag, calls notify_all() on every registered condition
// variable, then throws via GLOO_THROW with the message "GLOO ABORTED".
void abort();
// Adds `cv` to the set of condition variables that abort() notifies.
// The pointer is borrowed, not owned.
void _register_cv(std::condition_variable *cv);
// Removes `cv` from the set of condition variables that abort() notifies.
void _deregister_cv(std::condition_variable *cv);

// A base class for all gloo runtime errors
struct Exception : public std::runtime_error {
Exception() = delete;
Expand Down
1 change: 1 addition & 0 deletions gloo/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
find_package(OpenSSL 1.1 REQUIRED EXACT)

set(GLOO_TEST_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/abort_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/allgather_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/allgatherv_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_test.cc"
Expand Down
80 changes: 80 additions & 0 deletions gloo/test/abort_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/**
* Copyright (c) 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <chrono>
#include <functional>
#include <memory>
#include <string>
#include <thread>
#include <tuple>
#include <vector>

#include "gloo/barrier_all_to_all.h"
#include "gloo/barrier_all_to_one.h"
#include "gloo/broadcast.h"
#include "gloo/test/base_test.h"

namespace gloo {
namespace test {
namespace {

// Synchronized version of std::chrono::clock::now().
// All processes participating in the specified context will
// see the same value.
template <typename clock>
std::chrono::time_point<clock> syncNow(std::shared_ptr<Context> context) {
const typename clock::time_point now = clock::now();
typename clock::duration::rep count = now.time_since_epoch().count();
BroadcastOptions opts(context);
opts.setRoot(0);
opts.setOutput(&count, 1);
broadcast(opts);
return typename clock::time_point(typename clock::duration(count));
}

// Test parameter: (transport to run over, number of ranks in the context).
using NewParam = std::tuple<Transport, int>;

// Parameterized fixture for the abort tests; it adds no state of its own on
// top of BaseTest.
class AbortBarrierTest : public BaseTest,
public ::testing::WithParamInterface<NewParam> {};

// Checks that gloo::abort() releases ranks that are blocked in a barrier long
// before the context timeout would have expired.
//
// NOTE(review): the aborted flag set by abort() is never cleared in
// gloo/common/error.cc, so later parameterizations presumably start with it
// already set -- confirm they still exercise a real hang.
TEST_P(AbortBarrierTest, Default) {
  const auto transport = std::get<0>(GetParam());
  const auto contextSize = std::get<1>(GetParam());

  spawn(transport, contextSize, [&](std::shared_ptr<Context> context) {
    BarrierOptions opts(context);

    // Run barrier to synchronize processes after starting.
    barrier(opts);

    auto timeout = std::chrono::milliseconds(context->getTimeout());
    const auto start = syncNow<std::chrono::high_resolution_clock>(context);

    // Every rank except 0 enters a second barrier. Rank 0 never joins it, so
    // those ranks stay blocked until an abort wakes them.
    if (context->rank != 0) {
      barrier(opts);
    }

    // Abort should unhang the barrier. abort() is expected to always throw;
    // record a failure if it ever returns normally instead of letting the
    // test pass silently.
    try {
      gloo::abort();
      ADD_FAILURE() << "abort() returned without throwing";
    } catch (const Exception& e) {
      const std::string message(e.what());
      EXPECT_NE(message.find("GLOO ABORTED"), std::string::npos) << message;
    }

    // Expect all processes to have taken less than the timeout, as abort was
    // called.
    auto stop = std::chrono::high_resolution_clock::now();
    auto delta = std::chrono::duration_cast<decltype(timeout)>(stop - start);
    ASSERT_LE(delta.count(), timeout.count() / 4);
  });
}

// Runs AbortBarrierTest for every transport in
// kTransportsForFunctionAlgorithms combined with context sizes 1, 2, 4 and 7.
INSTANTIATE_TEST_CASE_P(
AbortBarrier, AbortBarrierTest,
::testing::Combine(::testing::ValuesIn(kTransportsForFunctionAlgorithms),
::testing::Values(1, 2, 4, 7)));

} // namespace
} // namespace test
} // namespace gloo
16 changes: 14 additions & 2 deletions gloo/transport/tcp/unbound_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@ UnboundBuffer::UnboundBuffer(
recvRank_(-1),
sendCompletions_(0),
sendRank_(-1),
shareableNonOwningPtr_(this) {}
shareableNonOwningPtr_(this) {
gloo::_register_cv(&recvCv_);
gloo::_register_cv(&sendCv_);
}

UnboundBuffer::~UnboundBuffer() {}
// Removes both condition variables from the abort registry so that
// gloo::abort() no longer calls notify_all() on them once this buffer is gone.
UnboundBuffer::~UnboundBuffer() {
gloo::_deregister_cv(&recvCv_);
gloo::_deregister_cv(&sendCv_);
}

void UnboundBuffer::handleRecvCompletion(int rank) {
std::lock_guard<std::mutex> lock(m_);
Expand Down Expand Up @@ -58,6 +64,9 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) {
if (recvCompletions_ == 0) {
auto done = recvCv_.wait_for(lock, timeout, [&] {
throwIfException();
if(gloo::_is_aborted()) {
abortWaitRecv_ = true;
}
return abortWaitRecv_ || recvCompletions_ > 0;
});
if (!done) {
Expand Down Expand Up @@ -109,6 +118,9 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) {
if (sendCompletions_ == 0) {
auto done = sendCv_.wait_for(lock, timeout, [&] {
throwIfException();
if(gloo::_is_aborted()) {
abortWaitSend_ = true;
}
return abortWaitSend_ || sendCompletions_ > 0;
});
if (!done) {
Expand Down
26 changes: 20 additions & 6 deletions gloo/transport/uv/unbound_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@ UnboundBuffer::UnboundBuffer(
recvRank_(-1),
sendCompletions_(0),
sendRank_(-1),
shareableNonOwningPtr_(this) {}
shareableNonOwningPtr_(this) {
gloo::_register_cv(&recvCv_);
gloo::_register_cv(&sendCv_);
}

UnboundBuffer::~UnboundBuffer() {}
// Removes both condition variables from the abort registry so that
// gloo::abort() no longer calls notify_all() on them once this buffer is gone.
UnboundBuffer::~UnboundBuffer() {
gloo::_deregister_cv(&recvCv_);
gloo::_deregister_cv(&sendCv_);
}

void UnboundBuffer::handleRecvCompletion(int rank) {
std::lock_guard<std::mutex> lock(mutex_);
Expand Down Expand Up @@ -56,8 +62,12 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) {
}

if (recvCompletions_ == 0) {
auto done = recvCv_.wait_for(
lock, timeout, [&] { return abortWaitRecv_ || recvCompletions_ > 0; });
auto done = recvCv_.wait_for(lock, timeout, [&] {
if(gloo::_is_aborted()) {
abortWaitRecv_ = true;
}
return abortWaitRecv_ || recvCompletions_ > 0;
});
if (!done) {
throw ::gloo::IoException(GLOO_ERROR_MSG(
"Timed out waiting ",
Expand Down Expand Up @@ -92,8 +102,12 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) {
}

if (sendCompletions_ == 0) {
auto done = sendCv_.wait_for(
lock, timeout, [&] { return abortWaitSend_ || sendCompletions_ > 0; });
auto done = sendCv_.wait_for(lock, timeout, [&] {
if(gloo::_is_aborted()) {
abortWaitSend_ = true;
}
return abortWaitSend_ || sendCompletions_ > 0;
});
if (!done) {
throw ::gloo::IoException(GLOO_ERROR_MSG(
"Timed out waiting ",
Expand Down
Loading