Skip to content

Commit

Permalink
implement gloo abort
Browse files Browse the repository at this point in the history
  • Loading branch information
Aidyn-A committed Nov 5, 2024
1 parent 43b7acb commit 0ad4df8
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 8 deletions.
1 change: 1 addition & 0 deletions gloo/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
set(GLOO_COMMON_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/logging.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/utils.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/error.cc"
)

set(GLOO_COMMON_HDRS
Expand Down
46 changes: 46 additions & 0 deletions gloo/common/error.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/**
* Copyright (c) 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <atomic>
#include <list>

#include "gloo/common/error.h"

namespace gloo {


std::list<std::condition_variable *> _cvs;
std::mutex _cvs_mutex;

std::atomic_bool _is_aborted_flag(false);

bool _is_aborted() {
return _is_aborted_flag.load();
}

void abort() {
_is_aborted_flag.store(true);
std::lock_guard<std::mutex> guard(_cvs_mutex);
for(auto& cv : _cvs) {
if(cv != NULL) {
cv->notify_all();
}
}
GLOO_THROW("GLOO ABORTED");
}

void _register_cv(std::condition_variable *cv) {
std::lock_guard<std::mutex> guard(_cvs_mutex);
_cvs.push_back(cv);
}

void _deregister_cv(std::condition_variable *cv) {
std::lock_guard<std::mutex> guard(_cvs_mutex);
_cvs.remove(cv);
}
} // namespace gloo
6 changes: 6 additions & 0 deletions gloo/common/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <chrono>
#include <exception>
#include <condition_variable>

#include "gloo/common/string.h"

Expand All @@ -20,6 +21,11 @@ namespace gloo {

const std::chrono::milliseconds kNoTimeout = std::chrono::milliseconds::zero();

bool _is_aborted();
void abort();
void _register_cv(std::condition_variable *cv);
void _deregister_cv(std::condition_variable *cv);

// A base class for all gloo runtime errors
struct Exception : public std::runtime_error {
Exception() = delete;
Expand Down
16 changes: 14 additions & 2 deletions gloo/transport/tcp/unbound_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,15 @@ UnboundBuffer::UnboundBuffer(
recvRank_(-1),
sendCompletions_(0),
sendRank_(-1),
shareableNonOwningPtr_(this) {}
shareableNonOwningPtr_(this) {
gloo::_register_cv(&recvCv_);
gloo::_register_cv(&sendCv_);
}

UnboundBuffer::~UnboundBuffer() {}
UnboundBuffer::~UnboundBuffer() {
gloo::_deregister_cv(&recvCv_);
gloo::_deregister_cv(&sendCv_);
}

void UnboundBuffer::handleRecvCompletion(int rank) {
std::lock_guard<std::mutex> lock(m_);
Expand Down Expand Up @@ -60,6 +66,9 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) {
if (recvCompletions_ == 0) {
auto done = recvCv_.wait_for(lock, timeout, [&] {
throwIfException();
if(gloo::_is_aborted()) {
abortWaitRecv_ = true;
}
return abortWaitRecv_ || recvCompletions_ > 0;
});
if (!done) {
Expand Down Expand Up @@ -111,6 +120,9 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) {
if (sendCompletions_ == 0) {
auto done = sendCv_.wait_for(lock, timeout, [&] {
throwIfException();
if(gloo::_is_aborted()) {
abortWaitSend_ = true;
}
return abortWaitSend_ || sendCompletions_ > 0;
});
if (!done) {
Expand Down
26 changes: 20 additions & 6 deletions gloo/transport/uv/unbound_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,15 @@ UnboundBuffer::UnboundBuffer(
recvRank_(-1),
sendCompletions_(0),
sendRank_(-1),
shareableNonOwningPtr_(this) {}
shareableNonOwningPtr_(this) {
gloo::_register_cv(&recvCv_);
gloo::_register_cv(&sendCv_);
}

UnboundBuffer::~UnboundBuffer() {}
UnboundBuffer::~UnboundBuffer() {
gloo::_deregister_cv(&recvCv_);
gloo::_deregister_cv(&sendCv_);
}

void UnboundBuffer::handleRecvCompletion(int rank) {
std::lock_guard<std::mutex> lock(mutex_);
Expand Down Expand Up @@ -58,8 +64,12 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) {
}

if (recvCompletions_ == 0) {
auto done = recvCv_.wait_for(
lock, timeout, [&] { return abortWaitRecv_ || recvCompletions_ > 0; });
auto done = recvCv_.wait_for(lock, timeout, [&] {
if(gloo::_is_aborted()) {
abortWaitRecv_ = true;
}
return abortWaitRecv_ || recvCompletions_ > 0;
});
if (!done) {
throw ::gloo::IoException(GLOO_ERROR_MSG(
"Timed out waiting ",
Expand Down Expand Up @@ -94,8 +104,12 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) {
}

if (sendCompletions_ == 0) {
auto done = sendCv_.wait_for(
lock, timeout, [&] { return abortWaitSend_ || sendCompletions_ > 0; });
auto done = sendCv_.wait_for(lock, timeout, [&] {
if(gloo::_is_aborted()) {
abortWaitSend_ = true;
}
return abortWaitSend_ || sendCompletions_ > 0;
});
if (!done) {
throw ::gloo::IoException(GLOO_ERROR_MSG(
"Timed out waiting ",
Expand Down

0 comments on commit 0ad4df8

Please sign in to comment.