diff --git a/src/batch/src/task/env.rs b/src/batch/src/task/env.rs index 58631cb9563a0..1487e8a60dee0 100644 --- a/src/batch/src/task/env.rs +++ b/src/batch/src/task/env.rs @@ -50,7 +50,7 @@ pub struct BatchEnvironment { /// Executor level metrics. executor_metrics: Arc, - /// Compute client pool for grpc exchange. + /// Compute client pool for batch gRPC exchange. client_pool: ComputeClientPoolRef, /// Manages dml information. @@ -111,7 +111,7 @@ impl BatchEnvironment { MonitoredStorageMetrics::unused(), )), task_metrics: Arc::new(BatchTaskMetrics::for_test()), - client_pool: Arc::new(ComputeClientPool::default()), + client_pool: Arc::new(ComputeClientPool::for_test()), dml_manager: Arc::new(DmlManager::for_test()), source_metrics: Arc::new(SourceMetrics::default()), executor_metrics: Arc::new(BatchExecutorMetrics::for_test()), diff --git a/src/compute/src/server.rs b/src/compute/src/server.rs index 75f545f9c7684..1247336f7fd70 100644 --- a/src/compute/src/server.rs +++ b/src/compute/src/server.rs @@ -308,7 +308,7 @@ pub async fn compute_node_serve( )); // Initialize batch environment. - let client_pool = Arc::new(ComputeClientPool::new(config.server.connection_pool_size)); + let batch_client_pool = Arc::new(ComputeClientPool::new(config.server.connection_pool_size)); let batch_env = BatchEnvironment::new( batch_mgr.clone(), advertise_addr.clone(), @@ -317,7 +317,7 @@ pub async fn compute_node_serve( state_store.clone(), batch_task_metrics.clone(), batch_executor_metrics.clone(), - client_pool, + batch_client_pool, dml_mgr.clone(), source_metrics.clone(), config.server.metrics_level, @@ -342,6 +342,7 @@ pub async fn compute_node_serve( }; // Initialize the streaming environment. + let stream_client_pool = Arc::new(ComputeClientPool::new(config.server.connection_pool_size)); let stream_env = StreamEnvironment::new( advertise_addr.clone(), connector_params, @@ -352,6 +353,7 @@ pub async fn compute_node_serve( system_params_manager.clone(), source_metrics, meta_client.clone(), + stream_client_pool, ); let stream_mgr = LocalStreamManager::new( diff --git a/src/ctl/src/cmd_impl/await_tree.rs b/src/ctl/src/cmd_impl/await_tree.rs index 1c4ff98562791..00d8d8461d340 100644 --- a/src/ctl/src/cmd_impl/await_tree.rs +++ b/src/ctl/src/cmd_impl/await_tree.rs @@ -28,7 +28,7 @@ pub async fn dump(context: &CtlContext) -> anyhow::Result<()> { let compute_nodes = meta_client .list_worker_nodes(Some(WorkerType::ComputeNode)) .await?; - let clients = ComputeClientPool::default(); + let clients = ComputeClientPool::adhoc(); // FIXME: the compute node may not be accessible directly from risectl, we may let the meta // service collect the reports from all compute nodes in the future. diff --git a/src/ctl/src/cmd_impl/profile.rs b/src/ctl/src/cmd_impl/profile.rs index edb48df5aeb39..bc7897a74c11c 100644 --- a/src/ctl/src/cmd_impl/profile.rs +++ b/src/ctl/src/cmd_impl/profile.rs @@ -33,7 +33,7 @@ pub async fn cpu_profile(context: &CtlContext, sleep_s: u64) -> anyhow::Result<( .into_iter() .filter(|w| w.r#type() == WorkerType::ComputeNode); - let clients = ComputeClientPool::default(); + let clients = ComputeClientPool::adhoc(); let profile_root_path = std::env::var("PREFIX_PROFILING").unwrap_or_else(|_| { tracing::info!("PREFIX_PROFILING is not set, using current directory"); @@ -96,7 +96,7 @@ pub async fn heap_profile(context: &CtlContext, dir: Option) -> anyhow:: .into_iter() .filter(|w| w.r#type() == WorkerType::ComputeNode); - let clients = ComputeClientPool::default(); + let clients = ComputeClientPool::adhoc(); let mut profile_futs = vec![]; diff --git a/src/frontend/src/scheduler/distributed/query.rs b/src/frontend/src/scheduler/distributed/query.rs index c6e866630067b..f290837e3312d 100644 --- a/src/frontend/src/scheduler/distributed/query.rs +++ b/src/frontend/src/scheduler/distributed/query.rs @@ -508,7 +508,7 @@ pub(crate) mod tests { async fn test_query_should_not_hang_with_empty_worker() { let worker_node_manager = Arc::new(WorkerNodeManager::mock(vec![])); let worker_node_selector = WorkerNodeSelector::new(worker_node_manager.clone(), false); - let compute_client_pool = Arc::new(ComputeClientPool::default()); + let compute_client_pool = Arc::new(ComputeClientPool::for_test()); let hummock_snapshot_manager = Arc::new(HummockSnapshotManager::new(Arc::new( MockFrontendMetaClient {}, ))); diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 130212f676f7a..16deee7848183 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -178,7 +178,7 @@ impl FrontendEnv { let meta_client = Arc::new(MockFrontendMetaClient {}); let hummock_snapshot_manager = Arc::new(HummockSnapshotManager::new(meta_client.clone())); let system_params_manager = Arc::new(LocalSystemParamsManager::for_test()); - let compute_client_pool = Arc::new(ComputeClientPool::default()); + let compute_client_pool = Arc::new(ComputeClientPool::for_test()); let query_manager = QueryManager::new( worker_node_manager.clone(), compute_client_pool, @@ -188,7 +188,7 @@ impl FrontendEnv { None, ); let server_addr = HostAddr::try_from("127.0.0.1:4565").unwrap(); - let client_pool = Arc::new(ComputeClientPool::default()); + let client_pool = Arc::new(ComputeClientPool::for_test()); let creating_streaming_tracker = StreamingJobTracker::new(meta_client.clone()); let compute_runtime = Arc::new(BackgroundShutdownRuntime::from( Builder::new_multi_thread() diff --git a/src/meta/node/src/server.rs b/src/meta/node/src/server.rs index 43e564271702a..10609cc7e2fc2 100644 --- a/src/meta/node/src/server.rs +++ b/src/meta/node/src/server.rs @@ -518,7 +518,7 @@ pub async fn start_service_as_election_leader( prometheus_client, prometheus_selector, metadata_manager: metadata_manager.clone(), - compute_clients: ComputeClientPool::default(), + compute_clients: ComputeClientPool::new(1), // typically no need for plural clients diagnose_command, trace_state, }; diff --git a/src/meta/src/manager/diagnose.rs b/src/meta/src/manager/diagnose.rs index 06c76c47c5daa..d70c56ecd3f2d 100644 --- a/src/meta/src/manager/diagnose.rs +++ b/src/meta/src/manager/diagnose.rs @@ -667,7 +667,7 @@ impl DiagnoseCommand { let mut all = StackTraceResponse::default(); - let compute_clients = ComputeClientPool::default(); + let compute_clients = ComputeClientPool::adhoc(); for worker_node in &worker_nodes { if let Ok(client) = compute_clients.get(worker_node).await && let Ok(result) = client.stack_trace().await diff --git a/src/meta/src/manager/env.rs b/src/meta/src/manager/env.rs index 5006f5864e84f..9d171000e7577 100644 --- a/src/meta/src/manager/env.rs +++ b/src/meta/src/manager/env.rs @@ -349,7 +349,7 @@ impl MetaSrvEnv { let notification_manager = Arc::new(NotificationManager::new(meta_store_impl.clone()).await); let idle_manager = Arc::new(IdleManager::new(opts.max_idle_ms)); - let stream_client_pool = Arc::new(StreamClientPool::default()); + let stream_client_pool = Arc::new(StreamClientPool::new(1)); // typically no need for plural clients let event_log_manager = Arc::new(start_event_log_manager( opts.event_log_enabled, opts.event_log_channel_max_size, diff --git a/src/rpc_client/src/compute_client.rs b/src/rpc_client/src/compute_client.rs index f908bb21aa3a2..641f56324d47b 100644 --- a/src/rpc_client/src/compute_client.rs +++ b/src/rpc_client/src/compute_client.rs @@ -278,4 +278,4 @@ impl RpcClient for ComputeClient { } pub type ComputeClientPool = RpcClientPool; -pub type ComputeClientPoolRef = Arc; +pub type ComputeClientPoolRef = Arc; // TODO: no need for `Arc` since clone is cheap and shared diff --git a/src/rpc_client/src/lib.rs b/src/rpc_client/src/lib.rs index fa276bdd0a5ce..bb1d90dcffbf4 100644 --- a/src/rpc_client/src/lib.rs +++ b/src/rpc_client/src/lib.rs @@ -88,19 +88,25 @@ pub struct RpcClientPool { clients: Cache>>, } -impl Default for RpcClientPool -where - S: RpcClient, -{ - fn default() -> Self { - Self::new(1) +impl std::fmt::Debug for RpcClientPool { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RpcClientPool") + .field("connection_pool_size", &self.connection_pool_size) + .field("type", &type_name::()) + .field("len", &self.clients.entry_count()) + .finish() } } +/// Intentionally not implementing `Default` to let callers be explicit about the pool size. +impl !Default for RpcClientPool {} + impl RpcClientPool where S: RpcClient, { + /// Create a new pool with the given `connection_pool_size`, which is the number of + /// connections to each node that will be reused. pub fn new(connection_pool_size: u16) -> Self { Self { connection_pool_size, @@ -108,6 +114,16 @@ where } } + /// Create a pool for testing purposes. Same as [`Self::adhoc`]. + pub fn for_test() -> Self { + Self::adhoc() + } + + /// Create a pool for ad-hoc usage, where the number of connections to each node is 1. + pub fn adhoc() -> Self { + Self::new(1) + } + /// Gets the RPC client for the given node. If the connection is not established, a /// new client will be created and returned. pub async fn get(&self, node: &WorkerNode) -> Result { diff --git a/src/stream/src/executor/exchange/input.rs b/src/stream/src/executor/exchange/input.rs index 11796441326aa..cefe8a8613d46 100644 --- a/src/stream/src/executor/exchange/input.rs +++ b/src/stream/src/executor/exchange/input.rs @@ -264,7 +264,7 @@ pub(crate) fn new_input( } else { RemoteInput::new( context.local_barrier_manager.clone(), - context.compute_client_pool.clone(), + context.compute_client_pool.as_ref().to_owned(), upstream_addr, (upstream_actor_id, actor_id), (upstream_fragment_id, fragment_id), diff --git a/src/stream/src/executor/merge.rs b/src/stream/src/executor/merge.rs index d453d6979ac76..0c9cec77f3a54 100644 --- a/src/stream/src/executor/merge.rs +++ b/src/stream/src/executor/merge.rs @@ -753,7 +753,7 @@ mod tests { assert!(server_run.load(Ordering::SeqCst)); let remote_input = { - let pool = ComputeClientPool::default(); + let pool = ComputeClientPool::for_test(); RemoteInput::new( LocalBarrierManager::for_test(), pool, diff --git a/src/stream/src/task/barrier_manager.rs b/src/stream/src/task/barrier_manager.rs index c5d564d9c6ab8..40104fdd83d7c 100644 --- a/src/stream/src/task/barrier_manager.rs +++ b/src/stream/src/task/barrier_manager.rs @@ -348,8 +348,7 @@ impl LocalBarrierWorker { let (event_tx, event_rx) = unbounded_channel(); let (failure_tx, failure_rx) = unbounded_channel(); let shared_context = Arc::new(SharedContext::new( - actor_manager.env.server_address().clone(), - actor_manager.env.config(), + &actor_manager.env, LocalBarrierManager { barrier_event_sender: event_tx, actor_failure_sender: failure_tx, diff --git a/src/stream/src/task/env.rs b/src/stream/src/task/env.rs index 9a0b26f25f0c5..fc405b12d06fb 100644 --- a/src/stream/src/task/env.rs +++ b/src/stream/src/task/env.rs @@ -21,7 +21,7 @@ use risingwave_common::util::addr::HostAddr; use risingwave_connector::source::monitor::SourceMetrics; use risingwave_connector::ConnectorParams; use risingwave_dml::dml_manager::DmlManagerRef; -use risingwave_rpc_client::MetaClient; +use risingwave_rpc_client::{ComputeClientPoolRef, MetaClient}; use risingwave_storage::StateStoreImpl; pub(crate) type WorkerNodeId = u32; @@ -59,6 +59,9 @@ pub struct StreamEnvironment { /// Meta client. Use `None` for test only meta_client: Option, + + /// Compute client pool for streaming gRPC exchange. + client_pool: ComputeClientPoolRef, } impl StreamEnvironment { @@ -73,6 +76,7 @@ impl StreamEnvironment { system_params_manager: LocalSystemParamsManagerRef, source_metrics: Arc, meta_client: MetaClient, + client_pool: ComputeClientPoolRef, ) -> Self { StreamEnvironment { server_addr, @@ -85,6 +89,7 @@ impl StreamEnvironment { source_metrics, total_mem_val: Arc::new(TrAdder::new()), meta_client: Some(meta_client), + client_pool, } } @@ -94,6 +99,7 @@ impl StreamEnvironment { use risingwave_common::system_param::local_manager::LocalSystemParamsManager; use risingwave_dml::dml_manager::DmlManager; use risingwave_pb::connector_service::SinkPayloadFormat; + use risingwave_rpc_client::ComputeClientPool; use risingwave_storage::monitor::MonitoredStorageMetrics; StreamEnvironment { server_addr: "127.0.0.1:5688".parse().unwrap(), @@ -108,6 +114,7 @@ impl StreamEnvironment { source_metrics: Arc::new(SourceMetrics::default()), total_mem_val: Arc::new(TrAdder::new()), meta_client: None, + client_pool: Arc::new(ComputeClientPool::for_test()), } } @@ -150,4 +157,8 @@ impl StreamEnvironment { pub fn meta_client(&self) -> Option { self.meta_client.clone() } + + pub fn client_pool(&self) -> ComputeClientPoolRef { + self.client_pool.clone() + } } diff --git a/src/stream/src/task/mod.rs b/src/stream/src/task/mod.rs index 7a6fd40f9231a..b0cc0d45b6acc 100644 --- a/src/stream/src/task/mod.rs +++ b/src/stream/src/task/mod.rs @@ -19,7 +19,7 @@ use parking_lot::{MappedMutexGuard, Mutex, MutexGuard, RwLock}; use risingwave_common::config::StreamingConfig; use risingwave_common::util::addr::HostAddr; use risingwave_pb::common::ActorInfo; -use risingwave_rpc_client::ComputeClientPool; +use risingwave_rpc_client::ComputeClientPoolRef; use crate::error::StreamResult; use crate::executor::exchange::permit::{self, Receiver, Sender}; @@ -75,10 +75,10 @@ pub struct SharedContext { /// between two actors/actors. pub(crate) addr: HostAddr, - /// The pool of compute clients. + /// Compute client pool for streaming gRPC exchange. // TODO: currently the client pool won't be cleared. Should remove compute clients when // disconnected. - pub(crate) compute_client_pool: ComputeClientPool, + pub(crate) compute_client_pool: ComputeClientPoolRef, pub(crate) config: StreamingConfig, @@ -94,30 +94,28 @@ impl std::fmt::Debug for SharedContext { } impl SharedContext { - pub fn new( - addr: HostAddr, - config: &StreamingConfig, - local_barrier_manager: LocalBarrierManager, - ) -> Self { + pub fn new(env: &StreamEnvironment, local_barrier_manager: LocalBarrierManager) -> Self { Self { channel_map: Default::default(), actor_infos: Default::default(), - addr, - compute_client_pool: ComputeClientPool::default(), - config: config.clone(), + addr: env.server_address().clone(), + config: env.config().as_ref().to_owned(), + compute_client_pool: env.client_pool(), local_barrier_manager, } } #[cfg(test)] pub fn for_test() -> Self { + use std::sync::Arc; + use risingwave_common::config::StreamingDeveloperConfig; + use risingwave_rpc_client::ComputeClientPool; Self { channel_map: Default::default(), actor_infos: Default::default(), addr: LOCAL_TEST_ADDR.clone(), - compute_client_pool: ComputeClientPool::default(), config: StreamingConfig { developer: StreamingDeveloperConfig { exchange_initial_permits: permit::for_test::INITIAL_PERMITS, @@ -127,6 +125,7 @@ impl SharedContext { }, ..Default::default() }, + compute_client_pool: Arc::new(ComputeClientPool::for_test()), local_barrier_manager: LocalBarrierManager::for_test(), } }