Skip to content

Commit

Permalink
moving out sltu opcode and adding testing
Browse files Browse the repository at this point in the history
  • Loading branch information
KimiWu123 committed Sep 26, 2024
1 parent d732adb commit 23644b7
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 33 deletions.
25 changes: 17 additions & 8 deletions ceno_zkvm/src/gadgets/lt.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::mem::MaybeUninit;
use std::{
mem::MaybeUninit,
ops::{Add, Sub},
};

use ff_ext::ExtensionField;
use goldilocks::SmallField;
Expand Down Expand Up @@ -55,17 +58,23 @@ impl<E: ExtensionField> LtGadget<E> {
lhs: E::BaseField,
rhs: E::BaseField,
) -> Result<(), ZKVMError> {
let lhs = lhs.to_canonical_u64();
let rhs = rhs.to_canonical_u64();

// Set `lt`
let lt = lhs < rhs;
let lt = lhs.to_canonical_u64() < rhs.to_canonical_u64();
set_val!(instance, self.lt, lt as u64);

// Set `diff`
let diff = lhs - rhs + (if lt { 1 << UInt::<E>::M } else { 0 });
self.diff
.assign_limbs(instance, Value::new(diff, lkm).u16_fields());
let diff = lhs.sub(rhs).add(if lt {
E::BaseField::from(1 << UInt::<E>::M)
} else {
E::BaseField::from(0)
});
self.diff.assign_limbs(
instance,
#[cfg(feature = "riv32")]
Value::new(diff.to_canonical_u64() as u32, lkm).u16_fields(),
#[cfg(feature = "riv64")]
Value::new(diff.to_canonical_u64(), lkm).u16_fields(),
);

Ok(())
}
Expand Down
25 changes: 2 additions & 23 deletions ceno_zkvm/src/instructions/riscv/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use itertools::Itertools;

use super::{constants::UInt, r_insn::RInstructionConfig, RIVInstruction};
use crate::{
circuit_builder::CircuitBuilder, error::ZKVMError, gadgets::lt::LtGadget,
instructions::Instruction, uint::Value, witness::LkMultiplicity,
circuit_builder::CircuitBuilder, error::ZKVMError, instructions::Instruction, uint::Value,
witness::LkMultiplicity,
};
use core::mem::MaybeUninit;

Expand Down Expand Up @@ -41,12 +41,6 @@ impl RIVInstruction for MulOp {
}
pub type MulInstruction<E> = ArithInstruction<E, MulOp>;

pub struct SLTUOp;
impl RIVInstruction for SLTUOp {
const INST_KIND: InsnKind = InsnKind::SLTU;
}
pub type SltuInstruction<E> = ArithInstruction<E, SLTUOp>;

impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E, I> {
type InstructionConfig = ArithConfig<E>;

Expand Down Expand Up @@ -89,21 +83,6 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E
(rs1_read, rs2_read, rd_written)
}

InsnKind::SLTU => {
// If rs1_read < rs2_read, rd_written = 1. Otherwise rd_written = 0
let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?;
let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?;

let lt = LtGadget::construct_circuit(
circuit_builder,
rs1_read.value(),
rs2_read.value(),
)?;
let rd_written = UInt::new(|| "rd_written", circuit_builder)?;
circuit_builder.require_equal(|| "rd == lt", rd_written.value(), lt.expr())?;
(rs1_read, rs2_read, rd_written)
}

_ => unreachable!("Unsupported instruction kind"),
};

Expand Down
185 changes: 185 additions & 0 deletions ceno_zkvm/src/instructions/riscv/sltu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
use std::marker::PhantomData;

use ceno_emul::{InsnKind, StepRecord};
use ff_ext::ExtensionField;

use super::{constants::UInt, r_insn::RInstructionConfig, RIVInstruction};
use crate::{
circuit_builder::CircuitBuilder, error::ZKVMError, gadgets::lt::LtGadget,
instructions::Instruction, uint::Value, witness::LkMultiplicity,
};
use core::mem::MaybeUninit;

/// This config handles R-Instructions that represent registers values as 2 * u16.
#[derive(Debug)]
pub struct ArithConfig<E: ExtensionField> {
r_insn: RInstructionConfig<E>,

rs1_read: UInt<E>,
rs2_read: UInt<E>,
rd_written: UInt<E>,

lt_gadget: LtGadget<E>,
}

pub struct ArithInstruction<E, I>(PhantomData<(E, I)>);

pub struct SLTUOp;
impl RIVInstruction for SLTUOp {
const INST_KIND: InsnKind = InsnKind::SLTU;
}
pub type SltuInstruction<E> = ArithInstruction<E, SLTUOp>;

impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E, I> {
type InstructionConfig = ArithConfig<E>;

fn name() -> String {
format!("{:?}", I::INST_KIND)
}

fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, ZKVMError> {
// If rs1_read < rs2_read, rd_written = 1. Otherwise rd_written = 0
let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?;
let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?;

let lt = LtGadget::construct_circuit(circuit_builder, rs1_read.value(), rs2_read.value())?;
let rd_written = UInt::new(|| "rd_written", circuit_builder)?;
circuit_builder.require_equal(|| "rd == lt", rd_written.value(), lt.expr())?;

let r_insn = RInstructionConfig::<E>::construct_circuit(
circuit_builder,
I::INST_KIND,
rs1_read.register_expr(),
rs2_read.register_expr(),
rd_written.register_expr(),
)?;

Ok(ArithConfig {
r_insn,
rs1_read,
rs2_read,
rd_written,
lt_gadget: lt,
})
}

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<<E as ExtensionField>::BaseField>],
lkm: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
config.r_insn.assign_instance(instance, lkm, step)?;

let rs1 = step.rs1().unwrap().value;
let rs2 = step.rs2().unwrap().value;

let rs1_read = Value::new_unchecked(rs1);
let rs2_read = Value::new_unchecked(rs2);
config
.rs1_read
.assign_limbs(instance, rs1_read.u16_fields());
config
.rs2_read
.assign_limbs(instance, rs2_read.u16_fields());
config.lt_gadget.assign(
instance,
lkm,
E::BaseField::from(rs1.into()),
E::BaseField::from(rs2.into()),
)?;

let lt = if rs1 < rs2 {
Value::new_unchecked(1u32)
} else {
Value::new_unchecked(0u32)
};
config.rd_written.assign_limbs(instance, lt.u16_fields());

Ok(())
}
}

#[cfg(test)]
mod test {
use std::u32;

use ceno_emul::{Change, StepRecord, Word, CENO_PLATFORM};
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLEs;
use rand::Rng;

use super::*;
use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::Instruction,
scheme::mock_prover::{MockProver, MOCK_PC_SLTU, MOCK_PROGRAM},
};

fn verify(name: &'static str, rs1: Word, rs2: Word, rd: Word) {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
.namespace(
|| format!("SLTU/{name}"),
|cb| {
let config = SltuInstruction::construct_circuit(cb);
Ok(config)
},
)
.unwrap()
.unwrap();

let idx = (MOCK_PC_SLTU.0 - CENO_PLATFORM.pc_start()) / 4;
let (raw_witin, _) = SltuInstruction::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord::new_r_instruction(
3,
MOCK_PC_SLTU,
MOCK_PROGRAM[idx as usize],
rs1,
rs2,
Change::new(0, rd),
0,
)],
)
.unwrap();

MockProver::assert_satisfied(
&mut cb,
&raw_witin
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
);
}

#[test]
fn test_sltu_simple() {
verify("lt = true, 0 < 1", 0, 1, 1);
verify("lt = true, 1 < 2", 1, 2, 1);
verify("lt = true, 0 < u32::MAX", 0, u32::MAX, 1);
verify("lt = true, u32::MAX - 1", u32::MAX - 1, u32::MAX, 1);
verify("lt = false, u32::MAX", u32::MAX, u32::MAX, 0);
verify("lt = false, u32::MAX - 1", u32::MAX, u32::MAX - 1, 0);
verify("lt = false, u32::MAX > 0", u32::MAX, 0, 0);
verify("lt = false, 2 > 1", 2, 1, 0);
}

#[test]
fn test_sltu_random() {
let mut rng = rand::thread_rng();
let a: u32 = rng.gen();
let b: u32 = rng.gen();
println!("random: {}, {}", a, b);
verify("random 1", a, b, (a < b) as u32);
verify("random 2", b, a, !(a < b) as u32);
}
}
6 changes: 4 additions & 2 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ pub const MOCK_PROGRAM: &[u32] = &[
// blt x2, x3, -8
0b_1_111111 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b_100 << 12 | 0b_1100_1 << 7 | 0x63,
// divu (0x01, 0x05, 0x33)
0x01 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x05 << 12 | MOCK_RD << 7 | 0x33,
0x01 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b101 << 12 | MOCK_RD << 7 | 0x33,
// srli x4, x2, 3
0x00 << 25 | MOCK_IMM_3 << 20 | MOCK_RS1 << 15 | 0x05 << 12 | MOCK_RD << 7 | 0x13,
// srli x4, x2, 31
0x00 << 25 | MOCK_IMM_31 << 20 | MOCK_RS1 << 15 | 0x05 << 12 | MOCK_RD << 7 | 0x13,
// sltu (0x00, 0x03, 0x33)
0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b011 << 12 | MOCK_RD << 7 | 0x33,
];
// Addresses of particular instructions in the mock program.
pub const MOCK_PC_ADD: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start());
Expand All @@ -83,8 +85,8 @@ pub const MOCK_PC_BLT: Change<ByteAddr> = Change {
pub const MOCK_PC_DIVU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 36);
pub const MOCK_PC_SRLI: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 40);
pub const MOCK_PC_SRLI_31: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 44);
pub const MOCK_PC_SLTU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 48);

#[allow(clippy::enum_variant_names)]
#[derive(Debug, PartialEq, Clone)]
pub(crate) enum MockProverError<E: ExtensionField> {
AssertZeroError {
Expand Down

0 comments on commit 23644b7

Please sign in to comment.