Feat: Implement JAL opcode #305

Merged
merged 16 commits into from
Oct 11, 2024
Changes from 6 commits
23 changes: 23 additions & 0 deletions ceno_emul/src/tracer.rs
@@ -142,6 +142,29 @@ impl StepRecord {
}
}

pub fn new_j_instruction(
cycle: Cycle,
pc: Change<ByteAddr>,
insn_code: Word,
rd: Change<Word>,
previous_cycle: Cycle,
) -> StepRecord {
let insn = DecodedInstruction::new(insn_code);
StepRecord {
cycle,
pc,
insn_code,
rs1: None,
rs2: None,
rd: Some(WriteOp {
addr: CENO_PLATFORM.register_vma(insn.rd() as RegIdx).into(),
value: rd,
previous_cycle,
}),
memory_op: None,
}
}

pub fn cycle(&self) -> Cycle {
self.cycle
}
27 changes: 18 additions & 9 deletions ceno_zkvm/examples/riscv_opcodes.rs
@@ -1,7 +1,7 @@
use std::{iter, time::Instant};

use ceno_zkvm::{
instructions::riscv::{arith::AddInstruction, branch::BltuInstruction},
instructions::riscv::{arith::AddInstruction, branch::BltuInstruction, jump::JalInstruction},
scheme::prover::ZKVMProver,
tables::ProgramTableCircuit,
};
@@ -10,7 +10,7 @@ use const_env::from_env;

use ceno_emul::{
ByteAddr,
InsnKind::{ADD, BLTU},
InsnKind::{ADD, BLTU, JAL},
StepRecord, VMState, CENO_PLATFORM,
};
use ceno_zkvm::{
@@ -37,11 +37,12 @@ const RAYON_NUM_THREADS: usize = 8;
#[allow(clippy::unusual_byte_groupings)]
const ECALL_HALT: u32 = 0b_000000000000_00000_000_00000_1110011;
#[allow(clippy::unusual_byte_groupings)]
const PROGRAM_CODE: [u32; 4] = [
const PROGRAM_CODE: [u32; 5] = [
// func7 rs2 rs1 f3 rd opcode
0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1
0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1
0b_1_111111_00011_00000_110_1100_1_1100011, // bltu x0, x3, -8
0b_0_0000000010_0_00000000_00001_1101111, // jal x1, 4
Collaborator
Not something introduced in this PR, but:

This style of specifying test data is approximately unreadable. At least for me. I can't (easily) tell whether these binary numbers here make any sense.

We can test the decoding from binary to something symbolic one by one for each instruction, but then when we test multiple instructions together, we should use the symbolic form as an intermediary input.

Collaborator Author
Great point, and strong agree that we could use a better way to encode this data -- it's challenging both to read and write instructions like this. I propose leaving the formatting as is in this PR and following up with at least a small rework to make these values more legible and less error-prone.

Issue to keep track of the improvement: #331
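
The symbolic intermediate form suggested in this thread could start from a small encoder for J-type words. The following is a minimal sketch, not code from this PR: `encode_jal` is a hypothetical helper name, and the bit layout assumed is the standard RISC-V J-type format (imm[20], imm[10:1], imm[11], imm[19:12], rd, opcode).

```rust
/// Opcode shared by all JAL instructions (0b1101111).
const JAL_OPCODE: u32 = 0b110_1111;

/// Encode `jal rd, imm` as a 32-bit RISC-V instruction word.
/// `imm` is a signed byte offset; it must be even and fit in 21 bits.
/// Hypothetical helper for illustration only.
pub fn encode_jal(rd: u32, imm: i32) -> u32 {
    assert!(rd < 32, "rd must be a register index in 0..=31");
    assert!(imm % 2 == 0, "J-type offsets are multiples of 2");
    assert!(
        (-(1 << 20)..(1 << 20)).contains(&imm),
        "offset does not fit in 21 signed bits"
    );

    // Reinterpret as unsigned so the shifts below pick out raw bits.
    let imm = imm as u32;
    let imm20 = (imm >> 20) & 0x1; // sign bit
    let imm10_1 = (imm >> 1) & 0x3ff;
    let imm11 = (imm >> 11) & 0x1;
    let imm19_12 = (imm >> 12) & 0xff;

    (imm20 << 31)
        | (imm10_1 << 21)
        | (imm11 << 20)
        | (imm19_12 << 12)
        | (rd << 7)
        | JAL_OPCODE
}
```

With a helper like this, the `jal x1, 4` entry above could be written as `encode_jal(1, 4)`, which evaluates to the same word as the binary literal.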

ECALL_HALT, // ecall halt
];

@@ -103,6 +104,7 @@ fn main() {
// opcode circuits
let add_config = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
let bltu_config = zkvm_cs.register_opcode_circuit::<BltuInstruction>();
let jal_config = zkvm_cs.register_opcode_circuit::<JalInstruction<E>>();
// tables
let u16_range_config = zkvm_cs.register_table_circuit::<U16TableCircuit<E>>();
let and_config = zkvm_cs.register_table_circuit::<AndTableCircuit<E>>();
@@ -118,6 +120,7 @@ fn main() {
let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);
zkvm_fixed_traces.register_opcode_circuit::<BltuInstruction>(&zkvm_cs);
zkvm_fixed_traces.register_opcode_circuit::<JalInstruction<E>>(&zkvm_cs);
Collaborator
Continued: See, we already have a symbolic form for AddInstruction, JalInstruction, etc., so we might as well make use of them.


zkvm_fixed_traces.register_table_circuit::<U16TableCircuit<E>>(
&zkvm_cs,
@@ -172,19 +175,22 @@ fn main() {
.collect::<Vec<_>>();
let mut add_records = Vec::new();
let mut bltu_records = Vec::new();
let mut jal_records = Vec::new();
all_records.iter().for_each(|record| {
bgillesp marked this conversation as resolved.
let kind = record.insn().kind().1;
if kind == ADD {
add_records.push(record.clone());
} else if kind == BLTU {
bltu_records.push(record.clone());
match kind {
Collaborator
Yes, matching is better than chains of if-else.

ADD => add_records.push(record.clone()),
BLTU => bltu_records.push(record.clone()),
JAL => jal_records.push(record.clone()),
bgillesp marked this conversation as resolved.
_ => {}
}
});

tracing::info!(
"tracer generated {} ADD records, {} BLTU records",
"tracer generated {} ADD records, {} BLTU records, {} JAL records",
add_records.len(),
bltu_records.len()
bltu_records.len(),
jal_records.len(),
);

let mut zkvm_witness = ZKVMWitnesses::default();
@@ -195,6 +201,9 @@ fn main() {
zkvm_witness
.assign_opcode_circuit::<BltuInstruction>(&zkvm_cs, &bltu_config, bltu_records)
.unwrap();
zkvm_witness
.assign_opcode_circuit::<JalInstruction<E>>(&zkvm_cs, &jal_config, jal_records)
.unwrap();
zkvm_witness.finalize_lk_multiplicities();
// assign table circuits
zkvm_witness
2 changes: 2 additions & 0 deletions ceno_zkvm/src/instructions/riscv.rs
@@ -6,13 +6,15 @@ pub mod branch;
pub mod config;
pub mod constants;
pub mod divu;
pub mod jump;
pub mod logic;
pub mod shift_imm;
pub mod sltu;

mod b_insn;
mod i_insn;
mod insn_base;
mod j_insn;
mod r_insn;

#[cfg(test)]
80 changes: 80 additions & 0 deletions ceno_zkvm/src/instructions/riscv/j_insn.rs
@@ -0,0 +1,80 @@
use ceno_emul::{InsnKind, StepRecord};
use ff_ext::ExtensionField;

use crate::{
chip_handler::RegisterExpr,
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{ToExpr, WitIn},
instructions::riscv::insn_base::{StateInOut, WriteRD},
set_val,
tables::InsnRecord,
witness::LkMultiplicity,
};
use core::mem::MaybeUninit;

// Opcode: 1101111

/// This config handles the common part of the J-type instruction (JAL):
/// - PC, cycle, fetch
/// - Register access
///
/// It does not witness the output rd value or next_pc produced by the JAL opcode
#[derive(Debug)]
pub struct JInstructionConfig<E: ExtensionField> {
pub vm_state: StateInOut<E>,
pub rd: WriteRD<E>,
pub imm: WitIn,
Collaborator
Technically the imm witness is redundant since it equals a degree-1 expression (next_pc - pc). You could use the expression directly.

Collaborator (@kunxian-xia, Oct 11, 2024)
Or we can use StateInOut::construct_circuit(circuit_builder, false)?; to initialize vm_state which is more natural.

Collaborator Author
Nice find @naure! We can't use StateInOut::construct_circuit(circuit_builder, false) @kunxian-xia because it imposes the constraint next_pc = pc + 4, but we can put this direct check into the J-type instruction gadget since JAL is the only opcode that uses it. Pushing a commit with this change now.

Collaborator
Nice! This is ready to merge from my side.
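
To make the outcome of this thread concrete, here is a plain-Rust model of the two relations the JAL circuit enforces: next_pc = pc + imm and rd = pc + PC_STEP_SIZE. It is only a sketch under assumptions: `JalStep`, `jal_step`, `implied_imm` and `satisfies_jal` are hypothetical names, the step size is assumed to be 4, and wrapping u32 arithmetic stands in for the field expressions the real circuit uses.

```rust
/// Byte distance between consecutive instructions (assumed to be 4 here).
const PC_STEP_SIZE: u32 = 4;

/// Hypothetical stand-in for the values a JAL step exposes to the circuit.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct JalStep {
    pub pc: u32,
    pub next_pc: u32,
    pub rd_written: u32,
}

/// Model the effect of `jal rd, imm` executed at `pc`.
pub fn jal_step(pc: u32, imm: i32) -> JalStep {
    JalStep {
        pc,
        // Jump target: signed offset added to the current PC.
        next_pc: pc.wrapping_add(imm as u32),
        // Return address: the instruction following the jump.
        rd_written: pc.wrapping_add(PC_STEP_SIZE),
    }
}

/// The immediate is fully determined by the PC change (next_pc - pc),
/// which is the degree-1 expression mentioned in the review.
pub fn implied_imm(step: &JalStep) -> i32 {
    step.next_pc.wrapping_sub(step.pc) as i32
}

/// Check both relations, where `imm` plays the role of the immediate
/// stored in the program table for this instruction.
pub fn satisfies_jal(step: &JalStep, imm: i32) -> bool {
    implied_imm(step) == imm && step.rd_written == step.pc.wrapping_add(PC_STEP_SIZE)
}
```

Note that a generic "next_pc = pc + 4" rule would reject every step produced by `jal_step` with an offset other than 4, which matches the reason given above for not building the state with the fixed-step constraint.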

}

impl<E: ExtensionField> JInstructionConfig<E> {
pub fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
insn_kind: InsnKind,
rd_written: RegisterExpr<E>,
) -> Result<Self, ZKVMError> {
// State in and out
let vm_state = StateInOut::construct_circuit(circuit_builder, true)?;

// Registers
let rd = WriteRD::construct_circuit(circuit_builder, rd_written, vm_state.ts)?;

// Immediate
let imm = circuit_builder.create_witin(|| "imm")?;

// Fetch instruction
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
(insn_kind.codes().opcode as usize).into(),
Collaborator
This seems a bit suspicious. Why do we need to cast as usize just to call .into() afterwards? Can't we provide an Into / From instance that would let us get by without the as usize cast?

Collaborator Author
Fair point! The into call here relies on the implementation of From<usize> for the Expression type right now, but it looks like it's a little complicated to specify generic behavior over all primitive integer types (as far as I can tell). I'll put an implementation of this into a separate small PR in case folks have opinions on how this should be handled.

Collaborator Author
See PR: #333
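
One way to avoid the cast discussed here is to generate `From` implementations for several unsigned integer types with a macro. The sketch below uses a stand-in `Expr` type rather than the project's Expression type, and it is not necessarily what the follow-up PR does; every name in it is hypothetical.

```rust
/// Stand-in for an expression type that can hold constants.
/// Hypothetical: the real type in the project has more variants.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    Constant(u64),
}

/// Generate `impl From<T> for Expr` for each listed unsigned integer type.
macro_rules! impl_from_unsigned {
    ($($t:ty),* $(,)?) => {
        $(
            impl From<$t> for Expr {
                fn from(value: $t) -> Self {
                    // Widening to u64 is lossless for the types listed below
                    // on targets where usize is at most 64 bits.
                    Expr::Constant(value as u64)
                }
            }
        )*
    };
}

impl_from_unsigned!(u8, u16, u32, u64, usize);
```

A trade-off worth checking before adopting something like this: with several integer `From` impls available, unsuffixed literals such as `0.into()` may stop inferring a unique type and need an explicit suffix.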

rd.id.expr(),
0.into(),
0.into(),
0.into(),
imm.expr(),
))?;

Ok(JInstructionConfig { vm_state, rd, imm })
}

pub fn assign_instance(
&self,
instance: &mut [MaybeUninit<<E as ExtensionField>::BaseField>],
lk_multiplicity: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
self.vm_state.assign_instance(instance, step)?;
self.rd.assign_instance(instance, lk_multiplicity, step)?;

// Immediate
set_val!(
instance,
self.imm,
InsnRecord::imm_or_funct7_field::<E::BaseField>(&step.insn())
);

// Fetch the instruction.
lk_multiplicity.fetch(step.pc().before.0);

Ok(())
}
}
14 changes: 14 additions & 0 deletions ceno_zkvm/src/instructions/riscv/jump.rs
@@ -0,0 +1,14 @@
mod jal;

use super::RIVInstruction;
use ceno_emul::InsnKind;
use jal::JalCircuit;

#[cfg(test)]
mod test;

pub struct JalOp;
impl RIVInstruction for JalOp {
const INST_KIND: InsnKind = InsnKind::JAL;
}
pub type JalInstruction<E> = JalCircuit<E, JalOp>;
76 changes: 76 additions & 0 deletions ceno_zkvm/src/instructions/riscv/jump/jal.rs
@@ -0,0 +1,76 @@
use std::{marker::PhantomData, mem::MaybeUninit};

use ff_ext::ExtensionField;

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::ToExpr,
instructions::{
riscv::{constants::UInt, j_insn::JInstructionConfig, RIVInstruction},
Instruction,
},
witness::LkMultiplicity,
Value,
};
use ceno_emul::PC_STEP_SIZE;

pub struct JalConfig<E: ExtensionField> {
pub j_insn: JInstructionConfig<E>,
pub rd_written: UInt<E>,
}

pub struct JalCircuit<E, I>(PhantomData<(E, I)>);

/// JAL instruction circuit
/// NOTE: does not validate that next_pc is aligned by 4-byte increments, which
/// should be verified by lookup argument of the next execution step against
/// the program table
impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for JalCircuit<E, I> {
type InstructionConfig = JalConfig<E>;

fn name() -> String {
format!("{:?}", I::INST_KIND)
}

fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<JalConfig<E>, ZKVMError> {
let rd_written = UInt::new_unchecked(|| "rd_limbs", circuit_builder)?;

let j_insn = JInstructionConfig::construct_circuit(
circuit_builder,
I::INST_KIND,
rd_written.register_expr(),
)?;

// constrain next_pc
circuit_builder.require_equal(
|| "jump next_pc",
j_insn.vm_state.next_pc.unwrap().expr(),
j_insn.vm_state.pc.expr() + j_insn.imm.expr(),
)?;

// constrain return address written to rd
let return_addr = j_insn.vm_state.pc.expr() + PC_STEP_SIZE.into();
circuit_builder.require_equal(|| "jump rd", rd_written.value(), return_addr)?;

Ok(JalConfig { j_insn, rd_written })
}

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
lk_multiplicity: &mut LkMultiplicity,
step: &ceno_emul::StepRecord,
) -> Result<(), ZKVMError> {
config
.j_insn
.assign_instance(instance, lk_multiplicity, step)?;

let rd = Value::new_unchecked(step.rd().unwrap().value.after);
config.rd_written.assign_limbs(instance, rd.as_u16_limbs());

Ok(())
}
}
53 changes: 53 additions & 0 deletions ceno_zkvm/src/instructions/riscv/jump/test.rs
@@ -0,0 +1,53 @@
use ceno_emul::{Change, StepRecord, PC_STEP_SIZE};
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLEs;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::Instruction,
scheme::mock_prover::{MockProver, MOCK_PC_JAL, MOCK_PROGRAM},
};

use super::JalInstruction;

#[test]
fn test_opcode_jal() {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
.namespace(
|| "jal",
|cb| {
let config = JalInstruction::<GoldilocksExt2>::construct_circuit(cb);
Ok(config)
},
)
.unwrap()
.unwrap();

let pc_offset: usize = 0x10004;
let (raw_witin, _lkm) = JalInstruction::<GoldilocksExt2>::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord::new_j_instruction(
4,
Change::new(MOCK_PC_JAL, MOCK_PC_JAL + pc_offset),
MOCK_PROGRAM[18],
Change::new(0, (MOCK_PC_JAL + PC_STEP_SIZE).into()),
0,
)],
)
.unwrap();

MockProver::assert_satisfied(
&mut cb,
&raw_witin
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
);
}
3 changes: 3 additions & 0 deletions ceno_zkvm/src/scheme/mock_prover.rs
@@ -76,6 +76,8 @@ pub const MOCK_PROGRAM: &[u32] = &[
0b_1_111111 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b_111 << 12 | 0b_1100_1 << 7 | 0x63,
// bge x2, x3, -8
0b_1_111111 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b_101 << 12 | 0b_1100_1 << 7 | 0x63,
// jal x4, 0x10004
0b_0_0000000010_0_00010000 << 12 | MOCK_RD << 7 | 0x6f,
Collaborator
As said above, putting binary numbers into the test data here makes it basically impossible to read.
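
A round-trip decoder is one way to check binary test words like the one above without reading the bits by hand. This is a sketch with a hypothetical helper name (`decode_jal`); it assumes MOCK_RD is 4, as the `jal x4` comment suggests, in which case the mock word evaluates to 0x0041026f.

```rust
/// Decode a JAL instruction word into `(rd, imm)`, where `imm` is the
/// sign-extended byte offset. Returns `None` if the opcode is not JAL.
/// Hypothetical helper for illustration only.
pub fn decode_jal(word: u32) -> Option<(u32, i32)> {
    // The low 7 bits hold the opcode; JAL is 0b1101111 (0x6f).
    if word & 0x7f != 0x6f {
        return None;
    }
    let rd = (word >> 7) & 0x1f;

    // Pull the scattered immediate fields out of the J-type layout.
    let imm20 = (word >> 31) & 0x1;
    let imm10_1 = (word >> 21) & 0x3ff;
    let imm11 = (word >> 20) & 0x1;
    let imm19_12 = (word >> 12) & 0xff;

    // Reassemble the 21-bit offset (bit 0 is always zero).
    let raw = (imm20 << 20) | (imm19_12 << 12) | (imm11 << 11) | (imm10_1 << 1);

    // Sign-extend from 21 bits using an arithmetic right shift.
    let imm = ((raw << 11) as i32) >> 11;
    Some((rd, imm))
}
```

Under the MOCK_RD assumption, `decode_jal(0x0041026f)` yields register 4 and offset 0x10004, which agrees with the `jal x4, 0x10004` comment.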

];
// Addresses of particular instructions in the mock program.
pub const MOCK_PC_ADD: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start());
@@ -96,6 +98,7 @@ pub const MOCK_PC_ADDI_SUB: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 56);
pub const MOCK_PC_BLTU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 60);
pub const MOCK_PC_BGEU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 64);
pub const MOCK_PC_BGE: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 68);
pub const MOCK_PC_JAL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 72);
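
The MOCK_PC_* constants follow the pattern pc_start + 4 * index, so MOCK_PC_JAL at offset 72 corresponds to MOCK_PROGRAM[18], the index used in the JAL test. A small sketch of that mapping, with hypothetical helper names and a caller-supplied base address standing in for CENO_PLATFORM.pc_start():

```rust
/// Byte distance between consecutive instructions (assumed to be 4 here).
const PC_STEP_SIZE: u32 = 4;

/// Address of the instruction at `index` in a program loaded at `pc_start`.
pub const fn pc_of_index(pc_start: u32, index: u32) -> u32 {
    pc_start + index * PC_STEP_SIZE
}

/// Inverse mapping: program index for `pc`, or `None` if `pc` lies before
/// the program start or is not aligned to an instruction boundary.
pub fn index_of_pc(pc_start: u32, pc: u32) -> Option<u32> {
    let offset = pc.checked_sub(pc_start)?;
    if offset % PC_STEP_SIZE == 0 {
        Some(offset / PC_STEP_SIZE)
    } else {
        None
    }
}
```

Deriving the address constants from indices in this way would keep them in sync when instructions are inserted into the mock program.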

#[allow(clippy::enum_variant_names)]
#[derive(Debug, PartialEq, Clone)]