Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for scripting engines as Valkey modules #1277

Open
wants to merge 18 commits into
base: unstable
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 89 additions & 70 deletions src/function_lua.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,13 @@ typedef struct luaFunctionCtx {
} luaFunctionCtx;

typedef struct loadCtx {
functionLibInfo *li;
list *functions;
monotime start_time;
size_t timeout;
} loadCtx;

typedef struct registerFunctionArgs {
sds name;
sds desc;
luaFunctionCtx *lua_f_ctx;
uint64_t f_flags;
} registerFunctionArgs;
static void luaEngineFreeFunction(ValkeyModuleScriptingEngineCtx *engine_ctx,
void *compiled_function);

/* Hook for FUNCTION LOAD execution.
* Used to cancel the execution in case of a timeout (500ms).
Expand All @@ -93,15 +89,30 @@ static void luaEngineLoadHook(lua_State *lua, lua_Debug *ar) {
}
}

static void freeCompiledFunc(luaEngineCtx *lua_engine_ctx, void *compiled_func) {
ValkeyModuleScriptingEngineCompiledFunction *func = compiled_func;
zfree(func->name);
if (func->desc) {
zfree(func->desc);
}
luaEngineFreeFunction(lua_engine_ctx, func->function);
zfree(func);
}

/*
* Compile a given blob and save it on the registry.
* Return a function ctx with Lua ref that allows to later retrieve the
* function from the registry.
*
* Return NULL on compilation error and set the error to the err variable
*/
static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size_t timeout, sds *err) {
int ret = C_ERR;
static ValkeyModuleScriptingEngineCompiledFunction **luaEngineCreate(
ValkeyModuleScriptingEngineCtx *engine_ctx,
const char *blob,
size_t timeout,
size_t *out_num_compiled_functions,
char **err) {
ValkeyModuleScriptingEngineCompiledFunction **compiled_functions = NULL;
luaEngineCtx *lua_engine_ctx = engine_ctx;
lua_State *lua = lua_engine_ctx->lua;

Expand All @@ -114,15 +125,15 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size
lua_pop(lua, 1); /* pop the metatable */

/* compile the code */
if (luaL_loadbuffer(lua, blob, sdslen(blob), "@user_function")) {
*err = sdscatprintf(sdsempty(), "Error compiling function: %s", lua_tostring(lua, -1));
if (luaL_loadbuffer(lua, blob, strlen(blob), "@user_function")) {
*err = valkey_asprintf("Error compiling function: %s", lua_tostring(lua, -1));
rjd15372 marked this conversation as resolved.
Show resolved Hide resolved
lua_pop(lua, 1); /* pops the error */
goto done;
}
serverAssert(lua_isfunction(lua, -1));

loadCtx load_ctx = {
.li = li,
.functions = listCreate(),
.start_time = getMonotonicUs(),
.timeout = timeout,
};
Expand All @@ -133,13 +144,31 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size
if (lua_pcall(lua, 0, 0, 0)) {
errorInfo err_info = {0};
luaExtractErrorInformation(lua, &err_info);
*err = sdscatprintf(sdsempty(), "Error registering functions: %s", err_info.msg);
*err = valkey_asprintf("Error registering functions: %s", err_info.msg);
lua_pop(lua, 1); /* pops the error */
luaErrorInformationDiscard(&err_info);
listIter *iter = listGetIterator(load_ctx.functions, AL_START_HEAD);
listNode *node = NULL;
while ((node = listNext(iter)) != NULL) {
freeCompiledFunc(lua_engine_ctx, listNodeValue(node));
}
listReleaseIterator(iter);
listRelease(load_ctx.functions);
goto done;
}

ret = C_OK;
compiled_functions =
zcalloc(sizeof(ValkeyModuleScriptingEngineCompiledFunction *) * listLength(load_ctx.functions));
listIter *iter = listGetIterator(load_ctx.functions, AL_START_HEAD);
listNode *node = NULL;
*out_num_compiled_functions = 0;
while ((node = listNext(iter)) != NULL) {
ValkeyModuleScriptingEngineCompiledFunction *func = listNodeValue(node);
compiled_functions[*out_num_compiled_functions] = func;
(*out_num_compiled_functions)++;
}
listReleaseIterator(iter);
listRelease(load_ctx.functions);

done:
/* restore original globals */
Expand All @@ -152,19 +181,22 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size

lua_sethook(lua, NULL, 0, 0); /* Disable hook */
luaSaveOnRegistry(lua, REGISTRY_LOAD_CTX_NAME, NULL);
return ret;
return compiled_functions;
}

/*
* Invole the give function with the given keys and args
*/
static void luaEngineCall(scriptRunCtx *run_ctx,
void *engine_ctx,
static void luaEngineCall(ValkeyModuleCtx *module_ctx,
ValkeyModuleScriptingEngineCtx *engine_ctx,
ValkeyModuleScriptingEngineFunctionCtx *func_ctx,
void *compiled_function,
robj **keys,
size_t nkeys,
robj **args,
size_t nargs) {
serverAssert(module_ctx == NULL);

luaEngineCtx *lua_engine_ctx = engine_ctx;
lua_State *lua = lua_engine_ctx->lua;
luaFunctionCtx *f_ctx = compiled_function;
Expand All @@ -177,11 +209,12 @@ static void luaEngineCall(scriptRunCtx *run_ctx,

serverAssert(lua_isfunction(lua, -1));

scriptRunCtx *run_ctx = (scriptRunCtx *)func_ctx;
luaCallFunction(run_ctx, lua, keys, nkeys, args, nargs, 0);
lua_pop(lua, 1); /* Pop error handler */
}

static size_t luaEngineGetUsedMemoy(void *engine_ctx) {
static size_t luaEngineGetUsedMemoy(ValkeyModuleScriptingEngineCtx *engine_ctx) {
luaEngineCtx *lua_engine_ctx = engine_ctx;
return luaMemory(lua_engine_ctx->lua);
}
Expand All @@ -190,39 +223,33 @@ static size_t luaEngineFunctionMemoryOverhead(void *compiled_function) {
return zmalloc_size(compiled_function);
}

static size_t luaEngineMemoryOverhead(void *engine_ctx) {
static size_t luaEngineMemoryOverhead(ValkeyModuleScriptingEngineCtx *engine_ctx) {
luaEngineCtx *lua_engine_ctx = engine_ctx;
return zmalloc_size(lua_engine_ctx);
}

static void luaEngineFreeFunction(void *engine_ctx, void *compiled_function) {
static void luaEngineFreeFunction(ValkeyModuleScriptingEngineCtx *engine_ctx,
void *compiled_function) {
luaEngineCtx *lua_engine_ctx = engine_ctx;
lua_State *lua = lua_engine_ctx->lua;
luaFunctionCtx *f_ctx = compiled_function;
lua_unref(lua, f_ctx->lua_function_ref);
zfree(f_ctx);
}

static void luaRegisterFunctionArgsInitialize(registerFunctionArgs *register_f_args,
sds name,
sds desc,
static void luaRegisterFunctionArgsInitialize(ValkeyModuleScriptingEngineCompiledFunction *func,
char *name,
char *desc,
luaFunctionCtx *lua_f_ctx,
uint64_t flags) {
*register_f_args = (registerFunctionArgs){
*func = (ValkeyModuleScriptingEngineCompiledFunction){
.name = name,
.desc = desc,
.lua_f_ctx = lua_f_ctx,
.function = lua_f_ctx,
.f_flags = flags,
};
}

static void luaRegisterFunctionArgsDispose(lua_State *lua, registerFunctionArgs *register_f_args) {
sdsfree(register_f_args->name);
if (register_f_args->desc) sdsfree(register_f_args->desc);
lua_unref(lua, register_f_args->lua_f_ctx->lua_function_ref);
zfree(register_f_args->lua_f_ctx);
}

/* Read function flags located on the top of the Lua stack.
* On success, return C_OK and set the flags to 'flags' out parameter
* Return C_ERR if encounter an unknown flag. */
Expand Down Expand Up @@ -267,10 +294,10 @@ static int luaRegisterFunctionReadFlags(lua_State *lua, uint64_t *flags) {
return ret;
}

static int luaRegisterFunctionReadNamedArgs(lua_State *lua, registerFunctionArgs *register_f_args) {
static int luaRegisterFunctionReadNamedArgs(lua_State *lua, ValkeyModuleScriptingEngineCompiledFunction *func) {
char *err = NULL;
sds name = NULL;
sds desc = NULL;
char *name = NULL;
char *desc = NULL;
luaFunctionCtx *lua_f_ctx = NULL;
uint64_t flags = 0;
if (!lua_istable(lua, 1)) {
Expand All @@ -289,12 +316,12 @@ static int luaRegisterFunctionReadNamedArgs(lua_State *lua, registerFunctionArgs
}
const char *key = lua_tostring(lua, -2);
if (!strcasecmp(key, "function_name")) {
if (!(name = luaGetStringSds(lua, -1))) {
if (!(name = luaGetStringCStr(lua, -1))) {
err = "function_name argument given to server.register_function must be a string";
goto error;
}
} else if (!strcasecmp(key, "description")) {
if (!(desc = luaGetStringSds(lua, -1))) {
if (!(desc = luaGetStringCStr(lua, -1))) {
err = "description argument given to server.register_function must be a string";
goto error;
}
Expand Down Expand Up @@ -335,13 +362,13 @@ static int luaRegisterFunctionReadNamedArgs(lua_State *lua, registerFunctionArgs
goto error;
}

luaRegisterFunctionArgsInitialize(register_f_args, name, desc, lua_f_ctx, flags);
luaRegisterFunctionArgsInitialize(func, name, desc, lua_f_ctx, flags);

return C_OK;

error:
if (name) sdsfree(name);
if (desc) sdsfree(desc);
if (name) zfree(name);
if (desc) zfree(desc);
if (lua_f_ctx) {
lua_unref(lua, lua_f_ctx->lua_function_ref);
zfree(lua_f_ctx);
Expand All @@ -350,11 +377,11 @@ static int luaRegisterFunctionReadNamedArgs(lua_State *lua, registerFunctionArgs
return C_ERR;
}

static int luaRegisterFunctionReadPositionalArgs(lua_State *lua, registerFunctionArgs *register_f_args) {
static int luaRegisterFunctionReadPositionalArgs(lua_State *lua, ValkeyModuleScriptingEngineCompiledFunction *func) {
char *err = NULL;
sds name = NULL;
char *name = NULL;
luaFunctionCtx *lua_f_ctx = NULL;
if (!(name = luaGetStringSds(lua, 1))) {
if (!(name = luaGetStringCStr(lua, 1))) {
err = "first argument to server.register_function must be a string";
goto error;
}
Expand All @@ -369,51 +396,46 @@ static int luaRegisterFunctionReadPositionalArgs(lua_State *lua, registerFunctio
lua_f_ctx = zmalloc(sizeof(*lua_f_ctx));
lua_f_ctx->lua_function_ref = lua_function_ref;

luaRegisterFunctionArgsInitialize(register_f_args, name, NULL, lua_f_ctx, 0);
luaRegisterFunctionArgsInitialize(func, name, NULL, lua_f_ctx, 0);

return C_OK;

error:
if (name) sdsfree(name);
if (name) zfree(name);
luaPushError(lua, err);
return C_ERR;
}

static int luaRegisterFunctionReadArgs(lua_State *lua, registerFunctionArgs *register_f_args) {
static int luaRegisterFunctionReadArgs(lua_State *lua, ValkeyModuleScriptingEngineCompiledFunction *func) {
int argc = lua_gettop(lua);
if (argc < 1 || argc > 2) {
luaPushError(lua, "wrong number of arguments to server.register_function");
return C_ERR;
}

if (argc == 1) {
return luaRegisterFunctionReadNamedArgs(lua, register_f_args);
return luaRegisterFunctionReadNamedArgs(lua, func);
} else {
return luaRegisterFunctionReadPositionalArgs(lua, register_f_args);
return luaRegisterFunctionReadPositionalArgs(lua, func);
}
}

static int luaRegisterFunction(lua_State *lua) {
registerFunctionArgs register_f_args = {0};
ValkeyModuleScriptingEngineCompiledFunction *func = zcalloc(sizeof(*func));

loadCtx *load_ctx = luaGetFromRegistry(lua, REGISTRY_LOAD_CTX_NAME);
if (!load_ctx) {
zfree(func);
luaPushError(lua, "server.register_function can only be called on FUNCTION LOAD command");
return luaError(lua);
}

if (luaRegisterFunctionReadArgs(lua, &register_f_args) != C_OK) {
if (luaRegisterFunctionReadArgs(lua, func) != C_OK) {
zfree(func);
return luaError(lua);
}

sds err = NULL;
if (functionLibCreateFunction(register_f_args.name, register_f_args.lua_f_ctx, load_ctx->li, register_f_args.desc,
register_f_args.f_flags, &err) != C_OK) {
luaRegisterFunctionArgsDispose(lua, &register_f_args);
luaPushError(lua, err);
sdsfree(err);
return luaError(lua);
}
listAddNodeTail(load_ctx->functions, func);

return 0;
}
Expand Down Expand Up @@ -494,16 +516,13 @@ int luaEngineInitEngine(void) {
lua_enablereadonlytable(lua_engine_ctx->lua, -1, 1); /* protect the new global table */
lua_replace(lua_engine_ctx->lua, LUA_GLOBALSINDEX); /* set new global table as the new globals */


engine *lua_engine = zmalloc(sizeof(*lua_engine));
*lua_engine = (engine){
.engine_ctx = lua_engine_ctx,
.create = luaEngineCreate,
.call = luaEngineCall,
.get_used_memory = luaEngineGetUsedMemoy,
.get_function_memory_overhead = luaEngineFunctionMemoryOverhead,
.get_engine_memory_overhead = luaEngineMemoryOverhead,
.free_function = luaEngineFreeFunction,
};
return functionsRegisterEngine(LUA_ENGINE_NAME, lua_engine);
return functionsRegisterEngine(LUA_ENGINE_NAME,
NULL,
lua_engine_ctx,
luaEngineCreate,
luaEngineCall,
luaEngineGetUsedMemoy,
luaEngineFunctionMemoryOverhead,
luaEngineMemoryOverhead,
luaEngineFreeFunction);
}
Loading
Loading