Skip to content

Commit

Permalink
test: fix unit test for check
Browse files Browse the repository at this point in the history
  • Loading branch information
sunng87 committed Oct 7, 2023
1 parent 8eb57f2 commit b7e12dc
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions src/servers/src/mysql/federated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,22 +331,25 @@ fn get_version() -> String {
#[cfg(test)]
mod test {

use session::context::QueryContext;
use session::context::{Channel, QueryContext};
use session::Session;

use super::*;

#[test]
fn test_check() {
let session = Arc::new(Session::new(None, Channel::Mysql));
let query = "select 1";
let result = check(query, QueryContext::arc());
let result = check(query, QueryContext::arc(), session.clone());
assert!(result.is_none());

let query = "select versiona";
let output = check(query, QueryContext::arc());
let output = check(query, QueryContext::arc(), session.clone());
assert!(output.is_none());

fn test(query: &str, expected: &str) {
let output = check(query, QueryContext::arc());
let session = Arc::new(Session::new(None, Channel::Mysql));
let output = check(query, QueryContext::arc(), session.clone());
match output.unwrap() {
Output::RecordBatches(r) => {
assert_eq!(&r.pretty_print().unwrap(), expected)
Expand All @@ -357,7 +360,7 @@ mod test {

let query = "select version()";
let version = env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "unknown".to_string());
let output = check(query, QueryContext::arc());
let output = check(query, QueryContext::arc(), session.clone());
match output.unwrap() {
Output::RecordBatches(r) => {
assert!(&r
Expand Down Expand Up @@ -435,17 +438,22 @@ mod test {

#[test]
fn test_set_time_zone() {
let query_context = QueryContext::arc();
let output = check("set time_zone = 'UTC'", query_context.clone());
let session = Arc::new(Session::new(None, Channel::Mysql));
let output = check(
"set time_zone = 'UTC'",
QueryContext::arc(),
session.clone(),
);
match output.unwrap() {
Output::AffectedRows(rows) => {
assert_eq!(rows, 0)
}
_ => unreachable!(),
}
let query_context = session.new_query_context();
assert_eq!("UTC", query_context.time_zone().unwrap().to_string());

let output = check("select @@time_zone", query_context);
let output = check("select @@time_zone", query_context.clone(), session.clone());
match output.unwrap() {
Output::RecordBatches(r) => {
let expected = "\
Expand Down

0 comments on commit b7e12dc

Please sign in to comment.