Skip to content

Commit

Permalink
overflow sll working
Browse files Browse the repository at this point in the history
  • Loading branch information
zemse committed Oct 2, 2024
1 parent 91ad507 commit 5059867
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 11 deletions.
51 changes: 49 additions & 2 deletions ceno_zkvm/src/instructions/riscv/config.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::{fmt::Display, mem::MaybeUninit};
use std::{fmt::Display, marker::PhantomData, mem::MaybeUninit};

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, ToExpr, WitIn},
gadgets::IsLtConfig,
gadgets::{IsLtConfig, IsZeroConfig},
set_val,
utils::i64_to_base,
witness::LkMultiplicity,
Expand Down Expand Up @@ -326,3 +326,50 @@ impl UIntLtSignedConfig {
Ok(())
}
}

/// Gadget to get a boolean expression if an expression is in the range.
pub struct RangeCheckU5<E: ExtensionField> {
pub is_zero: IsZeroConfig,
_phantom: PhantomData<E>,
}

impl<E: ExtensionField> RangeCheckU5<E> {
pub fn bool_expr(&self) -> Expression<E> {
self.is_zero.expr()
}

pub fn construct_circuit<NR: Into<String> + Display + Clone, N: FnOnce() -> NR>(
cb: &mut CircuitBuilder<E>,
name_fn: N,
expr: &Expression<E>,
) -> Result<Self, ZKVMError> {
cb.namespace(
|| "range_check_u5",
|cb| {
let range_check = (0..32)
.map(|i| expr.clone() - i.into())
.reduce(|a, b| (a * b).to_monomial_form())
.unwrap();

let is_zero = IsZeroConfig::construct_circuit(cb, name_fn, range_check)?;
Ok(RangeCheckU5 {
is_zero,
_phantom: PhantomData,
})
},
)
}

pub fn assign_instance(
&self,
instance: &mut [MaybeUninit<E::BaseField>],
value: u64,
) -> Result<(), ZKVMError> {
let range_check = (0..32)
.map(|i| E::BaseField::from(value.wrapping_sub(i)))
.reduce(|a, b| a * b)
.unwrap();
self.is_zero.assign_instance(instance, range_check)?;
Ok(())
}
}
49 changes: 42 additions & 7 deletions ceno_zkvm/src/instructions/riscv/sll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@ use std::marker::PhantomData;
use ceno_emul::InsnKind;
use ff_ext::ExtensionField;

use crate::{instructions::Instruction, Value};
use crate::{
expression::{ToExpr, WitIn},
instructions::Instruction,
Value,
};

use super::{constants::UInt, r_insn::RInstructionConfig, RIVInstruction};
use super::{config::RangeCheckU5, constants::UInt, r_insn::RInstructionConfig, RIVInstruction};

pub struct ShiftLeftConfig<E: ExtensionField> {
r_insn: RInstructionConfig<E>,
Expand All @@ -14,6 +18,8 @@ pub struct ShiftLeftConfig<E: ExtensionField> {
rs2_read: UInt<E>,
rd_written: UInt<E>,
shift: UInt<E>,
rs2_lt_32: RangeCheckU5<E>,
intermediate: WitIn,
}

pub struct ShiftLeftLogicalInstruction<E>(PhantomData<E>);
Expand All @@ -35,9 +41,24 @@ impl<E: ExtensionField> Instruction<E> for ShiftLeftLogicalInstruction<E> {
// rs1_read * rs2_read = rd_written
let mut rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?;
let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?;

let rs2_lt_32 =
RangeCheckU5::construct_circuit(circuit_builder, || "rs2 < 32", &rs2_read.value())?;
let intermediate = circuit_builder.create_witin(|| "intermediate")?;
let mut shift = UInt::new_unchecked(|| "shift", circuit_builder)?;
circuit_builder.lookup_pow(2.into(), rs2_read.value(), shift.value())?;

circuit_builder.require_equal(
|| "intermediate == rs2_lt_32 * rs2_read",
rs2_lt_32.bool_expr() * rs2_read.value(),
intermediate.expr(),
)?;

// rs2 < 32 then [2, rs2_read, shift] i.e. 2 ** rs2_read == shift
// rse >= 32 then [0, 0, 0] i.e. shift == 0
circuit_builder.lookup_pow(
rs2_lt_32.bool_expr() * 2.into(),
intermediate.expr(), // because degree 1 required
shift.value(),
)?;

let rd_written = rs1_read.mul(|| "rd_written", circuit_builder, &mut shift, true)?;

Expand All @@ -55,6 +76,8 @@ impl<E: ExtensionField> Instruction<E> for ShiftLeftLogicalInstruction<E> {
rs2_read,
rd_written,
shift,
rs2_lt_32,
intermediate,
})
}

Expand All @@ -66,12 +89,24 @@ impl<E: ExtensionField> Instruction<E> for ShiftLeftLogicalInstruction<E> {
) -> Result<(), crate::error::ZKVMError> {
let rs1_read = Value::new_unchecked(step.rs1().unwrap().value);
let rs2_read = Value::new_unchecked(step.rs2().unwrap().value);
let shift = Value::new_unchecked(1u32.wrapping_shl(rs2_read.as_u64() as u32));
let intermediate = (rs2_read.as_u64() < 32) as u64 * rs2_read.as_u64();
let shift = Value::new_unchecked(if rs2_read.as_u64() < 32 {
1u32.wrapping_shl(rs2_read.as_u64() as u32)
} else {
0
});
let rd_written = rs1_read.mul(&shift, lk_multiplicity, true);

config
.intermediate
.assign::<E>(instance, intermediate.into());

config
.r_insn
.assign_instance(instance, lk_multiplicity, step)?;
config
.rs2_lt_32
.assign_instance(instance, rs2_read.as_u64())?;
config.rs1_read.assign_value(instance, rs1_read);
config.rs2_read.assign_value(instance, rs2_read);
config.shift.assign_value(instance, shift);
Expand Down Expand Up @@ -99,7 +134,7 @@ mod tests {
use super::ShiftLeftLogicalInstruction;

#[test]
fn test_opcode_sll() {
fn test_opcode_sll_1() {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
Expand Down Expand Up @@ -142,7 +177,7 @@ mod tests {
}

#[test]
fn test_opcode_sll_overflow() {
fn test_opcode_sll_2_overflow() {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
Expand Down
7 changes: 5 additions & 2 deletions ceno_zkvm/src/tables/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,13 @@ pub struct PowTable;
impl OpsTable for PowTable {
const ROM_TYPE: ROMType = ROMType::Pow;
fn len() -> usize {
1 << 5
(1 << 5) + 1
}

fn content() -> Vec<[u64; 3]> {
(0..Self::len() as u64).map(|b| [2, b, 1 << b]).collect()
(0..Self::len() as u64)
.map(|b| [2, b, 1 << b])
.chain(std::iter::once([0, 0, 0]))
.collect()
}
}

0 comments on commit 5059867

Please sign in to comment.