From eb619bfb2ba39713c038fb0b0c5aa14527ae1f2e Mon Sep 17 00:00:00 2001 From: Bryan Gillespie Date: Thu, 31 Oct 2024 11:30:44 -0600 Subject: [PATCH] Refactor Signed gadget to store expression value instead of new WitIn --- ceno_zkvm/src/chip_handler/general.rs | 16 ++++++ ceno_zkvm/src/instructions/riscv/mulh.rs | 62 ++++++++++++------------ 2 files changed, 47 insertions(+), 31 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 2d6f8ba39..94383c493 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -165,6 +165,22 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { Ok(limb) } + /// Create a new WitIn constrained to be equal to input expression. + pub fn flatten_expr( + &mut self, + name_fn: N, + expr: Expression, + ) -> Result + where + NR: Into, + N: FnOnce() -> NR + Clone, + { + let wit = self.cs.create_witin(name_fn.clone()); + self.require_equal(name_fn, wit.expr(), expr)?; + + Ok(wit) + } + pub fn require_zero( &mut self, name_fn: N, diff --git a/ceno_zkvm/src/instructions/riscv/mulh.rs b/ceno_zkvm/src/instructions/riscv/mulh.rs index 3769084bb..3fa1248d1 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh.rs @@ -130,8 +130,10 @@ pub struct MulhConfig { rs1_read: UInt, rs2_read: UInt, rd_written: UInt, - rs1_signed: Signed, - rs2_signed: Signed, + rs1_signed: Signed, + rs1_signed_wit: WitIn, + rs2_signed: Signed, + rs2_signed_wit: WitIn, rd_sign_bit: IsLtConfig, unsigned_prod_low: UInt, r_insn: RInstructionConfig, @@ -153,7 +155,10 @@ impl Instruction for MulhInstruction { // 1. Compute the signed values associated with `rs1` and `rs2` let rs1_signed = Signed::construct_circuit(circuit_builder, || "rs1", &rs1_read)?; + let rs1_signed_wit = circuit_builder.flatten_expr(|| "rs1_signed", rs1_signed.expr())?; + let rs2_signed = Signed::construct_circuit(circuit_builder, || "rs2", &rs2_read)?; + let rs2_signed_wit = circuit_builder.flatten_expr(|| "rs2_signed", rs2_signed.expr())?; // 2. Compute the high order bit of `rd`, which is the sign bit of the 2s // complement value for which rd represents the high limb @@ -173,12 +178,11 @@ impl Instruction for MulhInstruction { let unsigned_prod_low = UInt::new(|| "unsigned_prod_low", circuit_builder)?; circuit_builder.require_equal( || "validate_prod_high_limb", - rs1_signed.val.expr() * rs2_signed.val.expr(), + rs1_signed_wit.expr() * rs2_signed_wit.expr(), Expression::::from(1u64 << 32) * rd_written.value() + unsigned_prod_low.value() - Expression::::from(1u128 << 64) * rd_sign_bit.expr(), )?; - // The soundness here is a bit subtle. The signed values of 32-bit // inputs `rs1` and `rs2` have values between `-2^31` and `2^31 - 1`, so // their product is constrained to lie between `-2^62 + 2^31` and @@ -207,7 +211,9 @@ impl Instruction for MulhInstruction { rs2_read, rd_written, rs1_signed, + rs1_signed_wit, rs2_signed, + rs2_signed_wit, rd_sign_bit, unsigned_prod_low, r_insn, @@ -237,15 +243,17 @@ impl Instruction for MulhInstruction { .assign_limbs(instance, rd_written.as_u16_limbs()); // Signed register values - let rs1_signed = - config - .rs1_signed - .assign_instance::(instance, lk_multiplicity, &rs1_read)?; - - let rs2_signed = - config - .rs2_signed - .assign_instance::(instance, lk_multiplicity, &rs2_read)?; + let rs1_signed = config + .rs1_signed + .assign_instance(instance, lk_multiplicity, &rs1_read)?; + let field_elt: E::BaseField = i64_to_base(rs1_signed as i64); + set_val!(instance, config.rs1_signed_wit, field_elt); + + let rs2_signed = config + .rs2_signed + .assign_instance(instance, lk_multiplicity, &rs2_read)?; + let field_elt: E::BaseField = i64_to_base(rs2_signed as i64); + set_val!(instance, config.rs2_signed_wit, field_elt); // Sign bit of rd register let rd_high_limb = *rd_written.limbs.last().unwrap() as u64; @@ -274,17 +282,13 @@ impl Instruction for MulhInstruction { /// Transform a value represented as a `UInt` into a `WitIn` containing its /// corresponding signed value, interpreting the bits as a 2s-complement /// encoding. Gadget allocates 3 `WitIn` values in total. -struct Signed { +struct Signed { pub is_negative: IsLtConfig, - pub val: WitIn, + val: Expression, } -impl Signed { - pub fn construct_circuit< - E: ExtensionField, - NR: Into + Display + Clone, - N: FnOnce() -> NR, - >( +impl Signed { + pub fn construct_circuit + Display + Clone, N: FnOnce() -> NR>( cb: &mut CircuitBuilder, name_fn: N, unsigned_val: &UInt, @@ -301,19 +305,14 @@ impl Signed { unsigned_val.expr().last().unwrap().clone(), 1, )?; - let val = cb.create_witin(|| format!("{name} signed_val witin")); - cb.require_equal( - || "signed_val", - val.expr(), - unsigned_val.value() - (1u64 << BIT_WIDTH) * is_negative.expr(), - )?; + let val = unsigned_val.value() - (1u64 << BIT_WIDTH) * is_negative.expr(); Ok(Self { is_negative, val }) }, ) } - pub fn assign_instance( + pub fn assign_instance( &self, instance: &mut [MaybeUninit], lkm: &mut LkMultiplicity, @@ -326,11 +325,12 @@ impl Signed { let signed_val = val.as_u32() as i32; - let field_elt: E::BaseField = i64_to_base(signed_val as i64); - set_val!(instance, self.val, field_elt); - Ok(signed_val) } + + pub fn expr(&self) -> Expression { + self.val.clone() + } } #[cfg(test)]