diff --git a/nexus/db-queries/src/db/saga_recovery.rs b/nexus/db-queries/src/db/saga_recovery.rs
index eb2003508c..e85011f60f 100644
--- a/nexus/db-queries/src/db/saga_recovery.rs
+++ b/nexus/db-queries/src/db/saga_recovery.rs
@@ -6,6 +6,7 @@
 
 use crate::context::OpContext;
 use crate::db;
+use crate::db::datastore::SQL_BATCH_SIZE;
 use crate::db::error::public_error_from_diesel;
 use crate::db::error::ErrorHandler;
 use crate::db::pagination::{paginated, paginated_multicolumn, Paginator};
@@ -21,7 +22,6 @@ use omicron_common::backoff::retry_notify;
 use omicron_common::backoff::retry_policy_internal_service;
 use omicron_common::backoff::BackoffError;
 use std::future::Future;
-use std::num::NonZeroU32;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
@@ -171,11 +171,6 @@ where
     }))
 }
 
-const BATCH_SIZE: NonZeroU32 = unsafe {
-    // Safety: 100 is more than zero
-    NonZeroU32::new_unchecked(100)
-};
-
 /// Queries the database to return a list of uncompleted sagas assigned to SEC
 /// `sec_id`
 // For now, we do the simplest thing: we fetch all the sagas that the
@@ -199,8 +194,8 @@ async fn list_unfinished_sagas(
     // risks blocking the DB for an unreasonable amount of time. Instead,
     // we paginate to avoid cutting off availability to the DB.
     let mut sagas = vec![];
-    let mut paginator = Paginator::new(BATCH_SIZE);
-    let conn = datastore.pool_connection_unauthorized().await?;
+    let mut paginator = Paginator::new(SQL_BATCH_SIZE);
+    let conn = datastore.pool_connection_authorized(opctx).await?;
     while let Some(p) = paginator.next() {
         use db::schema::saga::dsl;
 
@@ -257,7 +252,7 @@ where
         "saga_name" => saga_name.clone(),
     );
 
-    let log_events = load_saga_log(datastore, &saga).await?;
+    let log_events = load_saga_log(&opctx, datastore, &saga).await?;
     trace!(
         opctx.log,
         "recovering saga: loaded log";
@@ -295,6 +290,7 @@ where
 
 /// Queries the database to load the full log for the specified saga
 async fn load_saga_log(
+    opctx: &OpContext,
     datastore: &db::DataStore,
     saga: &db::saga_types::Saga,
 ) -> Result<Vec<steno::SagaNodeEvent>, Error> {
@@ -304,8 +300,8 @@ async fn load_saga_log(
     // risks blocking the DB for an unreasonable amount of time. Instead,
     // we paginate to avoid cutting off availability.
     let mut events = vec![];
-    let mut paginator = Paginator::new(BATCH_SIZE);
-    let conn = datastore.pool_connection_unauthorized().await?;
+    let mut paginator = Paginator::new(SQL_BATCH_SIZE);
+    let conn = datastore.pool_connection_authorized(opctx).await?;
     while let Some(p) = paginator.next() {
         use db::schema::saga_node_event::dsl;
         let batch = paginated_multicolumn(
@@ -316,15 +312,7 @@ async fn load_saga_log(
         .filter(dsl::saga_id.eq(saga.id))
         .select(db::saga_types::SagaNodeEvent::as_select())
         .load_async(&*conn)
-        .map_err(|e| {
-            public_error_from_diesel(
-                e,
-                ErrorHandler::NotFoundByLookup(
-                    ResourceType::SagaDbg,
-                    LookupType::ById(saga.id.0 .0),
-                ),
-            )
-        })
+        .map_err(|e| public_error_from_diesel(e, ErrorHandler::Server))
         .await?;
         paginator = p
             .found_batch(&batch, &|row| (row.node_id, row.event_type.clone()));
@@ -640,7 +628,7 @@ mod test {
             db::model::saga_types::Saga::new(sec_id, params)
         };
 
-        let mut inserted_sagas = (0..BATCH_SIZE.get() * 2)
+        let mut inserted_sagas = (0..SQL_BATCH_SIZE.get() * 2)
             .map(|_| new_running_db_saga())
             .collect::<Vec<_>>();
 
@@ -696,6 +684,10 @@
         let log = logctx.log.new(o!());
         let (mut db, db_datastore) = new_db(&log).await;
         let sec_id = db::SecId(uuid::Uuid::new_v4());
+        let opctx = OpContext::for_tests(
+            log,
+            Arc::clone(&db_datastore) as Arc<dyn nexus_auth::storage::Storage>,
+        );
         let saga_id = steno::SagaId(Uuid::new_v4());
 
         // Create a couple batches of saga events
@@ -709,7 +701,7 @@
             db::model::saga_types::SagaNodeEvent::new(event, sec_id)
         };
 
-        let mut inserted_nodes = (0..BATCH_SIZE.get() * 2)
+        let mut inserted_nodes = (0..SQL_BATCH_SIZE.get() * 2)
             .flat_map(|i| {
                 // This isn't an exhaustive list of event types, but gives us a few
                 // options to pick from. Since this is a pagination key, it's
@@ -750,7 +742,7 @@
             state: steno::SagaCachedState::Running,
         };
         let saga = db::model::saga_types::Saga::new(sec_id, params);
-        let observed_nodes = load_saga_log(&db_datastore, &saga)
+        let observed_nodes = load_saga_log(&opctx, &db_datastore, &saga)
             .await
             .expect("Failed to list unfinished nodes");
         inserted_nodes.sort_by_key(|a| (a.node_id, a.event_type.clone()));
@@ -777,4 +769,37 @@
         db.cleanup().await.unwrap();
         logctx.cleanup_successful();
     }
+
+    #[tokio::test]
+    async fn test_list_no_unfinished_nodes() {
+        // Test setup
+        let logctx = dev::test_setup_log("test_list_no_unfinished_nodes");
+        let log = logctx.log.new(o!());
+        let (mut db, db_datastore) = new_db(&log).await;
+        let sec_id = db::SecId(uuid::Uuid::new_v4());
+        let opctx = OpContext::for_tests(
+            log,
+            Arc::clone(&db_datastore) as Arc<dyn nexus_auth::storage::Storage>,
+        );
+        let saga_id = steno::SagaId(Uuid::new_v4());
+
+        let params = steno::SagaCreateParams {
+            id: saga_id,
+            name: steno::SagaName::new("test saga"),
+            dag: serde_json::value::Value::Null,
+            state: steno::SagaCachedState::Running,
+        };
+        let saga = db::model::saga_types::Saga::new(sec_id, params);
+
+        // Test that this returns "no nodes" rather than throwing some "not
+        // found" error.
+        let observed_nodes = load_saga_log(&opctx, &db_datastore, &saga)
+            .await
+            .expect("Failed to list unfinished nodes");
+        assert_eq!(observed_nodes.len(), 0);
+
+        // Test cleanup
+        db.cleanup().await.unwrap();
+        logctx.cleanup_successful();
+    }
 }
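
Note on the shared pagination pattern (illustration only, not part of the patch): after this change, `list_unfinished_sagas` and `load_saga_log` use the same loop shape — fetch up to `SQL_BATCH_SIZE` rows, remember the last key seen, and stop when a batch comes back short. The sketch below is a minimal, self-contained model of that loop; `Row` and `fetch_page` are hypothetical stand-ins for the diesel query and `Paginator`, and the batch size of 1000 is assumed here rather than taken from the datastore.

    use std::num::NonZeroU32;

    // Stand-in for crate::db::datastore::SQL_BATCH_SIZE (value assumed).
    // Built without `unsafe`, unlike the deleted local BATCH_SIZE constant.
    const SQL_BATCH_SIZE: NonZeroU32 = match NonZeroU32::new(1000) {
        Some(n) => n,
        None => panic!("batch size must be nonzero"),
    };

    #[derive(Clone)]
    struct Row {
        id: u32,
    }

    // Pretend paginated query: rows with `id` strictly greater than the
    // marker, at most `limit` of them, in ascending order.
    fn fetch_page(table: &[Row], after: Option<u32>, limit: u32) -> Vec<Row> {
        table
            .iter()
            .filter(|r| after.map_or(true, |a| r.id > a))
            .take(limit as usize)
            .cloned()
            .collect()
    }

    fn main() {
        let table: Vec<Row> = (0..2_500).map(|id| Row { id }).collect();

        let mut results = Vec::new();
        let mut marker: Option<u32> = None;
        loop {
            let batch = fetch_page(&table, marker, SQL_BATCH_SIZE.get());
            // A short batch means this was the final page; this mirrors
            // what Paginator::found_batch infers from the batch length.
            let last_page = (batch.len() as u32) < SQL_BATCH_SIZE.get();
            if let Some(last) = batch.last() {
                marker = Some(last.id);
            }
            results.extend(batch);
            if last_page {
                break;
            }
        }
        assert_eq!(results.len(), table.len());
    }

Bounding each round trip this way keeps a large saga log from turning into one unbounded query, which is the availability concern the comments in `list_unfinished_sagas` and `load_saga_log` call out.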