Skip to content

Commit

Permalink
Merge branch 'matthias/phantom-type-for-config' into feat/#482-unify-…
Browse files Browse the repository at this point in the history
…signed-bit-extraction
  • Loading branch information
matthiasgoergens committed Nov 1, 2024
2 parents 1a8f40a + af09fab commit af0260b
Show file tree
Hide file tree
Showing 11 changed files with 184 additions and 73 deletions.
20 changes: 5 additions & 15 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
NR: Into<String>,
N: FnOnce() -> NR,
{
self.namespace(
|| "require_one",
|cb| cb.cs.require_zero(name_fn, Expression::from(1) - expr),
)
self.namespace(|| "require_one", |cb| cb.cs.require_zero(name_fn, 1 - expr))
}

pub fn condition_require_equal<NR, N>(
Expand Down Expand Up @@ -260,7 +257,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
when_true: &Expression<E>,
when_false: &Expression<E>,
) -> Expression<E> {
cond.clone() * when_true.clone() + (1 - cond.clone()) * when_false.clone()
cond * when_true + (1 - cond) * when_false
}

pub(crate) fn assert_ux<NR, N, const C: usize>(
Expand Down Expand Up @@ -346,10 +343,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
{
self.namespace(
|| "assert_bit",
|cb| {
cb.cs
.require_zero(name_fn, expr.clone() * (Expression::ONE - expr))
},
|cb| cb.cs.require_zero(name_fn, &expr * (1 - &expr)),
)
}

Expand Down Expand Up @@ -417,14 +411,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
let is_eq = self.create_witin(|| "is_eq");
let diff_inverse = self.create_witin(|| "diff_inverse");

self.require_zero(|| "is equal", is_eq.expr() * &lhs - is_eq.expr() * &rhs)?;
self.require_zero(
|| "is equal",
is_eq.expr().clone() * lhs.clone() - is_eq.expr() * rhs.clone(),
)?;
self.require_zero(
|| "is equal",
Expression::from(1) - is_eq.expr().clone() - diff_inverse.expr() * lhs
+ diff_inverse.expr() * rhs,
1 - is_eq.expr() - diff_inverse.expr() * lhs + diff_inverse.expr() * rhs,
)?;

Ok((is_eq, diff_inverse))
Expand Down
172 changes: 149 additions & 23 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
fmt::Display,
iter::Sum,
mem::MaybeUninit,
ops::{Add, Deref, Mul, Neg, Sub},
ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Shl, ShlAssign, Sub, SubAssign},
};

use ff::Field;
Expand Down Expand Up @@ -315,6 +315,51 @@ impl<E: ExtensionField> Add for Expression<E> {
}
}

macro_rules! binop_assign_instances {
($op_assign: ident, $fun_assign: ident, $op: ident, $fun: ident) => {
impl<E: ExtensionField, Rhs> $op_assign<Rhs> for Expression<E>
where
Expression<E>: $op<Rhs, Output = Expression<E>>,
{
fn $fun_assign(&mut self, rhs: Rhs) {
// TODO: consider in-place?
*self = self.clone().$fun(rhs);
}
}
};
}

binop_assign_instances!(AddAssign, add_assign, Add, add);
binop_assign_instances!(SubAssign, sub_assign, Sub, sub);
binop_assign_instances!(MulAssign, mul_assign, Mul, mul);

impl<E: ExtensionField> Shl<usize> for Expression<E> {
type Output = Expression<E>;
fn shl(self, rhs: usize) -> Expression<E> {
self * (1 << rhs)
}
}

impl<E: ExtensionField> Shl<usize> for &Expression<E> {
type Output = Expression<E>;
fn shl(self, rhs: usize) -> Expression<E> {
self.clone() << rhs
}
}

impl<E: ExtensionField> Shl<usize> for &mut Expression<E> {
type Output = Expression<E>;
fn shl(self, rhs: usize) -> Expression<E> {
self.clone() << rhs
}
}

impl<E: ExtensionField> ShlAssign<usize> for Expression<E> {
fn shl_assign(&mut self, rhs: usize) {
*self = self.clone() << rhs;
}
}

impl<E: ExtensionField> Sum for Expression<E> {
fn sum<I: Iterator<Item = Expression<E>>>(iter: I) -> Expression<E> {
iter.fold(Expression::ZERO, |acc, x| acc + x)
Expand Down Expand Up @@ -442,7 +487,64 @@ impl<E: ExtensionField> Sub for Expression<E> {
}
}

macro_rules! binop_instances {
/// Instances for binary operations that mix Expression and &Expression
macro_rules! ref_binop_instances {
($op: ident, $fun: ident) => {
impl<E: ExtensionField> $op<&Expression<E>> for Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: &Expression<E>) -> Expression<E> {
self.$fun(rhs.clone())
}
}

impl<E: ExtensionField> $op<Expression<E>> for &Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: Expression<E>) -> Expression<E> {
self.clone().$fun(rhs)
}
}

