Skip to content

Commit

Permalink
Implement unit tests for MULH opcode
Browse files Browse the repository at this point in the history
  • Loading branch information
Bryan Gillespie committed Oct 29, 2024
1 parent 9a70475 commit ec321bb
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 58 deletions.
7 changes: 7 additions & 0 deletions ceno_emul/src/rv32im_encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ const MASK_8_BITS: u32 = 0xFF;
const MASK_10_BITS: u32 = 0x3FF;
const MASK_12_BITS: u32 = 0xFFF;

/// Generate bit encoding of a RISC-V instruction.
///
/// Values `rs1`, `rs2` and `rd1` are 5-bit register indices, and `imm` is of
/// bit length depending on the requirements of the instruction format type.
///
/// Fields not required by the instruction's format type are ignored, so one can
/// safely pass an arbitrary value for these, say 0.
pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: u32) -> u32 {
match kind.codes().format {
InsnFormat::R => encode_r(kind, rs1, rs2, rd),
Expand Down
199 changes: 141 additions & 58 deletions ceno_zkvm/src/instructions/riscv/mulh.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{marker::PhantomData, ops::Neg};
use std::{fmt::Display, marker::PhantomData, ops::Neg};

use ceno_emul::{InsnKind, StepRecord};
use ff_ext::ExtensionField;
Expand Down Expand Up @@ -151,9 +151,9 @@ impl<E: ExtensionField> Instruction<E> for MulhInstruction<E> {
let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?;
let rd_written = UInt::new(|| "rd_written", circuit_builder)?;

let rs1_signed = Signed::construct_circuit(circuit_builder, &rs1_read)?;
let rs2_signed = Signed::construct_circuit(circuit_builder, &rs2_read)?;
let rd_signed = Signed::construct_circuit(circuit_builder, &rd_written)?;
let rs1_signed = Signed::construct_circuit(circuit_builder, || "rs1", &rs1_read)?;
let rs2_signed = Signed::construct_circuit(circuit_builder, || "rs2", &rs2_read)?;
let rd_signed = Signed::construct_circuit(circuit_builder, || "rd", &rd_written)?;

let unsigned_prod_low = UInt::new(|| "prod_low", circuit_builder)?;

Expand All @@ -164,13 +164,31 @@ impl<E: ExtensionField> Instruction<E> for MulhInstruction<E> {
+ Expression::<E>::from(1u64 << 32) * rd_signed.abs_value.expr(),
)?;

circuit_builder.require_equal(
// Check that signs are compatible:
// negative * negative = non-negative * non-negative = non-negative
// negative * positive = positive * negative = negative
// negative * zero = zero * negative = non-negative
//
// For the nonzero cases, b1*(1-b2) + (1-b1)*b2 - b3 = 0 validates.
// If either input is zero, the result is nonnegative.
// Taking product of LHS above with abs value of rs1 and rs2 inputs
// gives value which can be zero only when one of the above outcomes
// holds.
//
// Note in particular since the above LHS has values in {-1, 0, 1},
// this product with two 31-bit unsigned values is zero in Goldilocks
// field only when one of the unsigned values is zero, or the LHS is
// zero -- no overflow can take place.

let rs1_sign_bit: Expression<E> = rs1_signed.is_negative.expr();
let rs2_sign_bit: Expression<E> = rs2_signed.is_negative.expr();
let rd_sign_bit: Expression<E> = rd_signed.is_negative.expr();
let sign_check = rs1_sign_bit.clone() * (Expression::ONE - rs2_sign_bit.clone())
+ (Expression::ONE - rs1_sign_bit) * rs2_sign_bit - rd_sign_bit;

circuit_builder.require_zero(
|| "check_signs",
rs1_signed.is_negative.expr::<E>()
* (Expression::<E>::ONE - rs2_signed.is_negative.expr())
+ (Expression::<E>::ONE - rs1_signed.is_negative.expr::<E>())
* rs2_signed.is_negative.expr(),
rd_signed.is_negative.expr(),
sign_check * rs1_signed.abs_value.expr() * rs2_signed.abs_value.expr()
)?;

let r_insn = RInstructionConfig::<E>::construct_circuit(
Expand Down Expand Up @@ -220,24 +238,27 @@ impl<E: ExtensionField> Instruction<E> for MulhInstruction<E> {
.assign_limbs(instance, rd_written.as_u16_limbs());

// Assign sign values
let (_, rs1_abs) = config.rs1_signed.assign_instance(
instance,
lk_multiplicity,
&rs1_read)?;
let (_, rs1_abs) =
config
.rs1_signed
.assign_instance(instance, lk_multiplicity, &rs1_read)?;

let (_, rs2_abs) = config.rs2_signed.assign_instance(
instance,
lk_multiplicity,
&rs2_read)?;
let (_, rs2_abs) =
config
.rs2_signed
.assign_instance(instance, lk_multiplicity, &rs2_read)?;

config.rd_signed.assign_instance(
instance,
lk_multiplicity,
&rd_written)?;
config
.rd_signed
.assign_instance(instance, lk_multiplicity, &rd_written)?;

// Extract low limbs value of unsigned product
let unsigned_prod_low = Value::new((rs1_abs * rs2_abs) % (1u64 << BIT_WIDTH), lk_multiplicity);
config.unsigned_prod_low
let unsigned_prod_low = Value::new(
((rs1_abs * rs2_abs) % (1u64 << BIT_WIDTH)) as u32,
lk_multiplicity,
);
config
.unsigned_prod_low
.assign_limbs(instance, unsigned_prod_low.as_u16_limbs());

Ok(())
Expand All @@ -250,29 +271,41 @@ struct Signed {
}

impl Signed {
pub fn construct_circuit<E: ExtensionField>(
pub fn construct_circuit<
E: ExtensionField,
NR: Into<String> + Display + Clone,
N: FnOnce() -> NR,
>(
cb: &mut CircuitBuilder<E>,
name_fn: N,
val: &UInt<E>,
) -> Result<Self, ZKVMError> {
// is_lt is set if top limb of val is negative
let is_negative = IsLtConfig::construct_circuit(
cb,
cb.namespace(
|| "signed",
(1u64 << (LIMB_BITS - 1)).into(),
val.expr().last().unwrap().clone(),
1,
)?;
let abs_value = cb.create_witin(|| "abs_value witin")?;
cb.require_equal(
|| "abs_value",
abs_value.expr(),
(1 - 2 * is_negative.expr()) * (val.value() - (1 << 32) * is_negative.expr()),
)?;

Ok(Self {
is_negative,
abs_value,
})
|cb| {
let name = name_fn();
// is_lt is set if top limb of val is negative
let is_negative = IsLtConfig::construct_circuit(
cb,
|| name.clone(),
(1u64 << (LIMB_BITS - 1)).into(),
val.expr().last().unwrap().clone(),
1,
)?;
let abs_value = cb.create_witin(|| format!("{name} abs_value witin"))?;
cb.require_equal(
|| "abs_value",
abs_value.expr(),
(1u64 - 2 * is_negative.expr())
* (val.value() - (1u64 << 32) * is_negative.expr()),
)?;

Ok(Self {
is_negative,
abs_value,
})
},
)
}

pub fn assign_instance<F: SmallField>(
Expand All @@ -283,12 +316,8 @@ impl Signed {
) -> Result<(bool, u64), ZKVMError> {
let high_limb = *val.limbs.last().unwrap() as u64;
let sign_cutoff = 1u64 << (LIMB_BITS - 1);
self.is_negative.assign_instance(
instance,
lkm,
sign_cutoff,
high_limb,
)?;
self.is_negative
.assign_instance(instance, lkm, sign_cutoff, high_limb)?;
let is_negative = sign_cutoff < high_limb;
let abs_value = {
let unsigned = val.as_u64();
Expand All @@ -298,11 +327,7 @@ impl Signed {
unsigned
}
};
set_val!(
instance,
self.abs_value,
abs_value
);
set_val!(instance, self.abs_value, abs_value);
Ok((is_negative, abs_value))
}
}
Expand All @@ -322,12 +347,12 @@ mod test {

#[test]
fn test_opcode_mulhu() {
verify(2, 11);
verify(u32::MAX, u32::MAX);
verify(u16::MAX as u32, u16::MAX as u32);
verify_mulhu(2, 11);
verify_mulhu(u32::MAX, u32::MAX);
verify_mulhu(u16::MAX as u32, u16::MAX as u32);
}

fn verify(rs1: u32, rs2: u32) {
fn verify_mulhu(rs1: u32, rs2: u32) {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
Expand Down Expand Up @@ -367,4 +392,62 @@ mod test {

MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm));
}

#[test]
fn test_opcode_mulh() {
let test_cases = vec![
(2, 11),
(0, -1),
(0, 1),
(1, 0),
(-1, -1),
(i32::MAX, i32::MIN), // TODO handle problem with abs value of min
(i32::MAX, i32::MAX),
(i32::MIN, i32::MIN),
];
test_cases
.into_iter()
.for_each(|(rs1, rs2)| verify_mulh(rs1, rs2));
}

fn verify_mulh(rs1: i32, rs2: i32) {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
.namespace(|| "mulh", |cb| Ok(MulhInstruction::construct_circuit(cb)))
.unwrap()
.unwrap();

let signed_prod_high = (rs1 as i64).wrapping_mul(rs2 as i64) >> 32;

println!("{rs1} {rs2} {signed_prod_high}");

// // values assignment
let insn_code = encode_rv32(InsnKind::MULH, 2, 3, 4, 0);
let (raw_witin, lkm) =
MulhInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![
StepRecord::new_r_instruction(
3,
MOCK_PC_START,
insn_code,
rs1 as u32,
rs2 as u32,
Change::new(0, signed_prod_high as u32),
0,
),
])
.unwrap();

// verify value write to register, which is only hi
// let expected_rd_written = UInt::from_const_unchecked(value_mul.as_hi_limb_slice().to_vec());
let rd_written_expr = cb.get_debug_expr(DebugIndex::RdWrite as usize)[0].clone();
cb.require_equal(
|| "assert_rd_written",
rd_written_expr,
Expression::from(signed_prod_high as u32),
)
.unwrap();

MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm));
}
}

0 comments on commit ec321bb

Please sign in to comment.