Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SRL and SLL #304

Merged
merged 14 commits into from
Oct 8, 2024
5 changes: 5 additions & 0 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,11 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
self.logic_u8(ROMType::Ltu, a, b, c)
}

// Assert that `2^b = c` and that `b` is a 5-bit unsigned integer.
pub fn lookup_pow2(&mut self, b: Expression<E>, c: Expression<E>) -> Result<(), ZKVMError> {
self.logic_u8(ROMType::Pow, 2.into(), b, c)
}

/// less_than
pub(crate) fn less_than<N, NR>(
&mut self,
Expand Down
72 changes: 72 additions & 0 deletions ceno_zkvm/src/gadgets/div.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use std::{fmt::Display, mem::MaybeUninit};

use ff_ext::ExtensionField;

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
instructions::riscv::constants::{UInt, BIT_WIDTH},
witness::LkMultiplicity,
Value,
};

use super::IsLtConfig;

/// divide gadget
#[derive(Debug, Clone)]
pub struct DivConfig<E: ExtensionField> {
pub dividend: UInt<E>,
pub r_lt: IsLtConfig,
pub intermediate_mul: UInt<E>,
}

impl<E: ExtensionField> DivConfig<E> {
/// giving divisor, quotient, and remainder
/// deriving dividend and respective constrains
/// NOTE once divisor is zero, then constrain will always failed
pub fn construct_circuit<NR: Into<String> + Display + Clone, N: FnOnce() -> NR>(
circuit_builder: &mut CircuitBuilder<E>,
name_fn: N,
divisor: &mut UInt<E>,
quotient: &mut UInt<E>,
remainder: &UInt<E>,
) -> Result<Self, ZKVMError> {
circuit_builder.namespace(name_fn, |cb| {
let intermediate_mul =
divisor.mul::<BIT_WIDTH, _, _>(|| "divisor_mul", cb, quotient, true)?;
let dividend = intermediate_mul.add(|| "dividend_add", cb, remainder, true)?;

// remainder range check
let r_lt = cb.less_than(
|| "remainder < divisor",
remainder.value(),
divisor.value(),
Some(true),
UInt::<E>::NUM_CELLS,
)?;
Ok(Self {
dividend,
intermediate_mul,
r_lt,
})
})
}

pub fn assign_instance<'a>(
&self,
instance: &mut [MaybeUninit<E::BaseField>],
lkm: &mut LkMultiplicity,
divisor: &Value<'a, u32>,
quotient: &Value<'a, u32>,
remainder: &Value<'a, u32>,
) -> Result<(), ZKVMError> {
let (dividend, intermediate) = divisor.mul_add(quotient, remainder, lkm, true);

self.r_lt
.assign_instance(instance, lkm, remainder.as_u64(), divisor.as_u64())?;
self.intermediate_mul
.assign_mul_outcome(instance, lkm, &intermediate)?;
self.dividend.assign_add_outcome(instance, &dividend);
Ok(())
}
}
2 changes: 2 additions & 0 deletions ceno_zkvm/src/gadgets/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
mod div;
mod is_lt;
mod is_zero;
pub use div::DivConfig;
pub use is_lt::IsLtConfig;
pub use is_zero::{IsEqualConfig, IsZeroConfig};
1 change: 1 addition & 0 deletions ceno_zkvm/src/instructions/riscv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod divu;
pub mod ecall;
pub mod logic;
pub mod mulh;
pub mod shift;
pub mod shift_imm;
pub mod sltu;

Expand Down
297 changes: 297 additions & 0 deletions ceno_zkvm/src/instructions/riscv/shift.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,297 @@
use std::{marker::PhantomData, mem::MaybeUninit};

use ceno_emul::InsnKind;
use ff_ext::ExtensionField;

use crate::{
expression::{ToExpr, WitIn},
gadgets::DivConfig,
instructions::Instruction,
set_val, Value,
};

use super::{constants::UInt, r_insn::RInstructionConfig, RIVInstruction};