impl<E: ExtensionField> $op<&Expression<E>> for &Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: &Expression<E>) -> Expression<E> {
self.clone().$fun(rhs.clone())
}
}

// for mutable references
impl<E: ExtensionField> $op<&mut Expression<E>> for Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: &mut Expression<E>) -> Expression<E> {
self.$fun(rhs.clone())
}
}

impl<E: ExtensionField> $op<Expression<E>> for &mut Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: Expression<E>) -> Expression<E> {
self.clone().$fun(rhs)
}
}

impl<E: ExtensionField> $op<&mut Expression<E>> for &mut Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: &mut Expression<E>) -> Expression<E> {
self.clone().$fun(rhs.clone())
}
}
};
}
ref_binop_instances!(Add, add);
ref_binop_instances!(Sub, sub);
ref_binop_instances!(Mul, mul);

macro_rules! mixed_binop_instances {
($op: ident, $fun: ident, ($($t:ty),*)) => {
$(impl<E: ExtensionField> $op<Expression<E>> for $t {
type Output = Expression<E>;
Expand All @@ -458,21 +560,38 @@ macro_rules! binop_instances {
fn $fun(self, rhs: $t) -> Expression<E> {
self.$fun(Expression::<E>::from(rhs))
}
})*
}

impl<E: ExtensionField> $op<&Expression<E>> for $t {
type Output = Expression<E>;

fn $fun(self, rhs: &Expression<E>) -> Expression<E> {
Expression::<E>::from(self).$fun(rhs)
}
}

impl<E: ExtensionField> $op<$t> for &Expression<E> {
type Output = Expression<E>;

fn $fun(self, rhs: $t) -> Expression<E> {
self.$fun(Expression::<E>::from(rhs))
}
}
)*
};
}

