feat(frontend): support extended query for cursor #17821

Merged · 10 commits · Aug 9, 2024
5 changes: 5 additions & 0 deletions e2e_test/subscription/main.py
@@ -249,6 +249,8 @@ def test_cursor_with_table_alter():
row = execute_query("fetch next from cur",conn)
check_rows_data([1,2],row[0],1)
row = execute_query("fetch next from cur",conn)
assert(row == [])
row = execute_query("fetch next from cur",conn)
check_rows_data([4,4,4],row[0],1)
execute_insert("insert into t1 values(5,5,5)",conn)
execute_insert("flush",conn)
@@ -258,6 +260,8 @@
execute_insert("insert into t1 values(6,6)",conn)
execute_insert("flush",conn)
row = execute_query("fetch next from cur",conn)
assert(row == [])
row = execute_query("fetch next from cur",conn)
check_rows_data([6,6],row[0],1)
drop_table_subscription()

@@ -324,6 +328,7 @@ def test_rebuild_table():
check_rows_data([1,1],row[0],1)
check_rows_data([1,1],row[1],4)
check_rows_data([1,100],row[2],3)
drop_table_subscription()

if __name__ == "__main__":
test_cursor_snapshot()
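
A minimal client-side sketch (not part of this PR's test suite) of what the change enables: with psycopg 3, passing prepare=True makes the driver send Parse/Bind/Execute messages, so FETCH goes through the extended query protocol rather than the simple protocol. The connection string, table name t1, and cursor name cur1 below are placeholders.

```python
# Sketch only; assumes a local RisingWave frontend on the default port 4566
# and an existing table t1 with a few rows.
import psycopg

with psycopg.connect("host=localhost port=4566 dbname=dev user=root",
                     autocommit=True) as conn:
    with conn.cursor() as cur:
        cur.execute("DECLARE cur1 CURSOR FOR SELECT * FROM t1")
        # prepare=True forces Parse/Bind/Execute (the extended query protocol)
        # for this statement instead of a single simple Query message.
        cur.execute("FETCH NEXT FROM cur1", prepare=True)
        print(cur.fetchall())
        cur.execute("CLOSE cur1")
```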
42 changes: 42 additions & 0 deletions src/frontend/src/binder/fetch_cursor.rs
@@ -0,0 +1,42 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::catalog::Schema;

use crate::error::Result;
use crate::Binder;

#[derive(Debug, Clone)]
pub struct BoundFetchCursor {
pub cursor_name: String,

pub count: u32,

pub returning_schema: Option<Schema>,
}

impl Binder {
pub fn bind_fetch_cursor(
&mut self,
cursor_name: String,
count: u32,
returning_schema: Option<Schema>,
) -> Result<BoundFetchCursor> {
Ok(BoundFetchCursor {
cursor_name,
count,
returning_schema,
})
}
}
1 change: 1 addition & 0 deletions src/frontend/src/binder/mod.rs
@@ -32,6 +32,7 @@ mod create;
mod create_view;
mod delete;
mod expr;
pub mod fetch_cursor;
mod for_system;
mod insert;
mod query;
7 changes: 7 additions & 0 deletions src/frontend/src/binder/statement.rs
@@ -17,6 +17,7 @@ use risingwave_common::catalog::Field;
use risingwave_sqlparser::ast::Statement;

use super::delete::BoundDelete;
use super::fetch_cursor::BoundFetchCursor;
use super::update::BoundUpdate;
use crate::binder::create_view::BoundCreateView;
use crate::binder::{Binder, BoundInsert, BoundQuery};
@@ -29,6 +30,7 @@ pub enum BoundStatement {
Delete(Box<BoundDelete>),
Update(Box<BoundUpdate>),
Query(Box<BoundQuery>),
FetchCursor(Box<BoundFetchCursor>),
CreateView(Box<BoundCreateView>),
}

@@ -48,6 +50,10 @@ impl BoundStatement {
.as_ref()
.map_or(vec![], |s| s.fields().into()),
BoundStatement::Query(q) => q.schema().fields().into(),
BoundStatement::FetchCursor(f) => f
.returning_schema
.as_ref()
.map_or(vec![], |s| s.fields().into()),
BoundStatement::CreateView(_) => vec![],
}
}
@@ -127,6 +133,7 @@ impl RewriteExprsRecursive for BoundStatement {
BoundStatement::Delete(inner) => inner.rewrite_exprs_recursive(rewriter),
BoundStatement::Update(inner) => inner.rewrite_exprs_recursive(rewriter),
BoundStatement::Query(inner) => inner.rewrite_exprs_recursive(rewriter),
BoundStatement::FetchCursor(_) => {}
BoundStatement::CreateView(inner) => inner.rewrite_exprs_recursive(rewriter),
}
}
53 changes: 45 additions & 8 deletions src/frontend/src/handler/declare_cursor.rs
@@ -12,18 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use pgwire::pg_field_descriptor::PgFieldDescriptor;
use std::sync::Arc;

use pgwire::pg_response::{PgResponse, StatementType};
use risingwave_common::catalog::Field;
use risingwave_common::session_config::QueryMode;
use risingwave_common::util::epoch::Epoch;
use risingwave_sqlparser::ast::{DeclareCursorStatement, ObjectName, Query, Since, Statement};

use super::query::{gen_batch_plan_by_statement, gen_batch_plan_fragmenter};
use super::query::{
gen_batch_plan_by_statement, gen_batch_plan_fragmenter, BatchPlanFragmenterResult,
};
use super::util::convert_unix_millis_to_logstore_u64;
use super::RwPgResponse;
use crate::error::{ErrorCode, Result};
use crate::handler::query::create_stream;
use crate::handler::query::{distribute_execute, local_execute};
use crate::handler::HandlerArgs;
use crate::{Binder, OptimizerContext, PgResponseStream};
use crate::session::cursor_manager::CursorDataChunkStream;
use crate::session::SessionImpl;
use crate::{Binder, OptimizerContext};

pub async fn handle_declare_cursor(
handle_args: HandlerArgs,
@@ -111,25 +118,55 @@ async fn handle_declare_query_cursor(
cursor_name: ObjectName,
query: Box<Query>,
) -> Result<RwPgResponse> {
let (row_stream, pg_descs) =
let (chunk_stream, fields) =
create_stream_for_cursor_stmt(handle_args.clone(), Statement::Query(query)).await?;
handle_args
.session
.get_cursor_manager()
.add_query_cursor(cursor_name, row_stream, pg_descs)
.add_query_cursor(cursor_name, chunk_stream, fields)
.await?;
Ok(PgResponse::empty_result(StatementType::DECLARE_CURSOR))
}

pub async fn create_stream_for_cursor_stmt(
handle_args: HandlerArgs,
stmt: Statement,
) -> Result<(PgResponseStream, Vec<PgFieldDescriptor>)> {
) -> Result<(CursorDataChunkStream, Vec<Field>)> {
let session = handle_args.session.clone();
let plan_fragmenter_result = {
let context = OptimizerContext::from_handler_args(handle_args);
let plan_result = gen_batch_plan_by_statement(&session, context.into(), stmt)?;
gen_batch_plan_fragmenter(&session, plan_result)?
};
create_stream(session, plan_fragmenter_result, vec![]).await
create_chunk_stream_for_cursor(session, plan_fragmenter_result).await
}

pub async fn create_chunk_stream_for_cursor(
session: Arc<SessionImpl>,
plan_fragmenter_result: BatchPlanFragmenterResult,
) -> Result<(CursorDataChunkStream, Vec<Field>)> {
let BatchPlanFragmenterResult {
plan_fragmenter,
query_mode,
schema,
..
} = plan_fragmenter_result;

let can_timeout_cancel = true;

let query = plan_fragmenter.generate_complete_query().await?;
tracing::trace!("Generated query after plan fragmenter: {:?}", &query);

Ok((
match query_mode {
QueryMode::Auto => unreachable!(),
QueryMode::Local => CursorDataChunkStream::LocalDataChunk(Some(
local_execute(session.clone(), query, can_timeout_cancel).await?,
)),
QueryMode::Distributed => CursorDataChunkStream::DistributedDataChunk(Some(
distribute_execute(session.clone(), query, can_timeout_cancel).await?,
)),
},
schema.fields.clone(),
))
}
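
As a hedged illustration of the split above: create_chunk_stream_for_cursor chooses local or distributed batch execution from the session's query_mode, so a declared query cursor is expected to return the same rows under either setting. The sketch below assumes query_mode can be set per session with the usual SET syntax; connection parameters and names are placeholders.

```python
# Sketch only: fetch the same cursor under both execution paths.
import psycopg

for mode in ("local", "distributed"):
    with psycopg.connect("host=localhost port=4566 dbname=dev user=root",
                         autocommit=True) as conn:
        with conn.cursor() as cur:
            cur.execute(f"SET query_mode TO {mode}")
            cur.execute("DECLARE cur1 CURSOR FOR SELECT * FROM t1")
            cur.execute("FETCH NEXT FROM cur1", prepare=True)
            print(mode, cur.fetchall())
            cur.execute("CLOSE cur1")
```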
14 changes: 10 additions & 4 deletions src/frontend/src/handler/extended_handle.rs
@@ -23,7 +23,7 @@ use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{CreateSink, Query, Statement};

use super::query::BoundResult;
use super::{handle, query, HandlerArgs, RwPgResponse};
use super::{fetch_cursor, handle, query, HandlerArgs, RwPgResponse};
use crate::error::Result;
use crate::session::SessionImpl;

@@ -94,7 +94,7 @@ impl std::fmt::Display for PortalResult {
}
}

pub fn handle_parse(
pub async fn handle_parse(
session: Arc<SessionImpl>,
statement: Statement,
specific_param_types: Vec<Option<DataType>>,
@@ -109,6 +109,9 @@
| Statement::Update { .. } => {
query::handle_parse(handler_args, statement, specific_param_types)
}
Statement::FetchCursor { .. } => {
fetch_cursor::handle_parse(handler_args, statement, specific_param_types).await
}
Statement::CreateView {
query,
materialized,
@@ -198,8 +201,11 @@ pub async fn handle_execute(session: Arc<SessionImpl>, portal: Portal) -> Result
let _guard = session.txn_begin_implicit(); // TODO(bugen): is this behavior correct?
let sql: Arc<str> = Arc::from(portal.statement.to_string());
let handler_args = HandlerArgs::new(session, &portal.statement, sql)?;

query::handle_execute(handler_args, portal).await
if let Statement::FetchCursor { .. } = &portal.statement {
fetch_cursor::handle_fetch_cursor_execute(handler_args, portal).await
} else {
query::handle_execute(handler_args, portal).await
}
}
Portal::PureStatement(stmt) => {
let sql: Arc<str> = Arc::from(stmt.to_string());
77 changes: 72 additions & 5 deletions src/frontend/src/handler/fetch_cursor.rs
@@ -14,27 +14,57 @@

use pgwire::pg_field_descriptor::PgFieldDescriptor;
use pgwire::pg_response::{PgResponse, StatementType};
use pgwire::types::Row;
use risingwave_sqlparser::ast::FetchCursorStatement;
use pgwire::types::{Format, Row};
use risingwave_common::bail_not_implemented;
use risingwave_common::catalog::Schema;
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{FetchCursorStatement, Statement};

use super::extended_handle::{PortalResult, PrepareStatement, PreparedResult};
use super::query::BoundResult;
use super::RwPgResponse;
use crate::binder::BoundStatement;
use crate::error::Result;
use crate::handler::HandlerArgs;
use crate::{Binder, PgResponseStream};

pub async fn handle_fetch_cursor_execute(
handler_args: HandlerArgs,
portal_result: PortalResult,
) -> Result<RwPgResponse> {
if let PortalResult {
statement: Statement::FetchCursor { stmt },
bound_result:
BoundResult {
bound: BoundStatement::FetchCursor(fetch_cursor),
..
},
result_formats,
..
} = portal_result
{
match fetch_cursor.returning_schema {
Some(_) => handle_fetch_cursor(handler_args, stmt, &result_formats).await,
None => Ok(build_fetch_cursor_response(vec![], vec![])),
}
} else {
bail_not_implemented!("unsupported portal {}", portal_result)
}
}
pub async fn handle_fetch_cursor(
handle_args: HandlerArgs,
handler_args: HandlerArgs,
stmt: FetchCursorStatement,
formats: &Vec<Format>,
) -> Result<RwPgResponse> {
let session = handle_args.session.clone();
let session = handler_args.session.clone();
let db_name = session.database();
let (_, cursor_name) =
Binder::resolve_schema_qualified_name(db_name, stmt.cursor_name.clone())?;

let cursor_manager = session.get_cursor_manager();

let (rows, pg_descs) = cursor_manager
.get_rows_with_cursor(cursor_name, stmt.count, handle_args)
.get_rows_with_cursor(cursor_name, stmt.count, handler_args, formats)
.await?;
Ok(build_fetch_cursor_response(rows, pg_descs))
}
@@ -45,3 +75,40 @@ fn build_fetch_cursor_response(rows: Vec<Row>, pg_descs: Vec<PgFieldDescriptor>)
.values(PgResponseStream::from(rows), pg_descs)
.into()
}

pub async fn handle_parse(
handler_args: HandlerArgs,
statement: Statement,
specific_param_types: Vec<Option<DataType>>,
) -> Result<PrepareStatement> {
if let Statement::FetchCursor { stmt } = &statement {
let session = handler_args.session.clone();
let db_name = session.database();
let (_, cursor_name) =
Binder::resolve_schema_qualified_name(db_name, stmt.cursor_name.clone())?;
let fields = session
.get_cursor_manager()
.get_fields_with_cursor(cursor_name.clone())
.await?;

let mut binder = Binder::new_with_param_types(&session, specific_param_types);
let schema = Some(Schema::new(fields));

let bound = binder.bind_fetch_cursor(cursor_name, stmt.count, schema)?;
let bound_result = BoundResult {
stmt_type: StatementType::FETCH_CURSOR,
must_dist: false,
bound: BoundStatement::FetchCursor(Box::new(bound)),
param_types: binder.export_param_types()?,
parsed_params: None,
dependent_relations: binder.included_relations(),
};
let result = PreparedResult {
statement,
bound_result,
};
Ok(PrepareStatement::Prepared(result))
} else {
bail_not_implemented!("unsupported statement {:?}", statement)
}
}
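
Since handle_parse resolves the cursor's cached fields from the cursor manager at Parse time, preparing a FETCH against a cursor that was never declared should be rejected when the statement is parsed rather than only when it is executed. A hedged sketch of that expectation (connection parameters and the cursor name are placeholders):

```python
# Sketch only: an undeclared cursor is expected to error during Parse/Describe,
# which psycopg surfaces on the execute() call that prepared the statement.
import psycopg

with psycopg.connect("host=localhost port=4566 dbname=dev user=root",
                     autocommit=True) as conn:
    with conn.cursor() as cur:
        try:
            cur.execute("FETCH NEXT FROM no_such_cursor", prepare=True)
        except psycopg.Error as e:
            print("rejected while preparing:", e)
```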
2 changes: 1 addition & 1 deletion src/frontend/src/handler/mod.rs
@@ -401,7 +401,7 @@ pub async fn handle(
declare_cursor::handle_declare_cursor(handler_args, stmt).await
}
Statement::FetchCursor { stmt } => {
fetch_cursor::handle_fetch_cursor(handler_args, stmt).await
fetch_cursor::handle_fetch_cursor(handler_args, stmt, &formats).await
}
Statement::CloseCursor { stmt } => {
close_cursor::handle_close_cursor(handler_args, stmt).await
1 change: 1 addition & 0 deletions src/frontend/src/handler/privilege.rs
@@ -115,6 +115,7 @@ pub(crate) fn resolve_privileges(stmt: &BoundStatement) -> Vec<ObjectCheckItem>
objects.push(object);
}
BoundStatement::Query(ref query) => objects.extend(resolve_query_privileges(query)),
BoundStatement::FetchCursor(_) => unimplemented!(),
BoundStatement::CreateView(ref create_view) => {
objects.extend(resolve_query_privileges(&create_view.query))
}
5 changes: 2 additions & 3 deletions src/frontend/src/handler/query.rs
@@ -516,7 +516,7 @@ async fn execute(
.into())
}

async fn distribute_execute(
pub async fn distribute_execute(
session: Arc<SessionImpl>,
query: Query,
can_timeout_cancel: bool,
@@ -538,8 +538,7 @@ async fn distribute_execute(
.map_err(|err| err.into())
}

#[expect(clippy::unused_async)]
async fn local_execute(
pub async fn local_execute(
session: Arc<SessionImpl>,
query: Query,
can_timeout_cancel: bool,