diff --git a/package/src/bin/omicron-package.rs b/package/src/bin/omicron-package.rs index a0146eee50..ea490e54cf 100644 --- a/package/src/bin/omicron-package.rs +++ b/package/src/bin/omicron-package.rs @@ -41,6 +41,11 @@ enum SubCommand { Deploy(DeployCommand), } +fn parse_duration_ms(arg: &str) -> Result { + let ms = arg.parse()?; + Ok(std::time::Duration::from_millis(ms)) +} + #[derive(Debug, Parser)] #[clap(name = "packaging tool")] struct Args { @@ -77,6 +82,23 @@ struct Args { )] force: bool, + #[clap( + long, + help = "Number of retries to use when re-attempting failed package downloads", + action, + default_value_t = 10 + )] + retry_count: usize, + + #[clap( + long, + help = "Duration, in ms, to wait before re-attempting failed package downloads", + action, + value_parser = parse_duration_ms, + default_value = "1000", + )] + retry_duration: std::time::Duration, + #[clap(subcommand)] subcommand: SubCommand, } @@ -303,8 +325,63 @@ async fn get_sha256_digest(path: &PathBuf) -> Result { Ok(context.finish()) } +async fn download_prebuilt( + progress: &PackageProgress, + package_name: &str, + repo: &str, + commit: &str, + expected_digest: &Vec, + path: &Path, +) -> Result<()> { + progress.set_message("downloading prebuilt".into()); + let url = format!( + "https://buildomat.eng.oxide.computer/public/file/oxidecomputer/{}/image/{}/{}", + repo, + commit, + path.file_name().unwrap().to_string_lossy(), + ); + let response = reqwest::Client::new() + .get(&url) + .send() + .await + .with_context(|| format!("failed to get {url}"))?; + progress.set_length( + response + .content_length() + .ok_or_else(|| anyhow!("Missing Content Length"))?, + ); + let mut file = tokio::fs::File::create(&path) + .await + .with_context(|| format!("failed to create {path:?}"))?; + let mut stream = response.bytes_stream(); + let mut context = DigestContext::new(&SHA256); + while let Some(chunk) = stream.next().await { + let chunk = chunk + .with_context(|| format!("failed reading response from {url}"))?; + // Update the running SHA digest + context.update(&chunk); + // Update the downloaded file + file.write_all(&chunk) + .await + .with_context(|| format!("failed writing {path:?}"))?; + // Record progress in the UI + progress.increment(chunk.len().try_into().unwrap()); + } + + let digest = context.finish(); + if digest.as_ref() != expected_digest { + bail!( + "Digest mismatch downloading {package_name}: Saw {}, expected {}", + hex::encode(digest.as_ref()), + hex::encode(expected_digest) + ); + } + Ok(()) +} + // Ensures a package exists, either by creating it or downloading it. async fn get_package( + config: &Config, target: &Target, ui: &Arc, package_name: &String, @@ -328,45 +405,30 @@ async fn get_package( }; if should_download { - progress.set_message("downloading prebuilt".into()); - let url = format!( - "https://buildomat.eng.oxide.computer/public/file/oxidecomputer/{}/image/{}/{}", - repo, - commit, - path.as_path().file_name().unwrap().to_string_lossy(), - ); - let response = reqwest::Client::new() - .get(&url) - .send() - .await - .with_context(|| format!("failed to get {url}"))?; - progress.set_length( - response - .content_length() - .ok_or_else(|| anyhow!("Missing Content Length"))?, - ); - let mut file = tokio::fs::File::create(&path) + let mut attempts_left = config.retry_count + 1; + loop { + match download_prebuilt( + &progress, + package_name, + repo, + commit, + &expected_digest, + path.as_path(), + ) .await - .with_context(|| format!("failed to create {path:?}"))?; - let mut stream = response.bytes_stream(); - let mut context = DigestContext::new(&SHA256); - while let Some(chunk) = stream.next().await { - let chunk = chunk.with_context(|| { - format!("failed reading response from {url}") - })?; - // Update the running SHA digest - context.update(&chunk); - // Update the downloaded file - file.write_all(&chunk) - .await - .with_context(|| format!("failed writing {path:?}"))?; - // Record progress in the UI - progress.increment(chunk.len().try_into().unwrap()); - } - - let digest = context.finish(); - if digest.as_ref() != expected_digest { - bail!("Digest mismatch downloading {package_name}: Saw {}, expected {}", hex::encode(digest.as_ref()), hex::encode(expected_digest)); + { + Ok(()) => break, + Err(err) => { + attempts_left -= 1; + let msg = format!("Failed to download prebuilt ({attempts_left} attempts remaining)"); + progress.set_error_message(msg.into()); + if attempts_left == 0 { + bail!("Failed to download package: {err}"); + } + tokio::time::sleep(config.retry_duration).await; + progress.reset(); + } + } } } } @@ -463,6 +525,7 @@ async fn do_package(config: &Config, output_directory: &Path) -> Result<()> { None, |((package_name, package), ui)| async move { get_package( + &config, &config.target, &ui, package_name, @@ -761,6 +824,13 @@ fn completed_progress_style() -> ProgressStyle { .progress_chars("#>.") } +fn error_progress_style() -> ProgressStyle { + ProgressStyle::default_bar() + .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg:.red}") + .expect("Invalid template") + .progress_chars("#>.") +} + // Struct managing display of progress to UI. struct ProgressUI { multi: MultiProgress, @@ -782,10 +852,21 @@ impl PackageProgress { fn set_length(&self, total: u64) { self.pb.set_length(total); } + + fn set_error_message(&self, message: std::borrow::Cow<'static, str>) { + self.pb.set_style(error_progress_style()); + self.pb.set_message(format!("{}: {}", self.service_name, message)); + self.pb.tick(); + } + + fn reset(&self) { + self.pb.reset(); + } } impl Progress for PackageProgress { fn set_message(&self, message: std::borrow::Cow<'static, str>) { + self.pb.set_style(in_progress_style()); self.pb.set_message(format!("{}: {}", self.service_name, message)); self.pb.tick(); } @@ -820,6 +901,10 @@ struct Config { target: Target, // True if we should skip confirmations for destructive operations. force: bool, + // Number of times to retry failed downloads. + retry_count: usize, + // Duration to wait before retrying failed downloads. + retry_duration: std::time::Duration, } impl Config { @@ -886,6 +971,8 @@ async fn main() -> Result<()> { package_config, target, force: args.force, + retry_count: args.retry_count, + retry_duration: args.retry_duration, }) };