Skip to content

Commit

Permalink
More 'mixed' instances to simplify Expression arithmetic (#479)
Browse files Browse the repository at this point in the history
We implement convenient instances to add / multiply / subtract
references to expressions.

Nothing changes under the hood, but we move some cloning from the caller
to the implementation of arithmetic operations.
  • Loading branch information
matthiasgoergens authored Nov 1, 2024
1 parent f97a8ca commit 684a762
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 62 deletions.
20 changes: 5 additions & 15 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
NR: Into<String>,
N: FnOnce() -> NR,
{
self.namespace(
|| "require_one",
|cb| cb.cs.require_zero(name_fn, Expression::from(1) - expr),
)
self.namespace(|| "require_one", |cb| cb.cs.require_zero(name_fn, 1 - expr))
}

pub fn condition_require_equal<NR, N>(
Expand Down Expand Up @@ -260,7 +257,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
when_true: &Expression<E>,
when_false: &Expression<E>,
) -> Expression<E> {
cond.clone() * when_true.clone() + (1 - cond.clone()) * when_false.clone()
cond * when_true + (1 - cond) * when_false
}

pub(crate) fn assert_ux<NR, N, const C: usize>(
Expand Down Expand Up @@ -346,10 +343,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
{
self.namespace(
|| "assert_bit",
|cb| {
cb.cs
.require_zero(name_fn, expr.clone() * (Expression::ONE - expr))
},
|cb| cb.cs.require_zero(name_fn, &expr * (1 - &expr)),
)
}

Expand Down Expand Up @@ -417,14 +411,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
let is_eq = self.create_witin(|| "is_eq");
let diff_inverse = self.create_witin(|| "diff_inverse");

self.require_zero(|| "is equal", is_eq.expr() * &lhs - is_eq.expr() * &rhs)?;
self.require_zero(
|| "is equal",
is_eq.expr().clone() * lhs.clone() - is_eq.expr() * rhs.clone(),
)?;
self.require_zero(
|| "is equal",
Expression::from(1) - is_eq.expr().clone() - diff_inverse.expr() * lhs
+ diff_inverse.expr() * rhs,
1 - is_eq.expr() - diff_inverse.expr() * lhs + diff_inverse.expr() * rhs,
)?;

Ok((is_eq, diff_inverse))
Expand Down
145 changes: 122 additions & 23 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
fmt::Display,
iter::Sum,
mem::MaybeUninit,
ops::{Add, Deref, Mul, Neg, Sub},
ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Sub, SubAssign},
};

use ff::Field;
Expand Down Expand Up @@ -315,6 +315,24 @@ impl<E: ExtensionField> Add for Expression<E> {
}
}

macro_rules! binop_assign_instances {
($op_assign: ident, $fun_assign: ident, $op: ident, $fun: ident) => {
impl<E: ExtensionField, Rhs> $op_assign<Rhs> for Expression<E>
where
Expression<E>: $op<Rhs, Output = Expression<E>>,
{
fn $fun_assign(&mut self, rhs: Rhs) {
// TODO: consider in-place?
*self = self.clone().$fun(rhs);
}
}
};
}

binop_assign_instances!(AddAssign, add_assign, Add, add);
binop_assign_instances!(SubAssign, sub_assign, Sub, sub);
binop_assign_instances!(MulAssign, mul_assign, Mul, mul);

impl<E: ExtensionField> Sum for Expression<E> {
fn sum<I: Iterator<Item = Expression<E>>>(iter: I) -> Expression<E> {
iter.fold(Expression::ZERO, |acc, x| acc + x)
Expand Down Expand Up @@ -442,7 +460,64 @@ impl<E: ExtensionField> Sub for Expression<E> {
}
}

macro_rules! binop_instances {
/// Instances for binary operations that mix Expression and &Expression
macro_rules! ref_binop_instances {
($op: ident, $fun: ident) => {
impl<E: ExtensionField> $op<&Expression<E>> for Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: &Expression<E>) -> Expression<E> {
self.$fun(rhs.clone())
}
}

impl<E: ExtensionField> $op<Expression<E>> for &Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: Expression<E>) -> Expression<E> {
self.clone().$fun(rhs)
}
}

impl<E: ExtensionField> $op<&Expression<E>> for &Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: &Expression<E>) -> Expression<E> {
self.clone().$fun(rhs.clone())
}
}

// for mutable references
impl<E: ExtensionField> $op<&mut Expression<E>> for Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: &mut Expression<E>) -> Expression<E> {
self.$fun(rhs.clone())
}
}

