diff --git a/update-engine/examples/update-engine-basic/display.rs b/update-engine/examples/update-engine-basic/display.rs index 122777211b..891bdce6d3 100644 --- a/update-engine/examples/update-engine-basic/display.rs +++ b/update-engine/examples/update-engine-basic/display.rs @@ -88,6 +88,7 @@ async fn display_group( slog::info!(log, "setting up display"); let mut display = GroupDisplay::new( + log, [ (GroupDisplayKey::Example, "example"), (GroupDisplayKey::Other, "other"), diff --git a/update-engine/src/buffer.rs b/update-engine/src/buffer.rs index 6e0e66d6d0..36a0626963 100644 --- a/update-engine/src/buffer.rs +++ b/update-engine/src/buffer.rs @@ -1627,6 +1627,16 @@ pub enum TerminalKind { Aborted, } +impl fmt::Display for TerminalKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Completed => write!(f, "completed"), + Self::Failed => write!(f, "failed"), + Self::Aborted => write!(f, "aborted"), + } + } +} + impl ExecutionStatus { /// Returns the terminal status and the total amount of time elapsed, or /// None if the execution has not reached a terminal state. @@ -1671,17 +1681,13 @@ mod tests { use std::collections::HashSet; use anyhow::{bail, ensure, Context}; - use futures::StreamExt; use indexmap::IndexSet; use omicron_test_utils::dev::test_setup_log; use serde::{de::IntoDeserializer, Deserialize}; - use tokio::sync::mpsc; - use tokio_stream::wrappers::ReceiverStream; use crate::{ - events::{ProgressCounter, ProgressUnits, StepProgress}, - test_utils::TestSpec, - StepContext, StepSuccess, UpdateEngine, + events::ProgressCounter, + test_utils::{generate_test_events, GenerateTestEventsKind, TestSpec}, }; use super::*; @@ -1689,108 +1695,11 @@ mod tests { #[tokio::test] async fn test_buffer() { let logctx = test_setup_log("test_buffer"); - // The channel is big enough to contain all possible events. - let (sender, receiver) = mpsc::channel(512); - let engine: UpdateEngine = - UpdateEngine::new(&logctx.log, sender); - - engine - .new_step("foo".to_owned(), 1, "Step 1", move |_cx| async move { - StepSuccess::new(()).into() - }) - .register(); - - engine - .new_step("bar".to_owned(), 2, "Step 2", move |cx| async move { - for _ in 0..20 { - cx.send_progress(StepProgress::with_current_and_total( - 5, - 20, - ProgressUnits::BYTES, - Default::default(), - )) - .await; - - cx.send_progress(StepProgress::reset( - Default::default(), - "reset step 2", - )) - .await; - - cx.send_progress(StepProgress::retry("retry step 2")).await; - } - StepSuccess::new(()).into() - }) - .register(); - - engine - .new_step( - "nested".to_owned(), - 3, - "Step 3 (this is nested)", - move |parent_cx| async move { - parent_cx - .with_nested_engine(|engine| { - define_nested_engine(&parent_cx, engine); - Ok(()) - }) - .await - .expect_err("this is expected to fail"); - - StepSuccess::new(()).into() - }, - ) - .register(); - - let log = logctx.log.clone(); - engine - .new_step( - "remote-nested".to_owned(), - 20, - "Step 4 (remote nested)", - move |cx| async move { - let (sender, mut receiver) = mpsc::channel(16); - let mut engine = UpdateEngine::new(&log, sender); - define_remote_nested_engine(&mut engine, 20); - - let mut buffer = EventBuffer::default(); - - let mut execute_fut = std::pin::pin!(engine.execute()); - let mut execute_done = false; - loop { - tokio::select! 
{ - res = &mut execute_fut, if !execute_done => { - res.expect("remote nested engine completed successfully"); - execute_done = true; - } - Some(event) = receiver.recv() => { - // Generate complete reports to ensure deduping - // happens within StepContexts. - buffer.add_event(event); - cx.send_nested_report(buffer.generate_report()).await?; - } - else => { - break; - } - } - } - - StepSuccess::new(()).into() - }, - ) - .register(); - - // The step index here (100) is large enough to be higher than all nested - // steps. - engine - .new_step("baz".to_owned(), 100, "Step 5", move |_cx| async move { - StepSuccess::new(()).into() - }) - .register(); - - engine.execute().await.expect("execution successful"); - let generated_events: Vec<_> = - ReceiverStream::new(receiver).collect().await; + let generated_events = generate_test_events( + &logctx.log, + GenerateTestEventsKind::Completed, + ) + .await; let test_cx = BufferTestContext::new(generated_events); @@ -2417,71 +2326,6 @@ mod tests { } } - fn define_nested_engine<'a>( - parent_cx: &'a StepContext, - engine: &mut UpdateEngine<'a, TestSpec>, - ) { - engine - .new_step( - "nested-foo".to_owned(), - 4, - "Nested step 1", - move |cx| async move { - parent_cx - .send_progress(StepProgress::with_current_and_total( - 1, - 3, - "steps", - Default::default(), - )) - .await; - cx.send_progress( - StepProgress::progress(Default::default()), - ) - .await; - StepSuccess::new(()).into() - }, - ) - .register(); - - engine - .new_step::<_, _, ()>( - "nested-bar".to_owned(), - 5, - "Nested step 2 (fails)", - move |cx| async move { - // This is used by NestedProgressCheck below. - parent_cx - .send_progress(StepProgress::with_current_and_total( - 2, - 3, - "steps", - Default::default(), - )) - .await; - - cx.send_progress(StepProgress::with_current( - 50, - "units", - Default::default(), - )) - .await; - - parent_cx - .send_progress(StepProgress::with_current_and_total( - 3, - 3, - "steps", - Default::default(), - )) - .await; - - bail!("failing step") - }, - ) - .register(); - } - #[derive(Clone, Copy, Debug, PartialEq, Eq)] enum NestedProgressCheck { Initial, @@ -2530,42 +2374,4 @@ mod tests { ); } } - - fn define_remote_nested_engine( - engine: &mut UpdateEngine<'_, TestSpec>, - start_id: usize, - ) { - engine - .new_step( - "nested-foo".to_owned(), - start_id + 1, - "Nested step 1", - move |cx| async move { - cx.send_progress( - StepProgress::progress(Default::default()), - ) - .await; - StepSuccess::new(()).into() - }, - ) - .register(); - - engine - .new_step::<_, _, ()>( - "nested-bar".to_owned(), - start_id + 2, - "Nested step 2", - move |cx| async move { - cx.send_progress(StepProgress::with_current( - 20, - "units", - Default::default(), - )) - .await; - - StepSuccess::new(()).into() - }, - ) - .register(); - } } diff --git a/update-engine/src/display/group_display.rs b/update-engine/src/display/group_display.rs index 0d50489a9f..cfd37aac16 100644 --- a/update-engine/src/display/group_display.rs +++ b/update-engine/src/display/group_display.rs @@ -30,6 +30,7 @@ use super::{ pub struct GroupDisplay { // We don't need to add any buffering here because we already write data to // the writer in a line-buffered fashion (see Self::write_events). + log: slog::Logger, writer: W, max_width: usize, // This is set to the highest value of root_total_elapsed seen from any event reports. @@ -45,6 +46,7 @@ impl GroupDisplay { /// /// The function passed in is expected to create a writer. 
pub fn new( + log: &slog::Logger, keys_and_prefixes: impl IntoIterator, writer: W, ) -> Self @@ -70,6 +72,7 @@ impl GroupDisplay { let not_started = single_states.len(); Self { + log: log.new(slog::o!("component" => "GroupDisplay")), writer, max_width, // This creates the stopwatch in the stopped state with duration 0 -- i.e. a minimal @@ -84,6 +87,7 @@ impl GroupDisplay { /// Creates a new `GroupDisplay` with the provided report keys, using the /// `Display` impl to obtain the respective prefixes. pub fn new_with_display( + log: &slog::Logger, keys: impl IntoIterator, writer: W, ) -> Self @@ -91,6 +95,7 @@ impl GroupDisplay { K: fmt::Display, { Self::new( + log, keys.into_iter().map(|k| { let prefix = k.to_string(); (k, prefix) @@ -144,7 +149,30 @@ impl GroupDisplay { TokioSw::with_elapsed_started(root_total_elapsed); } } + self.stats.apply_result(result); + + if result.before != result.after { + slog::info!( + self.log, + "add_event_report caused state transition"; + "prefix" => &state.prefix, + "before" => %result.before, + "after" => %result.after, + "current_stats" => ?self.stats, + "root_total_elapsed" => ?result.root_total_elapsed, + ); + } else { + slog::trace!( + self.log, + "add_event_report called, state did not change"; + "prefix" => &state.prefix, + "state" => %result.before, + "current_stats" => ?self.stats, + "root_total_elapsed" => ?result.root_total_elapsed, + ); + } + Ok(()) } else { Err(UnknownReportKey {}) @@ -179,7 +207,7 @@ impl GroupDisplay { } } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct GroupDisplayStats { /// The total number of reports. pub total: usize, @@ -236,18 +264,9 @@ impl GroupDisplayStats { } fn apply_result(&mut self, result: AddEventReportResult) { - // Process result.after first to avoid integer underflow. - match result.after { - SingleStateTag::NotStarted => self.not_started += 1, - SingleStateTag::Running => self.running += 1, - SingleStateTag::Terminal(TerminalKind::Completed) => { - self.completed += 1 - } - SingleStateTag::Terminal(TerminalKind::Failed) => self.failed += 1, - SingleStateTag::Terminal(TerminalKind::Aborted) => { - self.aborted += 1 - } - SingleStateTag::Overwritten => self.overwritten += 1, + if result.before == result.after { + // Nothing to do. + return; } match result.before { @@ -262,6 +281,19 @@ impl GroupDisplayStats { } SingleStateTag::Overwritten => self.overwritten -= 1, } + + match result.after { + SingleStateTag::NotStarted => self.not_started += 1, + SingleStateTag::Running => self.running += 1, + SingleStateTag::Terminal(TerminalKind::Completed) => { + self.completed += 1 + } + SingleStateTag::Terminal(TerminalKind::Failed) => self.failed += 1, + SingleStateTag::Terminal(TerminalKind::Aborted) => { + self.aborted += 1 + } + SingleStateTag::Overwritten => self.overwritten += 1, + } } fn format_line( @@ -336,92 +368,139 @@ impl SingleState { &mut self, event_report: EventReport, ) -> AddEventReportResult { - let before = match &self.kind { + match &mut self.kind { SingleStateKind::NotStarted { .. } => { - self.kind = SingleStateKind::Running { - event_buffer: EventBuffer::new(8), + // We're starting a new update. + let before = SingleStateTag::NotStarted; + let mut event_buffer = EventBuffer::default(); + let (after, root_total_elapsed) = + match Self::apply_report(&mut event_buffer, event_report) { + ApplyReportResult::NotStarted => { + // This means that the event report was empty. Don't + // update `self.kind`. 
+ (SingleStateTag::NotStarted, None) + } + ApplyReportResult::Running(root_total_elapsed) => { + self.kind = + SingleStateKind::Running { event_buffer }; + (SingleStateTag::Running, Some(root_total_elapsed)) + } + ApplyReportResult::Terminal(info) => { + let terminal_kind = info.kind; + let root_total_elapsed = info.root_total_elapsed; + + self.kind = SingleStateKind::Terminal { + info, + pending_event_buffer: Some(event_buffer), + }; + ( + SingleStateTag::Terminal(terminal_kind), + root_total_elapsed, + ) + } + ApplyReportResult::Overwritten => { + self.kind = SingleStateKind::Overwritten { + displayed: false, + }; + (SingleStateTag::Overwritten, None) + } + }; + + AddEventReportResult { before, after, root_total_elapsed } + } + SingleStateKind::Running { event_buffer } => { + // We're in the middle of an update. + let before = SingleStateTag::Running; + let (after, root_total_elapsed) = match Self::apply_report( + event_buffer, + event_report, + ) { + ApplyReportResult::NotStarted => { + // This is an illegal state transition: once a + // non-empty event report has been received, the + // event buffer never goes back to the NotStarted + // state. + unreachable!("illegal state transition from Running to NotStarted") + } + ApplyReportResult::Running(root_total_elapsed) => { + (SingleStateTag::Running, Some(root_total_elapsed)) + } + ApplyReportResult::Terminal(info) => { + let terminal_kind = info.kind; + let root_total_elapsed = info.root_total_elapsed; + + // Grab the event buffer so we can store it in the + // Terminal state below. + let event_buffer = std::mem::replace( + event_buffer, + EventBuffer::new(0), + ); + + self.kind = SingleStateKind::Terminal { + info, + pending_event_buffer: Some(event_buffer), + }; + ( + SingleStateTag::Terminal(terminal_kind), + root_total_elapsed, + ) + } + ApplyReportResult::Overwritten => { + self.kind = + SingleStateKind::Overwritten { displayed: false }; + (SingleStateTag::Overwritten, None) + } }; - SingleStateTag::NotStarted + AddEventReportResult { before, after, root_total_elapsed } } - SingleStateKind::Running { .. } => SingleStateTag::Running, - SingleStateKind::Terminal { info, .. } => { // Once we've reached a terminal state, we don't record any more // events. - return AddEventReportResult::unchanged( + AddEventReportResult::unchanged( SingleStateTag::Terminal(info.kind), info.root_total_elapsed, - ); + ) } SingleStateKind::Overwritten { .. } => { // This update has already completed -- assume that the event // buffer is for a new update, which we don't show. - return AddEventReportResult::unchanged( + AddEventReportResult::unchanged( SingleStateTag::Overwritten, None, - ); + ) } - }; - - let SingleStateKind::Running { event_buffer } = &mut self.kind else { - unreachable!("other branches were handled above"); - }; + } + } + /// The internal logic used by [`Self::add_event_report`]. + fn apply_report( + event_buffer: &mut EventBuffer, + event_report: EventReport, + ) -> ApplyReportResult { if let Some(root_execution_id) = event_buffer.root_execution_id() { if event_report.root_execution_id != Some(root_execution_id) { // The report is for a different execution ID -- assume that // this event is completed and mark our current execution as // completed. 
- self.kind = SingleStateKind::Overwritten { displayed: false }; - return AddEventReportResult { - before, - after: SingleStateTag::Overwritten, - root_total_elapsed: None, - }; + return ApplyReportResult::Overwritten; } } event_buffer.add_event_report(event_report); - let (after, max_total_elapsed) = - match event_buffer.root_execution_summary() { - Some(summary) => { - match summary.execution_status { - ExecutionStatus::NotStarted => { - (SingleStateTag::NotStarted, None) - } - ExecutionStatus::Running { - root_total_elapsed: max_total_elapsed, - .. - } => (SingleStateTag::Running, Some(max_total_elapsed)), - ExecutionStatus::Terminal(info) => { - // Grab the event buffer to store it in the terminal state. - let event_buffer = std::mem::replace( - event_buffer, - EventBuffer::new(0), - ); - let terminal_kind = info.kind; - let root_total_elapsed = info.root_total_elapsed; - self.kind = SingleStateKind::Terminal { - info, - pending_event_buffer: Some(event_buffer), - }; - ( - SingleStateTag::Terminal(terminal_kind), - root_total_elapsed, - ) - } - } + match event_buffer.root_execution_summary() { + Some(summary) => match summary.execution_status { + ExecutionStatus::NotStarted => ApplyReportResult::NotStarted, + ExecutionStatus::Running { root_total_elapsed, .. } => { + ApplyReportResult::Running(root_total_elapsed) } - None => { - // We don't have a summary yet. - (SingleStateTag::NotStarted, None) + ExecutionStatus::Terminal(info) => { + ApplyReportResult::Terminal(info) } - }; - - AddEventReportResult { - before, - after, - root_total_elapsed: max_total_elapsed, + }, + None => { + // We don't have a summary yet. + ApplyReportResult::NotStarted + } } } @@ -488,6 +567,7 @@ enum SingleStateKind { }, } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] struct AddEventReportResult { before: SingleStateTag, after: SingleStateTag, @@ -503,10 +583,238 @@ impl AddEventReportResult { } } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] enum SingleStateTag { NotStarted, Running, Terminal(TerminalKind), Overwritten, } + +impl fmt::Display for SingleStateTag { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::NotStarted => write!(f, "not started"), + Self::Running => write!(f, "running"), + Self::Terminal(kind) => write!(f, "{kind}"), + Self::Overwritten => write!(f, "overwritten"), + } + } +} + +#[derive(Clone, Debug)] +enum ApplyReportResult { + NotStarted, + Running(Duration), + Terminal(ExecutionTerminalInfo), + Overwritten, +} + +#[cfg(test)] +mod tests { + use omicron_test_utils::dev::test_setup_log; + + use super::*; + + use crate::test_utils::{generate_test_events, GenerateTestEventsKind}; + + #[tokio::test] + async fn test_stats() { + let logctx = test_setup_log("test_stats"); + // Generate three sets of events, one for each kind. + let generated_completed = generate_test_events( + &logctx.log, + GenerateTestEventsKind::Completed, + ) + .await; + let generated_failed = + generate_test_events(&logctx.log, GenerateTestEventsKind::Failed) + .await; + let generated_aborted = + generate_test_events(&logctx.log, GenerateTestEventsKind::Aborted) + .await; + + // Set up a `GroupDisplay` with three keys. 
+ let mut group_display = GroupDisplay::new_with_display( + &logctx.log, + vec![ + GroupDisplayKey::Completed, + GroupDisplayKey::Failed, + GroupDisplayKey::Aborted, + GroupDisplayKey::Overwritten, + ], + std::io::stdout(), + ); + + let mut expected_stats = GroupDisplayStats { + total: 4, + not_started: 4, + running: 0, + completed: 0, + failed: 0, + aborted: 0, + overwritten: 0, + }; + assert_eq!(group_display.stats(), &expected_stats); + assert!(!expected_stats.is_terminal()); + assert!(!expected_stats.has_failures()); + + // Pass in an empty EventReport -- ensure that this doesn't move it to + // a Running state. + + group_display + .add_event_report( + &GroupDisplayKey::Completed, + EventReport::default(), + ) + .unwrap(); + assert_eq!(group_display.stats(), &expected_stats); + + // Pass in events one by one -- ensure that we're always in the running + // state until we've completed. + { + expected_stats.not_started -= 1; + expected_stats.running += 1; + + let n = generated_completed.len(); + + let mut buffer = EventBuffer::default(); + let mut last_seen = None; + + for (i, event) in + generated_completed.clone().into_iter().enumerate() + { + buffer.add_event(event); + let report = buffer.generate_report_since(&mut last_seen); + group_display + .add_event_report(&GroupDisplayKey::Completed, report) + .unwrap(); + if i == n - 1 { + // The last event should have moved us to the completed + // state. + expected_stats.running -= 1; + expected_stats.completed += 1; + } else { + // We should still be in the running state. + } + assert_eq!(group_display.stats(), &expected_stats); + assert!(!expected_stats.is_terminal()); + assert!(!expected_stats.has_failures()); + } + } + + // Pass in failed events, this time using buffer.generate_report() + // rather than buffer.generate_report_since(). + { + expected_stats.not_started -= 1; + expected_stats.running += 1; + + let n = generated_failed.len(); + + let mut buffer = EventBuffer::default(); + for (i, event) in generated_failed.clone().into_iter().enumerate() { + buffer.add_event(event); + let report = buffer.generate_report(); + group_display + .add_event_report(&GroupDisplayKey::Failed, report) + .unwrap(); + if i == n - 1 { + // The last event should have moved us to the failed state. + expected_stats.running -= 1; + expected_stats.failed += 1; + assert!(expected_stats.has_failures()); + } else { + // We should still be in the running state. + assert!(!expected_stats.has_failures()); + } + assert_eq!(group_display.stats(), &expected_stats); + } + } + + // Pass in aborted events all at once. + { + expected_stats.not_started -= 1; + expected_stats.running += 1; + + let mut buffer = EventBuffer::default(); + for event in generated_aborted { + buffer.add_event(event); + } + let report = buffer.generate_report(); + group_display + .add_event_report(&GroupDisplayKey::Aborted, report) + .unwrap(); + // The aborted events should have moved us to the aborted state. + expected_stats.running -= 1; + expected_stats.aborted += 1; + assert_eq!(group_display.stats(), &expected_stats); + + // Try passing in one of the events that, if we were running, would + // cause us to move to an overwritten state. Ensure that that does + // not happen (i.e. 
expected_stats stays the same) + let mut buffer = EventBuffer::default(); + buffer.add_event(generated_failed.first().unwrap().clone()); + let report = buffer.generate_report(); + group_display + .add_event_report(&GroupDisplayKey::Aborted, report) + .unwrap(); + assert_eq!(group_display.stats(), &expected_stats); + } + + // For the overwritten state, pass in half of the completed events, and + // then pass in all of the failed events. + + { + expected_stats.not_started -= 1; + expected_stats.running += 1; + + let mut buffer = EventBuffer::default(); + let n = generated_completed.len() / 2; + for event in generated_completed.into_iter().take(n) { + buffer.add_event(event); + } + let report = buffer.generate_report(); + group_display + .add_event_report(&GroupDisplayKey::Overwritten, report) + .unwrap(); + assert_eq!(group_display.stats(), &expected_stats); + + // Now pass in a single failed event, which has a different + // execution ID. + let mut buffer = EventBuffer::default(); + buffer.add_event(generated_failed.first().unwrap().clone()); + let report = buffer.generate_report(); + group_display + .add_event_report(&GroupDisplayKey::Overwritten, report) + .unwrap(); + // The overwritten event should have moved us to the overwritten + // state. + expected_stats.running -= 1; + expected_stats.overwritten += 1; + } + + assert!(expected_stats.has_failures()); + assert!(expected_stats.is_terminal()); + + logctx.cleanup_successful(); + } + + #[derive(Debug, Eq, PartialEq, Ord, PartialOrd)] + enum GroupDisplayKey { + Completed, + Failed, + Aborted, + Overwritten, + } + + impl fmt::Display for GroupDisplayKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Completed => write!(f, "completed"), + Self::Failed => write!(f, "failed"), + Self::Aborted => write!(f, "aborted"), + Self::Overwritten => write!(f, "overwritten"), + } + } + } +} diff --git a/update-engine/src/test_utils.rs b/update-engine/src/test_utils.rs index 0bacfbeb8d..b943d1ddfe 100644 --- a/update-engine/src/test_utils.rs +++ b/update-engine/src/test_utils.rs @@ -4,9 +4,16 @@ // Copyright 2023 Oxide Computer Company +use anyhow::bail; +use futures::StreamExt; use schemars::JsonSchema; +use tokio::sync::{mpsc, oneshot}; +use tokio_stream::wrappers::ReceiverStream; -use crate::{ExecutionId, StepSpec}; +use crate::{ + events::{Event, ProgressUnits, StepProgress}, + EventBuffer, ExecutionId, StepContext, StepSpec, StepSuccess, UpdateEngine, +}; #[derive(JsonSchema)] pub(crate) enum TestSpec {} @@ -27,3 +34,278 @@ pub(crate) static TEST_EXECUTION_UUID: &str = pub fn test_execution_id() -> ExecutionId { ExecutionId(TEST_EXECUTION_UUID.parse().expect("valid UUID")) } + +#[derive(Copy, Clone, Debug)] +pub(crate) enum GenerateTestEventsKind { + Completed, + Failed, + Aborted, +} + +pub(crate) async fn generate_test_events( + log: &slog::Logger, + kind: GenerateTestEventsKind, +) -> Vec> { + // The channel is big enough to contain all possible events. 
+ let (sender, receiver) = mpsc::channel(512); + let engine = UpdateEngine::new(log, sender); + + match kind { + GenerateTestEventsKind::Completed => { + define_test_steps(log, &engine, LastStepOutcome::Completed); + engine.execute().await.expect("execution successful"); + } + GenerateTestEventsKind::Failed => { + define_test_steps(log, &engine, LastStepOutcome::Failed); + engine.execute().await.expect_err("execution failed"); + } + GenerateTestEventsKind::Aborted => { + // In this case, the last step signals that it has been reached via + // sending a message over this channel, and then waits forever. We + // abort execution by calling into the AbortHandle. + let (sender, receiver) = oneshot::channel(); + define_test_steps(log, &engine, LastStepOutcome::Aborted(sender)); + let abort_handle = engine.abort_handle(); + let mut execute_fut = std::pin::pin!(engine.execute()); + let mut receiver = std::pin::pin!(receiver); + let mut receiver_done = false; + loop { + tokio::select! { + res = &mut execute_fut => { + res.expect_err("execution should have been aborted, but completed successfully"); + break; + } + _ = &mut receiver, if !receiver_done => { + receiver_done = true; + abort_handle + .abort("test engine deliberately aborted") + .expect("engine should still be alive"); + } + } + } + } + } + + ReceiverStream::new(receiver).collect().await +} + +#[derive(Debug)] +enum LastStepOutcome { + Completed, + Failed, + Aborted(oneshot::Sender<()>), +} + +#[derive(Debug)] +enum Never {} + +fn define_test_steps( + log: &slog::Logger, + engine: &UpdateEngine, + last_step_outcome: LastStepOutcome, +) { + engine + .new_step("foo".to_owned(), 1, "Step 1", move |_cx| async move { + StepSuccess::new(()).into() + }) + .register(); + + engine + .new_step("bar".to_owned(), 2, "Step 2", move |cx| async move { + for _ in 0..20 { + cx.send_progress(StepProgress::with_current_and_total( + 5, + 20, + ProgressUnits::BYTES, + Default::default(), + )) + .await; + + cx.send_progress(StepProgress::reset( + Default::default(), + "reset step 2", + )) + .await; + + cx.send_progress(StepProgress::retry("retry step 2")).await; + } + StepSuccess::new(()).into() + }) + .register(); + + engine + .new_step( + "nested".to_owned(), + 3, + "Step 3 (this is nested)", + move |parent_cx| async move { + parent_cx + .with_nested_engine(|engine| { + define_nested_engine(&parent_cx, engine); + Ok(()) + }) + .await + .expect_err("this is expected to fail"); + + StepSuccess::new(()).into() + }, + ) + .register(); + + let log = log.clone(); + engine + .new_step( + "remote-nested".to_owned(), + 20, + "Step 4 (remote nested)", + move |cx| async move { + let (sender, mut receiver) = mpsc::channel(16); + let mut engine = UpdateEngine::new(&log, sender); + define_remote_nested_engine(&mut engine, 20); + + let mut buffer = EventBuffer::default(); + + let mut execute_fut = std::pin::pin!(engine.execute()); + let mut execute_done = false; + loop { + tokio::select! { + res = &mut execute_fut, if !execute_done => { + res.expect("remote nested engine completed successfully"); + execute_done = true; + } + Some(event) = receiver.recv() => { + // Generate complete reports to ensure deduping + // happens within StepContexts. + buffer.add_event(event); + cx.send_nested_report(buffer.generate_report()).await?; + } + else => { + break; + } + } + } + + StepSuccess::new(()).into() + }, + ) + .register(); + + // The step index here (100) is large enough to be higher than all nested + // steps. 
+ engine + .new_step("baz".to_owned(), 100, "Step 5", move |_cx| async move { + match last_step_outcome { + LastStepOutcome::Completed => StepSuccess::new(()).into(), + LastStepOutcome::Failed => { + bail!("last step failed") + } + LastStepOutcome::Aborted(sender) => { + sender.send(()).expect("receiver should be alive"); + // The driver of the engine is responsible for aborting it + // at this point. + std::future::pending::().await; + unreachable!("pending future can never resolve"); + } + } + }) + .register(); +} + +fn define_nested_engine<'a>( + parent_cx: &'a StepContext, + engine: &mut UpdateEngine<'a, TestSpec>, +) { + engine + .new_step( + "nested-foo".to_owned(), + 4, + "Nested step 1", + move |cx| async move { + parent_cx + .send_progress(StepProgress::with_current_and_total( + 1, + 3, + "steps", + Default::default(), + )) + .await; + cx.send_progress(StepProgress::progress(Default::default())) + .await; + StepSuccess::new(()).into() + }, + ) + .register(); + + engine + .new_step::<_, _, ()>( + "nested-bar".to_owned(), + 5, + "Nested step 2 (fails)", + move |cx| async move { + // This is used by NestedProgressCheck below. + parent_cx + .send_progress(StepProgress::with_current_and_total( + 2, + 3, + "steps", + Default::default(), + )) + .await; + + cx.send_progress(StepProgress::with_current( + 50, + "units", + Default::default(), + )) + .await; + + parent_cx + .send_progress(StepProgress::with_current_and_total( + 3, + 3, + "steps", + Default::default(), + )) + .await; + + bail!("failing step") + }, + ) + .register(); +} + +fn define_remote_nested_engine( + engine: &mut UpdateEngine<'_, TestSpec>, + start_id: usize, +) { + engine + .new_step( + "nested-foo".to_owned(), + start_id + 1, + "Nested step 1", + move |cx| async move { + cx.send_progress(StepProgress::progress(Default::default())) + .await; + StepSuccess::new(()).into() + }, + ) + .register(); + + engine + .new_step::<_, _, ()>( + "nested-bar".to_owned(), + start_id + 2, + "Nested step 2", + move |cx| async move { + cx.send_progress(StepProgress::with_current( + 20, + "units", + Default::default(), + )) + .await; + + StepSuccess::new(()).into() + }, + ) + .register(); +} diff --git a/wicket/src/cli/rack_update.rs b/wicket/src/cli/rack_update.rs index fa41fa7b8c..cac0f09ee5 100644 --- a/wicket/src/cli/rack_update.rs +++ b/wicket/src/cli/rack_update.rs @@ -174,6 +174,7 @@ async fn do_attach_to_updates( output: CommandOutput<'_>, ) -> Result<()> { let mut display = GroupDisplay::new_with_display( + &log, update_ids.iter().copied(), output.stderr, );
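A minimal, self-contained sketch (using hypothetical stand-in types, not the update-engine crate's own) of the counter bookkeeping that the reordered GroupDisplayStats::apply_result performs in this patch: return early when the state tag did not change, otherwise decrement the counter for the old tag and then increment the counter for the new one, so the per-state counts always sum to the total and never transiently underflow.

    // Hypothetical stand-ins for SingleStateTag / GroupDisplayStats.
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Tag {
        NotStarted,
        Running,
        Completed,
    }

    #[derive(Debug, Default, PartialEq, Eq)]
    struct Stats {
        not_started: usize,
        running: usize,
        completed: usize,
    }

    impl Stats {
        fn counter(&mut self, tag: Tag) -> &mut usize {
            match tag {
                Tag::NotStarted => &mut self.not_started,
                Tag::Running => &mut self.running,
                Tag::Completed => &mut self.completed,
            }
        }

        fn apply_transition(&mut self, before: Tag, after: Tag) {
            if before == after {
                // Nothing changed; skip the decrement/increment pair entirely.
                return;
            }
            *self.counter(before) -= 1;
            *self.counter(after) += 1;
        }
    }

    fn main() {
        // Two reports tracked, both initially not started.
        let mut stats = Stats { not_started: 2, running: 0, completed: 0 };
        stats.apply_transition(Tag::NotStarted, Tag::Running);
        stats.apply_transition(Tag::Running, Tag::Completed);
        assert_eq!(stats, Stats { not_started: 1, running: 0, completed: 1 });
    }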
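A second self-contained sketch (plain tokio with hypothetical names, not the update-engine API) of the abort handshake that generate_test_events uses for GenerateTestEventsKind::Aborted: the final step signals over a oneshot channel that it has been reached and then pends forever, while the driver selects over both futures and stops driving the work once the signal arrives (the point at which the real code calls abort_handle.abort(..)).

    use tokio::sync::oneshot;

    #[tokio::main]
    async fn main() {
        let (reached_tx, reached_rx) = oneshot::channel::<()>();

        // Stand-in for engine.execute(): report that the last step was
        // reached, then wait forever for the driver to abort us.
        let work = async move {
            reached_tx.send(()).expect("receiver should be alive");
            std::future::pending::<()>().await
        };

        let mut work = std::pin::pin!(work);
        let mut reached_rx = std::pin::pin!(reached_rx);

        tokio::select! {
            _ = &mut work => {
                unreachable!("work pends forever until the driver stops it");
            }
            res = &mut reached_rx => {
                res.expect("sender should not be dropped");
                // Stand-in for abort_handle.abort("..."): stop polling
                // `work` by simply returning and dropping it here.
            }
        }
    }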