Skip to content

Commit

Permalink
impl lt gadget and SLTU opcode
Browse files Browse the repository at this point in the history
  • Loading branch information
KimiWu123 committed Sep 26, 2024
1 parent d0d4e66 commit d732adb
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 10 deletions.
72 changes: 72 additions & 0 deletions ceno_zkvm/src/gadgets/lt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use std::mem::MaybeUninit;

use ff_ext::ExtensionField;
use goldilocks::SmallField;

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, ToExpr, WitIn},
instructions::riscv::constants::UInt,
set_val,
uint::UIntLimbs,
witness::LkMultiplicity,
Value,
};

/// Returns `1` when `lhs < rhs`, and returns `0` otherwise.
/// The equation is enforced `lhs - rhs == diff - (lt * range)`.
#[derive(Clone, Debug)]
pub struct LtGadget<E: ExtensionField> {
/// `1` when `lhs < rhs`, `0` otherwise.
lt: WitIn,
/// `diff` equals `lhs - rhs` if `lhs >= rhs`,`lhs - rhs + range` otherwise.
diff: UInt<E>,
}

impl<E: ExtensionField> LtGadget<E> {
pub fn construct_circuit(
cb: &mut CircuitBuilder<E>,
lhs: Expression<E>,
rhs: Expression<E>,
) -> Result<Self, ZKVMError> {
let lt = cb.create_witin(|| "lt")?;
let diff = UIntLimbs::new(|| "diff", cb)?;
let range = Expression::from(1 << UInt::<E>::M);

// The equation we require to hold: `lhs - rhs == diff - (lt * range)`.
cb.require_equal(
|| "lhs - rhs == diff - (lt ⋅ range)",
lhs - rhs,
diff.value() - (lt.expr() * range),
)?;

Ok(LtGadget { lt, diff })
}

pub(crate) fn expr(&self) -> Expression<E> {
self.lt.expr()
}

pub(crate) fn assign(
&self,
instance: &mut [MaybeUninit<E::BaseField>],
lkm: &mut LkMultiplicity,
lhs: E::BaseField,
rhs: E::BaseField,
) -> Result<(), ZKVMError> {
let lhs = lhs.to_canonical_u64();
let rhs = rhs.to_canonical_u64();

// Set `lt`
let lt = lhs < rhs;
set_val!(instance, self.lt, lt as u64);

// Set `diff`
let diff = lhs - rhs + (if lt { 1 << UInt::<E>::M } else { 0 });
self.diff
.assign_limbs(instance, Value::new(diff, lkm).u16_fields());

Ok(())
}
}
1 change: 1 addition & 0 deletions ceno_zkvm/src/gadgets/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
mod is_zero;
pub use is_zero::{IsEqualConfig, IsZeroConfig};
pub mod lt;
25 changes: 23 additions & 2 deletions ceno_zkvm/src/instructions/riscv/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use itertools::Itertools;

use super::{constants::UInt, r_insn::RInstructionConfig, RIVInstruction};
use crate::{
circuit_builder::CircuitBuilder, error::ZKVMError, instructions::Instruction, uint::Value,
witness::LkMultiplicity,
circuit_builder::CircuitBuilder, error::ZKVMError, gadgets::lt::LtGadget,
instructions::Instruction, uint::Value, witness::LkMultiplicity,
};
use core::mem::MaybeUninit;

Expand Down Expand Up @@ -41,6 +41,12 @@ impl RIVInstruction for MulOp {
}
pub type MulInstruction<E> = ArithInstruction<E, MulOp>;

pub struct SLTUOp;
impl RIVInstruction for SLTUOp {
const INST_KIND: InsnKind = InsnKind::SLTU;
}
pub type SltuInstruction<E> = ArithInstruction<E, SLTUOp>;

impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E, I> {
type InstructionConfig = ArithConfig<E>;

Expand Down Expand Up @@ -83,6 +89,21 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E
(rs1_read, rs2_read, rd_written)
}

InsnKind::SLTU => {
// If rs1_read < rs2_read, rd_written = 1. Otherwise rd_written = 0
let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?;
let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?;

let lt = LtGadget::construct_circuit(
circuit_builder,
rs1_read.value(),
rs2_read.value(),
)?;
let rd_written = UInt::new(|| "rd_written", circuit_builder)?;
circuit_builder.require_equal(|| "rd == lt", rd_written.value(), lt.expr())?;
(rs1_read, rs2_read, rd_written)
}

_ => unreachable!("Unsupported instruction kind"),
};

Expand Down
8 changes: 0 additions & 8 deletions ceno_zkvm/src/uint/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,6 @@ impl<const M: usize, const C: usize, E: ExtensionField> UIntLimbs<M, C, E> {
})
}

pub fn lt(
&self,
_circuit_builder: &mut CircuitBuilder<E>,
_rhs: &UIntLimbs<M, C, E>,
) -> Result<Expression<E>, ZKVMError> {
Ok(self.expr().remove(0) + 1.into())
}

pub fn is_equal(
&self,
circuit_builder: &mut CircuitBuilder<E>,
Expand Down

0 comments on commit d732adb

Please sign in to comment.