diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs index d5521fd37..664e4fb86 100644 --- a/benchmarks/src/lib.rs +++ b/benchmarks/src/lib.rs @@ -18,7 +18,7 @@ use futures_util::{future, TryFutureExt}; use http::header::CONTENT_TYPE; use http::Uri; use pprof::flamegraph::Options; -use restate_core::{TaskCenter, TaskCenterBuilder, TaskKind}; +use restate_core::{task_center, TaskCenter, TaskCenterBuilder, TaskKind}; use restate_node::Node; use restate_rocksdb::RocksDbManager; use restate_types::config::{ @@ -86,7 +86,7 @@ pub fn discover_deployment(current_thread_rt: &Runtime, address: Uri) { .is_success(),); } -pub fn spawn_restate(config: Configuration) -> TaskCenter { +pub fn spawn_restate(config: Configuration) -> task_center::Handle { if rlimit::increase_nofile_limit(u64::MAX).is_err() { warn!("Failed to increase the number of open file descriptors limit."); } @@ -94,19 +94,19 @@ pub fn spawn_restate(config: Configuration) -> TaskCenter { let tc = TaskCenterBuilder::default() .options(config.common.clone()) .build() - .expect("task_center builds"); - let cloned_tc = tc.clone(); + .expect("task_center builds") + .to_handle(); restate_types::config::set_current_config(config.clone()); let updateable_config = Configuration::updateable(); tc.block_on(async { RocksDbManager::init(Constant::new(config.common)); - tc.spawn(TaskKind::SystemBoot, "restate", None, async move { + TaskCenter::spawn(TaskKind::SystemBoot, "restate", async move { let node = Node::create(updateable_config) .await .expect("Restate node must build"); - cloned_tc.run_in_scope("startup", None, node.start()).await + node.start().await }) .unwrap(); }); diff --git a/crates/admin/src/cluster_controller/cluster_state_refresher.rs b/crates/admin/src/cluster_controller/cluster_state_refresher.rs index a8fd1968a..1732d68eb 100644 --- a/crates/admin/src/cluster_controller/cluster_state_refresher.rs +++ b/crates/admin/src/cluster_controller/cluster_state_refresher.rs @@ -30,7 +30,6 @@ use restate_types::time::MillisSinceEpoch; use restate_types::Version; pub struct ClusterStateRefresher { - metadata: Metadata, network_sender: Networking, get_state_router: RpcRouter, in_flight_refresh: Option>>, @@ -39,11 +38,7 @@ pub struct ClusterStateRefresher { } impl ClusterStateRefresher { - pub fn new( - metadata: Metadata, - network_sender: Networking, - router_builder: &mut MessageRouterBuilder, - ) -> Self { + pub fn new(network_sender: Networking, router_builder: &mut MessageRouterBuilder) -> Self { let get_state_router = RpcRouter::new(router_builder); let initial_state = ClusterState { @@ -57,7 +52,6 @@ impl ClusterStateRefresher { watch::channel(Arc::from(initial_state)); Self { - metadata, network_sender, get_state_router, in_flight_refresh: None, @@ -99,7 +93,6 @@ impl ClusterStateRefresher { self.get_state_router.clone(), self.network_sender.clone(), Arc::clone(&self.cluster_state_update_tx), - self.metadata.clone(), )?; Ok(()) @@ -109,10 +102,10 @@ impl ClusterStateRefresher { get_state_router: RpcRouter, network_sender: Networking, cluster_state_tx: Arc>>, - metadata: Metadata, ) -> Result>>, ShutdownError> { let refresh = async move { let last_state = Arc::clone(&cluster_state_tx.borrow()); + let metadata = Metadata::current(); // make sure we have a partition table that equals or newer than last refresh let partition_table_version = metadata .wait_for_version( @@ -228,10 +221,9 @@ impl ClusterStateRefresher { Ok(()) }; - let handle = TaskCenter::current().spawn_unmanaged( + let handle = TaskCenter::spawn_unmanaged( 
restate_core::TaskKind::Disposable, "cluster-state-refresh", - None, refresh, )?; diff --git a/crates/admin/src/cluster_controller/logs_controller.rs b/crates/admin/src/cluster_controller/logs_controller.rs index 760adea0f..ca66b8680 100644 --- a/crates/admin/src/cluster_controller/logs_controller.rs +++ b/crates/admin/src/cluster_controller/logs_controller.rs @@ -28,7 +28,7 @@ use restate_bifrost::{Bifrost, BifrostAdmin, Error as BifrostError}; use restate_core::metadata_store::{ retry_on_network_error, MetadataStoreClient, Precondition, ReadWriteError, WriteError, }; -use restate_core::{metadata, task_center, Metadata, MetadataWriter, ShutdownError}; +use restate_core::{Metadata, MetadataWriter, ShutdownError, TaskCenterFutureExt}; use restate_types::config::Configuration; use restate_types::errors::GenericError; use restate_types::identifiers::PartitionId; @@ -324,7 +324,7 @@ fn try_provisioning( #[cfg(feature = "replicated-loglet")] ProviderKind::Replicated => build_new_replicated_loglet_configuration( ReplicatedLogletId::new(log_id, SegmentIndex::OLDEST), - metadata().nodes_config_ref().as_ref(), + &Metadata::with_current(|m| m.nodes_config_ref()), observed_cluster_state, None, node_set_selector_hints.preferred_sequencer(&log_id), @@ -494,7 +494,7 @@ impl LogletConfiguration { LogletConfiguration::Replicated(configuration) => { build_new_replicated_loglet_configuration( configuration.loglet_id.next(), - &metadata().nodes_config_ref(), + &Metadata::with_current(|m| m.nodes_config_ref()), observed_cluster_state, Some(configuration), preferred_sequencer, @@ -621,7 +621,7 @@ struct LogsControllerInner { logs_state: HashMap, logs_write_in_progress: Option, - // We are storing the logs explicitly (not relying on metadata()) because we need a fixed + // We are storing the logs explicitly (not relying on Metadata::current()) because we need a fixed // snapshot to keep logs_state in sync. 
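// ---------------------------------------------------------------------------
// Reviewer sketch (not part of the patch): the spawn shape this change set
// migrates to, based on the call sites above (benchmarks/src/lib.rs,
// cluster_state_refresher.rs). `TaskCenter::spawn`/`spawn_unmanaged` are
// associated functions that use the ambient task center, and the former
// `Option<PartitionId>` argument is gone. Task names and bodies below are
// illustrative; the spawned futures are assumed to resolve to
// `anyhow::Result<()>`, matching how they are used in this patch.
use restate_core::{ShutdownError, TaskCenter, TaskHandle, TaskKind};

fn spawn_examples() -> Result<(), ShutdownError> {
    // Managed task: supervised by the ambient task center until shutdown.
    TaskCenter::spawn(TaskKind::SystemBoot, "example-boot", async move {
        // ... startup work ...
        anyhow::Ok(())
    })?;

    // Unmanaged task: the caller keeps the handle and drains/aborts it itself.
    let _handle: TaskHandle<anyhow::Result<()>> = TaskCenter::spawn_unmanaged(
        TaskKind::Disposable,
        "example-refresh",
        async move {
            // ... refresh work ...
            anyhow::Ok(())
        },
    )?;
    Ok(())
}
// ---------------------------------------------------------------------------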
current_logs: Arc, retry_policy: RetryPolicy, @@ -1024,15 +1024,11 @@ impl LogsController { Event::LogsTailUpdates { updates } }; - let tc = task_center(); - self.async_operations.spawn(async move { - tc.run_in_scope( - "log-controller-refresh-tail", - None, - find_tail.instrument(trace_span!("scheduled-find-tail")), - ) - .await - }); + self.async_operations.spawn( + find_tail + .instrument(trace_span!("scheduled-find-tail")) + .in_current_tc(), + ); } pub fn on_observed_cluster_state_update( @@ -1093,12 +1089,10 @@ impl LogsController { logs: Arc, mut debounce: Option>, ) { - let tc = task_center().clone(); let metadata_store_client = self.metadata_store_client.clone(); let metadata_writer = self.metadata_writer.clone(); self.async_operations.spawn(async move { - tc.run_in_scope("logs-controller-write-logs", None, async { if let Some(debounce) = &mut debounce { let delay = debounce.next().unwrap_or(FALLBACK_MAX_RETRY_DELAY); debug!(?delay, %previous_version, "Wait before attempting to write logs"); @@ -1153,9 +1147,7 @@ impl LogsController { let version = logs.version(); Event::WriteLogsSucceeded(version) - }) - .await - }); + }.in_current_tc()); } fn seal_log( @@ -1164,13 +1156,12 @@ impl LogsController { segment_index: SegmentIndex, mut debounce: Option>, ) { - let tc = task_center().clone(); let bifrost = self.bifrost.clone(); let metadata_store_client = self.metadata_store_client.clone(); let metadata_writer = self.metadata_writer.clone(); - self.async_operations.spawn(async move { - tc.run_in_scope("logs-controller-seal-log", None, async { + self.async_operations.spawn( + async move { if let Some(debounce) = &mut debounce { let delay = debounce.next().unwrap_or(FALLBACK_MAX_RETRY_DELAY); debug!(?delay, %log_id, %segment_index, "Wait before attempting to seal log"); @@ -1205,9 +1196,9 @@ impl LogsController { } } } - }) - .await - }); + } + .in_current_tc(), + ); } pub async fn run_async_operations(&mut self) -> Result { diff --git a/crates/admin/src/cluster_controller/scheduler.rs b/crates/admin/src/cluster_controller/scheduler.rs index ca40af8c5..a20293092 100644 --- a/crates/admin/src/cluster_controller/scheduler.rs +++ b/crates/admin/src/cluster_controller/scheduler.rs @@ -19,7 +19,7 @@ use restate_core::metadata_store::{ WriteError, }; use restate_core::network::{NetworkSender, Networking, Outgoing, TransportConnect}; -use restate_core::{metadata, ShutdownError, SyncError, TaskCenter, TaskKind}; +use restate_core::{Metadata, ShutdownError, SyncError, TaskCenter, TaskKind}; use restate_types::cluster_controller::{ ReplicationStrategy, SchedulingPlan, SchedulingPlanBuilder, TargetPartitionState, }; @@ -464,13 +464,15 @@ impl Scheduler { ); } + let (cur_partition_table_version, cur_logs_version) = + Metadata::with_current(|m| (m.partition_table_version(), m.logs_version())); for (node_id, commands) in commands.into_iter() { // only send control processors message if there are commands to send if !commands.is_empty() { let control_processors = ControlProcessors { // todo: Maybe remove unneeded partition table version - min_partition_table_version: metadata().partition_table_version(), - min_logs_table_version: metadata().logs_version(), + min_partition_table_version: cur_partition_table_version, + min_logs_table_version: cur_logs_version, commands, }; @@ -579,9 +581,7 @@ mod tests { HashSet, PartitionProcessorPlacementHints, Scheduler, }; use restate_core::network::{ForwardingHandler, Incoming, MessageCollectorMockConnector}; - use restate_core::{ - metadata, TaskCenterBuilder, 
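// ---------------------------------------------------------------------------
// Reviewer sketch (not part of the patch): the two ambient accessors this
// patch leans on instead of stored handles — `Metadata::with_current` /
// `Metadata::current()` for metadata, and `TaskCenterFutureExt::in_current_tc`
// in place of `tc.run_in_scope(..)`. Function and variable names below are
// illustrative only.
use restate_core::{Metadata, TaskCenterFutureExt};

async fn ambient_accessors() {
    // One-off read through a closure, as in the logs/scheduler controllers above.
    let (partition_table_version, logs_version) =
        Metadata::with_current(|m| (m.partition_table_version(), m.logs_version()));
    let _ = (partition_table_version, logs_version);

    // Longer-lived use inside a task: take a handle once and keep it.
    let metadata = Metadata::current();
    let _nodes_config = metadata.nodes_config_ref();

    // Propagate the ambient task center into a future handed to e.g. a JoinSet.
    some_async_operation().in_current_tc().await;
}

async fn some_async_operation() { /* ... */ }
// ---------------------------------------------------------------------------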
TaskCenterFutureExt, TestCoreEnv, TestCoreEnvBuilder, - }; + use restate_core::{Metadata, TestCoreEnv, TestCoreEnvBuilder}; use restate_types::cluster::cluster_state::{ AliveNode, ClusterState, DeadNode, NodeState, PartitionProcessorStatus, RunMode, }; @@ -612,47 +612,43 @@ mod tests { } } - #[test(tokio::test)] + #[test(restate_core::test)] async fn empty_leadership_changes_dont_modify_plan() -> googletest::Result<()> { let test_env = TestCoreEnv::create_with_single_node(0, 0).await; let metadata_store_client = test_env.metadata_store_client.clone(); let networking = test_env.networking.clone(); - async { - let initial_scheduling_plan = metadata_store_client - .get::(SCHEDULING_PLAN_KEY.clone()) - .await - .expect("scheduling plan"); - let mut scheduler = Scheduler::init( - Configuration::pinned().as_ref(), - metadata_store_client.clone(), - networking, + let initial_scheduling_plan = metadata_store_client + .get::(SCHEDULING_PLAN_KEY.clone()) + .await + .expect("scheduling plan"); + let mut scheduler = Scheduler::init( + Configuration::pinned().as_ref(), + metadata_store_client.clone(), + networking, + ) + .await?; + let observed_cluster_state = ObservedClusterState::default(); + + scheduler + .on_observed_cluster_state( + &observed_cluster_state, + &Metadata::with_current(|m| m.nodes_config_ref()), + NoPlacementHints, ) .await?; - let observed_cluster_state = ObservedClusterState::default(); - scheduler - .on_observed_cluster_state( - &observed_cluster_state, - &metadata().nodes_config_ref(), - NoPlacementHints, - ) - .await?; - - let scheduling_plan = metadata_store_client - .get::(SCHEDULING_PLAN_KEY.clone()) - .await - .expect("scheduling plan"); + let scheduling_plan = metadata_store_client + .get::(SCHEDULING_PLAN_KEY.clone()) + .await + .expect("scheduling plan"); - assert_eq!(initial_scheduling_plan, scheduling_plan); + assert_eq!(initial_scheduling_plan, scheduling_plan); - Ok(()) - } - .in_tc(&test_env.tc) - .await + Ok(()) } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn schedule_partitions_with_replication_factor() -> googletest::Result<()> { schedule_partitions(ReplicationStrategy::Factor( NonZero::new(3).expect("non-zero"), @@ -661,7 +657,7 @@ mod tests { Ok(()) } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn schedule_partitions_with_all_nodes_replication() -> googletest::Result<()> { schedule_partitions(ReplicationStrategy::OnAllNodes).await?; Ok(()) @@ -690,15 +686,11 @@ mod tests { nodes_config.upsert_node(node_config); } - let tc = TaskCenterBuilder::default_for_tests() - .build() - .expect("task_center builds"); - // network messages going to other nodes are written to `tx` let (tx, control_recv) = mpsc::channel(100); - let connector = MessageCollectorMockConnector::new(tc.clone(), 10, tx.clone()); + let connector = MessageCollectorMockConnector::new(10, tx.clone()); - let mut builder = TestCoreEnvBuilder::with_transport_connector(tc, connector); + let mut builder = TestCoreEnvBuilder::with_transport_connector(connector); builder.router_builder.add_raw_handler( TargetName::ControlProcessors, // network messages going to my node is also written to `tx` @@ -730,82 +722,76 @@ mod tests { let networking = builder.networking.clone(); - let env = builder + let _env = builder .set_nodes_config(nodes_config.clone()) .set_partition_table(partition_table.clone()) .set_scheduling_plan(initial_scheduling_plan) .build() .await; - async move { - let mut scheduler = 
Scheduler::init( - Configuration::pinned().as_ref(), - metadata_store_client.clone(), - networking, - ) - .await?; - let mut observed_cluster_state = ObservedClusterState::default(); + let mut scheduler = Scheduler::init( + Configuration::pinned().as_ref(), + metadata_store_client.clone(), + networking, + ) + .await?; + let mut observed_cluster_state = ObservedClusterState::default(); - for _ in 0..num_scheduling_rounds { - let cluster_state = random_cluster_state(&node_ids, num_partitions); + for _ in 0..num_scheduling_rounds { + let cluster_state = random_cluster_state(&node_ids, num_partitions); - observed_cluster_state.update(&cluster_state); - scheduler - .on_observed_cluster_state( - &observed_cluster_state, - &metadata().nodes_config_ref(), - NoPlacementHints, - ) - .await?; - // collect all control messages from the network to build up the effective scheduling plan - let control_messages = control_recv - .as_mut() - .take_until(tokio::time::sleep(Duration::from_secs(10))) - .collect::>() - .await; - - let observed_cluster_state = - derive_observed_cluster_state(&cluster_state, control_messages); - let target_scheduling_plan = metadata_store_client - .get::(SCHEDULING_PLAN_KEY.clone()) - .await? - .expect("the scheduler should have created a scheduling plan"); - - // assert that the effective scheduling plan aligns with the target scheduling plan - assert_that!( - observed_cluster_state, - matches_scheduling_plan(&target_scheduling_plan) - ); + observed_cluster_state.update(&cluster_state); + scheduler + .on_observed_cluster_state( + &observed_cluster_state, + &Metadata::with_current(|m| m.nodes_config_ref()), + NoPlacementHints, + ) + .await?; + // collect all control messages from the network to build up the effective scheduling plan + let control_messages = control_recv + .as_mut() + .take_until(tokio::time::sleep(Duration::from_secs(10))) + .collect::>() + .await; + + let observed_cluster_state = + derive_observed_cluster_state(&cluster_state, control_messages); + let target_scheduling_plan = metadata_store_client + .get::(SCHEDULING_PLAN_KEY.clone()) + .await? 
+ .expect("the scheduler should have created a scheduling plan"); - let alive_nodes: HashSet<_> = cluster_state - .alive_nodes() - .map(|node| node.generational_node_id.as_plain()) - .collect(); - - for (_, target_state) in target_scheduling_plan.iter() { - // assert that every partition has a leader which is part of the alive nodes set - assert!(target_state - .leader - .is_some_and(|leader| alive_nodes.contains(&leader))); - - // assert that the replication strategy was respected - match replication_strategy { - ReplicationStrategy::OnAllNodes => { - assert_eq!(target_state.node_set, alive_nodes) - } - ReplicationStrategy::Factor(replication_factor) => assert_eq!( - target_state.node_set.len(), - alive_nodes.len().min( - usize::try_from(replication_factor.get()) - .expect("u32 fits into usize") - ) - ), + // assert that the effective scheduling plan aligns with the target scheduling plan + assert_that!( + observed_cluster_state, + matches_scheduling_plan(&target_scheduling_plan) + ); + + let alive_nodes: HashSet<_> = cluster_state + .alive_nodes() + .map(|node| node.generational_node_id.as_plain()) + .collect(); + + for (_, target_state) in target_scheduling_plan.iter() { + // assert that every partition has a leader which is part of the alive nodes set + assert!(target_state + .leader + .is_some_and(|leader| alive_nodes.contains(&leader))); + + // assert that the replication strategy was respected + match replication_strategy { + ReplicationStrategy::OnAllNodes => { + assert_eq!(target_state.node_set, alive_nodes) } + ReplicationStrategy::Factor(replication_factor) => assert_eq!( + target_state.node_set.len(), + alive_nodes.len().min( + usize::try_from(replication_factor.get()).expect("u32 fits into usize") + ) + ), } } - googletest::Result::Ok(()) } - .in_tc(&env.tc) - .await?; Ok(()) } diff --git a/crates/admin/src/cluster_controller/service.rs b/crates/admin/src/cluster_controller/service.rs index 4a07f5e35..afe930c8e 100644 --- a/crates/admin/src/cluster_controller/service.rs +++ b/crates/admin/src/cluster_controller/service.rs @@ -58,7 +58,6 @@ pub enum Error { } pub struct Service { - metadata: Metadata, networking: Networking, bifrost: Bifrost, cluster_state_refresher: ClusterStateRefresher, @@ -83,7 +82,6 @@ where mut configuration: Live, health_status: HealthStatus, bifrost: Bifrost, - metadata: Metadata, networking: Networking, router_builder: &mut MessageRouterBuilder, server_builder: &mut NetworkServerBuilder, @@ -93,7 +91,7 @@ where let (command_tx, command_rx) = mpsc::channel(2); let cluster_state_refresher = - ClusterStateRefresher::new(metadata.clone(), networking.clone(), router_builder); + ClusterStateRefresher::new(networking.clone(), router_builder); let processor_manager_client = PartitionProcessorManagerClient::new(networking.clone(), router_builder); @@ -119,7 +117,6 @@ where Service { configuration, health_status, - metadata, networking, bifrost, cluster_state_refresher, @@ -226,7 +223,7 @@ impl Service { TaskCenter::spawn_child( TaskKind::SystemService, "cluster-controller-metadata-sync", - sync_cluster_controller_metadata(self.metadata.clone()), + sync_cluster_controller_metadata(), )?; let mut shutdown = std::pin::pin!(cancellation_watcher()); @@ -393,12 +390,13 @@ impl Service { } } -async fn sync_cluster_controller_metadata(metadata: Metadata) -> anyhow::Result<()> { +async fn sync_cluster_controller_metadata() -> anyhow::Result<()> { // todo make this configurable let mut interval = time::interval(Duration::from_secs(10)); 
interval.set_missed_tick_behavior(MissedTickBehavior::Delay); let mut cancel = std::pin::pin!(cancellation_watcher()); + let metadata = Metadata::current(); loop { tokio::select! { @@ -484,10 +482,8 @@ mod tests { use restate_core::network::{ FailingConnector, Incoming, MessageHandler, MockPeerConnection, NetworkServerBuilder, }; - use restate_core::{ - NoOpMessageHandler, TaskCenter, TaskCenterFutureExt, TaskKind, TestCoreEnv, - TestCoreEnvBuilder, - }; + use restate_core::test_env::NoOpMessageHandler; + use restate_core::{TaskCenter, TaskKind, TestCoreEnv, TestCoreEnvBuilder}; use restate_types::cluster::cluster_state::PartitionProcessorStatus; use restate_types::config::{AdminOptions, Configuration}; use restate_types::health::HealthStatus; @@ -500,53 +496,41 @@ mod tests { use restate_types::nodes_config::{LogServerConfig, NodeConfig, NodesConfiguration, Role}; use restate_types::{GenerationalNodeId, Version}; - #[test(tokio::test)] + #[test(restate_core::test)] async fn manual_log_trim() -> anyhow::Result<()> { const LOG_ID: LogId = LogId::new(0); let mut builder = TestCoreEnvBuilder::with_incoming_only_connector(); - let tc = builder.tc.clone(); - async { - let bifrost_svc = BifrostService::new().with_factory(memory_loglet::Factory::default()); - let bifrost = bifrost_svc.handle(); - - let svc = Service::new( - Live::from_value(Configuration::default()), - HealthStatus::default(), - bifrost.clone(), - builder.metadata.clone(), - builder.networking.clone(), - &mut builder.router_builder, - &mut NetworkServerBuilder::default(), - builder.metadata_writer.clone(), - builder.metadata_store_client.clone(), - ); - let svc_handle = svc.handle(); - - let _ = builder.build().await; - bifrost_svc.start().await?; - - let mut appender = bifrost.create_appender(LOG_ID)?; - - TaskCenter::current().spawn( - TaskKind::SystemService, - "cluster-controller", - None, - svc.run(), - )?; - - for _ in 1..=5 { - appender.append("").await?; - } + let bifrost_svc = BifrostService::new().with_factory(memory_loglet::Factory::default()); + let bifrost = bifrost_svc.handle(); + + let svc = Service::new( + Live::from_value(Configuration::default()), + HealthStatus::default(), + bifrost.clone(), + builder.networking.clone(), + &mut builder.router_builder, + &mut NetworkServerBuilder::default(), + builder.metadata_writer.clone(), + builder.metadata_store_client.clone(), + ); + let svc_handle = svc.handle(); + + let _ = builder.build().await; + bifrost_svc.start().await?; - svc_handle.trim_log(LOG_ID, Lsn::from(3)).await??; + let mut appender = bifrost.create_appender(LOG_ID)?; - let record = bifrost.read(LOG_ID, Lsn::OLDEST).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::OLDEST)); - assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(3)))); - Ok::<(), anyhow::Error>(()) + TaskCenter::spawn(TaskKind::SystemService, "cluster-controller", svc.run())?; + + for _ in 1..=5 { + appender.append("").await?; } - .in_tc(&tc) - .await?; + + svc_handle.trim_log(LOG_ID, Lsn::from(3)).await??; + + let record = bifrost.read(LOG_ID, Lsn::OLDEST).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::OLDEST)); + assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(3)))); Ok(()) } @@ -584,7 +568,7 @@ mod tests { } } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn auto_log_trim() -> anyhow::Result<()> { const LOG_ID: LogId = LogId::new(0); @@ -612,63 +596,58 @@ mod tests { 
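// ---------------------------------------------------------------------------
// Reviewer sketch (not part of the patch): the test shape the suite converges
// on. `restate_core::test` replaces `tokio::test` plus the manual
// `async { .. }.in_tc(&env.tc)` wrapping, since the macro provides the ambient
// task center. The test name and body are illustrative.
use restate_core::TestCoreEnv;

#[test_log::test(restate_core::test)]
async fn example_runs_inside_ambient_task_center() -> googletest::Result<()> {
    let _env = TestCoreEnv::create_with_single_node(0, 0).await;
    // ... exercise code that expects TaskCenter/Metadata to be ambient ...
    Ok(())
}
// ---------------------------------------------------------------------------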
.add_message_handler(NoOpMessageHandler::::default()) }) .await?; - let tc = node_env.tc.clone(); - - async move { - // simulate a connection from node 2 so we can have a connection between the two - // nodes - let node_2 = MockPeerConnection::connect( - GenerationalNodeId::new(2, 2), - node_env.metadata.nodes_config_version(), - node_env - .metadata - .nodes_config_ref() - .cluster_name() - .to_owned(), - node_env.networking.connection_manager(), - 10, - ) - .await?; - // let node2 receive messages and use the same message handler as node1 - let (_node_2, _node2_reactor) = node_2 - .process_with_message_handler(&TaskCenter::current(), get_node_state_handler)?; - - let mut appender = bifrost.create_appender(LOG_ID)?; - for i in 1..=20 { - let lsn = appender.append("").await?; - assert_eq!(Lsn::from(i), lsn); - } - tokio::time::sleep(interval_duration * 10).await; + // simulate a connection from node 2 so we can have a connection between the two + // nodes + let node_2 = MockPeerConnection::connect( + GenerationalNodeId::new(2, 2), + node_env.metadata.nodes_config_version(), + node_env + .metadata + .nodes_config_ref() + .cluster_name() + .to_owned(), + node_env.networking.connection_manager(), + 10, + ) + .await?; + // let node2 receive messages and use the same message handler as node1 + let (_node_2, _node2_reactor) = + node_2.process_with_message_handler(get_node_state_handler)?; + + let mut appender = bifrost.create_appender(LOG_ID)?; + for i in 1..=20 { + let lsn = appender.append("").await?; + assert_eq!(Lsn::from(i), lsn); + } + + tokio::time::sleep(interval_duration * 10).await; - assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); + assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); - // report persisted lsn back to cluster controller - persisted_lsn.store(6, Ordering::Relaxed); + // report persisted lsn back to cluster controller + persisted_lsn.store(6, Ordering::Relaxed); - tokio::time::sleep(interval_duration * 10).await; - // we delete 1-6. - assert_eq!(Lsn::from(6), bifrost.get_trim_point(LOG_ID).await?); + tokio::time::sleep(interval_duration * 10).await; + // we delete 1-6. 
+ assert_eq!(Lsn::from(6), bifrost.get_trim_point(LOG_ID).await?); - // increase by 4 more, this should not overcome the threshold - persisted_lsn.store(10, Ordering::Relaxed); + // increase by 4 more, this should not overcome the threshold + persisted_lsn.store(10, Ordering::Relaxed); - tokio::time::sleep(interval_duration * 10).await; - assert_eq!(Lsn::from(6), bifrost.get_trim_point(LOG_ID).await?); + tokio::time::sleep(interval_duration * 10).await; + assert_eq!(Lsn::from(6), bifrost.get_trim_point(LOG_ID).await?); - // now we have reached the min threshold wrt to the last trim point - persisted_lsn.store(11, Ordering::Relaxed); + // now we have reached the min threshold wrt to the last trim point + persisted_lsn.store(11, Ordering::Relaxed); - tokio::time::sleep(interval_duration * 10).await; - assert_eq!(Lsn::from(11), bifrost.get_trim_point(LOG_ID).await?); + tokio::time::sleep(interval_duration * 10).await; + assert_eq!(Lsn::from(11), bifrost.get_trim_point(LOG_ID).await?); - Ok::<(), anyhow::Error>(()) - } - .in_tc(&tc) - .await + Ok(()) } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn auto_log_trim_zero_threshold() -> anyhow::Result<()> { const LOG_ID: LogId = LogId::new(0); let mut admin_options = AdminOptions::default(); @@ -695,58 +674,53 @@ mod tests { }) .await?; - let tc = node_env.tc.clone(); - async move { - // simulate a connection from node 2 so we can have a connection between the two - // nodes - let node_2 = MockPeerConnection::connect( - GenerationalNodeId::new(2, 2), - node_env.metadata.nodes_config_version(), - node_env - .metadata - .nodes_config_ref() - .cluster_name() - .to_owned(), - node_env.networking.connection_manager(), - 10, - ) - .await?; - // let node2 receive messages and use the same message handler as node1 - let (_node_2, _node2_reactor) = - node_2.process_with_message_handler(&node_env.tc, get_node_state_handler)?; - - let mut appender = bifrost.create_appender(LOG_ID)?; - for i in 1..=20 { - let lsn = appender.append(format!("record{}", i)).await?; - assert_eq!(Lsn::from(i), lsn); - } - tokio::time::sleep(interval_duration * 10).await; - assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); + // simulate a connection from node 2 so we can have a connection between the two + // nodes + let node_2 = MockPeerConnection::connect( + GenerationalNodeId::new(2, 2), + node_env.metadata.nodes_config_version(), + node_env + .metadata + .nodes_config_ref() + .cluster_name() + .to_owned(), + node_env.networking.connection_manager(), + 10, + ) + .await?; + // let node2 receive messages and use the same message handler as node1 + let (_node_2, _node2_reactor) = + node_2.process_with_message_handler(get_node_state_handler)?; + + let mut appender = bifrost.create_appender(LOG_ID)?; + for i in 1..=20 { + let lsn = appender.append(format!("record{}", i)).await?; + assert_eq!(Lsn::from(i), lsn); + } + tokio::time::sleep(interval_duration * 10).await; + assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); - // report persisted lsn back to cluster controller - persisted_lsn.store(3, Ordering::Relaxed); + // report persisted lsn back to cluster controller + persisted_lsn.store(3, Ordering::Relaxed); - tokio::time::sleep(interval_duration * 10).await; - // everything before the persisted_lsn. 
- assert_eq!(bifrost.get_trim_point(LOG_ID).await?, Lsn::from(3)); - // we should be able to after the last persisted lsn - let v = bifrost.read(LOG_ID, Lsn::from(4)).await?.unwrap(); - assert_that!(v.sequence_number(), eq(Lsn::new(4))); - assert!(v.is_data_record()); - assert_that!(v.decode_unchecked::(), eq("record4".to_owned())); + tokio::time::sleep(interval_duration * 10).await; + // everything before the persisted_lsn. + assert_eq!(bifrost.get_trim_point(LOG_ID).await?, Lsn::from(3)); + // we should be able to after the last persisted lsn + let v = bifrost.read(LOG_ID, Lsn::from(4)).await?.unwrap(); + assert_that!(v.sequence_number(), eq(Lsn::new(4))); + assert!(v.is_data_record()); + assert_that!(v.decode_unchecked::(), eq("record4".to_owned())); - persisted_lsn.store(20, Ordering::Relaxed); + persisted_lsn.store(20, Ordering::Relaxed); - tokio::time::sleep(interval_duration * 10).await; - assert_eq!(Lsn::from(20), bifrost.get_trim_point(LOG_ID).await?); + tokio::time::sleep(interval_duration * 10).await; + assert_eq!(Lsn::from(20), bifrost.get_trim_point(LOG_ID).await?); - Ok::<(), anyhow::Error>(()) - } - .in_tc(&tc) - .await + Ok(()) } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn do_not_trim_if_not_all_nodes_report_persisted_lsn() -> anyhow::Result<()> { const LOG_ID: LogId = LogId::new(0); @@ -762,7 +736,7 @@ mod tests { let persisted_lsn = Arc::new(AtomicU64::new(0)); let archived_lsn = Arc::new(AtomicU64::new(0)); - let (node_env, bifrost) = create_test_env(config, |builder| { + let (_node_env, bifrost) = create_test_env(config, |builder| { let black_list = builder .nodes_config .iter() @@ -781,24 +755,18 @@ mod tests { }) .await?; - async move { - let mut appender = bifrost.create_appender(LOG_ID)?; - for i in 1..=5 { - let lsn = appender.append(format!("record{}", i)).await?; - assert_eq!(Lsn::from(i), lsn); - } - - // report persisted lsn back to cluster controller for a subset of the nodes - persisted_lsn.store(5, Ordering::Relaxed); + let mut appender = bifrost.create_appender(LOG_ID)?; + for i in 1..=5 { + let lsn = appender.append(format!("record{}", i)).await?; + assert_eq!(Lsn::from(i), lsn); + } - tokio::time::sleep(interval_duration * 10).await; - // no trimming should have happened because one node did not report the persisted lsn - assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); + // report persisted lsn back to cluster controller for a subset of the nodes + persisted_lsn.store(5, Ordering::Relaxed); - Ok::<(), anyhow::Error>(()) - } - .in_tc(&node_env.tc) - .await?; + tokio::time::sleep(interval_duration * 10).await; + // no trimming should have happened because one node did not report the persisted lsn + assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); Ok(()) } @@ -811,54 +779,43 @@ mod tests { F: FnMut(TestCoreEnvBuilder) -> TestCoreEnvBuilder, { let mut builder = TestCoreEnvBuilder::with_incoming_only_connector(); - let tc = builder.tc.clone(); - async { - let bifrost_svc = BifrostService::new().with_factory(memory_loglet::Factory::default()); - let bifrost = bifrost_svc.handle(); - - let mut server_builder = NetworkServerBuilder::default(); + let bifrost_svc = BifrostService::new().with_factory(memory_loglet::Factory::default()); + let bifrost = bifrost_svc.handle(); + + let mut server_builder = NetworkServerBuilder::default(); + + let svc = Service::new( + Live::from_value(config), + HealthStatus::default(), + bifrost.clone(), + builder.networking.clone(), + &mut 
builder.router_builder, + &mut server_builder, + builder.metadata_writer.clone(), + builder.metadata_store_client.clone(), + ); - let svc = Service::new( - Live::from_value(config), - HealthStatus::default(), - bifrost.clone(), - builder.metadata.clone(), - builder.networking.clone(), - &mut builder.router_builder, - &mut server_builder, - builder.metadata_writer.clone(), - builder.metadata_store_client.clone(), - ); - - let mut nodes_config = NodesConfiguration::new(Version::MIN, "test-cluster".to_owned()); - nodes_config.upsert_node(NodeConfig::new( - "node-1".to_owned(), - GenerationalNodeId::new(1, 1), - AdvertisedAddress::Uds("foobar".into()), - Role::Worker.into(), - LogServerConfig::default(), - )); - nodes_config.upsert_node(NodeConfig::new( - "node-2".to_owned(), - GenerationalNodeId::new(2, 2), - AdvertisedAddress::Uds("bar".into()), - Role::Worker.into(), - LogServerConfig::default(), - )); - let builder = modify_builder(builder.set_nodes_config(nodes_config)); - - let node_env = builder.build().await; - bifrost_svc.start().await?; - - node_env.tc.spawn( - TaskKind::SystemService, - "cluster-controller", - None, - svc.run(), - )?; - Ok((node_env, bifrost)) - } - .in_tc(&tc) - .await + let mut nodes_config = NodesConfiguration::new(Version::MIN, "test-cluster".to_owned()); + nodes_config.upsert_node(NodeConfig::new( + "node-1".to_owned(), + GenerationalNodeId::new(1, 1), + AdvertisedAddress::Uds("foobar".into()), + Role::Worker.into(), + LogServerConfig::default(), + )); + nodes_config.upsert_node(NodeConfig::new( + "node-2".to_owned(), + GenerationalNodeId::new(2, 2), + AdvertisedAddress::Uds("bar".into()), + Role::Worker.into(), + LogServerConfig::default(), + )); + let builder = modify_builder(builder.set_nodes_config(nodes_config)); + + let node_env = builder.build().await; + bifrost_svc.start().await?; + + TaskCenter::spawn(TaskKind::SystemService, "cluster-controller", svc.run())?; + Ok((node_env, bifrost)) } } diff --git a/crates/admin/src/cluster_controller/service/state.rs b/crates/admin/src/cluster_controller/service/state.rs index d8e0a5948..938a4ac7e 100644 --- a/crates/admin/src/cluster_controller/service/state.rs +++ b/crates/admin/src/cluster_controller/service/state.rs @@ -20,7 +20,7 @@ use tracing::{debug, info, warn}; use restate_bifrost::{Bifrost, BifrostAdmin}; use restate_core::metadata_store::MetadataStoreClient; use restate_core::network::TransportConnect; -use restate_core::{Metadata, MetadataWriter}; +use restate_core::{my_node_id, Metadata, MetadataWriter}; use restate_types::cluster::cluster_state::{AliveNode, NodeState}; use restate_types::config::{AdminOptions, Configuration}; use restate_types::identifiers::PartitionId; @@ -50,24 +50,26 @@ where T: TransportConnect, { pub async fn update(&mut self, service: &Service) -> anyhow::Result<()> { - let nodes_config = service.metadata.nodes_config_ref(); - let maybe_leader = nodes_config - .get_admin_nodes() - .filter(|node| { - service - .observed_cluster_state - .is_node_alive(node.current_generation) - }) - .map(|node| node.current_generation) - .sorted() - .next(); + let maybe_leader = { + let nodes_config = Metadata::with_current(|m| m.nodes_config_ref()); + nodes_config + .get_admin_nodes() + .filter(|node| { + service + .observed_cluster_state + .is_node_alive(node.current_generation) + }) + .map(|node| node.current_generation) + .sorted() + .next() + }; // A Cluster Controller is a leader if the node holds the smallest PlainNodeID // If no other node was found to take leadership, we assume 
leadership let is_leader = match maybe_leader { None => true, - Some(leader) => leader == service.metadata.my_node_id(), + Some(leader) => leader == my_node_id(), }; match (is_leader, &self) { @@ -160,6 +162,8 @@ where async fn from_service(service: &Service) -> anyhow::Result> { let configuration = service.configuration.pinned(); + let metadata = Metadata::current(); + let scheduler = Scheduler::init( &configuration, service.metadata_store_client.clone(), @@ -169,7 +173,7 @@ where let logs_controller = LogsController::init( &configuration, - service.metadata.clone(), + metadata.clone(), service.bifrost.clone(), service.metadata_store_client.clone(), service.metadata_writer.clone(), @@ -184,16 +188,16 @@ where find_logs_tail_interval.set_missed_tick_behavior(MissedTickBehavior::Delay); let mut leader = Self { - metadata: service.metadata.clone(), + metadata: metadata.clone(), bifrost: service.bifrost.clone(), metadata_store_client: service.metadata_store_client.clone(), metadata_writer: service.metadata_writer.clone(), - logs_watcher: service.metadata.watch(MetadataKind::Logs), - nodes_config: service.metadata.updateable_nodes_config(), - partition_table_watcher: service.metadata.watch(MetadataKind::PartitionTable), + logs_watcher: metadata.watch(MetadataKind::Logs), + nodes_config: metadata.updateable_nodes_config(), + partition_table_watcher: metadata.watch(MetadataKind::PartitionTable), cluster_state_watcher: service.cluster_state_refresher.cluster_state_watcher(), - partition_table: service.metadata.updateable_partition_table(), - logs: service.metadata.updateable_logs_metadata(), + partition_table: metadata.updateable_partition_table(), + logs: metadata.updateable_logs_metadata(), find_logs_tail_interval, log_trim_interval, log_trim_threshold, diff --git a/crates/admin/src/schema_registry/mod.rs b/crates/admin/src/schema_registry/mod.rs index b9869d47a..cd043be8d 100644 --- a/crates/admin/src/schema_registry/mod.rs +++ b/crates/admin/src/schema_registry/mod.rs @@ -21,7 +21,7 @@ use std::time::Duration; use tracing::subscriber::NoSubscriber; use restate_core::metadata_store::MetadataStoreClient; -use restate_core::{metadata, MetadataWriter}; +use restate_core::{Metadata, MetadataWriter}; use restate_service_protocol::discovery::{DiscoverEndpoint, DiscoveredEndpoint, ServiceDiscovery}; use restate_types::identifiers::{DeploymentId, ServiceRevision, SubscriptionId}; use restate_types::metadata_store::keys::SCHEMA_INFORMATION_KEY; @@ -132,7 +132,7 @@ impl SchemaRegistry { let (id, services) = if !apply_mode.should_apply() { let mut updater = SchemaUpdater::new( - metadata().schema().deref().clone(), + Metadata::with_current(|m| m.schema()).deref().clone(), self.experimental_feature_kafka_ingress_next, ); @@ -303,38 +303,33 @@ impl SchemaRegistry { } pub fn list_services(&self) -> Vec { - metadata().schema().list_services() + Metadata::with_current(|m| m.schema()).list_services() } pub fn get_service(&self, service_name: impl AsRef) -> Option { - metadata().schema().resolve_latest_service(&service_name) + Metadata::with_current(|m| m.schema()).resolve_latest_service(&service_name) } pub fn get_service_openapi(&self, service_name: impl AsRef) -> Option { - metadata() - .schema() - .resolve_latest_service_openapi(&service_name) + Metadata::with_current(|m| m.schema()).resolve_latest_service_openapi(&service_name) } pub fn get_deployment( &self, deployment_id: DeploymentId, ) -> Option<(Deployment, Vec)> { - metadata() - .schema() - .get_deployment_and_services(&deployment_id) + 
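// ---------------------------------------------------------------------------
// Reviewer sketch (not part of the patch): the leader check above now reads the
// ambient node id via `my_node_id()` and takes updateable views from
// `Metadata::current()` instead of a stored `service.metadata` field.
// `maybe_leader` is illustrative.
use restate_core::{my_node_id, Metadata};
use restate_types::GenerationalNodeId;

fn is_leader(maybe_leader: Option<GenerationalNodeId>) -> bool {
    match maybe_leader {
        // No live admin node observed: assume leadership, as above.
        None => true,
        Some(leader) => leader == my_node_id(),
    }
}

fn metadata_handles() {
    let metadata = Metadata::current();
    let _nodes_config = metadata.updateable_nodes_config();
    let _partition_table = metadata.updateable_partition_table();
}
// ---------------------------------------------------------------------------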
Metadata::with_current(|m| m.schema()).get_deployment_and_services(&deployment_id) } pub fn list_deployments(&self) -> Vec<(Deployment, Vec<(String, ServiceRevision)>)> { - metadata().schema().get_deployments() + Metadata::with_current(|m| m.schema()).get_deployments() } pub fn list_service_handlers( &self, service_name: impl AsRef, ) -> Option> { - metadata() - .schema() + Metadata::with_current(|m| m.schema()) .resolve_latest_service(&service_name) .map(|m| m.handlers) } @@ -344,8 +339,7 @@ impl SchemaRegistry { service_name: impl AsRef, handler_name: impl AsRef, ) -> Option { - metadata() - .schema() + Metadata::with_current(|m| m.schema()) .resolve_latest_service(&service_name) .and_then(|m| { m.handlers @@ -355,11 +349,11 @@ impl SchemaRegistry { } pub fn get_subscription(&self, subscription_id: SubscriptionId) -> Option { - metadata().schema().get_subscription(subscription_id) + Metadata::with_current(|m| m.schema()).get_subscription(subscription_id) } pub fn list_subscriptions(&self, filters: &[ListSubscriptionFilter]) -> Vec { - metadata().schema().list_subscriptions(filters) + Metadata::with_current(|m| m.schema()).list_subscriptions(filters) } } diff --git a/crates/bifrost/benches/append_throughput.rs b/crates/bifrost/benches/append_throughput.rs index 679c3f9c4..77bacf562 100644 --- a/crates/bifrost/benches/append_throughput.rs +++ b/crates/bifrost/benches/append_throughput.rs @@ -15,6 +15,7 @@ use std::ops::Range; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::StreamExt; +use restate_core::TaskCenterFutureExt; use tracing::info; use tracing_subscriber::EnvFilter; @@ -122,13 +123,9 @@ fn write_throughput_local_loglet(c: &mut Criterion) { .sample_size(10) .throughput(Throughput::Elements(count_per_run)) .bench_function("sequential_single_log", |bencher| { - bencher.to_async(&test_runner_rt).iter(|| { - tc.run_in_scope( - "bench", - None, - append_seq(bifrost.clone(), LogId::new(1), count_per_run), - ) - }); + bencher + .to_async(&test_runner_rt) + .iter(|| append_seq(bifrost.clone(), LogId::new(1), count_per_run).in_tc(&tc)); }); // Concurrent single log @@ -141,15 +138,12 @@ fn write_throughput_local_loglet(c: &mut Criterion) { count_per_run, |bencher, &count_per_run| { bencher.to_async(&test_runner_rt).iter(|| { - tc.run_in_scope( - "bench", - None, - append_records_concurrent_single_log( - bifrost.clone(), - LogId::new(1), - count_per_run, - ), + append_records_concurrent_single_log( + bifrost.clone(), + LogId::new(1), + count_per_run, ) + .in_tc(&tc) }); }, ); @@ -167,15 +161,12 @@ fn write_throughput_local_loglet(c: &mut Criterion) { count_per_run, |bencher, &count_per_run| { bencher.to_async(&test_runner_rt).iter(|| { - tc.run_in_scope( - "bench", - None, - append_records_multi_log( - bifrost.clone(), - 0..num_logs_per_run, - count_per_run, - ), + append_records_multi_log( + bifrost.clone(), + 0..num_logs_per_run, + count_per_run, ) + .in_tc(&tc) }); }, ); diff --git a/crates/bifrost/benches/util.rs b/crates/bifrost/benches/util.rs index 30a5c8ea1..3d107b05b 100644 --- a/crates/bifrost/benches/util.rs +++ b/crates/bifrost/benches/util.rs @@ -12,8 +12,8 @@ use std::sync::Arc; use tracing::warn; use restate_core::{ - spawn_metadata_manager, MetadataBuilder, MetadataManager, TaskCenter, TaskCenterBuilder, - TaskCenterFutureExt, + spawn_metadata_manager, task_center, MetadataBuilder, MetadataManager, TaskCenter, + TaskCenterBuilder, TaskCenterFutureExt, }; use 
restate_metadata_store::{MetadataStoreClient, Precondition}; use restate_rocksdb::RocksDbManager; @@ -26,14 +26,15 @@ pub async fn spawn_environment( config: Configuration, num_logs: u16, provider: ProviderKind, -) -> TaskCenter { +) -> task_center::Handle { if rlimit::increase_nofile_limit(u64::MAX).is_err() { warn!("Failed to increase the number of open file descriptors limit."); } let tc = TaskCenterBuilder::default() .options(config.common.clone()) .build() - .expect("task_center builds"); + .expect("task_center builds") + .to_handle(); async { restate_types::config::set_current_config(config.clone()); diff --git a/crates/bifrost/src/background_appender.rs b/crates/bifrost/src/background_appender.rs index 6ac7921eb..beb95d65c 100644 --- a/crates/bifrost/src/background_appender.rs +++ b/crates/bifrost/src/background_appender.rs @@ -17,7 +17,6 @@ use tokio::sync::{mpsc, oneshot, Notify}; use tracing::{trace, warn}; use restate_core::{cancellation_watcher, ShutdownError, TaskCenter, TaskHandle}; -use restate_types::identifiers::PartitionId; use restate_types::storage::StorageEncode; use crate::error::EnqueueError; @@ -58,17 +57,12 @@ where /// Start the background appender as a TaskCenter background task. Note that the task will not /// automatically react to TaskCenter's shutdown signal, it gives control over the shutdown /// behaviour to the owner of [`AppenderHandle`] to drain or drop when appropriate. - pub fn start( - self, - name: &'static str, - partition_id: Option, - ) -> Result, ShutdownError> { + pub fn start(self, name: &'static str) -> Result, ShutdownError> { let (tx, rx) = tokio::sync::mpsc::channel(self.queue_capacity); - let handle = TaskCenter::current().spawn_unmanaged( + let handle = TaskCenter::spawn_unmanaged( restate_core::TaskKind::BifrostAppender, name, - partition_id, self.run(rx), )?; diff --git a/crates/bifrost/src/bifrost.rs b/crates/bifrost/src/bifrost.rs index f04cf4aca..f1a97e14d 100644 --- a/crates/bifrost/src/bifrost.rs +++ b/crates/bifrost/src/bifrost.rs @@ -496,8 +496,8 @@ mod tests { use tracing::info; use tracing_test::traced_test; - use restate_core::TestCoreEnvBuilder2; - use restate_core::{TaskCenter, TaskKind, TestCoreEnv2}; + use restate_core::TestCoreEnvBuilder; + use restate_core::{TaskCenter, TaskKind, TestCoreEnv}; use restate_rocksdb::RocksDbManager; use restate_types::config::CommonOptions; use restate_types::live::Constant; @@ -514,7 +514,7 @@ mod tests { #[traced_test] async fn test_append_smoke() -> googletest::Result<()> { let num_partitions = 5; - let _ = TestCoreEnvBuilder2::with_incoming_only_connector() + let _ = TestCoreEnvBuilder::with_incoming_only_connector() .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, num_partitions, @@ -586,7 +586,7 @@ mod tests { #[restate_core::test(start_paused = true)] async fn test_lazy_initialization() -> googletest::Result<()> { - let _ = TestCoreEnv2::create_with_single_node(1, 1).await; + let _ = TestCoreEnv::create_with_single_node(1, 1).await; let delay = Duration::from_secs(5); // This memory provider adds a delay to its loglet initialization, we want // to ensure that appends do not fail while waiting for the loglet; @@ -604,7 +604,7 @@ mod tests { #[test(restate_core::test(flavor = "multi_thread", worker_threads = 2))] async fn trim_log_smoke_test() -> googletest::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder2::with_incoming_only_connector() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() 
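// ---------------------------------------------------------------------------
// Reviewer sketch (not part of the patch): how the benchmark entry points now
// hand out a `task_center::Handle` (via `to_handle()`) and drive futures either
// with `handle.block_on(..)` or by binding them with `in_tc(&handle)`, as in
// `spawn_restate`/`spawn_environment` above. The `config` handling is
// illustrative.
use restate_core::{task_center, TaskCenterBuilder, TaskCenterFutureExt};
use restate_types::config::Configuration;

fn bench_task_center(config: &Configuration) -> task_center::Handle {
    let tc = TaskCenterBuilder::default()
        .options(config.common.clone())
        .build()
        .expect("task_center builds")
        .to_handle();

    // Run setup inside the task center...
    tc.block_on(async {
        // ... initialize managers, spawn services, etc. ...
    });
    tc
}

async fn run_in_handle(tc: &task_center::Handle) {
    // ...or bind an individual future to it (as the criterion benches do).
    async { /* benchmark iteration */ }.in_tc(tc).await;
}
// ---------------------------------------------------------------------------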
.set_provider_kind(ProviderKind::Local) .build() .await; @@ -676,7 +676,7 @@ mod tests { #[restate_core::test(start_paused = true)] async fn test_read_across_segments() -> googletest::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder2::with_incoming_only_connector() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, @@ -862,7 +862,7 @@ mod tests { #[traced_test] async fn test_appends_correctly_handle_reconfiguration() -> googletest::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder2::with_incoming_only_connector() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, @@ -881,7 +881,7 @@ mod tests { // create an appender let stop_signal = Arc::new(AtomicBool::default()); let append_counter = Arc::new(AtomicUsize::new(0)); - let _ = TaskCenter::current().spawn(TaskKind::TestRunner, "append-records", None, { + let _ = TaskCenter::spawn(TaskKind::TestRunner, "append-records", { let append_counter = append_counter.clone(); let stop_signal = stop_signal.clone(); let bifrost = bifrost.clone(); diff --git a/crates/bifrost/src/loglet/loglet_tests.rs b/crates/bifrost/src/loglet/loglet_tests.rs index 12a2a9223..586b8bbf5 100644 --- a/crates/bifrost/src/loglet/loglet_tests.rs +++ b/crates/bifrost/src/loglet/loglet_tests.rs @@ -123,7 +123,7 @@ pub async fn gapless_loglet_smoke_test(loglet: Arc) -> googletest::R assert!(loglet.read_opt(Lsn::new(end)).await?.is_none()); let handle1: TaskHandle> = - TaskCenter::current().spawn_unmanaged(TaskKind::TestRunner, "read", None, { + TaskCenter::spawn_unmanaged(TaskKind::TestRunner, "read", { let loglet = loglet.clone(); async move { // read future record 4 @@ -140,7 +140,7 @@ pub async fn gapless_loglet_smoke_test(loglet: Arc) -> googletest::R // Waiting for 10 let handle2: TaskHandle> = - TaskCenter::current().spawn_unmanaged(TaskKind::TestRunner, "read", None, { + TaskCenter::spawn_unmanaged(TaskKind::TestRunner, "read", { let loglet = loglet.clone(); async move { // read future record 10 diff --git a/crates/bifrost/src/providers/local_loglet/mod.rs b/crates/bifrost/src/providers/local_loglet/mod.rs index 129e9406b..e9ab52f93 100644 --- a/crates/bifrost/src/providers/local_loglet/mod.rs +++ b/crates/bifrost/src/providers/local_loglet/mod.rs @@ -278,11 +278,11 @@ impl Loglet for LocalLoglet { mod tests { use futures::TryStreamExt; use googletest::prelude::eq; - use googletest::{assert_that, elements_are}; + use googletest::{assert_that, elements_are, IntoTestResult}; use test_log::test; use crate::loglet::Loglet; - use restate_core::TestCoreEnvBuilder; + use restate_core::{TaskCenter, TestCoreEnvBuilder}; use restate_rocksdb::RocksDbManager; use restate_types::config::Configuration; use restate_types::live::Live; @@ -294,55 +294,47 @@ mod tests { macro_rules! run_test { ($test:ident) => { paste::paste! 
{ - #[tokio::test(start_paused = true)] + #[restate_core::test(start_paused = true)] async fn []() -> googletest::Result<()> { - run_in_test_env(crate::loglet::loglet_tests::$test).await + let loglet = create_loglet().await.into_test_result()?; + crate::loglet::loglet_tests::$test(loglet).await?; + TaskCenter::shutdown_node("test completed", 0).await; + RocksDbManager::get().shutdown().await; + Ok(()) } } }; } - async fn run_in_test_env(mut future: F) -> googletest::Result<()> - where - F: FnMut(Arc) -> O, - O: std::future::Future>, - { - let node_env = TestCoreEnvBuilder::with_incoming_only_connector() + async fn create_loglet() -> anyhow::Result> { + let _node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; - node_env - .tc - .run_in_scope("test", None, async { - let config = Live::from_value(Configuration::default()); - RocksDbManager::init(config.clone().map(|c| &c.common)); - let params = LogletParams::from("42".to_string()); + let config = Live::from_value(Configuration::default()); + RocksDbManager::init(config.clone().map(|c| &c.common)); + let params = LogletParams::from("42".to_string()); - let log_store = RocksDbLogStore::create( - &config.pinned().bifrost.local, - config.clone().map(|c| &c.bifrost.local.rocksdb).boxed(), - ) - .await?; + let log_store = RocksDbLogStore::create( + &config.pinned().bifrost.local, + config.clone().map(|c| &c.bifrost.local.rocksdb).boxed(), + ) + .await?; - let log_writer = log_store - .create_writer() - .start(config.clone().map(|c| &c.bifrost.local).boxed())?; + let log_writer = log_store + .create_writer() + .start(config.clone().map(|c| &c.bifrost.local).boxed())?; - let loglet = Arc::new(LocalLoglet::create( - params - .parse() - .expect("loglet params can be converted into u64"), - log_store, - log_writer, - )?); + let loglet = Arc::new(LocalLoglet::create( + params + .parse() + .expect("loglet params can be converted into u64"), + log_store, + log_writer, + )?); - future(loglet).await - }) - .await?; - node_env.tc.shutdown_node("test completed", 0).await; - RocksDbManager::get().shutdown().await; - Ok(()) + Ok(loglet) } run_test!(gapless_loglet_smoke_test); @@ -351,85 +343,77 @@ mod tests { run_test!(append_after_seal); run_test!(seal_empty); - #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + #[restate_core::test(flavor = "multi_thread", worker_threads = 4)] async fn local_loglet_append_after_seal_concurrent() -> googletest::Result<()> { - let node_env = TestCoreEnvBuilder::with_incoming_only_connector() + let _node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; - node_env - .tc - .run_in_scope("test", None, async { - let config = Live::from_value(Configuration::default()); - RocksDbManager::init(config.clone().map(|c| &c.common)); - - let log_store = RocksDbLogStore::create( - &config.pinned().bifrost.local, - config.clone().map(|c| &c.bifrost.local.rocksdb).boxed(), - ) - .await?; - - let log_writer = log_store - .create_writer() - .start(config.clone().map(|c| &c.bifrost.local).boxed())?; - - // Run the test 10 times - for i in 1..=10 { - let loglet = Arc::new(LocalLoglet::create( - i, - log_store.clone(), - log_writer.clone(), - )?); - crate::loglet::loglet_tests::append_after_seal_concurrent(loglet).await?; - } + let config = Live::from_value(Configuration::default()); + RocksDbManager::init(config.clone().map(|c| &c.common)); + + let log_store = RocksDbLogStore::create( + 
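// ---------------------------------------------------------------------------
// Reviewer sketch (not part of the patch): the teardown shape used by the
// converted loglet tests above — run in the ambient task center set up by
// `restate_core::test`, then shut the node and RocksDB manager down explicitly.
// The environment/builder details are illustrative.
use restate_core::{TaskCenter, TestCoreEnvBuilder};
use restate_rocksdb::RocksDbManager;

#[restate_core::test(start_paused = true)]
async fn example_with_explicit_teardown() -> googletest::Result<()> {
    let _env = TestCoreEnvBuilder::with_incoming_only_connector()
        .build()
        .await;
    // ... create the loglet / log store under test and exercise it ...
    TaskCenter::shutdown_node("test completed", 0).await;
    RocksDbManager::get().shutdown().await;
    Ok(())
}
// ---------------------------------------------------------------------------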
&config.pinned().bifrost.local, + config.clone().map(|c| &c.bifrost.local.rocksdb).boxed(), + ) + .await?; + + let log_writer = log_store + .create_writer() + .start(config.clone().map(|c| &c.bifrost.local).boxed())?; + + // Run the test 10 times + for i in 1..=10 { + let loglet = Arc::new(LocalLoglet::create( + i, + log_store.clone(), + log_writer.clone(), + )?); + crate::loglet::loglet_tests::append_after_seal_concurrent(loglet).await?; + } - googletest::Result::Ok(()) - }) - .await?; - node_env.tc.shutdown_node("test completed", 0).await; + TaskCenter::shutdown_node("test completed", 0).await; RocksDbManager::get().shutdown().await; Ok(()) } - #[test(tokio::test)] + #[test(restate_core::test)] async fn read_stream_with_filters() -> googletest::Result<()> { - run_in_test_env(|loglet| async { - let batch: Arc<[Record]> = vec![ - ("record-1", Keys::Single(1)).into(), - ("record-2", Keys::Single(2)).into(), - ("record-3", Keys::Single(1)).into(), - ] - .into(); - let offset = loglet.enqueue_batch(batch).await?.await?; + let loglet = create_loglet().await.into_test_result()?; + let batch: Arc<[Record]> = vec![ + ("record-1", Keys::Single(1)).into(), + ("record-2", Keys::Single(2)).into(), + ("record-3", Keys::Single(1)).into(), + ] + .into(); + let offset = loglet.enqueue_batch(batch).await?.await?; + + let key_filter = KeyFilter::Include(1); + let read_stream = loglet + .create_read_stream(key_filter, LogletOffset::OLDEST, Some(offset)) + .await?; - let key_filter = KeyFilter::Include(1); - let read_stream = loglet - .create_read_stream(key_filter, LogletOffset::OLDEST, Some(offset)) - .await?; + let records: Vec<_> = read_stream + .try_collect::>() + .await? + .into_iter() + .map(|log_entry| { + ( + log_entry.sequence_number(), + log_entry.decode_unchecked::(), + ) + }) + .collect(); - let records: Vec<_> = read_stream - .try_collect::>() - .await? - .into_iter() - .map(|log_entry| { - ( - log_entry.sequence_number(), - log_entry.decode_unchecked::(), - ) - }) - .collect(); - - assert_that!( - records, - elements_are![ - eq((LogletOffset::from(1), "record-1".to_owned())), - eq((LogletOffset::from(3), "record-3".to_owned())) - ] - ); - - Ok(()) - }) - .await + assert_that!( + records, + elements_are![ + eq((LogletOffset::from(1), "record-1".to_owned())), + eq((LogletOffset::from(3), "record-3".to_owned())) + ] + ); + + Ok(()) } } diff --git a/crates/bifrost/src/providers/memory_loglet.rs b/crates/bifrost/src/providers/memory_loglet.rs index 8006c9746..a1bb9ae0f 100644 --- a/crates/bifrost/src/providers/memory_loglet.rs +++ b/crates/bifrost/src/providers/memory_loglet.rs @@ -408,20 +408,14 @@ mod tests { macro_rules! run_test { ($test:ident) => { paste::paste! 
{ - #[tokio::test(start_paused = true)] + #[restate_core::test(start_paused = true)] async fn []() -> googletest::Result<()> { - let node_env = TestCoreEnvBuilder::with_incoming_only_connector() - .set_provider_kind(ProviderKind::InMemory) - .build() - .await; - node_env - .tc - .run_in_scope("test", None, async { - let loglet = MemoryLoglet::new(LogletParams::from("112".to_string())); - crate::loglet::loglet_tests::$test(loglet).await - }) - .await?; - Ok(()) + let _node_env = TestCoreEnvBuilder::with_incoming_only_connector() + .set_provider_kind(ProviderKind::InMemory) + .build() + .await; + let loglet = MemoryLoglet::new(LogletParams::from("112".to_string())); + crate::loglet::loglet_tests::$test(loglet).await } } }; diff --git a/crates/bifrost/src/providers/replicated_loglet/loglet.rs b/crates/bifrost/src/providers/replicated_loglet/loglet.rs index 552470693..01baee29d 100644 --- a/crates/bifrost/src/providers/replicated_loglet/loglet.rs +++ b/crates/bifrost/src/providers/replicated_loglet/loglet.rs @@ -15,7 +15,6 @@ use futures::stream::BoxStream; use tracing::{debug, info, instrument}; use restate_core::network::{Networking, TransportConnect}; -use restate_core::task_center; use restate_types::logs::metadata::SegmentIndex; use restate_types::logs::{ KeyFilter, LogId, LogletOffset, Record, RecordCache, SequenceNumber, TailState, @@ -239,7 +238,6 @@ impl Loglet for ReplicatedLoglet { } SequencerAccess::Remote { .. } => { let task = FindTailTask::new( - task_center(), self.log_id, self.segment_index, self.my_params.clone(), @@ -347,7 +345,7 @@ mod tests { use test_log::test; use restate_core::network::NetworkServerBuilder; - use restate_core::TestCoreEnvBuilder; + use restate_core::{TaskCenter, TestCoreEnvBuilder}; use restate_log_server::LogServerService; use restate_rocksdb::RocksDbManager; use restate_types::config::{set_current_config, Configuration}; @@ -396,42 +394,36 @@ mod tests { let node_env = node_env.build().await; - node_env - .tc - .clone() - .run_in_scope("test", None, async { - RocksDbManager::init(config.clone().map(|c| &c.common)); - - log_server - .start(node_env.metadata_writer.clone(), &mut server_builder) - .await - .into_test_result()?; - - let loglet = Arc::new(ReplicatedLoglet::new( - LogId::new(1), - SegmentIndex::from(1), - loglet_params, - node_env.networking.clone(), - logserver_rpc, - sequencer_rpc, - record_cache.clone(), - )); + RocksDbManager::init(config.clone().map(|c| &c.common)); + + log_server + .start(node_env.metadata_writer.clone(), &mut server_builder) + .await + .into_test_result()?; + + let loglet = Arc::new(ReplicatedLoglet::new( + LogId::new(1), + SegmentIndex::from(1), + loglet_params, + node_env.networking.clone(), + logserver_rpc, + sequencer_rpc, + record_cache.clone(), + )); - let env = TestEnv { - loglet, - record_cache, - }; + let env = TestEnv { + loglet, + record_cache, + }; - future(env).await - }) - .await?; - node_env.tc.shutdown_node("test completed", 0).await; + future(env).await?; + TaskCenter::shutdown_node("test completed", 0).await; RocksDbManager::get().shutdown().await; Ok(()) } // ** Single-node replicated-loglet smoke tests ** - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn test_append_local_sequencer_single_node() -> Result<()> { let loglet_id = ReplicatedLogletId::new_unchecked(122); let params = ReplicatedLogletParams { @@ -474,7 +466,7 @@ mod tests { } // ** Single-node replicated-loglet seal ** - #[test(tokio::test(start_paused = true))] + 
#[test(restate_core::test(start_paused = true))] async fn test_seal_local_sequencer_single_node() -> Result<()> { let loglet_id = ReplicatedLogletId::new_unchecked(122); let params = ReplicatedLogletParams { @@ -522,7 +514,7 @@ mod tests { // # Loglet Spec Tests On Single Node // ** Single-node replicated-loglet ** - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn single_node_gapless_loglet_smoke_test() -> Result<()> { let record_cache = RecordCache::new(1_000_000); let loglet_id = ReplicatedLogletId::new_unchecked(122); @@ -538,7 +530,7 @@ mod tests { .await } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn single_node_single_loglet_readstream() -> Result<()> { let loglet_id = ReplicatedLogletId::new_unchecked(122); let params = ReplicatedLogletParams { @@ -554,7 +546,7 @@ mod tests { .await } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn single_node_single_loglet_readstream_with_trims() -> Result<()> { let loglet_id = ReplicatedLogletId::new_unchecked(122); let params = ReplicatedLogletParams { @@ -577,7 +569,7 @@ mod tests { .await } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn single_node_append_after_seal() -> Result<()> { let loglet_id = ReplicatedLogletId::new_unchecked(122); let params = ReplicatedLogletParams { @@ -593,7 +585,7 @@ mod tests { .await } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn single_node_append_after_seal_concurrent() -> Result<()> { let loglet_id = ReplicatedLogletId::new_unchecked(122); let params = ReplicatedLogletParams { @@ -610,7 +602,7 @@ mod tests { .await } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn single_node_seal_empty() -> Result<()> { let loglet_id = ReplicatedLogletId::new_unchecked(122); let params = ReplicatedLogletParams { diff --git a/crates/bifrost/src/providers/replicated_loglet/provider.rs b/crates/bifrost/src/providers/replicated_loglet/provider.rs index 70e9aa464..d164ccbfa 100644 --- a/crates/bifrost/src/providers/replicated_loglet/provider.rs +++ b/crates/bifrost/src/providers/replicated_loglet/provider.rs @@ -33,7 +33,6 @@ use crate::providers::replicated_loglet::tasks::PeriodicTailChecker; use crate::Error; pub struct Factory { - task_center: TaskCenter, metadata_store_client: MetadataStoreClient, networking: Networking, logserver_rpc_routers: LogServersRpc, @@ -44,7 +43,6 @@ pub struct Factory { impl Factory { pub fn new( - task_center: TaskCenter, metadata_store_client: MetadataStoreClient, networking: Networking, record_cache: RecordCache, @@ -61,7 +59,6 @@ impl Factory { let sequencer_rpc_routers = SequencersRpc::new(router_builder); // todo(asoli): Create a handler to answer to control plane monitoring questions Self { - task_center, metadata_store_client, networking, logserver_rpc_routers, @@ -81,7 +78,6 @@ impl LogletProviderFactory for Factory { async fn create(self: Box) -> Result, OperationError> { metric_definitions::describe_metrics(); let provider = Arc::new(ReplicatedLogletProvider::new( - self.task_center.clone(), self.metadata_store_client, self.networking, self.logserver_rpc_routers, @@ -90,23 +86,17 @@ impl LogletProviderFactory for Factory { )); // run the request pump. The request pump handles/routes incoming messages to our // locally hosted sequencers. 
- self.task_center.spawn( - TaskKind::NetworkMessageHandler, - "sequencers-ingress", - None, - { - let request_pump = self.request_pump; - let provider = provider.clone(); - async { request_pump.run(provider).await } - }, - )?; + TaskCenter::spawn(TaskKind::NetworkMessageHandler, "sequencers-ingress", { + let request_pump = self.request_pump; + let provider = provider.clone(); + request_pump.run(provider) + })?; Ok(provider) } } pub(super) struct ReplicatedLogletProvider { - task_center: TaskCenter, active_loglets: DashMap<(LogId, SegmentIndex), Arc>>, _metadata_store_client: MetadataStoreClient, networking: Networking, @@ -117,7 +107,6 @@ pub(super) struct ReplicatedLogletProvider { impl ReplicatedLogletProvider { fn new( - task_center: TaskCenter, metadata_store_client: MetadataStoreClient, networking: Networking, logserver_rpc_routers: LogServersRpc, @@ -125,7 +114,6 @@ impl ReplicatedLogletProvider { record_cache: RecordCache, ) -> Self { Self { - task_center, active_loglets: Default::default(), _metadata_store_client: metadata_store_client, networking, @@ -186,10 +174,9 @@ impl ReplicatedLogletProvider { ); let key_value = entry.insert(Arc::new(loglet)); let loglet = Arc::downgrade(key_value.value()); - let _ = self.task_center.spawn( + let _ = TaskCenter::spawn( TaskKind::BifrostBackgroundLowPriority, "periodic-tail-checker", - None, // todo: configuration PeriodicTailChecker::run(loglet_id, loglet, Duration::from_secs(2)), ); diff --git a/crates/bifrost/src/providers/replicated_loglet/read_path/read_stream_task.rs b/crates/bifrost/src/providers/replicated_loglet/read_path/read_stream_task.rs index 3d4540575..d041ef200 100644 --- a/crates/bifrost/src/providers/replicated_loglet/read_path/read_stream_task.rs +++ b/crates/bifrost/src/providers/replicated_loglet/read_path/read_stream_task.rs @@ -15,7 +15,7 @@ use tokio::sync::mpsc; use tracing::{info, trace}; use restate_core::network::{NetworkError, Networking, TransportConnect}; -use restate_core::{task_center, ShutdownError, TaskHandle, TaskKind}; +use restate_core::{ShutdownError, TaskCenter, TaskHandle, TaskKind}; use restate_types::config::Configuration; use restate_types::logs::{KeyFilter, LogletOffset, MatchKeyQuery, RecordCache, SequenceNumber}; use restate_types::net::log_server::{GetRecords, LogServerRequestHeader, MaybeRecord}; @@ -122,10 +122,9 @@ impl ReadStreamTask { stats: Stats::default(), move_beyond_global_tail, }; - let handle = task_center().spawn_unmanaged( + let handle = TaskCenter::spawn_unmanaged( TaskKind::ReplicatedLogletReadStream, "replicatedloglet-read-stream", - None, task.run(networking), )?; diff --git a/crates/bifrost/src/providers/replicated_loglet/remote_sequencer.rs b/crates/bifrost/src/providers/replicated_loglet/remote_sequencer.rs index 5a221b38b..baae009d5 100644 --- a/crates/bifrost/src/providers/replicated_loglet/remote_sequencer.rs +++ b/crates/bifrost/src/providers/replicated_loglet/remote_sequencer.rs @@ -23,7 +23,7 @@ use restate_core::{ rpc_router::{RpcRouter, RpcToken}, NetworkError, NetworkSendError, Networking, Outgoing, TransportConnect, WeakConnection, }, - task_center, ShutdownError, TaskKind, + ShutdownError, TaskCenter, TaskKind, }; use restate_types::{ config::Configuration, @@ -238,10 +238,9 @@ impl RemoteSequencerConnection { ) -> Result { let (tx, rx) = mpsc::unbounded_channel(); - task_center().spawn( + TaskCenter::spawn( TaskKind::NetworkMessageHandler, "remote-sequencer-connection", - None, Self::handle_appended_responses(known_global_tail, connection.clone(), rx), )?; @@ 
-494,7 +493,7 @@ mod test { use restate_core::{ network::{Incoming, MessageHandler, MockConnector}, - TaskCenterBuilder, TestCoreEnv, TestCoreEnvBuilder, + TestCoreEnv, TestCoreEnvBuilder, }; use restate_types::{ logs::{LogId, LogletOffset, Record, SequenceNumber, TailState}, @@ -564,17 +563,11 @@ mod test { F: FnOnce(TestEnv) -> O, { let (connector, _receiver) = MockConnector::new(100); - let tc = TaskCenterBuilder::default() - .default_runtime_handle(tokio::runtime::Handle::current()) - .ingress_runtime_handle(tokio::runtime::Handle::current()) - .build() - .expect("task_center builds"); let connector = Arc::new(connector); - let mut builder = - TestCoreEnvBuilder::with_transport_connector(tc.clone(), Arc::clone(&connector)) - .add_mock_nodes_config() - .add_message_handler(sequencer); + let mut builder = TestCoreEnvBuilder::with_transport_connector(Arc::clone(&connector)) + .add_mock_nodes_config() + .add_message_handler(sequencer); let sequencer_rpc = SequencersRpc::new(&mut builder.router_builder); @@ -595,20 +588,14 @@ mod test { ); let core_env = builder.build().await; - core_env - .tc - .clone() - .run_in_scope("test", None, async { - let env = TestEnv { - core_env, - remote_sequencer, - }; - test(env).await; - }) - .await; + let env = TestEnv { + core_env, + remote_sequencer, + }; + test(env).await; } - #[tokio::test] + #[restate_core::test] async fn test_remote_stream_ok() { let handler = SequencerMockHandler::default(); @@ -636,7 +623,7 @@ mod test { .await; } - #[tokio::test] + #[restate_core::test] async fn test_remote_stream_sealed() { let handler = SequencerMockHandler::with_reply_status(SequencerStatus::Sealed); diff --git a/crates/bifrost/src/providers/replicated_loglet/sequencer/mod.rs b/crates/bifrost/src/providers/replicated_loglet/sequencer/mod.rs index 04b0f9ed6..8b76cf659 100644 --- a/crates/bifrost/src/providers/replicated_loglet/sequencer/mod.rs +++ b/crates/bifrost/src/providers/replicated_loglet/sequencer/mod.rs @@ -21,7 +21,7 @@ use tracing::{debug, instrument, trace}; use restate_core::{ network::{rpc_router::RpcRouter, Networking, TransportConnect}, - task_center, ShutdownError, TaskKind, + ShutdownError, TaskCenter, TaskKind, }; use restate_types::{ config::Configuration, @@ -269,7 +269,7 @@ impl Sequencer { let fut = self.in_flight.track_future(appender.run()); - task_center().spawn(TaskKind::SequencerAppender, "sequencer-appender", None, fut)?; + TaskCenter::spawn(TaskKind::SequencerAppender, "sequencer-appender", fut)?; Ok(loglet_commit) } diff --git a/crates/bifrost/src/providers/replicated_loglet/tasks/digests.rs b/crates/bifrost/src/providers/replicated_loglet/tasks/digests.rs index b888d78c9..134da9015 100644 --- a/crates/bifrost/src/providers/replicated_loglet/tasks/digests.rs +++ b/crates/bifrost/src/providers/replicated_loglet/tasks/digests.rs @@ -16,7 +16,7 @@ use tracing::{debug, trace, warn}; use restate_core::network::rpc_router::{RpcError, RpcRouter}; use restate_core::network::{Networking, TransportConnect}; -use restate_core::{cancellation_watcher, task_center, ShutdownError}; +use restate_core::{cancellation_watcher, ShutdownError, TaskCenterFutureExt}; use restate_types::logs::{LogletOffset, SequenceNumber}; use restate_types::net::log_server::{ Digest, LogServerRequestHeader, RecordStatus, Status, Store, StoreFlags, @@ -233,13 +233,8 @@ impl Digests { let networking = networking.clone(); let msg = msg.clone(); let store_rpc = store_rpc.clone(); - let tc = task_center(); - async move { - tc.run_in_scope("repair-store", None, async move { - 
(node, store_rpc.call(&networking, node, msg).await) - }) - .await - } + + async move { (node, store_rpc.call(&networking, node, msg).await) }.in_current_tc() }); } let mut cancel = std::pin::pin!(cancellation_watcher()); diff --git a/crates/bifrost/src/providers/replicated_loglet/tasks/find_tail.rs b/crates/bifrost/src/providers/replicated_loglet/tasks/find_tail.rs index d8f504763..38771a13b 100644 --- a/crates/bifrost/src/providers/replicated_loglet/tasks/find_tail.rs +++ b/crates/bifrost/src/providers/replicated_loglet/tasks/find_tail.rs @@ -15,7 +15,7 @@ use tracing::{debug, error, info, instrument, trace, warn}; use restate_core::network::rpc_router::{RpcError, RpcRouter}; use restate_core::network::{Networking, Outgoing, TransportConnect}; -use restate_core::TaskCenter; +use restate_core::TaskCenterFutureExt; use restate_types::config::Configuration; use restate_types::logs::metadata::SegmentIndex; use restate_types::logs::{LogId, LogletOffset, RecordCache, SequenceNumber}; @@ -49,7 +49,6 @@ pub struct FindTailTask { log_id: LogId, segment_index: SegmentIndex, my_params: ReplicatedLogletParams, - task_center: TaskCenter, networking: Networking, logservers_rpc: LogServersRpc, sequencers_rpc: SequencersRpc, @@ -72,7 +71,6 @@ pub enum FindTailResult { impl FindTailTask { #[allow(clippy::too_many_arguments)] pub fn new( - task_center: TaskCenter, log_id: LogId, segment_index: SegmentIndex, my_params: ReplicatedLogletParams, @@ -83,7 +81,6 @@ impl FindTailTask { record_cache: RecordCache, ) -> Self { Self { - task_center, log_id, segment_index, networking, @@ -189,7 +186,6 @@ impl FindTailTask { let mut inflight_info_requests = JoinSet::new(); for node in effective_nodeset.iter() { inflight_info_requests.spawn({ - let tc = self.task_center.clone(); let networking = self.networking.clone(); let get_loglet_info_rpc = self.logservers_rpc.get_loglet_info.clone(); let known_global_tail = self.known_global_tail.clone(); @@ -201,9 +197,9 @@ impl FindTailTask { get_loglet_info_rpc: &get_loglet_info_rpc, known_global_tail: &known_global_tail, }; - tc.run_in_scope("find-tail-on-node", None, task.run(&networking)) - .await + task.run(&networking).await } + .in_current_tc() }); } @@ -307,7 +303,6 @@ impl FindTailTask { // We can start repair. match RepairTail::new( self.my_params.clone(), - self.task_center.clone(), self.networking.clone(), self.logservers_rpc.clone(), self.record_cache.clone(), @@ -418,16 +413,8 @@ impl FindTailTask { known_global_tail: self.known_global_tail.clone(), }; inflight_tail_update_watches.spawn({ - let tc = self.task_center.clone(); let networking = self.networking.clone(); - async move { - tc.run_in_scope( - "wait-for-tail-on-node", - None, - task.run(max_local_tail, networking), - ) - .await - } + task.run(max_local_tail, networking).in_current_tc() }); } loop { diff --git a/crates/bifrost/src/providers/replicated_loglet/tasks/get_trim_point.rs b/crates/bifrost/src/providers/replicated_loglet/tasks/get_trim_point.rs index 00c778744..4e8d525a7 100644 --- a/crates/bifrost/src/providers/replicated_loglet/tasks/get_trim_point.rs +++ b/crates/bifrost/src/providers/replicated_loglet/tasks/get_trim_point.rs @@ -8,11 +8,11 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. 
+use restate_core::TaskCenterFutureExt; use tokio::task::JoinSet; use tracing::trace; use restate_core::network::{Incoming, Networking, TransportConnect}; -use restate_core::task_center; use restate_types::config::Configuration; use restate_types::logs::{LogletOffset, SequenceNumber}; use restate_types::net::log_server::{GetLogletInfo, LogServerRequestHeader, LogletInfo, Status}; @@ -97,7 +97,6 @@ impl<'a> GetTrimPointTask<'a> { let networking = networking.clone(); let trim_rpc_router = self.logservers_rpc.get_loglet_info.clone(); let known_global_tail = self.known_global_tail.clone(); - let tc = task_center(); async move { let task = RunOnSingleNode::new( @@ -111,16 +110,10 @@ impl<'a> GetTrimPointTask<'a> { .log_server_retry_policy .clone(), ); - ( - node_id, - tc.run_in_scope( - "find-trimpoint-on-node", - None, - task.run(on_info_response, &networking), - ) - .await, - ) + + (node_id, task.run(on_info_response, &networking).await) } + .in_current_tc() }); } diff --git a/crates/bifrost/src/providers/replicated_loglet/tasks/repair_tail.rs b/crates/bifrost/src/providers/replicated_loglet/tasks/repair_tail.rs index 06e6e76f7..fc596ffe9 100644 --- a/crates/bifrost/src/providers/replicated_loglet/tasks/repair_tail.rs +++ b/crates/bifrost/src/providers/replicated_loglet/tasks/repair_tail.rs @@ -11,7 +11,7 @@ use std::time::Duration; use restate_core::network::{Networking, TransportConnect}; -use restate_core::{ShutdownError, TaskCenter}; +use restate_core::{ShutdownError, TaskCenterFutureExt}; use restate_types::logs::{KeyFilter, LogletOffset, RecordCache, SequenceNumber}; use restate_types::net::log_server::{GetDigest, LogServerRequestHeader}; use restate_types::replicated_loglet::{EffectiveNodeSet, ReplicatedLogletParams}; @@ -78,7 +78,6 @@ use super::digests::Digests; /// known_global_tail. This is a best-effort phase and it should not block the completion of the repair task. pub struct RepairTail { my_params: ReplicatedLogletParams, - task_center: TaskCenter, networking: Networking, logservers_rpc: LogServersRpc, record_cache: RecordCache, @@ -98,7 +97,6 @@ impl RepairTail { #[allow(clippy::too_many_arguments)] pub fn new( my_params: ReplicatedLogletParams, - task_center: TaskCenter, networking: Networking, logservers_rpc: LogServersRpc, record_cache: RecordCache, @@ -109,7 +107,6 @@ impl RepairTail { let digests = Digests::new(&my_params, start_offset, target_tail); RepairTail { my_params, - task_center, networking, logservers_rpc, record_cache, @@ -139,27 +136,24 @@ impl RepairTail { to_offset: self.digests.target_tail().prev(), }; get_digest_requests.spawn({ - let tc = self.task_center.clone(); let networking = self.networking.clone(); let logservers_rpc = self.logservers_rpc.clone(); let peer = *node; async move { - tc.run_in_scope("get-digest-from-node", None, async move { - loop { - // todo: handle retries with exponential backoff... - let Ok(incoming) = logservers_rpc - .get_digest - .call(&networking, peer, msg.clone()) - .await - else { - tokio::time::sleep(Duration::from_secs(1)).await; - continue; - }; - return incoming; - } - }) - .await + loop { + // todo: handle retries with exponential backoff... 
+ let Ok(incoming) = logservers_rpc + .get_digest + .call(&networking, peer, msg.clone()) + .await + else { + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + }; + return incoming; + } } + .in_current_tc() }); } diff --git a/crates/bifrost/src/providers/replicated_loglet/tasks/trim.rs b/crates/bifrost/src/providers/replicated_loglet/tasks/trim.rs index a9417d6f3..979ac3381 100644 --- a/crates/bifrost/src/providers/replicated_loglet/tasks/trim.rs +++ b/crates/bifrost/src/providers/replicated_loglet/tasks/trim.rs @@ -8,11 +8,11 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +use restate_core::TaskCenterFutureExt; use tokio::task::JoinSet; use tracing::{debug, trace, warn}; use restate_core::network::{Incoming, Networking, TransportConnect}; -use restate_core::task_center; use restate_types::config::Configuration; use restate_types::logs::{LogletOffset, SequenceNumber}; use restate_types::net::log_server::{LogServerRequestHeader, Status, Trim, Trimmed}; @@ -118,7 +118,6 @@ impl<'a> TrimTask<'a> { let networking = networking.clone(); let trim_rpc_router = self.logservers_rpc.trim.clone(); let known_global_tail = self.known_global_tail.clone(); - let tc = task_center(); async move { let task = RunOnSingleNode::new( @@ -132,16 +131,10 @@ impl<'a> TrimTask<'a> { .log_server_retry_policy .clone(), ); - ( - node_id, - tc.run_in_scope( - "trim-on-node", - None, - task.run(on_trim_response, &networking), - ) - .await, - ) + + (node_id, task.run(on_trim_response, &networking).await) } + .in_current_tc() }); } diff --git a/crates/bifrost/src/read_stream.rs b/crates/bifrost/src/read_stream.rs index 2f6a20e8a..554f31b88 100644 --- a/crates/bifrost/src/read_stream.rs +++ b/crates/bifrost/src/read_stream.rs @@ -445,7 +445,7 @@ mod tests { use tracing::info; use tracing_test::traced_test; - use restate_core::{MetadataKind, TargetVersion, TaskCenter, TaskKind, TestCoreEnvBuilder2}; + use restate_core::{MetadataKind, TargetVersion, TaskCenter, TaskKind, TestCoreEnvBuilder}; use restate_rocksdb::RocksDbManager; use restate_types::config::{CommonOptions, Configuration}; use restate_types::live::{Constant, Live}; @@ -461,7 +461,7 @@ mod tests { async fn test_readstream_one_loglet() -> anyhow::Result<()> { const LOG_ID: LogId = LogId::new(0); - let _ = TestCoreEnvBuilder2::with_incoming_only_connector() + let _ = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; @@ -490,28 +490,22 @@ mod tests { let read_counter = Arc::new(AtomicUsize::new(0)); // spawn a reader that reads 5 records and exits. 
let counter_clone = read_counter.clone(); - let id = TaskCenter::current().spawn( - TaskKind::TestRunner, - "read-records", - None, - async move { - for i in 6..=10 { - let record = reader.next().await.expect("to never terminate")?; - let expected_lsn = Lsn::from(i); - counter_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - assert_that!(record.sequence_number(), eq(expected_lsn)); - assert_that!(reader.read_pointer(), ge(record.sequence_number())); - assert_that!( - record.decode_unchecked::(), - eq(format!("record{}", expected_lsn)) - ); - } - Ok(()) - }, - )?; + let id = TaskCenter::spawn(TaskKind::TestRunner, "read-records", async move { + for i in 6..=10 { + let record = reader.next().await.expect("to never terminate")?; + let expected_lsn = Lsn::from(i); + counter_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + assert_that!(record.sequence_number(), eq(expected_lsn)); + assert_that!(reader.read_pointer(), ge(record.sequence_number())); + assert_that!( + record.decode_unchecked::(), + eq(format!("record{}", expected_lsn)) + ); + } + Ok(()) + })?; - let reader_bg_handle = - TaskCenter::with_current(|tc| tc.take_task(id)).expect("read-records task to exist"); + let reader_bg_handle = TaskCenter::take_task(id).expect("read-records task to exist"); tokio::task::yield_now().await; // Not finished, we still didn't append records @@ -547,7 +541,7 @@ mod tests { async fn test_read_stream_with_trim() -> anyhow::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder2::with_incoming_only_connector() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; @@ -640,7 +634,7 @@ mod tests { async fn test_readstream_simple_multi_loglet() -> anyhow::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder2::with_incoming_only_connector() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; @@ -796,7 +790,7 @@ mod tests { async fn test_readstream_sealed_multi_loglet() -> anyhow::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder2::with_incoming_only_connector() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; @@ -914,7 +908,7 @@ mod tests { async fn test_readstream_prefix_trimmed() -> anyhow::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder2::with_incoming_only_connector() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; diff --git a/crates/bifrost/src/service.rs b/crates/bifrost/src/service.rs index 0a1ed2eef..3222cc5c2 100644 --- a/crates/bifrost/src/service.rs +++ b/crates/bifrost/src/service.rs @@ -158,10 +158,9 @@ impl BifrostService { .map_err(|_| anyhow::anyhow!("bifrost must be initialized only once"))?; // We spawn the watchdog as a background long-running task - TaskCenter::current().spawn( + TaskCenter::spawn( TaskKind::BifrostBackgroundHighPriority, "bifrost-watchdog", - None, self.watchdog.run(), )?; diff --git a/crates/bifrost/src/watchdog.rs b/crates/bifrost/src/watchdog.rs index c6fd1783c..c6bece86b 100644 --- a/crates/bifrost/src/watchdog.rs +++ b/crates/bifrost/src/watchdog.rs @@ -54,10 +54,9 @@ impl Watchdog { match cmd { WatchdogCommand::ScheduleMetadataSync => { let bifrost = self.inner.clone(); - let _ = TaskCenter::current().spawn( + let _ = 
TaskCenter::spawn( TaskKind::MetadataBackgroundSync, "bifrost-metadata-sync", - None, async move { bifrost .sync_metadata() @@ -70,10 +69,9 @@ impl Watchdog { WatchdogCommand::WatchProvider(provider) => { self.live_providers.push(provider.clone()); // TODO: Convert to a managed background task - let _ = TaskCenter::current().spawn( + let _ = TaskCenter::spawn( TaskKind::BifrostBackgroundHighPriority, "bifrost-provider-on-start", - None, async move { provider.post_start().await; Ok(()) diff --git a/crates/core/derive/src/tc_test.rs b/crates/core/derive/src/tc_test.rs index 82d6f2a14..e9b75965f 100644 --- a/crates/core/derive/src/tc_test.rs +++ b/crates/core/derive/src/tc_test.rs @@ -463,8 +463,8 @@ fn parse_knobs(mut input: ItemFn, config: FinalConfig) -> TokenStream { .build() .expect("Failed building task-center"); - let ret = rt.block_on(#body_ident.in_tc(&task_center)); - rt.block_on(task_center.shutdown_node("completed", 0)); + let ret = rt.block_on(#body_ident.in_tc(&task_center.handle())); + rt.block_on(task_center.to_handle().shutdown_node("completed", 0)); ret } }; diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 857cbb194..9b3b54f3b 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -14,7 +14,7 @@ pub mod metadata_store; mod metric_definitions; pub mod network; pub mod partitions; -mod task_center; +pub mod task_center; pub mod worker_api; pub use error::*; @@ -44,16 +44,14 @@ pub use metadata::{ spawn_metadata_manager, Metadata, MetadataBuilder, MetadataKind, MetadataManager, MetadataWriter, SyncError, TargetVersion, }; -pub use task_center::*; - -#[cfg(any(test, feature = "test-util"))] -mod test_env; - -#[cfg(any(test, feature = "test-util"))] -mod test_env2; +pub use task_center::{ + cancellation_token, cancellation_watcher, is_cancellation_requested, my_node_id, AsyncRuntime, + MetadataFutureExt, RuntimeError, RuntimeRootTaskHandle, TaskCenter, TaskCenterBuildError, + TaskCenterBuilder, TaskCenterFutureExt, TaskContext, TaskHandle, TaskId, TaskKind, +}; #[cfg(any(test, feature = "test-util"))] -pub use test_env::{create_mock_nodes_config, NoOpMessageHandler, TestCoreEnv, TestCoreEnvBuilder}; +pub mod test_env; #[cfg(any(test, feature = "test-util"))] -pub use test_env2::{TestCoreEnv2, TestCoreEnvBuilder2}; +pub use test_env::{TestCoreEnv, TestCoreEnvBuilder}; diff --git a/crates/core/src/metadata/mod.rs b/crates/core/src/metadata.rs similarity index 99% rename from crates/core/src/metadata/mod.rs rename to crates/core/src/metadata.rs index e2dd6809a..c0c9d3227 100644 --- a/crates/core/src/metadata/mod.rs +++ b/crates/core/src/metadata.rs @@ -387,10 +387,9 @@ impl Default for VersionWatch { } pub fn spawn_metadata_manager(metadata_manager: MetadataManager) -> Result { - TaskCenter::current().spawn( + TaskCenter::spawn( TaskKind::MetadataBackgroundSync, "metadata-manager", - None, metadata_manager.run(), ) } diff --git a/crates/core/src/metadata/manager.rs b/crates/core/src/metadata/manager.rs index 5629a2e61..1da4fc4ab 100644 --- a/crates/core/src/metadata/manager.rs +++ b/crates/core/src/metadata/manager.rs @@ -611,7 +611,7 @@ mod tests { F: Fn(&Metadata) -> Version, S: Fn(&mut T, Version), { - let tc = TaskCenterBuilder::default().build()?; + let tc = TaskCenterBuilder::default().build()?.to_handle(); tc.block_on(async move { let metadata_builder = MetadataBuilder::default(); let metadata_store_client = MetadataStoreClient::new_in_memory(); @@ -682,7 +682,7 @@ mod tests { F: Fn(&Metadata) -> Version, I: Fn(&mut T), { - let tc = 
TaskCenterBuilder::default().build()?; + let tc = TaskCenterBuilder::default().build()?.to_handle(); tc.block_on(async move { let metadata_builder = MetadataBuilder::default(); let metadata_store_client = MetadataStoreClient::new_in_memory(); diff --git a/crates/core/src/metadata_store/mod.rs b/crates/core/src/metadata_store.rs similarity index 100% rename from crates/core/src/metadata_store/mod.rs rename to crates/core/src/metadata_store.rs diff --git a/crates/core/src/network/connection.rs b/crates/core/src/network/connection.rs index e383321d3..110848d2f 100644 --- a/crates/core/src/network/connection.rs +++ b/crates/core/src/network/connection.rs @@ -490,13 +490,12 @@ pub mod test_util { // Allow for messages received on this connection to be processed by a given message handler. pub fn process_with_message_handler( self, - task_center: &TaskCenter, handler: H, ) -> anyhow::Result<(WeakConnection, TaskHandle>)> { let mut router = MessageRouterBuilder::default(); router.add_message_handler(handler); let router = router.build(); - self.process_with_message_router(task_center, router) + self.process_with_message_router(router) } // Allow for messages received on this connection to be processed by a given message router. @@ -504,7 +503,6 @@ pub mod test_util { // drop the receive stream (simulates connection loss). pub fn process_with_message_router( self, - task_center: &TaskCenter, router: R, ) -> anyhow::Result<(WeakConnection, TaskHandle>)> { let Self { @@ -530,10 +528,9 @@ pub mod test_util { connection, recv_stream, }; - let handle = task_center.spawn_unmanaged( + let handle = TaskCenter::spawn_unmanaged( TaskKind::ConnectionReactor, "test-message-processor", - None, async move { message_processor.run().await }, )?; Ok((weak, handle)) @@ -542,7 +539,6 @@ pub mod test_util { // Allow for messages received on this connection to be forwarded to the supplied sender. 
pub fn forward_to_sender( self, - task_center: &TaskCenter, sender: mpsc::Sender<(GenerationalNodeId, Incoming)>, ) -> anyhow::Result<(WeakConnection, TaskHandle>)> { let handler = ForwardingHandler { @@ -550,7 +546,7 @@ pub mod test_util { inner_sender: sender, }; - self.process_with_message_router(task_center, handler) + self.process_with_message_router(handler) } } diff --git a/crates/core/src/network/connection_manager.rs b/crates/core/src/network/connection_manager.rs index dd9ce033d..ce7d5ff59 100644 --- a/crates/core/src/network/connection_manager.rs +++ b/crates/core/src/network/connection_manager.rs @@ -41,8 +41,7 @@ use super::{Handler, MessageRouter}; use crate::metadata::Urgency; use crate::network::handshake::{negotiate_protocol_version, wait_for_hello}; use crate::network::{Incoming, PeerMetadataVersion}; -use crate::{cancellation_watcher, current_task_id, TaskId, TaskKind}; -use crate::{Metadata, TaskCenter}; +use crate::{Metadata, TaskCenter, TaskContext, TaskId, TaskKind}; struct ConnectionManagerInner { router: MessageRouter, @@ -495,11 +494,9 @@ async fn run_reactor( where S: Stream> + Unpin + Send, { - Span::current().record( - "task_id", - tracing::field::display(current_task_id().unwrap()), - ); - let mut cancellation = std::pin::pin!(cancellation_watcher()); + let current_task = TaskContext::current(); + Span::current().record("task_id", tracing::field::display(current_task.id())); + let mut cancellation = std::pin::pin!(current_task.cancellation_token().cancelled()); let mut seen_versions = MetadataVersions::default(); // Receive loop @@ -705,9 +702,8 @@ fn on_connection_draining( } fn on_connection_terminated(inner_manager: &Mutex) { - let task_id = current_task_id().expect("TaskId is set"); let mut guard = inner_manager.lock(); - guard.drop_connection(task_id); + guard.drop_connection(TaskContext::with_current(|ctx| ctx.id())); } #[derive(Debug, Clone, PartialEq, derive_more::Index, derive_more::IndexMut)] @@ -782,232 +778,210 @@ mod tests { use restate_types::Version; use crate::network::MockPeerConnection; - use crate::{TestCoreEnv, TestCoreEnvBuilder}; + use crate::{self as restate_core, TestCoreEnv, TestCoreEnvBuilder}; // Test handshake with a client - #[tokio::test] + #[restate_core::test] async fn test_hello_welcome_handshake() -> Result<()> { - let test_setup = TestCoreEnv::create_with_single_node(1, 1).await; - test_setup - .tc - .run_in_scope("test", None, async { - let metadata = crate::metadata(); - let connections = ConnectionManager::new_incoming_only(metadata.clone()); - - let _mock_connection = MockPeerConnection::connect( - GenerationalNodeId::new(1, 1), - metadata.nodes_config_version(), - metadata.nodes_config_ref().cluster_name().to_owned(), - &connections, - 10, - ) - .await - .unwrap(); - - Ok(()) - }) - .await + let _env = TestCoreEnv::create_with_single_node(1, 1).await; + let metadata = Metadata::current(); + let connections = ConnectionManager::new_incoming_only(metadata.clone()); + + let _mock_connection = MockPeerConnection::connect( + GenerationalNodeId::new(1, 1), + metadata.nodes_config_version(), + metadata.nodes_config_ref().cluster_name().to_owned(), + &connections, + 10, + ) + .await + .unwrap(); + + Ok(()) } - #[tokio::test(start_paused = true)] + #[restate_core::test(start_paused = true)] async fn test_hello_welcome_timeout() -> Result<()> { - let test_setup = TestCoreEnv::create_with_single_node(1, 1).await; - let metadata = test_setup.metadata; + let _env = TestCoreEnv::create_with_single_node(1, 1).await; + let metadata = 
Metadata::current(); let net_opts = NetworkingOptions::default(); - test_setup - .tc - .run_in_scope("test", None, async { - let (_tx, rx) = mpsc::channel(1); - let connections = ConnectionManager::new_incoming_only(metadata); - - let start = tokio::time::Instant::now(); - let incoming = ReceiverStream::new(rx); - let resp = connections.accept_incoming_connection(incoming).await; - assert!(resp.is_err()); - assert!(matches!( - resp, - Err(NetworkError::ProtocolError( - ProtocolError::HandshakeTimeout(_) - )) - )); - assert!(start.elapsed() >= net_opts.handshake_timeout.into()); - Ok(()) - }) - .await + let (_tx, rx) = mpsc::channel(1); + let connections = ConnectionManager::new_incoming_only(metadata); + + let start = tokio::time::Instant::now(); + let incoming = ReceiverStream::new(rx); + let resp = connections.accept_incoming_connection(incoming).await; + assert!(resp.is_err()); + assert!(matches!( + resp, + Err(NetworkError::ProtocolError( + ProtocolError::HandshakeTimeout(_) + )) + )); + assert!(start.elapsed() >= net_opts.handshake_timeout.into()); + Ok(()) } - #[tokio::test] + #[restate_core::test] async fn test_bad_handshake() -> Result<()> { let test_setup = TestCoreEnv::create_with_single_node(1, 1).await; let metadata = test_setup.metadata; - test_setup - .tc - .run_in_scope("test", None, async { - let (tx, rx) = mpsc::channel(1); - let my_node_id = metadata.my_node_id(); - - // unsupported protocol version - let hello = Hello { - min_protocol_version: ProtocolVersion::Unknown.into(), - max_protocol_version: ProtocolVersion::Unknown.into(), - my_node_id: Some(my_node_id.into()), - cluster_name: metadata.nodes_config_ref().cluster_name().to_owned(), - }; - let hello = Message::new( - Header::new( - metadata.nodes_config_version(), - None, - None, - None, - crate::network::generate_msg_id(), - None, - ), - hello, - ); - tx.send(Ok(hello)) - .await - .expect("Channel accept hello message"); - - let connections = ConnectionManager::new_incoming_only(metadata.clone()); - let incoming = ReceiverStream::new(rx); - let resp = connections.accept_incoming_connection(incoming).await; - assert!(resp.is_err()); - assert!(matches!( - resp, - Err(NetworkError::ProtocolError( - ProtocolError::UnsupportedVersion(proto_version) - )) if proto_version == ProtocolVersion::Unknown as i32 - )); - - // cluster name mismatch - let (tx, rx) = mpsc::channel(1); - let my_node_id = metadata.my_node_id(); - let hello = Hello { - min_protocol_version: MIN_SUPPORTED_PROTOCOL_VERSION.into(), - max_protocol_version: CURRENT_PROTOCOL_VERSION.into(), - my_node_id: Some(my_node_id.into()), - cluster_name: "Random-cluster".to_owned(), - }; - let hello = Message::new( - Header::new( - metadata.nodes_config_version(), - None, - None, - None, - crate::network::generate_msg_id(), - None, - ), - hello, - ); - tx.send(Ok(hello)).await?; + let (tx, rx) = mpsc::channel(1); + let my_node_id = metadata.my_node_id(); + + // unsupported protocol version + let hello = Hello { + min_protocol_version: ProtocolVersion::Unknown.into(), + max_protocol_version: ProtocolVersion::Unknown.into(), + my_node_id: Some(my_node_id.into()), + cluster_name: metadata.nodes_config_ref().cluster_name().to_owned(), + }; + let hello = Message::new( + Header::new( + metadata.nodes_config_version(), + None, + None, + None, + crate::network::generate_msg_id(), + None, + ), + hello, + ); + tx.send(Ok(hello)) + .await + .expect("Channel accept hello message"); + + let connections = ConnectionManager::new_incoming_only(metadata.clone()); + let incoming = 
ReceiverStream::new(rx); + let resp = connections.accept_incoming_connection(incoming).await; + assert!(resp.is_err()); + assert!(matches!( + resp, + Err(NetworkError::ProtocolError( + ProtocolError::UnsupportedVersion(proto_version) + )) if proto_version == ProtocolVersion::Unknown as i32 + )); + + // cluster name mismatch + let (tx, rx) = mpsc::channel(1); + let my_node_id = metadata.my_node_id(); + let hello = Hello { + min_protocol_version: MIN_SUPPORTED_PROTOCOL_VERSION.into(), + max_protocol_version: CURRENT_PROTOCOL_VERSION.into(), + my_node_id: Some(my_node_id.into()), + cluster_name: "Random-cluster".to_owned(), + }; + let hello = Message::new( + Header::new( + metadata.nodes_config_version(), + None, + None, + None, + crate::network::generate_msg_id(), + None, + ), + hello, + ); + tx.send(Ok(hello)).await?; - let connections = ConnectionManager::new_incoming_only(metadata.clone()); - let incoming = ReceiverStream::new(rx); - let err = connections - .accept_incoming_connection(incoming) - .await - .err() - .unwrap(); - assert!(matches!( - err, - NetworkError::ProtocolError(ProtocolError::HandshakeFailed( - "cluster name mismatch" - )) - )); - Ok(()) - }) + let connections = ConnectionManager::new_incoming_only(metadata.clone()); + let incoming = ReceiverStream::new(rx); + let err = connections + .accept_incoming_connection(incoming) .await + .err() + .unwrap(); + assert!(matches!( + err, + NetworkError::ProtocolError(ProtocolError::HandshakeFailed("cluster name mismatch")) + )); + Ok(()) } - #[tokio::test] + #[restate_core::test] async fn test_node_generation() -> Result<()> { - let test_setup = TestCoreEnv::create_with_single_node(1, 2).await; - let metadata = test_setup.metadata; - test_setup - .tc - .run_in_scope("test", None, async { - let (tx, rx) = mpsc::channel(1); - let mut my_node_id = metadata.my_node_id(); - assert_eq!(2, my_node_id.generation()); - my_node_id.bump_generation(); - - // newer generation - let hello = Hello::new( - my_node_id, - metadata.nodes_config_ref().cluster_name().to_owned(), - ); - let hello = Message::new( - Header::new( - metadata.nodes_config_version(), - None, - None, - None, - crate::network::generate_msg_id(), - None, - ), - hello, - ); - tx.send(Ok(hello)) - .await - .expect("Channel accept hello message"); + let _env = TestCoreEnv::create_with_single_node(1, 2).await; + let metadata = Metadata::current(); + let (tx, rx) = mpsc::channel(1); + let mut my_node_id = metadata.my_node_id(); + assert_eq!(2, my_node_id.generation()); + my_node_id.bump_generation(); + + // newer generation + let hello = Hello::new( + my_node_id, + metadata.nodes_config_ref().cluster_name().to_owned(), + ); + let hello = Message::new( + Header::new( + metadata.nodes_config_version(), + None, + None, + None, + crate::network::generate_msg_id(), + None, + ), + hello, + ); + tx.send(Ok(hello)) + .await + .expect("Channel accept hello message"); - let connections = ConnectionManager::new_incoming_only(metadata.clone()); + let connections = ConnectionManager::new_incoming_only(metadata.clone()); - let incoming = ReceiverStream::new(rx); - let err = connections - .accept_incoming_connection(incoming) - .await - .err() - .unwrap(); - - assert!(matches!( - err, - NetworkError::ProtocolError(ProtocolError::HandshakeFailed( - "cannot accept a connection to the same NodeID from a different generation", - )) - )); - - // Unrecognized node Id - let (tx, rx) = mpsc::channel(1); - let my_node_id = GenerationalNodeId::new(55, 2); - - let hello = Hello::new( - my_node_id, - 
metadata.nodes_config_ref().cluster_name().to_owned(), - ); - let hello = Message::new( - Header::new( - metadata.nodes_config_version(), - None, - None, - None, - crate::network::generate_msg_id(), - None, - ), - hello, - ); - tx.send(Ok(hello)) - .await - .expect("Channel accept hello message"); + let incoming = ReceiverStream::new(rx); + let err = connections + .accept_incoming_connection(incoming) + .await + .err() + .unwrap(); - let connections = ConnectionManager::new_incoming_only(metadata); + assert!(matches!( + err, + NetworkError::ProtocolError(ProtocolError::HandshakeFailed( + "cannot accept a connection to the same NodeID from a different generation", + )) + )); - let incoming = ReceiverStream::new(rx); - let err = connections - .accept_incoming_connection(incoming) - .await - .err() - .unwrap(); - assert!(matches!( - err, - NetworkError::UnknownNode(NodesConfigError::UnknownNodeId(_)) - )); - Ok(()) - }) + // Unrecognized node Id + let (tx, rx) = mpsc::channel(1); + let my_node_id = GenerationalNodeId::new(55, 2); + + let hello = Hello::new( + my_node_id, + metadata.nodes_config_ref().cluster_name().to_owned(), + ); + let hello = Message::new( + Header::new( + metadata.nodes_config_version(), + None, + None, + None, + crate::network::generate_msg_id(), + None, + ), + hello, + ); + tx.send(Ok(hello)) .await + .expect("Channel accept hello message"); + + let connections = ConnectionManager::new_incoming_only(metadata); + + let incoming = ReceiverStream::new(rx); + let err = connections + .accept_incoming_connection(incoming) + .await + .err() + .unwrap(); + assert!(matches!( + err, + NetworkError::UnknownNode(NodesConfigError::UnknownNodeId(_)) + )); + Ok(()) } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn fetching_metadata_updates_through_message_headers() -> Result<()> { let mut nodes_config = NodesConfiguration::new(Version::MIN, "test-cluster".to_owned()); @@ -1026,49 +1000,44 @@ mod tests { .build() .await; - test_env - .tc - .run_in_scope("test", None, async { - let metadata = crate::metadata(); - - let mut connection = MockPeerConnection::connect( - node_id, - metadata.nodes_config_version(), - metadata.nodes_config_ref().cluster_name().to_string(), - test_env.networking.connection_manager(), - 10, - ) - .await - .into_test_result()?; - - let request = GetNodeState {}; - let partition_table_version = metadata.partition_table_version().next(); - let header = Header::new( - metadata.nodes_config_version(), - None, - None, - Some(partition_table_version), - crate::network::generate_msg_id(), - None, - ); + let metadata = Metadata::current(); - connection - .send_raw(request, header) - .await - .into_test_result()?; - - // we expect the request to go throught he existing open connection to my node - let message = connection.recv_stream.next().await.expect("some message"); - assert_get_metadata_request( - message, - connection.protocol_version, - MetadataKind::PartitionTable, - partition_table_version, - ); + let mut connection = MockPeerConnection::connect( + node_id, + metadata.nodes_config_version(), + metadata.nodes_config_ref().cluster_name().to_string(), + test_env.networking.connection_manager(), + 10, + ) + .await + .into_test_result()?; + + let request = GetNodeState {}; + let partition_table_version = metadata.partition_table_version().next(); + let header = Header::new( + metadata.nodes_config_version(), + None, + None, + Some(partition_table_version), + crate::network::generate_msg_id(), + None, + ); - 
Ok(()) - }) + connection + .send_raw(request, header) .await + .into_test_result()?; + + // we expect the request to go throught he existing open connection to my node + let message = connection.recv_stream.next().await.expect("some message"); + assert_get_metadata_request( + message, + connection.protocol_version, + MetadataKind::PartitionTable, + partition_table_version, + ); + + Ok(()) } fn assert_get_metadata_request( diff --git a/crates/core/src/network/net_util.rs b/crates/core/src/network/net_util.rs index dcfd06cd2..ea25b319e 100644 --- a/crates/core/src/network/net_util.rs +++ b/crates/core/src/network/net_util.rs @@ -28,7 +28,7 @@ use restate_types::config::{MetadataStoreClientOptions, NetworkingOptions}; use restate_types::errors::GenericError; use restate_types::net::{AdvertisedAddress, BindAddress}; -use crate::{cancellation_watcher, ShutdownError, TaskCenter, TaskKind}; +use crate::{cancellation_watcher, task_center, ShutdownError, TaskCenter, TaskKind}; pub fn create_tonic_channel_from_advertised_address( address: AdvertisedAddress, @@ -231,12 +231,12 @@ where #[derive(Clone)] struct TaskCenterExecutor { - task_center: TaskCenter, + task_center: task_center::Handle, name: &'static str, } impl TaskCenterExecutor { - fn new(task_center: TaskCenter, name: &'static str) -> Self { + fn new(task_center: task_center::Handle, name: &'static str) -> Self { Self { task_center, name } } } @@ -248,7 +248,7 @@ where { fn execute(&self, fut: F) { // ignore shutdown error - self.task_center.run_in_scope_sync(|| { + self.task_center.run_sync(|| { let _ = TaskCenter::spawn_child(TaskKind::RpcConnection, self.name, async move { // ignore the future output let _ = fut.await; diff --git a/crates/core/src/network/transport_connector.rs b/crates/core/src/network/transport_connector.rs index d5e6d7984..5984d9d1b 100644 --- a/crates/core/src/network/transport_connector.rs +++ b/crates/core/src/network/transport_connector.rs @@ -105,7 +105,7 @@ pub mod test_util { use super::{NetworkError, ProtocolError}; use crate::network::{Incoming, MockPeerConnection, PartialPeerConnection, WeakConnection}; - use crate::{TaskCenter, TaskHandle, TaskKind}; + use crate::{my_node_id, TaskCenter, TaskHandle, TaskKind}; #[derive(Clone)] pub struct MockConnector { @@ -148,7 +148,7 @@ pub mod test_util { let peer_connection = PartialPeerConnection { my_node_id: node_id, - peer: crate::metadata().my_node_id(), + peer: my_node_id(), sender, recv_stream: output_stream.boxed(), created: Instant::now(), @@ -177,7 +177,6 @@ pub mod test_util { impl MessageCollectorMockConnector { pub fn new( - task_center: TaskCenter, sendbuf: usize, sender: mpsc::Sender<(GenerationalNodeId, Incoming)>, ) -> Arc { @@ -188,20 +187,17 @@ pub mod test_util { }); // start acceptor - let _ = task_center - .clone() - .spawn(TaskKind::TestRunner, "test-connection-acceptor", None, { - let connector = connector.clone(); - async move { - while let Some(connection) = new_connections.recv().await { - let (connection, task) = - connection.forward_to_sender(&task_center, sender.clone())?; - connector.tasks.lock().push((connection, task)); - } - Ok(()) + TaskCenter::spawn(TaskKind::RpcConnection, "test-connection-acceptor", { + let connector = connector.clone(); + async move { + while let Some(connection) = new_connections.recv().await { + let (connection, task) = connection.forward_to_sender(sender.clone())?; + connector.tasks.lock().push((connection, task)); } - }) - .unwrap(); + Ok(()) + } + }) + .unwrap(); connector } } diff --git 
a/crates/core/src/network/types.rs b/crates/core/src/network/types.rs index c66ab781c..a0a9f7870 100644 --- a/crates/core/src/network/types.rs +++ b/crates/core/src/network/types.rs @@ -22,7 +22,7 @@ use restate_types::net::RpcRequest; use restate_types::protobuf::node::Header; use restate_types::{GenerationalNodeId, NodeId, Version}; -use crate::with_metadata; +use crate::Metadata; use super::connection::OwnedConnection; use super::metric_definitions::CONNECTION_SEND_DURATION; @@ -425,7 +425,7 @@ impl Outgoing { NetworkError::ConnectionClosed(connection.peer()) ); - with_metadata(|metadata| { + Metadata::with_current(|metadata| { permit.send(self, metadata); }); CONNECTION_SEND_DURATION.record(send_start.elapsed()); @@ -443,7 +443,7 @@ impl Outgoing { Err(e) => return Err(NetworkSendError::new(self, e)), }; - with_metadata(|metadata| { + Metadata::with_current(|metadata| { permit.send(self, metadata); }); CONNECTION_SEND_DURATION.record(send_start.elapsed()); @@ -459,7 +459,7 @@ impl Outgoing { let connection = bail_on_error!(self, self.try_upgrade()); let permit = bail_on_error!(self, connection.try_reserve()); - with_metadata(|metadata| { + Metadata::with_current(|metadata| { permit.send(self, metadata); }); diff --git a/crates/core/src/partitions/mod.rs b/crates/core/src/partitions.rs similarity index 95% rename from crates/core/src/partitions/mod.rs rename to crates/core/src/partitions.rs index 0e0b8b280..bd0885c57 100644 --- a/crates/core/src/partitions/mod.rs +++ b/crates/core/src/partitions.rs @@ -26,9 +26,7 @@ use restate_types::metadata_store::keys::SCHEDULING_PLAN_KEY; use restate_types::{NodeId, Version, Versioned}; use crate::metadata_store::MetadataStoreClient; -use crate::{ - cancellation_watcher, task_center, ShutdownError, TaskCenter, TaskHandle, TaskId, TaskKind, -}; +use crate::{cancellation_watcher, ShutdownError, TaskCenter, TaskHandle, TaskId, TaskKind}; pub type CommandSender = mpsc::Sender; pub type CommandReceiver = mpsc::Receiver; @@ -190,16 +188,10 @@ impl PartitionRoutingRefresher { let partition_to_node_mappings = self.inner.clone(); let metadata_store_client = self.metadata_store_client.clone(); - let task = task_center().spawn_unmanaged( + let task = TaskCenter::spawn_unmanaged( TaskKind::Disposable, "refresh-routing-information", - None, - { - async move { - sync_routing_information(partition_to_node_mappings, metadata_store_client) - .await; - } - }, + sync_routing_information(partition_to_node_mappings, metadata_store_client), ); self.inflight_refresh_task = task.ok(); } else { @@ -209,13 +201,11 @@ impl PartitionRoutingRefresher { } pub fn spawn_partition_routing_refresher( - tc: &TaskCenter, partition_routing_refresher: PartitionRoutingRefresher, ) -> Result { - tc.spawn( + TaskCenter::spawn( TaskKind::MetadataBackgroundSync, "partition-routing-refresher", - None, partition_routing_refresher.run(), ) } diff --git a/crates/core/src/task_center/mod.rs b/crates/core/src/task_center.rs similarity index 70% rename from crates/core/src/task_center/mod.rs rename to crates/core/src/task_center.rs index dd40530d6..ce9312284 100644 --- a/crates/core/src/task_center/mod.rs +++ b/crates/core/src/task_center.rs @@ -10,12 +10,16 @@ mod builder; mod extensions; +mod handle; +mod monitoring; mod runtime; mod task; mod task_kind; pub use builder::*; pub use extensions::*; +pub use handle::*; +pub use monitoring::*; pub use runtime::*; pub use task::*; pub use task_kind::*; @@ -28,14 +32,13 @@ use std::sync::{Arc, OnceLock}; use std::time::{Duration, Instant}; use 
futures::FutureExt; -use metrics::{counter, gauge}; +use metrics::counter; use parking_lot::Mutex; -use tokio::runtime::RuntimeMetrics; use tokio::sync::oneshot; use tokio::task::LocalSet; use tokio::task_local; -use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned}; -use tracing::{debug, error, info, instrument, trace, warn}; +use tokio_util::sync::CancellationToken; +use tracing::{debug, error, info, trace, warn}; use restate_types::identifiers::PartitionId; use restate_types::GenerationalNodeId; @@ -49,7 +52,7 @@ const EXIT_CODE_FAILURE: i32 = 1; task_local! { // Current task center - pub(self) static CURRENT_TASK_CENTER: TaskCenter; + pub(self) static CURRENT_TASK_CENTER: handle::Handle; // Tasks provide access to their context static TASK_CONTEXT: TaskContext; @@ -72,59 +75,16 @@ pub enum RuntimeError { } /// Task center is used to manage long-running and background tasks and their lifecycle. -#[derive(Clone, derive_more::Debug)] -#[debug("TaskCenter({})", inner.id)] -pub struct TaskCenter { - inner: Arc, -} - -static_assertions::assert_impl_all!(TaskCenter: Send, Sync, Clone); +pub struct TaskCenter {} impl TaskCenter { - fn new( - default_runtime_handle: tokio::runtime::Handle, - ingress_runtime_handle: tokio::runtime::Handle, - default_runtime: Option, - ingress_runtime: Option, - // used in tests to start all runtimes with clock paused. Note that this only impacts - // partition processor runtimes - pause_time: bool, - ) -> Self { - metric_definitions::describe_metrics(); - let root_task_context = TaskContext { - id: TaskId::ROOT, - name: "::", - kind: TaskKind::InPlace, - cancellation_token: CancellationToken::new(), - partition_id: None, - }; - Self { - inner: Arc::new(TaskCenterInner { - id: rand::random(), - start_time: Instant::now(), - default_runtime_handle, - default_runtime, - ingress_runtime_handle, - ingress_runtime, - global_cancel_token: CancellationToken::new(), - shutdown_requested: AtomicBool::new(false), - current_exit_code: AtomicI32::new(0), - managed_tasks: Mutex::new(HashMap::new()), - global_metadata: OnceLock::new(), - managed_runtimes: Mutex::new(HashMap::with_capacity(64)), - root_task_context, - pause_time, - }), - } - } - - pub fn try_current() -> Option { + pub fn try_current() -> Option { Self::try_with_current(Clone::clone) } pub fn try_with_current(f: F) -> Option where - F: FnOnce(&TaskCenter) -> R, + F: FnOnce(&Handle) -> R, { CURRENT_TASK_CENTER.try_with(|tc| f(tc)).ok() } @@ -132,109 +92,262 @@ impl TaskCenter { /// Get the current task center. Use this to spawn tasks on the current task center. /// This must be called from within a task-center task. #[track_caller] - pub fn current() -> TaskCenter { + pub fn current() -> Handle { Self::with_current(Clone::clone) } #[track_caller] pub fn with_current(f: F) -> R where - F: FnOnce(&TaskCenter) -> R, + F: FnOnce(&Handle) -> R, { CURRENT_TASK_CENTER .try_with(|tc| f(tc)) .expect("called outside task-center task") } - pub fn default_runtime_metrics(&self) -> RuntimeMetrics { - self.inner.default_runtime_handle.metrics() + #[track_caller] + /// Attempt to access task-level overridden metadata first, if we don't have an override, + /// fallback to task-center's level metadata. + pub(crate) fn with_metadata(f: F) -> Option + where + F: FnOnce(&Metadata) -> R, + { + Self::with_current(|tc| tc.with_metadata(f)) } - pub fn ingress_runtime_metrics(&self) -> RuntimeMetrics { - self.inner.ingress_runtime_handle.metrics() + /// Attempt to set the global metadata handle. 
This should be called once + /// at the startup of the node. + pub fn try_set_global_metadata(metadata: Metadata) -> bool { + Self::with_current(|tc| tc.try_set_global_metadata(metadata)) } - pub fn managed_runtime_metrics(&self) -> Vec<(&'static str, RuntimeMetrics)> { - let guard = self.inner.managed_runtimes.lock(); - guard.iter().map(|(k, v)| (*k, v.metrics())).collect() + /// Launch a new task + #[track_caller] + pub fn spawn(kind: TaskKind, name: &'static str, future: F) -> Result + where + F: Future> + Send + 'static, + { + Self::with_current(|tc| tc.spawn(kind, name, future)) } - /// How long has the task-center been running? - pub fn age(&self) -> Duration { - self.inner.start_time.elapsed() + /// Spawn a new task that is a child of the current task. The child task will be cancelled if the parent + /// task is cancelled. At the moment, the parent task will not automatically wait for children tasks to + /// finish before completion, but this might change in the future if the need for that arises. + #[track_caller] + pub fn spawn_child( + kind: TaskKind, + name: &'static str, + future: F, + ) -> Result + where + F: Future> + Send + 'static, + { + Self::with_current(|tc| tc.spawn_child(kind, name, future)) } - /// Submit telemetry for all runtimes to metrics recorder - pub fn submit_metrics(&self) { - Self::submit_runtime_metrics("default", self.default_runtime_metrics()); - Self::submit_runtime_metrics("ingress", self.ingress_runtime_metrics()); + #[track_caller] + pub fn spawn_unmanaged( + kind: TaskKind, + name: &'static str, + future: F, + ) -> Result, ShutdownError> + where + F: Future + Send + 'static, + T: Send + 'static, + { + Self::with_current(|tc| tc.spawn_unmanaged(kind, name, future)) + } - // Partition processor runtimes - let processor_runtimes = self.managed_runtime_metrics(); - for (task_name, metrics) in processor_runtimes { - Self::submit_runtime_metrics(task_name, metrics); - } + /// Must be called within a Localset-scoped task, not from a normal spawned task. + /// If ran from a non-localset task, this will panic. + #[track_caller] + pub fn spawn_local( + kind: TaskKind, + name: &'static str, + future: F, + ) -> Result + where + F: Future> + 'static, + { + Self::with_current(|tc| tc.spawn_local(kind, name, future)) } - /// Use to monitor an on-going shutdown when requested - pub fn watch_shutdown(&self) -> WaitForCancellationFutureOwned { - self.inner.global_cancel_token.clone().cancelled_owned() + /// Starts the `root_future` on a new runtime. The runtime is stopped once the root future + /// completes. + #[track_caller] + pub fn start_runtime( + root_task_kind: TaskKind, + runtime_name: &'static str, + partition_id: Option, + root_future: impl FnOnce() -> F + Send + 'static, + ) -> Result>, RuntimeError> + where + F: Future> + 'static, + { + Self::with_current(|tc| { + tc.start_runtime(root_task_kind, runtime_name, partition_id, root_future) + }) } - /// Use to monitor an on-going shutdown when requested - pub fn shutdown_token(&self) -> CancellationToken { - self.inner.global_cancel_token.clone() + /// Spawn a potentially thread-blocking future on a dedicated thread pool + #[track_caller] + pub fn spawn_blocking_unmanaged( + name: &'static str, + future: F, + ) -> tokio::task::JoinHandle + where + F: Future + Send + 'static, + O: Send + 'static, + { + Self::with_current(|tc| tc.spawn_blocking_unmanaged(name, future)) } - /// The exit code that the process should exit with. 
- pub fn exit_code(&self) -> i32 { - self.inner.current_exit_code.load(Ordering::Relaxed) + /// Take control over the running task from task-center. This returns None if the task was not + /// found, completed, or has been cancelled. + #[track_caller] + pub fn take_task(task_id: TaskId) -> Option> { + Self::with_current(|tc| tc.take_task(task_id)) + } + + /// Request cancellation of a task. This returns the join handle if the task was found and was + /// not already cancelled or completed. The returned task will not be awaited by task-center on + /// shutdown, and it's the responsibility of the caller to join or abort. + #[track_caller] + pub fn cancel_task(task_id: TaskId) -> Option> { + Self::with_current(|tc| tc.cancel_task(task_id)) + } + + /// Signal and wait for tasks to stop. + /// + /// + /// You can select which tasks to cancel. Any None arguments are ignored. + /// For example, to shut down all MetadataBackgroundSync tasks: + /// + /// cancel_tasks(Some(TaskKind::MetadataBackgroundSync), None) + /// + /// Or to shut down all tasks for a particular partition ID: + /// + /// cancel_tasks(None, Some(partition_id)) + /// + pub async fn cancel_tasks(kind: Option, partition_id: Option) { + Self::current().cancel_tasks(kind, partition_id).await + } + + /// Triggers a shutdown of the system. All running tasks will be asked gracefully + /// to cancel but we will only wait for tasks with a TaskKind that has the property + /// "OnCancel" set to "wait". + pub async fn shutdown_node(reason: &str, exit_code: i32) { + Self::current().shutdown_node(reason, exit_code).await + } + + #[track_caller] + pub fn shutdown_managed_runtimes() { + Self::with_current(|tc| tc.shutdown_managed_runtimes()) } - fn submit_runtime_metrics(runtime: &'static str, stats: RuntimeMetrics) { - gauge!("restate.tokio.num_workers", "runtime" => runtime).set(stats.num_workers() as f64); - gauge!("restate.tokio.blocking_threads", "runtime" => runtime) - .set(stats.num_blocking_threads() as f64); - gauge!("restate.tokio.blocking_queue_depth", "runtime" => runtime) - .set(stats.blocking_queue_depth() as f64); - gauge!("restate.tokio.num_alive_tasks", "runtime" => runtime) - .set(stats.num_alive_tasks() as f64); - gauge!("restate.tokio.io_driver_ready_count", "runtime" => runtime) - .set(stats.io_driver_ready_count() as f64); - gauge!("restate.tokio.remote_schedule_count", "runtime" => runtime) - .set(stats.remote_schedule_count() as f64); - // per worker stats - for idx in 0..stats.num_workers() { - gauge!("restate.tokio.worker_overflow_count", "runtime" => runtime, "worker" => - idx.to_string()) - .set(stats.worker_overflow_count(idx) as f64); - gauge!("restate.tokio.worker_poll_count", "runtime" => runtime, "worker" => idx.to_string()) - .set(stats.worker_poll_count(idx) as f64); - gauge!("restate.tokio.worker_park_count", "runtime" => runtime, "worker" => idx.to_string()) - .set(stats.worker_park_count(idx) as f64); - gauge!("restate.tokio.worker_noop_count", "runtime" => runtime, "worker" => idx.to_string()) - .set(stats.worker_noop_count(idx) as f64); - gauge!("restate.tokio.worker_steal_count", "runtime" => runtime, "worker" => idx.to_string()) - .set(stats.worker_steal_count(idx) as f64); - gauge!("restate.tokio.worker_total_busy_duration_seconds", "runtime" => runtime, "worker" => idx.to_string()) - .set(stats.worker_total_busy_duration(idx).as_secs_f64()); - gauge!("restate.tokio.worker_mean_poll_time", "runtime" => runtime, "worker" => idx.to_string()) - .set(stats.worker_mean_poll_time(idx).as_secs_f64()); + /// 
Sets the current task_center but doesn't create a task. Use this when you need to run a + /// future within task_center scope. + #[track_caller] + pub fn block_on(&self, future: F) -> O + where + F: Future, + { + Self::with_current(|tc| tc.block_on(future)) + } + + /// Sets the current task_center but doesn't create a task. Use this when you need to run a + /// closure within task_center scope. + #[track_caller] + pub fn run_sync(f: F) -> O + where + F: FnOnce() -> O, + { + Self::with_current(|tc| tc.run_sync(f)) + } +} + +struct TaskCenterInner { + #[allow(dead_code)] + /// used in Debug impl to distinguish between multiple task-centers + id: u16, + /// Should we start new runtimes with paused clock? + #[allow(dead_code)] + pause_time: bool, + default_runtime_handle: tokio::runtime::Handle, + ingress_runtime_handle: tokio::runtime::Handle, + managed_runtimes: Mutex>>, + start_time: Instant, + /// We hold on to the owned Runtime to ensure it's dropped when task center is dropped. If this + /// is None, it means that it's the responsibility of the Handle owner to correctly drop + /// tokio's runtime after dropping the task center. + #[allow(dead_code)] + default_runtime: Option, + #[allow(dead_code)] + ingress_runtime: Option, + global_cancel_token: CancellationToken, + shutdown_requested: AtomicBool, + current_exit_code: AtomicI32, + managed_tasks: Mutex>>, + global_metadata: OnceLock, + root_task_context: TaskContext, +} + +impl TaskCenterInner { + fn new( + default_runtime_handle: tokio::runtime::Handle, + ingress_runtime_handle: tokio::runtime::Handle, + default_runtime: Option, + ingress_runtime: Option, + // used in tests to start all runtimes with clock paused. Note that this only impacts + // partition processor runtimes + pause_time: bool, + ) -> Self { + metric_definitions::describe_metrics(); + let root_task_context = TaskContext { + id: TaskId::ROOT, + name: "::", + kind: TaskKind::InPlace, + cancellation_token: CancellationToken::new(), + partition_id: None, + }; + Self { + id: rand::random(), + start_time: Instant::now(), + default_runtime_handle, + default_runtime, + ingress_runtime_handle, + ingress_runtime, + global_cancel_token: CancellationToken::new(), + shutdown_requested: AtomicBool::new(false), + current_exit_code: AtomicI32::new(0), + managed_tasks: Mutex::new(HashMap::new()), + global_metadata: OnceLock::new(), + managed_runtimes: Mutex::new(HashMap::with_capacity(64)), + root_task_context, + pause_time, } } - pub(crate) fn metadata(&self) -> Option { + /// Attempt to set the global metadata handle. This should be called once + /// at the startup of the node. + pub fn try_set_global_metadata(self: &Arc, metadata: Metadata) -> bool { + self.global_metadata.set(metadata).is_ok() + } + + pub fn global_metadata(self: &Arc) -> Option<&Metadata> { + self.global_metadata.get() + } + + pub fn metadata(self: &Arc) -> Option { match OVERRIDES.try_with(|overrides| overrides.metadata.clone()) { Ok(Some(o)) => Some(o), // No metadata override, use task-center-level metadata - _ => self.inner.global_metadata.get().cloned(), + _ => self.global_metadata.get().cloned(), } } - #[track_caller] - /// Attempt to access task-level overridden metadata first, if we don't have an override, - /// fallback to task-center's level metadata. 
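// Example (sketch): `TaskCenterInner` is not built directly; `TaskCenterBuilder`
// constructs it and returns an `OwnedHandle` that owns the runtimes. Reusing the
// current tokio runtime for both handles is illustrative (and requires running inside
// a tokio runtime); a real node lets the builder create runtimes from `CommonOptions`.
fn build_sketch() -> anyhow::Result<()> {
    let owned = restate_core::TaskCenterBuilder::default()
        .default_runtime_handle(tokio::runtime::Handle::current())
        .ingress_runtime_handle(tokio::runtime::Handle::current())
        .build()?;
    // `handle()` hands out a cheap, cloneable `Handle`; `to_handle()` consumes the owner.
    let _tc = owned.handle();
    Ok(())
}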
- pub(crate) fn with_metadata(f: F) -> Option + pub fn with_metadata(self: &Arc, f: F) -> Option where F: FnOnce(&Metadata) -> R, { @@ -242,163 +355,117 @@ impl TaskCenter { .try_with(|overrides| match &overrides.metadata { Some(m) => Some(f(m)), // No metadata override, use task-center-level metadata - None => CURRENT_TASK_CENTER.with(|tc| tc.inner.global_metadata.get().map(f)), + None => self.global_metadata().map(f), }) .ok() .flatten() } - fn with_task_context(&self, f: F) -> R + pub fn run_sync(self: &Arc, f: F) -> O where - F: Fn(&TaskContext) -> R, + F: FnOnce() -> O, { - TASK_CONTEXT - .try_with(|ctx| f(ctx)) - .unwrap_or_else(|_| f(&self.inner.root_task_context)) - } - - /// Triggers a shutdown of the system. All running tasks will be asked gracefully - /// to cancel but we will only wait for tasks with a TaskKind that has the property - /// "OnCancel" set to "wait". - #[instrument(level = "error", skip(self, exit_code))] - pub async fn shutdown_node(&self, reason: &str, exit_code: i32) { - let inner = self.inner.clone(); - if inner - .shutdown_requested - .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) - .unwrap_or_else(|e| e) - { - // already shutting down.... - return; - } - let start = Instant::now(); - inner.current_exit_code.store(exit_code, Ordering::Relaxed); - - if exit_code != 0 { - warn!("** Shutdown requested"); - } else { - info!("** Shutdown requested"); - } - self.cancel_tasks(None, None).await; - self.shutdown_managed_runtimes(); - // notify outer components that we have completed the shutdown. - self.inner.global_cancel_token.cancel(); - info!("** Shutdown completed in {:?}", start.elapsed()); - } - - /// Attempt to set the global metadata handle. This should be called once - /// at the startup of the node. - pub fn try_set_global_metadata(metadata: Metadata) -> bool { - Self::with_current(|tc| tc.try_set_global_metadata_inner(metadata)) + CURRENT_TASK_CENTER.sync_scope(Handle::new(self), || { + OVERRIDES.sync_scope(OVERRIDES.try_with(Clone::clone).unwrap_or_default(), || { + TASK_CONTEXT.sync_scope(self.with_task_context(Clone::clone), f) + }) + }) } - /// Attempt to set the global metadata handle. This should be called once - /// at the startup of the node. - pub(crate) fn try_set_global_metadata_inner(&self, metadata: Metadata) -> bool { - self.inner.global_metadata.set(metadata).is_ok() + /// Sets the current task_center but doesn't create a task. Use this when you need to run a + /// future within task_center scope. 
+ pub fn block_on(self: &Arc, future: F) -> O + where + F: Future, + { + self.default_runtime_handle + .block_on(future.in_tc(&Handle::new(self))) } - #[track_caller] - fn spawn_inner( - &self, + /// Launch a new task + pub fn spawn( + self: &Arc, kind: TaskKind, name: &'static str, - partition_id: Option, - cancel: CancellationToken, future: F, - ) -> TaskId + ) -> Result where F: Future> + Send + 'static, { - let inner = self.inner.clone(); - let id = TaskId::default(); - let context = TaskContext { - id, - name, - kind, - partition_id, - cancellation_token: cancel.clone(), - }; - let task = Arc::new(Task { - context: context.clone(), - handle: Mutex::new(None), - }); - - inner.managed_tasks.lock().insert(id, Arc::clone(&task)); + if self.shutdown_requested.load(Ordering::Relaxed) { + return Err(ShutdownError); + } - let mut handle_mut = task.handle.lock(); + // spawned tasks get their own unlinked cancellation tokens + let cancel = CancellationToken::new(); + let (parent_id, parent_name, parent_partition) = + self.with_task_context(|ctx| (ctx.id, ctx.name, ctx.partition_id)); - let fut = wrapper(self.clone(), context, future); - *handle_mut = Some(self.spawn_on_runtime(kind, name, cancel, fut)); - // drop the lock - drop(handle_mut); - // Task is ready - id - } + let result = self.spawn_inner(kind, name, parent_id, parent_partition, cancel, future); - fn spawn_on_runtime( - &self, - kind: TaskKind, - name: &'static str, - cancellation_token: CancellationToken, - fut: F, - ) -> TaskHandle - where - F: Future + Send + 'static, - T: Send + 'static, - { - let kind_str: &'static str = kind.into(); - let runtime_name: &'static str = kind.runtime().into(); - let tokio_task = tokio::task::Builder::new().name(name); - counter!(TC_SPAWN, "kind" => kind_str, "runtime" => runtime_name).increment(1); - let runtime = match kind.runtime() { - crate::AsyncRuntime::Inherit => &tokio::runtime::Handle::current(), - crate::AsyncRuntime::Default => &self.inner.default_runtime_handle, - crate::AsyncRuntime::Ingress => &self.inner.ingress_runtime_handle, - }; - let inner_handle = tokio_task - .spawn_on(fut, runtime) - .expect("runtime can spawn tasks"); - - TaskHandle { - cancellation_token, - inner_handle, - } + trace!( + kind = ?kind, + name = %name, + parent_task = %parent_id, + "Task \"{}\" {} spawned \"{}\" {}", + parent_name, parent_id, name, result + ); + Ok(result) } - /// Launch a new task - #[track_caller] - pub fn spawn( - &self, + /// Spawn a new task that is a child of the current task. The child task will be cancelled if the parent + /// task is cancelled. At the moment, the parent task will not automatically wait for children tasks to + /// finish before completion, but this might change in the future if the need for that arises. 
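// Example (sketch): `block_on` and `run_sync` enter task-center scope without creating
// a managed task. This is how synchronous call sites (a binary's main, benchmarks, test
// harnesses) run code that internally relies on `TaskCenter::current()`. Bodies are
// illustrative.
fn scope_entry_sketch(tc: &restate_core::task_center::Handle) {
    // Runs the future on the default runtime with the task-center task-local installed.
    tc.block_on(async {
        // Associated functions such as `TaskCenter::current()` resolve in here.
        let _current = restate_core::TaskCenter::current();
    });

    // Same idea for synchronous closures.
    tc.run_sync(|| {
        let _metadata = restate_core::TaskCenter::try_current().and_then(|h| h.metadata());
    });
}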
+ pub fn spawn_child( + self: &Arc, kind: TaskKind, name: &'static str, - partition_id: Option, future: F, ) -> Result where F: Future> + Send + 'static, { - if self.inner.shutdown_requested.load(Ordering::Relaxed) { + if self.shutdown_requested.load(Ordering::Relaxed) { return Err(ShutdownError); } - Ok(self.spawn_unchecked(kind, name, partition_id, future)) + + let (parent_id, parent_name, parent_kind, parent_partition, cancel) = self + .with_task_context(|ctx| { + ( + ctx.id, + ctx.name, + ctx.kind, + ctx.partition_id, + ctx.cancellation_token.child_token(), + ) + }); + + let result = self.spawn_inner(kind, name, parent_id, parent_partition, cancel, future); + + trace!( + kind = ?parent_kind, + name = ?parent_name, + child_kind = ?kind, + "Task \"{}\" {} spawned \"{}\" {}", + parent_name, parent_id, name, result + ); + Ok(result) } - #[track_caller] pub fn spawn_unmanaged( - &self, + self: &Arc, kind: TaskKind, name: &'static str, - partition_id: Option, future: F, ) -> Result, ShutdownError> where F: Future + Send + 'static, T: Send + 'static, { - if self.inner.shutdown_requested.load(Ordering::Relaxed) { + if self.shutdown_requested.load(Ordering::Relaxed) { return Err(ShutdownError); } + let parent_partition = self.with_task_context(|ctx| (ctx.partition_id)); let cancel = CancellationToken::new(); let id = TaskId::default(); @@ -406,36 +473,19 @@ impl TaskCenter { id, name, kind, - partition_id, + partition_id: parent_partition, cancellation_token: cancel.clone(), }; - let fut = unmanaged_wrapper(self.clone(), context, future); + let fut = unmanaged_wrapper(Arc::clone(self), context, future); Ok(self.spawn_on_runtime(kind, name, cancel, fut)) } - // Allows for spawning a new task without checking if the system is shutting down. This means - // that this task might not be able to finish if the system is shutting down. - #[track_caller] - fn spawn_unchecked( - &self, - kind: TaskKind, - name: &'static str, - partition_id: Option, - future: F, - ) -> TaskId - where - F: Future> + Send + 'static, - { - let cancel = CancellationToken::new(); - self.spawn_inner(kind, name, partition_id, cancel, future) - } - /// Must be called within a Localset-scoped task, not from a normal spawned task. /// If ran from a non-localset task, this will panic. pub fn spawn_local( - &self, + self: &Arc, kind: TaskKind, name: &'static str, future: F, @@ -459,14 +509,12 @@ impl TaskCenter { handle: Mutex::new(None), }); - let inner = self.inner.clone(); - inner - .managed_tasks + self.managed_tasks .lock() .insert(context.id, Arc::clone(&task)); let mut handle_mut = task.handle.lock(); - let fut = wrapper(self.clone(), context, future); + let fut = wrapper(Arc::clone(self), context, future); let tokio_task = tokio::task::Builder::new().name(name); let inner_handle = tokio_task @@ -483,10 +531,26 @@ impl TaskCenter { Ok(id) } + // Spawn a future in its own thread + pub fn spawn_blocking_unmanaged( + self: &Arc, + name: &'static str, + future: F, + ) -> tokio::task::JoinHandle + where + F: Future + Send + 'static, + O: Send + 'static, + { + let rt_handle = self.default_runtime_handle.clone(); + let future = future.in_tc_as_task(&Handle::new(self), TaskKind::InPlace, name); + self.default_runtime_handle + .spawn_blocking(move || rt_handle.block_on(future)) + } + /// Starts the `root_future` on a new runtime. The runtime is stopped once the root future /// completes. 
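// Example (sketch): unmanaged tasks are not tracked in `managed_tasks`, so task-center
// will not wait for them on shutdown; the caller owns the returned `TaskHandle` and is
// responsible for awaiting or aborting it. Kind and payload are illustrative.
async fn unmanaged_sketch() -> Result<(), restate_core::ShutdownError> {
    let handle = restate_core::TaskCenter::spawn_unmanaged(
        restate_core::TaskKind::RoleRunner,
        "one-off-computation",
        async { 40 + 2 },
    )?;
    // The handle can be awaited (as task-center itself does during shutdown) or aborted.
    let _answer = handle.await;
    Ok(())
}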
pub fn start_runtime( - &self, + self: &Arc, root_task_kind: TaskKind, runtime_name: &'static str, partition_id: Option, @@ -495,14 +559,14 @@ impl TaskCenter { where F: Future> + 'static, { - if self.inner.shutdown_requested.load(Ordering::Relaxed) { + if self.shutdown_requested.load(Ordering::Relaxed) { return Err(ShutdownError.into()); } let cancel = CancellationToken::new(); // hold a lock while creating the runtime to avoid concurrent runtimes with the same name - let mut runtimes_guard = self.inner.managed_runtimes.lock(); + let mut runtimes_guard = self.managed_runtimes.lock(); if runtimes_guard.contains_key(runtime_name) { warn!( "Failed to start new runtime, a runtime with name {} already exists", @@ -516,7 +580,7 @@ impl TaskCenter { let mut builder = tokio::runtime::Builder::new_current_thread(); #[cfg(any(test, feature = "test-util"))] - builder.start_paused(self.inner.pause_time); + builder.start_paused(self.pause_time); let rt = builder .enable_all() @@ -556,283 +620,112 @@ impl TaskCenter { drop(rt_handle); tc.drop_runtime(runtime_name); - // need to use an oneshot here since we cannot await a thread::JoinHandle :-( - let _ = result_tx.send(result); - }) - .unwrap(); - - Ok(RuntimeRootTaskHandle { - inner_handle: result_rx, - cancellation_token: cancel, - }) - } - - fn drop_runtime(&self, name: &'static str) { - let mut runtimes_guard = self.inner.managed_runtimes.lock(); - if runtimes_guard.remove(name).is_some() { - trace!("Runtime {} was dropped", name); - } - } - - /// Spawn a new task that is a child of the current task. The child task will be cancelled if the parent - /// task is cancelled. At the moment, the parent task will not automatically wait for children tasks to - /// finish before completion, but this might change in the future if the need for that arises. - #[track_caller] - pub fn spawn_child( - kind: TaskKind, - name: &'static str, - future: F, - ) -> Result - where - F: Future> + Send + 'static, - { - Self::with_current(|tc| tc.spawn_child_inner(kind, name, future)) - } - - /// Spawn a new task that is a child of the current task. The child task will be cancelled if the parent - /// task is cancelled. At the moment, the parent task will not automatically wait for children tasks to - /// finish before completion, but this might change in the future if the need for that arises. 
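// Example (sketch): `start_runtime` gives a component its own single-threaded tokio
// runtime whose lifetime is tied to the root future; partition processor runtimes are
// managed this way. The task kind, runtime name and body below are illustrative.
fn start_runtime_sketch() -> anyhow::Result<()> {
    let _root = restate_core::TaskCenter::start_runtime(
        restate_core::TaskKind::RoleRunner,
        "example-runtime",
        None, // not associated with a partition
        || async {
            // Root future: the runtime is torn down once this completes.
            anyhow::Ok(())
        },
    )?;
    Ok(())
}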
- fn spawn_child_inner( - &self, - kind: TaskKind, - name: &'static str, - future: F, - ) -> Result - where - F: Future> + Send + 'static, - { - if self.inner.shutdown_requested.load(Ordering::Relaxed) { - return Err(ShutdownError); - } - - let (parent_id, parent_name, parent_kind, parent_partition, cancel) = self - .with_task_context(|ctx| { - ( - ctx.id, - ctx.name, - ctx.kind, - ctx.partition_id, - ctx.cancellation_token.child_token(), - ) - }); - - let result = self.spawn_inner(kind, name, parent_partition, cancel, future); - - trace!( - kind = ?parent_kind, - name = ?parent_name, - child_kind = ?kind, - "Task \"{}\" {} spawned \"{}\" {}", - parent_name, parent_id, name, result - ); - Ok(result) - } - - // Spawn a future in its own thread - pub fn spawn_blocking_unmanaged( - &self, - name: &'static str, - future: F, - ) -> tokio::task::JoinHandle - where - F: Future + Send + 'static, - O: Send + 'static, - { - let rt_handle = self.inner.default_runtime_handle.clone(); - let future = future.in_tc_as_task(self, TaskKind::InPlace, name); - self.inner - .default_runtime_handle - .spawn_blocking(move || rt_handle.block_on(future)) - } - - /// Cancelling the child will not cancel the parent. Note that parent task will not - /// wait for children tasks. The parent task is allowed to finish before children. - #[track_caller] - pub fn spawn_child_unchecked( - &self, - kind: TaskKind, - name: &'static str, - partition_id: Option, - future: F, - ) -> TaskId - where - F: Future> + Send + 'static, - { - let cancel = self.with_task_context(|ctx| ctx.cancellation_token.child_token()); - self.spawn_inner(kind, name, partition_id, cancel, future) - } - - /// Signal and wait for tasks to stop. - /// - /// - /// You can select which tasks to cancel. Any None arguments are ignored. - /// For example, to shut down all MetadataBackgroundSync tasks: - /// - /// cancel_tasks(Some(TaskKind::MetadataBackgroundSync), None) - /// - /// Or to shut down all tasks for a particular partition ID: - /// - /// cancel_tasks(None, Some(partition_id)) - /// - pub async fn cancel_tasks(&self, kind: Option, partition_id: Option) { - let inner = self.inner.clone(); - let mut victims = Vec::new(); - - { - let tasks = inner.managed_tasks.lock(); - for task in tasks.values() { - if (kind.is_none() || Some(task.context.kind) == kind) - && (partition_id.is_none() || task.context.partition_id == partition_id) - { - task.context.cancellation_token.cancel(); - victims.push((Arc::clone(task), task.kind(), task.partition_id())); - } - } - } - - for (task, task_kind, partition_id) in victims { - let handle = { - let mut task_mut = task.handle.lock(); - // Task is not running anymore or another cancel is waiting for it. - task_mut.take() - }; - if let Some(mut handle) = handle { - if task_kind.should_abort_on_cancel() { - // We should not wait, instead, just abort the tokio task. - debug!(kind = ?task_kind, name = ?task.name(), partition_id = ?partition_id, "task {} aborted!", task.id()); - handle.abort(); - } else if task_kind.should_wait_on_cancel() { - // Give the task a chance to finish before logging. - if tokio::time::timeout(Duration::from_secs(2), &mut handle) - .await - .is_err() - { - info!(kind = ?task_kind, name = ?task.name(), partition_id = ?partition_id, "waiting for task {} to shutdown", task.id()); - // Ignore join errors on cancel. 
on_finish already takes care - let _ = handle.await; - info!(kind = ?task_kind, name = ?task.name(), partition_id = ?partition_id, "task {} completed", task.id()); - } - } else { - // Ignore the task. the task will be dropped on tokio runtime drop. - } - } else { - // Possibly one of: - // * The task had not even fully started yet. - // * It was shut down concurrently and already exited (or failed) - } - } + // need to use an oneshot here since we cannot await a thread::JoinHandle :-( + let _ = result_tx.send(result); + }) + .unwrap(); + + Ok(RuntimeRootTaskHandle { + inner_handle: result_rx, + cancellation_token: cancel, + }) } - pub fn shutdown_managed_runtimes(&self) { - let mut runtimes = self.inner.managed_runtimes.lock(); - for (_, runtime) in runtimes.drain() { - if let Some(runtime) = Arc::into_inner(runtime) { - runtime.shutdown_background(); - } + fn drop_runtime(self: &Arc, name: &'static str) { + let mut runtimes_guard = self.managed_runtimes.lock(); + if runtimes_guard.remove(name).is_some() { + trace!("Runtime {} was dropped", name); } } - /// Sets the current task_center but doesn't create a task. Use this when you need to run a - /// future within task_center scope. - pub fn block_on(&self, future: F) -> O + fn with_task_context(&self, f: F) -> R where - F: Future, + F: Fn(&TaskContext) -> R, { - self.inner - .default_runtime_handle - .block_on(future.in_tc(self)) + TASK_CONTEXT + .try_with(|ctx| f(ctx)) + .unwrap_or_else(|_| f(&self.root_task_context)) } - /// Sets the current task_center but doesn't create a task. Use this when you need to run a - /// future within task_center scope. - pub async fn run_in_scope( - &self, + fn spawn_inner( + self: &Arc, + kind: TaskKind, name: &'static str, + _parent_id: TaskId, partition_id: Option, + cancel: CancellationToken, future: F, - ) -> O + ) -> TaskId where - F: Future, + F: Future> + Send + 'static, { - let cancellation_token = CancellationToken::new(); + let inner = Arc::clone(self); let id = TaskId::default(); - let ctx = TaskContext { + let context = TaskContext { id, name, - kind: TaskKind::InPlace, - cancellation_token: cancellation_token.clone(), + kind, partition_id, + cancellation_token: cancel.clone(), }; + let task = Arc::new(Task { + context: context.clone(), + handle: Mutex::new(None), + }); - CURRENT_TASK_CENTER - .scope( - self.clone(), - OVERRIDES.scope( - OVERRIDES.try_with(Clone::clone).unwrap_or_default(), - TASK_CONTEXT.scope(ctx, future), - ), - ) - .await - } - - /// Sets the current task_center but doesn't create a task. Use this when you need to run a - /// closure within task_center scope. - pub fn run_in_scope_sync(&self, f: F) -> O - where - F: FnOnce() -> O, - { - CURRENT_TASK_CENTER.sync_scope(self.clone(), || { - OVERRIDES.sync_scope(OVERRIDES.try_with(Clone::clone).unwrap_or_default(), || { - TASK_CONTEXT.sync_scope(self.with_task_context(Clone::clone), f) - }) - }) - } + inner.managed_tasks.lock().insert(id, Arc::clone(&task)); - /// Take control over the running task from task-center. This returns None if the task was not - /// found, completed, or has been cancelled. - pub fn take_task(&self, task_id: TaskId) -> Option> { - let inner = self.inner.clone(); - let task = { - // find the task - let mut tasks = inner.managed_tasks.lock(); - tasks.remove(&task_id)? - }; + let mut handle_mut = task.handle.lock(); - let mut task_mut = task.handle.lock(); - // Task is not running anymore or a cancellation is already in progress. 
- task_mut.take() + let fut = wrapper(inner, context, future); + *handle_mut = Some(self.spawn_on_runtime(kind, name, cancel, fut)); + // drop the lock + drop(handle_mut); + // Task is ready + id } - /// Request cancellation of a task. This returns the join handle if the task was found and was - /// not already cancelled or completed. The returned task will not be awaited by task-center on - /// shutdown, and it's the responsibility of the caller to join or abort. - pub fn cancel_task(&self, task_id: TaskId) -> Option> { - let inner = self.inner.clone(); - let task = { - // find the task - let tasks = inner.managed_tasks.lock(); - let task = tasks.get(&task_id)?; - // request cancellation - task.cancel(); - Arc::clone(task) + fn spawn_on_runtime( + self: &Arc, + kind: TaskKind, + name: &'static str, + cancellation_token: CancellationToken, + fut: F, + ) -> TaskHandle + where + F: Future + Send + 'static, + T: Send + 'static, + { + let kind_str: &'static str = kind.into(); + let runtime_name: &'static str = kind.runtime().into(); + let tokio_task = tokio::task::Builder::new().name(name); + counter!(TC_SPAWN, "kind" => kind_str, "runtime" => runtime_name).increment(1); + let runtime = match kind.runtime() { + crate::AsyncRuntime::Inherit => &tokio::runtime::Handle::current(), + crate::AsyncRuntime::Default => &self.default_runtime_handle, + crate::AsyncRuntime::Ingress => &self.ingress_runtime_handle, }; + let inner_handle = tokio_task + .spawn_on(fut, runtime) + .expect("runtime can spawn tasks"); - let mut task_mut = task.handle.lock(); - // Task is not running anymore or a cancellation is already in progress. - task_mut.take() + TaskHandle { + cancellation_token, + inner_handle, + } } async fn on_finish( - &self, + self: &Arc, task_id: TaskId, result: std::result::Result< anyhow::Result<()>, std::boxed::Box, >, ) { - let inner = self.inner.clone(); + //let inner = self.inner.clone(); // Remove our entry from the tasks map. - let Some(task) = inner.managed_tasks.lock().remove(&task_id) else { + let Some(task) = self.managed_tasks.lock().remove(&task_id) else { // This can happen if the task ownership was taken by calling take_task(id); return; }; @@ -893,37 +786,142 @@ impl TaskCenter { .await; } } -} -struct TaskCenterInner { - #[allow(dead_code)] - /// used in Debug impl to distinguish between multiple task-centers - id: u16, - /// Should we start new runtimes with paused clock? - #[allow(dead_code)] - pause_time: bool, - default_runtime_handle: tokio::runtime::Handle, - ingress_runtime_handle: tokio::runtime::Handle, - managed_runtimes: Mutex>>, - start_time: Instant, - /// We hold on to the owned Runtime to ensure it's dropped when task center is dropped. If this - /// is None, it means that it's the responsibility of the Handle owner to correctly drop - /// tokio's runtime after dropping the task center. - #[allow(dead_code)] - default_runtime: Option, - #[allow(dead_code)] - ingress_runtime: Option, - global_cancel_token: CancellationToken, - shutdown_requested: AtomicBool, - current_exit_code: AtomicI32, - managed_tasks: Mutex>>, - global_metadata: OnceLock, - root_task_context: TaskContext, + async fn shutdown_node(self: &Arc, reason: &str, exit_code: i32) { + //let inner = self.inner.clone(); + if self + .shutdown_requested + .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) + .unwrap_or_else(|e| e) + { + // already shutting down.... 
+ return; + } + let start = Instant::now(); + self.current_exit_code.store(exit_code, Ordering::Relaxed); + + if exit_code != 0 { + warn!(%reason, "** Shutdown requested"); + } else { + info!(%reason, "** Shutdown requested"); + } + self.cancel_tasks(None, None).await; + self.shutdown_managed_runtimes(); + // notify outer components that we have completed the shutdown. + self.global_cancel_token.cancel(); + info!("** Shutdown completed in {:?}", start.elapsed()); + } + + /// Take control over the running task from task-center. This returns None if the task was not + /// found, completed, or has been cancelled. + pub fn take_task(self: &Arc, task_id: TaskId) -> Option> { + let task = { + // find the task + let mut tasks = self.managed_tasks.lock(); + tasks.remove(&task_id)? + }; + + let mut task_mut = task.handle.lock(); + // Task is not running anymore or a cancellation is already in progress. + task_mut.take() + } + + /// Request cancellation of a task. This returns the join handle if the task was found and was + /// not already cancelled or completed. The returned task will not be awaited by task-center on + /// shutdown, and it's the responsibility of the caller to join or abort. + pub fn cancel_task(self: &Arc, task_id: TaskId) -> Option> { + let task = { + // find the task + let tasks = self.managed_tasks.lock(); + let task = tasks.get(&task_id)?; + // request cancellation + task.cancel(); + Arc::clone(task) + }; + + let mut task_mut = task.handle.lock(); + // Task is not running anymore or a cancellation is already in progress. + task_mut.take() + } + + /// Signal and wait for tasks to stop. + /// + /// + /// You can select which tasks to cancel. Any None arguments are ignored. + /// For example, to shut down all MetadataBackgroundSync tasks: + /// + /// cancel_tasks(Some(TaskKind::MetadataBackgroundSync), None) + /// + /// Or to shut down all tasks for a particular partition ID: + /// + /// cancel_tasks(None, Some(partition_id)) + /// + async fn cancel_tasks( + self: &Arc, + kind: Option, + partition_id: Option, + ) { + //let inner = self.inner.clone(); + let mut victims = Vec::new(); + + { + let tasks = self.managed_tasks.lock(); + for task in tasks.values() { + if (kind.is_none() || Some(task.context.kind) == kind) + && (partition_id.is_none() || task.context.partition_id == partition_id) + { + task.context.cancellation_token.cancel(); + victims.push((Arc::clone(task), task.kind(), task.partition_id())); + } + } + } + + for (task, task_kind, partition_id) in victims { + let handle = { + let mut task_mut = task.handle.lock(); + // Task is not running anymore or another cancel is waiting for it. + task_mut.take() + }; + if let Some(mut handle) = handle { + if task_kind.should_abort_on_cancel() { + // We should not wait, instead, just abort the tokio task. + debug!(kind = ?task_kind, name = ?task.name(), partition_id = ?partition_id, "task {} aborted!", task.id()); + handle.abort(); + } else if task_kind.should_wait_on_cancel() { + // Give the task a chance to finish before logging. + if tokio::time::timeout(Duration::from_secs(2), &mut handle) + .await + .is_err() + { + info!(kind = ?task_kind, name = ?task.name(), partition_id = ?partition_id, "waiting for task {} to shutdown", task.id()); + // Ignore join errors on cancel. on_finish already takes care + let _ = handle.await; + info!(kind = ?task_kind, name = ?task.name(), partition_id = ?partition_id, "task {} completed", task.id()); + } + } else { + // Ignore the task. the task will be dropped on tokio runtime drop. 
+ } + } else { + // Possibly one of: + // * The task had not even fully started yet. + // * It was shut down concurrently and already exited (or failed) + } + } + } + + fn shutdown_managed_runtimes(self: &Arc) { + let mut runtimes = self.managed_runtimes.lock(); + for (_, runtime) in runtimes.drain() { + if let Some(runtime) = Arc::into_inner(runtime) { + runtime.shutdown_background(); + } + } + } } /// This wrapper function runs in a newly-spawned task. It initializes the /// task-local variables and wraps the inner future. -async fn wrapper(task_center: TaskCenter, context: TaskContext, future: F) +async fn wrapper(inner: Arc, context: TaskContext, future: F) where F: Future> + 'static, { @@ -932,7 +930,7 @@ where let result = CURRENT_TASK_CENTER .scope( - task_center.clone(), + Handle::new(&inner), OVERRIDES.scope( OVERRIDES.try_with(Clone::clone).unwrap_or_default(), TASK_CONTEXT.scope( @@ -945,11 +943,11 @@ where ), ) .await; - task_center.on_finish(id, result).await; + inner.on_finish(id, result).await; } /// Like wrapper but doesn't call on_finish nor it catches panics -async fn unmanaged_wrapper(task_center: TaskCenter, context: TaskContext, future: F) -> T +async fn unmanaged_wrapper(inner: Arc, context: TaskContext, future: F) -> T where F: Future + 'static, { @@ -957,7 +955,7 @@ where CURRENT_TASK_CENTER .scope( - task_center.clone(), + Handle::new(&inner), OVERRIDES.scope( OVERRIDES.try_with(Clone::clone).unwrap_or_default(), TASK_CONTEXT.scope(context, future), @@ -966,21 +964,6 @@ where .await } -/// Access to global metadata handle. This is available in task-center tasks only! -#[track_caller] -pub fn metadata() -> Metadata { - // todo: migrate call-sites - Metadata::current() -} - -#[track_caller] -pub fn with_metadata(f: F) -> R -where - F: FnOnce(&Metadata) -> R, -{ - Metadata::with_current(f) -} - /// Access to this node id. This is available in task-center tasks only! #[track_caller] pub fn my_node_id() -> GenerationalNodeId { @@ -991,28 +974,12 @@ pub fn my_node_id() -> GenerationalNodeId { /// The current task-center task Id. This returns None if we are not in the scope /// of a task-center task. pub fn current_task_id() -> Option { - TASK_CONTEXT - .try_with(|ctx| Some(ctx.id)) - .unwrap_or(TaskCenter::try_with_current(|tc| { - tc.inner.root_task_context.id - })) + TaskContext::try_with_current(|ctx| ctx.id()) } /// The current partition Id associated to the running task-center task. pub fn current_task_partition_id() -> Option { - TASK_CONTEXT - .try_with(|ctx| Some(ctx.partition_id)) - .unwrap_or(TaskCenter::try_with_current(|tc| { - tc.inner.root_task_context.partition_id - })) - .flatten() -} - -/// Get the current task center. Use this to spawn tasks on the current task center. -/// This must be called from within a task-center task. -pub fn task_center() -> TaskCenter { - // migrate call-sites - TaskCenter::current() + TaskContext::try_with_current(|ctx| ctx.partition_id()).flatten() } /// A Future that can be used to check if the current task has been requested to @@ -1028,7 +995,7 @@ pub async fn cancellation_watcher() { /// cancel_task() call, or if it's a child and the parent is being cancelled by a /// cancel_task() call, this cancellation token will be set to cancelled. 
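// Example (sketch): a long-running task exits cooperatively by selecting on
// `cancellation_watcher()`, which resolves once `shutdown_node()`, `cancel_task()`, or a
// parent's cancellation fires this task's token. The interval and loop body are
// illustrative, and the crate-root re-export of `cancellation_watcher` is assumed.
async fn cancellable_loop_sketch() -> anyhow::Result<()> {
    let mut shutdown = std::pin::pin!(restate_core::cancellation_watcher());
    let mut tick = tokio::time::interval(std::time::Duration::from_secs(5));
    loop {
        tokio::select! {
            _ = &mut shutdown => {
                // Flush or drain here if needed, then return promptly.
                return Ok(());
            }
            _ = tick.tick() => {
                // ... periodic work ...
            }
        }
    }
}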
pub fn cancellation_token() -> CancellationToken { - let res = TASK_CONTEXT.try_with(|ctx| ctx.cancellation_token.clone()); + let res = TaskContext::try_with_current(|ctx| ctx.cancellation_token().clone()); if cfg!(any(test, feature = "test-util")) { // allow in tests to call from non-task-center tasks. @@ -1040,14 +1007,14 @@ pub fn cancellation_token() -> CancellationToken { /// Has the current task been requested to cancel? pub fn is_cancellation_requested() -> bool { - TASK_CONTEXT - .try_with(|ctx| ctx.cancellation_token.is_cancelled()) - .unwrap_or_else(|_| { + TaskContext::try_with_current(|ctx| ctx.cancellation_token().is_cancelled()).unwrap_or_else( + || { if cfg!(any(test, feature = "test-util")) { warn!("is_cancellation_requested() called outside task-center context"); } false - }) + }, + ) } #[cfg(test)] @@ -1069,9 +1036,10 @@ mod tests { .options(common_opts) .default_runtime_handle(tokio::runtime::Handle::current()) .ingress_runtime_handle(tokio::runtime::Handle::current()) - .build()?; + .build()? + .to_handle(); let start = tokio::time::Instant::now(); - tc.spawn(TaskKind::RoleRunner, "worker-role", None, async { + tc.spawn(TaskKind::RoleRunner, "worker-role", async { info!("Hello async"); tokio::time::sleep(Duration::from_secs(10)).await; info!("Bye async"); diff --git a/crates/core/src/task_center/builder.rs b/crates/core/src/task_center/builder.rs index 7325bbfa9..9d1e55cb7 100644 --- a/crates/core/src/task_center/builder.rs +++ b/crates/core/src/task_center/builder.rs @@ -14,7 +14,7 @@ use tracing::error; use restate_types::config::CommonOptions; -use super::TaskCenter; +use super::{OwnedHandle, TaskCenterInner}; static WORKER_ID: AtomicUsize = const { AtomicUsize::new(0) }; @@ -78,7 +78,7 @@ impl TaskCenterBuilder { .pause_time(true) } - pub fn build(mut self) -> Result { + pub fn build(mut self) -> Result { let options = self.options.unwrap_or_default(); if self.default_runtime_handle.is_none() { let mut default_runtime_builder = tokio_builder("worker", &options); @@ -97,13 +97,13 @@ impl TaskCenterBuilder { if cfg!(any(test, feature = "test-util")) { eprintln!("!!!! Runnning with test-util enabled !!!!"); } - Ok(TaskCenter::new( + Ok(OwnedHandle::new(TaskCenterInner::new( self.default_runtime_handle.unwrap(), self.ingress_runtime_handle.unwrap(), self.default_runtime, self.ingress_runtime, self.pause_time, - )) + ))) } } diff --git a/crates/core/src/task_center/extensions.rs b/crates/core/src/task_center/extensions.rs index 99736838e..fe550e7a1 100644 --- a/crates/core/src/task_center/extensions.rs +++ b/crates/core/src/task_center/extensions.rs @@ -19,24 +19,25 @@ use crate::task_center::TaskContext; use crate::Metadata; use super::{ - GlobalOverrides, TaskCenter, TaskId, TaskKind, CURRENT_TASK_CENTER, OVERRIDES, TASK_CONTEXT, + GlobalOverrides, Handle, TaskCenter, TaskId, TaskKind, CURRENT_TASK_CENTER, OVERRIDES, + TASK_CONTEXT, }; type TaskCenterFuture = - TaskLocalFuture>>; + TaskLocalFuture>>; /// Adds the ability to override task-center for a future and all its children pub trait TaskCenterFutureExt: Sized { /// Ensures that a future will run within a task-center context. This will inherit the current /// task context (if there is one). Otherwise, it'll run in the context of the root task (task-id=0). - fn in_tc(self, task_center: &TaskCenter) -> WithTaskCenter; + fn in_tc(self, task_center: &Handle) -> WithTaskCenter; /// Lets task-center treat this future as a pseudo-task. It gets its own TaskId and an /// independent cancellation token. 
However, task-center will not spawn this as a task nor /// manage its lifecycle. fn in_tc_as_task( self, - task_center: &TaskCenter, + task_center: &Handle, kind: TaskKind, name: &'static str, ) -> WithTaskCenter; @@ -67,7 +68,7 @@ impl TaskCenterFutureExt for F where F: Future, { - fn in_tc(self, task_center: &TaskCenter) -> WithTaskCenter { + fn in_tc(self, task_center: &Handle) -> WithTaskCenter { let ctx = task_center.with_task_context(Clone::clone); let inner = CURRENT_TASK_CENTER.scope( @@ -82,7 +83,7 @@ where fn in_tc_as_task( self, - task_center: &TaskCenter, + task_center: &Handle, kind: TaskKind, name: &'static str, ) -> WithTaskCenter { @@ -106,10 +107,12 @@ where /// Ensures that a future will run within a task-center context. This will inherit the current /// task context (if there is one). Otherwise, it'll run in the context of the root task (task-id=0). + #[track_caller] fn in_current_tc(self) -> WithTaskCenter { TaskCenter::with_current(|tc| self.in_tc(tc)) } + #[track_caller] fn in_current_tc_as_task(self, kind: TaskKind, name: &'static str) -> WithTaskCenter { TaskCenter::with_current(|tc| self.in_tc_as_task(tc, kind, name)) } diff --git a/crates/core/src/task_center/handle.rs b/crates/core/src/task_center/handle.rs new file mode 100644 index 000000000..cfd44fca5 --- /dev/null +++ b/crates/core/src/task_center/handle.rs @@ -0,0 +1,236 @@ +// Copyright (c) 2023 - 2025 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use std::future::Future; +use std::sync::atomic::Ordering; +use std::sync::Arc; + +use restate_types::identifiers::PartitionId; +use tokio_util::sync::CancellationToken; +use tracing::instrument; + +use crate::{Metadata, ShutdownError}; + +use super::{ + RuntimeError, RuntimeRootTaskHandle, TaskCenterInner, TaskContext, TaskHandle, TaskId, TaskKind, +}; + +#[derive(Clone, derive_more::Debug)] +#[debug("TaskCenter({})", inner.id)] +pub struct Handle { + pub(super) inner: Arc, +} + +static_assertions::assert_impl_all!(Handle: Send, Sync, Clone); + +impl Handle { + pub(super) fn new(inner: &Arc) -> Self { + Self { + inner: Arc::clone(inner), + } + } + + pub(crate) fn with_task_context(&self, f: F) -> R + where + F: Fn(&TaskContext) -> R, + { + self.inner.with_task_context(f) + } + + /// Attempt to access task-level overridden metadata first, if we don't have an override, + /// fallback to task-center's level metadata. + pub(crate) fn with_metadata(&self, f: F) -> Option + where + F: FnOnce(&Metadata) -> R, + { + self.inner.with_metadata(f) + } + + /// Attempt to set the global metadata handle. This should be called once + /// at the startup of the node. + pub fn try_set_global_metadata(&self, metadata: Metadata) -> bool { + self.inner.try_set_global_metadata(metadata) + } + + /// Sets the current task_center but doesn't create a task. Use this when you need to run a + /// closure within task_center scope. + pub fn run_sync(&self, f: F) -> O + where + F: FnOnce() -> O, + { + self.inner.run_sync(f) + } + + /// Sets the current task_center but doesn't create a task. Use this when you need to run a + /// future within task_center scope. 
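// Example (sketch): `in_tc` attaches task-center context to an arbitrary future, which
// keeps `TaskCenter::current()` and metadata lookups functional when handing work to
// plain tokio; `in_current_tc()` does the same using the ambient task center. The plain
// `tokio::spawn` below is only there to show the wrapped future is self-contained.
fn in_tc_sketch(tc: &restate_core::task_center::Handle) -> tokio::task::JoinHandle<()> {
    use restate_core::TaskCenterFutureExt as _;

    tokio::spawn(
        async {
            // Runs with the caller's task context (or the root context if there was none).
            let _current = restate_core::TaskCenter::current();
        }
        .in_tc(tc),
    )
}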
+ pub fn block_on(&self, future: F) -> O + where + F: Future, + { + self.inner.block_on(future) + } + + pub fn start_runtime( + &self, + root_task_kind: TaskKind, + runtime_name: &'static str, + partition_id: Option, + root_future: impl FnOnce() -> F + Send + 'static, + ) -> Result>, RuntimeError> + where + F: Future> + 'static, + { + self.inner + .start_runtime(root_task_kind, runtime_name, partition_id, root_future) + } + + /// Launch a new task + pub fn spawn( + &self, + kind: TaskKind, + name: &'static str, + future: F, + ) -> Result + where + F: Future> + Send + 'static, + { + self.inner.spawn(kind, name, future) + } + + /// Spawn a new task that is a child of the current task. The child task will be cancelled if the parent + /// task is cancelled. At the moment, the parent task will not automatically wait for children tasks to + /// finish before completion, but this might change in the future if the need for that arises. + pub fn spawn_child( + &self, + kind: TaskKind, + name: &'static str, + future: F, + ) -> Result + where + F: Future> + Send + 'static, + { + self.inner.spawn_child(kind, name, future) + } + + pub fn spawn_unmanaged( + &self, + kind: TaskKind, + name: &'static str, + future: F, + ) -> Result, ShutdownError> + where + F: Future + Send + 'static, + T: Send + 'static, + { + self.inner.spawn_unmanaged(kind, name, future) + } + + /// Must be called within a Localset-scoped task, not from a normal spawned task. + /// If ran from a non-localset task, this will panic. + pub fn spawn_local( + &self, + kind: TaskKind, + name: &'static str, + future: F, + ) -> Result + where + F: Future> + 'static, + { + self.inner.spawn_local(kind, name, future) + } + + pub fn metadata(&self) -> Option { + self.inner.metadata() + } + + /// Spawn a potentially thread-blocking future on a dedicated thread pool + pub fn spawn_blocking_unmanaged( + &self, + name: &'static str, + future: F, + ) -> tokio::task::JoinHandle + where + F: Future + Send + 'static, + O: Send + 'static, + { + self.inner.spawn_blocking_unmanaged(name, future) + } + + /// Take control over the running task from task-center. This returns None if the task was not + /// found, completed, or has been cancelled. + pub fn take_task(&self, task_id: TaskId) -> Option> { + self.inner.take_task(task_id) + } + + /// Request cancellation of a task. This returns the join handle if the task was found and was + /// not already cancelled or completed. The returned task will not be awaited by task-center on + /// shutdown, and it's the responsibility of the caller to join or abort. + pub fn cancel_task(&self, task_id: TaskId) -> Option> { + self.inner.cancel_task(task_id) + } + + /// Signal and wait for tasks to stop. + /// + /// + /// You can select which tasks to cancel. Any None arguments are ignored. + /// For example, to shut down all MetadataBackgroundSync tasks: + /// + /// cancel_tasks(Some(TaskKind::MetadataBackgroundSync), None) + /// + /// Or to shut down all tasks for a particular partition ID: + /// + /// cancel_tasks(None, Some(partition_id)) + /// + pub async fn cancel_tasks(&self, kind: Option, partition_id: Option) { + self.inner.cancel_tasks(kind, partition_id).await + } + + pub fn shutdown_managed_runtimes(&self) { + self.inner.shutdown_managed_runtimes() + } + + /// Triggers a shutdown of the system. All running tasks will be asked gracefully + /// to cancel but we will only wait for tasks with a TaskKind that has the property + /// "OnCancel" set to "wait". 
+ #[instrument(level = "error", skip(self, exit_code))] + pub async fn shutdown_node(&self, reason: &str, exit_code: i32) { + self.inner.shutdown_node(reason, exit_code).await; + } + + /// Use to monitor an on-going shutdown when requested + pub fn shutdown_token(&self) -> CancellationToken { + self.inner.global_cancel_token.clone() + } +} + +// Shutsdown +pub struct OwnedHandle { + inner: Arc, +} + +impl OwnedHandle { + pub(super) fn new(inner: TaskCenterInner) -> Self { + Self { + inner: Arc::new(inner), + } + } + + pub fn handle(&self) -> Handle { + Handle::new(&self.inner) + } + + pub fn to_handle(self) -> Handle { + Handle { inner: self.inner } + } + /// The exit code that the process should exit with. + pub fn exit_code(&self) -> i32 { + self.inner.current_exit_code.load(Ordering::Relaxed) + } +} diff --git a/crates/core/src/task_center/monitoring.rs b/crates/core/src/task_center/monitoring.rs new file mode 100644 index 000000000..be36b4551 --- /dev/null +++ b/crates/core/src/task_center/monitoring.rs @@ -0,0 +1,94 @@ +// Copyright (c) 2023 - 2025 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use std::time::Duration; + +use metrics::gauge; +use tokio::runtime::RuntimeMetrics; + +use super::Handle; + +pub trait TaskCenterMonitoring { + fn default_runtime_metrics(&self) -> RuntimeMetrics; + + fn ingress_runtime_metrics(&self) -> RuntimeMetrics; + + fn managed_runtime_metrics(&self) -> Vec<(&'static str, RuntimeMetrics)>; + + /// How long has the task-center been running? + fn age(&self) -> Duration; + + /// Submit telemetry for all runtimes to metrics recorder + fn submit_metrics(&self); +} + +impl TaskCenterMonitoring for Handle { + fn default_runtime_metrics(&self) -> RuntimeMetrics { + self.inner.default_runtime_handle.metrics() + } + + fn ingress_runtime_metrics(&self) -> RuntimeMetrics { + self.inner.ingress_runtime_handle.metrics() + } + + fn managed_runtime_metrics(&self) -> Vec<(&'static str, RuntimeMetrics)> { + let guard = self.inner.managed_runtimes.lock(); + guard.iter().map(|(k, v)| (*k, v.metrics())).collect() + } + + /// How long has the task-center been running? 
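// Example (sketch): the process owner (e.g. the binary's main) keeps the `OwnedHandle`,
// drives the node through `block_on`, and reads the exit code once shutdown completes.
// The `OwnedHandle` re-export path below is an assumption.
fn node_lifecycle_sketch(owned: restate_core::task_center::OwnedHandle) -> i32 {
    let tc = owned.handle();
    tc.block_on(async {
        // ... boot roles here, then park until some task calls
        // `TaskCenter::shutdown_node("signal received", 0)` ...
        let shutdown = tc.shutdown_token();
        shutdown.cancelled().await;
    });
    owned.exit_code()
}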
+ fn age(&self) -> Duration { + self.inner.start_time.elapsed() + } + + /// Submit telemetry for all runtimes to metrics recorder + fn submit_metrics(&self) { + submit_runtime_metrics("default", self.default_runtime_metrics()); + submit_runtime_metrics("ingress", self.ingress_runtime_metrics()); + + // Partition processor runtimes + let processor_runtimes = self.managed_runtime_metrics(); + for (task_name, metrics) in processor_runtimes { + submit_runtime_metrics(task_name, metrics); + } + } +} + +fn submit_runtime_metrics(runtime: &'static str, stats: RuntimeMetrics) { + gauge!("restate.tokio.num_workers", "runtime" => runtime).set(stats.num_workers() as f64); + gauge!("restate.tokio.blocking_threads", "runtime" => runtime) + .set(stats.num_blocking_threads() as f64); + gauge!("restate.tokio.blocking_queue_depth", "runtime" => runtime) + .set(stats.blocking_queue_depth() as f64); + gauge!("restate.tokio.num_alive_tasks", "runtime" => runtime) + .set(stats.num_alive_tasks() as f64); + gauge!("restate.tokio.io_driver_ready_count", "runtime" => runtime) + .set(stats.io_driver_ready_count() as f64); + gauge!("restate.tokio.remote_schedule_count", "runtime" => runtime) + .set(stats.remote_schedule_count() as f64); + // per worker stats + for idx in 0..stats.num_workers() { + gauge!("restate.tokio.worker_overflow_count", "runtime" => runtime, "worker" => + idx.to_string()) + .set(stats.worker_overflow_count(idx) as f64); + gauge!("restate.tokio.worker_poll_count", "runtime" => runtime, "worker" => idx.to_string()) + .set(stats.worker_poll_count(idx) as f64); + gauge!("restate.tokio.worker_park_count", "runtime" => runtime, "worker" => idx.to_string()) + .set(stats.worker_park_count(idx) as f64); + gauge!("restate.tokio.worker_noop_count", "runtime" => runtime, "worker" => idx.to_string()) + .set(stats.worker_noop_count(idx) as f64); + gauge!("restate.tokio.worker_steal_count", "runtime" => runtime, "worker" => idx.to_string()) + .set(stats.worker_steal_count(idx) as f64); + gauge!("restate.tokio.worker_total_busy_duration_seconds", "runtime" => runtime, "worker" => idx.to_string()) + .set(stats.worker_total_busy_duration(idx).as_secs_f64()); + gauge!("restate.tokio.worker_mean_poll_time", "runtime" => runtime, "worker" => idx.to_string()) + .set(stats.worker_mean_poll_time(idx).as_secs_f64()); + } +} diff --git a/crates/core/src/task_center/task.rs b/crates/core/src/task_center/task.rs index beb1b08f8..12408ed2f 100644 --- a/crates/core/src/task_center/task.rs +++ b/crates/core/src/task_center/task.rs @@ -17,11 +17,11 @@ use tokio_util::sync::CancellationToken; use restate_types::identifiers::PartitionId; -use super::{TaskId, TaskKind}; +use super::{TaskId, TaskKind, TASK_CONTEXT}; use crate::ShutdownError; #[derive(Clone)] -pub(super) struct TaskContext { +pub struct TaskContext { /// It's nice to have a unique ID for each task. 
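// Example (sketch): runtime telemetry moved off `TaskCenter` into this trait on the
// handle; a node would call it periodically from a telemetry task. The module path of
// `TaskCenterMonitoring` is an assumption.
fn report_runtime_metrics(tc: &restate_core::task_center::Handle) {
    use restate_core::task_center::monitoring::TaskCenterMonitoring as _;

    // Pushes gauges for the default, ingress, and managed (partition processor) runtimes.
    tc.submit_metrics();
    tracing::debug!("task-center has been up for {:?}", tc.age());
}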
pub(super) id: TaskId, pub(super) name: &'static str, @@ -33,6 +33,59 @@ pub(super) struct TaskContext { pub(super) partition_id: Option, } +impl TaskContext { + /// Access to current task-center task context + #[track_caller] + pub fn current() -> Self { + Self::with_current(Clone::clone) + } + #[track_caller] + pub fn with_current(f: F) -> R + where + F: FnOnce(&TaskContext) -> R, + { + TASK_CONTEXT + .try_with(|ctx| f(ctx)) + .expect("called outside task-center task") + } + + pub fn try_with_current(f: F) -> Option + where + F: FnOnce(&Self) -> R, + { + TASK_CONTEXT.try_with(|tc| f(tc)).ok() + } + + /// Access to current task-center task context + pub fn try_current() -> Option { + Self::try_with_current(Clone::clone) + } + + pub fn id(&self) -> TaskId { + self.id + } + + pub fn name(&self) -> &'static str { + self.name + } + + pub fn kind(&self) -> TaskKind { + self.kind + } + + pub fn partition_id(&self) -> Option { + self.partition_id + } + + pub fn cancellation_token(&self) -> &CancellationToken { + &self.cancellation_token + } + + pub fn cancel(&self) { + self.cancellation_token.cancel() + } +} + pub(super) struct Task { pub(super) context: TaskContext, pub(super) handle: Mutex>>, diff --git a/crates/core/src/test_env.rs b/crates/core/src/test_env.rs index 1d7143aab..d95f2ee5e 100644 --- a/crates/core/src/test_env.rs +++ b/crates/core/src/test_env.rs @@ -33,12 +33,11 @@ use crate::network::{ ConnectionManager, FailingConnector, Incoming, MessageHandler, MessageRouterBuilder, NetworkError, Networking, ProtocolError, TransportConnect, }; -use crate::{spawn_metadata_manager, MetadataBuilder, TaskCenterFutureExt, TaskId}; +use crate::TaskCenter; +use crate::{spawn_metadata_manager, MetadataBuilder, TaskId}; use crate::{Metadata, MetadataManager, MetadataWriter}; -use crate::{TaskCenter, TaskCenterBuilder}; pub struct TestCoreEnvBuilder { - pub tc: TaskCenter, pub my_node_id: GenerationalNodeId, pub metadata_manager: MetadataManager, pub metadata_writer: MetadataWriter, @@ -54,11 +53,6 @@ pub struct TestCoreEnvBuilder { impl TestCoreEnvBuilder { pub fn with_incoming_only_connector() -> Self { - let tc = TaskCenterBuilder::default() - .default_runtime_handle(tokio::runtime::Handle::current()) - .ingress_runtime_handle(tokio::runtime::Handle::current()) - .build() - .expect("task_center builds"); let metadata_builder = MetadataBuilder::default(); let net_opts = NetworkingOptions::default(); let connection_manager = @@ -69,11 +63,11 @@ impl TestCoreEnvBuilder { connection_manager, ); - TestCoreEnvBuilder::with_networking(tc, networking, metadata_builder) + TestCoreEnvBuilder::with_networking(networking, metadata_builder) } } impl TestCoreEnvBuilder { - pub fn with_transport_connector(tc: TaskCenter, connector: Arc) -> TestCoreEnvBuilder { + pub fn with_transport_connector(connector: Arc) -> TestCoreEnvBuilder { let metadata_builder = MetadataBuilder::default(); let net_opts = NetworkingOptions::default(); let connection_manager = @@ -84,14 +78,10 @@ impl TestCoreEnvBuilder { connection_manager, ); - TestCoreEnvBuilder::with_networking(tc, networking, metadata_builder) + TestCoreEnvBuilder::with_networking(networking, metadata_builder) } - pub fn with_networking( - tc: TaskCenter, - networking: Networking, - metadata_builder: MetadataBuilder, - ) -> Self { + pub fn with_networking(networking: Networking, metadata_builder: MetadataBuilder) -> Self { let my_node_id = GenerationalNodeId::new(1, 1); let metadata_store_client = MetadataStoreClient::new_in_memory(); let metadata = 
metadata_builder.to_metadata(); @@ -103,7 +93,7 @@ impl TestCoreEnvBuilder { let partition_table = PartitionTable::with_equally_sized_partitions(Version::MIN, 10); let scheduling_plan = SchedulingPlan::from(&partition_table, ReplicationStrategy::OnAllNodes); - tc.try_set_global_metadata_inner(metadata.clone()); + TaskCenter::try_set_global_metadata(metadata.clone()); // Use memory-loglet as a default if in test-mode #[cfg(any(test, feature = "test-util"))] @@ -112,7 +102,6 @@ impl TestCoreEnvBuilder { let provider_kind = ProviderKind::Local; TestCoreEnvBuilder { - tc, my_node_id, metadata_manager, metadata_writer, @@ -167,86 +156,79 @@ impl TestCoreEnvBuilder { } pub async fn build(mut self) -> TestCoreEnv { - let tc = self.tc; - async { - self.metadata_manager - .register_in_message_router(&mut self.router_builder); - self.networking - .connection_manager() - .set_message_router(self.router_builder.build()); - - let metadata_manager_task = spawn_metadata_manager(self.metadata_manager) - .expect("metadata manager should start"); - - self.metadata_store_client - .put( - NODES_CONFIG_KEY.clone(), - &self.nodes_config, - Precondition::None, - ) - .await - .expect("to store nodes config in metadata store"); - self.metadata_writer - .submit(Arc::new(self.nodes_config.clone())); - - let logs = bootstrap_logs_metadata( - self.provider_kind, - None, - self.partition_table.num_partitions(), - ); - self.metadata_store_client - .put(BIFROST_CONFIG_KEY.clone(), &logs, Precondition::None) - .await - .expect("to store bifrost config in metadata store"); - self.metadata_writer.submit(Arc::new(logs)); - - self.metadata_store_client - .put( - PARTITION_TABLE_KEY.clone(), - &self.partition_table, - Precondition::None, - ) - .await - .expect("to store partition table in metadata store"); - self.metadata_writer.submit(Arc::new(self.partition_table)); - - self.metadata_store_client - .put( - SCHEDULING_PLAN_KEY.clone(), - &self.scheduling_plan, - Precondition::None, - ) - .await - .expect("to store scheduling plan in metadata store"); - - let _ = self - .metadata - .wait_for_version( - MetadataKind::NodesConfiguration, - self.nodes_config.version(), - ) - .await - .unwrap(); - - self.metadata_writer.set_my_node_id(self.my_node_id); - - TestCoreEnv { - tc: TaskCenter::current(), - metadata: self.metadata, - metadata_manager_task, - metadata_writer: self.metadata_writer, - networking: self.networking, - metadata_store_client: self.metadata_store_client, - } + self.metadata_manager + .register_in_message_router(&mut self.router_builder); + self.networking + .connection_manager() + .set_message_router(self.router_builder.build()); + + let metadata_manager_task = + spawn_metadata_manager(self.metadata_manager).expect("metadata manager should start"); + + self.metadata_store_client + .put( + NODES_CONFIG_KEY.clone(), + &self.nodes_config, + Precondition::None, + ) + .await + .expect("to store nodes config in metadata store"); + self.metadata_writer + .submit(Arc::new(self.nodes_config.clone())); + + let logs = bootstrap_logs_metadata( + self.provider_kind, + None, + self.partition_table.num_partitions(), + ); + self.metadata_store_client + .put(BIFROST_CONFIG_KEY.clone(), &logs, Precondition::None) + .await + .expect("to store bifrost config in metadata store"); + self.metadata_writer.submit(Arc::new(logs)); + + self.metadata_store_client + .put( + PARTITION_TABLE_KEY.clone(), + &self.partition_table, + Precondition::None, + ) + .await + .expect("to store partition table in metadata store"); + 
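// Example (sketch): with the `tc` field gone, tests build the environment against the
// ambient task center instead of constructing their own. The `restate_core` re-export
// of `TestCoreEnvBuilder` (presumably behind the test-util feature) is assumed here.
async fn test_env_sketch() {
    let env = restate_core::TestCoreEnvBuilder::with_incoming_only_connector()
        .build()
        .await;
    // The environment still exposes metadata, metadata_writer, networking and the
    // in-memory metadata store client.
    let _nodes_config = env.metadata.nodes_config_ref();
}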
self.metadata_writer.submit(Arc::new(self.partition_table)); + + self.metadata_store_client + .put( + SCHEDULING_PLAN_KEY.clone(), + &self.scheduling_plan, + Precondition::None, + ) + .await + .expect("to store scheduling plan in metadata store"); + + let _ = self + .metadata + .wait_for_version( + MetadataKind::NodesConfiguration, + self.nodes_config.version(), + ) + .await + .unwrap(); + + self.metadata_writer.set_my_node_id(self.my_node_id); + + TestCoreEnv { + metadata: self.metadata, + metadata_manager_task, + metadata_writer: self.metadata_writer, + networking: self.networking, + metadata_store_client: self.metadata_store_client, } - .in_tc(&tc) - .await } } // This might need to be moved to a better place in the future. pub struct TestCoreEnv { - pub tc: TaskCenter, pub metadata: Metadata, pub metadata_writer: MetadataWriter, pub networking: Networking, diff --git a/crates/core/src/test_env2.rs b/crates/core/src/test_env2.rs deleted file mode 100644 index 17f966446..000000000 --- a/crates/core/src/test_env2.rs +++ /dev/null @@ -1,303 +0,0 @@ -// Copyright (c) 2023 - 2025 Restate Software, Inc., Restate GmbH. -// All rights reserved. -// -// Use of this software is governed by the Business Source License -// included in the LICENSE file. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0. - -use std::marker::PhantomData; -use std::str::FromStr; -use std::sync::Arc; - -use futures::Stream; - -use restate_types::cluster_controller::{ReplicationStrategy, SchedulingPlan}; -use restate_types::config::NetworkingOptions; -use restate_types::logs::metadata::{bootstrap_logs_metadata, ProviderKind}; -use restate_types::metadata_store::keys::{ - BIFROST_CONFIG_KEY, NODES_CONFIG_KEY, PARTITION_TABLE_KEY, SCHEDULING_PLAN_KEY, -}; -use restate_types::net::codec::{Targeted, WireDecode}; -use restate_types::net::metadata::MetadataKind; -use restate_types::net::AdvertisedAddress; -use restate_types::nodes_config::{LogServerConfig, NodeConfig, NodesConfiguration, Role}; -use restate_types::partition_table::PartitionTable; -use restate_types::protobuf::node::Message; -use restate_types::{GenerationalNodeId, Version}; - -use crate::metadata_store::{MetadataStoreClient, Precondition}; -use crate::network::{ - ConnectionManager, FailingConnector, Incoming, MessageHandler, MessageRouterBuilder, - NetworkError, Networking, ProtocolError, TransportConnect, -}; -use crate::TaskCenter; -use crate::{spawn_metadata_manager, MetadataBuilder, TaskId}; -use crate::{Metadata, MetadataManager, MetadataWriter}; - -pub struct TestCoreEnvBuilder2 { - pub my_node_id: GenerationalNodeId, - pub metadata_manager: MetadataManager, - pub metadata_writer: MetadataWriter, - pub metadata: Metadata, - pub networking: Networking, - pub nodes_config: NodesConfiguration, - pub provider_kind: ProviderKind, - pub router_builder: MessageRouterBuilder, - pub partition_table: PartitionTable, - pub scheduling_plan: SchedulingPlan, - pub metadata_store_client: MetadataStoreClient, -} - -impl TestCoreEnvBuilder2 { - pub fn with_incoming_only_connector() -> Self { - let metadata_builder = MetadataBuilder::default(); - let net_opts = NetworkingOptions::default(); - let connection_manager = - ConnectionManager::new_incoming_only(metadata_builder.to_metadata()); - let networking = Networking::with_connection_manager( - metadata_builder.to_metadata(), - net_opts, - connection_manager, - ); - - 
TestCoreEnvBuilder2::with_networking(networking, metadata_builder) - } -} -impl TestCoreEnvBuilder2 { - pub fn with_transport_connector(connector: Arc) -> TestCoreEnvBuilder2 { - let metadata_builder = MetadataBuilder::default(); - let net_opts = NetworkingOptions::default(); - let connection_manager = - ConnectionManager::new(metadata_builder.to_metadata(), connector, net_opts.clone()); - let networking = Networking::with_connection_manager( - metadata_builder.to_metadata(), - net_opts, - connection_manager, - ); - - TestCoreEnvBuilder2::with_networking(networking, metadata_builder) - } - - pub fn with_networking(networking: Networking, metadata_builder: MetadataBuilder) -> Self { - let my_node_id = GenerationalNodeId::new(1, 1); - let metadata_store_client = MetadataStoreClient::new_in_memory(); - let metadata = metadata_builder.to_metadata(); - let metadata_manager = - MetadataManager::new(metadata_builder, metadata_store_client.clone()); - let metadata_writer = metadata_manager.writer(); - let router_builder = MessageRouterBuilder::default(); - let nodes_config = NodesConfiguration::new(Version::MIN, "test-cluster".to_owned()); - let partition_table = PartitionTable::with_equally_sized_partitions(Version::MIN, 10); - let scheduling_plan = - SchedulingPlan::from(&partition_table, ReplicationStrategy::OnAllNodes); - TaskCenter::try_set_global_metadata(metadata.clone()); - - // Use memory-loglet as a default if in test-mode - #[cfg(any(test, feature = "test-util"))] - let provider_kind = ProviderKind::InMemory; - #[cfg(not(any(test, feature = "test-util")))] - let provider_kind = ProviderKind::Local; - - TestCoreEnvBuilder2 { - my_node_id, - metadata_manager, - metadata_writer, - metadata, - networking, - nodes_config, - router_builder, - partition_table, - scheduling_plan, - metadata_store_client, - provider_kind, - } - } - - pub fn set_nodes_config(mut self, nodes_config: NodesConfiguration) -> Self { - self.nodes_config = nodes_config; - self - } - - pub fn set_partition_table(mut self, partition_table: PartitionTable) -> Self { - self.partition_table = partition_table; - self - } - - pub fn set_scheduling_plan(mut self, scheduling_plan: SchedulingPlan) -> Self { - self.scheduling_plan = scheduling_plan; - self - } - - pub fn set_my_node_id(mut self, my_node_id: GenerationalNodeId) -> Self { - self.my_node_id = my_node_id; - self - } - - pub fn set_provider_kind(mut self, provider_kind: ProviderKind) -> Self { - self.provider_kind = provider_kind; - self - } - - pub fn add_mock_nodes_config(mut self) -> Self { - self.nodes_config = - create_mock_nodes_config(self.my_node_id.raw_id(), self.my_node_id.raw_generation()); - self - } - - pub fn add_message_handler(mut self, handler: H) -> Self - where - H: MessageHandler + Send + Sync + 'static, - { - self.router_builder.add_message_handler(handler); - self - } - - pub async fn build(mut self) -> TestCoreEnv2 { - self.metadata_manager - .register_in_message_router(&mut self.router_builder); - self.networking - .connection_manager() - .set_message_router(self.router_builder.build()); - - let metadata_manager_task = - spawn_metadata_manager(self.metadata_manager).expect("metadata manager should start"); - - self.metadata_store_client - .put( - NODES_CONFIG_KEY.clone(), - &self.nodes_config, - Precondition::None, - ) - .await - .expect("to store nodes config in metadata store"); - self.metadata_writer - .submit(Arc::new(self.nodes_config.clone())); - - let logs = bootstrap_logs_metadata( - self.provider_kind, - None, - 
self.partition_table.num_partitions(), - ); - self.metadata_store_client - .put(BIFROST_CONFIG_KEY.clone(), &logs, Precondition::None) - .await - .expect("to store bifrost config in metadata store"); - self.metadata_writer.submit(Arc::new(logs)); - - self.metadata_store_client - .put( - PARTITION_TABLE_KEY.clone(), - &self.partition_table, - Precondition::None, - ) - .await - .expect("to store partition table in metadata store"); - self.metadata_writer.submit(Arc::new(self.partition_table)); - - self.metadata_store_client - .put( - SCHEDULING_PLAN_KEY.clone(), - &self.scheduling_plan, - Precondition::None, - ) - .await - .expect("to store scheduling plan in metadata store"); - - let _ = self - .metadata - .wait_for_version( - MetadataKind::NodesConfiguration, - self.nodes_config.version(), - ) - .await - .unwrap(); - - self.metadata_writer.set_my_node_id(self.my_node_id); - - TestCoreEnv2 { - metadata: self.metadata, - metadata_manager_task, - metadata_writer: self.metadata_writer, - networking: self.networking, - metadata_store_client: self.metadata_store_client, - } - } -} - -// This might need to be moved to a better place in the future. -pub struct TestCoreEnv2 { - pub metadata: Metadata, - pub metadata_writer: MetadataWriter, - pub networking: Networking, - pub metadata_manager_task: TaskId, - pub metadata_store_client: MetadataStoreClient, -} - -impl TestCoreEnv2 { - pub async fn create_with_single_node(node_id: u32, generation: u32) -> Self { - TestCoreEnvBuilder2::with_incoming_only_connector() - .set_my_node_id(GenerationalNodeId::new(node_id, generation)) - .add_mock_nodes_config() - .build() - .await - } -} - -impl TestCoreEnv2 { - pub async fn accept_incoming_connection( - &self, - incoming: S, - ) -> Result + Unpin + Send + 'static, NetworkError> - where - S: Stream> + Unpin + Send + 'static, - { - self.networking - .connection_manager() - .accept_incoming_connection(incoming) - .await - } -} - -pub fn create_mock_nodes_config(node_id: u32, generation: u32) -> NodesConfiguration { - let mut nodes_config = NodesConfiguration::new(Version::MIN, "test-cluster".to_owned()); - let address = AdvertisedAddress::from_str("http://127.0.0.1:5122/").unwrap(); - let node_id = GenerationalNodeId::new(node_id, generation); - let roles = Role::Admin | Role::Worker; - let my_node = NodeConfig::new( - format!("MyNode-{}", node_id), - node_id, - address, - roles, - LogServerConfig::default(), - ); - nodes_config.upsert_node(my_node); - nodes_config -} - -/// No-op message handler which simply drops the received messages. Useful if you don't want to -/// react to network messages. 
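
With test_env2.rs deleted and TestCoreEnvBuilder no longer carrying a TaskCenter, tests rely on the ambient task center provided by the #[restate_core::test] macro. Below is a minimal sketch of what a test against the refactored API might look like; the test name and body are illustrative only, and it assumes restate_core's test-util surface plus anyhow (assumed here) for the error type, none of which is part of this patch.

    #[restate_core::test]
    async fn uses_ambient_task_center() -> anyhow::Result<()> {
        use restate_core::{TaskCenter, TaskKind, TestCoreEnv};

        // Sets up metadata, networking and the in-memory metadata store as before;
        // note there is no `tc` field to thread around anymore.
        let _env = TestCoreEnv::create_with_single_node(1, 1).await;

        // Spawning goes through the associated function; the Option<PartitionId>
        // argument of the old instance method is gone.
        let _task_id = TaskCenter::spawn(TaskKind::Disposable, "unit-under-test", async move {
            // ... exercise the system under test here ...
            Ok(())
        })?;

        Ok(())
    }
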
-pub struct NoOpMessageHandler { - phantom_data: PhantomData, -} - -impl Default for NoOpMessageHandler { - fn default() -> Self { - NoOpMessageHandler { - phantom_data: PhantomData, - } - } -} - -impl MessageHandler for NoOpMessageHandler -where - M: WireDecode + Targeted + Send + Sync, -{ - type MessageType = M; - - async fn on_message(&self, _msg: Incoming) { - // no-op - } -} diff --git a/crates/ingress-http/src/handler/tests.rs b/crates/ingress-http/src/handler/tests.rs index 0f28cb746..f64e3acb2 100644 --- a/crates/ingress-http/src/handler/tests.rs +++ b/crates/ingress-http/src/handler/tests.rs @@ -48,7 +48,7 @@ use super::Handler; use crate::handler::responses::X_RESTATE_ID; use crate::MockRequestDispatcher; -#[tokio::test] +#[restate_core::test] #[traced_test] async fn call_service() { let greeting_req = GreetingRequest { @@ -105,7 +105,7 @@ async fn call_service() { assert_eq!(response_value.greeting, "Igal"); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn call_service_with_get() { let req = hyper::Request::builder() @@ -165,7 +165,7 @@ async fn call_service_with_get() { assert_eq!(response_value.greeting, "Igal"); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn call_virtual_object() { let greeting_req = GreetingRequest { @@ -220,7 +220,7 @@ async fn call_virtual_object() { assert_eq!(response_value.greeting, "Igal"); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn send_service() { let greeting_req = GreetingRequest { @@ -265,7 +265,7 @@ async fn send_service() { let _: SendResponse = serde_json::from_slice(&response_bytes).unwrap(); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn send_with_delay_service() { let greeting_req = GreetingRequest { @@ -312,7 +312,7 @@ async fn send_with_delay_service() { let _: SendResponse = serde_json::from_slice(&response_bytes).unwrap(); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn send_virtual_object() { let greeting_req = GreetingRequest { @@ -358,7 +358,7 @@ async fn send_virtual_object() { let _: SendResponse = serde_json::from_slice(&response_bytes).unwrap(); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn idempotency_key_parsing() { let greeting_req = GreetingRequest { @@ -423,7 +423,7 @@ async fn idempotency_key_parsing() { assert_eq!(response_value.greeting, "Igal"); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn idempotency_key_and_send() { let greeting_req = GreetingRequest { @@ -478,7 +478,7 @@ async fn idempotency_key_and_send() { let _: SendResponse = serde_json::from_slice(&response_bytes).unwrap(); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn idempotency_key_and_send_with_different_invocation_id() { let greeting_req = GreetingRequest { @@ -539,7 +539,7 @@ async fn idempotency_key_and_send_with_different_invocation_id() { assert_eq!(send_response.invocation_id, expected_invocation_id); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn attach_with_invocation_id() { let invocation_id = InvocationId::mock_random(); @@ -594,7 +594,7 @@ async fn attach_with_invocation_id() { assert_eq!(response_value.greeting, "Igal"); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn attach_with_idempotency_id_to_unkeyed_service() { let mock_schemas = MockSchemas::default().with_service_and_target( @@ -650,7 +650,7 @@ async fn attach_with_idempotency_id_to_unkeyed_service() { assert_eq!(response_value.greeting, "Igal"); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn 
attach_with_idempotency_id_to_keyed_service() { let mock_schemas = MockSchemas::default().with_service_and_target( @@ -713,7 +713,7 @@ async fn attach_with_idempotency_id_to_keyed_service() { assert_eq!(response_value.greeting, "Igal"); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn get_output_with_invocation_id() { let invocation_id = InvocationId::mock_random(); @@ -768,7 +768,7 @@ async fn get_output_with_invocation_id() { assert_eq!(response_value.greeting, "Igal"); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn get_output_with_workflow_key() { let service_id = ServiceId::new("MyWorkflow", "my-key"); @@ -830,7 +830,7 @@ async fn get_output_with_workflow_key() { assert_eq!(response_value.greeting, "Igal"); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn bad_path_service() { let response = handle( @@ -861,7 +861,7 @@ async fn bad_path_service() { assert_eq!(response.status(), StatusCode::BAD_REQUEST); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn bad_path_virtual_object() { let response = handle( @@ -892,7 +892,7 @@ async fn bad_path_virtual_object() { assert_eq!(response.status(), StatusCode::BAD_REQUEST); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn unknown_service() { let response = handle( @@ -907,7 +907,7 @@ async fn unknown_service() { assert_eq!(response.status(), StatusCode::NOT_FOUND); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn unknown_handler() { let response = handle( @@ -920,7 +920,7 @@ async fn unknown_handler() { assert_eq!(response.status(), StatusCode::NOT_FOUND); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn private_service() { let response = handle_with_schemas_and_dispatcher( @@ -941,7 +941,7 @@ async fn private_service() { assert_eq!(response.status(), StatusCode::BAD_REQUEST); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn invalid_input() { let response = handle_with_schemas_and_dispatcher( @@ -970,7 +970,7 @@ async fn invalid_input() { assert_eq!(response.status(), StatusCode::BAD_REQUEST); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn set_custom_content_type_on_response() { let mock_schemas = MockSchemas::default().with_service_and_target( @@ -1018,7 +1018,7 @@ async fn set_custom_content_type_on_response() { ); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn set_custom_content_type_on_empty_response() { let mock_schemas = MockSchemas::default().with_service_and_target( @@ -1069,7 +1069,7 @@ async fn set_custom_content_type_on_empty_response() { ); } -#[tokio::test] +#[restate_core::test] #[traced_test] async fn health() { let req = hyper::Request::builder() @@ -1135,17 +1135,13 @@ where ::Error: std::error::Error + Send + Sync + 'static, ::Data: Send + Sync + 'static, { - let node_env = TestCoreEnv::create_with_single_node(1, 1).await; + let _env = TestCoreEnv::create_with_single_node(1, 1).await; req.extensions_mut() .insert(ConnectInfo::new("0.0.0.0:0".parse().unwrap())); req.extensions_mut().insert(opentelemetry::Context::new()); - let handler_fut = node_env.tc.run_in_scope( - "ingress", - None, - Handler::new(Live::from_value(schemas), Arc::new(dispatcher)).oneshot(req), - ); + let handler_fut = Handler::new(Live::from_value(schemas), Arc::new(dispatcher)).oneshot(req); handler_fut.await.unwrap() } diff --git a/crates/ingress-http/src/server.rs b/crates/ingress-http/src/server.rs index 733a759cb..ccc0d8a79 100644 --- a/crates/ingress-http/src/server.rs +++ 
b/crates/ingress-http/src/server.rs @@ -17,7 +17,7 @@ use http_body_util::Full; use hyper::body::Incoming; use hyper_util::rt::TokioIo; use hyper_util::server::conn::auto; -use restate_core::{cancellation_watcher, task_center, TaskKind}; +use restate_core::{cancellation_watcher, TaskCenter, TaskKind}; use restate_types::config::IngressOptions; use restate_types::health::HealthStatus; use restate_types::live::Live; @@ -202,7 +202,7 @@ where )); // Spawn a tokio task to serve the connection - task_center().spawn(TaskKind::Ingress, "ingress", None, async move { + TaskCenter::spawn(TaskKind::Ingress, "ingress", async move { let shutdown = cancellation_watcher(); let auto_connection = auto::Builder::new(TaskCenterExecutor); let serve_connection_fut = auto_connection.serve_connection(io, handler); @@ -231,7 +231,7 @@ where Fut::Output: Send + 'static, { fn execute(&self, fut: Fut) { - let _ = task_center().spawn(TaskKind::Ingress, "ingress", None, async { + let _ = TaskCenter::spawn(TaskKind::Ingress, "ingress", async { fut.await; Ok(()) }); @@ -247,7 +247,8 @@ mod tests { use http_body_util::Full; use hyper_util::client::legacy::Client; use hyper_util::rt::TokioExecutor; - use restate_core::{TaskCenter, TaskKind, TestCoreEnv}; + use restate_core::TestCoreEnv; + use restate_core::{TaskCenter, TaskKind}; use restate_test_util::assert_eq; use restate_types::health::Health; use restate_types::identifiers::WithInvocationId; @@ -270,7 +271,7 @@ mod tests { pub greeting: String, } - #[tokio::test] + #[restate_core::test] #[traced_test] async fn test_http_post() { let mut mock_dispatcher = MockRequestDispatcher::default(); @@ -303,7 +304,7 @@ mod tests { }))) }); - let (address, handle) = bootstrap_test(mock_dispatcher).await; + let address = bootstrap_test(mock_dispatcher).await; // Send the request let client = Client::builder(TokioExecutor::new()) @@ -331,14 +332,10 @@ mod tests { let response_bytes = response_body.collect().await.unwrap().to_bytes(); let response_value: GreetingResponse = serde_json::from_slice(&response_bytes).unwrap(); restate_test_util::assert_eq!(response_value.greeting, "Igal"); - - handle.close().await; } - async fn bootstrap_test( - mock_request_dispatcher: MockRequestDispatcher, - ) -> (SocketAddr, TestHandle) { - let node_env = TestCoreEnv::create_with_single_node(1, 1).await; + async fn bootstrap_test(mock_request_dispatcher: MockRequestDispatcher) -> SocketAddr { + let _env = TestCoreEnv::create_with_single_node(1, 1).await; let health = Health::default(); // Create the ingress and start it @@ -349,22 +346,9 @@ mod tests { Arc::new(mock_request_dispatcher), health.ingress_status(), ); - node_env - .tc - .spawn(TaskKind::SystemService, "ingress", None, ingress.run()) - .unwrap(); + TaskCenter::spawn(TaskKind::SystemService, "ingress", ingress.run()).unwrap(); // Wait server to start - let address = start_signal.await.unwrap(); - - (address, TestHandle(node_env.tc)) - } - - struct TestHandle(TaskCenter); - - impl TestHandle { - async fn close(self) { - self.0.cancel_tasks(None, None).await; - } + start_signal.await.unwrap() } } diff --git a/crates/ingress-kafka/src/consumer_task.rs b/crates/ingress-kafka/src/consumer_task.rs index 71619b3fb..fdc624da3 100644 --- a/crates/ingress-kafka/src/consumer_task.rs +++ b/crates/ingress-kafka/src/consumer_task.rs @@ -213,21 +213,14 @@ impl MessageSender { #[derive(Clone)] pub struct ConsumerTask { - task_center: TaskCenter, client_config: ClientConfig, topics: Vec, sender: MessageSender, } impl ConsumerTask { - pub fn new( - 
task_center: TaskCenter, - client_config: ClientConfig, - topics: Vec, - sender: MessageSender, - ) -> Self { + pub fn new(client_config: ClientConfig, topics: Vec, sender: MessageSender) -> Self { Self { - task_center, client_config, topics, sender, @@ -313,7 +306,7 @@ impl ConsumerTask { } }; for task_id in topic_partition_tasks.into_values() { - self.task_center.cancel_task(task_id); + TaskCenter::cancel_task(task_id); } result } diff --git a/crates/ingress-kafka/src/dispatcher.rs b/crates/ingress-kafka/src/dispatcher.rs index 63920a05d..3e6cc95b6 100644 --- a/crates/ingress-kafka/src/dispatcher.rs +++ b/crates/ingress-kafka/src/dispatcher.rs @@ -11,7 +11,7 @@ use crate::consumer_task::KafkaDeduplicationId; use bytes::Bytes; use restate_bifrost::Bifrost; -use restate_core::metadata; +use restate_core::{my_node_id, Metadata}; use restate_storage_api::deduplication_table::DedupInformation; use restate_types::identifiers::{ partitioner, InvocationId, PartitionKey, PartitionProcessorRpcRequestId, WithPartitionKey, @@ -190,7 +190,7 @@ impl DispatchKafkaEvent for KafkaIngressDispatcher { let envelope = wrap_service_invocation_in_envelope( partition_key, inner, - metadata().my_node_id(), + my_node_id(), deduplication_id.to_string(), deduplication_index, ); @@ -215,7 +215,7 @@ fn wrap_service_invocation_in_envelope( let header = Header { source: Source::Ingress { node_id: from_node_id, - nodes_config_version: metadata().nodes_config_version(), + nodes_config_version: Metadata::with_current(|m| m.nodes_config_version()), }, dest: Destination::Processor { partition_key, diff --git a/crates/ingress-kafka/src/subscription_controller.rs b/crates/ingress-kafka/src/subscription_controller.rs index 42619a7c4..6d51ae1c4 100644 --- a/crates/ingress-kafka/src/subscription_controller.rs +++ b/crates/ingress-kafka/src/subscription_controller.rs @@ -17,7 +17,7 @@ use crate::subscription_controller::task_orchestrator::TaskOrchestrator; use anyhow::Context; use rdkafka::error::KafkaError; use restate_bifrost::Bifrost; -use restate_core::{cancellation_watcher, task_center}; +use restate_core::cancellation_watcher; use restate_types::config::IngressOptions; use restate_types::identifiers::SubscriptionId; use restate_types::live::LiveLoad; @@ -142,7 +142,6 @@ impl Service { // Create the consumer task let consumer_task = consumer_task::ConsumerTask::new( - task_center(), client_config, vec![topic.to_string()], MessageSender::new( diff --git a/crates/invoker-impl/src/lib.rs b/crates/invoker-impl/src/lib.rs index e37e3f01a..960535e9f 100644 --- a/crates/invoker-impl/src/lib.rs +++ b/crates/invoker-impl/src/lib.rs @@ -1033,8 +1033,7 @@ mod tests { use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; - use restate_core::TaskKind; - use restate_core::TestCoreEnv; + use restate_core::{TaskCenter, TaskKind}; use restate_invoker_api::entry_enricher; use restate_invoker_api::test_util::EmptyStorageReader; use restate_invoker_api::InvokerHandle; @@ -1193,10 +1192,8 @@ mod tests { } } - #[test(tokio::test)] + #[test(restate_core::test)] async fn input_order_is_maintained() { - let node_env = TestCoreEnv::create_with_single_node(1, 1).await; - let tc = node_env.tc; let invoker_options = InvokerOptionsBuilder::default() // fixed amount of retries so that an invocation eventually completes with a failure .retry_policy(RetryPolicy::fixed_delay(Duration::ZERO, Some(1))) @@ -1221,14 +1218,12 @@ mod tests { let mut handle = service.handle(); - let invoker_task_id = tc - .spawn( - TaskKind::SystemService, - "invoker", - 
None, - service.run(Constant::new(invoker_options)), - ) - .unwrap(); + let invoker_task_id = TaskCenter::spawn( + TaskKind::SystemService, + "invoker", + service.run(Constant::new(invoker_options)), + ) + .unwrap(); let partition_leader_epoch = (PartitionId::from(0), LeaderEpoch::INITIAL); let invocation_target = InvocationTarget::mock_service(); @@ -1260,10 +1255,13 @@ mod tests { // the invocation and we won't see a result for the invocation (failure because the deployment cannot be resolved). check!(let Some(_) = output_rx.recv().await); - tc.cancel_task(invoker_task_id).unwrap().await.unwrap(); + TaskCenter::cancel_task(invoker_task_id) + .unwrap() + .await + .unwrap(); } - #[test(tokio::test)] + #[test(restate_core::test)] async fn quota_allows_one_concurrent_invocation() { let invoker_options = InvokerOptionsBuilder::default() // fixed amount of retries so that an invocation eventually completes with a failure @@ -1359,7 +1357,7 @@ mod tests { assert!(!service_inner.quota.is_slot_available()); } - #[test(tokio::test)] + #[test(restate_core::test)] async fn reclaim_quota_after_abort() { let invoker_options = InvokerOptionsBuilder::default() // fixed amount of retries so that an invocation eventually completes with a failure diff --git a/crates/log-server/src/lib.rs b/crates/log-server/src/lib.rs index bfcd43410..e4056000c 100644 --- a/crates/log-server/src/lib.rs +++ b/crates/log-server/src/lib.rs @@ -21,14 +21,3 @@ mod service; pub use error::LogServerBuildError; pub use service::LogServerService; - -#[cfg(test)] -pub(crate) fn setup_panic_handler() { - // Make sure that panics exits the process. - let orig_hook = std::panic::take_hook(); - std::panic::set_hook(Box::new(move |panic_info| { - // invoke the default handler and exit the process - orig_hook(panic_info); - std::process::exit(1); - })); -} diff --git a/crates/log-server/src/loglet_worker.rs b/crates/log-server/src/loglet_worker.rs index c83c86b85..6b8600868 100644 --- a/crates/log-server/src/loglet_worker.rs +++ b/crates/log-server/src/loglet_worker.rs @@ -119,10 +119,9 @@ impl LogletWorker { let (wait_for_tail_tx, wait_for_tail_rx) = mpsc::unbounded_channel(); let (get_digest_tx, get_digest_rx) = mpsc::unbounded_channel(); // todo - let tc_handle = TaskCenter::current().spawn_unmanaged( + let tc_handle = TaskCenter::spawn_unmanaged( TaskKind::LogletWriter, "loglet-worker", - None, writer.run( store_rx, release_rx, @@ -421,111 +420,97 @@ impl LogletWorker { fn process_wait_for_tail(&mut self, msg: Incoming) { let loglet_state = self.loglet_state.clone(); // fails on shutdown, in this case, we ignore the request - let _ = TaskCenter::current().spawn( - TaskKind::Disposable, - "logserver-tail-monitor", - None, - async move { - let (reciprocal, msg) = msg.split(); - let local_tail_watch = loglet_state.get_local_tail_watch(); - // If shutdown happened, this task will be disposed of and we won't send - // the response. - match msg.query { - TailUpdateQuery::LocalTail(target_offset) => { - local_tail_watch.wait_for_offset_or_seal(target_offset).await?; - } - TailUpdateQuery::GlobalTail(target_global_tail) => { - let global_tail_tracker = loglet_state.get_global_tail_tracker(); - tokio::select! { - res = global_tail_tracker.wait_for_offset(target_global_tail) => { res.map(|_|()) }, - // Are we locally sealed? - res = local_tail_watch.wait_for_seal() => { res }, - }?; - } - TailUpdateQuery::LocalOrGlobal(target_offset) => { - let global_tail_tracker = loglet_state.get_global_tail_tracker(); - tokio::select! 
{ - res = global_tail_tracker.wait_for_offset(target_offset) => { res.map(|_|()) }, - res = local_tail_watch.wait_for_offset_or_seal(target_offset) => { res.map(|_|()) }, - }?; - } - }; + let _ = TaskCenter::spawn(TaskKind::Disposable, "logserver-tail-monitor", async move { + let (reciprocal, msg) = msg.split(); + let local_tail_watch = loglet_state.get_local_tail_watch(); + // If shutdown happened, this task will be disposed of and we won't send + // the response. + match msg.query { + TailUpdateQuery::LocalTail(target_offset) => { + local_tail_watch + .wait_for_offset_or_seal(target_offset) + .await?; + } + TailUpdateQuery::GlobalTail(target_global_tail) => { + let global_tail_tracker = loglet_state.get_global_tail_tracker(); + tokio::select! { + res = global_tail_tracker.wait_for_offset(target_global_tail) => { res.map(|_|()) }, + // Are we locally sealed? + res = local_tail_watch.wait_for_seal() => { res }, + }?; + } + TailUpdateQuery::LocalOrGlobal(target_offset) => { + let global_tail_tracker = loglet_state.get_global_tail_tracker(); + tokio::select! { + res = global_tail_tracker.wait_for_offset(target_offset) => { res.map(|_|()) }, + res = local_tail_watch.wait_for_offset_or_seal(target_offset) => { res.map(|_|()) }, + }?; + } + }; - let update = - TailUpdated::new(loglet_state.local_tail(), loglet_state.known_global_tail()); - let _ = reciprocal.prepare(update).send().await; - Ok(()) - }, - ); + let update = + TailUpdated::new(loglet_state.local_tail(), loglet_state.known_global_tail()); + let _ = reciprocal.prepare(update).send().await; + Ok(()) + }); } fn process_get_records(&mut self, msg: Incoming) { let log_store = self.log_store.clone(); let loglet_state = self.loglet_state.clone(); // fails on shutdown, in this case, we ignore the request - let _ = TaskCenter::current().spawn( - TaskKind::Disposable, - "logserver-get-records", - None, - async move { - let (reciprocal, msg) = msg.split(); - let from_offset = msg.from_offset; - // validate that from_offset <= to_offset - if msg.from_offset > msg.to_offset { - let response = reciprocal - .prepare(Records::empty(from_offset).with_status(Status::Malformed)); - // ship the response to the original connection - let _ = response.send().await; - return Ok(()); - } - let records = match log_store.read_records(msg, &loglet_state).await { - Ok(records) => records, - Err(_) => Records::new( - loglet_state.local_tail(), - loglet_state.known_global_tail(), - from_offset, - ) - .with_status(Status::Disabled), - }; + let _ = TaskCenter::spawn(TaskKind::Disposable, "logserver-get-records", async move { + let (reciprocal, msg) = msg.split(); + let from_offset = msg.from_offset; + // validate that from_offset <= to_offset + if msg.from_offset > msg.to_offset { + let response = + reciprocal.prepare(Records::empty(from_offset).with_status(Status::Malformed)); // ship the response to the original connection - let _ = reciprocal.prepare(records).send().await; - Ok(()) - }, - ); + let _ = response.send().await; + return Ok(()); + } + let records = match log_store.read_records(msg, &loglet_state).await { + Ok(records) => records, + Err(_) => Records::new( + loglet_state.local_tail(), + loglet_state.known_global_tail(), + from_offset, + ) + .with_status(Status::Disabled), + }; + // ship the response to the original connection + let _ = reciprocal.prepare(records).send().await; + Ok(()) + }); } fn process_get_digest(&mut self, msg: Incoming) { let log_store = self.log_store.clone(); let loglet_state = self.loglet_state.clone(); // fails on shutdown, in 
this case, we ignore the request - let _ = TaskCenter::current().spawn( - TaskKind::Disposable, - "logserver-get-digest", - None, - async move { - let (reciprocal, msg) = msg.split(); - // validation. Note that to_offset is inclusive. - if msg.from_offset > msg.to_offset { - let response = - reciprocal.prepare(Digest::empty().with_status(Status::Malformed)); - // ship the response to the original connection - let _ = response.send().await; - return Ok(()); - } - let digest = match log_store.get_records_digest(msg, &loglet_state).await { - Ok(digest) => digest, - Err(_) => Digest::new( - loglet_state.local_tail(), - loglet_state.known_global_tail(), - Default::default(), - ) - .with_status(Status::Disabled), - }; + let _ = TaskCenter::spawn(TaskKind::Disposable, "logserver-get-digest", async move { + let (reciprocal, msg) = msg.split(); + // validation. Note that to_offset is inclusive. + if msg.from_offset > msg.to_offset { + let response = reciprocal.prepare(Digest::empty().with_status(Status::Malformed)); // ship the response to the original connection - let _ = reciprocal.prepare(digest).send().await; - Ok(()) - }, - ); + let _ = response.send().await; + return Ok(()); + } + let digest = match log_store.get_records_digest(msg, &loglet_state).await { + Ok(digest) => digest, + Err(_) => Digest::new( + loglet_state.local_tail(), + loglet_state.known_global_tail(), + Default::default(), + ) + .with_status(Status::Disabled), + }; + // ship the response to the original connection + let _ = reciprocal.prepare(digest).send().await; + Ok(()) + }); } fn process_trim(&mut self, msg: Incoming) { @@ -536,47 +521,54 @@ impl LogletWorker { // fails on shutdown, in this case, we ignore the request let mut loglet_state = self.loglet_state.clone(); let log_store = self.log_store.clone(); - let _ = - TaskCenter::current() - .spawn(TaskKind::Disposable, "logserver-trim", None, async move { - let loglet_id = msg.body().header.loglet_id; - let new_trim_point = msg.body().trim_point; - // cannot trim beyond the global known tail (if known) or the local_tail whichever is higher. - let local_tail = loglet_state.local_tail(); - let known_global_tail = loglet_state.known_global_tail(); - let high_watermark = known_global_tail.max(local_tail.offset()); - if new_trim_point < LogletOffset::OLDEST || new_trim_point >= high_watermark { - let _ = msg.to_rpc_response(Trimmed::new(loglet_state.local_tail(), known_global_tail).with_status(Status::Malformed)).send().await; - return Ok(()); - } + let _ = TaskCenter::spawn(TaskKind::Disposable, "logserver-trim", async move { + let loglet_id = msg.body().header.loglet_id; + let new_trim_point = msg.body().trim_point; + // cannot trim beyond the global known tail (if known) or the local_tail whichever is higher. + let local_tail = loglet_state.local_tail(); + let known_global_tail = loglet_state.known_global_tail(); + let high_watermark = known_global_tail.max(local_tail.offset()); + if new_trim_point < LogletOffset::OLDEST || new_trim_point >= high_watermark { + let _ = msg + .to_rpc_response( + Trimmed::new(loglet_state.local_tail(), known_global_tail) + .with_status(Status::Malformed), + ) + .send() + .await; + return Ok(()); + } - let (reciprocal, mut msg) = msg.split(); - // The trim point cannot be at or exceed the local_tail, we clip to the - // local_tail-1 if that's the case. 
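
The log-server hunks here follow the same mechanical rewrite applied across the patch: instance methods on a captured TaskCenter (or on TaskCenter::current()) become associated functions, and the Option<PartitionId> argument disappears. Side by side, as a sketch with an illustrative task body rather than code taken from this patch:

    // Before: spawn via a TaskCenter handle, with an explicit partition-id slot.
    let _ = TaskCenter::current().spawn(TaskKind::Disposable, "logserver-trim", None, async move {
        /* ... handler body ... */
        Ok(())
    });

    // After: associated function, no partition-id argument.
    let task_id = TaskCenter::spawn(TaskKind::Disposable, "logserver-trim", async move {
        /* ... handler body ... */
        Ok(())
    })?;

    // Cancellation likewise moves to an associated function taking the TaskId;
    // it returns the task's handle (if the task is still tracked), which can be
    // awaited to observe completion, as the invoker and kafka hunks above do.
    if let Some(handle) = TaskCenter::cancel_task(task_id) {
        let _ = handle.await;
    }
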
- msg.trim_point = msg.trim_point.min(local_tail.offset().prev()); - - - let body = if loglet_state.update_trim_point(msg.trim_point) { - match log_store.enqueue_trim(msg).await?.await { - Ok(_) => Trimmed::new(loglet_state.local_tail(), loglet_state.known_global_tail()).with_status(Status::Ok), - Err(_) => { - warn!( - %loglet_id, - "Log-store is disabled, and its trim-point will falsely be reported as {} since we couldn't commit that to the log-store. Trim-point will be correct after restart.", - new_trim_point - ); - Trimmed::new(loglet_state.local_tail(), loglet_state.known_global_tail()).with_status(Status::Disabled) - } + let (reciprocal, mut msg) = msg.split(); + // The trim point cannot be at or exceed the local_tail, we clip to the + // local_tail-1 if that's the case. + msg.trim_point = msg.trim_point.min(local_tail.offset().prev()); + + let body = if loglet_state.update_trim_point(msg.trim_point) { + match log_store.enqueue_trim(msg).await?.await { + Ok(_) => { + Trimmed::new(loglet_state.local_tail(), loglet_state.known_global_tail()) + .with_status(Status::Ok) + } + Err(_) => { + warn!( + %loglet_id, + "Log-store is disabled, and its trim-point will falsely be reported as {} since we couldn't commit that to the log-store. Trim-point will be correct after restart.", + new_trim_point + ); + Trimmed::new(loglet_state.local_tail(), loglet_state.known_global_tail()) + .with_status(Status::Disabled) } - } else { - // it's already trimmed - Trimmed::new(loglet_state.local_tail(), loglet_state.known_global_tail()) - }; + } + } else { + // it's already trimmed + Trimmed::new(loglet_state.local_tail(), loglet_state.known_global_tail()) + }; - // ship the response to the original connection - let _ = reciprocal.prepare(body).send().await; - Ok(()) - }); + // ship the response to the original connection + let _ = reciprocal.prepare(body).send().await; + Ok(()) + }); } async fn process_seal( @@ -614,7 +606,7 @@ mod tests { use test_log::test; use restate_core::network::OwnedConnection; - use restate_core::{MetadataBuilder, TaskCenter, TaskCenterBuilder, TaskCenterFutureExt}; + use restate_core::{MetadataBuilder, TaskCenter}; use restate_rocksdb::RocksDbManager; use restate_types::config::Configuration; use restate_types::live::Live; @@ -625,893 +617,864 @@ mod tests { use crate::metadata::LogletStateMap; use crate::rocksdb_logstore::{RocksDbLogStore, RocksDbLogStoreBuilder}; - use crate::setup_panic_handler; use super::LogletWorker; - async fn setup() -> Result<(TaskCenter, RocksDbLogStore)> { - setup_panic_handler(); - let tc = TaskCenterBuilder::default_for_tests().build()?; + async fn setup() -> Result { let config = Live::from_value(Configuration::default()); let common_rocks_opts = config.clone().map(|c| &c.common); - let log_store = async { - RocksDbManager::init(common_rocks_opts); - let metadata_builder = MetadataBuilder::default(); - assert!(TaskCenter::try_set_global_metadata( - metadata_builder.to_metadata() - )); - // create logstore. - let builder = RocksDbLogStoreBuilder::create( - config.clone().map(|c| &c.log_server).boxed(), - config.map(|c| &c.log_server.rocksdb).boxed(), - RecordCache::new(1_000_000), - ) - .await?; - let log_store = builder.start(Default::default()).await?; - Result::Ok(log_store) - } - .in_tc(&tc) + RocksDbManager::init(common_rocks_opts); + let metadata_builder = MetadataBuilder::default(); + assert!(TaskCenter::try_set_global_metadata( + metadata_builder.to_metadata() + )); + // create logstore. 
+ let builder = RocksDbLogStoreBuilder::create( + config.clone().map(|c| &c.log_server).boxed(), + config.map(|c| &c.log_server.rocksdb).boxed(), + RecordCache::new(1_000_000), + ) .await?; - Ok((tc, log_store)) + Ok(builder.start(Default::default()).await?) } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn test_simple_store_flow() -> Result<()> { - let (tc, log_store) = setup().await?; - async { - const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); - const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); - let loglet_state_map = LogletStateMap::default(); - let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); - - let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; - let worker = LogletWorker::start(LOGLET, log_store, loglet_state)?; - - let payloads: Arc<[Record]> = vec![ - Record::from("a sample record"), - Record::from("another record"), - ] - .into(); - - // offsets 1, 2 - let msg1 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::OLDEST, - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; - - // offsets 3, 4 - let msg2 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(3), - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; - - let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); - let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); - let msg1_id = msg1.msg_id(); - let msg2_id = msg2.msg_id(); - - // pipelined writes - worker.enqueue_store(msg1).unwrap(); - worker.enqueue_store(msg2).unwrap(); - // wait for response (in test-env, it's safe to assume that responses will arrive in order) - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg1_id)); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.local_tail, eq(LogletOffset::new(3))); - - // response 2 - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg2_id)); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.local_tail, eq(LogletOffset::new(5))); + let log_store = setup().await?; + const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); + const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); + let loglet_state_map = LogletStateMap::default(); + let (net_tx, mut net_rx) = mpsc::channel(10); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + + let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; + let worker = LogletWorker::start(LOGLET, log_store, loglet_state)?; + + let payloads: Arc<[Record]> = vec![ + Record::from("a sample record"), + Record::from("another record"), + ] + .into(); + + // offsets 1, 2 + let msg1 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + timeout_at: 
None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::OLDEST, + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; - tc.shutdown_node("test completed", 0).await; - RocksDbManager::get().shutdown().await; + // offsets 3, 4 + let msg2 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(3), + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; - Ok(()) - } - .in_tc(&tc) - .await + let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); + let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); + let msg1_id = msg1.msg_id(); + let msg2_id = msg2.msg_id(); + + // pipelined writes + worker.enqueue_store(msg1).unwrap(); + worker.enqueue_store(msg2).unwrap(); + // wait for response (in test-env, it's safe to assume that responses will arrive in order) + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg1_id)); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.local_tail, eq(LogletOffset::new(3))); + + // response 2 + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg2_id)); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.local_tail, eq(LogletOffset::new(5))); + + TaskCenter::shutdown_node("test completed", 0).await; + RocksDbManager::get().shutdown().await; + + Ok(()) } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn test_store_and_seal() -> Result<()> { - let (tc, log_store) = setup().await?; - async { - const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); - const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); - let loglet_state_map = LogletStateMap::default(); - let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); - - let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; - let worker = LogletWorker::start(LOGLET, log_store, loglet_state)?; - - let payloads: Arc<[Record]> = vec![ - Record::from("a sample record"), - Record::from("another record"), - ] - .into(); - - // offsets 1, 2 - let msg1 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::OLDEST, - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; - - let seal1 = Seal { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - sequencer: SEQUENCER, - }; - - let seal2 = Seal { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - sequencer: SEQUENCER, - }; + let log_store = setup().await?; + const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); + const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); + let loglet_state_map = LogletStateMap::default(); + let (net_tx, mut net_rx) = mpsc::channel(10); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + + let 
loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; + let worker = LogletWorker::start(LOGLET, log_store, loglet_state)?; + + let payloads: Arc<[Record]> = vec![ + Record::from("a sample record"), + Record::from("another record"), + ] + .into(); + + // offsets 1, 2 + let msg1 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::OLDEST, + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; - // offsets 3, 4 - let msg2 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(3), - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; + let seal1 = Seal { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + sequencer: SEQUENCER, + }; - let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); - let seal1 = Incoming::for_testing(connection.downgrade(), seal1, None); - let seal2 = Incoming::for_testing(connection.downgrade(), seal2, None); - let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); - let msg1_id = msg1.msg_id(); - let seal1_id = seal1.msg_id(); - let seal2_id = seal2.msg_id(); - let msg2_id = msg2.msg_id(); - - worker.enqueue_store(msg1).unwrap(); - // first store is successful - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg1_id)); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.local_tail, eq(LogletOffset::new(3))); - worker.enqueue_seal(seal1).unwrap(); - // should latch onto existing seal - worker.enqueue_seal(seal2).unwrap(); - // seal takes precedence, but it gets processed in the background. This store is likely to - // observe Status::Sealing - worker.enqueue_store(msg2).unwrap(); - // sealing - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg2_id)); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Sealing)); - assert_that!(stored.local_tail, eq(LogletOffset::new(3))); - // seal responses can come at any order, but we'll consume waiters queue before we process - // store messages. 
- // sealed - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), any!(eq(seal1_id), eq(seal2_id))); - let sealed: Sealed = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(sealed.status, eq(Status::Ok)); - assert_that!(sealed.local_tail, eq(LogletOffset::new(3))); - - // sealed2 - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), any!(eq(seal1_id), eq(seal2_id))); - let sealed: Sealed = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(sealed.status, eq(Status::Ok)); - assert_that!(sealed.local_tail, eq(LogletOffset::new(3))); - - // try another store - let msg3 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(3)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(3), - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; - let msg3 = Incoming::for_testing(connection.downgrade(), msg3, None); - let msg3_id = msg3.msg_id(); - worker.enqueue_store(msg3).unwrap(); - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg3_id)); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Sealed)); - assert_that!(stored.local_tail, eq(LogletOffset::new(3))); + let seal2 = Seal { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + sequencer: SEQUENCER, + }; - // GetLogletInfo - // offsets 3, 4 - let msg = GetLogletInfo { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - }; - let msg = Incoming::for_testing(connection.downgrade(), msg, None); - let msg_id = msg.msg_id(); - worker.enqueue_get_loglet_info(msg).unwrap(); - - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg_id)); - let info: LogletInfo = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(info.status, eq(Status::Ok)); - assert_that!(info.local_tail, eq(LogletOffset::new(3))); - assert_that!(info.trim_point, eq(LogletOffset::INVALID)); - assert_that!(info.sealed, eq(true)); + // offsets 3, 4 + let msg2 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(3), + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; - tc.shutdown_node("test completed", 0).await; - RocksDbManager::get().shutdown().await; - Ok(()) - } - .in_tc(&tc) - .await + let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); + let seal1 = Incoming::for_testing(connection.downgrade(), seal1, None); + let seal2 = Incoming::for_testing(connection.downgrade(), seal2, None); + let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); + let msg1_id = msg1.msg_id(); + let seal1_id = seal1.msg_id(); + let seal2_id = seal2.msg_id(); + let msg2_id = msg2.msg_id(); + + worker.enqueue_store(msg1).unwrap(); + // first store is successful + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg1_id)); + let stored: Stored = response + .body + .unwrap() + 
.try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.local_tail, eq(LogletOffset::new(3))); + worker.enqueue_seal(seal1).unwrap(); + // should latch onto existing seal + worker.enqueue_seal(seal2).unwrap(); + // seal takes precedence, but it gets processed in the background. This store is likely to + // observe Status::Sealing + worker.enqueue_store(msg2).unwrap(); + // sealing + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg2_id)); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Sealing)); + assert_that!(stored.local_tail, eq(LogletOffset::new(3))); + // seal responses can come at any order, but we'll consume waiters queue before we process + // store messages. + // sealed + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), any!(eq(seal1_id), eq(seal2_id))); + let sealed: Sealed = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(sealed.status, eq(Status::Ok)); + assert_that!(sealed.local_tail, eq(LogletOffset::new(3))); + + // sealed2 + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), any!(eq(seal1_id), eq(seal2_id))); + let sealed: Sealed = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(sealed.status, eq(Status::Ok)); + assert_that!(sealed.local_tail, eq(LogletOffset::new(3))); + + // try another store + let msg3 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(3)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(3), + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; + let msg3 = Incoming::for_testing(connection.downgrade(), msg3, None); + let msg3_id = msg3.msg_id(); + worker.enqueue_store(msg3).unwrap(); + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg3_id)); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Sealed)); + assert_that!(stored.local_tail, eq(LogletOffset::new(3))); + + // GetLogletInfo + // offsets 3, 4 + let msg = GetLogletInfo { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + }; + let msg = Incoming::for_testing(connection.downgrade(), msg, None); + let msg_id = msg.msg_id(); + worker.enqueue_get_loglet_info(msg).unwrap(); + + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg_id)); + let info: LogletInfo = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(info.status, eq(Status::Ok)); + assert_that!(info.local_tail, eq(LogletOffset::new(3))); + assert_that!(info.trim_point, eq(LogletOffset::INVALID)); + assert_that!(info.sealed, eq(true)); + + TaskCenter::shutdown_node("test completed", 0).await; + RocksDbManager::get().shutdown().await; + Ok(()) } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn test_repair_store() -> Result<()> { - let (tc, log_store) = setup().await?; - async { - const SEQUENCER: GenerationalNodeId = 
GenerationalNodeId::new(1, 1); - const PEER: GenerationalNodeId = GenerationalNodeId::new(2, 2); - const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); - let loglet_state_map = LogletStateMap::default(); - let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); - - let (peer_net_tx, mut peer_net_rx) = mpsc::channel(10); - let repair_connection = - OwnedConnection::new_fake(PEER, CURRENT_PROTOCOL_VERSION, peer_net_tx); - - let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; - let worker = LogletWorker::start(LOGLET, log_store, loglet_state)?; - - let payloads: Arc<[Record]> = vec![ - Record::from("a sample record"), - Record::from("another record"), - ] - .into(); - - // offsets 1, 2 - let msg1 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::OLDEST, - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; + let log_store = setup().await?; + const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); + const PEER: GenerationalNodeId = GenerationalNodeId::new(2, 2); + const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); + let loglet_state_map = LogletStateMap::default(); + let (net_tx, mut net_rx) = mpsc::channel(10); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + + let (peer_net_tx, mut peer_net_rx) = mpsc::channel(10); + let repair_connection = + OwnedConnection::new_fake(PEER, CURRENT_PROTOCOL_VERSION, peer_net_tx); + + let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; + let worker = LogletWorker::start(LOGLET, log_store, loglet_state)?; + + let payloads: Arc<[Record]> = vec![ + Record::from("a sample record"), + Record::from("another record"), + ] + .into(); + + // offsets 1, 2 + let msg1 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::OLDEST, + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; - // offsets 10, 11 - let msg2 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(10), - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; + // offsets 10, 11 + let msg2 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(10), + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; - let seal1 = Seal { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - sequencer: SEQUENCER, - }; + let seal1 = Seal { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + sequencer: SEQUENCER, + }; - // 5, 6 - let repair_message_before_local_tail = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(5), - flags: StoreFlags::IgnoreSeal, - payloads: payloads.clone(), - }; + // 5, 6 + let repair_message_before_local_tail = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), 
+ timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(5), + flags: StoreFlags::IgnoreSeal, + payloads: payloads.clone(), + }; - // 16, 17 - let repair_message_after_local_tail = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(16)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(16), - flags: StoreFlags::IgnoreSeal, - payloads: payloads.clone(), - }; + // 16, 17 + let repair_message_after_local_tail = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(16)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(16), + flags: StoreFlags::IgnoreSeal, + payloads: payloads.clone(), + }; + + let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); + let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); + let repair1 = Incoming::for_testing( + repair_connection.downgrade(), + repair_message_before_local_tail, + None, + ); + let repair2 = Incoming::for_testing( + repair_connection.downgrade(), + repair_message_after_local_tail, + None, + ); + let seal1 = Incoming::for_testing(connection.downgrade(), seal1, None); + + worker.enqueue_store(msg1).unwrap(); + worker.enqueue_store(msg2).unwrap(); + // first store is successful + let response = net_rx.recv().await.unwrap(); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.sealed, eq(false)); + assert_that!(stored.local_tail, eq(LogletOffset::new(3))); + + // 10, 11 + let response = net_rx.recv().await.unwrap(); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.sealed, eq(false)); + assert_that!(stored.local_tail, eq(LogletOffset::new(12))); + + worker.enqueue_seal(seal1).unwrap(); + // seal responses can come at any order, but we'll consume waiters queue before we process + // store messages. 
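// For orientation, the local tail this test expects to observe evolves as follows
// (matching the assertions below):
//   store {1, 2}           -> tail 3
//   store {10, 11}         -> tail 12
//   seal                   -> sealed, tail stays 12
//   repair store {5, 6}    -> below the tail (IgnoreSeal), tail stays 12
//   repair store {16, 17}  -> beyond the tail (IgnoreSeal), tail moves to 18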
+ // sealed + let response = net_rx.recv().await.unwrap(); + let sealed: Sealed = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(sealed.status, eq(Status::Ok)); + assert_that!(sealed.local_tail, eq(LogletOffset::new(12))); + + // repair store (before local tail, local tail won't move) + worker.enqueue_store(repair1).unwrap(); + let response = peer_net_rx.recv().await.unwrap(); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.local_tail, eq(LogletOffset::new(12))); + + worker.enqueue_store(repair2).unwrap(); + let response = peer_net_rx.recv().await.unwrap(); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.local_tail, eq(LogletOffset::new(18))); + + // GetLogletInfo + // offsets 3, 4 + let msg = GetLogletInfo { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + }; + let msg = Incoming::for_testing(connection.downgrade(), msg, None); + let msg_id = msg.msg_id(); + worker.enqueue_get_loglet_info(msg).unwrap(); + + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg_id)); + let info: LogletInfo = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(info.status, eq(Status::Ok)); + assert_that!(info.local_tail, eq(LogletOffset::new(18))); + assert_that!(info.trim_point, eq(LogletOffset::INVALID)); + assert_that!(info.sealed, eq(true)); + TaskCenter::shutdown_node("test completed", 0).await; + RocksDbManager::get().shutdown().await; + Ok(()) + } - let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); - let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); - let repair1 = Incoming::for_testing( - repair_connection.downgrade(), - repair_message_before_local_tail, + #[test(restate_core::test(start_paused = true))] + async fn test_simple_get_records_flow() -> Result<()> { + let log_store = setup().await?; + const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); + const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); + let loglet_state_map = LogletStateMap::default(); + let (net_tx, mut net_rx) = mpsc::channel(10); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + + let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; + let worker = LogletWorker::start(LOGLET, log_store, loglet_state)?; + + // Populate the log-store with some records (..,2,..,5,..,10, 11) + // Note: dots mean we don't have records at those globally committed offsets. 
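// The stores below fake the sequencer's view of the global tail through the
// request header: "offset N is released" is expressed by passing N + 1 as the
// header offset. A hypothetical helper (not used by this test, shown only to
// spell out that convention; field meanings taken from the messages below,
// and the u32 offset argument is an assumption):
fn fake_released_store(
    loglet: ReplicatedLogletId,
    sequencer: GenerationalNodeId,
    released_up_to: u32,
    payloads: Arc<[Record]>,
) -> Store {
    Store {
        // claimed global tail = last released offset + 1
        header: LogServerRequestHeader::new(loglet, LogletOffset::new(released_up_to + 1)),
        timeout_at: None,
        sequencer,
        known_archived: LogletOffset::INVALID,
        // in this test each batch starts right at the claimed global tail
        first_offset: LogletOffset::new(released_up_to + 1),
        flags: StoreFlags::empty(),
        payloads,
    }
}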
+ worker + .enqueue_store(Incoming::for_testing( + connection.downgrade(), + Store { + // faking that offset=1 is released + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(2)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(2), + flags: StoreFlags::empty(), + payloads: vec![Record::from("record2")].into(), + }, None, - ); - let repair2 = Incoming::for_testing( - repair_connection.downgrade(), - repair_message_after_local_tail, + )) + .unwrap(); + + worker + .enqueue_store(Incoming::for_testing( + connection.downgrade(), + Store { + // faking that offset=4 is released + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(5)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(5), + flags: StoreFlags::empty(), + payloads: vec![Record::from(("record5", Keys::Single(11)))].into(), + }, None, - ); - let seal1 = Incoming::for_testing(connection.downgrade(), seal1, None); - - worker.enqueue_store(msg1).unwrap(); - worker.enqueue_store(msg2).unwrap(); - // first store is successful - let response = net_rx.recv().await.unwrap(); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.sealed, eq(false)); - assert_that!(stored.local_tail, eq(LogletOffset::new(3))); + )) + .unwrap(); + + worker + .enqueue_store(Incoming::for_testing( + connection.downgrade(), + Store { + // faking that offset=9 is released + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(10), + flags: StoreFlags::empty(), + payloads: vec![Record::from("record10"), Record::from("record11")].into(), + }, + None, + )) + .unwrap(); - // 10, 11 - let response = net_rx.recv().await.unwrap(); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.sealed, eq(false)); - assert_that!(stored.local_tail, eq(LogletOffset::new(12))); - - worker.enqueue_seal(seal1).unwrap(); - // seal responses can come at any order, but we'll consume waiters queue before we process - // store messages. - // sealed - let response = net_rx.recv().await.unwrap(); - let sealed: Sealed = response - .body + // Wait for stores to complete. + for _ in 0..3 { + let stored: Stored = net_rx + .recv() + .await .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(sealed.status, eq(Status::Ok)); - assert_that!(sealed.local_tail, eq(LogletOffset::new(12))); - - // repair store (before local tail, local tail won't move) - worker.enqueue_store(repair1).unwrap(); - let response = peer_net_rx.recv().await.unwrap(); - let stored: Stored = response .body .unwrap() .try_decode(connection.protocol_version())?; assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.local_tail, eq(LogletOffset::new(12))); + } - worker.enqueue_store(repair2).unwrap(); - let response = peer_net_rx.recv().await.unwrap(); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.local_tail, eq(LogletOffset::new(18))); + // We expect to see [2, 5]. No trim gaps, no filtered gaps. 
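// Reading [from_offset = 1, to_offset = 7] against a store holding {2, 5, 10, 11}
// with KeyFilter::Any should therefore return the two data records 2 and 5,
// with next_offset = 8 (one past the requested range), local_tail = 12 (one past
// the highest stored record), and sealed = false, as asserted below.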
+ worker + .enqueue_get_records(Incoming::for_testing( + connection.downgrade(), + GetRecords { + // faking that offset=9 is released + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + filter: KeyFilter::Any, + // no memory limits + total_limit_in_bytes: None, + from_offset: LogletOffset::new(1), + to_offset: LogletOffset::new(7), + }, + None, + )) + .unwrap(); - // GetLogletInfo - // offsets 3, 4 - let msg = GetLogletInfo { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - }; - let msg = Incoming::for_testing(connection.downgrade(), msg, None); - let msg_id = msg.msg_id(); - worker.enqueue_get_loglet_info(msg).unwrap(); - - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg_id)); - let info: LogletInfo = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(info.status, eq(Status::Ok)); - assert_that!(info.local_tail, eq(LogletOffset::new(18))); - assert_that!(info.trim_point, eq(LogletOffset::INVALID)); - assert_that!(info.sealed, eq(true)); - tc.shutdown_node("test completed", 0).await; - RocksDbManager::get().shutdown().await; - Ok(()) + let mut records: Records = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(records.status, eq(Status::Ok)); + assert_that!(records.local_tail, eq(LogletOffset::new(12))); + assert_that!(records.sealed, eq(false)); + assert_that!(records.next_offset, eq(LogletOffset::new(8))); + assert_that!(records.records.len(), eq(2)); + // pop in reverse order + for i in [5, 2] { + let (offset, record) = records.records.pop().unwrap(); + assert_that!(offset, eq(LogletOffset::from(i))); + assert_that!(record.is_data(), eq(true)); + let data = record.try_unwrap_data().unwrap(); + let original: String = data.decode().unwrap(); + assert_that!(original, eq(format!("record{}", i))); } - .in_tc(&tc) - .await - } - #[test(tokio::test(start_paused = true))] - async fn test_simple_get_records_flow() -> Result<()> { - let (tc, log_store) = setup().await?; - async { - const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); - const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); - let loglet_state_map = LogletStateMap::default(); - let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); - - let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; - let worker = LogletWorker::start(LOGLET, log_store, loglet_state)?; - - // Populate the log-store with some records (..,2,..,5,..,10, 11) - // Note: dots mean we don't have records at those globally committed offsets. 
- worker - .enqueue_store(Incoming::for_testing( - connection.downgrade(), - Store { - // faking that offset=1 is released - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(2)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(2), - flags: StoreFlags::empty(), - payloads: vec![Record::from("record2")].into(), - }, - None, - )) - .unwrap(); - - worker - .enqueue_store(Incoming::for_testing( - connection.downgrade(), - Store { - // faking that offset=4 is released - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(5)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(5), - flags: StoreFlags::empty(), - payloads: vec![Record::from(("record5", Keys::Single(11)))].into(), - }, - None, - )) - .unwrap(); - - worker - .enqueue_store(Incoming::for_testing( - connection.downgrade(), - Store { - // faking that offset=9 is released - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(10), - flags: StoreFlags::empty(), - payloads: vec![Record::from("record10"), Record::from("record11")].into(), - }, - None, - )) - .unwrap(); - - // Wait for stores to complete. - for _ in 0..3 { - let stored: Stored = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - } + // We expect to see [2, FILTERED(5), 10, 11]. No trim gaps. + worker + .enqueue_get_records(Incoming::for_testing( + connection.downgrade(), + GetRecords { + // INVALID can be used when we don't have a reasonable value to pass in. + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + // no memory limits + total_limit_in_bytes: None, + filter: KeyFilter::Within(0..=5), + from_offset: LogletOffset::new(1), + // to a point beyond local tail + to_offset: LogletOffset::new(100), + }, + None, + )) + .unwrap(); - // We expect to see [2, 5]. No trim gaps, no filtered gaps. 
- worker - .enqueue_get_records(Incoming::for_testing( - connection.downgrade(), - GetRecords { - // faking that offset=9 is released - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - filter: KeyFilter::Any, - // no memory limits - total_limit_in_bytes: None, - from_offset: LogletOffset::new(1), - to_offset: LogletOffset::new(7), - }, - None, - )) - .unwrap(); - - let mut records: Records = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(records.status, eq(Status::Ok)); - assert_that!(records.local_tail, eq(LogletOffset::new(12))); - assert_that!(records.sealed, eq(false)); - assert_that!(records.next_offset, eq(LogletOffset::new(8))); - assert_that!(records.records.len(), eq(2)); - // pop in reverse order - for i in [5, 2] { - let (offset, record) = records.records.pop().unwrap(); - assert_that!(offset, eq(LogletOffset::from(i))); + let mut records: Records = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(records.status, eq(Status::Ok)); + assert_that!(records.local_tail, eq(LogletOffset::new(12))); + assert_that!(records.next_offset, eq(LogletOffset::new(12))); + assert_that!(records.sealed, eq(false)); + assert_that!(records.records.len(), eq(4)); + // pop() returns records in reverse order + for i in [11, 10, 5, 2] { + let (offset, record) = records.records.pop().unwrap(); + assert_that!(offset, eq(LogletOffset::from(i))); + if i == 5 { + // this one is filtered + assert_that!(record.is_filtered_gap(), eq(true)); + let gap = record.try_unwrap_filtered_gap().unwrap(); + assert_that!(gap.to, eq(LogletOffset::new(5))); + } else { assert_that!(record.is_data(), eq(true)); let data = record.try_unwrap_data().unwrap(); let original: String = data.decode().unwrap(); assert_that!(original, eq(format!("record{}", i))); } + } - // We expect to see [2, FILTERED(5), 10, 11]. No trim gaps. - worker - .enqueue_get_records(Incoming::for_testing( - connection.downgrade(), - GetRecords { - // INVALID can be used when we don't have a reasonable value to pass in. - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - // no memory limits - total_limit_in_bytes: None, - filter: KeyFilter::Within(0..=5), - from_offset: LogletOffset::new(1), - // to a point beyond local tail - to_offset: LogletOffset::new(100), - }, - None, - )) - .unwrap(); - - let mut records: Records = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(records.status, eq(Status::Ok)); - assert_that!(records.local_tail, eq(LogletOffset::new(12))); - assert_that!(records.next_offset, eq(LogletOffset::new(12))); - assert_that!(records.sealed, eq(false)); - assert_that!(records.records.len(), eq(4)); - // pop() returns records in reverse order - for i in [11, 10, 5, 2] { - let (offset, record) = records.records.pop().unwrap(); - assert_that!(offset, eq(LogletOffset::from(i))); - if i == 5 { - // this one is filtered - assert_that!(record.is_filtered_gap(), eq(true)); - let gap = record.try_unwrap_filtered_gap().unwrap(); - assert_that!(gap.to, eq(LogletOffset::new(5))); - } else { - assert_that!(record.is_data(), eq(true)); - let data = record.try_unwrap_data().unwrap(); - let original: String = data.decode().unwrap(); - assert_that!(original, eq(format!("record{}", i))); - } - } + // Apply memory limits (2 bytes) should always see the first real record. + // We expect to see [FILTERED(5), 10]. 
(11 is not returend due to budget) + worker + .enqueue_get_records(Incoming::for_testing( + connection.downgrade(), + GetRecords { + // INVALID can be used when we don't have a reasonable value to pass in. + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + // no memory limits + total_limit_in_bytes: Some(2), + filter: KeyFilter::Within(0..=5), + from_offset: LogletOffset::new(4), + // to a point beyond local tail + to_offset: LogletOffset::new(100), + }, + None, + )) + .unwrap(); - // Apply memory limits (2 bytes) should always see the first real record. - // We expect to see [FILTERED(5), 10]. (11 is not returend due to budget) - worker - .enqueue_get_records(Incoming::for_testing( - connection.downgrade(), - GetRecords { - // INVALID can be used when we don't have a reasonable value to pass in. - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - // no memory limits - total_limit_in_bytes: Some(2), - filter: KeyFilter::Within(0..=5), - from_offset: LogletOffset::new(4), - // to a point beyond local tail - to_offset: LogletOffset::new(100), - }, - None, - )) - .unwrap(); - - let mut records: Records = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(records.status, eq(Status::Ok)); - assert_that!(records.local_tail, eq(LogletOffset::new(12))); - assert_that!(records.next_offset, eq(LogletOffset::new(11))); - assert_that!(records.sealed, eq(false)); - assert_that!(records.records.len(), eq(2)); - // pop() returns records in reverse order - for i in [10, 5] { - let (offset, record) = records.records.pop().unwrap(); - assert_that!(offset, eq(LogletOffset::from(i))); - if i == 5 { - // this one is filtered - assert_that!(record.is_filtered_gap(), eq(true)); - let gap = record.try_unwrap_filtered_gap().unwrap(); - assert_that!(gap.to, eq(LogletOffset::new(5))); - } else { - assert_that!(record.is_data(), eq(true)); - let data = record.try_unwrap_data().unwrap(); - let original: String = data.decode().unwrap(); - assert_that!(original, eq(format!("record{}", i))); - } + let mut records: Records = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(records.status, eq(Status::Ok)); + assert_that!(records.local_tail, eq(LogletOffset::new(12))); + assert_that!(records.next_offset, eq(LogletOffset::new(11))); + assert_that!(records.sealed, eq(false)); + assert_that!(records.records.len(), eq(2)); + // pop() returns records in reverse order + for i in [10, 5] { + let (offset, record) = records.records.pop().unwrap(); + assert_that!(offset, eq(LogletOffset::from(i))); + if i == 5 { + // this one is filtered + assert_that!(record.is_filtered_gap(), eq(true)); + let gap = record.try_unwrap_filtered_gap().unwrap(); + assert_that!(gap.to, eq(LogletOffset::new(5))); + } else { + assert_that!(record.is_data(), eq(true)); + let data = record.try_unwrap_data().unwrap(); + let original: String = data.decode().unwrap(); + assert_that!(original, eq(format!("record{}", i))); } + } - tc.shutdown_node("test completed", 0).await; - RocksDbManager::get().shutdown().await; + TaskCenter::shutdown_node("test completed", 0).await; + RocksDbManager::get().shutdown().await; - Ok(()) - } - .in_tc(&tc) - .await + Ok(()) } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn test_trim_basics() -> Result<()> { - let (tc, log_store) = setup().await?; - async { - const SEQUENCER: GenerationalNodeId = 
GenerationalNodeId::new(1, 1); - const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); - let loglet_state_map = LogletStateMap::default(); - let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); - - let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; - let worker = LogletWorker::start(LOGLET, log_store.clone(), loglet_state.clone())?; - - assert_that!(loglet_state.trim_point(), eq(LogletOffset::INVALID)); - assert_that!(loglet_state.local_tail().offset(), eq(LogletOffset::OLDEST)); - // The loglet has no knowledge of global commits, it shouldn't accept trims. - worker - .enqueue_trim(Incoming::for_testing( - connection.downgrade(), - Trim { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::OLDEST), - trim_point: LogletOffset::OLDEST, - }, - None, - )) - .unwrap(); - - let trimmed: Trimmed = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(trimmed.status, eq(Status::Malformed)); - assert_that!(trimmed.local_tail, eq(LogletOffset::OLDEST)); - assert_that!(trimmed.sealed, eq(false)); - - // The loglet has knowledge of global tail of 10, it should accept trims up to 9 but it - // won't move trim point beyond its local tail. - worker - .enqueue_trim(Incoming::for_testing( - connection.downgrade(), - Trim { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - trim_point: LogletOffset::new(9), - }, - None, - )) - .unwrap(); - - let trimmed: Trimmed = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(trimmed.status, eq(Status::Ok)); - assert_that!(trimmed.local_tail, eq(LogletOffset::OLDEST)); - assert_that!(trimmed.sealed, eq(false)); - - // let's store some records at offsets (5, 6) - worker - .enqueue_store(Incoming::for_testing( - connection.downgrade(), - Store { - // faking that offset=9 is released - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(5), - flags: StoreFlags::empty(), - payloads: vec![Record::from("record5"), Record::from("record6")].into(), - }, - None, - )) - .unwrap(); - let stored: Stored = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.local_tail, eq(LogletOffset::new(7))); - - // trim to 5 - worker - .enqueue_trim(Incoming::for_testing( - connection.downgrade(), - Trim { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - trim_point: LogletOffset::new(5), - }, - None, - )) - .unwrap(); - - let trimmed: Trimmed = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(trimmed.status, eq(Status::Ok)); - assert_that!(trimmed.local_tail, eq(LogletOffset::new(7))); - assert_that!(trimmed.sealed, eq(false)); - - // Attempt to read. 
We expect to see a trim gap (1->5, 6 (data-record)) - worker - .enqueue_get_records(Incoming::for_testing( - connection.downgrade(), - GetRecords { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - total_limit_in_bytes: None, - filter: KeyFilter::Any, - from_offset: LogletOffset::OLDEST, - // to a point beyond local tail - to_offset: LogletOffset::new(100), - }, - None, - )) - .unwrap(); - - let mut records: Records = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(records.status, eq(Status::Ok)); - assert_that!(records.local_tail, eq(LogletOffset::new(7))); - assert_that!(records.next_offset, eq(LogletOffset::new(7))); - assert_that!(records.sealed, eq(false)); - assert_that!(records.records.len(), eq(2)); - // pop() returns records in reverse order - for i in [6, 1] { - let (offset, record) = records.records.pop().unwrap(); - assert_that!(offset, eq(LogletOffset::from(i))); - if i == 1 { - // this one is a trim gap - assert_that!(record.is_trim_gap(), eq(true)); - let gap = record.try_unwrap_trim_gap().unwrap(); - assert_that!(gap.to, eq(LogletOffset::new(5))); - } else { - assert_that!(record.is_data(), eq(true)); - let data = record.try_unwrap_data().unwrap(); - let original: String = data.decode().unwrap(); - assert_that!(original, eq(format!("record{}", i))); - } - } + let log_store = setup().await?; + const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); + const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); + let loglet_state_map = LogletStateMap::default(); + let (net_tx, mut net_rx) = mpsc::channel(10); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + + let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; + let worker = LogletWorker::start(LOGLET, log_store.clone(), loglet_state.clone())?; + + assert_that!(loglet_state.trim_point(), eq(LogletOffset::INVALID)); + assert_that!(loglet_state.local_tail().offset(), eq(LogletOffset::OLDEST)); + // The loglet has no knowledge of global commits, it shouldn't accept trims. + worker + .enqueue_trim(Incoming::for_testing( + connection.downgrade(), + Trim { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::OLDEST), + trim_point: LogletOffset::OLDEST, + }, + None, + )) + .unwrap(); - // trim everything - worker - .enqueue_trim(Incoming::for_testing( - connection.downgrade(), - Trim { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - trim_point: LogletOffset::new(9), - }, - None, - )) - .unwrap(); - - let trimmed: Trimmed = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(trimmed.status, eq(Status::Ok)); - assert_that!(trimmed.local_tail, eq(LogletOffset::new(7))); - assert_that!(trimmed.sealed, eq(false)); - - // Attempt to read again. 
We expect to see a trim gap (1->6) - worker - .enqueue_get_records(Incoming::for_testing( - connection.downgrade(), - GetRecords { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - total_limit_in_bytes: None, - filter: KeyFilter::Any, - from_offset: LogletOffset::OLDEST, - // to a point beyond local tail - to_offset: LogletOffset::new(100), - }, - None, - )) - .unwrap(); - - let mut records: Records = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(records.status, eq(Status::Ok)); - assert_that!(records.local_tail, eq(LogletOffset::new(7))); - assert_that!(records.next_offset, eq(LogletOffset::new(7))); - assert_that!(records.sealed, eq(false)); - assert_that!(records.records.len(), eq(1)); + let trimmed: Trimmed = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(trimmed.status, eq(Status::Malformed)); + assert_that!(trimmed.local_tail, eq(LogletOffset::OLDEST)); + assert_that!(trimmed.sealed, eq(false)); + + // The loglet has knowledge of global tail of 10, it should accept trims up to 9 but it + // won't move trim point beyond its local tail. + worker + .enqueue_trim(Incoming::for_testing( + connection.downgrade(), + Trim { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + trim_point: LogletOffset::new(9), + }, + None, + )) + .unwrap(); + + let trimmed: Trimmed = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(trimmed.status, eq(Status::Ok)); + assert_that!(trimmed.local_tail, eq(LogletOffset::OLDEST)); + assert_that!(trimmed.sealed, eq(false)); + + // let's store some records at offsets (5, 6) + worker + .enqueue_store(Incoming::for_testing( + connection.downgrade(), + Store { + // faking that offset=9 is released + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(5), + flags: StoreFlags::empty(), + payloads: vec![Record::from("record5"), Record::from("record6")].into(), + }, + None, + )) + .unwrap(); + let stored: Stored = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.local_tail, eq(LogletOffset::new(7))); + + // trim to 5 + worker + .enqueue_trim(Incoming::for_testing( + connection.downgrade(), + Trim { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + trim_point: LogletOffset::new(5), + }, + None, + )) + .unwrap(); + + let trimmed: Trimmed = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(trimmed.status, eq(Status::Ok)); + assert_that!(trimmed.local_tail, eq(LogletOffset::new(7))); + assert_that!(trimmed.sealed, eq(false)); + + // Attempt to read. 
We expect to see a trim gap (1->5, 6 (data-record)) + worker + .enqueue_get_records(Incoming::for_testing( + connection.downgrade(), + GetRecords { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + total_limit_in_bytes: None, + filter: KeyFilter::Any, + from_offset: LogletOffset::OLDEST, + // to a point beyond local tail + to_offset: LogletOffset::new(100), + }, + None, + )) + .unwrap(); + + let mut records: Records = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(records.status, eq(Status::Ok)); + assert_that!(records.local_tail, eq(LogletOffset::new(7))); + assert_that!(records.next_offset, eq(LogletOffset::new(7))); + assert_that!(records.sealed, eq(false)); + assert_that!(records.records.len(), eq(2)); + // pop() returns records in reverse order + for i in [6, 1] { let (offset, record) = records.records.pop().unwrap(); - assert_that!(offset, eq(LogletOffset::from(1))); - assert_that!(record.is_trim_gap(), eq(true)); - let gap = record.try_unwrap_trim_gap().unwrap(); - assert_that!(gap.to, eq(LogletOffset::new(6))); - - // Make sure that we can load the local-tail correctly when loading the loglet_state - let loglet_state_map = LogletStateMap::default(); - let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; - assert_that!(loglet_state.trim_point(), eq(LogletOffset::new(6))); - assert_that!(loglet_state.local_tail().offset(), eq(LogletOffset::new(7))); - - tc.shutdown_node("test completed", 0).await; - RocksDbManager::get().shutdown().await; - Ok(()) + assert_that!(offset, eq(LogletOffset::from(i))); + if i == 1 { + // this one is a trim gap + assert_that!(record.is_trim_gap(), eq(true)); + let gap = record.try_unwrap_trim_gap().unwrap(); + assert_that!(gap.to, eq(LogletOffset::new(5))); + } else { + assert_that!(record.is_data(), eq(true)); + let data = record.try_unwrap_data().unwrap(); + let original: String = data.decode().unwrap(); + assert_that!(original, eq(format!("record{}", i))); + } } - .in_tc(&tc) - .await + + // trim everything + worker + .enqueue_trim(Incoming::for_testing( + connection.downgrade(), + Trim { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + trim_point: LogletOffset::new(9), + }, + None, + )) + .unwrap(); + + let trimmed: Trimmed = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(trimmed.status, eq(Status::Ok)); + assert_that!(trimmed.local_tail, eq(LogletOffset::new(7))); + assert_that!(trimmed.sealed, eq(false)); + + // Attempt to read again. 
We expect to see a trim gap (1->6) + worker + .enqueue_get_records(Incoming::for_testing( + connection.downgrade(), + GetRecords { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + total_limit_in_bytes: None, + filter: KeyFilter::Any, + from_offset: LogletOffset::OLDEST, + // to a point beyond local tail + to_offset: LogletOffset::new(100), + }, + None, + )) + .unwrap(); + + let mut records: Records = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(records.status, eq(Status::Ok)); + assert_that!(records.local_tail, eq(LogletOffset::new(7))); + assert_that!(records.next_offset, eq(LogletOffset::new(7))); + assert_that!(records.sealed, eq(false)); + assert_that!(records.records.len(), eq(1)); + let (offset, record) = records.records.pop().unwrap(); + assert_that!(offset, eq(LogletOffset::from(1))); + assert_that!(record.is_trim_gap(), eq(true)); + let gap = record.try_unwrap_trim_gap().unwrap(); + assert_that!(gap.to, eq(LogletOffset::new(6))); + + // Make sure that we can load the local-tail correctly when loading the loglet_state + let loglet_state_map = LogletStateMap::default(); + let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; + assert_that!(loglet_state.trim_point(), eq(LogletOffset::new(6))); + assert_that!(loglet_state.local_tail().offset(), eq(LogletOffset::new(7))); + + TaskCenter::shutdown_node("test completed", 0).await; + RocksDbManager::get().shutdown().await; + Ok(()) } } diff --git a/crates/log-server/src/rocksdb_logstore/store.rs b/crates/log-server/src/rocksdb_logstore/store.rs index 6c0a43364..1c6b08290 100644 --- a/crates/log-server/src/rocksdb_logstore/store.rs +++ b/crates/log-server/src/rocksdb_logstore/store.rs @@ -497,7 +497,7 @@ mod tests { use googletest::prelude::*; use test_log::test; - use restate_core::{TaskCenter, TaskCenterBuilder, TaskCenterFutureExt}; + use restate_core::TaskCenter; use restate_rocksdb::RocksDbManager; use restate_types::config::Configuration; use restate_types::live::Live; @@ -512,33 +512,24 @@ mod tests { use crate::logstore::LogStore; use crate::metadata::LogStoreMarker; use crate::rocksdb_logstore::RocksDbLogStoreBuilder; - use crate::setup_panic_handler; - - async fn setup() -> Result<(TaskCenter, RocksDbLogStore)> { - setup_panic_handler(); - let tc = TaskCenterBuilder::default_for_tests().build()?; - let log_store = async { - let config = Live::from_value(Configuration::default()); - let common_rocks_opts = config.clone().map(|c| &c.common); - RocksDbManager::init(common_rocks_opts); - // create logstore. - let builder = RocksDbLogStoreBuilder::create( - config.clone().map(|c| &c.log_server).boxed(), - config.map(|c| &c.log_server.rocksdb).boxed(), - RecordCache::new(1_000_000), - ) - .await?; - let log_store = builder.start(Default::default()).await?; - Result::Ok(log_store) - } - .in_tc(&tc) + + async fn setup() -> Result { + let config = Live::from_value(Configuration::default()); + let common_rocks_opts = config.clone().map(|c| &c.common); + RocksDbManager::init(common_rocks_opts); + // create logstore. + let builder = RocksDbLogStoreBuilder::create( + config.clone().map(|c| &c.log_server).boxed(), + config.map(|c| &c.log_server.rocksdb).boxed(), + RecordCache::new(1_000_000), + ) .await?; - Ok((tc, log_store)) + Ok(builder.start(Default::default()).await?) 
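// Note on the reworked setup(): the explicit TaskCenterBuilder / in_tc(&tc)
// scaffolding is gone because the restate_core::test macro used by the tests
// in this module provides the ambient TaskCenter, so setup() now only builds
// and returns the log store.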
} - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn test_log_store_marker() -> Result<()> { - let (tc, log_store) = setup().await?; + let log_store = setup().await?; let marker = log_store.load_marker().await?; assert!(marker.is_none()); @@ -554,14 +545,14 @@ mod tests { let marker_again = log_store.load_marker().await?; assert_that!(marker_again, some(eq(marker))); - tc.shutdown_node("test completed", 0).await; + TaskCenter::shutdown_node("test completed", 0).await; RocksDbManager::get().shutdown().await; Ok(()) } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn test_load_loglet_state() -> Result<()> { - let (tc, log_store) = setup().await?; + let log_store = setup().await?; // fresh/unknown loglet let loglet_id_1 = ReplicatedLogletId::new_unchecked(88); let loglet_id_2 = ReplicatedLogletId::new_unchecked(89); @@ -646,14 +637,14 @@ mod tests { some(eq(&GenerationalNodeId::new(2, 212))) ); - tc.shutdown_node("test completed", 0).await; + TaskCenter::shutdown_node("test completed", 0).await; RocksDbManager::get().shutdown().await; Ok(()) } - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn test_digest() -> Result<()> { - let (tc, log_store) = setup().await?; + let log_store = setup().await?; let loglet_id_1 = ReplicatedLogletId::new_unchecked(88); let loglet_id_2 = ReplicatedLogletId::new_unchecked(89); let sequencer_1 = GenerationalNodeId::new(5, 213); @@ -812,7 +803,7 @@ mod tests { ] ); - tc.shutdown_node("test completed", 0).await; + TaskCenter::shutdown_node("test completed", 0).await; RocksDbManager::get().shutdown().await; Ok(()) } diff --git a/crates/metadata-store/src/local/tests.rs b/crates/metadata-store/src/local/tests.rs index 8111f3aaa..96fb9e82d 100644 --- a/crates/metadata-store/src/local/tests.rs +++ b/crates/metadata-store/src/local/tests.rs @@ -62,167 +62,152 @@ impl Versioned for Value { flexbuffers_storage_encode_decode!(Value); /// Tests basic operations of the metadata store. 
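// A minimal sketch of the harness migration the hunks below apply to this
// module's tests: bodies that used to be wrapped in
// `env.tc.run_in_scope("test", None, async move { .. }).await?` under
// `#[test(tokio::test(..))]` now run directly, with `#[test(restate_core::test(..))]`
// supplying the ambient TaskCenter. The test name and the trimmed-down body
// here are placeholders for illustration only.
#[test(restate_core::test(flavor = "multi_thread", worker_threads = 2))]
async fn example_basic_operation() -> anyhow::Result<()> {
    // environment creation is unchanged...
    let (client, _env) = create_test_environment(&MetadataStoreOptions::default()).await?;
    // ...but client calls no longer need a run_in_scope wrapper
    let key: ByteString = "example-key".into();
    assert!(client.get::<Value>(key.clone()).await?.is_none());
    client
        .put(key.clone(), &Value::default(), Precondition::DoesNotExist)
        .await?;
    assert!(client.get::<Value>(key).await?.is_some());
    Ok(())
}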
-#[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] +#[test(restate_core::test(flavor = "multi_thread", worker_threads = 2))] async fn basic_metadata_store_operations() -> anyhow::Result<()> { - let (client, env) = create_test_environment(&MetadataStoreOptions::default()).await?; - - env.tc - .run_in_scope("test", None, async move { - let key: ByteString = "key".into(); - let value = Value { - version: Version::MIN, - value: "test_value".to_owned(), - }; - - let next_value = Value { - version: Version::from(2), - value: "next_value".to_owned(), - }; - - let other_value = Value { - version: Version::MIN, - value: "other_value".to_owned(), - }; - - // first get should be empty - assert!(client.get::(key.clone()).await?.is_none()); - - // put initial value - client.put(key.clone(), &value, Precondition::None).await?; - - assert_eq!( - client.get_version(key.clone()).await?, - Some(value.version()) - ); - assert_eq!(client.get(key.clone()).await?, Some(value)); - - // fail to overwrite existing value - assert!(matches!( - client - .put(key.clone(), &other_value, Precondition::DoesNotExist) - .await, - Err(WriteError::FailedPrecondition(_)) - )); - - // fail to overwrite existing value with wrong version - assert!(matches!( - client - .put( - key.clone(), - &other_value, - Precondition::MatchesVersion(Version::INVALID) - ) - .await, - Err(WriteError::FailedPrecondition(_)) - )); - - // overwrite with matching version precondition - client - .put( - key.clone(), - &next_value, - Precondition::MatchesVersion(Version::MIN), - ) - .await?; - assert_eq!(client.get(key.clone()).await?, Some(next_value)); - - // try to delete value with wrong version should fail - assert!(matches!( - client - .delete(key.clone(), Precondition::MatchesVersion(Version::MIN)) - .await, - Err(WriteError::FailedPrecondition(_)) - )); - - // delete should succeed with the right precondition - client - .delete(key.clone(), Precondition::MatchesVersion(Version::from(2))) - .await?; - assert!(client.get::(key.clone()).await?.is_none()); - - // unconditional delete - client - .put(key.clone(), &other_value, Precondition::None) - .await?; - client.delete(key.clone(), Precondition::None).await?; - assert!(client.get::(key.clone()).await?.is_none()); - - Ok::<(), anyhow::Error>(()) - }) + let (client, _env) = create_test_environment(&MetadataStoreOptions::default()).await?; + + let key: ByteString = "key".into(); + let value = Value { + version: Version::MIN, + value: "test_value".to_owned(), + }; + + let next_value = Value { + version: Version::from(2), + value: "next_value".to_owned(), + }; + + let other_value = Value { + version: Version::MIN, + value: "other_value".to_owned(), + }; + + // first get should be empty + assert!(client.get::(key.clone()).await?.is_none()); + + // put initial value + client.put(key.clone(), &value, Precondition::None).await?; + + assert_eq!( + client.get_version(key.clone()).await?, + Some(value.version()) + ); + assert_eq!(client.get(key.clone()).await?, Some(value)); + + // fail to overwrite existing value + assert!(matches!( + client + .put(key.clone(), &other_value, Precondition::DoesNotExist) + .await, + Err(WriteError::FailedPrecondition(_)) + )); + + // fail to overwrite existing value with wrong version + assert!(matches!( + client + .put( + key.clone(), + &other_value, + Precondition::MatchesVersion(Version::INVALID) + ) + .await, + Err(WriteError::FailedPrecondition(_)) + )); + + // overwrite with matching version precondition + client + .put( + key.clone(), + &next_value, + 
Precondition::MatchesVersion(Version::MIN), + ) + .await?; + assert_eq!(client.get(key.clone()).await?, Some(next_value)); + + // try to delete value with wrong version should fail + assert!(matches!( + client + .delete(key.clone(), Precondition::MatchesVersion(Version::MIN)) + .await, + Err(WriteError::FailedPrecondition(_)) + )); + + // delete should succeed with the right precondition + client + .delete(key.clone(), Precondition::MatchesVersion(Version::from(2))) .await?; + assert!(client.get::(key.clone()).await?.is_none()); - env.tc.shutdown_node("shutdown", 0).await; + // unconditional delete + client + .put(key.clone(), &other_value, Precondition::None) + .await?; + client.delete(key.clone(), Precondition::None).await?; + assert!(client.get::(key.clone()).await?.is_none()); Ok(()) } /// Tests multiple concurrent operations issued by the same client -#[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] +#[test(restate_core::test(flavor = "multi_thread", worker_threads = 2))] async fn concurrent_operations() -> anyhow::Result<()> { - let (client, env) = create_test_environment(&MetadataStoreOptions::default()).await?; - - env.tc - .run_in_scope("test", None, async move { - let mut concurrent_operations = FuturesUnordered::default(); - - for key in 1u32..=10 { - for _instance in 0..key { - let client = client.clone(); - let key = ByteString::from(key.to_string()); - concurrent_operations.push(async move { - loop { - let value = client.get::(key.clone()).await?; - - let result = if let Some(value) = value { - let previous_version = value.version(); - client - .put( - key.clone(), - &value.next_version(), - Precondition::MatchesVersion(previous_version), - ) - .await - } else { - client - .put(key.clone(), &Value::default(), Precondition::DoesNotExist) - .await - }; - - match result { - Ok(()) => return Ok::<(), anyhow::Error>(()), - Err(WriteError::FailedPrecondition(_)) => continue, - Err(err) => return Err(err.into()), - } - } - }); + let (client, _env) = create_test_environment(&MetadataStoreOptions::default()).await?; + + let mut concurrent_operations = FuturesUnordered::default(); + + for key in 1u32..=10 { + for _instance in 0..key { + let client = client.clone(); + let key = ByteString::from(key.to_string()); + concurrent_operations.push(async move { + loop { + let value = client.get::(key.clone()).await?; + + let result = if let Some(value) = value { + let previous_version = value.version(); + client + .put( + key.clone(), + &value.next_version(), + Precondition::MatchesVersion(previous_version), + ) + .await + } else { + client + .put(key.clone(), &Value::default(), Precondition::DoesNotExist) + .await + }; + + match result { + Ok(()) => return Ok::<(), anyhow::Error>(()), + Err(WriteError::FailedPrecondition(_)) => continue, + Err(err) => return Err(err.into()), + } } - } - - while let Some(result) = concurrent_operations.next().await { - result?; - } + }); + } + } - // sanity check - for key in 1u32..=10 { - let metadata_key = ByteString::from(key.to_string()); - let value = client - .get::(metadata_key) - .await? - .map(|v| v.version()); + while let Some(result) = concurrent_operations.next().await { + result?; + } - assert_eq!(value, Some(Version::from(key))); - } + // sanity check + for key in 1u32..=10 { + let metadata_key = ByteString::from(key.to_string()); + let value = client + .get::(metadata_key) + .await? 
+ .map(|v| v.version()); - Ok::<(), anyhow::Error>(()) - }) - .await?; + assert_eq!(value, Some(Version::from(key))); + } - env.tc.shutdown_node("shutdown", 0).await; Ok(()) } /// Tests that the metadata store stores values durably so that they can be read after a restart. -#[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] +#[test(restate_core::test(flavor = "multi_thread", worker_threads = 2))] async fn durable_storage() -> anyhow::Result<()> { // get current base dir and use this for subsequent tests. let base_path = reset_base_temp_dir_and_retain(); @@ -231,34 +216,26 @@ async fn durable_storage() -> anyhow::Result<()> { assert!(base_path.starts_with(tmp)); assert_eq!(base_path.join("local-metadata-store"), opts.data_dir()); - let (client, env) = create_test_environment(&opts).await?; + let (client, _env) = create_test_environment(&opts).await?; // write data - env.tc - .run_in_scope("write-data", None, async move { - for key in 1u32..=10 { - let value = key.to_string(); - let metadata_key = ByteString::from(value.clone()); - client - .put( - metadata_key, - &Value { - version: Version::from(key), - value, - }, - Precondition::DoesNotExist, - ) - .await?; - } - - Ok::<(), anyhow::Error>(()) - }) - .await?; + for key in 1u32..=10 { + let value = key.to_string(); + let metadata_key = ByteString::from(value.clone()); + client + .put( + metadata_key, + &Value { + version: Version::from(key), + value, + }, + Precondition::DoesNotExist, + ) + .await?; + } // restart the metadata store - env.tc - .cancel_tasks(Some(TaskKind::MetadataStore), None) - .await; + TaskCenter::cancel_tasks(Some(TaskKind::MetadataStore), None).await; // reset RocksDbManager to allow restarting the metadata store RocksDbManager::get().reset().await?; @@ -277,31 +254,24 @@ async fn durable_storage() -> anyhow::Result<()> { metadata_store_client_opts, metadata_store_opts.clone().boxed(), metadata_store_opts.map(|c| &c.rocksdb).boxed(), - &env.tc, ) .await?; // validate data - env.tc - .run_in_scope("validate-data", None, async move { - for key in 1u32..=10 { - let value = key.to_string(); - let metadata_key = ByteString::from(value.clone()); - - assert_eq!( - client.get(metadata_key).await?, - Some(Value { - version: Version::from(key), - value - }) - ); - } - - Ok::<(), anyhow::Error>(()) - }) - .await?; + for key in 1u32..=10 { + let value = key.to_string(); + let metadata_key = ByteString::from(value.clone()); + + assert_eq!( + client.get(metadata_key).await?, + Some(Value { + version: Version::from(key), + value + }) + ); + } - env.tc.shutdown_node("shutdown", 0).await; + TaskCenter::shutdown_node("shutdown", 0).await; std::fs::remove_dir_all(base_path)?; Ok(()) } @@ -329,15 +299,12 @@ async fn create_test_environment( .build() .await; - let task_center = &env.tc; - - task_center.run_in_scope_sync(|| RocksDbManager::init(config.clone().map(|c| &c.common))); + RocksDbManager::init(config.clone().map(|c| &c.common)); let client = start_metadata_store( config.pinned().common.metadata_store_client.clone(), config.clone().map(|c| &c.metadata_store).boxed(), config.clone().map(|c| &c.metadata_store.rocksdb).boxed(), - task_center, ) .await?; @@ -348,7 +315,6 @@ async fn start_metadata_store( metadata_store_client_options: MetadataStoreClientOptions, opts: BoxedLiveLoad, updateables_rocksdb_options: BoxedLiveLoad, - task_center: &TaskCenter, ) -> anyhow::Result { let health_status = HealthStatus::default(); let service = LocalMetadataStoreService::from_options( @@ -357,10 +323,9 @@ async fn 
start_metadata_store( updateables_rocksdb_options, ); - task_center.spawn( + TaskCenter::spawn( TaskKind::MetadataStore, "local-metadata-store", - None, async move { service.run().await?; Ok(()) diff --git a/crates/node/Cargo.toml b/crates/node/Cargo.toml index 739b8e962..82dd87ea0 100644 --- a/crates/node/Cargo.toml +++ b/crates/node/Cargo.toml @@ -9,7 +9,7 @@ publish = false [features] default = [] -memory-loglet = ["restate-bifrost/memory-loglet"] +memory-loglet = ["restate-bifrost/memory-loglet", "restate-admin/memory-loglet"] replicated-loglet = ["restate-bifrost/replicated-loglet"] options_schema = [ "dep:schemars", diff --git a/crates/node/src/lib.rs b/crates/node/src/lib.rs index 29ab5e91a..4ecc8caed 100644 --- a/crates/node/src/lib.rs +++ b/crates/node/src/lib.rs @@ -23,11 +23,11 @@ use restate_core::network::{ GrpcConnector, MessageRouterBuilder, NetworkServerBuilder, Networking, }; use restate_core::partitions::{spawn_partition_routing_refresher, PartitionRoutingRefresher}; +use restate_core::TaskKind; use restate_core::{ spawn_metadata_manager, MetadataBuilder, MetadataKind, MetadataManager, TargetVersion, TaskCenter, }; -use restate_core::{task_center, TaskKind}; #[cfg(feature = "replicated-loglet")] use restate_log_server::LogServerService; use restate_metadata_store::local::LocalMetadataStoreService; @@ -122,7 +122,6 @@ pub struct Node { impl Node { pub async fn create(updateable_config: Live) -> Result { - let tc = task_center(); let health = Health::default(); health.node_status().update(NodeStatus::StartingUp); let mut server_builder = NetworkServerBuilder::default(); @@ -172,7 +171,6 @@ impl Node { // replicated-loglet #[cfg(feature = "replicated-loglet")] let replicated_loglet_factory = restate_bifrost::providers::replicated_loglet::Factory::new( - tc.clone(), metadata_store_client.clone(), networking.clone(), record_cache.clone(), @@ -256,12 +254,11 @@ impl Node { Some( AdminRole::create( health.admin_status(), - tc.clone(), bifrost.clone(), updateable_config.clone(), - metadata, partition_routing_refresher.partition_routing(), networking.clone(), + metadata, metadata_manager.writer(), &mut server_builder, &mut router_builder, @@ -310,15 +307,12 @@ impl Node { } pub async fn start(mut self) -> Result<(), anyhow::Error> { - let tc = task_center(); - let config = self.updateable_config.pinned(); if let Some(metadata_store) = self.metadata_store_role { - tc.spawn( + TaskCenter::spawn( TaskKind::MetadataStore, "local-metadata-store", - None, async move { metadata_store.run().await?; Ok(()) @@ -335,7 +329,7 @@ impl Node { spawn_metadata_manager(self.metadata_manager)?; // Start partition routing information refresher - spawn_partition_routing_refresher(&tc, self.partition_routing_refresher)?; + spawn_partition_routing_refresher(self.partition_routing_refresher)?; let nodes_config = Self::upsert_node_config(&self.metadata_store_client, &config.common).await?; @@ -403,21 +397,21 @@ impl Node { let admin_node_id = admin_node.current_generation; let networking = self.networking.clone(); - tc.spawn_unmanaged(TaskKind::Disposable, "announce-node-at-admin-node", None, async move { - if let Err(err) = networking - .node_connection(admin_node_id) - .await - { - info!("Failed connecting to admin node '{admin_node_id}' and announcing myself. 
This can indicate network problems: {err}"); - } - })?; + TaskCenter::spawn_unmanaged( + TaskKind::Disposable, + "announce-node-at-admin-node", + async move { + if let Err(err) = networking.node_connection(admin_node_id).await { + info!("Failed connecting to admin node '{admin_node_id}' and announcing myself. This can indicate network problems: {err}"); + } + }, + )?; } } // Ensures bifrost has initial metadata synced up before starting the worker. // Need to run start in new tc scope to have access to metadata() - tc.run_in_scope("bifrost-init", None, self.bifrost.start()) - .await?; + self.bifrost.start().await?; #[cfg(feature = "replicated-loglet")] if let Some(log_server) = self.log_server { @@ -427,23 +421,18 @@ impl Node { } if let Some(admin_role) = self.admin_role { - tc.spawn(TaskKind::SystemBoot, "admin-init", None, admin_role.start())?; + TaskCenter::spawn(TaskKind::SystemBoot, "admin-init", admin_role.start())?; } if let Some(worker_role) = self.worker_role { - tc.spawn( - TaskKind::SystemBoot, - "worker-init", - None, - worker_role.start(), - )?; + TaskCenter::spawn(TaskKind::SystemBoot, "worker-init", worker_role.start())?; } if let Some(ingress_role) = self.ingress_role { TaskCenter::spawn_child(TaskKind::Ingress, "ingress-http", ingress_role.run())?; } - tc.spawn(TaskKind::RpcServer, "node-rpc-server", None, { + TaskCenter::spawn(TaskKind::RpcServer, "node-rpc-server", { let health = self.health.clone(); let common_options = config.common.clone(); let connection_manager = self.networking.connection_manager().clone(); @@ -463,7 +452,7 @@ impl Node { let my_roles = my_node_config.roles; // Report that the node is running when all roles are ready - let _ = tc.spawn(TaskKind::Disposable, "status-report", None, async move { + let _ = TaskCenter::spawn(TaskKind::Disposable, "status-report", async move { self.health .node_status() .wait_for_value(NodeStatus::Alive) diff --git a/crates/node/src/network_server/grpc_svc_handler.rs b/crates/node/src/network_server/grpc_svc_handler.rs index 5ff24a22d..9e3fb26cf 100644 --- a/crates/node/src/network_server/grpc_svc_handler.rs +++ b/crates/node/src/network_server/grpc_svc_handler.rs @@ -11,7 +11,6 @@ use bytes::BytesMut; use enumset::EnumSet; use futures::stream::BoxStream; -use restate_types::storage::StorageCodec; use tokio_stream::StreamExt; use tonic::{Request, Response, Status, Streaming}; @@ -21,15 +20,15 @@ use restate_core::network::protobuf::node_svc::{ }; use restate_core::network::ConnectionManager; use restate_core::network::{ProtocolError, TransportConnect}; -use restate_core::{ - metadata, MetadataKind, TargetVersion, TaskCenter, TaskCenterFutureExt, TaskKind, -}; +use restate_core::task_center::TaskCenterMonitoring; +use restate_core::{task_center, Metadata, MetadataKind, TargetVersion}; use restate_types::health::Health; use restate_types::nodes_config::Role; use restate_types::protobuf::node::Message; +use restate_types::storage::StorageCodec; pub struct NodeSvcHandler { - task_center: TaskCenter, + task_center: task_center::Handle, cluster_name: String, roles: EnumSet, health: Health, @@ -38,7 +37,7 @@ pub struct NodeSvcHandler { impl NodeSvcHandler { pub fn new( - task_center: TaskCenter, + task_center: task_center::Handle, cluster_name: String, roles: EnumSet, health: Health, @@ -63,24 +62,22 @@ impl NodeSvc for NodeSvcHandler { let metadata_server_status = self.health.current_metadata_server_status(); let log_server_status = self.health.current_log_server_status(); let age_s = self.task_center.age().as_secs(); - 
self.task_center.run_in_scope_sync(|| { - let metadata = metadata(); - Ok(Response::new(IdentResponse { - status: node_status.into(), - node_id: Some(metadata.my_node_id().into()), - roles: self.roles.iter().map(|r| r.to_string()).collect(), - cluster_name: self.cluster_name.clone(), - age_s, - admin_status: admin_status.into(), - worker_status: worker_status.into(), - metadata_server_status: metadata_server_status.into(), - log_server_status: log_server_status.into(), - nodes_config_version: metadata.nodes_config_version().into(), - logs_version: metadata.logs_version().into(), - schema_version: metadata.schema_version().into(), - partition_table_version: metadata.partition_table_version().into(), - })) - }) + let metadata = Metadata::current(); + Ok(Response::new(IdentResponse { + status: node_status.into(), + node_id: Some(metadata.my_node_id().into()), + roles: self.roles.iter().map(|r| r.to_string()).collect(), + cluster_name: self.cluster_name.clone(), + age_s, + admin_status: admin_status.into(), + worker_status: worker_status.into(), + metadata_server_status: metadata_server_status.into(), + log_server_status: log_server_status.into(), + nodes_config_version: metadata.nodes_config_version().into(), + logs_version: metadata.logs_version().into(), + schema_version: metadata.schema_version().into(), + partition_table_version: metadata.partition_table_version().into(), + })) } type CreateConnectionStream = BoxStream<'static, Result>; @@ -102,7 +99,6 @@ impl NodeSvc for NodeSvcHandler { let output_stream = self .connections .accept_incoming_connection(transformed) - .in_current_tc_as_task(TaskKind::InPlace, "accept-connection") .await?; // For uniformity with outbound connections, we map all responses to Ok, we never rely on @@ -116,7 +112,7 @@ impl NodeSvc for NodeSvcHandler { request: Request, ) -> Result, Status> { let request = request.into_inner(); - let metadata = metadata(); + let metadata = Metadata::current(); let kind = request.kind.into(); if request.sync { metadata diff --git a/crates/node/src/network_server/metrics.rs b/crates/node/src/network_server/metrics.rs index 889daa1eb..ae91f5faf 100644 --- a/crates/node/src/network_server/metrics.rs +++ b/crates/node/src/network_server/metrics.rs @@ -18,6 +18,7 @@ use metrics_util::layers::Layer; use metrics_util::MetricKindMask; use rocksdb::statistics::{Histogram, Ticker}; +use restate_core::task_center::TaskCenterMonitoring; use restate_rocksdb::{CfName, RocksDbManager}; use restate_types::config::CommonOptions; diff --git a/crates/node/src/network_server/service.rs b/crates/node/src/network_server/service.rs index 13125351c..9fe9cb863 100644 --- a/crates/node/src/network_server/service.rs +++ b/crates/node/src/network_server/service.rs @@ -13,7 +13,7 @@ use tonic::codec::CompressionEncoding; use restate_core::network::protobuf::node_svc::node_svc_server::NodeSvcServer; use restate_core::network::{ConnectionManager, NetworkServerBuilder, TransportConnect}; -use restate_core::task_center; +use restate_core::TaskCenter; use restate_types::config::CommonOptions; use restate_types::health::Health; @@ -31,10 +31,9 @@ impl NetworkServer { mut server_builder: NetworkServerBuilder, options: CommonOptions, ) -> Result<(), anyhow::Error> { - let tc = task_center(); // Configure Metric Exporter let mut state_builder = NodeCtrlHandlerStateBuilder::default(); - state_builder.task_center(tc.clone()); + state_builder.task_center(TaskCenter::current()); if !options.disable_prometheus { 
state_builder.prometheus_handle(Some(install_global_prometheus_recorder(&options))); @@ -51,7 +50,7 @@ impl NetworkServer { server_builder.register_grpc_service( NodeSvcServer::new(NodeSvcHandler::new( - tc, + TaskCenter::current(), options.cluster_name().to_owned(), options.roles, health, diff --git a/crates/node/src/network_server/state.rs b/crates/node/src/network_server/state.rs index b82b13352..b9d9f0687 100644 --- a/crates/node/src/network_server/state.rs +++ b/crates/node/src/network_server/state.rs @@ -9,11 +9,11 @@ // by the Apache License, Version 2.0. use metrics_exporter_prometheus::PrometheusHandle; -use restate_core::TaskCenter; +use restate_core::task_center; #[derive(Clone, derive_builder::Builder)] pub struct NodeCtrlHandlerState { #[builder(default)] pub prometheus_handle: Option, - pub task_center: TaskCenter, + pub task_center: task_center::Handle, } diff --git a/crates/node/src/roles/admin.rs b/crates/node/src/roles/admin.rs index d2fb4a7eb..b06ed3eec 100644 --- a/crates/node/src/roles/admin.rs +++ b/crates/node/src/roles/admin.rs @@ -62,12 +62,11 @@ impl AdminRole { #[allow(clippy::too_many_arguments)] pub async fn create( health_status: HealthStatus, - task_center: TaskCenter, bifrost: Bifrost, updateable_config: Live, - metadata: Metadata, partition_routing: PartitionRouting, networking: Networking, + metadata: Metadata, metadata_writer: MetadataWriter, server_builder: &mut NetworkServerBuilder, router_builder: &mut MessageRouterBuilder, @@ -87,18 +86,14 @@ impl AdminRole { query_context } else { let remote_scanner_manager = RemoteScannerManager::new( - create_remote_scanner_service( - networking.clone(), - task_center.clone(), - router_builder, - ), + create_remote_scanner_service(networking.clone(), router_builder), create_partition_locator(partition_routing, metadata.clone()), ); // need to create a remote query context since we are not co-located with a worker role QueryContext::create( &config.admin.query_engine, - SelectPartitionsFromMetadata::new(metadata.clone()), + SelectPartitionsFromMetadata, None, Option::::None, metadata.updateable_schema(), @@ -122,7 +117,6 @@ impl AdminRole { updateable_config.clone(), health_status, bifrost, - metadata, networking, router_builder, server_builder, diff --git a/crates/partition-store/Cargo.toml b/crates/partition-store/Cargo.toml index 0df225cee..12148d98c 100644 --- a/crates/partition-store/Cargo.toml +++ b/crates/partition-store/Cargo.toml @@ -39,7 +39,7 @@ static_assertions = { workspace = true } strum = { workspace = true } sync_wrapper = { workspace = true } thiserror = { workspace = true } -tokio = { workspace = true } +tokio = { workspace = true, features = ["fs"] } tokio-stream = { workspace = true } tracing = { workspace = true } diff --git a/crates/partition-store/benches/basic_benchmark.rs b/crates/partition-store/benches/basic_benchmark.rs index 135de7b39..50623738e 100644 --- a/crates/partition-store/benches/basic_benchmark.rs +++ b/crates/partition-store/benches/basic_benchmark.rs @@ -44,10 +44,11 @@ fn basic_writing_reading_benchmark(c: &mut Criterion) { let tc = TaskCenterBuilder::default() .default_runtime_handle(rt.handle().clone()) .build() - .expect("task_center builds"); + .expect("task_center builds") + .to_handle(); let worker_options = WorkerOptions::default(); - tc.run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); + tc.run_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); let rocksdb = tc.block_on(async { // // setup diff --git 
a/crates/partition-store/src/tests/idempotency_table_test/mod.rs b/crates/partition-store/src/tests/idempotency_table_test/mod.rs index 258322c9c..da7ca3d34 100644 --- a/crates/partition-store/src/tests/idempotency_table_test/mod.rs +++ b/crates/partition-store/src/tests/idempotency_table_test/mod.rs @@ -31,7 +31,7 @@ const IDEMPOTENCY_ID_2: IdempotencyId = const IDEMPOTENCY_ID_3: IdempotencyId = IdempotencyId::unkeyed(10, "my-component", "my-handler-2", "my-key"); -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[restate_core::test(flavor = "multi_thread", worker_threads = 2)] async fn test_idempotency_key() { let mut rocksdb = storage_test_environment().await; diff --git a/crates/partition-store/src/tests/invocation_status_table_test/mod.rs b/crates/partition-store/src/tests/invocation_status_table_test/mod.rs index 055861254..aadd219e4 100644 --- a/crates/partition-store/src/tests/invocation_status_table_test/mod.rs +++ b/crates/partition-store/src/tests/invocation_status_table_test/mod.rs @@ -174,7 +174,7 @@ async fn verify_all_svc_with_status_invoked(txn: &mut ); } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[restate_core::test(flavor = "multi_thread", worker_threads = 2)] async fn test_invocation_status() { let mut rocksdb = storage_test_environment().await; let mut txn = rocksdb.transaction(); @@ -184,7 +184,7 @@ async fn test_invocation_status() { verify_all_svc_with_status_invoked(&mut txn).await; } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[restate_core::test(flavor = "multi_thread", worker_threads = 2)] async fn test_migration() { let mut rocksdb = storage_test_environment().await; diff --git a/crates/partition-store/src/tests/journal_table_test/mod.rs b/crates/partition-store/src/tests/journal_table_test/mod.rs index c63ad3f59..bf32914d6 100644 --- a/crates/partition-store/src/tests/journal_table_test/mod.rs +++ b/crates/partition-store/src/tests/journal_table_test/mod.rs @@ -129,7 +129,7 @@ async fn verify_journal_deleted(txn: &mut T) { } } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[restate_core::test(flavor = "multi_thread", worker_threads = 2)] async fn journal_tests() { let mut rocksdb = storage_test_environment().await; diff --git a/crates/partition-store/src/tests/mod.rs b/crates/partition-store/src/tests/mod.rs index b986a3672..b4c095231 100644 --- a/crates/partition-store/src/tests/mod.rs +++ b/crates/partition-store/src/tests/mod.rs @@ -17,7 +17,6 @@ use futures::Stream; use tokio_stream::StreamExt; use crate::{OpenMode, PartitionStore, PartitionStoreManager}; -use restate_core::TaskCenterBuilder; use restate_rocksdb::RocksDbManager; use restate_storage_api::StorageError; use restate_types::config::{CommonOptions, WorkerOptions}; @@ -47,12 +46,7 @@ async fn storage_test_environment_with_manager() -> (PartitionStoreManager, Part // // create a rocksdb storage from options // - let tc = TaskCenterBuilder::default() - .default_runtime_handle(tokio::runtime::Handle::current()) - .ingress_runtime_handle(tokio::runtime::Handle::current()) - .build() - .expect("task_center builds"); - tc.run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); + RocksDbManager::init(Constant::new(CommonOptions::default())); let worker_options = Live::from_value(WorkerOptions::default()); let manager = PartitionStoreManager::create( worker_options.clone().map(|c| &c.storage), @@ -75,7 +69,7 @@ async fn storage_test_environment_with_manager() -> (PartitionStoreManager, Part (manager, store) } 
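// --- Illustrative sketch, not part of this patch -----------------------------
// The attribute swaps below rely on #[restate_core::test] providing both the
// Tokio runtime and an ambient TaskCenter, which is why the helper above can
// call RocksDbManager::init(..) directly instead of building a TaskCenter by
// hand. Assuming that macro behaviour, a minimal test against the helper
// (hypothetical name and body) would look like:
//
// #[restate_core::test(flavor = "multi_thread", worker_threads = 2)]
// async fn open_store_smoke_test() {
//     let mut store = storage_test_environment().await;
//     let txn = store.transaction();
//     drop(txn); // nothing written; just proves the store opens under the macro
// }
// ------------------------------------------------------------------------------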
-#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[restate_core::test(flavor = "multi_thread", worker_threads = 2)] async fn test_read_write() { let (manager, store) = storage_test_environment_with_manager().await; diff --git a/crates/partition-store/src/tests/promise_table_test/mod.rs b/crates/partition-store/src/tests/promise_table_test/mod.rs index 49129c4cf..b88544932 100644 --- a/crates/partition-store/src/tests/promise_table_test/mod.rs +++ b/crates/partition-store/src/tests/promise_table_test/mod.rs @@ -34,7 +34,7 @@ const PROMISE_COMPLETED: Promise = Promise { state: PromiseState::Completed(EntryResult::Success(Bytes::from_static(b"{}"))), }; -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[restate_core::test(flavor = "multi_thread", worker_threads = 2)] async fn test_promise_table() { let mut rocksdb = storage_test_environment().await; diff --git a/crates/partition-store/src/tests/state_table_test/mod.rs b/crates/partition-store/src/tests/state_table_test/mod.rs index 2384074bd..1603c6831 100644 --- a/crates/partition-store/src/tests/state_table_test/mod.rs +++ b/crates/partition-store/src/tests/state_table_test/mod.rs @@ -111,7 +111,7 @@ pub(crate) async fn run_tests(mut rocksdb: PartitionStore) { verify_prefix_scan_after_delete(&mut txn).await; } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[restate_core::test(flavor = "multi_thread", worker_threads = 2)] async fn test_delete_all() { let mut rocksdb = storage_test_environment().await; diff --git a/crates/rocksdb/src/db_manager.rs b/crates/rocksdb/src/db_manager.rs index 0fa04692e..34c9467d5 100644 --- a/crates/rocksdb/src/db_manager.rs +++ b/crates/rocksdb/src/db_manager.rs @@ -18,7 +18,7 @@ use rocksdb::{BlockBasedOptions, Cache, LogLevel, WriteBufferManager}; use tokio::sync::mpsc; use tracing::{debug, info, warn}; -use restate_core::{cancellation_watcher, task_center, ShutdownError, TaskKind}; +use restate_core::{cancellation_watcher, ShutdownError, TaskCenter, TaskKind}; use restate_serde_util::ByteCount; use restate_types::config::{ CommonOptions, Configuration, RocksDbLogLevel, RocksDbOptions, StatisticsLevel, @@ -121,14 +121,12 @@ impl RocksDbManager { DB_MANAGER.set(manager).expect("DBManager initialized once"); // Start db monitoring. - task_center() - .spawn( - TaskKind::SystemService, - "db-manager", - None, - DbWatchdog::run(Self::get(), watchdog_rx, base_opts), - ) - .expect("run db watchdog"); + TaskCenter::spawn( + TaskKind::SystemService, + "db-manager", + DbWatchdog::run(Self::get(), watchdog_rx, base_opts), + ) + .expect("run db watchdog"); Self::get() } diff --git a/crates/storage-query-datafusion/src/context.rs b/crates/storage-query-datafusion/src/context.rs index 0a98c1b9c..8718cf830 100644 --- a/crates/storage-query-datafusion/src/context.rs +++ b/crates/storage-query-datafusion/src/context.rs @@ -323,25 +323,13 @@ impl AsRef for QueryContext { /// Newtype to add debug implementation which is required for [`SelectPartitions`]. 
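// --- Illustrative sketch, not part of this patch -----------------------------
// The change just below turns SelectPartitionsFromMetadata into a unit struct:
// instead of carrying a Metadata value, it reads the ambient one. The two
// access shapes used throughout this diff are Metadata::current(), which hands
// back a usable handle, and Metadata::with_current(|m| ...), which borrows it
// for a single read. A hypothetical helper in the same style:
//
// fn live_partition_count() -> usize {
//     Metadata::with_current(|m| m.partition_table_ref().partition_ids().count())
// }
// ------------------------------------------------------------------------------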
#[derive(Clone, derive_more::Debug)] -pub struct SelectPartitionsFromMetadata { - #[debug(skip)] - metadata: Metadata, -} - -impl SelectPartitionsFromMetadata { - pub fn new(metadata: Metadata) -> Self { - Self { metadata } - } -} +pub struct SelectPartitionsFromMetadata; #[async_trait] impl SelectPartitions for SelectPartitionsFromMetadata { async fn get_live_partitions(&self) -> Result, GenericError> { - Ok(self - .metadata - .partition_table_ref() - .partition_ids() - .cloned() - .collect()) + Ok(Metadata::with_current(|m| { + m.partition_table_ref().partition_ids().cloned().collect() + })) } } diff --git a/crates/storage-query-datafusion/src/idempotency/tests.rs b/crates/storage-query-datafusion/src/idempotency/tests.rs index c8733dd98..b725bf41f 100644 --- a/crates/storage-query-datafusion/src/idempotency/tests.rs +++ b/crates/storage-query-datafusion/src/idempotency/tests.rs @@ -16,20 +16,13 @@ use datafusion::arrow::record_batch::RecordBatch; use futures::StreamExt; use googletest::all; use googletest::prelude::{assert_that, eq}; -use restate_core::TaskCenterBuilder; use restate_storage_api::idempotency_table::{IdempotencyMetadata, IdempotencyTable}; use restate_storage_api::Transaction; use restate_types::identifiers::{IdempotencyId, InvocationId}; -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[restate_core::test(flavor = "multi_thread", worker_threads = 2)] async fn get_idempotency_key() { - let tc = TaskCenterBuilder::default() - .default_runtime_handle(tokio::runtime::Handle::current()) - .build() - .expect("task_center builds"); - let mut engine = tc - .run_in_scope("mock-query-engine", None, MockQueryEngine::create()) - .await; + let mut engine = MockQueryEngine::create().await; let mut tx = engine.partition_store().transaction(); let invocation_id_1 = InvocationId::mock_random(); diff --git a/crates/storage-query-datafusion/src/inbox/tests.rs b/crates/storage-query-datafusion/src/inbox/tests.rs index 005bac25c..5b8c15f1a 100644 --- a/crates/storage-query-datafusion/src/inbox/tests.rs +++ b/crates/storage-query-datafusion/src/inbox/tests.rs @@ -15,21 +15,14 @@ use datafusion::arrow::record_batch::RecordBatch; use futures::StreamExt; use googletest::all; use googletest::prelude::{assert_that, eq}; -use restate_core::TaskCenterBuilder; use restate_storage_api::inbox_table::{InboxEntry, InboxTable}; use restate_storage_api::Transaction; use restate_types::identifiers::InvocationId; use restate_types::invocation::InvocationTarget; -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[restate_core::test(flavor = "multi_thread", worker_threads = 2)] async fn get_inbox() { - let tc = TaskCenterBuilder::default() - .default_runtime_handle(tokio::runtime::Handle::current()) - .build() - .expect("task_center builds"); - let mut engine = tc - .run_in_scope("mock-query-engine", None, MockQueryEngine::create()) - .await; + let mut engine = MockQueryEngine::create().await; let mut tx = engine.partition_store().transaction(); let invocation_target = InvocationTarget::mock_virtual_object(); diff --git a/crates/storage-query-datafusion/src/journal/tests.rs b/crates/storage-query-datafusion/src/journal/tests.rs index 53b00e954..614529d90 100644 --- a/crates/storage-query-datafusion/src/journal/tests.rs +++ b/crates/storage-query-datafusion/src/journal/tests.rs @@ -17,7 +17,6 @@ use futures::StreamExt; use googletest::all; use googletest::prelude::{assert_that, eq}; use prost::Message; -use restate_core::TaskCenterBuilder; use 
restate_service_protocol::codec::ProtobufRawEntryCodec; use restate_storage_api::journal_table::{JournalEntry, JournalTable}; use restate_storage_api::Transaction; @@ -29,15 +28,9 @@ use restate_types::journal::enriched::{ use restate_types::journal::{Entry, EntryType, InputEntry}; use restate_types::service_protocol; -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[restate_core::test(flavor = "multi_thread", worker_threads = 2)] async fn get_entries() { - let tc = TaskCenterBuilder::default() - .default_runtime_handle(tokio::runtime::Handle::current()) - .build() - .expect("task_center builds"); - let mut engine = tc - .run_in_scope("mock-query-engine", None, MockQueryEngine::create()) - .await; + let mut engine = MockQueryEngine::create().await; let mut tx = engine.partition_store().transaction(); let journal_invocation_id = InvocationId::mock_random(); @@ -131,15 +124,9 @@ async fn get_entries() { ); } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[restate_core::test(flavor = "multi_thread", worker_threads = 2)] async fn select_count_star() { - let tc = TaskCenterBuilder::default() - .default_runtime_handle(tokio::runtime::Handle::current()) - .build() - .expect("task_center builds"); - let mut engine = tc - .run_in_scope("mock-query-engine", None, MockQueryEngine::create()) - .await; + let mut engine = MockQueryEngine::create().await; let mut tx = engine.partition_store().transaction(); let journal_invocation_id = InvocationId::mock_random(); diff --git a/crates/storage-query-datafusion/src/mocks.rs b/crates/storage-query-datafusion/src/mocks.rs index 46428b066..a4fe33a65 100644 --- a/crates/storage-query-datafusion/src/mocks.rs +++ b/crates/storage-query-datafusion/src/mocks.rs @@ -19,7 +19,6 @@ use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::DataFusionError; use datafusion::execution::SendableRecordBatchStream; use googletest::matcher::{Matcher, MatcherResult}; -use restate_core::task_center; use restate_invoker_api::status_handle::test_util::MockStatusHandle; use restate_invoker_api::StatusHandle; use restate_partition_store::{OpenMode, PartitionStore, PartitionStoreManager}; @@ -161,8 +160,7 @@ impl MockQueryEngine { + 'static, ) -> Self { // Prepare Rocksdb - task_center() - .run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); + RocksDbManager::init(Constant::new(CommonOptions::default())); let worker_options = Live::from_value(WorkerOptions::default()); let manager = PartitionStoreManager::create( worker_options.clone().map(|c| &c.storage), diff --git a/crates/storage-query-datafusion/src/remote_query_scanner_client.rs b/crates/storage-query-datafusion/src/remote_query_scanner_client.rs index 4794b43b3..916da0c8f 100644 --- a/crates/storage-query-datafusion/src/remote_query_scanner_client.rs +++ b/crates/storage-query-datafusion/src/remote_query_scanner_client.rs @@ -12,22 +12,24 @@ use std::fmt::{Debug, Formatter}; use std::ops::RangeInclusive; use std::sync::Arc; -use crate::{decode_record_batch, encode_schema}; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::error::DataFusionError; use datafusion::execution::SendableRecordBatchStream; use datafusion::physical_plan::stream::RecordBatchReceiverStream; +use tracing::warn; + use restate_core::network::rpc_router::RpcRouter; use restate_core::network::{Incoming, MessageRouterBuilder, Networking, TransportConnect}; -use restate_core::TaskCenter; +use restate_core::{task_center, TaskCenter, 
TaskCenterFutureExt, TaskKind}; use restate_types::identifiers::{PartitionId, PartitionKey}; use restate_types::net::remote_query_scanner::{ RemoteQueryScannerClose, RemoteQueryScannerClosed, RemoteQueryScannerNext, RemoteQueryScannerNextResult, RemoteQueryScannerOpen, RemoteQueryScannerOpened, }; use restate_types::NodeId; -use tracing::warn; + +use crate::{decode_record_batch, encode_schema}; // ----- rpc service definition ----- @@ -55,12 +57,11 @@ pub trait RemoteScannerService: Send + Sync + Debug + 'static { // ----- service proxy ----- pub fn create_remote_scanner_service( network: Networking, - task_center: TaskCenter, router_builder: &mut MessageRouterBuilder, ) -> Arc { Arc::new(RemoteScannerServiceProxy::new( network, - task_center, + TaskCenter::current(), router_builder, )) } @@ -165,7 +166,7 @@ pub fn remote_scan_as_datafusion_stream( #[derive(Clone)] struct RemoteScannerServiceProxy { networking: Networking, - task_center: TaskCenter, + task_center: task_center::Handle, open_rpc: RpcRouter, next_rpc: RpcRouter, close_rpc: RpcRouter, @@ -180,7 +181,7 @@ impl Debug for RemoteScannerServiceProxy { impl RemoteScannerServiceProxy { fn new( networking: Networking, - task_center: TaskCenter, + task_center: task_center::Handle, router_builder: &mut MessageRouterBuilder, ) -> Self { Self { @@ -200,15 +201,16 @@ impl RemoteScannerService for RemoteScannerServiceProxy peer: NodeId, req: RemoteQueryScannerOpen, ) -> Result { - self.task_center - .run_in_scope("RemoteScannerServiceProxy::open", None, async { - self.open_rpc - .call(&self.networking, peer, req) - .await - .map_err(|e| DataFusionError::External(e.into())) - .map(Incoming::into_body) - }) + self.open_rpc + .call(&self.networking, peer, req) + .in_tc_as_task( + &self.task_center, + TaskKind::InPlace, + "RemoteScannerServiceProxy::open", + ) .await + .map_err(|e| DataFusionError::External(e.into())) + .map(Incoming::into_body) } async fn next_batch( @@ -216,15 +218,16 @@ impl RemoteScannerService for RemoteScannerServiceProxy peer: NodeId, req: RemoteQueryScannerNext, ) -> Result { - self.task_center - .run_in_scope("RemoteScannerServiceProxy::next_batch", None, async { - self.next_rpc - .call(&self.networking, peer, req) - .await - .map_err(|e| DataFusionError::External(e.into())) - .map(Incoming::into_body) - }) + self.next_rpc + .call(&self.networking, peer, req) + .in_tc_as_task( + &self.task_center, + TaskKind::InPlace, + "RemoteScannerServiceProxy::next_batch", + ) .await + .map_err(|e| DataFusionError::External(e.into())) + .map(Incoming::into_body) } async fn close( @@ -232,14 +235,15 @@ impl RemoteScannerService for RemoteScannerServiceProxy peer: NodeId, req: RemoteQueryScannerClose, ) -> Result { - self.task_center - .run_in_scope("RemoteScannerServiceProxy::close", None, async { - self.close_rpc - .call(&self.networking, peer, req) - .await - .map_err(|e| DataFusionError::External(e.into())) - .map(Incoming::into_body) - }) + self.close_rpc + .call(&self.networking, peer, req) + .in_tc_as_task( + &self.task_center, + TaskKind::InPlace, + "RemoteScannerServiceProxy::close", + ) .await + .map_err(|e| DataFusionError::External(e.into())) + .map(Incoming::into_body) } } diff --git a/crates/storage-query-datafusion/src/remote_query_scanner_server.rs b/crates/storage-query-datafusion/src/remote_query_scanner_server.rs index a82ecefa5..e25fd7374 100644 --- a/crates/storage-query-datafusion/src/remote_query_scanner_server.rs +++ b/crates/storage-query-datafusion/src/remote_query_scanner_server.rs @@ -105,8 +105,6 @@ impl 
RemoteQueryScannerServer { let mut scanners: HashMap = Default::default(); let mut interval = time::interval(expire_old_scanners_after); - let node_id = my_node_id(); - loop { tokio::select! { biased; @@ -119,7 +117,7 @@ impl RemoteQueryScannerServer { }, Some(scan_req) = open_stream.next() => { next_scanner_id += 1; - let scanner_id = ScannerId(node_id, next_scanner_id); + let scanner_id = ScannerId(my_node_id(), next_scanner_id); Self::on_open(scanner_id, scan_req, &mut scanners, query_context.clone()).await; }, Some(next_req) = next_stream.next() => { diff --git a/crates/storage-query-datafusion/src/tests.rs b/crates/storage-query-datafusion/src/tests.rs index 9476af445..a3815ddac 100644 --- a/crates/storage-query-datafusion/src/tests.rs +++ b/crates/storage-query-datafusion/src/tests.rs @@ -8,14 +8,14 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. -use crate::mocks::*; -use crate::row; +use std::time::{Duration, SystemTime}; + use datafusion::arrow::array::{LargeStringArray, UInt64Array}; use datafusion::arrow::record_batch::RecordBatch; use futures::StreamExt; use googletest::all; use googletest::prelude::{assert_that, eq}; -use restate_core::TaskCenterBuilder; + use restate_invoker_api::status_handle::test_util::MockStatusHandle; use restate_invoker_api::status_handle::InvocationStatusReportInner; use restate_invoker_api::{InvocationErrorReport, InvocationStatusReport}; @@ -29,46 +29,39 @@ use restate_types::identifiers::PartitionId; use restate_types::identifiers::{DeploymentId, InvocationId}; use restate_types::invocation::InvocationTarget; use restate_types::journal::EntryType; -use std::time::{Duration, SystemTime}; -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +use crate::mocks::*; +use crate::row; + +#[restate_core::test(flavor = "multi_thread", worker_threads = 2)] async fn query_sys_invocation() { let invocation_id = InvocationId::mock_random(); let invocation_target = InvocationTarget::service("MySvc", "MyMethod"); let invocation_error = InvocationError::internal("my error"); - let tc = TaskCenterBuilder::default() - .default_runtime_handle(tokio::runtime::Handle::current()) - .build() - .expect("task_center builds"); - let mut engine = tc - .run_in_scope( - "mock-query-engine", - None, - MockQueryEngine::create_with( - MockStatusHandle::default().with(InvocationStatusReport::new( - invocation_id, - (PartitionId::MIN, LeaderEpoch::INITIAL), - InvocationStatusReportInner { - in_flight: false, - start_count: 1, - last_start_at: SystemTime::now() - Duration::from_secs(10), - last_retry_attempt_failure: Some(InvocationErrorReport { - err: invocation_error.clone(), - doc_error_code: None, - related_entry_index: Some(1), - related_entry_name: Some("my-side-effect".to_string()), - related_entry_type: Some(EntryType::Run), - }), - next_retry_at: Some(SystemTime::now() + Duration::from_secs(10)), - last_attempt_deployment_id: Some(DeploymentId::new()), - last_attempt_server: Some("restate-sdk-java/0.8.0".to_owned()), - }, - )), - MockSchemas::default(), - ), - ) - .await; + let mut engine = MockQueryEngine::create_with( + MockStatusHandle::default().with(InvocationStatusReport::new( + invocation_id, + (PartitionId::MIN, LeaderEpoch::INITIAL), + InvocationStatusReportInner { + in_flight: false, + start_count: 1, + last_start_at: SystemTime::now() - Duration::from_secs(10), + last_retry_attempt_failure: Some(InvocationErrorReport { + err: invocation_error.clone(), + doc_error_code: None, + 
related_entry_index: Some(1), + related_entry_name: Some("my-side-effect".to_string()), + related_entry_type: Some(EntryType::Run), + }), + next_retry_at: Some(SystemTime::now() + Duration::from_secs(10)), + last_attempt_deployment_id: Some(DeploymentId::new()), + last_attempt_server: Some("restate-sdk-java/0.8.0".to_owned()), + }, + )), + MockSchemas::default(), + ) + .await; let mut tx = engine.partition_store().transaction(); tx.put_invocation_status( diff --git a/crates/types/src/logs/metadata.rs b/crates/types/src/logs/metadata.rs index b94600b8d..cb526e20c 100644 --- a/crates/types/src/logs/metadata.rs +++ b/crates/types/src/logs/metadata.rs @@ -218,7 +218,7 @@ pub struct LogletConfig { } impl LogletConfig { - #[cfg(any(test, feature = "test-util"))] + #[cfg(any(test, feature = "memory-loglet"))] pub fn for_testing() -> Self { Self { kind: ProviderKind::InMemory, diff --git a/crates/worker/src/lib.rs b/crates/worker/src/lib.rs index 7d47536d7..06498669a 100644 --- a/crates/worker/src/lib.rs +++ b/crates/worker/src/lib.rs @@ -29,7 +29,7 @@ use restate_core::network::Networking; use restate_core::network::TransportConnect; use restate_core::partitions::PartitionRouting; use restate_core::worker_api::ProcessorsManagerHandle; -use restate_core::{task_center, Metadata, TaskKind}; +use restate_core::{Metadata, TaskKind}; use restate_ingress_kafka::Service as IngressKafkaService; use restate_invoker_impl::InvokerHandle as InvokerChannelServiceHandle; use restate_metadata_store::MetadataStoreClient; @@ -149,13 +149,13 @@ impl Worker { router_builder.add_message_handler(partition_processor_manager.message_handler()); let remote_scanner_manager = RemoteScannerManager::new( - create_remote_scanner_service(networking, task_center(), router_builder), + create_remote_scanner_service(networking, router_builder), create_partition_locator(partition_routing, metadata.clone()), ); let schema = metadata.updateable_schema(); let storage_query_context = QueryContext::create( &config.admin.query_engine, - SelectPartitionsFromMetadata::new(metadata), + SelectPartitionsFromMetadata, Some(partition_store_manager.clone()), Some(partition_processor_manager.invokers_status_reader()), schema, diff --git a/crates/worker/src/partition/cleaner.rs b/crates/worker/src/partition/cleaner.rs index c8ca2c4c8..2bea1b639 100644 --- a/crates/worker/src/partition/cleaner.rs +++ b/crates/worker/src/partition/cleaner.rs @@ -168,7 +168,7 @@ mod tests { use futures::{stream, Stream}; use googletest::prelude::*; - use restate_core::{Metadata, TaskCenter, TaskKind, TestCoreEnvBuilder2}; + use restate_core::{Metadata, TaskCenter, TaskKind, TestCoreEnvBuilder}; use restate_storage_api::invocation_status_table::{ CompletedInvocation, InFlightInvocationMetadata, InvocationStatus, }; @@ -213,7 +213,7 @@ mod tests { // Start paused makes sure the timer is immediately fired #[test(restate_core::test(start_paused = true))] pub async fn cleanup_works() { - let _env = TestCoreEnvBuilder2::with_incoming_only_connector() + let _env = TestCoreEnvBuilder::with_incoming_only_connector() .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, @@ -257,22 +257,20 @@ mod tests { ), ]); - TaskCenter::current() - .spawn( - TaskKind::Cleaner, - "cleaner", - Some(PartitionId::MIN), - Cleaner::new( - PartitionId::MIN, - LeaderEpoch::INITIAL, - mock_storage, - bifrost.clone(), - RangeInclusive::new(PartitionKey::MIN, PartitionKey::MAX), - Duration::from_secs(1), - ) - .run(), + TaskCenter::spawn( + TaskKind::Cleaner, + 
"cleaner", + Cleaner::new( + PartitionId::MIN, + LeaderEpoch::INITIAL, + mock_storage, + bifrost.clone(), + RangeInclusive::new(PartitionKey::MIN, PartitionKey::MAX), + Duration::from_secs(1), ) - .unwrap(); + .run(), + ) + .unwrap(); // By yielding once we let the cleaner task run, and perform the cleanup tokio::task::yield_now().await; diff --git a/crates/worker/src/partition/leadership/mod.rs b/crates/worker/src/partition/leadership/mod.rs index f25176e63..5cdaa57d4 100644 --- a/crates/worker/src/partition/leadership/mod.rs +++ b/crates/worker/src/partition/leadership/mod.rs @@ -23,7 +23,7 @@ use tracing::{debug, instrument, warn}; use restate_bifrost::Bifrost; use restate_core::network::Reciprocal; -use restate_core::{metadata, my_node_id, ShutdownError, TaskCenter, TaskKind}; +use restate_core::{my_node_id, ShutdownError, TaskCenter, TaskKind}; use restate_errors::NotRunningError; use restate_invoker_api::InvokeInputJournal; use restate_partition_store::PartitionStore; @@ -227,7 +227,6 @@ where self.partition_processor_metadata.partition_id, EpochSequenceNumber::new(leader_epoch), &self.bifrost, - metadata(), )?; self_proposer @@ -591,7 +590,7 @@ mod tests { use crate::partition::leadership::{LeadershipState, PartitionProcessorMetadata, State}; use assert2::let_assert; use restate_bifrost::Bifrost; - use restate_core::{TaskCenter, TestCoreEnv2}; + use restate_core::{TaskCenter, TestCoreEnv}; use restate_invoker_api::test_util::MockInvokerHandle; use restate_partition_store::{OpenMode, PartitionStoreManager}; use restate_rocksdb::RocksDbManager; @@ -615,7 +614,7 @@ mod tests { #[test(restate_core::test)] async fn become_leader_then_step_down() -> googletest::Result<()> { - let _env = TestCoreEnv2::create_with_single_node(0, 0).await; + let _env = TestCoreEnv::create_with_single_node(0, 0).await; let storage_options = StorageOptions::default(); let rocksdb_options = RocksDbOptions::default(); diff --git a/crates/worker/src/partition/leadership/self_proposer.rs b/crates/worker/src/partition/leadership/self_proposer.rs index 91e24512a..a3611463c 100644 --- a/crates/worker/src/partition/leadership/self_proposer.rs +++ b/crates/worker/src/partition/leadership/self_proposer.rs @@ -11,7 +11,7 @@ use crate::partition::leadership::Error; use futures::never::Never; use restate_bifrost::{Bifrost, CommitToken}; -use restate_core::Metadata; +use restate_core::my_node_id; use restate_storage_api::deduplication_table::{DedupInformation, EpochSequenceNumber}; use restate_types::identifiers::{PartitionId, PartitionKey}; use restate_types::logs::LogId; @@ -33,7 +33,6 @@ pub struct SelfProposer { partition_id: PartitionId, epoch_sequence_number: EpochSequenceNumber, bifrost_appender: restate_bifrost::AppenderHandle, - metadata: Metadata, } impl SelfProposer { @@ -41,7 +40,6 @@ impl SelfProposer { partition_id: PartitionId, epoch_sequence_number: EpochSequenceNumber, bifrost: &Bifrost, - metadata: Metadata, ) -> Result { let bifrost_appender = bifrost .create_background_appender( @@ -49,13 +47,12 @@ impl SelfProposer { BIFROST_QUEUE_SIZE, MAX_BIFROST_APPEND_BATCH, )? 
- .start("self-appender", Some(partition_id))?; + .start("self-appender")?; Ok(Self { partition_id, epoch_sequence_number, bifrost_appender, - metadata, }) } @@ -97,7 +94,7 @@ impl SelfProposer { let esn = self.epoch_sequence_number; self.epoch_sequence_number = self.epoch_sequence_number.next(); - let my_node_id = self.metadata.my_node_id(); + let my_node_id = my_node_id(); Header { dest: Destination::Processor { partition_key, diff --git a/crates/worker/src/partition/shuffle.rs b/crates/worker/src/partition/shuffle.rs index 6a568addc..87bbceeb8 100644 --- a/crates/worker/src/partition/shuffle.rs +++ b/crates/worker/src/partition/shuffle.rs @@ -454,7 +454,7 @@ mod tests { use restate_bifrost::{Bifrost, LogEntry}; use restate_core::network::FailingConnector; - use restate_core::{TaskCenter, TaskKind, TestCoreEnv2, TestCoreEnvBuilder2}; + use restate_core::{TaskCenter, TaskKind, TestCoreEnv, TestCoreEnvBuilder}; use restate_storage_api::outbox_table::OutboxMessage; use restate_storage_api::StorageError; use restate_types::identifiers::{InvocationId, LeaderEpoch, PartitionId}; @@ -622,7 +622,7 @@ mod tests { struct ShuffleEnv { #[allow(dead_code)] - env: TestCoreEnv2, + env: TestCoreEnv, bifrost: Bifrost, shuffle: Shuffle, } @@ -631,7 +631,7 @@ mod tests { outbox_reader: OR, ) -> ShuffleEnv { // set numbers of partitions to 1 to easily find all sent messages by the shuffle - let env = TestCoreEnvBuilder2::with_incoming_only_connector() + let env = TestCoreEnvBuilder::with_incoming_only_connector() .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, diff --git a/crates/worker/src/partition/state_machine/tests/delayed_send.rs b/crates/worker/src/partition/state_machine/tests/delayed_send.rs index 0b6d57777..19513abfe 100644 --- a/crates/worker/src/partition/state_machine/tests/delayed_send.rs +++ b/crates/worker/src/partition/state_machine/tests/delayed_send.rs @@ -16,7 +16,7 @@ use restate_types::time::MillisSinceEpoch; use std::time::{Duration, SystemTime}; use test_log::test; -#[test(tokio::test)] +#[test(restate_core::test)] async fn send_with_delay() { let mut test_env = TestEnv::create().await; @@ -74,7 +74,7 @@ async fn send_with_delay() { test_env.shutdown().await; } -#[test(tokio::test)] +#[test(restate_core::test)] async fn send_with_delay_to_locked_virtual_object() { let mut test_env = TestEnv::create().await; @@ -152,7 +152,7 @@ async fn send_with_delay_to_locked_virtual_object() { test_env.shutdown().await; } -#[test(tokio::test)] +#[test(restate_core::test)] async fn send_with_delay_and_idempotency_key() { let mut test_env = TestEnv::create().await; diff --git a/crates/worker/src/partition/state_machine/tests/idempotency.rs b/crates/worker/src/partition/state_machine/tests/idempotency.rs index 8e6929d14..66fb2c8a6 100644 --- a/crates/worker/src/partition/state_machine/tests/idempotency.rs +++ b/crates/worker/src/partition/state_machine/tests/idempotency.rs @@ -27,7 +27,7 @@ use std::time::Duration; #[rstest] #[case(true)] #[case(false)] -#[tokio::test] +#[restate_core::test] async fn start_and_complete_idempotent_invocation(#[case] disable_idempotency_table: bool) { let mut test_env = TestEnv::create_with_options(disable_idempotency_table).await; @@ -133,7 +133,7 @@ async fn start_and_complete_idempotent_invocation(#[case] disable_idempotency_ta #[rstest] #[case(true)] #[case(false)] -#[tokio::test] +#[restate_core::test] async fn start_and_complete_idempotent_invocation_neo_table( #[case] disable_idempotency_table: bool, ) { @@ -245,7 +245,7 @@ 
async fn start_and_complete_idempotent_invocation_neo_table( #[rstest] #[case(true)] #[case(false)] -#[tokio::test] +#[restate_core::test] async fn complete_already_completed_invocation(#[case] disable_idempotency_table: bool) { let mut test_env = TestEnv::create_with_options(disable_idempotency_table).await; @@ -304,7 +304,7 @@ async fn complete_already_completed_invocation(#[case] disable_idempotency_table #[rstest] #[case(true)] #[case(false)] -#[tokio::test] +#[restate_core::test] async fn attach_with_service_invocation_command_while_executing( #[case] disable_idempotency_table: bool, ) { @@ -403,7 +403,7 @@ async fn attach_with_service_invocation_command_while_executing( #[case(true, false)] #[case(false, true)] #[case(false, false)] -#[tokio::test] +#[restate_core::test] async fn attach_with_send_service_invocation( #[case] disable_idempotency_table: bool, #[case] use_same_request_id: bool, @@ -526,7 +526,7 @@ async fn attach_with_send_service_invocation( #[rstest] #[case(true)] #[case(false)] -#[tokio::test] +#[restate_core::test] async fn attach_inboxed_with_send_service_invocation(#[case] disable_idempotency_table: bool) { let mut test_env = TestEnv::create_with_options(disable_idempotency_table).await; @@ -622,7 +622,7 @@ async fn attach_inboxed_with_send_service_invocation(#[case] disable_idempotency #[rstest] #[case(true)] #[case(false)] -#[tokio::test] +#[restate_core::test] async fn attach_command(#[case] disable_idempotency_table: bool) { let mut test_env = TestEnv::create_with_options(disable_idempotency_table).await; @@ -715,7 +715,7 @@ async fn attach_command(#[case] disable_idempotency_table: bool) { test_env.shutdown().await; } -#[tokio::test] +#[restate_core::test] async fn attach_command_without_blocking_inflight() { let mut test_env = TestEnv::create().await; @@ -775,7 +775,7 @@ async fn attach_command_without_blocking_inflight() { #[rstest] #[case(true)] #[case(false)] -#[tokio::test] +#[restate_core::test] async fn purge_completed_idempotent_invocation(#[case] disable_idempotency_table: bool) { let mut test_env = TestEnv::create_with_options(disable_idempotency_table).await; diff --git a/crates/worker/src/partition/state_machine/tests/kill_cancel.rs b/crates/worker/src/partition/state_machine/tests/kill_cancel.rs index 18f0a1fad..069d1ee39 100644 --- a/crates/worker/src/partition/state_machine/tests/kill_cancel.rs +++ b/crates/worker/src/partition/state_machine/tests/kill_cancel.rs @@ -22,7 +22,7 @@ use restate_types::journal::enriched::EnrichedEntryHeader; use restate_types::service_protocol; use test_log::test; -#[test(tokio::test)] +#[test(restate_core::test)] async fn kill_inboxed_invocation() -> anyhow::Result<()> { let mut test_env = TestEnv::create().await; @@ -108,7 +108,7 @@ async fn kill_inboxed_invocation() -> anyhow::Result<()> { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn kill_call_tree() -> anyhow::Result<()> { let mut test_env = TestEnv::create().await; @@ -218,7 +218,7 @@ async fn kill_call_tree() -> anyhow::Result<()> { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn cancel_invoked_invocation() -> Result<(), Error> { let mut test_env = TestEnv::create().await; @@ -331,7 +331,7 @@ async fn cancel_invoked_invocation() -> Result<(), Error> { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn cancel_suspended_invocation() -> Result<(), Error> { let mut test_env = TestEnv::create().await; @@ -448,7 +448,7 @@ async fn cancel_suspended_invocation() -> Result<(), Error> { Ok(()) } 
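// --- Illustrative sketch, not part of this patch -----------------------------
// As the idempotency and workflow tests in this diff show, #[restate_core::test]
// composes with the existing test_log::test and rstest attributes without
// further changes. Assuming TestEnv from these test modules is in scope, the
// migrated shape of a parameterized state-machine test is (hypothetical body):
//
// #[rstest]
// #[case(true)]
// #[case(false)]
// #[restate_core::test]
// async fn hypothetical_case(#[case] disable_idempotency_table: bool) {
//     let mut test_env = TestEnv::create_with_options(disable_idempotency_table).await;
//     // ... drive commands through test_env.apply(..) and assert on actions ...
//     test_env.shutdown().await;
// }
// ------------------------------------------------------------------------------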
-#[test(tokio::test)] +#[test(restate_core::test)] async fn cancel_invocation_entry_referring_to_previous_entry() { let mut test_env = TestEnv::create().await; diff --git a/crates/worker/src/partition/state_machine/tests/mod.rs b/crates/worker/src/partition/state_machine/tests/mod.rs index ce730ebc0..4548ed299 100644 --- a/crates/worker/src/partition/state_machine/tests/mod.rs +++ b/crates/worker/src/partition/state_machine/tests/mod.rs @@ -29,7 +29,7 @@ use bytestring::ByteString; use futures::{StreamExt, TryStreamExt}; use googletest::matcher::Matcher; use googletest::{all, assert_that, pat, property}; -use restate_core::{task_center, TaskCenter, TaskCenterBuilder}; +use restate_core::TaskCenter; use restate_invoker_api::{EffectKind, InvokeInputJournal}; use restate_partition_store::{OpenMode, PartitionStore, PartitionStoreManager}; use restate_rocksdb::RocksDbManager; @@ -69,8 +69,6 @@ use test_log::test; use tracing_subscriber::fmt::format::FmtSpan; pub struct TestEnv { - #[allow(dead_code)] - task_center: TaskCenter, state_machine: StateMachine, // TODO for the time being we use rocksdb storage because we have no mocks for storage interfaces. // Perhaps we could make these tests faster by having those. @@ -79,7 +77,7 @@ pub struct TestEnv { impl TestEnv { pub async fn shutdown(self) { - self.task_center.shutdown_node("test complete", 0).await; + TaskCenter::shutdown_node("test complete", 0).await; RocksDbManager::get().shutdown().await; } @@ -123,43 +121,33 @@ impl TestEnv { None => FmtSpan::NONE, }).with_test_writer().try_init(); - let tc = TaskCenterBuilder::default() - .default_runtime_handle(tokio::runtime::Handle::current()) - .ingress_runtime_handle(tokio::runtime::Handle::current()) - .build() - .expect("task_center builds"); - - tc.run_in_scope("init", None, async { - RocksDbManager::init(Constant::new(CommonOptions::default())); - let worker_options = Live::from_value(WorkerOptions::default()); - info!( - "Using RocksDB temp directory {}", - worker_options.pinned().storage.data_dir().display() - ); - let manager = PartitionStoreManager::create( - worker_options.clone().map(|c| &c.storage), - worker_options.clone().map(|c| &c.storage.rocksdb).boxed(), - &[], + RocksDbManager::init(Constant::new(CommonOptions::default())); + let worker_options = Live::from_value(WorkerOptions::default()); + info!( + "Using RocksDB temp directory {}", + worker_options.pinned().storage.data_dir().display() + ); + let manager = PartitionStoreManager::create( + worker_options.clone().map(|c| &c.storage), + worker_options.clone().map(|c| &c.storage.rocksdb).boxed(), + &[], + ) + .await + .unwrap(); + let rocksdb_storage = manager + .open_partition_store( + PartitionId::MIN, + RangeInclusive::new(PartitionKey::MIN, PartitionKey::MAX), + OpenMode::CreateIfMissing, + &worker_options.pinned().storage.rocksdb, ) .await .unwrap(); - let rocksdb_storage = manager - .open_partition_store( - PartitionId::MIN, - RangeInclusive::new(PartitionKey::MIN, PartitionKey::MAX), - OpenMode::CreateIfMissing, - &worker_options.pinned().storage.rocksdb, - ) - .await - .unwrap(); - - Self { - task_center: task_center(), - state_machine, - storage: rocksdb_storage, - } - }) - .await + + Self { + state_machine, + storage: rocksdb_storage, + } } pub async fn apply(&mut self, command: Command) -> Vec { @@ -193,7 +181,7 @@ impl TestEnv { type TestResult = Result<(), anyhow::Error>; -#[test(tokio::test)] +#[test(restate_core::test)] async fn start_invocation() -> TestResult { let mut test_env = TestEnv::create().await; let id = 
fixtures::mock_start_invocation(&mut test_env).await; @@ -204,7 +192,7 @@ async fn start_invocation() -> TestResult { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn shared_invocation_skips_inbox() -> TestResult { let mut test_env = TestEnv::create().await; @@ -249,7 +237,7 @@ async fn shared_invocation_skips_inbox() -> TestResult { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn awakeable_completion_received_before_entry() -> TestResult { let mut test_env = TestEnv::create().await; let invocation_id = fixtures::mock_start_invocation(&mut test_env).await; @@ -353,7 +341,7 @@ async fn awakeable_completion_received_before_entry() -> TestResult { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn complete_awakeable_with_success() { let mut test_env = TestEnv::create().await; let invocation_id = fixtures::mock_start_invocation(&mut test_env).await; @@ -396,7 +384,7 @@ async fn complete_awakeable_with_success() { test_env.shutdown().await; } -#[test(tokio::test)] +#[test(restate_core::test)] async fn complete_awakeable_with_failure() { let mut test_env = TestEnv::create().await; let invocation_id = fixtures::mock_start_invocation(&mut test_env).await; @@ -442,7 +430,7 @@ async fn complete_awakeable_with_failure() { test_env.shutdown().await; } -#[test(tokio::test)] +#[test(restate_core::test)] async fn invoke_with_headers() -> TestResult { let mut test_env = TestEnv::create().await; let service_id = ServiceId::mock_random(); @@ -485,7 +473,7 @@ async fn invoke_with_headers() -> TestResult { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn mutate_state() -> anyhow::Result<()> { let mut test_env = TestEnv::create().await; let invocation_target = InvocationTarget::mock_virtual_object(); @@ -554,7 +542,7 @@ async fn mutate_state() -> anyhow::Result<()> { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn clear_all_user_states() -> anyhow::Result<()> { let mut test_env = TestEnv::create().await; let service_id = ServiceId::new("MySvc", "my-key"); @@ -591,7 +579,7 @@ async fn clear_all_user_states() -> anyhow::Result<()> { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn get_state_keys() -> TestResult { let mut test_env = TestEnv::create().await; let service_id = ServiceId::mock_random(); @@ -632,7 +620,7 @@ async fn get_state_keys() -> TestResult { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn get_invocation_id_entry() { let mut test_env = TestEnv::create().await; let invocation_id = fixtures::mock_start_invocation(&mut test_env).await; @@ -721,7 +709,7 @@ async fn get_invocation_id_entry() { test_env.shutdown().await; } -#[tokio::test] +#[restate_core::test] async fn attach_invocation_entry() { let mut test_env = TestEnv::create().await; let invocation_id = fixtures::mock_start_invocation(&mut test_env).await; @@ -765,7 +753,7 @@ async fn attach_invocation_entry() { test_env.shutdown().await; } -#[tokio::test] +#[restate_core::test] async fn get_invocation_output_entry() { let mut test_env = TestEnv::create().await; let invocation_id = fixtures::mock_start_invocation(&mut test_env).await; @@ -826,7 +814,7 @@ async fn get_invocation_output_entry() { test_env.shutdown().await; } -#[test(tokio::test)] +#[test(restate_core::test)] async fn send_ingress_response_to_multiple_targets() -> TestResult { let mut test_env = TestEnv::create().await; let invocation_target = InvocationTarget::mock_virtual_object(); @@ -933,7 +921,7 @@ async fn 
send_ingress_response_to_multiple_targets() -> TestResult { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn truncate_outbox_from_empty() -> Result<(), Error> { // An outbox message with index 0 has been successfully processed, and must now be truncated let outbox_index = 0; @@ -955,7 +943,7 @@ async fn truncate_outbox_from_empty() -> Result<(), Error> { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn truncate_outbox_with_gap() -> Result<(), Error> { // The outbox contains items [3..=5], and the range must be truncated after message 5 is processed let outbox_head_index = 3; @@ -988,7 +976,7 @@ async fn truncate_outbox_with_gap() -> Result<(), Error> { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn consecutive_exclusive_handler_invocations_will_use_inbox() -> TestResult { let mut test_env = TestEnv::create().await; @@ -1100,7 +1088,7 @@ async fn consecutive_exclusive_handler_invocations_will_use_inbox() -> TestResul Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn deduplicate_requests_with_same_pp_rpc_request_id() -> TestResult { let mut test_env = TestEnv::create().await; let invocation_id = InvocationId::mock_random(); diff --git a/crates/worker/src/partition/state_machine/tests/workflow.rs b/crates/worker/src/partition/state_machine/tests/workflow.rs index a60829b26..5c5448f2e 100644 --- a/crates/worker/src/partition/state_machine/tests/workflow.rs +++ b/crates/worker/src/partition/state_machine/tests/workflow.rs @@ -22,7 +22,7 @@ use std::time::Duration; #[rstest] #[case(true)] #[case(false)] -#[tokio::test] +#[restate_core::test] async fn start_workflow_method(#[case] disable_idempotency_table: bool) { let mut test_env = TestEnv::create_with_options(disable_idempotency_table).await; @@ -186,7 +186,7 @@ async fn start_workflow_method(#[case] disable_idempotency_table: bool) { #[rstest] #[case(true)] #[case(false)] -#[tokio::test] +#[restate_core::test] async fn attach_by_workflow_key(#[case] disable_idempotency_table: bool) { let mut test_env = TestEnv::create_with_options(disable_idempotency_table).await; @@ -324,7 +324,7 @@ async fn attach_by_workflow_key(#[case] disable_idempotency_table: bool) { #[rstest] #[case(true)] #[case(false)] -#[tokio::test] +#[restate_core::test] async fn purge_completed_workflow(#[case] disable_idempotency_table: bool) { let mut test_env = TestEnv::create_with_options(disable_idempotency_table).await; diff --git a/crates/worker/src/partition_processor_manager/mod.rs b/crates/worker/src/partition_processor_manager/mod.rs index 642c8145c..6d9a96939 100644 --- a/crates/worker/src/partition_processor_manager/mod.rs +++ b/crates/worker/src/partition_processor_manager/mod.rs @@ -35,7 +35,8 @@ use restate_core::worker_api::{ SnapshotResult, }; use restate_core::{ - cancellation_watcher, Metadata, ShutdownError, TaskCenterFutureExt, TaskHandle, TaskKind, + cancellation_watcher, my_node_id, Metadata, ShutdownError, TaskCenterFutureExt, TaskHandle, + TaskKind, }; use restate_core::{RuntimeRootTaskHandle, TaskCenter}; use restate_invoker_api::StatusHandle; @@ -293,10 +294,9 @@ impl PartitionProcessorManager { match self.processor_states.get(&partition_id) { None => { // ignore shutdown errors - let _ = TaskCenter::current().spawn( + let _ = TaskCenter::spawn( TaskKind::Disposable, "partition-processor-rpc-response", - None, async move { partition_processor_rpc .to_rpc_response(Err(PartitionProcessorRpcError::NotLeader( @@ -481,7 +481,7 @@ impl PartitionProcessorManager { 
leader_epoch_token, partition_id, metadata_store_client, - Metadata::with_current(|m| m.my_node_id()), + my_node_id(), ) .in_current_tc(), ); @@ -628,7 +628,7 @@ impl PartitionProcessorManager { // We spawn the partition processors start tasks on the blocking thread pool due to a macOS issue // where doing otherwise appears to starve the Tokio event loop, causing very slow startup. - let handle = TaskCenter::current().spawn_blocking_unmanaged( + let handle = TaskCenter::spawn_blocking_unmanaged( "starting-partition-processor", starting_task.run(), ); @@ -779,10 +779,9 @@ impl PartitionProcessorManager { node_name: config.common.node_name().into(), }; - let spawn_task_result = TaskCenter::current().spawn_unmanaged( + let spawn_task_result = TaskCenter::spawn_unmanaged( TaskKind::PartitionSnapshotProducer, "create-snapshot", - Some(partition_id), create_snapshot_task.run(), ); @@ -897,7 +896,7 @@ mod tests { use restate_bifrost::providers::memory_loglet; use restate_bifrost::BifrostService; use restate_core::network::MockPeerConnection; - use restate_core::{TaskCenter, TaskKind, TestCoreEnvBuilder2}; + use restate_core::{TaskCenter, TaskKind, TestCoreEnvBuilder}; use restate_partition_store::PartitionStoreManager; use restate_rocksdb::RocksDbManager; use restate_types::config::{CommonOptions, Configuration, RocksDbOptions, StorageOptions}; @@ -931,7 +930,7 @@ mod tests { nodes_config.upsert_node(node_config); let mut env_builder = - TestCoreEnvBuilder2::with_incoming_only_connector().set_nodes_config(nodes_config); + TestCoreEnvBuilder::with_incoming_only_connector().set_nodes_config(nodes_config); let health_status = HealthStatus::default(); RocksDbManager::init(Constant::new(CommonOptions::default())); @@ -959,10 +958,9 @@ mod tests { let processors_manager_handle = partition_processor_manager.handle(); bifrost_svc.start().await.into_test_result()?; - TaskCenter::current().spawn( + TaskCenter::spawn( TaskKind::SystemService, "partition-processor-manager", - None, partition_processor_manager.run(), )?; diff --git a/crates/worker/src/partition_processor_manager/persisted_lsn_watchdog.rs b/crates/worker/src/partition_processor_manager/persisted_lsn_watchdog.rs index 99b39c523..5a894966e 100644 --- a/crates/worker/src/partition_processor_manager/persisted_lsn_watchdog.rs +++ b/crates/worker/src/partition_processor_manager/persisted_lsn_watchdog.rs @@ -156,7 +156,7 @@ impl PersistedLogLsnWatchdog { #[cfg(test)] mod tests { use crate::partition_processor_manager::persisted_lsn_watchdog::PersistedLogLsnWatchdog; - use restate_core::{TaskKind, TestCoreEnv}; + use restate_core::{TaskCenter, TaskKind, TestCoreEnv}; use restate_partition_store::{OpenMode, PartitionStoreManager}; use restate_rocksdb::RocksDbManager; use restate_storage_api::fsm_table::FsmTable; @@ -172,15 +172,13 @@ mod tests { use tokio::sync::watch; use tokio::time::Instant; - #[test(tokio::test(start_paused = true))] + #[test(restate_core::test(start_paused = true))] async fn persisted_log_lsn_watchdog_detects_applied_lsns() -> anyhow::Result<()> { - let node_env = TestCoreEnv::create_with_single_node(1, 1).await; + let _node_env = TestCoreEnv::create_with_single_node(1, 1).await; let storage_options = StorageOptions::default(); let rocksdb_options = RocksDbOptions::default(); - node_env - .tc - .run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); + RocksDbManager::init(Constant::new(CommonOptions::default())); let all_partition_keys = RangeInclusive::new(0, PartitionKey::MAX); let 
partition_store_manager = PartitionStoreManager::create( @@ -210,12 +208,7 @@ mod tests { let now = Instant::now(); - node_env.tc.spawn( - TaskKind::Watchdog, - "persiste-log-lsn-test", - None, - watchdog.run(), - )?; + TaskCenter::spawn(TaskKind::Watchdog, "persiste-log-lsn-test", watchdog.run())?; assert!( tokio::time::timeout(Duration::from_secs(1), watch_rx.changed()) @@ -262,6 +255,8 @@ mod tests { Some(&next_persisted_lsn) ); + RocksDbManager::get().shutdown().await; + Ok(()) } } diff --git a/crates/worker/src/partition_processor_manager/processor_state.rs b/crates/worker/src/partition_processor_manager/processor_state.rs index 24eb61499..e7924a48b 100644 --- a/crates/worker/src/partition_processor_manager/processor_state.rs +++ b/crates/worker/src/partition_processor_manager/processor_state.rs @@ -283,10 +283,9 @@ impl ProcessorState { ) { match self { ProcessorState::Starting { .. } => { - let _ = TaskCenter::current().spawn( + let _ = TaskCenter::spawn( TaskKind::Disposable, "partition-processor-rpc", - None, async move { partition_processor_rpc .into_outgoing(Err(PartitionProcessorRpcError::Starting)) @@ -304,10 +303,9 @@ impl ProcessorState { { match err { TrySendError::Full(req) => { - let _ = TaskCenter::current().spawn( + let _ = TaskCenter::spawn( TaskKind::Disposable, "partition-processor-rpc", - None, async move { req.into_outgoing(Err(PartitionProcessorRpcError::Busy)) .send() @@ -317,10 +315,9 @@ impl ProcessorState { ); } TrySendError::Closed(req) => { - let _ = TaskCenter::current().spawn( + let _ = TaskCenter::spawn( TaskKind::Disposable, "partition-processor-rpc", - None, async move { req.into_outgoing(Err(PartitionProcessorRpcError::NotLeader( partition_id, @@ -335,10 +332,9 @@ impl ProcessorState { } } ProcessorState::Stopping { .. } => { - let _ = TaskCenter::current().spawn( + let _ = TaskCenter::spawn( TaskKind::Disposable, "partition-processor-rpc", - None, async move { partition_processor_rpc .into_outgoing(Err(PartitionProcessorRpcError::Stopping)) diff --git a/server/Cargo.toml b/server/Cargo.toml index 031273ef9..0c8a1c858 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -69,6 +69,7 @@ restate-bifrost = { workspace = true, features = ["test-util"] } restate-core = { workspace = true, features = ["test-util"] } restate-local-cluster-runner = { workspace = true } restate-test-util = { workspace = true } +restate-types = { workspace = true, features = ["test-util"] } anyhow = { workspace = true } async-trait = { workspace = true } diff --git a/server/src/main.rs b/server/src/main.rs index 98004b374..cf4d993eb 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -17,6 +17,7 @@ use std::time::Duration; use clap::Parser; use codederror::CodedError; +use restate_core::TaskCenter; use tokio::io; use tracing::error; use tracing::{info, trace, warn}; @@ -161,8 +162,7 @@ fn main() { .options(Configuration::pinned().common.clone()) .build() .expect("task_center builds"); - tc.block_on({ - let tc = tc.clone(); + tc.handle().block_on({ async move { // Apply tracing config globally // We need to apply this first to log correctly @@ -213,9 +213,9 @@ fn main() { // We ignore errors since we will wait for shutdown below anyway. // This starts node roles and the rest of the system async under tasks managed by // the TaskCenter. 
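// --- Illustrative sketch, not part of this patch -----------------------------
// The main() changes below follow the pattern used throughout this diff:
// spawning goes through associated functions on the ambient TaskCenter (no
// cloned TaskCenter value, no separate partition-id argument), and shutdown is
// likewise a static call. Roughly, assuming the TaskCenter has already been
// built and entered as above:
//
//   TaskCenter::spawn(TaskKind::SystemBoot, "init", node.start())?;
//   // ...
//   TaskCenter::shutdown_node("shutting down", 0).await;
//
// An owned task_center::Handle (via TaskCenter::current(), or to_handle() on a
// freshly built TaskCenter) is only kept where a handle must be stored or
// passed around, e.g. NodeSvcHandler and NodeCtrlHandlerState.
// ------------------------------------------------------------------------------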
- let _ = tc.spawn(TaskKind::SystemBoot, "init", None, node.unwrap().start()); + let _ = TaskCenter::spawn(TaskKind::SystemBoot, "init", node.unwrap().start()); - let task_center_watch = tc.shutdown_token(); + let task_center_watch = TaskCenter::current().shutdown_token(); tokio::pin!(task_center_watch); let config_update_watcher = Configuration::watcher(); @@ -231,7 +231,7 @@ fn main() { let shutdown_with_timeout = tokio::time::timeout( Configuration::pinned().common.shutdown_grace_period(), async { - tc.shutdown_node(&signal_reason, 0).await; + TaskCenter::shutdown_node(&signal_reason, 0).await; rocksdb_manager.shutdown().await; } ); diff --git a/server/tests/cluster.rs b/server/tests/cluster.rs index 0fd6cd1a3..5cea3ef2a 100644 --- a/server/tests/cluster.rs +++ b/server/tests/cluster.rs @@ -24,7 +24,7 @@ use test_log::test; mod common; -#[test(tokio::test)] +#[test(restate_core::test)] async fn node_id_mismatch() -> googletest::Result<()> { let base_config = Configuration::default(); @@ -72,7 +72,7 @@ async fn node_id_mismatch() -> googletest::Result<()> { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn cluster_name_mismatch() -> googletest::Result<()> { let base_config = Configuration::default(); @@ -115,7 +115,7 @@ async fn cluster_name_mismatch() -> googletest::Result<()> { Ok(()) } -#[test(tokio::test)] +#[test(restate_core::test)] async fn replicated_loglet() -> googletest::Result<()> { let mut base_config = Configuration::default(); base_config.bifrost.default_provider = ProviderKind::Replicated; diff --git a/server/tests/common/replicated_loglet.rs b/server/tests/common/replicated_loglet.rs index d326998ef..94efe9f6f 100644 --- a/server/tests/common/replicated_loglet.rs +++ b/server/tests/common/replicated_loglet.rs @@ -17,8 +17,8 @@ use googletest::IntoTestResult; use restate_bifrost::{loglet::Loglet, Bifrost, BifrostAdmin}; use restate_core::metadata_store::Precondition; -use restate_core::TaskCenterFutureExt; -use restate_core::{metadata_store::MetadataStoreClient, MetadataWriter, TaskCenterBuilder}; +use restate_core::TaskCenter; +use restate_core::{metadata_store::MetadataStoreClient, MetadataWriter}; use restate_local_cluster_runner::{ cluster::{Cluster, MaybeTempDir, StartedCluster}, node::{BinarySource, Node}, @@ -121,76 +121,66 @@ where log_server_count, ); - let tc = TaskCenterBuilder::default() - .default_runtime_handle(tokio::runtime::Handle::current()) - .ingress_runtime_handle(tokio::runtime::Handle::current()) - .build() - .expect("task_center builds"); - // ensure base dir lives longer than the node, otherwise it sees shutdown errors // this will still respect LOCAL_CLUSTER_RUNNER_RETAIN_TEMPDIR=true let base_dir: MaybeTempDir = tempfile::tempdir()?.into(); - async { - RocksDbManager::init(Configuration::mapped_updateable(|c| &c.common)); - - let cluster = Cluster::builder() - .base_dir(base_dir.as_path().to_owned()) - .nodes(nodes) - .build() - .start() - .await?; - - cluster.wait_healthy(Duration::from_secs(30)).await?; - - let loglet_params = ReplicatedLogletParams { - loglet_id: ReplicatedLogletId::new(LogId::from(1u32), SegmentIndex::OLDEST), - sequencer, - replication, - // node 1 is the metadata, 2..=count+1 are logservers - nodeset: (2..=log_server_count + 1).collect(), - }; - let loglet_params = loglet_params.serialize()?; - - let chain = Chain::new(ProviderKind::Replicated, LogletParams::from(loglet_params)); - let mut logs_builder = LogsBuilder::default(); - logs_builder.add_log(LogId::MIN, chain)?; - - let metadata_store_client = 
cluster.nodes[0] - .metadata_client() - .await - .map_err(|err| TestAssertionFailure::create(err.to_string()))?; - metadata_store_client - .put( - BIFROST_CONFIG_KEY.clone(), - &logs_builder.build(), - Precondition::None, - ) - .await?; - - // join a new node to the cluster solely to act as a bifrost client - // it will have node id log_server_count+2 - let (bifrost, loglet, metadata_writer, metadata_store_client) = replicated_loglet_client( - base_config, - &cluster, - PlainNodeId::new(log_server_count + 2), - ) + RocksDbManager::init(Configuration::mapped_updateable(|c| &c.common)); + + let cluster = Cluster::builder() + .base_dir(base_dir.as_path().to_owned()) + .nodes(nodes) + .build() + .start() .await?; - // global metadata should now be set, running in scope sets it in the task center context - future(TestEnv { - bifrost, - loglet, - cluster, - metadata_writer, - metadata_store_client, - }) + cluster.wait_healthy(Duration::from_secs(30)).await?; + + let loglet_params = ReplicatedLogletParams { + loglet_id: ReplicatedLogletId::new(LogId::from(1u32), SegmentIndex::OLDEST), + sequencer, + replication, + // node 1 is the metadata, 2..=count+1 are logservers + nodeset: (2..=log_server_count + 1).collect(), + }; + let loglet_params = loglet_params.serialize()?; + + let chain = Chain::new(ProviderKind::Replicated, LogletParams::from(loglet_params)); + let mut logs_builder = LogsBuilder::default(); + logs_builder.add_log(LogId::MIN, chain)?; + + let metadata_store_client = cluster.nodes[0] + .metadata_client() .await - } - .in_tc(&tc) + .map_err(|err| TestAssertionFailure::create(err.to_string()))?; + metadata_store_client + .put( + BIFROST_CONFIG_KEY.clone(), + &logs_builder.build(), + Precondition::None, + ) + .await?; + + // join a new node to the cluster solely to act as a bifrost client + // it will have node id log_server_count+2 + let (bifrost, loglet, metadata_writer, metadata_store_client) = replicated_loglet_client( + base_config, + &cluster, + PlainNodeId::new(log_server_count + 2), + ) + .await?; + + // global metadata should now be set, running in scope sets it in the task center context + future(TestEnv { + bifrost, + loglet, + cluster, + metadata_writer, + metadata_store_client, + }) .await?; - tc.shutdown_node("test completed", 0).await; + TaskCenter::shutdown_node("test completed", 0).await; RocksDbManager::get().shutdown().await; Ok(()) } diff --git a/server/tests/replicated_loglet.rs b/server/tests/replicated_loglet.rs index 9bbdbb751..c32c70008 100644 --- a/server/tests/replicated_loglet.rs +++ b/server/tests/replicated_loglet.rs @@ -22,7 +22,7 @@ mod tests { use futures_util::StreamExt; use googletest::prelude::*; use restate_bifrost::loglet::AppendError; - use restate_core::{cancellation_token, metadata, task_center}; + use restate_core::{cancellation_token, Metadata, TaskCenterFutureExt}; use test_log::test; use restate_types::{ @@ -49,7 +49,7 @@ mod tests { ) } - #[test(tokio::test)] + #[test(restate_core::test)] async fn test_append_local_sequencer_three_logserver() -> Result<()> { run_in_test_env( Configuration::default(), @@ -76,7 +76,7 @@ mod tests { .await } - #[test(tokio::test)] + #[test(restate_core::test)] async fn test_seal_local_sequencer_three_logserver() -> Result<()> { run_in_test_env( Configuration::default(), @@ -114,7 +114,7 @@ mod tests { .await } - #[test(tokio::test)] + #[test(restate_core::test)] async fn three_logserver_gapless_smoke_test() -> googletest::Result<()> { run_in_test_env( Configuration::default(), @@ -128,7 +128,7 @@ mod tests { 
.await } - #[test(tokio::test)] + #[test(restate_core::test)] async fn three_logserver_readstream() -> googletest::Result<()> { run_in_test_env( Configuration::default(), @@ -142,7 +142,7 @@ mod tests { .await } - #[test(tokio::test)] + #[test(restate_core::test)] async fn three_logserver_readstream_with_trims() -> googletest::Result<()> { // For this test to work, we need to disable the record cache to ensure we // observer the moving trimpoint. @@ -166,7 +166,7 @@ mod tests { .await } - #[test(tokio::test)] + #[test(restate_core::test)] async fn three_logserver_append_after_seal() -> googletest::Result<()> { run_in_test_env( Configuration::default(), @@ -178,7 +178,7 @@ mod tests { .await } - #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + #[restate_core::test(flavor = "multi_thread", worker_threads = 4)] async fn three_logserver_append_after_seal_concurrent() -> googletest::Result<()> { run_in_test_env( Configuration::default(), @@ -192,7 +192,7 @@ mod tests { .await } - #[test(tokio::test)] + #[test(restate_core::test)] async fn three_logserver_seal_empty() -> googletest::Result<()> { run_in_test_env( Configuration::default(), @@ -204,7 +204,7 @@ mod tests { .await } - #[test(tokio::test(flavor = "multi_thread", worker_threads = 4))] + #[test(restate_core::test(flavor = "multi_thread", worker_threads = 4))] async fn bifrost_append_and_seal_concurrent() -> googletest::Result<()> { const TEST_DURATION: Duration = Duration::from_secs(10); const SEAL_PERIOD: Duration = Duration::from_secs(1); @@ -218,8 +218,7 @@ mod tests { |test_env| async move { let log_id = LogId::new(0); - let tc = task_center(); - let metadata = metadata(); + let metadata = Metadata::current(); let mut appenders: JoinSet> = JoinSet::new(); @@ -229,30 +228,25 @@ mod tests { appenders.spawn({ let bifrost = test_env.bifrost.clone(); let cancel_appenders = cancel_appenders.clone(); - let tc = tc.clone(); async move { - tc.run_in_scope("append", None, async move { - let mut i = 1; - let mut committed = Vec::new(); - while !cancel_appenders.is_cancelled() { - let offset = bifrost - .append( - log_id, - format!("appender-{}-record{}", appender_id, i), - ) - .await?; - i += 1; - committed.push(offset); - } - Ok(committed) - }) - .await - } + let mut i = 1; + let mut committed = Vec::new(); + while !cancel_appenders.is_cancelled() { + let offset = bifrost + .append( + log_id, + format!("appender-{}-record{}", appender_id, i), + ) + .await?; + i += 1; + committed.push(offset); + } + Ok(committed) + }.in_current_tc() }); } let mut sealer_handle: JoinHandle> = tokio::task::spawn({ - let tc = tc.clone(); let (bifrost, metadata_writer, metadata_store_client) = ( test_env.bifrost.clone(), test_env.metadata_writer.clone(), @@ -270,37 +264,33 @@ mod tests { &metadata_store_client, ); - tc.run_in_scope("sealer", None, async move { - let mut last_loglet_id = None; - - while !cancellation_token.is_cancelled() { - tokio::time::sleep(SEAL_PERIOD).await; + let mut last_loglet_id = None; - let mut params = ReplicatedLogletParams::deserialize_from( - chain.live_load().tail().config.params.as_ref(), - )?; - if last_loglet_id == Some(params.loglet_id) { - fail!("Could not seal as metadata has not caught up from the last seal (version={})", metadata.logs_version())?; - } - last_loglet_id = Some(params.loglet_id); - eprintln!("Sealing loglet {} and creating new loglet {}", params.loglet_id, params.loglet_id.next()); - params.loglet_id = params.loglet_id.next(); + while !cancellation_token.is_cancelled() { + 
tokio::time::sleep(SEAL_PERIOD).await; - bifrost_admin - .seal_and_extend_chain( - log_id, - None, - Version::MIN, - ProviderKind::Replicated, - LogletParams::from(params.serialize()?), - ) - .await?; + let mut params = ReplicatedLogletParams::deserialize_from( + chain.live_load().tail().config.params.as_ref(), + )?; + if last_loglet_id == Some(params.loglet_id) { + fail!("Could not seal as metadata has not caught up from the last seal (version={})", metadata.logs_version())?; } - - Ok(()) - }) - .await - } + last_loglet_id = Some(params.loglet_id); + eprintln!("Sealing loglet {} and creating new loglet {}", params.loglet_id, params.loglet_id.next()); + params.loglet_id = params.loglet_id.next(); + + bifrost_admin + .seal_and_extend_chain( + log_id, + None, + Version::MIN, + ProviderKind::Replicated, + LogletParams::from(params.serialize()?), + ) + .await?; + } + Ok(()) + }.in_current_tc() }); tokio::select! { diff --git a/tools/bifrost-benchpress/src/append_latency.rs b/tools/bifrost-benchpress/src/append_latency.rs index 864a2ff85..37d5a82df 100644 --- a/tools/bifrost-benchpress/src/append_latency.rs +++ b/tools/bifrost-benchpress/src/append_latency.rs @@ -15,7 +15,6 @@ use hdrhistogram::Histogram; use tracing::info; use restate_bifrost::Bifrost; -use restate_core::TaskCenter; use restate_types::logs::{LogId, WithKeys}; use crate::util::{print_latencies, DummyPayload}; @@ -37,7 +36,6 @@ pub struct AppendLatencyOpts { pub async fn run( _common_args: &Arguments, args: &AppendLatencyOpts, - _tc: TaskCenter, bifrost: Bifrost, ) -> anyhow::Result<()> { let blob = BytesMut::zeroed(args.payload_size).freeze(); diff --git a/tools/bifrost-benchpress/src/main.rs b/tools/bifrost-benchpress/src/main.rs index 902e4b68e..28512d130 100644 --- a/tools/bifrost-benchpress/src/main.rs +++ b/tools/bifrost-benchpress/src/main.rs @@ -14,13 +14,15 @@ use std::time::Duration; use clap::Parser; use codederror::CodedError; use metrics_exporter_prometheus::PrometheusBuilder; +use restate_core::task_center::TaskCenterMonitoring; use tracing::trace; use bifrost_benchpress::util::{print_prometheus_stats, print_rocksdb_stats}; use bifrost_benchpress::{append_latency, write_to_read, Arguments, Command}; use restate_bifrost::{Bifrost, BifrostService}; use restate_core::{ - spawn_metadata_manager, MetadataBuilder, MetadataManager, TaskCenter, TaskCenterBuilder, + spawn_metadata_manager, task_center, MetadataBuilder, MetadataManager, TaskCenter, + TaskCenterBuilder, }; use restate_errors::fmt::RestateCode; use restate_metadata_store::{MetadataStoreClient, Precondition}; @@ -98,10 +100,10 @@ fn main() -> anyhow::Result<()> { match args.command { Command::WriteToRead(ref opts) => { - write_to_read::run(&args, opts, task_center.clone(), bifrost).await?; + write_to_read::run(&args, opts, bifrost).await?; } Command::AppendLatency(ref opts) => { - append_latency::run(&args, opts, task_center.clone(), bifrost).await?; + append_latency::run(&args, opts, bifrost).await?; } } // record tokio's runtime metrics @@ -142,11 +144,12 @@ fn main() -> anyhow::Result<()> { Ok(()) } -fn spawn_environment(config: Live, num_logs: u16) -> (TaskCenter, Bifrost) { +fn spawn_environment(config: Live, num_logs: u16) -> (task_center::Handle, Bifrost) { let tc = TaskCenterBuilder::default() .options(config.pinned().common.clone()) .build() - .expect("task_center builds"); + .expect("task_center builds") + .to_handle(); let bifrost = tc.block_on(async move { let metadata_builder = MetadataBuilder::default(); diff --git 
a/tools/bifrost-benchpress/src/write_to_read.rs b/tools/bifrost-benchpress/src/write_to_read.rs index ee2e13505..f4f58f63a 100644 --- a/tools/bifrost-benchpress/src/write_to_read.rs +++ b/tools/bifrost-benchpress/src/write_to_read.rs @@ -45,18 +45,13 @@ pub struct WriteToReadOpts { payload_size: usize, } -pub async fn run( - _common_args: &Arguments, - args: &WriteToReadOpts, - tc: TaskCenter, - bifrost: Bifrost, -) -> Result<()> { +pub async fn run(_common_args: &Arguments, args: &WriteToReadOpts, bifrost: Bifrost) -> Result<()> { let clock = quanta::Clock::new(); // Create two tasks, one that writes to the log continously and another one that reads from the // log and measures the latency. Collect latencies in a histogram and print the histogram // before the test ends. let reader_task: TaskHandle> = - tc.spawn_unmanaged(TaskKind::PartitionProcessor, "test-log-reader", None, { + TaskCenter::spawn_unmanaged(TaskKind::PartitionProcessor, "test-log-reader", { let clock = clock.clone(); let args = args.clone(); let bifrost = bifrost.clone(); @@ -91,7 +86,7 @@ pub async fn run( })?; let writer_task: TaskHandle> = - tc.spawn_unmanaged(TaskKind::PartitionProcessor, "test-log-appender", None, { + TaskCenter::spawn_unmanaged(TaskKind::PartitionProcessor, "test-log-appender", { let clock = clock.clone(); let bifrost = bifrost.clone(); let args = args.clone(); @@ -104,7 +99,7 @@ pub async fn run( args.write_buffer_size, args.max_batch_size, )? - .start("writer", None)?; + .start("writer")?; let sender = appender_handle.sender(); let start = Instant::now(); for counter in 1..=args.num_records { @@ -145,6 +140,6 @@ pub async fn run( println!(); print_latencies("append latency", append_latency); print_latencies("write-to-read latency", read_latency); - tc.shutdown_node("completed", 0).await; + TaskCenter::shutdown_node("completed", 0).await; Ok(()) } diff --git a/tools/restatectl/src/commands/log/dump_log.rs b/tools/restatectl/src/commands/log/dump_log.rs index 30ee43d5d..caedaaf71 100644 --- a/tools/restatectl/src/commands/log/dump_log.rs +++ b/tools/restatectl/src/commands/log/dump_log.rs @@ -80,7 +80,6 @@ async fn dump_log(opts: &DumpLogOpts) -> anyhow::Result<()> { Live::from_value(config.metadata_store.clone()) .map(|c| &c.rocksdb) .boxed(), - &TaskCenter::current(), ) .await?; debug!("Metadata store client created"); @@ -90,10 +89,9 @@ async fn dump_log(opts: &DumpLogOpts) -> anyhow::Result<()> { let mut router_builder = MessageRouterBuilder::default(); metadata_manager.register_in_message_router(&mut router_builder); - TaskCenter::current().spawn( + TaskCenter::spawn( TaskKind::SystemService, "metadata-manager", - None, metadata_manager.run(), )?; diff --git a/tools/restatectl/src/commands/metadata/get.rs b/tools/restatectl/src/commands/metadata/get.rs index 6a1533dc9..bb483bca8 100644 --- a/tools/restatectl/src/commands/metadata/get.rs +++ b/tools/restatectl/src/commands/metadata/get.rs @@ -13,7 +13,6 @@ use clap::Parser; use cling::{Collect, Run}; use tracing::debug; -use restate_core::TaskCenter; use restate_rocksdb::RocksDbManager; use restate_types::config::Configuration; use restate_types::live::Live; @@ -68,7 +67,6 @@ async fn get_value_direct(opts: &GetValueOpts) -> anyhow::Result, updateables_rocksdb_options: BoxedLiveLoad, - task_center: &TaskCenter, ) -> anyhow::Result { let health_status = HealthStatus::default(); let service = LocalMetadataStoreService::from_options( @@ -31,10 +30,9 @@ pub async fn start_metadata_store( updateables_rocksdb_options, ); - task_center.spawn( + 
TaskCenter::spawn( TaskKind::MetadataStore, "local-metadata-store", - None, async move { service.run().await?; Ok(()) diff --git a/tools/restatectl/src/environment/task_center.rs b/tools/restatectl/src/environment/task_center.rs index 5e286aa4b..73d576107 100644 --- a/tools/restatectl/src/environment/task_center.rs +++ b/tools/restatectl/src/environment/task_center.rs @@ -55,9 +55,10 @@ where .ingress_runtime_handle(tokio::runtime::Handle::current()) .options(config.common.clone()) .build() - .expect("task_center builds"); + .expect("task_center builds") + .to_handle(); - let result = task_center.run_in_scope_sync(|| fn_body(config)).await; + let result = task_center.run_sync(|| fn_body(config)).await; task_center.shutdown_node("finished", 0).await; result diff --git a/tools/xtask/src/main.rs b/tools/xtask/src/main.rs index 8a068bfd1..82155c9b4 100644 --- a/tools/xtask/src/main.rs +++ b/tools/xtask/src/main.rs @@ -18,8 +18,8 @@ use schemars::gen::SchemaSettings; use restate_admin::service::AdminService; use restate_bifrost::Bifrost; -use restate_core::{TaskCenter, TaskCenterBuilder, TestCoreEnv2}; -use restate_core::{TaskCenterFutureExt, TaskKind}; +use restate_core::TaskKind; +use restate_core::{TaskCenter, TaskCenterBuilder, TestCoreEnv}; use restate_service_client::{AssumeRoleCacheMode, ServiceClient}; use restate_service_protocol::discovery::ServiceDiscovery; use restate_storage_query_datafusion::table_docs; @@ -101,58 +101,45 @@ async fn generate_rest_api_doc() -> anyhow::Result<()> { config.admin.bind_address.port() ); - let tc = TaskCenterBuilder::default_for_tests() - .build() - .expect("building task-center should not fail"); - async { - // We start the Meta service, then download the openapi schema generated - let node_env = TestCoreEnv2::create_with_single_node(1, 1).await; - let bifrost = Bifrost::init_in_memory().await; - - let admin_service = AdminService::new( - node_env.metadata_writer.clone(), - node_env.metadata_store_client.clone(), - bifrost, - Mock, - ServiceDiscovery::new( - RetryPolicy::default(), - ServiceClient::from_options( - &config.common.service_client, - AssumeRoleCacheMode::None, - ) + // We start the Meta service, then download the openapi schema generated + let node_env = TestCoreEnv::create_with_single_node(1, 1).await; + let bifrost = Bifrost::init_in_memory().await; + + let admin_service = AdminService::new( + node_env.metadata_writer.clone(), + node_env.metadata_store_client.clone(), + bifrost, + Mock, + ServiceDiscovery::new( + RetryPolicy::default(), + ServiceClient::from_options(&config.common.service_client, AssumeRoleCacheMode::None) .unwrap(), - ), - false, - None, - ); - - TaskCenter::current().spawn( - TaskKind::TestRunner, - "doc-gen", - None, - admin_service.run(Constant::new(config.admin)), - )?; + ), + false, + None, + ); - let res = RetryPolicy::fixed_delay(Duration::from_millis(100), Some(20)) - .retry(|| async { - reqwest::Client::builder() - .build()? - .get(openapi_address.clone()) - .header(ACCEPT, "application/json") - .send() - .await? - .text() - .await - }) - .await - .unwrap(); - - println!("{}", res); - anyhow::Ok(()) - } - .in_tc(&tc) - .await?; - tc.shutdown_node("completed", 0).await; + TaskCenter::spawn( + TaskKind::TestRunner, + "doc-gen", + admin_service.run(Constant::new(config.admin)), + )?; + + let res = RetryPolicy::fixed_delay(Duration::from_millis(100), Some(20)) + .retry(|| async { + reqwest::Client::builder() + .build()? + .get(openapi_address.clone()) + .header(ACCEPT, "application/json") + .send() + .await? 
+ .text() + .await + }) + .await + .unwrap(); + + println!("{}", res); Ok(()) } @@ -219,13 +206,17 @@ Tasks: #[tokio::main] async fn main() -> anyhow::Result<()> { + let tc = TaskCenterBuilder::default_for_tests() + .build() + .expect("building task-center should not fail") + .to_handle(); let task = env::args().nth(1); match task { None => print_help(), Some(t) => match t.as_str() { "generate-config-schema" => generate_config_schema()?, "generate-default-config" => generate_default_config(), - "generate-rest-api-doc" => generate_rest_api_doc().await?, + "generate-rest-api-doc" => tc.block_on(generate_rest_api_doc())?, "generate-table-docs" => generate_table_docs()?, invalid => { print_help(); @@ -233,5 +224,6 @@ async fn main() -> anyhow::Result<()> { } }, }; + tc.shutdown_node("completed", 0).await; Ok(()) }
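The change applied across this patch is mechanical: call sites stop threading an owned TaskCenter (and the optional partition-id argument) through functions and instead rely on the ambient task-center context. As a non-authoritative illustration, not part of the patch itself, the following minimal sketch assumes the workspace's restate_core, restate_types and anyhow crates and a hypothetical run_with_task_center entry point; it only pulls together call shapes that appear above: TaskCenterBuilder::...build()...to_handle(), the three-argument static TaskCenter::spawn, Metadata::current(), in_current_tc() for futures handed to a plain tokio runtime, and the static TaskCenter::shutdown_node.

use restate_core::{Metadata, TaskCenter, TaskCenterBuilder, TaskCenterFutureExt, TaskKind};
use restate_types::config::Configuration;

// Hypothetical consumer of the reworked API; the function name and wiring are
// assumptions distilled from the call sites in this patch, not repository code.
fn run_with_task_center(config: Configuration) -> anyhow::Result<()> {
    // Building the task center now yields a cloneable handle rather than an
    // owned `TaskCenter` that has to be passed around.
    let tc = TaskCenterBuilder::default()
        .options(config.common.clone())
        .build()
        .expect("task_center builds")
        .to_handle();

    tc.block_on(async move {
        // Spawning is a static associated function and no longer takes the
        // trailing `None` partition-id argument.
        TaskCenter::spawn(TaskKind::SystemBoot, "boot", async move {
            // Metadata is likewise fetched from the ambient context instead of
            // being injected through constructors.
            let _metadata = Metadata::current();
            Ok(())
        })?;

        // Futures given to a plain tokio runtime keep the task-center context
        // by wrapping them with `in_current_tc()`.
        tokio::spawn(async move { /* plain tokio work */ }.in_current_tc()).await?;

        // Shutdown is a static call as well.
        TaskCenter::shutdown_node("done", 0).await;
        anyhow::Ok(())
    })
}

In tests, the same ambient context is what #[restate_core::test] supplies in place of #[tokio::test], which is why the test attributes above change without further edits to the test bodies; Metadata::with_current(|m| m.nodes_config_ref()) is the closure-style accessor used where only a short-lived borrow of the metadata is needed.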