review feedback
smklein committed Jun 27, 2024
1 parent f55b9d7 commit d7ed0b0
Showing 1 changed file with 48 additions and 23 deletions.
nexus/db-queries/src/db/saga_recovery.rs (71 changes: 48 additions & 23 deletions)
@@ -6,6 +6,7 @@
use crate::context::OpContext;
use crate::db;
+use crate::db::datastore::SQL_BATCH_SIZE;
use crate::db::error::public_error_from_diesel;
use crate::db::error::ErrorHandler;
use crate::db::pagination::{paginated, paginated_multicolumn, Paginator};
@@ -21,7 +22,6 @@ use omicron_common::backoff::retry_notify;
use omicron_common::backoff::retry_policy_internal_service;
use omicron_common::backoff::BackoffError;
use std::future::Future;
-use std::num::NonZeroU32;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
@@ -171,11 +171,6 @@ where
}))
}

-const BATCH_SIZE: NonZeroU32 = unsafe {
-// Safety: 100 is more than zero
-NonZeroU32::new_unchecked(100)
-};
-
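The removed constant needed `unsafe` only to construct a `NonZeroU32`; the commit switches both call sites to the shared `SQL_BATCH_SIZE` from `db::datastore` instead. As an aside, a constant like this can also be built without `unsafe` on recent stable Rust (panics in const contexts landed in 1.57); a minimal sketch, not part of the commit:

```rust
use std::num::NonZeroU32;

// Safe const construction: the zero case is rejected at compile time,
// so neither `unsafe` nor a safety comment is needed.
const BATCH_SIZE: NonZeroU32 = match NonZeroU32::new(100) {
    Some(n) => n,
    None => panic!("batch size must be non-zero"),
};
```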
/// Queries the database to return a list of uncompleted sagas assigned to SEC
/// `sec_id`
// For now, we do the simplest thing: we fetch all the sagas that the
@@ -199,8 +194,8 @@ async fn list_unfinished_sagas(
// risks blocking the DB for an unreasonable amount of time. Instead,
// we paginate to avoid cutting off availability to the DB.
let mut sagas = vec![];
-let mut paginator = Paginator::new(BATCH_SIZE);
-let conn = datastore.pool_connection_unauthorized().await?;
+let mut paginator = Paginator::new(SQL_BATCH_SIZE);
+let conn = datastore.pool_connection_authorized(opctx).await?;
while let Some(p) = paginator.next() {
use db::schema::saga::dsl;

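Both queries in this file now share the same batched-fetch shape. A minimal sketch of that loop, assuming omicron's `Paginator` API as used elsewhere in this file (`next`, `current_pagparams`, `found_batch`); the table `my_table` and row type `MyRow` are hypothetical stand-ins, and the real queries add their own filters:

```rust
let mut rows = vec![];
let mut paginator = Paginator::new(SQL_BATCH_SIZE);
let conn = datastore.pool_connection_authorized(opctx).await?;
while let Some(p) = paginator.next() {
    use db::schema::my_table::dsl;
    // Fetch one page, keyed on `id`, mapping failures to server errors.
    let batch = paginated(dsl::my_table, dsl::id, &p.current_pagparams())
        .select(MyRow::as_select())
        .load_async(&*conn)
        .map_err(|e| public_error_from_diesel(e, ErrorHandler::Server))
        .await?;
    // Feed the last row's key back in; an undersized batch ends the loop.
    paginator = p.found_batch(&batch, &|row| row.id);
    rows.extend(batch);
}
```

Each iteration issues one bounded query, so the database is never asked to materialize the full result set at once.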
@@ -257,7 +252,7 @@ where
"saga_name" => saga_name.clone(),
);

-let log_events = load_saga_log(datastore, &saga).await?;
+let log_events = load_saga_log(&opctx, datastore, &saga).await?;
trace!(
opctx.log,
"recovering saga: loaded log";
@@ -295,6 +290,7 @@

/// Queries the database to load the full log for the specified saga
async fn load_saga_log(
+opctx: &OpContext,
datastore: &db::DataStore,
saga: &db::saga_types::Saga,
) -> Result<Vec<steno::SagaNodeEvent>, Error> {
@@ -304,8 +300,8 @@ async fn load_saga_log(
// risks blocking the DB for an unreasonable amount of time. Instead,
// we paginate to avoid cutting off availability.
let mut events = vec![];
-let mut paginator = Paginator::new(BATCH_SIZE);
-let conn = datastore.pool_connection_unauthorized().await?;
+let mut paginator = Paginator::new(SQL_BATCH_SIZE);
+let conn = datastore.pool_connection_authorized(opctx).await?;
while let Some(p) = paginator.next() {
use db::schema::saga_node_event::dsl;
let batch = paginated_multicolumn(
@@ -316,15 +312,7 @@
.filter(dsl::saga_id.eq(saga.id))
.select(db::saga_types::SagaNodeEvent::as_select())
.load_async(&*conn)
-.map_err(|e| {
-public_error_from_diesel(
-e,
-ErrorHandler::NotFoundByLookup(
-ResourceType::SagaDbg,
-LookupType::ById(saga.id.0 .0),
-),
-)
-})
+.map_err(|e| public_error_from_diesel(e, ErrorHandler::Server))
.await?;
paginator =
p.found_batch(&batch, &|row| (row.node_id, row.event_type.clone()));
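One behavioral note on the mapping above (my reading; the commit message says only "review feedback"): an empty saga log is a legitimate state, not a missing resource, so query failures here should surface as server errors rather than as a "SagaDbg not found" error for the saga. A sketch of the behavior that the new test at the bottom of this diff locks in:

```rust
// Hypothetical saga with no log entries persisted yet.
let events = load_saga_log(&opctx, &datastore, &saga).await?;
// With ErrorHandler::Server, an empty log is simply an empty Vec, not a
// not-found error as under the old NotFoundByLookup mapping.
assert!(events.is_empty());
```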
Expand Down Expand Up @@ -640,7 +628,7 @@ mod test {

db::model::saga_types::Saga::new(sec_id, params)
};
-let mut inserted_sagas = (0..BATCH_SIZE.get() * 2)
+let mut inserted_sagas = (0..SQL_BATCH_SIZE.get() * 2)
.map(|_| new_running_db_saga())
.collect::<Vec<_>>();

@@ -696,6 +684,10 @@
let log = logctx.log.new(o!());
let (mut db, db_datastore) = new_db(&log).await;
let sec_id = db::SecId(uuid::Uuid::new_v4());
+let opctx = OpContext::for_tests(
+log,
+Arc::clone(&db_datastore) as Arc<dyn nexus_auth::storage::Storage>,
+);
let saga_id = steno::SagaId(Uuid::new_v4());

// Create a couple batches of saga events
@@ -709,7 +701,7 @@

db::model::saga_types::SagaNodeEvent::new(event, sec_id)
};
-let mut inserted_nodes = (0..BATCH_SIZE.get() * 2)
+let mut inserted_nodes = (0..SQL_BATCH_SIZE.get() * 2)
.flat_map(|i| {
// This isn't an exhaustive list of event types, but gives us a few
// options to pick from. Since this is a pagination key, it's
@@ -750,7 +742,7 @@
state: steno::SagaCachedState::Running,
};
let saga = db::model::saga_types::Saga::new(sec_id, params);
-let observed_nodes = load_saga_log(&db_datastore, &saga)
+let observed_nodes = load_saga_log(&opctx, &db_datastore, &saga)
.await
.expect("Failed to list unfinished nodes");
inserted_nodes.sort_by_key(|a| (a.node_id, a.event_type.clone()));
@@ -777,4 +769,37 @@
db.cleanup().await.unwrap();
logctx.cleanup_successful();
}
+
+#[tokio::test]
+async fn test_list_no_unfinished_nodes() {
+// Test setup
+let logctx = dev::test_setup_log("test_list_no_unfinished_nodes");
+let log = logctx.log.new(o!());
+let (mut db, db_datastore) = new_db(&log).await;
+let sec_id = db::SecId(uuid::Uuid::new_v4());
+let opctx = OpContext::for_tests(
+log,
+Arc::clone(&db_datastore) as Arc<dyn nexus_auth::storage::Storage>,
+);
+let saga_id = steno::SagaId(Uuid::new_v4());
+
+let params = steno::SagaCreateParams {
+id: saga_id,
+name: steno::SagaName::new("test saga"),
+dag: serde_json::value::Value::Null,
+state: steno::SagaCachedState::Running,
+};
+let saga = db::model::saga_types::Saga::new(sec_id, params);
+
+// Test that this returns "no nodes" rather than throwing some "not
+// found" error.
+let observed_nodes = load_saga_log(&opctx, &db_datastore, &saga)
+.await
+.expect("Failed to list unfinished nodes");
+assert_eq!(observed_nodes.len(), 0);
+
+// Test cleanup
+db.cleanup().await.unwrap();
+logctx.cleanup_successful();
+}
}
