diff --git a/examples/acl.rs b/examples/acl.rs index 409acc7e..9dcbc4b2 100644 --- a/examples/acl.rs +++ b/examples/acl.rs @@ -1,8 +1,15 @@ +use std::sync::Mutex; + +use lazy_static::{__Deref, lazy_static}; use redis_module::{ redis_module, AclPermissions, Context, NextArg, RedisError, RedisResult, RedisString, - RedisValue, + RedisUser, RedisValue, Status, }; +lazy_static! { + static ref USER: Mutex = Mutex::new(RedisUser::new("acl")); +} + fn verify_key_access_for_user(ctx: &Context, args: Vec) -> RedisResult { let mut args = args.into_iter().skip(1); let user = args.next_arg()?; @@ -18,6 +25,20 @@ fn get_current_user(ctx: &Context, _args: Vec) -> RedisResult { Ok(RedisValue::BulkRedisString(ctx.get_current_user())) } +fn authenticate_with_user(ctx: &Context, _args: Vec) -> RedisResult { + let user = USER.lock()?; + ctx.authenticate_client_with_user(user.deref())?; + Ok(RedisValue::SimpleStringStatic("OK")) +} + +fn init(_ctx: &Context, _args: &[RedisString]) -> Status { + // Set the user ACL + let _ = USER.lock().unwrap().set_acl("on allcommands allkeys"); + + // Module initialized + Status::Ok +} + ////////////////////////////////////////////////////// redis_module! { @@ -25,7 +46,9 @@ redis_module! { version: 1, allocator: (redis_module::alloc::RedisAlloc, redis_module::alloc::RedisAlloc), data_types: [], + init: init, commands: [ + ["authenticate_with_user", authenticate_with_user, "", 0, 0, 0], ["verify_key_access_for_user", verify_key_access_for_user, "", 0, 0, 0], ["get_current_user", get_current_user, "", 0, 0, 0], ], diff --git a/src/context/mod.rs b/src/context/mod.rs index 32d6f5ab..ec88fdb5 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -6,12 +6,12 @@ use std::os::raw::{c_char, c_int, c_long, c_longlong}; use std::ptr::{self, NonNull}; use std::sync::atomic::{AtomicPtr, Ordering}; -use crate::add_info_section; use crate::key::{RedisKey, RedisKeyWritable}; use crate::logging::RedisLogLevel; use crate::raw::{ModuleOptions, Version}; use crate::redisvalue::RedisValueKey; use crate::{add_info_field_long_long, add_info_field_str, raw, utils, Status}; +use crate::{add_info_section, RedisUser}; use crate::{RedisError, RedisResult, RedisString, RedisValue}; use std::ops::Deref; @@ -731,6 +731,16 @@ impl Context { RedisString::from_redis_module_string(ptr::null_mut(), user) } + /// Return the current user as a [RedisUser] object + pub fn get_module_user(&self, user_name: &RedisString) -> Option { + let user = unsafe { raw::RedisModule_GetModuleUserFromUserName.unwrap()(user_name.inner) }; + if user.is_null() { + return None; + } + + Some(RedisUser::from_redis_module_user(user)) + } + /// Attach the given user to the current context so each operation performed from /// now on using this context will be validated againts this new user. /// Return [ContextUserScope] which make sure to unset the user when freed and @@ -747,6 +757,25 @@ impl Context { Ok(ContextUserScope::new(self, user)) } + /// Authenticate the current context's user with the provided [RedisUser]. + pub fn authenticate_client_with_user(&self, user: &RedisUser) -> Result<(), RedisError> { + let result = unsafe { + raw::RedisModule_AuthenticateClientWithUser.unwrap()( + self.ctx, + user.user, + None, + std::ptr::null_mut(), + std::ptr::null_mut(), + ) + }; + + if result != raw::REDISMODULE_OK as i32 { + return Err(RedisError::Str("Error authenticating user client")); + } + + Ok(()) + } + fn deautenticate_user(&self) { unsafe { raw::RedisModule_SetContextUser.unwrap()(self.ctx, ptr::null_mut()) }; } @@ -760,21 +789,10 @@ impl Context { key_name: &RedisString, permissions: &AclPermissions, ) -> Result<(), RedisError> { - let user = unsafe { raw::RedisModule_GetModuleUserFromUserName.unwrap()(user_name.inner) }; - if user.is_null() { - return Err(RedisError::Str("User does not exists or disabled")); - } - let acl_permission_result: raw::Status = unsafe { - raw::RedisModule_ACLCheckKeyPermissions.unwrap()( - user, - key_name.inner, - permissions.bits(), - ) + match self.get_module_user(user_name) { + Some(user) => user.acl_check_key_permission(key_name, permissions), + None => Err(RedisError::Str("User does not exists or disabled")), } - .into(); - unsafe { raw::RedisModule_FreeModuleUser.unwrap()(user) }; - let acl_permission_result: Result<(), &str> = acl_permission_result.into(); - acl_permission_result.map_err(|_e| RedisError::Str("User does not have permissions on key")) } api!( diff --git a/src/lib.rs b/src/lib.rs index c72658c7..bbcd1d9f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub mod raw; pub mod rediserror; mod redismodule; pub mod redisraw; +pub mod redisuser; pub mod redisvalue; pub mod stream; diff --git a/src/redismodule.rs b/src/redismodule.rs index 186e5b63..15a3c100 100644 --- a/src/redismodule.rs +++ b/src/redismodule.rs @@ -15,6 +15,7 @@ use serde::de::{Error, SeqAccess}; pub use crate::raw; pub use crate::rediserror::RedisError; +pub use crate::redisuser::RedisUser; pub use crate::redisvalue::RedisValue; use crate::Context; diff --git a/src/redisuser.rs b/src/redisuser.rs new file mode 100644 index 00000000..9564e04a --- /dev/null +++ b/src/redisuser.rs @@ -0,0 +1,76 @@ +use std::{ffi::CString, os::raw::c_char}; + +use crate::{raw, AclPermissions, RedisError, RedisString}; + +pub struct RedisUser { + pub(super) user: *mut raw::RedisModuleUser, +} + +impl RedisUser { + pub fn new(username: &str) -> RedisUser { + let username = CString::new(username).unwrap(); + let module_user = unsafe { raw::RedisModule_CreateModuleUser.unwrap()(username.as_ptr()) }; + + RedisUser { user: module_user } + } + + pub(super) fn from_redis_module_user(user: *mut raw::RedisModuleUser) -> RedisUser { + RedisUser { user } + } + + pub fn set_acl(&self, acl: &str) -> Result<(), RedisError> { + let acl = CString::new(acl).unwrap(); + let mut error: *mut raw::RedisModuleString = std::ptr::null_mut(); + let error_ptr: *mut *mut raw::RedisModuleString = &mut error; + + let result = unsafe { + raw::RedisModule_SetModuleUserACLString.unwrap()( + std::ptr::null_mut(), + self.user, + acl.as_ptr().cast::(), + error_ptr, + ) + }; + + // If the result is an error, parse the error string + if result != raw::REDISMODULE_OK as i32 { + let error = RedisString::from_redis_module_string(std::ptr::null_mut(), error); + return Err(RedisError::String(error.to_string_lossy())); + } + + Ok(()) + } + + pub fn acl(&self) -> RedisString { + let acl = unsafe { raw::RedisModule_GetModuleUserACLString.unwrap()(self.user) }; + RedisString::from_redis_module_string(std::ptr::null_mut(), acl) + } + + /// Verify the the given user has the give ACL permission on the given key. + /// Return Ok(()) if the user has the permissions or error (with relevant error message) + /// if the validation failed. + pub fn acl_check_key_permission( + &self, + key_name: &RedisString, + permissions: &AclPermissions, + ) -> Result<(), RedisError> { + let acl_permission_result: raw::Status = unsafe { + raw::RedisModule_ACLCheckKeyPermissions.unwrap()( + self.user, + key_name.inner, + permissions.bits(), + ) + } + .into(); + let acl_permission_result: Result<(), &str> = acl_permission_result.into(); + acl_permission_result.map_err(|_e| RedisError::Str("User does not have permissions on key")) + } +} + +impl Drop for RedisUser { + fn drop(&mut self) { + unsafe { raw::RedisModule_FreeModuleUser.unwrap()(self.user) }; + } +} + +unsafe impl Send for RedisUser {} diff --git a/tests/integration.rs b/tests/integration.rs index 1b00bd2c..f314a92a 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -282,6 +282,23 @@ fn test_get_current_user() -> Result<()> { Ok(()) } +#[test] +fn test_authenticate_client_with_user() -> Result<()> { + let port: u16 = 6490; + let _guards = vec![start_redis_server_with_module("acl", port) + .with_context(|| "failed to start redis server")?]; + let mut con = + get_redis_connection(port).with_context(|| "failed to connect to redis server")?; + + let res: String = redis::cmd("authenticate_with_user").query(&mut con)?; + assert_eq!(&res, "OK"); + + let res: String = redis::cmd("get_current_user").query(&mut con)?; + assert_eq!(&res, "acl"); + + Ok(()) +} + #[test] fn test_verify_acl_on_user() -> Result<()> { let port: u16 = 6491;