Skip to content

Commit

Permalink
fix: mysql timezone settings (#2534)
Browse files Browse the repository at this point in the history
* fix: restore time zone settings for mysql

* test: add integration test for time zone

* test: fix unit test for check
  • Loading branch information
sunng87 authored Oct 7, 2023
1 parent b44e39f commit 0ad3fb6
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 27 deletions.
37 changes: 25 additions & 12 deletions src/servers/src/mysql/federated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use once_cell::sync::Lazy;
use regex::bytes::RegexSet;
use regex::Regex;
use session::context::QueryContextRef;
use session::SessionRef;

static SELECT_VAR_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new("(?i)^(SELECT @@(.*))").unwrap());
static MYSQL_CONN_JAVA_PATTERN: Lazy<Regex> =
Expand Down Expand Up @@ -263,12 +264,12 @@ fn check_show_variables(query: &str) -> Option<Output> {
}

// TODO(sunng87): extract this to use sqlparser for more variables
fn check_set_variables(query: &str, query_ctx: QueryContextRef) -> Option<Output> {
fn check_set_variables(query: &str, session: SessionRef) -> Option<Output> {
if let Some(captures) = SET_TIME_ZONE_PATTERN.captures(query) {
// get the capture
let tz = captures.get(1).unwrap();
if let Ok(timezone) = TimeZone::from_tz_string(tz.as_str()) {
query_ctx.set_time_zone(timezone);
session.set_time_zone(timezone);
return Some(Output::AffectedRows(0));
}
}
Expand Down Expand Up @@ -300,7 +301,11 @@ fn check_others(query: &str, query_ctx: QueryContextRef) -> Option<Output> {

// Check whether the query is a federated or driver setup command,
// and return some faked results if there are any.
pub(crate) fn check(query: &str, query_ctx: QueryContextRef) -> Option<Output> {
pub(crate) fn check(
query: &str,
query_ctx: QueryContextRef,
session: SessionRef,
) -> Option<Output> {
// INSERT don't need MySQL federated check. We assume the query doesn't contain
// federated or driver setup command if it starts with a 'INSERT' statement.
if query.len() > 6 && query[..6].eq_ignore_ascii_case("INSERT") {
Expand All @@ -311,7 +316,7 @@ pub(crate) fn check(query: &str, query_ctx: QueryContextRef) -> Option<Output> {
check_select_variable(query, query_ctx.clone())
// Then to check "show variables like ...".
.or_else(|| check_show_variables(query))
.or_else(|| check_set_variables(query, query_ctx.clone()))
.or_else(|| check_set_variables(query, session.clone()))
// Last check
.or_else(|| check_others(query, query_ctx))
}
Expand All @@ -326,22 +331,25 @@ fn get_version() -> String {
#[cfg(test)]
mod test {

use session::context::QueryContext;
use session::context::{Channel, QueryContext};
use session::Session;

use super::*;

#[test]
fn test_check() {
let session = Arc::new(Session::new(None, Channel::Mysql));
let query = "select 1";
let result = check(query, QueryContext::arc());
let result = check(query, QueryContext::arc(), session.clone());
assert!(result.is_none());

let query = "select versiona";
let output = check(query, QueryContext::arc());
let output = check(query, QueryContext::arc(), session.clone());
assert!(output.is_none());

fn test(query: &str, expected: &str) {
let output = check(query, QueryContext::arc());
let session = Arc::new(Session::new(None, Channel::Mysql));
let output = check(query, QueryContext::arc(), session.clone());
match output.unwrap() {
Output::RecordBatches(r) => {
assert_eq!(&r.pretty_print().unwrap(), expected)
Expand All @@ -352,7 +360,7 @@ mod test {

let query = "select version()";
let version = env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "unknown".to_string());
let output = check(query, QueryContext::arc());
let output = check(query, QueryContext::arc(), session.clone());
match output.unwrap() {
Output::RecordBatches(r) => {
assert!(&r
Expand Down Expand Up @@ -430,17 +438,22 @@ mod test {

#[test]
fn test_set_time_zone() {
let query_context = QueryContext::arc();
let output = check("set time_zone = 'UTC'", query_context.clone());
let session = Arc::new(Session::new(None, Channel::Mysql));
let output = check(
"set time_zone = 'UTC'",
QueryContext::arc(),
session.clone(),
);
match output.unwrap() {
Output::AffectedRows(rows) => {
assert_eq!(rows, 0)
}
_ => unreachable!(),
}
let query_context = session.new_query_context();
assert_eq!("UTC", query_context.time_zone().unwrap().to_string());

let output = check("select @@time_zone", query_context);
let output = check("select @@time_zone", query_context.clone(), session.clone());
match output.unwrap() {
Output::RecordBatches(r) => {
let expected = "\
Expand Down
8 changes: 6 additions & 2 deletions src/servers/src/mysql/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ impl MysqlInstanceShim {
}

async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
if let Some(output) = crate::mysql::federated::check(query, query_ctx.clone()) {
if let Some(output) =
crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
{
vec![Ok(output)]
} else {
let trace_id = query_ctx.trace_id();
Expand All @@ -110,7 +112,9 @@ impl MysqlInstanceShim {
plan: LogicalPlan,
query_ctx: QueryContextRef,
) -> Result<Output> {
if let Some(output) = crate::mysql::federated::check(query, query_ctx.clone()) {
if let Some(output) =
crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
{
Ok(output)
} else {
self.query_handler.do_exec_plan(plan, query_ctx).await
Expand Down
13 changes: 3 additions & 10 deletions src/session/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub struct QueryContext {
current_catalog: String,
current_schema: String,
current_user: ArcSwap<Option<UserInfoRef>>,
time_zone: ArcSwap<Option<TimeZone>>,
time_zone: Option<TimeZone>,
sql_dialect: Box<dyn Dialect + Send + Sync>,
trace_id: u64,
}
Expand Down Expand Up @@ -103,12 +103,7 @@ impl QueryContext {

#[inline]
pub fn time_zone(&self) -> Option<TimeZone> {
self.time_zone.load().as_ref().clone()
}

#[inline]
pub fn set_time_zone(&self, tz: Option<TimeZone>) {
let _ = self.time_zone.swap(Arc::new(tz));
self.time_zone.clone()
}

#[inline]
Expand Down Expand Up @@ -139,9 +134,7 @@ impl QueryContextBuilder {
current_user: self
.current_user
.unwrap_or_else(|| ArcSwap::new(Arc::new(None))),
time_zone: self
.time_zone
.unwrap_or_else(|| ArcSwap::new(Arc::new(None))),
time_zone: self.time_zone.unwrap_or(None),
sql_dialect: self
.sql_dialect
.unwrap_or_else(|| Box::new(GreptimeDbDialect {})),
Expand Down
14 changes: 14 additions & 0 deletions src/session/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use arc_swap::ArcSwap;
use auth::UserInfoRef;
use common_catalog::build_db_string;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_time::TimeZone;
use context::QueryContextBuilder;

use crate::context::{Channel, ConnInfo, QueryContextRef};
Expand All @@ -32,6 +33,7 @@ pub struct Session {
schema: ArcSwap<String>,
user_info: ArcSwap<UserInfoRef>,
conn_info: ConnInfo,
time_zone: ArcSwap<Option<TimeZone>>,
}

pub type SessionRef = Arc<Session>;
Expand All @@ -43,6 +45,7 @@ impl Session {
schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.into())),
user_info: ArcSwap::new(Arc::new(auth::userinfo_by_name(None))),
conn_info: ConnInfo::new(addr, channel),
time_zone: ArcSwap::new(Arc::new(None)),
}
}

Expand All @@ -55,6 +58,7 @@ impl Session {
.current_catalog(self.catalog.load().to_string())
.current_schema(self.schema.load().to_string())
.sql_dialect(self.conn_info.channel.dialect())
.time_zone((**self.time_zone.load()).clone())
.build()
}

Expand All @@ -68,6 +72,16 @@ impl Session {
&mut self.conn_info
}

#[inline]
pub fn time_zone(&self) -> Option<TimeZone> {
self.time_zone.load().as_ref().clone()
}

#[inline]
pub fn set_time_zone(&self, tz: Option<TimeZone>) {
let _ = self.time_zone.swap(Arc::new(tz));
}

#[inline]
pub fn user_info(&self) -> UserInfoRef {
self.user_info.load().clone().as_ref().clone()
Expand Down
50 changes: 47 additions & 3 deletions tests-integration/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
// limitations under the License.

use auth::user_provider_from_option;
use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
use sqlx::mysql::{MySqlDatabaseError, MySqlPoolOptions};
use chrono::{DateTime, NaiveDate, NaiveDateTime, SecondsFormat, Utc};
use sqlx::mysql::{MySqlConnection, MySqlDatabaseError, MySqlPoolOptions};
use sqlx::postgres::{PgDatabaseError, PgPoolOptions};
use sqlx::Row;
use sqlx::{Connection, Executor, Row};
use tests_integration::test_util::{
setup_mysql_server, setup_mysql_server_with_user_provider, setup_pg_server,
setup_pg_server_with_user_provider, StorageType,
Expand Down Expand Up @@ -55,6 +55,7 @@ macro_rules! sql_tests {

test_mysql_auth,
test_mysql_crud,
test_mysql_timezone,
test_postgres_auth,
test_postgres_crud,
test_postgres_parameter_inference,
Expand Down Expand Up @@ -207,6 +208,49 @@ pub async fn test_mysql_crud(store_type: StorageType) {
guard.remove_all().await;
}

pub async fn test_mysql_timezone(store_type: StorageType) {
common_telemetry::init_default_ut_logging();

let (addr, mut guard, fe_mysql_server) = setup_mysql_server(store_type, "mysql_timezone").await;
let mut conn = MySqlConnection::connect(&format!("mysql://{addr}/public"))
.await
.unwrap();

let _ = conn.execute("SET time_zone = 'UTC'").await.unwrap();
let time_zone = conn.fetch_all("SELECT @@time_zone").await.unwrap();
assert_eq!(time_zone[0].get::<String, usize>(0), "UTC");

// test data
let _ = conn
.execute("create table demo(i bigint, ts timestamp time index)")
.await
.unwrap();
let _ = conn
.execute("insert into demo values(1, 1667446797450)")
.await
.unwrap();
let rows = conn.fetch_all("select ts from demo").await.unwrap();
assert_eq!(
rows[0]
.get::<chrono::DateTime<Utc>, usize>(0)
.to_rfc3339_opts(SecondsFormat::Millis, true),
"2022-11-03T03:39:57.450Z"
);

let _ = conn.execute("SET time_zone = '+08:00'").await.unwrap();
let rows2 = conn.fetch_all("select ts from demo").await.unwrap();
// we use Utc here for format only
assert_eq!(
rows2[0]
.get::<chrono::DateTime<Utc>, usize>(0)
.to_rfc3339_opts(SecondsFormat::Millis, true),
"2022-11-03T11:39:57.450Z"
);

let _ = fe_mysql_server.shutdown().await;
guard.remove_all().await;
}

pub async fn test_postgres_auth(store_type: StorageType) {
let user_provider = user_provider_from_option(
&"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(),
Expand Down

0 comments on commit 0ad3fb6

Please sign in to comment.