Skip to content

Commit

Permalink
update trivial circuit code
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenfeizhang committed Oct 11, 2024
1 parent 5625484 commit 0520972
Showing 1 changed file with 33 additions and 20 deletions.
53 changes: 33 additions & 20 deletions expander_compiler/bin/trivial_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct Args {
}

// this cannot be too big as we currently uses static array of size 2^LOG_NUM_VARS
const LOG_NUM_VARS: usize = 15;
const LOG_NUM_VARS: usize = 22;
const NUM_LAYERS: usize = 1;

fn main() {
Expand All @@ -42,7 +42,7 @@ fn main() {
fn build<C: Config>() {
let assignment = TrivialCircuit::<C::CircuitField>::random_witnesses();

let compile_result = compile::<C, _>(&TrivialCircuit::default()).unwrap();
let compile_result = compile::<C, _>(&TrivialCircuit::new()).unwrap();

let CompileResult {
witness_solver,
Expand All @@ -54,11 +54,11 @@ fn build<C: Config>() {

assert_eq!(res, vec![true]);

let file = std::fs::File::create("circuit.txt").unwrap();
let file = std::fs::File::create(format!("trivial_circuit_{}.txt", LOG_NUM_VARS)).unwrap();
let writer = std::io::BufWriter::new(file);
layered_circuit.serialize_into(writer).unwrap();

let file = std::fs::File::create("witness.txt").unwrap();
let file = std::fs::File::create(format!("trivial_witness_{}.txt", LOG_NUM_VARS)).unwrap();
let writer = std::io::BufWriter::new(file);
witness.serialize_into(writer).unwrap();
}
Expand All @@ -72,8 +72,8 @@ fn print_info(args: &Args) {
}

declare_circuit!(TrivialCircuit {
input_layer: [Variable; 1 << LOG_NUM_VARS],
output_layer: [Variable; 1 << LOG_NUM_VARS],
input_layer: [Variable],
output_layer: [Variable],
});

impl<C: Config> Define<C> for TrivialCircuit<Variable> {
Expand All @@ -85,13 +85,11 @@ impl<C: Config> Define<C> for TrivialCircuit<Variable> {
}
}

fn compute_output<C: Config>(
api: &mut API<C>,
input_layer: &[Variable; 1 << LOG_NUM_VARS],
) -> [Variable; 1 << LOG_NUM_VARS] {
let mut cur_layer = *input_layer;
(1..NUM_LAYERS).for_each(|_| {
let mut next_layer = [Variable::default(); 1 << LOG_NUM_VARS];
fn compute_output<C: Config>(api: &mut API<C>, input_layer: &[Variable]) -> Vec<Variable> {
let mut cur_layer = input_layer.to_vec();

(0..NUM_LAYERS).for_each(|_| {
let mut next_layer = vec![Variable::default(); 1 << LOG_NUM_VARS];
for i in 0..(1 << (LOG_NUM_VARS - 1)) {
next_layer[i << 1] = api.add(cur_layer[i << 1], cur_layer[(i << 1) + 1]);
next_layer[(i << 1) + 1] = api.mul(cur_layer[i << 1], cur_layer[(i << 1) + 1]);
Expand All @@ -101,18 +99,33 @@ fn compute_output<C: Config>(
cur_layer
}

impl<T: Default> TrivialCircuit<T> {
fn new() -> Self {
let input_layer = (0..1 << LOG_NUM_VARS)
.map(|_| T::default())
.collect::<Vec<_>>();
let output_layer = (0..1 << LOG_NUM_VARS)
.map(|_| T::default())
.collect::<Vec<_>>();

Self {
input_layer,
output_layer,
}
}
}

impl<T: Field> TrivialCircuit<T> {
fn random_witnesses() -> Self {
let mut rng = test_rng();

let mut input_layer = [T::default(); 1 << LOG_NUM_VARS];
input_layer
.iter_mut()
.for_each(|x| *x = T::random_unsafe(&mut rng));
let input_layer = (0..1 << LOG_NUM_VARS)
.map(|_| T::random_unsafe(&mut rng))
.collect::<Vec<_>>();

let mut cur_layer = input_layer;
(1..NUM_LAYERS).for_each(|_| {
let mut next_layer = [T::default(); 1 << LOG_NUM_VARS];
let mut cur_layer = input_layer.clone();
(0..NUM_LAYERS).for_each(|_| {
let mut next_layer = vec![T::default(); 1 << LOG_NUM_VARS];
for i in 0..1 << (LOG_NUM_VARS - 1) {
next_layer[i << 1] = cur_layer[i << 1] + cur_layer[(i << 1) + 1];
next_layer[(i << 1) + 1] = cur_layer[i << 1] * cur_layer[(i << 1) + 1];
Expand Down

0 comments on commit 0520972

Please sign in to comment.