From 62752e6d7630aec33768a63afdbaedd972cbf0b0 Mon Sep 17 00:00:00 2001 From: Mateusz Chudkowski <120587768+chudkowsky@users.noreply.github.com> Date: Mon, 9 Sep 2024 22:32:15 +0200 Subject: [PATCH] Fix/cario run errors (#45) * error handling in config.rs * fmt --- prover/src/threadpool/prove.rs | 30 +++++++++++++++++------ prover/src/utils/config.rs | 44 ++++++++++++++++------------------ 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/prover/src/threadpool/prove.rs b/prover/src/threadpool/prove.rs index 21516da..7858e88 100644 --- a/prover/src/threadpool/prove.rs +++ b/prover/src/threadpool/prove.rs @@ -73,7 +73,7 @@ pub async fn prove( } } - generate(public_input_file.clone(), params_file.clone()); + generate(public_input_file.clone(), params_file.clone())?; let mut command_proof = Command::new("cpu_air_prover"); command_proof @@ -140,10 +140,18 @@ pub async fn cairo0_run( .arg("--program_input") .arg(&program_input_path) .arg("--program") - .arg(&program_path); + .arg(&program_path) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()); - let mut child = command.spawn()?; - let _status = child.wait().await?; + let child = command.spawn()?; + let output = child.wait_with_output().await?; + + // Capture stderr in case of an error + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(ProverError::CustomError(stderr.into())); + } Ok(()) } pub async fn cairo_run( @@ -170,10 +178,18 @@ pub async fn cairo_run( .arg(&private_input_file) .arg("--args_file") .arg(&program_input_path) - .arg(&program_path); + .arg(&program_path) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()); - let mut child = command.spawn()?; - let _status = child.wait().await?; + let child = command.spawn()?; + let output = child.wait_with_output().await?; + + // Capture stderr in case of an error + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(ProverError::CustomError(stderr.into())); + } Ok(()) } pub fn prepare_input(felts: Vec) -> Result { diff --git a/prover/src/utils/config.rs b/prover/src/utils/config.rs index 28359b3..85f5970 100644 --- a/prover/src/utils/config.rs +++ b/prover/src/utils/config.rs @@ -3,7 +3,8 @@ use serde_json::Value; use std::fs::File; use std::io::{Read, Write}; use std::path::PathBuf; -use std::process; + +use crate::errors::ProverError; #[derive(Serialize, Deserialize, Debug)] struct StarkFri { @@ -47,39 +48,33 @@ fn update_template_and_save_to_file( template: &mut Template, fri_step_list: Vec, file_path: PathBuf, -) -> Result<(), String> { +) -> Result<(), ProverError> { template.stark.fri.fri_step_list = fri_step_list; - let mut file: File = File::create(file_path).map_err(|e| e.to_string())?; - let json_string = serde_json::to_string_pretty(template).expect("Failed to serialize JSON"); - file.write_all(json_string.as_bytes()) - .map_err(|e| e.to_string()) + let mut file: File = File::create(file_path)?; + let json_string = serde_json::to_string_pretty(template)?; + file.write_all(json_string.as_bytes())?; + Ok(()) } -fn read_json_from_file(file_path: PathBuf) -> Result { +fn read_json_from_file(file_path: PathBuf) -> Result { let mut buffer = String::new(); - let mut file = File::open(file_path).map_err(|e| e.to_string())?; - file.read_to_string(&mut buffer) - .map_err(|e| e.to_string())?; - serde_json::from_str(&buffer).map_err(|e| e.to_string()) + let mut file = File::open(file_path)?; + file.read_to_string(&mut buffer)?; + let result = serde_json::from_str(&buffer)?; + Ok(result) } -pub fn generate(input_file: PathBuf, output_file: PathBuf) { - let program_public_input: Value = match read_json_from_file(input_file) { - Ok(data) => data, - Err(err) => { - eprintln!("Error: Invalid JSON input. {}", err); - process::exit(1); - } - }; +pub fn generate(input_file: PathBuf, output_file: PathBuf) -> Result<(), ProverError> { + let program_public_input: Value = read_json_from_file(input_file)?; - let n_steps = match program_public_input["n_steps"].as_u64() { + let n_steps: u32 = match program_public_input["n_steps"].as_u64() { Some(val) => val as u32, None => { - eprintln!("Error: 'n_steps' is missing or not an integer."); - process::exit(1); + return Err(ProverError::CustomError( + "Failed to get n_steps from cairo run execution".to_string(), + )) } }; - let mut template = Template { field: "PrimeField0".to_string(), channel_hash: "poseidon3".to_string(), @@ -104,5 +99,6 @@ pub fn generate(input_file: PathBuf, output_file: PathBuf) { let last_layer_degree_bound = template.stark.fri.last_layer_degree_bound; let fri_step_list = calculate_fri_step_list(n_steps, last_layer_degree_bound); - let _ = update_template_and_save_to_file(&mut template, fri_step_list, output_file); + update_template_and_save_to_file(&mut template, fri_step_list, output_file)?; + Ok(()) }