diff --git a/ceno_zkvm/src/gadgets/lt.rs b/ceno_zkvm/src/gadgets/lt.rs new file mode 100644 index 000000000..bfc892cb4 --- /dev/null +++ b/ceno_zkvm/src/gadgets/lt.rs @@ -0,0 +1,72 @@ +use std::mem::MaybeUninit; + +use ff_ext::ExtensionField; +use goldilocks::SmallField; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + instructions::riscv::constants::UInt, + set_val, + uint::UIntLimbs, + witness::LkMultiplicity, + Value, +}; + +/// Returns `1` when `lhs < rhs`, and returns `0` otherwise. +/// The equation is enforced `lhs - rhs == diff - (lt * range)`. +#[derive(Clone, Debug)] +pub struct LtGadget { + /// `1` when `lhs < rhs`, `0` otherwise. + lt: WitIn, + /// `diff` equals `lhs - rhs` if `lhs >= rhs`,`lhs - rhs + range` otherwise. + diff: UInt, +} + +impl LtGadget { + pub fn construct_circuit( + cb: &mut CircuitBuilder, + lhs: Expression, + rhs: Expression, + ) -> Result { + let lt = cb.create_witin(|| "lt")?; + let diff = UIntLimbs::new(|| "diff", cb)?; + let range = Expression::from(1 << UInt::::M); + + // The equation we require to hold: `lhs - rhs == diff - (lt * range)`. + cb.require_equal( + || "lhs - rhs == diff - (lt ⋅ range)", + lhs - rhs, + diff.value() - (lt.expr() * range), + )?; + + Ok(LtGadget { lt, diff }) + } + + pub(crate) fn expr(&self) -> Expression { + self.lt.expr() + } + + pub(crate) fn assign( + &self, + instance: &mut [MaybeUninit], + lkm: &mut LkMultiplicity, + lhs: E::BaseField, + rhs: E::BaseField, + ) -> Result<(), ZKVMError> { + let lhs = lhs.to_canonical_u64(); + let rhs = rhs.to_canonical_u64(); + + // Set `lt` + let lt = lhs < rhs; + set_val!(instance, self.lt, lt as u64); + + // Set `diff` + let diff = lhs - rhs + (if lt { 1 << UInt::::M } else { 0 }); + self.diff + .assign_limbs(instance, Value::new(diff, lkm).u16_fields()); + + Ok(()) + } +} diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs index 6851cb9a6..1dbb10a8b 100644 --- a/ceno_zkvm/src/gadgets/mod.rs +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -1,2 +1,3 @@ mod is_zero; pub use is_zero::{IsEqualConfig, IsZeroConfig}; +pub mod lt; diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 42e0a752c..5e029ef54 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -6,8 +6,8 @@ use itertools::Itertools; use super::{constants::UInt, r_insn::RInstructionConfig, RIVInstruction}; use crate::{ - circuit_builder::CircuitBuilder, error::ZKVMError, instructions::Instruction, uint::Value, - witness::LkMultiplicity, + circuit_builder::CircuitBuilder, error::ZKVMError, gadgets::lt::LtGadget, + instructions::Instruction, uint::Value, witness::LkMultiplicity, }; use core::mem::MaybeUninit; @@ -41,6 +41,12 @@ impl RIVInstruction for MulOp { } pub type MulInstruction = ArithInstruction; +pub struct SLTUOp; +impl RIVInstruction for SLTUOp { + const INST_KIND: InsnKind = InsnKind::SLTU; +} +pub type SltuInstruction = ArithInstruction; + impl Instruction for ArithInstruction { type InstructionConfig = ArithConfig; @@ -83,6 +89,21 @@ impl Instruction for ArithInstruction { + // If rs1_read < rs2_read, rd_written = 1. Otherwise rd_written = 0 + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; + + let lt = LtGadget::construct_circuit( + circuit_builder, + rs1_read.value(), + rs2_read.value(), + )?; + let rd_written = UInt::new(|| "rd_written", circuit_builder)?; + circuit_builder.require_equal(|| "rd == lt", rd_written.value(), lt.expr())?; + (rs1_read, rs2_read, rd_written) + } + _ => unreachable!("Unsupported instruction kind"), }; diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 817c83ca0..5486549d1 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -213,14 +213,6 @@ impl UIntLimbs { }) } - pub fn lt( - &self, - _circuit_builder: &mut CircuitBuilder, - _rhs: &UIntLimbs, - ) -> Result, ZKVMError> { - Ok(self.expr().remove(0) + 1.into()) - } - pub fn is_equal( &self, circuit_builder: &mut CircuitBuilder,