pub struct ShiftConfig<E: ExtensionField> {
r_insn: RInstructionConfig<E>,

rs1_read: UInt<E>,
rs2_read: UInt<E>,
rd_written: UInt<E>,

rs2_high: UInt<E>,
rs2_low5: WitIn,
pow2_rs2_low5: UInt<E>,

// for SRL division arithmetics
remainder: Option<UInt<E>>,
div_config: Option<DivConfig<E>>,
}

pub struct ShiftLogicalInstruction<E, I>(PhantomData<(E, I)>);

struct SllOp;
impl RIVInstruction for SllOp {
const INST_KIND: InsnKind = InsnKind::SLL;
}

struct SrlOp;
impl RIVInstruction for SrlOp {
const INST_KIND: InsnKind = InsnKind::SRL;
}

impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ShiftLogicalInstruction<E, I> {
type InstructionConfig = ShiftConfig<E>;

fn name() -> String {
format!("{:?}", I::INST_KIND)
}

fn construct_circuit(
circuit_builder: &mut crate::circuit_builder::CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, crate::error::ZKVMError> {
let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?;
let rs2_low5 = circuit_builder.create_witin(|| "rs2_low5")?;
KimiWu123 marked this conversation as resolved.
Show resolved Hide resolved
// pow2_rs2_low5 is unchecked because it's assignment will be constrained due it's use in lookup_pow2 below
let mut pow2_rs2_low5 = UInt::new_unchecked(|| "pow2_rs2_low5", circuit_builder)?;
// rs2 = rs2_high | rs2_low5
let rs2_high = UInt::new(|| "rs2_high", circuit_builder)?;

let (rs1_read, rd_written, remainder, div_config) = match I::INST_KIND {
InsnKind::SLL => {
let mut rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?;
let rd_written = rs1_read.mul(
|| "rd_written = rs1_read * pow2_rs2_low5",
circuit_builder,
&mut pow2_rs2_low5,
true,
)?;
(rs1_read, rd_written, None, None)
}
InsnKind::SRL => {
let mut rd_written = UInt::new(|| "rd_written", circuit_builder)?;
let remainder = UInt::new(|| "remainder", circuit_builder)?;
let div_config = DivConfig::construct_circuit(
circuit_builder,
|| "srl_div",
&mut pow2_rs2_low5,
&mut rd_written,
&remainder,
)?;
(
div_config.dividend.clone(),
rd_written,
Some(remainder),
Some(div_config),
)
}
_ => unreachable!(),
};

let r_insn = RInstructionConfig::<E>::construct_circuit(
circuit_builder,
I::INST_KIND,
rs1_read.register_expr(),
rs2_read.register_expr(),
rd_written.register_expr(),
)?;

circuit_builder.lookup_pow2(rs2_low5.expr(), pow2_rs2_low5.value())?;
circuit_builder.assert_ux::<_, _, 5>(|| "rs2_low5 in u5", rs2_low5.expr())?;
circuit_builder.require_equal(
|| "rs2 == rs2_high * 2^5 + rs2_low5",
rs2_read.value(),
rs2_high.value() * (1 << 5).into() + rs2_low5.expr(),
)?;

Ok(ShiftConfig {
r_insn,
rs1_read,
rs2_read,
rd_written,
rs2_high,
rs2_low5,
pow2_rs2_low5,
remainder,
div_config,
})
}

fn assign_instance(
zemse marked this conversation as resolved.
Show resolved Hide resolved
config: &Self::InstructionConfig,
instance: &mut [std::mem::MaybeUninit<<E as ExtensionField>::BaseField>],
lk_multiplicity: &mut crate::witness::LkMultiplicity,
step: &ceno_emul::StepRecord,
) -> Result<(), crate::error::ZKVMError> {
let rs2_read = Value::new_unchecked(step.rs2().unwrap().value);
let rs2_low5 = rs2_read.as_u64() & 0b11111;
let pow2_rs2_low5 = Value::new_unchecked((1 << rs2_low5) as u32);
let rs2_high = Value::new(
((rs2_read.as_u64() - rs2_low5) >> 5) as u32,
lk_multiplicity,
);

match I::INST_KIND {
InsnKind::SLL => {
let rs1_read = Value::new_unchecked(step.rs1().unwrap().value);
let rd_written = rs1_read.mul(&pow2_rs2_low5, lk_multiplicity, true);
config.rs1_read.assign_value(instance, rs1_read);
config
.rd_written
.assign_mul_outcome(instance, lk_multiplicity, &rd_written)?;
}
InsnKind::SRL => {
let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity);
let remainder = Value::new(
// rs1 - rd * pow2_rs2_low5
step.rs1()
.unwrap()
.value
.wrapping_sub((rd_written.as_u64() * pow2_rs2_low5.as_u64()) as u32),
lk_multiplicity,
);

config.div_config.as_ref().unwrap().assign_instance(
instance,
lk_multiplicity,
&pow2_rs2_low5,
&rd_written,
&remainder,
)?;

config.rd_written.assign_value(instance, rd_written);
config
.remainder
.as_ref()
.unwrap()
.assign_value(instance, remainder);
}
_ => unreachable!(),
}

config
.r_insn
.assign_instance(instance, lk_multiplicity, step)?;
config.rs2_read.assign_value(instance, rs2_read);
set_val!(instance, config.rs2_low5, rs2_low5);
config.rs2_high.assign_value(instance, rs2_high);
config.pow2_rs2_low5.assign_value(instance, pow2_rs2_low5);

Ok(())
}
}

