Skip to content

Commit

Permalink
remove msb_decompose and related code
Browse files Browse the repository at this point in the history
  • Loading branch information
KimiWu123 committed Nov 1, 2024
1 parent 610b86a commit 2a68019
Showing 1 changed file with 1 addition and 166 deletions.
167 changes: 1 addition & 166 deletions ceno_zkvm/src/uint/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
error::ZKVMError,
expression::{Expression, ToExpr, WitIn},
gadgets::AssertLTConfig,
instructions::riscv::config::{IsEqualConfig, MsbConfig, UIntLtConfig, UIntLtuConfig},
instructions::riscv::config::IsEqualConfig,
};

impl<const M: usize, const C: usize, E: ExtensionField> UIntLimbs<M, C, E> {
Expand Down Expand Up @@ -292,171 +292,6 @@ impl<const M: usize, const C: usize, E: ExtensionField> UIntLimbs<M, C, E> {
}
}

impl<const M: usize, E: ExtensionField> UIntLimbs<M, 8, E> {
/// decompose x = (x_s, x_{<s})
/// where x_s is highest bit, x_{<s} is the rest
pub fn msb_decompose<F: SmallField>(
&self,
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<MsbConfig, ZKVMError>
where
E: ExtensionField<BaseField = F>,
{
let high_limb_no_msb = circuit_builder.create_witin(|| "high_limb_mask");
let high_limb = self.limbs[Self::NUM_LIMBS - 1].expr();

circuit_builder.lookup_and_byte(
high_limb.clone(),
Expression::from(0b0111_1111),
high_limb_no_msb.expr(),
)?;

let inv_128 = F::from(128).invert().unwrap();
let msb = (high_limb - high_limb_no_msb.expr()) * Expression::Constant(inv_128);
let msb = WitIn::from_expr(|| "msb", circuit_builder, msb, false)?;
Ok(MsbConfig {
msb,
high_limb_no_msb,
})
}

/// compare unsigned intergers a < b
pub fn ltu_limb8(
&self,
circuit_builder: &mut CircuitBuilder<E>,
rhs: &UIntLimbs<M, 8, E>,
) -> Result<UIntLtuConfig, ZKVMError> {
let n_bytes = Self::NUM_LIMBS;
let indexes: Vec<WitIn> = (0..n_bytes)
.map(|_| circuit_builder.create_witin(|| "index"))
.collect();

// indicate the first non-zero byte index i_0 of a[i] - b[i]
// from high to low
// indexes
// .iter()
// .try_for_each(|idx| circuit_builder.assert_bit(|| "bit assert", idx.expr()))?;
// let index_sum = indexes
// .iter()
// .fold(Expression::from(0), |acc, idx| acc + idx.expr());
// circuit_builder.assert_bit(|| "bit assert", index_sum)?;

// equal zero if a==b, otherwise equal (a[i_0]-b[i_0])^{-1}
let byte_diff_inv = circuit_builder.create_witin(|| "byte_diff_inverse");

// define accumulated index sum from high to low
let si_expr: Vec<Expression<E>> = indexes
.iter()
.rev()
.scan(Expression::from(0), |state, idx| {
*state = state.clone() + idx.expr();
Some(state.clone())
})
.collect();
let si = si_expr
.into_iter()
.rev()
.enumerate()
.map(|(i, expr)| {
WitIn::from_expr(|| format!("si_expr_{i}"), circuit_builder, expr, false)
})
.collect::<Result<Vec<WitIn>, ZKVMError>>()?;

// check byte diff that before the first non-zero i_0 equals zero
si.iter()
.zip(self.limbs.iter())
.zip(rhs.limbs.iter())
.enumerate()
.try_for_each(|(i, ((flag, a), b))| {
circuit_builder.require_zero(
|| format!("byte diff {i} zero check"),
a.expr() - b.expr() - flag.expr() * a.expr() + flag.expr() * b.expr(),
)
})?;

// define accumulated byte sum
// when a!= b, sa should equal the first non-zero byte a[i_0]
let sa = self
.limbs
.iter()
.zip_eq(indexes.iter())
.fold(Expression::from(0), |acc, (ai, idx)| {
acc.clone() + ai.expr() * idx.expr()
});
let sb = rhs
.limbs
.iter()
.zip_eq(indexes.iter())
.fold(Expression::from(0), |acc, (bi, idx)| {
acc.clone() + bi.expr() * idx.expr()
});

// check the first byte difference has a inverse
// unwrap is safe because vector len > 0
let lhs_ne_byte = WitIn::from_expr(|| "lhs_ne_byte", circuit_builder, sa.clone(), false)?;
let rhs_ne_byte = WitIn::from_expr(|| "rhs_ne_byte", circuit_builder, sb.clone(), false)?;
let index_ne = si.first().unwrap();
circuit_builder.require_zero(
|| "byte inverse check",
lhs_ne_byte.expr() * byte_diff_inv.expr()
- rhs_ne_byte.expr() * byte_diff_inv.expr()
- index_ne.expr(),
)?;

let is_ltu = circuit_builder.create_witin(|| "is_ltu");
// now we know the first non-equal byte pairs is (lhs_ne_byte, rhs_ne_byte)
circuit_builder.lookup_ltu_byte(lhs_ne_byte.expr(), rhs_ne_byte.expr(), is_ltu.expr())?;
Ok(UIntLtuConfig {
byte_diff_inv,
indexes,
acc_indexes: si,
lhs_ne_byte,
rhs_ne_byte,
is_ltu,
})
}

pub fn lt_limb8(
&self,
circuit_builder: &mut CircuitBuilder<E>,
rhs: &UIntLimbs<M, 8, E>,
) -> Result<UIntLtConfig, ZKVMError> {
let is_lt = circuit_builder.create_witin(|| "is_lt");
// circuit_builder.assert_bit(|| "assert_bit", is_lt.expr())?;

let lhs_msb = self.msb_decompose(circuit_builder)?;
let rhs_msb = rhs.msb_decompose(circuit_builder)?;

let mut lhs_limbs = self.limbs.iter().copied().collect_vec();
lhs_limbs[Self::NUM_LIMBS - 1] = lhs_msb.high_limb_no_msb;
let lhs_no_msb = Self::from_witins_unchecked(lhs_limbs, None, None);
let mut rhs_limbs = rhs.limbs.iter().copied().collect_vec();
rhs_limbs[Self::NUM_LIMBS - 1] = rhs_msb.high_limb_no_msb;
let rhs_no_msb = Self::from_witins_unchecked(rhs_limbs, None, None);

// (1) compute ltu(a_{<s},b_{<s})
let is_ltu = lhs_no_msb.ltu_limb8(circuit_builder, &rhs_no_msb)?;
// (2) compute $lt(a,b)=a_s\cdot (1-b_s)+eq(a_s,b_s)\cdot ltu(a_{<s},b_{<s})$
// Refer Jolt 5.3: Set Less Than (https://people.cs.georgetown.edu/jthaler/Jolt-paper.pdf)
let (msb_is_equal, msb_diff_inv) =
circuit_builder.is_equal(lhs_msb.msb.expr(), rhs_msb.msb.expr())?;
circuit_builder.require_zero(
|| "is lt zero check",
lhs_msb.msb.expr() - lhs_msb.msb.expr() * rhs_msb.msb.expr()
+ msb_is_equal.expr() * is_ltu.is_ltu.expr()
- is_lt.expr(),
)?;
Ok(UIntLtConfig {
lhs_msb,
rhs_msb,
msb_is_equal,
msb_diff_inv,
is_ltu,
is_lt,
})
}
}

#[cfg(test)]
mod tests {

Expand Down

0 comments on commit 2a68019

Please sign in to comment.