diff --git a/ceno_zkvm/src/instructions/riscv/mul.rs b/ceno_zkvm/src/instructions/riscv/mul.rs index a9dbe2702..50a9615b7 100644 --- a/ceno_zkvm/src/instructions/riscv/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/mul.rs @@ -13,7 +13,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{ToExpr, WitIn}, - instructions::Instruction, + instructions::{riscv::config::ExprLtInput, Instruction}, set_val, uint::UIntValue, witness::LkMultiplicity, @@ -150,7 +150,7 @@ impl Instruction for MulInstruction { ) -> Result<(), ZKVMError> { // TODO use fields from step set_val!(instance, config.pc, 1); - set_val!(instance, config.ts, 2); + set_val!(instance, config.ts, 3); let multiplier_1 = UIntValue::new_unchecked(step.rs1().unwrap().value); let multiplier_2 = UIntValue::new_unchecked(step.rs2().unwrap().value); @@ -170,7 +170,7 @@ impl Instruction for MulInstruction { instance, outcome .into_iter() - .map(|carry| E::BaseField::from(carry as u64)) + .map(|c| E::BaseField::from(c as u64)) .collect_vec(), ); config.outcome.assign_carries( @@ -180,20 +180,37 @@ impl Instruction for MulInstruction { .map(|carry| E::BaseField::from(carry as u64)) .collect_vec(), ); - // TODO #167 - set_val!(instance, config.rs1_id, 2); + + set_val!(instance, config.rs1_id, 1); set_val!(instance, config.rs2_id, 2); - set_val!(instance, config.rd_id, 2); + set_val!(instance, config.rd_id, 3); set_val!(instance, config.prev_rs1_ts, 2); set_val!(instance, config.prev_rs2_ts, 2); set_val!(instance, config.prev_rd_ts, 2); + + ExprLtInput { + lhs: 2, // rs1_ts + rhs: 3, // cur_ts + } + .assign(instance, &config.lt_rs1_ts_cfg); + ExprLtInput { + lhs: 2, // rs2_ts + rhs: 4, // cur_ts + } + .assign(instance, &config.lt_rs2_ts_cfg); + ExprLtInput { + lhs: 2, // rd_ts + rhs: 5, // cur_ts + } + .assign(instance, &config.lt_rd_ts_cfg); + Ok(()) } } #[cfg(test)] mod test { - use ceno_emul::{ReadOp, StepRecord}; + use ceno_emul::{Change, ReadOp, StepRecord, WriteOp}; use goldilocks::GoldilocksExt2; use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; @@ -208,34 +225,41 @@ mod test { #[test] #[allow(clippy::option_map_unit_fn)] - fn test_opcode_add() { + fn test_opcode_mul() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let config = cb - .namespace( - || "add", - |cb| { - let config = MulInstruction::construct_circuit(cb); - Ok(config) - }, - ) + .namespace(|| "mul", |cb| Ok(MulInstruction::construct_circuit(cb))) .unwrap() .unwrap(); + // values assignment + let rs1 = Some(ReadOp { + addr: 0.into(), + value: 11u32, + previous_cycle: 0, + }); + let rs2 = Some(ReadOp { + addr: 0.into(), + value: 2u32, + previous_cycle: 0, + }); + let rd = Some(WriteOp { + addr: 2.into(), + value: Change { + before: 0u32, + after: 22u32, + }, + previous_cycle: 0, + }); + let (raw_witin, _) = MulInstruction::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord { - rs1: Some(ReadOp { - addr: 0.into(), - value: 11u32, - previous_cycle: 0, - }), - rs2: Some(ReadOp { - addr: 0.into(), - value: 0xfffffffeu32, - previous_cycle: 0, - }), + rs1, + rs2, + rd, ..Default::default() }], ) @@ -255,34 +279,41 @@ mod test { #[test] #[allow(clippy::option_map_unit_fn)] - fn test_opcode_add_overflow() { + fn test_opcode_mul_overflow() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let config = cb - .namespace( - || "add", - |cb| { - let config = MulInstruction::construct_circuit(cb); - Ok(config) - }, - ) + .namespace(|| "mul", |cb| Ok(MulInstruction::construct_circuit(cb))) .unwrap() .unwrap(); + // values assignment + let rs1 = Some(ReadOp { + addr: 0.into(), + value: u32::MAX / 2 + 1, // equals to 2^32 / 2 + previous_cycle: 0, + }); + let rs2 = Some(ReadOp { + addr: 0.into(), + value: 2u32, + previous_cycle: 0, + }); + let rd = Some(WriteOp { + addr: 2.into(), + value: Change { + before: 0u32, + after: 0u32, + }, + previous_cycle: 0, + }); + let (raw_witin, _) = MulInstruction::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord { - rs1: Some(ReadOp { - addr: 0.into(), - value: u32::MAX - 1, - previous_cycle: 0, - }), - rs2: Some(ReadOp { - addr: 0.into(), - value: u32::MAX - 1, - previous_cycle: 0, - }), + rs1, + rs2, + rd, ..Default::default() }], ) diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index b16ccef98..2b44a05e6 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -566,11 +566,11 @@ impl + Copy> UIntValue { b_limbs.iter().enumerate().for_each(|(j, b_limb)| { let idx = i + j; if idx < num_limbs { - let (c, overflow) = a_limb.overflowing_mul(*b_limb); - c_limbs[idx] += c; - if overflow { - carries[idx] += 1; - } + let (c, overflow_mul) = a_limb.overflowing_mul(*b_limb); + let (ret, overflow_add) = c_limbs[idx].overflowing_add(c); + + c_limbs[idx] = ret; + carries[idx] += (overflow_add as u16) + (overflow_mul as u16); } }) });