From 62f9e72f7470af023ff0ac57a764d943813e309a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Wed, 17 Jul 2024 06:40:04 +0800 Subject: [PATCH] Reorganize arrow-flight test code (#6065) * Reorganize test code * asf header * reuse TestFixture * .await * Create flight_sql_client.rs * remove code * remove unused import * Fix clippy lints --- arrow-flight/Cargo.toml | 5 + arrow-flight/tests/client.rs | 108 +----------- arrow-flight/tests/common/fixture.rs | 117 +++++++++++++ arrow-flight/tests/common/mod.rs | 20 +++ arrow-flight/tests/common/server.rs | 23 +++ arrow-flight/tests/flight_sql_client.rs | 119 +++++++++++++ arrow-flight/tests/flight_sql_client_cli.rs | 176 +------------------- 7 files changed, 300 insertions(+), 268 deletions(-) create mode 100644 arrow-flight/tests/common/fixture.rs create mode 100644 arrow-flight/tests/common/mod.rs create mode 100644 arrow-flight/tests/flight_sql_client.rs diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index 111bf94d804c..539b1ea35d6c 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -86,6 +86,11 @@ required-features = ["flight-sql-experimental", "tls"] name = "flight_sql_client" required-features = ["cli", "flight-sql-experimental", "tls"] +[[test]] +name = "flight_sql_client" +path = "tests/flight_sql_client.rs" +required-features = ["flight-sql-experimental", "tls"] + [[test]] name = "flight_sql_client_cli" path = "tests/flight_sql_client_cli.rs" diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs index 478938d939a9..25dad0e77a3e 100644 --- a/arrow-flight/tests/client.rs +++ b/arrow-flight/tests/client.rs @@ -17,10 +17,9 @@ //! Integration test for "mid level" Client -mod common { - pub mod server; - pub mod trailers_layer; -} +mod common; + +use crate::common::fixture::TestFixture; use arrow_array::{RecordBatch, UInt64Array}; use arrow_flight::{ decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder, error::FlightError, Action, @@ -30,18 +29,12 @@ use arrow_flight::{ }; use arrow_schema::{DataType, Field, Schema}; use bytes::Bytes; -use common::{server::TestFlightServer, trailers_layer::TrailersLayer}; +use common::server::TestFlightServer; use futures::{Future, StreamExt, TryStreamExt}; use prost::Message; -use tokio::{net::TcpListener, task::JoinHandle}; -use tonic::{ - transport::{Channel, Uri}, - Status, -}; - -use std::{net::SocketAddr, sync::Arc, time::Duration}; +use tonic::Status; -const DEFAULT_TIMEOUT_SECONDS: u64 = 30; +use std::sync::Arc; #[tokio::test] async fn test_handshake() { @@ -1123,7 +1116,7 @@ where Fut: Future, { let test_server = TestFlightServer::new(); - let fixture = TestFixture::new(&test_server).await; + let fixture = TestFixture::new(test_server.service()).await; let client = FlightClient::new(fixture.channel().await); // run the test function @@ -1156,90 +1149,3 @@ fn expect_status(error: FlightError, expected: Status) { "Got {status:?} want {expected:?}" ); } - -/// Creates and manages a running TestServer with a background task -struct TestFixture { - /// channel to send shutdown command - shutdown: Option>, - - /// Address the server is listening on - addr: SocketAddr, - - // handle for the server task - handle: Option>>, -} - -impl TestFixture { - /// create a new test fixture from the server - pub async fn new(test_server: &TestFlightServer) -> Self { - // let OS choose a a free port - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - println!("Listening on {addr}"); - - // prepare the shutdown channel - let (tx, rx) = tokio::sync::oneshot::channel(); - - let server_timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECONDS); - - let shutdown_future = async move { - rx.await.ok(); - }; - - let serve_future = tonic::transport::Server::builder() - .timeout(server_timeout) - .layer(TrailersLayer) - .add_service(test_server.service()) - .serve_with_incoming_shutdown( - tokio_stream::wrappers::TcpListenerStream::new(listener), - shutdown_future, - ); - - // Run the server in its own background task - let handle = tokio::task::spawn(serve_future); - - Self { - shutdown: Some(tx), - addr, - handle: Some(handle), - } - } - - /// Return a [`Channel`] connected to the TestServer - pub async fn channel(&self) -> Channel { - let url = format!("http://{}", self.addr); - let uri: Uri = url.parse().expect("Valid URI"); - Channel::builder(uri) - .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECONDS)) - .connect() - .await - .expect("error connecting to server") - } - - /// Stops the test server and waits for the server to shutdown - pub async fn shutdown_and_wait(mut self) { - if let Some(shutdown) = self.shutdown.take() { - shutdown.send(()).expect("server quit early"); - } - if let Some(handle) = self.handle.take() { - println!("Waiting on server to finish"); - handle - .await - .expect("task join error (panic?)") - .expect("Server Error found at shutdown"); - } - } -} - -impl Drop for TestFixture { - fn drop(&mut self) { - if let Some(shutdown) = self.shutdown.take() { - shutdown.send(()).ok(); - } - if self.handle.is_some() { - // tests should properly clean up TestFixture - println!("TestFixture::Drop called prior to `shutdown_and_wait`"); - } - } -} diff --git a/arrow-flight/tests/common/fixture.rs b/arrow-flight/tests/common/fixture.rs new file mode 100644 index 000000000000..141879e2a358 --- /dev/null +++ b/arrow-flight/tests/common/fixture.rs @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::common::trailers_layer::TrailersLayer; +use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; +use http::Uri; +use std::net::SocketAddr; +use std::time::Duration; +use tokio::net::TcpListener; +use tokio::task::JoinHandle; +use tonic::transport::Channel; + +/// All tests must complete within this many seconds or else the test server is shutdown +const DEFAULT_TIMEOUT_SECONDS: u64 = 30; + +/// Creates and manages a running TestServer with a background task +pub struct TestFixture { + /// channel to send shutdown command + shutdown: Option>, + + /// Address the server is listening on + pub addr: SocketAddr, + + /// handle for the server task + handle: Option>>, +} + +impl TestFixture { + /// create a new test fixture from the server + pub async fn new(test_server: FlightServiceServer) -> Self { + // let OS choose a free port + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + println!("Listening on {addr}"); + + // prepare the shutdown channel + let (tx, rx) = tokio::sync::oneshot::channel(); + + let server_timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECONDS); + + let shutdown_future = async move { + rx.await.ok(); + }; + + let serve_future = tonic::transport::Server::builder() + .timeout(server_timeout) + .layer(TrailersLayer) + .add_service(test_server) + .serve_with_incoming_shutdown( + tokio_stream::wrappers::TcpListenerStream::new(listener), + shutdown_future, + ); + + // Run the server in its own background task + let handle = tokio::task::spawn(serve_future); + + Self { + shutdown: Some(tx), + addr, + handle: Some(handle), + } + } + + /// Return a [`Channel`] connected to the TestServer + #[allow(dead_code)] + pub async fn channel(&self) -> Channel { + let url = format!("http://{}", self.addr); + let uri: Uri = url.parse().expect("Valid URI"); + Channel::builder(uri) + .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECONDS)) + .connect() + .await + .expect("error connecting to server") + } + + /// Stops the test server and waits for the server to shutdown + #[allow(dead_code)] + pub async fn shutdown_and_wait(mut self) { + if let Some(shutdown) = self.shutdown.take() { + shutdown.send(()).expect("server quit early"); + } + if let Some(handle) = self.handle.take() { + println!("Waiting on server to finish"); + handle + .await + .expect("task join error (panic?)") + .expect("Server Error found at shutdown"); + } + } +} + +impl Drop for TestFixture { + fn drop(&mut self) { + if let Some(shutdown) = self.shutdown.take() { + shutdown.send(()).ok(); + } + if self.handle.is_some() { + // tests should properly clean up TestFixture + println!("TestFixture::Drop called prior to `shutdown_and_wait`"); + } + } +} diff --git a/arrow-flight/tests/common/mod.rs b/arrow-flight/tests/common/mod.rs new file mode 100644 index 000000000000..85716e56058c --- /dev/null +++ b/arrow-flight/tests/common/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod fixture; +pub mod server; +pub mod trailers_layer; diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs index a75590a13334..a004ccb0737e 100644 --- a/arrow-flight/tests/common/server.rs +++ b/arrow-flight/tests/common/server.rs @@ -38,6 +38,7 @@ pub struct TestFlightServer { impl TestFlightServer { /// Create a `TestFlightServer` + #[allow(dead_code)] pub fn new() -> Self { Self { state: Arc::new(Mutex::new(State::new())), @@ -46,18 +47,21 @@ impl TestFlightServer { /// Return an [`FlightServiceServer`] that can be used with a /// [`Server`](tonic::transport::Server) + #[allow(dead_code)] pub fn service(&self) -> FlightServiceServer { // wrap up tonic goop FlightServiceServer::new(self.clone()) } /// Specify the response returned from the next call to handshake + #[allow(dead_code)] pub fn set_handshake_response(&self, response: Result) { let mut state = self.state.lock().expect("mutex not poisoned"); state.handshake_response.replace(response); } /// Take and return last handshake request sent to the server, + #[allow(dead_code)] pub fn take_handshake_request(&self) -> Option { self.state .lock() @@ -67,12 +71,14 @@ impl TestFlightServer { } /// Specify the response returned from the next call to get_flight_info + #[allow(dead_code)] pub fn set_get_flight_info_response(&self, response: Result) { let mut state = self.state.lock().expect("mutex not poisoned"); state.get_flight_info_response.replace(response); } /// Take and return last get_flight_info request sent to the server, + #[allow(dead_code)] pub fn take_get_flight_info_request(&self) -> Option { self.state .lock() @@ -82,12 +88,14 @@ impl TestFlightServer { } /// Specify the response returned from the next call to poll_flight_info + #[allow(dead_code)] pub fn set_poll_flight_info_response(&self, response: Result) { let mut state = self.state.lock().expect("mutex not poisoned"); state.poll_flight_info_response.replace(response); } /// Take and return last poll_flight_info request sent to the server, + #[allow(dead_code)] pub fn take_poll_flight_info_request(&self) -> Option { self.state .lock() @@ -97,12 +105,14 @@ impl TestFlightServer { } /// Specify the response returned from the next call to `do_get` + #[allow(dead_code)] pub fn set_do_get_response(&self, response: Vec>) { let mut state = self.state.lock().expect("mutex not poisoned"); state.do_get_response.replace(response); } /// Take and return last do_get request send to the server, + #[allow(dead_code)] pub fn take_do_get_request(&self) -> Option { self.state .lock() @@ -112,12 +122,14 @@ impl TestFlightServer { } /// Specify the response returned from the next call to `do_put` + #[allow(dead_code)] pub fn set_do_put_response(&self, response: Vec>) { let mut state = self.state.lock().expect("mutex not poisoned"); state.do_put_response.replace(response); } /// Take and return last do_put request sent to the server, + #[allow(dead_code)] pub fn take_do_put_request(&self) -> Option> { self.state .lock() @@ -127,12 +139,14 @@ impl TestFlightServer { } /// Specify the response returned from the next call to `do_exchange` + #[allow(dead_code)] pub fn set_do_exchange_response(&self, response: Vec>) { let mut state = self.state.lock().expect("mutex not poisoned"); state.do_exchange_response.replace(response); } /// Take and return last do_exchange request send to the server, + #[allow(dead_code)] pub fn take_do_exchange_request(&self) -> Option> { self.state .lock() @@ -142,12 +156,14 @@ impl TestFlightServer { } /// Specify the response returned from the next call to `list_flights` + #[allow(dead_code)] pub fn set_list_flights_response(&self, response: Vec>) { let mut state = self.state.lock().expect("mutex not poisoned"); state.list_flights_response.replace(response); } /// Take and return last list_flights request send to the server, + #[allow(dead_code)] pub fn take_list_flights_request(&self) -> Option { self.state .lock() @@ -157,12 +173,14 @@ impl TestFlightServer { } /// Specify the response returned from the next call to `get_schema` + #[allow(dead_code)] pub fn set_get_schema_response(&self, response: Result) { let mut state = self.state.lock().expect("mutex not poisoned"); state.get_schema_response.replace(response); } /// Take and return last get_schema request send to the server, + #[allow(dead_code)] pub fn take_get_schema_request(&self) -> Option { self.state .lock() @@ -172,12 +190,14 @@ impl TestFlightServer { } /// Specify the response returned from the next call to `list_actions` + #[allow(dead_code)] pub fn set_list_actions_response(&self, response: Vec>) { let mut state = self.state.lock().expect("mutex not poisoned"); state.list_actions_response.replace(response); } /// Take and return last list_actions request send to the server, + #[allow(dead_code)] pub fn take_list_actions_request(&self) -> Option { self.state .lock() @@ -187,12 +207,14 @@ impl TestFlightServer { } /// Specify the response returned from the next call to `do_action` + #[allow(dead_code)] pub fn set_do_action_response(&self, response: Vec>) { let mut state = self.state.lock().expect("mutex not poisoned"); state.do_action_response.replace(response); } /// Take and return last do_action request send to the server, + #[allow(dead_code)] pub fn take_do_action_request(&self) -> Option { self.state .lock() @@ -202,6 +224,7 @@ impl TestFlightServer { } /// Returns the last metadata from a request received by the server + #[allow(dead_code)] pub fn take_last_request_metadata(&self) -> Option { self.state .lock() diff --git a/arrow-flight/tests/flight_sql_client.rs b/arrow-flight/tests/flight_sql_client.rs new file mode 100644 index 000000000000..94b768a13621 --- /dev/null +++ b/arrow-flight/tests/flight_sql_client.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod common; + +use crate::common::fixture::TestFixture; +use arrow_flight::flight_service_server::FlightServiceServer; +use arrow_flight::sql::client::FlightSqlServiceClient; +use arrow_flight::sql::server::FlightSqlService; +use arrow_flight::sql::{ + ActionBeginTransactionRequest, ActionBeginTransactionResult, ActionEndTransactionRequest, + EndTransaction, SqlInfo, +}; +use arrow_flight::Action; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; +use tonic::{Request, Status}; +use uuid::Uuid; + +#[tokio::test] +pub async fn test_begin_end_transaction() { + let test_server = FlightSqlServiceImpl { + transactions: Arc::new(Mutex::new(HashMap::new())), + }; + let fixture = TestFixture::new(test_server.service()).await; + let channel = fixture.channel().await; + let mut flight_sql_client = FlightSqlServiceClient::new(channel); + + // begin commit + let transaction_id = flight_sql_client.begin_transaction().await.unwrap(); + flight_sql_client + .end_transaction(transaction_id, EndTransaction::Commit) + .await + .unwrap(); + + // begin rollback + let transaction_id = flight_sql_client.begin_transaction().await.unwrap(); + flight_sql_client + .end_transaction(transaction_id, EndTransaction::Rollback) + .await + .unwrap(); + + // unknown transaction id + let transaction_id = "UnknownTransactionId".to_string().into(); + assert!(flight_sql_client + .end_transaction(transaction_id, EndTransaction::Commit) + .await + .is_err()); +} + +#[derive(Clone)] +pub struct FlightSqlServiceImpl { + transactions: Arc>>, +} + +impl FlightSqlServiceImpl { + /// Return an [`FlightServiceServer`] that can be used with a + /// [`Server`](tonic::transport::Server) + pub fn service(&self) -> FlightServiceServer { + // wrap up tonic goop + FlightServiceServer::new(self.clone()) + } +} + +#[tonic::async_trait] +impl FlightSqlService for FlightSqlServiceImpl { + type FlightService = FlightSqlServiceImpl; + + async fn do_action_begin_transaction( + &self, + _query: ActionBeginTransactionRequest, + _request: Request, + ) -> Result { + let transaction_id = Uuid::new_v4().to_string(); + self.transactions + .lock() + .await + .insert(transaction_id.clone(), ()); + Ok(ActionBeginTransactionResult { + transaction_id: transaction_id.as_bytes().to_vec().into(), + }) + } + + async fn do_action_end_transaction( + &self, + query: ActionEndTransactionRequest, + _request: Request, + ) -> Result<(), Status> { + let transaction_id = String::from_utf8(query.transaction_id.to_vec()) + .map_err(|_| Status::invalid_argument("Invalid transaction id"))?; + if self + .transactions + .lock() + .await + .remove(&transaction_id) + .is_none() + { + return Err(Status::invalid_argument("Transaction id not found")); + } + Ok(()) + } + + async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} +} diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs index 631f5cd31465..168015d07e2d 100644 --- a/arrow-flight/tests/flight_sql_client_cli.rs +++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -15,21 +15,20 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashMap; -use std::{net::SocketAddr, pin::Pin, sync::Arc, time::Duration}; +mod common; +use std::{pin::Pin, sync::Arc}; + +use crate::common::fixture::TestFixture; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; -use arrow_flight::sql::client::FlightSqlServiceClient; -use arrow_flight::sql::EndTransaction; use arrow_flight::{ decode::FlightRecordBatchStream, flight_service_server::{FlightService, FlightServiceServer}, sql::{ server::{FlightSqlService, PeekableFlightDataStream}, - ActionBeginTransactionRequest, ActionBeginTransactionResult, - ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, - ActionEndTransactionRequest, Any, CommandPreparedStatementQuery, CommandStatementQuery, - DoPutPreparedStatementResult, ProstMessageExt, SqlInfo, + ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, Any, + CommandPreparedStatementQuery, CommandStatementQuery, DoPutPreparedStatementResult, + ProstMessageExt, SqlInfo, }, utils::batches_to_flight_data, Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, @@ -41,18 +40,14 @@ use assert_cmd::Command; use bytes::Bytes; use futures::{Stream, TryStreamExt}; use prost::Message; -use tokio::sync::Mutex; -use tokio::{net::TcpListener, task::JoinHandle}; -use tonic::transport::Endpoint; use tonic::{Request, Response, Status, Streaming}; -use uuid::Uuid; const QUERY: &str = "SELECT * FROM table;"; #[tokio::test] async fn test_simple() { let test_server = FlightSqlServiceImpl::default(); - let fixture = TestFixture::new(&test_server).await; + let fixture = TestFixture::new(test_server.service()).await; let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { @@ -95,7 +90,7 @@ const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle"; const UPDATED_PREPARED_STATEMENT_HANDLE: &str = "updated_prepared_statement_handle"; async fn test_do_put_prepared_statement(test_server: FlightSqlServiceImpl) { - let fixture = TestFixture::new(&test_server).await; + let fixture = TestFixture::new(test_server.service()).await; let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { @@ -139,7 +134,6 @@ async fn test_do_put_prepared_statement(test_server: FlightSqlServiceImpl) { pub async fn test_do_put_prepared_statement_stateless() { test_do_put_prepared_statement(FlightSqlServiceImpl { stateless_prepared_statements: true, - transactions: Arc::new(Mutex::new(HashMap::new())), }) .await } @@ -148,65 +142,22 @@ pub async fn test_do_put_prepared_statement_stateless() { pub async fn test_do_put_prepared_statement_stateful() { test_do_put_prepared_statement(FlightSqlServiceImpl { stateless_prepared_statements: false, - transactions: Arc::new(Mutex::new(HashMap::new())), }) .await } -#[tokio::test] -pub async fn test_begin_end_transaction() { - let test_server = FlightSqlServiceImpl { - stateless_prepared_statements: true, - transactions: Arc::new(Mutex::new(HashMap::new())), - }; - let fixture = TestFixture::new(&test_server).await; - let addr = fixture.addr; - let channel = Endpoint::from_shared(format!("http://{}:{}", addr.ip(), addr.port())) - .unwrap() - .connect() - .await - .expect("error connecting"); - let mut flight_sql_client = FlightSqlServiceClient::new(channel); - - // begin commit - let transaction_id = flight_sql_client.begin_transaction().await.unwrap(); - flight_sql_client - .end_transaction(transaction_id, EndTransaction::Commit) - .await - .unwrap(); - - // begin rollback - let transaction_id = flight_sql_client.begin_transaction().await.unwrap(); - flight_sql_client - .end_transaction(transaction_id, EndTransaction::Rollback) - .await - .unwrap(); - - // unknown transaction id - let transaction_id = "UnknownTransactionId".to_string().into(); - assert!(flight_sql_client - .end_transaction(transaction_id, EndTransaction::Commit) - .await - .is_err()); -} - -/// All tests must complete within this many seconds or else the test server is shutdown -const DEFAULT_TIMEOUT_SECONDS: u64 = 30; - #[derive(Clone)] pub struct FlightSqlServiceImpl { /// Whether to emulate stateless (true) or stateful (false) behavior for /// prepared statements. stateful servers will not return an updated /// handle after executing `DoPut(CommandPreparedStatementQuery)` stateless_prepared_statements: bool, - transactions: Arc>>, } impl Default for FlightSqlServiceImpl { fn default() -> Self { Self { stateless_prepared_statements: true, - transactions: Arc::new(Mutex::new(HashMap::new())), } } } @@ -401,118 +352,9 @@ impl FlightSqlService for FlightSqlServiceImpl { .map_err(|e| Status::internal(format!("Unable to serialize schema: {e}"))) } - async fn do_action_begin_transaction( - &self, - _query: ActionBeginTransactionRequest, - _request: Request, - ) -> Result { - let transaction_id = Uuid::new_v4().to_string(); - self.transactions - .lock() - .await - .insert(transaction_id.clone(), ()); - Ok(ActionBeginTransactionResult { - transaction_id: transaction_id.as_bytes().to_vec().into(), - }) - } - - async fn do_action_end_transaction( - &self, - query: ActionEndTransactionRequest, - _request: Request, - ) -> Result<(), Status> { - let transaction_id = String::from_utf8(query.transaction_id.to_vec()) - .map_err(|_| Status::invalid_argument("Invalid transaction id"))?; - if self - .transactions - .lock() - .await - .remove(&transaction_id) - .is_none() - { - return Err(Status::invalid_argument("Transaction id not found")); - } - Ok(()) - } - async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} } -/// Creates and manages a running TestServer with a background task -struct TestFixture { - /// channel to send shutdown command - shutdown: Option>, - - /// Address the server is listening on - addr: SocketAddr, - - // handle for the server task - handle: Option>>, -} - -impl TestFixture { - /// create a new test fixture from the server - pub async fn new(test_server: &FlightSqlServiceImpl) -> Self { - // let OS choose a a free port - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - println!("Listening on {addr}"); - - // prepare the shutdown channel - let (tx, rx) = tokio::sync::oneshot::channel(); - - let server_timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECONDS); - - let shutdown_future = async move { - rx.await.ok(); - }; - - let serve_future = tonic::transport::Server::builder() - .timeout(server_timeout) - .add_service(test_server.service()) - .serve_with_incoming_shutdown( - tokio_stream::wrappers::TcpListenerStream::new(listener), - shutdown_future, - ); - - // Run the server in its own background task - let handle = tokio::task::spawn(serve_future); - - Self { - shutdown: Some(tx), - addr, - handle: Some(handle), - } - } - - /// Stops the test server and waits for the server to shutdown - pub async fn shutdown_and_wait(mut self) { - if let Some(shutdown) = self.shutdown.take() { - shutdown.send(()).expect("server quit early"); - } - if let Some(handle) = self.handle.take() { - println!("Waiting on server to finish"); - handle - .await - .expect("task join error (panic?)") - .expect("Server Error found at shutdown"); - } - } -} - -impl Drop for TestFixture { - fn drop(&mut self) { - if let Some(shutdown) = self.shutdown.take() { - shutdown.send(()).ok(); - } - if self.handle.is_some() { - // tests should properly clean up TestFixture - println!("TestFixture::Drop called prior to `shutdown_and_wait`"); - } - } -} - #[derive(Clone, PartialEq, ::prost::Message)] pub struct FetchResults { #[prost(string, tag = "1")]