diff --git a/src/dict.c b/src/dict.c index 280f0b6abc..a902bdbed9 100644 --- a/src/dict.c +++ b/src/dict.c @@ -329,6 +329,12 @@ int _dictResize(dict *d, unsigned long size, int *malloc_failed) { return DICT_OK; } + if (d->type->no_deferred_rehash) { + /* If the dict type does not support deferred rehashing, we need to + * rehash the whole table immediately. */ + while (dictRehash(d, 1000)); + } + return DICT_OK; } diff --git a/src/dict.h b/src/dict.h index a7c5c71826..aa55fb1bde 100644 --- a/src/dict.h +++ b/src/dict.h @@ -87,6 +87,8 @@ typedef struct dictType { /* If embedded_entry flag is set, it indicates that a copy of the key is created and the key is embedded * as part of the dict entry. */ unsigned int embedded_entry : 1; + /* Perform rehashing during resizing instead of deferring it across multiple steps */ + unsigned int no_deferred_rehash : 1; } dictType; #define DICTHT_SIZE(exp) ((exp) == -1 ? 0 : (unsigned long)1 << (exp)) diff --git a/src/io_threads.c b/src/io_threads.c index 6149febabc..89e2686cd9 100644 --- a/src/io_threads.c +++ b/src/io_threads.c @@ -124,6 +124,18 @@ int inMainThread(void) { return thread_id == 0; } +/* Drains the I/O threads queue by waiting for all jobs to be processed. + * This function must be called from the main thread. */ +void drainIOThreadsQueue(void) { + serverAssert(inMainThread()); + for (int i = 1; i < IO_THREADS_MAX_NUM; i++) { /* No need to drain thread 0, which is the main thread. */ + while (!IOJobQueue_isEmpty(&io_jobs[i])) { + /* memory barrier acquire to get the latest job queue state */ + atomic_thread_fence(memory_order_acquire); + } + } +} + /* Wait until the IO-thread is done with the client */ void waitForClientIO(client *c) { /* No need to wait if the client was not offloaded to the IO thread. */ diff --git a/src/io_threads.h b/src/io_threads.h index 30d1cdad79..9fb23c190e 100644 --- a/src/io_threads.h +++ b/src/io_threads.h @@ -9,5 +9,6 @@ int inMainThread(void); int trySendReadToIOThreads(client *c); int trySendWriteToIOThreads(client *c); void adjustIOThreadsByEventLoad(int numevents, int increase_only); +void drainIOThreadsQueue(void); #endif /* IO_THREADS_H */ diff --git a/src/module.c b/src/module.c index 5844fcbdea..90a18e0a38 100644 --- a/src/module.c +++ b/src/module.c @@ -61,6 +61,7 @@ #include "hdr_histogram.h" #include "crc16_slottable.h" #include "valkeymodule.h" +#include "io_threads.h" #include #include #include @@ -684,7 +685,7 @@ void moduleReleaseTempClient(client *c) { c->raw_flag = 0; c->flag.module = 1; c->user = NULL; /* Root user */ - c->cmd = c->lastcmd = c->realcmd = NULL; + c->cmd = c->lastcmd = c->realcmd = c->io_parsed_cmd = NULL; if (c->bstate.async_rm_call_handle) { ValkeyModuleAsyncRMCallPromise *promise = c->bstate.async_rm_call_handle; promise->c = NULL; /* Remove the client from the promise so it will no longer be possible to abort it. */ @@ -1295,7 +1296,8 @@ int VM_CreateCommand(ValkeyModuleCtx *ctx, ValkeyModuleCommand *cp = moduleCreateCommandProxy(ctx->module, declared_name, sdsdup(declared_name), cmdfunc, flags, firstkey, lastkey, keystep); cp->serverCmd->arity = cmdfunc ? -1 : -2; /* Default value, can be changed later via dedicated API */ - + /* Drain IO queue before modifying commands dictionary to prevent concurrent access while modifying it. */ + drainIOThreadsQueue(); serverAssert(dictAdd(server.commands, sdsdup(declared_name), cp->serverCmd) == DICT_OK); serverAssert(dictAdd(server.orig_commands, sdsdup(declared_name), cp->serverCmd) == DICT_OK); cp->serverCmd->id = ACLGetCommandID(declared_name); /* ID used for ACL. */ @@ -6281,7 +6283,7 @@ ValkeyModuleCallReply *VM_Call(ValkeyModuleCtx *ctx, const char *cmdname, const if (error_as_call_replies) reply = callReplyCreateError(err, ctx); goto cleanup; } - if (!commandCheckArity(c, error_as_call_replies ? &err : NULL)) { + if (!commandCheckArity(c->cmd, c->argc, error_as_call_replies ? &err : NULL)) { errno = EINVAL; if (error_as_call_replies) reply = callReplyCreateError(err, ctx); goto cleanup; @@ -10675,6 +10677,8 @@ void moduleCallCommandFilters(client *c) { ValkeyModuleCommandFilterCtx filter = {.argv = c->argv, .argv_len = c->argv_len, .argc = c->argc, .c = c}; + robj *tmp = c->argv[0]; + incrRefCount(tmp); while ((ln = listNext(&li))) { ValkeyModuleCommandFilter *f = ln->value; @@ -10690,6 +10694,12 @@ void moduleCallCommandFilters(client *c) { c->argv = filter.argv; c->argv_len = filter.argv_len; c->argc = filter.argc; + if (tmp != c->argv[0]) { + /* With I/O thread command-lookup offload, we set c->io_parsed_cmd to the command corresponding to c->argv[0]. + * Since the command filter just changed it, we need to reset c->io_parsed_cmd to null. */ + c->io_parsed_cmd = NULL; + } + decrRefCount(tmp); } /* Return the number of arguments a filtered command has. The number of @@ -12037,6 +12047,8 @@ int moduleFreeCommand(struct ValkeyModule *module, struct serverCommand *cmd) { } void moduleUnregisterCommands(struct ValkeyModule *module) { + /* Drain IO queue before modifying commands dictionary to prevent concurrent access while modifying it. */ + drainIOThreadsQueue(); /* Unregister all the commands registered by this module. */ dictIterator *di = dictGetSafeIterator(server.commands); dictEntry *de; diff --git a/src/networking.c b/src/networking.c index b249aa61f3..a8cdd4e4cf 100644 --- a/src/networking.c +++ b/src/networking.c @@ -164,7 +164,7 @@ client *createClient(connection *conn) { c->nread = 0; c->read_flags = 0; c->write_flags = 0; - c->cmd = c->lastcmd = c->realcmd = NULL; + c->cmd = c->lastcmd = c->realcmd = c->io_parsed_cmd = NULL; c->cur_script = NULL; c->multibulklen = 0; c->bulklen = -1; @@ -1428,6 +1428,7 @@ void freeClientArgv(client *c) { for (j = 0; j < c->argc; j++) decrRefCount(c->argv[j]); c->argc = 0; c->cmd = NULL; + c->io_parsed_cmd = NULL; c->argv_len_sum = 0; c->argv_len = 0; zfree(c->argv); @@ -4635,6 +4636,24 @@ void ioThreadReadQueryFromClient(void *data) { parseCommand(c); + /* Parsing was not completed - let the main-thread handle it. */ + if (!(c->read_flags & READ_FLAGS_PARSING_COMPLETED)) { + goto done; + } + + /* Empty command - Multibulk processing could see a <= 0 length. */ + if (c->argc == 0) { + goto done; + } + + /* Lookup command offload */ + c->io_parsed_cmd = lookupCommand(c->argv, c->argc); + if (c->io_parsed_cmd && commandCheckArity(c->io_parsed_cmd, c->argc, NULL) == 0) { + /* The command was found, but the arity is invalid. + * In this case, we reset the parsed_cmd and will let the main thread handle it. */ + c->io_parsed_cmd = NULL; + } + done: trimClientQueryBuffer(c); atomic_thread_fence(memory_order_release); diff --git a/src/server.c b/src/server.c index 465aa29391..61d6022dfa 100644 --- a/src/server.c +++ b/src/server.c @@ -492,12 +492,13 @@ dictType dbExpiresDictType = { /* Command table. sds string -> command struct pointer. */ dictType commandTableDictType = { - dictSdsCaseHash, /* hash function */ - NULL, /* key dup */ - dictSdsKeyCaseCompare, /* key compare */ - dictSdsDestructor, /* key destructor */ - NULL, /* val destructor */ - NULL /* allow to expand */ + dictSdsCaseHash, /* hash function */ + NULL, /* key dup */ + dictSdsKeyCaseCompare, /* key compare */ + dictSdsDestructor, /* key destructor */ + NULL, /* val destructor */ + NULL, /* allow to expand */ + .no_deferred_rehash = 1, /* no deferred rehash as the command table may be accessed from IO threads. */ }; /* Hash type hash table (note that small hashes are represented with listpacks) */ @@ -3719,11 +3720,11 @@ int commandCheckExistence(client *c, sds *err) { /* Check if c->argc is valid for c->cmd, fills `err` with details in case it isn't. * Return 1 if valid. */ -int commandCheckArity(client *c, sds *err) { - if ((c->cmd->arity > 0 && c->cmd->arity != c->argc) || (c->argc < -c->cmd->arity)) { +int commandCheckArity(struct serverCommand *cmd, int argc, sds *err) { + if ((cmd->arity > 0 && cmd->arity != argc) || (argc < -cmd->arity)) { if (err) { *err = sdsnew(NULL); - *err = sdscatprintf(*err, "wrong number of arguments for '%s' command", c->cmd->fullname); + *err = sdscatprintf(*err, "wrong number of arguments for '%s' command", cmd->fullname); } return 0; } @@ -3794,13 +3795,14 @@ int processCommand(client *c) { * In case we are reprocessing a command after it was blocked, * we do not have to repeat the same checks */ if (!client_reprocessing_command) { - c->cmd = c->lastcmd = c->realcmd = lookupCommand(c->argv, c->argc); + struct serverCommand *cmd = c->io_parsed_cmd ? c->io_parsed_cmd : lookupCommand(c->argv, c->argc); + c->cmd = c->lastcmd = c->realcmd = cmd; sds err; if (!commandCheckExistence(c, &err)) { rejectCommandSds(c, err); return C_OK; } - if (!commandCheckArity(c, &err)) { + if (!commandCheckArity(c->cmd, c->argc, &err)) { rejectCommandSds(c, err); return C_OK; } diff --git a/src/server.h b/src/server.h index 36a4b641e7..a5cf2d08aa 100644 --- a/src/server.h +++ b/src/server.h @@ -1216,6 +1216,7 @@ typedef struct client { struct serverCommand *realcmd; /* The original command that was executed by the client, Used to update error stats in case the c->cmd was modified during the command invocation (like on GEOADD for example). */ + struct serverCommand *io_parsed_cmd; /* The command that was parsed by the IO thread. */ user *user; /* User associated with this connection. If the user is set to NULL the connection can do anything (admin). */ @@ -3147,7 +3148,7 @@ struct serverCommand *lookupCommandByCStringLogic(dict *commands, const char *s) struct serverCommand *lookupCommandByCString(const char *s); struct serverCommand *lookupCommandOrOriginal(robj **argv, int argc); int commandCheckExistence(client *c, sds *err); -int commandCheckArity(client *c, sds *err); +int commandCheckArity(struct serverCommand *cmd, int argc, sds *err); void startCommandExecution(void); int incrCommandStatsOnError(struct serverCommand *cmd, int flags); void call(client *c, int flags);