Skip to content

Commit

Permalink
Feat/#98 riscv mul opcode (#219)
Browse files Browse the repository at this point in the history
close #98
  • Loading branch information
KimiWu123 authored Sep 18, 2024
1 parent 1637f4c commit 713461b
Show file tree
Hide file tree
Showing 10 changed files with 523 additions and 136 deletions.
6 changes: 6 additions & 0 deletions ceno_emul/src/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ impl From<WordAddr> for u32 {
}
}

impl From<WordAddr> for u64 {
fn from(addr: WordAddr) -> Self {
addr.baddr().0 as u64
}
}

impl ByteAddr {
pub const fn waddr(self) -> WordAddr {
WordAddr(self.0 / WORD_SIZE as u32)
Expand Down
22 changes: 22 additions & 0 deletions ceno_emul/src/rv32im.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,28 @@ impl DecodedInstruction {
}
}

#[allow(dead_code)]
pub fn from_raw(kind: InsnKind, rs1: u32, rs2: u32, rd: u32) -> Self {
// limit the range of inputs
let rs2 = rs2 & 0x1f; // 5bits mask
let rs1 = rs1 & 0x1f;
let rd = rd & 0x1f;
let func7 = kind.codes().func7;
let func3 = kind.codes().func3;
let opcode = kind.codes().opcode;
let insn = func7 << 25 | rs2 << 20 | rs1 << 15 | func3 << 12 | rd << 7 | opcode;
Self {
insn,
top_bit: func7 | 0x80,
func7,
rs2,
rs1,
func3,
rd,
opcode,
}
}

pub fn encoded(&self) -> u32 {
self.insn
}
Expand Down
7 changes: 4 additions & 3 deletions ceno_emul/src/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ impl StepRecord {
rs1_read: Word,
rs2_read: Word,
rd: Change<Word>,
previous_cycle: Cycle,
) -> StepRecord {
let insn = DecodedInstruction::new(insn_code);
StepRecord {
Expand All @@ -69,17 +70,17 @@ impl StepRecord {
rs1: Some(ReadOp {
addr: CENO_PLATFORM.register_vma(insn.rs1() as RegIdx).into(),
value: rs1_read,
previous_cycle: 0,
previous_cycle,
}),
rs2: Some(ReadOp {
addr: CENO_PLATFORM.register_vma(insn.rs2() as RegIdx).into(),
value: rs2_read,
previous_cycle: 0,
previous_cycle,
}),
rd: Some(WriteOp {
addr: CENO_PLATFORM.register_vma(insn.rd() as RegIdx).into(),
value: rd,
previous_cycle: 0,
previous_cycle,
}),
memory_op: None,
}
Expand Down
9 changes: 9 additions & 0 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
cmp::max,
mem::MaybeUninit,
ops::{Add, Deref, Mul, Neg, Sub},
};

Expand Down Expand Up @@ -426,6 +427,14 @@ impl WitIn {
},
)
}

pub fn assign<E: ExtensionField>(
&self,
instance: &mut [MaybeUninit<E::BaseField>],
value: E::BaseField,
) {
instance[self.id as usize] = MaybeUninit::new(value);
}
}

#[macro_export]
Expand Down
2 changes: 2 additions & 0 deletions ceno_zkvm/src/instructions/riscv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ pub mod addsub;
pub mod blt;
pub mod config;
pub mod constants;
pub mod mul;

mod r_insn;

#[cfg(test)]
Expand Down
4 changes: 4 additions & 0 deletions ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ mod test {
11,
0xfffffffe,
Change::new(0, 11_u32.wrapping_add(0xfffffffe)),
0,
)],
)
.unwrap();
Expand Down Expand Up @@ -225,6 +226,7 @@ mod test {
u32::MAX - 1,
u32::MAX - 1,
Change::new(0, (u32::MAX - 1).wrapping_add(u32::MAX - 1)),
0,
)],
)
.unwrap();
Expand Down Expand Up @@ -267,6 +269,7 @@ mod test {
11,
2,
Change::new(0, 11_u32.wrapping_sub(2)),
0,
)],
)
.unwrap();
Expand Down Expand Up @@ -309,6 +312,7 @@ mod test {
3,
11,
Change::new(0, 3_u32.wrapping_sub(11)),
0,
)],
)
.unwrap();
Expand Down
218 changes: 218 additions & 0 deletions ceno_zkvm/src/instructions/riscv/mul.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
use ceno_emul::{InsnKind, StepRecord};
use ff_ext::ExtensionField;
use itertools::Itertools;

use super::{constants::RegUInt, r_insn::RInstructionConfig, RIVInstruction};
use crate::{
circuit_builder::CircuitBuilder, error::ZKVMError, instructions::Instruction, uint::UIntValue,
witness::LkMultiplicity,
};
use core::mem::MaybeUninit;
use std::marker::PhantomData;

#[derive(Debug)]
pub struct ArithConfig<E: ExtensionField> {
r_insn: RInstructionConfig<E>,

multiplier_1: RegUInt<E>,
multiplier_2: RegUInt<E>,
outcome: RegUInt<E>,
}

