Skip to content

Commit

Permalink
hints
Browse files Browse the repository at this point in the history
  • Loading branch information
siq1 committed Dec 10, 2024
1 parent 7f7c612 commit b506cdd
Show file tree
Hide file tree
Showing 10 changed files with 587 additions and 347 deletions.
66 changes: 45 additions & 21 deletions expander_compiler/src/circuit/ir/hint_normalized/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;

use crate::field::FieldArith;
use crate::hints::registry::HintRegistry;
use crate::utils::error::Error;
use crate::{
circuit::{
Expand Down Expand Up @@ -201,6 +202,42 @@ impl<C: Config> common::Instruction<C> for Instruction<C> {
}
}

impl<C: Config> Instruction<C> {
fn eval_safe(
&self,
values: &[C::CircuitField],
public_inputs: &[C::CircuitField],
hint_registry: &mut HintRegistry<C::CircuitField>,
) -> EvalResult<C> {
if let Instruction::ConstantLike(coef) = self {
return match coef {
Coef::Constant(c) => EvalResult::Value(c.clone()),
Coef::PublicInput(i) => EvalResult::Value(public_inputs[*i]),
Coef::Random => EvalResult::Error(Error::UserError(
"random coef occured in witness solver".to_string(),
)),
};
}
if let Instruction::Hint {
hint_id,
inputs,
num_outputs,
} = self
{
return match hints::safe_impl(
hint_registry,
*hint_id,
&inputs.iter().map(|i| values[*i]).collect(),
*num_outputs,
) {
Ok(outputs) => EvalResult::Values(outputs),
Err(e) => EvalResult::Error(e),
};
}
self.eval_unsafe(values)
}
}

pub type Circuit<C> = common::Circuit<Irc<C>>;
pub type RootCircuit<C> = common::RootCircuit<Irc<C>>;

Expand Down Expand Up @@ -443,52 +480,39 @@ impl<C: Config> RootCircuit<C> {
self.circuits.insert(0, c0);
}

pub fn eval_with_public_inputs(
pub fn eval_safe(
&self,
inputs: Vec<C::CircuitField>,
public_inputs: &[C::CircuitField],
hint_registry: &mut HintRegistry<C::CircuitField>,
) -> Result<Vec<C::CircuitField>, Error> {
assert_eq!(inputs.len(), self.input_size());
self.eval_sub_with_public_inputs(&self.circuits[&0], inputs, public_inputs)
self.eval_sub_safe(&self.circuits[&0], inputs, public_inputs, hint_registry)
}

fn eval_sub_with_public_inputs(
fn eval_sub_safe(
&self,
circuit: &Circuit<C>,
inputs: Vec<C::CircuitField>,
public_inputs: &[C::CircuitField],
hint_registry: &mut HintRegistry<C::CircuitField>,
) -> Result<Vec<C::CircuitField>, Error> {
let mut values = vec![C::CircuitField::zero(); 1];
values.extend(inputs);
for insn in circuit.instructions.iter() {
if let Instruction::ConstantLike(coef) = insn {
match coef {
Coef::Constant(c) => {
values.push(*c);
}
Coef::PublicInput(i) => {
values.push(public_inputs[*i]);
}
Coef::Random => {
return Err(Error::UserError(
"random coef occured in witness solver".to_string(),
));
}
}
continue;
}
match insn.eval_unsafe(&values) {
match insn.eval_safe(&values, public_inputs, hint_registry) {
EvalResult::Value(v) => {
values.push(v);
}
EvalResult::Values(mut vs) => {
values.append(&mut vs);
}
EvalResult::SubCircuitCall(sub_circuit_id, inputs) => {
let res = self.eval_sub_with_public_inputs(
let res = self.eval_sub_safe(
&self.circuits[&sub_circuit_id],
inputs.iter().map(|&i| values[i]).collect(),
public_inputs,
hint_registry,
)?;
values.extend(res);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ impl<C: Config> WitnessSolver<C> {
&self,
vars: Vec<C::CircuitField>,
public_vars: Vec<C::CircuitField>,
hint_registry: &mut HintRegistry<C::CircuitField>,
) -> Result<(Vec<C::CircuitField>, usize), Error> {
assert_eq!(vars.len(), self.circuit.input_size());
assert_eq!(public_vars.len(), self.circuit.num_public_inputs);
let mut a = self.circuit.eval_with_public_inputs(vars, &public_vars)?;
let mut a = self.circuit.eval_safe(vars, &public_vars, hint_registry)?;
let res_len = a.len();
a.extend(public_vars);
Ok((a, res_len))
Expand All @@ -24,8 +25,10 @@ impl<C: Config> WitnessSolver<C> {
&self,
vars: Vec<C::CircuitField>,
public_vars: Vec<C::CircuitField>,
hint_registry: &mut HintRegistry<C::CircuitField>,
) -> Result<Witness<C>, Error> {
let (values, num_inputs_per_witness) = self.solve_witness_inner(vars, public_vars)?;
let (values, num_inputs_per_witness) =
self.solve_witness_inner(vars, public_vars, hint_registry)?;
Ok(Witness {
num_witnesses: 1,
num_inputs_per_witness,
Expand All @@ -40,12 +43,13 @@ impl<C: Config> WitnessSolver<C> {
&self,
num_witnesses: usize,
f: F,
hint_registry: &mut HintRegistry<C::CircuitField>,
) -> Result<Witness<C>, Error> {
let mut values = Vec::new();
let mut num_inputs_per_witness = 0;
for i in 0..num_witnesses {
let (a, b) = f(i);
let (a, num) = self.solve_witness_inner(a, b)?;
let (a, num) = self.solve_witness_inner(a, b, hint_registry)?;
values.extend(a);
num_inputs_per_witness = num;
}
Expand Down
6 changes: 6 additions & 0 deletions expander_compiler/src/frontend/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ pub trait BasicAPI<C: Config> {
y: impl ToVariableOrValue<C::CircuitField>,
);
fn get_random_value(&mut self) -> Variable;
fn new_hint(
&mut self,
hint_key: &str,
inputs: &[Variable],
num_outputs: usize,
) -> Vec<Variable>;
}

pub trait UnconstrainedAPI<C: Config> {
Expand Down
25 changes: 24 additions & 1 deletion expander_compiler/src/frontend/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
layered::Coef,
},
field::{Field, FieldArith},
hints,
hints::{self, registry::hint_key_to_id},
utils::function_id::get_function_id,
};

Expand Down Expand Up @@ -292,6 +292,20 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
.push(SourceInstruction::ConstantLike(Coef::Random));
self.new_var()
}

fn new_hint(
&mut self,
hint_key: &str,
inputs: &[Variable],
num_outputs: usize,
) -> Vec<Variable> {
self.instructions.push(SourceInstruction::Hint {
hint_id: hint_key_to_id(hint_key),
inputs: inputs.iter().map(|v| v.id).collect(),
num_outputs,
});
(0..num_outputs).map(|_| self.new_var()).collect()
}
}

// write macro rules for unconstrained binary op definition
Expand Down Expand Up @@ -445,6 +459,15 @@ impl<C: Config> BasicAPI<C> for RootBuilder<C> {
fn get_random_value(&mut self) -> Variable {
self.last_builder().get_random_value()
}

fn new_hint(
&mut self,
hint_key: &str,
inputs: &[Variable],
num_outputs: usize,
) -> Vec<Variable> {
self.last_builder().new_hint(hint_key, inputs, num_outputs)
}
}

impl<C: Config> RootBuilder<C> {
Expand Down
3 changes: 2 additions & 1 deletion expander_compiler/src/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ mod witness;
pub use circuit::declare_circuit;
pub type API<C> = builder::RootBuilder<C>;
pub use crate::circuit::config::*;
pub use crate::field::{Field, BN254, GF2, M31};
pub use crate::field::{Field, FieldArith, FieldModulus, BN254, GF2, M31};
pub use crate::hints::registry::HintRegistry;
pub use crate::utils::error::Error;
pub use api::BasicAPI;
pub use builder::Variable;
Expand Down
36 changes: 28 additions & 8 deletions expander_compiler/src/frontend/witness.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,48 @@
pub use crate::circuit::ir::hint_normalized::witness_solver::WitnessSolver;
use crate::circuit::layered::witness::Witness;
use crate::{circuit::layered::witness::Witness, hints::registry::HintRegistry};

use super::{internal, Config, Error};

impl<C: Config> WitnessSolver<C> {
pub fn solve_witness<Cir: internal::DumpLoadTwoVariables<C::CircuitField>>(
&self,
assignment: &Cir,
) -> Result<Witness<C>, Error> {
self.solve_witness_with_hints(assignment, &mut HintRegistry::new())
}

pub fn solve_witness_with_hints<Cir: internal::DumpLoadTwoVariables<C::CircuitField>>(
&self,
assignment: &Cir,
hint_registry: &mut HintRegistry<C::CircuitField>,
) -> Result<Witness<C>, Error> {
let mut vars = Vec::new();
let mut public_vars = Vec::new();
assignment.dump_into(&mut vars, &mut public_vars);
self.solve_witness_from_raw_inputs(vars, public_vars)
self.solve_witness_from_raw_inputs(vars, public_vars, hint_registry)
}

pub fn solve_witnesses<Cir: internal::DumpLoadTwoVariables<C::CircuitField>>(
&self,
assignments: &[Cir],
) -> Result<Witness<C>, Error> {
self.solve_witnesses_from_raw_inputs(assignments.len(), |i| {
let mut vars = Vec::new();
let mut public_vars = Vec::new();
assignments[i].dump_into(&mut vars, &mut public_vars);
(vars, public_vars)
})
self.solve_witnesses_with_hints(assignments, &mut HintRegistry::new())
}

pub fn solve_witnesses_with_hints<Cir: internal::DumpLoadTwoVariables<C::CircuitField>>(
&self,
assignments: &[Cir],
hint_registry: &mut HintRegistry<C::CircuitField>,
) -> Result<Witness<C>, Error> {
self.solve_witnesses_from_raw_inputs(
assignments.len(),
|i| {
let mut vars = Vec::new();
let mut public_vars = Vec::new();
assignments[i].dump_into(&mut vars, &mut public_vars);
(vars, public_vars)
},
hint_registry,
)
}
}
Loading

0 comments on commit b506cdd

Please sign in to comment.