Skip to content

Commit

Permalink
feat(frontend): support extendend query for cursor (#17821)
Browse files Browse the repository at this point in the history
  • Loading branch information
xxhZs authored Aug 9, 2024
1 parent cc21a6a commit efdbf3c
Show file tree
Hide file tree
Showing 16 changed files with 438 additions and 122 deletions.
5 changes: 5 additions & 0 deletions e2e_test/subscription/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ def test_cursor_with_table_alter():
row = execute_query("fetch next from cur",conn)
check_rows_data([1,2],row[0],1)
row = execute_query("fetch next from cur",conn)
assert(row == [])
row = execute_query("fetch next from cur",conn)
check_rows_data([4,4,4],row[0],1)
execute_insert("insert into t1 values(5,5,5)",conn)
execute_insert("flush",conn)
Expand All @@ -258,6 +260,8 @@ def test_cursor_with_table_alter():
execute_insert("insert into t1 values(6,6)",conn)
execute_insert("flush",conn)
row = execute_query("fetch next from cur",conn)
assert(row == [])
row = execute_query("fetch next from cur",conn)
check_rows_data([6,6],row[0],1)
drop_table_subscription()

Expand Down Expand Up @@ -324,6 +328,7 @@ def test_rebuild_table():
check_rows_data([1,1],row[0],1)
check_rows_data([1,1],row[1],4)
check_rows_data([1,100],row[2],3)
drop_table_subscription()

if __name__ == "__main__":
test_cursor_snapshot()
Expand Down
42 changes: 42 additions & 0 deletions src/frontend/src/binder/fetch_cursor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::catalog::Schema;

use crate::error::Result;
use crate::Binder;

#[derive(Debug, Clone)]
pub struct BoundFetchCursor {
pub cursor_name: String,

pub count: u32,

pub returning_schema: Option<Schema>,
}

impl Binder {
pub fn bind_fetch_cursor(
&mut self,
cursor_name: String,
count: u32,
returning_schema: Option<Schema>,
) -> Result<BoundFetchCursor> {
Ok(BoundFetchCursor {
cursor_name,
count,
returning_schema,
})
}
}
1 change: 1 addition & 0 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ mod create;
mod create_view;
mod delete;
mod expr;
pub mod fetch_cursor;
mod for_system;
mod insert;
mod query;
Expand Down
7 changes: 7 additions & 0 deletions src/frontend/src/binder/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use risingwave_common::catalog::Field;
use risingwave_sqlparser::ast::Statement;

use super::delete::BoundDelete;
use super::fetch_cursor::BoundFetchCursor;
use super::update::BoundUpdate;
use crate::binder::create_view::BoundCreateView;
use crate::binder::{Binder, BoundInsert, BoundQuery};
Expand All @@ -29,6 +30,7 @@ pub enum BoundStatement {
Delete(Box<BoundDelete>),
Update(Box<BoundUpdate>),
Query(Box<BoundQuery>),
FetchCursor(Box<BoundFetchCursor>),
CreateView(Box<BoundCreateView>),
}

Expand All @@ -48,6 +50,10 @@ impl BoundStatement {
.as_ref()
.map_or(vec![], |s| s.fields().into()),
BoundStatement::Query(q) => q.schema().fields().into(),
BoundStatement::FetchCursor(f) => f
.returning_schema
.as_ref()
.map_or(vec![], |s| s.fields().into()),
BoundStatement::CreateView(_) => vec![],
}
}
Expand Down Expand Up @@ -127,6 +133,7 @@ impl RewriteExprsRecursive for BoundStatement {
BoundStatement::Delete(inner) => inner.rewrite_exprs_recursive(rewriter),
BoundStatement::Update(inner) => inner.rewrite_exprs_recursive(rewriter),
BoundStatement::Query(inner) => inner.rewrite_exprs_recursive(rewriter),
BoundStatement::FetchCursor(_) => {}
BoundStatement::CreateView(inner) => inner.rewrite_exprs_recursive(rewriter),
}
}
Expand Down
53 changes: 45 additions & 8 deletions src/frontend/src/handler/declare_cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use pgwire::pg_field_descriptor::PgFieldDescriptor;
use std::sync::Arc;

use pgwire::pg_response::{PgResponse, StatementType};
use risingwave_common::catalog::Field;
use risingwave_common::session_config::QueryMode;
use risingwave_common::util::epoch::Epoch;
use risingwave_sqlparser::ast::{DeclareCursorStatement, ObjectName, Query, Since, Statement};

use super::query::{gen_batch_plan_by_statement, gen_batch_plan_fragmenter};
use super::query::{
gen_batch_plan_by_statement, gen_batch_plan_fragmenter, BatchPlanFragmenterResult,
};
use super::util::convert_unix_millis_to_logstore_u64;
use super::RwPgResponse;
use crate::error::{ErrorCode, Result};
use crate::handler::query::create_stream;
use crate::handler::query::{distribute_execute, local_execute};
use crate::handler::HandlerArgs;
use crate::{Binder, OptimizerContext, PgResponseStream};
use crate::session::cursor_manager::CursorDataChunkStream;
use crate::session::SessionImpl;
use crate::{Binder, OptimizerContext};

