From b7e12dccc6a1e8c7a532f16841e905197d3f1b75 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sat, 7 Oct 2023 16:05:25 +0800 Subject: [PATCH] test: fix unit test for check --- src/servers/src/mysql/federated.rs | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/servers/src/mysql/federated.rs b/src/servers/src/mysql/federated.rs index f0470fcee018..e3d4a8ce803e 100644 --- a/src/servers/src/mysql/federated.rs +++ b/src/servers/src/mysql/federated.rs @@ -331,22 +331,25 @@ fn get_version() -> String { #[cfg(test)] mod test { - use session::context::QueryContext; + use session::context::{Channel, QueryContext}; + use session::Session; use super::*; #[test] fn test_check() { + let session = Arc::new(Session::new(None, Channel::Mysql)); let query = "select 1"; - let result = check(query, QueryContext::arc()); + let result = check(query, QueryContext::arc(), session.clone()); assert!(result.is_none()); let query = "select versiona"; - let output = check(query, QueryContext::arc()); + let output = check(query, QueryContext::arc(), session.clone()); assert!(output.is_none()); fn test(query: &str, expected: &str) { - let output = check(query, QueryContext::arc()); + let session = Arc::new(Session::new(None, Channel::Mysql)); + let output = check(query, QueryContext::arc(), session.clone()); match output.unwrap() { Output::RecordBatches(r) => { assert_eq!(&r.pretty_print().unwrap(), expected) @@ -357,7 +360,7 @@ mod test { let query = "select version()"; let version = env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "unknown".to_string()); - let output = check(query, QueryContext::arc()); + let output = check(query, QueryContext::arc(), session.clone()); match output.unwrap() { Output::RecordBatches(r) => { assert!(&r @@ -435,17 +438,22 @@ mod test { #[test] fn test_set_time_zone() { - let query_context = QueryContext::arc(); - let output = check("set time_zone = 'UTC'", query_context.clone()); + let session = Arc::new(Session::new(None, Channel::Mysql)); + let output = check( + "set time_zone = 'UTC'", + QueryContext::arc(), + session.clone(), + ); match output.unwrap() { Output::AffectedRows(rows) => { assert_eq!(rows, 0) } _ => unreachable!(), } + let query_context = session.new_query_context(); assert_eq!("UTC", query_context.time_zone().unwrap().to_string()); - let output = check("select @@time_zone", query_context); + let output = check("select @@time_zone", query_context.clone(), session.clone()); match output.unwrap() { Output::RecordBatches(r) => { let expected = "\