Skip to content

Commit

Permalink
Fix/cario run errors (#45)
Browse files Browse the repository at this point in the history
* error handling in config.rs

* fmt
  • Loading branch information
chudkowsky authored Sep 9, 2024
1 parent fd6115d commit 62752e6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
30 changes: 23 additions & 7 deletions prover/src/threadpool/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub async fn prove(
}
}

generate(public_input_file.clone(), params_file.clone());
generate(public_input_file.clone(), params_file.clone())?;

let mut command_proof = Command::new("cpu_air_prover");
command_proof
Expand Down Expand Up @@ -140,10 +140,18 @@ pub async fn cairo0_run(
.arg("--program_input")
.arg(&program_input_path)
.arg("--program")
.arg(&program_path);
.arg(&program_path)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());

let mut child = command.spawn()?;
let _status = child.wait().await?;
let child = command.spawn()?;
let output = child.wait_with_output().await?;

// Capture stderr in case of an error
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(ProverError::CustomError(stderr.into()));
}
Ok(())
}
pub async fn cairo_run(
Expand All @@ -170,10 +178,18 @@ pub async fn cairo_run(
.arg(&private_input_file)
.arg("--args_file")
.arg(&program_input_path)
.arg(&program_path);
.arg(&program_path)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());

let mut child = command.spawn()?;
let _status = child.wait().await?;
let child = command.spawn()?;
let output = child.wait_with_output().await?;

// Capture stderr in case of an error
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(ProverError::CustomError(stderr.into()));
}
Ok(())
}
pub fn prepare_input(felts: Vec<Felt>) -> Result<String, ProverError> {
Expand Down
44 changes: 20 additions & 24 deletions prover/src/utils/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use serde_json::Value;
use std::fs::File;
use std::io::{Read, Write};
use std::path::PathBuf;
use std::process;

use crate::errors::ProverError;

#[derive(Serialize, Deserialize, Debug)]
struct StarkFri {
Expand Down Expand Up @@ -47,39 +48,33 @@ fn update_template_and_save_to_file(
template: &mut Template,
fri_step_list: Vec<u32>,
file_path: PathBuf,
) -> Result<(), String> {
) -> Result<(), ProverError> {
template.stark.fri.fri_step_list = fri_step_list;
let mut file: File = File::create(file_path).map_err(|e| e.to_string())?;
let json_string = serde_json::to_string_pretty(template).expect("Failed to serialize JSON");
file.write_all(json_string.as_bytes())
.map_err(|e| e.to_string())
let mut file: File = File::create(file_path)?;
let json_string = serde_json::to_string_pretty(template)?;
file.write_all(json_string.as_bytes())?;
Ok(())
}

fn read_json_from_file(file_path: PathBuf) -> Result<Value, String> {
fn read_json_from_file(file_path: PathBuf) -> Result<Value, ProverError> {
let mut buffer = String::new();
let mut file = File::open(file_path).map_err(|e| e.to_string())?;
file.read_to_string(&mut buffer)
.map_err(|e| e.to_string())?;
serde_json::from_str(&buffer).map_err(|e| e.to_string())
let mut file = File::open(file_path)?;
file.read_to_string(&mut buffer)?;
let result = serde_json::from_str(&buffer)?;
Ok(result)
}

pub fn generate(input_file: PathBuf, output_file: PathBuf) {
let program_public_input: Value = match read_json_from_file(input_file) {
Ok(data) => data,
Err(err) => {
eprintln!("Error: Invalid JSON input. {}", err);
process::exit(1);
}
};
pub fn generate(input_file: PathBuf, output_file: PathBuf) -> Result<(), ProverError> {
let program_public_input: Value = read_json_from_file(input_file)?;

let n_steps = match program_public_input["n_steps"].as_u64() {
let n_steps: u32 = match program_public_input["n_steps"].as_u64() {
Some(val) => val as u32,
None => {
eprintln!("Error: 'n_steps' is missing or not an integer.");
process::exit(1);
return Err(ProverError::CustomError(
"Failed to get n_steps from cairo run execution".to_string(),
))
}
};

let mut template = Template {
field: "PrimeField0".to_string(),
channel_hash: "poseidon3".to_string(),
Expand All @@ -104,5 +99,6 @@ pub fn generate(input_file: PathBuf, output_file: PathBuf) {
let last_layer_degree_bound = template.stark.fri.last_layer_degree_bound;

let fri_step_list = calculate_fri_step_list(n_steps, last_layer_degree_bound);
let _ = update_template_and_save_to_file(&mut template, fri_step_list, output_file);
update_template_and_save_to_file(&mut template, fri_step_list, output_file)?;
Ok(())
}

0 comments on commit 62752e6

Please sign in to comment.