Skip to content

Commit

Permalink
Refactor tasks and make cancellable
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisspyB committed Nov 29, 2024
1 parent c620b78 commit 94bd63e
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 74 deletions.
8 changes: 3 additions & 5 deletions src/gribjump/Engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,17 @@ TaskReport Engine::scheduleExtractionTasks(filemap_t& filemap){

bool inefficientExtraction = LibGribJump::instance().config().getBool("inefficientExtraction", false);

size_t counter = 0;
TaskGroup taskGroup;

for (auto& [fname, extractionItems] : filemap) {
if (extractionItems[0]->isRemote()) {
if (inefficientExtraction) {
taskGroup.enqueueTask(new InefficientFileExtractionTask(taskGroup, counter++, fname, extractionItems));
taskGroup.enqueueTask<InefficientFileExtractionTask>(fname, extractionItems);
} else {
throw eckit::SeriousBug("Got remote URI from FDB, but forwardExtraction enabled in gribjump config.");
}
} else {
taskGroup.enqueueTask(new FileExtractionTask(taskGroup, counter++, fname, extractionItems));
taskGroup.enqueueTask<FileExtractionTask>(fname, extractionItems);
}
}
taskGroup.waitForTasks();
Expand Down Expand Up @@ -208,11 +207,10 @@ TaskOutcome<size_t> Engine::scan(std::vector<eckit::PathName> files) {

TaskOutcome<size_t> Engine::scheduleScanTasks(const scanmap_t& scanmap) {

size_t counter = 0;
std::atomic<size_t> nfields(0);
TaskGroup taskGroup;
for (auto& [uri, offsets] : scanmap) {
taskGroup.enqueueTask(new FileScanTask(taskGroup, counter++, uri.path(), offsets, nfields));
taskGroup.enqueueTask<FileScanTask>(uri.path(), offsets, nfields);
}
taskGroup.waitForTasks();

Expand Down
7 changes: 2 additions & 5 deletions src/gribjump/Forwarder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@ TaskOutcome<size_t> Forwarder::scan(const std::vector<eckit::URI>& uris) {
}

TaskGroup taskGroup;
size_t counter = 0;
std::atomic<size_t> nFields(0);
size_t i = 0;
for (auto& [endpoint, scanmap] : serverfilemaps) {
taskGroup.enqueueTask(new ForwardScanTask(taskGroup, counter++, endpoint, scanmap, nFields));
taskGroup.enqueueTask<ForwardScanTask>(endpoint, scanmap, nFields);
}
taskGroup.waitForTasks();

Expand All @@ -57,9 +55,8 @@ TaskReport Forwarder::extract(filemap_t& filemap) {
std::unordered_map<eckit::net::Endpoint, filemap_t> serverfilemaps = serverFileMap(filemap);

TaskGroup taskGroup;
size_t counter = 0;
for (auto& [endpoint, subfilemap] : serverfilemaps) {
taskGroup.enqueueTask(new ForwardExtractionTask(taskGroup, counter++, endpoint, subfilemap));
taskGroup.enqueueTask<ForwardExtractionTask>(endpoint, subfilemap);
}
taskGroup.waitForTasks();

Expand Down
98 changes: 61 additions & 37 deletions src/gribjump/Task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,7 @@

namespace gribjump {

/// Renders the calling thread's std::thread::id as text using its stream insertion operator.
/// @return the streamed representation of std::this_thread::get_id().
static std::string thread_id_str() {
    std::ostringstream rendered;
    rendered << std::this_thread::get_id();
    return rendered.str();
}
// Compile-time switch read by TaskGroup::notifyError(): when true, a reported error triggers cancelTasks().
constexpr bool cancelOnFirstError = true; ///@todo make this configurable

//----------------------------------------------------------------------------------------------------------------------

Expand All @@ -45,65 +40,106 @@ Task::Task(TaskGroup& taskGroup, size_t taskid) : taskGroup_(taskGroup), taskid_
// Out-of-line destructor definition with an empty body.
Task::~Task() {}

/// Records successful completion on this task, then reports it to the TaskGroup it belongs to.
/// The status is written before the group is told, so the order of the two statements matters.
void Task::notify() {
    status_.store(Status::DONE);
    taskGroup_.notify(id());
}

/// Marks this task as failed, then forwards the error text to the TaskGroup it belongs to.
/// @param s error description handed on unchanged to TaskGroup::notifyError().
void Task::notifyError(const std::string& s) {
    status_.store(Status::FAILED);
    taskGroup_.notifyError(id(), s);
}

/// Marks this task as cancelled, then reports the cancellation to the TaskGroup it belongs to.
void Task::notifyCancelled() {
    status_.store(Status::CANCELLED);
    taskGroup_.notifyCancelled(id());
}

void Task::execute() {
// atomically set status to executing, but only if it is currently pending
Status expected = Status::PENDING;
if (!status_.compare_exchange_strong(expected, Status::EXECUTING)) {
return;
}
executeImpl();
notify();
}

void Task::cancel() {
// atomically set status to cancelled, but only if it is pending
Status expected = Status::PENDING;
status_.compare_exchange_strong(expected, Status::CANCELLED);
}

//----------------------------------------------------------------------------------------------------------------------

// Called when a task finishes successfully: bumps the completed-task count while holding m_, logs
// progress each time the count reaches logcounter_ (only while waiting_ is set), and wakes one
// thread blocked on cv_.
// NOTE(review): this span is a rendered diff hunk without +/- markers. The lines using
// taskStatus_ / counter_ and the lines using tasks_ / nComplete_ appear back to back — presumably
// the former are the pre-commit text and the latter their replacements; confirm against the
// repository. As shown, the braces do not balance and the block does not compile.
void TaskGroup::notify(size_t taskid) {
std::lock_guard<std::mutex> lock(m_);
taskStatus_[taskid] = Task::Status::DONE;
counter_++;
nComplete_++;

// Logging progress
if (waiting_) {
if (counter_ == logcounter_) {
eckit::Log::info() << "Gribjump Progress: " << counter_ << " of " << taskStatus_.size() << " tasks complete" << std::endl;
if (nComplete_ == logcounter_) {
eckit::Log::info() << "Gribjump Progress: " << nComplete_ << " of " << tasks_.size() << " tasks complete" << std::endl;
logcounter_ += logincrement_;
}
}

cv_.notify_one();
}

/// Accounts for a cancelled task: it counts towards both the cancelled total and the completed
/// total, after which one thread waiting on cv_ is woken. All updates happen while holding m_.
/// @param taskid identifier of the cancelled task; not used in the body.
void TaskGroup::notifyCancelled(size_t taskid) {
    std::lock_guard<std::mutex> guard(m_);
    ++nCancelledTasks_;
    ++nComplete_;
    cv_.notify_one();
}

// Called when a task fails: stores the error text in errors_, counts the task as complete, wakes
// one thread blocked on cv_, and — when cancelOnFirstError is true — cancels the group's tasks.
// Everything, including the cancelTasks() call, runs while m_ is held by the lock_guard below.
// NOTE(review): this span is a rendered diff hunk without +/- markers. The taskStatus_ / counter_
// lines and the nComplete_ line appear back to back — presumably old text followed by its
// replacement; confirm against the repository before treating this as the real function body.
void TaskGroup::notifyError(size_t taskid, const std::string& s) {
std::lock_guard<std::mutex> lock(m_);
taskStatus_[taskid] = Task::Status::FAILED;
errors_.push_back(s);
counter_++;
nComplete_++;
cv_.notify_one();

if (cancelOnFirstError) {
cancelTasks();
}
}

// Note: This will only affect tasks that have not yet started. Cancelled tasks will call notifyCancelled() when they are executed.
// NB: We do not lock a mutex as this will be called from notifyError()
void TaskGroup::cancelTasks() {
for (auto& task : tasks_) {
task->cancel();
}
}

// Registers a task with this group and pushes it onto the shared WorkQueue.
// NOTE(review): this span is a rendered diff hunk without +/- markers. The first four body lines
// (taskStatus_ / WorkItem / queue.push) and the next three (lock, tasks_.push_back,
// WorkQueue::instance().push) appear back to back — presumably old text followed by its
// replacement; confirm against the repository. Read literally, the task would be queued twice.
void TaskGroup::enqueueTask(Task* task) {
taskStatus_.push_back(Task::Status::PENDING);
WorkItem w(task);
WorkQueue& queue = WorkQueue::instance();
queue.push(w);
std::lock_guard<std::mutex> lock(m_);
tasks_.push_back(std::unique_ptr<Task>(task)); // TaskGroup takes ownership of its tasks
WorkQueue::instance().push(task);

LOG_DEBUG_LIB(LibGribJump) << "Queued task " << tasks_.size() << std::endl;
}

void TaskGroup::waitForTasks() {
ASSERT(taskStatus_.size() > 0); // todo Might want to allow for "no tasks" case, though be careful with the lock / counter.
LOG_DEBUG_LIB(LibGribJump) << "Waiting for " << eckit::Plural(taskStatus_.size(), "task") << "..." << std::endl;
std::unique_lock<std::mutex> lock(m_);
ASSERT(tasks_.size() > 0);
LOG_DEBUG_LIB(LibGribJump) << "Waiting for " << eckit::Plural(tasks_.size(), "task") << "..." << std::endl;

waiting_ = true;
logincrement_ = taskStatus_.size() / 10;
logincrement_ = tasks_.size() / 10;
if (logincrement_ == 0) {
logincrement_ = 1;
}

cv_.wait(lock, [&]{return counter_ == taskStatus_.size();});
cv_.wait(lock, [&]{return nComplete_ == tasks_.size();});
waiting_ = false;
done_ = true;
LOG_DEBUG_LIB(LibGribJump) << "All tasks complete" << std::endl;

MetricsManager::instance().set("count_tasks", taskStatus_.size());
MetricsManager::instance().set("count_tasks", tasks_.size());
MetricsManager::instance().set("count_failed_tasks", errors_.size());
MetricsManager::instance().set("count_cancelled_tasks", nCancelledTasks_);

if (errors_.size() > 0) {
MetricsManager::instance().set("first_error", errors_[0]);
Expand Down Expand Up @@ -143,18 +179,14 @@ FileExtractionTask::FileExtractionTask(TaskGroup& taskgroup, const size_t id, co
{
}

// Sorts this task's extraction items by ascending offset and then runs extract().
// NOTE(review): this span is a rendered diff hunk without +/- markers. Two function headers
// (execute and executeImpl) appear back to back, so the block does not compile as shown.
// Presumably execute(), the thread-id timer lines and the trailing notify() are the pre-commit
// text, since Task::execute() calls notify() after executeImpl() — confirm against the repository.
void FileExtractionTask::execute() {
const std::string thread_id = thread_id_str();
eckit::Timer full_timer("Thread total time. Thread: " + thread_id, eckit::Log::debug());
void FileExtractionTask::executeImpl() {

// Sort extractionItems_ by offset
std::sort(extractionItems_.begin(), extractionItems_.end(), [](const ExtractionItem* a, const ExtractionItem* b) {
return a->offset() < b->offset();
});

extract();

notify();
}

void FileExtractionTask::extract() {
Expand Down Expand Up @@ -201,12 +233,10 @@ ForwardExtractionTask::ForwardExtractionTask(TaskGroup& taskgroup, const size_t
filemap_(filemap)
{}

// Builds a RemoteGribJump for endpoint_ and calls forwardExtract() on it with filemap_.
// NOTE(review): this span is a rendered diff hunk without +/- markers. Two function headers
// (execute and executeImpl) appear back to back, so the block does not compile as shown.
// Presumably execute() and the trailing notify() are the pre-commit text — confirm against the
// repository.
void ForwardExtractionTask::execute(){
void ForwardExtractionTask::executeImpl(){

RemoteGribJump remoteGribJump(endpoint_);
remoteGribJump.forwardExtract(filemap_);

notify();
}

ForwardScanTask::ForwardScanTask(TaskGroup& taskgroup, const size_t id, eckit::net::Endpoint endpoint, scanmap_t& scanmap, std::atomic<size_t>& nfields):
Expand All @@ -216,11 +246,10 @@ ForwardScanTask::ForwardScanTask(TaskGroup& taskgroup, const size_t id, eckit::n
nfields_(nfields) {
}

// Builds a RemoteGribJump for endpoint_, calls forwardScan() on it with scanmap_, and adds the
// returned value to the shared nfields_ counter.
// NOTE(review): this span is a rendered diff hunk without +/- markers. Two function headers
// (execute and executeImpl) appear back to back, so the block does not compile as shown.
// Presumably execute() and the trailing notify() are the pre-commit text — confirm against the
// repository.
void ForwardScanTask::execute(){
void ForwardScanTask::executeImpl(){

RemoteGribJump remoteGribJump(endpoint_);
nfields_ += remoteGribJump.forwardScan(scanmap_);
notify();
}

//----------------------------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -277,15 +306,10 @@ FileScanTask::FileScanTask(TaskGroup& taskgroup, const size_t id, const eckit::P
nfields_(nfields){
}

// Sorts offsets_ into ascending order and then runs scan().
// NOTE(review): this span is a rendered diff hunk without +/- markers. Two function headers
// (execute and executeImpl) appear back to back, so the block does not compile as shown.
// Presumably execute(), the two eckit::Timer lines and the trailing notify() are the pre-commit
// text — confirm against the repository.
void FileScanTask::execute() {
eckit::Timer timer;
eckit::Timer full_timer("Thread total time. Thread: " + thread_id_str());
void FileScanTask::executeImpl() {

std::sort(offsets_.begin(), offsets_.end());

scan();

notify();
}

void FileScanTask::scan() {
Expand Down
Loading

0 comments on commit 94bd63e

Please sign in to comment.