Skip to content

Commit

Permalink
Merge pull request #33 from PolyhedraZK/zz/trivial_layers
Browse files Browse the repository at this point in the history
add a trivial circuit biulder
  • Loading branch information
siq1 authored Oct 12, 2024
2 parents 488d5c2 + a1a55b7 commit 13fde59
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ opt-level = 3


[workspace.dependencies]
ark-std = "0.4.0"
rand = "0.8.5"
chrono = "0.4"
clap = { version = "4.1", features = ["derive"] }
ethnum = "1.5.0"
tiny-keccak = { version = "2.0", features = ["keccak"] }
halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = [
Expand All @@ -24,3 +26,4 @@ gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" }
gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" }
mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" }
expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "transcript" }

6 changes: 6 additions & 0 deletions expander_compiler/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ edition = "2021"


[dependencies]
ark-std.workspace = true
rand.workspace = true
chrono.workspace = true
clap.workspace = true
ethnum.workspace = true
halo2curves.workspace = true
tiny-keccak.workspace = true
Expand All @@ -16,3 +18,7 @@ gkr.workspace = true
arith.workspace = true
gf2.workspace = true
mersenne31.workspace = true

[[bin]]
name = "trivial_circuit"
path = "bin/trivial_circuit.rs"
140 changes: 140 additions & 0 deletions expander_compiler/bin/trivial_circuit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
//! This module generate a trivial GKR layered circuit for test purpose.
//! Arguments:
//! - field: field identifier
//! - n_var: number of variables
//! - n_layer: number of layers
use ark_std::test_rng;
use clap::Parser;
use expander_compiler::field::Field;
use expander_compiler::frontend::{compile, BN254Config, CompileResult, Define, M31Config};
use expander_compiler::utils::serde::Serde;
use expander_compiler::{
declare_circuit,
frontend::{BasicAPI, Config, Variable, API},
};

/// Arguments for the command line
/// - field: field identifier
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Field Identifier: bn254, m31ext3, gf2ext128
#[arg(short, long,default_value_t = String::from("bn254"))]
field: String,
}

// this cannot be too big as we currently uses static array of size 2^LOG_NUM_VARS
const LOG_NUM_VARS: usize = 22;
const NUM_LAYERS: usize = 1;

fn main() {
let args = Args::parse();
print_info(&args);

match args.field.as_str() {
"bn254" => build::<BN254Config>(),
"m31ext3" => build::<M31Config>(),
_ => panic!("Unsupported field"),
}
}

fn build<C: Config>() {
let assignment = TrivialCircuit::<C::CircuitField>::random_witnesses();

let compile_result = compile::<C, _>(&TrivialCircuit::new()).unwrap();

let CompileResult {
witness_solver,
layered_circuit,
} = compile_result;

let witness = witness_solver.solve_witness(&assignment).unwrap();
let res = layered_circuit.run(&witness);

assert_eq!(res, vec![true]);

let file = std::fs::File::create(format!("trivial_circuit_{}.txt", LOG_NUM_VARS)).unwrap();
let writer = std::io::BufWriter::new(file);
layered_circuit.serialize_into(writer).unwrap();

let file = std::fs::File::create(format!("trivial_witness_{}.txt", LOG_NUM_VARS)).unwrap();
let writer = std::io::BufWriter::new(file);
witness.serialize_into(writer).unwrap();
}

fn print_info(args: &Args) {
println!("===============================");
println!("Gen circuit for {} field", args.field);
println!("Log Num of variables: {}", LOG_NUM_VARS);
println!("Number of layers: {}", NUM_LAYERS);
println!("===============================")
}

declare_circuit!(TrivialCircuit {
input_layer: [Variable],
output_layer: [Variable],
});

impl<C: Config> Define<C> for TrivialCircuit<Variable> {
fn define(&self, builder: &mut API<C>) {
let out = compute_output::<C>(builder, &self.input_layer);
out.iter().zip(self.output_layer.iter()).for_each(|(x, y)| {
builder.assert_is_equal(x, y);
});
}
}

fn compute_output<C: Config>(api: &mut API<C>, input_layer: &[Variable]) -> Vec<Variable> {
let mut cur_layer = input_layer.to_vec();

(0..NUM_LAYERS).for_each(|_| {
let mut next_layer = vec![Variable::default(); 1 << LOG_NUM_VARS];
for i in 0..(1 << (LOG_NUM_VARS - 1)) {
next_layer[i << 1] = api.add(cur_layer[i << 1], cur_layer[(i << 1) + 1]);
next_layer[(i << 1) + 1] = api.mul(cur_layer[i << 1], cur_layer[(i << 1) + 1]);
}
cur_layer = next_layer;
});
cur_layer
}

impl<T: Default> TrivialCircuit<T> {
fn new() -> Self {
let input_layer = (0..1 << LOG_NUM_VARS)
.map(|_| T::default())
.collect::<Vec<_>>();
let output_layer = (0..1 << LOG_NUM_VARS)
.map(|_| T::default())
.collect::<Vec<_>>();

Self {
input_layer,
output_layer,
}
}
}

impl<T: Field> TrivialCircuit<T> {
fn random_witnesses() -> Self {
let mut rng = test_rng();

let input_layer = (0..1 << LOG_NUM_VARS)
.map(|_| T::random_unsafe(&mut rng))
.collect::<Vec<_>>();

let mut cur_layer = input_layer.clone();
(0..NUM_LAYERS).for_each(|_| {
let mut next_layer = vec![T::default(); 1 << LOG_NUM_VARS];
for i in 0..1 << (LOG_NUM_VARS - 1) {
next_layer[i << 1] = cur_layer[i << 1] + cur_layer[(i << 1) + 1];
next_layer[(i << 1) + 1] = cur_layer[i << 1] * cur_layer[(i << 1) + 1];
}
cur_layer = next_layer;
});
Self {
input_layer,
output_layer: cur_layer,
}
}
}

0 comments on commit 13fde59

Please sign in to comment.