pub async fn handle_declare_cursor(
handle_args: HandlerArgs,
Expand Down Expand Up @@ -111,25 +118,55 @@ async fn handle_declare_query_cursor(
cursor_name: ObjectName,
query: Box<Query>,
) -> Result<RwPgResponse> {
let (row_stream, pg_descs) =
let (chunk_stream, fields) =
create_stream_for_cursor_stmt(handle_args.clone(), Statement::Query(query)).await?;
handle_args
.session
.get_cursor_manager()
.add_query_cursor(cursor_name, row_stream, pg_descs)
.add_query_cursor(cursor_name, chunk_stream, fields)
.await?;
Ok(PgResponse::empty_result(StatementType::DECLARE_CURSOR))
}

pub async fn create_stream_for_cursor_stmt(
handle_args: HandlerArgs,
stmt: Statement,
) -> Result<(PgResponseStream, Vec<PgFieldDescriptor>)> {
) -> Result<(CursorDataChunkStream, Vec<Field>)> {
let session = handle_args.session.clone();
let plan_fragmenter_result = {
let context = OptimizerContext::from_handler_args(handle_args);
let plan_result = gen_batch_plan_by_statement(&session, context.into(), stmt)?;
gen_batch_plan_fragmenter(&session, plan_result)?
};
create_stream(session, plan_fragmenter_result, vec![]).await
create_chunk_stream_for_cursor(session, plan_fragmenter_result).await
}

pub async fn create_chunk_stream_for_cursor(
session: Arc<SessionImpl>,
plan_fragmenter_result: BatchPlanFragmenterResult,
) -> Result<(CursorDataChunkStream, Vec<Field>)> {
let BatchPlanFragmenterResult {
plan_fragmenter,
query_mode,
schema,
..
} = plan_fragmenter_result;

let can_timeout_cancel = true;

let query = plan_fragmenter.generate_complete_query().await?;
tracing::trace!("Generated query after plan fragmenter: {:?}", &query);

Ok((
match query_mode {
QueryMode::Auto => unreachable!(),
QueryMode::Local => CursorDataChunkStream::LocalDataChunk(Some(
local_execute(session.clone(), query, can_timeout_cancel).await?,
)),
QueryMode::Distributed => CursorDataChunkStream::DistributedDataChunk(Some(
distribute_execute(session.clone(), query, can_timeout_cancel).await?,
)),
},
schema.fields.clone(),
))
}
14 changes: 10 additions & 4 deletions src/frontend/src/handler/extended_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{CreateSink, Query, Statement};

use super::query::BoundResult;
use super::{handle, query, HandlerArgs, RwPgResponse};
use super::{fetch_cursor, handle, query, HandlerArgs, RwPgResponse};
use crate::error::Result;
use crate::session::SessionImpl;

Expand Down Expand Up @@ -94,7 +94,7 @@ impl std::fmt::Display for PortalResult {
}
}

