diff --git a/Cargo.lock b/Cargo.lock index 9bda000949a8..bab77ee4db3b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9938,6 +9938,7 @@ dependencies = [ "tower", "tower-http", "urlencoding", + "uuid", "zstd 0.13.1", ] diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index b217089c97b6..bd68a0209936 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -105,6 +105,7 @@ tonic-reflection = "0.11" tower = { workspace = true, features = ["full"] } tower-http = { version = "0.4", features = ["full"] } urlencoding = "2.1" +uuid.workspace = true zstd.workspace = true [target.'cfg(not(windows))'.dependencies] diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 84a58faeb9f2..0e90ed0ddc1c 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -59,7 +59,7 @@ pub struct MysqlInstanceShim { salt: [u8; 20], session: SessionRef, user_provider: Option, - prepared_stmts: Arc>>, + prepared_stmts: Arc>>, prepared_stmts_counter: AtomicU32, } @@ -134,18 +134,88 @@ impl MysqlInstanceShim { self.query_handler.do_describe(statement, query_ctx).await } - /// Save query and logical plan, return the unique id - fn save_plan(&self, plan: SqlPlan) -> u32 { - let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed); + /// Save query and logical plan with a given statement key + fn save_plan(&self, plan: SqlPlan, stmt_key: String) { let mut prepared_stmts = self.prepared_stmts.write(); - let _ = prepared_stmts.insert(stmt_id, plan); - stmt_id + let _ = prepared_stmts.insert(stmt_key, plan); } - /// Retrieve the query and logical plan by id - fn plan(&self, stmt_id: u32) -> Option { + /// Retrieve the query and logical plan by a given statement key + fn plan(&self, stmt_key: String) -> Option { let guard = self.prepared_stmts.read(); - guard.get(&stmt_id).cloned() + guard.get(&stmt_key).cloned() + } + + /// Save the prepared statement and return the parameters and result columns + async fn do_prepare( + &mut self, + raw_query: &str, + query_ctx: QueryContextRef, + stmt_key: String, + ) -> Result<(Vec, Vec)> { + let (query, param_num) = replace_placeholders(raw_query); + + let statement = validate_query(raw_query).await?; + + // We have to transform the placeholder, because DataFusion only parses placeholders + // in the form of "$i", it can't process "?" right now. + let statement = transform_placeholders(statement); + + let describe_result = self + .do_describe(statement.clone(), query_ctx.clone()) + .await?; + let (plan, schema) = if let Some(DescribeResult { + logical_plan, + schema, + }) = describe_result + { + (Some(logical_plan), Some(schema)) + } else { + (None, None) + }; + + let params = if let Some(plan) = &plan { + prepared_params( + &plan + .get_param_types() + .context(error::GetPreparedStmtParamsSnafu)?, + )? + } else { + dummy_params(param_num)? + }; + + debug_assert_eq!(params.len(), param_num - 1); + + let columns = schema + .as_ref() + .map(|schema| { + schema + .column_schemas() + .iter() + .map(|column_schema| { + create_mysql_column(&column_schema.data_type, &column_schema.name) + }) + .collect::>>() + }) + .transpose()? + .unwrap_or_default(); + + self.save_plan( + SqlPlan { + query: query.to_string(), + plan, + schema, + }, + stmt_key, + ); + + Ok((params, columns)) + } + + /// Remove the prepared statement by a given statement key + fn do_close(&mut self, stmt_key: String) { + let mut guard = self.prepared_stmts.write(); + let _ = guard.remove(&stmt_key); } } @@ -210,59 +280,11 @@ impl AsyncMysqlShim for MysqlInstanceShi w: StatementMetaWriter<'a, W>, ) -> Result<()> { let query_ctx = self.session.new_query_context(); - let (query, param_num) = replace_placeholders(raw_query); - - let statement = validate_query(raw_query).await?; - - // We have to transform the placeholder, because DataFusion only parses placeholders - // in the form of "$i", it can't process "?" right now. - let statement = transform_placeholders(statement); - - let describe_result = self - .do_describe(statement.clone(), query_ctx.clone()) + let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed); + let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string(); + let (params, columns) = self + .do_prepare(raw_query, query_ctx.clone(), stmt_key) .await?; - let (plan, schema) = if let Some(DescribeResult { - logical_plan, - schema, - }) = describe_result - { - (Some(logical_plan), Some(schema)) - } else { - (None, None) - }; - - let params = if let Some(plan) = &plan { - prepared_params( - &plan - .get_param_types() - .context(error::GetPreparedStmtParamsSnafu)?, - )? - } else { - dummy_params(param_num)? - }; - - debug_assert_eq!(params.len(), param_num - 1); - - let columns = schema - .as_ref() - .map(|schema| { - schema - .column_schemas() - .iter() - .map(|column_schema| { - create_mysql_column(&column_schema.data_type, &column_schema.name) - }) - .collect::>>() - }) - .transpose()? - .unwrap_or_default(); - - let stmt_id = self.save_plan(SqlPlan { - query: query.to_string(), - plan, - schema, - }); - w.reply(stmt_id, ¶ms, &columns).await?; crate::metrics::METRIC_MYSQL_PREPARED_COUNT .with_label_values(&[query_ctx.get_db_string().as_str()]) @@ -283,11 +305,12 @@ impl AsyncMysqlShim for MysqlInstanceShi .start_timer(); let params: Vec = p.into_iter().collect(); - let sql_plan = match self.plan(stmt_id) { + let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string(); + let sql_plan = match self.plan(stmt_key) { None => { w.error( ErrorKind::ER_UNKNOWN_STMT_HANDLER, - b"prepare statement not exist", + b"prepare statement not found", ) .await?; return Ok(()); @@ -334,7 +357,11 @@ impl AsyncMysqlShim for MysqlInstanceShi ] } None => { - let query = replace_params(params, sql_plan.query); + let param_strs = params + .iter() + .map(|x| convert_param_value_to_string(x)) + .collect(); + let query = replace_params(param_strs, sql_plan.query); debug!("Mysql execute replaced query: {}", query); self.do_query(&query, query_ctx.clone()).await } @@ -349,8 +376,8 @@ impl AsyncMysqlShim for MysqlInstanceShi where W: 'async_trait, { - let mut guard = self.prepared_stmts.write(); - let _ = guard.remove(&stmt_id); + let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string(); + self.do_close(stmt_key); } #[tracing::instrument(skip_all, fields(protocol = "mysql"))] @@ -364,6 +391,130 @@ impl AsyncMysqlShim for MysqlInstanceShi let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER .with_label_values(&[crate::metrics::METRIC_MYSQL_TEXTQUERY, db.as_str()]) .start_timer(); + + let query_upcase = query.to_uppercase(); + if query_upcase.starts_with("PREPARE ") { + match ParserContext::parse_mysql_prepare_stmt(query, query_ctx.sql_dialect()) { + Ok((stmt_name, stmt)) => { + let prepare_results = + self.do_prepare(&stmt, query_ctx.clone(), stmt_name).await; + match prepare_results { + Ok(_) => { + let outputs = vec![Ok(Output::new_with_affected_rows(0))]; + writer::write_output(writer, query_ctx, outputs).await?; + return Ok(()); + } + Err(e) => { + writer + .error(ErrorKind::ER_SP_BADSTATEMENT, e.output_msg().as_bytes()) + .await?; + return Ok(()); + } + } + } + Err(e) => { + writer + .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes()) + .await?; + return Ok(()); + } + } + } else if query_upcase.starts_with("EXECUTE ") { + match ParserContext::parse_mysql_execute_stmt(query, query_ctx.sql_dialect()) { + // TODO: similar to on_execute, refactor this + Ok((stmt_name, params)) => { + let sql_plan = match self.plan(stmt_name) { + None => { + writer + .error( + ErrorKind::ER_UNKNOWN_STMT_HANDLER, + b"prepare statement not found", + ) + .await?; + return Ok(()); + } + Some(sql_plan) => sql_plan, + }; + + let outputs = match sql_plan.plan { + Some(plan) => { + let param_types = plan + .get_param_types() + .context(error::GetPreparedStmtParamsSnafu)?; + + if params.len() != param_types.len() { + writer + .error( + ErrorKind::ER_SP_BADSTATEMENT, + b"prepare statement params number mismatch", + ) + .await?; + return Ok(()); + } + + let plan = match replace_params_with_exprs(&plan, param_types, ¶ms) + { + Ok(plan) => plan, + Err(e) => { + if e.status_code().should_log_error() { + error!(e; "params: {}", params + .iter() + .map(|x| format!("({:?})", x)) + .join(", ")); + } + + writer + .error( + ErrorKind::ER_TRUNCATED_WRONG_VALUE, + e.output_msg().as_bytes(), + ) + .await?; + return Ok(()); + } + }; + + debug!("Mysql execute prepared plan: {}", plan.display_indent()); + vec![ + self.do_exec_plan(&sql_plan.query, plan, query_ctx.clone()) + .await, + ] + } + None => { + let param_strs = params.iter().map(|x| x.to_string()).collect(); + let query = replace_params(param_strs, sql_plan.query); + debug!("Mysql execute replaced query: {}", query); + let outputs = self.do_query(&query, query_ctx.clone()).await; + writer::write_output(writer, query_ctx, outputs).await?; + return Ok(()); + } + }; + writer::write_output(writer, query_ctx, outputs).await?; + return Ok(()); + } + Err(e) => { + writer + .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes()) + .await?; + return Ok(()); + } + } + } else if query_upcase.starts_with("DEALLOCATE ") { + match ParserContext::parse_mysql_deallocate_stmt(query, query_ctx.sql_dialect()) { + Ok(stmt_name) => { + self.do_close(stmt_name); + let outputs = vec![Ok(Output::new_with_affected_rows(0))]; + writer::write_output(writer, query_ctx, outputs).await?; + return Ok(()); + } + Err(e) => { + writer + .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes()) + .await?; + return Ok(()); + } + } + } + let outputs = self.do_query(query, query_ctx.clone()).await; writer::write_output(writer, query_ctx, outputs).await?; Ok(()) @@ -420,21 +571,24 @@ impl AsyncMysqlShim for MysqlInstanceShi } } -fn replace_params(params: Vec, query: String) -> String { +fn convert_param_value_to_string(param: &ParamValue) -> String { + match param.value.into_inner() { + ValueInner::Int(u) => u.to_string(), + ValueInner::UInt(u) => u.to_string(), + ValueInner::Double(u) => u.to_string(), + ValueInner::NULL => "NULL".to_string(), + ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)), + ValueInner::Date(_) => NaiveDate::from(param.value).to_string(), + ValueInner::Datetime(_) => NaiveDateTime::from(param.value).to_string(), + ValueInner::Time(_) => format_duration(Duration::from(param.value)), + } +} + +fn replace_params(params: Vec, query: String) -> String { let mut query = query; let mut index = 1; for param in params { - let s = match param.value.into_inner() { - ValueInner::Int(u) => u.to_string(), - ValueInner::UInt(u) => u.to_string(), - ValueInner::Double(u) => u.to_string(), - ValueInner::NULL => "NULL".to_string(), - ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)), - ValueInner::Date(_) => NaiveDate::from(param.value).to_string(), - ValueInner::Datetime(_) => NaiveDateTime::from(param.value).to_string(), - ValueInner::Time(_) => format_duration(Duration::from(param.value)), - }; - query = query.replace(&format_placeholder(index), &s); + query = query.replace(&format_placeholder(index), ¶m); index += 1; } query @@ -477,6 +631,33 @@ fn replace_params_with_values( .context(error::ReplacePreparedStmtParamsSnafu) } +fn replace_params_with_exprs( + plan: &LogicalPlan, + param_types: HashMap>, + params: &[sql::ast::Expr], +) -> Result { + debug_assert_eq!(param_types.len(), params.len()); + + debug!( + "replace_params_with_exprs(param_types: {:#?}, params: {:#?})", + param_types, + params.iter().map(|x| format!("({:?})", x)).join(", ") + ); + + let mut values = Vec::with_capacity(params.len()); + + for (i, param) in params.iter().enumerate() { + if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) { + let value = helper::convert_expr_to_scalar_value(param, t)?; + + values.push(value); + } + } + + plan.replace_params_with_values(&values) + .context(error::ReplacePreparedStmtParamsSnafu) +} + async fn validate_query(query: &str) -> Result { let statement = ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default()); diff --git a/src/servers/src/mysql/helper.rs b/src/servers/src/mysql/helper.rs index f1aede0b5d5f..df174b38400c 100644 --- a/src/servers/src/mysql/helper.rs +++ b/src/servers/src/mysql/helper.rs @@ -23,6 +23,7 @@ use itertools::Itertools; use opensrv_mysql::{to_naive_datetime, ParamValue, ValueInner}; use snafu::ResultExt; use sql::ast::{visit_expressions_mut, Expr, Value as ValueExpr, VisitMut}; +use sql::statements::sql_value_to_value; use sql::statements::statement::Statement; use crate::error::{self, Result}; @@ -201,6 +202,27 @@ pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result Result { + match param { + Expr::Value(v) => { + let v = sql_value_to_value("", t, v, None); + match v { + Ok(v) => v + .try_to_scalar_value(t) + .context(error::ConvertScalarValueSnafu), + Err(e) => error::InvalidParameterSnafu { + reason: e.to_string(), + } + .fail(), + } + } + _ => error::InvalidParameterSnafu { + reason: format!("cannot convert {:?} to scalar value of type {}", param, t), + } + .fail(), + } +} + #[cfg(test)] mod tests { use sql::dialect::MySqlDialect; @@ -265,4 +287,45 @@ mod tests { }; assert_eq!("SELECT from AS demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3", select.inner.to_string()); } + + #[test] + fn test_convert_expr_to_scalar_value() { + let expr = Expr::Value(ValueExpr::Number("123".to_string(), false)); + let t = ConcreteDataType::int32_datatype(); + let v = convert_expr_to_scalar_value(&expr, &t).unwrap(); + assert_eq!(ScalarValue::Int32(Some(123)), v); + + let expr = Expr::Value(ValueExpr::Number("123.456789".to_string(), false)); + let t = ConcreteDataType::float64_datatype(); + let v = convert_expr_to_scalar_value(&expr, &t).unwrap(); + assert_eq!(ScalarValue::Float64(Some(123.456789)), v); + + let expr = Expr::Value(ValueExpr::SingleQuotedString("2001-01-02".to_string())); + let t = ConcreteDataType::date_datatype(); + let v = convert_expr_to_scalar_value(&expr, &t).unwrap(); + let scalar_v = ScalarValue::Utf8(Some("2001-01-02".to_string())) + .cast_to(&arrow_schema::DataType::Date32) + .unwrap(); + assert_eq!(scalar_v, v); + + let expr = Expr::Value(ValueExpr::SingleQuotedString( + "2001-01-02 03:04:05".to_string(), + )); + let t = ConcreteDataType::datetime_datatype(); + let v = convert_expr_to_scalar_value(&expr, &t).unwrap(); + let scalar_v = ScalarValue::Utf8(Some("2001-01-02 03:04:05".to_string())) + .cast_to(&arrow_schema::DataType::Date64) + .unwrap(); + assert_eq!(scalar_v, v); + + let expr = Expr::Value(ValueExpr::SingleQuotedString("hello".to_string())); + let t = ConcreteDataType::string_datatype(); + let v = convert_expr_to_scalar_value(&expr, &t).unwrap(); + assert_eq!(ScalarValue::Utf8(Some("hello".to_string())), v); + + let expr = Expr::Value(ValueExpr::Null); + let t = ConcreteDataType::time_microsecond_datatype(); + let v = convert_expr_to_scalar_value(&expr, &t).unwrap(); + assert_eq!(ScalarValue::Time64Microsecond(None), v); + } } diff --git a/src/sql/src/parser.rs b/src/sql/src/parser.rs index 65a12b9ea335..41f771de87cd 100644 --- a/src/sql/src/parser.rs +++ b/src/sql/src/parser.rs @@ -175,6 +175,27 @@ impl<'a> ParserContext<'a> { } } + /// Parses MySQL style 'PREPARE stmt_name FROM stmt' into a (stmt_name, stmt) tuple. + pub fn parse_mysql_prepare_stmt( + sql: &'a str, + dialect: &dyn Dialect, + ) -> Result<(String, String)> { + ParserContext::new(dialect, sql)?.parse_mysql_prepare() + } + + /// Parses MySQL style 'EXECUTE stmt_name USING param_list' into a stmt_name string and a list of parameters. + pub fn parse_mysql_execute_stmt( + sql: &'a str, + dialect: &dyn Dialect, + ) -> Result<(String, Vec)> { + ParserContext::new(dialect, sql)?.parse_mysql_execute() + } + + /// Parses MySQL style 'DEALLOCATE stmt_name' into a stmt_name string. + pub fn parse_mysql_deallocate_stmt(sql: &'a str, dialect: &dyn Dialect) -> Result { + ParserContext::new(dialect, sql)?.parse_deallocate() + } + /// Raises an "unsupported statement" error. pub fn unsupported(&self, keyword: String) -> Result { error::UnsupportedSnafu { @@ -257,6 +278,7 @@ impl<'a> ParserContext<'a> { mod tests { use datatypes::prelude::ConcreteDataType; + use sqlparser::dialect::MySqlDialect; use super::*; use crate::dialect::GreptimeDbDialect; @@ -351,4 +373,57 @@ mod tests { assert_eq!(object_name.0.len(), 1); assert_eq!(object_name.to_string(), table_name.to_ascii_lowercase()); } + + #[test] + pub fn test_parse_mysql_prepare_stmt() { + let sql = "PREPARE stmt1 FROM 'SELECT * FROM t1 WHERE id = ?';"; + let (stmt_name, stmt) = + ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap(); + assert_eq!(stmt_name, "stmt1"); + assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?"); + + let sql = "PREPARE stmt2 FROM \"SELECT * FROM t1 WHERE id = ?\""; + let (stmt_name, stmt) = + ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap(); + assert_eq!(stmt_name, "stmt2"); + assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?"); + } + + #[test] + pub fn test_parse_mysql_execute_stmt() { + let sql = "EXECUTE stmt1 USING 1, 'hello';"; + let (stmt_name, params) = + ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap(); + assert_eq!(stmt_name, "stmt1"); + assert_eq!(params.len(), 2); + assert_eq!(params[0].to_string(), "1"); + assert_eq!(params[1].to_string(), "'hello'"); + + let sql = "EXECUTE stmt2;"; + let (stmt_name, params) = + ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap(); + assert_eq!(stmt_name, "stmt2"); + assert_eq!(params.len(), 0); + + let sql = "EXECUTE stmt3 USING 231, 'hello', \"2003-03-1\", NULL, ;"; + let (stmt_name, params) = + ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap(); + assert_eq!(stmt_name, "stmt3"); + assert_eq!(params.len(), 4); + assert_eq!(params[0].to_string(), "231"); + assert_eq!(params[1].to_string(), "'hello'"); + assert_eq!(params[2].to_string(), "\"2003-03-1\""); + assert_eq!(params[3].to_string(), "NULL"); + } + + #[test] + pub fn test_parse_mysql_deallocate_stmt() { + let sql = "DEALLOCATE stmt1;"; + let stmt_name = ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap(); + assert_eq!(stmt_name, "stmt1"); + + let sql = "DEALLOCATE stmt2"; + let stmt_name = ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap(); + assert_eq!(stmt_name, "stmt2"); + } } diff --git a/src/sql/src/parsers.rs b/src/sql/src/parsers.rs index 721f41367784..e4643da3e90a 100644 --- a/src/sql/src/parsers.rs +++ b/src/sql/src/parsers.rs @@ -15,12 +15,15 @@ mod alter_parser; pub(crate) mod copy_parser; pub(crate) mod create_parser; +pub(crate) mod deallocate_parser; pub(crate) mod delete_parser; pub(crate) mod describe_parser; pub(crate) mod drop_parser; pub(crate) mod error; +pub(crate) mod execute_parser; pub(crate) mod explain_parser; pub(crate) mod insert_parser; +pub(crate) mod prepare_parser; pub(crate) mod query_parser; pub(crate) mod set_var_parser; pub(crate) mod show_parser; diff --git a/src/sql/src/parsers/deallocate_parser.rs b/src/sql/src/parsers/deallocate_parser.rs new file mode 100644 index 000000000000..e53337ac9b4c --- /dev/null +++ b/src/sql/src/parsers/deallocate_parser.rs @@ -0,0 +1,30 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use snafu::ResultExt; +use sqlparser::keywords::Keyword; + +use crate::error::{Result, SyntaxSnafu}; +use crate::parser::ParserContext; + +impl<'a> ParserContext<'a> { + /// Parses MySQL style 'PREPARE stmt_name' into a stmt_name string. + pub(crate) fn parse_deallocate(&mut self) -> Result { + self.parser + .expect_keyword(Keyword::DEALLOCATE) + .context(SyntaxSnafu)?; + let stmt_name = self.parser.parse_identifier(false).context(SyntaxSnafu)?; + Ok(stmt_name.value) + } +} diff --git a/src/sql/src/parsers/execute_parser.rs b/src/sql/src/parsers/execute_parser.rs new file mode 100644 index 000000000000..67b3e8b6690f --- /dev/null +++ b/src/sql/src/parsers/execute_parser.rs @@ -0,0 +1,41 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use snafu::ResultExt; +use sqlparser::ast::Expr; +use sqlparser::keywords::Keyword; +use sqlparser::parser::Parser; + +use crate::error::{Result, SyntaxSnafu}; +use crate::parser::ParserContext; + +impl<'a> ParserContext<'a> { + /// Parses MySQL style 'EXECUTE stmt_name USING param_list' into a stmt_name string and a list of parameters. + /// Only use for MySQL. for PostgreSQL, use `sqlparser::parser::Parser::parse_execute` instead. + pub(crate) fn parse_mysql_execute(&mut self) -> Result<(String, Vec)> { + self.parser + .expect_keyword(Keyword::EXECUTE) + .context(SyntaxSnafu)?; + let stmt_name = self.parser.parse_identifier(false).context(SyntaxSnafu)?; + if self.parser.parse_keyword(Keyword::USING) { + let param_list = self + .parser + .parse_comma_separated(Parser::parse_expr) + .context(SyntaxSnafu)?; + Ok((stmt_name.value, param_list)) + } else { + Ok((stmt_name.value, vec![])) + } + } +} diff --git a/src/sql/src/parsers/prepare_parser.rs b/src/sql/src/parsers/prepare_parser.rs new file mode 100644 index 000000000000..a0fc07456b0e --- /dev/null +++ b/src/sql/src/parsers/prepare_parser.rs @@ -0,0 +1,46 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use snafu::ResultExt; +use sqlparser::keywords::Keyword; +use sqlparser::tokenizer::Token; + +use crate::error::{Result, SyntaxSnafu}; +use crate::parser::ParserContext; + +impl<'a> ParserContext<'a> { + /// Parses MySQL style 'PREPARE stmt_name FROM stmt' into a (stmt_name, stmt) tuple. + /// Only use for MySQL. for PostgreSQL, use `sqlparser::parser::Parser::parse_prepare` instead. + pub(crate) fn parse_mysql_prepare(&mut self) -> Result<(String, String)> { + self.parser + .expect_keyword(Keyword::PREPARE) + .context(SyntaxSnafu)?; + let stmt_name = self.parser.parse_identifier(false).context(SyntaxSnafu)?; + self.parser + .expect_keyword(Keyword::FROM) + .context(SyntaxSnafu)?; + let next_token = self.parser.peek_token(); + let stmt = match next_token.token { + Token::SingleQuotedString(s) | Token::DoubleQuotedString(s) => { + let _ = self.parser.next_token(); + s + } + _ => self + .parser + .expected("string literal", next_token) + .context(SyntaxSnafu)?, + }; + Ok((stmt_name.value, stmt)) + } +}