From 24e657387e7ed2bcecc58365bd395ec8d420d628 Mon Sep 17 00:00:00 2001 From: Seungmin Lee Date: Mon, 25 Nov 2024 02:05:42 -0800 Subject: [PATCH] Add new replicate module API to bypass command validation Signed-off-by: Seungmin Lee --- src/module.c | 66 +++++++++++++++++++----------- src/valkeymodule.h | 8 ++++ tests/modules/propagate.c | 25 +++++++++-- tests/unit/moduleapi/propagate.tcl | 26 ++++++++++++ 4 files changed, 98 insertions(+), 27 deletions(-) diff --git a/src/module.c b/src/module.c index 1e98b36f30..119118a9a9 100644 --- a/src/module.c +++ b/src/module.c @@ -502,6 +502,7 @@ static void zsetKeyReset(ValkeyModuleKey *key); static void moduleInitKeyTypeSpecific(ValkeyModuleKey *key); void VM_FreeDict(ValkeyModuleCtx *ctx, ValkeyModuleDict *d); void VM_FreeServerInfo(ValkeyModuleCtx *ctx, ValkeyModuleServerInfoData *data); +int moduleReplicate(ValkeyModuleCtx *ctx, ValkeyModuleFlag flag, const char *cmdname, const char *fmt, va_list ap); /* Helpers for VM_SetCommandInfo. */ static int moduleValidateCommandInfo(const ValkeyModuleCommandInfo *info); @@ -3531,34 +3532,22 @@ int VM_ReplyWithLongDouble(ValkeyModuleCtx *ctx, long double ld) { * The command returns VALKEYMODULE_ERR if the format specifiers are invalid * or the command name does not belong to a known command. */ int VM_Replicate(ValkeyModuleCtx *ctx, const char *cmdname, const char *fmt, ...) { - struct serverCommand *cmd; - robj **argv = NULL; - int argc = 0, flags = 0, j; va_list ap; - - cmd = lookupCommandByCString((char *)cmdname); - if (!cmd) return VALKEYMODULE_ERR; - - /* Create the client and dispatch the command. */ va_start(ap, fmt); - argv = moduleCreateArgvFromUserFormat(cmdname, fmt, &argc, &flags, ap); + int result = moduleReplicate(ctx, VALKEYMODULE_FLAG_DEFAULT, cmdname, fmt, ap); va_end(ap); - if (argv == NULL) return VALKEYMODULE_ERR; - - /* Select the propagation target. Usually is AOF + replicas, however - * the caller can exclude one or the other using the "A" or "R" - * modifiers. */ - int target = 0; - if (!(flags & VALKEYMODULE_ARGV_NO_AOF)) target |= PROPAGATE_AOF; - if (!(flags & VALKEYMODULE_ARGV_NO_REPLICAS)) target |= PROPAGATE_REPL; - - alsoPropagate(ctx->client->db->id, argv, argc, target); + return result; +} - /* Release the argv. */ - for (j = 0; j < argc; j++) decrRefCount(argv[j]); - zfree(argv); - server.dirty++; - return VALKEYMODULE_OK; +/* Same as ValkeyModule_Replicate, but can take ValkeyModuleFlag + * Can be either VALKEYMODULE_FLAG_DEFAULT, which means default behavior + * (same as calling ValkeyModule_Replicate) */ +int VM_ReplicateWithFlag(ValkeyModuleCtx *ctx, ValkeyModuleFlag flag, const char *cmdname, const char *fmt, ...) { + va_list ap; + va_start(ap, fmt); + int result = moduleReplicate(ctx, flag, cmdname, fmt, ap); + va_end(ap); + return result; } /* This function will replicate the command exactly as it was invoked @@ -13523,6 +13512,34 @@ void moduleDefragGlobals(void) { dictReleaseIterator(di); } +/* Helper function for VM_Replicate and VM_ReplicateWithFlag to replicate the specified command + * and arguments to replicas and AOF, as effect of execution of the calling command implementation. + * Skip command validation if the ValkeyModuleFlag is set to VALKEYMODULE_FLAG_SKIP_VALIDATION. */ +int moduleReplicate(ValkeyModuleCtx *ctx, ValkeyModuleFlag flag, const char *cmdname, const char *fmt, va_list ap) { + struct serverCommand *cmd; + robj **argv = NULL; + int argc = 0, flags = 0, j; + if (flag != VALKEYMODULE_FLAG_SKIP_VALIDATION) { + cmd = lookupCommandByCString((char *)cmdname); + if (!cmd) return VALKEYMODULE_ERR; + } + /* Create the client and dispatch the command. */ + argv = moduleCreateArgvFromUserFormat(cmdname, fmt, &argc, &flags, ap); + if (argv == NULL) return VALKEYMODULE_ERR; + /* Select the propagation target. Usually is AOF + replicas, however + * the caller can exclude one or the other using the "A" or "R" + * modifiers. */ + int target = 0; + if (!(flags & VALKEYMODULE_ARGV_NO_AOF)) target |= PROPAGATE_AOF; + if (!(flags & VALKEYMODULE_ARGV_NO_REPLICAS)) target |= PROPAGATE_REPL; + alsoPropagate(ctx->client->db->id, argv, argc, target); + /* Release the argv. */ + for (j = 0; j < argc; j++) decrRefCount(argv[j]); + zfree(argv); + server.dirty++; + return VALKEYMODULE_OK; +} + /* Returns the name of the key currently being processed. * There is no guarantee that the key name is always available, so this may return NULL. */ @@ -13635,6 +13652,7 @@ void moduleRegisterCoreAPI(void) { REGISTER_API(StringPtrLen); REGISTER_API(AutoMemory); REGISTER_API(Replicate); + REGISTER_API(ReplicateWithFlag); REGISTER_API(ReplicateVerbatim); REGISTER_API(DeleteKey); REGISTER_API(UnlinkKey); diff --git a/src/valkeymodule.h b/src/valkeymodule.h index c2cdb2f0e7..be3e85642a 100644 --- a/src/valkeymodule.h +++ b/src/valkeymodule.h @@ -782,6 +782,11 @@ typedef enum { VALKEYMODULE_ACL_LOG_CHANNEL /* Channel authorization failure */ } ValkeyModuleACLLogEntryReason; +typedef enum { + VALKEYMODULE_FLAG_DEFAULT = 0, /* Default behavior */ + VALKEYMODULE_FLAG_SKIP_VALIDATION, /* Skip validation */ +} ValkeyModuleFlag; + /* Incomplete structures needed by both the core and modules. */ typedef struct ValkeyModuleIO ValkeyModuleIO; typedef struct ValkeyModuleDigest ValkeyModuleDigest; @@ -1092,6 +1097,8 @@ VALKEYMODULE_API int (*ValkeyModule_StringToStreamID)(const ValkeyModuleString * VALKEYMODULE_API void (*ValkeyModule_AutoMemory)(ValkeyModuleCtx *ctx) VALKEYMODULE_ATTR; VALKEYMODULE_API int (*ValkeyModule_Replicate)(ValkeyModuleCtx *ctx, const char *cmdname, const char *fmt, ...) VALKEYMODULE_ATTR; +VALKEYMODULE_API int (*ValkeyModule_ReplicateWithFlag)(ValkeyModuleCtx *ctx, ValkeyModuleFlag flag, const char *cmdname, const char *fmt, ...) + VALKEYMODULE_ATTR; VALKEYMODULE_API int (*ValkeyModule_ReplicateVerbatim)(ValkeyModuleCtx *ctx) VALKEYMODULE_ATTR; VALKEYMODULE_API const char *(*ValkeyModule_CallReplyStringPtr)(ValkeyModuleCallReply *reply, size_t *len)VALKEYMODULE_ATTR; @@ -1750,6 +1757,7 @@ static int ValkeyModule_Init(ValkeyModuleCtx *ctx, const char *name, int ver, in VALKEYMODULE_GET_API(StringPtrLen); VALKEYMODULE_GET_API(AutoMemory); VALKEYMODULE_GET_API(Replicate); + VALKEYMODULE_GET_API(ReplicateWithFlag); VALKEYMODULE_GET_API(ReplicateVerbatim); VALKEYMODULE_GET_API(DeleteKey); VALKEYMODULE_GET_API(UnlinkKey); diff --git a/tests/modules/propagate.c b/tests/modules/propagate.c index b3cd279e5a..72684317aa 100644 --- a/tests/modules/propagate.c +++ b/tests/modules/propagate.c @@ -250,7 +250,8 @@ int propagateTestSimpleCommand(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, /* Replicate two commands to test MULTI/EXEC wrapping. */ ValkeyModule_Replicate(ctx,"INCR","c","counter-1"); - ValkeyModule_Replicate(ctx,"INCR","c","counter-2"); + ValkeyModule_ReplicateWithFlag(ctx, VALKEYMODULE_FLAG_SKIP_VALIDATION, "INCR", + "c", "counter-2"); ValkeyModule_ReplyWithSimpleString(ctx,"OK"); return VALKEYMODULE_OK; } @@ -266,8 +267,8 @@ int propagateTestMixedCommand(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, i ValkeyModule_FreeCallReply(reply); ValkeyModule_Replicate(ctx,"INCR","c","counter-1"); - ValkeyModule_Replicate(ctx,"INCR","c","counter-2"); - + ValkeyModule_ReplicateWithFlag(ctx, VALKEYMODULE_FLAG_SKIP_VALIDATION, "INCR", + "c", "counter-2"); reply = ValkeyModule_Call(ctx, "INCR", "c!", "after-call"); ValkeyModule_FreeCallReply(reply); @@ -275,6 +276,19 @@ int propagateTestMixedCommand(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, i return VALKEYMODULE_OK; } +int propagateTestInvalidCommand(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, + int argc) { + VALKEYMODULE_NOT_USED(argv); + VALKEYMODULE_NOT_USED(argc); + /* Replicate two commands to test MULTI/EXEC wrapping. */ + ValkeyModule_ReplicateWithFlag(ctx, VALKEYMODULE_FLAG_SKIP_VALIDATION, "INVALID", + "c", "counter-1"); + ValkeyModule_ReplicateWithFlag(ctx, VALKEYMODULE_FLAG_SKIP_VALIDATION, "INVALID", + "c", "counter-2"); + ValkeyModule_ReplyWithSimpleString(ctx, "OK"); + return VALKEYMODULE_OK; +} + int propagateTestNestedCommand(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc) { VALKEYMODULE_NOT_USED(argv); @@ -380,6 +394,11 @@ int ValkeyModule_OnLoad(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int arg "write",1,1,1) == VALKEYMODULE_ERR) return VALKEYMODULE_ERR; + if (ValkeyModule_CreateCommand(ctx, "propagate-test.invalid", + propagateTestInvalidCommand, "", 1, 1, + 1) == VALKEYMODULE_ERR) + return VALKEYMODULE_ERR; + if (ValkeyModule_CreateCommand(ctx,"propagate-test.nested", propagateTestNestedCommand, "write",1,1,1) == VALKEYMODULE_ERR) diff --git a/tests/unit/moduleapi/propagate.tcl b/tests/unit/moduleapi/propagate.tcl index 89f79998f0..6920f698a6 100644 --- a/tests/unit/moduleapi/propagate.tcl +++ b/tests/unit/moduleapi/propagate.tcl @@ -676,6 +676,32 @@ tags "modules" { } } +tags "modules" { + start_server [list overrides [list loadmodule "$testmodule"]] { + set replica [srv 0 client] + set replica_host [srv 0 host] + set replica_port [srv 0 port] + start_server [list overrides [list loadmodule "$testmodule"]] { + set master [srv 0 client] + set master_host [srv 0 host] + set master_port [srv 0 port] + # Start the replication process... + $replica replicaof $master_host $master_port + wait_for_sync $replica + after 1000 + test {module crash when propagating invalid command} { + $master propagate-test.invalid + catch {wait_for_sync $replica} + + wait_for_log_messages -1 {"*=== * BUG REPORT START: Cut & paste starting from here ===*"} 0 10 1000 + wait_for_log_messages -1 {"* This replica panicked sending an error to its primary after processing the command '' *"} 0 10 1000 + + assert_equal 1 [count_log_message -1 "=== .* BUG REPORT START: Cut & paste starting from here ==="] + assert_equal 1 [count_log_message -1 "This replica panicked sending an error to its primary after processing the command ''"] + } + } + } +} tags "modules aof" { foreach aofload_type {debug_cmd startup} {