diff --git a/Cargo.lock b/Cargo.lock
index 5f073db250..138080640e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -10063,6 +10063,7 @@ dependencies = [
  "installinator-artifact-client",
  "installinator-artifactd",
  "installinator-common",
+ "itertools 0.11.0",
  "omicron-certificates",
  "omicron-common 0.1.0",
  "omicron-passwords 0.1.0",
diff --git a/openapi/wicketd.json b/openapi/wicketd.json
index 40d798da00..d67fc79f7a 100644
--- a/openapi/wicketd.json
+++ b/openapi/wicketd.json
@@ -598,6 +598,33 @@
         }
       }
     },
+    "/update": {
+      "post": {
+        "summary": "An endpoint to start updating one or more sleds, switches and PSCs.",
+        "operationId": "post_start_update",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/StartUpdateParams"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "204": {
+            "description": "resource updated"
+          },
+          "4XX": {
+            "$ref": "#/components/responses/Error"
+          },
+          "5XX": {
+            "$ref": "#/components/responses/Error"
+          }
+        }
+      }
+    },
     "/update/{type}/{slot}": {
       "get": {
         "summary": "An endpoint to get the status of any update being performed or recently",
@@ -641,51 +668,6 @@
           "$ref": "#/components/responses/Error"
         }
       }
-    },
-    "post": {
-      "summary": "An endpoint to start updating a sled.",
-      "operationId": "post_start_update",
-      "parameters": [
-        {
-          "in": "path",
-          "name": "slot",
-          "required": true,
-          "schema": {
-            "type": "integer",
-            "format": "uint32",
-            "minimum": 0
-          }
-        },
-        {
-          "in": "path",
-          "name": "type",
-          "required": true,
-          "schema": {
-            "$ref": "#/components/schemas/SpType"
-          }
-        }
-      ],
-      "requestBody": {
-        "content": {
-          "application/json": {
-            "schema": {
-              "$ref": "#/components/schemas/StartUpdateOptions"
-            }
-          }
-        },
-        "required": true
-      },
-      "responses": {
-        "204": {
-          "description": "resource updated"
-        },
-        "4XX": {
-          "$ref": "#/components/responses/Error"
-        },
-        "5XX": {
-          "$ref": "#/components/responses/Error"
-        }
-      }
     }
   }
 },
@@ -2761,6 +2743,31 @@
       "skip_sp_version_check"
     ]
   },
+  "StartUpdateParams": {
+    "type": "object",
+    "properties": {
+      "options": {
+        "description": "Options for the update.",
+        "allOf": [
+          {
+            "$ref": "#/components/schemas/StartUpdateOptions"
+          }
+        ]
+      },
+      "targets": {
+        "description": "The SP identifiers to start the update with. Must be non-empty.",
+        "type": "array",
+        "items": {
+          "$ref": "#/components/schemas/SpIdentifier"
+        },
+        "uniqueItems": true
+      }
+    },
+    "required": [
+      "options",
+      "targets"
+    ]
+  },
   "StepComponentSummaryForGenericSpec": {
     "type": "object",
     "properties": {
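For readers of this diff, a minimal sketch of a caller of the new consolidated endpoint, using the generated `wicketd_client` bindings (the slot numbers and error handling are illustrative, not part of the change):

```rust
use wicketd_client::types::{
    SpIdentifier, SpType, StartUpdateOptions, StartUpdateParams,
};

// `client` is assumed to be an already-configured wicketd_client::Client.
async fn start_updates(client: &wicketd_client::Client) -> Result<(), String> {
    let params = StartUpdateParams {
        // One request can now fan out to several SPs at once.
        targets: vec![
            SpIdentifier { type_: SpType::Sled, slot: 0 },
            SpIdentifier { type_: SpType::Sled, slot: 1 },
        ],
        options: StartUpdateOptions::default(),
    };
    // A 204 means every target passed the pre-flight checks; any failure
    // comes back as a single 400 carrying one or more error messages.
    client.post_start_update(&params).await.map_err(|e| e.to_string())?;
    Ok(())
}
```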
Must be non-empty.", + "type": "array", + "items": { + "$ref": "#/components/schemas/SpIdentifier" + }, + "uniqueItems": true + } + }, + "required": [ + "options", + "targets" + ] + }, "StepComponentSummaryForGenericSpec": { "type": "object", "properties": { diff --git a/wicket/src/wicketd.rs b/wicket/src/wicketd.rs index 160bcb1c6a..2411542429 100644 --- a/wicket/src/wicketd.rs +++ b/wicket/src/wicketd.rs @@ -12,7 +12,7 @@ use tokio::time::{interval, Duration, MissedTickBehavior}; use wicketd_client::types::{ AbortUpdateOptions, ClearUpdateStateOptions, GetInventoryParams, GetInventoryResponse, GetLocationResponse, IgnitionCommand, SpIdentifier, - SpType, StartUpdateOptions, + SpType, StartUpdateOptions, StartUpdateParams, }; use crate::events::EventReportMap; @@ -164,10 +164,11 @@ impl WicketdManager { tokio::spawn(async move { let update_client = create_wicketd_client(&log, addr, WICKETD_TIMEOUT); - let sp: SpIdentifier = component_id.into(); - let response = match update_client - .post_start_update(sp.type_, sp.slot, &options) - .await + let params = StartUpdateParams { + targets: vec![component_id.into()], + options, + }; + let response = match update_client.post_start_update(¶ms).await { Ok(_) => Ok(()), Err(error) => Err(error.to_string()), diff --git a/wicketd/Cargo.toml b/wicketd/Cargo.toml index a36344b6fb..6df5e0e4e5 100644 --- a/wicketd/Cargo.toml +++ b/wicketd/Cargo.toml @@ -24,6 +24,7 @@ hubtools.workspace = true http.workspace = true hyper.workspace = true illumos-utils.workspace = true +itertools.workspace = true reqwest.workspace = true schemars.workspace = true serde.workspace = true diff --git a/wicketd/src/helpers.rs b/wicketd/src/helpers.rs new file mode 100644 index 0000000000..a8b47d4f12 --- /dev/null +++ b/wicketd/src/helpers.rs @@ -0,0 +1,41 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +//! Helpers and utility functions for wicketd. + +use std::fmt; + +use gateway_client::types::{SpIdentifier, SpType}; +use itertools::Itertools; + +#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub(crate) struct SpIdentifierDisplay(pub(crate) SpIdentifier); + +impl From for SpIdentifierDisplay { + fn from(id: SpIdentifier) -> Self { + SpIdentifierDisplay(id) + } +} + +impl<'a> From<&'a SpIdentifier> for SpIdentifierDisplay { + fn from(id: &'a SpIdentifier) -> Self { + SpIdentifierDisplay(*id) + } +} + +impl fmt::Display for SpIdentifierDisplay { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0.type_ { + SpType::Sled => write!(f, "sled {}", self.0.slot), + SpType::Switch => write!(f, "switch {}", self.0.slot), + SpType::Power => write!(f, "PSC {}", self.0.slot), + } + } +} + +pub(crate) fn sps_to_string>( + sps: impl IntoIterator, +) -> String { + sps.into_iter().map_into().join(", ") +} diff --git a/wicketd/src/http_entrypoints.rs b/wicketd/src/http_entrypoints.rs index 98cac8dc5d..72c3341334 100644 --- a/wicketd/src/http_entrypoints.rs +++ b/wicketd/src/http_entrypoints.rs @@ -4,6 +4,8 @@ //! 
diff --git a/wicketd/src/http_entrypoints.rs b/wicketd/src/http_entrypoints.rs
index 98cac8dc5d..72c3341334 100644
--- a/wicketd/src/http_entrypoints.rs
+++ b/wicketd/src/http_entrypoints.rs
@@ -4,6 +4,8 @@
 
 //! HTTP entrypoint functions for wicketd
 
+use crate::helpers::sps_to_string;
+use crate::helpers::SpIdentifierDisplay;
 use crate::mgs::GetInventoryError;
 use crate::mgs::GetInventoryResponse;
 use crate::mgs::MgsHandle;
@@ -44,7 +46,6 @@ use std::net::IpAddr;
 use std::net::Ipv6Addr;
 use std::time::Duration;
 use tokio::io::AsyncWriteExt;
-use uuid::Uuid;
 use wicket_common::rack_setup::PutRssUserConfigInsensitive;
 use wicket_common::update_events::EventReport;
 
@@ -652,6 +653,15 @@ async fn get_artifacts_and_event_reports(
     Ok(HttpResponseOk(response))
 }
 
+#[derive(Clone, Debug, JsonSchema, Deserialize)]
+pub(crate) struct StartUpdateParams {
+    /// The SP identifiers to start the update with. Must be non-empty.
+    pub(crate) targets: BTreeSet<SpIdentifier>,
+
+    /// Options for the update.
+    pub(crate) options: StartUpdateOptions,
+}
+
 #[derive(Clone, Debug, JsonSchema, Deserialize)]
 pub(crate) struct StartUpdateOptions {
     /// If passed in, fails the update with a simulated error.
@@ -730,19 +740,24 @@ impl UpdateTestError {
         log: &slog::Logger,
         reason: &str,
     ) -> HttpError {
+        let message = self.into_error_string(log, reason).await;
+        HttpError::for_bad_request(None, message)
+    }
+
+    pub(crate) async fn into_error_string(
+        self,
+        log: &slog::Logger,
+        reason: &str,
+    ) -> String {
         match self {
-            UpdateTestError::Fail => HttpError::for_bad_request(
-                None,
-                format!("Simulated failure while {reason}"),
-            ),
+            UpdateTestError::Fail => {
+                format!("Simulated failure while {reason}")
+            }
             UpdateTestError::Timeout { secs } => {
                 slog::info!(log, "Simulating timeout while {reason}");
                 // 15 seconds should be enough to cause a timeout.
                 tokio::time::sleep(Duration::from_secs(secs)).await;
-                HttpError::for_bad_request(
-                    None,
-                    "XXX request should time out before this is hit".into(),
-                )
+                "XXX request should time out before this is hit".into()
             }
         }
     }
@@ -834,21 +849,27 @@ async fn get_location(
     }))
 }
 
-/// An endpoint to start updating a sled.
+/// An endpoint to start updating one or more sleds, switches and PSCs.
 #[endpoint {
     method = POST,
-    path = "/update/{type}/{slot}",
+    path = "/update",
 }]
 async fn post_start_update(
     rqctx: RequestContext<ServerContext>,
-    target: Path<SpIdentifier>,
-    opts: TypedBody<StartUpdateOptions>,
+    params: TypedBody<StartUpdateParams>,
 ) -> Result<HttpResponseUpdatedNoContent, HttpError> {
     let log = &rqctx.log;
     let rqctx = rqctx.context();
-    let target = target.into_inner();
+    let params = params.into_inner();
+
+    if params.targets.is_empty() {
+        return Err(HttpError::for_bad_request(
+            None,
+            "No update targets specified".into(),
+        ));
+    }
 
-    // Can we update the target SP? We refuse to update if:
+    // Can we update the target SPs? We refuse to update if, for any target SP:
     //
     // 1. We haven't pulled its state in our inventory (most likely cause: the
     //    cubby is empty; less likely cause: the SP is misbehaving, which will
@@ -870,70 +891,136 @@ async fn post_start_update(
         }
     };
 
-    // Next, do we have the state of the target SP?
-    let sp_state = match inventory {
+    // Error cases.
+    let mut inventory_absent = BTreeSet::new();
+    let mut self_update = None;
+    let mut maybe_self_update = BTreeSet::new();
+
+    // Next, do we have the states of the target SPs?
+    let sp_states = match inventory {
         GetInventoryResponse::Response { inventory, .. } => inventory
             .sps
             .into_iter()
-            .filter_map(|sp| if sp.id == target { sp.state } else { None })
-            .next(),
-        GetInventoryResponse::Unavailable => None,
-    };
-    let Some(sp_state) = sp_state else {
-        return Err(HttpError::for_bad_request(
-            None,
-            "cannot update target sled (no inventory state present)".into(),
-        ));
+            .filter_map(|sp| {
+                if params.targets.contains(&sp.id) {
+                    if let Some(sp_state) = sp.state {
+                        Some((sp.id, sp_state))
+                    } else {
+                        None
+                    }
+                } else {
+                    None
+                }
+            })
+            .collect(),
+        GetInventoryResponse::Unavailable => BTreeMap::new(),
     };
 
-    // If we have the state of the SP, are we allowed to update it? We
-    // refuse to try to update our own sled.
-    match &rqctx.baseboard {
-        Some(baseboard) => {
-            if baseboard.identifier() == sp_state.serial_number
-                && baseboard.model() == sp_state.model
-                && baseboard.revision() == i64::from(sp_state.revision)
-            {
-                return Err(HttpError::for_bad_request(
-                    None,
-                    "cannot update sled where wicketd is running".into(),
-                ));
+    for target in &params.targets {
+        let sp_state = match sp_states.get(target) {
+            Some(sp_state) => sp_state,
+            None => {
+                // The state isn't present, so add to inventory_absent.
+                inventory_absent.insert(*target);
+                continue;
             }
-        }
-        None => {
-            // We don't know our own baseboard, which is a very
-            // questionable state to be in! For now, we will hard-code
-            // the possibly locations where we could be running:
-            // scrimlets can only be in cubbies 14 or 16, so we refuse
-            // to update either of those.
-            let target_is_scrimlet =
-                matches!((target.type_, target.slot), (SpType::Sled, 14 | 16));
-            if target_is_scrimlet {
-                return Err(HttpError::for_bad_request(
-                    None,
-                    "wicketd does not know its own baseboard details: \
-                     refusing to update either scrimlet"
-                        .into(),
-                ));
+        };
+
+        // If we have the state of the SP, are we allowed to update it? We
+        // refuse to try to update our own sled.
+        match &rqctx.baseboard {
+            Some(baseboard) => {
+                if baseboard.identifier() == sp_state.serial_number
+                    && baseboard.model() == sp_state.model
+                    && baseboard.revision() == i64::from(sp_state.revision)
+                {
+                    self_update = Some(*target);
+                    continue;
+                }
+            }
+            None => {
+                // We don't know our own baseboard, which is a very questionable
+                // state to be in! For now, we will hard-code the possible
+                // locations where we could be running: scrimlets can only be in
+                // cubbies 14 or 16, so we refuse to update either of those.
+                let target_is_scrimlet = matches!(
+                    (target.type_, target.slot),
+                    (SpType::Sled, 14 | 16)
+                );
+                if target_is_scrimlet {
+                    maybe_self_update.insert(*target);
+                    continue;
+                }
             }
         }
     }
 
-    let opts = opts.into_inner();
-    if let Some(test_error) = opts.test_error {
-        return Err(test_error.into_http_error(log, "starting update").await);
+    // Do we have any errors?
+    let mut errors = Vec::new();
+    if !inventory_absent.is_empty() {
+        errors.push(format!(
+            "cannot update sleds (no inventory state present for {})",
+            sps_to_string(&inventory_absent)
+        ));
+    }
+    if let Some(self_update) = self_update {
+        errors.push(format!(
+            "cannot update sled where wicketd is running ({})",
+            SpIdentifierDisplay(self_update)
+        ));
+    }
+    if !maybe_self_update.is_empty() {
+        errors.push(format!(
+            "wicketd does not know its own baseboard details: \
+             refusing to update either scrimlet ({})",
+            sps_to_string(&maybe_self_update)
+        ));
     }
 
-    // All pre-flight update checks look OK: start the update.
-    //
-    // Generate an ID for this update; the update tracker will send it to the
-    // sled as part of the InstallinatorImageId, and installinator will send it
-    // back to our artifact server with its progress reports.
-    let update_id = Uuid::new_v4();
+    if let Some(test_error) = &params.options.test_error {
+        errors.push(test_error.into_error_string(log, "starting update").await);
+    }
 
-    match rqctx.update_tracker.start(target, update_id, opts).await {
-        Ok(()) => Ok(HttpResponseUpdatedNoContent {}),
-        Err(err) => Err(err.to_http_error()),
+    let start_update_errors = if errors.is_empty() {
+        // No errors: we can try and proceed with this update.
+        match rqctx.update_tracker.start(params.targets, params.options).await {
+            Ok(()) => return Ok(HttpResponseUpdatedNoContent {}),
+            Err(errors) => errors,
+        }
+    } else {
+        // We've already found errors, so all we want to do is to check whether
+        // the update tracker thinks there are any errors as well.
+        match rqctx.update_tracker.update_pre_checks(params.targets).await {
+            Ok(()) => Vec::new(),
+            Err(errors) => errors,
+        }
+    };
+
+    errors.extend(start_update_errors.iter().map(|error| error.to_string()));
+
+    // If we get here, we have errors to report.
+    match errors.len() {
+        0 => {
+            unreachable!(
+                "we already returned Ok(_) above if there were no errors"
+            )
+        }
+        1 => {
+            return Err(HttpError::for_bad_request(
+                None,
+                errors.pop().unwrap(),
+            ));
+        }
+        _ => {
+            return Err(HttpError::for_bad_request(
+                None,
+                format!(
+                    "multiple errors encountered:\n - {}",
+                    itertools::join(errors, "\n - ")
+                ),
+            ));
+        }
     }
 }
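To make the combined-failure mode concrete, a hypothetical sketch of the message the endpoint builds when several pre-flight checks fail (the specific SPs and error strings are invented, but the formatting mirrors the `match errors.len()` arm above):

```rust
fn main() {
    // Hypothetical pre-flight failures collected during the checks:
    let errors = vec![
        "cannot update sleds (no inventory state present for sled 3)"
            .to_string(),
        "cannot update sled where wicketd is running (sled 14)".to_string(),
    ];
    // post_start_update folds them into a single 400 response body:
    let message = format!(
        "multiple errors encountered:\n - {}",
        itertools::join(errors, "\n - ")
    );
    assert_eq!(
        message,
        "multiple errors encountered:\n \
         - cannot update sleds (no inventory state present for sled 3)\n \
         - cannot update sled where wicketd is running (sled 14)"
    );
}
```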
diff --git a/wicketd/src/lib.rs b/wicketd/src/lib.rs
index 78209ea04a..e17c15642c 100644
--- a/wicketd/src/lib.rs
+++ b/wicketd/src/lib.rs
@@ -6,6 +6,7 @@ mod artifacts;
 mod bootstrap_addrs;
 mod config;
 mod context;
+mod helpers;
 mod http_entrypoints;
 mod installinator_progress;
 mod inventory;
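The update_tracker.rs changes that follow replace the old closure-based `start_impl` with a small strategy trait: `setup` runs once per request, and `spawn_update_driver` runs once per target SP; an uninhabited implementation lets the pre-check path reuse the same code without a driver. A simplified sketch of that shape, with abbreviated, hypothetical names (`Plan` and `Task` stand in for wicketd's `UpdatePlan` and `SpUpdateData`):

```rust
use async_trait::async_trait;

struct Plan;
struct Task;

#[async_trait]
trait SpawnDriver {
    type Setup;
    // Called once per start-update request.
    async fn setup(&mut self, plan: &Plan) -> Self::Setup;
    // Called once per target SP.
    async fn spawn(&mut self, slot: u32, setup: &Self::Setup) -> Task;
}

// The uninhabited-type trick: an empty enum satisfies the trait bound at
// the type level, and `None::<Never>` lets shared code run its pre-checks
// without ever constructing a driver.
enum Never {}

#[async_trait]
impl SpawnDriver for Never {
    type Setup = ();
    async fn setup(&mut self, _plan: &Plan) -> Self::Setup {}
    async fn spawn(&mut self, _slot: u32, _setup: &Self::Setup) -> Task {
        unreachable!("Never cannot be constructed")
    }
}

// Shared entry point: pre-checks call `start_or_check(None::<Never>).await`.
async fn start_or_check<S: SpawnDriver>(_driver: Option<S>) {
    // ... run pre-flight checks; spawn only if a driver was provided ...
}
```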
diff --git a/wicketd/src/update_tracker.rs b/wicketd/src/update_tracker.rs
index a95a98bd72..1bbda00158 100644
--- a/wicketd/src/update_tracker.rs
+++ b/wicketd/src/update_tracker.rs
@@ -7,6 +7,7 @@
 use crate::artifacts::ArtifactIdData;
 use crate::artifacts::UpdatePlan;
 use crate::artifacts::WicketdArtifactStore;
+use crate::helpers::sps_to_string;
 use crate::http_entrypoints::GetArtifactsAndEventReportsResponse;
 use crate::http_entrypoints::StartUpdateOptions;
 use crate::http_entrypoints::UpdateSimulatedResult;
@@ -19,7 +20,6 @@ use anyhow::ensure;
 use anyhow::Context;
 use display_error_chain::DisplayErrorChain;
 use dropshot::HttpError;
-use futures::Future;
 use gateway_client::types::HostPhase2Progress;
 use gateway_client::types::HostPhase2RecoveryImageId;
 use gateway_client::types::HostStartupOptions;
@@ -156,146 +156,23 @@ impl UpdateTracker {
 
     pub(crate) async fn start(
         &self,
-        sp: SpIdentifier,
-        update_id: Uuid,
+        sps: BTreeSet<SpIdentifier>,
         opts: StartUpdateOptions,
-    ) -> Result<(), StartUpdateError> {
-        self.start_impl(sp, |plan| async {
-            // Do we need to upload this plan's trampoline phase 2 to MGS?
-            let upload_trampoline_phase_2_to_mgs = {
-                let mut upload_trampoline_phase_2_to_mgs =
-                    self.upload_trampoline_phase_2_to_mgs.lock().await;
-
-                match upload_trampoline_phase_2_to_mgs.as_mut() {
-                    Some(prev) => {
-                        // We've previously started an upload - does it match
-                        // this artifact? If not, cancel the old task (which
-                        // might still be trying to upload) and start a new one
-                        // with our current image.
-                        if prev.status.borrow().hash
-                            != plan.trampoline_phase_2.data.hash()
-                        {
-                            // It does _not_ match - we have a new plan with a
-                            // different trampoline image. If the old task is
-                            // still running, cancel it, and start a new one.
-                            prev.task.abort();
-                            *prev = self
-                                .spawn_upload_trampoline_phase_2_to_mgs(&plan);
-                        }
-                    }
-                    None => {
-                        *upload_trampoline_phase_2_to_mgs = Some(
-                            self.spawn_upload_trampoline_phase_2_to_mgs(&plan),
-                        );
-                    }
-                }
-
-                // Both branches above leave `upload_trampoline_phase_2_to_mgs`
-                // with data, so we can unwrap here to clone the `watch`
-                // channel.
-                upload_trampoline_phase_2_to_mgs
-                    .as_ref()
-                    .unwrap()
-                    .status
-                    .clone()
-            };
-
-            let event_buffer = Arc::new(StdMutex::new(EventBuffer::new(16)));
-            let ipr_start_receiver =
-                self.ipr_update_tracker.register(update_id);
-
-            let update_cx = UpdateContext {
-                update_id,
-                sp,
-                mgs_client: self.mgs_client.clone(),
-                upload_trampoline_phase_2_to_mgs,
-                log: self.log.new(o!(
-                    "sp" => format!("{sp:?}"),
-                    "update_id" => update_id.to_string(),
-                )),
-            };
-            // TODO do we need `UpdateDriver` as a distinct type?
-            let update_driver = UpdateDriver {};
-
-            // Using a oneshot channel to communicate the abort handle isn't
-            // ideal, but it works and is the easiest way to send it without
-            // restructuring this code.
-            let (abort_handle_sender, abort_handle_receiver) =
-                oneshot::channel();
-            let task = tokio::spawn(update_driver.run(
-                plan,
-                update_cx,
-                event_buffer.clone(),
-                ipr_start_receiver,
-                opts,
-                abort_handle_sender,
-            ));
-
-            let abort_handle = abort_handle_receiver
-                .await
-                .expect("abort handle is sent immediately");
-
-            SpUpdateData { task, abort_handle, event_buffer }
-        })
-        .await
+    ) -> Result<(), Vec<StartUpdateError>> {
+        let imp = RealSpawnUpdateDriver { update_tracker: self, opts };
+        self.start_impl(sps, Some(imp)).await
     }
 
     /// Starts a fake update that doesn't perform any steps, but simply waits
-    /// for a oneshot receiver to resolve.
+    /// for a watch receiver to resolve.
     #[doc(hidden)]
     pub async fn start_fake_update(
         &self,
-        sp: SpIdentifier,
-        oneshot_receiver: oneshot::Receiver<()>,
-    ) -> Result<(), StartUpdateError> {
-        self.start_impl(sp, |_plan| async move {
-            let (sender, mut receiver) = mpsc::channel(128);
-            let event_buffer = Arc::new(StdMutex::new(EventBuffer::new(16)));
-            let event_buffer_2 = event_buffer.clone();
-            let log = self.log.clone();
-
-            let engine = UpdateEngine::new(&log, sender);
-            let abort_handle = engine.abort_handle();
-
-            let task = tokio::spawn(async move {
-                // The step component and ID have been chosen arbitrarily here --
-                // they aren't important.
-                engine
-                    .new_step(
-                        UpdateComponent::Host,
-                        UpdateStepId::RunningInstallinator,
-                        "Fake step that waits for receiver to resolve",
-                        move |_cx| async move {
-                            _ = oneshot_receiver.await;
-                            StepSuccess::new(()).into()
-                        },
-                    )
-                    .register();
-
-                // Spawn a task to accept all events from the executing engine.
-                let event_receiving_task = tokio::spawn(async move {
-                    while let Some(event) = receiver.recv().await {
-                        event_buffer_2.lock().unwrap().add_event(event);
-                    }
-                });
-
-                match engine.execute().await {
-                    Ok(_cx) => (),
-                    Err(err) => {
-                        error!(log, "update failed"; "err" => %err);
-                    }
-                }
-
-                // Wait for all events to be received and written to the event
-                // buffer.
-                event_receiving_task
-                    .await
-                    .expect("event receiving task panicked");
-            });
-
-            SpUpdateData { task, abort_handle, event_buffer }
-        })
-        .await
+        sps: BTreeSet<SpIdentifier>,
+        watch_receiver: watch::Receiver<()>,
+    ) -> Result<(), Vec<StartUpdateError>> {
+        let imp = FakeUpdateDriver { watch_receiver, log: self.log.clone() };
+        self.start_impl(sps, Some(imp)).await
     }
 
     pub(crate) async fn clear_update_state(
@@ -315,40 +192,107 @@ impl UpdateTracker {
         update_data.abort_update(sp, message).await
     }
 
-    async fn start_impl<F, Fut>(
+    /// Checks whether an update can be started for the given SPs, without
+    /// actually starting it.
+    ///
+    /// This should only be used in situations where starting the update is not
+    /// desired (for example, if we've already encountered errors earlier in the
+    /// process and we're just adding to the list of errors). In cases where the
+    /// start method *is* desired, prefer the [`Self::start`] method, which also
+    /// performs the same checks.
+    pub(crate) async fn update_pre_checks(
         &self,
-        sp: SpIdentifier,
-        spawn_update_driver: F,
-    ) -> Result<(), StartUpdateError>
+        sps: BTreeSet<SpIdentifier>,
+    ) -> Result<(), Vec<StartUpdateError>> {
+        self.start_impl::<NeverUpdateDriver>(sps, None).await
+    }
+
+    async fn start_impl<Spawn>(
+        &self,
+        sps: BTreeSet<SpIdentifier>,
+        spawn_update_driver: Option<Spawn>,
+    ) -> Result<(), Vec<StartUpdateError>>
     where
-        F: FnOnce(UpdatePlan) -> Fut,
-        Fut: Future<Output = SpUpdateData> + Send,
+        Spawn: SpawnUpdateDriver,
     {
         let mut update_data = self.sp_update_data.lock().await;
 
-        let plan = update_data
-            .artifact_store
-            .current_plan()
-            .ok_or(StartUpdateError::TufRepositoryUnavailable)?;
+        let mut errors = Vec::new();
 
-        match update_data.sp_update_data.entry(sp) {
-            // Vacant: this is the first time we've started an update to this
-            // sp.
-            Entry::Vacant(slot) => {
-                slot.insert(spawn_update_driver(plan).await);
-                Ok(())
-            }
-            // Occupied: we've previously started an update to this sp; only
-            // allow this one if that update is no longer running.
-            Entry::Occupied(mut slot) => {
-                if slot.get().task.is_finished() {
-                    slot.insert(spawn_update_driver(plan).await);
-                    Ok(())
-                } else {
-                    Err(StartUpdateError::UpdateInProgress(sp))
+        // Check that we're not already updating any of these SPs.
+        let update_in_progress: Vec<_> = sps
+            .iter()
+            .filter(|sp| {
+                // If we don't have any update data for this SP, it's not in
+                // progress.
+                //
+                // If we do, it's in progress if the task is not finished.
+                update_data
+                    .sp_update_data
+                    .get(sp)
+                    .map_or(false, |data| !data.task.is_finished())
+            })
+            .copied()
+            .collect();
+
+        if !update_in_progress.is_empty() {
+            errors.push(StartUpdateError::UpdateInProgress(update_in_progress));
+        }
+
+        let plan = update_data.artifact_store.current_plan();
+        if plan.is_none() {
+            // (1), referred to below.
+            errors.push(StartUpdateError::TufRepositoryUnavailable);
+        }
+
+        // If there are any errors, return now.
+        if !errors.is_empty() {
+            return Err(errors);
+        }
+
+        let plan =
+            plan.expect("we'd have returned an error at (1) if plan was None");
+
+        // Call the setup method now.
+        if let Some(mut spawn_update_driver) = spawn_update_driver {
+            let setup_data = spawn_update_driver.setup(&plan).await;
+
+            for sp in sps {
+                match update_data.sp_update_data.entry(sp) {
+                    // Vacant: this is the first time we've started an update
+                    // to this sp.
+                    Entry::Vacant(slot) => {
+                        slot.insert(
+                            spawn_update_driver
+                                .spawn_update_driver(
+                                    sp,
+                                    plan.clone(),
+                                    &setup_data,
+                                )
+                                .await,
+                        );
+                    }
+                    // Occupied: we've previously started an update to this sp.
+                    Entry::Occupied(mut slot) => {
+                        assert!(
+                            slot.get().task.is_finished(),
+                            "we just checked that the task was finished"
+                        );
+                        slot.insert(
+                            spawn_update_driver
+                                .spawn_update_driver(
+                                    sp,
+                                    plan.clone(),
+                                    &setup_data,
+                                )
+                                .await,
+                        );
+                    }
                 }
             }
         }
+
+        Ok(())
     }
 
     fn spawn_upload_trampoline_phase_2_to_mgs(
@@ -425,6 +369,226 @@ impl UpdateTracker {
     }
 }
 
+/// A trait that represents a backend implementation for spawning the update
+/// driver.
+#[async_trait::async_trait]
+trait SpawnUpdateDriver {
+    /// The type returned by the [`Self::setup`] method. This is passed in by
+    /// reference to [`Self::spawn_update_driver`].
+    type Setup;
+
+    /// Perform setup required to spawn the update driver.
+    ///
+    /// This is called *once*, before any calls to
+    /// [`Self::spawn_update_driver`].
+    async fn setup(&mut self, plan: &UpdatePlan) -> Self::Setup;
+
+    /// Spawn the update driver for the given SP.
+    ///
+    /// This is called once per SP.
+    async fn spawn_update_driver(
+        &mut self,
+        sp: SpIdentifier,
+        plan: UpdatePlan,
+        setup_data: &Self::Setup,
+    ) -> SpUpdateData;
+}
+
+/// The production implementation of [`SpawnUpdateDriver`].
+///
+/// This implementation spawns real update drivers.
+#[derive(Debug)]
+struct RealSpawnUpdateDriver<'tr> {
+    update_tracker: &'tr UpdateTracker,
+    opts: StartUpdateOptions,
+}
+
+#[async_trait::async_trait]
+impl<'tr> SpawnUpdateDriver for RealSpawnUpdateDriver<'tr> {
+    type Setup = watch::Receiver<UploadTrampolinePhase2ToMgsStatus>;
+
+    async fn setup(&mut self, plan: &UpdatePlan) -> Self::Setup {
+        // Do we need to upload this plan's trampoline phase 2 to MGS?
+        let mut upload_trampoline_phase_2_to_mgs =
+            self.update_tracker.upload_trampoline_phase_2_to_mgs.lock().await;
+
+        match upload_trampoline_phase_2_to_mgs.as_mut() {
+            Some(prev) => {
+                // We've previously started an upload - does it match
+                // this artifact? If not, cancel the old task (which
+                // might still be trying to upload) and start a new one
+                // with our current image.
+                if prev.status.borrow().hash
+                    != plan.trampoline_phase_2.data.hash()
+                {
+                    // It does _not_ match - we have a new plan with a
+                    // different trampoline image. If the old task is
+                    // still running, cancel it, and start a new one.
+                    prev.task.abort();
+                    *prev = self
+                        .update_tracker
+                        .spawn_upload_trampoline_phase_2_to_mgs(&plan);
+                }
+            }
+            None => {
+                *upload_trampoline_phase_2_to_mgs = Some(
+                    self.update_tracker
+                        .spawn_upload_trampoline_phase_2_to_mgs(&plan),
+                );
+            }
+        }
+
+        // Both branches above leave `upload_trampoline_phase_2_to_mgs`
+        // with data, so we can unwrap here to clone the `watch`
+        // channel.
+        upload_trampoline_phase_2_to_mgs.as_ref().unwrap().status.clone()
+    }
+
+    async fn spawn_update_driver(
+        &mut self,
+        sp: SpIdentifier,
+        plan: UpdatePlan,
+        setup_data: &Self::Setup,
+    ) -> SpUpdateData {
+        // Generate an ID for this update; the update tracker will send it to the
+        // sled as part of the InstallinatorImageId, and installinator will send it
+        // back to our artifact server with its progress reports.
+        let update_id = Uuid::new_v4();
+
+        let event_buffer = Arc::new(StdMutex::new(EventBuffer::new(16)));
+        let ipr_start_receiver =
+            self.update_tracker.ipr_update_tracker.register(update_id);
+
+        let update_cx = UpdateContext {
+            update_id,
+            sp,
+            mgs_client: self.update_tracker.mgs_client.clone(),
+            upload_trampoline_phase_2_to_mgs: setup_data.clone(),
+            log: self.update_tracker.log.new(o!(
+                "sp" => format!("{sp:?}"),
+                "update_id" => update_id.to_string(),
+            )),
+        };
+        // TODO do we need `UpdateDriver` as a distinct type?
+        let update_driver = UpdateDriver {};
+
+        // Using a oneshot channel to communicate the abort handle isn't
+        // ideal, but it works and is the easiest way to send it without
+        // restructuring this code.
+        let (abort_handle_sender, abort_handle_receiver) = oneshot::channel();
+        let task = tokio::spawn(update_driver.run(
+            plan,
+            update_cx,
+            event_buffer.clone(),
+            ipr_start_receiver,
+            self.opts.clone(),
+            abort_handle_sender,
+        ));
+
+        let abort_handle = abort_handle_receiver
+            .await
+            .expect("abort handle is sent immediately");
+
+        SpUpdateData { task, abort_handle, event_buffer }
+    }
+}
+
+/// A fake implementation of [`SpawnUpdateDriver`].
+///
+/// This implementation is only used by tests. It contains a single step that
+/// waits for a [`watch::Receiver`] to resolve.
+#[derive(Debug)]
+struct FakeUpdateDriver {
+    watch_receiver: watch::Receiver<()>,
+    log: Logger,
+}
+
+#[async_trait::async_trait]
+impl SpawnUpdateDriver for FakeUpdateDriver {
+    type Setup = ();
+
+    async fn setup(&mut self, _plan: &UpdatePlan) -> Self::Setup {}
+
+    async fn spawn_update_driver(
+        &mut self,
+        _sp: SpIdentifier,
+        _plan: UpdatePlan,
+        _setup_data: &Self::Setup,
+    ) -> SpUpdateData {
+        let (sender, mut receiver) = mpsc::channel(128);
+        let event_buffer = Arc::new(StdMutex::new(EventBuffer::new(16)));
+        let event_buffer_2 = event_buffer.clone();
+        let log = self.log.clone();
+
+        let engine = UpdateEngine::new(&log, sender);
+        let abort_handle = engine.abort_handle();
+
+        let mut watch_receiver = self.watch_receiver.clone();
+
+        let task = tokio::spawn(async move {
+            // The step component and ID have been chosen arbitrarily here --
+            // they aren't important.
+            engine
+                .new_step(
+                    UpdateComponent::Host,
+                    UpdateStepId::RunningInstallinator,
+                    "Fake step that waits for receiver to resolve",
+                    move |_cx| async move {
+                        // This will resolve as soon as the watch sender
+                        // (typically a test) sends a value over the watch
+                        // channel.
+                        _ = watch_receiver.changed().await;
+                        StepSuccess::new(()).into()
+                    },
+                )
+                .register();
+
+            // Spawn a task to accept all events from the executing engine.
+            let event_receiving_task = tokio::spawn(async move {
+                while let Some(event) = receiver.recv().await {
+                    event_buffer_2.lock().unwrap().add_event(event);
+                }
+            });
+
+            match engine.execute().await {
+                Ok(_cx) => (),
+                Err(err) => {
+                    error!(log, "update failed"; "err" => %err);
+                }
+            }
+
+            // Wait for all events to be received and written to the event
+            // buffer.
+            event_receiving_task.await.expect("event receiving task panicked");
+        });
+
+        SpUpdateData { task, abort_handle, event_buffer }
+    }
+}
+
+/// An implementation of [`SpawnUpdateDriver`] that cannot be constructed.
+///
+/// This is an uninhabited type (an empty enum), and is only used to provide a
+/// type parameter for the [`UpdateTracker::update_pre_checks`] method.
+enum NeverUpdateDriver {}
+
+#[async_trait::async_trait]
+impl SpawnUpdateDriver for NeverUpdateDriver {
+    type Setup = ();
+
+    async fn setup(&mut self, _plan: &UpdatePlan) -> Self::Setup {}
+
+    async fn spawn_update_driver(
+        &mut self,
+        _sp: SpIdentifier,
+        _plan: UpdatePlan,
+        _setup_data: &Self::Setup,
+    ) -> SpUpdateData {
+        unreachable!("this update driver cannot be constructed")
+    }
+}
+
 #[derive(Debug)]
 struct UpdateTrackerData {
     artifact_store: WicketdArtifactStore,
@@ -518,21 +682,8 @@ impl UpdateTrackerData {
 pub enum StartUpdateError {
     #[error("no TUF repository available")]
     TufRepositoryUnavailable,
-    #[error("target is already being updated: {0:?}")]
-    UpdateInProgress(SpIdentifier),
-}
-
-impl StartUpdateError {
-    pub(crate) fn to_http_error(&self) -> HttpError {
-        let message = DisplayErrorChain::new(self).to_string();
-
-        match self {
-            StartUpdateError::TufRepositoryUnavailable
-            | StartUpdateError::UpdateInProgress(_) => {
-                HttpError::for_bad_request(None, message)
-            }
-        }
-    }
+    #[error("targets are already being updated: {}", sps_to_string(.0))]
+    UpdateInProgress(Vec<SpIdentifier>),
 }
 
 #[derive(Debug, Clone, Error, Eq, PartialEq)]
diff --git a/wicketd/tests/integration_tests/updates.rs b/wicketd/tests/integration_tests/updates.rs
index a4b330930a..a198068ef3 100644
--- a/wicketd/tests/integration_tests/updates.rs
+++ b/wicketd/tests/integration_tests/updates.rs
@@ -16,13 +16,13 @@ use omicron_common::{
     api::internal::nexus::KnownArtifactKind,
     update::{ArtifactHashId, ArtifactKind},
 };
-use tokio::sync::oneshot;
+use tokio::sync::watch;
 use uuid::Uuid;
 use wicket_common::update_events::{StepEventKind, UpdateComponent};
 use wicketd::{RunningUpdateState, StartUpdateError};
 use wicketd_client::types::{
     GetInventoryParams, GetInventoryResponse, SpIdentifier, SpType,
-    StartUpdateOptions,
+    StartUpdateOptions, StartUpdateParams,
 };
 
 #[tokio::test]
@@ -138,13 +138,11 @@ async fn test_updates() {
     }
 
     // Now, try starting the update on SP 0.
+    let options = StartUpdateOptions::default();
+    let params = StartUpdateParams { targets: vec![target_sp], options };
     wicketd_testctx
         .wicketd_client
-        .post_start_update(
-            target_sp.type_,
-            target_sp.slot,
-            &StartUpdateOptions::default(),
-        )
+        .post_start_update(&params)
         .await
         .expect("update started successfully");
 
@@ -352,12 +350,13 @@ async fn test_update_races() {
         slot: 0,
         type_: gateway_client::types::SpType::Sled,
     };
+    let sps: BTreeSet<_> = vec![sp].into_iter().collect();
 
-    let (sender, receiver) = oneshot::channel();
+    let (sender, receiver) = watch::channel(());
     wicketd_testctx
         .server
         .update_tracker
-        .start_fake_update(sp, receiver)
+        .start_fake_update(sps.clone(), receiver)
         .await
         .expect("start_fake_update successful");
 
@@ -372,14 +371,18 @@ async fn test_update_races() {
     // Also try starting another fake update, which should fail -- we don't let
     // updates be started in the middle of other updates.
     {
-        let (_, receiver) = oneshot::channel();
+        let (_, receiver) = watch::channel(());
         let err = wicketd_testctx
             .server
             .update_tracker
            .start_fake_update(sps, receiver)
            .await
            .expect_err("start_fake_update failed while update is running");
-        assert_eq!(err, StartUpdateError::UpdateInProgress(sp));
+        assert_eq!(err.len(), 1, "one error returned: {err:?}");
+        assert_eq!(
+            err.first().unwrap(),
+            &StartUpdateError::UpdateInProgress(vec![sp])
+        );
    }
 
    // Unblock the update, letting it run to completion.
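A natural extension of this test, sketched below as a hypothetical follow-up (not part of this change): when several requested targets are busy, `UpdateInProgress` should carry all of them in a single error. The variable names, the assumption that fake updates are already running on both sleds, and the ordering of the reported SPs are all illustrative:

```rust
// Assume fake updates are already running on sled 0 and sled 1, tracked by
// `update_tracker` as in test_update_races above.
let sled0 = SpIdentifier { type_: SpType::Sled, slot: 0 };
let sled1 = SpIdentifier { type_: SpType::Sled, slot: 1 };
let sps: BTreeSet<_> = [sled0, sled1].into_iter().collect();

let (_sender, receiver) = watch::channel(());
let errors = update_tracker
    .start_fake_update(sps, receiver)
    .await
    .expect_err("both SPs are mid-update");
// Both busy targets surface in one UpdateInProgress value rather than as
// two separate errors.
assert_eq!(
    errors,
    vec![StartUpdateError::UpdateInProgress(vec![sled0, sled1])]
);
```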