Skip to content

Commit

Permalink
Changed HighsTaskExecutor to explicitly handle reference counting to …
Browse files Browse the repository at this point in the history
…ensure shared memory is kept alive until all threads have finished.

Changed identification of main thread to use std::thread::id rather than thread_local memory pointer.

Added vector of workerThreads which can be detached or joined on shutdown.

Ensured that shutdown can only block if called on main thread, otherwise it might be possible to deadlock.

Manually using cache_aligned memory allocation; it was previously used with shared_ptr, and I wanted to keep it just in case.

Note: I had some weird issues when compiling with the /MD flag with MSVC. It would run but often crash. The /MT flag works consistently for me.
  • Loading branch information
mathgeekcoder committed Sep 13, 2024
1 parent 338c47d commit a2d5db9
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 44 deletions.
17 changes: 14 additions & 3 deletions src/parallel/HighsTaskExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,18 @@ thread_local HighsTaskExecutor::ExecutorHandle
HighsTaskExecutor::globalExecutorHandle{};
#endif

HighsTaskExecutor::ExecutorHandle::~ExecutorHandle() {
if (ptr && this == ptr->mainWorkerHandle.load(std::memory_order_relaxed))
HighsTaskExecutor::shutdown();
// Releases this handle's reference to the executor, if it holds one.
// When called on the thread whose id is stored in mainWorkerId, the worker
// threads are asked to stop (non-blocking) before the reference is dropped.
// The handle that brings referenceCount to zero frees the executor.
// Calling dispose() on an empty handle is a no-op.
void HighsTaskExecutor::ExecutorHandle::dispose() {
  // empty handle: nothing to release
  if (ptr == nullptr) return;

  // the main worker is responsible for shutting the executor down
  const bool onMainWorker =
      ptr->mainWorkerId.load() == std::this_thread::get_id();
  if (onMainWorker) ptr->stopWorkerThreads(false);

  // fetch_sub returns the previous count, so 1 means we held the last reference
  const bool lastReference = ptr->referenceCount.fetch_sub(1) == 1;
  if (lastReference) cache_aligned::Deleter<HighsTaskExecutor>()(ptr);

  ptr = nullptr;
}
99 changes: 58 additions & 41 deletions src/parallel/HighsTaskExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ class HighsTaskExecutor {
public:
using cache_aligned = highs::cache_aligned;
struct ExecutorHandle {
cache_aligned::shared_ptr<HighsTaskExecutor> ptr{nullptr};
HighsTaskExecutor* ptr{nullptr};

~ExecutorHandle();
void dispose();
~ExecutorHandle() { dispose(); }
};

private:
Expand All @@ -52,9 +53,11 @@ class HighsTaskExecutor {
}
#endif

std::vector<cache_aligned::unique_ptr<HighsSplitDeque>> workerDeques;
std::atomic<int> referenceCount;
std::atomic<std::thread::id> mainWorkerId;
cache_aligned::shared_ptr<HighsSplitDeque::WorkerBunk> workerBunk;
std::atomic<ExecutorHandle*> mainWorkerHandle;
std::vector<cache_aligned::unique_ptr<HighsSplitDeque>> workerDeques;
std::vector<std::thread> workerThreads;

HighsTask* random_steal_loop(HighsSplitDeque* localDeque) {
const int numWorkers = workerDeques.size();
Expand Down Expand Up @@ -85,42 +88,68 @@ class HighsTaskExecutor {
return nullptr;
}

static void run_worker(
int workerId,
ExecutorHandle* executor,
highs::cache_aligned::shared_ptr<HighsTaskExecutor> ref) {
static void run_worker(int workerId, HighsTaskExecutor* ptr) {

threadLocalExecutorHandle().ptr = ptr;

// check if main thread has shutdown before thread has started
if (ptr->mainWorkerId.load() != std::thread::id()) {
HighsSplitDeque* localDeque = ptr->workerDeques[workerId].get();
threadLocalWorkerDeque() = localDeque;

// now acquire a reference count of the global executor
threadLocalExecutorHandle() = *executor;
HighsSplitDeque* localDeque = ref->workerDeques[workerId].get();
threadLocalWorkerDeque() = localDeque;
HighsTask* currentTask = ref->workerBunk->waitForNewTask(localDeque);
while (currentTask != nullptr) {
localDeque->runStolenTask(currentTask);
HighsTask* currentTask = ptr->workerBunk->waitForNewTask(localDeque);
while (currentTask != nullptr) {
localDeque->runStolenTask(currentTask);

currentTask = ref->random_steal_loop(localDeque);
if (currentTask != nullptr) continue;
currentTask = ptr->random_steal_loop(localDeque);
if (currentTask != nullptr) continue;

currentTask = ref->workerBunk->waitForNewTask(localDeque);
currentTask = ptr->workerBunk->waitForNewTask(localDeque);
}
}

threadLocalExecutorHandle().dispose();
}

public:
HighsTaskExecutor(int numThreads) {
assert(numThreads > 0);
mainWorkerHandle.store(nullptr, std::memory_order_relaxed);
mainWorkerId.store(std::this_thread::get_id());
workerDeques.resize(numThreads);
workerBunk = cache_aligned::make_shared<HighsSplitDeque::WorkerBunk>();
for (int i = 0; i < numThreads; ++i)
workerDeques[i] = cache_aligned::make_unique<HighsSplitDeque>(
workerBunk, workerDeques.data(), i, numThreads);

threadLocalWorkerDeque() = workerDeques[0].get();
}
workerThreads.reserve(numThreads - 1);
referenceCount.store(numThreads);

void init(ExecutorHandle* executor) {
for (int i = 1, numThreads = workerDeques.size(); i < numThreads; ++i) {
std::thread(&HighsTaskExecutor::run_worker, i, executor, executor->ptr).detach();
workerThreads.emplace_back(
std::move(std::thread(&HighsTaskExecutor::run_worker, i, this)));
}
}

void stopWorkerThreads(bool blocking = false) {
auto id = mainWorkerId.exchange(std::thread::id());
if (id == std::thread::id()) return; // already been called

// now inject the null task as termination signal to every worker
for (auto& workerDeque : workerDeques) {
workerDeque->injectTaskAndNotify(nullptr);
}

// only block if called on main thread, otherwise deadlock may occur
if (blocking && std::this_thread::get_id() == id) {
for (auto& workerThread : workerThreads) {
workerThread.join();
}
}
else {
for (auto& workerThread : workerThreads) {
workerThread.detach();
}
}
}

Expand All @@ -134,31 +163,19 @@ class HighsTaskExecutor {

static void initialize(int numThreads) {
auto& executorHandle = threadLocalExecutorHandle();
if (!executorHandle.ptr) {
executorHandle.ptr =
cache_aligned::make_shared<HighsTaskExecutor>(numThreads);
executorHandle.ptr->mainWorkerHandle.store(&executorHandle,
std::memory_order_release);
executorHandle.ptr->init(&executorHandle);
if (executorHandle.ptr == nullptr) {
executorHandle.ptr = new (cache_aligned::alloc(sizeof(HighsTaskExecutor))) HighsTaskExecutor(numThreads);
}
}

// can be called on main or worker threads
// blocking ignored unless called on main thread
static void shutdown(bool blocking = false) {
auto& executorHandle = threadLocalExecutorHandle();
if (executorHandle.ptr) {
// set the active flag to false first with release ordering
executorHandle.ptr->mainWorkerHandle.store(nullptr,
std::memory_order_release);
// now inject the null task as termination signal to every worker
for (auto& workerDeque : executorHandle.ptr->workerDeques)
workerDeque->injectTaskAndNotify(nullptr);
// finally release the global executor reference
if (blocking) {
while (executorHandle.ptr.use_count() != 1)
HighsSpinMutex::yieldProcessor();
}

executorHandle.ptr.reset();
if (executorHandle.ptr != nullptr) {
executorHandle.ptr->stopWorkerThreads(blocking);
executorHandle.dispose();
}
}

Expand Down

0 comments on commit a2d5db9

Please sign in to comment.