binop_instances!(
mixed_binop_instances!(
Add,
add,
(u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize)
);
binop_instances!(
mixed_binop_instances!(
Sub,
sub,
(u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize)
);
binop_instances!(
mixed_binop_instances!(
Mul,
mul,
(u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize)
Expand Down Expand Up @@ -686,6 +805,20 @@ impl<F: SmallField, E: ExtensionField<BaseField = F>> ToExpr<E> for F {
}
}

macro_rules! impl_from_via_ToExpr {
($($t:ty),*) => {
$(
impl<E: ExtensionField> From<$t> for Expression<E> {
fn from(value: $t) -> Self {
value.expr()
}
}
)*
};
}
impl_from_via_ToExpr!(WitIn, Fixed, Instance);
impl_from_via_ToExpr!(&WitIn, &Fixed, &Instance);

// Implement From trait for unsigned types of at most 64 bits
macro_rules! impl_from_unsigned {
($($t:ty),*) => {
Expand Down Expand Up @@ -880,8 +1013,7 @@ mod tests {

// scaledsum * challenge
// 3 * x + 2
let expr: Expression<E> =
Into::<Expression<E>>::into(3usize) * x.expr() + Into::<Expression<E>>::into(2usize);
let expr: Expression<E> = 3 * x.expr() + 2;
// c^3 + 1
let c = Expression::Challenge(0, 3, 1.into(), 1.into());
// res
Expand All @@ -897,7 +1029,7 @@ mod tests {

// constant * witin
// 3 * x
let expr: Expression<E> = Into::<Expression<E>>::into(3usize) * x.expr();
let expr: Expression<E> = 3 * x.expr();
assert_eq!(
expr,
Expression::ScaledSum(
Expand Down Expand Up @@ -947,35 +1079,30 @@ mod tests {
let z = cb.create_witin(|| "z");
// scaledsum * challenge
// 3 * x + 2
let expr: Expression<E> =
Into::<Expression<E>>::into(3usize) * x.expr() + Into::<Expression<E>>::into(2usize);
let expr: Expression<E> = 3 * x.expr() + 2;
assert!(expr.is_monomial_form());

// 2 product term
let expr: Expression<E> = Into::<Expression<E>>::into(3usize) * x.expr() * y.expr()
+ Into::<Expression<E>>::into(2usize) * x.expr();
let expr: Expression<E> = 3 * x.expr() * y.expr() + 2 * x.expr();
assert!(expr.is_monomial_form());

// complex linear operation
// (2c + 3) * x * y - 6z
let expr: Expression<E> =
Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr() * y.expr()
- Into::<Expression<E>>::into(6usize) * z.expr();
Expression::Challenge(0, 1, 2_u64.into(), 3_u64.into()) * x.expr() * y.expr()
- 6 * z.expr();
assert!(expr.is_monomial_form());

// complex linear operation
// (2c + 3) * x * y - 6z
let expr: Expression<E> =
Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr() * y.expr()
- Into::<Expression<E>>::into(6usize) * z.expr();
Expression::Challenge(0, 1, 2_u64.into(), 3_u64.into()) * x.expr() * y.expr()
- 6 * z.expr();
assert!(expr.is_monomial_form());

// complex linear operation
// (2 * x + 3) * 3 + 6 * 8
let expr: Expression<E> = (Into::<Expression<E>>::into(2usize) * x.expr()
+ Into::<Expression<E>>::into(3usize))
* Into::<Expression<E>>::into(3usize)
+ Into::<Expression<E>>::into(6usize) * Into::<Expression<E>>::into(8usize);
let expr: Expression<E> = (2 * x.expr() + 3) * 3 + 6 * 8;
assert!(expr.is_monomial_form());
}

Expand All @@ -988,8 +1115,7 @@ mod tests {
let y = cb.create_witin(|| "y");
// scaledsum * challenge
// (x + 1) * (y + 1)
let expr: Expression<E> = (Into::<Expression<E>>::into(1usize) + x.expr())
* (Into::<Expression<E>>::into(2usize) + y.expr());
let expr: Expression<E> = (1 + x.expr()) * (2 + y.expr());
assert!(!expr.is_monomial_form());
}

Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/expression/monomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl<E: ExtensionField> Expression<E> {
for a in a {
for b in &b {
res.push(Term {
coeff: a.coeff.clone() * b.coeff.clone(),
coeff: &a.coeff * &b.coeff,
vars: a.vars.iter().chain(b.vars.iter()).cloned().collect(),
});
}
Expand All @@ -54,7 +54,7 @@ impl<E: ExtensionField> Expression<E> {
for x in x {
for a in &a {
res.push(Term {
coeff: x.coeff.clone() * a.coeff.clone(),
coeff: &x.coeff * &a.coeff,
vars: x.vars.iter().chain(a.vars.iter()).cloned().collect(),
});
}
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/gadgets/signed_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<E: ExtensionField> SignedExtendConfig<E> {
assert_ux(
cb,
|| "0 <= 2*val - msb*2^N_BITS < 2^N_BITS",
2 * val - msb.expr() * (1 << n_bits),
2 * val - (msb.expr() << n_bits),
)?;

Ok(SignedExtendConfig {
Expand Down
12 changes: 6 additions & 6 deletions ceno_zkvm/src/instructions/riscv/insn_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,13 @@ impl<E: ExtensionField> MemAddr<E> {

/// Represent the address aligned to 2 bytes.
pub fn expr_align2(&self) -> AddressExpr<E> {
self.addr.address_expr() - self.low_bit_exprs()[0].clone()
self.addr.address_expr() - &self.low_bit_exprs()[0]
}

/// Represent the address aligned to 4 bytes.
pub fn expr_align4(&self) -> AddressExpr<E> {
let low_bits = self.low_bit_exprs();
self.addr.address_expr() - low_bits[1].clone() * 2 - low_bits[0].clone()
self.addr.address_expr() - &low_bits[1] * 2 - &low_bits[0]
}

/// Expressions of the low bits of the address, LSB-first: [bit_0, bit_1].
Expand Down Expand Up @@ -417,15 +417,15 @@ impl<E: ExtensionField> MemAddr<E> {
// Express the value of the low bits.
let low_sum: Expression<E> = (n_zeros..Self::N_LOW_BITS)
.zip_eq(low_bits.iter())
.map(|(pos, bit)| bit.expr() * (1 << pos))
.map(|(pos, bit)| bit.expr() << pos)
.sum();

// Range check the middle bits, that is the low limb excluding the low bits.
let shift_right = E::BaseField::from(1 << Self::N_LOW_BITS)
.invert()
.unwrap()
.expr();
let mid_u14 = (limbs[0].clone() - low_sum) * shift_right;
let mid_u14 = (&limbs[0] - low_sum) * shift_right;
cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?;

// Range check the high limb.
Expand Down Expand Up @@ -537,8 +537,8 @@ mod test {

if is_ok {
cb.require_equal(|| "", mem_addr.expr_unaligned(), addr.into())?;
cb.require_equal(|| "", mem_addr.expr_align2(), (addr >> 1 << 1).into())?;
cb.require_equal(|| "", mem_addr.expr_align4(), (addr >> 2 << 2).into())?;
cb.require_equal(|| "", mem_addr.expr_align2(), (addr & !1).into())?;
cb.require_equal(|| "", mem_addr.expr_align4(), (addr & !3).into())?;
}
MockProver::assert_with_expected_errors(
&cb,
Expand Down
Loading

0 comments on commit af0260b

Please sign in to comment.