Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Implement JAL opcode #305

Merged
merged 16 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions ceno_emul/src/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,29 @@ impl StepRecord {
}
}

pub fn new_j_instruction(
cycle: Cycle,
pc: Change<ByteAddr>,
insn_code: Word,
rd: Change<Word>,
previous_cycle: Cycle,
) -> StepRecord {
let insn = DecodedInstruction::new(insn_code);
StepRecord {
cycle,
pc,
insn_code,
rs1: None,
rs2: None,
rd: Some(WriteOp {
addr: CENO_PLATFORM.register_vma(insn.rd() as RegIdx).into(),
value: rd,
previous_cycle,
}),
memory_op: None,
}
}

pub fn cycle(&self) -> Cycle {
self.cycle
}
Expand Down
19 changes: 14 additions & 5 deletions ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{iter, panic, time::Instant};

use ceno_zkvm::{
instructions::riscv::{arith::AddInstruction, branch::BltuInstruction},
instructions::riscv::{arith::AddInstruction, branch::BltuInstruction, jump::JalInstruction},
scheme::prover::ZKVMProver,
tables::ProgramTableCircuit,
};
Expand All @@ -10,7 +10,7 @@ use const_env::from_env;

use ceno_emul::{
ByteAddr,
InsnKind::{ADD, BLTU, EANY},
InsnKind::{ADD, BLTU, EANY, JAL},
StepRecord, VMState, CENO_PLATFORM,
};
use ceno_zkvm::{
Expand Down Expand Up @@ -39,11 +39,12 @@ const RAYON_NUM_THREADS: usize = 8;
#[allow(clippy::unusual_byte_groupings)]
const ECALL_HALT: u32 = 0b_000000000000_00000_000_00000_1110011;
#[allow(clippy::unusual_byte_groupings)]
const PROGRAM_CODE: [u32; 4] = [
const PROGRAM_CODE: [u32; 5] = [
// func7 rs2 rs1 f3 rd opcode
0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1
0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1
0b_1_111111_00011_00000_110_1100_1_1100011, // bltu x0, x3, -8
0b_0_0000000010_0_00000000_00001_1101111, // jal x1, 4
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not something introduced in this PR, but:

This style of specifying test data is approximately unreadable. At least for me. I can't (easily) tell whether these binary numbers here make any sense.

We can test the decoding from binary to something symbolic one by one for each instruction, but then when we test multiple instructions together, we should use the symbolic form as an intermediary input.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point, and strong agree that we could use a better way to encode this data -- it's challenging both to read and write instructions like this. I propose leaving the formatting as is in this PR, and follow up with at least a small rework to make these values more legible and less error-prone.

Issue to keep track of the improvement: #331

ECALL_HALT, // ecall halt
];

Expand Down Expand Up @@ -105,6 +106,7 @@ fn main() {
// opcode circuits
let add_config = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
let bltu_config = zkvm_cs.register_opcode_circuit::<BltuInstruction>();
let jal_config = zkvm_cs.register_opcode_circuit::<JalInstruction<E>>();
let halt_config = zkvm_cs.register_opcode_circuit::<HaltInstruction<E>>();
// tables
let u16_range_config = zkvm_cs.register_table_circuit::<U16TableCircuit<E>>();
Expand All @@ -121,6 +123,7 @@ fn main() {
let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);
zkvm_fixed_traces.register_opcode_circuit::<BltuInstruction>(&zkvm_cs);
zkvm_fixed_traces.register_opcode_circuit::<JalInstruction<E>>(&zkvm_cs);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Continued: See, we even already have some symbolic form for AddInstruction and JlaInstruction etc, so might as well make use of them.

zkvm_fixed_traces.register_opcode_circuit::<HaltInstruction<E>>(&zkvm_cs);

zkvm_fixed_traces.register_table_circuit::<U16TableCircuit<E>>(
Expand Down Expand Up @@ -176,12 +179,14 @@ fn main() {
.collect::<Vec<_>>();
let mut add_records = Vec::new();
let mut bltu_records = Vec::new();
let mut jal_records = Vec::new();
let mut halt_records = Vec::new();
all_records.into_iter().for_each(|record| {
let kind = record.insn().kind().1;
match kind {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, matching is better than chains of if-else.

ADD => add_records.push(record),
BLTU => bltu_records.push(record),
JAL => jal_records.push(record),
EANY => {
if record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt() {
halt_records.push(record);
Expand All @@ -196,9 +201,10 @@ fn main() {
let pi = PublicValues::new(exit_code, 0);

tracing::info!(
"tracer generated {} ADD records, {} BLTU records",
"tracer generated {} ADD records, {} BLTU records, {} JAL records",
add_records.len(),
bltu_records.len()
bltu_records.len(),
jal_records.len(),
);

let mut zkvm_witness = ZKVMWitnesses::default();
Expand All @@ -209,6 +215,9 @@ fn main() {
zkvm_witness
.assign_opcode_circuit::<BltuInstruction>(&zkvm_cs, &bltu_config, bltu_records)
.unwrap();
zkvm_witness
.assign_opcode_circuit::<JalInstruction<E>>(&zkvm_cs, &jal_config, jal_records)
.unwrap();
zkvm_witness
.assign_opcode_circuit::<HaltInstruction<E>>(&zkvm_cs, &halt_config, halt_records)
.unwrap();
Expand Down
4 changes: 3 additions & 1 deletion ceno_zkvm/src/instructions/riscv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod config;
pub mod constants;
pub mod divu;
pub mod ecall;
pub mod jump;
pub mod logic;
pub mod mulh;
pub mod shift;
Expand All @@ -16,9 +17,10 @@ pub mod sltu;
mod b_insn;
mod i_insn;
mod insn_base;
mod j_insn;
mod r_insn;

mod ecall_insn;
mod r_insn;

#[cfg(test)]
mod test;
Expand Down
1 change: 0 additions & 1 deletion ceno_zkvm/src/instructions/riscv/divu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ mod test {
let mut rng = rand::thread_rng();
let a: u32 = rng.gen();
let b: u32 = rng.gen_range(1..u32::MAX);
println!("random: {} / {} = {}", a, b, a / b);
verify("random", a, b, a / b);
}
}
Expand Down
69 changes: 69 additions & 0 deletions ceno_zkvm/src/instructions/riscv/j_insn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use ceno_emul::{InsnKind, StepRecord};
use ff_ext::ExtensionField;

use crate::{
chip_handler::RegisterExpr,
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::ToExpr,
instructions::riscv::insn_base::{StateInOut, WriteRD},
tables::InsnRecord,
witness::LkMultiplicity,
};
use core::mem::MaybeUninit;

// Opcode: 1101111

/// This config handles the common part of the J-type instruction (JAL):
/// - PC, cycle, fetch
/// - Register access
///
/// It does not witness the output rd value produced by the JAL opcode, but
/// does constrain next_pc = pc + imm using the instruction table lookup
#[derive(Debug)]
pub struct JInstructionConfig<E: ExtensionField> {
pub vm_state: StateInOut<E>,
pub rd: WriteRD<E>,
}

impl<E: ExtensionField> JInstructionConfig<E> {
pub fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
insn_kind: InsnKind,
rd_written: RegisterExpr<E>,
) -> Result<Self, ZKVMError> {
// State in and out
let vm_state = StateInOut::construct_circuit(circuit_builder, true)?;

// Registers
let rd = WriteRD::construct_circuit(circuit_builder, rd_written, vm_state.ts)?;

// Fetch instruction
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
(insn_kind.codes().opcode as usize).into(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a bit suspicious. Why do we need to cast as usize to just run .into() afterwards? Can't we provide an Into / From from instance that would let us get by without as usize?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair point! The into call here relies on the implementation of From<usize> for the Expression type right now, but it looks like it's a little complicated to specify generic behavior over all primitive integer types (as far as I can tell). I'll put an implementation of this into a separate small PR in case folks have opinions on how this should be handled.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See PR: #333

rd.id.expr(),
0.into(),
0.into(),
0.into(),
vm_state.next_pc.unwrap().expr() - vm_state.pc.expr(),
))?;

Ok(JInstructionConfig { vm_state, rd })
}

pub fn assign_instance(
&self,
instance: &mut [MaybeUninit<<E as ExtensionField>::BaseField>],
lk_multiplicity: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
self.vm_state.assign_instance(instance, step)?;
self.rd.assign_instance(instance, lk_multiplicity, step)?;

// Fetch the instruction.
lk_multiplicity.fetch(step.pc().before.0);

Ok(())
}
}
14 changes: 14 additions & 0 deletions ceno_zkvm/src/instructions/riscv/jump.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
mod jal;

use super::RIVInstruction;
use ceno_emul::InsnKind;
use jal::JalCircuit;

#[cfg(test)]
mod test;

pub struct JalOp;
impl RIVInstruction for JalOp {
const INST_KIND: InsnKind = InsnKind::JAL;
}
pub type JalInstruction<E> = JalCircuit<E, JalOp>;
78 changes: 78 additions & 0 deletions ceno_zkvm/src/instructions/riscv/jump/jal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use std::{marker::PhantomData, mem::MaybeUninit};

use ff_ext::ExtensionField;

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::ToExpr,
instructions::{
riscv::{constants::UInt, j_insn::JInstructionConfig, RIVInstruction},
Instruction,
},
witness::LkMultiplicity,
Value,
};
use ceno_emul::PC_STEP_SIZE;

pub struct JalConfig<E: ExtensionField> {
pub j_insn: JInstructionConfig<E>,
pub rd_written: UInt<E>,
}

pub struct JalCircuit<E, I>(PhantomData<(E, I)>);

/// JAL instruction circuit
///
/// Note: does not validate that next_pc is aligned by 4-byte increments, which
/// should be verified by lookup argument of the next execution step against
/// the program table
///
/// Assumption: values for valid initial program counter must lie between
/// 2^20 and 2^32 - 2^20 + 2 inclusive, probably enforced by the static
/// program lookup table. If this assumption does not hold, then resulting
/// value for next_pc may not correctly wrap mod 2^32 because of the use
/// of native WitIn values for address space arithmetic.
impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for JalCircuit<E, I> {
type InstructionConfig = JalConfig<E>;

fn name() -> String {
format!("{:?}", I::INST_KIND)
}

fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<JalConfig<E>, ZKVMError> {
let rd_written = UInt::new(|| "rd_written", circuit_builder)?;

let j_insn = JInstructionConfig::construct_circuit(
circuit_builder,
I::INST_KIND,
rd_written.register_expr(),
)?;

circuit_builder.require_equal(
|| "jal rd_written",
rd_written.value(),
j_insn.vm_state.pc.expr() + PC_STEP_SIZE.into(),
)?;

Ok(JalConfig { j_insn, rd_written })
}

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
lk_multiplicity: &mut LkMultiplicity,
step: &ceno_emul::StepRecord,
) -> Result<(), ZKVMError> {
config
.j_insn
.assign_instance(instance, lk_multiplicity, step)?;

let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity);
config.rd_written.assign_value(instance, rd_written);

Ok(())
}
}
54 changes: 54 additions & 0 deletions ceno_zkvm/src/instructions/riscv/jump/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use ceno_emul::{ByteAddr, Change, StepRecord, PC_STEP_SIZE};
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLEs;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::Instruction,
scheme::mock_prover::{MockProver, MOCK_PC_JAL, MOCK_PROGRAM},
};

use super::JalInstruction;

#[test]
fn test_opcode_jal() {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
.namespace(
|| "jal",
|cb| {
let config = JalInstruction::<GoldilocksExt2>::construct_circuit(cb);
Ok(config)
},
)
.unwrap()
.unwrap();

let pc_offset: i32 = -4i32;
let new_pc: ByteAddr = ByteAddr(MOCK_PC_JAL.0.wrapping_add_signed(pc_offset));
let (raw_witin, _lkm) = JalInstruction::<GoldilocksExt2>::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord::new_j_instruction(
4,
Change::new(MOCK_PC_JAL, new_pc),
MOCK_PROGRAM[21],
Change::new(0, (MOCK_PC_JAL + PC_STEP_SIZE).into()),
0,
)],
)
.unwrap();

MockProver::assert_satisfied(
&mut cb,
&raw_witin
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
);
}
1 change: 0 additions & 1 deletion ceno_zkvm/src/instructions/riscv/sltu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ mod test {
let mut rng = rand::thread_rng();
let a: u32 = rng.gen();
let b: u32 = rng.gen();
println!("random: {}, {}", a, b);
verify("random 1", a, b, (a < b) as u32);
verify("random 2", b, a, !(a < b) as u32);
}
Expand Down
3 changes: 3 additions & 0 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ pub const MOCK_PROGRAM: &[u32] = &[
0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b001 << 12 | MOCK_RD << 7 | 0x33,
// srl x4, x2, x3
0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b101 << 12 | MOCK_RD << 7 | 0x33,
// jal x4, 0xffffe
0b_1_1111111110_1_11111111 << 12 | MOCK_RD << 7 | 0x6f,
];
// Addresses of particular instructions in the mock program.
pub const MOCK_PC_ADD: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start());
Expand All @@ -105,6 +107,7 @@ pub const MOCK_PC_BGE: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 68);
pub const MOCK_PC_MULHU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 72);
pub const MOCK_PC_SLL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 76);
pub const MOCK_PC_SRL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 80);
pub const MOCK_PC_JAL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 84);

#[allow(clippy::enum_variant_names)]
#[derive(Debug, PartialEq, Clone)]
Expand Down
Loading