Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/#98 riscv mul opcode #219

Merged
merged 10 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ceno_emul/src/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ impl From<WordAddr> for u32 {
}
}

impl From<WordAddr> for u64 {
fn from(addr: WordAddr) -> Self {
addr.baddr().0 as u64
}
}

impl ByteAddr {
pub const fn waddr(self) -> WordAddr {
WordAddr(self.0 / WORD_SIZE as u32)
Expand Down
22 changes: 22 additions & 0 deletions ceno_emul/src/rv32im.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,28 @@ impl DecodedInstruction {
}
}

#[allow(dead_code)]
pub fn from_raw(kind: InsnKind, rs1: u32, rs2: u32, rd: u32) -> Self {
// limit the range of inputs
let rs2 = rs2 & 0x1f; // 5bits mask
let rs1 = rs1 & 0x1f;
let rd = rd & 0x1f;
let func7 = kind.codes().func7;
let func3 = kind.codes().func3;
let opcode = kind.codes().opcode;
let insn = func7 << 25 | rs2 << 20 | rs1 << 15 | func3 << 12 | rd << 7 | opcode;
Self {
insn,
top_bit: func7 | 0x80,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not right. See above:

top_bit: (insn & 0x80000000) >> 31,

func7,
rs2,
rs1,
func3,
rd,
opcode,
}
}

pub fn encoded(&self) -> u32 {
self.insn
}
Expand Down
7 changes: 4 additions & 3 deletions ceno_emul/src/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ impl StepRecord {
rs1_read: Word,
rs2_read: Word,
rd: Change<Word>,
previous_cycle: Cycle,
) -> StepRecord {
let insn = DecodedInstruction::new(insn_code);
StepRecord {
Expand All @@ -69,17 +70,17 @@ impl StepRecord {
rs1: Some(ReadOp {
addr: CENO_PLATFORM.register_vma(insn.rs1() as RegIdx).into(),
value: rs1_read,
previous_cycle: 0,
previous_cycle,
}),
rs2: Some(ReadOp {
addr: CENO_PLATFORM.register_vma(insn.rs2() as RegIdx).into(),
value: rs2_read,
previous_cycle: 0,
previous_cycle,
}),
rd: Some(WriteOp {
addr: CENO_PLATFORM.register_vma(insn.rd() as RegIdx).into(),
value: rd,
previous_cycle: 0,
previous_cycle,
}),
memory_op: None,
}
Expand Down
9 changes: 9 additions & 0 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
cmp::max,
mem::MaybeUninit,
ops::{Add, Deref, Mul, Neg, Sub},
};

Expand Down Expand Up @@ -426,6 +427,14 @@ impl WitIn {
},
)
}

pub fn assign<E: ExtensionField>(
KimiWu123 marked this conversation as resolved.
Show resolved Hide resolved
&self,
instance: &mut [MaybeUninit<E::BaseField>],
value: E::BaseField,
) {
instance[self.id as usize] = MaybeUninit::new(value);
}
}

#[macro_export]
Expand Down
2 changes: 2 additions & 0 deletions ceno_zkvm/src/instructions/riscv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ pub mod addsub;
pub mod blt;
pub mod config;
pub mod constants;
pub mod mul;

mod r_insn;

#[cfg(test)]
Expand Down
4 changes: 4 additions & 0 deletions ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ mod test {
11,
0xfffffffe,
Change::new(0, 11_u32.wrapping_add(0xfffffffe)),
0,
)],
)
.unwrap();
Expand Down Expand Up @@ -225,6 +226,7 @@ mod test {
u32::MAX - 1,
u32::MAX - 1,
Change::new(0, (u32::MAX - 1).wrapping_add(u32::MAX - 1)),
0,
)],
)
.unwrap();
Expand Down Expand Up @@ -267,6 +269,7 @@ mod test {
11,
2,
Change::new(0, 11_u32.wrapping_sub(2)),
0,
)],
)
.unwrap();
Expand Down Expand Up @@ -309,6 +312,7 @@ mod test {
3,
11,
Change::new(0, 3_u32.wrapping_sub(11)),
0,
)],
)
.unwrap();
Expand Down
218 changes: 218 additions & 0 deletions ceno_zkvm/src/instructions/riscv/mul.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
use ceno_emul::{InsnKind, StepRecord};
use ff_ext::ExtensionField;
use itertools::Itertools;

use super::{constants::RegUInt, r_insn::RInstructionConfig, RIVInstruction};
use crate::{
circuit_builder::CircuitBuilder, error::ZKVMError, instructions::Instruction, uint::UIntValue,
witness::LkMultiplicity,
};
use core::mem::MaybeUninit;
use std::marker::PhantomData;

#[derive(Debug)]
pub struct ArithConfig<E: ExtensionField> {
r_insn: RInstructionConfig<E>,

multiplier_1: RegUInt<E>,
multiplier_2: RegUInt<E>,
outcome: RegUInt<E>,
}

pub struct ArithInstruction<E, I>(PhantomData<(E, I)>);

