Skip to content

Commit

Permalink
Refactor Signed gadget to store expression value instead of new WitIn
Browse files Browse the repository at this point in the history
  • Loading branch information
Bryan Gillespie committed Oct 31, 2024
1 parent 066fd95 commit eb619bf
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 31 deletions.
16 changes: 16 additions & 0 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,22 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
Ok(limb)
}

/// Create a new WitIn constrained to be equal to input expression.
pub fn flatten_expr<NR, N>(
&mut self,
name_fn: N,
expr: Expression<E>,
) -> Result<WitIn, ZKVMError>
where
NR: Into<String>,
N: FnOnce() -> NR + Clone,
{
let wit = self.cs.create_witin(name_fn.clone());
self.require_equal(name_fn, wit.expr(), expr)?;

Ok(wit)
}

pub fn require_zero<NR, N>(
&mut self,
name_fn: N,
Expand Down
62 changes: 31 additions & 31 deletions ceno_zkvm/src/instructions/riscv/mulh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,10 @@ pub struct MulhConfig<E: ExtensionField> {
rs1_read: UInt<E>,
rs2_read: UInt<E>,
rd_written: UInt<E>,
rs1_signed: Signed,
rs2_signed: Signed,
rs1_signed: Signed<E>,
rs1_signed_wit: WitIn,
rs2_signed: Signed<E>,
rs2_signed_wit: WitIn,
rd_sign_bit: IsLtConfig,
unsigned_prod_low: UInt<E>,
r_insn: RInstructionConfig<E>,
Expand All @@ -153,7 +155,10 @@ impl<E: ExtensionField> Instruction<E> for MulhInstruction<E> {
// 1. Compute the signed values associated with `rs1` and `rs2`

let rs1_signed = Signed::construct_circuit(circuit_builder, || "rs1", &rs1_read)?;
let rs1_signed_wit = circuit_builder.flatten_expr(|| "rs1_signed", rs1_signed.expr())?;

let rs2_signed = Signed::construct_circuit(circuit_builder, || "rs2", &rs2_read)?;
let rs2_signed_wit = circuit_builder.flatten_expr(|| "rs2_signed", rs2_signed.expr())?;

// 2. Compute the high order bit of `rd`, which is the sign bit of the 2s
// complement value for which rd represents the high limb
Expand All @@ -173,12 +178,11 @@ impl<E: ExtensionField> Instruction<E> for MulhInstruction<E> {
let unsigned_prod_low = UInt::new(|| "unsigned_prod_low", circuit_builder)?;
circuit_builder.require_equal(
|| "validate_prod_high_limb",
rs1_signed.val.expr() * rs2_signed.val.expr(),
rs1_signed_wit.expr() * rs2_signed_wit.expr(),
Expression::<E>::from(1u64 << 32) * rd_written.value() + unsigned_prod_low.value()
- Expression::<E>::from(1u128 << 64) * rd_sign_bit.expr(),
)?;


// The soundness here is a bit subtle. The signed values of 32-bit
// inputs `rs1` and `rs2` have values between `-2^31` and `2^31 - 1`, so
// their product is constrained to lie between `-2^62 + 2^31` and
Expand Down Expand Up @@ -207,7 +211,9 @@ impl<E: ExtensionField> Instruction<E> for MulhInstruction<E> {
rs2_read,
rd_written,
rs1_signed,
rs1_signed_wit,
rs2_signed,
rs2_signed_wit,
rd_sign_bit,
unsigned_prod_low,
r_insn,
Expand Down Expand Up @@ -237,15 +243,17 @@ impl<E: ExtensionField> Instruction<E> for MulhInstruction<E> {
.assign_limbs(instance, rd_written.as_u16_limbs());

// Signed register values
let rs1_signed =
config
.rs1_signed
.assign_instance::<E>(instance, lk_multiplicity, &rs1_read)?;

let rs2_signed =
config
.rs2_signed
.assign_instance::<E>(instance, lk_multiplicity, &rs2_read)?;
let rs1_signed = config
.rs1_signed
.assign_instance(instance, lk_multiplicity, &rs1_read)?;
let field_elt: E::BaseField = i64_to_base(rs1_signed as i64);
set_val!(instance, config.rs1_signed_wit, field_elt);

let rs2_signed = config
.rs2_signed
.assign_instance(instance, lk_multiplicity, &rs2_read)?;
let field_elt: E::BaseField = i64_to_base(rs2_signed as i64);
set_val!(instance, config.rs2_signed_wit, field_elt);

// Sign bit of rd register
let rd_high_limb = *rd_written.limbs.last().unwrap() as u64;
Expand Down Expand Up @@ -274,17 +282,13 @@ impl<E: ExtensionField> Instruction<E> for MulhInstruction<E> {
/// Transform a value represented as a `UInt` into a `WitIn` containing its
/// corresponding signed value, interpreting the bits as a 2s-complement
/// encoding. Gadget allocates 3 `WitIn` values in total.
struct Signed {
struct Signed<E: ExtensionField> {
pub is_negative: IsLtConfig,
pub val: WitIn,
val: Expression<E>,
}

impl Signed {
pub fn construct_circuit<
E: ExtensionField,
NR: Into<String> + Display + Clone,
N: FnOnce() -> NR,
>(
impl<E: ExtensionField> Signed<E> {
pub fn construct_circuit<NR: Into<String> + Display + Clone, N: FnOnce() -> NR>(
cb: &mut CircuitBuilder<E>,
name_fn: N,
unsigned_val: &UInt<E>,
Expand All @@ -301,19 +305,14 @@ impl Signed {
unsigned_val.expr().last().unwrap().clone(),
1,
)?;
let val = cb.create_witin(|| format!("{name} signed_val witin"));
cb.require_equal(
|| "signed_val",
val.expr(),
unsigned_val.value() - (1u64 << BIT_WIDTH) * is_negative.expr(),
)?;
let val = unsigned_val.value() - (1u64 << BIT_WIDTH) * is_negative.expr();

Ok(Self { is_negative, val })
},
)
}

pub fn assign_instance<E: ExtensionField>(
pub fn assign_instance(
&self,
instance: &mut [MaybeUninit<E::BaseField>],
lkm: &mut LkMultiplicity,
Expand All @@ -326,11 +325,12 @@ impl Signed {

let signed_val = val.as_u32() as i32;

let field_elt: E::BaseField = i64_to_base(signed_val as i64);
set_val!(instance, self.val, field_elt);

Ok(signed_val)
}

pub fn expr(&self) -> Expression<E> {
self.val.clone()
}
}

#[cfg(test)]
Expand Down

0 comments on commit eb619bf

Please sign in to comment.