diff --git a/prover/src/prove/cairo.rs b/prover/src/prove/cairo.rs
index 7933bd0..073c8cf 100644
--- a/prover/src/prove/cairo.rs
+++ b/prover/src/prove/cairo.rs
@@ -2,7 +2,6 @@ use crate::auth::jwt::Claims;
 use crate::extractors::workdir::TempDirHandle;
 use crate::server::AppState;
 use crate::threadpool::CairoVersionedInput;
-use crate::utils::job::create_job;
 use axum::Json;
 use axum::{extract::State, http::StatusCode, response::IntoResponse};
 use common::cairo_prover_input::CairoProverInput;
@@ -16,7 +15,7 @@ pub async fn root(
 ) -> impl IntoResponse {
     let thread_pool = app_state.thread_pool.clone();
     let job_store = app_state.job_store.clone();
-    let job_id = create_job(&job_store).await;
+    let job_id = job_store.create_job().await;
     let thread = thread_pool.lock().await;
     thread
         .execute(
diff --git a/prover/src/prove/cairo0.rs b/prover/src/prove/cairo0.rs
index 4860382..b836e31 100644
--- a/prover/src/prove/cairo0.rs
+++ b/prover/src/prove/cairo0.rs
@@ -2,7 +2,6 @@ use crate::auth::jwt::Claims;
 use crate::extractors::workdir::TempDirHandle;
 use crate::server::AppState;
 use crate::threadpool::CairoVersionedInput;
-use crate::utils::job::create_job;
 use axum::Json;
 use axum::{extract::State, http::StatusCode, response::IntoResponse};
 use common::cairo0_prover_input::Cairo0ProverInput;
@@ -16,7 +15,7 @@ pub async fn root(
 ) -> impl IntoResponse {
     let thread_pool = app_state.thread_pool.clone();
     let job_store = app_state.job_store.clone();
-    let job_id = create_job(&job_store).await;
+    let job_id = job_store.create_job().await;
     let thread = thread_pool.lock().await;
     thread
         .execute(
diff --git a/prover/src/server.rs b/prover/src/server.rs
index 0ce1395..9510802 100644
--- a/prover/src/server.rs
+++ b/prover/src/server.rs
@@ -68,7 +68,7 @@ pub async fn start(args: Args) -> Result<(), ProverError> {
         jwt_secret_key: args.jwt_secret_key,
         nonces: Arc::new(Mutex::new(HashMap::new())),
         authorizer,
-        job_store: Arc::new(Mutex::new(Vec::new())),
+        job_store: JobStore::default(),
         thread_pool: Arc::new(Mutex::new(ThreadPool::new(args.num_workers))),
         admin_key,
         sse_tx: Arc::new(Mutex::new(sse_tx)),
diff --git a/prover/src/sse.rs b/prover/src/sse.rs
index 4fb1868..46b4c6d 100644
--- a/prover/src/sse.rs
+++ b/prover/src/sse.rs
@@ -23,13 +23,11 @@ pub async fn sse_handler(
     let mut rx = state.sse_tx.lock().await.subscribe();
     let job_id = params.job_id;

-    let job_status = {
-        let jobs = state.job_store.lock().await;
-        jobs.iter()
-            .find(|job| job.id == job_id)
-            .map(|job| job.status.clone())
-            .unwrap_or(JobStatus::Unknown)
-    };
+    let job_status = state
+        .job_store
+        .get_job(job_id)
+        .await
+        .map_or(JobStatus::Unknown, |j| j.status);

     let stream = stream! {
         if matches!(job_status.clone(), JobStatus::Completed | JobStatus::Failed) {
diff --git a/prover/src/threadpool/prove.rs b/prover/src/threadpool/prove.rs
index 7858e88..f79cb70 100644
--- a/prover/src/threadpool/prove.rs
+++ b/prover/src/threadpool/prove.rs
@@ -1,9 +1,6 @@
 use super::CairoVersionedInput;
 use crate::errors::ProverError;
-use crate::utils::{
-    config::generate,
-    job::{update_job_status, JobStore},
-};
+use crate::utils::{config::Template, job::JobStore};
 use common::models::JobStatus;
 use serde_json::Value;
 use starknet_types_core::felt::Felt;
@@ -24,7 +21,9 @@ pub async fn prove(
     program_input: CairoVersionedInput,
     sse_tx: Arc<Mutex<Sender<String>>>,
 ) -> Result<(), ProverError> {
-    update_job_status(job_id, &job_store, JobStatus::Running, None).await;
+    job_store
+        .update_job_status(job_id, JobStatus::Running, None)
+        .await;
     let path = dir.into_path();
     let program_input_path: PathBuf = path.join("program_input.json");
     let program_path: PathBuf = path.join("program.json");
@@ -39,17 +38,17 @@ pub async fn prove(
         CairoVersionedInput::Cairo(input) => {
             let program = serde_json::to_string(&input.program)?;
             let layout = input.layout;
-            let input = prepare_input(input.program_input)?;
+            let input = prepare_input(&input.program_input);
             fs::write(&program_path, &program)?;
             fs::write(&program_input_path, &input)?;
             cairo_run(
-                trace_file,
-                memory_file,
+                &trace_file,
+                &memory_file,
                 layout,
-                public_input_file.clone(),
-                private_input_file.clone(),
-                program_input_path,
-                program_path,
+                &public_input_file,
+                &private_input_file,
+                &program_input_path,
+                &program_path,
             )
             .await?;
         }
@@ -61,19 +60,19 @@ pub async fn prove(
             fs::write(&program_path, serde_json::to_string(&input.program)?)?;
             let layout = input.layout;
             cairo0_run(
-                trace_file,
-                memory_file,
+                &trace_file,
+                &memory_file,
                 layout,
-                public_input_file.clone(),
-                private_input_file.clone(),
-                program_input_path,
-                program_path,
+                &public_input_file,
+                &private_input_file,
+                &program_input_path,
+                &program_path,
             )
             .await?;
         }
     }

-    generate(public_input_file.clone(), params_file.clone())?;
+    Template::generate_from_public_input_file(&public_input_file)?.save_to_file(&params_file)?;

     let mut command_proof = Command::new("cpu_air_prover");
     command_proof
@@ -97,14 +96,18 @@ pub async fn prove(

     let sender = sse_tx.lock().await;
     if status_proof.success() {
-        update_job_status(job_id, &job_store, JobStatus::Completed, Some(final_result)).await;
+        job_store
+            .update_job_status(job_id, JobStatus::Completed, Some(final_result))
+            .await;
         if sender.receiver_count() > 0 {
             sender
                 .send(serde_json::to_string(&(JobStatus::Completed, job_id))?)
                 .unwrap();
         }
     } else {
-        update_job_status(job_id, &job_store, JobStatus::Failed, Some(final_result)).await;
+        job_store
+            .update_job_status(job_id, JobStatus::Failed, Some(final_result))
+            .await;
         if sender.receiver_count() > 0 {
             sender
                 .send(serde_json::to_string(&(JobStatus::Failed, job_id))?)
@@ -115,32 +118,32 @@ pub async fn prove(
 }

 pub async fn cairo0_run(
-    trace_file: PathBuf,
-    memory_file: PathBuf,
+    trace_file: &PathBuf,
+    memory_file: &PathBuf,
     layout: String,
-    public_input_file: PathBuf,
-    private_input_file: PathBuf,
-    program_input_path: PathBuf,
-    program_path: PathBuf,
+    public_input_file: &PathBuf,
+    private_input_file: &PathBuf,
+    program_input_path: &PathBuf,
+    program_path: &PathBuf,
 ) -> Result<(), ProverError> {
     trace!("Running cairo0-run");
     let mut command = Command::new("cairo-run");
     command
         .arg("--trace_file")
-        .arg(&trace_file)
+        .arg(trace_file)
         .arg("--memory_file")
-        .arg(&memory_file)
+        .arg(memory_file)
         .arg("--layout")
         .arg(layout)
         .arg("--proof_mode")
         .arg("--air_public_input")
-        .arg(&public_input_file)
+        .arg(public_input_file)
         .arg("--air_private_input")
-        .arg(&private_input_file)
+        .arg(private_input_file)
         .arg("--program_input")
-        .arg(&program_input_path)
+        .arg(program_input_path)
         .arg("--program")
-        .arg(&program_path)
+        .arg(program_path)
         .stdout(std::process::Stdio::piped())
         .stderr(std::process::Stdio::piped());

@@ -155,56 +158,59 @@ pub async fn cairo0_run(
     Ok(())
 }
 pub async fn cairo_run(
-    trace_file: PathBuf,
-    memory_file: PathBuf,
+    trace_file: &PathBuf,
+    memory_file: &PathBuf,
     layout: String,
-    public_input_file: PathBuf,
-    private_input_file: PathBuf,
-    program_input_path: PathBuf,
-    program_path: PathBuf,
+    public_input_file: &PathBuf,
+    private_input_file: &PathBuf,
+    program_input_path: &PathBuf,
+    program_path: &PathBuf,
 ) -> Result<(), ProverError> {
     let mut command = Command::new("cairo1-run");
     command
         .arg("--trace_file")
-        .arg(&trace_file)
+        .arg(trace_file)
         .arg("--memory_file")
-        .arg(&memory_file)
+        .arg(memory_file)
         .arg("--layout")
         .arg(layout)
         .arg("--proof_mode")
         .arg("--air_public_input")
-        .arg(&public_input_file)
+        .arg(public_input_file)
         .arg("--air_private_input")
-        .arg(&private_input_file)
+        .arg(private_input_file)
         .arg("--args_file")
-        .arg(&program_input_path)
-        .arg(&program_path)
+        .arg(program_input_path)
+        .arg(program_path)
         .stdout(std::process::Stdio::piped())
         .stderr(std::process::Stdio::piped());

     let child = command.spawn()?;
     let output = child.wait_with_output().await?;

-    // Capture stderr in case of an error
-    if !output.status.success() {
+    if output.status.success() {
+        Ok(())
+    } else {
         let stderr = String::from_utf8_lossy(&output.stderr);
-        return Err(ProverError::CustomError(stderr.into()));
+        Err(ProverError::CustomError(stderr.into()))
     }
-    Ok(())
 }
-pub fn prepare_input(felts: Vec<Felt>) -> Result<String, ProverError> {
-    if felts.is_empty() {
-        return Err(ProverError::CustomError(
-            "Input is empty, input must be a array of felt in format [felt,...,felt]".to_string(),
-        ));
-    }
-    let mut input = String::from("[");
-    for i in 0..felts.len() {
-        input.push_str(&felts[i].to_string());
-        if i != felts.len() - 1 {
-            input.push(' ');
-        }
-    }
-    input.push(']');
-    Ok(input)
+
+pub fn prepare_input(felts: &[Felt]) -> String {
+    felts
+        .iter()
+        .fold("[".to_string(), |a, i| a + &i.to_string() + " ")
+        .trim_end()
+        .to_string()
+        + "]"
+}
+
+#[test]
+fn test_prepare_input() {
+    assert_eq!("[]", prepare_input(&[]));
+    assert_eq!("[1]", prepare_input(&[1.into()]));
+    assert_eq!(
+        "[1 2 3 4]",
+        prepare_input(&[1.into(), 2.into(), 3.into(), 4.into()])
    );
 }
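Reviewer note (not part of the change): the new `prepare_input` produces the space-separated bracket format pinned down by the added test ("[]", "[1]", "[1 2 3 4]"). An equivalent formulation using `join`, shown only as a sketch for comparison, avoids the trailing-space-then-trim step:

```rust
use starknet_types_core::felt::Felt;

// Same output as the fold/trim_end version in the diff: "[]", "[1]", "[1 2 3 4]".
pub fn prepare_input(felts: &[Felt]) -> String {
    let joined = felts
        .iter()
        .map(|f| f.to_string())
        .collect::<Vec<_>>()
        .join(" ");
    format!("[{joined}]")
}
```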
diff --git a/prover/src/utils/config.rs b/prover/src/utils/config.rs
index 85f5970..e4bb348 100644
--- a/prover/src/utils/config.rs
+++ b/prover/src/utils/config.rs
@@ -1,7 +1,7 @@
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use std::fs::File;
-use std::io::{Read, Write};
+use std::io::{BufReader, Write};
 use std::path::PathBuf;

 use crate::errors::ProverError;
@@ -21,7 +21,7 @@ struct Stark {
 }

 #[derive(Serialize, Deserialize, Debug)]
-struct Template {
+pub struct Template {
     field: String,
     channel_hash: String,
     commitment_hash: String,
@@ -34,71 +34,68 @@ struct Template {
     verifier_friendly_commitment_hash: String,
 }

-fn calculate_fri_step_list(n_steps: u32, degree_bound: u32) -> Vec<u32> {
-    let fri_degree = ((n_steps as f64 / degree_bound as f64).log(2.0).round() as u32) + 4;
-    let mut steps = vec![0];
-    steps.extend(vec![4; (fri_degree / 4) as usize]);
-    if fri_degree % 4 != 0 {
-        steps.push(fri_degree % 4);
+impl Template {
+    pub fn generate_from_public_input_file(file: &PathBuf) -> Result<Self, ProverError> {
+        Self::generate_from_public_input(ProgramPublicInputAsNSteps::read_from_file(file)?)
+    }
+    pub fn save_to_file(&self, file: &PathBuf) -> Result<(), ProverError> {
+        let json_string = serde_json::to_string_pretty(self)?;
+        File::create(file)?
+            .write_all(json_string.as_bytes())
+            .map_err(ProverError::from)
+    }
+    fn generate_from_public_input(
+        public_input: ProgramPublicInputAsNSteps,
+    ) -> Result<Self, ProverError> {
+        let mut template = Self::default();
+        let fri_step_list =
+            public_input.calculate_fri_step_list(template.stark.fri.last_layer_degree_bound);
+        template.stark.fri.fri_step_list = fri_step_list;
+        Ok(template)
     }
-    steps
 }

-fn update_template_and_save_to_file(
-    template: &mut Template,
-    fri_step_list: Vec<u32>,
-    file_path: PathBuf,
-) -> Result<(), ProverError> {
-    template.stark.fri.fri_step_list = fri_step_list;
-    let mut file: File = File::create(file_path)?;
-    let json_string = serde_json::to_string_pretty(template)?;
-    file.write_all(json_string.as_bytes())?;
-    Ok(())
+impl core::default::Default for Template {
+    fn default() -> Self {
+        Template {
+            field: "PrimeField0".to_string(),
+            channel_hash: "poseidon3".to_string(),
+            commitment_hash: "blake256_masked160_lsb".to_string(),
+            n_verifier_friendly_commitment_layers: 9999,
+            pow_hash: "keccak256".to_string(),
+            statement: serde_json::json!({ "page_hash": "pedersen" }),
+            stark: Stark {
+                fri: StarkFri {
+                    fri_step_list: vec![0, 4, 4, 4],
+                    last_layer_degree_bound: 128,
+                    n_queries: 16,
+                    proof_of_work_bits: 30,
+                },
+                log_n_cosets: 3,
+            },
+            use_extension_field: false,
+            verifier_friendly_channel_updates: true,
+            verifier_friendly_commitment_hash: "poseidon3".to_string(),
+        }
+    }
 }

-fn read_json_from_file(file_path: PathBuf) -> Result<Value, ProverError> {
-    let mut buffer = String::new();
-    let mut file = File::open(file_path)?;
-    file.read_to_string(&mut buffer)?;
-    let result = serde_json::from_str(&buffer)?;
-    Ok(result)
+#[derive(Debug, Deserialize)]
+struct ProgramPublicInputAsNSteps {
+    n_steps: u32,
 }

-pub fn generate(input_file: PathBuf, output_file: PathBuf) -> Result<(), ProverError> {
-    let program_public_input: Value = read_json_from_file(input_file)?;
-
-    let n_steps: u32 = match program_public_input["n_steps"].as_u64() {
-        Some(val) => val as u32,
-        None => {
-            return Err(ProverError::CustomError(
-                "Failed to get n_steps from cairo run execution".to_string(),
-            ))
+impl ProgramPublicInputAsNSteps {
+    pub fn read_from_file(input_file: &PathBuf) -> Result<Self, ProverError> {
+        serde_json::from_reader(BufReader::new(File::open(input_file)?)).map_err(ProverError::from)
+    }
+    fn calculate_fri_step_list(&self, degree_bound: u32) -> Vec<u32> {
+        let fri_degree = ((self.n_steps as f64 / degree_bound as f64).log(2.0).round() as u32) + 4;
+        let mut steps = vec![0];
+        steps.extend(vec![4; (fri_degree / 4) as usize]);
+        if fri_degree % 4 != 0 {
+            steps.push(fri_degree % 4);
         }
-    };
-    let mut template = Template {
-        field: "PrimeField0".to_string(),
-        channel_hash: "poseidon3".to_string(),
-        commitment_hash: "blake256_masked160_lsb".to_string(),
-        n_verifier_friendly_commitment_layers: 9999,
-        pow_hash: "keccak256".to_string(),
-        statement: serde_json::json!({ "page_hash": "pedersen" }),
-        stark: Stark {
-            fri: StarkFri {
-                fri_step_list: vec![0, 4, 4, 4],
-                last_layer_degree_bound: 128,
-                n_queries: 16,
-                proof_of_work_bits: 30,
-            },
-            log_n_cosets: 3,
-        },
-        use_extension_field: false,
-        verifier_friendly_channel_updates: true,
-        verifier_friendly_commitment_hash: "poseidon3".to_string(),
-    };
-
-    let last_layer_degree_bound = template.stark.fri.last_layer_degree_bound;
-
-    let fri_step_list = calculate_fri_step_list(n_steps, last_layer_degree_bound);
-    update_template_and_save_to_file(&mut template, fri_step_list, output_file)?;
-    Ok(())
+        steps
+    }
 }
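For context, the refactor keeps the fri_step_list arithmetic unchanged: fri_degree = round(log2(n_steps / last_layer_degree_bound)) + 4, emitted as a leading 0 followed by 4s and a remainder. A standalone restatement with illustrative numbers (the sample n_steps values are invented; only last_layer_degree_bound = 128 comes from the template):

```rust
// Standalone restatement of calculate_fri_step_list, for illustration only.
fn fri_step_list(n_steps: u32, degree_bound: u32) -> Vec<u32> {
    let fri_degree = ((n_steps as f64 / degree_bound as f64).log(2.0).round() as u32) + 4;
    let mut steps = vec![0];
    steps.extend(vec![4; (fri_degree / 4) as usize]);
    if fri_degree % 4 != 0 {
        steps.push(fri_degree % 4);
    }
    steps
}

#[test]
fn fri_step_list_examples() {
    assert_eq!(fri_step_list(32_768, 128), vec![0, 4, 4, 4]); // fri_degree = 8 + 4 = 12
    assert_eq!(fri_step_list(16_384, 128), vec![0, 4, 4, 3]); // fri_degree = 7 + 4 = 11
}
```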
diff --git a/prover/src/utils/job.rs b/prover/src/utils/job.rs
index 7b6345d..112e033 100644
--- a/prover/src/utils/job.rs
+++ b/prover/src/utils/job.rs
@@ -6,16 +6,21 @@ use axum::{
 };
 use common::models::JobStatus;
 use serde::Serialize;
-use std::sync::Arc;
+use std::{
+    collections::BTreeMap,
+    sync::Arc,
+    time::{Duration, Instant},
+};
 use tokio::sync::Mutex;

 use crate::{auth::jwt::Claims, server::AppState};

-#[derive(Serialize, Clone)]
+#[derive(Clone)]
 pub struct Job {
     pub id: u64,
     pub status: JobStatus,
-    pub result: Option<String>, // You can change this to any type based on your use case
+    pub result: Option<String>,
+    pub created: Instant,
 }

 #[derive(Serialize)]
@@ -25,42 +30,77 @@ pub enum JobResponse {
     Completed { result: String, status: JobStatus },
     Failed { error: String },
 }

-pub type JobStore = Arc<Mutex<Vec<Job>>>;
-pub async fn create_job(job_store: &JobStore) -> u64 {
-    let mut jobs = job_store.lock().await;
-    let job_id = jobs.len() as u64;
-    let new_job = Job {
-        id: job_id,
-        status: JobStatus::Pending,
-        result: None,
-    };
-    jobs.push(new_job);
-    drop(jobs);
-    job_id
+#[derive(Default, Clone)]
+pub struct JobStore {
+    inner: Arc<Mutex<JobStoreInner>>,
 }

-pub async fn update_job_status(
-    job_id: u64,
-    job_store: &JobStore,
-    status: JobStatus,
-    result: Option<String>,
-) {
-    let mut jobs = job_store.lock().await;
-    if let Some(job) = jobs.iter_mut().find(|job| job.id == job_id) {
-        job.status = status;
-        job.result = result;
+impl JobStore {
+    pub async fn create_job(&self) -> u64 {
+        self.inner.lock().await.create_job()
+    }
+    pub async fn update_job_status(&self, job_id: u64, status: JobStatus, result: Option<String>) {
+        self.inner
+            .lock()
+            .await
+            .update_job_status(job_id, status, result)
+    }
+    pub async fn get_job(&self, id: u64) -> Option<Job> {
+        self.inner.lock().await.get_job(id)
     }
-    drop(jobs);
 }
+
+#[derive(Default)]
+struct JobStoreInner {
+    jobs: BTreeMap<u64, Job>,
+    counter: u64,
+}
+
+impl JobStoreInner {
+    pub fn create_job(&mut self) -> u64 {
+        let job_id = self.counter;
+        self.counter += 1;
+        let new_job = Job {
+            id: job_id,
+            status: JobStatus::Pending,
+            result: None,
+            created: Instant::now(),
+        };
+        self.jobs.insert(job_id, new_job);
+        self.clear_old_jobs();
+        job_id
+    }
+    pub fn update_job_status(&mut self, job_id: u64, status: JobStatus, result: Option<String>) {
+        if let Some(job) = self.jobs.get_mut(&job_id) {
+            job.status = status;
+            job.result = result;
+        }
+        self.clear_old_jobs()
+    }
+    pub fn get_job(&mut self, id: u64) -> Option<Job> {
+        let job = self.jobs.get(&id).cloned();
+        self.clear_old_jobs();
+        job
+    }
+    // Clear old jobs so that the memory doesn't go ballistic if the server runs for a long time
+    fn clear_old_jobs(&mut self) {
+        let expiry_duration = Duration::from_secs(5 * 60 * 60); // 5 hours
+        while let Some((id, job)) = self.jobs.pop_first() {
+            if job.created.elapsed() < expiry_duration {
+                self.jobs.insert(id, job);
+                break;
+            }
+        }
+    }
+}
+
 pub async fn get_job(
     Path(id): Path<u64>,
     State(app_state): State<AppState>,
     _claims: Claims,
 ) -> impl IntoResponse {
-    let job_store = &app_state.job_store;
-    let jobs = job_store.lock().await;
-    if let Some(job) = jobs.iter().find(|job| job.id == id) {
+    if let Some(job) = app_state.job_store.get_job(id).await {
         let (status, response) = match job.status {
             JobStatus::Pending | JobStatus::Running => (
                 StatusCode::OK,
diff --git a/prover/src/verifier.rs b/prover/src/verifier.rs
index d1315b8..e93765e 100644
--- a/prover/src/verifier.rs
+++ b/prover/src/verifier.rs
@@ -1,9 +1,6 @@
 use crate::{
-    auth::jwt::Claims,
-    errors::ProverError,
-    extractors::workdir::TempDirHandle,
-    server::AppState,
-    utils::job::{create_job, update_job_status, JobStore},
+    auth::jwt::Claims, errors::ProverError, extractors::workdir::TempDirHandle, server::AppState,
+    utils::job::JobStore,
 };
 use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
 use common::models::JobStatus;
@@ -19,14 +16,17 @@ pub async fn root(
     _claims: Claims,
     Json(proof): Json<String>,
 ) -> impl IntoResponse {
-    let job_id = create_job(&app_state.job_store).await;
     let job_store = app_state.job_store.clone();
+    let job_id = job_store.create_job().await;
+
     tokio::spawn({
         async move {
             if let Err(e) =
                 verify_proof(job_id, job_store.clone(), dir, proof, app_state.sse_tx).await
             {
-                update_job_status(job_id, &job_store, JobStatus::Failed, Some(e.to_string())).await;
+                job_store
+                    .update_job_status(job_id, JobStatus::Failed, Some(e.to_string()))
+                    .await;
             }
         }
     });
@@ -44,7 +44,9 @@ pub async fn verify_proof(
     proof: String,
     sender: Arc<Mutex<Sender<String>>>,
 ) -> Result<(), ProverError> {
-    update_job_status(job_id, &job_store, JobStatus::Running, None).await;
+    job_store
+        .update_job_status(job_id, JobStatus::Running, None)
+        .await;

     // Define the path for the proof file
     let path = dir.into_path();
@@ -63,13 +65,13 @@ pub async fn verify_proof(
     std::fs::remove_file(&file)?;

     // Check if the command was successful
-    update_job_status(
-        job_id,
-        &job_store,
-        JobStatus::Completed,
-        Some(status.success().to_string()),
-    )
-    .await;
+    job_store
+        .update_job_status(
+            job_id,
+            JobStatus::Completed,
+            Some(status.success().to_string()),
+        )
+        .await;
     let sender = sender.lock().await;
     if sender.receiver_count() > 0 {
         sender
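For orientation, a minimal sketch of how the refactored `JobStore` is driven end to end (the handler/worker split mirrors prove.rs and verifier.rs above; this exact function and the "proof" result string are illustrative, not part of the diff):

```rust
use crate::utils::job::JobStore;
use common::models::JobStatus;

// Illustrative lifecycle: create a job, advance its status from a worker, read it back.
async fn job_lifecycle_example(job_store: JobStore) {
    // Handler side: allocate an id before queueing work, as cairo.rs / verifier.rs now do.
    let job_id = job_store.create_job().await;

    // Worker side: move through the lifecycle and attach the final result.
    job_store
        .update_job_status(job_id, JobStatus::Running, None)
        .await;
    job_store
        .update_job_status(job_id, JobStatus::Completed, Some("proof".to_string()))
        .await;

    // Reader side (get_job route / sse.rs): unknown or expired ids come back as None.
    assert!(job_store.get_job(job_id).await.is_some());
    assert!(job_store.get_job(job_id + 1).await.is_none());
}
```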