pub struct MulOp;
impl RIVInstruction for MulOp {
const INST_KIND: InsnKind = InsnKind::MUL;
}
pub type MulInstruction<E> = ArithInstruction<E, MulOp>;

impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E, I> {
type InstructionConfig = ArithConfig<E>;

fn name() -> String {
format!("{:?}", I::INST_KIND)
}

fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, ZKVMError> {
let mut multiplier_1 = RegUInt::new_unchecked(|| "multiplier_1", circuit_builder)?;
let mut multiplier_2 = RegUInt::new_unchecked(|| "multiplier_2", circuit_builder)?;
let outcome = multiplier_1.mul(|| "outcome", circuit_builder, &mut multiplier_2, true)?;

let r_insn = RInstructionConfig::<E>::construct_circuit(
circuit_builder,
I::INST_KIND,
&multiplier_1,
&multiplier_2,
&outcome,
)?;

Ok(ArithConfig {
r_insn,
multiplier_1,
multiplier_2,
outcome,
})
}

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
lkm: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
config.r_insn.assign_instance(instance, lkm, step)?;

let multiplier_1 = UIntValue::new_unchecked(step.rs1().unwrap().value);
let multiplier_2 = UIntValue::new_unchecked(step.rs2().unwrap().value);
let outcome = UIntValue::new_unchecked(step.rd().unwrap().value.after);

config
.multiplier_1
.assign_limbs(instance, multiplier_1.u16_fields());
config
.multiplier_2
.assign_limbs(instance, multiplier_2.u16_fields());
let (_, carries) = multiplier_1.mul(&multiplier_2, lkm, true);

config.outcome.assign_limbs(instance, outcome.u16_fields());
config.outcome.assign_carries(
instance,
carries
.into_iter()
.map(|carry| E::BaseField::from(carry as u64))
.collect_vec(),
);

Ok(())
}
}

#[cfg(test)]
mod test {
use ceno_emul::{Change, StepRecord};
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLEs;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::Instruction,
scheme::mock_prover::{MockProver, MOCK_PC_MUL, MOCK_PROGRAM},
};

use super::MulInstruction;

#[test]
fn test_opcode_mul() {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
.namespace(|| "mul", |cb| Ok(MulInstruction::construct_circuit(cb)))
.unwrap()
.unwrap();

// values assignment
let (raw_witin, _) = MulInstruction::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord::new_r_instruction(
3,
MOCK_PC_MUL,
MOCK_PROGRAM[2],
11,
2,
Change::new(0, 22),
0,
)],
)
.unwrap();

MockProver::assert_satisfied(
&mut cb,
&raw_witin
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
);
}

#[test]
fn test_opcode_mul_overflow() {
hero78119 marked this conversation as resolved.
Show resolved Hide resolved
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
.namespace(|| "mul", |cb| Ok(MulInstruction::construct_circuit(cb)))
.unwrap()
.unwrap();

// values assignment
let (raw_witin, _) = MulInstruction::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord::new_r_instruction(
3,
MOCK_PC_MUL,
MOCK_PROGRAM[2],
u32::MAX / 2 + 1,
2,
Change::new(0, 0),
0,
)],
)
.unwrap();

MockProver::assert_satisfied(
&mut cb,
&raw_witin
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
);
}

#[test]
fn test_opcode_mul_overflow2() {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
.namespace(|| "mul", |cb| Ok(MulInstruction::construct_circuit(cb)))
.unwrap()
.unwrap();

// values assignment
let (raw_witin, _) = MulInstruction::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord::new_r_instruction(
3,
MOCK_PC_MUL,
MOCK_PROGRAM[2],
4294901760,
4294901760,
Change::new(0, 0),
0,
)],
)
.unwrap();

MockProver::assert_satisfied(
&mut cb,
&raw_witin
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
);
}
}
13 changes: 11 additions & 2 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,26 @@ use itertools::Itertools;
use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension;
use std::{collections::HashSet, hash::Hash, marker::PhantomData, ops::Neg, sync::OnceLock};

pub const MOCK_RS1: u32 = 2;
pub const MOCK_RS2: u32 = 3;
pub const MOCK_RD: u32 = 4;
/// The program baked in the MockProver.
/// TODO: Make this a parameter?
pub const MOCK_PROGRAM: &[u32] = &[
// R-Type
// funct7 | rs2 | rs1 | funct3 | rd | opcode
// -----------------------------------------
// add x4, x2, x3
0x00 << 25 | 3 << 20 | 2 << 15 | 4 << 7 | 0x33,
0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33,
// sub x4, x2, x3
0x20 << 25 | 3 << 20 | 2 << 15 | 4 << 7 | 0x33,
0x20 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33,
// mul (0x01, 0x00, 0x33)
0x01 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33,
];
// Addresses of particular instructions in the mock program.
pub const MOCK_PC_ADD: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start());
pub const MOCK_PC_SUB: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 4);
pub const MOCK_PC_MUL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 8);

#[allow(clippy::enum_variant_names)]
#[derive(Debug, PartialEq, Clone)]
Expand Down
Loading
Loading