From 1ab8dea81a9860eb84baf6292b9d9387cbbec44f Mon Sep 17 00:00:00 2001 From: Soham Zemse <22412996+zemse@users.noreply.github.com> Date: Fri, 6 Sep 2024 12:41:23 +0530 Subject: [PATCH] avoid using `lookup_ltu_limb8` instead use ux lookup --- ceno_zkvm/src/chip_handler/general.rs | 28 +++++++++++++-------- ceno_zkvm/src/instructions/riscv/addsub.rs | 29 ++++++++++++++-------- ceno_zkvm/src/scheme/mock_prover.rs | 12 ++++----- 3 files changed, 42 insertions(+), 27 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 63b9c3f84..ebb856dfd 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -266,34 +266,40 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { Ok(()) } - pub(crate) fn less_than( + pub(crate) fn less_than( &mut self, name_fn: N, lhs: Expression, rhs: Expression, - ) -> Result + ) -> Result<(WitIn, WitIn), ZKVMError> where - NR: Into + Display, + NR: Into + Display + Clone, N: FnOnce() -> NR, { self.namespace( || "less_than", |cb| { let name = name_fn(); - let is_lt = cb.create_witin(|| format!("{name} witin"))?; - // TODO add name_fn to lookup_ltu_limb8, not done yet to avoid merge conflicts - cb.lookup_ltu_limb8(is_lt.expr(), lhs, rhs)?; - Ok(is_lt) + let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?; + let diff = cb.create_witin(|| format!("{name} diff witin"))?; + let range = Expression::Constant(2u64.pow(C as u32).into()); + cb.require_equal( + || name.clone(), + lhs - rhs, + diff.expr() - is_lt.expr() * range, + )?; + cb.assert_ux::<_, _, C>(|| name, diff.expr())?; + Ok((is_lt, diff)) }, ) } - pub(crate) fn assert_less_than( + pub(crate) fn assert_less_than( &mut self, name_fn: N, lhs: Expression, rhs: Expression, - ) -> Result + ) -> Result<(WitIn, WitIn), ZKVMError> where NR: Into + Clone + Display, N: FnOnce() -> NR, @@ -302,9 +308,9 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { || "assert_less_than", |cb| { let name = name_fn(); - let is_lt = cb.less_than(|| name.clone(), lhs, rhs)?; + let (is_lt, diff) = cb.less_than::<_, _, C>(|| name.clone(), lhs, rhs)?; cb.require_one(|| name, is_lt.expr())?; - Ok(is_lt) + Ok((is_lt, diff)) }, ) } diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index a4ac40daf..2992dc738 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -34,9 +34,9 @@ pub struct InstructionConfig { pub prev_rs1_ts: WitIn, pub prev_rs2_ts: WitIn, pub prev_rd_ts: WitIn, - pub lt_wtns_rs1: WitIn, - pub lt_wtns_rs2: WitIn, - pub lt_wtns_rd: WitIn, + pub lt_wtns_rs1: (WitIn, WitIn), + pub lt_wtns_rs2: (WitIn, WitIn), + pub lt_wtns_rd: (WitIn, WitIn), phantom: PhantomData, } @@ -120,18 +120,21 @@ fn add_sub_gadget( let next_ts = ts + 1.into(); circuit_builder.state_out(next_pc, next_ts)?; - let lt_wtns_rs1 = circuit_builder.assert_less_than( + let lt_wtns_rs1 = circuit_builder.assert_less_than::<_, _, 16>( || "prev_rs1_ts < ts", prev_rs1_ts.expr(), cur_ts.expr(), )?; - let lt_wtns_rs2 = circuit_builder.assert_less_than( + let lt_wtns_rs2 = circuit_builder.assert_less_than::<_, _, 16>( || "prev_rs2_ts < ts", prev_rs2_ts.expr(), cur_ts.expr(), )?; - let lt_wtns_rd = - circuit_builder.assert_less_than(|| "prev_rd_ts < ts", prev_rd_ts.expr(), cur_ts.expr())?; + let lt_wtns_rd = circuit_builder.assert_less_than::<_, _, 16>( + || "prev_rd_ts < ts", + prev_rd_ts.expr(), + cur_ts.expr(), + )?; Ok(InstructionConfig { pc, @@ -195,9 +198,15 @@ impl Instruction for AddInstruction { set_val!(instance, config.prev_rs1_ts, 2); set_val!(instance, config.prev_rs2_ts, 2); set_val!(instance, config.prev_rd_ts, 2); - set_val!(instance, config.lt_wtns_rs1, 1); - set_val!(instance, config.lt_wtns_rs2, 1); - set_val!(instance, config.lt_wtns_rd, 1); + + set_val!(instance, config.lt_wtns_rs1.0, 1); + set_val!(instance, config.lt_wtns_rs1.1, 2u64.pow(16) - 2 + 1); // range - lhs + rhs + + set_val!(instance, config.lt_wtns_rs2.0, 1); + set_val!(instance, config.lt_wtns_rs2.1, 2u64.pow(16) - 3 + 2); // range - lhs + rhs + + set_val!(instance, config.lt_wtns_rd.0, 1); + set_val!(instance, config.lt_wtns_rd.1, 2u64.pow(16) - 3 + 2); // range - lhs + rhs Ok(()) } } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 24a386e55..eb6504e44 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -325,7 +325,7 @@ pub fn load_u5_table( cb: &CircuitBuilder, challenge: [E; 2], ) { - for i in 0..32 { + for i in 0..(1 << 5) { let rlc_record = cb.rlc_chip_record(vec![ Expression::Constant(E::BaseField::from(ROMType::U5 as u64)), i.into(), @@ -340,8 +340,8 @@ pub fn load_u16_table( cb: &CircuitBuilder, challenge: [E; 2], ) { - t_vec.reserve(u16::MAX as usize); - for i in 0..(u16::MAX as usize) { + t_vec.reserve(1 << 16); + for i in 0..(1 << 16) { let rlc_record = cb.rlc_chip_record(vec![ Expression::Constant(E::BaseField::from(ROMType::U16 as u64)), i.into(), @@ -356,9 +356,9 @@ pub fn load_lt_table( cb: &CircuitBuilder, challenge: [E; 2], ) { - t_vec.reserve(u16::MAX as usize); - for lhs in 0..(u8::MAX as usize) { - for rhs in 0..(u8::MAX as usize) { + t_vec.reserve(1 << 16); + for lhs in 0..(1 << 8) { + for rhs in 0..(1 << 8) { let is_lt = if lhs < rhs { 1 } else { 0 }; let lhs_rhs = lhs * 256 + rhs; let rlc_record = cb.rlc_chip_record(vec![