Skip to content

Commit

Permalink
Piniom/1 Small code quality tweaks (#47)
Browse files Browse the repository at this point in the history
* small code quality tweaks

* Piniom/2 Job store  (#48)

* made job store a struct

* Piniom/3 Fix the `JobStore` (#50)

* fixed problems with the `JobStore`

* cargo clippy and fmt
  • Loading branch information
piniom authored Sep 10, 2024
1 parent b2c64c8 commit 6bd6430
Show file tree
Hide file tree
Showing 8 changed files with 222 additions and 181 deletions.
3 changes: 1 addition & 2 deletions prover/src/prove/cairo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::auth::jwt::Claims;
use crate::extractors::workdir::TempDirHandle;
use crate::server::AppState;
use crate::threadpool::CairoVersionedInput;
use crate::utils::job::create_job;
use axum::Json;
use axum::{extract::State, http::StatusCode, response::IntoResponse};
use common::cairo_prover_input::CairoProverInput;
Expand All @@ -16,7 +15,7 @@ pub async fn root(
) -> impl IntoResponse {
let thread_pool = app_state.thread_pool.clone();
let job_store = app_state.job_store.clone();
let job_id = create_job(&job_store).await;
let job_id = job_store.create_job().await;
let thread = thread_pool.lock().await;
thread
.execute(
Expand Down
3 changes: 1 addition & 2 deletions prover/src/prove/cairo0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::auth::jwt::Claims;
use crate::extractors::workdir::TempDirHandle;
use crate::server::AppState;
use crate::threadpool::CairoVersionedInput;
use crate::utils::job::create_job;
use axum::Json;
use axum::{extract::State, http::StatusCode, response::IntoResponse};
use common::cairo0_prover_input::Cairo0ProverInput;
Expand All @@ -16,7 +15,7 @@ pub async fn root(
) -> impl IntoResponse {
let thread_pool = app_state.thread_pool.clone();
let job_store = app_state.job_store.clone();
let job_id = create_job(&job_store).await;
let job_id = job_store.create_job().await;
let thread = thread_pool.lock().await;
thread
.execute(
Expand Down
2 changes: 1 addition & 1 deletion prover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub async fn start(args: Args) -> Result<(), ProverError> {
jwt_secret_key: args.jwt_secret_key,
nonces: Arc::new(Mutex::new(HashMap::new())),
authorizer,
job_store: Arc::new(Mutex::new(Vec::new())),
job_store: JobStore::default(),
thread_pool: Arc::new(Mutex::new(ThreadPool::new(args.num_workers))),
admin_key,
sse_tx: Arc::new(Mutex::new(sse_tx)),
Expand Down
12 changes: 5 additions & 7 deletions prover/src/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ pub async fn sse_handler(
let mut rx = state.sse_tx.lock().await.subscribe();
let job_id = params.job_id;

let job_status = {
let jobs = state.job_store.lock().await;
jobs.iter()
.find(|job| job.id == job_id)
.map(|job| job.status.clone())
.unwrap_or(JobStatus::Unknown)
};
let job_status = state
.job_store
.get_job(job_id)
.await
.map_or(JobStatus::Unknown, |j| j.status);

let stream = stream! {
if matches!(job_status.clone(), JobStatus::Completed | JobStatus::Failed) {
Expand Down
134 changes: 70 additions & 64 deletions prover/src/threadpool/prove.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use super::CairoVersionedInput;
use crate::errors::ProverError;
use crate::utils::{
config::generate,
job::{update_job_status, JobStore},
};
use crate::utils::{config::Template, job::JobStore};
use common::models::JobStatus;
use serde_json::Value;
use starknet_types_core::felt::Felt;
Expand All @@ -24,7 +21,9 @@ pub async fn prove(
program_input: CairoVersionedInput,
sse_tx: Arc<Mutex<Sender<String>>>,
) -> Result<(), ProverError> {
update_job_status(job_id, &job_store, JobStatus::Running, None).await;
job_store
.update_job_status(job_id, JobStatus::Running, None)
.await;
let path = dir.into_path();
let program_input_path: PathBuf = path.join("program_input.json");
let program_path: PathBuf = path.join("program.json");
Expand All @@ -39,17 +38,17 @@ pub async fn prove(
CairoVersionedInput::Cairo(input) => {
let program = serde_json::to_string(&input.program)?;
let layout = input.layout;
let input = prepare_input(input.program_input)?;
let input = prepare_input(&input.program_input);
fs::write(&program_path, &program)?;
fs::write(&program_input_path, &input)?;
cairo_run(
trace_file,
memory_file,
&trace_file,
&memory_file,
layout,
public_input_file.clone(),
private_input_file.clone(),
program_input_path,
program_path,
&public_input_file,
&private_input_file,
&program_input_path,
&program_path,
)
.await?;
}
Expand All @@ -61,19 +60,19 @@ pub async fn prove(
fs::write(&program_path, serde_json::to_string(&input.program)?)?;
let layout = input.layout;
cairo0_run(
trace_file,
memory_file,
&trace_file,
&memory_file,
layout,
public_input_file.clone(),
private_input_file.clone(),
program_input_path,
program_path,
&public_input_file,
&private_input_file,
&program_input_path,
&program_path,
)
.await?;
}
}

generate(public_input_file.clone(), params_file.clone())?;
Template::generate_from_public_input_file(&public_input_file)?.save_to_file(&params_file)?;

let mut command_proof = Command::new("cpu_air_prover");
command_proof
Expand All @@ -97,14 +96,18 @@ pub async fn prove(
let sender = sse_tx.lock().await;

if status_proof.success() {
update_job_status(job_id, &job_store, JobStatus::Completed, Some(final_result)).await;
job_store
.update_job_status(job_id, JobStatus::Completed, Some(final_result))
.await;
if sender.receiver_count() > 0 {
sender
.send(serde_json::to_string(&(JobStatus::Completed, job_id))?)
.unwrap();
}
} else {
update_job_status(job_id, &job_store, JobStatus::Failed, Some(final_result)).await;
job_store
.update_job_status(job_id, JobStatus::Failed, Some(final_result))
.await;
if sender.receiver_count() > 0 {
sender
.send(serde_json::to_string(&(JobStatus::Failed, job_id))?)
Expand All @@ -115,32 +118,32 @@ pub async fn prove(
}

pub async fn cairo0_run(
trace_file: PathBuf,
memory_file: PathBuf,
trace_file: &PathBuf,
memory_file: &PathBuf,
layout: String,
public_input_file: PathBuf,
private_input_file: PathBuf,
program_input_path: PathBuf,
program_path: PathBuf,
public_input_file: &PathBuf,
private_input_file: &PathBuf,
program_input_path: &PathBuf,
program_path: &PathBuf,
) -> Result<(), ProverError> {
trace!("Running cairo0-run");
let mut command = Command::new("cairo-run");
command
.arg("--trace_file")
.arg(&trace_file)
.arg(trace_file)
.arg("--memory_file")
.arg(&memory_file)
.arg(memory_file)
.arg("--layout")
.arg(layout)
.arg("--proof_mode")
.arg("--air_public_input")
.arg(&public_input_file)
.arg(public_input_file)
.arg("--air_private_input")
.arg(&private_input_file)
.arg(private_input_file)
.arg("--program_input")
.arg(&program_input_path)
.arg(program_input_path)
.arg("--program")
.arg(&program_path)
.arg(program_path)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());

Expand All @@ -155,56 +158,59 @@ pub async fn cairo0_run(
Ok(())
}
pub async fn cairo_run(
trace_file: PathBuf,
memory_file: PathBuf,
trace_file: &PathBuf,
memory_file: &PathBuf,
layout: String,
public_input_file: PathBuf,
private_input_file: PathBuf,
program_input_path: PathBuf,
program_path: PathBuf,
public_input_file: &PathBuf,
private_input_file: &PathBuf,
program_input_path: &PathBuf,
program_path: &PathBuf,
) -> Result<(), ProverError> {
let mut command = Command::new("cairo1-run");
command
.arg("--trace_file")
.arg(&trace_file)
.arg(trace_file)
.arg("--memory_file")
.arg(&memory_file)
.arg(memory_file)
.arg("--layout")
.arg(layout)
.arg("--proof_mode")
.arg("--air_public_input")
.arg(&public_input_file)
.arg(public_input_file)
.arg("--air_private_input")
.arg(&private_input_file)
.arg(private_input_file)
.arg("--args_file")
.arg(&program_input_path)
.arg(&program_path)
.arg(program_input_path)
.arg(program_path)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());

let child = command.spawn()?;
let output = child.wait_with_output().await?;

// Capture stderr in case of an error
if !output.status.success() {
if output.status.success() {
Ok(())
} else {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(ProverError::CustomError(stderr.into()));
Err(ProverError::CustomError(stderr.into()))
}
Ok(())
}
pub fn prepare_input(felts: Vec<Felt>) -> Result<String, ProverError> {
if felts.is_empty() {
return Err(ProverError::CustomError(
"Input is empty, input must be a array of felt in format [felt,...,felt]".to_string(),
));
}
let mut input = String::from("[");
for i in 0..felts.len() {
input.push_str(&felts[i].to_string());
if i != felts.len() - 1 {
input.push(' ');
}
}
input.push(']');
Ok(input)

pub fn prepare_input(felts: &[Felt]) -> String {
felts
.iter()
.fold("[".to_string(), |a, i| a + &i.to_string() + " ")
.trim_end()
.to_string()
+ "]"
}

#[test]
fn test_prepare_input() {
assert_eq!("[]", prepare_input(&[]));
assert_eq!("[1]", prepare_input(&[1.into()]));
assert_eq!(
"[1 2 3 4]",
prepare_input(&[1.into(), 2.into(), 3.into(), 4.into()])
);
}
Loading

0 comments on commit 6bd6430

Please sign in to comment.