diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 66e4a41f6..b03418f9f 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -225,10 +225,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR, { - self.namespace( - || "require_one", - |cb| cb.cs.require_zero(name_fn, Expression::from(1) - expr), - ) + self.namespace(|| "require_one", |cb| cb.cs.require_zero(name_fn, 1 - expr)) } pub fn condition_require_equal( @@ -260,7 +257,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { when_true: &Expression, when_false: &Expression, ) -> Expression { - cond.clone() * when_true.clone() + (1 - cond.clone()) * when_false.clone() + cond * when_true + (1 - cond) * when_false } pub(crate) fn assert_ux( @@ -346,10 +343,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { { self.namespace( || "assert_bit", - |cb| { - cb.cs - .require_zero(name_fn, expr.clone() * (Expression::ONE - expr)) - }, + |cb| cb.cs.require_zero(name_fn, &expr * (1 - &expr)), ) } @@ -417,14 +411,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { let is_eq = self.create_witin(|| "is_eq"); let diff_inverse = self.create_witin(|| "diff_inverse"); + self.require_zero(|| "is equal", is_eq.expr() * &lhs - is_eq.expr() * &rhs)?; self.require_zero( || "is equal", - is_eq.expr().clone() * lhs.clone() - is_eq.expr() * rhs.clone(), - )?; - self.require_zero( - || "is equal", - Expression::from(1) - is_eq.expr().clone() - diff_inverse.expr() * lhs - + diff_inverse.expr() * rhs, + 1 - is_eq.expr() - diff_inverse.expr() * lhs + diff_inverse.expr() * rhs, )?; Ok((is_eq, diff_inverse)) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index efbf08502..8e489f781 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -5,7 +5,7 @@ use std::{ fmt::Display, iter::Sum, mem::MaybeUninit, - ops::{Add, Deref, Mul, Neg, Sub}, + ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Sub, SubAssign}, }; use ff::Field; @@ -315,6 +315,24 @@ impl Add for Expression { } } +macro_rules! binop_assign_instances { + ($op_assign: ident, $fun_assign: ident, $op: ident, $fun: ident) => { + impl $op_assign for Expression + where + Expression: $op>, + { + fn $fun_assign(&mut self, rhs: Rhs) { + // TODO: consider in-place? + *self = self.clone().$fun(rhs); + } + } + }; +} + +binop_assign_instances!(AddAssign, add_assign, Add, add); +binop_assign_instances!(SubAssign, sub_assign, Sub, sub); +binop_assign_instances!(MulAssign, mul_assign, Mul, mul); + impl Sum for Expression { fn sum>>(iter: I) -> Expression { iter.fold(Expression::ZERO, |acc, x| acc + x) @@ -442,7 +460,64 @@ impl Sub for Expression { } } -macro_rules! binop_instances { +/// Instances for binary operations that mix Expression and &Expression +macro_rules! ref_binop_instances { + ($op: ident, $fun: ident) => { + impl $op<&Expression> for Expression { + type Output = Expression; + + fn $fun(self, rhs: &Expression) -> Expression { + self.$fun(rhs.clone()) + } + } + + impl $op> for &Expression { + type Output = Expression; + + fn $fun(self, rhs: Expression) -> Expression { + self.clone().$fun(rhs) + } + } + + impl $op<&Expression> for &Expression { + type Output = Expression; + + fn $fun(self, rhs: &Expression) -> Expression { + self.clone().$fun(rhs.clone()) + } + } + + // for mutable references + impl $op<&mut Expression> for Expression { + type Output = Expression; + + fn $fun(self, rhs: &mut Expression) -> Expression { + self.$fun(rhs.clone()) + } + } + + impl $op> for &mut Expression { + type Output = Expression; + + fn $fun(self, rhs: Expression) -> Expression { + self.clone().$fun(rhs) + } + } + + impl $op<&mut Expression> for &mut Expression { + type Output = Expression; + + fn $fun(self, rhs: &mut Expression) -> Expression { + self.clone().$fun(rhs.clone()) + } + } + }; +} +ref_binop_instances!(Add, add); +ref_binop_instances!(Sub, sub); +ref_binop_instances!(Mul, mul); + +macro_rules! mixed_binop_instances { ($op: ident, $fun: ident, ($($t:ty),*)) => { $(impl $op> for $t { type Output = Expression; @@ -458,21 +533,38 @@ macro_rules! binop_instances { fn $fun(self, rhs: $t) -> Expression { self.$fun(Expression::::from(rhs)) } - })* + } + + impl $op<&Expression> for $t { + type Output = Expression; + + fn $fun(self, rhs: &Expression) -> Expression { + Expression::::from(self).$fun(rhs) + } + } + + impl $op<$t> for &Expression { + type Output = Expression; + + fn $fun(self, rhs: $t) -> Expression { + self.$fun(Expression::::from(rhs)) + } + } + )* }; } -binop_instances!( +mixed_binop_instances!( Add, add, (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) ); -binop_instances!( +mixed_binop_instances!( Sub, sub, (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) ); -binop_instances!( +mixed_binop_instances!( Mul, mul, (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) @@ -686,6 +778,20 @@ impl> ToExpr for F { } } +macro_rules! impl_from_via_ToExpr { + ($($t:ty),*) => { + $( + impl From<$t> for Expression { + fn from(value: $t) -> Self { + value.expr() + } + } + )* + }; +} +impl_from_via_ToExpr!(WitIn, Fixed, Instance); +impl_from_via_ToExpr!(&WitIn, &Fixed, &Instance); + // Implement From trait for unsigned types of at most 64 bits macro_rules! impl_from_unsigned { ($($t:ty),*) => { @@ -880,8 +986,7 @@ mod tests { // scaledsum * challenge // 3 * x + 2 - let expr: Expression = - Into::>::into(3usize) * x.expr() + Into::>::into(2usize); + let expr: Expression = 3 * x.expr() + 2; // c^3 + 1 let c = Expression::Challenge(0, 3, 1.into(), 1.into()); // res @@ -897,7 +1002,7 @@ mod tests { // constant * witin // 3 * x - let expr: Expression = Into::>::into(3usize) * x.expr(); + let expr: Expression = 3 * x.expr(); assert_eq!( expr, Expression::ScaledSum( @@ -947,35 +1052,30 @@ mod tests { let z = cb.create_witin(|| "z"); // scaledsum * challenge // 3 * x + 2 - let expr: Expression = - Into::>::into(3usize) * x.expr() + Into::>::into(2usize); + let expr: Expression = 3 * x.expr() + 2; assert!(expr.is_monomial_form()); // 2 product term - let expr: Expression = Into::>::into(3usize) * x.expr() * y.expr() - + Into::>::into(2usize) * x.expr(); + let expr: Expression = 3 * x.expr() * y.expr() + 2 * x.expr(); assert!(expr.is_monomial_form()); // complex linear operation // (2c + 3) * x * y - 6z let expr: Expression = - Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr() * y.expr() - - Into::>::into(6usize) * z.expr(); + Expression::Challenge(0, 1, 2_u64.into(), 3_u64.into()) * x.expr() * y.expr() + - 6 * z.expr(); assert!(expr.is_monomial_form()); // complex linear operation // (2c + 3) * x * y - 6z let expr: Expression = - Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr() * y.expr() - - Into::>::into(6usize) * z.expr(); + Expression::Challenge(0, 1, 2_u64.into(), 3_u64.into()) * x.expr() * y.expr() + - 6 * z.expr(); assert!(expr.is_monomial_form()); // complex linear operation // (2 * x + 3) * 3 + 6 * 8 - let expr: Expression = (Into::>::into(2usize) * x.expr() - + Into::>::into(3usize)) - * Into::>::into(3usize) - + Into::>::into(6usize) * Into::>::into(8usize); + let expr: Expression = (2 * x.expr() + 3) * 3 + 6 * 8; assert!(expr.is_monomial_form()); } @@ -988,8 +1088,7 @@ mod tests { let y = cb.create_witin(|| "y"); // scaledsum * challenge // (x + 1) * (y + 1) - let expr: Expression = (Into::>::into(1usize) + x.expr()) - * (Into::>::into(2usize) + y.expr()); + let expr: Expression = (1 + x.expr()) * (2 + y.expr()); assert!(!expr.is_monomial_form()); } diff --git a/ceno_zkvm/src/expression/monomial.rs b/ceno_zkvm/src/expression/monomial.rs index 4c73c557b..da16ef753 100644 --- a/ceno_zkvm/src/expression/monomial.rs +++ b/ceno_zkvm/src/expression/monomial.rs @@ -39,7 +39,7 @@ impl Expression { for a in a { for b in &b { res.push(Term { - coeff: a.coeff.clone() * b.coeff.clone(), + coeff: &a.coeff * &b.coeff, vars: a.vars.iter().chain(b.vars.iter()).cloned().collect(), }); } @@ -54,7 +54,7 @@ impl Expression { for x in x { for a in &a { res.push(Term { - coeff: x.coeff.clone() * a.coeff.clone(), + coeff: &x.coeff * &a.coeff, vars: x.vars.iter().chain(a.vars.iter()).cloned().collect(), }); } diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 64ac91b12..023e3b60e 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -381,13 +381,13 @@ impl MemAddr { /// Represent the address aligned to 2 bytes. pub fn expr_align2(&self) -> AddressExpr { - self.addr.address_expr() - self.low_bit_exprs()[0].clone() + self.addr.address_expr() - &self.low_bit_exprs()[0] } /// Represent the address aligned to 4 bytes. pub fn expr_align4(&self) -> AddressExpr { let low_bits = self.low_bit_exprs(); - self.addr.address_expr() - low_bits[1].clone() * 2 - low_bits[0].clone() + self.addr.address_expr() - &low_bits[1] * 2 - &low_bits[0] } /// Expressions of the low bits of the address, LSB-first: [bit_0, bit_1]. @@ -425,7 +425,7 @@ impl MemAddr { .invert() .unwrap() .expr(); - let mid_u14 = (limbs[0].clone() - low_sum) * shift_right; + let mid_u14 = (&limbs[0] - low_sum) * shift_right; cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?; // Range check the high limb. diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 8c6a36998..64f2a4c40 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -80,7 +80,7 @@ impl MemWordChange { let u8_base_inv = E::BaseField::from(1 << 8).invert().unwrap(); cb.assert_ux::<_, _, 8>( || "rs2_limb[0].le_bytes[1]", - u8_base_inv.expr() * (rs2_limbs[0].clone() - rs2_limb_bytes[0].expr()), + u8_base_inv.expr() * (&rs2_limbs[0] - rs2_limb_bytes[0].expr()), )?; // alloc a new witIn to cache degree 2 expression @@ -125,8 +125,8 @@ impl MemWordChange { // degree 2 expression low_bits[1].clone(), expected_change.expr(), - (1 << 16) * (rs2_limbs[0].clone() - prev_limbs[1].clone()), - rs2_limbs[0].clone() - prev_limbs[0].clone(), + (1 << 16) * (&rs2_limbs[0] - &prev_limbs[1]), + &rs2_limbs[0] - &prev_limbs[0], )?; Ok(MemWordChange { diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 81ebfcf78..cc3766014 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -446,7 +446,7 @@ mod tests { let expected_final_product: E = last_layer .iter() .map(|f| match f.evaluations() { - FieldType::Ext(e) => e.iter().cloned().reduce(|a, b| a * b).unwrap(), + FieldType::Ext(e) => e.iter().copied().reduce(|a, b| a * b).unwrap(), _ => unreachable!(""), }) .product(); diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 23b55606c..7173af3bd 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -169,10 +169,7 @@ impl UIntLimbs { .assert_ux::<_, _, C>(|| "range check", w.expr()) .unwrap(); circuit_builder - .require_zero( - || "create_witin_from_expr", - w.expr() - expr_limbs[i].clone(), - ) + .require_zero(|| "create_witin_from_expr", w.expr() - &expr_limbs[i]) .unwrap(); w }) @@ -299,7 +296,7 @@ impl UIntLimbs { chunk .iter() .zip(shift_pows.iter()) - .map(|(limb, shift)| shift.clone() * limb.expr()) + .map(|(limb, shift)| shift * limb.expr()) .reduce(|a, b| a + b) .unwrap() }) @@ -317,7 +314,7 @@ impl UIntLimbs { let shift_pows = { let mut shift_pows = Vec::with_capacity(k); shift_pows.push(Expression::Constant(E::BaseField::ONE)); - (0..k - 1).for_each(|_| shift_pows.push(shift_pows.last().unwrap().clone() * (1 << 8))); + (0..k - 1).for_each(|_| shift_pows.push(shift_pows.last().unwrap() * (1 << 8))); shift_pows }; let split_limbs = x @@ -334,7 +331,7 @@ impl UIntLimbs { let combined_limb = limbs .iter() .zip(shift_pows.iter()) - .map(|(limb, shift)| shift.clone() * limb.clone()) + .map(|(limb, shift)| shift * limb) .reduce(|a, b| a + b) .unwrap(); @@ -508,11 +505,10 @@ impl UIntLimbs { /// Get an Expression from the limbs, unsafe if Uint value exceeds field limit pub fn value(&self) -> Expression { - let base = Expression::from(1 << C); self.expr() .into_iter() .rev() - .reduce(|sum, limb| sum * base.clone() + limb) + .reduce(|sum, limb| sum * (1 << C) + limb) .unwrap() } @@ -626,7 +622,7 @@ impl UIntLimbs<32, 8, E> { let u16_limbs = u8_limbs .chunks(2) .map(|chunk| { - let (a, b) = (chunk[0].clone(), chunk[1].clone()); + let (a, b) = (&chunk[0], &chunk[1]); a + b * 256 }) .collect_vec(); diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 01413f6a3..2bba30371 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -179,9 +179,9 @@ impl UIntLimbs { let idx = i + j; if idx < c_limbs.len() { if result_c.get(idx).is_none() { - result_c.push(a.clone() * b.clone()); + result_c.push(a * b); } else { - result_c[idx] = result_c[idx].clone() + a.clone() * b.clone(); + result_c[idx] += a * b; } } }); diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index 9aaafdf4f..75c930191 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -202,8 +202,7 @@ mod tests { let mut virtual_polys = VirtualPolynomials::new(1, 0); // 3xy + 2y - let expr: Expression = - Expression::from(3) * x.expr() * y.expr() + Expression::from(2) * y.expr(); + let expr: Expression = 3 * x.expr() * y.expr() + 2 * y.expr(); let distrinct_zerocheck_terms_set = virtual_polys.add_mle_list_by_expr( None, @@ -216,7 +215,7 @@ mod tests { assert!(virtual_polys.degree() == 2); // 3x^3 - let expr: Expression = Expression::from(3) * x.expr() * x.expr() * x.expr(); + let expr: Expression = 3 * x.expr() * x.expr() * x.expr(); let distrinct_zerocheck_terms_set = virtual_polys.add_mle_list_by_expr( None, wits_in.iter().collect_vec(),