Skip to content

Commit

Permalink
move fmt expr utils to expression module and impl Display
Browse files Browse the repository at this point in the history
mock prover fmt - rebase it to fmt
  • Loading branch information
zemse committed Oct 2, 2024
1 parent f821585 commit 27f4e41
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 173 deletions.
167 changes: 166 additions & 1 deletion ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod monomial;

use std::{
cmp::max,
fmt::Display,
mem::MaybeUninit,
ops::{Add, Deref, Mul, Neg, Sub},
};
Expand Down Expand Up @@ -499,13 +500,143 @@ impl<F: SmallField, E: ExtensionField<BaseField = F>> From<usize> for Expression
}
}

impl<E: ExtensionField> Display for Expression<E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut wtns = vec![];
write!(f, "{}", fmt::expr(self, &mut wtns, false))
}
}

pub mod fmt {
use super::*;
use itertools::Itertools;
use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension;
use std::fmt::Write;

pub fn expr<E: ExtensionField>(
expression: &Expression<E>,
wtns: &mut Vec<WitnessId>,
add_prn_sum: bool,
) -> String {
match expression {
Expression::WitIn(wit_in) => {
wtns.push(*wit_in);
format!("WitIn({})", wit_in)
}
Expression::Challenge(id, pow, scaler, offset) => {
if *pow == 1 && *scaler == 1.into() && *offset == 0.into() {
format!("Challenge({})", id)
} else {
let mut s = String::new();
if *scaler != 1.into() {
write!(s, "{}*", field(scaler)).unwrap();
}
write!(s, "Challenge({})", id,).unwrap();
if *pow > 1 {
write!(s, "^{}", pow).unwrap();
}
if *offset != 0.into() {
write!(s, "+{}", field(offset)).unwrap();
}
s
}
}
Expression::Constant(constant) => base_field::<E>(constant, true).to_string(),
Expression::Fixed(fixed) => format!("{:?}", fixed),
Expression::Sum(left, right) => {
let s = format!("{} + {}", expr(left, wtns, false), expr(right, wtns, false));
if add_prn_sum { format!("({})", s) } else { s }
}
Expression::Product(left, right) => {
format!("{} * {}", expr(left, wtns, true), expr(right, wtns, true))
}
Expression::ScaledSum(x, a, b) => {
let s = format!(
"{} * {} + {}",
expr(a, wtns, true),
expr(x, wtns, true),
expr(b, wtns, false)
);
if add_prn_sum { format!("({})", s) } else { s }
}
}
}

pub fn field<E: ExtensionField>(field: &E) -> String {
let name = format!("{:?}", field);
let name = name.split('(').next().unwrap_or("ExtensionField");

let data = field
.as_bases()
.iter()
.map(|b| base_field::<E>(b, false))
.collect::<Vec<String>>();
let only_one_limb = field.as_bases()[1..].iter().all(|&x| x == 0.into());

if only_one_limb {
data[0].to_string()
} else {
format!("{name}[{}]", data.join(","))
}
}

pub fn base_field<E: ExtensionField>(base_field: &E::BaseField, add_prn: bool) -> String {
let value = base_field.to_canonical_u64();

if value > E::BaseField::MODULUS_U64 - u16::MAX as u64 {
// beautiful format for negative number > -65536
prn(format!("-{}", E::BaseField::MODULUS_U64 - value), add_prn)
} else if value < u16::MAX as u64 {
format!("{value}")
} else {
// hex
if value > E::BaseField::MODULUS_U64 - (u32::MAX as u64 + u16::MAX as u64) {
prn(
format!("-{:#x}", E::BaseField::MODULUS_U64 - value),
add_prn,
)
} else {
format!("{value:#x}")
}
}
}

pub fn prn(s: String, add_prn: bool) -> String {
if add_prn { format!("({})", s) } else { s }
}

#[allow(dead_code)]
pub fn wtns<E: ExtensionField>(
wtns: &[WitnessId],
wits_in: &[ArcMultilinearExtension<E>],
inst_id: usize,
wits_in_name: &[String],
) -> String {
wtns.iter()
.sorted()
.map(|wt_id| {
let wit = &wits_in[*wt_id as usize];
let name = &wits_in_name[*wt_id as usize];
let value_fmt = if let Some(e) = wit.get_ext_field_vec_optn() {
field(&e[inst_id])
} else if let Some(bf) = wit.get_base_field_vec_optn() {
base_field::<E>(&bf[inst_id], true)
} else {
"Unknown".to_string()
};
format!(" WitIn({wt_id})={value_fmt} {name:?}")
})
.join("\n")
}
}

#[cfg(test)]
mod tests {
use goldilocks::GoldilocksExt2;

use crate::circuit_builder::{CircuitBuilder, ConstraintSystem};

use super::{Expression, ToExpr};
use super::{fmt, Expression, ToExpr};
use ff::Field;

#[test]
Expand Down Expand Up @@ -629,4 +760,38 @@ mod tests {
* (Into::<Expression<E>>::into(2usize) + y.expr());
assert!(!expr.is_monomial_form());
}

#[test]
fn test_fmt_expr_challenge_1() {
let a = Expression::<GoldilocksExt2>::Challenge(0, 2, 3.into(), 4.into());
let b = Expression::<GoldilocksExt2>::Challenge(0, 5, 6.into(), 7.into());

let mut wtns_acc = vec![];
let s = fmt::expr(&(a * b), &mut wtns_acc, false);

assert_eq!(
s,
"18*Challenge(0)^7+28 + 21*Challenge(0)^2 + 24*Challenge(0)^5"
);
}

#[test]
fn test_fmt_expr_challenge_2() {
let a = Expression::<GoldilocksExt2>::Challenge(0, 1, 1.into(), 0.into());
let b = Expression::<GoldilocksExt2>::Challenge(0, 1, 1.into(), 0.into());

let mut wtns_acc = vec![];
let s = fmt::expr(&(a * b), &mut wtns_acc, false);

assert_eq!(s, "Challenge(0)^2");
}

#[test]
fn test_fmt_expr_wtns_acc_1() {
let expr = Expression::<GoldilocksExt2>::WitIn(0);
let mut wtns_acc = vec![];
let s = fmt::expr(&expr, &mut wtns_acc, false);
assert_eq!(s, "WitIn(0)");
assert_eq!(wtns_acc, vec![0]);
}
}
Loading

0 comments on commit 27f4e41

Please sign in to comment.