Skip to content

Commit

Permalink
Add ability for modules to know which client is being cmd filtered (#…
Browse files Browse the repository at this point in the history
…12219)

Adds API
- RedisModule_CommandFilterGetClientId()

Includes addition to commandfilter test module to validate that it works
by performing the same command from 2 different clients
  • Loading branch information
sjpotter authored Jun 20, 2023
1 parent cd4f3e2 commit 07316f1
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ typedef struct RedisModuleCommandFilterCtx {
RedisModuleString **argv;
int argv_len;
int argc;
client *c;
} RedisModuleCommandFilterCtx;

typedef void (*RedisModuleCommandFilterFunc) (RedisModuleCommandFilterCtx *filter);
Expand Down Expand Up @@ -10645,7 +10646,8 @@ void moduleCallCommandFilters(client *c) {
RedisModuleCommandFilterCtx filter = {
.argv = c->argv,
.argv_len = c->argv_len,
.argc = c->argc
.argc = c->argc,
.c = c
};

while((ln = listNext(&li))) {
Expand Down Expand Up @@ -10738,6 +10740,11 @@ int RM_CommandFilterArgDelete(RedisModuleCommandFilterCtx *fctx, int pos)
return REDISMODULE_OK;
}

/* Get Client ID for client that issued the command we are filtering */
unsigned long long RM_CommandFilterGetClientId(RedisModuleCommandFilterCtx *fctx) {
return fctx->c->id;
}

/* For a given pointer allocated via RedisModule_Alloc() or
* RedisModule_Realloc(), return the amount of memory allocated for it.
* Note that this may be different (larger) than the memory we allocated
Expand Down Expand Up @@ -13722,6 +13729,7 @@ void moduleRegisterCoreAPI(void) {
REGISTER_API(CommandFilterArgInsert);
REGISTER_API(CommandFilterArgReplace);
REGISTER_API(CommandFilterArgDelete);
REGISTER_API(CommandFilterGetClientId);
REGISTER_API(Fork);
REGISTER_API(SendChildHeartbeat);
REGISTER_API(ExitFromChild);
Expand Down
2 changes: 2 additions & 0 deletions src/redismodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -1259,6 +1259,7 @@ REDISMODULE_API RedisModuleString * (*RedisModule_CommandFilterArgGet)(RedisModu
REDISMODULE_API int (*RedisModule_CommandFilterArgInsert)(RedisModuleCommandFilterCtx *fctx, int pos, RedisModuleString *arg) REDISMODULE_ATTR;
REDISMODULE_API int (*RedisModule_CommandFilterArgReplace)(RedisModuleCommandFilterCtx *fctx, int pos, RedisModuleString *arg) REDISMODULE_ATTR;
REDISMODULE_API int (*RedisModule_CommandFilterArgDelete)(RedisModuleCommandFilterCtx *fctx, int pos) REDISMODULE_ATTR;
REDISMODULE_API unsigned long long (*RedisModule_CommandFilterGetClientId)(RedisModuleCommandFilterCtx *fctx) REDISMODULE_ATTR;
REDISMODULE_API int (*RedisModule_Fork)(RedisModuleForkDoneHandler cb, void *user_data) REDISMODULE_ATTR;
REDISMODULE_API void (*RedisModule_SendChildHeartbeat)(double progress) REDISMODULE_ATTR;
REDISMODULE_API int (*RedisModule_ExitFromChild)(int retcode) REDISMODULE_ATTR;
Expand Down Expand Up @@ -1619,6 +1620,7 @@ static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int
REDISMODULE_GET_API(CommandFilterArgInsert);
REDISMODULE_GET_API(CommandFilterArgReplace);
REDISMODULE_GET_API(CommandFilterArgDelete);
REDISMODULE_GET_API(CommandFilterGetClientId);
REDISMODULE_GET_API(Fork);
REDISMODULE_GET_API(SendChildHeartbeat);
REDISMODULE_GET_API(ExitFromChild);
Expand Down
30 changes: 30 additions & 0 deletions tests/modules/commandfilter.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ static const char log_command_name[] = "commandfilter.log";
static const char ping_command_name[] = "commandfilter.ping";
static const char retained_command_name[] = "commandfilter.retained";
static const char unregister_command_name[] = "commandfilter.unregister";
static const char unfiltered_clientid_name[] = "unfilter_clientid";
static int in_log_command = 0;

unsigned long long unfiltered_clientid = 0;

static RedisModuleCommandFilter *filter, *filter1;
static RedisModuleString *retained;

Expand Down Expand Up @@ -89,6 +92,26 @@ int CommandFilter_LogCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
return REDISMODULE_OK;
}

int CommandFilter_UnfilteredClientdId(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
{
if (argc < 2)
return RedisModule_WrongArity(ctx);

long long id;
if (RedisModule_StringToLongLong(argv[1], &id) != REDISMODULE_OK) {
RedisModule_ReplyWithError(ctx, "invalid client id");
return REDISMODULE_OK;
}
if (id < 0) {
RedisModule_ReplyWithError(ctx, "invalid client id");
return REDISMODULE_OK;
}

unfiltered_clientid = id;
RedisModule_ReplyWithSimpleString(ctx, "OK");
return REDISMODULE_OK;
}

/* Filter to protect against Bug #11894 reappearing
*
* ensures that the filter is only run the first time through, and not on reprocessing
Expand Down Expand Up @@ -117,6 +140,9 @@ void CommandFilter_BlmoveSwap(RedisModuleCommandFilterCtx *filter)

void CommandFilter_CommandFilter(RedisModuleCommandFilterCtx *filter)
{
unsigned long long id = RedisModule_CommandFilterGetClientId(filter);
if (id == unfiltered_clientid) return;

if (in_log_command) return; /* don't process our own RM_Call() from CommandFilter_LogCommand() */

/* Fun manipulations:
Expand Down Expand Up @@ -192,6 +218,10 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
CommandFilter_UnregisterCommand,"write deny-oom",1,1,1) == REDISMODULE_ERR)
return REDISMODULE_ERR;

if (RedisModule_CreateCommand(ctx, unfiltered_clientid_name,
CommandFilter_UnfilteredClientdId, "admin", 1,1,1) == REDISMODULE_ERR)
return REDISMODULE_ERR;

if ((filter = RedisModule_RegisterCommandFilter(ctx, CommandFilter_CommandFilter,
noself ? REDISMODULE_CMDFILTER_NOSELF : 0))
== NULL) return REDISMODULE_ERR;
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/moduleapi/commandfilter.tcl
Original file line number Diff line number Diff line change
Expand Up @@ -139,5 +139,27 @@ test {Blocking Commands don't run through command filter when reprocessed} {
assert_equal [$rd read] 1
# validate that we moved the correct elements to the correct side of the list
assert_equal [r lpop list2{t}] 1

$rd close
}
}

test {Filtering based on client id} {
start_server {tags {"modules"}} {
r module load $testmodule log-key 0

set rr [redis_client]
set cid [$rr client id]
r unfilter_clientid $cid

r rpush mylist elem1 @replaceme elem2
assert_equal [r lrange mylist 0 -1] {elem1 --replaced-- elem2}

r del mylist

assert_equal [$rr rpush mylist elem1 @replaceme elem2] 3
assert_equal [r lrange mylist 0 -1] {elem1 @replaceme elem2}

$rr close
}
}

0 comments on commit 07316f1

Please sign in to comment.