diff --git a/src/common/proc_macro/src/session_config.rs b/src/common/proc_macro/src/session_config.rs index 0ca0b1c81f2f..aba8fe7043e4 100644 --- a/src/common/proc_macro/src/session_config.rs +++ b/src/common/proc_macro/src/session_config.rs @@ -21,6 +21,7 @@ use syn::DeriveInput; #[derive(FromAttributes)] struct Parameter { pub rename: Option, + pub alias: Option, pub default: syn::Expr, pub flags: Option, pub check_hook: Option, @@ -39,6 +40,7 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { let mut get_match_branches = vec![]; let mut reset_match_branches = vec![]; let mut show_all_list = vec![]; + let mut alias_to_entry_name_branches = vec![]; for field in fields { let field_ident = field.ident.expect_or_abort("Field need to be named"); @@ -62,6 +64,7 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { Parameter::from_attributes(&field.attrs).expect_or_abort("Failed to parse attribute"); let Parameter { rename, + alias, default, flags, check_hook: check_hook_name, @@ -78,6 +81,12 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { quote! {stringify!(#ident)} }; + if let Some(alias) = alias { + alias_to_entry_name_branches.push(quote! { + #alias => #entry_name, + }) + } + let flags = flags.map(|f| f.value()).unwrap_or_default(); let flags: Vec<_> = flags.split('|').map(|str| str.trim()).collect(); @@ -236,6 +245,13 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { Default::default() } + fn alias_to_entry_name(key_name: &str) -> &str { + match key_name { + #(#alias_to_entry_name_branches)* + _ => key_name, + } + } + #(#struct_impl_get)* #(#struct_impl_set)* @@ -244,6 +260,7 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { /// Set a parameter given it's name and value string. pub fn set(&mut self, key_name: &str, value: String, reporter: &mut impl ConfigReporter) -> SessionConfigResult<()> { + let key_name = Self::alias_to_entry_name(key_name); match key_name.to_ascii_lowercase().as_ref() { #(#set_match_branches)* _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())), @@ -252,6 +269,7 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { /// Get a parameter by it's name. pub fn get(&self, key_name: &str) -> SessionConfigResult { + let key_name = Self::alias_to_entry_name(key_name); match key_name.to_ascii_lowercase().as_ref() { #(#get_match_branches)* _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())), @@ -260,6 +278,7 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { /// Reset a parameter by it's name. pub fn reset(&mut self, key_name: &str, reporter: &mut impl ConfigReporter) -> SessionConfigResult<()> { + let key_name = Self::alias_to_entry_name(key_name); match key_name.to_ascii_lowercase().as_ref() { #(#reset_match_branches)* _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())), diff --git a/src/common/src/session_config/mod.rs b/src/common/src/session_config/mod.rs index 82c5bd0a6f12..d5c93944f81a 100644 --- a/src/common/src/session_config/mod.rs +++ b/src/common/src/session_config/mod.rs @@ -304,3 +304,25 @@ pub trait ConfigReporter { impl ConfigReporter for () { fn report_status(&mut self, _key: &str, _new_val: String) {} } + +#[cfg(test)] +mod test { + use super::*; + + #[derive(SessionConfig)] + struct TestConfig { + #[parameter(default = 1, alias = "test_param_alias" | "alias_param_test")] + test_param: i32, + } + + #[test] + fn test_session_config_alias() { + let mut config = TestConfig::default(); + config.set("test_param", "2".to_string(), &mut ()).unwrap(); + assert_eq!(config.get("test_param_alias").unwrap(), "2"); + config + .set("alias_param_test", "3".to_string(), &mut ()) + .unwrap(); + assert_eq!(config.get("test_param_alias").unwrap(), "3"); + } +}