From e3c0b5482f379f71433eb2fbd0d8ea63c38018f1 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Tue, 24 Sep 2024 16:45:55 +0800 Subject: [PATCH] feat: returning warning instead of error on unsupported `SET` statement (#4761) * feat: add capability to send warning to pgclient * fix: refactor query context to carry query scope data * feat: return a warning for unsupported postgres statement --- src/operator/src/statement.rs | 16 +++++-- src/servers/src/postgres/handler.rs | 42 ++++++++++++++++--- src/servers/src/postgres/types.rs | 2 +- src/servers/src/postgres/types/error.rs | 40 +++++++++--------- src/session/src/context.rs | 56 ++++++++++++++++++------- src/session/src/lib.rs | 2 +- 6 files changed, 112 insertions(+), 46 deletions(-) diff --git a/src/operator/src/statement.rs b/src/operator/src/statement.rs index 4dc43e0d92e9..7c76d0dcfffc 100644 --- a/src/operator/src/statement.rs +++ b/src/operator/src/statement.rs @@ -46,7 +46,7 @@ use datafusion_expr::LogicalPlan; use partition::manager::{PartitionRuleManager, PartitionRuleManagerRef}; use query::parser::QueryStatement; use query::QueryEngineRef; -use session::context::QueryContextRef; +use session::context::{Channel, QueryContextRef}; use session::table_name::table_idents_to_full_name; use snafu::{ensure, OptionExt, ResultExt}; use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument}; @@ -338,10 +338,18 @@ impl StatementExecutor { "CLIENT_ENCODING" => validate_client_encoding(set_var)?, _ => { - return NotSupportedSnafu { - feat: format!("Unsupported set variable {}", var_name), + // for postgres, we give unknown SET statements a warning with + // success, this is prevent the SET call becoming a blocker + // of connection establishment + // + if query_ctx.channel() == Channel::Postgres { + query_ctx.set_warning(format!("Unsupported set variable {}", var_name)); + } else { + return NotSupportedSnafu { + feat: format!("Unsupported set variable {}", var_name), + } + .fail(); } - .fail() } } Ok(Output::new_with_affected_rows(0)) diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 158e2cab4da9..522c558cdc71 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::fmt::Debug; use std::sync::Arc; use async_trait::async_trait; @@ -23,7 +24,7 @@ use common_telemetry::{debug, error, tracing}; use datafusion_common::ParamValues; use datatypes::prelude::ConcreteDataType; use datatypes::schema::SchemaRef; -use futures::{future, stream, Stream, StreamExt}; +use futures::{future, stream, Sink, SinkExt, Stream, StreamExt}; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{ @@ -32,6 +33,7 @@ use pgwire::api::results::{ use pgwire::api::stmt::{QueryParser, StoredStatement}; use pgwire::api::{ClientInfo, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use pgwire::messages::PgWireBackendMessage; use query::query_engine::DescribeResult; use session::context::QueryContextRef; use session::Session; @@ -49,11 +51,13 @@ impl SimpleQueryHandler for PostgresServerHandlerInner { #[tracing::instrument(skip_all, fields(protocol = "postgres"))] async fn do_query<'a, C>( &self, - _client: &mut C, + client: &mut C, query: &'a str, ) -> PgWireResult>> where - C: ClientInfo + Unpin + Send + Sync, + C: ClientInfo + Sink + Unpin + Send + Sync, + C::Error: Debug, + PgWireError: From<>::Error>, { let query_ctx = self.session.new_query_context(); let db = query_ctx.get_db_string(); @@ -67,6 +71,7 @@ impl SimpleQueryHandler for PostgresServerHandlerInner { } if let Some(resps) = fixtures::process(query, query_ctx.clone()) { + send_warning_opt(client, query_ctx).await?; Ok(resps) } else { let outputs = self.query_handler.do_query(query, query_ctx.clone()).await; @@ -79,11 +84,34 @@ impl SimpleQueryHandler for PostgresServerHandlerInner { results.push(resp); } + send_warning_opt(client, query_ctx).await?; Ok(results) } } } +async fn send_warning_opt(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()> +where + C: Sink + Unpin + Send + Sync, + C::Error: Debug, + PgWireError: From<>::Error>, +{ + if let Some(warning) = query_context.warning() { + client + .feed(PgWireBackendMessage::NoticeResponse( + ErrorInfo::new( + PgErrorSeverity::Warning.to_string(), + PgErrorCode::Ec01000.code(), + warning.to_string(), + ) + .into(), + )) + .await?; + } + + Ok(()) +} + pub(crate) fn output_to_query_response<'a>( query_ctx: QueryContextRef, output: Result, @@ -247,12 +275,14 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { async fn do_query<'a, C>( &self, - _client: &mut C, + client: &mut C, portal: &'a Portal, _max_rows: usize, ) -> PgWireResult> where - C: ClientInfo + Unpin + Send + Sync, + C: ClientInfo + Sink + Unpin + Send + Sync, + C::Error: Debug, + PgWireError: From<>::Error>, { let query_ctx = self.session.new_query_context(); let db = query_ctx.get_db_string(); @@ -268,6 +298,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { } if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) { + send_warning_opt(client, query_ctx).await?; // if the statement matches our predefined rules, return it early return Ok(resps.remove(0)); } @@ -297,6 +328,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { .remove(0) }; + send_warning_opt(client, query_ctx.clone()).await?; output_to_query_response(query_ctx, output, &portal.result_column_format) } diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs index 2e4a805ef0bc..a5d1d392ac3b 100644 --- a/src/servers/src/postgres/types.rs +++ b/src/servers/src/postgres/types.rs @@ -37,7 +37,7 @@ use session::session_config::PGByteaOutputValue; use self::bytea::{EscapeOutputBytea, HexOutputBytea}; use self::datetime::{StylingDate, StylingDateTime}; -pub use self::error::PgErrorCode; +pub use self::error::{PgErrorCode, PgErrorSeverity}; use self::interval::PgInterval; use crate::error::{self as server_error, Error, Result}; use crate::SqlPlan; diff --git a/src/servers/src/postgres/types/error.rs b/src/servers/src/postgres/types/error.rs index 928c5454ce27..9e6f570f2610 100644 --- a/src/servers/src/postgres/types/error.rs +++ b/src/servers/src/postgres/types/error.rs @@ -19,7 +19,7 @@ use strum::{AsRefStr, Display, EnumIter, EnumMessage}; #[derive(Display, Debug, PartialEq)] #[allow(dead_code)] -enum ErrorSeverity { +pub enum PgErrorSeverity { #[strum(serialize = "INFO")] Info, #[strum(serialize = "DEBUG")] @@ -335,23 +335,23 @@ pub enum PgErrorCode { } impl PgErrorCode { - fn severity(&self) -> ErrorSeverity { + fn severity(&self) -> PgErrorSeverity { match self { - PgErrorCode::Ec00000 => ErrorSeverity::Info, - PgErrorCode::Ec01000 => ErrorSeverity::Warning, + PgErrorCode::Ec00000 => PgErrorSeverity::Info, + PgErrorCode::Ec01000 => PgErrorSeverity::Warning, PgErrorCode::EcXX000 | PgErrorCode::Ec42P14 | PgErrorCode::Ec22023 => { - ErrorSeverity::Error + PgErrorSeverity::Error } PgErrorCode::Ec28000 | PgErrorCode::Ec28P01 | PgErrorCode::Ec3D000 => { - ErrorSeverity::Fatal + PgErrorSeverity::Fatal } - _ => ErrorSeverity::Error, + _ => PgErrorSeverity::Error, } } - fn code(&self) -> String { + pub(crate) fn code(&self) -> String { self.as_ref()[2..].to_string() } @@ -428,19 +428,19 @@ mod tests { use common_error::status_code::StatusCode; use strum::{EnumMessage, IntoEnumIterator}; - use super::{ErrorInfo, ErrorSeverity, PgErrorCode}; + use super::{ErrorInfo, PgErrorCode, PgErrorSeverity}; #[test] fn test_error_severity() { // test for ErrorSeverity enum - assert_eq!("INFO", ErrorSeverity::Info.to_string()); - assert_eq!("DEBUG", ErrorSeverity::Debug.to_string()); - assert_eq!("NOTICE", ErrorSeverity::Notice.to_string()); - assert_eq!("WARNING", ErrorSeverity::Warning.to_string()); + assert_eq!("INFO", PgErrorSeverity::Info.to_string()); + assert_eq!("DEBUG", PgErrorSeverity::Debug.to_string()); + assert_eq!("NOTICE", PgErrorSeverity::Notice.to_string()); + assert_eq!("WARNING", PgErrorSeverity::Warning.to_string()); - assert_eq!("ERROR", ErrorSeverity::Error.to_string()); - assert_eq!("FATAL", ErrorSeverity::Fatal.to_string()); - assert_eq!("PANIC", ErrorSeverity::Panic.to_string()); + assert_eq!("ERROR", PgErrorSeverity::Error.to_string()); + assert_eq!("FATAL", PgErrorSeverity::Fatal.to_string()); + assert_eq!("PANIC", PgErrorSeverity::Panic.to_string()); // test for severity method for code in PgErrorCode::iter() { @@ -448,13 +448,13 @@ mod tests { assert_eq!("Ec", &name[0..2]); if name.starts_with("Ec00") { - assert_eq!(ErrorSeverity::Info, code.severity()); + assert_eq!(PgErrorSeverity::Info, code.severity()); } else if name.starts_with("Ec01") { - assert_eq!(ErrorSeverity::Warning, code.severity()); + assert_eq!(PgErrorSeverity::Warning, code.severity()); } else if name.starts_with("Ec28") || name.starts_with("Ec3D") { - assert_eq!(ErrorSeverity::Fatal, code.severity()); + assert_eq!(PgErrorSeverity::Fatal, code.severity()); } else { - assert_eq!(ErrorSeverity::Error, code.severity()); + assert_eq!(PgErrorSeverity::Error, code.severity()); } } } diff --git a/src/session/src/context.rs b/src/session/src/context.rs index 70168d9498eb..f85a8ceea313 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -40,7 +40,9 @@ pub struct QueryContext { current_catalog: String, // we use Arc> for modifiable fields #[builder(default)] - mutable_inner: Arc>, + mutable_session_data: Arc>, + #[builder(default)] + mutable_query_context_data: Arc>, sql_dialect: Arc, #[builder(default)] extensions: HashMap, @@ -52,6 +54,12 @@ pub struct QueryContext { channel: Channel, } +/// This fields hold data that is only valid to current query context +#[derive(Debug, Builder, Clone, Default)] +pub struct QueryContextMutableFields { + warning: Option, +} + impl Display for QueryContext { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( @@ -65,21 +73,26 @@ impl Display for QueryContext { impl QueryContextBuilder { pub fn current_schema(mut self, schema: String) -> Self { - if self.mutable_inner.is_none() { - self.mutable_inner = Some(Arc::new(RwLock::new(MutableInner::default()))); + if self.mutable_session_data.is_none() { + self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default()))); } // safe for unwrap because previous none check - self.mutable_inner.as_mut().unwrap().write().unwrap().schema = schema; + self.mutable_session_data + .as_mut() + .unwrap() + .write() + .unwrap() + .schema = schema; self } pub fn timezone(mut self, timezone: Timezone) -> Self { - if self.mutable_inner.is_none() { - self.mutable_inner = Some(Arc::new(RwLock::new(MutableInner::default()))); + if self.mutable_session_data.is_none() { + self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default()))); } - self.mutable_inner + self.mutable_session_data .as_mut() .unwrap() .write() @@ -120,7 +133,7 @@ impl From for api::v1::QueryContext { fn from( QueryContext { current_catalog, - mutable_inner, + mutable_session_data: mutable_inner, extensions, channel, .. @@ -182,11 +195,11 @@ impl QueryContext { } pub fn current_schema(&self) -> String { - self.mutable_inner.read().unwrap().schema.clone() + self.mutable_session_data.read().unwrap().schema.clone() } pub fn set_current_schema(&self, new_schema: &str) { - self.mutable_inner.write().unwrap().schema = new_schema.to_string(); + self.mutable_session_data.write().unwrap().schema = new_schema.to_string(); } pub fn current_catalog(&self) -> &str { @@ -208,19 +221,19 @@ impl QueryContext { } pub fn timezone(&self) -> Timezone { - self.mutable_inner.read().unwrap().timezone.clone() + self.mutable_session_data.read().unwrap().timezone.clone() } pub fn set_timezone(&self, timezone: Timezone) { - self.mutable_inner.write().unwrap().timezone = timezone; + self.mutable_session_data.write().unwrap().timezone = timezone; } pub fn current_user(&self) -> UserInfoRef { - self.mutable_inner.read().unwrap().user_info.clone() + self.mutable_session_data.read().unwrap().user_info.clone() } pub fn set_current_user(&self, user: UserInfoRef) { - self.mutable_inner.write().unwrap().user_info = user; + self.mutable_session_data.write().unwrap().user_info = user; } pub fn set_extension, S2: Into>(&mut self, key: S1, value: S2) { @@ -257,6 +270,18 @@ impl QueryContext { pub fn set_channel(&mut self, channel: Channel) { self.channel = channel; } + + pub fn warning(&self) -> Option { + self.mutable_query_context_data + .read() + .unwrap() + .warning + .clone() + } + + pub fn set_warning(&self, msg: String) { + self.mutable_query_context_data.write().unwrap().warning = Some(msg); + } } impl QueryContextBuilder { @@ -266,7 +291,8 @@ impl QueryContextBuilder { current_catalog: self .current_catalog .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()), - mutable_inner: self.mutable_inner.unwrap_or_default(), + mutable_session_data: self.mutable_session_data.unwrap_or_default(), + mutable_query_context_data: self.mutable_query_context_data.unwrap_or_default(), sql_dialect: self .sql_dialect .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})), diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index ecfc02f23001..33bd140c7057 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -76,7 +76,7 @@ impl Session { // catalog is not allowed for update in query context so we use // string here .current_catalog(self.catalog.read().unwrap().clone()) - .mutable_inner(self.mutable_inner.clone()) + .mutable_session_data(self.mutable_inner.clone()) .sql_dialect(self.conn_info.channel.dialect()) .configuration_parameter(self.configuration_variables.clone()) .channel(self.conn_info.channel)