Skip to content

Commit

Permalink
retry downloads
Browse files Browse the repository at this point in the history
  • Loading branch information
smklein committed Sep 29, 2023
1 parent b03dd6b commit 9bee530
Showing 1 changed file with 125 additions and 38 deletions.
163 changes: 125 additions & 38 deletions package/src/bin/omicron-package.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ enum SubCommand {
Deploy(DeployCommand),
}

fn parse_duration_ms(arg: &str) -> Result<std::time::Duration> {
let ms = arg.parse()?;
Ok(std::time::Duration::from_millis(ms))
}

#[derive(Debug, Parser)]
#[clap(name = "packaging tool")]
struct Args {
Expand Down Expand Up @@ -77,6 +82,23 @@ struct Args {
)]
force: bool,

#[clap(
long,
help = "Number of retries to use when re-attempting failed package downloads",
action,
default_value_t = 10
)]
retry_count: usize,

#[clap(
long,
help = "Duration, in ms, to wait before re-attempting failed package downloads",
action,
value_parser = parse_duration_ms,
default_value = "1000",
)]
retry_duration: std::time::Duration,

#[clap(subcommand)]
subcommand: SubCommand,
}
Expand Down Expand Up @@ -303,8 +325,63 @@ async fn get_sha256_digest(path: &PathBuf) -> Result<Digest> {
Ok(context.finish())
}

async fn download_prebuilt(
progress: &PackageProgress,
package_name: &str,
repo: &str,
commit: &str,
expected_digest: &Vec<u8>,
path: &Path,
) -> Result<()> {
progress.set_message("downloading prebuilt".into());
let url = format!(
"https://buildomat.eng.oxide.computer/public/file/oxidecomputer/{}/image/{}/{}",
repo,
commit,
path.file_name().unwrap().to_string_lossy(),
);
let response = reqwest::Client::new()
.get(&url)
.send()
.await
.with_context(|| format!("failed to get {url}"))?;
progress.set_length(
response
.content_length()
.ok_or_else(|| anyhow!("Missing Content Length"))?,
);
let mut file = tokio::fs::File::create(&path)
.await
.with_context(|| format!("failed to create {path:?}"))?;
let mut stream = response.bytes_stream();
let mut context = DigestContext::new(&SHA256);
while let Some(chunk) = stream.next().await {
let chunk = chunk
.with_context(|| format!("failed reading response from {url}"))?;
// Update the running SHA digest
context.update(&chunk);
// Update the downloaded file
file.write_all(&chunk)
.await
.with_context(|| format!("failed writing {path:?}"))?;
// Record progress in the UI
progress.increment(chunk.len().try_into().unwrap());
}

let digest = context.finish();
if digest.as_ref() != expected_digest {
bail!(
"Digest mismatch downloading {package_name}: Saw {}, expected {}",
hex::encode(digest.as_ref()),
hex::encode(expected_digest)
);
}
Ok(())
}

// Ensures a package exists, either by creating it or downloading it.
async fn get_package(
config: &Config,
target: &Target,
ui: &Arc<ProgressUI>,
package_name: &String,
Expand All @@ -328,45 +405,30 @@ async fn get_package(
};

if should_download {
progress.set_message("downloading prebuilt".into());
let url = format!(
"https://buildomat.eng.oxide.computer/public/file/oxidecomputer/{}/image/{}/{}",
repo,
commit,
path.as_path().file_name().unwrap().to_string_lossy(),
);
let response = reqwest::Client::new()
.get(&url)
.send()
.await
.with_context(|| format!("failed to get {url}"))?;
progress.set_length(
response
.content_length()
.ok_or_else(|| anyhow!("Missing Content Length"))?,
);
let mut file = tokio::fs::File::create(&path)
let mut attempts_left = config.retry_count + 1;
loop {
match download_prebuilt(
&progress,
package_name,
repo,
commit,
&expected_digest,
path.as_path(),
)
.await
.with_context(|| format!("failed to create {path:?}"))?;
let mut stream = response.bytes_stream();
let mut context = DigestContext::new(&SHA256);
while let Some(chunk) = stream.next().await {
let chunk = chunk.with_context(|| {
format!("failed reading response from {url}")
})?;
// Update the running SHA digest
context.update(&chunk);
// Update the downloaded file
file.write_all(&chunk)
.await
.with_context(|| format!("failed writing {path:?}"))?;
// Record progress in the UI
progress.increment(chunk.len().try_into().unwrap());
}

let digest = context.finish();
if digest.as_ref() != expected_digest {
bail!("Digest mismatch downloading {package_name}: Saw {}, expected {}", hex::encode(digest.as_ref()), hex::encode(expected_digest));
{
Ok(()) => break,
Err(err) => {
attempts_left -= 1;
let msg = format!("Failed to download prebuilt ({attempts_left} attempts remaining)");
progress.set_error_message(msg.into());
if attempts_left == 0 {
bail!("Failed to download package: {err}");
}
tokio::time::sleep(config.retry_duration).await;
progress.reset();
}
}
}
}
}
Expand Down Expand Up @@ -463,6 +525,7 @@ async fn do_package(config: &Config, output_directory: &Path) -> Result<()> {
None,
|((package_name, package), ui)| async move {
get_package(
&config,
&config.target,
&ui,
package_name,
Expand Down Expand Up @@ -761,6 +824,13 @@ fn completed_progress_style() -> ProgressStyle {
.progress_chars("#>.")
}

fn error_progress_style() -> ProgressStyle {
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg:.red}")
.expect("Invalid template")
.progress_chars("#>.")
}

// Struct managing display of progress to UI.
struct ProgressUI {
multi: MultiProgress,
Expand All @@ -782,10 +852,21 @@ impl PackageProgress {
fn set_length(&self, total: u64) {
self.pb.set_length(total);
}

fn set_error_message(&self, message: std::borrow::Cow<'static, str>) {
self.pb.set_style(error_progress_style());
self.pb.set_message(format!("{}: {}", self.service_name, message));
self.pb.tick();
}

fn reset(&self) {
self.pb.reset();
}
}

impl Progress for PackageProgress {
fn set_message(&self, message: std::borrow::Cow<'static, str>) {
self.pb.set_style(in_progress_style());
self.pb.set_message(format!("{}: {}", self.service_name, message));
self.pb.tick();
}
Expand Down Expand Up @@ -820,6 +901,10 @@ struct Config {
target: Target,
// True if we should skip confirmations for destructive operations.
force: bool,
// Number of times to retry failed downloads.
retry_count: usize,
// Duration to wait before retrying failed downloads.
retry_duration: std::time::Duration,
}

impl Config {
Expand Down Expand Up @@ -886,6 +971,8 @@ async fn main() -> Result<()> {
package_config,
target,
force: args.force,
retry_count: args.retry_count,
retry_duration: args.retry_duration,
})
};

Expand Down

0 comments on commit 9bee530

Please sign in to comment.