From fcf74e4a91a6ef3ac01a39380f5f231b4c89b7d8 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Thu, 7 Nov 2024 20:43:23 -0600 Subject: [PATCH] refactor logup to std (#41) --- Cargo.lock | 15 ++ Cargo.toml | 2 +- circuit-std-rs/Cargo.toml | 17 ++ circuit-std-rs/src/lib.rs | 5 + .../tests => circuit-std-rs/src}/logup.rs | 167 +++++++++--------- circuit-std-rs/src/traits.rs | 14 ++ circuit-std-rs/tests/common.rs | 38 ++++ circuit-std-rs/tests/logup.rs | 18 ++ 8 files changed, 193 insertions(+), 83 deletions(-) create mode 100644 circuit-std-rs/Cargo.toml create mode 100644 circuit-std-rs/src/lib.rs rename {expander_compiler/tests => circuit-std-rs/src}/logup.rs (53%) create mode 100644 circuit-std-rs/src/traits.rs create mode 100644 circuit-std-rs/tests/common.rs create mode 100644 circuit-std-rs/tests/logup.rs diff --git a/Cargo.lock b/Cargo.lock index 32fef1b..a3e937a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -343,6 +343,21 @@ dependencies = [ "transcript", ] +[[package]] +name = "circuit-std-rs" +version = "0.1.0" +dependencies = [ + "arith", + "ark-std", + "circuit", + "config", + "expander_compiler", + "gf2", + "gkr", + "mersenne31", + "rand", +] + [[package]] name = "clang-sys" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index 21b1ead..98bcbd0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["expander_compiler", "expander_compiler/ec_go_lib"] +members = [ "circuit-std-rs","expander_compiler", "expander_compiler/ec_go_lib"] [profile.test] opt-level = 3 diff --git a/circuit-std-rs/Cargo.toml b/circuit-std-rs/Cargo.toml new file mode 100644 index 0000000..9fbf43a --- /dev/null +++ b/circuit-std-rs/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "circuit-std-rs" +version = "0.1.0" +edition = "2021" + + +[dependencies] +expander_compiler = { path = "../expander_compiler"} + +ark-std.workspace = true +rand.workspace = true +expander_config.workspace = true +expander_circuit.workspace = true +gkr.workspace = true +arith.workspace = true +gf2.workspace = true +mersenne31.workspace = true diff --git a/circuit-std-rs/src/lib.rs b/circuit-std-rs/src/lib.rs new file mode 100644 index 0000000..248446f --- /dev/null +++ b/circuit-std-rs/src/lib.rs @@ -0,0 +1,5 @@ +pub mod traits; +pub use traits::StdCircuit; + +pub mod logup; +pub use logup::{LogUpCircuit, LogUpParams}; diff --git a/expander_compiler/tests/logup.rs b/circuit-std-rs/src/logup.rs similarity index 53% rename from expander_compiler/tests/logup.rs rename to circuit-std-rs/src/logup.rs index fc32d44..25911b7 100644 --- a/expander_compiler/tests/logup.rs +++ b/circuit-std-rs/src/logup.rs @@ -1,25 +1,31 @@ use arith::Field; use expander_compiler::frontend::*; -use extra::Serde; -use rand::{thread_rng, Rng}; +use rand::Rng; -const KEY_LEN: usize = 3; -const N_TABLE_ROWS: usize = 17; -const N_COLUMNS: usize = 5; -const N_QUERIES: usize = 33; +use crate::StdCircuit; -declare_circuit!(Circuit { - table_keys: [[Variable; KEY_LEN]; N_TABLE_ROWS], - table_values: [[Variable; N_COLUMNS]; N_TABLE_ROWS], +#[derive(Clone, Copy, Debug)] +pub struct LogUpParams { + pub key_len: usize, + pub value_len: usize, + pub n_table_rows: usize, + pub n_queries: usize, +} + +declare_circuit!(_LogUpCircuit { + table_keys: [[Variable]], + table_values: [[Variable]], - query_keys: [[Variable; KEY_LEN]; N_QUERIES], - query_results: [[Variable; N_COLUMNS]; N_QUERIES], + query_keys: [[Variable]], + query_results: [[Variable]], // counting the number of occurences for each row of the table - query_count: [Variable; N_TABLE_ROWS], + query_count: [Variable], }); -#[derive(Clone, Copy)] +pub type LogUpCircuit = _LogUpCircuit; + +#[derive(Clone, Copy, Debug)] struct Rational { numerator: Variable, denominator: Variable, @@ -77,7 +83,7 @@ fn sum_rational_vec(builder: &mut API, vs: &[Rational]) -> Rationa vvs[0] } -// TODO: Add poly randomness +// TODO-Feature: poly randomness fn get_column_randomness(builder: &mut API, n_columns: usize) -> Vec { let mut randomness = vec![]; randomness.push(builder.constant(1)); @@ -87,9 +93,16 @@ fn get_column_randomness(builder: &mut API, n_columns: usize) -> V randomness } +fn concat_d1(v1: &[Vec], v2: &[Vec]) -> Vec> { + v1.iter() + .zip(v2.iter()) + .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) + .collect() +} + fn combine_columns( builder: &mut API, - vec_2d: &Vec>, + vec_2d: &[Vec], randomness: &[Variable], ) -> Vec { if vec_2d.is_empty() { @@ -128,31 +141,24 @@ fn logup_poly_val( sum_rational_vec(builder, &poly_terms) } -impl Define for Circuit { +impl Define for LogUpCircuit { fn define(&self, builder: &mut API) { + let key_len = self.table_keys[0].len(); + let value_len = self.table_values[0].len(); + let alpha = builder.get_random_value(); - let randomness = get_column_randomness(builder, KEY_LEN + N_COLUMNS); + let randomness = get_column_randomness(builder, key_len + value_len); let table_combined = combine_columns( builder, - &self - .table_keys - .iter() - .zip(self.table_values) - .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) - .collect(), + &concat_d1(&self.table_keys, &self.table_values), &randomness, ); let v_table = logup_poly_val(builder, &table_combined, &self.query_count, &alpha); let query_combined = combine_columns( builder, - &self - .query_keys - .iter() - .zip(self.query_results) - .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) - .collect(), + &concat_d1(&self.query_keys, &self.query_results), &randomness, ); let one = builder.constant(1); @@ -167,63 +173,60 @@ impl Define for Circuit { } } -#[inline] -fn gen_assignment() -> Circuit { - let mut circuit = Circuit::::default(); - let mut rng = thread_rng(); - for i in 0..N_TABLE_ROWS { - for j in 0..KEY_LEN { - circuit.table_keys[i][j] = C::CircuitField::random_unsafe(&mut rng); - } +impl StdCircuit for LogUpCircuit { + type Params = LogUpParams; + type Assignment = _LogUpCircuit; - for j in 0..N_COLUMNS { - circuit.table_values[i][j] = C::CircuitField::random_unsafe(&mut rng); - } - } + fn new_circuit(params: &Self::Params) -> Self { + let mut circuit = Self::default(); - circuit.query_count = [C::CircuitField::ZERO; N_TABLE_ROWS]; - for i in 0..N_QUERIES { - let query_id: usize = rng.gen::() % N_TABLE_ROWS; - circuit.query_count[query_id] += C::CircuitField::ONE; - circuit.query_keys[i] = circuit.table_keys[query_id]; - circuit.query_results[i] = circuit.table_values[query_id]; + circuit.table_keys.resize( + params.n_table_rows, + vec![Variable::default(); params.key_len], + ); + circuit.table_values.resize( + params.n_table_rows, + vec![Variable::default(); params.value_len], + ); + circuit + .query_keys + .resize(params.n_queries, vec![Variable::default(); params.key_len]); + circuit.query_results.resize( + params.n_queries, + vec![Variable::default(); params.value_len], + ); + circuit + .query_count + .resize(params.n_table_rows, Variable::default()); + + circuit } - circuit -} + fn new_assignment(params: &Self::Params, mut rng: impl rand::RngCore) -> Self::Assignment { + let mut assignment = _LogUpCircuit::::default(); + assignment.table_keys.resize(params.n_table_rows, vec![]); + assignment.table_values.resize(params.n_table_rows, vec![]); + assignment.query_keys.resize(params.n_queries, vec![]); + assignment.query_results.resize(params.n_queries, vec![]); + + for i in 0..params.n_table_rows { + for _ in 0..params.key_len { + assignment.table_keys[i].push(C::CircuitField::random_unsafe(&mut rng)); + } + + for _ in 0..params.value_len { + assignment.table_values[i].push(C::CircuitField::random_unsafe(&mut rng)); + } + } -fn logup_test_helper() { - let compile_result: CompileResult = compile(&Circuit::default()).unwrap(); - let assignment = gen_assignment::(); - let witness = compile_result - .witness_solver - .solve_witness(&assignment) - .unwrap(); - let output = compile_result.layered_circuit.run(&witness); - assert_eq!(output, vec![true]); - - let file = std::fs::File::create("circuit.txt").unwrap(); - let writer = std::io::BufWriter::new(file); - compile_result - .layered_circuit - .serialize_into(writer) - .unwrap(); - - let file = std::fs::File::create("witness.txt").unwrap(); - let writer = std::io::BufWriter::new(file); - witness.serialize_into(writer).unwrap(); - - let file = std::fs::File::create("witness_solver.txt").unwrap(); - let writer = std::io::BufWriter::new(file); - compile_result - .witness_solver - .serialize_into(writer) - .unwrap(); -} + assignment.query_count = vec![C::CircuitField::ZERO; params.n_table_rows]; + for i in 0..params.n_queries { + let query_id: usize = rng.gen::() % params.n_table_rows; + assignment.query_count[query_id] += C::CircuitField::ONE; + assignment.query_keys[i] = assignment.table_keys[query_id].clone(); + assignment.query_results[i] = assignment.table_values[query_id].clone(); + } -#[test] -fn logup_test() { - logup_test_helper::(); - logup_test_helper::(); - logup_test_helper::(); + assignment + } } diff --git a/circuit-std-rs/src/traits.rs b/circuit-std-rs/src/traits.rs new file mode 100644 index 0000000..f42ca17 --- /dev/null +++ b/circuit-std-rs/src/traits.rs @@ -0,0 +1,14 @@ +use std::fmt::Debug; + +use expander_compiler::frontend::{internal::DumpLoadTwoVariables, Config, Define, Variable}; +use rand::RngCore; + +// All std circuits must implement the following trait +pub trait StdCircuit: Clone + Define + DumpLoadTwoVariables { + type Params: Clone + Debug; + type Assignment: Clone + DumpLoadTwoVariables; + + fn new_circuit(params: &Self::Params) -> Self; + + fn new_assignment(params: &Self::Params, rng: impl RngCore) -> Self::Assignment; +} diff --git a/circuit-std-rs/tests/common.rs b/circuit-std-rs/tests/common.rs new file mode 100644 index 0000000..1adb95a --- /dev/null +++ b/circuit-std-rs/tests/common.rs @@ -0,0 +1,38 @@ +use circuit_std_rs::StdCircuit; +use expander_compiler::frontend::*; +use extra::Serde; +use rand::thread_rng; + +pub fn circuit_test_helper(params: &Cir::Params) +where + Cfg: Config, + Cir: StdCircuit, +{ + let mut rng = thread_rng(); + let compile_result: CompileResult = compile(&Cir::new_circuit(¶ms)).unwrap(); + let assignment = Cir::new_assignment(¶ms, &mut rng); + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + + let file = std::fs::File::create("circuit.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + compile_result + .layered_circuit + .serialize_into(writer) + .unwrap(); + + let file = std::fs::File::create("witness.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + witness.serialize_into(writer).unwrap(); + + let file = std::fs::File::create("witness_solver.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + compile_result + .witness_solver + .serialize_into(writer) + .unwrap(); +} diff --git a/circuit-std-rs/tests/logup.rs b/circuit-std-rs/tests/logup.rs new file mode 100644 index 0000000..1f2a44c --- /dev/null +++ b/circuit-std-rs/tests/logup.rs @@ -0,0 +1,18 @@ +mod common; + +use circuit_std_rs::{LogUpCircuit, LogUpParams}; +use expander_compiler::frontend::*; + +#[test] +fn logup_test() { + let logup_params = LogUpParams { + key_len: 7, + value_len: 7, + n_table_rows: 123, + n_queries: 456, + }; + + common::circuit_test_helper::(&logup_params); + common::circuit_test_helper::(&logup_params); + common::circuit_test_helper::(&logup_params); +}