pub struct ArithInstruction<E, I>(PhantomData<(E, I)>);

pub struct MulOp;
impl RIVInstruction for MulOp {
const INST_KIND: InsnKind = InsnKind::MUL;
}
pub type MulInstruction<E> = ArithInstruction<E, MulOp>;

impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E, I> {
type InstructionConfig = ArithConfig<E>;

fn name() -> String {
format!("{:?}", I::INST_KIND)
}

fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, ZKVMError> {
let mut multiplier_1 = RegUInt::new_unchecked(|| "multiplier_1", circuit_builder)?;
let mut multiplier_2 = RegUInt::new_unchecked(|| "multiplier_2", circuit_builder)?;
let outcome = multiplier_1.mul(|| "outcome", circuit_builder, &mut multiplier_2, true)?;

let r_insn = RInstructionConfig::<E>::construct_circuit(
circuit_builder,
I::INST_KIND,
&multiplier_1,
&multiplier_2,
&outcome,
)?;

Ok(ArithConfig {
r_insn,
multiplier_1,
multiplier_2,
outcome,
})
}

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
lkm: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
config.r_insn.assign_instance(instance, lkm, step)?;

let multiplier_1 = UIntValue::new_unchecked(step.rs1().unwrap().value);
let multiplier_2 = UIntValue::new_unchecked(step.rs2().unwrap().value);
let outcome = UIntValue::new_unchecked(step.rd().unwrap().value.after);

config
.multiplier_1
.assign_limbs(instance, multiplier_1.u16_fields());
config
.multiplier_2
.assign_limbs(instance, multiplier_2.u16_fields());
let (_, carries) = multiplier_1.mul(&multiplier_2, lkm, true);

config.outcome.assign_limbs(instance, outcome.u16_fields());
config.outcome.assign_carries(
instance,
carries
.into_iter()
.map(|carry| E::BaseField::from(carry as u64))
.collect_vec(),
);

Ok(())
}
}

#[cfg(test)]
mod test {
use ceno_emul::{Change, StepRecord};
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLEs;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::Instruction,
scheme::mock_prover::{MockProver, MOCK_PC_MUL, MOCK_PROGRAM},
};

use super::MulInstruction;

#[test]
fn test_opcode_mul() {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
.namespace(|| "mul", |cb| Ok(MulInstruction::construct_circuit(cb)))
.unwrap()
.unwrap();

// values assignment
let (raw_witin, _) = MulInstruction::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord::new_r_instruction(
3,
MOCK_PC_MUL,
MOCK_PROGRAM[2],
11,
2,
Change::new(0, 22),
0,
)],
)
.unwrap();

MockProver::assert_satisfied(
&mut cb,
&raw_witin
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
);
}

#[test]
fn test_opcode_mul_overflow() {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
.namespace(|| "mul", |cb| Ok(MulInstruction::construct_circuit(cb)))
.unwrap()
.unwrap();

// values assignment
let (raw_witin, _) = MulInstruction::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord::new_r_instruction(
3,
MOCK_PC_MUL,
MOCK_PROGRAM[2],
u32::MAX / 2 + 1,
2,
Change::new(0, 0),
0,
)],
)
.unwrap();

MockProver::assert_satisfied(
&mut cb,
&raw_witin
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
);
}

#[test]
fn test_opcode_mul_overflow2() {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
.namespace(|| "mul", |cb| Ok(MulInstruction::construct_circuit(cb)))
.unwrap()
.unwrap();

// values assignment
let (raw_witin, _) = MulInstruction::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord::new_r_instruction(
3,
MOCK_PC_MUL,
MOCK_PROGRAM[2],
4294901760,
4294901760,
Change::new(0, 0),
0,
)],
)
.unwrap();

MockProver::assert_satisfied(
&mut cb,
&raw_witin
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
);
}
}
13 changes: 11 additions & 2 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,26 @@ use itertools::Itertools;
use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension;
use std::{collections::HashSet, hash::Hash, marker::PhantomData, ops::Neg, sync::OnceLock};

pub const MOCK_RS1: u32 = 2;
pub const MOCK_RS2: u32 = 3;
pub const MOCK_RD: u32 = 4;
/// The program baked in the MockProver.
/// TODO: Make this a parameter?
pub const MOCK_PROGRAM: &[u32] = &[
// R-Type
// funct7 | rs2 | rs1 | funct3 | rd | opcode
// -----------------------------------------
// add x4, x2, x3
0x00 << 25 | 3 << 20 | 2 << 15 | 4 << 7 | 0x33,
0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33,
// sub x4, x2, x3
0x20 << 25 | 3 << 20 | 2 << 15 | 4 << 7 | 0x33,
0x20 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33,
// mul (0x01, 0x00, 0x33)
0x01 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33,
];
// Addresses of particular instructions in the mock program.
pub const MOCK_PC_ADD: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start());
pub const MOCK_PC_SUB: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 4);
pub const MOCK_PC_MUL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 8);

#[allow(clippy::enum_variant_names)]
#[derive(Debug, PartialEq, Clone)]
Expand Down
Loading

0 comments on commit 713461b

Please sign in to comment.