diff --git a/Cargo.lock b/Cargo.lock index 381109a..c362ae0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -623,8 +623,10 @@ name = "expander_compiler" version = "0.1.0" dependencies = [ "arith", + "ark-std", "chrono", "circuit", + "clap", "config", "ethnum", "gf2", diff --git a/Cargo.toml b/Cargo.toml index a98c5c2..21b1ead 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,8 +10,10 @@ opt-level = 3 [workspace.dependencies] +ark-std = "0.4.0" rand = "0.8.5" chrono = "0.4" +clap = { version = "4.1", features = ["derive"] } ethnum = "1.5.0" tiny-keccak = { version = "2.0", features = ["keccak"] } halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = [ @@ -24,3 +26,4 @@ gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "transcript" } + diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index bfb4b2b..a955bc2 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -5,8 +5,10 @@ edition = "2021" [dependencies] +ark-std.workspace = true rand.workspace = true chrono.workspace = true +clap.workspace = true ethnum.workspace = true halo2curves.workspace = true tiny-keccak.workspace = true @@ -16,3 +18,7 @@ gkr.workspace = true arith.workspace = true gf2.workspace = true mersenne31.workspace = true + +[[bin]] +name = "trivial_circuit" +path = "bin/trivial_circuit.rs" diff --git a/expander_compiler/bin/trivial_circuit.rs b/expander_compiler/bin/trivial_circuit.rs new file mode 100644 index 0000000..07b0bc3 --- /dev/null +++ b/expander_compiler/bin/trivial_circuit.rs @@ -0,0 +1,140 @@ +//! This module generate a trivial GKR layered circuit for test purpose. +//! Arguments: +//! - field: field identifier +//! - n_var: number of variables +//! - n_layer: number of layers + +use ark_std::test_rng; +use clap::Parser; +use expander_compiler::field::Field; +use expander_compiler::frontend::{compile, BN254Config, CompileResult, Define, M31Config}; +use expander_compiler::utils::serde::Serde; +use expander_compiler::{ + declare_circuit, + frontend::{BasicAPI, Config, Variable, API}, +}; + +/// Arguments for the command line +/// - field: field identifier +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Field Identifier: bn254, m31ext3, gf2ext128 + #[arg(short, long,default_value_t = String::from("bn254"))] + field: String, +} + +// this cannot be too big as we currently uses static array of size 2^LOG_NUM_VARS +const LOG_NUM_VARS: usize = 22; +const NUM_LAYERS: usize = 1; + +fn main() { + let args = Args::parse(); + print_info(&args); + + match args.field.as_str() { + "bn254" => build::(), + "m31ext3" => build::(), + _ => panic!("Unsupported field"), + } +} + +fn build() { + let assignment = TrivialCircuit::::random_witnesses(); + + let compile_result = compile::(&TrivialCircuit::new()).unwrap(); + + let CompileResult { + witness_solver, + layered_circuit, + } = compile_result; + + let witness = witness_solver.solve_witness(&assignment).unwrap(); + let res = layered_circuit.run(&witness); + + assert_eq!(res, vec![true]); + + let file = std::fs::File::create(format!("trivial_circuit_{}.txt", LOG_NUM_VARS)).unwrap(); + let writer = std::io::BufWriter::new(file); + layered_circuit.serialize_into(writer).unwrap(); + + let file = std::fs::File::create(format!("trivial_witness_{}.txt", LOG_NUM_VARS)).unwrap(); + let writer = std::io::BufWriter::new(file); + witness.serialize_into(writer).unwrap(); +} + +fn print_info(args: &Args) { + println!("==============================="); + println!("Gen circuit for {} field", args.field); + println!("Log Num of variables: {}", LOG_NUM_VARS); + println!("Number of layers: {}", NUM_LAYERS); + println!("===============================") +} + +declare_circuit!(TrivialCircuit { + input_layer: [Variable], + output_layer: [Variable], +}); + +impl Define for TrivialCircuit { + fn define(&self, builder: &mut API) { + let out = compute_output::(builder, &self.input_layer); + out.iter().zip(self.output_layer.iter()).for_each(|(x, y)| { + builder.assert_is_equal(x, y); + }); + } +} + +fn compute_output(api: &mut API, input_layer: &[Variable]) -> Vec { + let mut cur_layer = input_layer.to_vec(); + + (0..NUM_LAYERS).for_each(|_| { + let mut next_layer = vec![Variable::default(); 1 << LOG_NUM_VARS]; + for i in 0..(1 << (LOG_NUM_VARS - 1)) { + next_layer[i << 1] = api.add(cur_layer[i << 1], cur_layer[(i << 1) + 1]); + next_layer[(i << 1) + 1] = api.mul(cur_layer[i << 1], cur_layer[(i << 1) + 1]); + } + cur_layer = next_layer; + }); + cur_layer +} + +impl TrivialCircuit { + fn new() -> Self { + let input_layer = (0..1 << LOG_NUM_VARS) + .map(|_| T::default()) + .collect::>(); + let output_layer = (0..1 << LOG_NUM_VARS) + .map(|_| T::default()) + .collect::>(); + + Self { + input_layer, + output_layer, + } + } +} + +impl TrivialCircuit { + fn random_witnesses() -> Self { + let mut rng = test_rng(); + + let input_layer = (0..1 << LOG_NUM_VARS) + .map(|_| T::random_unsafe(&mut rng)) + .collect::>(); + + let mut cur_layer = input_layer.clone(); + (0..NUM_LAYERS).for_each(|_| { + let mut next_layer = vec![T::default(); 1 << LOG_NUM_VARS]; + for i in 0..1 << (LOG_NUM_VARS - 1) { + next_layer[i << 1] = cur_layer[i << 1] + cur_layer[(i << 1) + 1]; + next_layer[(i << 1) + 1] = cur_layer[i << 1] * cur_layer[(i << 1) + 1]; + } + cur_layer = next_layer; + }); + Self { + input_layer, + output_layer: cur_layer, + } + } +}