diff --git a/src/common/proc_macro/src/session_config.rs b/src/common/proc_macro/src/session_config.rs index 6b622241b1296..0ca0b1c81f2f3 100644 --- a/src/common/proc_macro/src/session_config.rs +++ b/src/common/proc_macro/src/session_config.rs @@ -37,6 +37,7 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { let mut struct_impl_reset = vec![]; let mut set_match_branches = vec![]; let mut get_match_branches = vec![]; + let mut reset_match_branches = vec![]; let mut show_all_list = vec![]; for field in fields { @@ -161,12 +162,13 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { }); let reset_func_name = format_ident!("reset_{}", field_ident); - struct_impl_reset.push(quote_spanned! { - field_ident.span()=> + struct_impl_reset.push(quote! { #[allow(clippy::useless_conversion)] - pub fn #reset_func_name(&mut self) { - self.#field_ident = #default.into(); + pub fn #reset_func_name(&mut self, reporter: &mut impl ConfigReporter) { + let val = #default; + #report_hook + self.#field_ident = val.into(); } }); @@ -202,6 +204,10 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { #entry_name => self.#set_func_name(&value, reporter), }); + reset_match_branches.push(quote! { + #entry_name => Ok(self.#reset_func_name(reporter)), + }); + if !flags.contains(&"NO_SHOW_ALL") { show_all_list.push(quote! { VariableInfo { @@ -230,7 +236,6 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { Default::default() } - #(#struct_impl_get)* #(#struct_impl_set)* @@ -253,6 +258,14 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { } } + /// Reset a parameter by it's name. + pub fn reset(&mut self, key_name: &str, reporter: &mut impl ConfigReporter) -> SessionConfigResult<()> { + match key_name.to_ascii_lowercase().as_ref() { + #(#reset_match_branches)* + _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())), + } + } + /// Show all parameters. pub fn show_all(&self) -> Vec { vec![ diff --git a/src/frontend/src/handler/variable.rs b/src/frontend/src/handler/variable.rs index d7c8695040a2d..e58981685d09d 100644 --- a/src/frontend/src/handler/variable.rs +++ b/src/frontend/src/handler/variable.rs @@ -45,9 +45,7 @@ pub fn handle_set( value: SetVariableValue, ) -> Result { // Strip double and single quotes - let string_val = set_var_to_param_str(&value).ok_or(ErrorCode::InternalError( - "SET TO DEFAULT is not supported yet".to_string(), - ))?; + let string_val = set_var_to_param_str(&value); let mut status = ParameterStatus::default(); diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 55d9756187bbd..abe042d2fc30d 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -617,13 +617,20 @@ impl SessionImpl { pub fn set_config_report( &self, key: &str, - value: String, + value: Option, mut reporter: impl ConfigReporter, ) -> Result<()> { - self.config_map - .write() - .set(key, value, &mut reporter) - .map_err(Into::into) + if let Some(value) = value { + self.config_map + .write() + .set(key, value, &mut reporter) + .map_err(Into::into) + } else { + self.config_map + .write() + .reset(key, &mut reporter) + .map_err(Into::into) + } } pub fn session_id(&self) -> SessionId {