Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor logup to std #41

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[workspace]
resolver = "2"
members = ["expander_compiler", "expander_compiler/ec_go_lib"]
members = [ "circuit-std-rs","expander_compiler", "expander_compiler/ec_go_lib"]

[profile.test]
opt-level = 3
Expand Down
17 changes: 17 additions & 0 deletions circuit-std-rs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "circuit-std-rs"
version = "0.1.0"
edition = "2021"


[dependencies]
expander_compiler = { path = "../expander_compiler"}

ark-std.workspace = true
rand.workspace = true
expander_config.workspace = true
expander_circuit.workspace = true
gkr.workspace = true
arith.workspace = true
gf2.workspace = true
mersenne31.workspace = true
5 changes: 5 additions & 0 deletions circuit-std-rs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pub mod traits;
pub use traits::StdCircuit;

pub mod logup;
pub use logup::{LogUpCircuit, LogUpParams};
167 changes: 85 additions & 82 deletions expander_compiler/tests/logup.rs → circuit-std-rs/src/logup.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
use arith::Field;
use expander_compiler::frontend::*;
use extra::Serde;
use rand::{thread_rng, Rng};
use rand::Rng;

const KEY_LEN: usize = 3;
const N_TABLE_ROWS: usize = 17;
const N_COLUMNS: usize = 5;
const N_QUERIES: usize = 33;
use crate::StdCircuit;

