diff --git a/ceno_emul/src/rv32im_encode.rs b/ceno_emul/src/rv32im_encode.rs index 4a7fe3e7f..18d1d15a2 100644 --- a/ceno_emul/src/rv32im_encode.rs +++ b/ceno_emul/src/rv32im_encode.rs @@ -8,6 +8,13 @@ const MASK_8_BITS: u32 = 0xFF; const MASK_10_BITS: u32 = 0x3FF; const MASK_12_BITS: u32 = 0xFFF; +/// Generate bit encoding of a RISC-V instruction. +/// +/// Values `rs1`, `rs2` and `rd1` are 5-bit register indices, and `imm` is of +/// bit length depending on the requirements of the instruction format type. +/// +/// Fields not required by the instruction's format type are ignored, so one can +/// safely pass an arbitrary value for these, say 0. pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: u32) -> u32 { match kind.codes().format { InsnFormat::R => encode_r(kind, rs1, rs2, rd), diff --git a/ceno_zkvm/src/instructions/riscv/mulh.rs b/ceno_zkvm/src/instructions/riscv/mulh.rs index 4704b57b0..e4bd4f1fd 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh.rs @@ -1,4 +1,4 @@ -use std::{marker::PhantomData, ops::Neg}; +use std::{fmt::Display, marker::PhantomData, ops::Neg}; use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; @@ -151,9 +151,9 @@ impl Instruction for MulhInstruction { let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; let rd_written = UInt::new(|| "rd_written", circuit_builder)?; - let rs1_signed = Signed::construct_circuit(circuit_builder, &rs1_read)?; - let rs2_signed = Signed::construct_circuit(circuit_builder, &rs2_read)?; - let rd_signed = Signed::construct_circuit(circuit_builder, &rd_written)?; + let rs1_signed = Signed::construct_circuit(circuit_builder, || "rs1", &rs1_read)?; + let rs2_signed = Signed::construct_circuit(circuit_builder, || "rs2", &rs2_read)?; + let rd_signed = Signed::construct_circuit(circuit_builder, || "rd", &rd_written)?; let unsigned_prod_low = UInt::new(|| "prod_low", circuit_builder)?; @@ -164,13 +164,31 @@ impl Instruction for MulhInstruction { + Expression::::from(1u64 << 32) * rd_signed.abs_value.expr(), )?; - circuit_builder.require_equal( + // Check that signs are compatible: + // negative * negative = non-negative * non-negative = non-negative + // negative * positive = positive * negative = negative + // negative * zero = zero * negative = non-negative + // + // For the nonzero cases, b1*(1-b2) + (1-b1)*b2 - b3 = 0 validates. + // If either input is zero, the result is nonnegative. + // Taking product of LHS above with abs value of rs1 and rs2 inputs + // gives value which can be zero only when one of the above outcomes + // holds. + // + // Note in particular since the above LHS has values in {-1, 0, 1}, + // this product with two 31-bit unsigned values is zero in Goldilocks + // field only when one of the unsigned values is zero, or the LHS is + // zero -- no overflow can take place. + + let rs1_sign_bit: Expression = rs1_signed.is_negative.expr(); + let rs2_sign_bit: Expression = rs2_signed.is_negative.expr(); + let rd_sign_bit: Expression = rd_signed.is_negative.expr(); + let sign_check = rs1_sign_bit.clone() * (Expression::ONE - rs2_sign_bit.clone()) + + (Expression::ONE - rs1_sign_bit) * rs2_sign_bit - rd_sign_bit; + + circuit_builder.require_zero( || "check_signs", - rs1_signed.is_negative.expr::() - * (Expression::::ONE - rs2_signed.is_negative.expr()) - + (Expression::::ONE - rs1_signed.is_negative.expr::()) - * rs2_signed.is_negative.expr(), - rd_signed.is_negative.expr(), + sign_check * rs1_signed.abs_value.expr() * rs2_signed.abs_value.expr() )?; let r_insn = RInstructionConfig::::construct_circuit( @@ -220,24 +238,27 @@ impl Instruction for MulhInstruction { .assign_limbs(instance, rd_written.as_u16_limbs()); // Assign sign values - let (_, rs1_abs) = config.rs1_signed.assign_instance( - instance, - lk_multiplicity, - &rs1_read)?; + let (_, rs1_abs) = + config + .rs1_signed + .assign_instance(instance, lk_multiplicity, &rs1_read)?; - let (_, rs2_abs) = config.rs2_signed.assign_instance( - instance, - lk_multiplicity, - &rs2_read)?; + let (_, rs2_abs) = + config + .rs2_signed + .assign_instance(instance, lk_multiplicity, &rs2_read)?; - config.rd_signed.assign_instance( - instance, - lk_multiplicity, - &rd_written)?; + config + .rd_signed + .assign_instance(instance, lk_multiplicity, &rd_written)?; // Extract low limbs value of unsigned product - let unsigned_prod_low = Value::new((rs1_abs * rs2_abs) % (1u64 << BIT_WIDTH), lk_multiplicity); - config.unsigned_prod_low + let unsigned_prod_low = Value::new( + ((rs1_abs * rs2_abs) % (1u64 << BIT_WIDTH)) as u32, + lk_multiplicity, + ); + config + .unsigned_prod_low .assign_limbs(instance, unsigned_prod_low.as_u16_limbs()); Ok(()) @@ -250,29 +271,41 @@ struct Signed { } impl Signed { - pub fn construct_circuit( + pub fn construct_circuit< + E: ExtensionField, + NR: Into + Display + Clone, + N: FnOnce() -> NR, + >( cb: &mut CircuitBuilder, + name_fn: N, val: &UInt, ) -> Result { - // is_lt is set if top limb of val is negative - let is_negative = IsLtConfig::construct_circuit( - cb, + cb.namespace( || "signed", - (1u64 << (LIMB_BITS - 1)).into(), - val.expr().last().unwrap().clone(), - 1, - )?; - let abs_value = cb.create_witin(|| "abs_value witin")?; - cb.require_equal( - || "abs_value", - abs_value.expr(), - (1 - 2 * is_negative.expr()) * (val.value() - (1 << 32) * is_negative.expr()), - )?; - - Ok(Self { - is_negative, - abs_value, - }) + |cb| { + let name = name_fn(); + // is_lt is set if top limb of val is negative + let is_negative = IsLtConfig::construct_circuit( + cb, + || name.clone(), + (1u64 << (LIMB_BITS - 1)).into(), + val.expr().last().unwrap().clone(), + 1, + )?; + let abs_value = cb.create_witin(|| format!("{name} abs_value witin"))?; + cb.require_equal( + || "abs_value", + abs_value.expr(), + (1u64 - 2 * is_negative.expr()) + * (val.value() - (1u64 << 32) * is_negative.expr()), + )?; + + Ok(Self { + is_negative, + abs_value, + }) + }, + ) } pub fn assign_instance( @@ -283,12 +316,8 @@ impl Signed { ) -> Result<(bool, u64), ZKVMError> { let high_limb = *val.limbs.last().unwrap() as u64; let sign_cutoff = 1u64 << (LIMB_BITS - 1); - self.is_negative.assign_instance( - instance, - lkm, - sign_cutoff, - high_limb, - )?; + self.is_negative + .assign_instance(instance, lkm, sign_cutoff, high_limb)?; let is_negative = sign_cutoff < high_limb; let abs_value = { let unsigned = val.as_u64(); @@ -298,11 +327,7 @@ impl Signed { unsigned } }; - set_val!( - instance, - self.abs_value, - abs_value - ); + set_val!(instance, self.abs_value, abs_value); Ok((is_negative, abs_value)) } } @@ -322,12 +347,12 @@ mod test { #[test] fn test_opcode_mulhu() { - verify(2, 11); - verify(u32::MAX, u32::MAX); - verify(u16::MAX as u32, u16::MAX as u32); + verify_mulhu(2, 11); + verify_mulhu(u32::MAX, u32::MAX); + verify_mulhu(u16::MAX as u32, u16::MAX as u32); } - fn verify(rs1: u32, rs2: u32) { + fn verify_mulhu(rs1: u32, rs2: u32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let config = cb @@ -367,4 +392,62 @@ mod test { MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } + + #[test] + fn test_opcode_mulh() { + let test_cases = vec![ + (2, 11), + (0, -1), + (0, 1), + (1, 0), + (-1, -1), + (i32::MAX, i32::MIN), // TODO handle problem with abs value of min + (i32::MAX, i32::MAX), + (i32::MIN, i32::MIN), + ]; + test_cases + .into_iter() + .for_each(|(rs1, rs2)| verify_mulh(rs1, rs2)); + } + + fn verify_mulh(rs1: i32, rs2: i32) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace(|| "mulh", |cb| Ok(MulhInstruction::construct_circuit(cb))) + .unwrap() + .unwrap(); + + let signed_prod_high = (rs1 as i64).wrapping_mul(rs2 as i64) >> 32; + + println!("{rs1} {rs2} {signed_prod_high}"); + + // // values assignment + let insn_code = encode_rv32(InsnKind::MULH, 2, 3, 4, 0); + let (raw_witin, lkm) = + MulhInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ + StepRecord::new_r_instruction( + 3, + MOCK_PC_START, + insn_code, + rs1 as u32, + rs2 as u32, + Change::new(0, signed_prod_high as u32), + 0, + ), + ]) + .unwrap(); + + // verify value write to register, which is only hi + // let expected_rd_written = UInt::from_const_unchecked(value_mul.as_hi_limb_slice().to_vec()); + let rd_written_expr = cb.get_debug_expr(DebugIndex::RdWrite as usize)[0].clone(); + cb.require_equal( + || "assert_rd_written", + rd_written_expr, + Expression::from(signed_prod_high as u32), + ) + .unwrap(); + + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + } }