Skip to content

Commit

Permalink
refactor logup to std (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiyong1997 authored and zhenfeizhang committed Nov 11, 2024
1 parent 893b17b commit fcf74e4
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 83 deletions.
15 changes: 15 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[workspace]
resolver = "2"
members = ["expander_compiler", "expander_compiler/ec_go_lib"]
members = [ "circuit-std-rs","expander_compiler", "expander_compiler/ec_go_lib"]

[profile.test]
opt-level = 3
Expand Down
17 changes: 17 additions & 0 deletions circuit-std-rs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "circuit-std-rs"
version = "0.1.0"
edition = "2021"


[dependencies]
expander_compiler = { path = "../expander_compiler"}

ark-std.workspace = true
rand.workspace = true
expander_config.workspace = true
expander_circuit.workspace = true
gkr.workspace = true
arith.workspace = true
gf2.workspace = true
mersenne31.workspace = true
5 changes: 5 additions & 0 deletions circuit-std-rs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pub mod traits;
pub use traits::StdCircuit;

pub mod logup;
pub use logup::{LogUpCircuit, LogUpParams};
167 changes: 85 additions & 82 deletions expander_compiler/tests/logup.rs → circuit-std-rs/src/logup.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
use arith::Field;
use expander_compiler::frontend::*;
use extra::Serde;
use rand::{thread_rng, Rng};
use rand::Rng;

const KEY_LEN: usize = 3;
const N_TABLE_ROWS: usize = 17;
const N_COLUMNS: usize = 5;
const N_QUERIES: usize = 33;
use crate::StdCircuit;

declare_circuit!(Circuit {
table_keys: [[Variable; KEY_LEN]; N_TABLE_ROWS],
table_values: [[Variable; N_COLUMNS]; N_TABLE_ROWS],
#[derive(Clone, Copy, Debug)]
pub struct LogUpParams {
pub key_len: usize,
pub value_len: usize,
pub n_table_rows: usize,
pub n_queries: usize,
}

declare_circuit!(_LogUpCircuit {
table_keys: [[Variable]],
table_values: [[Variable]],

query_keys: [[Variable; KEY_LEN]; N_QUERIES],
query_results: [[Variable; N_COLUMNS]; N_QUERIES],
query_keys: [[Variable]],
query_results: [[Variable]],

// counting the number of occurences for each row of the table
query_count: [Variable; N_TABLE_ROWS],
query_count: [Variable],
});

#[derive(Clone, Copy)]
pub type LogUpCircuit = _LogUpCircuit<Variable>;

#[derive(Clone, Copy, Debug)]
struct Rational {
numerator: Variable,
denominator: Variable,
Expand Down Expand Up @@ -77,7 +83,7 @@ fn sum_rational_vec<C: Config>(builder: &mut API<C>, vs: &[Rational]) -> Rationa
vvs[0]
}

// TODO: Add poly randomness
// TODO-Feature: poly randomness
fn get_column_randomness<C: Config>(builder: &mut API<C>, n_columns: usize) -> Vec<Variable> {
let mut randomness = vec![];
randomness.push(builder.constant(1));
Expand All @@ -87,9 +93,16 @@ fn get_column_randomness<C: Config>(builder: &mut API<C>, n_columns: usize) -> V
randomness
}

fn concat_d1(v1: &[Vec<Variable>], v2: &[Vec<Variable>]) -> Vec<Vec<Variable>> {
v1.iter()
.zip(v2.iter())
.map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat())
.collect()
}

fn combine_columns<C: Config>(
builder: &mut API<C>,
vec_2d: &Vec<Vec<Variable>>,
vec_2d: &[Vec<Variable>],
randomness: &[Variable],
) -> Vec<Variable> {
if vec_2d.is_empty() {
Expand Down Expand Up @@ -128,31 +141,24 @@ fn logup_poly_val<C: Config>(
sum_rational_vec(builder, &poly_terms)
}

impl<C: Config> Define<C> for Circuit<Variable> {
impl<C: Config> Define<C> for LogUpCircuit {
fn define(&self, builder: &mut API<C>) {
let key_len = self.table_keys[0].len();
let value_len = self.table_values[0].len();

let alpha = builder.get_random_value();
let randomness = get_column_randomness(builder, KEY_LEN + N_COLUMNS);
let randomness = get_column_randomness(builder, key_len + value_len);

let table_combined = combine_columns(
builder,
&self
.table_keys
.iter()
.zip(self.table_values)
.map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat())
.collect(),
&concat_d1(&self.table_keys, &self.table_values),
&randomness,
);
let v_table = logup_poly_val(builder, &table_combined, &self.query_count, &alpha);

let query_combined = combine_columns(
builder,
&self
.query_keys
.iter()
.zip(self.query_results)
.map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat())
.collect(),
&concat_d1(&self.query_keys, &self.query_results),
&randomness,
);
let one = builder.constant(1);
Expand All @@ -167,63 +173,60 @@ impl<C: Config> Define<C> for Circuit<Variable> {
}
}

#[inline]
fn gen_assignment<C: Config>() -> Circuit<C::CircuitField> {
let mut circuit = Circuit::<C::CircuitField>::default();
let mut rng = thread_rng();
for i in 0..N_TABLE_ROWS {
for j in 0..KEY_LEN {
circuit.table_keys[i][j] = C::CircuitField::random_unsafe(&mut rng);
}
impl<C: Config> StdCircuit<C> for LogUpCircuit {
type Params = LogUpParams;
type Assignment = _LogUpCircuit<C::CircuitField>;

for j in 0..N_COLUMNS {
circuit.table_values[i][j] = C::CircuitField::random_unsafe(&mut rng);
}
}
fn new_circuit(params: &Self::Params) -> Self {
let mut circuit = Self::default();

circuit.query_count = [C::CircuitField::ZERO; N_TABLE_ROWS];
for i in 0..N_QUERIES {
let query_id: usize = rng.gen::<usize>() % N_TABLE_ROWS;
circuit.query_count[query_id] += C::CircuitField::ONE;
circuit.query_keys[i] = circuit.table_keys[query_id];
circuit.query_results[i] = circuit.table_values[query_id];
circuit.table_keys.resize(
params.n_table_rows,
vec![Variable::default(); params.key_len],
);
circuit.table_values.resize(
params.n_table_rows,
vec![Variable::default(); params.value_len],
);
circuit
.query_keys
.resize(params.n_queries, vec![Variable::default(); params.key_len]);
circuit.query_results.resize(
params.n_queries,
vec![Variable::default(); params.value_len],
);
circuit
.query_count
.resize(params.n_table_rows, Variable::default());

circuit
}

circuit
}
fn new_assignment(params: &Self::Params, mut rng: impl rand::RngCore) -> Self::Assignment {
let mut assignment = _LogUpCircuit::<C::CircuitField>::default();
assignment.table_keys.resize(params.n_table_rows, vec![]);
assignment.table_values.resize(params.n_table_rows, vec![]);
assignment.query_keys.resize(params.n_queries, vec![]);
assignment.query_results.resize(params.n_queries, vec![]);

for i in 0..params.n_table_rows {
for _ in 0..params.key_len {
assignment.table_keys[i].push(C::CircuitField::random_unsafe(&mut rng));
}

for _ in 0..params.value_len {
assignment.table_values[i].push(C::CircuitField::random_unsafe(&mut rng));
}
}

fn logup_test_helper<C: Config>() {
let compile_result: CompileResult<C> = compile(&Circuit::default()).unwrap();
let assignment = gen_assignment::<C>();
let witness = compile_result
.witness_solver
.solve_witness(&assignment)
.unwrap();
let output = compile_result.layered_circuit.run(&witness);
assert_eq!(output, vec![true]);

let file = std::fs::File::create("circuit.txt").unwrap();
let writer = std::io::BufWriter::new(file);
compile_result
.layered_circuit
.serialize_into(writer)
.unwrap();

let file = std::fs::File::create("witness.txt").unwrap();
let writer = std::io::BufWriter::new(file);
witness.serialize_into(writer).unwrap();

let file = std::fs::File::create("witness_solver.txt").unwrap();
let writer = std::io::BufWriter::new(file);
compile_result
.witness_solver
.serialize_into(writer)
.unwrap();
}
assignment.query_count = vec![C::CircuitField::ZERO; params.n_table_rows];
for i in 0..params.n_queries {
let query_id: usize = rng.gen::<usize>() % params.n_table_rows;
assignment.query_count[query_id] += C::CircuitField::ONE;
assignment.query_keys[i] = assignment.table_keys[query_id].clone();
assignment.query_results[i] = assignment.table_values[query_id].clone();
}

#[test]
fn logup_test() {
logup_test_helper::<GF2Config>();
logup_test_helper::<M31Config>();
logup_test_helper::<BN254Config>();
assignment
}
}
14 changes: 14 additions & 0 deletions circuit-std-rs/src/traits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use std::fmt::Debug;

use expander_compiler::frontend::{internal::DumpLoadTwoVariables, Config, Define, Variable};
use rand::RngCore;

// All std circuits must implement the following trait
pub trait StdCircuit<C: Config>: Clone + Define<C> + DumpLoadTwoVariables<Variable> {
type Params: Clone + Debug;
type Assignment: Clone + DumpLoadTwoVariables<C::CircuitField>;

fn new_circuit(params: &Self::Params) -> Self;

fn new_assignment(params: &Self::Params, rng: impl RngCore) -> Self::Assignment;
}
38 changes: 38 additions & 0 deletions circuit-std-rs/tests/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use circuit_std_rs::StdCircuit;
use expander_compiler::frontend::*;
use extra::Serde;
use rand::thread_rng;

pub fn circuit_test_helper<Cfg, Cir>(params: &Cir::Params)
where
Cfg: Config,
Cir: StdCircuit<Cfg>,
{
let mut rng = thread_rng();
let compile_result: CompileResult<Cfg> = compile(&Cir::new_circuit(&params)).unwrap();
let assignment = Cir::new_assignment(&params, &mut rng);
let witness = compile_result
.witness_solver
.solve_witness(&assignment)
.unwrap();
let output = compile_result.layered_circuit.run(&witness);
assert_eq!(output, vec![true]);

let file = std::fs::File::create("circuit.txt").unwrap();
let writer = std::io::BufWriter::new(file);
compile_result
.layered_circuit
.serialize_into(writer)
.unwrap();

let file = std::fs::File::create("witness.txt").unwrap();
let writer = std::io::BufWriter::new(file);
witness.serialize_into(writer).unwrap();

let file = std::fs::File::create("witness_solver.txt").unwrap();
let writer = std::io::BufWriter::new(file);
compile_result
.witness_solver
.serialize_into(writer)
.unwrap();
}
18 changes: 18 additions & 0 deletions circuit-std-rs/tests/logup.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
mod common;

use circuit_std_rs::{LogUpCircuit, LogUpParams};
use expander_compiler::frontend::*;

#[test]
fn logup_test() {
let logup_params = LogUpParams {
key_len: 7,
value_len: 7,
n_table_rows: 123,
n_queries: 456,
};

common::circuit_test_helper::<BN254Config, LogUpCircuit>(&logup_params);
common::circuit_test_helper::<M31Config, LogUpCircuit>(&logup_params);
common::circuit_test_helper::<GF2Config, LogUpCircuit>(&logup_params);
}

0 comments on commit fcf74e4

Please sign in to comment.