Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SRL and SLL #304

Merged
merged 14 commits into from
Oct 8, 2024
5 changes: 5 additions & 0 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,11 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
self.logic_u8(ROMType::Ltu, a, b, c)
}

// Assert that `2^b = c` and that `b` is a 5-bit unsigned integer.
pub fn lookup_pow2(&mut self, b: Expression<E>, c: Expression<E>) -> Result<(), ZKVMError> {
self.logic_u8(ROMType::Pow, 2.into(), b, c)
}

/// less_than
pub(crate) fn less_than<N, NR>(
&mut self,
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/src/instructions/riscv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod divu;
pub mod ecall;
pub mod logic;
pub mod mulh;
pub mod shift;
pub mod shift_imm;
pub mod sltu;

Expand Down
303 changes: 303 additions & 0 deletions ceno_zkvm/src/instructions/riscv/shift.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
use std::{marker::PhantomData, mem::MaybeUninit};

use ceno_emul::InsnKind;
use ff_ext::ExtensionField;

use crate::{
expression::{ToExpr, WitIn},
gadgets::IsLtConfig,
instructions::Instruction,
set_val, Value,
};

use super::{constants::UInt, r_insn::RInstructionConfig, RIVInstruction};

pub struct ShiftConfig<E: ExtensionField> {
r_insn: RInstructionConfig<E>,

rs1_read: UInt<E>,
rs2_read: UInt<E>,
rd_written: UInt<E>,

rs2_high: UInt<E>,
rs2_low5: WitIn,
pow2_rs2_low5: UInt<E>,

intermediate: Option<UInt<E>>,
KimiWu123 marked this conversation as resolved.
Show resolved Hide resolved
remainder: Option<UInt<E>>,
lt_config: Option<IsLtConfig>,
}

pub struct ShiftLogicalInstruction<E, I>(PhantomData<(E, I)>);

struct SllOp;
impl RIVInstruction for SllOp {
const INST_KIND: InsnKind = InsnKind::SLL;
}

struct SrlOp;
impl RIVInstruction for SrlOp {
const INST_KIND: InsnKind = InsnKind::SRL;
}

impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ShiftLogicalInstruction<E, I> {
type InstructionConfig = ShiftConfig<E>;

fn name() -> String {
format!("{:?}", I::INST_KIND)
}

fn construct_circuit(
circuit_builder: &mut crate::circuit_builder::CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, crate::error::ZKVMError> {
let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?;
let rs2_low5 = circuit_builder.create_witin(|| "rs2_low5")?;
KimiWu123 marked this conversation as resolved.
Show resolved Hide resolved
// pow2_rs2_low5 is unchecked because it's assignment will be constrained due it's use in lookup_pow2 below
let mut pow2_rs2_low5 = UInt::new_unchecked(|| "pow2_rs2_low5", circuit_builder)?;
// rs2 = rs2_high | rs2_low5
let rs2_high = UInt::new(|| "rs2_high", circuit_builder)?;

let (rs1_read, rd_written, intermediate, remainder, lt_config) =
if I::INST_KIND == InsnKind::SLL {
zemse marked this conversation as resolved.
Show resolved Hide resolved
let mut rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?;
let rd_written = rs1_read.mul(
|| "rd_written = rs1_read * pow2_rs2_low5",
circuit_builder,
&mut pow2_rs2_low5,
true,
)?;
(rs1_read, rd_written, None, None, None)
} else if I::INST_KIND == InsnKind::SRL {
let mut rd_written = UInt::new(|| "rd_written", circuit_builder)?;
let remainder = UInt::new(|| "remainder", circuit_builder)?;
let (rs1_read, intermediate) = rd_written.mul_add(
|| "rs1_read = rd_written * pow2_rs2_low5 + remainder",
circuit_builder,
&mut pow2_rs2_low5,
&remainder,
true,
)?;

let lt_config = circuit_builder.less_than(
|| "remainder < pow2_rs2_low5",
remainder.value(),
pow2_rs2_low5.value(),
Some(true),
2,
)?;

(
rs1_read,
rd_written,
Some(intermediate),
Some(remainder),
Some(lt_config),
)
} else {
unreachable!()
};

let r_insn = RInstructionConfig::<E>::construct_circuit(
circuit_builder,
I::INST_KIND,
rs1_read.register_expr(),
rs2_read.register_expr(),
rd_written.register_expr(),
)?;

circuit_builder.lookup_pow2(rs2_low5.expr(), pow2_rs2_low5.value())?;
circuit_builder.assert_ux::<_, _, 5>(|| "rs2_low5 in u5", rs2_low5.expr())?;
circuit_builder.require_equal(
|| "rs2 == rs2_high * 2^5 + rs2_low5",
rs2_read.value(),
rs2_high.value() * (1 << 5).into() + rs2_low5.expr(),
)?;

Ok(ShiftConfig {
r_insn,
rs1_read,
rs2_read,
rd_written,
rs2_high,
rs2_low5,
pow2_rs2_low5,
intermediate,
remainder,
lt_config,
})
}

fn assign_instance(
zemse marked this conversation as resolved.
Show resolved Hide resolved
config: &Self::InstructionConfig,
instance: &mut [std::mem::MaybeUninit<<E as ExtensionField>::BaseField>],
lk_multiplicity: &mut crate::witness::LkMultiplicity,
step: &ceno_emul::StepRecord,
) -> Result<(), crate::error::ZKVMError> {
let rs2_read = Value::new_unchecked(step.rs2().unwrap().value);
let rs2_low5 = rs2_read.as_u64() & 0b11111;
let pow2_rs2_low5 = Value::new_unchecked((1 << rs2_low5) as u32);
let rs2_high = Value::new(
((rs2_read.as_u64() - rs2_low5) >> 5) as u32,
lk_multiplicity,
);

if I::INST_KIND == InsnKind::SLL {
let rs1_read = Value::new_unchecked(step.rs1().unwrap().value);
let rd_written = rs1_read.mul(&pow2_rs2_low5, lk_multiplicity, true);
config.rs1_read.assign_value(instance, rs1_read);
config
.rd_written
.assign_mul_outcome(instance, lk_multiplicity, &rd_written)?;
} else if I::INST_KIND == InsnKind::SRL {
let rd_written = Value::new_unchecked(step.rd().unwrap().value.after);
let remainder = Value::new(
// rs1 - rd * pow2_rs2_low5
step.rs1()
.unwrap()
.value
.wrapping_sub((rd_written.as_u64() * pow2_rs2_low5.as_u64()) as u32),
lk_multiplicity,
);
let (rs1_read, intermediate) =
rd_written.mul_add(&pow2_rs2_low5, &remainder, lk_multiplicity, true);

config.lt_config.as_ref().unwrap().assign_instance(
instance,
lk_multiplicity,
remainder.as_u64(),
pow2_rs2_low5.as_u64(),
)?;

config.rs1_read.assign_add_outcome(instance, &rs1_read);
config.rd_written.assign_value(instance, rd_written);
config
.remainder
.as_ref()
.unwrap()
.assign_value(instance, remainder);
config.intermediate.as_ref().unwrap().assign_mul_outcome(
instance,
lk_multiplicity,
&intermediate,
)?;
} else {
unreachable!()
};

config
.r_insn
.assign_instance(instance, lk_multiplicity, step)?;
config.rs2_read.assign_value(instance, rs2_read);
set_val!(instance, config.rs2_low5, rs2_low5);
config.rs2_high.assign_value(instance, rs2_high);
config.pow2_rs2_low5.assign_value(instance, pow2_rs2_low5);

Ok(())
}
}

#[cfg(test)]
mod tests {
zemse marked this conversation as resolved.
Show resolved Hide resolved
use ceno_emul::{Change, InsnKind, StepRecord};
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLEs;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::{
riscv::{constants::UInt, RIVInstruction},
Instruction,
},
scheme::mock_prover::{MockProver, MOCK_PC_SLL, MOCK_PC_SRL, MOCK_PROGRAM},
Value,
};

use super::{ShiftLogicalInstruction, SllOp, SrlOp};

#[test]
fn test_opcode_sll_1() {
zemse marked this conversation as resolved.
Show resolved Hide resolved
verify::<SllOp>(32, 3, 32 << 3);
}

#[test]
fn test_opcode_sll_2_overflow() {
// 33 << 33 === 33 << 1
verify::<SllOp>(33, 33, 33 << (33 - 32));
}
zemse marked this conversation as resolved.
Show resolved Hide resolved

#[test]
fn test_opcode_srl_1() {
verify::<SrlOp>(33, 3, 33 >> 3);
}

#[test]
fn test_opcode_srl_2_overflow() {
// 33 >> 33 === 33 >> 1
verify::<SrlOp>(33, 33, 33 >> 1);
}

fn verify<I: RIVInstruction>(rs1_read: u32, rs2_read: u32, expected_rd_written: u32) {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);

let (name, mock_pc, mock_program_op) = if I::INST_KIND == InsnKind::SLL {
("SLL", MOCK_PC_SLL, MOCK_PROGRAM[19])
} else if I::INST_KIND == InsnKind::SRL {
("SRL", MOCK_PC_SRL, MOCK_PROGRAM[20])
} else {
unreachable!()
};

let config = cb
.namespace(
|| name,
|cb| {
let config =
ShiftLogicalInstruction::<GoldilocksExt2, I>::construct_circuit(cb);
Ok(config)
},
)
.unwrap()
.unwrap();

config
.rd_written
.require_equal(
|| "assert_rd_written",
&mut cb,
&UInt::from_const_unchecked(
Value::new_unchecked(expected_rd_written)
.as_u16_limbs()
.to_vec(),
),
)
.unwrap();

let (raw_witin, _) = ShiftLogicalInstruction::<GoldilocksExt2, I>::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord::new_r_instruction(
3,
mock_pc,
mock_program_op,
rs1_read,
rs2_read,
Change::new(0, expected_rd_written),
0,
)],
)
.unwrap();

MockProver::assert_satisfied(
&cb,
&raw_witin
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
);
}
}
13 changes: 10 additions & 3 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use crate::{
expression::{fmt, Expression},
scheme::utils::eval_by_expr_with_fixed,
tables::{
AndTable, LtuTable, OpsTable, OrTable, ProgramTableCircuit, RangeTable, TableCircuit,
U16Table, U5Table, U8Table, XorTable,
AndTable, LtuTable, OpsTable, OrTable, PowTable, ProgramTableCircuit, RangeTable,
TableCircuit, U16Table, U5Table, U8Table, XorTable,
},
};
use ark_std::test_rng;
Expand Down Expand Up @@ -68,7 +68,7 @@ pub const MOCK_PROGRAM: &[u32] = &[
0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b011 << 12 | MOCK_RD << 7 | 0x33,
// addi x4, x2, 3
0x00 << 25 | MOCK_IMM_3 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x13,
// addi x4, x2, -3, correc this below
// addi x4, x2, -3
0b_1_111111 << 25 | MOCK_IMM_NEG3 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x13,
// bltu x2, x3, -8
0b_1_111111 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b_110 << 12 | 0b_1100_1 << 7 | 0x63,
Expand All @@ -78,6 +78,10 @@ pub const MOCK_PROGRAM: &[u32] = &[
0b_1_111111 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b_101 << 12 | 0b_1100_1 << 7 | 0x63,
// mulhu (0x01, 0x00, 0x33)
0x01 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x3 << 12 | MOCK_RD << 7 | 0x33,
// sll x4, x2, x3
0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b001 << 12 | MOCK_RD << 7 | 0x33,
// srl x4, x2, x3
0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b101 << 12 | MOCK_RD << 7 | 0x33,
];
// Addresses of particular instructions in the mock program.
pub const MOCK_PC_ADD: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start());
Expand All @@ -99,6 +103,8 @@ pub const MOCK_PC_BLTU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 60);
pub const MOCK_PC_BGEU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 64);
pub const MOCK_PC_BGE: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 68);
pub const MOCK_PC_MULHU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 72);
pub const MOCK_PC_SLL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 76);
pub const MOCK_PC_SRL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 80);

#[allow(clippy::enum_variant_names)]
#[derive(Debug, PartialEq, Clone)]
Expand Down Expand Up @@ -256,6 +262,7 @@ fn load_tables<E: ExtensionField>(cb: &CircuitBuilder<E>, challenge: [E; 2]) ->
load_op_table::<OrTable, _>(&mut table_vec, cb, challenge);
load_op_table::<XorTable, _>(&mut table_vec, cb, challenge);
load_op_table::<LtuTable, _>(&mut table_vec, cb, challenge);
load_op_table::<PowTable, _>(&mut table_vec, cb, challenge);
load_program_table(&mut table_vec, cb, challenge);
HashSet::from_iter(table_vec)
}
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub enum ROMType {
Or, // a | b where a, b are bytes
Xor, // a ^ b where a, b are bytes
Ltu, // a <(usign) b where a, b are bytes and the result is 0/1.
Pow, // a ** b where a is 2 and b is 5-bit value
Instruction, // Decoded instruction from the fixed program.
}

Expand Down
12 changes: 12 additions & 0 deletions ceno_zkvm/src/tables/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,15 @@ impl OpsTable for LtuTable {
}
}
pub type LtuTableCircuit<E> = OpsTableCircuit<E, LtuTable>;

pub struct PowTable;
impl OpsTable for PowTable {
const ROM_TYPE: ROMType = ROMType::Pow;
fn len() -> usize {
1 << 5
}

fn content() -> Vec<[u64; 3]> {
(0..Self::len() as u64).map(|b| [2, b, 1 << b]).collect()
}
}
Loading