Skip to content

Commit

Permalink
feat: returning warning instead of error on unsupported SET stateme…
Browse files Browse the repository at this point in the history
…nt (#4761)

* feat: add capability to send warning to pgclient

* fix: refactor query context to carry query scope data

* feat: return a warning for unsupported postgres statement
  • Loading branch information
sunng87 authored Sep 24, 2024
1 parent d1b2527 commit e3c0b54
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 46 deletions.
16 changes: 12 additions & 4 deletions src/operator/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use datafusion_expr::LogicalPlan;
use partition::manager::{PartitionRuleManager, PartitionRuleManagerRef};
use query::parser::QueryStatement;
use query::QueryEngineRef;
use session::context::QueryContextRef;
use session::context::{Channel, QueryContextRef};
use session::table_name::table_idents_to_full_name;
use snafu::{ensure, OptionExt, ResultExt};
use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument};
Expand Down Expand Up @@ -338,10 +338,18 @@ impl StatementExecutor {

"CLIENT_ENCODING" => validate_client_encoding(set_var)?,
_ => {
return NotSupportedSnafu {
feat: format!("Unsupported set variable {}", var_name),
// for postgres, we give unknown SET statements a warning with
// success, this is prevent the SET call becoming a blocker
// of connection establishment
//
if query_ctx.channel() == Channel::Postgres {
query_ctx.set_warning(format!("Unsupported set variable {}", var_name));
} else {
return NotSupportedSnafu {
feat: format!("Unsupported set variable {}", var_name),
}
.fail();
}
.fail()
}
}
Ok(Output::new_with_affected_rows(0))
Expand Down
42 changes: 37 additions & 5 deletions src/servers/src/postgres/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Debug;
use std::sync::Arc;

use async_trait::async_trait;
Expand All @@ -23,7 +24,7 @@ use common_telemetry::{debug, error, tracing};
use datafusion_common::ParamValues;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::SchemaRef;
use futures::{future, stream, Stream, StreamExt};
use futures::{future, stream, Sink, SinkExt, Stream, StreamExt};
use pgwire::api::portal::{Format, Portal};
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{
Expand All @@ -32,6 +33,7 @@ use pgwire::api::results::{
use pgwire::api::stmt::{QueryParser, StoredStatement};
use pgwire::api::{ClientInfo, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::PgWireBackendMessage;
use query::query_engine::DescribeResult;
use session::context::QueryContextRef;
use session::Session;
Expand All @@ -49,11 +51,13 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
#[tracing::instrument(skip_all, fields(protocol = "postgres"))]
async fn do_query<'a, C>(
&self,
_client: &mut C,
client: &mut C,
query: &'a str,
) -> PgWireResult<Vec<Response<'a>>>
where
C: ClientInfo + Unpin + Send + Sync,
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
let query_ctx = self.session.new_query_context();
let db = query_ctx.get_db_string();
Expand All @@ -67,6 +71,7 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
}

if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
send_warning_opt(client, query_ctx).await?;
Ok(resps)
} else {
let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;
Expand All @@ -79,11 +84,34 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
results.push(resp);
}

send_warning_opt(client, query_ctx).await?;
Ok(results)
}
}
}

async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
where
C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
if let Some(warning) = query_context.warning() {
client
.feed(PgWireBackendMessage::NoticeResponse(
ErrorInfo::new(
PgErrorSeverity::Warning.to_string(),
PgErrorCode::Ec01000.code(),
warning.to_string(),
)
.into(),
))
.await?;
}

Ok(())
}

pub(crate) fn output_to_query_response<'a>(
query_ctx: QueryContextRef,
output: Result<Output>,
Expand Down Expand Up @@ -247,12 +275,14 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {

async fn do_query<'a, C>(
&self,
_client: &mut C,
client: &mut C,
portal: &'a Portal<Self::Statement>,
_max_rows: usize,
) -> PgWireResult<Response<'a>>
where
C: ClientInfo + Unpin + Send + Sync,
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
let query_ctx = self.session.new_query_context();
let db = query_ctx.get_db_string();
Expand All @@ -268,6 +298,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
}

if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
send_warning_opt(client, query_ctx).await?;
// if the statement matches our predefined rules, return it early
return Ok(resps.remove(0));
}
Expand Down Expand Up @@ -297,6 +328,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
.remove(0)
};

send_warning_opt(client, query_ctx.clone()).await?;
output_to_query_response(query_ctx, output, &portal.result_column_format)
}

Expand Down
2 changes: 1 addition & 1 deletion src/servers/src/postgres/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use session::session_config::PGByteaOutputValue;

