diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 4f3e82612..f26cbb609 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -2,6 +2,7 @@ mod monomial; use std::{ cmp::max, + fmt::Display, mem::MaybeUninit, ops::{Add, Deref, Mul, Neg, Sub}, }; @@ -499,13 +500,143 @@ impl> From for Expression } } +impl Display for Expression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut wtns = vec![]; + write!(f, "{}", fmt::expr(self, &mut wtns, false)) + } +} + +pub mod fmt { + use super::*; + use itertools::Itertools; + use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; + use std::fmt::Write; + + pub fn expr( + expression: &Expression, + wtns: &mut Vec, + add_prn_sum: bool, + ) -> String { + match expression { + Expression::WitIn(wit_in) => { + wtns.push(*wit_in); + format!("WitIn({})", wit_in) + } + Expression::Challenge(id, pow, scaler, offset) => { + if *pow == 1 && *scaler == 1.into() && *offset == 0.into() { + format!("Challenge({})", id) + } else { + let mut s = String::new(); + if *scaler != 1.into() { + write!(s, "{}*", field(scaler)).unwrap(); + } + write!(s, "Challenge({})", id,).unwrap(); + if *pow > 1 { + write!(s, "^{}", pow).unwrap(); + } + if *offset != 0.into() { + write!(s, "+{}", field(offset)).unwrap(); + } + s + } + } + Expression::Constant(constant) => base_field::(constant, true).to_string(), + Expression::Fixed(fixed) => format!("{:?}", fixed), + Expression::Sum(left, right) => { + let s = format!("{} + {}", expr(left, wtns, false), expr(right, wtns, false)); + if add_prn_sum { format!("({})", s) } else { s } + } + Expression::Product(left, right) => { + format!("{} * {}", expr(left, wtns, true), expr(right, wtns, true)) + } + Expression::ScaledSum(x, a, b) => { + let s = format!( + "{} * {} + {}", + expr(a, wtns, true), + expr(x, wtns, true), + expr(b, wtns, false) + ); + if add_prn_sum { format!("({})", s) } else { s } + } + } + } + + pub fn field(field: &E) -> String { + let name = format!("{:?}", field); + let name = name.split('(').next().unwrap_or("ExtensionField"); + + let data = field + .as_bases() + .iter() + .map(|b| base_field::(b, false)) + .collect::>(); + let only_one_limb = field.as_bases()[1..].iter().all(|&x| x == 0.into()); + + if only_one_limb { + data[0].to_string() + } else { + format!("{name}[{}]", data.join(",")) + } + } + + pub fn base_field(base_field: &E::BaseField, add_prn: bool) -> String { + let value = base_field.to_canonical_u64(); + + if value > E::BaseField::MODULUS_U64 - u16::MAX as u64 { + // beautiful format for negative number > -65536 + prn(format!("-{}", E::BaseField::MODULUS_U64 - value), add_prn) + } else if value < u16::MAX as u64 { + format!("{value}") + } else { + // hex + if value > E::BaseField::MODULUS_U64 - (u32::MAX as u64 + u16::MAX as u64) { + prn( + format!("-{:#x}", E::BaseField::MODULUS_U64 - value), + add_prn, + ) + } else { + format!("{value:#x}") + } + } + } + + pub fn prn(s: String, add_prn: bool) -> String { + if add_prn { format!("({})", s) } else { s } + } + + #[allow(dead_code)] + pub fn wtns( + wtns: &[WitnessId], + wits_in: &[ArcMultilinearExtension], + inst_id: usize, + wits_in_name: &[String], + ) -> String { + wtns.iter() + .sorted() + .map(|wt_id| { + let wit = &wits_in[*wt_id as usize]; + let name = &wits_in_name[*wt_id as usize]; + let value_fmt = if let Some(e) = wit.get_ext_field_vec_optn() { + field(&e[inst_id]) + } else if let Some(bf) = wit.get_base_field_vec_optn() { + base_field::(&bf[inst_id], true) + } else { + "Unknown".to_string() + }; + format!(" WitIn({wt_id})={value_fmt} {name:?}") + }) + .join("\n") + } +} + #[cfg(test)] mod tests { use goldilocks::GoldilocksExt2; use crate::circuit_builder::{CircuitBuilder, ConstraintSystem}; - use super::{Expression, ToExpr}; + use super::{fmt, Expression, ToExpr}; use ff::Field; #[test] @@ -629,4 +760,38 @@ mod tests { * (Into::>::into(2usize) + y.expr()); assert!(!expr.is_monomial_form()); } + + #[test] + fn test_fmt_expr_challenge_1() { + let a = Expression::::Challenge(0, 2, 3.into(), 4.into()); + let b = Expression::::Challenge(0, 5, 6.into(), 7.into()); + + let mut wtns_acc = vec![]; + let s = fmt::expr(&(a * b), &mut wtns_acc, false); + + assert_eq!( + s, + "18*Challenge(0)^7+28 + 21*Challenge(0)^2 + 24*Challenge(0)^5" + ); + } + + #[test] + fn test_fmt_expr_challenge_2() { + let a = Expression::::Challenge(0, 1, 1.into(), 0.into()); + let b = Expression::::Challenge(0, 1, 1.into(), 0.into()); + + let mut wtns_acc = vec![]; + let s = fmt::expr(&(a * b), &mut wtns_acc, false); + + assert_eq!(s, "Challenge(0)^2"); + } + + #[test] + fn test_fmt_expr_wtns_acc_1() { + let expr = Expression::::WitIn(0); + let mut wtns_acc = vec![]; + let s = fmt::expr(&expr, &mut wtns_acc, false); + assert_eq!(s, "WitIn(0)"); + assert_eq!(wtns_acc, vec![0]); + } } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index e5f4be920..ca5209be5 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -1,9 +1,8 @@ use super::utils::{eval_by_expr, wit_infer_by_expr}; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - expression::Expression, + expression::{fmt, Expression}, scheme::utils::eval_by_expr_with_fixed, - structs::WitnessId, tables::{ AndTable, LtuTable, OpsTable, OrTable, ProgramTableCircuit, RangeTable, TableCircuit, U16Table, U5Table, U8Table, XorTable, @@ -14,12 +13,10 @@ use base64::{engine::general_purpose::STANDARD_NO_PAD, Engine}; use ceno_emul::{ByteAddr, CENO_PLATFORM}; use ff_ext::ExtensionField; use generic_static::StaticTypeMap; -use goldilocks::SmallField; use itertools::Itertools; use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; use std::{ collections::HashSet, - fmt::Write, fs::{self, File}, hash::Hash, io::{BufReader, ErrorKind}, @@ -139,9 +136,9 @@ impl MockProverError { name, inst_id, } => { - let expression_fmt = fmt_expr(expression, &mut wtns, false); - let wtns_fmt = fmt_wtns::(&wtns, wits_in, *inst_id, wits_in_name); - let eval_fmt = fmt_field::(evaluated); + let expression_fmt = fmt::expr(expression, &mut wtns, false); + let wtns_fmt = fmt::wtns(&wtns, wits_in, *inst_id, wits_in_name); + let eval_fmt = fmt::field(evaluated); println!( "\nAssertZeroError {name:?}: Evaluated expression is not zero\n\ Expression: {expression_fmt}\n\ @@ -157,11 +154,11 @@ impl MockProverError { name, inst_id, } => { - let left_expression_fmt = fmt_expr(left_expression, &mut wtns, false); - let right_expression_fmt = fmt_expr(right_expression, &mut wtns, false); - let wtns_fmt = fmt_wtns::(&wtns, wits_in, *inst_id, wits_in_name); - let left_eval_fmt = fmt_field::(left); - let right_eval_fmt = fmt_field::(right); + let left_expression_fmt = fmt::expr(left_expression, &mut wtns, false); + let right_expression_fmt = fmt::expr(right_expression, &mut wtns, false); + let wtns_fmt = fmt::wtns(&wtns, wits_in, *inst_id, wits_in_name); + let left_eval_fmt = fmt::field(left); + let right_eval_fmt = fmt::field(right); println!( "\nAssertEqualError {name:?}\n\ Left: {left_eval_fmt} != Right: {right_eval_fmt}\n\ @@ -176,9 +173,9 @@ impl MockProverError { name, inst_id, } => { - let expression_fmt = fmt_expr(expression, &mut wtns, false); - let wtns_fmt = fmt_wtns::(&wtns, wits_in, *inst_id, wits_in_name); - let eval_fmt = fmt_field::(evaluated); + let expression_fmt = fmt::expr(expression, &mut wtns, false); + let wtns_fmt = fmt::wtns(&wtns, wits_in, *inst_id, wits_in_name); + let eval_fmt = fmt::field(evaluated); println!( "\nLookupError {name:#?}: Evaluated expression does not exist in T vector\n\ Expression: {expression_fmt}\n\ @@ -190,129 +187,6 @@ impl MockProverError { } } -fn fmt_expr( - expression: &Expression, - wtns: &mut Vec, - add_prn_sum: bool, -) -> String { - match expression { - Expression::WitIn(wit_in) => { - wtns.push(*wit_in); - format!("WitIn({})", wit_in) - } - Expression::Challenge(id, pow, scaler, offset) => { - if *pow == 1 && *scaler == 1.into() && *offset == 0.into() { - format!("Challenge({})", id) - } else { - let mut s = String::new(); - if *scaler != 1.into() { - write!(s, "{}*", fmt_field(scaler)).unwrap(); - } - write!(s, "Challenge({})", id,).unwrap(); - if *pow > 1 { - write!(s, "^{}", pow).unwrap(); - } - if *offset != 0.into() { - write!(s, "+{}", fmt_field(offset)).unwrap(); - } - s - } - } - Expression::Constant(constant) => fmt_base_field::(constant, true).to_string(), - Expression::Fixed(fixed) => format!("{:?}", fixed), - Expression::Sum(left, right) => { - let s = format!( - "{} + {}", - fmt_expr(left, wtns, false), - fmt_expr(right, wtns, false) - ); - if add_prn_sum { format!("({})", s) } else { s } - } - Expression::Product(left, right) => { - format!( - "{} * {}", - fmt_expr(left, wtns, true), - fmt_expr(right, wtns, true) - ) - } - Expression::ScaledSum(x, a, b) => { - let s = format!( - "{} * {} + {}", - fmt_expr(a, wtns, true), - fmt_expr(x, wtns, true), - fmt_expr(b, wtns, false) - ); - if add_prn_sum { format!("({})", s) } else { s } - } - } -} - -fn fmt_field(field: &E) -> String { - let name = format!("{:?}", field); - let name = name.split('(').next().unwrap_or("ExtensionField"); - - let data = field - .as_bases() - .iter() - .map(|b| fmt_base_field::(b, false)) - .collect::>(); - let only_one_limb = field.as_bases()[1..].iter().all(|&x| x == 0.into()); - - if only_one_limb { - data[0].to_string() - } else { - format!("{name}[{}]", data.join(",")) - } -} - -fn fmt_base_field(base_field: &E::BaseField, add_prn: bool) -> String { - let value = base_field.to_canonical_u64(); - - if value > E::BaseField::MODULUS_U64 - u16::MAX as u64 { - // beautiful format for negative number > -65536 - fmt_prn(format!("-{}", E::BaseField::MODULUS_U64 - value), add_prn) - } else if value < u16::MAX as u64 { - format!("{value}") - } else { - // hex - if value > E::BaseField::MODULUS_U64 - (u32::MAX as u64 + u16::MAX as u64) { - fmt_prn( - format!("-{:#x}", E::BaseField::MODULUS_U64 - value), - add_prn, - ) - } else { - format!("{value:#x}") - } - } -} - -fn fmt_prn(s: String, add_prn: bool) -> String { - if add_prn { format!("({})", s) } else { s } -} - -fn fmt_wtns( - wtns: &[WitnessId], - wits_in: &[ArcMultilinearExtension], - inst_id: usize, - wits_in_name: &[String], -) -> String { - wtns.iter() - .sorted() - .map(|wt_id| { - let wit = &wits_in[*wt_id as usize]; - let name = &wits_in_name[*wt_id as usize]; - let value_fmt = if let Some(e) = wit.get_ext_field_vec_optn() { - fmt_field(&e[inst_id]) - } else if let Some(bf) = wit.get_base_field_vec_optn() { - fmt_base_field::(&bf[inst_id], true) - } else { - "Unknown".to_string() - }; - format!(" WitIn({wt_id})={value_fmt} {name:?}") - }) - .join("\n") -} - pub(crate) struct MockProver { _phantom: PhantomData, } @@ -587,40 +461,6 @@ mod tests { use goldilocks::{Goldilocks, GoldilocksExt2}; use multilinear_extensions::mle::{IntoMLE, IntoMLEs}; - #[test] - fn test_fmt_expr_challenge_1() { - let a = Expression::::Challenge(0, 2, 3.into(), 4.into()); - let b = Expression::::Challenge(0, 5, 6.into(), 7.into()); - - let mut wtns_acc = vec![]; - let s = fmt_expr(&(a * b), &mut wtns_acc, false); - - assert_eq!( - s, - "18*Challenge(0)^7+28 + 21*Challenge(0)^2 + 24*Challenge(0)^5" - ); - } - - #[test] - fn test_fmt_expr_challenge_2() { - let a = Expression::::Challenge(0, 1, 1.into(), 0.into()); - let b = Expression::::Challenge(0, 1, 1.into(), 0.into()); - - let mut wtns_acc = vec![]; - let s = fmt_expr(&(a * b), &mut wtns_acc, false); - - assert_eq!(s, "Challenge(0)^2"); - } - - #[test] - fn test_fmt_expr_wtns_acc_1() { - let expr = Expression::::WitIn(0); - let mut wtns_acc = vec![]; - let s = fmt_expr(&expr, &mut wtns_acc, false); - assert_eq!(s, "WitIn(0)"); - assert_eq!(wtns_acc, vec![0]); - } - #[derive(Debug)] #[allow(dead_code)] struct AssertZeroCircuit {