diff --git a/SConscript b/SConscript index 147eb304..9a317f7e 100644 --- a/SConscript +++ b/SConscript @@ -15,7 +15,7 @@ msgq_objects = env.SharedObject([ 'msgq/msgq.cc', ]) msgq = env.Library('msgq', msgq_objects) -msgq_python = envCython.Program('msgq/ipc_pyx.so', 'msgq/ipc_pyx.pyx', LIBS=envCython["LIBS"]+[msgq, "zmq", common]) +msgq_python = envCython.Program('msgq/ipc_pyx.so', 'msgq/ipc_pyx.pyx', LIBS=envCython["LIBS"]+[msgq, "zmq", common, 'pthread']) # Build Vision IPC vipc_files = ['visionipc.cc', 'visionipc_server.cc', 'visionipc_client.cc', 'visionbuf.cc'] @@ -31,7 +31,7 @@ visionipc = env.Library('visionipc', vipc_objects) vipc_frameworks = [] -vipc_libs = envCython["LIBS"] + [visionipc, msgq, common, "zmq"] +vipc_libs = envCython["LIBS"] + [visionipc, msgq, common, "zmq", 'pthread'] if arch == "Darwin": vipc_frameworks.append('OpenCL') else: @@ -40,7 +40,7 @@ envCython.Program(f'{visionipc_dir.abspath}/visionipc_pyx.so', f'{visionipc_dir. LIBS=vipc_libs, FRAMEWORKS=vipc_frameworks) if GetOption('extras'): - env.Program('msgq/test_runner', ['msgq/test_runner.cc', 'msgq/msgq_tests.cc'], LIBS=[msgq, common]) + env.Program('msgq/test_runner', ['msgq/test_runner.cc', 'msgq/msgq_tests.cc'], LIBS=[msgq, common, 'pthread']) env.Program(f'{visionipc_dir.abspath}/test_runner', [f'{visionipc_dir.abspath}/test_runner.cc', f'{visionipc_dir.abspath}/visionipc_tests.cc'], LIBS=['pthread'] + vipc_libs, FRAMEWORKS=vipc_frameworks) diff --git a/msgq/impl_msgq.cc b/msgq/impl_msgq.cc index b2399135..598f2dc5 100644 --- a/msgq/impl_msgq.cc +++ b/msgq/impl_msgq.cc @@ -70,7 +70,6 @@ int MSGQSubSocket::connect(Context *context, std::string endpoint, std::string a return 0; } - Message * MSGQSubSocket::receive(bool non_blocking){ msgq_do_exit = 0; @@ -81,49 +80,27 @@ Message * MSGQSubSocket::receive(bool non_blocking){ prev_handler_sigterm = std::signal(SIGTERM, sig_handler); } + int rc = 0; msgq_msg_t msg; - - MSGQMessage *r = NULL; - - int rc = msgq_msg_recv(&msg, q); - - // Hack to implement blocking read with a poller. Don't use this - while (!non_blocking && rc == 0 && msgq_do_exit == 0){ - msgq_pollitem_t items[1]; - items[0].q = q; - - int t = (timeout != -1) ? timeout : 100; - - int n = msgq_poll(items, 1, t); + while (!msgq_do_exit) { rc = msgq_msg_recv(&msg, q); + if (rc > 0 || non_blocking) break; - // The poll indicated a message was ready, but the receive failed. Try again - if (n == 1 && rc == 0){ - continue; - } - - if (timeout != -1){ - break; - } + int ms = (timeout != -1) ? timeout : 100; + if (!q->shm->waitFor(ms) && timeout != -1) break; } - if (!non_blocking){ std::signal(SIGINT, prev_handler_sigint); std::signal(SIGTERM, prev_handler_sigterm); } errno = msgq_do_exit ? EINTR : 0; - + MSGQMessage *r = nullptr; if (rc > 0){ - if (msgq_do_exit){ - msgq_msg_close(&msg); // Free unused message on exit - } else { - r = new MSGQMessage; - r->takeOwnership(msg.data, msg.size); - } + r = new MSGQMessage; + r->takeOwnership(msg.data, msg.size); } - return (Message*)r; } diff --git a/msgq/msgq.cc b/msgq/msgq.cc index fed2959b..7a89db1a 100644 --- a/msgq/msgq.cc +++ b/msgq/msgq.cc @@ -1,34 +1,112 @@ +#include "msgq.h" + #include #include #include +#include #include -#include -#include #include #include -#include -#include #include -#include #include +#include -#include -#include #include -#include -#include #include #include #include -#include +constexpr const char *SHM_INIT_SEM = "/op_shm_init_sem"; + +SharedMemory::SharedMemory(const std::string &name, size_t size) :shm_name(name) { + const char* prefix = std::getenv("OPENPILOT_PREFIX"); + std::string full_path = "/dev/shm/"; + if (prefix) { + full_path += std::string(prefix) + "/"; + } + full_path += shm_name; + shm_fd = open(full_path.c_str(), O_RDWR | O_CREAT, 0664); + assert(shm_fd != -1); + + shm_size = sizeof(SharedMemoryHeader) + size; + int ret = ftruncate(shm_fd, shm_size); + assert(ret != -1); + shm_ptr = mmap(0, shm_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); + assert(shm_ptr != MAP_FAILED); -#include "msgq/msgq.h" + initMutexCond(); +} + +SharedMemory::~SharedMemory() { + munmap(shm_ptr, shm_size); + ::close(shm_fd); +} + +void SharedMemory::initMutexCond() { + sem_t *sem = sem_open(SHM_INIT_SEM, O_CREAT, 0644, 1); + assert(sem != SEM_FAILED); + sem_wait(sem); // Lock semaphore + + // Initialize the header if it hasn't been initialized yet + header = (SharedMemoryHeader *)shm_ptr; + if (!header->initialized) { + pthread_mutexattr_t mutex_attr; + pthread_mutexattr_init(&mutex_attr); + pthread_mutexattr_setpshared(&mutex_attr, PTHREAD_PROCESS_SHARED); +#ifndef __APPLE__ + // Set the mutex to be robust, meaning it can recover from a process crash + pthread_mutexattr_setrobust(&mutex_attr, PTHREAD_MUTEX_ROBUST); +#endif + pthread_mutex_init(&(header->mutex), &mutex_attr); + + pthread_condattr_t cond_attr; + pthread_condattr_init(&cond_attr); + pthread_condattr_setpshared(&cond_attr, PTHREAD_PROCESS_SHARED); + pthread_cond_init(&(header->cond), &cond_attr); + + pthread_mutexattr_destroy(&mutex_attr); + pthread_condattr_destroy(&cond_attr); + header->initialized = true; + } -void sigusr2_handler(int signal) { - assert(signal == SIGUSR2); + sem_post(sem); + sem_close(sem); } +void SharedMemory::notifyAll() { + pthread_cond_broadcast(&(header->cond)); +} + +bool SharedMemory::waitFor(int timeout_ms) { + struct timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); + ts.tv_sec += (timeout_ms / 1000); + ts.tv_nsec += (timeout_ms % 1000) * 1000000; + if (ts.tv_nsec >= 1000000000) { + ts.tv_sec += 1; + ts.tv_nsec -= 1000000000; + } + + int ret = pthread_mutex_lock(&(header->mutex)); +#ifndef __APPLE__ + // Handle case where previous owner of the mutex died + if (ret == EOWNERDEAD) { + pthread_mutex_consistent((&(header->mutex))); + } +#endif + ret = pthread_cond_timedwait(&(header->cond), &(header->mutex), &ts); + pthread_mutex_unlock(&(header->mutex)); + + // Return true if condition was signaled, false if timed out + return ret == 0; +} + +PollerContext::PollerContext() : shm("msgq_poll", 0) { + ctx = (SharedMemoryHeader *)shm.shm_ptr; +} + +PollerContext poller_context; + uint64_t msgq_get_uid(void){ std::random_device rd("/dev/urandom"); std::uniform_int_distribution distribution(0, std::numeric_limits::max()); @@ -53,7 +131,6 @@ int msgq_msg_init_size(msgq_msg_t * msg, size_t size){ int msgq_msg_init_data(msgq_msg_t * msg, char * data, size_t size) { int r = msgq_msg_init_size(msg, size); - if (r == 0) memcpy(msg->data, data, size); @@ -79,41 +156,13 @@ void msgq_wait_for_subscriber(msgq_queue_t *q){ while (*q->num_readers == 0){ // wait for subscriber } - return; } int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){ assert(size < 0xFFFFFFFF); // Buffer must be smaller than 2^32 bytes - std::signal(SIGUSR2, sigusr2_handler); - - std::string full_path = "/dev/shm/"; - const char* prefix = std::getenv("OPENPILOT_PREFIX"); - if (prefix) { - full_path += std::string(prefix) + "/"; - } - full_path += path; - - auto fd = open(full_path.c_str(), O_RDWR | O_CREAT, 0664); - if (fd < 0) { - std::cout << "Warning, could not open: " << full_path << std::endl; - return -1; - } - - int rc = ftruncate(fd, size + sizeof(msgq_header_t)); - if (rc < 0){ - close(fd); - return -1; - } - char * mem = (char*)mmap(NULL, size + sizeof(msgq_header_t), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); - close(fd); - - if (mem == NULL){ - return -1; - } - q->mmap_p = mem; - - msgq_header_t *header = (msgq_header_t *)mem; + q->shm = std::make_unique(path, size + sizeof(msgq_header_t)); + msgq_header_t *header = (msgq_header_t*)(q->shm->header + 1); // Setup pointers to header segment q->num_readers = reinterpret_cast*>(&header->num_readers); @@ -126,7 +175,7 @@ int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){ q->read_uids[i] = reinterpret_cast*>(&header->read_uids[i]); } - q->data = mem + sizeof(msgq_header_t); + q->data = (char*)(header + 1); q->size = size; q->reader_id = -1; @@ -136,12 +185,7 @@ int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){ return 0; } -void msgq_close_queue(msgq_queue_t *q){ - if (q->mmap_p != NULL){ - munmap(q->mmap_p, q->size + sizeof(msgq_header_t)); - } -} - +void msgq_close_queue(msgq_queue_t *q) {} void msgq_init_publisher(msgq_queue_t * q) { //std::cout << "Starting publisher" << std::endl; @@ -158,15 +202,6 @@ void msgq_init_publisher(msgq_queue_t * q) { q->write_uid_local = uid; } -static void thread_signal(uint32_t tid) { - #ifndef SYS_tkill - // TODO: this won't work for multithreaded programs - kill(tid, SIGUSR2); - #else - syscall(SYS_tkill, tid, SIGUSR2); - #endif -} - void msgq_init_subscriber(msgq_queue_t * q) { assert(q != NULL); assert(q->num_readers != NULL); @@ -185,13 +220,11 @@ void msgq_init_subscriber(msgq_queue_t * q) { for (size_t i = 0; i < NUM_READERS; i++){ *q->read_valids[i] = false; - - uint64_t old_uid = *q->read_uids[i]; *q->read_uids[i] = 0; - // Wake up reader in case they are in a poll - thread_signal(old_uid & 0xFFFFFFFF); + q->shm->notifyAll(); } + poller_context.shm.notifyAll(); continue; } @@ -292,12 +325,9 @@ int msgq_msg_send(msgq_msg_t * msg, msgq_queue_t *q){ uint32_t new_ptr = ALIGN(write_pointer + msg->size + sizeof(int64_t)); PACK64(*q->write_pointer, write_cycles, new_ptr); - // Notify readers - for (uint64_t i = 0; i < num_readers; i++){ - uint64_t reader_uid = *q->read_uids[i]; - thread_signal(reader_uid & 0xFFFFFFFF); - } - + // Notify pollers and readers + poller_context.shm.notifyAll(); + q->shm->notifyAll(); return msg->size; } @@ -418,42 +448,21 @@ int msgq_msg_recv(msgq_msg_t * msg, msgq_queue_t * q){ return msg->size; } - - -int msgq_poll(msgq_pollitem_t * items, size_t nitems, int timeout){ +int msgq_poll(msgq_pollitem_t *items, size_t nitems, int timeout) { int num = 0; - - // Check if messages ready - for (size_t i = 0; i < nitems; i++) { - items[i].revents = msgq_msg_ready(items[i].q); - if (items[i].revents) num++; - } - int ms = (timeout == -1) ? 100 : timeout; - struct timespec ts; - ts.tv_sec = ms / 1000; - ts.tv_nsec = (ms % 1000) * 1000 * 1000; - - - while (num == 0) { - int ret; - - ret = nanosleep(&ts, &ts); + while (true) { // Check if messages ready - for (size_t i = 0; i < nitems; i++) { - if (items[i].revents == 0 && msgq_msg_ready(items[i].q)){ - num += 1; - items[i].revents = 1; - } + for (size_t i = 0; i < nitems; ++i) { + items[i].revents = msgq_msg_ready(items[i].q); + if (items[i].revents) ++num; } + if (num > 0 || timeout == 0) break; - // exit if we had a timeout and the sleep finished - if (timeout != -1 && ret == 0){ - break; - } + // Wait until messages are ready or timeout occurs + if (!poller_context.shm.waitFor(ms) && timeout != -1) break; } - return num; } diff --git a/msgq/msgq.h b/msgq/msgq.h index 94e18494..6cfcfd50 100644 --- a/msgq/msgq.h +++ b/msgq/msgq.h @@ -1,9 +1,9 @@ #pragma once -#include #include #include #include +#include #define DEFAULT_SEGMENT_SIZE (10 * 1024 * 1024) #define NUM_READERS 15 @@ -13,10 +13,42 @@ #define UNPACK64(higher, lower, input) do {uint64_t tmp = input; higher = tmp >> 32; lower = tmp & 0xFFFFFFFF;} while (0) #define PACK64(output, higher, lower) output = ((uint64_t)higher << 32) | ((uint64_t)lower & 0xFFFFFFFF) +struct SharedMemoryHeader { + pthread_mutex_t mutex; + pthread_cond_t cond; + bool initialized; +}; + +class SharedMemory { +public: + SharedMemory(const std::string &name, size_t size); + ~SharedMemory(); + void notifyAll(); + bool waitFor(int timeout_ms); + void *shm_ptr; + size_t shm_size; + SharedMemoryHeader *header; + +private: + void initMutexCond(); + std::string shm_name; + int shm_fd; +}; + +struct PollerContext { + PollerContext(); + SharedMemory shm; + SharedMemoryHeader *ctx; +}; + +extern PollerContext poller_context; + struct msgq_header_t { uint64_t num_readers; uint64_t write_pointer; uint64_t write_uid; + pthread_mutex_t mutex; + pthread_cond_t cond; uint64_t read_pointers[NUM_READERS]; uint64_t read_valids[NUM_READERS]; uint64_t read_uids[NUM_READERS]; @@ -29,7 +61,6 @@ struct msgq_queue_t { std::atomic *read_pointers[NUM_READERS]; std::atomic *read_valids[NUM_READERS]; std::atomic *read_uids[NUM_READERS]; - char * mmap_p; char * data; size_t size; int reader_id; @@ -38,6 +69,7 @@ struct msgq_queue_t { bool read_conflate; std::string endpoint; + std::unique_ptr shm; }; struct msgq_msg_t {