Skip to content

Commit

Permalink
add close cursor
Browse files Browse the repository at this point in the history
  • Loading branch information
xxhZs committed Feb 21, 2024
1 parent 4192f18 commit 49ba7ba
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 16 deletions.
37 changes: 37 additions & 0 deletions src/frontend/src/handler/close_cursor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use pgwire::pg_response::{PgResponse, StatementType};
use risingwave_sqlparser::ast::CloseCursorStatement;

use super::{HandlerArgs, RwPgResponse};
use crate::error::Result;
use crate::Binder;

pub async fn handle_close_cursor(
handle_args: HandlerArgs,
stmt: CloseCursorStatement,
) -> Result<RwPgResponse> {
let session = handle_args.session.clone();
let db_name = session.database();
let (_, cursor_name) =
Binder::resolve_schema_qualified_name(db_name, stmt.cursor_name.clone())?;
session
.get_cursor_manager()
.lock()
.await
.remove_cursor(cursor_name)?;

Ok(PgResponse::empty_result(StatementType::CLOSE_CURSOR))
}
8 changes: 7 additions & 1 deletion src/frontend/src/handler/declare_cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use core::time::Duration;

use pgwire::pg_response::{PgResponse, StatementType};
use pgwire::types::Format;
use risingwave_sqlparser::ast::{DeclareCursorStatement, Ident, ObjectName, Statement};
Expand Down Expand Up @@ -45,7 +47,10 @@ pub async fn handle_declare_cursor(
let is_snapshot = start_rw_timestamp == 0;
let subscription =
session.get_subscription_by_name(schema_name, &cursor_from_subscription_name)?;
// let retention_seconds = subscription.get_retention_seconds()?;
let cursor_retention_secs = std::cmp::min(
session.statement_timeout(),
Duration::from_secs(subscription.get_retention_seconds()?),
);
let (start_rw_timestamp, res) = if is_snapshot {
let subscription_from_table_name = ObjectName(vec![Ident::from(
subscription.subscription_from_name.as_ref(),
Expand All @@ -72,6 +77,7 @@ pub async fn handle_declare_cursor(
is_snapshot,
true,
stmt.cursor_from.clone(),
cursor_retention_secs,
)
.await?;
session
Expand Down
18 changes: 14 additions & 4 deletions src/frontend/src/handler/fetch_cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ pub async fn handle_fetch_cursor(
CursorRowValue::Row((row, pg_descs)) => {
return Ok(build_fetch_cursor_response(vec![row], pg_descs));
}
CursorRowValue::QueryWithNextRwTimestamp(rw_timestamp, subscription_name) => {
CursorRowValue::QueryWithNextRwTimestamp(
rw_timestamp,
subscription_name,
cursor_retention_secs,
) => {
let query_stmt =
gen_query_from_logstore_ge_rw_timestamp(subscription_name.clone(), rw_timestamp)?;
let res = handle_query(handle_args, query_stmt, formats).await?;
Expand All @@ -54,11 +58,16 @@ pub async fn handle_fetch_cursor(
false,
true,
subscription_name.clone(),
cursor_retention_secs,
)
.await?;
cursor_manager.update_cursor(cursor)?;
}
CursorRowValue::QueryWithStartRwTimestamp(rw_timestamp, subscription_name) => {
CursorRowValue::QueryWithStartRwTimestamp(
rw_timestamp,
subscription_name,
cursor_retention_secs,
) => {
let query_stmt = gen_query_from_logstore_ge_rw_timestamp(
subscription_name.clone(),
rw_timestamp + 1,
Expand All @@ -72,6 +81,7 @@ pub async fn handle_fetch_cursor(
false,
false,
subscription_name.clone(),
cursor_retention_secs,
)
.await?;
cursor_manager.update_cursor(cursor)?;
Expand All @@ -82,10 +92,10 @@ pub async fn handle_fetch_cursor(
CursorRowValue::Row((row, pg_descs)) => {
Ok(build_fetch_cursor_response(vec![row], pg_descs))
}
CursorRowValue::QueryWithStartRwTimestamp(_, _) => {
CursorRowValue::QueryWithStartRwTimestamp(_, _, _) => {
Ok(build_fetch_cursor_response(vec![], vec![]))
}
CursorRowValue::QueryWithNextRwTimestamp(_, _) => Err(ErrorCode::InternalError(
CursorRowValue::QueryWithNextRwTimestamp(_, _, _) => Err(ErrorCode::InternalError(
"Fetch cursor, one must get a row or null".to_string(),
)
.into()),
Expand Down
4 changes: 4 additions & 0 deletions src/frontend/src/handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ mod alter_system;
mod alter_table_column;
pub mod alter_user;
pub mod cancel_job;
pub mod close_cursor;
mod comment;
pub mod create_connection;
mod create_database;
Expand Down Expand Up @@ -318,6 +319,9 @@ pub async fn handle(
Statement::FetchCursor { stmt } => {
fetch_cursor::handle_fetch_cursor(handler_args, stmt, formats).await
}
Statement::CloseCursor { stmt } => {
close_cursor::handle_close_cursor(handler_args, stmt).await
}
Statement::AlterUser(stmt) => alter_user::handle_alter_user(handler_args, stmt).await,
Statement::Grant { .. } => {
handle_privilege::handle_grant_privilege(handler_args, stmt).await
Expand Down
28 changes: 18 additions & 10 deletions src/frontend/src/session/cursor_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
// limitations under the License.

use core::ops::Index;
use core::time::Duration;
use std::collections::{HashMap, VecDeque};
use std::time::Instant;

use bytes::Bytes;
use futures::StreamExt;
Expand All @@ -34,6 +36,8 @@ pub struct Cursor {
is_snapshot: bool,
subscription_name: ObjectName,
pg_desc: Vec<PgFieldDescriptor>,
cursor_need_drop_time: Instant,
cursor_retention_secs: Duration,
}

impl Cursor {
Expand All @@ -44,6 +48,7 @@ impl Cursor {
is_snapshot: bool,
need_check_timestamp: bool,
subscription_name: ObjectName,
cursor_retention_secs: Duration,
) -> Result<Self> {
let (rw_timestamp, data_chunk_cache) = if is_snapshot {
(start_timestamp, vec![])
Expand All @@ -62,7 +67,6 @@ impl Cursor {
let query_timestamp = data_chunk_cache
.get(0)
.map(|row| {
println!("123");
row.index(0)
.as_ref()
.map(|bytes| std::str::from_utf8(bytes).unwrap().parse().unwrap())
Expand All @@ -79,19 +83,27 @@ impl Cursor {
(query_timestamp, data_chunk_cache)
};
let pg_desc = build_desc(rw_pg_response.row_desc(), is_snapshot);
// check timestamp.
let cursor_need_drop_time = Instant::now() + cursor_retention_secs;
Ok(Self {
cursor_name,
rw_pg_response,
data_chunk_cache: VecDeque::from(data_chunk_cache),
rw_timestamp,
is_snapshot,
subscription_name,
cursor_need_drop_time,
cursor_retention_secs,
pg_desc,
})
}

pub async fn next(&mut self) -> Result<CursorRowValue> {
if Instant::now() > self.cursor_need_drop_time {
return Err(ErrorCode::InternalError(
"The cursor has exceeded its maximum lifetime, please recreate it.".to_string(),
)
.into());
}
let stream = self.rw_pg_response.values_stream();
loop {
if self.data_chunk_cache.is_empty() {
Expand All @@ -106,11 +118,10 @@ impl Cursor {
return Ok(CursorRowValue::QueryWithStartRwTimestamp(
self.rw_timestamp,
self.subscription_name.clone(),
self.cursor_retention_secs,
));
}
}
println!("desc:{:?}", self.pg_desc);
println!("data_chunk_cache:{:?}", self.data_chunk_cache);
if let Some(row) = self.data_chunk_cache.pop_front() {
let new_row = row.take();
if self.is_snapshot {
Expand All @@ -127,14 +138,11 @@ impl Cursor {
.map(|bytes| std::str::from_utf8(bytes).unwrap().parse().unwrap())
.unwrap();

println!(
"timestamp_row:{:?},self.rw_timestamp{:?}",
timestamp_row, self.rw_timestamp
);
if timestamp_row != self.rw_timestamp {
return Ok(CursorRowValue::QueryWithNextRwTimestamp(
timestamp_row,
self.subscription_name.clone(),
self.cursor_retention_secs,
));
} else {
return Ok(CursorRowValue::Row((
Expand Down Expand Up @@ -193,8 +201,8 @@ pub fn build_desc(mut descs: Vec<PgFieldDescriptor>, is_snapshot: bool) -> Vec<P

pub enum CursorRowValue {
Row((Row, Vec<PgFieldDescriptor>)),
QueryWithNextRwTimestamp(i64, ObjectName),
QueryWithStartRwTimestamp(i64, ObjectName),
QueryWithNextRwTimestamp(i64, ObjectName, Duration),
QueryWithStartRwTimestamp(i64, ObjectName, Duration),
}
#[derive(Default)]
pub struct CursorManager {
Expand Down
8 changes: 7 additions & 1 deletion src/sqlparser/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,11 @@ pub enum Statement {
stmt: FetchCursorStatement,
},

// CLOSE CURSOR
CloseCursor {
stmt: CloseCursorStatement,
},

/// ALTER DATABASE
AlterDatabase {
name: ObjectName,
Expand Down Expand Up @@ -1740,7 +1745,8 @@ impl fmt::Display for Statement {
Statement::CreateSubscription { stmt } => write!(f, "CREATE SUBSCRIPTION {}", stmt,),
Statement::CreateConnection { stmt } => write!(f, "CREATE CONNECTION {}", stmt,),
Statement::DeclareCursor { stmt } => write!(f, "DECLARE CURSOR {}", stmt,),
Statement::FetchCursor { stmt } => write!(f, "DECLARE {}", stmt),
Statement::FetchCursor { stmt } => write!(f, "FETCH {}", stmt),
Statement::CloseCursor { stmt } => write!(f, "CLOSE {}", stmt),
Statement::AlterDatabase { name, operation } => {
write!(f, "ALTER DATABASE {} {}", name, operation)
}
Expand Down
24 changes: 24 additions & 0 deletions src/sqlparser/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,30 @@ impl fmt::Display for FetchCursorStatement {
}
}

// sql_grammar!(CloseCursorStatement {
// cursor_name: Ident,
// });
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CloseCursorStatement {
pub cursor_name: ObjectName,
}

impl ParseTo for CloseCursorStatement {
fn parse_to(p: &mut Parser) -> Result<Self, ParserError> {
impl_parse_to!(cursor_name: ObjectName, p);

Ok(Self { cursor_name })
}
}
impl fmt::Display for CloseCursorStatement {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut v: Vec<String> = vec![];
impl_fmt_display!(cursor_name, v, self);
v.iter().join(" ").fmt(f)
}
}

// sql_grammar!(CreateConnectionStatement {
// if_not_exists => [Keyword::IF, Keyword::NOT, Keyword::EXISTS],
// connection_name: Ident,
Expand Down
7 changes: 7 additions & 0 deletions src/sqlparser/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ impl Parser {
}
Keyword::DECLARE => Ok(self.parse_declare()?),
Keyword::FETCH => Ok(self.parse_fetch_cursor()?),
Keyword::CLOSE => Ok(self.parse_close_cursor()?),
Keyword::TRUNCATE => Ok(self.parse_truncate()?),
Keyword::CREATE => Ok(self.parse_create()?),
Keyword::DROP => Ok(self.parse_drop()?),
Expand Down Expand Up @@ -2307,6 +2308,12 @@ impl Parser {
})
}

pub fn parse_close_cursor(&mut self) -> Result<Statement, ParserError> {
Ok(Statement::CloseCursor {
stmt: CloseCursorStatement::parse_to(self)?,
})
}

fn parse_table_column_def(&mut self) -> Result<TableColumnDef, ParserError> {
Ok(TableColumnDef {
name: self.parse_identifier_non_reserved()?,
Expand Down
1 change: 1 addition & 0 deletions src/utils/pgwire/src/pg_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub enum StatementType {
FETCH,
COPY,
EXPLAIN,
CLOSE_CURSOR,
CREATE_TABLE,
CREATE_MATERIALIZED_VIEW,
CREATE_VIEW,
Expand Down

0 comments on commit 49ba7ba

Please sign in to comment.