From 94bd63ef3d691e08449273378f8ee2e6bb4a64ba Mon Sep 17 00:00:00 2001 From: Chris Bradley Date: Fri, 29 Nov 2024 18:36:28 +0000 Subject: [PATCH] Refactor tasks and make cancellable --- src/gribjump/Engine.cc | 8 +-- src/gribjump/Forwarder.cc | 7 +-- src/gribjump/Task.cc | 98 ++++++++++++++++++++------------ src/gribjump/Task.h | 75 ++++++++++++++++-------- src/gribjump/remote/WorkItem.h | 3 +- src/gribjump/remote/WorkQueue.cc | 3 +- src/gribjump/remote/WorkQueue.h | 3 +- 7 files changed, 123 insertions(+), 74 deletions(-) diff --git a/src/gribjump/Engine.cc b/src/gribjump/Engine.cc index 067399f..7d07b6f 100644 --- a/src/gribjump/Engine.cc +++ b/src/gribjump/Engine.cc @@ -114,18 +114,17 @@ TaskReport Engine::scheduleExtractionTasks(filemap_t& filemap){ bool inefficientExtraction = LibGribJump::instance().config().getBool("inefficientExtraction", false); - size_t counter = 0; TaskGroup taskGroup; for (auto& [fname, extractionItems] : filemap) { if (extractionItems[0]->isRemote()) { if (inefficientExtraction) { - taskGroup.enqueueTask(new InefficientFileExtractionTask(taskGroup, counter++, fname, extractionItems)); + taskGroup.enqueueTask(fname, extractionItems); } else { throw eckit::SeriousBug("Got remote URI from FDB, but forwardExtraction enabled in gribjump config."); } } else { - taskGroup.enqueueTask(new FileExtractionTask(taskGroup, counter++, fname, extractionItems)); + taskGroup.enqueueTask(fname, extractionItems); } } taskGroup.waitForTasks(); @@ -208,11 +207,10 @@ TaskOutcome Engine::scan(std::vector files) { TaskOutcome Engine::scheduleScanTasks(const scanmap_t& scanmap) { - size_t counter = 0; std::atomic nfields(0); TaskGroup taskGroup; for (auto& [uri, offsets] : scanmap) { - taskGroup.enqueueTask(new FileScanTask(taskGroup, counter++, uri.path(), offsets, nfields)); + taskGroup.enqueueTask(uri.path(), offsets, nfields); } taskGroup.waitForTasks(); diff --git a/src/gribjump/Forwarder.cc b/src/gribjump/Forwarder.cc index e2088e9..a6dd261 100644 --- a/src/gribjump/Forwarder.cc +++ b/src/gribjump/Forwarder.cc @@ -42,11 +42,9 @@ TaskOutcome Forwarder::scan(const std::vector& uris) { } TaskGroup taskGroup; - size_t counter = 0; std::atomic nFields(0); - size_t i = 0; for (auto& [endpoint, scanmap] : serverfilemaps) { - taskGroup.enqueueTask(new ForwardScanTask(taskGroup, counter++, endpoint, scanmap, nFields)); + taskGroup.enqueueTask(endpoint, scanmap, nFields); } taskGroup.waitForTasks(); @@ -57,9 +55,8 @@ TaskReport Forwarder::extract(filemap_t& filemap) { std::unordered_map serverfilemaps = serverFileMap(filemap); TaskGroup taskGroup; - size_t counter = 0; for (auto& [endpoint, subfilemap] : serverfilemaps) { - taskGroup.enqueueTask(new ForwardExtractionTask(taskGroup, counter++, endpoint, subfilemap)); + taskGroup.enqueueTask(endpoint, subfilemap); } taskGroup.waitForTasks(); diff --git a/src/gribjump/Task.cc b/src/gribjump/Task.cc index da450b6..54bc4fb 100644 --- a/src/gribjump/Task.cc +++ b/src/gribjump/Task.cc @@ -30,12 +30,7 @@ namespace gribjump { -static std::string thread_id_str() { - auto id = std::this_thread::get_id(); - std::stringstream ss; - ss << id; - return ss.str(); -} +constexpr bool cancelOnFirstError = true; ///@todo make this configurable //---------------------------------------------------------------------------------------------------------------------- @@ -45,24 +40,46 @@ Task::Task(TaskGroup& taskGroup, size_t taskid) : taskGroup_(taskGroup), taskid_ Task::~Task() {} void Task::notify() { + status_ = Status::DONE; taskGroup_.notify(id()); } void Task::notifyError(const std::string& s) { + status_ = Status::FAILED; taskGroup_.notifyError(id(), s); } +void Task::notifyCancelled() { + status_ = Status::CANCELLED; + taskGroup_.notifyCancelled(id()); +} + +void Task::execute() { + // atomically set status to executing, but only if it is currently pending + Status expected = Status::PENDING; + if (!status_.compare_exchange_strong(expected, Status::EXECUTING)) { + return; + } + executeImpl(); + notify(); +} + +void Task::cancel() { + // atomically set status to cancelled, but only if it is pending + Status expected = Status::PENDING; + status_.compare_exchange_strong(expected, Status::CANCELLED); +} + //---------------------------------------------------------------------------------------------------------------------- void TaskGroup::notify(size_t taskid) { std::lock_guard lock(m_); - taskStatus_[taskid] = Task::Status::DONE; - counter_++; + nComplete_++; // Logging progress if (waiting_) { - if (counter_ == logcounter_) { - eckit::Log::info() << "Gribjump Progress: " << counter_ << " of " << taskStatus_.size() << " tasks complete" << std::endl; + if (nComplete_ == logcounter_) { + eckit::Log::info() << "Gribjump Progress: " << nComplete_ << " of " << tasks_.size() << " tasks complete" << std::endl; logcounter_ += logincrement_; } } @@ -70,40 +87,59 @@ void TaskGroup::notify(size_t taskid) { cv_.notify_one(); } +void TaskGroup::notifyCancelled(size_t taskid) { + std::lock_guard lock(m_); + nComplete_++; + nCancelledTasks_++; + cv_.notify_one(); +} void TaskGroup::notifyError(size_t taskid, const std::string& s) { std::lock_guard lock(m_); - taskStatus_[taskid] = Task::Status::FAILED; errors_.push_back(s); - counter_++; + nComplete_++; cv_.notify_one(); + + if (cancelOnFirstError) { + cancelTasks(); + } +} + +// Note: This will only affect tasks that have not yet started. Cancelled tasks will call notifyCancelled() when they are executed. +// NB: We do not lock a mutex as this will be called from notifyError() +void TaskGroup::cancelTasks() { + for (auto& task : tasks_) { + task->cancel(); + } } void TaskGroup::enqueueTask(Task* task) { - taskStatus_.push_back(Task::Status::PENDING); - WorkItem w(task); - WorkQueue& queue = WorkQueue::instance(); - queue.push(w); + std::lock_guard lock(m_); + tasks_.push_back(std::unique_ptr(task)); // TaskGroup takes ownership of its tasks + WorkQueue::instance().push(task); + LOG_DEBUG_LIB(LibGribJump) << "Queued task " << tasks_.size() << std::endl; } void TaskGroup::waitForTasks() { - ASSERT(taskStatus_.size() > 0); // todo Might want to allow for "no tasks" case, though be careful with the lock / counter. - LOG_DEBUG_LIB(LibGribJump) << "Waiting for " << eckit::Plural(taskStatus_.size(), "task") << "..." << std::endl; std::unique_lock lock(m_); + ASSERT(tasks_.size() > 0); + LOG_DEBUG_LIB(LibGribJump) << "Waiting for " << eckit::Plural(tasks_.size(), "task") << "..." << std::endl; waiting_ = true; - logincrement_ = taskStatus_.size() / 10; + logincrement_ = tasks_.size() / 10; if (logincrement_ == 0) { logincrement_ = 1; } - cv_.wait(lock, [&]{return counter_ == taskStatus_.size();}); + cv_.wait(lock, [&]{return nComplete_ == tasks_.size();}); waiting_ = false; + done_ = true; LOG_DEBUG_LIB(LibGribJump) << "All tasks complete" << std::endl; - MetricsManager::instance().set("count_tasks", taskStatus_.size()); + MetricsManager::instance().set("count_tasks", tasks_.size()); MetricsManager::instance().set("count_failed_tasks", errors_.size()); + MetricsManager::instance().set("count_cancelled_tasks", nCancelledTasks_); if (errors_.size() > 0) { MetricsManager::instance().set("first_error", errors_[0]); @@ -143,9 +179,7 @@ FileExtractionTask::FileExtractionTask(TaskGroup& taskgroup, const size_t id, co { } -void FileExtractionTask::execute() { - const std::string thread_id = thread_id_str(); - eckit::Timer full_timer("Thread total time. Thread: " + thread_id, eckit::Log::debug()); +void FileExtractionTask::executeImpl() { // Sort extractionItems_ by offset std::sort(extractionItems_.begin(), extractionItems_.end(), [](const ExtractionItem* a, const ExtractionItem* b) { @@ -153,8 +187,6 @@ void FileExtractionTask::execute() { }); extract(); - - notify(); } void FileExtractionTask::extract() { @@ -201,12 +233,10 @@ ForwardExtractionTask::ForwardExtractionTask(TaskGroup& taskgroup, const size_t filemap_(filemap) {} -void ForwardExtractionTask::execute(){ +void ForwardExtractionTask::executeImpl(){ RemoteGribJump remoteGribJump(endpoint_); remoteGribJump.forwardExtract(filemap_); - - notify(); } ForwardScanTask::ForwardScanTask(TaskGroup& taskgroup, const size_t id, eckit::net::Endpoint endpoint, scanmap_t& scanmap, std::atomic& nfields): @@ -216,11 +246,10 @@ ForwardScanTask::ForwardScanTask(TaskGroup& taskgroup, const size_t id, eckit::n nfields_(nfields) { } -void ForwardScanTask::execute(){ +void ForwardScanTask::executeImpl(){ RemoteGribJump remoteGribJump(endpoint_); nfields_ += remoteGribJump.forwardScan(scanmap_); - notify(); } //---------------------------------------------------------------------------------------------------------------------- @@ -277,15 +306,10 @@ FileScanTask::FileScanTask(TaskGroup& taskgroup, const size_t id, const eckit::P nfields_(nfields){ } -void FileScanTask::execute() { - eckit::Timer timer; - eckit::Timer full_timer("Thread total time. Thread: " + thread_id_str()); +void FileScanTask::executeImpl() { std::sort(offsets_.begin(), offsets_.end()); - scan(); - - notify(); } void FileScanTask::scan() { diff --git a/src/gribjump/Task.h b/src/gribjump/Task.h index e44e84c..18bd30d 100644 --- a/src/gribjump/Task.h +++ b/src/gribjump/Task.h @@ -30,8 +30,10 @@ class Task { enum Status { DONE = 0, - PENDING = 1, - FAILED = 2 + PENDING, + FAILED, + EXECUTING, + CANCELLED, }; Task(TaskGroup& taskGroup, size_t id); @@ -41,18 +43,28 @@ class Task { size_t id() const { return taskid_; } /// executes the task to completion - virtual void execute() = 0; + virtual void execute() final; /// notifies the completion of the task - virtual void notify(); + void notify(); + + /// notifies that the task was cancelled before execution e.g. because of an error in a related task + void notifyCancelled(); /// notifies the error in execution of the task - virtual void notifyError(const std::string& s); + void notifyError(const std::string& s); + + /// cancels the task. If execute() is called after this, it will return immediately. + void cancel(); + +protected: + virtual void executeImpl() = 0; protected: TaskGroup& taskGroup_; //< Groups like-tasks to be executed in parallel size_t taskid_; //< Task id within parent request + std::atomic status_ = Status::PENDING; }; //---------------------------------------------------------------------------------------------------------------------- @@ -81,45 +93,60 @@ class TaskGroup { TaskGroup() = default; /// Notify that a task has been completed - /// potentially completing all the work for this request void notify(size_t taskid); /// Notify that a task has finished with error - /// potentially completing all the work for this request void notifyError(size_t taskid, const std::string& s); - /// Enqueue tasks to be executed to complete this request - void enqueueTask(Task* task); - + /// Notify that a task was cancelled + void notifyCancelled(size_t taskid); + + /// Enqueue tasks on the global task queue + template + void enqueueTask(Args&&... args) { + enqueueTask(new TaskType(*this, tasks_.size(), std::forward(args)...)); + } + /// Wait for all queued tasks to be executed void waitForTasks(); - TaskReport report() {return TaskReport(std::move(errors_)); } - - std::mutex debugMutex_; + /// Report on errors and other status information about executed tasks. + /// Calling code may use this to report to a client or raise an exception. + TaskReport report() { + std::lock_guard lock(m_); + ASSERT(done_); + return TaskReport(std::move(errors_)); + } size_t nTasks() const { std::lock_guard lock(m_); - return taskStatus_.size(); + return tasks_.size(); } + size_t nErrors() const { std::lock_guard lock(m_); return errors_.size(); } - + +private: + + void enqueueTask(Task* task); + + void cancelTasks(); + private: - int counter_ = 0; //< incremented by notify() or notifyError() + int nComplete_ = 0; //< incremented when a task completes + int nCancelledTasks_ = 0; //< incremented by notifyCancelled() int logcounter_ = 1; //< used to log progress int logincrement_ = 1; //< used to log progress - bool waiting_ = false; - + bool waiting_ = false; //< true if waiting for tasks to complete + bool done_ = false; //< true if all tasks have completed mutable std::mutex m_; std::condition_variable cv_; - std::vector> tasks_; //< stores tasks status, must be initialised by derived class - std::vector taskStatus_; + std::vector> tasks_; std::vector errors_; //< stores error messages, empty if no errors }; @@ -133,7 +160,7 @@ class FileExtractionTask : public Task { FileExtractionTask(TaskGroup& taskgroup, const size_t id, const eckit::PathName& fname, ExtractionItems& extractionItems); - void execute() override; + void executeImpl() override; virtual void extract(); @@ -164,7 +191,7 @@ class ForwardExtractionTask : public Task { ForwardExtractionTask(TaskGroup& taskgroup, const size_t id, eckit::net::Endpoint endpoint, filemap_t& filemap); - void execute() override; + void executeImpl() override; private: eckit::net::Endpoint endpoint_; @@ -177,7 +204,7 @@ class ForwardScanTask : public Task { ForwardScanTask(TaskGroup& taskgroup, const size_t id, eckit::net::Endpoint endpoint, scanmap_t& scanmap, std::atomic& nfields_); - void execute() override; + void executeImpl() override; private: eckit::net::Endpoint endpoint_; @@ -194,7 +221,7 @@ class FileScanTask : public Task { FileScanTask(TaskGroup& taskgroup, const size_t id, const eckit::PathName& fname, const std::vector& offsets, std::atomic& nfields); - void execute() override; + void executeImpl() override; void scan(); diff --git a/src/gribjump/remote/WorkItem.h b/src/gribjump/remote/WorkItem.h index c4582c4..c8bd638 100644 --- a/src/gribjump/remote/WorkItem.h +++ b/src/gribjump/remote/WorkItem.h @@ -34,7 +34,8 @@ class WorkItem { void error(const std::string& s); private: - Task* task_; + + Task* task_; //< non-owning pointer }; //---------------------------------------------------------------------------------------------------------------------- diff --git a/src/gribjump/remote/WorkQueue.cc b/src/gribjump/remote/WorkQueue.cc index 5a5dd36..f18e24a 100644 --- a/src/gribjump/remote/WorkQueue.cc +++ b/src/gribjump/remote/WorkQueue.cc @@ -65,7 +65,8 @@ WorkQueue::WorkQueue() : queue_(eckit::Resource("$GRIBJUMP_QUEUESIZE;gri } } -void WorkQueue::push(WorkItem& item) { +void WorkQueue::push(Task* task) { + WorkItem item(task); queue_.push(item); } diff --git a/src/gribjump/remote/WorkQueue.h b/src/gribjump/remote/WorkQueue.h index 323056d..0580b9f 100644 --- a/src/gribjump/remote/WorkQueue.h +++ b/src/gribjump/remote/WorkQueue.h @@ -18,6 +18,7 @@ #include "gribjump/ExtractionData.h" #include "gribjump/remote/WorkItem.h" +#include "gribjump/Task.h" namespace gribjump { @@ -30,7 +31,7 @@ class WorkQueue : private eckit::NonCopyable { ~WorkQueue(); - void push(WorkItem& item); + void push(Task* task); protected: WorkQueue();