Skip to content

Commit

Permalink
Rust const variables (#60)
Browse files Browse the repository at this point in the history
* rust const variables

* clippy
  • Loading branch information
siq1 authored Jan 8, 2025
1 parent 1caf607 commit 385ca25
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 8 deletions.
6 changes: 6 additions & 0 deletions expander_compiler/src/frontend/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ pub trait BasicAPI<C: Config> {
num_outputs: usize,
) -> Vec<Variable>;
fn constant(&mut self, x: impl ToVariableOrValue<C::CircuitField>) -> Variable;
// try to get the value of a compile-time constant variable
// this function has different behavior in normal and debug mode, in debug mode it always returns Some(value)
fn constant_value(
&mut self,
x: impl ToVariableOrValue<C::CircuitField>,
) -> Option<C::CircuitField>;
}

pub trait UnconstrainedAPI<C: Config> {
Expand Down
146 changes: 138 additions & 8 deletions expander_compiler/src/frontend/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ use super::api::{BasicAPI, DebugAPI, RootAPI, UnconstrainedAPI};
pub struct Builder<C: Config> {
instructions: Vec<SourceInstruction<C>>,
constraints: Vec<SourceConstraint>,
var_max: usize,
var_const_id: Vec<usize>,
const_values: Vec<C::CircuitField>,
num_inputs: usize,
}

Expand All @@ -31,6 +32,12 @@ pub struct Variable {
id: usize,
}

impl Variable {
pub fn id(&self) -> usize {
self.id
}
}

pub fn new_variable(id: usize) -> Variable {
Variable { id }
}
Expand All @@ -44,7 +51,7 @@ pub enum VariableOrValue<F: Field> {
Value(F),
}

pub trait ToVariableOrValue<F: Field> {
pub trait ToVariableOrValue<F: Field>: Clone {
fn convert_to_variable_or_value(self) -> VariableOrValue<F>;
}

Expand All @@ -53,7 +60,7 @@ impl NotVariable for u32 {}
impl NotVariable for U256 {}
impl<F: Field> NotVariable for F {}

impl<F: Field, T: Into<F> + NotVariable> ToVariableOrValue<F> for T {
impl<F: Field, T: Into<F> + NotVariable + Clone> ToVariableOrValue<F> for T {
fn convert_to_variable_or_value(self) -> VariableOrValue<F> {
VariableOrValue::Value(self.into())
}
Expand All @@ -77,8 +84,9 @@ impl<C: Config> Builder<C> {
Builder {
instructions: Vec::new(),
constraints: Vec::new(),
var_max: num_inputs,
num_inputs,
var_const_id: vec![0; num_inputs + 1],
const_values: vec![C::CircuitField::zero()],
},
(1..=num_inputs).map(|id| Variable { id }).collect(),
)
Expand All @@ -99,15 +107,20 @@ impl<C: Config> Builder<C> {
VariableOrValue::Value(v) => {
self.instructions
.push(SourceInstruction::ConstantLike(Coef::Constant(v)));
self.var_max += 1;
Variable { id: self.var_max }
self.var_const_id.push(self.const_values.len());
self.const_values.push(v);
Variable {
id: self.var_const_id.len() - 1,
}
}
}
}

fn new_var(&mut self) -> Variable {
self.var_max += 1;
Variable { id: self.var_max }
self.var_const_id.push(0);
Variable {
id: self.var_const_id.len() - 1,
}
}
}

Expand All @@ -117,6 +130,13 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
x: impl ToVariableOrValue<C::CircuitField>,
y: impl ToVariableOrValue<C::CircuitField>,
) -> Variable {
let xc = self.constant_value(x.clone());
let yc = self.constant_value(y.clone());
if let Some(xv) = xc {
if let Some(yv) = yc {
return self.constant(xv + yv);
}
}
let x = self.convert_to_variable(x);
let y = self.convert_to_variable(y);
self.instructions.push(SourceInstruction::LinComb(LinComb {
Expand All @@ -140,6 +160,13 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
x: impl ToVariableOrValue<C::CircuitField>,
y: impl ToVariableOrValue<C::CircuitField>,
) -> Variable {
let xc = self.constant_value(x.clone());
let yc = self.constant_value(y.clone());
if let Some(xv) = xc {
if let Some(yv) = yc {
return self.constant(xv - yv);
}
}
let x = self.convert_to_variable(x);
let y = self.convert_to_variable(y);
self.instructions.push(SourceInstruction::LinComb(LinComb {
Expand All @@ -159,6 +186,10 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
}

fn neg(&mut self, x: impl ToVariableOrValue<C::CircuitField>) -> Variable {
let xc = self.constant_value(x.clone());
if let Some(xv) = xc {
return self.constant(-xv);
}
let x = self.convert_to_variable(x);
self.instructions.push(SourceInstruction::LinComb(LinComb {
terms: vec![LinCombTerm {
Expand All @@ -175,6 +206,13 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
x: impl ToVariableOrValue<C::CircuitField>,
y: impl ToVariableOrValue<C::CircuitField>,
) -> Variable {
let xc = self.constant_value(x.clone());
let yc = self.constant_value(y.clone());
if let Some(xv) = xc {
if let Some(yv) = yc {
return self.constant(xv * yv);
}
}
let x = self.convert_to_variable(x);
let y = self.convert_to_variable(y);
self.instructions
Expand All @@ -188,6 +226,21 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
y: impl ToVariableOrValue<C::CircuitField>,
checked: bool,
) -> Variable {
let xc = self.constant_value(x.clone());
let yc = self.constant_value(y.clone());
if let Some(xv) = xc {
if let Some(yv) = yc {
let res = if yv.is_zero() {
if checked || !xv.is_zero() {
panic!("division by zero");
}
C::CircuitField::zero()
} else {
xv * yv.inv().unwrap()
};
return self.constant(res);
}
}
let x = self.convert_to_variable(x);
let y = self.convert_to_variable(y);
self.instructions.push(SourceInstruction::Div {
Expand All @@ -203,6 +256,15 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
x: impl ToVariableOrValue<C::CircuitField>,
y: impl ToVariableOrValue<C::CircuitField>,
) -> Variable {
let xc = self.constant_value(x.clone());
let yc = self.constant_value(y.clone());
if let Some(xv) = xc {
if let Some(yv) = yc {
self.assert_is_bool(xv);
self.assert_is_bool(yv);
return self.constant(C::CircuitField::from((xv != yv) as u32));
}
}
let x = self.convert_to_variable(x);
let y = self.convert_to_variable(y);
self.instructions.push(SourceInstruction::BoolBinOp {
Expand All @@ -218,6 +280,17 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
x: impl ToVariableOrValue<C::CircuitField>,
y: impl ToVariableOrValue<C::CircuitField>,
) -> Variable {
let xc = self.constant_value(x.clone());
let yc = self.constant_value(y.clone());
if let Some(xv) = xc {
if let Some(yv) = yc {
self.assert_is_bool(xv);
self.assert_is_bool(yv);
return self.constant(C::CircuitField::from(
(!xv.is_zero() || !yv.is_zero()) as u32,
));
}
}
let x = self.convert_to_variable(x);
let y = self.convert_to_variable(y);
self.instructions.push(SourceInstruction::BoolBinOp {
Expand All @@ -233,6 +306,17 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
x: impl ToVariableOrValue<C::CircuitField>,
y: impl ToVariableOrValue<C::CircuitField>,
) -> Variable {
let xc = self.constant_value(x.clone());
let yc = self.constant_value(y.clone());
if let Some(xv) = xc {
if let Some(yv) = yc {
self.assert_is_bool(xv);
self.assert_is_bool(yv);
return self.constant(C::CircuitField::from(
(!xv.is_zero() && !yv.is_zero()) as u32,
));
}
}
let x = self.convert_to_variable(x);
let y = self.convert_to_variable(y);
self.instructions.push(SourceInstruction::BoolBinOp {
Expand All @@ -244,12 +328,22 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
}

fn is_zero(&mut self, x: impl ToVariableOrValue<C::CircuitField>) -> Variable {
let xc = self.constant_value(x.clone());
if let Some(xv) = xc {
return self.constant(C::CircuitField::from(xv.is_zero() as u32));
}
let x = self.convert_to_variable(x);
self.instructions.push(SourceInstruction::IsZero(x.id));
self.new_var()
}

fn assert_is_zero(&mut self, x: impl ToVariableOrValue<C::CircuitField>) {
let xc = self.constant_value(x.clone());
if let Some(xv) = xc {
if !xv.is_zero() {
panic!("assert_is_zero failed");
}
}
let x = self.convert_to_variable(x);
self.constraints.push(SourceConstraint {
typ: source::ConstraintType::Zero,
Expand All @@ -258,6 +352,12 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
}

fn assert_is_non_zero(&mut self, x: impl ToVariableOrValue<C::CircuitField>) {
let xc = self.constant_value(x.clone());
if let Some(xv) = xc {
if xv.is_zero() {
panic!("assert_is_zero failed");
}
}
let x = self.convert_to_variable(x);
self.constraints.push(SourceConstraint {
typ: source::ConstraintType::NonZero,
Expand All @@ -266,6 +366,12 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
}

fn assert_is_bool(&mut self, x: impl ToVariableOrValue<C::CircuitField>) {
let xc = self.constant_value(x.clone());
if let Some(xv) = xc {
if !xv.is_zero() && xv != C::CircuitField::one() {
panic!("assert_is_bool failed");
}
}
let x = self.convert_to_variable(x);
self.constraints.push(SourceConstraint {
typ: source::ConstraintType::Bool,
Expand Down Expand Up @@ -296,6 +402,23 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
fn constant(&mut self, value: impl ToVariableOrValue<C::CircuitField>) -> Variable {
self.convert_to_variable(value)
}

fn constant_value(
&mut self,
x: impl ToVariableOrValue<<C as Config>::CircuitField>,
) -> Option<<C as Config>::CircuitField> {
match x.convert_to_variable_or_value() {
VariableOrValue::Variable(v) => {
let t = self.var_const_id[v.id];
if t != 0 {
Some(self.const_values[t])
} else {
None
}
}
VariableOrValue::Value(v) => Some(v),
}
}
}

// write macro rules for unconstrained binary op definition
Expand Down Expand Up @@ -442,6 +565,13 @@ impl<C: Config> BasicAPI<C> for RootBuilder<C> {
fn constant(&mut self, x: impl ToVariableOrValue<<C as Config>::CircuitField>) -> Variable {
self.last_builder().constant(x)
}

fn constant_value(
&mut self,
x: impl ToVariableOrValue<<C as Config>::CircuitField>,
) -> Option<<C as Config>::CircuitField> {
self.last_builder().constant_value(x)
}
}

impl<C: Config> RootAPI<C> for RootBuilder<C> {
Expand Down
6 changes: 6 additions & 0 deletions expander_compiler/src/frontend/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ impl<C: Config, H: HintCaller<C::CircuitField>> BasicAPI<C> for DebugBuilder<C,
let x = self.convert_to_value(x);
self.return_as_variable(x)
}
fn constant_value(
&mut self,
x: impl ToVariableOrValue<<C as Config>::CircuitField>,
) -> Option<<C as Config>::CircuitField> {
Some(self.convert_to_value(x))
}
}

impl<C: Config, H: HintCaller<C::CircuitField>> UnconstrainedAPI<C> for DebugBuilder<C, H> {
Expand Down

0 comments on commit 385ca25

Please sign in to comment.