Skip to content

Commit

Permalink
implement enum AclCategory
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimfeldblum committed Nov 4, 2024
1 parent a6e4203 commit 0f33019
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 25 deletions.
10 changes: 5 additions & 5 deletions examples/acl.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use redis_module::{
redis_module, AclPermissions, Context, NextArg, RedisError, RedisResult, RedisString,
RedisValue,
redis_module, AclCategory, AclPermissions, Context, NextArg, RedisError, RedisResult,
RedisString, RedisValue,
};

fn verify_key_access_for_user(ctx: &Context, args: Vec<RedisString>) -> RedisResult {
Expand All @@ -25,9 +25,9 @@ redis_module! {
version: 1,
allocator: (redis_module::alloc::RedisAlloc, redis_module::alloc::RedisAlloc),
data_types: [],
acl_category: "acl",
acl_categories: [AclCategory::from("acl"), ],
commands: [
["verify_key_access_for_user", verify_key_access_for_user, "", 0, 0, 0, "read", "acl"],
["get_current_user", get_current_user, "", 0, 0, 0, "read", "acl"],
["verify_key_access_for_user", verify_key_access_for_user, "", 0, 0, 0, AclCategory::Read, AclCategory::from("acl")],
["get_current_user", get_current_user, "", 0, 0, 0, vec![AclCategory::Read, AclCategory::Fast], AclCategory::from("acl")],
],
}
107 changes: 107 additions & 0 deletions src/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,113 @@ bitflags! {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum AclCategory {
#[default]
None,
Keyspace,
Read,
Write,
Set,
SortedSet,
List,
Hash,
String,
Bitmap,
HyperLogLog,
Geo,
Stream,
PubSub,
Admin,
Fast,
Slow,
Blocking,
Dangerous,
Connection,
Transaction,
Scripting,
Single(String),
Multi(Vec<AclCategory>),
}

impl From<Vec<AclCategory>> for AclCategory {
fn from(value: Vec<AclCategory>) -> Self {
AclCategory::Multi(value)
}
}

impl From<&str> for AclCategory {
fn from(value: &str) -> Self {
match value {
"" => AclCategory::None,
"keyspace" => AclCategory::Keyspace,
"read" => AclCategory::Read,
"write" => AclCategory::Write,
"set" => AclCategory::Set,
"sortedset" => AclCategory::SortedSet,
"list" => AclCategory::List,
"hash" => AclCategory::Hash,
"string" => AclCategory::String,
"bitmap" => AclCategory::Bitmap,
"hyperloglog" => AclCategory::HyperLogLog,
"geo" => AclCategory::Geo,
"stream" => AclCategory::Stream,
"pubsub" => AclCategory::PubSub,
"admin" => AclCategory::Admin,
"fast" => AclCategory::Fast,
"slow" => AclCategory::Slow,
"blocking" => AclCategory::Blocking,
"dangerous" => AclCategory::Dangerous,
"connection" => AclCategory::Connection,
"transaction" => AclCategory::Transaction,
"scripting" => AclCategory::Scripting,
_ if !value.contains(" ") => AclCategory::Single(value.to_string()),
_ => AclCategory::Multi(value.split_whitespace().map(AclCategory::from).collect()),
}
}
}

impl From<AclCategory> for String {
fn from(value: AclCategory) -> Self {
match value {
AclCategory::None => "".to_string(),
AclCategory::Keyspace => "keyspace".to_string(),
AclCategory::Read => "read".to_string(),
AclCategory::Write => "write".to_string(),
AclCategory::Set => "set".to_string(),
AclCategory::SortedSet => "sortedset".to_string(),
AclCategory::List => "list".to_string(),
AclCategory::Hash => "hash".to_string(),
AclCategory::String => "string".to_string(),
AclCategory::Bitmap => "bitmap".to_string(),
AclCategory::HyperLogLog => "hyperloglog".to_string(),
AclCategory::Geo => "geo".to_string(),
AclCategory::Stream => "stream".to_string(),
AclCategory::PubSub => "pubsub".to_string(),
AclCategory::Admin => "admin".to_string(),
AclCategory::Fast => "fast".to_string(),
AclCategory::Slow => "slow".to_string(),
AclCategory::Blocking => "blocking".to_string(),
AclCategory::Dangerous => "dangerous".to_string(),
AclCategory::Connection => "connection".to_string(),
AclCategory::Transaction => "transaction".to_string(),
AclCategory::Scripting => "scripting".to_string(),
AclCategory::Single(s) => s,
AclCategory::Multi(v) => v
.into_iter()
.map(String::from)
.collect::<Vec<_>>()
.join(" "),
}
}
}

impl std::fmt::Display for AclCategory {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", String::from(self.clone()))
}
}

/// The values allowed in the "info" sections and dictionaries.
#[derive(Debug, Clone)]
pub enum InfoContextBuilderFieldBottomLevelValue {
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub use crate::context::commands;
pub use crate::context::defrag;
pub use crate::context::keys_cursor::KeysCursor;
pub use crate::context::server_events;
pub use crate::context::AclCategory;
pub use crate::context::AclPermissions;
#[cfg(any(
feature = "min-redis-compatibility-version-7-4",
Expand Down
51 changes: 31 additions & 20 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ macro_rules! redis_command {
$mandatory_acl_categories:expr
$(, $optional_acl_categories:expr)?
) => {{
use redis_module::AclCategory;

let name = CString::new($command_name).unwrap();
let flags = CString::new($command_flags).unwrap();

Expand Down Expand Up @@ -56,13 +58,15 @@ macro_rules! redis_command {
return $crate::raw::Status::Err as c_int;
}

let mandatory_acl_categories = AclCategory::from($mandatory_acl_categories);
if let Some(RM_SetCommandACLCategories) = $crate::raw::RedisModule_SetCommandACLCategories {
let mut acl_categories = CString::new("").unwrap();
let mut acl_categories = CString::default();
$(
if $mandatory_acl_categories != "" && $optional_acl_categories != "" {
acl_categories = CString::new(format!("{} {}", $mandatory_acl_categories, $optional_acl_categories)).unwrap();
} else if $optional_acl_categories != "" {
acl_categories = CString::new($optional_acl_categories).unwrap();
let optional_acl_categories = AclCategory::from($optional_acl_categories);
if mandatory_acl_categories != AclCategory::None && optional_acl_categories != AclCategory::None {
acl_categories = CString::new(format!("{} {}", mandatory_acl_categories, optional_acl_categories)).unwrap();
} else if optional_acl_categories != AclCategory::None {
acl_categories = CString::new(format!("{}", $optional_acl_categories)).unwrap();
}
// Warn if optional ACL categories are not set, but don't fail.
if RM_SetCommandACLCategories(command, acl_categories.as_ptr()) == $crate::raw::Status::Err as c_int {
Expand All @@ -75,8 +79,8 @@ macro_rules! redis_command {
);
} else
)?
if $mandatory_acl_categories != "" {
acl_categories = CString::new($mandatory_acl_categories).unwrap();
if mandatory_acl_categories != AclCategory::None {
acl_categories = CString::new(format!("{}", mandatory_acl_categories)).unwrap();

// Fail if mandatory ACL categories are not set.
if RM_SetCommandACLCategories(command, acl_categories.as_ptr())
Expand All @@ -86,13 +90,13 @@ macro_rules! redis_command {
$ctx,
&format!(
"Error: failed to set command `{}` mandatory ACL categories `{}`",
$command_name, $mandatory_acl_categories
$command_name, mandatory_acl_categories
),
);
return $crate::raw::Status::Err as c_int;
}
}
} else if $mandatory_acl_categories != "" {
} else if mandatory_acl_categories != AclCategory::None {
$crate::raw::redis_log(
$ctx,
"Error: Redis version does not support ACL categories",
Expand Down Expand Up @@ -167,9 +171,11 @@ macro_rules! redis_module {
data_types: [
$($data_type:ident),* $(,)*
],
// eg: `acl_category: "name_of_module_acl_category",`
// eg: `acl_category: [ "name_of_module_acl_category", ],`
// This will add the specified (optional) ACL categories.
$(acl_category: $module_acl_categories:expr,)* $(,)*
$(acl_categories: [
$($module_acl_category:expr,)*
],)?
$(init: $init_func:ident,)* $(,)*
$(deinit: $deinit_func:ident,)* $(,)*
$(info: $info_func:ident,)?
Expand Down Expand Up @@ -307,16 +313,21 @@ macro_rules! redis_module {
)*

$(
if let Some(RM_AddACLCategory) = raw::RedisModule_AddACLCategory {
let categories = CString::new($module_acl_categories).unwrap();
if RM_AddACLCategory(ctx, categories.as_ptr()) == raw::Status::Err as c_int {
raw::redis_log(ctx, &format!("Error: failed to add ACL categories `{}`", $module_acl_categories));
return raw::Status::Err as c_int;
$(
if let Some(RM_AddACLCategory) = raw::RedisModule_AddACLCategory {
let module_acl_category = AclCategory::from($module_acl_category);
if module_acl_category != AclCategory::None {
let category = CString::new(format!("{}", $module_acl_category)).unwrap();
if RM_AddACLCategory(ctx, category.as_ptr()) == raw::Status::Err as c_int {
raw::redis_log(ctx, &format!("Error: failed to add ACL category `{}`", $module_acl_category));
return raw::Status::Err as c_int;
}
}
} else {
raw::redis_log(ctx, "Warning: Redis version does not support adding new ACL categories");
}
} else {
raw::redis_log(ctx, "Warning: Redis version does not support adding new ACL categories");
}
)*
)*
)?

$(
$crate::redis_command!(ctx, $name, $command, $flags, $firstkey, $lastkey, $keystep, $mandatory_command_acl_categories $(, $optional_command_acl_categories)?);
Expand Down

0 comments on commit 0f33019

Please sign in to comment.