diff --git a/src/frontend/src/handler/alter_system.rs b/src/frontend/src/handler/alter_system.rs index 91f8cff23c0d4..b040474d98c49 100644 --- a/src/frontend/src/handler/alter_system.rs +++ b/src/frontend/src/handler/alter_system.rs @@ -14,8 +14,9 @@ use pgwire::pg_response::StatementType; use risingwave_common::error::Result; -use risingwave_sqlparser::ast::{Ident, SetVariableValue, Value}; +use risingwave_sqlparser::ast::{Ident, SetVariableValue}; +use super::variable::set_var_to_param_str; use super::{HandlerArgs, RwPgResponse}; // Warn user if barrier_interval_ms is set above 5mins. @@ -28,12 +29,7 @@ pub async fn handle_alter_system( param: Ident, value: SetVariableValue, ) -> Result { - let value = match value { - SetVariableValue::Literal(Value::DoubleQuotedString(s)) - | SetVariableValue::Literal(Value::SingleQuotedString(s)) => Some(s), - SetVariableValue::Default => None, - _ => Some(value.to_string()), - }; + let value = set_var_to_param_str(&value); let params = handler_args .session .env() diff --git a/src/frontend/src/handler/variable.rs b/src/frontend/src/handler/variable.rs index 112107e725318..d7c8695040a2d 100644 --- a/src/frontend/src/handler/variable.rs +++ b/src/frontend/src/handler/variable.rs @@ -26,15 +26,16 @@ use super::RwPgResponse; use crate::handler::HandlerArgs; use crate::utils::infer_stmt_row_desc::infer_show_variable; -fn set_var_to_guc_str(value: &SetVariableValue) -> String { +/// convert `SetVariableValue` to string while remove the quotes on literals. +pub(crate) fn set_var_to_param_str(value: &SetVariableValue) -> Option { match value { - SetVariableValue::Literal(Value::DoubleQuotedString(s)) - | SetVariableValue::Literal(Value::SingleQuotedString(s)) => s.clone(), - SetVariableValue::List(list) => list - .iter() - .map(set_var_to_guc_str) - .join(SESSION_CONFIG_LIST_SEP), - _ => value.to_string(), + SetVariableValue::Single(var) => Some(var.to_string_unquoted()), + SetVariableValue::List(list) => Some( + list.iter() + .map(|var| var.to_string_unquoted()) + .join(SESSION_CONFIG_LIST_SEP), + ), + SetVariableValue::Default => None, } } @@ -44,7 +45,9 @@ pub fn handle_set( value: SetVariableValue, ) -> Result { // Strip double and single quotes - let string_val = set_var_to_guc_str(&value); + let string_val = set_var_to_param_str(&value).ok_or(ErrorCode::InternalError( + "SET TO DEFAULT is not supported yet".to_string(), + ))?; let mut status = ParameterStatus::default(); diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index a4867ba9d5ae8..a57a6a9175ebd 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -2652,24 +2652,51 @@ impl fmt::Display for CreateFunctionUsing { #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum SetVariableValue { - Ident(Ident), - Literal(Value), - List(Vec), + Single(SetVariableValueSingle), + List(Vec), Default, } +impl From for SetVariableValue { + fn from(value: SetVariableValueSingle) -> Self { + SetVariableValue::Single(value) + } +} + impl fmt::Display for SetVariableValue { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use SetVariableValue::*; + match self { + Single(val) => write!(f, "{}", val), + List(list) => write!(f, "{}", display_comma_separated(list),), + Default => write!(f, "DEFAULT"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum SetVariableValueSingle { + Ident(Ident), + Literal(Value), +} + +impl SetVariableValueSingle { + pub fn to_string_unquoted(&self) -> String { + match self { + Self::Literal(Value::SingleQuotedString(s)) + | Self::Literal(Value::DoubleQuotedString(s)) => s.clone(), + _ => self.to_string(), + } + } +} + +impl fmt::Display for SetVariableValueSingle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use SetVariableValueSingle::*; match self { Ident(ident) => write!(f, "{}", ident), Literal(literal) => write!(f, "{}", literal), - List(list) => write!( - f, - "{}", - list.iter().map(|value| value.to_string()).join(", ") - ), - Default => write!(f, "DEFAULT"), } } } diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 439aed4a18e7f..4db8b3ebdfff6 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -3228,16 +3228,22 @@ impl Parser { loop { let token = self.peek_token(); let value = match (self.parse_value(), token.token) { - (Ok(value), _) => SetVariableValue::Literal(value), + (Ok(value), _) => SetVariableValueSingle::Literal(value), (Err(_), Token::Word(w)) => { if w.keyword == Keyword::DEFAULT { - SetVariableValue::Default + if !values.is_empty() { + self.expected( + "parameter list value", + Token::Word(w).with_location(token.location), + )? + } + return Ok(SetVariableValue::Default); } else { - SetVariableValue::Ident(w.to_ident()?) + SetVariableValueSingle::Ident(w.to_ident()?) } } (Err(_), unexpected) => { - self.expected("variable value", unexpected.with_location(token.location))? + self.expected("parameter value", unexpected.with_location(token.location))? } }; values.push(value); @@ -3246,7 +3252,7 @@ impl Parser { } } if values.len() == 1 { - Ok(values[0].clone()) + Ok(SetVariableValue::Single(values[0].clone())) } else { Ok(SetVariableValue::List(values)) } diff --git a/src/sqlparser/tests/sqlparser_postgres.rs b/src/sqlparser/tests/sqlparser_postgres.rs index 94eb2d53fbfa5..2d97834ad23b5 100644 --- a/src/sqlparser/tests/sqlparser_postgres.rs +++ b/src/sqlparser/tests/sqlparser_postgres.rs @@ -392,7 +392,7 @@ fn parse_set() { Statement::SetVariable { local: false, variable: "a".into(), - value: SetVariableValue::Ident("b".into()), + value: SetVariableValueSingle::Ident("b".into()).into(), } ); @@ -402,7 +402,7 @@ fn parse_set() { Statement::SetVariable { local: false, variable: "a".into(), - value: SetVariableValue::Literal(Value::SingleQuotedString("b".into())), + value: SetVariableValueSingle::Literal(Value::SingleQuotedString("b".into())).into(), } ); @@ -412,7 +412,7 @@ fn parse_set() { Statement::SetVariable { local: false, variable: "a".into(), - value: SetVariableValue::Literal(number("0")), + value: SetVariableValueSingle::Literal(number("0")).into(), } ); @@ -432,7 +432,7 @@ fn parse_set() { Statement::SetVariable { local: true, variable: "a".into(), - value: SetVariableValue::Ident("b".into()), + value: SetVariableValueSingle::Ident("b".into()).into(), } ); @@ -441,7 +441,7 @@ fn parse_set() { for (sql, err_msg) in [ ("SET", "Expected identifier, found: EOF"), ("SET a b", "Expected equals sign or TO, found: b"), - ("SET a =", "Expected variable value, found: EOF"), + ("SET a =", "Expected parameter value, found: EOF"), ] { let res = parse_sql_statements(sql); assert!(format!("{}", res.unwrap_err()).contains(err_msg)); diff --git a/src/sqlparser/tests/testdata/set.yaml b/src/sqlparser/tests/testdata/set.yaml index 309ffc5213aee..947bbea7056c9 100644 --- a/src/sqlparser/tests/testdata/set.yaml +++ b/src/sqlparser/tests/testdata/set.yaml @@ -13,3 +13,13 @@ formatted_sql: SET TIME ZONE UTC - input: set time = '1'; formatted_sql: SET time = '1' +- input: set search_path to 'default', 'my_path'; + formatted_sql: SET search_path = 'default', 'my_path' +- input: set search_path to default, 'my_path'; + error_msg: |- + sql parser error: Expected end of statement, found: , at line:1, column:28 + Near "set search_path to default" +- input: set search_path to 'my_path', default; + error_msg: |- + sql parser error: Expected parameter list value, found: default at line:1, column:36 + Near "set search_path to 'my_path', default"