From 1f8162a210df5e9d5b00bc70e23c7b99e53c0870 Mon Sep 17 00:00:00 2001 From: KimiWu Date: Fri, 1 Nov 2024 11:21:11 +0800 Subject: [PATCH] apply SignedExtendConfig to the places needed msb/is_neg config --- ceno_zkvm/src/gadgets/is_lt.rs | 6 +-- ceno_zkvm/src/instructions/riscv/mulh.rs | 42 +++++++------------ ceno_zkvm/src/instructions/riscv/shift_imm.rs | 15 +++---- ceno_zkvm/src/instructions/riscv/slti.rs | 12 +++--- ceno_zkvm/src/uint.rs | 19 ++------- 5 files changed, 33 insertions(+), 61 deletions(-) diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index 1dd11659a..b9bc6815a 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -335,10 +335,8 @@ impl InnerSignedLtConfig { is_lt_expr: Expression, ) -> Result { // Extract the sign bit. - let is_lhs_neg = - SignedExtendConfig::construct_limb(cb, lhs.limbs.iter().last().unwrap().expr())?; - let is_rhs_neg = - SignedExtendConfig::construct_limb(cb, rhs.limbs.iter().last().unwrap().expr())?; + let is_lhs_neg = lhs.is_negative(cb)?; + let is_rhs_neg = rhs.is_negative(cb)?; // Convert to field arithmetic. let lhs_value = lhs.to_field_expr(is_lhs_neg.expr()); diff --git a/ceno_zkvm/src/instructions/riscv/mulh.rs b/ceno_zkvm/src/instructions/riscv/mulh.rs index 12b98eec6..c2c15569b 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh.rs @@ -6,13 +6,13 @@ use ff_ext::ExtensionField; use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, - expression::{Expression, ToExpr}, - gadgets::IsLtConfig, + expression::Expression, + gadgets::SignedExtendConfig, instructions::{ Instruction, riscv::{ RIVInstruction, - constants::{BIT_WIDTH, LIMB_BITS, UInt, UIntMul}, + constants::{BIT_WIDTH, UInt, UIntMul}, r_insn::RInstructionConfig, }, }, @@ -256,7 +256,7 @@ impl Instruction for MulhInstruction { /// corresponding signed value, interpreting the bits as a 2s-complement /// encoding. Gadget allocates 2 `WitIn` values in total. struct Signed { - pub is_negative: IsLtConfig, + pub is_negative: SignedExtendConfig, val: Expression, } @@ -266,23 +266,13 @@ impl Signed { name_fn: N, unsigned_val: &UInt, ) -> Result { - cb.namespace( - || "signed", - |cb| { - let name = name_fn(); - // is_lt is set if top limb of val is negative - let is_negative = IsLtConfig::construct_circuit( - cb, - || name.clone(), - ((1u64 << (LIMB_BITS - 1)) - 1).into(), - unsigned_val.expr().last().unwrap().clone(), - 1, - )?; - let val = unsigned_val.value() - (1u64 << BIT_WIDTH) * is_negative.expr(); - - Ok(Self { is_negative, val }) - }, - ) + cb.namespace(name_fn, |cb| { + // is_lt is set if top limb of val is negative + let is_negative = unsigned_val.is_negative(cb)?; + let val = unsigned_val.value() - (1u64 << BIT_WIDTH) * is_negative.expr(); + + Ok(Self { is_negative, val }) + }) } pub fn assign_instance( @@ -291,11 +281,11 @@ impl Signed { lkm: &mut LkMultiplicity, val: &Value, ) -> Result { - let high_limb = *val.limbs.last().unwrap() as u64; - let sign_cutoff = (1u64 << (LIMB_BITS - 1)) - 1; - self.is_negative - .assign_instance(instance, lkm, sign_cutoff, high_limb)?; - + self.is_negative.assign_instance::( + instance, + lkm, + *val.as_u16_limbs().last().unwrap() as u64, + )?; let signed_val = val.as_u32() as i32; Ok(signed_val) diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 34e2b0d78..8b8f46b28 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -4,7 +4,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, ToExpr, WitIn}, - gadgets::{AssertLTConfig, IsLtConfig}, + gadgets::{AssertLTConfig, SignedExtendConfig}, instructions::{ Instruction, riscv::{constants::UInt, i_insn::IInstructionConfig}, @@ -26,7 +26,7 @@ pub struct ShiftImmConfig { assert_lt_config: AssertLTConfig, // SRAI - is_lt_config: Option, + is_lt_config: Option, } pub struct ShiftImmInstruction(PhantomData<(E, I)>); @@ -84,10 +84,9 @@ impl Instruction for ShiftImmInstructio InsnKind::SRAI | InsnKind::SRLI => { let (inflow, is_lt_config) = match I::INST_KIND { InsnKind::SRAI => { - let is_rs1_neg = rs1_read.is_negative(circuit_builder, || "lhs_msb")?; - let msb_expr: Expression = is_rs1_neg.is_lt.expr(); + let is_rs1_neg = rs1_read.is_negative(circuit_builder)?; let ones = imm.expr() - 1; - (msb_expr * ones, Some(is_rs1_neg)) + (is_rs1_neg.expr() * ones, Some(is_rs1_neg)) } InsnKind::SRLI => (Expression::ZERO, None), _ => unreachable!(), @@ -140,12 +139,10 @@ impl Instruction for ShiftImmInstructio InsnKind::SLLI => (rs1_read.as_u64() * imm as u64) >> UInt::::TOTAL_BITS, InsnKind::SRAI | InsnKind::SRLI => { if I::INST_KIND == InsnKind::SRAI { - let max_signed_limb_expr = (1 << (UInt::::LIMB_BITS - 1)) - 1; - config.is_lt_config.as_ref().unwrap().assign_instance( + config.is_lt_config.as_ref().unwrap().assign_instance::( instance, lk_multiplicity, - max_signed_limb_expr, - rs1_read.as_u64() >> UInt::::LIMB_BITS, + *rs1_read.as_u16_limbs().last().unwrap() as u64, )?; } diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index d35d63d75..85efba194 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -12,7 +12,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{ToExpr, WitIn}, - gadgets::IsLtConfig, + gadgets::{IsLtConfig, SignedExtendConfig}, instructions::Instruction, set_val, tables::InsnRecord, @@ -32,7 +32,7 @@ pub struct SetLessThanImmConfig { lt: IsLtConfig, // SLTI - is_rs1_neg: Option, + is_rs1_neg: Option, } pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); @@ -62,7 +62,7 @@ impl Instruction for SetLessThanImmInst let (value_expr, is_rs1_neg) = match I::INST_KIND { InsnKind::SLTIU => (rs1_read.value(), None), InsnKind::SLTI => { - let is_rs1_neg = rs1_read.is_negative(cb, || "lhs_msb")?; + let is_rs1_neg = rs1_read.is_negative(cb)?; (rs1_read.to_field_expr(is_rs1_neg.expr()), Some(is_rs1_neg)) } _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), @@ -100,7 +100,6 @@ impl Instruction for SetLessThanImmInst config.i_insn.assign_instance(instance, lkm, step)?; let rs1 = step.rs1().unwrap().value; - let max_signed_limb = (1u64 << (UInt::::LIMB_BITS - 1)) - 1; let rs1_value = Value::new_unchecked(rs1 as Word); config .rs1_read @@ -117,11 +116,10 @@ impl Instruction for SetLessThanImmInst .assign_instance(instance, lkm, rs1 as u64, imm as u64)?; } InsnKind::SLTI => { - config.is_rs1_neg.as_ref().unwrap().assign_instance( + config.is_rs1_neg.as_ref().unwrap().assign_instance::( instance, lkm, - max_signed_limb, - *rs1_value.limbs.last().unwrap() as u64, + *rs1_value.as_u16_limbs().last().unwrap() as u64, )?; config .lt diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 23b55606c..989701305 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -8,8 +8,8 @@ use crate::{ circuit_builder::CircuitBuilder, error::{UtilError, ZKVMError}, expression::{Expression, ToExpr, WitIn}, - gadgets::{AssertLTConfig, IsLtConfig}, - instructions::riscv::constants::{LIMB_BITS, UInt}, + gadgets::{AssertLTConfig, SignedExtendConfig}, + instructions::riscv::constants::UInt, utils::add_one_to_big_num, witness::LkMultiplicity, }; @@ -21,7 +21,6 @@ use goldilocks::SmallField; use itertools::Itertools; use std::{ borrow::Cow, - fmt::Display, mem::{self, MaybeUninit}, ops::Index, }; @@ -541,18 +540,8 @@ impl UInt { /// /// Also called Most Significant Bit extraction, when /// interpreted as an unsigned int. - pub fn is_negative + Display + Clone, N: FnOnce() -> NR>( - &self, - cb: &mut CircuitBuilder, - name_fn: N, - ) -> Result { - IsLtConfig::construct_circuit( - cb, - name_fn, - ((1u64 << (UInt::::LIMB_BITS - 1)) - 1).into(), - self.expr().last().unwrap().clone(), - LIMB_BITS.div_ceil(LIMB_BITS), - ) + pub fn is_negative(&self, cb: &mut CircuitBuilder) -> Result { + SignedExtendConfig::construct_limb(cb, self.limbs.iter().last().unwrap().expr()) } }