diff --git a/update-engine/src/buffer.rs b/update-engine/src/buffer.rs index a813f3f80a..a9e04c1d12 100644 --- a/update-engine/src/buffer.rs +++ b/update-engine/src/buffer.rs @@ -1599,7 +1599,7 @@ mod tests { use tokio_stream::wrappers::ReceiverStream; use crate::{ - events::{ProgressCounter, ProgressUnits, StepProgress}, + events::{ProgressUnits, StepProgress}, test_utils::TestSpec, StepContext, StepSuccess, UpdateEngine, }; @@ -1834,36 +1834,6 @@ mod tests { } }; - // Ensure that nested step 2 produces progress events in the - // expected order and in succession. - let mut progress_check = NestedProgressCheck::new(); - for event in &generated_events { - if let Event::Progress(event) = event { - let progress_counter = event.kind.progress_counter(); - if progress_counter - == Some(&ProgressCounter::new(2, 3, "steps")) - { - progress_check.two_out_of_three_seen(); - } else if progress_check - == NestedProgressCheck::TwoOutOfThreeSteps - { - assert_eq!( - progress_counter, - Some(&ProgressCounter::current(50, "units")) - ); - progress_check.fifty_units_seen(); - } else if progress_check == NestedProgressCheck::FiftyUnits - { - assert_eq!( - progress_counter, - Some(&ProgressCounter::new(3, 3, "steps")) - ); - progress_check.three_out_of_three_seen(); - } - } - } - progress_check.assert_done(); - // Ensure that events are never seen twice. let mut event_indexes_seen = HashSet::new(); let mut leaf_event_indexes_seen = HashSet::new(); @@ -2370,7 +2340,6 @@ mod tests { 5, "Nested step 2 (fails)", move |cx| async move { - // This is used by NestedProgressCheck below. parent_cx .send_progress(StepProgress::with_current_and_total( 2, @@ -2381,76 +2350,18 @@ mod tests { .await; cx.send_progress(StepProgress::with_current( - 50, + 20, "units", Default::default(), )) .await; - parent_cx - .send_progress(StepProgress::with_current_and_total( - 3, - 3, - "steps", - Default::default(), - )) - .await; - bail!("failing step") }, ) .register(); } - #[derive(Clone, Copy, Debug, PartialEq, Eq)] - enum NestedProgressCheck { - Initial, - TwoOutOfThreeSteps, - FiftyUnits, - ThreeOutOfThreeSteps, - } - - impl NestedProgressCheck { - fn new() -> Self { - Self::Initial - } - - fn two_out_of_three_seen(&mut self) { - assert_eq!( - *self, - Self::Initial, - "two_out_of_three_seen: expected Initial", - ); - *self = Self::TwoOutOfThreeSteps; - } - - fn fifty_units_seen(&mut self) { - assert_eq!( - *self, - Self::TwoOutOfThreeSteps, - "twenty_units_seen: expected TwoOutOfThreeSteps", - ); - *self = Self::FiftyUnits; - } - - fn three_out_of_three_seen(&mut self) { - assert_eq!( - *self, - Self::FiftyUnits, - "three_out_of_three_seen: expected TwentyUnits", - ); - *self = Self::ThreeOutOfThreeSteps; - } - - fn assert_done(&self) { - assert_eq!( - *self, - Self::ThreeOutOfThreeSteps, - "assert_done: expected ThreeOutOfThreeSteps", - ); - } - } - fn define_remote_nested_engine( engine: &mut UpdateEngine<'_, TestSpec>, start_id: usize, diff --git a/update-engine/src/context.rs b/update-engine/src/context.rs index 6b22578c2e..c2c1e32119 100644 --- a/update-engine/src/context.rs +++ b/update-engine/src/context.rs @@ -11,7 +11,6 @@ use std::{collections::HashMap, fmt}; use derive_where::derive_where; use futures::FutureExt; use tokio::sync::{mpsc, oneshot}; -use tokio::time::Instant; use crate::errors::NestedEngineError; use crate::{ @@ -57,13 +56,10 @@ impl StepContext { /// Sends a progress update to the update engine. #[inline] pub async fn send_progress(&self, progress: StepProgress) { - let now = Instant::now(); - let (done, done_rx) = oneshot::channel(); self.payload_sender - .send(StepContextPayload::Progress { now, progress, done }) + .send(StepContextPayload::Progress(progress)) .await - .expect("our code always keeps payload_receiver open"); - _ = done_rx.await; + .expect("our code always keeps the receiver open") } /// Sends a report from a nested engine, typically one running on a remote @@ -75,8 +71,6 @@ impl StepContext { &self, report: EventReport, ) -> Result<(), NestedEngineError> { - let now = Instant::now(); - let mut res = Ok(()); let delta_report = if let Some(id) = report.root_execution_id { let mut nested_buffers = self.nested_buffers.lock().unwrap(); @@ -140,32 +134,17 @@ impl StepContext { } self.payload_sender - .send(StepContextPayload::Nested { - now, - event: Event::Step(event), - }) + .send(StepContextPayload::Nested(Event::Step(event))) .await - .expect("our code always keeps payload_receiver open"); + .expect("our code always keeps the receiver open"); } for event in delta_report.progress_events { self.payload_sender - .send(StepContextPayload::Nested { - now, - event: Event::Progress(event), - }) + .send(StepContextPayload::Nested(Event::Progress(event))) .await - .expect("our code always keeps payload_receiver open"); + .expect("our code always keeps the receiver open"); } - - // Ensure that all reports have been received by the engine before - // returning. - let (done, done_rx) = oneshot::channel(); - self.payload_sender - .send(StepContextPayload::Sync { done }) - .await - .expect("our code always keeps payload_receiver open"); - _ = done_rx.await; } res @@ -184,75 +163,58 @@ impl StepContext { F: FnOnce(&mut UpdateEngine<'a, S2>) -> Result<(), S2::Error> + Send, S2: StepSpec + 'a, { - // Previously, this code was of the form: - // - // let (sender, mut receiver) = mpsc::channel(128); - // let mut engine = UpdateEngine::new(&self.log, sender); - // - // And there was a loop below that selected over `engine` and - // `receiver`. - // - // That approach was abandoned because it had ordering issues, because - // it wasn't guaranteed that events were received in the order they were - // processed. For example, consider what happens if: - // - // 1. User code sent an event E1 through a child (nested) StepContext. - // 2. Then in quick succession, the same code sent an event E2 through - // self. - // - // What users would expect to happen is that E1 is received before E2. - // However, what actually happened was that: - // - // 1. `engine` was driven until the next suspend point. This caused E2 - // to be sent. - // 2. Then, `receiver` was polled. This caused E1 to be received. - // - // So the order of events was reversed. - // - // To fix this, we now use a single channel, and send events through it - // both from the nested engine and from self. - // - // An alternative would be to use a oneshot channel as a synchronization - // tool. However, just sharing a channel is easier. - let mut engine = UpdateEngine::::new_nested( - &self.log, - self.payload_sender.clone(), - ); - + let (sender, mut receiver) = mpsc::channel(128); + let mut engine = UpdateEngine::new(&self.log, sender); // Create the engine's steps. (engine_fn)(&mut engine) .map_err(|error| NestedEngineError::Creation { error })?; // Now run the engine. let engine = engine.execute(); - match engine.await { - Ok(cx) => Ok(cx), - Err(ExecutionError::EventSendError(_)) => { - unreachable!("our code always keeps payload_receiver open") + tokio::pin!(engine); + + let mut result = None; + let mut events_done = false; + + loop { + tokio::select! { + ret = &mut engine, if result.is_none() => { + match ret { + Ok(cx) => { + result = Some(Ok(cx)); + } + Err(ExecutionError::EventSendError(_)) => { + unreachable!("we always keep the receiver open") + } + Err(ExecutionError::StepFailed { component, id, description, error }) => { + result = Some(Err(NestedEngineError::StepFailed { component, id, description, error })); + } + Err(ExecutionError::Aborted { component, id, description, message }) => { + result = Some(Err(NestedEngineError::Aborted { component, id, description, message })); + } + } + } + event = receiver.recv(), if !events_done => { + match event { + Some(event) => { + self.payload_sender.send( + StepContextPayload::Nested(event.into_generic()) + ) + .await + .expect("we always keep the receiver open"); + } + None => { + events_done = true; + } + } + } + else => { + break; + } } - Err(ExecutionError::StepFailed { - component, - id, - description, - error, - }) => Err(NestedEngineError::StepFailed { - component, - id, - description, - error, - }), - Err(ExecutionError::Aborted { - component, - id, - description, - message, - }) => Err(NestedEngineError::Aborted { - component, - id, - description, - message, - }), } + + result.expect("the loop only exits if result is set") } /// Retrieves a token used to fetch the value out of a [`StepHandle`]. @@ -285,32 +247,10 @@ impl NestedEventBuffer { } } -/// An uninhabited type for oneshot channels, since we only care about them -/// being dropped. -#[derive(Debug)] -pub(crate) enum Never {} - #[derive_where(Debug)] pub(crate) enum StepContextPayload { - Progress { - now: Instant, - progress: StepProgress, - done: oneshot::Sender, - }, - /// A single nested event with synchronization. - NestedSingle { - now: Instant, - event: Event, - done: oneshot::Sender, - }, - /// One out of a series of nested events sent in succession. - Nested { - now: Instant, - event: Event, - }, - Sync { - done: oneshot::Sender, - }, + Progress(StepProgress), + Nested(Event), } /// Context for a step's metadata-generation function. diff --git a/update-engine/src/engine.rs b/update-engine/src/engine.rs index 56d7739a40..24f055858c 100644 --- a/update-engine/src/engine.rs +++ b/update-engine/src/engine.rs @@ -5,12 +5,7 @@ // Copyright 2023 Oxide Computer Company use std::{ - borrow::Cow, - fmt, - ops::ControlFlow, - pin::Pin, - sync::{Arc, Mutex}, - task::Poll, + borrow::Cow, fmt, ops::ControlFlow, pin::Pin, sync::Mutex, task::Poll, }; use cancel_safe_futures::coop_cancel; @@ -33,7 +28,7 @@ use crate::{ StepEvent, StepEventKind, StepInfo, StepInfoWithMetadata, StepOutcome, StepProgress, }, - AsError, CompletionContext, MetadataContext, NestedSpec, StepContext, + AsError, CompletionContext, MetadataContext, StepContext, StepContextPayload, StepHandle, StepSpec, }; @@ -69,7 +64,7 @@ pub struct UpdateEngine<'a, S: StepSpec> { // be a graph in the future. log: slog::Logger, execution_id: ExecutionId, - sender: EngineSender, + sender: mpsc::Sender>, // This is set to None in Self::execute. canceler: Option>, @@ -87,21 +82,6 @@ pub struct UpdateEngine<'a, S: StepSpec> { impl<'a, S: StepSpec + 'a> UpdateEngine<'a, S> { /// Creates a new `UpdateEngine`. pub fn new(log: &slog::Logger, sender: mpsc::Sender>) -> Self { - let sender = Arc::new(DefaultSender { sender }); - Self::new_impl(log, EngineSender { sender }) - } - - // See the comment on `StepContext::with_nested_engine` for why this is - // necessary.`` - pub(crate) fn new_nested( - log: &slog::Logger, - sender: mpsc::Sender>, - ) -> Self { - let sender = Arc::new(NestedSender { sender }); - Self::new_impl(log, EngineSender { sender }) - } - - fn new_impl(log: &slog::Logger, sender: EngineSender) -> Self { let execution_id = ExecutionId(Uuid::new_v4()); let (canceler, cancel_receiver) = coop_cancel::new_pair(); Self { @@ -323,88 +303,6 @@ impl<'a, S: StepSpec + 'a> UpdateEngine<'a, S> { } } -/// Abstraction used to send events to whatever receiver is interested in them. -/// -/// # Why is this type so weird? -/// -/// `EngineSender` is a wrapper around a cloneable trait object. Why do we need -/// that? -/// -/// `SenderImpl` has two implementations: -/// -/// 1. `DefaultSender`, which is a wrapper around an `mpsc::Sender>`. -/// This is used when the receiver is user code. -/// 2. `NestedSender`, which is a more complex wrapper around an -/// `mpsc::Sender>`. -/// -/// You might imagine that we could just have `EngineSender` be an enum with -/// these two variants. But we actually want `NestedSender` to implement -/// `SenderImpl` for *any* StepSpec, not just `S`, to allow nested engines to -/// be a different StepSpec than the outer engine. -/// -/// So we need to use a trait object to achieve type erasure. -#[derive_where(Clone, Debug)] -struct EngineSender { - sender: Arc>, -} - -impl EngineSender { - async fn send(&self, event: Event) -> Result<(), ExecutionError> { - self.sender.send(event).await - } -} - -trait SenderImpl: Send + Sync + fmt::Debug { - fn send( - &self, - event: Event, - ) -> BoxFuture<'_, Result<(), ExecutionError>>; -} - -#[derive_where(Debug)] -struct DefaultSender { - sender: mpsc::Sender>, -} - -impl SenderImpl for DefaultSender { - fn send( - &self, - event: Event, - ) -> BoxFuture<'_, Result<(), ExecutionError>> { - self.sender.send(event).map_err(|error| error.into()).boxed() - } -} - -#[derive_where(Debug)] -struct NestedSender { - sender: mpsc::Sender>, -} - -// Note that NestedSender implements SenderImpl for any S2: StepSpec. -// That is to allow nested engines to implement arbitrary StepSpecs. -impl SenderImpl for NestedSender { - fn send( - &self, - event: Event, - ) -> BoxFuture<'_, Result<(), ExecutionError>> { - let now = Instant::now(); - async move { - let (done, done_rx) = oneshot::channel(); - self.sender - .send(StepContextPayload::NestedSingle { - now, - event: event.into_generic(), - done, - }) - .await - .expect("our code always keeps payload_receiver open"); - _ = done_rx.await; - Ok(()) - } - .boxed() - } -} - /// A join handle for an UpdateEngine. /// /// This handle should be awaited to drive and obtain the result of an execution. @@ -921,16 +819,6 @@ impl<'a, S: StepSpec> StepExec<'a, S> { Ok(ControlFlow::Continue(())) } - // Note: payload_receiver is always kept open while step_fut - // is being driven. It is only dropped before completion if - // the step is aborted, in which case step_fut is also - // cancelled without being driven further. A bunch of - // expects with "our code always keeps payload_receiver - // open" rely on this. - // - // If we ever move the payload receiver to another task so - // it runs in parallel, this situation would have to be - // handled with care. payload = payload_receiver.recv(), if !payload_done => { match payload { Some(payload) => { @@ -980,14 +868,14 @@ struct ExecutionContext { execution_id: ExecutionId, next_event_index: DebugIgnore, total_start: Instant, - sender: EngineSender, + sender: mpsc::Sender>, } impl ExecutionContext { fn new( execution_id: ExecutionId, next_event_index: F, - sender: EngineSender, + sender: mpsc::Sender>, ) -> Self { let total_start = Instant::now(); Self { @@ -1018,7 +906,7 @@ struct StepExecutionContext { next_event_index: DebugIgnore, total_start: Instant, step_info: StepInfoWithMetadata, - sender: EngineSender, + sender: mpsc::Sender>, } type StepMetadataFn<'a, S> = Box< @@ -1053,7 +941,7 @@ struct StepProgressReporter { step_start: Instant, attempt: usize, attempt_start: Instant, - sender: EngineSender, + sender: mpsc::Sender>, } impl usize> StepProgressReporter { @@ -1075,32 +963,51 @@ impl usize> StepProgressReporter { async fn handle_payload( &mut self, payload: StepContextPayload, - ) -> Result<(), ExecutionError> { + ) -> Result<(), mpsc::error::SendError>> { match payload { - StepContextPayload::Progress { now, progress, done } => { - self.handle_progress(now, progress).await?; - std::mem::drop(done); + StepContextPayload::Progress(progress) => { + self.handle_progress(progress).await } - StepContextPayload::NestedSingle { now, event, done } => { - self.handle_nested(now, event).await?; - std::mem::drop(done); - } - StepContextPayload::Nested { now, event } => { - self.handle_nested(now, event).await?; + StepContextPayload::Nested(Event::Step(event)) => { + self.sender + .send(Event::Step(StepEvent { + spec: S::schema_name(), + execution_id: self.execution_id, + event_index: (self.next_event_index)(), + total_elapsed: self.total_start.elapsed(), + kind: StepEventKind::Nested { + step: self.step_info.clone(), + attempt: self.attempt, + event: Box::new(event), + step_elapsed: self.step_start.elapsed(), + attempt_elapsed: self.attempt_start.elapsed(), + }, + })) + .await } - StepContextPayload::Sync { done } => { - std::mem::drop(done); + StepContextPayload::Nested(Event::Progress(event)) => { + self.sender + .send(Event::Progress(ProgressEvent { + spec: S::schema_name(), + execution_id: self.execution_id, + total_elapsed: self.total_start.elapsed(), + kind: ProgressEventKind::Nested { + step: self.step_info.clone(), + attempt: self.attempt, + event: Box::new(event), + step_elapsed: self.step_start.elapsed(), + attempt_elapsed: self.attempt_start.elapsed(), + }, + })) + .await } } - - Ok(()) } async fn handle_progress( &mut self, - now: Instant, progress: StepProgress, - ) -> Result<(), ExecutionError> { + ) -> Result<(), mpsc::error::SendError>> { match progress { StepProgress::Progress { progress, metadata } => { // Send the progress to the sender. @@ -1108,14 +1015,14 @@ impl usize> StepProgressReporter { .send(Event::Progress(ProgressEvent { spec: S::schema_name(), execution_id: self.execution_id, - total_elapsed: now - self.total_start, + total_elapsed: self.total_start.elapsed(), kind: ProgressEventKind::Progress { step: self.step_info.clone(), attempt: self.attempt, progress, metadata, - step_elapsed: now - self.step_start, - attempt_elapsed: now - self.attempt_start, + step_elapsed: self.step_start.elapsed(), + attempt_elapsed: self.attempt_start.elapsed(), }, })) .await @@ -1127,13 +1034,13 @@ impl usize> StepProgressReporter { spec: S::schema_name(), execution_id: self.execution_id, event_index: (self.next_event_index)(), - total_elapsed: now - self.total_start, + total_elapsed: self.total_start.elapsed(), kind: StepEventKind::ProgressReset { step: self.step_info.clone(), attempt: self.attempt, metadata, - step_elapsed: now - self.step_start, - attempt_elapsed: now - self.attempt_start, + step_elapsed: self.step_start.elapsed(), + attempt_elapsed: self.attempt_start.elapsed(), message, }, })) @@ -1142,7 +1049,7 @@ impl usize> StepProgressReporter { StepProgress::Retry { message } => { // Retry this step. self.attempt += 1; - let attempt_elapsed = now - self.attempt_start; + let attempt_elapsed = self.attempt_start.elapsed(); self.attempt_start = Instant::now(); // Send the retry message. @@ -1151,11 +1058,11 @@ impl usize> StepProgressReporter { spec: S::schema_name(), execution_id: self.execution_id, event_index: (self.next_event_index)(), - total_elapsed: now - self.total_start, + total_elapsed: self.total_start.elapsed(), kind: StepEventKind::AttemptRetry { step: self.step_info.clone(), next_attempt: self.attempt, - step_elapsed: now - self.step_start, + step_elapsed: self.step_start.elapsed(), attempt_elapsed, message, }, @@ -1165,48 +1072,6 @@ impl usize> StepProgressReporter { } } - async fn handle_nested( - &mut self, - now: Instant, - event: Event, - ) -> Result<(), ExecutionError> { - match event { - Event::Step(event) => { - self.sender - .send(Event::Step(StepEvent { - spec: S::schema_name(), - execution_id: self.execution_id, - event_index: (self.next_event_index)(), - total_elapsed: now - self.total_start, - kind: StepEventKind::Nested { - step: self.step_info.clone(), - attempt: self.attempt, - event: Box::new(event), - step_elapsed: now - self.step_start, - attempt_elapsed: now - self.attempt_start, - }, - })) - .await - } - Event::Progress(event) => { - self.sender - .send(Event::Progress(ProgressEvent { - spec: S::schema_name(), - execution_id: self.execution_id, - total_elapsed: now - self.total_start, - kind: ProgressEventKind::Nested { - step: self.step_info.clone(), - attempt: self.attempt, - event: Box::new(event), - step_elapsed: now - self.step_start, - attempt_elapsed: now - self.attempt_start, - }, - })) - .await - } - } - } - async fn handle_abort(mut self, message: String) -> ExecutionError { // Send the abort message over the channel. // @@ -1237,7 +1102,7 @@ impl usize> StepProgressReporter { description: self.step_info.info.description.clone(), message: message, }, - Err(error) => error, + Err(error) => error.into(), } } @@ -1322,7 +1187,7 @@ impl usize> StepProgressReporter { async fn send_error( mut self, error: &S::Error, - ) -> Result<(), ExecutionError> { + ) -> Result<(), mpsc::error::SendError>> { // Stringify `error` into a message + list causes; this is written the // way it is to avoid `error` potentially living across the `.await` // below (which can cause lifetime issues in callers). diff --git a/update-engine/src/errors.rs b/update-engine/src/errors.rs index abb0d4cd22..f40ce096d3 100644 --- a/update-engine/src/errors.rs +++ b/update-engine/src/errors.rs @@ -48,13 +48,13 @@ impl fmt::Display for ExecutionError { ) } Self::EventSendError(_) => { - write!(f, "while sending event, event receiver dropped") + write!(f, "event receiver dropped") } } } } -impl error::Error for ExecutionError { +impl error::Error for ExecutionError { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { ExecutionError::StepFailed { error, .. } => Some(error.as_error()), @@ -112,7 +112,7 @@ impl fmt::Display for NestedEngineError { } } -impl error::Error for NestedEngineError { +impl error::Error for NestedEngineError { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { Self::Creation { error } => Some(error.as_error()), diff --git a/update-engine/src/spec.rs b/update-engine/src/spec.rs index 9b0121c213..a569bcf14a 100644 --- a/update-engine/src/spec.rs +++ b/update-engine/src/spec.rs @@ -15,7 +15,7 @@ use serde::{de::DeserializeOwned, Serialize}; /// /// NOTE: `StepSpec` is only required to implement `JsonSchema` to obtain the /// name of the schema. This is an upstream limitation in `JsonSchema`. -pub trait StepSpec: JsonSchema + Send + 'static { +pub trait StepSpec: JsonSchema + Send { /// A component associated with each step. type Component: Clone + fmt::Debug @@ -183,7 +183,7 @@ impl AsError for NestedError { /// Trait that abstracts over concrete errors and `anyhow::Error`. /// /// This needs to be manually implemented for any custom error types. -pub trait AsError: fmt::Debug + Send + Sync + 'static { +pub trait AsError: fmt::Debug + Send + Sync { fn as_error(&self) -> &(dyn std::error::Error + 'static); }