diff --git a/e2e_test/subscription/main.py b/e2e_test/subscription/main.py
index caa1d3a141c09..fa89c9697d40c 100644
--- a/e2e_test/subscription/main.py
+++ b/e2e_test/subscription/main.py
@@ -249,6 +249,8 @@ def test_cursor_with_table_alter():
     row = execute_query("fetch next from cur",conn)
     check_rows_data([1,2],row[0],1)
     row = execute_query("fetch next from cur",conn)
+    assert(row == [])
+    row = execute_query("fetch next from cur",conn)
     check_rows_data([4,4,4],row[0],1)
     execute_insert("insert into t1 values(5,5,5)",conn)
     execute_insert("flush",conn)
@@ -258,6 +260,8 @@ def test_cursor_with_table_alter():
     execute_insert("insert into t1 values(6,6)",conn)
     execute_insert("flush",conn)
     row = execute_query("fetch next from cur",conn)
+    assert(row == [])
+    row = execute_query("fetch next from cur",conn)
     check_rows_data([6,6],row[0],1)
 
     drop_table_subscription()
@@ -324,6 +328,7 @@ def test_rebuild_table():
     check_rows_data([1,1],row[0],1)
     check_rows_data([1,1],row[1],4)
     check_rows_data([1,100],row[2],3)
+    drop_table_subscription()
 
 if __name__ == "__main__":
     test_cursor_snapshot()
diff --git a/src/frontend/src/binder/fetch_cursor.rs b/src/frontend/src/binder/fetch_cursor.rs
new file mode 100644
index 0000000000000..50b48f631fcac
--- /dev/null
+++ b/src/frontend/src/binder/fetch_cursor.rs
@@ -0,0 +1,42 @@
+// Copyright 2024 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
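+
+// `BoundFetchCursor` captures a bound `FETCH <count> FROM <cursor>` statement: the
+// cursor name, the fetch count, and the schema the cursor was declared with (if
+// known). Binding performs no planning here; the schema is only used to describe
+// the statement's output fields under the extended protocol.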
+
+use risingwave_common::catalog::Schema;
+
+use crate::error::Result;
+use crate::Binder;
+
+#[derive(Debug, Clone)]
+pub struct BoundFetchCursor {
+    pub cursor_name: String,
+
+    pub count: u32,
+
+    pub returning_schema: Option<Schema>,
+}
+
+impl Binder {
+    pub fn bind_fetch_cursor(
+        &mut self,
+        cursor_name: String,
+        count: u32,
+        returning_schema: Option<Schema>,
+    ) -> Result<BoundFetchCursor> {
+        Ok(BoundFetchCursor {
+            cursor_name,
+            count,
+            returning_schema,
+        })
+    }
+}
diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs
index af1be41a711ed..82d0a8d7edd25 100644
--- a/src/frontend/src/binder/mod.rs
+++ b/src/frontend/src/binder/mod.rs
@@ -32,6 +32,7 @@ mod create;
 mod create_view;
 mod delete;
 mod expr;
+pub mod fetch_cursor;
 mod for_system;
 mod insert;
 mod query;
diff --git a/src/frontend/src/binder/statement.rs b/src/frontend/src/binder/statement.rs
index 7fca5ff483dfe..b73fab90aed9a 100644
--- a/src/frontend/src/binder/statement.rs
+++ b/src/frontend/src/binder/statement.rs
@@ -17,6 +17,7 @@ use risingwave_common::catalog::Field;
 use risingwave_sqlparser::ast::Statement;
 
 use super::delete::BoundDelete;
+use super::fetch_cursor::BoundFetchCursor;
 use super::update::BoundUpdate;
 use crate::binder::create_view::BoundCreateView;
 use crate::binder::{Binder, BoundInsert, BoundQuery};
@@ -29,6 +30,7 @@ pub enum BoundStatement {
     Delete(Box<BoundDelete>),
     Update(Box<BoundUpdate>),
     Query(Box<BoundQuery>),
+    FetchCursor(Box<BoundFetchCursor>),
     CreateView(Box<BoundCreateView>),
 }
 
@@ -48,6 +50,10 @@ impl BoundStatement {
                 .as_ref()
                 .map_or(vec![], |s| s.fields().into()),
             BoundStatement::Query(q) => q.schema().fields().into(),
+            BoundStatement::FetchCursor(f) => f
+                .returning_schema
+                .as_ref()
+                .map_or(vec![], |s| s.fields().into()),
             BoundStatement::CreateView(_) => vec![],
         }
     }
@@ -127,6 +133,7 @@ impl RewriteExprsRecursive for BoundStatement {
             BoundStatement::Delete(inner) => inner.rewrite_exprs_recursive(rewriter),
             BoundStatement::Update(inner) => inner.rewrite_exprs_recursive(rewriter),
             BoundStatement::Query(inner) => inner.rewrite_exprs_recursive(rewriter),
+            BoundStatement::FetchCursor(_) => {}
             BoundStatement::CreateView(inner) => inner.rewrite_exprs_recursive(rewriter),
         }
     }
diff --git a/src/frontend/src/handler/declare_cursor.rs b/src/frontend/src/handler/declare_cursor.rs
index 25e146fa714ce..a4974530cfe50 100644
--- a/src/frontend/src/handler/declare_cursor.rs
+++ b/src/frontend/src/handler/declare_cursor.rs
@@ -12,18 +12,25 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use pgwire::pg_field_descriptor::PgFieldDescriptor;
+use std::sync::Arc;
+
 use pgwire::pg_response::{PgResponse, StatementType};
+use risingwave_common::catalog::Field;
+use risingwave_common::session_config::QueryMode;
 use risingwave_common::util::epoch::Epoch;
 use risingwave_sqlparser::ast::{DeclareCursorStatement, ObjectName, Query, Since, Statement};
 
-use super::query::{gen_batch_plan_by_statement, gen_batch_plan_fragmenter};
+use super::query::{
+    gen_batch_plan_by_statement, gen_batch_plan_fragmenter, BatchPlanFragmenterResult,
+};
 use super::util::convert_unix_millis_to_logstore_u64;
 use super::RwPgResponse;
 use crate::error::{ErrorCode, Result};
-use crate::handler::query::create_stream;
+use crate::handler::query::{distribute_execute, local_execute};
 use crate::handler::HandlerArgs;
-use crate::{Binder, OptimizerContext, PgResponseStream};
+use crate::session::cursor_manager::CursorDataChunkStream;
+use crate::session::SessionImpl;
+use crate::{Binder, OptimizerContext};
 
 pub async fn handle_declare_cursor(
     handle_args: HandlerArgs,
@@ -111,12 +118,12 @@ async fn handle_declare_query_cursor(
     cursor_name: ObjectName,
     query: Box<Query>,
 ) -> Result<RwPgResponse> {
-    let (row_stream, pg_descs) =
+    let (chunk_stream, fields) =
         create_stream_for_cursor_stmt(handle_args.clone(), Statement::Query(query)).await?;
     handle_args
         .session
         .get_cursor_manager()
-        .add_query_cursor(cursor_name, row_stream, pg_descs)
+        .add_query_cursor(cursor_name, chunk_stream, fields)
         .await?;
     Ok(PgResponse::empty_result(StatementType::DECLARE_CURSOR))
 }
@@ -124,12 +131,42 @@ async fn handle_declare_query_cursor(
 pub async fn create_stream_for_cursor_stmt(
     handle_args: HandlerArgs,
     stmt: Statement,
-) -> Result<(PgResponseStream, Vec<PgFieldDescriptor>)> {
+) -> Result<(CursorDataChunkStream, Vec<Field>)> {
     let session = handle_args.session.clone();
     let plan_fragmenter_result = {
         let context = OptimizerContext::from_handler_args(handle_args);
         let plan_result = gen_batch_plan_by_statement(&session, context.into(), stmt)?;
         gen_batch_plan_fragmenter(&session, plan_result)?
     };
-    create_stream(session, plan_fragmenter_result, vec![]).await
+    create_chunk_stream_for_cursor(session, plan_fragmenter_result).await
+}
+
+pub async fn create_chunk_stream_for_cursor(
+    session: Arc<SessionImpl>,
+    plan_fragmenter_result: BatchPlanFragmenterResult,
+) -> Result<(CursorDataChunkStream, Vec<Field>)> {
+    let BatchPlanFragmenterResult {
+        plan_fragmenter,
+        query_mode,
+        schema,
+        ..
+    } = plan_fragmenter_result;
+
+    let can_timeout_cancel = true;
+
+    let query = plan_fragmenter.generate_complete_query().await?;
+    tracing::trace!("Generated query after plan fragmenter: {:?}", &query);
+
+    Ok((
+        match query_mode {
+            QueryMode::Auto => unreachable!(),
+            QueryMode::Local => CursorDataChunkStream::LocalDataChunk(Some(
+                local_execute(session.clone(), query, can_timeout_cancel).await?,
+            )),
+            QueryMode::Distributed => CursorDataChunkStream::DistributedDataChunk(Some(
+                distribute_execute(session.clone(), query, can_timeout_cancel).await?,
+            )),
+        },
+        schema.fields.clone(),
+    ))
+}
diff --git a/src/frontend/src/handler/extended_handle.rs b/src/frontend/src/handler/extended_handle.rs
index ac0e799d9c1ab..f12eaa617352b 100644
--- a/src/frontend/src/handler/extended_handle.rs
+++ b/src/frontend/src/handler/extended_handle.rs
@@ -23,7 +23,7 @@ use risingwave_common::types::DataType;
 use risingwave_sqlparser::ast::{CreateSink, Query, Statement};
 
 use super::query::BoundResult;
-use super::{handle, query, HandlerArgs, RwPgResponse};
+use super::{fetch_cursor, handle, query, HandlerArgs, RwPgResponse};
 use crate::error::Result;
 use crate::session::SessionImpl;
 
@@ -94,7 +94,7 @@ impl std::fmt::Display for PortalResult {
     }
 }
 
-pub fn handle_parse(
+pub async fn handle_parse(
     session: Arc<SessionImpl>,
     statement: Statement,
     specific_param_types: Vec<Option<DataType>>,
@@ -109,6 +109,9 @@ pub fn handle_parse(
         | Statement::Update { .. } => {
             query::handle_parse(handler_args, statement, specific_param_types)
         }
+        Statement::FetchCursor { .. } => {
+            fetch_cursor::handle_parse(handler_args, statement, specific_param_types).await
+        }
         Statement::CreateView {
             query,
             materialized,
@@ -198,8 +201,11 @@ pub async fn handle_execute(session: Arc<SessionImpl>, portal: Portal) -> Result<RwPgResponse> {
             let _guard = session.txn_begin_implicit(); // TODO(bugen): is this behavior correct?
             let sql: Arc<str> = Arc::from(portal.statement.to_string());
             let handler_args = HandlerArgs::new(session, &portal.statement, sql)?;
-
-            query::handle_execute(handler_args, portal).await
+            if let Statement::FetchCursor { .. } = &portal.statement {
+                fetch_cursor::handle_fetch_cursor_execute(handler_args, portal).await
+            } else {
+                query::handle_execute(handler_args, portal).await
+            }
         }
         Portal::PureStatement(stmt) => {
            let sql: Arc<str> = Arc::from(stmt.to_string());
diff --git a/src/frontend/src/handler/fetch_cursor.rs b/src/frontend/src/handler/fetch_cursor.rs
index 05305a9657b1a..d339e3e7a1acb 100644
--- a/src/frontend/src/handler/fetch_cursor.rs
+++ b/src/frontend/src/handler/fetch_cursor.rs
@@ -14,19 +14,49 @@
 
 use pgwire::pg_field_descriptor::PgFieldDescriptor;
 use pgwire::pg_response::{PgResponse, StatementType};
-use pgwire::types::Row;
-use risingwave_sqlparser::ast::FetchCursorStatement;
+use pgwire::types::{Format, Row};
+use risingwave_common::bail_not_implemented;
+use risingwave_common::catalog::Schema;
+use risingwave_common::types::DataType;
+use risingwave_sqlparser::ast::{FetchCursorStatement, Statement};
 
+use super::extended_handle::{PortalResult, PrepareStatement, PreparedResult};
+use super::query::BoundResult;
 use super::RwPgResponse;
+use crate::binder::BoundStatement;
 use crate::error::Result;
 use crate::handler::HandlerArgs;
 use crate::{Binder, PgResponseStream};
 
+pub async fn handle_fetch_cursor_execute(
+    handler_args: HandlerArgs,
+    portal_result: PortalResult,
+) -> Result<RwPgResponse> {
+    if let PortalResult {
+        statement: Statement::FetchCursor { stmt },
+        bound_result:
+            BoundResult {
+                bound: BoundStatement::FetchCursor(fetch_cursor),
+                ..
+            },
+        result_formats,
+        ..
+    } = portal_result
+    {
+        match fetch_cursor.returning_schema {
+            Some(_) => handle_fetch_cursor(handler_args, stmt, &result_formats).await,
+            None => Ok(build_fetch_cursor_response(vec![], vec![])),
+        }
+    } else {
+        bail_not_implemented!("unsupported portal {}", portal_result)
+    }
+}
+
 pub async fn handle_fetch_cursor(
-    handle_args: HandlerArgs,
+    handler_args: HandlerArgs,
     stmt: FetchCursorStatement,
+    formats: &Vec<Format>,
 ) -> Result<RwPgResponse> {
-    let session = handle_args.session.clone();
+    let session = handler_args.session.clone();
     let db_name = session.database();
     let (_, cursor_name) =
         Binder::resolve_schema_qualified_name(db_name, stmt.cursor_name.clone())?;
@@ -34,7 +64,7 @@ pub async fn handle_fetch_cursor(
     let cursor_manager = session.get_cursor_manager();
 
     let (rows, pg_descs) = cursor_manager
-        .get_rows_with_cursor(cursor_name, stmt.count, handle_args)
+        .get_rows_with_cursor(cursor_name, stmt.count, handler_args, formats)
         .await?;
     Ok(build_fetch_cursor_response(rows, pg_descs))
 }
@@ -45,3 +75,40 @@ fn build_fetch_cursor_response(rows: Vec<Row>, pg_descs: Vec<PgFieldDescriptor>)
         .values(PgResponseStream::from(rows), pg_descs)
         .into()
 }
+
+pub async fn handle_parse(
+    handler_args: HandlerArgs,
+    statement: Statement,
+    specific_param_types: Vec<Option<DataType>>,
+) -> Result<PrepareStatement> {
+    if let Statement::FetchCursor { stmt } = &statement {
+        let session = handler_args.session.clone();
+        let db_name = session.database();
+        let (_, cursor_name) =
+            Binder::resolve_schema_qualified_name(db_name, stmt.cursor_name.clone())?;
+        let fields = session
+            .get_cursor_manager()
+            .get_fields_with_cursor(cursor_name.clone())
+            .await?;
+
+        let mut binder = Binder::new_with_param_types(&session, specific_param_types);
+        let schema = Some(Schema::new(fields));
+
+        let bound = binder.bind_fetch_cursor(cursor_name, stmt.count, schema)?;
+        let bound_result = BoundResult {
+            stmt_type: StatementType::FETCH_CURSOR,
+            must_dist: false,
+            bound: BoundStatement::FetchCursor(Box::new(bound)),
+            param_types: binder.export_param_types()?,
+            parsed_params: None,
+            dependent_relations: binder.included_relations(),
+        };
+        let result = PreparedResult {
+            statement,
+            bound_result,
+        };
+        Ok(PrepareStatement::Prepared(result))
+    } else {
+        bail_not_implemented!("unsupported statement {:?}", statement)
+    }
+}
diff --git a/src/frontend/src/handler/mod.rs b/src/frontend/src/handler/mod.rs
index f8beeedb19438..dbc7da91b4800 100644
--- a/src/frontend/src/handler/mod.rs
+++ b/src/frontend/src/handler/mod.rs
@@ -401,7 +401,7 @@ pub async fn handle(
             declare_cursor::handle_declare_cursor(handler_args, stmt).await
         }
         Statement::FetchCursor { stmt } => {
-            fetch_cursor::handle_fetch_cursor(handler_args, stmt).await
+            fetch_cursor::handle_fetch_cursor(handler_args, stmt, &formats).await
         }
         Statement::CloseCursor { stmt } => {
             close_cursor::handle_close_cursor(handler_args, stmt).await
diff --git a/src/frontend/src/handler/privilege.rs b/src/frontend/src/handler/privilege.rs
index e9f60a1f78b79..ff47dac4af860 100644
--- a/src/frontend/src/handler/privilege.rs
+++ b/src/frontend/src/handler/privilege.rs
@@ -115,6 +115,7 @@ pub(crate) fn resolve_privileges(stmt: &BoundStatement) -> Vec<ObjectCheckItem> {
             objects.push(object);
         }
         BoundStatement::Query(ref query) => objects.extend(resolve_query_privileges(query)),
+        BoundStatement::FetchCursor(_) => unimplemented!(),
         BoundStatement::CreateView(ref create_view) => {
             objects.extend(resolve_query_privileges(&create_view.query))
         }
diff --git a/src/frontend/src/handler/query.rs b/src/frontend/src/handler/query.rs
index bdb32b590300b..de60743e47173 100644
--- a/src/frontend/src/handler/query.rs
+++ b/src/frontend/src/handler/query.rs
@@ -516,7 +516,7 @@ async fn execute(
     .into())
 }
 
-async fn distribute_execute(
+pub async fn distribute_execute(
     session: Arc<SessionImpl>,
     query: Query,
     can_timeout_cancel: bool,
@@ -538,8 +538,7 @@ async fn distribute_execute(
     .map_err(|err| err.into())
 }
 
-#[expect(clippy::unused_async)]
-async fn local_execute(
+pub async fn local_execute(
     session: Arc<SessionImpl>,
     query: Query,
     can_timeout_cancel: bool,
diff --git a/src/frontend/src/handler/util.rs b/src/frontend/src/handler/util.rs
index 73b52b977c7a4..0531ce5a65284 100644
--- a/src/frontend/src/handler/util.rs
+++ b/src/frontend/src/handler/util.rs
@@ -53,14 +53,14 @@ pin_project! {
         #[pin]
         chunk_stream: VS,
         column_types: Vec<DataType>,
-        formats: Vec<Format>,
+        pub formats: Vec<Format>,
         session_data: StaticSessionData,
     }
 }
 
 // Static session data frozen at the time of the creation of the stream
-struct StaticSessionData {
-    timezone: String,
+pub struct StaticSessionData {
+    pub timezone: String,
 }
 
 impl<VS> DataChunkToRowSetAdapter<VS>
 where
@@ -110,7 +110,7 @@ where
 }
 
 /// Format scalars according to postgres convention.
-fn pg_value_format(
+pub fn pg_value_format(
     data_type: &DataType,
     d: ScalarRefImpl<'_>,
     format: Format,
diff --git a/src/frontend/src/planner/statement.rs b/src/frontend/src/planner/statement.rs
index 22b63de9f40be..91c1b9edfc619 100644
--- a/src/frontend/src/planner/statement.rs
+++ b/src/frontend/src/planner/statement.rs
@@ -24,6 +24,7 @@ impl Planner {
             BoundStatement::Delete(d) => self.plan_delete(*d),
             BoundStatement::Update(u) => self.plan_update(*u),
             BoundStatement::Query(q) => self.plan_query(*q),
+            BoundStatement::FetchCursor(_) => unimplemented!(),
             BoundStatement::CreateView(c) => self.plan_query(*c.query),
         }
     }
diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs
index 8266fd48fcbf1..ada6a4d5c611f 100644
--- a/src/frontend/src/session.rs
+++ b/src/frontend/src/session.rs
@@ -1307,13 +1307,13 @@ impl Session for SessionImpl {
         self.id
     }
 
-    fn parse(
+    async fn parse(
        self: Arc<Self>,
        statement: Option<Statement>,
        params_types: Vec<Option<DataType>>,
    ) -> std::result::Result<PrepareStatement, BoxedError> {
        Ok(if let Some(statement) = statement {
-            handle_parse(self, statement, params_types)?
+            handle_parse(self, statement, params_types).await?
        } else {
            PrepareStatement::Empty
        })
@@ -1446,7 +1446,8 @@ fn infer(bound: Option<BoundStatement>, stmt: Statement) -> Result<Vec<PgFieldDescriptor>> {
         Statement::Query(_)
         | Statement::Insert { .. }
         | Statement::Delete { .. }
-        | Statement::Update { .. } => Ok(bound
+        | Statement::Update { .. }
+        | Statement::FetchCursor { .. } => Ok(bound
             .unwrap()
             .output_fields()
             .iter()
diff --git a/src/frontend/src/session/cursor_manager.rs b/src/frontend/src/session/cursor_manager.rs
index bcd1aa11ec749..390428f09bea3 100644
--- a/src/frontend/src/session/cursor_manager.rs
+++ b/src/frontend/src/session/cursor_manager.rs
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
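+
+// Overview of the flow implemented below: `DECLARE` produces a
+// `CursorDataChunkStream` holding the raw local/distributed data-chunk stream, and
+// the output schema is cached as `Vec<Field>` on the cursor. Conversion into a
+// `PgResponseStream` of rows is deferred to `FETCH` time, once the output
+// `formats` requested by the pg client are known.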
+use core::mem;
 use core::time::Duration;
 use std::collections::{HashMap, VecDeque};
 use std::rc::Rc;
@@ -23,7 +24,9 @@ use fixedbitset::FixedBitSet;
 use futures::StreamExt;
 use pgwire::pg_field_descriptor::PgFieldDescriptor;
 use pgwire::pg_response::StatementType;
-use pgwire::types::Row;
+use pgwire::types::{Format, Row};
+use risingwave_common::catalog::Field;
+use risingwave_common::error::BoxedError;
 use risingwave_common::session_config::QueryMode;
 use risingwave_common::types::DataType;
 use risingwave_sqlparser::ast::{Ident, ObjectName, Statement};
@@ -32,15 +35,70 @@
 use super::SessionImpl;
 use crate::catalog::subscription_catalog::SubscriptionCatalog;
 use crate::catalog::TableId;
 use crate::error::{ErrorCode, Result};
-use crate::handler::declare_cursor::create_stream_for_cursor_stmt;
-use crate::handler::query::{create_stream, gen_batch_plan_fragmenter, BatchQueryPlanResult};
-use crate::handler::util::{convert_logstore_u64_to_unix_millis, gen_query_from_table_name};
+use crate::handler::declare_cursor::{
+    create_chunk_stream_for_cursor, create_stream_for_cursor_stmt,
+};
+use crate::handler::query::{gen_batch_plan_fragmenter, BatchQueryPlanResult};
+use crate::handler::util::{
+    convert_logstore_u64_to_unix_millis, gen_query_from_table_name, pg_value_format, to_pg_field,
+    DataChunkToRowSetAdapter, StaticSessionData,
+};
 use crate::handler::HandlerArgs;
 use crate::optimizer::plan_node::{generic, BatchLogSeqScan};
 use crate::optimizer::property::{Order, RequiredDist};
 use crate::optimizer::PlanRoot;
+use crate::scheduler::{DistributedQueryStream, LocalQueryStream};
 use crate::{OptimizerContext, OptimizerContextRef, PgResponseStream, PlanRef, TableCatalog};
 
+pub enum CursorDataChunkStream {
+    LocalDataChunk(Option<LocalQueryStream>),
+    DistributedDataChunk(Option<DistributedQueryStream>),
+    PgResponse(PgResponseStream),
+}
+
+impl CursorDataChunkStream {
+    pub fn init_row_stream(
+        &mut self,
+        fields: &Vec<Field>,
+        formats: &Vec<Format>,
+        session: Arc<SessionImpl>,
+    ) {
+        let columns_type = fields.iter().map(|f| f.data_type().clone()).collect();
+        match self {
+            CursorDataChunkStream::LocalDataChunk(data_chunk) => {
+                let data_chunk = mem::take(data_chunk).unwrap();
+                let row_stream = PgResponseStream::LocalQuery(DataChunkToRowSetAdapter::new(
+                    data_chunk,
+                    columns_type,
+                    formats.clone(),
+                    session,
+                ));
+                *self = CursorDataChunkStream::PgResponse(row_stream);
+            }
+            CursorDataChunkStream::DistributedDataChunk(data_chunk) => {
+                let data_chunk = mem::take(data_chunk).unwrap();
+                let row_stream = PgResponseStream::DistributedQuery(DataChunkToRowSetAdapter::new(
+                    data_chunk,
+                    columns_type,
+                    formats.clone(),
+                    session,
+                ));
+                *self = CursorDataChunkStream::PgResponse(row_stream);
+            }
+            _ => {}
+        }
+    }
+
+    pub async fn next(&mut self) -> Result<Option<std::result::Result<Vec<Row>, BoxedError>>> {
+        match self {
+            CursorDataChunkStream::PgResponse(row_stream) => Ok(row_stream.next().await),
+            _ => Err(ErrorCode::InternalError(
+                "Only 'CursorDataChunkStream' can call next and return rows".to_string(),
+            )
+            .into()),
+        }
+    }
+}
 pub enum Cursor {
     Subscription(SubscriptionCursor),
     Query(QueryCursor),
 }
 
 impl Cursor {
     pub async fn next(
         &mut self,
         count: u32,
         handle_args: HandlerArgs,
+        formats: &Vec<Format>,
     ) -> Result<(Vec<Row>, Vec<PgFieldDescriptor>)> {
         match self {
-            Cursor::Subscription(cursor) => cursor.next(count, handle_args).await,
-            Cursor::Query(cursor) => cursor.next(count).await,
+            Cursor::Subscription(cursor) => cursor.next(count, handle_args, formats).await,
+            Cursor::Query(cursor) => cursor.next(count, formats, handle_args).await,
+        }
+    }
+
+    pub fn get_fields(&mut self) -> Vec<Field> {
+        match self {
+            Cursor::Subscription(cursor) => cursor.fields.clone(),
+            Cursor::Query(cursor) => cursor.fields.clone(),
+        }
+    }
 }
 
 pub struct QueryCursor {
-    row_stream: PgResponseStream,
-    pg_descs: Vec<PgFieldDescriptor>,
+    chunk_stream: CursorDataChunkStream,
+    fields: Vec<Field>,
     remaining_rows: VecDeque<Row>,
 }
 
 impl QueryCursor {
-    pub fn new(row_stream: PgResponseStream, pg_descs: Vec<PgFieldDescriptor>) -> Result<Self> {
+    pub fn new(chunk_stream: CursorDataChunkStream, fields: Vec<Field>) -> Result<Self> {
         Ok(Self {
-            row_stream,
-            pg_descs,
+            chunk_stream,
+            fields,
             remaining_rows: VecDeque::<Row>::new(),
         })
     }
 
     pub async fn next_once(&mut self) -> Result<Option<Row>> {
         while self.remaining_rows.is_empty() {
-            let rows = self.row_stream.next().await;
+            let rows = self.chunk_stream.next().await?;
             let rows = match rows {
                 None => return Ok(None),
                 Some(row) => row?,
@@ -86,18 +152,27 @@ impl QueryCursor {
         Ok(Some(row))
     }
 
-    pub async fn next(&mut self, count: u32) -> Result<(Vec<Row>, Vec<PgFieldDescriptor>)> {
+    pub async fn next(
+        &mut self,
+        count: u32,
+        formats: &Vec<Format>,
+        handle_args: HandlerArgs,
+    ) -> Result<(Vec<Row>, Vec<PgFieldDescriptor>)> {
         // `FETCH NEXT` is equivalent to `FETCH 1`.
         // min with 100 to avoid allocating too many memory at once.
+        let session = handle_args.session;
         let mut ans = Vec::with_capacity(std::cmp::min(100, count) as usize);
         let mut cur = 0;
+        let desc = self.fields.iter().map(to_pg_field).collect();
+        self.chunk_stream
+            .init_row_stream(&self.fields, formats, session);
         while cur < count
             && let Some(row) = self.next_once().await?
         {
             cur += 1;
             ans.push(row);
         }
-        Ok((ans, self.pg_descs.clone()))
+        Ok((ans, desc))
     }
 }
@@ -120,11 +195,7 @@ enum State {
 
         // The row stream to from the batch query read.
         // It is returned from the batch execution.
-        row_stream: PgResponseStream,
-
-        // The pg descs to from the batch query read.
-        // It is returned from the batch execution.
-        pg_descs: Vec<PgFieldDescriptor>,
+        chunk_stream: CursorDataChunkStream,
 
         // A cache to store the remaining rows from the row stream.
         remaining_rows: VecDeque<Row>,
@@ -140,6 +211,9 @@ pub struct SubscriptionCursor {
     dependent_table_id: TableId,
     cursor_need_drop_time: Instant,
     state: State,
+    // `fields` is set from the table's catalog when the cursor is created, and is reset
+    // each time the chunk stream is recreated, so that an upstream ALTER TABLE does not leave the cursor with a stale schema.
+    fields: Vec<Field>,
 }
 
 impl SubscriptionCursor {
@@ -150,17 +224,28 @@ impl SubscriptionCursor {
         dependent_table_id: TableId,
         handle_args: &HandlerArgs,
     ) -> Result<Self> {
-        let state = if let Some(start_timestamp) = start_timestamp {
-            State::InitLogStoreQuery {
-                seek_timestamp: start_timestamp,
-                expected_timestamp: None,
-            }
+        let (state, fields) = if let Some(start_timestamp) = start_timestamp {
+            let table_catalog = handle_args.session.get_table_by_id(&dependent_table_id)?;
+            let fields = table_catalog
+                .columns
+                .iter()
+                .filter(|c| !c.is_hidden)
+                .map(|c| Field::with_name(c.data_type().clone(), c.name()))
+                .collect();
+            let fields = Self::build_desc(fields, true);
+            (
+                State::InitLogStoreQuery {
+                    seek_timestamp: start_timestamp,
+                    expected_timestamp: None,
+                },
+                fields,
+            )
         } else {
             // The query stream needs to initiated on cursor creation to make sure
             // future fetch on the cursor starts from the snapshot when the cursor is declared.
            //
            // TODO: is this the right behavior? Should we delay the query stream initiation till the first fetch?
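+            // Creating the stream eagerly also pins a snapshot epoch on the session;
+            // that pinned epoch is used below as the cursor's starting `rw_timestamp`.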
-            let (row_stream, pg_descs) =
+            let (chunk_stream, fields) =
                 Self::initiate_query(None, &dependent_table_id, handle_args.clone()).await?;
             let pinned_epoch = handle_args
                 .session
@@ -177,14 +262,16 @@
                 .0;
             let start_timestamp = pinned_epoch;
 
-            State::Fetch {
-                from_snapshot: true,
-                rw_timestamp: start_timestamp,
-                row_stream,
-                pg_descs,
-                remaining_rows: VecDeque::new(),
-                expected_timestamp: None,
-            }
+            (
+                State::Fetch {
+                    from_snapshot: true,
+                    rw_timestamp: start_timestamp,
+                    chunk_stream,
+                    remaining_rows: VecDeque::new(),
+                    expected_timestamp: None,
+                },
+                fields,
+            )
         };
 
         let cursor_need_drop_time =
@@ -195,14 +282,15 @@
             dependent_table_id,
             cursor_need_drop_time,
             state,
+            fields,
         })
     }
 
     async fn next_row(
         &mut self,
         handle_args: &HandlerArgs,
-        expected_pg_descs: &Vec<PgFieldDescriptor>,
-    ) -> Result<(Option<Row>, Vec<PgFieldDescriptor>)> {
+        formats: &Vec<Format>,
+    ) -> Result<Option<Row>> {
         loop {
             match &mut self.state {
                 State::InitLogStoreQuery {
                    seek_timestamp,
                    expected_timestamp,
                } => {
@@ -222,33 +310,39 @@
                     .await
                     {
                         Ok((Some(rw_timestamp), expected_timestamp)) => {
-                            let (mut row_stream, pg_descs) = Self::initiate_query(
+                            let (mut chunk_stream, fields) = Self::initiate_query(
                                 Some(rw_timestamp),
                                 &self.dependent_table_id,
                                 handle_args.clone(),
                             )
                             .await?;
+                            Self::init_row_stream(
+                                &mut chunk_stream,
+                                formats,
+                                &from_snapshot,
+                                &self.fields,
+                                handle_args.session.clone(),
+                            );
+
                             self.cursor_need_drop_time = Instant::now()
                                 + Duration::from_secs(self.subscription.retention_seconds);
                             let mut remaining_rows = VecDeque::new();
-                            Self::try_refill_remaining_rows(&mut row_stream, &mut remaining_rows)
+                            Self::try_refill_remaining_rows(&mut chunk_stream, &mut remaining_rows)
                                 .await?;
                             // Transition to the Fetch state
                             self.state = State::Fetch {
                                 from_snapshot,
                                 rw_timestamp,
-                                row_stream,
-                                pg_descs: pg_descs.clone(),
+                                chunk_stream,
                                 remaining_rows,
                                 expected_timestamp,
                             };
-                            if (!expected_pg_descs.is_empty()) && expected_pg_descs.ne(&pg_descs) {
-                                // If the user alters the table upstream of the sub, there will be different descs here.
-                                // So we should output data for different descs in two separate batches
-                                return Ok((None, vec![]));
+                            if self.fields.ne(&fields) {
+                                self.fields = fields;
+                                return Ok(None);
                             }
                         }
-                        Ok((None, _)) => return Ok((None, vec![])),
+                        Ok((None, _)) => return Ok(None),
                         Err(e) => {
                             self.state = State::Invalid;
                             return Err(e);
@@ -258,30 +352,36 @@
                 State::Fetch {
                     from_snapshot,
                     rw_timestamp,
-                    row_stream,
-                    pg_descs,
+                    chunk_stream,
                     remaining_rows,
                     expected_timestamp,
                 } => {
+                    let session_data = StaticSessionData {
+                        timezone: handle_args.session.config().timezone(),
+                    };
                     let from_snapshot = *from_snapshot;
                     let rw_timestamp = *rw_timestamp;
 
                     // Try refill remaining rows
-                    Self::try_refill_remaining_rows(row_stream, remaining_rows).await?;
+                    Self::try_refill_remaining_rows(chunk_stream, remaining_rows).await?;
 
                     if let Some(row) = remaining_rows.pop_front() {
                         // 1. Fetch the next row
                         let new_row = row.take();
                         if from_snapshot {
-                            return Ok((
-                                Some(Row::new(Self::build_row(new_row, None)?)),
-                                pg_descs.clone(),
-                            ));
+                            return Ok(Some(Row::new(Self::build_row(
+                                new_row,
+                                None,
+                                formats,
+                                &session_data,
+                            )?)));
                         } else {
-                            return Ok((
-                                Some(Row::new(Self::build_row(new_row, Some(rw_timestamp))?)),
-                                pg_descs.clone(),
-                            ));
+                            return Ok(Some(Row::new(Self::build_row(
+                                new_row,
+                                Some(rw_timestamp),
+                                formats,
+                                &session_data,
+                            )?)));
                         }
                     } else {
                        // 2. Reach EOF for the current query.
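+                        // The state machine then re-enters `InitLogStoreQuery` with the
+                        // next seek timestamp, so a later iteration of this loop issues a
+                        // new batch query against the log store.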
@@ -314,6 +414,7 @@ impl SubscriptionCursor {
         &mut self,
         count: u32,
         handle_args: HandlerArgs,
+        formats: &Vec<Format>,
     ) -> Result<(Vec<Row>, Vec<PgFieldDescriptor>)> {
         if Instant::now() > self.cursor_need_drop_time {
             return Err(ErrorCode::InternalError(
@@ -324,12 +425,25 @@ impl SubscriptionCursor {
 
         let mut ans = Vec::with_capacity(std::cmp::min(100, count) as usize);
         let mut cur = 0;
-        let mut pg_descs_ans = vec![];
+        let desc = self.fields.iter().map(to_pg_field).collect();
+        if let State::Fetch {
+            from_snapshot,
+            chunk_stream,
+            ..
+        } = &mut self.state
+        {
+            Self::init_row_stream(
+                chunk_stream,
+                formats,
+                from_snapshot,
+                &self.fields,
+                handle_args.session.clone(),
+            );
+        }
         while cur < count {
-            let (row, descs_ans) = self.next_row(&handle_args, &pg_descs_ans).await?;
+            let row = self.next_row(&handle_args, formats).await?;
             match row {
                 Some(row) => {
-                    pg_descs_ans = descs_ans;
                     cur += 1;
                     ans.push(row);
                 }
@@ -339,7 +453,7 @@ impl SubscriptionCursor {
             }
         }
 
-        Ok((ans, pg_descs_ans))
+        Ok((ans, desc))
     }
 
     async fn get_next_rw_timestamp(
@@ -379,10 +493,10 @@ impl SubscriptionCursor {
         rw_timestamp: Option<u64>,
         dependent_table_id: &TableId,
         handle_args: HandlerArgs,
-    ) -> Result<(PgResponseStream, Vec<PgFieldDescriptor>)> {
+    ) -> Result<(CursorDataChunkStream, Vec<Field>)> {
         let session = handle_args.clone().session;
         let table_catalog = session.get_table_by_id(dependent_table_id)?;
-        let (row_stream, pg_descs) = if let Some(rw_timestamp) = rw_timestamp {
+        let (chunk_stream, fields) = if let Some(rw_timestamp) = rw_timestamp {
             let context = OptimizerContext::from_handler_args(handle_args);
             let plan_fragmenter_result = gen_batch_plan_fragmenter(
                 &session,
@@ -394,7 +508,7 @@ impl SubscriptionCursor {
                     rw_timestamp,
                 )?,
             )?;
-            create_stream(session, plan_fragmenter_result, vec![]).await?
+            create_chunk_stream_for_cursor(session, plan_fragmenter_result).await?
         } else {
             let subscription_from_table_name =
                 ObjectName(vec![Ident::from(table_catalog.name.as_ref())]);
@@ -404,17 +518,17 @@ impl SubscriptionCursor {
             create_stream_for_cursor_stmt(handle_args, query_stmt).await?
         };
         Ok((
-            row_stream,
-            Self::build_desc(pg_descs, rw_timestamp.is_none()),
+            chunk_stream,
+            Self::build_desc(fields, rw_timestamp.is_none()),
         ))
     }
 
     async fn try_refill_remaining_rows(
-        row_stream: &mut PgResponseStream,
+        chunk_stream: &mut CursorDataChunkStream,
         remaining_rows: &mut VecDeque<Row>,
     ) -> Result<()> {
         if remaining_rows.is_empty()
-            && let Some(row_set) = row_stream.next().await
+            && let Some(row_set) = chunk_stream.next().await?
         {
             remaining_rows.extend(row_set?);
         }
@@ -424,34 +538,39 @@ impl SubscriptionCursor {
     pub fn build_row(
         mut row: Vec<Option<Bytes>>,
         rw_timestamp: Option<u64>,
+        formats: &Vec<Format>,
+        session_data: &StaticSessionData,
     ) -> Result<Vec<Option<Bytes>>> {
+        let row_len = row.len();
         let new_row = if let Some(rw_timestamp) = rw_timestamp {
-            vec![Some(Bytes::from(
-                convert_logstore_u64_to_unix_millis(rw_timestamp).to_string(),
-            ))]
+            let rw_timestamp_formats = formats.get(row_len).unwrap_or(&Format::Text);
+            let rw_timestamp = convert_logstore_u64_to_unix_millis(rw_timestamp);
+            let rw_timestamp = pg_value_format(
+                &DataType::Int64,
+                risingwave_common::types::ScalarRefImpl::Int64(rw_timestamp as i64),
+                *rw_timestamp_formats,
+                session_data,
+            )?;
+            vec![Some(rw_timestamp)]
         } else {
-            vec![Some(Bytes::from(1i16.to_string())), None]
+            let op_formats = formats.get(row_len).unwrap_or(&Format::Text);
+            let op = pg_value_format(
+                &DataType::Int16,
+                risingwave_common::types::ScalarRefImpl::Int16(1_i16),
+                *op_formats,
+                session_data,
+            )?;
+            vec![Some(op), None]
        };
        row.extend(new_row);
        Ok(row)
    }
 
-    pub fn build_desc(
-        mut descs: Vec<PgFieldDescriptor>,
-        from_snapshot: bool,
-    ) -> Vec<PgFieldDescriptor> {
+    pub fn build_desc(mut descs: Vec<Field>, from_snapshot: bool) -> Vec<Field> {
         if from_snapshot {
-            descs.push(PgFieldDescriptor::new(
-                "op".to_owned(),
-                DataType::Int16.to_oid(),
-                DataType::Int16.type_len(),
-            ));
+            descs.push(Field::with_name(DataType::Int16, "op"));
         }
-        descs.push(PgFieldDescriptor::new(
-            "rw_timestamp".to_owned(),
-            DataType::Int64.to_oid(),
-            DataType::Int64.type_len(),
-        ));
+        descs.push(Field::with_name(DataType::Int64, "rw_timestamp"));
         descs
     }
 
@@ -508,6 +627,26 @@ impl SubscriptionCursor {
             dependent_relations: table_catalog.dependent_relations.clone(),
         })
     }
+
+    // When the cursor is declared (DECLARE cur), `formats` starts out empty and is not
+    // the real set of formats; on each FETCH we refill it with the formats sent by the pg client.
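+    // `build_row` appends the `rw_timestamp` column (and, for snapshot rows, the
+    // `op` column) after per-column formatting, so the trailing entries are popped
+    // from `fields`/`formats` before initializing the data-chunk-to-row conversion.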
+    pub fn init_row_stream(
+        chunk_stream: &mut CursorDataChunkStream,
+        formats: &Vec<Format>,
+        from_snapshot: &bool,
+        fields: &Vec<Field>,
+        session: Arc<SessionImpl>,
+    ) {
+        let mut formats = formats.clone();
+        let mut fields = fields.clone();
+        formats.pop();
+        fields.pop();
+        if *from_snapshot {
+            formats.pop();
+            fields.pop();
+        }
+        chunk_stream.init_row_stream(&fields, &formats, session);
+    }
 }
 
 #[derive(Default)]
@@ -553,10 +692,10 @@ impl CursorManager {
     pub async fn add_query_cursor(
         &self,
         cursor_name: ObjectName,
-        row_stream: PgResponseStream,
-        pg_descs: Vec<PgFieldDescriptor>,
+        chunk_stream: CursorDataChunkStream,
+        fields: Vec<Field>,
     ) -> Result<()> {
-        let cursor = QueryCursor::new(row_stream, pg_descs)?;
+        let cursor = QueryCursor::new(chunk_stream, fields)?;
         self.cursor_map
             .lock()
             .await
@@ -595,9 +734,18 @@ impl CursorManager {
         cursor_name: String,
         count: u32,
         handle_args: HandlerArgs,
+        formats: &Vec<Format>,
     ) -> Result<(Vec<Row>, Vec<PgFieldDescriptor>)> {
         if let Some(cursor) = self.cursor_map.lock().await.get_mut(&cursor_name) {
-            cursor.next(count, handle_args).await
+            cursor.next(count, handle_args, formats).await
+        } else {
+            Err(ErrorCode::ItemNotFound(format!("Cannot find cursor `{}`", cursor_name)).into())
+        }
+    }
+
+    pub async fn get_fields_with_cursor(&self, cursor_name: String) -> Result<Vec<Field>> {
+        if let Some(cursor) = self.cursor_map.lock().await.get_mut(&cursor_name) {
+            Ok(cursor.get_fields())
         } else {
             Err(ErrorCode::ItemNotFound(format!("Cannot find cursor `{}`", cursor_name)).into())
         }
     }
diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs
index d700e39757df1..72b99f6d50d64 100644
--- a/src/utils/pgwire/src/pg_protocol.rs
+++ b/src/utils/pgwire/src/pg_protocol.rs
@@ -414,7 +414,7 @@ where
             FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
             FeMessage::Terminate => self.process_terminate(),
             FeMessage::Parse(m) => {
-                if let Err(err) = self.process_parse_msg(m) {
+                if let Err(err) = self.process_parse_msg(m).await {
                     self.ignore_util_sync = true;
                     return Err(err);
                 }
@@ -681,16 +681,17 @@ where
         self.is_terminate = true;
     }
 
-    fn process_parse_msg(&mut self, msg: FeParseMessage) -> PsqlResult<()> {
+    async fn process_parse_msg(&mut self, msg: FeParseMessage) -> PsqlResult<()> {
         let sql = cstr_to_str(&msg.sql_bytes).unwrap();
         record_sql_in_span(sql, self.redact_sql_option_keywords.clone());
         let session = self.session.clone().unwrap();
         let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_string();
 
         self.inner_process_parse_msg(session, sql, statement_name, msg.type_ids)
+            .await
     }
 
-    fn inner_process_parse_msg(
+    async fn inner_process_parse_msg(
         &mut self,
         session: Arc<SM::Session>,
         sql: &str,
@@ -737,6 +738,7 @@ where
 
         let prepare_statement = session
             .parse(stmt, param_types)
+            .await
             .map_err(PsqlError::ExtendedPrepareError)?;
 
         if statement_name.is_empty() {
@@ -850,7 +852,6 @@ where
             .unwrap()
             .describe_statement(prepare_statement)
             .map_err(PsqlError::Uncategorized)?;
-
         self.stream
             .write_no_flush(&BeMessage::ParameterDescription(
                 &param_types.iter().map(|t| t.to_oid()).collect_vec(),
diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs
index 840f21dda1be2..4b0b8657ac59f 100644
--- a/src/utils/pgwire/src/pg_server.rs
+++ b/src/utils/pgwire/src/pg_server.rs
@@ -85,7 +85,7 @@ pub trait Session: Send + Sync {
         self: Arc<Self>,
         sql: Option<Statement>,
         params_types: Vec<Option<DataType>>,
-    ) -> Result<Self::PreparedStatement>;
+    ) -> impl Future<Output = Result<Self::PreparedStatement>> + Send;
 
     // TODO: maybe this function should be async and return the notice more timely
     /// try to take the current notices from the session
@@ -424,7 +424,7 @@ mod tests {
             .into())
         }
 
-        fn parse(
+        async fn parse(
            self: Arc<Self>,
            _sql: Option<Statement>,
            _params_types: Vec<Option<DataType>>,