diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index 0f77053fa..f8d40cdee 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -271,7 +271,7 @@ impl AssertSignedLtConfig { #[derive(Debug)] pub struct SignedLtConfig { is_lt: WitIn, - pub config: InnerSignedLtConfig, + config: InnerSignedLtConfig, } impl SignedLtConfig { @@ -318,9 +318,9 @@ impl SignedLtConfig { } #[derive(Debug)] -pub struct InnerSignedLtConfig { - pub is_lhs_neg: IsLtConfig, - pub is_rhs_neg: IsLtConfig, +struct InnerSignedLtConfig { + is_lhs_neg: IsLtConfig, + is_rhs_neg: IsLtConfig, config: InnerLtConfig, } diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index b8b2eba65..96b192d60 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -18,7 +18,6 @@ pub mod shift; pub mod shift_imm; pub mod slt; pub mod slti; -pub mod slti2; pub mod sltu; mod b_insn; @@ -35,6 +34,7 @@ mod memory; mod s_insn; #[cfg(test)] mod test; +#[cfg(test)] mod test_utils; pub trait RIVInstruction { diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 013c3615a..af31961df 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -1,14 +1,17 @@ use std::marker::PhantomData; -use ceno_emul::{InsnKind, SWord, StepRecord}; +use ceno_emul::{InsnKind, SWord, StepRecord, Word}; use ff_ext::ExtensionField; -use super::{constants::UInt, i_insn::IInstructionConfig}; +use super::{ + constants::{UINT_LIMBS, UInt}, + i_insn::IInstructionConfig, +}; use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, - expression::{ToExpr, WitIn}, - gadgets::SignedLtConfig, + expression::{Expression, ToExpr, WitIn}, + gadgets::IsLtConfig, instructions::Instruction, set_val, tables::InsnRecord, @@ -20,18 +23,14 @@ use core::mem::MaybeUninit; #[derive(Debug)] pub struct InstructionConfig { i_insn: IInstructionConfig, - rs1_read: UInt, - // `imm` data is a field element (which is a u64 data since we're using Goldilock) - // and `imm` is used as an lookup argument in the instruction lookup. - // However, our current gadgets (IsLtConfig or SignedLtConfig) don't support a field element comparison. - // That's the reason why we add `imm_uint` to compare `rs1_read` in SignedLtConfig. + rs1_read: UInt, imm: WitIn, - imm_uint: UInt, #[allow(dead_code)] rd_written: UInt, - signed_lt: SignedLtConfig, + is_rs1_neg: IsLtConfig, + lt: IsLtConfig, } pub struct SltiInstruction(PhantomData); @@ -46,15 +45,25 @@ impl Instruction for SltiInstruction { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { // If rs1_read < imm, rd_written = 1. Otherwise rd_written = 0 let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; - let imm_uint = UInt::new_unchecked(|| "imm_uint", cb)?; let imm = cb.create_witin(|| "imm")?; - let lt = SignedLtConfig::construct_circuit(cb, || "rs1 < imm_uint", &rs1_read, &imm_uint)?; - let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()])?; + let max_signed_limb_expr: Expression<_> = ((1 << (UInt::::LIMB_BITS - 1)) - 1).into(); + let is_rs1_neg = IsLtConfig::construct_circuit( + cb, + || "lhs_msb", + max_signed_limb_expr.clone(), + rs1_read.limbs.iter().last().unwrap().expr(), // msb limb + 1, + )?; - // Constrain imm == imm_uint by converting imm_uint to a field element - let imm_field_expr = imm_uint.to_field_expr(lt.config.is_rhs_neg.expr()); - cb.require_equal(|| "imm_uint == imm", imm_field_expr, imm.expr())?; + let lt = IsLtConfig::construct_circuit( + cb, + || "rs1 < imm", + rs1_read.to_field_expr(is_rs1_neg.expr()), + imm.expr(), + UINT_LIMBS, + )?; + let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()])?; let i_insn = IInstructionConfig::::construct_circuit( cb, @@ -68,10 +77,10 @@ impl Instruction for SltiInstruction { Ok(InstructionConfig { i_insn, rs1_read, - imm_uint, imm, rd_written, - signed_lt: lt, + is_rs1_neg, + lt, }) } @@ -84,20 +93,25 @@ impl Instruction for SltiInstruction { config.i_insn.assign_instance(instance, lkm, step)?; let rs1 = step.rs1().unwrap().value; + let max_signed_limb = (1u64 << (UInt::::LIMB_BITS - 1)) - 1; + let rs1_value = Value::new_unchecked(rs1 as Word); config .rs1_read .assign_value(instance, Value::new_unchecked(rs1)); + config.is_rs1_neg.assign_instance( + instance, + lkm, + max_signed_limb, + *rs1_value.limbs.last().unwrap() as u64, + )?; let imm = step.insn().imm_or_funct7(); let imm_field = InsnRecord::imm_or_funct7_field::(&step.insn()); set_val!(instance, config.imm, imm_field); - config - .imm_uint - .assign_value(instance, Value::new_unchecked(imm)); config - .signed_lt - .assign_instance::(instance, lkm, rs1 as SWord, imm as SWord)?; + .lt + .assign_instance_signed(instance, lkm, rs1 as SWord, imm as SWord)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/slti2.rs b/ceno_zkvm/src/instructions/riscv/slti2.rs deleted file mode 100644 index 3f67b6107..000000000 --- a/ceno_zkvm/src/instructions/riscv/slti2.rs +++ /dev/null @@ -1,191 +0,0 @@ -use std::marker::PhantomData; - -use ceno_emul::{InsnKind, SWord, StepRecord, Word}; -use ff_ext::ExtensionField; - -use super::{ - constants::{UINT_LIMBS, UInt}, - i_insn::IInstructionConfig, -}; -use crate::{ - circuit_builder::CircuitBuilder, - error::ZKVMError, - expression::{Expression, ToExpr, WitIn}, - gadgets::IsLtConfig, - instructions::Instruction, - set_val, - tables::InsnRecord, - uint::Value, - witness::LkMultiplicity, -}; -use core::mem::MaybeUninit; - -#[derive(Debug)] -pub struct InstructionConfig { - i_insn: IInstructionConfig, - rs1_read: UInt, - - imm: WitIn, - #[allow(dead_code)] - rd_written: UInt, - - is_rs1_neg: IsLtConfig, - lt: IsLtConfig, - // signed_lt: SignedLtConfig, -} - -pub struct SltiInstruction2(PhantomData); - -impl Instruction for SltiInstruction2 { - type InstructionConfig = InstructionConfig; - - fn name() -> String { - format!("{:?}", InsnKind::SLTI) - } - - fn construct_circuit(cb: &mut CircuitBuilder) -> Result { - // If rs1_read < imm, rd_written = 1. Otherwise rd_written = 0 - let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; - let imm = cb.create_witin(|| "imm")?; - - let max_signed_limb_expr: Expression<_> = ((1 << (UInt::::LIMB_BITS - 1)) - 1).into(); - let is_rs1_neg = IsLtConfig::construct_circuit( - cb, - || "lhs_msb", - max_signed_limb_expr.clone(), - rs1_read.limbs.iter().last().unwrap().expr(), // msb limb - 1, - )?; - - let lt = IsLtConfig::construct_circuit( - cb, - || "rs1 < imm", - rs1_read.to_field_expr(is_rs1_neg.expr()), - imm.expr(), - UINT_LIMBS, - )?; - let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()])?; - - let i_insn = IInstructionConfig::::construct_circuit( - cb, - InsnKind::SLTI, - &imm.expr(), - rs1_read.register_expr(), - rd_written.register_expr(), - false, - )?; - - Ok(InstructionConfig { - i_insn, - rs1_read, - imm, - rd_written, - is_rs1_neg, - lt, - }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [MaybeUninit], - lkm: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - config.i_insn.assign_instance(instance, lkm, step)?; - - let rs1 = step.rs1().unwrap().value; - let max_signed_limb = (1u64 << (UInt::::LIMB_BITS - 1)) - 1; - let rs1_value = Value::new_unchecked(rs1 as Word); - config - .rs1_read - .assign_value(instance, Value::new_unchecked(rs1)); - config.is_rs1_neg.assign_instance( - instance, - lkm, - max_signed_limb, - *rs1_value.limbs.last().unwrap() as u64, - )?; - - let imm = step.insn().imm_or_funct7(); - let imm_field = InsnRecord::imm_or_funct7_field::(&step.insn()); - set_val!(instance, config.imm, imm_field); - - config - .lt - .assign_instance_signed(instance, lkm, rs1 as SWord, imm as SWord)?; - - Ok(()) - } -} - -#[cfg(test)] -mod test { - use ceno_emul::{Change, PC_STEP_SIZE, StepRecord, Word, encode_rv32}; - use goldilocks::GoldilocksExt2; - - use itertools::Itertools; - use multilinear_extensions::mle::IntoMLEs; - - use super::*; - use crate::{ - circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{Instruction, riscv::test_utils::imm_i}, - scheme::mock_prover::{MOCK_PC_START, MockProver}, - }; - - fn verify(name: &'static str, rs1: i32, imm: i32, rd: Word) { - let mut cs = ConstraintSystem::::new(|| "riscv"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = cb - .namespace( - || format!("SLTI/{name}"), - |cb| { - let config = SltiInstruction2::construct_circuit(cb); - Ok(config) - }, - ) - .unwrap() - .unwrap(); - - let insn_code = encode_rv32(InsnKind::SLTI, 2, 0, 4, imm_i(imm)); - let (raw_witin, lkm) = - SltiInstruction2::assign_instances(&config, cb.cs.num_witin as usize, vec![ - StepRecord::new_i_instruction( - 3, - Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), - insn_code, - rs1 as Word, - Change::new(0, rd), - 0, - ), - ]) - .unwrap(); - - let expected_rd_written = - UInt::from_const_unchecked(Value::new_unchecked(rd).as_u16_limbs().to_vec()); - config - .rd_written - .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) - .unwrap(); - - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); - } - - #[test] - fn test_slti_failed() { - verify("lt = true, -1 < 0", -1, 0, 1); - verify("lt = true, -1 < 1", -1, 1, 1); - verify("lt = true, -2 < -1", -2, -1, 1); - } -} diff --git a/ceno_zkvm/src/instructions/riscv/test_utils.rs b/ceno_zkvm/src/instructions/riscv/test_utils.rs index ab3795494..416ce628c 100644 --- a/ceno_zkvm/src/instructions/riscv/test_utils.rs +++ b/ceno_zkvm/src/instructions/riscv/test_utils.rs @@ -1,27 +1,22 @@ -#[cfg(test)] pub fn imm_b(imm: i32) -> u32 { // imm is 13 bits in B-type imm_with_max_valid_bits(imm, 13) } -#[cfg(test)] pub fn imm_i(imm: i32) -> u32 { // imm is 12 bits in I-type imm_with_max_valid_bits(imm, 12) } -#[cfg(test)] pub fn imm_j(imm: i32) -> u32 { // imm is 21 bits in J-type imm_with_max_valid_bits(imm, 21) } -#[cfg(test)] fn imm_with_max_valid_bits(imm: i32, bits: u32) -> u32 { imm as u32 & !(u32::MAX << bits) } -#[cfg(test)] pub fn imm_u(imm: u32) -> u32 { // valid imm is imm[12:31] in U-type imm << 12