#[cfg(test)]
mod tests {
zemse marked this conversation as resolved.
Show resolved Hide resolved
use ceno_emul::{Change, InsnKind, StepRecord};
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLEs;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::{
riscv::{constants::UInt, RIVInstruction},
Instruction,
},
scheme::mock_prover::{MockProver, MOCK_PC_SLL, MOCK_PC_SRL, MOCK_PROGRAM},
Value,
};

use super::{ShiftLogicalInstruction, SllOp, SrlOp};

#[test]
fn test_opcode_sll_1() {
zemse marked this conversation as resolved.
Show resolved Hide resolved
verify::<SllOp>(0b_1, 3, 0b_1000);
}

#[test]
fn test_opcode_sll_2_rs2_overflow() {
// 33 << 33 === 33 << 1
verify::<SllOp>(0b_1, 33, 0b_10);
}

#[test]
fn test_opcode_sll_3_bit_loss() {
verify::<SllOp>(1 << 31 | 1, 1, 0b_10);
}
zemse marked this conversation as resolved.
Show resolved Hide resolved

#[test]
fn test_opcode_srl_1() {
verify::<SrlOp>(0b_1000, 3, 0b_1);
}

#[test]
fn test_opcode_srl_2_rs2_overflow() {
// 33 >> 33 === 33 >> 1
verify::<SrlOp>(0b_1010, 33, 0b_101);
}

#[test]
fn test_opcode_srl_3_bit_loss() {
// 33 >> 33 === 33 >> 1
verify::<SrlOp>(0b_1001, 1, 0b_100);
}

fn verify<I: RIVInstruction>(rs1_read: u32, rs2_read: u32, expected_rd_written: u32) {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);

let (name, mock_pc, mock_program_op) = match I::INST_KIND {
InsnKind::SLL => ("SLL", MOCK_PC_SLL, MOCK_PROGRAM[19]),
InsnKind::SRL => ("SRL", MOCK_PC_SRL, MOCK_PROGRAM[20]),
_ => unreachable!(),
};

let config = cb
.namespace(
|| name,
|cb| {
let config =
ShiftLogicalInstruction::<GoldilocksExt2, I>::construct_circuit(cb);
Ok(config)
},
)
.unwrap()
.unwrap();

config
.rd_written
.require_equal(
|| "assert_rd_written",
&mut cb,
&UInt::from_const_unchecked(
Value::new_unchecked(expected_rd_written)
.as_u16_limbs()
.to_vec(),
),
)
.unwrap();

let (raw_witin, _) = ShiftLogicalInstruction::<GoldilocksExt2, I>::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord::new_r_instruction(
3,
mock_pc,
mock_program_op,
rs1_read,
rs2_read,
Change::new(0, expected_rd_written),
0,
)],
)
.unwrap();

MockProver::assert_satisfied(
&cb,
&raw_witin
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
);
}
}
Loading
Loading