From d1d64c61d759f6ff618a3040c91cb960f93fbbd2 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Wed, 23 Oct 2024 16:47:48 -0700 Subject: [PATCH 1/2] test: add a streaming response example --- Cargo.toml | 2 +- examples/streaming.rs | 158 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 159 insertions(+), 1 deletion(-) create mode 100644 examples/streaming.rs diff --git a/Cargo.toml b/Cargo.toml index cd03bd9..86e4256 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,7 +70,7 @@ _sqlite = [] _bundled = ["duckdb/bundled", "rusqlite/bundled"] [dev-dependencies] -tokio = { version = "1.19", features = ["rt-multi-thread", "net", "macros"]} +tokio = { version = "1.19", features = ["rt-multi-thread", "net", "macros", "time"]} rusqlite = { version = "0.32.1", features = ["column_decltype"] } ## for duckdb example duckdb = { version = "1.0.0" } diff --git a/examples/streaming.rs b/examples/streaming.rs new file mode 100644 index 0000000..3937003 --- /dev/null +++ b/examples/streaming.rs @@ -0,0 +1,158 @@ +use std::fmt::Debug; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Poll; +use std::time::Duration; + +use async_trait::async_trait; +use futures::{Sink, SinkExt, Stream}; +use pgwire::messages::data::DataRow; +use tokio::net::TcpListener; + +use pgwire::api::auth::noop::NoopStartupHandler; +use pgwire::api::copy::NoopCopyHandler; +use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; +use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response}; +use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type}; +use pgwire::error::ErrorInfo; +use pgwire::error::{PgWireError, PgWireResult}; +use pgwire::messages::response::NoticeResponse; +use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; +use pgwire::tokio::process_socket; +use tokio::time::{interval, Interval}; + +pub struct DummyProcessor; + +#[async_trait] +impl NoopStartupHandler for DummyProcessor { + async fn post_startup( + &self, + client: &mut C, + _message: PgWireFrontendMessage, + ) -> PgWireResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + C::Error: Debug, + PgWireError: From<>::Error>, + { + println!("Connected: {}", client.socket_addr()); + client + .send(PgWireBackendMessage::NoticeResponse(NoticeResponse::from( + ErrorInfo::new( + "NOTICE".to_owned(), + "01000".to_owned(), + "This is an example demos streaming response from backend. Try `SELECT 1;` to see it in action." + .to_string(), + ), + ))) + .await?; + Ok(()) + } +} + +struct ResultStream { + schema: Arc>, + counter: usize, + interval: Interval, +} + +impl Stream for ResultStream { + type Item = PgWireResult; + + fn size_hint(&self) -> (usize, Option) { + (self.counter, Some(self.counter)) + } + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + if self.counter >= 10 { + Poll::Ready(None) + } else { + match Pin::new(&mut self.interval).poll_tick(cx) { + Poll::Ready(_) => { + self.counter += 1; + let row = { + let mut encoder = DataRowEncoder::new(self.schema.clone()); + encoder.encode_field(&Some(1))?; + + encoder.finish() + }; + Poll::Ready(Some(row)) + } + Poll::Pending => Poll::Pending, + } + } + } +} + +#[async_trait] +impl SimpleQueryHandler for DummyProcessor { + async fn do_query<'a, C>( + &self, + _client: &mut C, + _query: &'a str, + ) -> PgWireResult>> + where + C: ClientInfo + Sink + Unpin + Send + Sync, + C::Error: Debug, + PgWireError: From<>::Error>, + { + let f1 = FieldInfo::new("SELECT 1".into(), None, None, Type::INT4, FieldFormat::Text); + let schema = Arc::new(vec![f1]); + + // generate 10 results + let data_row_stream = ResultStream { + schema: schema.clone(), + counter: 0, + interval: interval(Duration::from_secs(1)), + }; + + let resp = Response::Query(QueryResponse::new(schema, data_row_stream)); + Ok(vec![resp]) + } +} + +struct DummyProcessorFactory { + handler: Arc, +} + +impl PgWireHandlerFactory for DummyProcessorFactory { + type StartupHandler = DummyProcessor; + type SimpleQueryHandler = DummyProcessor; + type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; + type CopyHandler = NoopCopyHandler; + + fn simple_query_handler(&self) -> Arc { + self.handler.clone() + } + + fn extended_query_handler(&self) -> Arc { + Arc::new(PlaceholderExtendedQueryHandler) + } + + fn startup_handler(&self) -> Arc { + self.handler.clone() + } + + fn copy_handler(&self) -> Arc { + Arc::new(NoopCopyHandler) + } +} + +#[tokio::main] +pub async fn main() { + let factory = Arc::new(DummyProcessorFactory { + handler: Arc::new(DummyProcessor), + }); + + let server_addr = "127.0.0.1:5432"; + let listener = TcpListener::bind(server_addr).await.unwrap(); + println!("Listening to {}", server_addr); + loop { + let incoming_socket = listener.accept().await.unwrap(); + let factory_ref = factory.clone(); + tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await }); + } +} From 2a88ad0dd32b699ef71225bd19a8df8b0a6d1e7f Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Wed, 23 Oct 2024 16:51:32 -0700 Subject: [PATCH 2/2] chore: flush data in stream --- src/api/query.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api/query.rs b/src/api/query.rs index ef52083..a2ec636 100644 --- a/src/api/query.rs +++ b/src/api/query.rs @@ -420,7 +420,7 @@ where while let Some(row) = data_rows.next().await { let row = row?; rows += 1; - client.feed(PgWireBackendMessage::DataRow(row)).await?; + client.send(PgWireBackendMessage::DataRow(row)).await?; } let tag = Tag::new(&command_tag).with_rows(rows);