diff --git a/gloo/common/CMakeLists.txt b/gloo/common/CMakeLists.txt index 307588a89..4b8e4c5d2 100644 --- a/gloo/common/CMakeLists.txt +++ b/gloo/common/CMakeLists.txt @@ -1,6 +1,7 @@ set(GLOO_COMMON_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/logging.cc" "${CMAKE_CURRENT_SOURCE_DIR}/utils.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/error.cc" ) set(GLOO_COMMON_HDRS diff --git a/gloo/common/error.cc b/gloo/common/error.cc new file mode 100644 index 000000000..c14f38b4a --- /dev/null +++ b/gloo/common/error.cc @@ -0,0 +1,46 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include "gloo/common/error.h" + +namespace gloo { + + +std::list _cvs; +std::mutex _cvs_mutex; + +std::atomic_bool _is_aborted_flag(false); + +bool _is_aborted() { + return _is_aborted_flag.load(); +} + +void abort() { + _is_aborted_flag.store(true); + std::lock_guard guard(_cvs_mutex); + for(auto& cv : _cvs) { + if(cv != NULL) { + cv->notify_all(); + } + } + GLOO_THROW("GLOO ABORTED"); +} + +void _register_cv(std::condition_variable *cv) { + std::lock_guard guard(_cvs_mutex); + _cvs.push_back(cv); +} + +void _deregister_cv(std::condition_variable *cv) { + std::lock_guard guard(_cvs_mutex); + _cvs.remove(cv); +} +} // namespace gloo diff --git a/gloo/common/error.h b/gloo/common/error.h index f3fdab659..687e5eb67 100644 --- a/gloo/common/error.h +++ b/gloo/common/error.h @@ -10,6 +10,7 @@ #include #include +#include #include "gloo/common/string.h" @@ -20,6 +21,11 @@ namespace gloo { const std::chrono::milliseconds kNoTimeout = std::chrono::milliseconds::zero(); +bool _is_aborted(); +void abort(); +void _register_cv(std::condition_variable *cv); +void _deregister_cv(std::condition_variable *cv); + // A base class for all gloo runtime errors struct Exception : public std::runtime_error { Exception() = delete; diff --git a/gloo/transport/tcp/unbound_buffer.cc b/gloo/transport/tcp/unbound_buffer.cc index fc8fb559b..c69205bab 100644 --- a/gloo/transport/tcp/unbound_buffer.cc +++ b/gloo/transport/tcp/unbound_buffer.cc @@ -28,9 +28,15 @@ UnboundBuffer::UnboundBuffer( recvRank_(-1), sendCompletions_(0), sendRank_(-1), - shareableNonOwningPtr_(this) {} + shareableNonOwningPtr_(this) { + gloo::_register_cv(&recvCv_); + gloo::_register_cv(&sendCv_); +} -UnboundBuffer::~UnboundBuffer() {} +UnboundBuffer::~UnboundBuffer() { + gloo::_deregister_cv(&recvCv_); + gloo::_deregister_cv(&sendCv_); +} void UnboundBuffer::handleRecvCompletion(int rank) { std::lock_guard lock(m_); @@ -60,6 +66,9 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) { if (recvCompletions_ == 0) { auto done = recvCv_.wait_for(lock, timeout, [&] { throwIfException(); + if(gloo::_is_aborted()) { + abortWaitRecv_ = true; + } return abortWaitRecv_ || recvCompletions_ > 0; }); if (!done) { @@ -111,9 +120,12 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) { if (sendCompletions_ == 0) { auto done = sendCv_.wait_for(lock, timeout, [&] { - throwIfException(); - return abortWaitSend_ || sendCompletions_ > 0; - }); + throwIfException(); + if(gloo::_is_aborted()) { + abortWaitSend_ = true; + } + return abortWaitSend_ || sendCompletions_ > 0; + }); if (!done) { // Below, we let all pairs in the transport context know about this // application side timeout. This in turn will call into all pending diff --git a/gloo/transport/uv/unbound_buffer.cc b/gloo/transport/uv/unbound_buffer.cc index bc9ba1c97..858d89417 100644 --- a/gloo/transport/uv/unbound_buffer.cc +++ b/gloo/transport/uv/unbound_buffer.cc @@ -28,9 +28,15 @@ UnboundBuffer::UnboundBuffer( recvRank_(-1), sendCompletions_(0), sendRank_(-1), - shareableNonOwningPtr_(this) {} + shareableNonOwningPtr_(this) { + gloo::_register_cv(&recvCv_); + gloo::_register_cv(&sendCv_); +} -UnboundBuffer::~UnboundBuffer() {} +UnboundBuffer::~UnboundBuffer() { + gloo::_deregister_cv(&recvCv_); + gloo::_deregister_cv(&sendCv_); +} void UnboundBuffer::handleRecvCompletion(int rank) { std::lock_guard lock(mutex_); @@ -58,8 +64,12 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) { } if (recvCompletions_ == 0) { - auto done = recvCv_.wait_for( - lock, timeout, [&] { return abortWaitRecv_ || recvCompletions_ > 0; }); + auto done = recvCv_.wait_for(lock, timeout, [&] { + if(gloo::_is_aborted()) { + abortWaitRecv_ = true; + } + return abortWaitRecv_ || recvCompletions_ > 0; + }); if (!done) { throw ::gloo::IoException(GLOO_ERROR_MSG( "Timed out waiting ", @@ -94,8 +104,12 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) { } if (sendCompletions_ == 0) { - auto done = sendCv_.wait_for( - lock, timeout, [&] { return abortWaitSend_ || sendCompletions_ > 0; }); + auto done = sendCv_.wait_for(lock, timeout, [&] { + if(gloo::_is_aborted()) { + abortWaitSend_ = true; + } + return abortWaitSend_ || sendCompletions_ > 0; + }); if (!done) { throw ::gloo::IoException(GLOO_ERROR_MSG( "Timed out waiting ",