Skip to content

Commit

Permalink
sse flag and poll as default (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
chudkowsky authored Sep 6, 2024
1 parent d373c49 commit 48bce0f
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 4 deletions.
42 changes: 40 additions & 2 deletions bin/cairo-prove/src/fetch.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::time::Duration;

use prover_sdk::sdk::ProverSDK;
use serde::Deserialize;
use serde_json::Value;
use tokio::time::sleep;
use tracing::info;

use crate::errors::ProveErrors;

Expand All @@ -9,10 +13,11 @@ pub struct JobId {
pub job_id: u64,
}

pub async fn fetch_job(sdk: ProverSDK, job: String) -> Result<String, ProveErrors> {
pub async fn fetch_job_sse(sdk: ProverSDK, job: String) -> Result<String, ProveErrors> {
let job: JobId = serde_json::from_str(&job)?;
println!("Job ID: {}", job.job_id);
info!("Job ID: {}", job.job_id);
sdk.sse(job.job_id).await?;
info!("Job completed");
let response = sdk.get_job(job.job_id).await?;
let response = response.text().await?;
let json_response: Value = serde_json::from_str(&response)?;
Expand All @@ -30,3 +35,36 @@ pub async fn fetch_job(sdk: ProverSDK, job: String) -> Result<String, ProveError
Err(ProveErrors::Custom(json_response.to_string()))
}
}
pub async fn fetch_job_polling(sdk: ProverSDK, job: String) -> Result<String, ProveErrors> {
let job: JobId = serde_json::from_str(&job)?;
info!("Fetching job: {}", job.job_id);
let mut counter = 0;
loop {
let response = sdk.get_job(job.job_id).await?;
let response = response.text().await?;
let json_response: Value = serde_json::from_str(&response)?;
if let Some(status) = json_response.get("status").and_then(Value::as_str) {
match status {
"Completed" => {
return Ok(json_response
.get("result")
.and_then(Value::as_str)
.unwrap_or("No result found")
.to_string());
}
"Pending" | "Running" => {
info!("Job is still in progress. Status: {}", status);
info!(
"Time passed: {} Waiting for 10 seconds before retrying...",
counter * 10
);
counter += 1;
sleep(Duration::from_secs(10)).await;
}
_ => {
return Err(ProveErrors::Custom(json_response.to_string()));
}
}
}
}
}
2 changes: 2 additions & 0 deletions bin/cairo-prove/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ pub struct Args {
pub prover_access_key: String,
#[arg(long, env, default_value = "false")]
pub wait: bool,
#[arg(long, env, default_value = "false")]
pub sse: bool,
}

fn validate_input(input: &str) -> Result<Vec<Felt>, ProveErrors> {
Expand Down
12 changes: 10 additions & 2 deletions bin/cairo-prove/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use cairo_prove::errors::ProveErrors;
use cairo_prove::prove::prove;
use cairo_prove::{fetch::fetch_job, Args};
use cairo_prove::{
fetch::{fetch_job_polling, fetch_job_sse},
Args,
};
use clap::Parser;
use prover_sdk::access_key::ProverAccessKey;
use prover_sdk::sdk::ProverSDK;
Expand All @@ -12,9 +15,14 @@ pub async fn main() -> Result<(), ProveErrors> {
let sdk = ProverSDK::new(args.prover_url.clone(), access_key).await?;
let job = prove(args.clone(), sdk.clone()).await?;
if args.wait {
let job = fetch_job(sdk, job).await?;
let job = if args.sse {
fetch_job_sse(sdk, job).await?
} else {
fetch_job_polling(sdk, job).await?
};
let path: std::path::PathBuf = args.program_output;
std::fs::write(path, job)?;
}

Ok(())
}

0 comments on commit 48bce0f

Please sign in to comment.