diff --git a/Cargo.lock b/Cargo.lock index d007b6c6f4f3e..c028ec24b9e6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8470,6 +8470,7 @@ dependencies = [ "risingwave_sqlparser", "serde", "serde_json", + "socket2 0.5.6", "tempfile", "thiserror", "thiserror-ext", diff --git a/src/cmd_all/src/standalone.rs b/src/cmd_all/src/standalone.rs index 2d09461959d0b..fd3e950f34d69 100644 --- a/src/cmd_all/src/standalone.rs +++ b/src/cmd_all/src/standalone.rs @@ -498,6 +498,7 @@ mod test { frontend_opts: Some( FrontendOpts { listen_addr: "0.0.0.0:4566", + tcp_keepalive_idle_secs: 300, advertise_addr: None, meta_addr: List( [ diff --git a/src/frontend/src/lib.rs b/src/frontend/src/lib.rs index 003700feb6763..d8b484e3d6fa2 100644 --- a/src/frontend/src/lib.rs +++ b/src/frontend/src/lib.rs @@ -44,6 +44,7 @@ risingwave_expr_impl::enable!(); mod catalog; use std::collections::HashSet; +use std::time::Duration; pub use catalog::TableCatalog; mod binder; @@ -55,6 +56,7 @@ mod observer; pub mod optimizer; pub use optimizer::{Explain, OptimizerContext, OptimizerContextRef, PlanRef}; mod planner; +use pgwire::net::TcpKeepalive; pub use planner::Planner; mod scheduler; pub mod session; @@ -97,6 +99,11 @@ pub struct FrontendOpts { #[clap(long, env = "RW_LISTEN_ADDR", default_value = "0.0.0.0:4566")] pub listen_addr: String, + /// The amount of time with no network activity after which the server will send a + /// TCP keepalive message to the client. + #[clap(long, env = "RW_TCP_KEEPALIVE_IDLE_SECS", default_value = "300")] + pub tcp_keepalive_idle_secs: usize, + /// The address for contacting this instance of the service. /// This would be synonymous with the service's "public address" /// or "identifying address". @@ -187,6 +194,9 @@ pub fn start( // slow compile in release mode. 
Box::pin(async move { let listen_addr = opts.listen_addr.clone(); + let tcp_keepalive = + TcpKeepalive::new().with_time(Duration::from_secs(opts.tcp_keepalive_idle_secs as _)); + let session_mgr = Arc::new(SessionManagerImpl::new(opts).await.unwrap()); SESSION_MANAGER.get_or_init(|| session_mgr.clone()); let redact_sql_option_keywords = Arc::new( @@ -201,6 +211,7 @@ pub fn start( pg_serve( &listen_addr, + tcp_keepalive, session_mgr.clone(), TlsConfig::new_default(), Some(redact_sql_option_keywords), diff --git a/src/utils/pgwire/Cargo.toml b/src/utils/pgwire/Cargo.toml index 0d29afea01855..d5bd081b92d6a 100644 --- a/src/utils/pgwire/Cargo.toml +++ b/src/utils/pgwire/Cargo.toml @@ -30,6 +30,7 @@ risingwave_common = { workspace = true } risingwave_sqlparser = { workspace = true } serde = { version = "1", features = ["derive"] } serde_json = "1" +socket2 = "0.5" thiserror = "1" thiserror-ext = { workspace = true } tokio = { version = "0.2", package = "madsim-tokio", features = [ diff --git a/src/utils/pgwire/src/net.rs b/src/utils/pgwire/src/net.rs index 7fc54676af1ca..8d89272c4b859 100644 --- a/src/utils/pgwire/src/net.rs +++ b/src/utils/pgwire/src/net.rs @@ -74,11 +74,18 @@ impl Listener { /// Accepts a new incoming connection from this listener. /// /// Returns a tuple of the stream and the string representation of the peer address. - pub async fn accept(&self) -> io::Result<(Stream, Address)> { + pub async fn accept(&self, tcp_keepalive: &TcpKeepalive) -> io::Result<(Stream, Address)> { match self { Self::Tcp(listener) => { let (stream, addr) = listener.accept().await?; stream.set_nodelay(true)?; + // Apply the caller-configured TCP keepalive. The default idle time of 5 minutes is less than the connection idle timeout of 350 seconds in AWS ELB. 
+ // https://docs.aws.amazon.com/elasticloadbalancing/latest/network/network-load-balancers.html#connection-idle-timeout + #[cfg(not(madsim))] + { + let r = socket2::SockRef::from(&stream); + r.set_tcp_keepalive(tcp_keepalive)?; + } Ok((Stream::Tcp(stream), Address::Tcp(addr))) } Self::Unix(listener) => { @@ -88,3 +95,5 @@ impl Listener { } } } + +pub use socket2::TcpKeepalive; diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index 99cbd58bc20aa..e637b4bb3a2e3 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -30,7 +30,7 @@ use thiserror_ext::AsReport; use tokio::io::{AsyncRead, AsyncWrite}; use crate::error::{PsqlError, PsqlResult}; -use crate::net::{AddressRef, Listener}; +use crate::net::{AddressRef, Listener, TcpKeepalive}; use crate::pg_field_descriptor::PgFieldDescriptor; use crate::pg_message::TransactionStatus; use crate::pg_protocol::{PgProtocol, TlsConfig}; @@ -265,6 +265,7 @@ impl UserAuthenticator { /// Returns when the `shutdown` token is triggered. pub async fn pg_serve( addr: &str, + tcp_keepalive: TcpKeepalive, session_mgr: Arc, tls_config: Option, redact_sql_option_keywords: Option, @@ -291,7 +292,7 @@ pub async fn pg_serve( let session_mgr_clone = session_mgr.clone(); let f = async move { loop { - let conn_ret = listener.accept().await; + let conn_ret = listener.accept(&tcp_keepalive).await; match conn_ret { Ok((stream, peer_addr)) => { tracing::info!(%peer_addr, "accept connection"); @@ -534,6 +535,7 @@ mod tests { tokio::spawn(async move { pg_serve( &bind_addr, + socket2::TcpKeepalive::new(), Arc::new(session_mgr), None, None,