From b506cdd3c84eb373b3b37c004282a530f31a74a6 Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 10 Dec 2024 09:03:44 +0800 Subject: [PATCH] hints --- .../src/circuit/ir/hint_normalized/mod.rs | 66 ++-- .../ir/hint_normalized/witness_solver.rs | 10 +- expander_compiler/src/frontend/api.rs | 6 + expander_compiler/src/frontend/builder.rs | 25 +- expander_compiler/src/frontend/mod.rs | 3 +- expander_compiler/src/frontend/witness.rs | 36 +- expander_compiler/src/hints/builtin.rs | 321 +++++++++++++++++ expander_compiler/src/hints/mod.rs | 325 +----------------- expander_compiler/src/hints/registry.rs | 53 +++ expander_compiler/tests/to_binary_hint.rs | 89 +++++ 10 files changed, 587 insertions(+), 347 deletions(-) create mode 100644 expander_compiler/src/hints/builtin.rs create mode 100644 expander_compiler/src/hints/registry.rs create mode 100644 expander_compiler/tests/to_binary_hint.rs diff --git a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs index b5b169d..4bb723b 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use crate::field::FieldArith; +use crate::hints::registry::HintRegistry; use crate::utils::error::Error; use crate::{ circuit::{ @@ -201,6 +202,42 @@ impl common::Instruction for Instruction { } } +impl Instruction { + fn eval_safe( + &self, + values: &[C::CircuitField], + public_inputs: &[C::CircuitField], + hint_registry: &mut HintRegistry, + ) -> EvalResult { + if let Instruction::ConstantLike(coef) = self { + return match coef { + Coef::Constant(c) => EvalResult::Value(c.clone()), + Coef::PublicInput(i) => EvalResult::Value(public_inputs[*i]), + Coef::Random => EvalResult::Error(Error::UserError( + "random coef occured in witness solver".to_string(), + )), + }; + } + if let Instruction::Hint { + hint_id, + inputs, + num_outputs, + } = self + { + return match hints::safe_impl( + hint_registry, + *hint_id, + &inputs.iter().map(|i| values[*i]).collect(), + *num_outputs, + ) { + Ok(outputs) => EvalResult::Values(outputs), + Err(e) => EvalResult::Error(e), + }; + } + self.eval_unsafe(values) + } +} + pub type Circuit = common::Circuit>; pub type RootCircuit = common::RootCircuit>; @@ -443,41 +480,27 @@ impl RootCircuit { self.circuits.insert(0, c0); } - pub fn eval_with_public_inputs( + pub fn eval_safe( &self, inputs: Vec, public_inputs: &[C::CircuitField], + hint_registry: &mut HintRegistry, ) -> Result, Error> { assert_eq!(inputs.len(), self.input_size()); - self.eval_sub_with_public_inputs(&self.circuits[&0], inputs, public_inputs) + self.eval_sub_safe(&self.circuits[&0], inputs, public_inputs, hint_registry) } - fn eval_sub_with_public_inputs( + fn eval_sub_safe( &self, circuit: &Circuit, inputs: Vec, public_inputs: &[C::CircuitField], + hint_registry: &mut HintRegistry, ) -> Result, Error> { let mut values = vec![C::CircuitField::zero(); 1]; values.extend(inputs); for insn in circuit.instructions.iter() { - if let Instruction::ConstantLike(coef) = insn { - match coef { - Coef::Constant(c) => { - values.push(*c); - } - Coef::PublicInput(i) => { - values.push(public_inputs[*i]); - } - Coef::Random => { - return Err(Error::UserError( - "random coef occured in witness solver".to_string(), - )); - } - } - continue; - } - match insn.eval_unsafe(&values) { + match insn.eval_safe(&values, public_inputs, hint_registry) { EvalResult::Value(v) => { values.push(v); } @@ -485,10 +508,11 @@ impl RootCircuit { values.append(&mut vs); } EvalResult::SubCircuitCall(sub_circuit_id, inputs) => { - let res = self.eval_sub_with_public_inputs( + let res = self.eval_sub_safe( &self.circuits[&sub_circuit_id], inputs.iter().map(|&i| values[i]).collect(), public_inputs, + hint_registry, )?; values.extend(res); } diff --git a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs index 970307d..3374739 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs @@ -11,10 +11,11 @@ impl WitnessSolver { &self, vars: Vec, public_vars: Vec, + hint_registry: &mut HintRegistry, ) -> Result<(Vec, usize), Error> { assert_eq!(vars.len(), self.circuit.input_size()); assert_eq!(public_vars.len(), self.circuit.num_public_inputs); - let mut a = self.circuit.eval_with_public_inputs(vars, &public_vars)?; + let mut a = self.circuit.eval_safe(vars, &public_vars, hint_registry)?; let res_len = a.len(); a.extend(public_vars); Ok((a, res_len)) @@ -24,8 +25,10 @@ impl WitnessSolver { &self, vars: Vec, public_vars: Vec, + hint_registry: &mut HintRegistry, ) -> Result, Error> { - let (values, num_inputs_per_witness) = self.solve_witness_inner(vars, public_vars)?; + let (values, num_inputs_per_witness) = + self.solve_witness_inner(vars, public_vars, hint_registry)?; Ok(Witness { num_witnesses: 1, num_inputs_per_witness, @@ -40,12 +43,13 @@ impl WitnessSolver { &self, num_witnesses: usize, f: F, + hint_registry: &mut HintRegistry, ) -> Result, Error> { let mut values = Vec::new(); let mut num_inputs_per_witness = 0; for i in 0..num_witnesses { let (a, b) = f(i); - let (a, num) = self.solve_witness_inner(a, b)?; + let (a, num) = self.solve_witness_inner(a, b, hint_registry)?; values.extend(a); num_inputs_per_witness = num; } diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index d0bad08..6f7d050 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -42,6 +42,12 @@ pub trait BasicAPI { y: impl ToVariableOrValue, ); fn get_random_value(&mut self) -> Variable; + fn new_hint( + &mut self, + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec; } pub trait UnconstrainedAPI { diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index 220927d..7074af7 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -13,7 +13,7 @@ use crate::{ layered::Coef, }, field::{Field, FieldArith}, - hints, + hints::{self, registry::hint_key_to_id}, utils::function_id::get_function_id, }; @@ -292,6 +292,20 @@ impl BasicAPI for Builder { .push(SourceInstruction::ConstantLike(Coef::Random)); self.new_var() } + + fn new_hint( + &mut self, + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec { + self.instructions.push(SourceInstruction::Hint { + hint_id: hint_key_to_id(hint_key), + inputs: inputs.iter().map(|v| v.id).collect(), + num_outputs, + }); + (0..num_outputs).map(|_| self.new_var()).collect() + } } // write macro rules for unconstrained binary op definition @@ -445,6 +459,15 @@ impl BasicAPI for RootBuilder { fn get_random_value(&mut self) -> Variable { self.last_builder().get_random_value() } + + fn new_hint( + &mut self, + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec { + self.last_builder().new_hint(hint_key, inputs, num_outputs) + } } impl RootBuilder { diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 1b087b3..4a66c3f 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -11,7 +11,8 @@ mod witness; pub use circuit::declare_circuit; pub type API = builder::RootBuilder; pub use crate::circuit::config::*; -pub use crate::field::{Field, BN254, GF2, M31}; +pub use crate::field::{Field, FieldArith, FieldModulus, BN254, GF2, M31}; +pub use crate::hints::registry::HintRegistry; pub use crate::utils::error::Error; pub use api::BasicAPI; pub use builder::Variable; diff --git a/expander_compiler/src/frontend/witness.rs b/expander_compiler/src/frontend/witness.rs index f686fe1..d39f130 100644 --- a/expander_compiler/src/frontend/witness.rs +++ b/expander_compiler/src/frontend/witness.rs @@ -1,5 +1,5 @@ pub use crate::circuit::ir::hint_normalized::witness_solver::WitnessSolver; -use crate::circuit::layered::witness::Witness; +use crate::{circuit::layered::witness::Witness, hints::registry::HintRegistry}; use super::{internal, Config, Error}; @@ -7,22 +7,42 @@ impl WitnessSolver { pub fn solve_witness>( &self, assignment: &Cir, + ) -> Result, Error> { + self.solve_witness_with_hints(assignment, &mut HintRegistry::new()) + } + + pub fn solve_witness_with_hints>( + &self, + assignment: &Cir, + hint_registry: &mut HintRegistry, ) -> Result, Error> { let mut vars = Vec::new(); let mut public_vars = Vec::new(); assignment.dump_into(&mut vars, &mut public_vars); - self.solve_witness_from_raw_inputs(vars, public_vars) + self.solve_witness_from_raw_inputs(vars, public_vars, hint_registry) } pub fn solve_witnesses>( &self, assignments: &[Cir], ) -> Result, Error> { - self.solve_witnesses_from_raw_inputs(assignments.len(), |i| { - let mut vars = Vec::new(); - let mut public_vars = Vec::new(); - assignments[i].dump_into(&mut vars, &mut public_vars); - (vars, public_vars) - }) + self.solve_witnesses_with_hints(assignments, &mut HintRegistry::new()) + } + + pub fn solve_witnesses_with_hints>( + &self, + assignments: &[Cir], + hint_registry: &mut HintRegistry, + ) -> Result, Error> { + self.solve_witnesses_from_raw_inputs( + assignments.len(), + |i| { + let mut vars = Vec::new(); + let mut public_vars = Vec::new(); + assignments[i].dump_into(&mut vars, &mut public_vars); + (vars, public_vars) + }, + hint_registry, + ) } } diff --git a/expander_compiler/src/hints/builtin.rs b/expander_compiler/src/hints/builtin.rs new file mode 100644 index 0000000..0444ecb --- /dev/null +++ b/expander_compiler/src/hints/builtin.rs @@ -0,0 +1,321 @@ +use std::hash::{DefaultHasher, Hash, Hasher}; + +use ethnum::U256; +use rand::RngCore; + +use crate::{field::Field, utils::error::Error}; + +#[repr(u64)] +pub enum BuiltinHintIds { + Identity = 0xccc000000000, + Div, + Eq, + NotEq, + BoolOr, + BoolAnd, + BitOr, + BitAnd, + BitXor, + Select, + Pow, + IntDiv, + Mod, + ShiftL, + ShiftR, + LesserEq, + GreaterEq, + Lesser, + Greater, +} + +#[cfg(not(target_pointer_width = "64"))] +compile_error!("compilation is only allowed for 64-bit targets"); + +impl BuiltinHintIds { + pub fn from_usize(id: usize) -> Option { + if id < (BuiltinHintIds::Identity as u64 as usize) { + return None; + } + if id > (BuiltinHintIds::Identity as u64 as usize + 100) { + return None; + } + match id { + x if x == BuiltinHintIds::Identity as u64 as usize => Some(BuiltinHintIds::Identity), + x if x == BuiltinHintIds::Div as u64 as usize => Some(BuiltinHintIds::Div), + x if x == BuiltinHintIds::Eq as u64 as usize => Some(BuiltinHintIds::Eq), + x if x == BuiltinHintIds::NotEq as u64 as usize => Some(BuiltinHintIds::NotEq), + x if x == BuiltinHintIds::BoolOr as u64 as usize => Some(BuiltinHintIds::BoolOr), + x if x == BuiltinHintIds::BoolAnd as u64 as usize => Some(BuiltinHintIds::BoolAnd), + x if x == BuiltinHintIds::BitOr as u64 as usize => Some(BuiltinHintIds::BitOr), + x if x == BuiltinHintIds::BitAnd as u64 as usize => Some(BuiltinHintIds::BitAnd), + x if x == BuiltinHintIds::BitXor as u64 as usize => Some(BuiltinHintIds::BitXor), + x if x == BuiltinHintIds::Select as u64 as usize => Some(BuiltinHintIds::Select), + x if x == BuiltinHintIds::Pow as u64 as usize => Some(BuiltinHintIds::Pow), + x if x == BuiltinHintIds::IntDiv as u64 as usize => Some(BuiltinHintIds::IntDiv), + x if x == BuiltinHintIds::Mod as u64 as usize => Some(BuiltinHintIds::Mod), + x if x == BuiltinHintIds::ShiftL as u64 as usize => Some(BuiltinHintIds::ShiftL), + x if x == BuiltinHintIds::ShiftR as u64 as usize => Some(BuiltinHintIds::ShiftR), + x if x == BuiltinHintIds::LesserEq as u64 as usize => Some(BuiltinHintIds::LesserEq), + x if x == BuiltinHintIds::GreaterEq as u64 as usize => Some(BuiltinHintIds::GreaterEq), + x if x == BuiltinHintIds::Lesser as u64 as usize => Some(BuiltinHintIds::Lesser), + x if x == BuiltinHintIds::Greater as u64 as usize => Some(BuiltinHintIds::Greater), + _ => None, + } + } +} + +fn stub_impl_general(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { + let mut hasher = DefaultHasher::new(); + hint_id.hash(&mut hasher); + inputs.hash(&mut hasher); + let mut outputs = Vec::with_capacity(num_outputs); + for _ in 0..num_outputs { + let t = hasher.finish(); + outputs.push(F::from(t as u32)); + t.hash(&mut hasher); + } + outputs +} + +fn validate_builtin_hint( + hint_id: BuiltinHintIds, + num_inputs: usize, + num_outputs: usize, +) -> Result<(), Error> { + match hint_id { + BuiltinHintIds::Identity => { + if num_inputs != num_outputs { + return Err(Error::InternalError( + "identity hint requires exactly the same number of inputs and outputs" + .to_string(), + )); + } + if num_inputs == 0 { + return Err(Error::InternalError( + "identity hint requires at least 1 input".to_string(), + )); + } + } + BuiltinHintIds::Div + | BuiltinHintIds::Eq + | BuiltinHintIds::NotEq + | BuiltinHintIds::BoolOr + | BuiltinHintIds::BoolAnd + | BuiltinHintIds::BitOr + | BuiltinHintIds::BitAnd + | BuiltinHintIds::BitXor + | BuiltinHintIds::Pow + | BuiltinHintIds::IntDiv + | BuiltinHintIds::Mod + | BuiltinHintIds::ShiftL + | BuiltinHintIds::ShiftR + | BuiltinHintIds::LesserEq + | BuiltinHintIds::GreaterEq + | BuiltinHintIds::Lesser + | BuiltinHintIds::Greater => { + if num_inputs != 2 { + return Err(Error::InternalError( + "binary op requires exactly 2 inputs".to_string(), + )); + } + if num_outputs != 1 { + return Err(Error::InternalError( + "binary op requires exactly 1 output".to_string(), + )); + } + } + BuiltinHintIds::Select => { + if num_inputs != 3 { + return Err(Error::InternalError( + "select requires exactly 3 inputs".to_string(), + )); + } + if num_outputs != 1 { + return Err(Error::InternalError( + "select requires exactly 1 output".to_string(), + )); + } + } + } + Ok(()) +} + +pub fn validate_hint(hint_id: usize, num_inputs: usize, num_outputs: usize) -> Result<(), Error> { + match BuiltinHintIds::from_usize(hint_id) { + Some(hint_id) => validate_builtin_hint(hint_id, num_inputs, num_outputs), + None => { + if num_outputs == 0 { + return Err(Error::InternalError( + "custom hint requires at least 1 output".to_string(), + )); + } + if num_inputs == 0 { + return Err(Error::InternalError( + "custom hint requires at least 1 input".to_string(), + )); + } + Ok(()) + } + } +} + +pub fn impl_builtin_hint( + hint_id: BuiltinHintIds, + inputs: &[F], + num_outputs: usize, +) -> Vec { + match hint_id { + BuiltinHintIds::Identity => inputs.iter().take(num_outputs).cloned().collect(), + BuiltinHintIds::Div => binop_hint(inputs, |x, y| match y.inv() { + Some(inv) => x * inv, + None => F::zero(), + }), + BuiltinHintIds::Eq => binop_hint(inputs, |x, y| F::from((x == y) as u32)), + BuiltinHintIds::NotEq => binop_hint(inputs, |x, y| F::from((x != y) as u32)), + BuiltinHintIds::BoolOr => binop_hint(inputs, |x, y| { + F::from((!x.is_zero() || !y.is_zero()) as u32) + }), + BuiltinHintIds::BoolAnd => binop_hint(inputs, |x, y| { + F::from((!x.is_zero() && !y.is_zero()) as u32) + }), + BuiltinHintIds::BitOr => binop_hint_on_u256(inputs, |x, y| x | y), + BuiltinHintIds::BitAnd => binop_hint_on_u256(inputs, |x, y| x & y), + BuiltinHintIds::BitXor => binop_hint_on_u256(inputs, |x, y| x ^ y), + BuiltinHintIds::Select => { + let mut outputs = Vec::with_capacity(num_outputs); + outputs.push(if !inputs[0].is_zero() { + inputs[1] + } else { + inputs[2] + }); + outputs + } + BuiltinHintIds::Pow => binop_hint(inputs, |x, y| { + let mut t = x; + let mut res = F::one(); + let mut y: U256 = y.to_u256(); + while y != U256::ZERO { + if y & U256::from(1u32) != U256::ZERO { + res *= t; + } + y >>= 1; + t = t * t; + } + res + }), + BuiltinHintIds::IntDiv => { + binop_hint_on_u256( + inputs, + |x, y| if y == U256::ZERO { U256::ZERO } else { x / y }, + ) + } + BuiltinHintIds::Mod => { + binop_hint_on_u256( + inputs, + |x, y| if y == U256::ZERO { U256::ZERO } else { x % y }, + ) + } + BuiltinHintIds::ShiftL => binop_hint_on_u256(inputs, |x, y| circom_shift_l_impl::(x, y)), + BuiltinHintIds::ShiftR => binop_hint_on_u256(inputs, |x, y| circom_shift_r_impl::(x, y)), + BuiltinHintIds::LesserEq => binop_hint(inputs, |x, y| F::from((x <= y) as u32)), + BuiltinHintIds::GreaterEq => binop_hint(inputs, |x, y| F::from((x >= y) as u32)), + BuiltinHintIds::Lesser => binop_hint(inputs, |x, y| F::from((x < y) as u32)), + BuiltinHintIds::Greater => binop_hint(inputs, |x, y| F::from((x > y) as u32)), + } +} + +fn binop_hint F>(inputs: &[F], f: G) -> Vec { + vec![f(inputs[0], inputs[1])] +} + +fn binop_hint_on_u256 U256>(inputs: &[F], f: G) -> Vec { + let x_u256: U256 = inputs[0].to_u256(); + let y_u256: U256 = inputs[1].to_u256(); + let z_u256 = f(x_u256, y_u256); + vec![F::from_u256(z_u256)] +} + +pub fn stub_impl(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { + match BuiltinHintIds::from_usize(hint_id) { + Some(hint_id) => impl_builtin_hint(hint_id, inputs, num_outputs), + None => stub_impl_general(hint_id, inputs, num_outputs), + } +} + +pub fn random_builtin(mut rand: impl RngCore) -> (usize, usize, usize) { + loop { + let hint_id = (rand.next_u64() as usize % 100) + (BuiltinHintIds::Identity as u64 as usize); + if let Some(hint_id) = BuiltinHintIds::from_usize(hint_id) { + match hint_id { + BuiltinHintIds::Identity => { + let num_inputs = (rand.next_u64() % 10) as usize + 1; + let num_outputs = num_inputs; + return (hint_id as usize, num_inputs, num_outputs); + } + BuiltinHintIds::Div + | BuiltinHintIds::Eq + | BuiltinHintIds::NotEq + | BuiltinHintIds::BoolOr + | BuiltinHintIds::BoolAnd + | BuiltinHintIds::BitOr + | BuiltinHintIds::BitAnd + | BuiltinHintIds::BitXor + | BuiltinHintIds::Pow + | BuiltinHintIds::IntDiv + | BuiltinHintIds::Mod + | BuiltinHintIds::ShiftL + | BuiltinHintIds::ShiftR + | BuiltinHintIds::LesserEq + | BuiltinHintIds::GreaterEq + | BuiltinHintIds::Lesser + | BuiltinHintIds::Greater => { + return (hint_id as usize, 2, 1); + } + BuiltinHintIds::Select => { + return (hint_id as usize, 3, 1); + } + } + } + } +} + +pub fn u256_bit_length(x: U256) -> usize { + 256 - x.leading_zeros() as usize +} + +pub fn circom_shift_l_impl(x: U256, k: U256) -> U256 { + let top = F::modulus() / 2; + if k <= top { + let shift = if (k >> U256::from(64u32)) == U256::ZERO { + k.as_u64() as usize + } else { + u256_bit_length(F::modulus()) + }; + if shift >= 256 { + return U256::ZERO; + } + let value = x << shift; + let mask = U256::from(1u32) << u256_bit_length(F::modulus()); + let mask = mask - 1; + value & mask + } else { + circom_shift_r_impl::(x, F::modulus() - k) + } +} + +pub fn circom_shift_r_impl(x: U256, k: U256) -> U256 { + let top = F::modulus() / 2; + if k <= top { + let shift = if (k >> U256::from(64u32)) == U256::ZERO { + k.as_u64() as usize + } else { + u256_bit_length(F::modulus()) + }; + if shift >= 256 { + return U256::ZERO; + } + x >> shift + } else { + circom_shift_l_impl::(x, F::modulus() - k) + } +} diff --git a/expander_compiler/src/hints/mod.rs b/expander_compiler/src/hints/mod.rs index b9a312c..9d39723 100644 --- a/expander_compiler/src/hints/mod.rs +++ b/expander_compiler/src/hints/mod.rs @@ -1,321 +1,20 @@ -use std::hash::{DefaultHasher, Hash, Hasher}; +pub mod builtin; +pub mod registry; -use ethnum::U256; -use rand::RngCore; +pub use builtin::*; -use crate::{field::Field, utils::error::Error}; - -#[repr(u64)] -pub enum BuiltinHintIds { - Identity = 0xccc000000000, - Div, - Eq, - NotEq, - BoolOr, - BoolAnd, - BitOr, - BitAnd, - BitXor, - Select, - Pow, - IntDiv, - Mod, - ShiftL, - ShiftR, - LesserEq, - GreaterEq, - Lesser, - Greater, -} - -#[cfg(not(target_pointer_width = "64"))] -compile_error!("compilation is only allowed for 64-bit targets"); - -impl BuiltinHintIds { - pub fn from_usize(id: usize) -> Option { - if id < (BuiltinHintIds::Identity as u64 as usize) { - return None; - } - if id > (BuiltinHintIds::Identity as u64 as usize + 100) { - return None; - } - match id { - x if x == BuiltinHintIds::Identity as u64 as usize => Some(BuiltinHintIds::Identity), - x if x == BuiltinHintIds::Div as u64 as usize => Some(BuiltinHintIds::Div), - x if x == BuiltinHintIds::Eq as u64 as usize => Some(BuiltinHintIds::Eq), - x if x == BuiltinHintIds::NotEq as u64 as usize => Some(BuiltinHintIds::NotEq), - x if x == BuiltinHintIds::BoolOr as u64 as usize => Some(BuiltinHintIds::BoolOr), - x if x == BuiltinHintIds::BoolAnd as u64 as usize => Some(BuiltinHintIds::BoolAnd), - x if x == BuiltinHintIds::BitOr as u64 as usize => Some(BuiltinHintIds::BitOr), - x if x == BuiltinHintIds::BitAnd as u64 as usize => Some(BuiltinHintIds::BitAnd), - x if x == BuiltinHintIds::BitXor as u64 as usize => Some(BuiltinHintIds::BitXor), - x if x == BuiltinHintIds::Select as u64 as usize => Some(BuiltinHintIds::Select), - x if x == BuiltinHintIds::Pow as u64 as usize => Some(BuiltinHintIds::Pow), - x if x == BuiltinHintIds::IntDiv as u64 as usize => Some(BuiltinHintIds::IntDiv), - x if x == BuiltinHintIds::Mod as u64 as usize => Some(BuiltinHintIds::Mod), - x if x == BuiltinHintIds::ShiftL as u64 as usize => Some(BuiltinHintIds::ShiftL), - x if x == BuiltinHintIds::ShiftR as u64 as usize => Some(BuiltinHintIds::ShiftR), - x if x == BuiltinHintIds::LesserEq as u64 as usize => Some(BuiltinHintIds::LesserEq), - x if x == BuiltinHintIds::GreaterEq as u64 as usize => Some(BuiltinHintIds::GreaterEq), - x if x == BuiltinHintIds::Lesser as u64 as usize => Some(BuiltinHintIds::Lesser), - x if x == BuiltinHintIds::Greater as u64 as usize => Some(BuiltinHintIds::Greater), - _ => None, - } - } -} - -fn stub_impl_general(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { - let mut hasher = DefaultHasher::new(); - hint_id.hash(&mut hasher); - inputs.hash(&mut hasher); - let mut outputs = Vec::with_capacity(num_outputs); - for _ in 0..num_outputs { - let t = hasher.finish(); - outputs.push(F::from(t as u32)); - t.hash(&mut hasher); - } - outputs -} - -fn validate_builtin_hint( - hint_id: BuiltinHintIds, - num_inputs: usize, - num_outputs: usize, -) -> Result<(), Error> { - match hint_id { - BuiltinHintIds::Identity => { - if num_inputs != num_outputs { - return Err(Error::InternalError( - "identity hint requires exactly the same number of inputs and outputs" - .to_string(), - )); - } - if num_inputs == 0 { - return Err(Error::InternalError( - "identity hint requires at least 1 input".to_string(), - )); - } - } - BuiltinHintIds::Div - | BuiltinHintIds::Eq - | BuiltinHintIds::NotEq - | BuiltinHintIds::BoolOr - | BuiltinHintIds::BoolAnd - | BuiltinHintIds::BitOr - | BuiltinHintIds::BitAnd - | BuiltinHintIds::BitXor - | BuiltinHintIds::Pow - | BuiltinHintIds::IntDiv - | BuiltinHintIds::Mod - | BuiltinHintIds::ShiftL - | BuiltinHintIds::ShiftR - | BuiltinHintIds::LesserEq - | BuiltinHintIds::GreaterEq - | BuiltinHintIds::Lesser - | BuiltinHintIds::Greater => { - if num_inputs != 2 { - return Err(Error::InternalError( - "binary op requires exactly 2 inputs".to_string(), - )); - } - if num_outputs != 1 { - return Err(Error::InternalError( - "binary op requires exactly 1 output".to_string(), - )); - } - } - BuiltinHintIds::Select => { - if num_inputs != 3 { - return Err(Error::InternalError( - "select requires exactly 3 inputs".to_string(), - )); - } - if num_outputs != 1 { - return Err(Error::InternalError( - "select requires exactly 1 output".to_string(), - )); - } - } - } - Ok(()) -} +use registry::HintRegistry; -pub fn validate_hint(hint_id: usize, num_inputs: usize, num_outputs: usize) -> Result<(), Error> { - match BuiltinHintIds::from_usize(hint_id) { - Some(hint_id) => validate_builtin_hint(hint_id, num_inputs, num_outputs), - None => { - if num_outputs == 0 { - return Err(Error::InternalError( - "custom hint requires at least 1 output".to_string(), - )); - } - if num_inputs == 0 { - return Err(Error::InternalError( - "custom hint requires at least 1 input".to_string(), - )); - } - Ok(()) - } - } -} +use crate::{field::Field, utils::error::Error}; -fn impl_builtin_hint( - hint_id: BuiltinHintIds, - inputs: &[F], +pub fn safe_impl( + hint_registry: &mut HintRegistry, + hint_id: usize, + inputs: &Vec, num_outputs: usize, -) -> Vec { - match hint_id { - BuiltinHintIds::Identity => inputs.iter().take(num_outputs).cloned().collect(), - BuiltinHintIds::Div => binop_hint(inputs, |x, y| match y.inv() { - Some(inv) => x * inv, - None => F::zero(), - }), - BuiltinHintIds::Eq => binop_hint(inputs, |x, y| F::from((x == y) as u32)), - BuiltinHintIds::NotEq => binop_hint(inputs, |x, y| F::from((x != y) as u32)), - BuiltinHintIds::BoolOr => binop_hint(inputs, |x, y| { - F::from((!x.is_zero() || !y.is_zero()) as u32) - }), - BuiltinHintIds::BoolAnd => binop_hint(inputs, |x, y| { - F::from((!x.is_zero() && !y.is_zero()) as u32) - }), - BuiltinHintIds::BitOr => binop_hint_on_u256(inputs, |x, y| x | y), - BuiltinHintIds::BitAnd => binop_hint_on_u256(inputs, |x, y| x & y), - BuiltinHintIds::BitXor => binop_hint_on_u256(inputs, |x, y| x ^ y), - BuiltinHintIds::Select => { - let mut outputs = Vec::with_capacity(num_outputs); - outputs.push(if !inputs[0].is_zero() { - inputs[1] - } else { - inputs[2] - }); - outputs - } - BuiltinHintIds::Pow => binop_hint(inputs, |x, y| { - let mut t = x; - let mut res = F::one(); - let mut y: U256 = y.to_u256(); - while y != U256::ZERO { - if y & U256::from(1u32) != U256::ZERO { - res *= t; - } - y >>= 1; - t = t * t; - } - res - }), - BuiltinHintIds::IntDiv => { - binop_hint_on_u256( - inputs, - |x, y| if y == U256::ZERO { U256::ZERO } else { x / y }, - ) - } - BuiltinHintIds::Mod => { - binop_hint_on_u256( - inputs, - |x, y| if y == U256::ZERO { U256::ZERO } else { x % y }, - ) - } - BuiltinHintIds::ShiftL => binop_hint_on_u256(inputs, |x, y| circom_shift_l_impl::(x, y)), - BuiltinHintIds::ShiftR => binop_hint_on_u256(inputs, |x, y| circom_shift_r_impl::(x, y)), - BuiltinHintIds::LesserEq => binop_hint(inputs, |x, y| F::from((x <= y) as u32)), - BuiltinHintIds::GreaterEq => binop_hint(inputs, |x, y| F::from((x >= y) as u32)), - BuiltinHintIds::Lesser => binop_hint(inputs, |x, y| F::from((x < y) as u32)), - BuiltinHintIds::Greater => binop_hint(inputs, |x, y| F::from((x > y) as u32)), - } -} - -fn binop_hint F>(inputs: &[F], f: G) -> Vec { - vec![f(inputs[0], inputs[1])] -} - -fn binop_hint_on_u256 U256>(inputs: &[F], f: G) -> Vec { - let x_u256: U256 = inputs[0].to_u256(); - let y_u256: U256 = inputs[1].to_u256(); - let z_u256 = f(x_u256, y_u256); - vec![F::from_u256(z_u256)] -} - -pub fn stub_impl(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { +) -> Result, Error> { match BuiltinHintIds::from_usize(hint_id) { - Some(hint_id) => impl_builtin_hint(hint_id, inputs, num_outputs), - None => stub_impl_general(hint_id, inputs, num_outputs), - } -} - -pub fn random_builtin(mut rand: impl RngCore) -> (usize, usize, usize) { - loop { - let hint_id = (rand.next_u64() as usize % 100) + (BuiltinHintIds::Identity as u64 as usize); - if let Some(hint_id) = BuiltinHintIds::from_usize(hint_id) { - match hint_id { - BuiltinHintIds::Identity => { - let num_inputs = (rand.next_u64() % 10) as usize + 1; - let num_outputs = num_inputs; - return (hint_id as usize, num_inputs, num_outputs); - } - BuiltinHintIds::Div - | BuiltinHintIds::Eq - | BuiltinHintIds::NotEq - | BuiltinHintIds::BoolOr - | BuiltinHintIds::BoolAnd - | BuiltinHintIds::BitOr - | BuiltinHintIds::BitAnd - | BuiltinHintIds::BitXor - | BuiltinHintIds::Pow - | BuiltinHintIds::IntDiv - | BuiltinHintIds::Mod - | BuiltinHintIds::ShiftL - | BuiltinHintIds::ShiftR - | BuiltinHintIds::LesserEq - | BuiltinHintIds::GreaterEq - | BuiltinHintIds::Lesser - | BuiltinHintIds::Greater => { - return (hint_id as usize, 2, 1); - } - BuiltinHintIds::Select => { - return (hint_id as usize, 3, 1); - } - } - } - } -} - -pub fn u256_bit_length(x: U256) -> usize { - 256 - x.leading_zeros() as usize -} - -pub fn circom_shift_l_impl(x: U256, k: U256) -> U256 { - let top = F::modulus() / 2; - if k <= top { - let shift = if (k >> U256::from(64u32)) == U256::ZERO { - k.as_u64() as usize - } else { - u256_bit_length(F::modulus()) - }; - if shift >= 256 { - return U256::ZERO; - } - let value = x << shift; - let mask = U256::from(1u32) << u256_bit_length(F::modulus()); - let mask = mask - 1; - value & mask - } else { - circom_shift_r_impl::(x, F::modulus() - k) - } -} - -pub fn circom_shift_r_impl(x: U256, k: U256) -> U256 { - let top = F::modulus() / 2; - if k <= top { - let shift = if (k >> U256::from(64u32)) == U256::ZERO { - k.as_u64() as usize - } else { - u256_bit_length(F::modulus()) - }; - if shift >= 256 { - return U256::ZERO; - } - x >> shift - } else { - circom_shift_l_impl::(x, F::modulus() - k) + Some(hint_id) => Ok(impl_builtin_hint(hint_id, inputs, num_outputs)), + None => hint_registry.call(hint_id, inputs, num_outputs), } } diff --git a/expander_compiler/src/hints/registry.rs b/expander_compiler/src/hints/registry.rs new file mode 100644 index 0000000..4a3c72f --- /dev/null +++ b/expander_compiler/src/hints/registry.rs @@ -0,0 +1,53 @@ +use std::collections::HashMap; + +use tiny_keccak::Hasher; + +use crate::{field::Field, utils::error::Error}; + +use super::BuiltinHintIds; + +pub type HintFn = dyn FnMut(&[F], &mut [F]) -> Result<(), Error>; + +pub struct HintRegistry { + hints: HashMap>>, +} + +pub fn hint_key_to_id(key: &str) -> usize { + let mut hasher = tiny_keccak::Keccak::v256(); + hasher.update(key.as_bytes()); + let mut hash = [0u8; 32]; + hasher.finalize(&mut hash); + + let res = usize::from_le_bytes(hash[0..8].try_into().unwrap()); + if BuiltinHintIds::from_usize(res).is_some() { + panic!("Hint id {} collides with a builtin hint id", res); + } + res +} + +impl HintRegistry { + pub fn new() -> Self { + Self { + hints: HashMap::new(), + } + } + pub fn register Result<(), Error> + 'static>( + &mut self, + key: &str, + hint: Hint, + ) { + let id = hint_key_to_id(key); + if self.hints.contains_key(&id) { + panic!("Hint with id {} already exists", id); + } + self.hints.insert(id, Box::new(hint)); + } + pub fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error> { + if let Some(hint) = self.hints.get_mut(&id) { + let mut outputs = vec![F::zero(); num_outputs]; + hint(args, &mut outputs).map(|_| outputs) + } else { + panic!("Hint with id {} not found", id); + } + } +} diff --git a/expander_compiler/tests/to_binary_hint.rs b/expander_compiler/tests/to_binary_hint.rs new file mode 100644 index 0000000..fb7700f --- /dev/null +++ b/expander_compiler/tests/to_binary_hint.rs @@ -0,0 +1,89 @@ +use std::cell::RefCell; +use std::rc::Rc; + +use expander_compiler::frontend::*; + +declare_circuit!(Circuit { + input: PublicVariable, +}); + +fn to_binary(api: &mut API, x: Variable, n_bits: usize) -> Vec { + api.new_hint("myhint.tobinary", &vec![x], n_bits) +} + +fn from_binary(api: &mut API, bits: Vec) -> Variable { + let mut res = api.constant(0); + for i in 0..bits.len() { + let coef = 1 << i; + let cur = api.mul(coef, bits[i]); + res = api.add(res, cur); + } + res +} + +impl Define for Circuit { + fn define(&self, builder: &mut API) { + let bits = to_binary(builder, self.input, 8); + let x = from_binary(builder, bits); + builder.assert_is_equal(x, self.input); + } +} + +fn to_binary_hint(x: &[M31], y: &mut [M31]) -> Result<(), Error> { + let t = x[0].to_u256(); + for (i, k) in y.iter_mut().enumerate() { + *k = M31::from_u256(t >> i as u32 & 1); + } + Ok(()) +} + +#[test] +fn test_300() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + + let compile_result = compile(&Circuit::default()).unwrap(); + for i in 0..300 { + let assignment = Circuit:: { + input: M31::from(i as u32), + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![i < 256]); + } +} + +#[test] +fn test_300_closure() { + let mut hint_registry = HintRegistry::::new(); + let call_count = Rc::new(RefCell::new(0)); + let call_count_clone = call_count.clone(); + hint_registry.register( + "myhint.tobinary", + move |x: &[M31], y: &mut [M31]| -> Result<(), Error> { + *call_count_clone.borrow_mut() += 1; + let t = x[0].to_u256(); + for (i, k) in y.iter_mut().enumerate() { + *k = M31::from_u256(t >> i as u32 & 1); + } + Ok(()) + }, + ); + + let compile_result = compile(&Circuit::default()).unwrap(); + for i in 0..300 { + let assignment = Circuit:: { + input: M31::from(i as u32), + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![i < 256]); + } + assert_eq!(*call_count.borrow(), 300); +}