diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index cedf9f96e..e811e8f22 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -6,6 +6,7 @@ pub mod branch; pub mod config; pub mod constants; pub mod divu; +pub mod jump; pub mod logic; pub mod shift_imm; pub mod sltu; @@ -13,6 +14,7 @@ pub mod sltu; mod b_insn; mod i_insn; mod insn_base; +mod j_insn; mod r_insn; #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs new file mode 100644 index 000000000..e9bcf3242 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -0,0 +1,90 @@ +#![allow(dead_code)] // TODO: remove after BLT, BEQ, … + +use ceno_emul::{InsnKind, StepRecord}; +use ff_ext::ExtensionField; + +use crate::{ + chip_handler::RegisterExpr, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{ToExpr, WitIn}, + instructions::riscv::insn_base::{StateInOut, WriteRD}, + set_val, + tables::InsnRecord, + witness::LkMultiplicity, +}; +use core::mem::MaybeUninit; + +// Opcode: 1100011 +// Funct3: +// 000 BEQ +// 001 BNE +// 100 BLT +// 101 BGE +// 110 BLTU +// 111 BGEU +// + +/// This config handles the common part of the J-type instruction (JAL): +/// - PC, cycle, fetch +/// - Register access +/// +/// It does not witness the output rd value or next_pc produced by the JAL opcode +#[derive(Debug)] +pub struct JInstructionConfig { + pub vm_state: StateInOut, + pub rd: WriteRD, + pub imm: WitIn, +} + +impl JInstructionConfig { + pub fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + insn_kind: InsnKind, + rd_written: RegisterExpr, + ) -> Result { + // State in and out + let vm_state = StateInOut::construct_circuit(circuit_builder, true)?; + + // Registers + let rd = WriteRD::construct_circuit(circuit_builder, rd_written, vm_state.ts)?; + + // Immediate + let imm = circuit_builder.create_witin(|| "imm")?; + + // Fetch instruction + circuit_builder.lk_fetch(&InsnRecord::new( + vm_state.pc.expr(), + (insn_kind.codes().opcode as usize).into(), + rd.id.expr(), + (insn_kind.codes().func3 as usize).into(), + 0.into(), + 0.into(), + imm.expr(), + ))?; + + Ok(JInstructionConfig { vm_state, rd, imm }) + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit<::BaseField>], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + self.vm_state.assign_instance(instance, step)?; + self.rd.assign_instance(instance, lk_multiplicity, step)?; + + // Immediate + set_val!( + instance, + self.imm, + InsnRecord::imm_or_funct7_field::(&step.insn()) + ); + + // Fetch the instruction. + lk_multiplicity.fetch(step.pc().before.0); + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs new file mode 100644 index 000000000..5d068c6d1 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -0,0 +1,13 @@ +mod jal; + +use super::RIVInstruction; +use ceno_emul::InsnKind; + +// #[cfg(test)] +// mod test; + +pub struct JalOp; +impl RIVInstruction for JalOp { + const INST_KIND: InsnKind = InsnKind::JAL; +} +pub type JalInstruction = jal::JalInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal.rs b/ceno_zkvm/src/instructions/riscv/jump/jal.rs new file mode 100644 index 000000000..44704b345 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/jump/jal.rs @@ -0,0 +1,73 @@ +use std::{marker::PhantomData, mem::MaybeUninit}; + +use ff_ext::ExtensionField; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::ToExpr, + instructions::{ + riscv::{constants::UInt, j_insn::JInstructionConfig, RIVInstruction}, + Instruction, + }, + witness::LkMultiplicity, + Value, +}; +use ceno_emul::PC_STEP_SIZE; + +pub struct JalInstruction(PhantomData<(E, I)>); + +pub struct InstructionConfig { + pub j_insn: JInstructionConfig, + pub rd_written: UInt, +} + +impl Instruction for JalInstruction { + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + type InstructionConfig = InstructionConfig; + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + ) -> Result, ZKVMError> { + let rd_written = UInt::new_unchecked(|| "rd_limbs", circuit_builder)?; + + let j_insn = JInstructionConfig::construct_circuit( + circuit_builder, + I::INST_KIND, + rd_written.register_expr(), + )?; + + // constrain next_pc + let jump_delta = j_insn.imm.expr() * 2.into(); + circuit_builder.require_equal( + || "jump next_pc", + j_insn.vm_state.next_pc.unwrap().expr(), + j_insn.vm_state.pc.expr() + jump_delta, + )?; + + // constrain return address written to rd + let return_addr = j_insn.vm_state.pc.expr() + PC_STEP_SIZE.into(); + circuit_builder.require_equal(|| "jump rd", rd_written.value(), return_addr)?; + + Ok(InstructionConfig { j_insn, rd_written }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .j_insn + .assign_instance(instance, lk_multiplicity, step)?; + + let rd = Value::new_unchecked(step.rd().unwrap().value.after); + config.rd_written.assign_limbs(instance, rd.as_u16_limbs()); + + Ok(()) + } +}