diff --git a/nexus/src/app/test_interfaces.rs b/nexus/src/app/test_interfaces.rs index 6161a9a1c1..581b9a89bb 100644 --- a/nexus/src/app/test_interfaces.rs +++ b/nexus/src/app/test_interfaces.rs @@ -10,10 +10,9 @@ use sled_agent_client::Client as SledAgentClient; use std::sync::Arc; use uuid::Uuid; +pub use super::update::HostPhase1Updater; pub use super::update::MgsClients; -pub use super::update::RotUpdateError; pub use super::update::RotUpdater; -pub use super::update::SpUpdateError; pub use super::update::SpUpdater; pub use super::update::UpdateProgress; pub use gateway_client::types::SpType; diff --git a/nexus/src/app/update/common_sp_update.rs b/nexus/src/app/update/common_sp_update.rs new file mode 100644 index 0000000000..69a5b132a2 --- /dev/null +++ b/nexus/src/app/update/common_sp_update.rs @@ -0,0 +1,239 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +//! Module containing implementation details shared amongst all MGS-to-SP-driven +//! updates. + +use super::MgsClients; +use super::UpdateProgress; +use gateway_client::types::SpType; +use gateway_client::types::SpUpdateStatus; +use slog::Logger; +use std::time::Duration; +use tokio::sync::watch; +use uuid::Uuid; + +type GatewayClientError = gateway_client::Error; + +/// Error type returned when an update to a component managed by the SP fails. +/// +/// Note that the SP manages itself, as well, so "SP component" here includes +/// the SP. +#[derive(Debug, thiserror::Error)] +pub enum SpComponentUpdateError { + #[error("error communicating with MGS")] + MgsCommunication(#[from] GatewayClientError), + #[error("different update is now preparing ({0})")] + DifferentUpdatePreparing(Uuid), + #[error("different update is now in progress ({0})")] + DifferentUpdateInProgress(Uuid), + #[error("different update is now complete ({0})")] + DifferentUpdateComplete(Uuid), + #[error("different update is now aborted ({0})")] + DifferentUpdateAborted(Uuid), + #[error("different update failed ({0})")] + DifferentUpdateFailed(Uuid), + #[error("update status lost (did the SP reset?)")] + UpdateStatusLost, + #[error("update was aborted")] + UpdateAborted, + #[error("update failed (error code {0})")] + UpdateFailedWithCode(u32), + #[error("update failed (error message {0})")] + UpdateFailedWithMessage(String), +} + +pub(super) trait SpComponentUpdater { + /// The target component. + /// + /// Should be produced via `SpComponent::const_as_str()`. + fn component(&self) -> &'static str; + + /// The type of the target SP. + fn target_sp_type(&self) -> SpType; + + /// The slot number of the target SP. + fn target_sp_slot(&self) -> u32; + + /// The target firmware slot for the component. + fn firmware_slot(&self) -> u16; + + /// The ID of this update. + fn update_id(&self) -> Uuid; + + /// The update payload data to send to MGS. + // TODO-performance This has to be convertible into a `reqwest::Body`, so we + // return an owned Vec. That requires all our implementors to clone the data + // at least once; maybe we should use `Bytes` instead (which is cheap to + // clone and also convertible into a reqwest::Body)? + fn update_data(&self) -> Vec; + + /// The sending half of the watch channel to report update progress. + fn progress(&self) -> &watch::Sender>; + + /// Logger to use while performing this update. 
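// Aside (illustrative only, not part of this patch): a minimal skeleton of
// what a hypothetical new component updater would need to provide in order to
// plug into `deliver_update` below. All names are invented for the example; a
// real implementation would return `SpComponent::...::const_as_str()` from
// `component()`. The payload is a `Vec<u8>` and progress is reported on a
// `watch::Sender<Option<UpdateProgress>>`, matching the existing updaters.
//
//     struct ExampleComponentUpdater {
//         log: Logger,
//         progress: watch::Sender<Option<UpdateProgress>>,
//         sp_type: SpType,
//         sp_slot: u32,
//         update_id: Uuid,
//         data: Vec<u8>,
//     }
//
//     impl SpComponentUpdater for ExampleComponentUpdater {
//         fn component(&self) -> &'static str { "example-component" }
//         fn target_sp_type(&self) -> SpType { self.sp_type }
//         fn target_sp_slot(&self) -> u32 { self.sp_slot }
//         fn firmware_slot(&self) -> u16 { 0 }
//         fn update_id(&self) -> Uuid { self.update_id }
//         fn update_data(&self) -> Vec<u8> { self.data.clone() }
//         fn progress(&self) -> &watch::Sender<Option<UpdateProgress>> { &self.progress }
//         fn logger(&self) -> &Logger { &self.log }
//     }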
+ fn logger(&self) -> &Logger; +} + +pub(super) async fn deliver_update( + updater: &(dyn SpComponentUpdater + Send + Sync), + mgs_clients: &mut MgsClients, +) -> Result<(), SpComponentUpdateError> { + // How frequently do we poll MGS for the update progress? + const STATUS_POLL_INTERVAL: Duration = Duration::from_secs(3); + + // Start the update. + mgs_clients + .try_all_serially(updater.logger(), |client| async move { + client + .sp_component_update( + updater.target_sp_type(), + updater.target_sp_slot(), + updater.component(), + updater.firmware_slot(), + &updater.update_id(), + reqwest::Body::from(updater.update_data()), + ) + .await?; + updater.progress().send_replace(Some(UpdateProgress::Started)); + info!( + updater.logger(), "update started"; + "mgs_addr" => client.baseurl(), + ); + Ok(()) + }) + .await?; + + // Wait for the update to complete. + loop { + let status = mgs_clients + .try_all_serially(updater.logger(), |client| async move { + let update_status = client + .sp_component_update_status( + updater.target_sp_type(), + updater.target_sp_slot(), + updater.component(), + ) + .await?; + + debug!( + updater.logger(), "got update status"; + "mgs_addr" => client.baseurl(), + "status" => ?update_status, + ); + + Ok(update_status) + }) + .await?; + + if status_is_complete( + status.into_inner(), + updater.update_id(), + updater.progress(), + updater.logger(), + )? { + updater.progress().send_replace(Some(UpdateProgress::InProgress { + progress: Some(1.0), + })); + return Ok(()); + } + + tokio::time::sleep(STATUS_POLL_INTERVAL).await; + } +} + +fn status_is_complete( + status: SpUpdateStatus, + update_id: Uuid, + progress_tx: &watch::Sender>, + log: &Logger, +) -> Result { + match status { + // For `Preparing` and `InProgress`, we could check the progress + // information returned by these steps and try to check that + // we're still _making_ progress, but every Nexus instance needs + // to do that anyway in case we (or the MGS instance delivering + // the update) crash, so we'll omit that check here. Instead, we + // just sleep and we'll poll again shortly. 
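// Aside (illustrative only, not part of this patch): both the `Preparing` and
// `InProgress` arms below reduce the SP-reported counters to an optional
// fraction, guarding against nonsense values. The same guard as a standalone
// helper, assuming `u32` counters as in the gateway progress types:
fn progress_fraction(current: u32, total: u32) -> Option<f64> {
    if current > total || total == 0 {
        // Either the counters are nonsense or there is nothing to report yet.
        None
    } else {
        Some(f64::from(current) / f64::from(total))
    }
}
// e.g. `progress_fraction(512, 1024) == Some(0.5)`, `progress_fraction(5, 0) == None`.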
+ SpUpdateStatus::Preparing { id, progress } => { + if id == update_id { + let progress = progress.and_then(|progress| { + if progress.current > progress.total { + warn!( + log, "nonsense preparing progress"; + "current" => progress.current, + "total" => progress.total, + ); + None + } else if progress.total == 0 { + None + } else { + Some( + f64::from(progress.current) + / f64::from(progress.total), + ) + } + }); + progress_tx + .send_replace(Some(UpdateProgress::Preparing { progress })); + Ok(false) + } else { + Err(SpComponentUpdateError::DifferentUpdatePreparing(id)) + } + } + SpUpdateStatus::InProgress { id, bytes_received, total_bytes } => { + if id == update_id { + let progress = if bytes_received > total_bytes { + warn!( + log, "nonsense update progress"; + "bytes_received" => bytes_received, + "total_bytes" => total_bytes, + ); + None + } else if total_bytes == 0 { + None + } else { + Some(f64::from(bytes_received) / f64::from(total_bytes)) + }; + progress_tx.send_replace(Some(UpdateProgress::InProgress { + progress, + })); + Ok(false) + } else { + Err(SpComponentUpdateError::DifferentUpdateInProgress(id)) + } + } + SpUpdateStatus::Complete { id } => { + if id == update_id { + Ok(true) + } else { + Err(SpComponentUpdateError::DifferentUpdateComplete(id)) + } + } + SpUpdateStatus::None => Err(SpComponentUpdateError::UpdateStatusLost), + SpUpdateStatus::Aborted { id } => { + if id == update_id { + Err(SpComponentUpdateError::UpdateAborted) + } else { + Err(SpComponentUpdateError::DifferentUpdateAborted(id)) + } + } + SpUpdateStatus::Failed { code, id } => { + if id == update_id { + Err(SpComponentUpdateError::UpdateFailedWithCode(code)) + } else { + Err(SpComponentUpdateError::DifferentUpdateFailed(id)) + } + } + SpUpdateStatus::RotError { id, message } => { + if id == update_id { + Err(SpComponentUpdateError::UpdateFailedWithMessage(format!( + "rot error: {message}" + ))) + } else { + Err(SpComponentUpdateError::DifferentUpdateFailed(id)) + } + } + } +} diff --git a/nexus/src/app/update/host_phase1_updater.rs b/nexus/src/app/update/host_phase1_updater.rs new file mode 100644 index 0000000000..fb013d0ffe --- /dev/null +++ b/nexus/src/app/update/host_phase1_updater.rs @@ -0,0 +1,177 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +//! Module containing types for updating host OS phase1 images via MGS. + +use super::common_sp_update::deliver_update; +use super::common_sp_update::SpComponentUpdater; +use super::MgsClients; +use super::SpComponentUpdateError; +use super::UpdateProgress; +use gateway_client::types::SpComponentFirmwareSlot; +use gateway_client::types::SpType; +use gateway_client::SpComponent; +use slog::Logger; +use tokio::sync::watch; +use uuid::Uuid; + +type GatewayClientError = gateway_client::Error; + +pub struct HostPhase1Updater { + log: Logger, + progress: watch::Sender>, + sp_type: SpType, + sp_slot: u32, + target_host_slot: u16, + update_id: Uuid, + // TODO-clarity maybe a newtype for this? TBD how we get this from + // wherever it's stored, which might give us a stronger type already. 
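// Aside (illustrative only, not part of this patch): the newtype floated in
// the TODO above could be as thin as a wrapper that names what the bytes are,
// e.g. `pub struct HostPhase1Image(pub Vec<u8>);`. Where the image actually
// comes from is still TBD, so the field below stays a plain byte vector.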
+ phase1_data: Vec, +} + +impl HostPhase1Updater { + pub fn new( + sp_type: SpType, + sp_slot: u32, + target_host_slot: u16, + update_id: Uuid, + phase1_data: Vec, + log: &Logger, + ) -> Self { + let log = log.new(slog::o!( + "component" => "HostPhase1Updater", + "sp_type" => format!("{sp_type:?}"), + "sp_slot" => sp_slot, + "target_host_slot" => target_host_slot, + "update_id" => format!("{update_id}"), + )); + let progress = watch::Sender::new(None); + Self { + log, + progress, + sp_type, + sp_slot, + target_host_slot, + update_id, + phase1_data, + } + } + + pub fn progress_watcher(&self) -> watch::Receiver> { + self.progress.subscribe() + } + + /// Drive this host phase 1 update to completion (or failure). + /// + /// Only one MGS instance is required to drive an update; however, if + /// multiple MGS instances are available and passed to this method and an + /// error occurs communicating with one instance, `HostPhase1Updater` will + /// try the remaining instances before failing. + pub async fn update( + mut self, + mgs_clients: &mut MgsClients, + ) -> Result<(), SpComponentUpdateError> { + // The async block below wants a `&self` reference, but we take `self` + // for API clarity (to start a new update, the caller should construct a + // new instance of the updater). Create a `&self` ref that we use + // through the remainder of this method. + let me = &self; + + // Prior to delivering the update, ensure the correct target slot is + // activated. + // + // TODO-correctness Should we be doing this, or should a higher level + // executor set this up before calling us? + mgs_clients + .try_all_serially(&self.log, |client| async move { + me.mark_target_slot_active(&client).await + }) + .await?; + + // Deliver and drive the update to completion + deliver_update(&mut self, mgs_clients).await?; + + // Unlike SP and RoT updates, we have nothing to do after delivery of + // the update completes; signal to any watchers that we're done. + self.progress.send_replace(Some(UpdateProgress::Complete)); + + // wait for any progress watchers to be dropped before we return; + // otherwise, they'll get `RecvError`s when trying to check the current + // status + self.progress.closed().await; + + Ok(()) + } + + async fn mark_target_slot_active( + &self, + client: &gateway_client::Client, + ) -> Result<(), GatewayClientError> { + // TODO-correctness Should we always persist this choice? + let persist = true; + + let slot = self.firmware_slot(); + + // TODO-correctness Until + // https://github.com/oxidecomputer/hubris/issues/1172 is fixed, the + // host must be in A2 for this operation to succeed. After it is fixed, + // there will still be a window while a host is booting where this + // operation can fail. How do we handle this? + client + .sp_component_active_slot_set( + self.sp_type, + self.sp_slot, + self.component(), + persist, + &SpComponentFirmwareSlot { slot }, + ) + .await?; + + // TODO-correctness Should we send some kind of update to + // `self.progress`? We haven't actually started delivering an update + // yet, but it seems weird to give no indication that we have + // successfully (potentially) modified the state of the target sled. 
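// Aside (illustrative only, not part of this patch): a sketch of how a caller
// might drive this updater while consuming the progress channel. It assumes
// the types in this module; note that the receiver must be dropped once
// `Complete` is observed, because on success `update()` waits for all
// watchers to go away before returning.
async fn drive_update_with_progress(
    updater: HostPhase1Updater,
    mgs_clients: &mut MgsClients,
) -> Result<(), SpComponentUpdateError> {
    let mut progress = updater.progress_watcher();
    let watcher = tokio::spawn(async move {
        while progress.changed().await.is_ok() {
            let status = progress.borrow_and_update().clone();
            println!("host phase1 update progress: {status:?}");
            if matches!(status, Some(UpdateProgress::Complete)) {
                // Drop the receiver so `update()` can finish.
                break;
            }
        }
    });
    let result = updater.update(mgs_clients).await;
    watcher.await.expect("progress watcher task panicked");
    result
}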
+ + info!( + self.log, "host phase1 target slot marked active"; + "mgs_addr" => client.baseurl(), + ); + + Ok(()) + } +} + +impl SpComponentUpdater for HostPhase1Updater { + fn component(&self) -> &'static str { + SpComponent::HOST_CPU_BOOT_FLASH.const_as_str() + } + + fn target_sp_type(&self) -> SpType { + self.sp_type + } + + fn target_sp_slot(&self) -> u32 { + self.sp_slot + } + + fn firmware_slot(&self) -> u16 { + self.target_host_slot + } + + fn update_id(&self) -> Uuid { + self.update_id + } + + fn update_data(&self) -> Vec { + self.phase1_data.clone() + } + + fn progress(&self) -> &watch::Sender> { + &self.progress + } + + fn logger(&self) -> &Logger { + &self.log + } +} diff --git a/nexus/src/app/update/mgs_clients.rs b/nexus/src/app/update/mgs_clients.rs index 5915505829..4b200a1819 100644 --- a/nexus/src/app/update/mgs_clients.rs +++ b/nexus/src/app/update/mgs_clients.rs @@ -5,53 +5,14 @@ //! Module providing support for handling failover between multiple MGS clients use futures::Future; -use gateway_client::types::SpType; -use gateway_client::types::SpUpdateStatus; use gateway_client::Client; use slog::Logger; use std::collections::VecDeque; use std::sync::Arc; -use uuid::Uuid; pub(super) type GatewayClientError = gateway_client::Error; -pub(super) enum PollUpdateStatus { - Preparing { progress: Option }, - InProgress { progress: Option }, - Complete, -} - -#[derive(Debug, thiserror::Error)] -pub enum UpdateStatusError { - #[error("different update is now preparing ({0})")] - DifferentUpdatePreparing(Uuid), - #[error("different update is now in progress ({0})")] - DifferentUpdateInProgress(Uuid), - #[error("different update is now complete ({0})")] - DifferentUpdateComplete(Uuid), - #[error("different update is now aborted ({0})")] - DifferentUpdateAborted(Uuid), - #[error("different update failed ({0})")] - DifferentUpdateFailed(Uuid), - #[error("update status lost (did the SP reset?)")] - UpdateStatusLost, - #[error("update was aborted")] - UpdateAborted, - #[error("update failed (error code {0})")] - UpdateFailedWithCode(u32), - #[error("update failed (error message {0})")] - UpdateFailedWithMessage(String), -} - -#[derive(Debug, thiserror::Error)] -pub(super) enum PollUpdateStatusError { - #[error(transparent)] - StatusError(#[from] UpdateStatusError), - #[error(transparent)] - ClientError(#[from] GatewayClientError), -} - #[derive(Debug, Clone)] pub struct MgsClients { clients: VecDeque>, @@ -130,111 +91,4 @@ impl MgsClients { // errors. Return the error from the last MGS we tried. Err(GatewayClientError::CommunicationError(last_err.unwrap())) } - - /// Poll for the status of an expected-to-be-in-progress update. - pub(super) async fn poll_update_status( - &mut self, - sp_type: SpType, - sp_slot: u32, - component: &'static str, - update_id: Uuid, - log: &Logger, - ) -> Result { - let update_status = self - .try_all_serially(log, |client| async move { - let update_status = client - .sp_component_update_status(sp_type, sp_slot, component) - .await?; - - debug!( - log, "got update status"; - "mgs_addr" => client.baseurl(), - "status" => ?update_status, - ); - - Ok(update_status) - }) - .await? 
- .into_inner(); - - match update_status { - SpUpdateStatus::Preparing { id, progress } => { - if id == update_id { - let progress = progress.and_then(|progress| { - if progress.current > progress.total { - warn!( - log, "nonsense preparing progress"; - "current" => progress.current, - "total" => progress.total, - ); - None - } else if progress.total == 0 { - None - } else { - Some( - f64::from(progress.current) - / f64::from(progress.total), - ) - } - }); - Ok(PollUpdateStatus::Preparing { progress }) - } else { - Err(UpdateStatusError::DifferentUpdatePreparing(id).into()) - } - } - SpUpdateStatus::InProgress { id, bytes_received, total_bytes } => { - if id == update_id { - let progress = if bytes_received > total_bytes { - warn!( - log, "nonsense update progress"; - "bytes_received" => bytes_received, - "total_bytes" => total_bytes, - ); - None - } else if total_bytes == 0 { - None - } else { - Some(f64::from(bytes_received) / f64::from(total_bytes)) - }; - Ok(PollUpdateStatus::InProgress { progress }) - } else { - Err(UpdateStatusError::DifferentUpdateInProgress(id).into()) - } - } - SpUpdateStatus::Complete { id } => { - if id == update_id { - Ok(PollUpdateStatus::Complete) - } else { - Err(UpdateStatusError::DifferentUpdateComplete(id).into()) - } - } - SpUpdateStatus::None => { - Err(UpdateStatusError::UpdateStatusLost.into()) - } - SpUpdateStatus::Aborted { id } => { - if id == update_id { - Err(UpdateStatusError::UpdateAborted.into()) - } else { - Err(UpdateStatusError::DifferentUpdateAborted(id).into()) - } - } - SpUpdateStatus::Failed { code, id } => { - if id == update_id { - Err(UpdateStatusError::UpdateFailedWithCode(code).into()) - } else { - Err(UpdateStatusError::DifferentUpdateFailed(id).into()) - } - } - SpUpdateStatus::RotError { id, message } => { - if id == update_id { - Err(UpdateStatusError::UpdateFailedWithMessage(format!( - "rot error: {message}" - )) - .into()) - } else { - Err(UpdateStatusError::DifferentUpdateFailed(id).into()) - } - } - } - } } diff --git a/nexus/src/app/update/mod.rs b/nexus/src/app/update/mod.rs index 7d5c642822..5075e421ae 100644 --- a/nexus/src/app/update/mod.rs +++ b/nexus/src/app/update/mod.rs @@ -26,13 +26,17 @@ use std::path::Path; use tokio::io::AsyncWriteExt; use uuid::Uuid; +mod common_sp_update; +mod host_phase1_updater; mod mgs_clients; mod rot_updater; mod sp_updater; -pub use mgs_clients::{MgsClients, UpdateStatusError}; -pub use rot_updater::{RotUpdateError, RotUpdater}; -pub use sp_updater::{SpUpdateError, SpUpdater}; +pub use common_sp_update::SpComponentUpdateError; +pub use host_phase1_updater::HostPhase1Updater; +pub use mgs_clients::MgsClients; +pub use rot_updater::RotUpdater; +pub use sp_updater::SpUpdater; #[derive(Debug, PartialEq, Clone)] pub enum UpdateProgress { diff --git a/nexus/src/app/update/rot_updater.rs b/nexus/src/app/update/rot_updater.rs index d7d21e3b3a..12126a7de9 100644 --- a/nexus/src/app/update/rot_updater.rs +++ b/nexus/src/app/update/rot_updater.rs @@ -4,40 +4,21 @@ //! Module containing types for updating RoTs via MGS. 
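// Aside (illustrative only, not part of this patch): every step in these
// updaters now goes through `MgsClients::try_all_serially`, which retries the
// closure against the next MGS client on a communication error (and, per the
// tests below, keeps using whichever client last succeeded). A minimal usage
// sketch, reusing this module's `GatewayClientError` alias:
async fn reset_rot_via_any_mgs(
    mgs_clients: &mut MgsClients,
    log: &Logger,
    sp_type: SpType,
    sp_slot: u32,
) -> Result<(), GatewayClientError> {
    mgs_clients
        .try_all_serially(log, |client| async move {
            client
                .sp_component_reset(sp_type, sp_slot, SpComponent::ROT.const_as_str())
                .await?;
            Ok(())
        })
        .await
}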
-use super::mgs_clients::PollUpdateStatusError; +use super::common_sp_update::deliver_update; +use super::common_sp_update::SpComponentUpdater; use super::MgsClients; +use super::SpComponentUpdateError; use super::UpdateProgress; -use super::UpdateStatusError; -use crate::app::update::mgs_clients::PollUpdateStatus; use gateway_client::types::RotSlot; use gateway_client::types::SpComponentFirmwareSlot; use gateway_client::types::SpType; use gateway_client::SpComponent; use slog::Logger; -use std::time::Duration; use tokio::sync::watch; use uuid::Uuid; type GatewayClientError = gateway_client::Error; -#[derive(Debug, thiserror::Error)] -pub enum RotUpdateError { - #[error("error communicating with MGS")] - MgsCommunication(#[from] GatewayClientError), - - #[error("failed checking update status: {0}")] - PollUpdateStatus(#[from] UpdateStatusError), -} - -impl From for RotUpdateError { - fn from(err: PollUpdateStatusError) -> Self { - match err { - PollUpdateStatusError::StatusError(err) => err.into(), - PollUpdateStatusError::ClientError(err) => err.into(), - } - } -} - pub struct RotUpdater { log: Logger, progress: watch::Sender>, @@ -89,9 +70,14 @@ impl RotUpdater { /// error occurs communicating with one instance, `RotUpdater` will try the /// remaining instances before failing. pub async fn update( - self, - mut mgs_clients: MgsClients, - ) -> Result<(), RotUpdateError> { + mut self, + mgs_clients: &mut MgsClients, + ) -> Result<(), SpComponentUpdateError> { + // Deliver and drive the update to "completion" (which isn't really + // complete for the RoT, since we still have to do the steps below after + // the delivery of the update completes). + deliver_update(&mut self, mgs_clients).await?; + // The async blocks below want `&self` references, but we take `self` // for API clarity (to start a new update, the caller should construct a // new updater). Create a `&self` ref that we use through the remainder @@ -100,23 +86,13 @@ impl RotUpdater { mgs_clients .try_all_serially(&self.log, |client| async move { - me.start_update_one_mgs(&client).await - }) - .await?; - - // `wait_for_update_completion` uses `try_all_mgs_clients` internally, - // so we don't wrap it here. - me.wait_for_update_completion(&mut mgs_clients).await?; - - mgs_clients - .try_all_serially(&self.log, |client| async move { - me.mark_target_slot_active_one_mgs(&client).await + me.mark_target_slot_active(&client).await }) .await?; mgs_clients .try_all_serially(&self.log, |client| async move { - me.finalize_update_via_reset_one_mgs(&client).await + me.finalize_update_via_reset(&client).await }) .await?; @@ -128,82 +104,7 @@ impl RotUpdater { Ok(()) } - async fn start_update_one_mgs( - &self, - client: &gateway_client::Client, - ) -> Result<(), GatewayClientError> { - let firmware_slot = self.target_rot_slot.as_u16(); - - // Start the update. - client - .sp_component_update( - self.sp_type, - self.sp_slot, - SpComponent::ROT.const_as_str(), - firmware_slot, - &self.update_id, - reqwest::Body::from(self.rot_hubris_archive.clone()), - ) - .await?; - - self.progress.send_replace(Some(UpdateProgress::Started)); - - info!( - self.log, "RoT update started"; - "mgs_addr" => client.baseurl(), - ); - - Ok(()) - } - - async fn wait_for_update_completion( - &self, - mgs_clients: &mut MgsClients, - ) -> Result<(), RotUpdateError> { - // How frequently do we poll MGS for the update progress? 
- const STATUS_POLL_INTERVAL: Duration = Duration::from_secs(3); - - loop { - let status = mgs_clients - .poll_update_status( - self.sp_type, - self.sp_slot, - SpComponent::ROT.const_as_str(), - self.update_id, - &self.log, - ) - .await?; - - // For `Preparing` and `InProgress`, we could check the progress - // information returned by these steps and try to check that - // we're still _making_ progress, but every Nexus instance needs - // to do that anyway in case we (or the MGS instance delivering - // the update) crash, so we'll omit that check here. Instead, we - // just sleep and we'll poll again shortly. - match status { - PollUpdateStatus::Preparing { progress } => { - self.progress.send_replace(Some( - UpdateProgress::Preparing { progress }, - )); - } - PollUpdateStatus::InProgress { progress } => { - self.progress.send_replace(Some( - UpdateProgress::InProgress { progress }, - )); - } - PollUpdateStatus::Complete => { - self.progress.send_replace(Some( - UpdateProgress::InProgress { progress: Some(1.0) }, - )); - return Ok(()); - } - } - - tokio::time::sleep(STATUS_POLL_INTERVAL).await; - } - } - - async fn mark_target_slot_active_one_mgs( + async fn mark_target_slot_active( &self, client: &gateway_client::Client, ) -> Result<(), GatewayClientError> { @@ -211,13 +112,13 @@ impl RotUpdater { // tell it to persist our choice. let persist = true; - let slot = self.target_rot_slot.as_u16(); + let slot = self.firmware_slot(); client .sp_component_active_slot_set( self.sp_type, self.sp_slot, - SpComponent::ROT.const_as_str(), + self.component(), persist, &SpComponentFirmwareSlot { slot }, ) @@ -236,16 +137,12 @@ impl RotUpdater { Ok(()) } - async fn finalize_update_via_reset_one_mgs( + async fn finalize_update_via_reset( &self, client: &gateway_client::Client, ) -> Result<(), GatewayClientError> { client - .sp_component_reset( - self.sp_type, - self.sp_slot, - SpComponent::ROT.const_as_str(), - ) + .sp_component_reset(self.sp_type, self.sp_slot, self.component()) .await?; self.progress.send_replace(Some(UpdateProgress::Complete)); @@ -258,15 +155,39 @@ impl RotUpdater { } } -trait RotSlotAsU16 { - fn as_u16(&self) -> u16; -} +impl SpComponentUpdater for RotUpdater { + fn component(&self) -> &'static str { + SpComponent::ROT.const_as_str() + } + + fn target_sp_type(&self) -> SpType { + self.sp_type + } -impl RotSlotAsU16 for RotSlot { - fn as_u16(&self) -> u16 { - match self { + fn target_sp_slot(&self) -> u32 { + self.sp_slot + } + + fn firmware_slot(&self) -> u16 { + match self.target_rot_slot { RotSlot::A => 0, RotSlot::B => 1, } } + + fn update_id(&self) -> Uuid { + self.update_id + } + + fn update_data(&self) -> Vec { + self.rot_hubris_archive.clone() + } + + fn progress(&self) -> &watch::Sender> { + &self.progress + } + + fn logger(&self) -> &Logger { + &self.log + } } diff --git a/nexus/src/app/update/sp_updater.rs b/nexus/src/app/update/sp_updater.rs index 419a733441..2a6ddc6de6 100644 --- a/nexus/src/app/update/sp_updater.rs +++ b/nexus/src/app/update/sp_updater.rs @@ -4,39 +4,19 @@ //! Module containing types for updating SPs via MGS. 
-use crate::app::update::mgs_clients::PollUpdateStatus; - -use super::mgs_clients::PollUpdateStatusError; +use super::common_sp_update::deliver_update; +use super::common_sp_update::SpComponentUpdater; use super::MgsClients; +use super::SpComponentUpdateError; use super::UpdateProgress; -use super::UpdateStatusError; use gateway_client::types::SpType; use gateway_client::SpComponent; use slog::Logger; -use std::time::Duration; use tokio::sync::watch; use uuid::Uuid; type GatewayClientError = gateway_client::Error; -#[derive(Debug, thiserror::Error)] -pub enum SpUpdateError { - #[error("error communicating with MGS")] - MgsCommunication(#[from] GatewayClientError), - - #[error("failed checking update status: {0}")] - PollUpdateStatus(#[from] UpdateStatusError), -} - -impl From for SpUpdateError { - fn from(err: PollUpdateStatusError) -> Self { - match err { - PollUpdateStatusError::StatusError(err) => err.into(), - PollUpdateStatusError::ClientError(err) => err.into(), - } - } -} - pub struct SpUpdater { log: Logger, progress: watch::Sender>, @@ -77,10 +57,15 @@ impl SpUpdater { /// error occurs communicating with one instance, `SpUpdater` will try the /// remaining instances before failing. pub async fn update( - self, - mut mgs_clients: MgsClients, - ) -> Result<(), SpUpdateError> { - // The async blocks below want `&self` references, but we take `self` + mut self, + mgs_clients: &mut MgsClients, + ) -> Result<(), SpComponentUpdateError> { + // Deliver and drive the update to "completion" (which isn't really + // complete for the SP, since we still have to reset it after the + // delivery of the update completes). + deliver_update(&mut self, mgs_clients).await?; + + // The async block below wants a `&self` reference, but we take `self` // for API clarity (to start a new SP update, the caller should // construct a new `SpUpdater`). Create a `&self` ref that we use // through the remainder of this method. @@ -88,17 +73,7 @@ impl SpUpdater { mgs_clients .try_all_serially(&self.log, |client| async move { - me.start_update_one_mgs(&client).await - }) - .await?; - - // `wait_for_update_completion` uses `try_all_mgs_clients` internally, - // so we don't wrap it here. - me.wait_for_update_completion(&mut mgs_clients).await?; - - mgs_clients - .try_all_serially(&self.log, |client| async move { - me.finalize_update_via_reset_one_mgs(&client).await + me.finalize_update_via_reset(&client).await }) .await?; @@ -110,102 +85,57 @@ impl SpUpdater { Ok(()) } - async fn start_update_one_mgs( + async fn finalize_update_via_reset( &self, client: &gateway_client::Client, ) -> Result<(), GatewayClientError> { - // The SP has two firmware slots, but they're aren't individually - // labled. We always request an update to slot 0, which means "the - // inactive slot". - let firmware_slot = 0; - - // Start the update. client - .sp_component_update( - self.sp_type, - self.sp_slot, - SpComponent::SP_ITSELF.const_as_str(), - firmware_slot, - &self.update_id, - reqwest::Body::from(self.sp_hubris_archive.clone()), - ) + .sp_component_reset(self.sp_type, self.sp_slot, self.component()) .await?; - self.progress.send_replace(Some(UpdateProgress::Started)); - + self.progress.send_replace(Some(UpdateProgress::Complete)); info!( - self.log, "SP update started"; + self.log, "SP update complete"; "mgs_addr" => client.baseurl(), ); Ok(()) } +} - async fn wait_for_update_completion( - &self, - mgs_clients: &mut MgsClients, - ) -> Result<(), SpUpdateError> { - // How frequently do we poll MGS for the update progress? 
- const STATUS_POLL_INTERVAL: Duration = Duration::from_secs(3); - - loop { - let status = mgs_clients - .poll_update_status( - self.sp_type, - self.sp_slot, - SpComponent::SP_ITSELF.const_as_str(), - self.update_id, - &self.log, - ) - .await?; - - // For `Preparing` and `InProgress`, we could check the progress - // information returned by these steps and try to check that - // we're still _making_ progress, but every Nexus instance needs - // to do that anyway in case we (or the MGS instance delivering - // the update) crash, so we'll omit that check here. Instead, we - // just sleep and we'll poll again shortly. - match status { - PollUpdateStatus::Preparing { progress } => { - self.progress.send_replace(Some( - UpdateProgress::Preparing { progress }, - )); - } - PollUpdateStatus::InProgress { progress } => { - self.progress.send_replace(Some( - UpdateProgress::InProgress { progress }, - )); - } - PollUpdateStatus::Complete => { - self.progress.send_replace(Some( - UpdateProgress::InProgress { progress: Some(1.0) }, - )); - return Ok(()); - } - } - - tokio::time::sleep(STATUS_POLL_INTERVAL).await; - } +impl SpComponentUpdater for SpUpdater { + fn component(&self) -> &'static str { + SpComponent::SP_ITSELF.const_as_str() } - async fn finalize_update_via_reset_one_mgs( - &self, - client: &gateway_client::Client, - ) -> Result<(), GatewayClientError> { - client - .sp_component_reset( - self.sp_type, - self.sp_slot, - SpComponent::SP_ITSELF.const_as_str(), - ) - .await?; + fn target_sp_type(&self) -> SpType { + self.sp_type + } - self.progress.send_replace(Some(UpdateProgress::Complete)); - info!( - self.log, "SP update complete"; - "mgs_addr" => client.baseurl(), - ); + fn target_sp_slot(&self) -> u32 { + self.sp_slot + } - Ok(()) + fn firmware_slot(&self) -> u16 { + // The SP has two firmware slots, but they're aren't individually + // labled. We always request an update to slot 0, which means "the + // inactive slot". + 0 + } + + fn update_id(&self) -> Uuid { + self.update_id + } + + fn update_data(&self) -> Vec { + self.sp_hubris_archive.clone() + } + + fn progress(&self) -> &watch::Sender> { + &self.progress + } + + fn logger(&self) -> &Logger { + &self.log } } diff --git a/nexus/tests/integration_tests/host_phase1_updater.rs b/nexus/tests/integration_tests/host_phase1_updater.rs new file mode 100644 index 0000000000..01d546636e --- /dev/null +++ b/nexus/tests/integration_tests/host_phase1_updater.rs @@ -0,0 +1,584 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +//! Tests `HostPhase1Updater`'s delivery of updates to host phase 1 flash via +//! MGS to SP. + +use gateway_client::types::SpType; +use gateway_messages::{SpPort, UpdateInProgressStatus, UpdateStatus}; +use gateway_test_utils::setup as mgs_setup; +use omicron_nexus::app::test_interfaces::{ + HostPhase1Updater, MgsClients, UpdateProgress, +}; +use rand::RngCore; +use sp_sim::SimulatedSp; +use std::mem; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::AsyncWriteExt; +use tokio::net::TcpListener; +use tokio::net::TcpStream; +use tokio::sync::mpsc; +use uuid::Uuid; + +fn make_fake_host_phase1_image() -> Vec { + let mut image = vec![0; 128]; + rand::thread_rng().fill_bytes(&mut image); + image +} + +#[tokio::test] +async fn test_host_phase1_updater_updates_sled() { + // Start MGS + Sim SP. 
+ let mgstestctx = mgs_setup::test_setup( + "test_host_phase1_updater_updates_sled", + SpPort::One, + ) + .await; + + // Configure an MGS client. + let mut mgs_clients = + MgsClients::from_clients([gateway_client::Client::new( + &mgstestctx.client.url("/").to_string(), + mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")), + )]); + + for target_host_slot in [0, 1] { + // Configure and instantiate an `HostPhase1Updater`. + let sp_type = SpType::Sled; + let sp_slot = 0; + let update_id = Uuid::new_v4(); + let phase1_data = make_fake_host_phase1_image(); + + let host_phase1_updater = HostPhase1Updater::new( + sp_type, + sp_slot, + target_host_slot, + update_id, + phase1_data.clone(), + &mgstestctx.logctx.log, + ); + + // Run the update. + host_phase1_updater + .update(&mut mgs_clients) + .await + .expect("update failed"); + + // Ensure the SP received the complete update. + let last_update_image = mgstestctx.simrack.gimlets[sp_slot as usize] + .last_host_phase1_update_data(target_host_slot) + .await + .expect("simulated host phase1 did not receive an update"); + + assert_eq!( + phase1_data.as_slice(), + &*last_update_image, + "simulated host phase1 update contents (len {}) \ + do not match test generated fake image (len {})", + last_update_image.len(), + phase1_data.len(), + ); + } + + mgstestctx.teardown().await; +} + +#[tokio::test] +async fn test_host_phase1_updater_remembers_successful_mgs_instance() { + // Start MGS + Sim SP. + let mgstestctx = mgs_setup::test_setup( + "test_host_phase1_updater_remembers_successful_mgs_instance", + SpPort::One, + ) + .await; + + // Also start a local TCP server that we will claim is an MGS instance, but + // it will close connections immediately after accepting them. This will + // allow us to count how many connections it receives, while simultaneously + // causing errors in the HostPhase1Updater when it attempts to use this + // "MGS". + let (failing_mgs_task, failing_mgs_addr, failing_mgs_conn_counter) = { + let socket = TcpListener::bind("[::1]:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + let conn_count = Arc::new(AtomicUsize::new(0)); + + let task = { + let conn_count = Arc::clone(&conn_count); + tokio::spawn(async move { + loop { + let (mut stream, _peer) = socket.accept().await.unwrap(); + conn_count.fetch_add(1, Ordering::SeqCst); + stream.shutdown().await.unwrap(); + } + }) + }; + + (task, addr, conn_count) + }; + + // Order the MGS clients such that the bogus MGS that immediately closes + // connections comes first. `HostPhase1Updater` should remember that the + // second MGS instance succeeds, and only send subsequent requests to it: we + // should only see a single attempted connection to the bogus MGS, even + // though delivering an update requires a bare minimum of three requests + // (start the update, query the status, reset the SP) and often more (if + // repeated queries are required to wait for completion). 
+ let mut mgs_clients = MgsClients::from_clients([ + gateway_client::Client::new( + &format!("http://{failing_mgs_addr}"), + mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient1")), + ), + gateway_client::Client::new( + &mgstestctx.client.url("/").to_string(), + mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")), + ), + ]); + + let sp_type = SpType::Sled; + let sp_slot = 0; + let target_host_slot = 0; + let update_id = Uuid::new_v4(); + let phase1_data = make_fake_host_phase1_image(); + + let host_phase1_updater = HostPhase1Updater::new( + sp_type, + sp_slot, + target_host_slot, + update_id, + phase1_data.clone(), + &mgstestctx.logctx.log, + ); + + host_phase1_updater.update(&mut mgs_clients).await.expect("update failed"); + + let last_update_image = mgstestctx.simrack.gimlets[sp_slot as usize] + .last_host_phase1_update_data(target_host_slot) + .await + .expect("simulated host phase1 did not receive an update"); + + assert_eq!( + phase1_data.as_slice(), + &*last_update_image, + "simulated host phase1 update contents (len {}) \ + do not match test generated fake image (len {})", + last_update_image.len(), + phase1_data.len(), + ); + + // Check that our bogus MGS only received a single connection attempt. + // (After HostPhase1Updater failed to talk to this instance, it should have + // fallen back to the valid one for all further requests.) + assert_eq!( + failing_mgs_conn_counter.load(Ordering::SeqCst), + 1, + "bogus MGS instance didn't receive the expected number of connections" + ); + failing_mgs_task.abort(); + + mgstestctx.teardown().await; +} + +#[tokio::test] +async fn test_host_phase1_updater_switches_mgs_instances_on_failure() { + enum MgsProxy { + One(TcpStream), + Two(TcpStream), + } + + // Start MGS + Sim SP. + let mgstestctx = mgs_setup::test_setup( + "test_host_phase1_updater_switches_mgs_instances_on_failure", + SpPort::One, + ) + .await; + let mgs_bind_addr = mgstestctx.client.bind_address; + + let spawn_mgs_proxy_task = |mut stream: TcpStream| { + tokio::spawn(async move { + let mut mgs_stream = TcpStream::connect(mgs_bind_addr) + .await + .expect("failed to connect to MGS"); + tokio::io::copy_bidirectional(&mut stream, &mut mgs_stream) + .await + .expect("failed to proxy connection to MGS"); + }) + }; + + // Start two MGS proxy tasks; when each receives an incoming TCP connection, + // it forwards that `TcpStream` along the `mgs_proxy_connections` channel + // along with a tag of which proxy it is. We'll use this below to flip flop + // between MGS "instances" (really these two proxies). + let (mgs_proxy_connections_tx, mut mgs_proxy_connections_rx) = + mpsc::unbounded_channel(); + let (mgs_proxy_one_task, mgs_proxy_one_addr) = { + let socket = TcpListener::bind("[::1]:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + let mgs_proxy_connections_tx = mgs_proxy_connections_tx.clone(); + let task = tokio::spawn(async move { + loop { + let (stream, _peer) = socket.accept().await.unwrap(); + mgs_proxy_connections_tx.send(MgsProxy::One(stream)).unwrap(); + } + }); + (task, addr) + }; + let (mgs_proxy_two_task, mgs_proxy_two_addr) = { + let socket = TcpListener::bind("[::1]:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + let task = tokio::spawn(async move { + loop { + let (stream, _peer) = socket.accept().await.unwrap(); + mgs_proxy_connections_tx.send(MgsProxy::Two(stream)).unwrap(); + } + }); + (task, addr) + }; + + // Disable connection pooling so each request gets a new TCP connection. 
+ let client = + reqwest::Client::builder().pool_max_idle_per_host(0).build().unwrap(); + + // Configure two MGS clients pointed at our two proxy tasks. + let mut mgs_clients = MgsClients::from_clients([ + gateway_client::Client::new_with_client( + &format!("http://{mgs_proxy_one_addr}"), + client.clone(), + mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient1")), + ), + gateway_client::Client::new_with_client( + &format!("http://{mgs_proxy_two_addr}"), + client, + mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient2")), + ), + ]); + + let sp_type = SpType::Sled; + let sp_slot = 0; + let target_host_slot = 0; + let update_id = Uuid::new_v4(); + let phase1_data = make_fake_host_phase1_image(); + + let host_phase1_updater = HostPhase1Updater::new( + sp_type, + sp_slot, + target_host_slot, + update_id, + phase1_data.clone(), + &mgstestctx.logctx.log, + ); + + // Spawn the actual update task. + let mut update_task = tokio::spawn(async move { + host_phase1_updater.update(&mut mgs_clients).await + }); + + // Loop over incoming requests. We expect this sequence: + // + // 1. Connection arrives on the first proxy + // 2. We spawn a task to service that request, and set `should_swap` + // 3. Connection arrives on the first proxy + // 4. We drop that connection, flip `expected_proxy`, and clear + // `should_swap` + // 5. Connection arrives on the second proxy + // 6. We spawn a task to service that request, and set `should_swap` + // 7. Connection arrives on the second proxy + // 8. We drop that connection, flip `expected_proxy`, and clear + // `should_swap` + // + // ... repeat until the update is complete. + let mut expected_proxy = 0; + let mut proxy_one_count = 0; + let mut proxy_two_count = 0; + let mut total_requests_handled = 0; + let mut should_swap = false; + loop { + tokio::select! { + Some(proxy_stream) = mgs_proxy_connections_rx.recv() => { + let stream = match proxy_stream { + MgsProxy::One(stream) => { + assert_eq!(expected_proxy, 0); + proxy_one_count += 1; + stream + } + MgsProxy::Two(stream) => { + assert_eq!(expected_proxy, 1); + proxy_two_count += 1; + stream + } + }; + + // Should we trigger `HostPhase1Updater` to swap to the other + // MGS (proxy)? If so, do that by dropping this connection + // (which will cause a client failure) and note that we expect + // the next incoming request to come on the other proxy. + if should_swap { + mem::drop(stream); + expected_proxy ^= 1; + should_swap = false; + } else { + // Otherwise, handle this connection. + total_requests_handled += 1; + spawn_mgs_proxy_task(stream); + should_swap = true; + } + } + + result = &mut update_task => { + match result { + Ok(Ok(())) => { + mgs_proxy_one_task.abort(); + mgs_proxy_two_task.abort(); + break; + } + Ok(Err(err)) => panic!("update failed: {err}"), + Err(err) => panic!("update task panicked: {err}"), + } + } + } + } + + // A host flash update requires a minimum of 3 requests to MGS: set the + // active flash slot, post the update, and check the status. There may be + // more requests if the update is not yet complete when the status is + // checked, but we can just check that each of our proxies received at least + // 2 incoming requests; based on our outline above, if we got the minimum of + // 3 requests, it would look like this: + // + // 1. POST update -> first proxy (success) + // 2. GET status -> first proxy (fail) + // 3. GET status retry -> second proxy (success) + // 4. POST reset -> second proxy (fail) + // 5. 
POST reset -> first proxy (success) + // + // This pattern would repeat if multiple status requests were required, so + // we always expect the first proxy to see exactly one more connection + // attempt than the second (because it went first before they started + // swapping), and the two together should see a total of one less than + // double the number of successful requests required. + assert!(total_requests_handled >= 3); + assert_eq!(proxy_one_count, proxy_two_count + 1); + assert_eq!( + (proxy_one_count + proxy_two_count + 1) / 2, + total_requests_handled + ); + + let last_update_image = mgstestctx.simrack.gimlets[sp_slot as usize] + .last_host_phase1_update_data(target_host_slot) + .await + .expect("simulated host phase1 did not receive an update"); + + assert_eq!( + phase1_data.as_slice(), + &*last_update_image, + "simulated host phase1 update contents (len {}) \ + do not match test generated fake image (len {})", + last_update_image.len(), + phase1_data.len(), + ); + + mgstestctx.teardown().await; +} + +#[tokio::test] +async fn test_host_phase1_updater_delivers_progress() { + // Start MGS + Sim SP. + let mgstestctx = mgs_setup::test_setup( + "test_host_phase1_updater_delivers_progress", + SpPort::One, + ) + .await; + + // Configure an MGS client. + let mut mgs_clients = + MgsClients::from_clients([gateway_client::Client::new( + &mgstestctx.client.url("/").to_string(), + mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")), + )]); + + let sp_type = SpType::Sled; + let sp_slot = 0; + let target_host_slot = 0; + let update_id = Uuid::new_v4(); + let phase1_data = make_fake_host_phase1_image(); + + let host_phase1_updater = HostPhase1Updater::new( + sp_type, + sp_slot, + target_host_slot, + update_id, + phase1_data.clone(), + &mgstestctx.logctx.log, + ); + + let phase1_data_len = phase1_data.len() as u32; + + // Subscribe to update progress, and check that there is no status yet; we + // haven't started the update. + let mut progress = host_phase1_updater.progress_watcher(); + assert_eq!(*progress.borrow_and_update(), None); + + // Install a semaphore on the requests our target SP will receive so we can + // inspect progress messages without racing. + let target_sp = &mgstestctx.simrack.gimlets[sp_slot as usize]; + let sp_accept_sema = target_sp.install_udp_accept_semaphore().await; + let mut sp_responses = target_sp.responses_sent_count().unwrap(); + + // Spawn the update on a background task so we can watch `progress` as it is + // applied. + let do_update_task = tokio::spawn(async move { + host_phase1_updater.update(&mut mgs_clients).await + }); + + // Allow the SP to respond to 2 messages: the message to activate the target + // flash slot and the "prepare update" messages that triggers the start of an + // update, then ensure we see the "started" progress. + sp_accept_sema.send(2).unwrap(); + progress.changed().await.unwrap(); + assert_eq!(*progress.borrow_and_update(), Some(UpdateProgress::Started)); + + // Ensure our simulated SP is in the state we expect: it's prepared for an + // update but has not yet received any data. + assert_eq!( + target_sp.current_update_status().await, + UpdateStatus::InProgress(UpdateInProgressStatus { + id: update_id.into(), + bytes_received: 0, + total_size: phase1_data_len, + }) + ); + + // Record the number of responses the SP has sent; we'll use + // `sp_responses.changed()` in the loop below, and want to mark whatever + // value this watch channel currently has as seen. 
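// Aside (illustrative only, not part of this patch): `borrow_and_update()`
// marks the value currently in the channel as seen, so a later
// `changed().await` only completes once a *new* value is sent. A minimal,
// self-contained example of the pattern used on `sp_responses` below:
#[tokio::test]
async fn watch_mark_seen_example() {
    let (tx, mut rx) = tokio::sync::watch::channel(0u32);
    tx.send_replace(1);
    // Mark whatever is currently in the channel (1) as seen.
    rx.borrow_and_update();
    tx.send_replace(2);
    // Completes because an unseen value (2) arrived after we marked 1 as seen.
    rx.changed().await.unwrap();
    assert_eq!(*rx.borrow_and_update(), 2);
}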
+ sp_responses.borrow_and_update(); + + // At this point, there are two clients racing each other to talk to our + // simulated SP: + // + // 1. MGS is trying to deliver the update + // 2. `host_phase1_updater` is trying to poll (via MGS) for update status + // + // and we want to ensure that we see any relevant progress reports from + // `host_phase1_updater`. We'll let one MGS -> SP message through at a time + // (waiting until our SP has responded by waiting for a change to + // `sp_responses`) then check its update state: if it changed, the packet we + // let through was data from MGS; otherwise, it was a status request from + // `host_phase1_updater`. + // + // This loop will continue until either: + // + // 1. We see an `UpdateStatus::InProgress` message indicating 100% delivery, + // at which point we break out of the loop + // 2. We time out waiting for the previous step (by timing out for either + // the SP to process a request or `host_phase1_updater` to realize + // there's been progress), at which point we panic and fail this test. + let mut prev_bytes_received = 0; + let mut expect_progress_change = false; + loop { + // Allow the SP to accept and respond to a single UDP packet. + sp_accept_sema.send(1).unwrap(); + + // Wait until the SP has sent a response, with a safety rail that we + // haven't screwed up our untangle-the-race logic: if we don't see the + // SP process any new messages after several seconds, our test is + // broken, so fail. + tokio::time::timeout(Duration::from_secs(10), sp_responses.changed()) + .await + .expect("timeout waiting for SP response count to change") + .expect("sp response count sender dropped"); + + // Inspec the SP's in-memory update state; we expect only `InProgress` + // or `Complete`, and in either case we note whether we expect to see + // status changes from `host_phase1_updater`. + match target_sp.current_update_status().await { + UpdateStatus::InProgress(sp_progress) => { + if sp_progress.bytes_received > prev_bytes_received { + prev_bytes_received = sp_progress.bytes_received; + expect_progress_change = true; + continue; + } + } + UpdateStatus::Complete(_) => { + if prev_bytes_received < phase1_data_len { + break; + } + } + status @ (UpdateStatus::None + | UpdateStatus::Preparing(_) + | UpdateStatus::SpUpdateAuxFlashChckScan { .. } + | UpdateStatus::Aborted(_) + | UpdateStatus::Failed { .. } + | UpdateStatus::RotError { .. }) => { + panic!("unexpected status {status:?}"); + } + } + + // If we get here, the most recent packet did _not_ change the SP's + // internal update state, so it was a status request from + // `host_phase1_updater`. If we expect the updater to see new progress, + // wait for that change here. + if expect_progress_change { + // Safety rail that we haven't screwed up our untangle-the-race + // logic: if we don't see a new progress after several seconds, our + // test is broken, so fail. + tokio::time::timeout(Duration::from_secs(10), progress.changed()) + .await + .expect("progress timeout") + .expect("progress watch sender dropped"); + let status = progress.borrow_and_update().clone().unwrap(); + expect_progress_change = false; + + assert!( + matches!(status, UpdateProgress::InProgress { .. }), + "unexpected progress status {status:?}" + ); + } + } + + // We know the SP has received a complete update, but `HostPhase1Updater` + // may still need to request status to realize that; release the socket + // semaphore so the SP can respond. 
+ sp_accept_sema.send(usize::MAX).unwrap(); + + // Unlike the SP and RoT cases, there are no MGS/SP steps in between the + // update completing and `HostPhase1Updater` sending + // `UpdateProgress::Complete`. Therefore, it's a race whether we'll see + // some number of `InProgress` status before `Complete`, but we should + // quickly move to `Complete`. + loop { + tokio::time::timeout(Duration::from_secs(10), progress.changed()) + .await + .expect("progress timeout") + .expect("progress watch sender dropped"); + let status = progress.borrow_and_update().clone().unwrap(); + match status { + UpdateProgress::Complete => break, + UpdateProgress::InProgress { .. } => continue, + _ => panic!("unexpected progress status {status:?}"), + } + } + + // drop our progress receiver so `do_update_task` can complete + mem::drop(progress); + + do_update_task.await.expect("update task panicked").expect("update failed"); + + let last_update_image = target_sp + .last_host_phase1_update_data(target_host_slot) + .await + .expect("simulated host phase1 did not receive an update"); + + assert_eq!( + phase1_data.as_slice(), + &*last_update_image, + "simulated host phase1 update contents (len {}) \ + do not match test generated fake image (len {})", + last_update_image.len(), + phase1_data.len(), + ); + + mgstestctx.teardown().await; +} diff --git a/nexus/tests/integration_tests/mod.rs b/nexus/tests/integration_tests/mod.rs index 87c5c74f0f..4d7b41cfa8 100644 --- a/nexus/tests/integration_tests/mod.rs +++ b/nexus/tests/integration_tests/mod.rs @@ -12,6 +12,7 @@ mod commands; mod console_api; mod device_auth; mod disks; +mod host_phase1_updater; mod images; mod initialization; mod instances; diff --git a/nexus/tests/integration_tests/rot_updater.rs b/nexus/tests/integration_tests/rot_updater.rs index 750f9571d0..2e6d65f8b1 100644 --- a/nexus/tests/integration_tests/rot_updater.rs +++ b/nexus/tests/integration_tests/rot_updater.rs @@ -45,10 +45,11 @@ async fn test_rot_updater_updates_sled() { .await; // Configure an MGS client. - let mgs_clients = MgsClients::from_clients([gateway_client::Client::new( - &mgstestctx.client.url("/").to_string(), - mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")), - )]); + let mut mgs_clients = + MgsClients::from_clients([gateway_client::Client::new( + &mgstestctx.client.url("/").to_string(), + mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")), + )]); // Configure and instantiate an `RotUpdater`. let sp_type = SpType::Sled; @@ -67,7 +68,7 @@ async fn test_rot_updater_updates_sled() { ); // Run the update. - rot_updater.update(mgs_clients).await.expect("update failed"); + rot_updater.update(&mut mgs_clients).await.expect("update failed"); // Ensure the RoT received the complete update. let last_update_image = mgstestctx.simrack.gimlets[sp_slot as usize] @@ -97,10 +98,11 @@ async fn test_rot_updater_updates_switch() { .await; // Configure an MGS client. 
- let mgs_clients = MgsClients::from_clients([gateway_client::Client::new( - &mgstestctx.client.url("/").to_string(), - mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")), - )]); + let mut mgs_clients = + MgsClients::from_clients([gateway_client::Client::new( + &mgstestctx.client.url("/").to_string(), + mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")), + )]); let sp_type = SpType::Switch; let sp_slot = 0; @@ -117,7 +119,7 @@ async fn test_rot_updater_updates_switch() { &mgstestctx.logctx.log, ); - rot_updater.update(mgs_clients).await.expect("update failed"); + rot_updater.update(&mut mgs_clients).await.expect("update failed"); let last_update_image = mgstestctx.simrack.sidecars[sp_slot as usize] .last_rot_update_data() @@ -177,7 +179,7 @@ async fn test_rot_updater_remembers_successful_mgs_instance() { // delivering an update requires a bare minimum of three requests (start the // update, query the status, reset the RoT) and often more (if repeated // queries are required to wait for completion). - let mgs_clients = MgsClients::from_clients([ + let mut mgs_clients = MgsClients::from_clients([ gateway_client::Client::new( &format!("http://{failing_mgs_addr}"), mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient1")), @@ -203,7 +205,7 @@ async fn test_rot_updater_remembers_successful_mgs_instance() { &mgstestctx.logctx.log, ); - rot_updater.update(mgs_clients).await.expect("update failed"); + rot_updater.update(&mut mgs_clients).await.expect("update failed"); let last_update_image = mgstestctx.simrack.gimlets[sp_slot as usize] .last_rot_update_data() @@ -295,7 +297,7 @@ async fn test_rot_updater_switches_mgs_instances_on_failure() { reqwest::Client::builder().pool_max_idle_per_host(0).build().unwrap(); // Configure two MGS clients pointed at our two proxy tasks. - let mgs_clients = MgsClients::from_clients([ + let mut mgs_clients = MgsClients::from_clients([ gateway_client::Client::new_with_client( &format!("http://{mgs_proxy_one_addr}"), client.clone(), @@ -324,7 +326,8 @@ async fn test_rot_updater_switches_mgs_instances_on_failure() { ); // Spawn the actual update task. - let mut update_task = tokio::spawn(rot_updater.update(mgs_clients)); + let mut update_task = + tokio::spawn(async move { rot_updater.update(&mut mgs_clients).await }); // Loop over incoming requests. We expect this sequence: // @@ -447,10 +450,11 @@ async fn test_rot_updater_delivers_progress() { .await; // Configure an MGS client. - let mgs_clients = MgsClients::from_clients([gateway_client::Client::new( - &mgstestctx.client.url("/").to_string(), - mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")), - )]); + let mut mgs_clients = + MgsClients::from_clients([gateway_client::Client::new( + &mgstestctx.client.url("/").to_string(), + mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")), + )]); let sp_type = SpType::Sled; let sp_slot = 0; @@ -483,10 +487,11 @@ async fn test_rot_updater_delivers_progress() { // Spawn the update on a background task so we can watch `progress` as it is // applied. - let do_update_task = tokio::spawn(rot_updater.update(mgs_clients)); + let do_update_task = + tokio::spawn(async move { rot_updater.update(&mut mgs_clients).await }); // Allow the SP to respond to 1 message: the "prepare update" messages that - // trigger the start of an update, then ensure we see the "started" + // triggers the start of an update, then ensure we see the "started" // progress. 
sp_accept_sema.send(1).unwrap(); progress.changed().await.unwrap(); @@ -556,7 +561,6 @@ async fn test_rot_updater_delivers_progress() { UpdateStatus::Complete(_) => { if prev_bytes_received < rot_image_len { prev_bytes_received = rot_image_len; - continue; } } status @ (UpdateStatus::None diff --git a/nexus/tests/integration_tests/sp_updater.rs b/nexus/tests/integration_tests/sp_updater.rs index 89735ac3d9..1b6764e609 100644 --- a/nexus/tests/integration_tests/sp_updater.rs +++ b/nexus/tests/integration_tests/sp_updater.rs @@ -46,10 +46,11 @@ async fn test_sp_updater_updates_sled() { .await; // Configure an MGS client. - let mgs_clients = MgsClients::from_clients([gateway_client::Client::new( - &mgstestctx.client.url("/").to_string(), - mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")), - )]); + let mut mgs_clients = + MgsClients::from_clients([gateway_client::Client::new( + &mgstestctx.client.url("/").to_string(), + mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")), + )]); // Configure and instantiate an `SpUpdater`. let sp_type = SpType::Sled; @@ -66,7 +67,7 @@ async fn test_sp_updater_updates_sled() { ); // Run the update. - sp_updater.update(mgs_clients).await.expect("update failed"); + sp_updater.update(&mut mgs_clients).await.expect("update failed"); // Ensure the SP received the complete update. let last_update_image = mgstestctx.simrack.gimlets[sp_slot as usize] @@ -96,10 +97,11 @@ async fn test_sp_updater_updates_switch() { .await; // Configure an MGS client. - let mgs_clients = MgsClients::from_clients([gateway_client::Client::new( - &mgstestctx.client.url("/").to_string(), - mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")), - )]); + let mut mgs_clients = + MgsClients::from_clients([gateway_client::Client::new( + &mgstestctx.client.url("/").to_string(), + mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")), + )]); let sp_type = SpType::Switch; let sp_slot = 0; @@ -114,7 +116,7 @@ async fn test_sp_updater_updates_switch() { &mgstestctx.logctx.log, ); - sp_updater.update(mgs_clients).await.expect("update failed"); + sp_updater.update(&mut mgs_clients).await.expect("update failed"); let last_update_image = mgstestctx.simrack.sidecars[sp_slot as usize] .last_sp_update_data() @@ -174,7 +176,7 @@ async fn test_sp_updater_remembers_successful_mgs_instance() { // delivering an update requires a bare minimum of three requests (start the // update, query the status, reset the SP) and often more (if repeated // queries are required to wait for completion). - let mgs_clients = MgsClients::from_clients([ + let mut mgs_clients = MgsClients::from_clients([ gateway_client::Client::new( &format!("http://{failing_mgs_addr}"), mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient1")), @@ -198,7 +200,7 @@ async fn test_sp_updater_remembers_successful_mgs_instance() { &mgstestctx.logctx.log, ); - sp_updater.update(mgs_clients).await.expect("update failed"); + sp_updater.update(&mut mgs_clients).await.expect("update failed"); let last_update_image = mgstestctx.simrack.gimlets[sp_slot as usize] .last_sp_update_data() @@ -290,7 +292,7 @@ async fn test_sp_updater_switches_mgs_instances_on_failure() { reqwest::Client::builder().pool_max_idle_per_host(0).build().unwrap(); // Configure two MGS clients pointed at our two proxy tasks. 
-    let mgs_clients = MgsClients::from_clients([
+    let mut mgs_clients = MgsClients::from_clients([
         gateway_client::Client::new_with_client(
             &format!("http://{mgs_proxy_one_addr}"),
             client.clone(),
@@ -317,7 +319,8 @@ async fn test_sp_updater_switches_mgs_instances_on_failure() {
     );

     // Spawn the actual update task.
-    let mut update_task = tokio::spawn(sp_updater.update(mgs_clients));
+    let mut update_task =
+        tokio::spawn(async move { sp_updater.update(&mut mgs_clients).await });

     // Loop over incoming requests. We expect this sequence:
     //
@@ -436,10 +439,11 @@ async fn test_sp_updater_delivers_progress() {
         .await;

     // Configure an MGS client.
-    let mgs_clients = MgsClients::from_clients([gateway_client::Client::new(
-        &mgstestctx.client.url("/").to_string(),
-        mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")),
-    )]);
+    let mut mgs_clients =
+        MgsClients::from_clients([gateway_client::Client::new(
+            &mgstestctx.client.url("/").to_string(),
+            mgstestctx.logctx.log.new(slog::o!("component" => "MgsClient")),
+        )]);

     let sp_type = SpType::Sled;
     let sp_slot = 0;
@@ -470,10 +474,11 @@ async fn test_sp_updater_delivers_progress() {

     // Spawn the update on a background task so we can watch `progress` as it is
     // applied.
-    let do_update_task = tokio::spawn(sp_updater.update(mgs_clients));
+    let do_update_task =
+        tokio::spawn(async move { sp_updater.update(&mut mgs_clients).await });

     // Allow the SP to respond to 2 messages: the caboose check and the "prepare
-    // update" messages that trigger the start of an update, then ensure we see
+    // update" messages that triggers the start of an update, then ensure we see
     // the "started" progress.
     sp_accept_sema.send(2).unwrap();
     progress.changed().await.unwrap();
@@ -543,7 +548,6 @@ async fn test_sp_updater_delivers_progress() {
             UpdateStatus::Complete(_) => {
                 if prev_bytes_received < sp_image_len {
                     prev_bytes_received = sp_image_len;
-                    continue;
                 }
             }
             status @ (UpdateStatus::None
diff --git a/sp-sim/src/gimlet.rs b/sp-sim/src/gimlet.rs
index 635e8fde6b..5cfad94c86 100644
--- a/sp-sim/src/gimlet.rs
+++ b/sp-sim/src/gimlet.rs
@@ -123,6 +123,15 @@ impl SimulatedSp for Gimlet {
         handler.update_state.last_rot_update_data()
     }

+    async fn last_host_phase1_update_data(
+        &self,
+        slot: u16,
+    ) -> Option<Box<[u8]>> {
+        let handler = self.handler.as_ref()?;
+        let handler = handler.lock().await;
+        handler.update_state.last_host_phase1_update_data(slot)
+    }
+
     async fn current_update_status(&self) -> gateway_messages::UpdateStatus {
         let Some(handler) = self.handler.as_ref() else {
             return gateway_messages::UpdateStatus::None;
@@ -1188,7 +1197,7 @@ impl SpHandler for Handler {
         port: SpPort,
         component: SpComponent,
     ) -> Result<u16, SpError> {
-        warn!(
+        debug!(
             &self.log, "asked for component active slot";
             "sender" => %sender,
             "port" => ?port,
@@ -1211,7 +1220,7 @@ impl SpHandler for Handler {
         slot: u16,
         persist: bool,
     ) -> Result<(), SpError> {
-        warn!(
+        debug!(
             &self.log, "asked to set component active slot";
             "sender" => %sender,
             "port" => ?port,
@@ -1222,9 +1231,12 @@ impl SpHandler for Handler {
         if component == SpComponent::ROT {
             self.rot_active_slot = rot_slot_id_from_u16(slot)?;
             Ok(())
+        } else if component == SpComponent::HOST_CPU_BOOT_FLASH {
+            self.update_state.set_active_host_slot(slot);
+            Ok(())
         } else {
             // The real SP returns `RequestUnsupportedForComponent` for anything
-            // other than the RoT, including SP_ITSELF.
+            // other than the RoT and host boot flash, including SP_ITSELF.
             Err(SpError::RequestUnsupportedForComponent)
         }
     }
diff --git a/sp-sim/src/lib.rs b/sp-sim/src/lib.rs
index 0958e8a177..87643af9a8 100644
--- a/sp-sim/src/lib.rs
+++ b/sp-sim/src/lib.rs
@@ -68,6 +68,12 @@ pub trait SimulatedSp {
     /// Only returns data after a simulated reset of the RoT.
     async fn last_rot_update_data(&self) -> Option<Box<[u8]>>;

+    /// Get the last completed update delivered to the host phase1 flash slot.
+    async fn last_host_phase1_update_data(
+        &self,
+        slot: u16,
+    ) -> Option<Box<[u8]>>;
+
     /// Get the current update status, just as would be returned by an MGS
     /// request to get the update status.
     async fn current_update_status(&self) -> gateway_messages::UpdateStatus;
diff --git a/sp-sim/src/sidecar.rs b/sp-sim/src/sidecar.rs
index 19e84ffc64..1bd6fe4964 100644
--- a/sp-sim/src/sidecar.rs
+++ b/sp-sim/src/sidecar.rs
@@ -134,6 +134,14 @@ impl SimulatedSp for Sidecar {
         handler.update_state.last_rot_update_data()
     }

+    async fn last_host_phase1_update_data(
+        &self,
+        _slot: u16,
+    ) -> Option<Box<[u8]>> {
+        // sidecars do not have attached hosts
+        None
+    }
+
     async fn current_update_status(&self) -> gateway_messages::UpdateStatus {
         let Some(handler) = self.handler.as_ref() else {
             return gateway_messages::UpdateStatus::None;
diff --git a/sp-sim/src/update.rs b/sp-sim/src/update.rs
index 9879a3ecde..0efa730a26 100644
--- a/sp-sim/src/update.rs
+++ b/sp-sim/src/update.rs
@@ -2,6 +2,7 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this
 // file, You can obtain one at https://mozilla.org/MPL/2.0/.

+use std::collections::BTreeMap;
 use std::io::Cursor;
 use std::mem;

@@ -15,6 +16,8 @@ pub(crate) struct SimSpUpdate {
     state: UpdateState,
     last_sp_update_data: Option<Box<[u8]>>,
     last_rot_update_data: Option<Box<[u8]>>,
+    last_host_phase1_update_data: BTreeMap<u16, Box<[u8]>>,
+    active_host_slot: Option<u16>,
 }

 impl Default for SimSpUpdate {
@@ -23,6 +26,13 @@ impl Default for SimSpUpdate {
             state: UpdateState::NotPrepared,
             last_sp_update_data: None,
             last_rot_update_data: None,
+            last_host_phase1_update_data: BTreeMap::new(),
+
+            // In the real SP, there is always _some_ active host slot. We could
+            // emulate that by always defaulting to slot 0, but instead we'll
+            // ensure any tests that expect to read or write a particular slot
+            // set that slot as active first.
+            active_host_slot: None,
         }
     }
 }
@@ -43,9 +53,20 @@ impl SimSpUpdate {
             UpdateState::NotPrepared
             | UpdateState::Aborted(_)
             | UpdateState::Completed { .. } => {
+                let slot = if component == SpComponent::HOST_CPU_BOOT_FLASH {
+                    match self.active_host_slot {
+                        Some(slot) => slot,
+                        None => return Err(SpError::InvalidSlotForComponent),
+                    }
+                } else {
+                    // We don't manage SP or RoT slots, so just use 0
+                    0
+                };
+
                 self.state = UpdateState::Prepared {
                     component,
                     id,
+                    slot,
                     data: Cursor::new(vec![0u8; total_size].into_boxed_slice()),
                 };
                 Ok(())
@@ -63,7 +84,7 @@
         chunk_data: &[u8],
     ) -> Result<(), SpError> {
         match &mut self.state {
-            UpdateState::Prepared { component, id, data } => {
+            UpdateState::Prepared { component, id, slot, data } => {
                 // Ensure that the update ID and target component are correct.
                 if chunk.id != *id || chunk.component != *component {
                     return Err(SpError::InvalidUpdateId { sp_update_id: *id });
@@ -84,10 +105,17 @@
                 if data.position() == data.get_ref().len() as u64 {
                     let mut stolen = Cursor::new(Box::default());
                     mem::swap(data, &mut stolen);
+                    let data = stolen.into_inner();
+
+                    if *component == SpComponent::HOST_CPU_BOOT_FLASH {
+                        self.last_host_phase1_update_data
+                            .insert(*slot, data.clone());
+                    }
+
                     self.state = UpdateState::Completed {
                         component: *component,
                         id: *id,
-                        data: stolen.into_inner(),
+                        data,
                     };
                 }

@@ -150,6 +178,17 @@
     pub(crate) fn last_rot_update_data(&self) -> Option<Box<[u8]>> {
         self.last_rot_update_data.clone()
     }
+
+    pub(crate) fn last_host_phase1_update_data(
+        &self,
+        slot: u16,
+    ) -> Option<Box<[u8]>> {
+        self.last_host_phase1_update_data.get(&slot).cloned()
+    }
+
+    pub(crate) fn set_active_host_slot(&mut self, slot: u16) {
+        self.active_host_slot = Some(slot);
+    }
 }

 enum UpdateState {
@@ -157,6 +196,7 @@
     Prepared {
         component: SpComponent,
         id: UpdateId,
+        slot: u16,
         // data would ordinarily be a Cursor<Vec<u8>>, but that can grow and
         // reallocate. We want to ensure that we don't receive any more data
         // than originally promised, so use a Cursor<Box<[u8]>> to ensure that
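
Note on how the sp-sim plumbing above is meant to be exercised: `SimSpUpdate` defaults to no active host slot, so a test has to mark the target `HOST_CPU_BOOT_FLASH` slot active (via the set-component-active-slot path, which now calls `set_active_host_slot`) before delivering a phase 1 image, and can afterwards read the delivered bytes back per slot through `SimulatedSp::last_host_phase1_update_data`. A minimal sketch of that read-back assertion, mirroring the `last_rot_update_data` checks in the tests above; `sp_slot`, `target_host_slot`, and `phase1_image` are hypothetical test locals, not names from this change:

    // Hypothetical assertion at the end of a host phase 1 update test; assumes
    // the test already set `target_host_slot` active and drove an update to
    // `SpComponent::HOST_CPU_BOOT_FLASH` through MGS.
    let last_update_image = mgstestctx.simrack.gimlets[sp_slot as usize]
        .last_host_phase1_update_data(target_host_slot)
        .await
        .expect("simulated host phase1 flash did not receive an update");

    assert_eq!(
        phase1_image.as_slice(),
        &*last_update_image,
        "simulated host phase1 flash contents do not match the delivered image",
    );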