diff --git a/src/servers/src/mysql/federated.rs b/src/servers/src/mysql/federated.rs index 3df71b1c0377..e3d4a8ce803e 100644 --- a/src/servers/src/mysql/federated.rs +++ b/src/servers/src/mysql/federated.rs @@ -30,6 +30,7 @@ use once_cell::sync::Lazy; use regex::bytes::RegexSet; use regex::Regex; use session::context::QueryContextRef; +use session::SessionRef; static SELECT_VAR_PATTERN: Lazy = Lazy::new(|| Regex::new("(?i)^(SELECT @@(.*))").unwrap()); static MYSQL_CONN_JAVA_PATTERN: Lazy = @@ -263,12 +264,12 @@ fn check_show_variables(query: &str) -> Option { } // TODO(sunng87): extract this to use sqlparser for more variables -fn check_set_variables(query: &str, query_ctx: QueryContextRef) -> Option { +fn check_set_variables(query: &str, session: SessionRef) -> Option { if let Some(captures) = SET_TIME_ZONE_PATTERN.captures(query) { // get the capture let tz = captures.get(1).unwrap(); if let Ok(timezone) = TimeZone::from_tz_string(tz.as_str()) { - query_ctx.set_time_zone(timezone); + session.set_time_zone(timezone); return Some(Output::AffectedRows(0)); } } @@ -300,7 +301,11 @@ fn check_others(query: &str, query_ctx: QueryContextRef) -> Option { // Check whether the query is a federated or driver setup command, // and return some faked results if there are any. -pub(crate) fn check(query: &str, query_ctx: QueryContextRef) -> Option { +pub(crate) fn check( + query: &str, + query_ctx: QueryContextRef, + session: SessionRef, +) -> Option { // INSERT don't need MySQL federated check. We assume the query doesn't contain // federated or driver setup command if it starts with a 'INSERT' statement. if query.len() > 6 && query[..6].eq_ignore_ascii_case("INSERT") { @@ -311,7 +316,7 @@ pub(crate) fn check(query: &str, query_ctx: QueryContextRef) -> Option { check_select_variable(query, query_ctx.clone()) // Then to check "show variables like ...". .or_else(|| check_show_variables(query)) - .or_else(|| check_set_variables(query, query_ctx.clone())) + .or_else(|| check_set_variables(query, session.clone())) // Last check .or_else(|| check_others(query, query_ctx)) } @@ -326,22 +331,25 @@ fn get_version() -> String { #[cfg(test)] mod test { - use session::context::QueryContext; + use session::context::{Channel, QueryContext}; + use session::Session; use super::*; #[test] fn test_check() { + let session = Arc::new(Session::new(None, Channel::Mysql)); let query = "select 1"; - let result = check(query, QueryContext::arc()); + let result = check(query, QueryContext::arc(), session.clone()); assert!(result.is_none()); let query = "select versiona"; - let output = check(query, QueryContext::arc()); + let output = check(query, QueryContext::arc(), session.clone()); assert!(output.is_none()); fn test(query: &str, expected: &str) { - let output = check(query, QueryContext::arc()); + let session = Arc::new(Session::new(None, Channel::Mysql)); + let output = check(query, QueryContext::arc(), session.clone()); match output.unwrap() { Output::RecordBatches(r) => { assert_eq!(&r.pretty_print().unwrap(), expected) @@ -352,7 +360,7 @@ mod test { let query = "select version()"; let version = env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "unknown".to_string()); - let output = check(query, QueryContext::arc()); + let output = check(query, QueryContext::arc(), session.clone()); match output.unwrap() { Output::RecordBatches(r) => { assert!(&r @@ -430,17 +438,22 @@ mod test { #[test] fn test_set_time_zone() { - let query_context = QueryContext::arc(); - let output = check("set time_zone = 'UTC'", query_context.clone()); + let session = Arc::new(Session::new(None, Channel::Mysql)); + let output = check( + "set time_zone = 'UTC'", + QueryContext::arc(), + session.clone(), + ); match output.unwrap() { Output::AffectedRows(rows) => { assert_eq!(rows, 0) } _ => unreachable!(), } + let query_context = session.new_query_context(); assert_eq!("UTC", query_context.time_zone().unwrap().to_string()); - let output = check("select @@time_zone", query_context); + let output = check("select @@time_zone", query_context.clone(), session.clone()); match output.unwrap() { Output::RecordBatches(r) => { let expected = "\ diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 0deebc02bb30..4674f65b1d9a 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -91,7 +91,9 @@ impl MysqlInstanceShim { } async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { - if let Some(output) = crate::mysql::federated::check(query, query_ctx.clone()) { + if let Some(output) = + crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone()) + { vec![Ok(output)] } else { let trace_id = query_ctx.trace_id(); @@ -110,7 +112,9 @@ impl MysqlInstanceShim { plan: LogicalPlan, query_ctx: QueryContextRef, ) -> Result { - if let Some(output) = crate::mysql::federated::check(query, query_ctx.clone()) { + if let Some(output) = + crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone()) + { Ok(output) } else { self.query_handler.do_exec_plan(plan, query_ctx).await diff --git a/src/session/src/context.rs b/src/session/src/context.rs index 219c58e33a04..bfdac8f7725f 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -34,7 +34,7 @@ pub struct QueryContext { current_catalog: String, current_schema: String, current_user: ArcSwap>, - time_zone: ArcSwap>, + time_zone: Option, sql_dialect: Box, trace_id: u64, } @@ -103,12 +103,7 @@ impl QueryContext { #[inline] pub fn time_zone(&self) -> Option { - self.time_zone.load().as_ref().clone() - } - - #[inline] - pub fn set_time_zone(&self, tz: Option) { - let _ = self.time_zone.swap(Arc::new(tz)); + self.time_zone.clone() } #[inline] @@ -139,9 +134,7 @@ impl QueryContextBuilder { current_user: self .current_user .unwrap_or_else(|| ArcSwap::new(Arc::new(None))), - time_zone: self - .time_zone - .unwrap_or_else(|| ArcSwap::new(Arc::new(None))), + time_zone: self.time_zone.unwrap_or(None), sql_dialect: self .sql_dialect .unwrap_or_else(|| Box::new(GreptimeDbDialect {})), diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index e67bd5b6f048..2ab4e8c56ee4 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -21,6 +21,7 @@ use arc_swap::ArcSwap; use auth::UserInfoRef; use common_catalog::build_db_string; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use common_time::TimeZone; use context::QueryContextBuilder; use crate::context::{Channel, ConnInfo, QueryContextRef}; @@ -32,6 +33,7 @@ pub struct Session { schema: ArcSwap, user_info: ArcSwap, conn_info: ConnInfo, + time_zone: ArcSwap>, } pub type SessionRef = Arc; @@ -43,6 +45,7 @@ impl Session { schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.into())), user_info: ArcSwap::new(Arc::new(auth::userinfo_by_name(None))), conn_info: ConnInfo::new(addr, channel), + time_zone: ArcSwap::new(Arc::new(None)), } } @@ -55,6 +58,7 @@ impl Session { .current_catalog(self.catalog.load().to_string()) .current_schema(self.schema.load().to_string()) .sql_dialect(self.conn_info.channel.dialect()) + .time_zone((**self.time_zone.load()).clone()) .build() } @@ -68,6 +72,16 @@ impl Session { &mut self.conn_info } + #[inline] + pub fn time_zone(&self) -> Option { + self.time_zone.load().as_ref().clone() + } + + #[inline] + pub fn set_time_zone(&self, tz: Option) { + let _ = self.time_zone.swap(Arc::new(tz)); + } + #[inline] pub fn user_info(&self) -> UserInfoRef { self.user_info.load().clone().as_ref().clone() diff --git a/tests-integration/tests/sql.rs b/tests-integration/tests/sql.rs index f8e81230eb3c..1f1a17164d33 100644 --- a/tests-integration/tests/sql.rs +++ b/tests-integration/tests/sql.rs @@ -13,10 +13,10 @@ // limitations under the License. use auth::user_provider_from_option; -use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc}; -use sqlx::mysql::{MySqlDatabaseError, MySqlPoolOptions}; +use chrono::{DateTime, NaiveDate, NaiveDateTime, SecondsFormat, Utc}; +use sqlx::mysql::{MySqlConnection, MySqlDatabaseError, MySqlPoolOptions}; use sqlx::postgres::{PgDatabaseError, PgPoolOptions}; -use sqlx::Row; +use sqlx::{Connection, Executor, Row}; use tests_integration::test_util::{ setup_mysql_server, setup_mysql_server_with_user_provider, setup_pg_server, setup_pg_server_with_user_provider, StorageType, @@ -55,6 +55,7 @@ macro_rules! sql_tests { test_mysql_auth, test_mysql_crud, + test_mysql_timezone, test_postgres_auth, test_postgres_crud, test_postgres_parameter_inference, @@ -207,6 +208,49 @@ pub async fn test_mysql_crud(store_type: StorageType) { guard.remove_all().await; } +pub async fn test_mysql_timezone(store_type: StorageType) { + common_telemetry::init_default_ut_logging(); + + let (addr, mut guard, fe_mysql_server) = setup_mysql_server(store_type, "mysql_timezone").await; + let mut conn = MySqlConnection::connect(&format!("mysql://{addr}/public")) + .await + .unwrap(); + + let _ = conn.execute("SET time_zone = 'UTC'").await.unwrap(); + let time_zone = conn.fetch_all("SELECT @@time_zone").await.unwrap(); + assert_eq!(time_zone[0].get::(0), "UTC"); + + // test data + let _ = conn + .execute("create table demo(i bigint, ts timestamp time index)") + .await + .unwrap(); + let _ = conn + .execute("insert into demo values(1, 1667446797450)") + .await + .unwrap(); + let rows = conn.fetch_all("select ts from demo").await.unwrap(); + assert_eq!( + rows[0] + .get::, usize>(0) + .to_rfc3339_opts(SecondsFormat::Millis, true), + "2022-11-03T03:39:57.450Z" + ); + + let _ = conn.execute("SET time_zone = '+08:00'").await.unwrap(); + let rows2 = conn.fetch_all("select ts from demo").await.unwrap(); + // we use Utc here for format only + assert_eq!( + rows2[0] + .get::, usize>(0) + .to_rfc3339_opts(SecondsFormat::Millis, true), + "2022-11-03T11:39:57.450Z" + ); + + let _ = fe_mysql_server.shutdown().await; + guard.remove_all().await; +} + pub async fn test_postgres_auth(store_type: StorageType) { let user_provider = user_provider_from_option( &"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(),