impl<E: ExtensionField> $op<Expression<E>> for &mut Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: Expression<E>) -> Expression<E> {
self.clone().$fun(rhs)
}
}

impl<E: ExtensionField> $op<&mut Expression<E>> for &mut Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: &mut Expression<E>) -> Expression<E> {
self.clone().$fun(rhs.clone())
}
}
};
}
ref_binop_instances!(Add, add);
ref_binop_instances!(Sub, sub);
ref_binop_instances!(Mul, mul);

macro_rules! mixed_binop_instances {
($op: ident, $fun: ident, ($($t:ty),*)) => {
$(impl<E: ExtensionField> $op<Expression<E>> for $t {
type Output = Expression<E>;
Expand All @@ -458,21 +533,38 @@ macro_rules! binop_instances {
fn $fun(self, rhs: $t) -> Expression<E> {
self.$fun(Expression::<E>::from(rhs))
}
})*
}

impl<E: ExtensionField> $op<&Expression<E>> for $t {
type Output = Expression<E>;

fn $fun(self, rhs: &Expression<E>) -> Expression<E> {
Expression::<E>::from(self).$fun(rhs)
}
}

impl<E: ExtensionField> $op<$t> for &Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: $t) -> Expression<E> {
self.$fun(Expression::<E>::from(rhs))
}
}
)*
};
}

