Skip to content

Commit

Permalink
Feat: Implement JAL opcode (#305)
Browse files Browse the repository at this point in the history
Implement a J-type instruction base type and the `JAL` opcode.

---------

Co-authored-by: Bryan Gillespie <[email protected]>
Co-authored-by: xkx <[email protected]>
  • Loading branch information
3 people authored Oct 11, 2024
1 parent 8e2214d commit cdb771a
Show file tree
Hide file tree
Showing 10 changed files with 258 additions and 8 deletions.
23 changes: 23 additions & 0 deletions ceno_emul/src/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,29 @@ impl StepRecord {
}
}

pub fn new_j_instruction(
cycle: Cycle,
pc: Change<ByteAddr>,
insn_code: Word,
rd: Change<Word>,
previous_cycle: Cycle,
) -> StepRecord {
let insn = DecodedInstruction::new(insn_code);
StepRecord {
cycle,
pc,
insn_code,
rs1: None,
rs2: None,
rd: Some(WriteOp {
addr: CENO_PLATFORM.register_vma(insn.rd() as RegIdx).into(),
value: rd,
previous_cycle,
}),
memory_op: None,
}
}

pub fn cycle(&self) -> Cycle {
self.cycle
}
Expand Down
19 changes: 14 additions & 5 deletions ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{iter, panic, time::Instant};

use ceno_zkvm::{
instructions::riscv::{arith::AddInstruction, branch::BltuInstruction},
instructions::riscv::{arith::AddInstruction, branch::BltuInstruction, jump::JalInstruction},
scheme::prover::ZKVMProver,
tables::ProgramTableCircuit,
};
Expand All @@ -10,7 +10,7 @@ use const_env::from_env;

