From 94c5dd7a1cd2173ab94759934efa4fcecc6501de Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 9 Jan 2024 13:49:21 +0100 Subject: [PATCH] Fix segmentation fault on canceled requests that are not awaited (#153) Calling `InflightRequests::cancelAll()` does not guarantee that the requests are canceled immediately; just like any other request, we must check their status before releasing the `ucxx::Request` object. To prevent the user from releasing a request that has not yet completed and may still call the completion callback, we must make sure we keep a reference to it until its status is set. With this change we now ensure both `ucxx::Endpoint` and `ucxx::Worker` will only release their references once the requests for which cancelation was issued have completed. It is also important that all requests have a valid status before the object is destroyed, thus we should cancel a request if all references to the Cython `UCXRequest` object have been dropped but the request has not completed yet. Additionally, use `std::unique_ptr` in `ucxx::Worker` and re-enable `test_ucxx_unreachable`.
Authors: - Peter Andreas Entschev (https://github.com/pentschev) Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: https://github.com/rapidsai/ucxx/pull/153 --- cpp/include/ucxx/inflight_requests.h | 64 +++++++++++---- cpp/include/ucxx/worker.h | 10 +-- cpp/src/endpoint.cpp | 16 ++-- cpp/src/inflight_requests.cpp | 77 +++++++++++++++---- cpp/src/worker.cpp | 27 ++++--- .../distributed_ucxx/tests/test_ucxx.py | 4 - python/ucxx/_lib/libucxx.pyx | 1 + python/ucxx/_lib/ucxx_api.pxd | 1 + 8 files changed, 143 insertions(+), 57 deletions(-) diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index a068b3d0..a1712761 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace ucxx { @@ -14,16 +15,37 @@ class Request; typedef std::map> InflightRequestsMap; typedef std::unique_ptr InflightRequestsMapPtr; +typedef struct TrackedRequests { + InflightRequestsMapPtr _inflight; + InflightRequestsMapPtr _canceling; + + TrackedRequests() + : _inflight(std::make_unique()), + _canceling(std::make_unique()) + { + } +} TrackedRequests; +typedef std::unique_ptr TrackedRequestsPtr; class InflightRequests { private: - InflightRequestsMapPtr _inflightRequests{ - std::make_unique()}; ///< Container storing pointers to all inflight - ///< requests known to the owner of this object + TrackedRequestsPtr _trackedRequests{ + std::make_unique()}; ///< Container storing pointers to all inflight + ///< and in cancelation process requests known to + ///< the owner of this object std::mutex _mutex{}; ///< Mutex to control access to inflight requests container std::mutex _cancelMutex{}; ///< Mutex to allow cancelation and prevent removing requests simultaneously + /** + * @brief Drop references to requests that completed cancelation. + * + * Drops references to requests that completed cancelation and stop tracking them. 
+ * + * @returns The number of requests that have completed cancelation since last call. + */ + size_t dropCanceled(); + public: /** * @brief Default constructor. @@ -57,15 +79,15 @@ class InflightRequests { void insert(std::shared_ptr request); /** - * @brief Merge a container of inflight requests with the internal container. + * @brief Merge containers of inflight requests with the internal containers. * - * Merge a container of inflight requests obtained from `InflightRequests::release()` of - * another object with the internal container. + * Merge containers of inflight requests obtained from `InflightRequests::release()` of + * another object with the internal containers. * - * @param[in] inflightRequestsMap container of inflight requests to merge with the - * internal container. + * @param[in] trackedRequestsPtr containers of tracked inflight requests to merge with the + * internal tracked inflight requests. */ - void merge(InflightRequestsMapPtr inflightRequestsMap); + void merge(TrackedRequestsPtr trackedRequests); /** * @brief Remove an inflight request from the internal container. @@ -92,15 +114,27 @@ class InflightRequests { size_t cancelAll(); /** - * @brief Releases the internal container. + * @brief Releases the internally-tracked containers. + * + * Releases the internally-tracked containers that can be merged into another + * `InflightRequests` object with `InflightRequests::merge()`. Effectively leaves the + * internal state as a clean, new object. + * + * @returns The internally-tracked containers. + */ + TrackedRequestsPtr release(); + + /** + * @brief Get count of requests in process of cancelation. * - * Releases the internal container that can be merged into another `InflightRequests` - * object with `InflightRequests::release()`. Effectively leaves the internal state as a - * clean, new object. 
+ * After `cancelAll()` is called the requests are scheduled for cancelation but may not + * complete immediately, therefore they are tracked until cancelation is completed. This + * method returns the count of requests in process of cancelation and drops references + * to those that have completed. * - * @returns The internal container. + * @returns The count of requests that are in process of cancelation. */ - InflightRequestsMapPtr release(); + size_t getCancelingSize(); }; } // namespace ucxx diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h index ada15f2c..a12bb5d3 100644 --- a/cpp/include/ucxx/worker.h +++ b/cpp/include/ucxx/worker.h @@ -40,12 +40,12 @@ class Worker : public Component { int _epollFileDescriptor{-1}; ///< The epoll file descriptor int _workerFileDescriptor{-1}; ///< The worker file descriptor std::mutex _inflightRequestsMutex{}; ///< Mutex to access the inflight requests pool - std::shared_ptr _inflightRequests{ - std::make_shared()}; ///< The inflight requests + std::unique_ptr _inflightRequests{ + std::make_unique()}; ///< The inflight requests std::mutex _inflightRequestsToCancelMutex{}; ///< Mutex to access the inflight requests to cancel pool - std::shared_ptr _inflightRequestsToCancel{ - std::make_shared()}; ///< The inflight requests scheduled to be canceled + std::unique_ptr _inflightRequestsToCancel{ + std::make_unique()}; ///< The inflight requests scheduled to be canceled std::shared_ptr _progressThread{nullptr}; ///< The progress thread object std::thread::id _progressThreadId{}; ///< The progress thread ID std::function _progressThreadStartCallback{ @@ -601,7 +601,7 @@ class Worker : public Component { * * @param[in] inflight requests object that implements the `cancelAll()` method. */ - void scheduleRequestCancel(std::shared_ptr inflightRequests); + void scheduleRequestCancel(TrackedRequestsPtr trackedRequests); /** * @brief Remove reference to request from internal container. 
diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 5e40f028..c5bccf40 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -258,7 +258,7 @@ std::shared_ptr Endpoint::registerInflightRequest(std::shared_ptrstatus != UCS_OK) - _callbackData->worker->scheduleRequestCancel(_inflightRequests); + _callbackData->worker->scheduleRequestCancel(_inflightRequests->release()); return request; } @@ -275,22 +275,24 @@ size_t Endpoint::cancelInflightRequests(uint64_t period, uint64_t maxAttempts) if (std::this_thread::get_id() == worker->getProgressThreadId()) { canceled = _inflightRequests->cancelAll(); - worker->progress(); + for (uint64_t i = 0; i < maxAttempts && _inflightRequests->getCancelingSize() > 0; ++i) + worker->progress(); } else if (worker->isProgressThreadRunning()) { bool cancelSuccess = false; for (uint64_t i = 0; i < maxAttempts && !cancelSuccess; ++i) { utils::CallbackNotifier callbackNotifierPre{}; worker->registerGenericPre([this, &callbackNotifierPre, &canceled]() { - canceled = _inflightRequests->cancelAll(); + canceled += _inflightRequests->cancelAll(); callbackNotifierPre.set(); }); if (!callbackNotifierPre.wait(period)) continue; utils::CallbackNotifier callbackNotifierPost{}; - worker->registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); }); + worker->registerGenericPost([this, &callbackNotifierPost, &cancelSuccess]() { + cancelSuccess = _inflightRequests->getCancelingSize() == 0; + callbackNotifierPost.set(); + }); if (!callbackNotifierPost.wait(period)) continue; - - cancelSuccess = true; } if (!cancelSuccess) ucxx_error("All attempts to cancel inflight requests failed on endpoint: %p, UCP handle: %p", @@ -402,7 +404,7 @@ void Endpoint::errorCallback(void* arg, ucp_ep_h ep, ucs_status_t status) { ErrorCallbackData* data = reinterpret_cast(arg); data->status = status; - data->worker->scheduleRequestCancel(data->inflightRequests); + data->worker->scheduleRequestCancel(data->inflightRequests->release()); if 
(data->closeCallback) { ucxx_debug("Calling user callback for endpoint %p", ep); data->closeCallback(data->closeCallbackArg); diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 47d51581..50ada0c8 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -12,20 +12,25 @@ namespace ucxx { InflightRequests::~InflightRequests() { cancelAll(); } -size_t InflightRequests::size() { return _inflightRequests->size(); } +size_t InflightRequests::size() { return _trackedRequests->_inflight->size(); } void InflightRequests::insert(std::shared_ptr request) { std::lock_guard lock(_mutex); - _inflightRequests->insert({request.get(), request}); + _trackedRequests->_inflight->insert({request.get(), request}); } -void InflightRequests::merge(InflightRequestsMapPtr inflightRequestsMap) +void InflightRequests::merge(TrackedRequestsPtr trackedRequests) { - std::lock_guard lock(_mutex); - - _inflightRequests->merge(*inflightRequestsMap); + { + std::lock_guard lock(_mutex); + _trackedRequests->_inflight->merge(*(trackedRequests->_inflight)); + } + { + std::lock_guard lock(_cancelMutex); + _trackedRequests->_canceling->merge(*(trackedRequests->_canceling)); + } } void InflightRequests::remove(const Request* const request) @@ -46,19 +51,19 @@ void InflightRequests::remove(const Request* const request) if (result == 0) { return; } else if (result == -1) { - auto search = _inflightRequests->find(request); + auto search = _trackedRequests->_inflight->find(request); decltype(search->second) tmpRequest; - if (search != _inflightRequests->end()) { + if (search != _trackedRequests->_inflight->end()) { /** * If this is the last request to hold `std::shared_ptr` erasing it * may cause the `ucxx::Endpoint`s destructor and subsequently the `close()` method * to be called which will in turn call `cancelAll()` and attempt to take the * mutexes. 
For this reason we should make a temporary copy of the request being - * erased from `_inflightRequests` to allow unlocking the mutexes and only then + * erased from `_trackedRequests->_inflight` to allow unlocking the mutexes and only then * destroy the object upon this method's return. */ tmpRequest = search->second; - _inflightRequests->erase(search); + _trackedRequests->_inflight->erase(search); } _cancelMutex.unlock(); _mutex.unlock(); @@ -67,19 +72,52 @@ void InflightRequests::remove(const Request* const request) } while (true); } +size_t InflightRequests::dropCanceled() +{ + size_t removed = 0; + + { + std::scoped_lock lock{_cancelMutex}; + for (auto it = _trackedRequests->_canceling->begin(); + it != _trackedRequests->_canceling->end();) { + auto request = it->second; + if (request != nullptr && request->getStatus() != UCS_INPROGRESS) { + it = _trackedRequests->_canceling->erase(it); + ++removed; + } else { + ++it; + } + } + } + + return removed; +} + +size_t InflightRequests::getCancelingSize() +{ + dropCanceled(); + size_t cancelingSize = 0; + { + std::scoped_lock lock{_cancelMutex}; + cancelingSize = _trackedRequests->_canceling->size(); + } + + return cancelingSize; +} + size_t InflightRequests::cancelAll() { - decltype(_inflightRequests) toCancel; + decltype(_trackedRequests->_inflight) toCancel; size_t total; { std::scoped_lock lock{_cancelMutex, _mutex}; - total = _inflightRequests->size(); + total = _trackedRequests->_inflight->size(); // Fast path when no requests have been registered or the map has been // previously released. 
if (total == 0) return 0; - toCancel = std::exchange(_inflightRequests, std::make_unique()); + toCancel = std::exchange(_trackedRequests->_inflight, std::make_unique()); } ucxx_debug("Canceling %lu requests", total); @@ -88,16 +126,21 @@ size_t InflightRequests::cancelAll() auto request = r.second; if (request != nullptr) { request->cancel(); } } - toCancel->clear(); + + { + std::scoped_lock lock{_cancelMutex, _mutex}; + _trackedRequests->_canceling->merge(*toCancel); + } + dropCanceled(); return total; } -InflightRequestsMapPtr InflightRequests::release() +TrackedRequestsPtr InflightRequests::release() { - std::lock_guard lock(_mutex); + std::scoped_lock lock{_cancelMutex, _mutex}; - return std::exchange(_inflightRequests, std::make_unique()); + return std::exchange(_trackedRequests, std::make_unique()); } } // namespace ucxx diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index 8ccf8a89..71ad2afe 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -403,7 +403,7 @@ size_t Worker::cancelInflightRequests(uint64_t period, uint64_t maxAttempts) { size_t canceled = 0; - auto inflightRequestsToCancel = std::make_shared(); + auto inflightRequestsToCancel = std::make_unique(); { std::lock_guard lock(_inflightRequestsMutex); std::swap(_inflightRequestsToCancel, inflightRequestsToCancel); @@ -411,22 +411,25 @@ size_t Worker::cancelInflightRequests(uint64_t period, uint64_t maxAttempts) if (std::this_thread::get_id() == getProgressThreadId()) { canceled = inflightRequestsToCancel->cancelAll(); - progressPending(); + for (uint64_t i = 0; i < maxAttempts && inflightRequestsToCancel->getCancelingSize() > 0; ++i) + progressPending(); } else if (isProgressThreadRunning()) { bool cancelSuccess = false; for (uint64_t i = 0; i < maxAttempts && !cancelSuccess; ++i) { utils::CallbackNotifier callbackNotifierPre{}; registerGenericPre([&callbackNotifierPre, &canceled, &inflightRequestsToCancel]() { - canceled = inflightRequestsToCancel->cancelAll(); + canceled += 
inflightRequestsToCancel->cancelAll(); callbackNotifierPre.set(); }); if (!callbackNotifierPre.wait(period)) continue; utils::CallbackNotifier callbackNotifierPost{}; - registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); }); + registerGenericPost( + [this, &callbackNotifierPost, &inflightRequestsToCancel, &cancelSuccess]() { + cancelSuccess = inflightRequestsToCancel->getCancelingSize() == 0; + callbackNotifierPost.set(); + }); if (!callbackNotifierPost.wait(period)) continue; - - cancelSuccess = true; } if (!cancelSuccess) @@ -437,15 +440,21 @@ size_t Worker::cancelInflightRequests(uint64_t period, uint64_t maxAttempts) canceled = inflightRequestsToCancel->cancelAll(); } + if (inflightRequestsToCancel->getCancelingSize() > 0) { + std::lock_guard lock(_inflightRequestsMutex); + _inflightRequestsToCancel->merge(inflightRequestsToCancel->release()); + } + return canceled; } -void Worker::scheduleRequestCancel(std::shared_ptr inflightRequests) +void Worker::scheduleRequestCancel(TrackedRequestsPtr trackedRequests) { { std::lock_guard lock(_inflightRequestsMutex); - ucxx_debug("Scheduling cancelation of %lu requests", inflightRequests->size()); - _inflightRequestsToCancel->merge(inflightRequests->release()); + ucxx_debug("Scheduling cancelation of %lu requests", + trackedRequests->_inflight->size() + trackedRequests->_canceling->size()); + _inflightRequestsToCancel->merge(std::move(trackedRequests)); } } diff --git a/python/distributed-ucxx/distributed_ucxx/tests/test_ucxx.py b/python/distributed-ucxx/distributed_ucxx/tests/test_ucxx.py index 91af3dad..16f7116a 100644 --- a/python/distributed-ucxx/distributed_ucxx/tests/test_ucxx.py +++ b/python/distributed-ucxx/distributed_ucxx/tests/test_ucxx.py @@ -393,10 +393,6 @@ async def test_ucxx_protocol(ucxx_loop, cleanup, port): @gen_test() -@pytest.mark.skipif( - int(os.environ.get("UCXPY_ENABLE_PYTHON_FUTURE", "1")) != 0, - reason="Segfaults when Python futures are enabled", -) async def 
test_ucxx_unreachable( ucxx_loop, ): diff --git a/python/ucxx/_lib/libucxx.pyx b/python/ucxx/_lib/libucxx.pyx index 49ff590d..57534320 100644 --- a/python/ucxx/_lib/libucxx.pyx +++ b/python/ucxx/_lib/libucxx.pyx @@ -691,6 +691,7 @@ cdef class UCXRequest(): def __dealloc__(self): with nogil: + self._request.get().cancel() self._request.reset() def is_completed(self): diff --git a/python/ucxx/_lib/ucxx_api.pxd b/python/ucxx/_lib/ucxx_api.pxd index b0ff3476..39fa86b3 100644 --- a/python/ucxx/_lib/ucxx_api.pxd +++ b/python/ucxx/_lib/ucxx_api.pxd @@ -327,6 +327,7 @@ cdef extern from "" namespace "ucxx" nogil: void checkError() except +raise_py_error void* getFuture() except +raise_py_error shared_ptr[Buffer] getRecvBuffer() except +raise_py_error + void cancel() cdef extern from "" namespace "ucxx" nogil: