Skip to content

Commit

Permalink
SRAI (#463)
Browse files Browse the repository at this point in the history
This PR adds SRAI to shift_imm, also changes logic for SLLI and SRLI.

Closes #365
  • Loading branch information
zemse authored Oct 28, 2024
1 parent ef5caa9 commit 47f5572
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 91 deletions.
223 changes: 132 additions & 91 deletions ceno_zkvm/src/instructions/riscv/shift_imm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ use crate::{
Value,
circuit_builder::CircuitBuilder,
error::ZKVMError,
gadgets::DivConfig,
expression::{Expression, ToExpr, WitIn},
gadgets::{AssertLTConfig, IsLtConfig},
instructions::{
Instruction,
riscv::{constants::UInt, i_insn::IInstructionConfig},
},
set_val,
witness::LkMultiplicity,
};
use ceno_emul::{InsnKind, StepRecord};
Expand All @@ -17,13 +19,14 @@ use std::{marker::PhantomData, mem::MaybeUninit};
pub struct ShiftImmConfig<E: ExtensionField> {
i_insn: IInstructionConfig<E>,

imm: WitIn,
rs1_read: UInt<E>,
imm: UInt<E>,
rd_written: UInt<E>,
outflow: WitIn,
assert_lt_config: AssertLTConfig,

// for SRLI division arithmetics
remainder: Option<UInt<E>>,
div_config: Option<DivConfig<E>>,
// SRAI
is_lt_config: Option<IsLtConfig>,
}

pub struct ShiftImmInstruction<E, I>(PhantomData<(E, I)>);
Expand All @@ -33,9 +36,14 @@ impl RIVInstruction for SlliOp {
const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SLLI;
}

pub struct SraiOp;
impl RIVInstruction for SraiOp {
const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SRAI;
}

pub struct SrliOp;
impl RIVInstruction for SrliOp {
const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SRLI;
const INST_KIND: ceno_emul::InsnKind = InsnKind::SRLI;
}

impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ShiftImmInstruction<E, I> {
Expand All @@ -48,47 +56,64 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ShiftImmInstructio
fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, ZKVMError> {
let mut imm = UInt::new(|| "imm", circuit_builder)?;
// Note: `imm` wtns is set to 2**imm (upto 32 bit) just for efficient verification.
let imm = circuit_builder.create_witin(|| "imm")?;
let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?;
let rd_written = UInt::new(|| "rd_written", circuit_builder)?;

let outflow = circuit_builder.create_witin(|| "outflow")?;
let assert_lt_config = AssertLTConfig::construct_circuit(
circuit_builder,
|| "outflow < imm",
outflow.expr(),
imm.expr(),
2,
)?;

let two_pow_total_bits: Expression<_> = (1u64 << UInt::<E>::TOTAL_BITS).into();

// Note: `imm` is set to 2**imm (upto 32 bit) just for efficient verification
// Goal is to constrain:
// rs1 == rd_written * imm + remainder
let (rs1_read, rd_written, remainder, div_config) = match I::INST_KIND {
let is_lt_config = match I::INST_KIND {
InsnKind::SLLI => {
let mut rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?;
let rd_written = rs1_read.mul(
|| "rd_written = rs1_read * imm",
circuit_builder,
&mut imm,
true,
circuit_builder.require_equal(
|| "shift check",
rs1_read.value() * imm.expr(), // inflow is zero for this case
outflow.expr() * two_pow_total_bits + rd_written.value(),
)?;

(rs1_read, rd_written, None, None)
None
}
InsnKind::SRLI => {
let mut rd_written = UInt::new(|| "rd_written", circuit_builder)?;
let remainder = UInt::new(|| "remainder", circuit_builder)?;
let div_config = DivConfig::construct_circuit(
circuit_builder,
|| "srli_div",
&mut imm,
&mut rd_written,
&remainder,
InsnKind::SRAI | InsnKind::SRLI => {
let (inflow, is_lt_config) = match I::INST_KIND {
InsnKind::SRAI => {
let max_signed_limb_expr: Expression<_> =
((1 << (UInt::<E>::LIMB_BITS - 1)) - 1).into();
let is_rs1_neg = IsLtConfig::construct_circuit(
circuit_builder,
|| "lhs_msb",
max_signed_limb_expr.clone(),
rs1_read.limbs.iter().last().unwrap().expr(), // msb limb
1,
)?;
let msb_expr: Expression<E> = is_rs1_neg.is_lt.expr();
let ones = imm.expr() - Expression::ONE;
(msb_expr * ones, Some(is_rs1_neg))
}
InsnKind::SRLI => (Expression::ZERO, None),
_ => unreachable!(),
};
circuit_builder.require_equal(
|| "shift check",
rd_written.value() * imm.expr() + outflow.expr(),
inflow * two_pow_total_bits + rs1_read.value(),
)?;
(
div_config.dividend.clone(),
rd_written,
Some(remainder),
Some(div_config),
)
is_lt_config
}
_ => unreachable!(),
_ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND),
};

let i_insn = IInstructionConfig::<E>::construct_circuit(
circuit_builder,
I::INST_KIND,
&imm.value(),
&imm.expr(),
rs1_read.register_expr(),
rd_written.register_expr(),
false,
Expand All @@ -97,10 +122,11 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ShiftImmInstructio
Ok(ShiftImmConfig {
i_insn,
imm,
rd_written,
rs1_read,
remainder,
div_config,
rd_written,
outflow,
assert_lt_config,
is_lt_config,
})
}

Expand All @@ -110,38 +136,36 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ShiftImmInstructio
lk_multiplicity: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
let imm = Value::new(step.insn().imm_or_funct7(), lk_multiplicity);
match I::INST_KIND {
InsnKind::SLLI => {
let rs1_read = Value::new_unchecked(step.rs1().unwrap().value);
let rd_written = rs1_read.mul(&imm, lk_multiplicity, true);
config.rs1_read.assign_value(instance, rs1_read);
config
.rd_written
.assign_mul_outcome(instance, lk_multiplicity, &rd_written)?;
}
InsnKind::SRLI => {
let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity);
let rs1_read = step.rs1().unwrap().value;
let remainder = Value::new(rs1_read % imm.as_u32(), lk_multiplicity);
config.div_config.as_ref().unwrap().assign_instance(
instance,
lk_multiplicity,
&imm,
&rd_written,
&remainder,
)?;
config
.remainder
.as_ref()
.unwrap()
.assign_value(instance, remainder);
config.rd_written.assign_value(instance, rd_written);
let imm = step.insn().imm_or_funct7();
let rs1_read = Value::new_unchecked(step.rs1().unwrap().value);
let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity);

set_val!(instance, config.imm, imm as u64);
config.rs1_read.assign_value(instance, rs1_read.clone());
config.rd_written.assign_value(instance, rd_written);

let outflow = match I::INST_KIND {
InsnKind::SLLI => (rs1_read.as_u64() * imm as u64) >> UInt::<E>::TOTAL_BITS,
InsnKind::SRAI | InsnKind::SRLI => {
if I::INST_KIND == InsnKind::SRAI {
let max_signed_limb_expr = (1 << (UInt::<E>::LIMB_BITS - 1)) - 1;
config.is_lt_config.as_ref().unwrap().assign_instance(
instance,
lk_multiplicity,
max_signed_limb_expr,
rs1_read.as_u64() >> UInt::<E>::LIMB_BITS,
)?;
}

rs1_read.as_u64() & (imm as u64 - 1)
}
_ => unreachable!(),
_ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND),
};

config.imm.assign_value(instance, imm);
set_val!(instance, config.outflow, outflow);
config
.assert_lt_config
.assign_instance(instance, lk_multiplicity, outflow, imm as u64)?;

config
.i_insn
Expand All @@ -158,6 +182,7 @@ mod test {
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLEs;

use super::{ShiftImmInstruction, SlliOp, SraiOp, SrliOp};
use crate::{
Value,
circuit_builder::{CircuitBuilder, ConstraintSystem},
Expand All @@ -168,30 +193,51 @@ mod test {
scheme::mock_prover::{MOCK_PC_START, MockProver},
};

use super::{ShiftImmInstruction, SlliOp, SrliOp};

#[test]
fn test_opcode_slli() {
verify::<SlliOp>("imm = 3, rs1 = 32", 3, 32, 32 << 3);
verify::<SlliOp>("imm = 3, rs1 = 33", 3, 33, 33 << 3);
// imm = 3
verify::<SlliOp>("32 << 3", 32, 3, 32 << 3);
verify::<SlliOp>("33 << 3", 33, 3, 33 << 3);
// imm = 31
verify::<SlliOp>("32 << 31", 32, 31, 32 << 31);
verify::<SlliOp>("33 << 31", 33, 31, 33 << 31);
}

verify::<SlliOp>("imm = 31, rs1 = 32", 31, 32, 32 << 31);
verify::<SlliOp>("imm = 31, rs1 = 33", 31, 33, 33 << 31);
#[test]
fn test_opcode_srai() {
// positive rs1
// imm = 3
verify::<SraiOp>("32 >> 3", 32, 3, 32 >> 3);
verify::<SraiOp>("33 >> 3", 33, 3, 33 >> 3);
// imm = 31
verify::<SraiOp>("32 >> 31", 32, 31, 32 >> 31);
verify::<SraiOp>("33 >> 31", 33, 31, 33 >> 31);

// negative rs1
// imm = 3
verify::<SraiOp>("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32);
verify::<SraiOp>("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32);
// imm = 31
verify::<SraiOp>("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32);
verify::<SraiOp>("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32);
}

#[test]
fn test_opcode_srli() {
verify::<SrliOp>("imm = 3, rs1 = 32", 3, 32, 32 >> 3);
verify::<SrliOp>("imm = 3, rs1 = 33", 3, 33, 33 >> 3);

verify::<SrliOp>("imm = 31, rs1 = 32", 31, 32, 32 >> 31);
verify::<SrliOp>("imm = 31, rs1 = 33", 31, 33, 33 >> 31);
// imm = 3
verify::<SrliOp>("32 >> 3", 32, 3, 32 >> 3);
verify::<SrliOp>("33 >> 3", 33, 3, 33 >> 3);
// imm = 31
verify::<SrliOp>("32 >> 31", 32, 31, 32 >> 31);
verify::<SrliOp>("33 >> 31", 33, 31, 33 >> 31);
// rs1 top bit is 1
verify::<SrliOp>("-32 >> 3", (-32_i32) as u32, 3, (-32_i32) as u32 >> 3);
}

fn verify<I: RIVInstruction>(
name: &'static str,
imm: u32,
rs1_read: u32,
imm: u32,
expected_rd_written: u32,
) {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
Expand All @@ -203,6 +249,11 @@ mod test {
encode_rv32(InsnKind::SLLI, 2, 0, 4, imm),
rs1_read << imm,
),
InsnKind::SRAI => (
"SRAI",
encode_rv32(InsnKind::SRAI, 2, 0, 4, imm),
(rs1_read as i32 >> imm as i32) as u32,
),
InsnKind::SRLI => (
"SRLI",
encode_rv32(InsnKind::SRLI, 2, 0, 4, imm),
Expand All @@ -225,7 +276,7 @@ mod test {
config
.rd_written
.require_equal(
|| "assert_rd_written",
|| format!("{prefix}_({name})_assert_rd_written"),
&mut cb,
&UInt::from_const_unchecked(
Value::new_unchecked(expected_rd_written)
Expand All @@ -249,16 +300,6 @@ mod test {
)
.unwrap();

let expected_rd_written = UInt::from_const_unchecked(
Value::new_unchecked(expected_rd_written)
.as_u16_limbs()
.to_vec(),
);
config
.rd_written
.require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written)
.unwrap();

MockProver::assert_satisfied(
&cb,
&raw_witin
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/src/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,7 @@ impl ValueMul {
}
}

#[derive(Clone)]
pub struct Value<'a, T: Into<u64> + From<u32> + Copy + Default> {
#[allow(dead_code)]
val: T,
Expand Down

0 comments on commit 47f5572

Please sign in to comment.