use ceno_emul::{
ByteAddr,
InsnKind::{ADD, BLTU, EANY},
InsnKind::{ADD, BLTU, EANY, JAL},
StepRecord, VMState, CENO_PLATFORM,
};
use ceno_zkvm::{
Expand Down Expand Up @@ -39,11 +39,12 @@ const RAYON_NUM_THREADS: usize = 8;
#[allow(clippy::unusual_byte_groupings)]
const ECALL_HALT: u32 = 0b_000000000000_00000_000_00000_1110011;
#[allow(clippy::unusual_byte_groupings)]
const PROGRAM_CODE: [u32; 4] = [
const PROGRAM_CODE: [u32; 5] = [
// func7 rs2 rs1 f3 rd opcode
0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1
0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1
0b_1_111111_00011_00000_110_1100_1_1100011, // bltu x0, x3, -8
0b_0_0000000010_0_00000000_00001_1101111, // jal x1, 4
ECALL_HALT, // ecall halt
];

Expand Down Expand Up @@ -105,6 +106,7 @@ fn main() {
// opcode circuits
let add_config = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
let bltu_config = zkvm_cs.register_opcode_circuit::<BltuInstruction>();
let jal_config = zkvm_cs.register_opcode_circuit::<JalInstruction<E>>();
let halt_config = zkvm_cs.register_opcode_circuit::<HaltInstruction<E>>();
// tables
let u16_range_config = zkvm_cs.register_table_circuit::<U16TableCircuit<E>>();
Expand All @@ -121,6 +123,7 @@ fn main() {
let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);
zkvm_fixed_traces.register_opcode_circuit::<BltuInstruction>(&zkvm_cs);
zkvm_fixed_traces.register_opcode_circuit::<JalInstruction<E>>(&zkvm_cs);
zkvm_fixed_traces.register_opcode_circuit::<HaltInstruction<E>>(&zkvm_cs);

zkvm_fixed_traces.register_table_circuit::<U16TableCircuit<E>>(
Expand Down Expand Up @@ -176,12 +179,14 @@ fn main() {
.collect::<Vec<_>>();
let mut add_records = Vec::new();
let mut bltu_records = Vec::new();
let mut jal_records = Vec::new();
let mut halt_records = Vec::new();
all_records.into_iter().for_each(|record| {
let kind = record.insn().kind().1;
match kind {
ADD => add_records.push(record),
BLTU => bltu_records.push(record),
JAL => jal_records.push(record),
EANY => {
if record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt() {
halt_records.push(record);
Expand All @@ -196,9 +201,10 @@ fn main() {
let pi = PublicValues::new(exit_code, 0);

tracing::info!(
"tracer generated {} ADD records, {} BLTU records",
"tracer generated {} ADD records, {} BLTU records, {} JAL records",
add_records.len(),
bltu_records.len()
bltu_records.len(),
jal_records.len(),
);

let mut zkvm_witness = ZKVMWitnesses::default();
Expand All @@ -209,6 +215,9 @@ fn main() {
zkvm_witness
.assign_opcode_circuit::<BltuInstruction>(&zkvm_cs, &bltu_config, bltu_records)
.unwrap();
zkvm_witness
.assign_opcode_circuit::<JalInstruction<E>>(&zkvm_cs, &jal_config, jal_records)
.unwrap();
zkvm_witness
.assign_opcode_circuit::<HaltInstruction<E>>(&zkvm_cs, &halt_config, halt_records)
.unwrap();
Expand Down
4 changes: 3 additions & 1 deletion ceno_zkvm/src/instructions/riscv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod config;
pub mod constants;
pub mod divu;
pub mod ecall;
pub mod jump;
pub mod logic;
pub mod mulh;
pub mod shift;
Expand All @@ -16,9 +17,10 @@ pub mod sltu;
mod b_insn;
mod i_insn;
mod insn_base;
mod j_insn;
mod r_insn;

mod ecall_insn;
mod r_insn;

#[cfg(test)]
mod test;
Expand Down
1 change: 0 additions & 1 deletion ceno_zkvm/src/instructions/riscv/divu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ mod test {
let mut rng = rand::thread_rng();
let a: u32 = rng.gen();
let b: u32 = rng.gen_range(1..u32::MAX);
println!("random: {} / {} = {}", a, b, a / b);
verify("random", a, b, a / b);
}
}
Expand Down
69 changes: 69 additions & 0 deletions ceno_zkvm/src/instructions/riscv/j_insn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use ceno_emul::{InsnKind, StepRecord};
use ff_ext::ExtensionField;

use crate::{
chip_handler::RegisterExpr,
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::ToExpr,
instructions::riscv::insn_base::{StateInOut, WriteRD},
tables::InsnRecord,
witness::LkMultiplicity,
};
use core::mem::MaybeUninit;

// Opcode: 1101111

/// This config handles the common part of the J-type instruction (JAL):
/// - PC, cycle, fetch
/// - Register access
///
/// It does not witness the output rd value produced by the JAL opcode, but
/// does constrain next_pc = pc + imm using the instruction table lookup
#[derive(Debug)]
pub struct JInstructionConfig<E: ExtensionField> {
pub vm_state: StateInOut<E>,
pub rd: WriteRD<E>,
}

impl<E: ExtensionField> JInstructionConfig<E> {
pub fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
insn_kind: InsnKind,
rd_written: RegisterExpr<E>,
) -> Result<Self, ZKVMError> {
// State in and out
let vm_state = StateInOut::construct_circuit(circuit_builder, true)?;

// Registers
let rd = WriteRD::construct_circuit(circuit_builder, rd_written, vm_state.ts)?;

// Fetch instruction
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
(insn_kind.codes().opcode as usize).into(),
rd.id.expr(),
0.into(),
0.into(),
0.into(),
vm_state.next_pc.unwrap().expr() - vm_state.pc.expr(),
))?;

Ok(JInstructionConfig { vm_state, rd })
}

pub fn assign_instance(
&self,
instance: &mut [MaybeUninit<<E as ExtensionField>::BaseField>],
lk_multiplicity: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
self.vm_state.assign_instance(instance, step)?;
self.rd.assign_instance(instance, lk_multiplicity, step)?;

// Fetch the instruction.
lk_multiplicity.fetch(step.pc().before.0);

Ok(())
}
}
14 changes: 14 additions & 0 deletions ceno_zkvm/src/instructions/riscv/jump.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
mod jal;

use super::RIVInstruction;
use ceno_emul::InsnKind;
use jal::JalCircuit;

#[cfg(test)]
mod test;

pub struct JalOp;
impl RIVInstruction for JalOp {
const INST_KIND: InsnKind = InsnKind::JAL;
}
pub type JalInstruction<E> = JalCircuit<E, JalOp>;
78 changes: 78 additions & 0 deletions ceno_zkvm/src/instructions/riscv/jump/jal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use std::{marker::PhantomData, mem::MaybeUninit};

use ff_ext::ExtensionField;

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::ToExpr,
instructions::{
riscv::{constants::UInt, j_insn::JInstructionConfig, RIVInstruction},
Instruction,
},
witness::LkMultiplicity,
Value,
};
use ceno_emul::PC_STEP_SIZE;

pub struct JalConfig<E: ExtensionField> {
pub j_insn: JInstructionConfig<E>,
pub rd_written: UInt<E>,
}

pub struct JalCircuit<E, I>(PhantomData<(E, I)>);

/// JAL instruction circuit
///
/// Note: does not validate that next_pc is aligned by 4-byte increments, which
/// should be verified by lookup argument of the next execution step against
/// the program table
///
/// Assumption: values for valid initial program counter must lie between
/// 2^20 and 2^32 - 2^20 + 2 inclusive, probably enforced by the static
/// program lookup table. If this assumption does not hold, then resulting
/// value for next_pc may not correctly wrap mod 2^32 because of the use
/// of native WitIn values for address space arithmetic.
impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for JalCircuit<E, I> {
type InstructionConfig = JalConfig<E>;

fn name() -> String {
format!("{:?}", I::INST_KIND)
}

fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<JalConfig<E>, ZKVMError> {
let rd_written = UInt::new(|| "rd_written", circuit_builder)?;

let j_insn = JInstructionConfig::construct_circuit(
circuit_builder,
I::INST_KIND,
rd_written.register_expr(),
)?;

circuit_builder.require_equal(
|| "jal rd_written",
rd_written.value(),
j_insn.vm_state.pc.expr() + PC_STEP_SIZE.into(),
)?;

Ok(JalConfig { j_insn, rd_written })
}

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
lk_multiplicity: &mut LkMultiplicity,
step: &ceno_emul::StepRecord,
) -> Result<(), ZKVMError> {
config
.j_insn
.assign_instance(instance, lk_multiplicity, step)?;

let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity);
config.rd_written.assign_value(instance, rd_written);

Ok(())
}
}
54 changes: 54 additions & 0 deletions ceno_zkvm/src/instructions/riscv/jump/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use ceno_emul::{ByteAddr, Change, StepRecord, PC_STEP_SIZE};
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLEs;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::Instruction,
scheme::mock_prover::{MockProver, MOCK_PC_JAL, MOCK_PROGRAM},
};

use super::JalInstruction;

#[test]
fn test_opcode_jal() {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
.namespace(
|| "jal",
|cb| {
let config = JalInstruction::<GoldilocksExt2>::construct_circuit(cb);
Ok(config)
},
)
.unwrap()
.unwrap();

let pc_offset: i32 = -4i32;
let new_pc: ByteAddr = ByteAddr(MOCK_PC_JAL.0.wrapping_add_signed(pc_offset));
let (raw_witin, _lkm) = JalInstruction::<GoldilocksExt2>::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord::new_j_instruction(
4,
Change::new(MOCK_PC_JAL, new_pc),
MOCK_PROGRAM[21],
Change::new(0, (MOCK_PC_JAL + PC_STEP_SIZE).into()),
0,
)],
)
.unwrap();

MockProver::assert_satisfied(
&mut cb,
&raw_witin
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
);
}
1 change: 0 additions & 1 deletion ceno_zkvm/src/instructions/riscv/sltu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ mod test {
let mut rng = rand::thread_rng();
let a: u32 = rng.gen();
let b: u32 = rng.gen();
println!("random: {}, {}", a, b);
verify("random 1", a, b, (a < b) as u32);
verify("random 2", b, a, !(a < b) as u32);
}
Expand Down
3 changes: 3 additions & 0 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ pub const MOCK_PROGRAM: &[u32] = &[
0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b001 << 12 | MOCK_RD << 7 | 0x33,
// srl x4, x2, x3
0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b101 << 12 | MOCK_RD << 7 | 0x33,
// jal x4, 0xffffe
0b_1_1111111110_1_11111111 << 12 | MOCK_RD << 7 | 0x6f,
];
// Addresses of particular instructions in the mock program.
pub const MOCK_PC_ADD: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start());
Expand All @@ -105,6 +107,7 @@ pub const MOCK_PC_BGE: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 68);
pub const MOCK_PC_MULHU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 72);
pub const MOCK_PC_SLL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 76);
pub const MOCK_PC_SRL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 80);
pub const MOCK_PC_JAL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 84);

#[allow(clippy::enum_variant_names)]
#[derive(Debug, PartialEq, Clone)]
Expand Down

0 comments on commit cdb771a

Please sign in to comment.