Skip to content

Commit

Permalink
Simplify expressions.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Nov 24, 2024
1 parent e350aa7 commit 56185a3
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 14 deletions.
101 changes: 94 additions & 7 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ pub enum Expr {
}

impl Expr {
#[allow(dead_code)]
pub fn format_expr(&self) -> String {
match self {
Expr::Col(ColumnExpr {
Expand Down Expand Up @@ -67,6 +66,10 @@ impl Expr {
Expr::Inv(a) => format!("1 / ({})", a.format_expr()),
}
}

pub fn simplify_and_format(&self) -> String {
simplify(self.clone()).format_expr()
}
}

impl From<BaseField> for Expr {
Expand Down Expand Up @@ -183,6 +186,91 @@ impl AddAssign<BaseField> for Expr {
}
}

pub fn simplify(expr: Expr) -> Expr {
match expr {
Expr::Add(a, b) => {
let a = simplify(*a);
let b = simplify(*b);
if let (Expr::Const(a), Expr::Const(b)) = (a.clone(), b.clone()) {
Expr::Const(a + b)
} else if a == Expr::zero() {
b
} else if b == Expr::zero() {
a
} else if let Expr::Neg(a) = a {
if let Expr::Neg(b) = b {
-(*a + *b)
} else {
b - *a
}
} else if let Expr::Neg(b) = b {
a - *b
} else {
a + b
}
}
Expr::Sub(a, b) => {
let a = simplify(*a);
let b = simplify(*b);
if a == Expr::zero() {
-b
} else if b == Expr::zero() {
a
} else if a == b {
Expr::zero()
} else {
a - b
}
}
Expr::Mul(a, b) => {
let a = simplify(*a);
let b = simplify(*b);
let minus_one = Expr::Const(-BaseField::one());
if let (Expr::Const(a), Expr::Const(b)) = (a.clone(), b.clone()) {
Expr::Const(a * b)
} else if a == Expr::zero() || b == Expr::zero() {
Expr::zero()
} else if a == Expr::one() {
b
} else if b == Expr::one() {
a
} else if a == minus_one {
-b
} else if b == minus_one {
-a
} else {
a * b
}
}
Expr::Col(colexpr) => Expr::Col(colexpr),
Expr::SecureCol([a, b, c, d]) => Expr::SecureCol([
Box::new(simplify(*a)),
Box::new(simplify(*b)),
Box::new(simplify(*c)),
Box::new(simplify(*d)),
]),
Expr::Const(c) => Expr::Const(c),
Expr::Param(x) => Expr::Param(x),
Expr::Neg(a) => {
let a = simplify(*a);
match a {
Expr::Neg(b) => *b,
Expr::Const(c) => Expr::Const(-c),
Expr::Sub(a, b) => Expr::Sub(b, a),
_ => -a,
}
}
Expr::Inv(a) => {
let a = simplify(*a);
match a {
Expr::Inv(b) => *b,
Expr::Const(c) => Expr::Const(c.inverse()),
_ => Expr::Inv(Box::new(a)),
}
}
}
}

/// Returns the expression
/// `value[0] * <relation>_alpha0 + value[1] * <relation>_alpha1 + ... - <relation>_z.`
fn combine_formal<R: Relation<Expr, Expr>>(relation: &R, values: &[Expr]) -> Expr {
Expand Down Expand Up @@ -266,15 +354,15 @@ impl ExprEvaluator {
let lets_string = self
.intermediates
.iter()
.map(|(name, expr)| format!("let {} = {};", name, expr.format_expr()))
.map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format()))
.collect::<Vec<String>>()
.join("\n");

let constraints_str = self
.constraints
.iter()
.enumerate()
.map(|(i, c)| format!("let constraint_{i} = ") + &c.format_expr() + ";")
.map(|(i, c)| format!("let constraint_{i} = ") + &c.simplify_and_format() + ";")
.collect::<Vec<String>>()
.join("\n\n");

Expand Down Expand Up @@ -362,8 +450,7 @@ mod tests {
fn test_format_expr() {
let test_struct = TestStruct {};
let eval = test_struct.evaluate(ExprEvaluator::new(16, false));
let expected = "let temp_0 = 0 \
+ (TestRelation_alpha0) * (col_1_0[0]) \
let expected = "let temp_0 = (TestRelation_alpha0) * (col_1_0[0]) \
+ (TestRelation_alpha1) * (col_1_1[0]) \
+ (TestRelation_alpha2) * (col_1_2[0]) \
- (TestRelation_z);
Expand All @@ -375,8 +462,8 @@ mod tests {
\
let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \
- (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \
- ((col_0_3[0]) * (total_sum))) \
- (0)) \
- ((col_0_3[0]) * (total_sum)))\
) \
* (temp_0) \
- (1);"
.to_string();
Expand Down
12 changes: 5 additions & 7 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,11 @@ mod tests {
);

let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true));
let expected = "let temp_0 = 0 \
+ (StateMachineElements_alpha0) * (col_1_0[0]) \
let expected = "let temp_0 = (StateMachineElements_alpha0) * (col_1_0[0]) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
- (StateMachineElements_z);
\
let temp_1 = 0 \
+ (StateMachineElements_alpha0) * (col_1_0[0] + 1) \
let temp_1 = (StateMachineElements_alpha0) * (col_1_0[0] + 1) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
- (StateMachineElements_z);
Expand All @@ -315,10 +313,10 @@ mod tests {
\
let constraint_1 = (SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \
- (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \
- ((col_0_2[0]) * (total_sum))) \
- (0)) \
- ((col_0_2[0]) * (total_sum)))\
) \
* ((temp_0) * (temp_1)) \
- ((temp_1) * (1) + (temp_0) * (-(1)));"
- (temp_1 - (temp_0));"
.to_string();

assert_eq!(eval.format_constraints(), expected);
Expand Down

0 comments on commit 56185a3

Please sign in to comment.