binop_instances!(
mixed_binop_instances!(
Add,
add,
(u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize)
);
binop_instances!(
mixed_binop_instances!(
Sub,
sub,
(u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize)
);
binop_instances!(
mixed_binop_instances!(
Mul,
mul,
(u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize)
Expand Down Expand Up @@ -686,6 +778,20 @@ impl<F: SmallField, E: ExtensionField<BaseField = F>> ToExpr<E> for F {
}
}

macro_rules! impl_from_via_ToExpr {
($($t:ty),*) => {
$(
impl<E: ExtensionField> From<$t> for Expression<E> {
fn from(value: $t) -> Self {
value.expr()
}
}
)*
};
}
impl_from_via_ToExpr!(WitIn, Fixed, Instance);
impl_from_via_ToExpr!(&WitIn, &Fixed, &Instance);

// Implement From trait for unsigned types of at most 64 bits
macro_rules! impl_from_unsigned {
($($t:ty),*) => {
Expand Down Expand Up @@ -880,8 +986,7 @@ mod tests {

// scaledsum * challenge
// 3 * x + 2
let expr: Expression<E> =
Into::<Expression<E>>::into(3usize) * x.expr() + Into::<Expression<E>>::into(2usize);
let expr: Expression<E> = 3 * x.expr() + 2;
// c^3 + 1
let c = Expression::Challenge(0, 3, 1.into(), 1.into());
// res
Expand All @@ -897,7 +1002,7 @@ mod tests {

// constant * witin
// 3 * x
let expr: Expression<E> = Into::<Expression<E>>::into(3usize) * x.expr();
let expr: Expression<E> = 3 * x.expr();
assert_eq!(
expr,
Expression::ScaledSum(
Expand Down Expand Up @@ -947,35 +1052,30 @@ mod tests {
let z = cb.create_witin(|| "z");
// scaledsum * challenge
// 3 * x + 2
let expr: Expression<E> =
Into::<Expression<E>>::into(3usize) * x.expr() + Into::<Expression<E>>::into(2usize);
let expr: Expression<E> = 3 * x.expr() + 2;
assert!(expr.is_monomial_form());

// 2 product term
let expr: Expression<E> = Into::<Expression<E>>::into(3usize) * x.expr() * y.expr()
+ Into::<Expression<E>>::into(2usize) * x.expr();
let expr: Expression<E> = 3 * x.expr() * y.expr() + 2 * x.expr();
assert!(expr.is_monomial_form());

// complex linear operation
// (2c + 3) * x * y - 6z
let expr: Expression<E> =
Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr() * y.expr()
- Into::<Expression<E>>::into(6usize) * z.expr();
Expression::Challenge(0, 1, 2_u64.into(), 3_u64.into()) * x.expr() * y.expr()
- 6 * z.expr();
assert!(expr.is_monomial_form());

// complex linear operation
// (2c + 3) * x * y - 6z
let expr: Expression<E> =
Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr() * y.expr()
- Into::<Expression<E>>::into(6usize) * z.expr();
Expression::Challenge(0, 1, 2_u64.into(), 3_u64.into()) * x.expr() * y.expr()
- 6 * z.expr();
assert!(expr.is_monomial_form());

// complex linear operation
// (2 * x + 3) * 3 + 6 * 8
let expr: Expression<E> = (Into::<Expression<E>>::into(2usize) * x.expr()
+ Into::<Expression<E>>::into(3usize))
* Into::<Expression<E>>::into(3usize)
+ Into::<Expression<E>>::into(6usize) * Into::<Expression<E>>::into(8usize);
let expr: Expression<E> = (2 * x.expr() + 3) * 3 + 6 * 8;
assert!(expr.is_monomial_form());
}

Expand All @@ -988,8 +1088,7 @@ mod tests {
let y = cb.create_witin(|| "y");
// scaledsum * challenge
// (x + 1) * (y + 1)
let expr: Expression<E> = (Into::<Expression<E>>::into(1usize) + x.expr())
* (Into::<Expression<E>>::into(2usize) + y.expr());
let expr: Expression<E> = (1 + x.expr()) * (2 + y.expr());
assert!(!expr.is_monomial_form());
}

Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/expression/monomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl<E: ExtensionField> Expression<E> {
for a in a {
for b in &b {
res.push(Term {
coeff: a.coeff.clone() * b.coeff.clone(),
coeff: &a.coeff * &b.coeff,
vars: a.vars.iter().chain(b.vars.iter()).cloned().collect(),
});
}
Expand All @@ -54,7 +54,7 @@ impl<E: ExtensionField> Expression<E> {
for x in x {
for a in &a {
res.push(Term {
coeff: x.coeff.clone() * a.coeff.clone(),
coeff: &x.coeff * &a.coeff,
vars: x.vars.iter().chain(a.vars.iter()).cloned().collect(),
});
}
Expand Down
6 changes: 3 additions & 3 deletions ceno_zkvm/src/instructions/riscv/insn_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,13 @@ impl<E: ExtensionField> MemAddr<E> {

/// Represent the address aligned to 2 bytes.
pub fn expr_align2(&self) -> AddressExpr<E> {
self.addr.address_expr() - self.low_bit_exprs()[0].clone()
self.addr.address_expr() - &self.low_bit_exprs()[0]
}

/// Represent the address aligned to 4 bytes.
pub fn expr_align4(&self) -> AddressExpr<E> {
let low_bits = self.low_bit_exprs();
self.addr.address_expr() - low_bits[1].clone() * 2 - low_bits[0].clone()
self.addr.address_expr() - &low_bits[1] * 2 - &low_bits[0]
}

/// Expressions of the low bits of the address, LSB-first: [bit_0, bit_1].
Expand Down Expand Up @@ -425,7 +425,7 @@ impl<E: ExtensionField> MemAddr<E> {
.invert()
.unwrap()
.expr();
let mid_u14 = (limbs[0].clone() - low_sum) * shift_right;
let mid_u14 = (&limbs[0] - low_sum) * shift_right;
cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?;

// Range check the high limb.
Expand Down
6 changes: 3 additions & 3 deletions ceno_zkvm/src/instructions/riscv/memory/gadget.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl<const N_ZEROS: usize> MemWordChange<N_ZEROS> {
let u8_base_inv = E::BaseField::from(1 << 8).invert().unwrap();
cb.assert_ux::<_, _, 8>(
|| "rs2_limb[0].le_bytes[1]",
u8_base_inv.expr() * (rs2_limbs[0].clone() - rs2_limb_bytes[0].expr()),
u8_base_inv.expr() * (&rs2_limbs[0] - rs2_limb_bytes[0].expr()),
)?;

// alloc a new witIn to cache degree 2 expression
Expand Down Expand Up @@ -125,8 +125,8 @@ impl<const N_ZEROS: usize> MemWordChange<N_ZEROS> {
// degree 2 expression
low_bits[1].clone(),
expected_change.expr(),
(1 << 16) * (rs2_limbs[0].clone() - prev_limbs[1].clone()),
rs2_limbs[0].clone() - prev_limbs[0].clone(),
(1 << 16) * (&rs2_limbs[0] - &prev_limbs[1]),
&rs2_limbs[0] - &prev_limbs[0],
)?;

Ok(MemWordChange {
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/scheme/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ mod tests {
let expected_final_product: E = last_layer
.iter()
.map(|f| match f.evaluations() {
FieldType::Ext(e) => e.iter().cloned().reduce(|a, b| a * b).unwrap(),
FieldType::Ext(e) => e.iter().copied().reduce(|a, b| a * b).unwrap(),
_ => unreachable!(""),
})
.product();
Expand Down
Loading

0 comments on commit 684a762

Please sign in to comment.