From 2cd6a6d7b713ca7ff77bcebc3bd78dd3049f69d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Thu, 12 Dec 2024 10:04:43 +0800 Subject: [PATCH 1/3] Introduce `Value::as_i32` (#732) To help make https://github.com/scroll-tech/ceno/pull/596 easier to read and reason about. Also introduce a few more conversion helpers. --- ceno_zkvm/src/instructions/riscv/mul.rs | 4 +-- ceno_zkvm/src/uint.rs | 34 ++++++++++++++++++++++--- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/mul.rs b/ceno_zkvm/src/instructions/riscv/mul.rs index 58b410960..16a08fe41 100644 --- a/ceno_zkvm/src/instructions/riscv/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/mul.rs @@ -432,9 +432,7 @@ impl Signed { lkm, *val.as_u16_limbs().last().unwrap() as u64, )?; - let signed_val = val.as_u32() as i32; - - Ok(signed_val) + Ok(i32::from(val)) } pub fn expr(&self) -> Expression { diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 193d34f13..5a639a055 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -606,6 +606,30 @@ pub struct Value<'a, T: Into + From + Copy + Default> { pub limbs: Cow<'a, [u16]>, } +impl<'a, T: Into + From + Copy + Default> From<&'a Value<'a, T>> for &'a [u16] { + fn from(v: &'a Value<'a, T>) -> Self { + v.as_u16_limbs() + } +} + +impl<'a, T: Into + From + Copy + Default> From<&Value<'a, T>> for u64 { + fn from(v: &Value<'a, T>) -> Self { + v.as_u64() + } +} + +impl<'a, T: Into + From + Copy + Default> From<&Value<'a, T>> for u32 { + fn from(v: &Value<'a, T>) -> Self { + v.as_u32() + } +} + +impl<'a, T: Into + From + Copy + Default> From<&Value<'a, T>> for i32 { + fn from(v: &Value<'a, T>) -> Self { + v.as_i32() + } +} + // TODO generalize to support non 16 bit limbs // TODO optimize api with fixed size array impl<'a, T: Into + From + Copy + Default> Value<'a, T> { @@ -616,10 +640,7 @@ impl<'a, T: Into + From + Copy + Default> Value<'a, T> { const LIMBS: usize = (Self::M + 15) / 16; pub fn new(val: T, lkm: &mut LkMultiplicity) -> Self { - let uint = Value:: { - val, - limbs: Cow::Owned(Self::split_to_u16(val)), - }; + let uint = Self::new_unchecked(val); Self::assert_u16(&uint.limbs, lkm); uint } @@ -684,6 +705,11 @@ impl<'a, T: Into + From + Copy + Default> Value<'a, T> { self.as_u64() as u32 } + /// Convert the limbs to an i32 value + pub fn as_i32(&self) -> i32 { + self.as_u32() as i32 + } + pub fn u16_fields(&self) -> Vec { self.limbs.iter().map(|v| F::from(*v as u64)).collect_vec() } From 1e52d68832228a56eb6b57145da97f6f21a951d4 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 10:31:28 +0800 Subject: [PATCH 2/3] Use `sum` instead of writing our own --- ceno_zkvm/src/uint/arithmetic.rs | 4 +--- mpcs/src/sum_check/classic/coeff.rs | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index dfe33b076..729021be9 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -274,9 +274,7 @@ impl UIntLimbs { .into_iter() .unzip(); - let sum_expr = is_equal_per_limb - .iter() - .fold(Expression::ZERO, |acc, flag| acc.clone() + flag.expr()); + let sum_expr = is_equal_per_limb.iter().map(ToExpr::expr).sum(); let sum_flag = WitIn::from_expr(|| "sum_flag", circuit_builder, sum_expr, false)?; let (is_equal, diff_inv) = diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index 12f46880f..36d5390a6 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -49,9 +49,7 @@ impl ClassicSumCheckRoundMessage for Coefficients { } fn sum(&self) -> E { - self[1..] - .iter() - .fold(self[0].double(), |acc, coeff| acc + coeff) + self[..].iter().sum() } fn evaluate(&self, _: &Self::Auxiliary, challenge: &E) -> E { From 4c597d98a84e7e9b005b6d22cae04847bc751bcb Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 11:12:22 +0800 Subject: [PATCH 3/3] Fix --- mpcs/src/sum_check/classic/coeff.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index 36d5390a6..10d5c1c20 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -49,7 +49,7 @@ impl ClassicSumCheckRoundMessage for Coefficients { } fn sum(&self) -> E { - self[..].iter().sum() + self[0] + self[..].iter().sum::() } fn evaluate(&self, _: &Self::Auxiliary, challenge: &E) -> E {