From 385ca25405c692420c4d8aabf4e78745ccca839b Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Thu, 9 Jan 2025 03:35:32 +0700 Subject: [PATCH] Rust const variables (#60) * rust const variables * clippy --- expander_compiler/src/frontend/api.rs | 6 + expander_compiler/src/frontend/builder.rs | 146 ++++++++++++++++++++-- expander_compiler/src/frontend/debug.rs | 6 + 3 files changed, 150 insertions(+), 8 deletions(-) diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index c66ffb2..75c7549 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -57,6 +57,12 @@ pub trait BasicAPI { num_outputs: usize, ) -> Vec; fn constant(&mut self, x: impl ToVariableOrValue) -> Variable; + // try to get the value of a compile-time constant variable + // this function has different behavior in normal and debug mode, in debug mode it always returns Some(value) + fn constant_value( + &mut self, + x: impl ToVariableOrValue, + ) -> Option; } pub trait UnconstrainedAPI { diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index bf1a435..b6918e8 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -22,7 +22,8 @@ use super::api::{BasicAPI, DebugAPI, RootAPI, UnconstrainedAPI}; pub struct Builder { instructions: Vec>, constraints: Vec, - var_max: usize, + var_const_id: Vec, + const_values: Vec, num_inputs: usize, } @@ -31,6 +32,12 @@ pub struct Variable { id: usize, } +impl Variable { + pub fn id(&self) -> usize { + self.id + } +} + pub fn new_variable(id: usize) -> Variable { Variable { id } } @@ -44,7 +51,7 @@ pub enum VariableOrValue { Value(F), } -pub trait ToVariableOrValue { +pub trait ToVariableOrValue: Clone { fn convert_to_variable_or_value(self) -> VariableOrValue; } @@ -53,7 +60,7 @@ impl NotVariable for u32 {} impl NotVariable for U256 {} impl NotVariable for F {} -impl + NotVariable> ToVariableOrValue for T { +impl + NotVariable + Clone> ToVariableOrValue for T { fn convert_to_variable_or_value(self) -> VariableOrValue { VariableOrValue::Value(self.into()) } @@ -77,8 +84,9 @@ impl Builder { Builder { instructions: Vec::new(), constraints: Vec::new(), - var_max: num_inputs, num_inputs, + var_const_id: vec![0; num_inputs + 1], + const_values: vec![C::CircuitField::zero()], }, (1..=num_inputs).map(|id| Variable { id }).collect(), ) @@ -99,15 +107,20 @@ impl Builder { VariableOrValue::Value(v) => { self.instructions .push(SourceInstruction::ConstantLike(Coef::Constant(v))); - self.var_max += 1; - Variable { id: self.var_max } + self.var_const_id.push(self.const_values.len()); + self.const_values.push(v); + Variable { + id: self.var_const_id.len() - 1, + } } } } fn new_var(&mut self) -> Variable { - self.var_max += 1; - Variable { id: self.var_max } + self.var_const_id.push(0); + Variable { + id: self.var_const_id.len() - 1, + } } } @@ -117,6 +130,13 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + return self.constant(xv + yv); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::LinComb(LinComb { @@ -140,6 +160,13 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + return self.constant(xv - yv); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::LinComb(LinComb { @@ -159,6 +186,10 @@ impl BasicAPI for Builder { } fn neg(&mut self, x: impl ToVariableOrValue) -> Variable { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + return self.constant(-xv); + } let x = self.convert_to_variable(x); self.instructions.push(SourceInstruction::LinComb(LinComb { terms: vec![LinCombTerm { @@ -175,6 +206,13 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + return self.constant(xv * yv); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions @@ -188,6 +226,21 @@ impl BasicAPI for Builder { y: impl ToVariableOrValue, checked: bool, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + let res = if yv.is_zero() { + if checked || !xv.is_zero() { + panic!("division by zero"); + } + C::CircuitField::zero() + } else { + xv * yv.inv().unwrap() + }; + return self.constant(res); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::Div { @@ -203,6 +256,15 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + self.assert_is_bool(xv); + self.assert_is_bool(yv); + return self.constant(C::CircuitField::from((xv != yv) as u32)); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::BoolBinOp { @@ -218,6 +280,17 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + self.assert_is_bool(xv); + self.assert_is_bool(yv); + return self.constant(C::CircuitField::from( + (!xv.is_zero() || !yv.is_zero()) as u32, + )); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::BoolBinOp { @@ -233,6 +306,17 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + self.assert_is_bool(xv); + self.assert_is_bool(yv); + return self.constant(C::CircuitField::from( + (!xv.is_zero() && !yv.is_zero()) as u32, + )); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::BoolBinOp { @@ -244,12 +328,22 @@ impl BasicAPI for Builder { } fn is_zero(&mut self, x: impl ToVariableOrValue) -> Variable { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + return self.constant(C::CircuitField::from(xv.is_zero() as u32)); + } let x = self.convert_to_variable(x); self.instructions.push(SourceInstruction::IsZero(x.id)); self.new_var() } fn assert_is_zero(&mut self, x: impl ToVariableOrValue) { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + if !xv.is_zero() { + panic!("assert_is_zero failed"); + } + } let x = self.convert_to_variable(x); self.constraints.push(SourceConstraint { typ: source::ConstraintType::Zero, @@ -258,6 +352,12 @@ impl BasicAPI for Builder { } fn assert_is_non_zero(&mut self, x: impl ToVariableOrValue) { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + if xv.is_zero() { + panic!("assert_is_zero failed"); + } + } let x = self.convert_to_variable(x); self.constraints.push(SourceConstraint { typ: source::ConstraintType::NonZero, @@ -266,6 +366,12 @@ impl BasicAPI for Builder { } fn assert_is_bool(&mut self, x: impl ToVariableOrValue) { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + if !xv.is_zero() && xv != C::CircuitField::one() { + panic!("assert_is_bool failed"); + } + } let x = self.convert_to_variable(x); self.constraints.push(SourceConstraint { typ: source::ConstraintType::Bool, @@ -296,6 +402,23 @@ impl BasicAPI for Builder { fn constant(&mut self, value: impl ToVariableOrValue) -> Variable { self.convert_to_variable(value) } + + fn constant_value( + &mut self, + x: impl ToVariableOrValue<::CircuitField>, + ) -> Option<::CircuitField> { + match x.convert_to_variable_or_value() { + VariableOrValue::Variable(v) => { + let t = self.var_const_id[v.id]; + if t != 0 { + Some(self.const_values[t]) + } else { + None + } + } + VariableOrValue::Value(v) => Some(v), + } + } } // write macro rules for unconstrained binary op definition @@ -442,6 +565,13 @@ impl BasicAPI for RootBuilder { fn constant(&mut self, x: impl ToVariableOrValue<::CircuitField>) -> Variable { self.last_builder().constant(x) } + + fn constant_value( + &mut self, + x: impl ToVariableOrValue<::CircuitField>, + ) -> Option<::CircuitField> { + self.last_builder().constant_value(x) + } } impl RootAPI for RootBuilder { diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs index 2020a52..ccffe8b 100644 --- a/expander_compiler/src/frontend/debug.rs +++ b/expander_compiler/src/frontend/debug.rs @@ -145,6 +145,12 @@ impl> BasicAPI for DebugBuilder::CircuitField>, + ) -> Option<::CircuitField> { + Some(self.convert_to_value(x)) + } } impl> UnconstrainedAPI for DebugBuilder {