use self::bytea::{EscapeOutputBytea, HexOutputBytea};
use self::datetime::{StylingDate, StylingDateTime};
pub use self::error::PgErrorCode;
pub use self::error::{PgErrorCode, PgErrorSeverity};
use self::interval::PgInterval;
use crate::error::{self as server_error, Error, Result};
use crate::SqlPlan;
Expand Down
40 changes: 20 additions & 20 deletions src/servers/src/postgres/types/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use strum::{AsRefStr, Display, EnumIter, EnumMessage};

#[derive(Display, Debug, PartialEq)]
#[allow(dead_code)]
enum ErrorSeverity {
pub enum PgErrorSeverity {
#[strum(serialize = "INFO")]
Info,
#[strum(serialize = "DEBUG")]
Expand Down Expand Up @@ -335,23 +335,23 @@ pub enum PgErrorCode {
}

impl PgErrorCode {
fn severity(&self) -> ErrorSeverity {
fn severity(&self) -> PgErrorSeverity {
match self {
PgErrorCode::Ec00000 => ErrorSeverity::Info,
PgErrorCode::Ec01000 => ErrorSeverity::Warning,
PgErrorCode::Ec00000 => PgErrorSeverity::Info,
PgErrorCode::Ec01000 => PgErrorSeverity::Warning,

PgErrorCode::EcXX000 | PgErrorCode::Ec42P14 | PgErrorCode::Ec22023 => {
ErrorSeverity::Error
PgErrorSeverity::Error
}
PgErrorCode::Ec28000 | PgErrorCode::Ec28P01 | PgErrorCode::Ec3D000 => {
ErrorSeverity::Fatal
PgErrorSeverity::Fatal
}

_ => ErrorSeverity::Error,
_ => PgErrorSeverity::Error,
}
}

fn code(&self) -> String {
pub(crate) fn code(&self) -> String {
self.as_ref()[2..].to_string()
}

Expand Down Expand Up @@ -428,33 +428,33 @@ mod tests {
use common_error::status_code::StatusCode;
use strum::{EnumMessage, IntoEnumIterator};

use super::{ErrorInfo, ErrorSeverity, PgErrorCode};
use super::{ErrorInfo, PgErrorCode, PgErrorSeverity};

#[test]
fn test_error_severity() {
// test for ErrorSeverity enum
assert_eq!("INFO", ErrorSeverity::Info.to_string());
assert_eq!("DEBUG", ErrorSeverity::Debug.to_string());
assert_eq!("NOTICE", ErrorSeverity::Notice.to_string());
assert_eq!("WARNING", ErrorSeverity::Warning.to_string());
assert_eq!("INFO", PgErrorSeverity::Info.to_string());
assert_eq!("DEBUG", PgErrorSeverity::Debug.to_string());
assert_eq!("NOTICE", PgErrorSeverity::Notice.to_string());
assert_eq!("WARNING", PgErrorSeverity::Warning.to_string());

assert_eq!("ERROR", ErrorSeverity::Error.to_string());
assert_eq!("FATAL", ErrorSeverity::Fatal.to_string());
assert_eq!("PANIC", ErrorSeverity::Panic.to_string());
assert_eq!("ERROR", PgErrorSeverity::Error.to_string());
assert_eq!("FATAL", PgErrorSeverity::Fatal.to_string());
assert_eq!("PANIC", PgErrorSeverity::Panic.to_string());

// test for severity method
for code in PgErrorCode::iter() {
let name = code.as_ref();
assert_eq!("Ec", &name[0..2]);

if name.starts_with("Ec00") {
assert_eq!(ErrorSeverity::Info, code.severity());
assert_eq!(PgErrorSeverity::Info, code.severity());
} else if name.starts_with("Ec01") {
assert_eq!(ErrorSeverity::Warning, code.severity());
assert_eq!(PgErrorSeverity::Warning, code.severity());
} else if name.starts_with("Ec28") || name.starts_with("Ec3D") {
assert_eq!(ErrorSeverity::Fatal, code.severity());
assert_eq!(PgErrorSeverity::Fatal, code.severity());
} else {
assert_eq!(ErrorSeverity::Error, code.severity());
assert_eq!(PgErrorSeverity::Error, code.severity());
}
}
}
Expand Down
56 changes: 41 additions & 15 deletions src/session/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ pub struct QueryContext {
current_catalog: String,
// we use Arc<RwLock>> for modifiable fields
#[builder(default)]
mutable_inner: Arc<RwLock<MutableInner>>,
mutable_session_data: Arc<RwLock<MutableInner>>,
#[builder(default)]
mutable_query_context_data: Arc<RwLock<QueryContextMutableFields>>,
sql_dialect: Arc<dyn Dialect + Send + Sync>,
#[builder(default)]
extensions: HashMap<String, String>,
Expand All @@ -52,6 +54,12 @@ pub struct QueryContext {
channel: Channel,
}

/// This fields hold data that is only valid to current query context
#[derive(Debug, Builder, Clone, Default)]
pub struct QueryContextMutableFields {
warning: Option<String>,
}

impl Display for QueryContext {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
Expand All @@ -65,21 +73,26 @@ impl Display for QueryContext {

impl QueryContextBuilder {
pub fn current_schema(mut self, schema: String) -> Self {
if self.mutable_inner.is_none() {
self.mutable_inner = Some(Arc::new(RwLock::new(MutableInner::default())));
if self.mutable_session_data.is_none() {
self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
}

// safe for unwrap because previous none check
self.mutable_inner.as_mut().unwrap().write().unwrap().schema = schema;
self.mutable_session_data
.as_mut()
.unwrap()
.write()
.unwrap()
.schema = schema;
self
}

pub fn timezone(mut self, timezone: Timezone) -> Self {
if self.mutable_inner.is_none() {
self.mutable_inner = Some(Arc::new(RwLock::new(MutableInner::default())));
if self.mutable_session_data.is_none() {
self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
}

self.mutable_inner
self.mutable_session_data
.as_mut()
.unwrap()
.write()
Expand Down Expand Up @@ -120,7 +133,7 @@ impl From<QueryContext> for api::v1::QueryContext {
fn from(
QueryContext {
current_catalog,
mutable_inner,
mutable_session_data: mutable_inner,
extensions,
channel,
..
Expand Down Expand Up @@ -182,11 +195,11 @@ impl QueryContext {
}

pub fn current_schema(&self) -> String {
self.mutable_inner.read().unwrap().schema.clone()
self.mutable_session_data.read().unwrap().schema.clone()
}

pub fn set_current_schema(&self, new_schema: &str) {
self.mutable_inner.write().unwrap().schema = new_schema.to_string();
self.mutable_session_data.write().unwrap().schema = new_schema.to_string();
}

pub fn current_catalog(&self) -> &str {
Expand All @@ -208,19 +221,19 @@ impl QueryContext {
}

pub fn timezone(&self) -> Timezone {
self.mutable_inner.read().unwrap().timezone.clone()
self.mutable_session_data.read().unwrap().timezone.clone()
}

pub fn set_timezone(&self, timezone: Timezone) {
self.mutable_inner.write().unwrap().timezone = timezone;
self.mutable_session_data.write().unwrap().timezone = timezone;
}

pub fn current_user(&self) -> UserInfoRef {
self.mutable_inner.read().unwrap().user_info.clone()
self.mutable_session_data.read().unwrap().user_info.clone()
}

pub fn set_current_user(&self, user: UserInfoRef) {
self.mutable_inner.write().unwrap().user_info = user;
self.mutable_session_data.write().unwrap().user_info = user;
}

pub fn set_extension<S1: Into<String>, S2: Into<String>>(&mut self, key: S1, value: S2) {
Expand Down Expand Up @@ -257,6 +270,18 @@ impl QueryContext {
pub fn set_channel(&mut self, channel: Channel) {
self.channel = channel;
}

pub fn warning(&self) -> Option<String> {
self.mutable_query_context_data
.read()
.unwrap()
.warning
.clone()
}

pub fn set_warning(&self, msg: String) {
self.mutable_query_context_data.write().unwrap().warning = Some(msg);
}
}

impl QueryContextBuilder {
Expand All @@ -266,7 +291,8 @@ impl QueryContextBuilder {
current_catalog: self
.current_catalog
.unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()),
mutable_inner: self.mutable_inner.unwrap_or_default(),
mutable_session_data: self.mutable_session_data.unwrap_or_default(),
mutable_query_context_data: self.mutable_query_context_data.unwrap_or_default(),
sql_dialect: self
.sql_dialect
.unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
Expand Down
2 changes: 1 addition & 1 deletion src/session/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ impl Session {
// catalog is not allowed for update in query context so we use
// string here
.current_catalog(self.catalog.read().unwrap().clone())
.mutable_inner(self.mutable_inner.clone())
.mutable_session_data(self.mutable_inner.clone())
.sql_dialect(self.conn_info.channel.dialect())
.configuration_parameter(self.configuration_variables.clone())
.channel(self.conn_info.channel)
Expand Down

0 comments on commit e3c0b54

Please sign in to comment.