diff --git a/Cargo.lock b/Cargo.lock index 199455ab8e5a8..485d2435e1470 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11412,6 +11412,7 @@ name = "risingwave_sqlparser" version = "1.9.0-alpha" dependencies = [ "itertools 0.12.1", + "madsim-tokio", "matches", "serde", "thiserror", diff --git a/src/common/src/config.rs b/src/common/src/config.rs index d2037cd70ae48..99aae245125eb 100644 --- a/src/common/src/config.rs +++ b/src/common/src/config.rs @@ -527,6 +527,11 @@ pub struct BatchConfig { /// This is the secs used to mask a worker unavailable temporarily. #[serde(default = "default::batch::mask_worker_temporary_secs")] pub mask_worker_temporary_secs: usize, + + /// Keywords on which SQL option redaction is based in the query log. + /// A SQL option with a name containing any of these keywords will be redacted. + #[serde(default = "default::batch::redact_sql_option_keywords")] + pub redact_sql_option_keywords: Vec, } /// The section `[streaming]` in `risingwave.toml`. @@ -1749,6 +1754,20 @@ pub mod default { pub fn mask_worker_temporary_secs() -> usize { 30 } + + pub fn redact_sql_option_keywords() -> Vec { + [ + "credential", + "key", + "password", + "private", + "secret", + "token", + ] + .into_iter() + .map(str::to_string) + .collect() + } } pub mod compaction_config { diff --git a/src/config/docs.md b/src/config/docs.md index 22548a9ff97bd..aea210f5235af 100644 --- a/src/config/docs.md +++ b/src/config/docs.md @@ -11,6 +11,7 @@ This page is automatically generated by `./risedev generate-example-config` | frontend_compute_runtime_worker_threads | frontend compute runtime worker threads | 4 | | mask_worker_temporary_secs | This is the secs used to mask a worker unavailable temporarily. | 30 | | max_batch_queries_per_frontend_node | This is the max number of batch queries per frontend node. | | +| redact_sql_option_keywords | Keywords on which SQL option redaction is based in the query log. A SQL option with a name containing any of these keywords will be redacted. | ["credential", "key", "password", "private", "secret", "token"] | | statement_timeout_in_sec | Timeout for a batch query in seconds. | 3600 | | worker_threads_num | The thread number of the batch task runtime in the compute node. The default value is decided by `tokio`. | | diff --git a/src/config/example.toml b/src/config/example.toml index 36874e3cfea85..fc70258788bbc 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -85,6 +85,7 @@ enable_barrier_read = false statement_timeout_in_sec = 3600 frontend_compute_runtime_worker_threads = 4 mask_worker_temporary_secs = 30 +redact_sql_option_keywords = ["credential", "key", "password", "private", "secret", "token"] [batch.developer] batch_connector_message_buffer_size = 16 diff --git a/src/frontend/src/lib.rs b/src/frontend/src/lib.rs index 2c0eac6d0ad88..5f0c03061c0ab 100644 --- a/src/frontend/src/lib.rs +++ b/src/frontend/src/lib.rs @@ -42,6 +42,9 @@ risingwave_expr_impl::enable!(); #[macro_use] mod catalog; + +use std::collections::HashSet; + pub use catalog::TableCatalog; mod binder; pub use binder::{bind_data_type, Binder}; @@ -168,8 +171,22 @@ pub fn start(opts: FrontendOpts) -> Pin + Send>> { Box::pin(async move { let listen_addr = opts.listen_addr.clone(); let session_mgr = Arc::new(SessionManagerImpl::new(opts).await.unwrap()); - pg_serve(&listen_addr, session_mgr, TlsConfig::new_default()) - .await - .unwrap(); + let redact_sql_option_keywords = Arc::new( + session_mgr + .env() + .batch_config() + .redact_sql_option_keywords + .iter() + .map(|s| s.to_lowercase()) + .collect::>(), + ); + pg_serve( + &listen_addr, + session_mgr, + TlsConfig::new_default(), + Some(redact_sql_option_keywords), + ) + .await + .unwrap() }) } diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 071608ac2f5e2..2fa728194a17d 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -1156,6 +1156,10 @@ impl SessionManagerImpl { }) } + pub fn env(&self) -> &FrontendEnv { + &self.env + } + fn insert_session(&self, session: Arc) { let active_sessions = { let mut write_guard = self.env.sessions_map.write(); diff --git a/src/sqlparser/Cargo.toml b/src/sqlparser/Cargo.toml index 8c20eb9bf2f29..b47803339e646 100644 --- a/src/sqlparser/Cargo.toml +++ b/src/sqlparser/Cargo.toml @@ -25,6 +25,7 @@ normal = ["workspace-hack"] itertools = { workspace = true } serde = { version = "1.0", features = ["derive"], optional = true } thiserror = "1.0.61" +tokio = { version = "0.2", package = "madsim-tokio" } tracing = "0.1" tracing-subscriber = "0.3" winnow = { version = "0.6.8", git = "https://github.com/TennyZhuang/winnow.git", rev = "a6b1f04" } diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index 11a7f350f3578..2c3aa67cfaf36 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -27,6 +27,8 @@ use alloc::{ }; use core::fmt; use core::fmt::Display; +use std::collections::HashSet; +use std::sync::Arc; use itertools::Itertools; #[cfg(feature = "serde")] @@ -59,6 +61,12 @@ pub use crate::ast::ddl::{ use crate::keywords::Keyword; use crate::parser::{IncludeOption, IncludeOptionItem, Parser, ParserError}; +pub type RedactSqlOptionKeywordsRef = Arc>; + +tokio::task_local! { + pub static REDACT_SQL_OPTION_KEYWORDS: RedactSqlOptionKeywordsRef; +} + pub struct DisplaySeparated<'a, T> where T: fmt::Display, @@ -2584,7 +2592,17 @@ pub struct SqlOption { impl fmt::Display for SqlOption { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} = {}", self.name, self.value) + let should_redact = REDACT_SQL_OPTION_KEYWORDS + .try_with(|keywords| { + let sql_option_name = self.name.real_value().to_lowercase(); + keywords.iter().any(|k| sql_option_name.contains(k)) + }) + .unwrap_or(false); + if should_redact { + write!(f, "{} = [REDACTED]", self.name) + } else { + write!(f, "{} = {}", self.name, self.value) + } } } @@ -3142,6 +3160,12 @@ impl fmt::Display for DiscardType { } } +impl Statement { + pub fn to_redacted_string(&self, keywords: RedactSqlOptionKeywordsRef) -> String { + REDACT_SQL_OPTION_KEYWORDS.sync_scope(keywords, || self.to_string()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 5e9e7a056f261..d700e39757df1 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -29,7 +29,7 @@ use risingwave_common::types::DataType; use risingwave_common::util::panic::FutureCatchUnwindExt; use risingwave_common::util::query_log::*; use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS}; -use risingwave_sqlparser::ast::Statement; +use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement}; use risingwave_sqlparser::parser::Parser; use thiserror_ext::AsReport; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; @@ -101,6 +101,8 @@ where // Client Address peer_addr: AddressRef, + + redact_sql_option_keywords: Option, } /// Configures TLS encryption for connections. @@ -152,16 +154,31 @@ pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> { } /// Record `sql` in the current tracing span. -fn record_sql_in_span(sql: &str) { +fn record_sql_in_span(sql: &str, redact_sql_option_keywords: Option) { + let redacted_sql = if let Some(keywords) = redact_sql_option_keywords { + redact_sql(sql, keywords) + } else { + sql.to_owned() + }; tracing::Span::current().record( "sql", tracing::field::display(truncated_fmt::TruncatedFmt( - &sql, + &redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN, )), ); } +fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String { + match Parser::parse_sql(sql) { + Ok(sqls) => sqls + .into_iter() + .map(|sql| sql.to_redacted_string(keywords.clone())) + .join(";"), + Err(_) => sql.to_owned(), + } +} + impl PgProtocol where S: AsyncWrite + AsyncRead + Unpin, @@ -172,6 +189,7 @@ where session_mgr: Arc, tls_config: Option, peer_addr: AddressRef, + redact_sql_option_keywords: Option, ) -> Self { Self { stream: Conn::Unencrypted(PgStream { @@ -193,6 +211,7 @@ where statement_portal_dependency: Default::default(), ignore_util_sync: false, peer_addr, + redact_sql_option_keywords, } } @@ -555,7 +574,7 @@ where async fn process_query_msg(&mut self, query_string: io::Result<&str>) -> PsqlResult<()> { let sql: Arc = Arc::from(query_string.map_err(|err| PsqlError::SimpleQueryError(Box::new(err)))?); - record_sql_in_span(&sql); + record_sql_in_span(&sql, self.redact_sql_option_keywords.clone()); let session = self.session.clone().unwrap(); session.check_idle_in_transaction_timeout()?; @@ -664,7 +683,7 @@ where fn process_parse_msg(&mut self, msg: FeParseMessage) -> PsqlResult<()> { let sql = cstr_to_str(&msg.sql_bytes).unwrap(); - record_sql_in_span(sql); + record_sql_in_span(sql, self.redact_sql_option_keywords.clone()); let session = self.session.clone().unwrap(); let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_string(); @@ -798,7 +817,7 @@ where } else { let portal = self.get_portal(&portal_name)?; let sql: Arc = Arc::from(format!("{}", portal)); - record_sql_in_span(&sql); + record_sql_in_span(&sql, self.redact_sql_option_keywords.clone()); session.check_idle_in_transaction_timeout()?; let _exec_context_guard = session.init_exec_context(sql.clone()); @@ -1205,3 +1224,25 @@ pub mod truncated_fmt { } } } + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::*; + + #[test] + fn test_redact_parsable_sql() { + let keywords = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()])); + let sql = r" + create source temp (k bigint, v varchar) with ( + connector = 'datagen', + v1 = 123, + v2 = 'with', + v3 = false, + v4 = '', + ) FORMAT plain ENCODE json (a='1',b='2') + "; + assert_eq!(redact_sql(sql, keywords), "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])"); + } +} diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index 6d53e7049c34d..d2882fdceafbb 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -23,7 +23,7 @@ use bytes::Bytes; use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation}; use parking_lot::Mutex; use risingwave_common::types::DataType; -use risingwave_sqlparser::ast::Statement; +use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement}; use serde::Deserialize; use thiserror_ext::AsReport; use tokio::io::{AsyncRead, AsyncWrite}; @@ -251,6 +251,7 @@ pub async fn pg_serve( addr: &str, session_mgr: Arc, tls_config: Option, + redact_sql_option_keywords: Option, ) -> io::Result<()> { let listener = Listener::bind(addr).await?; tracing::info!(addr, "server started"); @@ -281,6 +282,7 @@ pub async fn pg_serve( session_mgr.clone(), tls_config.clone(), Arc::new(peer_addr), + redact_sql_option_keywords.clone(), )); } @@ -299,11 +301,18 @@ pub async fn handle_connection( session_mgr: Arc, tls_config: Option, peer_addr: AddressRef, + redact_sql_option_keywords: Option, ) where S: AsyncWrite + AsyncRead + Unpin, SM: SessionManager, { - let mut pg_proto = PgProtocol::new(stream, session_mgr, tls_config, peer_addr); + let mut pg_proto = PgProtocol::new( + stream, + session_mgr, + tls_config, + peer_addr, + redact_sql_option_keywords, + ); loop { let msg = match pg_proto.read_message().await { Ok(msg) => msg, @@ -486,7 +495,7 @@ mod tests { let pg_config = pg_config.into(); let session_mgr = Arc::new(MockSessionManager {}); - tokio::spawn(async move { pg_serve(&bind_addr, session_mgr, None).await }); + tokio::spawn(async move { pg_serve(&bind_addr, session_mgr, None, None).await }); // wait for server to start tokio::time::sleep(std::time::Duration::from_millis(100)).await;