declare_circuit!(Circuit {
table_keys: [[Variable; KEY_LEN]; N_TABLE_ROWS],
table_values: [[Variable; N_COLUMNS]; N_TABLE_ROWS],
#[derive(Clone, Copy, Debug)]
pub struct LogUpParams {
pub key_len: usize,
pub value_len: usize,
pub n_table_rows: usize,
pub n_queries: usize,
}

declare_circuit!(_LogUpCircuit {
table_keys: [[Variable]],
table_values: [[Variable]],

query_keys: [[Variable; KEY_LEN]; N_QUERIES],
query_results: [[Variable; N_COLUMNS]; N_QUERIES],
query_keys: [[Variable]],
query_results: [[Variable]],

// counting the number of occurences for each row of the table
query_count: [Variable; N_TABLE_ROWS],
query_count: [Variable],
});

#[derive(Clone, Copy)]
pub type LogUpCircuit = _LogUpCircuit<Variable>;

#[derive(Clone, Copy, Debug)]
struct Rational {
numerator: Variable,
denominator: Variable,
Expand Down Expand Up @@ -77,7 +83,7 @@ fn sum_rational_vec<C: Config>(builder: &mut API<C>, vs: &[Rational]) -> Rationa
vvs[0]
}

// TODO: Add poly randomness
// TODO-Feature: poly randomness
fn get_column_randomness<C: Config>(builder: &mut API<C>, n_columns: usize) -> Vec<Variable> {
let mut randomness = vec![];
randomness.push(builder.constant(1));
Expand All @@ -87,9 +93,16 @@ fn get_column_randomness<C: Config>(builder: &mut API<C>, n_columns: usize) -> V
randomness
}

fn concat_d1(v1: &[Vec<Variable>], v2: &[Vec<Variable>]) -> Vec<Vec<Variable>> {
v1.iter()
.zip(v2.iter())
.map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat())
.collect()
}

fn combine_columns<C: Config>(
builder: &mut API<C>,
vec_2d: &Vec<Vec<Variable>>,
vec_2d: &[Vec<Variable>],
randomness: &[Variable],
) -> Vec<Variable> {
if vec_2d.is_empty() {
Expand Down Expand Up @@ -128,31 +141,24 @@ fn logup_poly_val<C: Config>(
sum_rational_vec(builder, &poly_terms)
}

impl<C: Config> Define<C> for Circuit<Variable> {
impl<C: Config> Define<C> for LogUpCircuit {
fn define(&self, builder: &mut API<C>) {
let key_len = self.table_keys[0].len();
let value_len = self.table_values[0].len();

let alpha = builder.get_random_value();
let randomness = get_column_randomness(builder, KEY_LEN + N_COLUMNS);
let randomness = get_column_randomness(builder, key_len + value_len);

let table_combined = combine_columns(
builder,
&self
.table_keys
.iter()
.zip(self.table_values)
.map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat())
.collect(),
&concat_d1(&self.table_keys, &self.table_values),
&randomness,
);
let v_table = logup_poly_val(builder, &table_combined, &self.query_count, &alpha);

let query_combined = combine_columns(
builder,
&self
.query_keys
.iter()
.zip(self.query_results)
.map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat())
.collect(),
&concat_d1(&self.query_keys, &self.query_results),
&randomness,
);
let one = builder.constant(1);
Expand All @@ -167,63 +173,60 @@ impl<C: Config> Define<C> for Circuit<Variable> {
}
}

#[inline]
fn gen_assignment<C: Config>() -> Circuit<C::CircuitField> {
let mut circuit = Circuit::<C::CircuitField>::default();
let mut rng = thread_rng();
for i in 0..N_TABLE_ROWS {
for j in 0..KEY_LEN {
circuit.table_keys[i][j] = C::CircuitField::random_unsafe(&mut rng);
}
impl<C: Config> StdCircuit<C> for LogUpCircuit {
type Params = LogUpParams;
type Assignment = _LogUpCircuit<C::CircuitField>;

for j in 0..N_COLUMNS {
circuit.table_values[i][j] = C::CircuitField::random_unsafe(&mut rng);
}
}
fn new_circuit(params: &Self::Params) -> Self {
let mut circuit = Self::default();

circuit.query_count = [C::CircuitField::ZERO; N_TABLE_ROWS];
for i in 0..N_QUERIES {
let query_id: usize = rng.gen::<usize>() % N_TABLE_ROWS;
circuit.query_count[query_id] += C::CircuitField::ONE;
circuit.query_keys[i] = circuit.table_keys[query_id];
circuit.query_results[i] = circuit.table_values[query_id];
circuit.table_keys.resize(
params.n_table_rows,
vec![Variable::default(); params.key_len],
);
circuit.table_values.resize(
params.n_table_rows,
vec![Variable::default(); params.value_len],
);
circuit
.query_keys
.resize(params.n_queries, vec![Variable::default(); params.key_len]);
circuit.query_results.resize(
params.n_queries,
vec![Variable::default(); params.value_len],
);
circuit
.query_count
.resize(params.n_table_rows, Variable::default());

circuit
}

circuit
}
fn new_assignment(params: &Self::Params, mut rng: impl rand::RngCore) -> Self::Assignment {
let mut assignment = _LogUpCircuit::<C::CircuitField>::default();
assignment.table_keys.resize(params.n_table_rows, vec![]);
assignment.table_values.resize(params.n_table_rows, vec![]);
assignment.query_keys.resize(params.n_queries, vec![]);
assignment.query_results.resize(params.n_queries, vec![]);

for i in 0..params.n_table_rows {
for _ in 0..params.key_len {
assignment.table_keys[i].push(C::CircuitField::random_unsafe(&mut rng));
}

for _ in 0..params.value_len {
assignment.table_values[i].push(C::CircuitField::random_unsafe(&mut rng));
}
}

fn logup_test_helper<C: Config>() {
let compile_result: CompileResult<C> = compile(&Circuit::default()).unwrap();
let assignment = gen_assignment::<C>();
let witness = compile_result
.witness_solver
.solve_witness(&assignment)
.unwrap();
let output = compile_result.layered_circuit.run(&witness);
assert_eq!(output, vec![true]);

let file = std::fs::File::create("circuit.txt").unwrap();
let writer = std::io::BufWriter::new(file);
compile_result
.layered_circuit
.serialize_into(writer)
.unwrap();

let file = std::fs::File::create("witness.txt").unwrap();
let writer = std::io::BufWriter::new(file);
witness.serialize_into(writer).unwrap();

let file = std::fs::File::create("witness_solver.txt").unwrap();
let writer = std::io::BufWriter::new(file);
compile_result
.witness_solver
.serialize_into(writer)
.unwrap();
}
assignment.query_count = vec![C::CircuitField::ZERO; params.n_table_rows];
for i in 0..params.n_queries {
let query_id: usize = rng.gen::<usize>() % params.n_table_rows;
assignment.query_count[query_id] += C::CircuitField::ONE;
assignment.query_keys[i] = assignment.table_keys[query_id].clone();
assignment.query_results[i] = assignment.table_values[query_id].clone();
}

#[test]
fn logup_test() {
logup_test_helper::<GF2Config>();
logup_test_helper::<M31Config>();
logup_test_helper::<BN254Config>();
assignment
}
}
14 changes: 14 additions & 0 deletions circuit-std-rs/src/traits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use std::fmt::Debug;

use expander_compiler::frontend::{internal::DumpLoadTwoVariables, Config, Define, Variable};
use rand::RngCore;

// All std circuits must implement the following trait
pub trait StdCircuit<C: Config>: Clone + Define<C> + DumpLoadTwoVariables<Variable> {
type Params: Clone + Debug;
type Assignment: Clone + DumpLoadTwoVariables<C::CircuitField>;

fn new_circuit(params: &Self::Params) -> Self;

fn new_assignment(params: &Self::Params, rng: impl RngCore) -> Self::Assignment;
}
38 changes: 38 additions & 0 deletions circuit-std-rs/tests/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use circuit_std_rs::StdCircuit;
use expander_compiler::frontend::*;
use extra::Serde;
use rand::thread_rng;

pub fn circuit_test_helper<Cfg, Cir>(params: &Cir::Params)
where
Cfg: Config,
Cir: StdCircuit<Cfg>,
{
let mut rng = thread_rng();
let compile_result: CompileResult<Cfg> = compile(&Cir::new_circuit(&params)).unwrap();
let assignment = Cir::new_assignment(&params, &mut rng);
let witness = compile_result
.witness_solver
.solve_witness(&assignment)
.unwrap();
let output = compile_result.layered_circuit.run(&witness);
assert_eq!(output, vec![true]);

let file = std::fs::File::create("circuit.txt").unwrap();
let writer = std::io::BufWriter::new(file);
compile_result
.layered_circuit
.serialize_into(writer)
.unwrap();

let file = std::fs::File::create("witness.txt").unwrap();
let writer = std::io::BufWriter::new(file);
witness.serialize_into(writer).unwrap();

let file = std::fs::File::create("witness_solver.txt").unwrap();
let writer = std::io::BufWriter::new(file);
compile_result
.witness_solver
.serialize_into(writer)
.unwrap();
}
18 changes: 18 additions & 0 deletions circuit-std-rs/tests/logup.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
mod common;

use circuit_std_rs::{LogUpCircuit, LogUpParams};
use expander_compiler::frontend::*;

#[test]
fn logup_test() {
let logup_params = LogUpParams {
key_len: 7,
value_len: 7,
n_table_rows: 123,
n_queries: 456,
};

common::circuit_test_helper::<BN254Config, LogUpCircuit>(&logup_params);
common::circuit_test_helper::<M31Config, LogUpCircuit>(&logup_params);
common::circuit_test_helper::<GF2Config, LogUpCircuit>(&logup_params);
}
Loading