diff --git a/Cargo.toml b/Cargo.toml index a1fcddab..2e292e9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,7 @@ required-features = ["experimental-api"] [[example]] name = "test_helper" crate-type = ["cdylib"] -required-features = ["test"] +required-features = ["test","experimental-api"] [[example]] name = "info" diff --git a/examples/hello.rs b/examples/hello.rs index dd6b4388..55db1667 100644 --- a/examples/hello.rs +++ b/examples/hello.rs @@ -1,9 +1,6 @@ #[macro_use] extern crate redis_module; -use redis_module::InfoContext; -use redis_module::Status; - use redis_module::{Context, RedisError, RedisResult, RedisString}; fn hello_mul(_: &Context, args: Vec) -> RedisResult { if args.len() < 2 { @@ -24,32 +21,13 @@ fn hello_mul(_: &Context, args: Vec) -> RedisResult { Ok(response.into()) } -fn hello_err(ctx: &Context, args: Vec) -> RedisResult { - if args.is_empty() { - return Err(RedisError::WrongArity); - } - - let msg = args.get(1).unwrap(); - - ctx.reply_error_string(msg.try_as_str().unwrap()); - Ok(().into()) -} - -fn add_info(ctx: &InfoContext, _for_crash_report: bool) { - if ctx.add_info_section(Some("hello")) == Status::Ok { - ctx.add_info_field_str("field", "hello_value"); - } -} - ////////////////////////////////////////////////////// redis_module! { name: "hello", version: 1, data_types: [], - info: add_info, commands: [ ["hello.mul", hello_mul, "", 0, 0, 0], - ["hello.err", hello_err, "", 0, 0, 0], ], } diff --git a/examples/test_helper.rs b/examples/test_helper.rs index c2578ee6..e10271db 100644 --- a/examples/test_helper.rs +++ b/examples/test_helper.rs @@ -1,7 +1,9 @@ #[macro_use] extern crate redis_module; -use redis_module::{Context, RedisResult, RedisString}; +use redis_module::InfoContext; +use redis_module::Status; +use redis_module::{Context, RedisError, RedisResult, RedisString}; fn test_helper_version(ctx: &Context, _args: Vec) -> RedisResult { let ver = ctx.get_redis_version()?; @@ -18,25 +20,39 @@ fn test_helper_version_rm_call(ctx: &Context, _args: Vec) -> RedisR Ok(response.into()) } -////////////////////////////////////////////////////// +fn test_helper_command_name(ctx: &Context, _args: Vec) -> RedisResult { + Ok(ctx.current_command_name()?.into()) +} -#[cfg(not(feature = "test"))] -redis_module! { - name: "test_helper", - version: 1, - data_types: [], - commands: [ - ["test_helper.version", test_helper_version, "", 0, 0, 0], - ], +fn test_helper_err(ctx: &Context, args: Vec) -> RedisResult { + if args.len() < 1 { + return Err(RedisError::WrongArity); + } + + let msg = args.get(1).unwrap(); + + ctx.reply_error_string(msg.try_as_str().unwrap()); + Ok(().into()) } +fn add_info(ctx: &InfoContext, _for_crash_report: bool) { + if ctx.add_info_section(Some("test_helper")) == Status::Ok { + ctx.add_info_field_str("field", "test_helper_value"); + } +} + +////////////////////////////////////////////////////// + #[cfg(feature = "test")] redis_module! { name: "test_helper", version: 1, data_types: [], + info: add_info, commands: [ ["test_helper.version", test_helper_version, "", 0, 0, 0], ["test_helper._version_rm_call", test_helper_version_rm_call, "", 0, 0, 0], + ["test_helper.name", test_helper_command_name, "", 0, 0, 0], + ["test_helper.err", test_helper_err, "", 0, 0, 0], ], } diff --git a/src/context/mod.rs b/src/context/mod.rs index f387bd34..1e64e2ef 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -8,6 +8,9 @@ use crate::{add_info_field_long_long, add_info_field_str, raw, utils, Status}; use crate::{add_info_section, LogLevel}; use crate::{RedisError, RedisResult, RedisString, RedisValue}; +#[cfg(feature = "experimental-api")] +use std::ffi::CStr; + #[cfg(feature = "experimental-api")] mod timer; @@ -294,7 +297,19 @@ impl Context { unsafe { raw::notify_keyspace_event(self.ctx, event_type, event, keyname) } } - /// Returns the redis version either by calling ``RedisModule_GetServerVersion`` API, + #[cfg(feature = "experimental-api")] + pub fn current_command_name(&self) -> Result { + unsafe { + match raw::RedisModule_GetCurrentCommandName { + Some(cmd) => Ok(CStr::from_ptr(cmd(self.ctx)).to_str().unwrap().to_string()), + None => Err(RedisError::Str( + "API RedisModule_GetCurrentCommandName is not available", + )), + } + } + } + + /// Returns the redis version either by calling RedisModule_GetServerVersion API, /// Or if it is not available, by calling "info server" API and parsing the reply pub fn get_redis_version(&self) -> Result { self.get_redis_version_internal(false) @@ -306,6 +321,22 @@ impl Context { self.get_redis_version_internal(true) } + pub fn version_from_info(info: RedisValue) -> Result { + if let RedisValue::SimpleString(info_str) = info { + if let Some(ver) = utils::get_regexp_captures( + info_str.as_str(), + r"(?m)\bredis_version:([0-9]+)\.([0-9]+)\.([0-9]+)\b", + ) { + return Ok(Version { + major: ver[0].parse::().unwrap(), + minor: ver[1].parse::().unwrap(), + patch: ver[2].parse::().unwrap(), + }); + } + } + Err(RedisError::Str("Error getting redis_version")) + } + #[allow(clippy::not_unsafe_ptr_arg_deref)] fn get_redis_version_internal(&self, force_use_rm_call: bool) -> Result { match unsafe { raw::RedisModule_GetServerVersion } { @@ -315,21 +346,11 @@ impl Context { } _ => { // Call "info server" - if let Ok(RedisValue::SimpleString(s)) = self.call("info", &["server"]) { - if let Some(ver) = utils::get_regexp_captures( - s.as_str(), - r"(?m)\bredis_version:([0-9]+)\.([0-9]+)\.([0-9]+)\b", - ) { - return Ok(Version { - major: ver[0].parse::().unwrap(), - minor: ver[1].parse::().unwrap(), - patch: ver[2].parse::().unwrap(), - }); - } + if let Ok(info) = self.call("info", &["server"]) { + Context::version_from_info(info) + } else { + Err(RedisError::Str("Error calling \"info server\"")) } - Err(RedisError::Str( - "Error getting redis_version from \"info server\" call", - )) } } } diff --git a/src/include/redismodule.h b/src/include/redismodule.h index 672da07e..8a3b8acd 100644 --- a/src/include/redismodule.h +++ b/src/include/redismodule.h @@ -838,6 +838,7 @@ REDISMODULE_API int (*RedisModule_AuthenticateClientWithUser)(RedisModuleCtx *ct REDISMODULE_API int (*RedisModule_DeauthenticateAndCloseClient)(RedisModuleCtx *ctx, uint64_t client_id) REDISMODULE_ATTR; REDISMODULE_API RedisModuleString * (*RedisModule_GetClientCertificate)(RedisModuleCtx *ctx, uint64_t id) REDISMODULE_ATTR; REDISMODULE_API int *(*RedisModule_GetCommandKeys)(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, int *num_keys) REDISMODULE_ATTR; +REDISMODULE_API const char *(*RedisModule_GetCurrentCommandName)(RedisModuleCtx *ctx) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_RegisterDefragFunc)(RedisModuleCtx *ctx, RedisModuleDefragFunc func) REDISMODULE_ATTR; REDISMODULE_API void *(*RedisModule_DefragAlloc)(RedisModuleDefragCtx *ctx, void *ptr) REDISMODULE_ATTR; REDISMODULE_API RedisModuleString *(*RedisModule_DefragRedisModuleString)(RedisModuleDefragCtx *ctx, RedisModuleString *str) REDISMODULE_ATTR; @@ -1110,6 +1111,7 @@ static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int REDISMODULE_GET_API(AuthenticateClientWithUser); REDISMODULE_GET_API(GetClientCertificate); REDISMODULE_GET_API(GetCommandKeys); + REDISMODULE_GET_API(GetCurrentCommandName); REDISMODULE_GET_API(RegisterDefragFunc); REDISMODULE_GET_API(DefragAlloc); REDISMODULE_GET_API(DefragRedisModuleString); diff --git a/tests/integration.rs b/tests/integration.rs index c0916beb..e6ad5404 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1,9 +1,8 @@ +use crate::utils::{get_redis_connection, start_redis_server_with_module}; use anyhow::Context; use anyhow::Result; use redis::RedisError; -use utils::{get_redis_connection, start_redis_server_with_module}; - mod utils; #[test] @@ -74,38 +73,76 @@ fn test_test_helper_version() -> Result<()> { Ok(()) } +#[cfg(feature = "experimental-api")] #[test] -fn test_hello_info() -> Result<()> { +fn test_command_name() -> Result<()> { + use redis_module::RedisValue; + let port: u16 = 6482; - let _guards = vec![start_redis_server_with_module("hello", port) + let _guards = vec![start_redis_server_with_module("test_helper", port) + .with_context(|| "failed to start redis server")?]; + let mut con = + get_redis_connection(port).with_context(|| "failed to connect to redis server")?; + + // Call the tested command + let res: Result = redis::cmd("test_helper.name").query(&mut con); + + // The expected result is according to redis version + let info: String = redis::cmd("info") + .arg(&["server"]) + .query(&mut con) + .with_context(|| "failed to run test_helper.name")?; + + if let Ok(ver) = redis_module::Context::version_from_info(RedisValue::SimpleString(info)) { + if ver.major > 6 + || (ver.major == 6 && ver.minor > 2) + || (ver.major == 6 && ver.minor == 2 && ver.patch >= 5) + { + assert_eq!(res.unwrap(), String::from("test_helper.name")); + } else { + assert!(res + .err() + .unwrap() + .to_string() + .contains("RedisModule_GetCurrentCommandName is not available")); + } + } + + Ok(()) +} + +#[test] +fn test_test_helper_info() -> Result<()> { + let port: u16 = 6483; + let _guards = vec![start_redis_server_with_module("test_helper", port) .with_context(|| "failed to start redis server")?]; let mut con = get_redis_connection(port).with_context(|| "failed to connect to redis server")?; let res: String = redis::cmd("INFO") - .arg("HELLO") + .arg("TEST_HELPER") .query(&mut con) - .with_context(|| "failed to run INFO HELLO")?; - assert!(res.contains("hello_field:hello_value")); + .with_context(|| "failed to run INFO TEST_HELPER")?; + assert!(res.contains("test_helper_field:test_helper_value")); Ok(()) } #[allow(unused_must_use)] #[test] -fn test_hello_err() -> Result<()> { - let port: u16 = 6483; +fn test_test_helper_err() -> Result<()> { + let port: u16 = 6484; let _guards = vec![start_redis_server_with_module("hello", port) .with_context(|| "failed to start redis server")?]; let mut con = get_redis_connection(port).with_context(|| "failed to connect to redis server")?; // Make sure embedded nulls do not cause a crash - redis::cmd("hello.err") + redis::cmd("test_helper.err") .arg(&["\x00\x00"]) .query::<()>(&mut con); - redis::cmd("hello.err") + redis::cmd("test_helper.err") .arg(&["no crash\x00"]) .query::<()>(&mut con);