Skip to content

Commit

Permalink
Adds support for scripting engines as Valkey modules
Browse files Browse the repository at this point in the history
This commit extends the module API to support the addition of different
scripting engines to run user defined functions.

The scripting engine can be implemented as a Valkey module, and can be
dynamically loaded with the `loadmodule` config directive, or with
the `MODULE LOAD` command.

This commit also adds an example of a dummy scripting engine module,
to show how to use the new module API.

The current module API support, only allows to load scripting engines to
run functions using `FCALL` command.

In a follow up PR, we will move the Lua scripting engine implmentation
into its own module.

Signed-off-by: Ricardo Dias <[email protected]>
  • Loading branch information
rjd15372 committed Nov 11, 2024
1 parent 45d596e commit c7afe80
Show file tree
Hide file tree
Showing 11 changed files with 590 additions and 100 deletions.
2 changes: 1 addition & 1 deletion src/aof.c
Original file line number Diff line number Diff line change
Expand Up @@ -2175,7 +2175,7 @@ static int rewriteFunctions(rio *aof) {
dictIterator *iter = dictGetIterator(functions);
dictEntry *entry = NULL;
while ((entry = dictNext(iter))) {
functionLibInfo *li = dictGetVal(entry);
ValkeyModuleScriptingEngineFunctionLibrary *li = dictGetVal(entry);
if (rioWrite(aof, "*3\r\n", 4) == 0) goto werr;
char function_load[] = "$8\r\nFUNCTION\r\n$4\r\nLOAD\r\n";
if (rioWrite(aof, function_load, sizeof(function_load) - 1) == 0) goto werr;
Expand Down
30 changes: 14 additions & 16 deletions src/function_lua.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ typedef struct luaFunctionCtx {
} luaFunctionCtx;

typedef struct loadCtx {
functionLibInfo *li;
ValkeyModuleScriptingEngineFunctionLibrary *li;
monotime start_time;
size_t timeout;
} loadCtx;
Expand Down Expand Up @@ -100,7 +100,7 @@ static void luaEngineLoadHook(lua_State *lua, lua_Debug *ar) {
*
* Return NULL on compilation error and set the error to the err variable
*/
static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size_t timeout, sds *err) {
static int luaEngineCreate(void *engine_ctx, ValkeyModuleScriptingEngineFunctionLibrary *li, const char *blob, size_t timeout, char **err) {
int ret = C_ERR;
luaEngineCtx *lua_engine_ctx = engine_ctx;
lua_State *lua = lua_engine_ctx->lua;
Expand All @@ -114,7 +114,7 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size
lua_pop(lua, 1); /* pop the metatable */

/* compile the code */
if (luaL_loadbuffer(lua, blob, sdslen(blob), "@user_function")) {
if (luaL_loadbuffer(lua, blob, strlen(blob), "@user_function")) {
*err = sdscatprintf(sdsempty(), "Error compiling function: %s", lua_tostring(lua, -1));
lua_pop(lua, 1); /* pops the error */
goto done;
Expand Down Expand Up @@ -158,7 +158,7 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size
/*
* Invole the give function with the given keys and args
*/
static void luaEngineCall(scriptRunCtx *run_ctx,
static void luaEngineCall(ValkeyModuleScriptingEngineFunctionCallCtx *func_ctx,
void *engine_ctx,
void *compiled_function,
robj **keys,
Expand All @@ -177,6 +177,7 @@ static void luaEngineCall(scriptRunCtx *run_ctx,

serverAssert(lua_isfunction(lua, -1));

scriptRunCtx *run_ctx = moduleGetScriptRunCtxFromFunctionCtx(func_ctx);
luaCallFunction(run_ctx, lua, keys, nkeys, args, nargs, 0);
lua_pop(lua, 1); /* Pop error handler */
}
Expand Down Expand Up @@ -494,16 +495,13 @@ int luaEngineInitEngine(void) {
lua_enablereadonlytable(lua_engine_ctx->lua, -1, 1); /* protect the new global table */
lua_replace(lua_engine_ctx->lua, LUA_GLOBALSINDEX); /* set new global table as the new globals */


engine *lua_engine = zmalloc(sizeof(*lua_engine));
*lua_engine = (engine){
.engine_ctx = lua_engine_ctx,
.create = luaEngineCreate,
.call = luaEngineCall,
.get_used_memory = luaEngineGetUsedMemoy,
.get_function_memory_overhead = luaEngineFunctionMemoryOverhead,
.get_engine_memory_overhead = luaEngineMemoryOverhead,
.free_function = luaEngineFreeFunction,
};
return functionsRegisterEngine(LUA_ENGINE_NAME, lua_engine);
return functionsRegisterEngine(LUA_ENGINE_NAME,
NULL,
lua_engine_ctx,
luaEngineCreate,
luaEngineCall,
luaEngineGetUsedMemoy,
luaEngineFunctionMemoryOverhead,
luaEngineMemoryOverhead,
luaEngineFreeFunction);
}
155 changes: 96 additions & 59 deletions src/functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ static size_t functionMallocSize(functionInfo *fi) {
fi->li->ei->engine->get_function_memory_overhead(fi->function);
}

static size_t libraryMallocSize(functionLibInfo *li) {
static size_t libraryMallocSize(ValkeyModuleScriptingEngineFunctionLibrary *li) {
return zmalloc_size(li) + sdsAllocSize(li->name) + sdsAllocSize(li->code);
}

Expand All @@ -148,7 +148,7 @@ static void engineFunctionDispose(dict *d, void *obj) {
zfree(fi);
}

static void engineLibraryFree(functionLibInfo *li) {
static void engineLibraryFree(ValkeyModuleScriptingEngineFunctionLibrary *li) {
if (!li) {
return;
}
Expand Down Expand Up @@ -227,6 +227,15 @@ functionsLibCtx *functionsLibCtxCreate(void) {
return ret;
}

void functionsAddEngineStats(engineInfo *ei) {
serverAssert(curr_functions_lib_ctx != NULL);
dictEntry *entry = dictFind(curr_functions_lib_ctx->engines_stats, ei->name);
if (entry == NULL) {
functionsLibEngineStats *stats = zcalloc(sizeof(*stats));
dictAdd(curr_functions_lib_ctx->engines_stats, ei->name, stats);
}
}

/*
* Creating a function inside the given library.
* On success, return C_OK.
Expand All @@ -236,7 +245,7 @@ functionsLibCtx *functionsLibCtxCreate(void) {
* the function will verify that the given name is following the naming format
* and return an error if its not.
*/
int functionLibCreateFunction(sds name, void *function, functionLibInfo *li, sds desc, uint64_t f_flags, sds *err) {
int functionLibCreateFunction(sds name, void *function, ValkeyModuleScriptingEngineFunctionLibrary *li, sds desc, uint64_t f_flags, sds *err) {
if (functionsVerifyName(name) != C_OK) {
*err = sdsnew("Library names can only contain letters, numbers, or underscores(_) and must be at least one "
"character long");
Expand All @@ -263,9 +272,9 @@ int functionLibCreateFunction(sds name, void *function, functionLibInfo *li, sds
return C_OK;
}

static functionLibInfo *engineLibraryCreate(sds name, engineInfo *ei, sds code) {
functionLibInfo *li = zmalloc(sizeof(*li));
*li = (functionLibInfo){
static ValkeyModuleScriptingEngineFunctionLibrary *engineLibraryCreate(sds name, engineInfo *ei, sds code) {
ValkeyModuleScriptingEngineFunctionLibrary *li = zmalloc(sizeof(*li));
*li = (ValkeyModuleScriptingEngineFunctionLibrary){
.name = sdsdup(name),
.functions = dictCreate(&libraryFunctionDictType),
.ei = ei,
Expand All @@ -274,7 +283,7 @@ static functionLibInfo *engineLibraryCreate(sds name, engineInfo *ei, sds code)
return li;
}

static void libraryUnlink(functionsLibCtx *lib_ctx, functionLibInfo *li) {
static void libraryUnlink(functionsLibCtx *lib_ctx, ValkeyModuleScriptingEngineFunctionLibrary *li) {
dictIterator *iter = dictGetIterator(li->functions);
dictEntry *entry = NULL;
while ((entry = dictNext(iter))) {
Expand All @@ -296,7 +305,7 @@ static void libraryUnlink(functionsLibCtx *lib_ctx, functionLibInfo *li) {
stats->n_functions -= dictSize(li->functions);
}

static void libraryLink(functionsLibCtx *lib_ctx, functionLibInfo *li) {
static void libraryLink(functionsLibCtx *lib_ctx, ValkeyModuleScriptingEngineFunctionLibrary *li) {
dictIterator *iter = dictGetIterator(li->functions);
dictEntry *entry = NULL;
while ((entry = dictNext(iter))) {
Expand Down Expand Up @@ -332,8 +341,8 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l
dictEntry *entry = NULL;
iter = dictGetIterator(functions_lib_ctx_src->libraries);
while ((entry = dictNext(iter))) {
functionLibInfo *li = dictGetVal(entry);
functionLibInfo *old_li = dictFetchValue(functions_lib_ctx_dst->libraries, li->name);
ValkeyModuleScriptingEngineFunctionLibrary *li = dictGetVal(entry);
ValkeyModuleScriptingEngineFunctionLibrary *old_li = dictFetchValue(functions_lib_ctx_dst->libraries, li->name);
if (old_li) {
if (!replace) {
/* library already exists, failed the restore. */
Expand Down Expand Up @@ -367,7 +376,7 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l
/* No collision, it is safe to link all the new libraries. */
iter = dictGetIterator(functions_lib_ctx_src->libraries);
while ((entry = dictNext(iter))) {
functionLibInfo *li = dictGetVal(entry);
ValkeyModuleScriptingEngineFunctionLibrary *li = dictGetVal(entry);
libraryLink(functions_lib_ctx_dst, li);
dictSetVal(functions_lib_ctx_src->libraries, entry, NULL);
}
Expand All @@ -387,7 +396,7 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l
/* Link back all libraries on tmp_l_ctx */
while (listLength(old_libraries_list) > 0) {
listNode *head = listFirst(old_libraries_list);
functionLibInfo *li = listNodeValue(head);
ValkeyModuleScriptingEngineFunctionLibrary *li = listNodeValue(head);
listNodeValue(head) = NULL;
libraryLink(functions_lib_ctx_dst, li);
listDelNode(old_libraries_list, head);
Expand All @@ -400,34 +409,99 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l
/* Register an engine, should be called once by the engine on startup and give the following:
*
* - engine_name - name of the engine to register
* - engine_module - the valkey module that implements this engine
* - engine_ctx - the engine ctx that should be used by the server to interact with the engine */
int functionsRegisterEngine(const char *engine_name, engine *engine) {
int functionsRegisterEngine(const char *engine_name,
ValkeyModule *engine_module,
void *engine_ctx,
ValkeyModuleCreateScriptingEngineFunc create_func,
ValkeyModuleCallScriptingEngineFunctionFunc call_func,
ValkeyModuleScriptingEngineGetUsedMemoryFunc get_used_memory_func,
ValkeyModuleScriptingEngineGetFunctionMemoryOverheadFunc get_function_memory_overhead_func,
ValkeyModuleScriptingEngineGetEngineMemoryOverheadFunc get_engine_memory_overhead_func,
ValkeyModuleScriptingEngineFreeFunctionFunc free_function_func) {
sds engine_name_sds = sdsnew(engine_name);
if (dictFetchValue(engines, engine_name_sds)) {
serverLog(LL_WARNING, "Same engine was registered twice");
sdsfree(engine_name_sds);
return C_ERR;
}

engine *engine = zmalloc(sizeof(struct engine));
*engine = (struct engine) {
.engine_ctx = engine_ctx,
.create = create_func,
.call = call_func,
.get_used_memory = get_used_memory_func,
.get_function_memory_overhead = get_function_memory_overhead_func,
.get_engine_memory_overhead = get_engine_memory_overhead_func,
.free_function = free_function_func,
};

client *c = createClient(NULL);
c->flag.deny_blocking = 1;
c->flag.script = 1;
c->flag.fake = 1;
engineInfo *ei = zmalloc(sizeof(*ei));
*ei = (engineInfo){
.name = engine_name_sds,
.engineModule = engine_module,
.engine = engine,
.c = c,
};

dictAdd(engines, engine_name_sds, ei);

functionsAddEngineStats(ei);

engine_cache_memory += zmalloc_size(ei) + sdsAllocSize(ei->name) + zmalloc_size(engine) +
engine->get_engine_memory_overhead(engine->engine_ctx);

return C_OK;
}

/* Removes an egine fomr the server.
*
*/
int functionsUnregisterEngine(const char *engine_name) {
sds engine_name_sds = sdsnew(engine_name);
dictEntry *entry = dictFind(engines, engine_name_sds);
if (entry == NULL) {
serverLog(LL_WARNING, "There's no engine registered with name %s", engine_name);
sdsfree(engine_name_sds);
return C_ERR;
}

engineInfo* ei = dictGetVal(entry);

ValkeyModuleScriptingEngineFunctionLibrary *li = NULL;

dictIterator *iter = dictGetIterator(curr_functions_lib_ctx->libraries);
while ((entry = dictNext(iter))) {
li = dictGetVal(entry);
if (li->ei == ei) {
break;
// int ret = dictDelete(curr_functions_lib_ctx->libraries, li->name);
// serverAssert(ret == DICT_OK);
}
}
dictReleaseIterator(iter);

libraryUnlink(curr_functions_lib_ctx, li);
engineLibraryFree(li);

int ret = dictDelete(engines, engine_name_sds);
serverAssert(ret == DICT_OK);

zfree(ei->engine);
sdsfree(ei->name);
freeClient(ei->c);
zfree(ei);

sdsfree(engine_name_sds);
return C_OK;
}

/*
* FUNCTION STATS
*/
Expand Down Expand Up @@ -535,7 +609,7 @@ void functionListCommand(client *c) {
dictIterator *iter = dictGetIterator(curr_functions_lib_ctx->libraries);
dictEntry *entry = NULL;
while ((entry = dictNext(iter))) {
functionLibInfo *li = dictGetVal(entry);
ValkeyModuleScriptingEngineFunctionLibrary *li = dictGetVal(entry);
if (library_name) {
if (!stringmatchlen(library_name, sdslen(library_name), li->name, sdslen(li->name), 1)) {
continue;
Expand Down Expand Up @@ -584,7 +658,7 @@ void functionListCommand(client *c) {
*/
void functionDeleteCommand(client *c) {
robj *function_name = c->argv[2];
functionLibInfo *li = dictFetchValue(curr_functions_lib_ctx->libraries, function_name->ptr);
ValkeyModuleScriptingEngineFunctionLibrary *li = dictFetchValue(curr_functions_lib_ctx->libraries, function_name->ptr);
if (!li) {
addReplyError(c, "Library not found");
return;
Expand Down Expand Up @@ -614,55 +688,18 @@ uint64_t fcallGetCommandFlags(client *c, uint64_t cmd_flags) {
return scriptFlagsToCmdFlags(cmd_flags, script_flags);
}

static void fcallCommandGeneric(client *c, int ro) {
/* Functions need to be fed to monitors before the commands they execute. */
replicationFeedMonitors(c, server.monitors, c->db->id, c->argv, c->argc);

robj *function_name = c->argv[1];
dictEntry *de = c->cur_script;
if (!de) de = dictFind(curr_functions_lib_ctx->functions, function_name->ptr);
if (!de) {
addReplyError(c, "Function not found");
return;
}
functionInfo *fi = dictGetVal(de);
engine *engine = fi->li->ei->engine;

long long numkeys;
/* Get the number of arguments that are keys */
if (getLongLongFromObject(c->argv[2], &numkeys) != C_OK) {
addReplyError(c, "Bad number of keys provided");
return;
}
if (numkeys > (c->argc - 3)) {
addReplyError(c, "Number of keys can't be greater than number of args");
return;
} else if (numkeys < 0) {
addReplyError(c, "Number of keys can't be negative");
return;
}

scriptRunCtx run_ctx;

if (scriptPrepareForRun(&run_ctx, fi->li->ei->c, c, fi->name, fi->f_flags, ro) != C_OK) return;

engine->call(&run_ctx, engine->engine_ctx, fi->function, c->argv + 3, numkeys, c->argv + 3 + numkeys,
c->argc - 3 - numkeys);
scriptResetRun(&run_ctx);
}

/*
* FCALL <FUNCTION NAME> nkeys <key1 .. keyn> <arg1 .. argn>
*/
void fcallCommand(client *c) {
fcallCommandGeneric(c, 0);
fcallCommandGeneric(curr_functions_lib_ctx->functions, c, 0);
}

/*
* FCALL_RO <FUNCTION NAME> nkeys <key1 .. keyn> <arg1 .. argn>
*/
void fcallroCommand(client *c) {
fcallCommandGeneric(c, 1);
fcallCommandGeneric(curr_functions_lib_ctx->functions, c, 1);
}

/*
Expand Down Expand Up @@ -952,9 +989,10 @@ void functionFreeLibMetaData(functionsLibMetaData *md) {
sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibCtx *lib_ctx, size_t timeout) {
dictIterator *iter = NULL;
dictEntry *entry = NULL;
functionLibInfo *new_li = NULL;
functionLibInfo *old_li = NULL;
ValkeyModuleScriptingEngineFunctionLibrary *old_li = NULL;
functionsLibMetaData md = {0};
ValkeyModuleScriptingEngineFunctionLibrary *new_li = NULL;

if (functionExtractLibMetaData(code, &md, err) != C_OK) {
return NULL;
}
Expand Down Expand Up @@ -1114,12 +1152,11 @@ size_t functionsLibCtxFunctionsLen(functionsLibCtx *functions_ctx) {
int functionsInit(void) {
engines = dictCreate(&engineDictType);

curr_functions_lib_ctx = functionsLibCtxCreate();

if (luaEngineInitEngine() != C_OK) {
return C_ERR;
}

/* Must be initialized after engines initialization */
curr_functions_lib_ctx = functionsLibCtxCreate();

return C_OK;
}
Loading

0 comments on commit c7afe80

Please sign in to comment.