diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 21791f426..e3ae37145 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -1,12 +1,9 @@ -use std::fmt::Display; - use ff_ext::ExtensionField; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, expression::{Expression, Fixed, Instance, ToExpr, WitIn}, - gadgets::IsLtConfig, instructions::riscv::constants::EXIT_CODE_IDX, structs::ROMType, tables::InsnRecord, @@ -328,22 +325,6 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.logic_u8(ROMType::Pow, 2.into(), b, c) } - /// less_than - pub(crate) fn less_than( - &mut self, - name_fn: N, - lhs: Expression, - rhs: Expression, - assert_less_than: Option, - max_num_u16_limbs: usize, - ) -> Result - where - NR: Into + Display + Clone, - N: FnOnce() -> NR, - { - IsLtConfig::construct_circuit(self, name_fn, lhs, rhs, assert_less_than, max_num_u16_limbs) - } - pub(crate) fn is_equal( &mut self, lhs: Expression, diff --git a/ceno_zkvm/src/chip_handler/memory.rs b/ceno_zkvm/src/chip_handler/memory.rs index 2900d48ff..cd56f6254 100644 --- a/ceno_zkvm/src/chip_handler/memory.rs +++ b/ceno_zkvm/src/chip_handler/memory.rs @@ -50,7 +50,8 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> MemoryChipOpera cb.write_record(|| "write_record", write_record)?; // assert prev_ts < current_ts - let lt_cfg = cb.less_than( + let lt_cfg = IsLtConfig::construct_circuit( + cb, || "prev_ts < ts", prev_ts, ts.clone(), @@ -102,7 +103,8 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> MemoryChipOpera cb.read_record(|| "read_record", read_record)?; cb.write_record(|| "write_record", write_record)?; - let lt_cfg = cb.less_than( + let lt_cfg = IsLtConfig::construct_circuit( + cb, || "prev_ts < ts", prev_ts, ts.clone(), diff --git a/ceno_zkvm/src/chip_handler/register.rs b/ceno_zkvm/src/chip_handler/register.rs index e8ecf865a..25e58d130 100644 --- a/ceno_zkvm/src/chip_handler/register.rs +++ b/ceno_zkvm/src/chip_handler/register.rs @@ -51,7 +51,8 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe cb.write_record(|| "write_record", write_record)?; // assert prev_ts < current_ts - let lt_cfg = cb.less_than( + let lt_cfg = IsLtConfig::construct_circuit( + cb, || "prev_ts < ts", prev_ts, ts.clone(), @@ -103,7 +104,8 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe cb.read_record(|| "read_record", read_record)?; cb.write_record(|| "write_record", write_record)?; - let lt_cfg = cb.less_than( + let lt_cfg = IsLtConfig::construct_circuit( + cb, || "prev_ts < ts", prev_ts, ts.clone(), diff --git a/ceno_zkvm/src/gadgets/div.rs b/ceno_zkvm/src/gadgets/div.rs index 5160eef04..b336ad44d 100644 --- a/ceno_zkvm/src/gadgets/div.rs +++ b/ceno_zkvm/src/gadgets/div.rs @@ -5,7 +5,7 @@ use ff_ext::ExtensionField; use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, - instructions::riscv::constants::{UInt, BIT_WIDTH}, + instructions::riscv::constants::{UInt, UINT_LIMBS}, witness::LkMultiplicity, Value, }; @@ -32,18 +32,18 @@ impl DivConfig { remainder: &UInt, ) -> Result { circuit_builder.namespace(name_fn, |cb| { - let intermediate_mul = - divisor.mul::(|| "divisor_mul", cb, quotient, true)?; - let dividend = intermediate_mul.add(|| "dividend_add", cb, remainder, true)?; + let (dividend, intermediate_mul) = + divisor.mul_add(|| "divisor * outcome + r", cb, quotient, remainder, true)?; - // remainder range check - let r_lt = cb.less_than( + let r_lt = IsLtConfig::construct_circuit( + cb, || "remainder < divisor", remainder.value(), divisor.value(), Some(true), - UInt::::NUM_CELLS, + UINT_LIMBS, )?; + Ok(Self { dividend, intermediate_mul, @@ -61,7 +61,6 @@ impl DivConfig { remainder: &Value<'a, u32>, ) -> Result<(), ZKVMError> { let (dividend, intermediate) = divisor.mul_add(quotient, remainder, lkm, true); - self.r_lt .assign_instance(instance, lkm, remainder.as_u64(), divisor.as_u64())?; self.intermediate_mul diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index 60cb38a74..70d4f7dd2 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -39,7 +39,7 @@ impl IsLtConfig { ) -> Result { assert!(max_num_u16_limbs >= 1); cb.namespace( - || "less_than", + || "is_lt", |cb| { let name = name_fn(); let (is_lt, is_lt_expr) = if let Some(lt) = assert_less_than { diff --git a/ceno_zkvm/src/instructions/riscv/divu.rs b/ceno_zkvm/src/instructions/riscv/divu.rs index 517135cb9..8b139ff51 100644 --- a/ceno_zkvm/src/instructions/riscv/divu.rs +++ b/ceno_zkvm/src/instructions/riscv/divu.rs @@ -1,10 +1,19 @@ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; -use super::{constants::UInt, r_insn::RInstructionConfig, RIVInstruction}; +use super::{ + constants::{UInt, UINT_LIMBS}, + r_insn::RInstructionConfig, + RIVInstruction, +}; use crate::{ - circuit_builder::CircuitBuilder, error::ZKVMError, gadgets::IsZeroConfig, - instructions::Instruction, uint::Value, witness::LkMultiplicity, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::Expression, + gadgets::{IsLtConfig, IsZeroConfig}, + instructions::Instruction, + uint::Value, + witness::LkMultiplicity, }; use core::mem::MaybeUninit; use std::marker::PhantomData; @@ -19,6 +28,7 @@ pub struct ArithConfig { remainder: UInt, inter_mul_value: UInt, is_zero: IsZeroConfig, + pub remainder_lt: IsLtConfig, } pub struct ArithInstruction(PhantomData<(E, I)>); @@ -36,37 +46,46 @@ impl Instruction for ArithInstruction, - ) -> Result { + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { // outcome = dividend / divisor + remainder => dividend = divisor * outcome + r - let mut divisor = UInt::new_unchecked(|| "divisor", circuit_builder)?; - let mut outcome = UInt::new(|| "outcome", circuit_builder)?; - let r = UInt::new(|| "remainder", circuit_builder)?; - + let mut divisor = UInt::new_unchecked(|| "divisor", cb)?; + let mut outcome = UInt::new(|| "outcome", cb)?; + let r = UInt::new(|| "remainder", cb)?; let (dividend, inter_mul_value) = - divisor.mul_add(|| "dividend", circuit_builder, &mut outcome, &r, true)?; + divisor.mul_add(|| "divisor * outcome + r", cb, &mut outcome, &r, true)?; // div by zero check - let is_zero = IsZeroConfig::construct_circuit( - circuit_builder, - || "divisor_zero_check", + let is_zero = + IsZeroConfig::construct_circuit(cb, || "divisor_zero_check", divisor.value())?; + let outcome_value = outcome.value(); + cb.condition_require_equal( + || "outcome_is_zero", + is_zero.expr(), + outcome_value.clone(), + ((1u64 << UInt::::M) - 1).into(), + outcome_value, + )?; + + // remainder should be less than divisor if divisor != 0. + let lt = IsLtConfig::construct_circuit( + cb, + || "remainder < divisor?", + r.value(), divisor.value(), + None, + UINT_LIMBS, )?; - let outcome_value = outcome.value(); - circuit_builder - .condition_require_equal( - || "outcome_is_zero", - is_zero.expr(), - outcome_value.clone(), - ((1u64 << UInt::::M) - 1).into(), - outcome_value, - ) - .unwrap(); + // When divisor is zero, remainder is -1 implies "remainder > divisor" aka. lt.expr() == 0 + // otherwise lt.expr() == 1 + cb.require_equal( + || "remainder < divisor when non-zero divisor", + is_zero.expr() + lt.expr(), + Expression::ONE, + )?; let r_insn = RInstructionConfig::::construct_circuit( - circuit_builder, + cb, I::INST_KIND, dividend.register_expr(), divisor.register_expr(), @@ -81,6 +100,7 @@ impl Instruction for ArithInstruction Instruction for ArithInstruction::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let config = cb .namespace( - || format!("divu_{name}"), + || format!("divu_({name})"), |cb| Ok(DivUInstruction::construct_circuit(cb)), ) .unwrap() .unwrap(); + let outcome = if divisor == 0 { + u32::MAX + } else { + dividend / divisor + }; // values assignment let (raw_witin, _) = DivUInstruction::assign_instances( &config, @@ -179,8 +204,9 @@ mod test { ) .unwrap(); - let expected_rd_written = - UInt::from_const_unchecked(Value::new_unchecked(outcome).as_u16_limbs().to_vec()); + let expected_rd_written = UInt::from_const_unchecked( + Value::new_unchecked(exp_outcome).as_u16_limbs().to_vec(), + ); config .outcome @@ -206,8 +232,8 @@ mod test { verify("u32::MAX", u32::MAX, u32::MAX, 1); verify("div u32::MAX", 3, u32::MAX, 0); verify("u32::MAX div by 2", u32::MAX, 2, u32::MAX / 2); + verify("mul with carries", 1202729773, 171818539, 7); verify("div by zero", 10, 0, u32::MAX); - verify("mul carry", 1202729773, 171818539, 7); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 4ebfdd161..3b9692411 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -112,8 +112,8 @@ impl Instruction for ShiftLogicalInstru rs2_high, rs2_low5, pow2_rs2_low5, - remainder, div_config, + remainder, }) } @@ -201,51 +201,46 @@ mod tests { use super::{ShiftLogicalInstruction, SllOp, SrlOp}; #[test] - fn test_opcode_sll_1() { - verify::(0b_1, 3, 0b_1000); - } - - #[test] - fn test_opcode_sll_2_rs2_overflow() { + fn test_opcode_sll() { + verify::("basic", 0b_0001, 3, 0b_1000); // 33 << 33 === 33 << 1 - verify::(0b_1, 33, 0b_10); - } - - #[test] - fn test_opcode_sll_3_bit_loss() { - verify::(1 << 31 | 1, 1, 0b_10); - } - - #[test] - fn test_opcode_srl_1() { - verify::(0b_1000, 3, 0b_1); - } - - #[test] - fn test_opcode_srl_2_rs2_overflow() { - // 33 >> 33 === 33 >> 1 - verify::(0b_1010, 33, 0b_101); + verify::("rs2 over 5-bits", 0b_0001, 33, 0b_0010); + verify::("bit loss", 1 << 31 | 1, 1, 0b_0010); + verify::("zero shift", 0b_0001, 0, 0b_0001); + verify::("all zeros", 0b_0000, 0, 0b_0000); + verify::("base is zero", 0b_0000, 1, 0b_0000); } #[test] - fn test_opcode_srl_3_bit_loss() { + fn test_opcode_srl() { + verify::("basic", 0b_1000, 3, 0b_0001); // 33 >> 33 === 33 >> 1 - verify::(0b_1001, 1, 0b_100); + verify::("rs2 over 5-bits", 0b_1010, 33, 0b_0101); + verify::("bit loss", 0b_1001, 1, 0b_0100); + verify::("zero shift", 0b_1000, 0, 0b_1000); + verify::("all zeros", 0b_0000, 0, 0b_0000); + verify::("base is zero", 0b_0000, 1, 0b_0000); } - fn verify(rs1_read: u32, rs2_read: u32, expected_rd_written: u32) { + fn verify( + name: &'static str, + rs1_read: u32, + rs2_read: u32, + expected_rd_written: u32, + ) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let (name, mock_pc, mock_program_op) = match I::INST_KIND { - InsnKind::SLL => ("SLL", MOCK_PC_SLL, MOCK_PROGRAM[19]), - InsnKind::SRL => ("SRL", MOCK_PC_SRL, MOCK_PROGRAM[20]), + let shift = rs2_read & 0b11111; + let (prefix, mock_pc, mock_program_op, rd_written) = match I::INST_KIND { + InsnKind::SLL => ("SLL", MOCK_PC_SLL, MOCK_PROGRAM[19], rs1_read << shift), + InsnKind::SRL => ("SRL", MOCK_PC_SRL, MOCK_PROGRAM[20], rs1_read >> shift), _ => unreachable!(), }; let config = cb .namespace( - || name, + || format!("{prefix}_({name})"), |cb| { let config = ShiftLogicalInstruction::::construct_circuit(cb); @@ -258,7 +253,7 @@ mod tests { config .rd_written .require_equal( - || "assert_rd_written", + || format!("{prefix}_({name})_assert_rd_written"), &mut cb, &UInt::from_const_unchecked( Value::new_unchecked(expected_rd_written) @@ -277,7 +272,7 @@ mod tests { mock_program_op, rs1_read, rs2_read, - Change::new(0, expected_rd_written), + Change::new(0, rd_written), 0, )], ) diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index e0e0b9e8a..4714563ea 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -17,7 +17,7 @@ use itertools::Itertools; use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; use std::{ collections::HashSet, - fs::{self, File}, + fs::File, hash::Hash, io::{BufReader, ErrorKind}, marker::PhantomData, @@ -669,7 +669,8 @@ mod tests { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let a = cb.create_witin(|| "a")?; let b = cb.create_witin(|| "b")?; - let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr(), Some(true), 1)?; + let lt_wtns = + IsLtConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), Some(true), 1)?; Ok(Self { a, b, lt_wtns }) } @@ -789,7 +790,7 @@ mod tests { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let a = cb.create_witin(|| "a")?; let b = cb.create_witin(|| "b")?; - let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr(), None, 1)?; + let lt_wtns = IsLtConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), None, 1)?; Ok(Self { a, b, lt_wtns }) }