pub fn handle_parse(
pub async fn handle_parse(
session: Arc<SessionImpl>,
statement: Statement,
specific_param_types: Vec<Option<DataType>>,
Expand All @@ -109,6 +109,9 @@ pub fn handle_parse(
| Statement::Update { .. } => {
query::handle_parse(handler_args, statement, specific_param_types)
}
Statement::FetchCursor { .. } => {
fetch_cursor::handle_parse(handler_args, statement, specific_param_types).await
}
Statement::CreateView {
query,
materialized,
Expand Down Expand Up @@ -198,8 +201,11 @@ pub async fn handle_execute(session: Arc<SessionImpl>, portal: Portal) -> Result
let _guard = session.txn_begin_implicit(); // TODO(bugen): is this behavior correct?
let sql: Arc<str> = Arc::from(portal.statement.to_string());
let handler_args = HandlerArgs::new(session, &portal.statement, sql)?;

query::handle_execute(handler_args, portal).await
if let Statement::FetchCursor { .. } = &portal.statement {
fetch_cursor::handle_fetch_cursor_execute(handler_args, portal).await
} else {
query::handle_execute(handler_args, portal).await
}
}
Portal::PureStatement(stmt) => {
let sql: Arc<str> = Arc::from(stmt.to_string());
Expand Down
77 changes: 72 additions & 5 deletions src/frontend/src/handler/fetch_cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,57 @@

use pgwire::pg_field_descriptor::PgFieldDescriptor;
use pgwire::pg_response::{PgResponse, StatementType};
use pgwire::types::Row;
use risingwave_sqlparser::ast::FetchCursorStatement;
use pgwire::types::{Format, Row};
use risingwave_common::bail_not_implemented;
use risingwave_common::catalog::Schema;
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{FetchCursorStatement, Statement};

use super::extended_handle::{PortalResult, PrepareStatement, PreparedResult};
use super::query::BoundResult;
use super::RwPgResponse;
use crate::binder::BoundStatement;
use crate::error::Result;
use crate::handler::HandlerArgs;
use crate::{Binder, PgResponseStream};

pub async fn handle_fetch_cursor_execute(
handler_args: HandlerArgs,
portal_result: PortalResult,
) -> Result<RwPgResponse> {
if let PortalResult {
statement: Statement::FetchCursor { stmt },
bound_result:
BoundResult {
bound: BoundStatement::FetchCursor(fetch_cursor),
..
},
result_formats,
..
} = portal_result
{
match fetch_cursor.returning_schema {
Some(_) => handle_fetch_cursor(handler_args, stmt, &result_formats).await,
None => Ok(build_fetch_cursor_response(vec![], vec![])),
}
} else {
bail_not_implemented!("unsupported portal {}", portal_result)
}
}
pub async fn handle_fetch_cursor(
handle_args: HandlerArgs,
handler_args: HandlerArgs,
stmt: FetchCursorStatement,
formats: &Vec<Format>,
) -> Result<RwPgResponse> {
let session = handle_args.session.clone();
let session = handler_args.session.clone();
let db_name = session.database();
let (_, cursor_name) =
Binder::resolve_schema_qualified_name(db_name, stmt.cursor_name.clone())?;

let cursor_manager = session.get_cursor_manager();

let (rows, pg_descs) = cursor_manager
.get_rows_with_cursor(cursor_name, stmt.count, handle_args)
.get_rows_with_cursor(cursor_name, stmt.count, handler_args, formats)
.await?;
Ok(build_fetch_cursor_response(rows, pg_descs))
}
Expand All @@ -45,3 +75,40 @@ fn build_fetch_cursor_response(rows: Vec<Row>, pg_descs: Vec<PgFieldDescriptor>)
.values(PgResponseStream::from(rows), pg_descs)
.into()
}

pub async fn handle_parse(
handler_args: HandlerArgs,
statement: Statement,
specific_param_types: Vec<Option<DataType>>,
) -> Result<PrepareStatement> {
if let Statement::FetchCursor { stmt } = &statement {
let session = handler_args.session.clone();
let db_name = session.database();
let (_, cursor_name) =
Binder::resolve_schema_qualified_name(db_name, stmt.cursor_name.clone())?;
let fields = session
.get_cursor_manager()
.get_fields_with_cursor(cursor_name.clone())
.await?;

let mut binder = Binder::new_with_param_types(&session, specific_param_types);
let schema = Some(Schema::new(fields));

let bound = binder.bind_fetch_cursor(cursor_name, stmt.count, schema)?;
let bound_result = BoundResult {
stmt_type: StatementType::FETCH_CURSOR,
must_dist: false,
bound: BoundStatement::FetchCursor(Box::new(bound)),
param_types: binder.export_param_types()?,
parsed_params: None,
dependent_relations: binder.included_relations(),
};
let result = PreparedResult {
statement,
bound_result,
};
Ok(PrepareStatement::Prepared(result))
} else {
bail_not_implemented!("unsupported statement {:?}", statement)
}
}
2 changes: 1 addition & 1 deletion src/frontend/src/handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ pub async fn handle(
declare_cursor::handle_declare_cursor(handler_args, stmt).await
}
Statement::FetchCursor { stmt } => {
fetch_cursor::handle_fetch_cursor(handler_args, stmt).await
fetch_cursor::handle_fetch_cursor(handler_args, stmt, &formats).await
}
Statement::CloseCursor { stmt } => {
close_cursor::handle_close_cursor(handler_args, stmt).await
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/handler/privilege.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ pub(crate) fn resolve_privileges(stmt: &BoundStatement) -> Vec<ObjectCheckItem>
objects.push(object);
}
BoundStatement::Query(ref query) => objects.extend(resolve_query_privileges(query)),
BoundStatement::FetchCursor(_) => unimplemented!(),
BoundStatement::CreateView(ref create_view) => {
objects.extend(resolve_query_privileges(&create_view.query))
}
Expand Down
5 changes: 2 additions & 3 deletions src/frontend/src/handler/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ async fn execute(
.into())
}

async fn distribute_execute(
pub async fn distribute_execute(
session: Arc<SessionImpl>,
query: Query,
can_timeout_cancel: bool,
Expand All @@ -538,8 +538,7 @@ async fn distribute_execute(
.map_err(|err| err.into())
}

#[expect(clippy::unused_async)]
async fn local_execute(
pub async fn local_execute(
session: Arc<SessionImpl>,
query: Query,
can_timeout_cancel: bool,
Expand Down
Loading

0 comments on commit efdbf3c

Please sign in to comment.