diff --git a/src/frontend/planner_test/src/lib.rs b/src/frontend/planner_test/src/lib.rs index 9c8db9dde51ee..3e6ebc7ef4322 100644 --- a/src/frontend/planner_test/src/lib.rs +++ b/src/frontend/planner_test/src/lib.rs @@ -282,7 +282,12 @@ impl TestCase { .chain(std::iter::once(self.sql())) { result = self - .run_sql(sql, session.clone(), do_check_result, result) + .run_sql( + Arc::from(sql.to_owned()), + session.clone(), + do_check_result, + result, + ) .await?; } @@ -326,7 +331,7 @@ impl TestCase { ); let temp_file = create_proto_file(content.as_str()); self.run_sql( - &(sql + temp_file.path().to_str().unwrap() + "')"), + Arc::from(sql + temp_file.path().to_str().unwrap() + "')"), session.clone(), false, None, @@ -357,7 +362,7 @@ impl TestCase { ); let temp_file = create_proto_file(content.as_str()); self.run_sql( - &(sql + temp_file.path().to_str().unwrap() + "')"), + Arc::from(sql + temp_file.path().to_str().unwrap() + "')"), session.clone(), false, None, @@ -376,15 +381,15 @@ impl TestCase { async fn run_sql( &self, - sql: &str, + sql: Arc, session: Arc, do_check_result: bool, mut result: Option, ) -> Result> { - let statements = Parser::parse_sql(sql).unwrap(); + let statements = Parser::parse_sql(&sql).unwrap(); for stmt in statements { // TODO: `sql` may contain multiple statements here. - let handler_args = HandlerArgs::new(session.clone(), &stmt, sql)?; + let handler_args = HandlerArgs::new(session.clone(), &stmt, sql.clone())?; let _guard = session.txn_begin_implicit(); match stmt.clone() { Statement::Query(_) @@ -399,7 +404,7 @@ impl TestCase { ..Default::default() }; let context = OptimizerContext::new( - HandlerArgs::new(session.clone(), &stmt, sql)?, + HandlerArgs::new(session.clone(), &stmt, sql.clone())?, explain_options, ); let ret = self.apply_query(&stmt, context.into())?; diff --git a/src/frontend/src/handler/alter_table_column.rs b/src/frontend/src/handler/alter_table_column.rs index 1b3babc41ceaf..f75c4043290df 100644 --- a/src/frontend/src/handler/alter_table_column.rs +++ b/src/frontend/src/handler/alter_table_column.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use anyhow::Context; use itertools::Itertools; use pgwire::pg_response::{PgResponse, StatementType}; @@ -172,7 +174,7 @@ pub async fn handle_alter_table_column( } // Create handler args as if we're creating a new table with the altered definition. - let handler_args = HandlerArgs::new(session.clone(), &definition, "")?; + let handler_args = HandlerArgs::new(session.clone(), &definition, Arc::from(""))?; let col_id_gen = ColumnIdGenerator::new_alter(&original_catalog); let Statement::CreateTable { columns, diff --git a/src/frontend/src/handler/extended_handle.rs b/src/frontend/src/handler/extended_handle.rs index 1c0f0d36f0cbf..d6f22984f404e 100644 --- a/src/frontend/src/handler/extended_handle.rs +++ b/src/frontend/src/handler/extended_handle.rs @@ -57,11 +57,7 @@ impl std::fmt::Display for Portal { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match &self { Portal::Empty => write!(f, "Empty"), - Portal::Portal(portal) => write!( - f, - "{}, params = {:?}", - portal.statement, portal.bound_result.parsed_params - ), + Portal::Portal(portal) => portal.fmt(f), Portal::PureStatement(stmt) => write!(f, "{}", stmt), } } @@ -74,14 +70,24 @@ pub struct PortalResult { pub result_formats: Vec, } +impl std::fmt::Display for PortalResult { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!( + f, + "{}, params = {:?}", + self.statement, self.bound_result.parsed_params + ) + } +} + pub fn handle_parse( session: Arc, statement: Statement, specific_param_types: Vec>, ) -> Result { session.clear_cancel_query_flag(); - let str_sql = statement.to_string(); - let handler_args = HandlerArgs::new(session, &statement, &str_sql)?; + let sql: Arc = Arc::from(statement.to_string()); + let handler_args = HandlerArgs::new(session, &statement, sql)?; match &statement { Statement::Query(_) | Statement::Insert { .. } @@ -181,8 +187,8 @@ pub async fn handle_execute(session: Arc, portal: Portal) -> Result Portal::Portal(portal) => { session.clear_cancel_query_flag(); let _guard = session.txn_begin_implicit(); // TODO(bugen): is this behavior correct? - let str_sql = portal.statement.to_string(); - let handler_args = HandlerArgs::new(session, &portal.statement, &str_sql)?; + let sql: Arc = Arc::from(portal.statement.to_string()); + let handler_args = HandlerArgs::new(session, &portal.statement, sql)?; match &portal.statement { Statement::Query(_) | Statement::Insert { .. } @@ -192,8 +198,8 @@ pub async fn handle_execute(session: Arc, portal: Portal) -> Result } } Portal::PureStatement(stmt) => { - let sql = stmt.to_string(); - handle(session, stmt, &sql, vec![]).await + let sql: Arc = Arc::from(stmt.to_string()); + handle(session, stmt, sql, vec![]).await } } } diff --git a/src/frontend/src/handler/mod.rs b/src/frontend/src/handler/mod.rs index 748269a0bd58a..de8f048660f4a 100644 --- a/src/frontend/src/handler/mod.rs +++ b/src/frontend/src/handler/mod.rs @@ -110,16 +110,16 @@ impl From> for PgResponseStream { #[derive(Clone)] pub struct HandlerArgs { pub session: Arc, - pub sql: String, + pub sql: Arc, pub normalized_sql: String, pub with_options: WithOptions, } impl HandlerArgs { - pub fn new(session: Arc, stmt: &Statement, sql: &str) -> Result { + pub fn new(session: Arc, stmt: &Statement, sql: Arc) -> Result { Ok(Self { session, - sql: sql.into(), + sql, with_options: WithOptions::try_from(stmt)?, normalized_sql: Self::normalize_sql(stmt), }) @@ -172,12 +172,11 @@ impl HandlerArgs { pub async fn handle( session: Arc, stmt: Statement, - sql: &str, + sql: Arc, formats: Vec, ) -> Result { session.clear_cancel_query_flag(); let _guard = session.txn_begin_implicit(); - let handler_args = HandlerArgs::new(session, &stmt, sql)?; match stmt { diff --git a/src/frontend/src/handler/show.rs b/src/frontend/src/handler/show.rs index 88a9b1e694e33..2957519e39dff 100644 --- a/src/frontend/src/handler/show.rs +++ b/src/frontend/src/handler/show.rs @@ -16,7 +16,9 @@ use std::sync::Arc; use itertools::Itertools; use pgwire::pg_field_descriptor::PgFieldDescriptor; +use pgwire::pg_protocol::truncated_fmt; use pgwire::pg_response::{PgResponse, StatementType}; +use pgwire::pg_server::Session; use pgwire::types::Row; use risingwave_common::catalog::{ColumnCatalog, DEFAULT_SCHEMA_NAME}; use risingwave_common::error::{ErrorCode, Result}; @@ -267,6 +269,32 @@ pub async fn handle_show_object( .values(rows.into(), row_desc) .into()); } + ShowObject::ProcessList => { + let rows = { + let sessions_map = session.env().sessions_map(); + sessions_map + .read() + .values() + .map(|s| { + Row::new(vec![ + Some(format!("{}-{}", s.id().0, s.id().1).into()), + Some(s.user_name().to_owned().into()), + Some(format!("{}", s.peer_addr()).into()), + Some(s.database().to_owned().into()), + s.elapse_since_running_sql() + .map(|mills| format!("{}ms", mills).into()), + s.running_sql().map(|sql| { + format!("{}", truncated_fmt::TruncatedFmt(&sql, 1024)).into() + }), + ]) + }) + .collect_vec() + }; + + return Ok(PgResponse::builder(StatementType::SHOW_COMMAND) + .values(rows.into(), row_desc) + .into()); + } }; let rows = names diff --git a/src/frontend/src/optimizer/optimizer_context.rs b/src/frontend/src/optimizer/optimizer_context.rs index dcb4b74464b37..e4b8d3c566813 100644 --- a/src/frontend/src/optimizer/optimizer_context.rs +++ b/src/frontend/src/optimizer/optimizer_context.rs @@ -33,7 +33,7 @@ pub struct OptimizerContext { /// Store plan node id next_plan_node_id: RefCell, /// The original SQL string, used for debugging. - sql: String, + sql: Arc, /// Normalized SQL string. See [`HandlerArgs::normalize_sql`]. normalized_sql: String, /// Explain options @@ -97,7 +97,7 @@ impl OptimizerContext { Self { session_ctx: Arc::new(SessionImpl::mock()), next_plan_node_id: RefCell::new(0), - sql: "".to_owned(), + sql: Arc::from(""), normalized_sql: "".to_owned(), explain_options: ExplainOptions::default(), optimizer_trace: RefCell::new(vec![]), diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 0170057024621..201882cc6416c 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -14,16 +14,22 @@ use std::collections::HashMap; use std::io::{Error, ErrorKind}; +#[cfg(test)] +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::atomic::{AtomicI32, Ordering}; -use std::sync::Arc; -use std::time::Duration; +use std::sync::{Arc, Weak}; +use std::time::{Duration, Instant}; use bytes::Bytes; use parking_lot::{Mutex, RwLock, RwLockReadGuard}; +use pgwire::net::{Address, AddressRef}; use pgwire::pg_field_descriptor::PgFieldDescriptor; use pgwire::pg_message::TransactionStatus; use pgwire::pg_response::PgResponse; -use pgwire::pg_server::{BoxedError, Session, SessionId, SessionManager, UserAuthenticator}; +use pgwire::pg_server::{ + BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager, + UserAuthenticator, +}; use pgwire::types::{Format, FormatIterator}; use rand::RngCore; use risingwave_batch::task::{ShutdownSender, ShutdownToken}; @@ -132,7 +138,7 @@ pub struct FrontendEnv { } /// Session map identified by `(process_id, secret_key)` -type SessionMapRef = Arc>>>; +type SessionMapRef = Arc>>>; impl FrontendEnv { pub fn mock() -> Self { @@ -169,7 +175,7 @@ impl FrontendEnv { hummock_snapshot_manager, server_addr, client_pool, - sessions_map: Arc::new(Mutex::new(HashMap::new())), + sessions_map: Arc::new(RwLock::new(HashMap::new())), frontend_metrics: Arc::new(FrontendMetrics::for_test()), batch_config: BatchConfig::default(), meta_config: MetaConfig::default(), @@ -329,7 +335,7 @@ impl FrontendEnv { server_addr: frontend_address, client_pool, frontend_metrics, - sessions_map: Arc::new(Mutex::new(HashMap::new())), + sessions_map: Arc::new(RwLock::new(HashMap::new())), batch_config, meta_config, source_metrics, @@ -415,6 +421,10 @@ impl FrontendEnv { &self.creating_streaming_job_tracker } + pub fn sessions_map(&self) -> &SessionMapRef { + &self.sessions_map + } + pub fn compute_runtime(&self) -> Arc { self.compute_runtime.clone() } @@ -450,7 +460,7 @@ impl AuthContext { pub struct SessionImpl { env: FrontendEnv, auth_context: Arc, - // Used for user authentication. + /// Used for user authentication. user_authenticator: UserAuthenticator, /// Stores the value of configurations. config_map: Arc>, @@ -460,15 +470,21 @@ pub struct SessionImpl { /// Identified by process_id, secret_key. Corresponds to SessionManager. id: (i32, i32), + /// Client address + peer_addr: AddressRef, + /// Transaction state. - // TODO: get rid of the `Mutex` here as a workaround if the `Send` requirement of - // async functions, there should actually be no contention. + /// TODO: get rid of the `Mutex` here as a workaround if the `Send` requirement of + /// async functions, there should actually be no contention. txn: Arc>, /// Query cancel flag. /// This flag is set only when current query is executed in local mode, and used to cancel /// local query. current_query_cancel_flag: Mutex>, + + /// execution context represents the lifetime of a running SQL in the current session + exec_context: Mutex>>, } #[derive(Error, Debug)] @@ -494,6 +510,7 @@ impl SessionImpl { auth_context: Arc, user_authenticator: UserAuthenticator, id: SessionId, + peer_addr: AddressRef, ) -> Self { Self { env, @@ -501,9 +518,11 @@ impl SessionImpl { user_authenticator, config_map: Default::default(), id, + peer_addr, txn: Default::default(), current_query_cancel_flag: Mutex::new(None), notices: Default::default(), + exec_context: Mutex::new(None), } } @@ -523,6 +542,12 @@ impl SessionImpl { txn: Default::default(), current_query_cancel_flag: Mutex::new(None), notices: Default::default(), + exec_context: Mutex::new(None), + peer_addr: Address::Tcp(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + 8080, + )) + .into(), } } @@ -571,6 +596,26 @@ impl SessionImpl { self.id } + pub fn running_sql(&self) -> Option> { + self.exec_context + .lock() + .as_ref() + .and_then(|weak| weak.upgrade()) + .map(|context| context.running_sql.clone()) + } + + pub fn peer_addr(&self) -> &Address { + &self.peer_addr + } + + pub fn elapse_since_running_sql(&self) -> Option { + self.exec_context + .lock() + .as_ref() + .and_then(|weak| weak.upgrade()) + .map(|context| context.last_instant.elapsed().as_millis()) + } + pub fn check_relation_name_duplicated( &self, name: ObjectName, @@ -710,11 +755,11 @@ impl SessionImpl { /// Maybe we can remove it in the future. pub async fn run_statement( self: Arc, - sql: &str, + sql: Arc, formats: Vec, ) -> std::result::Result, BoxedError> { // Parse sql. - let mut stmts = Parser::parse_sql(sql) + let mut stmts = Parser::parse_sql(&sql) .inspect_err(|e| tracing::error!("failed to parse sql:\n{}:\n{}", sql, e))?; if stmts.is_empty() { return Ok(PgResponse::empty_result( @@ -730,7 +775,7 @@ impl SessionImpl { } let stmt = stmts.swap_remove(0); let rsp = { - let mut handle_fut = Box::pin(handle(self, stmt, sql, formats)); + let mut handle_fut = Box::pin(handle(self, stmt, sql.clone(), formats)); if cfg!(debug_assertions) { // Report the SQL in the log periodically if the query is slow. const SLOW_QUERY_LOG_PERIOD: Duration = Duration::from_secs(60); @@ -740,7 +785,7 @@ impl SessionImpl { Ok(result) => break result, Err(_) => tracing::warn!( target: SLOW_QUERY_LOG, - sql, + sql = sql.as_ref(), "slow query has been running for another {SLOW_QUERY_LOG_PERIOD:?}" ), } @@ -782,6 +827,7 @@ impl SessionManager for SessionManagerImpl { &self, database: &str, user_name: &str, + peer_addr: AddressRef, ) -> std::result::Result, BoxedError> { let catalog_reader = self.env.catalog_reader(); let reader = catalog_reader.read_guard(); @@ -849,6 +895,7 @@ impl SessionManager for SessionManagerImpl { )), user_authenticator, id, + peer_addr, ) .into(); self.insert_session(session_impl.clone()); @@ -864,7 +911,7 @@ impl SessionManager for SessionManagerImpl { /// Used when cancel request happened. fn cancel_queries_in_session(&self, session_id: SessionId) { - let guard = self.env.sessions_map.lock(); + let guard = self.env.sessions_map.read(); if let Some(session) = guard.get(&session_id) { session.cancel_current_query() } else { @@ -873,7 +920,7 @@ impl SessionManager for SessionManagerImpl { } fn cancel_creating_jobs_in_session(&self, session_id: SessionId) { - let guard = self.env.sessions_map.lock(); + let guard = self.env.sessions_map.read(); if let Some(session) = guard.get(&session_id) { session.cancel_current_creating_job() } else { @@ -899,7 +946,7 @@ impl SessionManagerImpl { fn insert_session(&self, session: Arc) { let active_sessions = { - let mut write_guard = self.env.sessions_map.lock(); + let mut write_guard = self.env.sessions_map.write(); write_guard.insert(session.id(), session); write_guard.len() }; @@ -911,7 +958,7 @@ impl SessionManagerImpl { fn delete_session(&self, session_id: &SessionId) { let active_sessions = { - let mut write_guard = self.env.sessions_map.lock(); + let mut write_guard = self.env.sessions_map.write(); write_guard.remove(session_id); write_guard.len() }; @@ -934,9 +981,11 @@ impl Session for SessionImpl { stmt: Statement, format: Format, ) -> std::result::Result, BoxedError> { - let sql_str = stmt.to_string(); + let string = stmt.to_string(); + let sql_str = string.as_str(); + let sql: Arc = Arc::from(sql_str); let rsp = { - let mut handle_fut = Box::pin(handle(self, stmt, &sql_str, vec![format])); + let mut handle_fut = Box::pin(handle(self, stmt, sql.clone(), vec![format])); if cfg!(debug_assertions) { // Report the SQL in the log periodically if the query is slow. const SLOW_QUERY_LOG_PERIOD: Duration = Duration::from_secs(60); @@ -953,7 +1002,7 @@ impl Session for SessionImpl { handle_fut.await } } - .inspect_err(|e| tracing::error!("failed to handle sql:\n{}:\n{}", sql_str, e))?; + .inspect_err(|e| tracing::error!("failed to handle sql:\n{}:\n{}", sql, e))?; Ok(rsp) } @@ -1072,6 +1121,16 @@ impl Session for SessionImpl { // TODO: failed transaction } } + + /// Init and return an `ExecContextGuard` which could be used as a guard to represent the execution flow. + fn init_exec_context(&self, sql: Arc) -> ExecContextGuard { + let exec_context = Arc::new(ExecContext { + running_sql: sql, + last_instant: Instant::now(), + }); + *self.exec_context.lock() = Some(Arc::downgrade(&exec_context)); + ExecContextGuard::new(exec_context) + } } /// Returns row description of the statement diff --git a/src/frontend/src/test_utils.rs b/src/frontend/src/test_utils.rs index c2c6840c8dd77..4e336b29c002d 100644 --- a/src/frontend/src/test_utils.rs +++ b/src/frontend/src/test_utils.rs @@ -14,11 +14,13 @@ use std::collections::HashMap; use std::io::Write; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use futures_async_stream::for_await; use parking_lot::RwLock; +use pgwire::net::{Address, AddressRef}; use pgwire::pg_response::StatementType; use pgwire::pg_server::{BoxedError, SessionId, SessionManager, UserAuthenticator}; use pgwire::types::Row; @@ -76,6 +78,7 @@ impl SessionManager for LocalFrontend { &self, _database: &str, _user_name: &str, + _peer_addr: AddressRef, ) -> std::result::Result, BoxedError> { Ok(self.session_ref()) } @@ -104,8 +107,8 @@ impl LocalFrontend { &self, sql: impl Into, ) -> std::result::Result> { - let sql = sql.into(); - self.session_ref().run_statement(sql.as_str(), vec![]).await + let sql: Arc = Arc::from(sql.into()); + self.session_ref().run_statement(sql, vec![]).await } pub async fn run_sql_with_session( @@ -113,8 +116,8 @@ impl LocalFrontend { session_ref: Arc, sql: impl Into, ) -> std::result::Result> { - let sql = sql.into(); - session_ref.run_statement(sql.as_str(), vec![]).await + let sql: Arc = Arc::from(sql.into()); + session_ref.run_statement(sql, vec![]).await } pub async fn run_user_sql( @@ -124,9 +127,9 @@ impl LocalFrontend { user_name: String, user_id: UserId, ) -> std::result::Result> { - let sql = sql.into(); + let sql: Arc = Arc::from(sql.into()); self.session_user_ref(database, user_name, user_id) - .run_statement(sql.as_str(), vec![]) + .run_statement(sql, vec![]) .await } @@ -178,6 +181,11 @@ impl LocalFrontend { UserAuthenticator::None, // Local Frontend use a non-sense id. (0, 0), + Address::Tcp(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + 6666, + )) + .into(), )) } } diff --git a/src/frontend/src/utils/infer_stmt_row_desc.rs b/src/frontend/src/utils/infer_stmt_row_desc.rs index 8ebb7ac5c7d7a..2dc4115e305c0 100644 --- a/src/frontend/src/utils/infer_stmt_row_desc.rs +++ b/src/frontend/src/utils/infer_stmt_row_desc.rs @@ -161,6 +161,38 @@ pub fn infer_show_object(objects: &ShowObject) -> Vec { DataType::Varchar.type_len(), ), ], + ShowObject::ProcessList => vec![ + PgFieldDescriptor::new( + "Id".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + ), + PgFieldDescriptor::new( + "User".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + ), + PgFieldDescriptor::new( + "Host".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + ), + PgFieldDescriptor::new( + "Database".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + ), + PgFieldDescriptor::new( + "Time".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + ), + PgFieldDescriptor::new( + "Info".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + ), + ], _ => vec![PgFieldDescriptor::new( "Name".to_owned(), DataType::Varchar.to_oid(), diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index d96016334b72c..757fb7f2a237b 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -858,6 +858,7 @@ pub enum ShowObject { Indexes { table: ObjectName }, Cluster, Jobs, + ProcessList, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -899,6 +900,7 @@ impl fmt::Display for ShowObject { write!(f, "CLUSTER") } ShowObject::Jobs => write!(f, "JOBS"), + ShowObject::ProcessList => write!(f, "PROCESSLIST"), } } } diff --git a/src/sqlparser/src/keywords.rs b/src/sqlparser/src/keywords.rs index 4188f06f76ae3..4efdce19605cd 100644 --- a/src/sqlparser/src/keywords.rs +++ b/src/sqlparser/src/keywords.rs @@ -383,6 +383,7 @@ define_keywords!( PRIMARY, PRIVILEGES, PROCEDURE, + PROCESSLIST, PURGE, RANGE, RANK, diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 70050d920cdf3..85932fb65e5f3 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -4090,6 +4090,12 @@ impl Parser { filter: self.parse_show_statement_filter()?, }); } + Keyword::PROCESSLIST => { + return Ok(Statement::ShowObjects { + object: ShowObject::ProcessList, + filter: self.parse_show_statement_filter()?, + }); + } _ => {} } } diff --git a/src/tests/sqlsmith/tests/frontend/mod.rs b/src/tests/sqlsmith/tests/frontend/mod.rs index 8f681ab38a956..ec34c83a4405d 100644 --- a/src/tests/sqlsmith/tests/frontend/mod.rs +++ b/src/tests/sqlsmith/tests/frontend/mod.rs @@ -44,7 +44,7 @@ pub struct SqlsmithEnv { /// Returns `Ok(true)` if query result was ignored. /// Skip status is required, so that we know if a SQL statement writing to the database was skipped. /// Then, we can infer the correct state of the database. -async fn handle(session: Arc, stmt: Statement, sql: &str) -> Result { +async fn handle(session: Arc, stmt: Statement, sql: Arc) -> Result { let result = handler::handle(session.clone(), stmt, sql, vec![]) .await .map(|_| ()) @@ -97,18 +97,19 @@ async fn create_tables( let (mut tables, statements) = parse_create_table_statements(sql); for s in statements { - let create_sql = s.to_string(); - handle(session.clone(), s, &create_sql).await?; + let create_sql: Arc = Arc::from(s.to_string()); + handle(session.clone(), s, create_sql).await?; } // Generate some mviews for i in 0..20 { let (sql, table) = mview_sql_gen(rng, tables.clone(), &format!("m{}", i)); + let sql: Arc = Arc::from(sql); reproduce_failing_queries(&setup_sql, &sql); setup_sql.push_str(&format!("{};", &sql)); let stmts = parse_sql(&sql); let stmt = stmts[0].clone(); - let skipped = handle(session.clone(), stmt, &sql).await?; + let skipped = handle(session.clone(), stmt, sql).await?; if !skipped { tables.push(table); } @@ -158,15 +159,16 @@ async fn test_stream_query( } let (sql, table) = mview_sql_gen(&mut rng, tables.clone(), "stream_query"); + let sql: Arc = Arc::from(sql); reproduce_failing_queries(setup_sql, &sql); // The generated SQL must be parsable. let stmt = round_trip_parse_test(&sql)?; - let skipped = handle(session.clone(), stmt, &sql).await?; + let skipped = handle(session.clone(), stmt, sql).await?; if !skipped { - let drop_sql = format!("DROP MATERIALIZED VIEW {}", table.name); + let drop_sql: Arc = Arc::from(format!("DROP MATERIALIZED VIEW {}", table.name)); let drop_stmts = parse_sql(&drop_sql); let drop_stmt = drop_stmts[0].clone(); - handle(session.clone(), drop_stmt, &drop_sql).await?; + handle(session.clone(), drop_stmt, drop_sql).await?; } Ok(()) } @@ -215,13 +217,13 @@ fn test_batch_query( rng = SmallRng::seed_from_u64(seed); } - let sql = sql_gen(&mut rng, tables); + let sql: Arc = Arc::from(sql_gen(&mut rng, tables)); reproduce_failing_queries(setup_sql, &sql); // The generated SQL must be parsable. let stmt = round_trip_parse_test(&sql)?; let context: OptimizerContextRef = - OptimizerContext::from_handler_args(HandlerArgs::new(session.clone(), &stmt, &sql)?).into(); + OptimizerContext::from_handler_args(HandlerArgs::new(session.clone(), &stmt, sql)?).into(); match stmt { Statement::Query(_) => { diff --git a/src/utils/pgwire/src/net.rs b/src/utils/pgwire/src/net.rs index ce341dec3e742..7b2d9f76d8ebf 100644 --- a/src/utils/pgwire/src/net.rs +++ b/src/utils/pgwire/src/net.rs @@ -16,6 +16,7 @@ use std::io; use std::net::SocketAddr as IpSocketAddr; #[cfg(madsim)] use std::os::unix::net::SocketAddr as UnixSocketAddr; +use std::sync::Arc; #[cfg(not(madsim))] use tokio::net::unix::SocketAddr as UnixSocketAddr; @@ -35,11 +36,13 @@ pub(crate) enum Stream { } /// A wrapper of either [`std::net::SocketAddr`] or [`tokio::net::unix::SocketAddr`]. -pub(crate) enum Address { +pub enum Address { Tcp(IpSocketAddr), Unix(UnixSocketAddr), } +pub type AddressRef = Arc
; + impl std::fmt::Display for Address { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index c1ff5db59be64..8bdbd90d17360 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -35,6 +35,7 @@ use tokio_openssl::SslStream; use tracing::{error, warn, Instrument}; use crate::error::{PsqlError, PsqlResult}; +use crate::net::AddressRef; use crate::pg_extended::ResultCache; use crate::pg_message::{ BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage, @@ -90,6 +91,9 @@ where // Used in extended query protocol. When encounter error in extended query, we need to ignore // the following message util sync message. ignore_util_sync: bool, + + // Client Address + peer_addr: AddressRef, } const PGWIRE_QUERY_LOG: &str = "pgwire_query_log"; @@ -154,7 +158,12 @@ where S: AsyncWrite + AsyncRead + Unpin, SM: SessionManager, { - pub fn new(stream: S, session_mgr: Arc, tls_config: Option) -> Self { + pub fn new( + stream: S, + session_mgr: Arc, + tls_config: Option, + peer_addr: AddressRef, + ) -> Self { Self { stream: Conn::Unencrypted(PgStream { stream: Some(stream), @@ -174,6 +183,7 @@ where portal_store: Default::default(), statement_portal_dependency: Default::default(), ignore_util_sync: false, + peer_addr, } } @@ -366,7 +376,7 @@ where let session = self .session_mgr - .connect(&db_name, &user_name) + .connect(&db_name, &user_name, self.peer_addr.clone()) .map_err(PsqlError::StartupError)?; let application_name = msg.config.get("application_name"); @@ -429,11 +439,15 @@ where } async fn process_query_msg(&mut self, query_string: io::Result<&str>) -> PsqlResult<()> { - let sql = query_string.map_err(|err| PsqlError::QueryError(Box::new(err)))?; + let sql: Arc = + Arc::from(query_string.map_err(|err| PsqlError::QueryError(Box::new(err)))?); let start = Instant::now(); let session = self.session.clone().unwrap(); let session_id = session.id().0; - let result = self.inner_process_query_msg(sql, session).await; + let _exec_context_guard = session.init_exec_context(sql.clone()); + let result = self + .inner_process_query_msg(sql.clone(), session.clone()) + .await; let mills = start.elapsed().as_millis(); @@ -451,11 +465,11 @@ where async fn inner_process_query_msg( &mut self, - sql: &str, + sql: Arc, session: Arc, ) -> PsqlResult<()> { // Parse sql. - let stmts = Parser::parse_sql(sql) + let stmts = Parser::parse_sql(&sql) .inspect_err(|e| tracing::error!("failed to parse sql:\n{}:\n{}", sql, e)) .map_err(|err| PsqlError::QueryError(err.into()))?; if stmts.is_empty() { @@ -700,9 +714,10 @@ where } else { let start = Instant::now(); let portal = self.get_portal(&portal_name)?; - let sql = format!("{}", portal); + let sql: Arc = Arc::from(format!("{}", portal)); - let result = session.execute(portal).await; + let _exec_context_guard = session.init_exec_context(sql.clone()); + let result = session.clone().execute(portal).await; let mills = start.elapsed().as_millis(); @@ -1043,7 +1058,7 @@ fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult { Ok(acceptor.into_context()) } -mod truncated_fmt { +pub mod truncated_fmt { use std::fmt::*; struct TruncatedFormatter<'a, 'b> { diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index 2734ff857735e..f561540797b1a 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -16,13 +16,14 @@ use std::future::Future; use std::io; use std::result::Result; use std::sync::Arc; +use std::time::Instant; use bytes::Bytes; use risingwave_common::types::DataType; use risingwave_sqlparser::ast::Statement; use tokio::io::{AsyncRead, AsyncWrite}; -use crate::net::Listener; +use crate::net::{AddressRef, Listener}; use crate::pg_field_descriptor::PgFieldDescriptor; use crate::pg_message::TransactionStatus; use crate::pg_protocol::{PgProtocol, TlsConfig}; @@ -37,7 +38,12 @@ pub type SessionId = (i32, i32); pub trait SessionManager: Send + Sync + 'static { type Session: Session; - fn connect(&self, database: &str, user_name: &str) -> Result, BoxedError>; + fn connect( + &self, + database: &str, + user_name: &str, + peer_addr: AddressRef, + ) -> Result, BoxedError>; fn cancel_queries_in_session(&self, session_id: SessionId); @@ -57,7 +63,7 @@ pub trait Session: Send + Sync { /// view, see . fn run_one_query( self: Arc, - sql: Statement, + stmt: Statement, format: Format, ) -> impl Future, BoxedError>> + Send; @@ -101,6 +107,26 @@ pub trait Session: Send + Sync { fn set_config(&self, key: &str, value: Vec) -> Result<(), BoxedError>; fn transaction_status(&self) -> TransactionStatus; + + fn init_exec_context(&self, sql: Arc) -> ExecContextGuard; +} + +/// Each session could run different SQLs multiple times. +/// `ExecContext` represents the lifetime of a running SQL in the current session. +pub struct ExecContext { + pub running_sql: Arc, + /// The instant of the running sql + pub last_instant: Instant, +} + +/// `ExecContextGuard` holds a `Arc` pointer. Once `ExecContextGuard` is dropped, +/// the inner `Arc` should not be referred anymore, so that its `Weak` reference (used in `SessionImpl`) will be the same lifecycle of the running sql execution context. +pub struct ExecContextGuard(Arc); + +impl ExecContextGuard { + pub fn new(exec_context: Arc) -> Self { + Self(exec_context) + } } #[derive(Debug, Clone)] @@ -146,6 +172,7 @@ pub async fn pg_serve( stream, session_mgr.clone(), tls_config.clone(), + Arc::new(peer_addr), )); } @@ -163,11 +190,12 @@ pub async fn handle_connection( stream: S, session_mgr: Arc, tls_config: Option, + peer_addr: AddressRef, ) where S: AsyncWrite + AsyncRead + Unpin, SM: SessionManager, { - let mut pg_proto = PgProtocol::new(stream, session_mgr, tls_config); + let mut pg_proto = PgProtocol::new(stream, session_mgr, tls_config, peer_addr); loop { let msg = match pg_proto.read_message().await { Ok(msg) => msg, @@ -191,6 +219,7 @@ pub async fn handle_connection( mod tests { use std::error::Error; use std::sync::Arc; + use std::time::Instant; use bytes::Bytes; use futures::stream::BoxStream; @@ -203,7 +232,8 @@ mod tests { use crate::pg_message::TransactionStatus; use crate::pg_response::{PgResponse, RowSetResult, StatementType}; use crate::pg_server::{ - pg_serve, BoxedError, Session, SessionId, SessionManager, UserAuthenticator, + pg_serve, BoxedError, ExecContext, ExecContextGuard, Session, SessionId, SessionManager, + UserAuthenticator, }; use crate::types; use crate::types::Row; @@ -218,6 +248,7 @@ mod tests { &self, _database: &str, _user_name: &str, + _peer_addr: crate::net::AddressRef, ) -> Result, Box> { Ok(Arc::new(MockSession {})) } @@ -240,7 +271,7 @@ mod tests { async fn run_one_query( self: Arc, - _sql: Statement, + _stmt: Statement, _format: types::Format, ) -> Result>, BoxedError> { Ok(PgResponse::builder(StatementType::SELECT) @@ -329,6 +360,14 @@ mod tests { fn transaction_status(&self) -> TransactionStatus { TransactionStatus::Idle } + + fn init_exec_context(&self, sql: Arc) -> ExecContextGuard { + let exec_context = Arc::new(ExecContext { + running_sql: sql, + last_instant: Instant::now(), + }); + ExecContextGuard::new(exec_context) + } } async fn do_test_query(bind_addr: impl Into, pg_config: impl Into) {