diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index f8c53453c..1e93155c5 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -142,6 +142,29 @@ impl StepRecord { } } + pub fn new_j_instruction( + cycle: Cycle, + pc: Change, + insn_code: Word, + rd: Change, + previous_cycle: Cycle, + ) -> StepRecord { + let insn = DecodedInstruction::new(insn_code); + StepRecord { + cycle, + pc, + insn_code, + rs1: None, + rs2: None, + rd: Some(WriteOp { + addr: CENO_PLATFORM.register_vma(insn.rd() as RegIdx).into(), + value: rd, + previous_cycle, + }), + memory_op: None, + } + } + pub fn cycle(&self) -> Cycle { self.cycle } diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index 9e6997a90..964a1e9d3 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -1,7 +1,7 @@ use std::{iter, panic, time::Instant}; use ceno_zkvm::{ - instructions::riscv::{arith::AddInstruction, branch::BltuInstruction}, + instructions::riscv::{arith::AddInstruction, branch::BltuInstruction, jump::JalInstruction}, scheme::prover::ZKVMProver, tables::ProgramTableCircuit, }; @@ -10,7 +10,7 @@ use const_env::from_env; use ceno_emul::{ ByteAddr, - InsnKind::{ADD, BLTU, EANY}, + InsnKind::{ADD, BLTU, EANY, JAL}, StepRecord, VMState, CENO_PLATFORM, }; use ceno_zkvm::{ @@ -39,11 +39,12 @@ const RAYON_NUM_THREADS: usize = 8; #[allow(clippy::unusual_byte_groupings)] const ECALL_HALT: u32 = 0b_000000000000_00000_000_00000_1110011; #[allow(clippy::unusual_byte_groupings)] -const PROGRAM_CODE: [u32; 4] = [ +const PROGRAM_CODE: [u32; 5] = [ // func7 rs2 rs1 f3 rd opcode 0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1 0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1 0b_1_111111_00011_00000_110_1100_1_1100011, // bltu x0, x3, -8 + 0b_0_0000000010_0_00000000_00001_1101111, // jal x1, 4 ECALL_HALT, // ecall halt ]; @@ -105,6 +106,7 @@ fn main() { // opcode circuits let add_config = zkvm_cs.register_opcode_circuit::>(); let bltu_config = zkvm_cs.register_opcode_circuit::(); + let jal_config = zkvm_cs.register_opcode_circuit::>(); let halt_config = zkvm_cs.register_opcode_circuit::>(); // tables let u16_range_config = zkvm_cs.register_table_circuit::>(); @@ -121,6 +123,7 @@ fn main() { let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); zkvm_fixed_traces.register_opcode_circuit::(&zkvm_cs); + zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); zkvm_fixed_traces.register_table_circuit::>( @@ -176,12 +179,14 @@ fn main() { .collect::>(); let mut add_records = Vec::new(); let mut bltu_records = Vec::new(); + let mut jal_records = Vec::new(); let mut halt_records = Vec::new(); all_records.into_iter().for_each(|record| { let kind = record.insn().kind().1; match kind { ADD => add_records.push(record), BLTU => bltu_records.push(record), + JAL => jal_records.push(record), EANY => { if record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt() { halt_records.push(record); @@ -196,9 +201,10 @@ fn main() { let pi = PublicValues::new(exit_code, 0); tracing::info!( - "tracer generated {} ADD records, {} BLTU records", + "tracer generated {} ADD records, {} BLTU records, {} JAL records", add_records.len(), - bltu_records.len() + bltu_records.len(), + jal_records.len(), ); let mut zkvm_witness = ZKVMWitnesses::default(); @@ -209,6 +215,9 @@ fn main() { zkvm_witness .assign_opcode_circuit::(&zkvm_cs, &bltu_config, bltu_records) .unwrap(); + zkvm_witness + .assign_opcode_circuit::>(&zkvm_cs, &jal_config, jal_records) + .unwrap(); zkvm_witness .assign_opcode_circuit::>(&zkvm_cs, &halt_config, halt_records) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 3fc4d5aa3..5be73f578 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -7,6 +7,7 @@ pub mod config; pub mod constants; pub mod divu; pub mod ecall; +pub mod jump; pub mod logic; pub mod mulh; pub mod shift; @@ -16,9 +17,10 @@ pub mod sltu; mod b_insn; mod i_insn; mod insn_base; +mod j_insn; +mod r_insn; mod ecall_insn; -mod r_insn; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/divu.rs b/ceno_zkvm/src/instructions/riscv/divu.rs index 8b139ff51..2378a0b9b 100644 --- a/ceno_zkvm/src/instructions/riscv/divu.rs +++ b/ceno_zkvm/src/instructions/riscv/divu.rs @@ -241,7 +241,6 @@ mod test { let mut rng = rand::thread_rng(); let a: u32 = rng.gen(); let b: u32 = rng.gen_range(1..u32::MAX); - println!("random: {} / {} = {}", a, b, a / b); verify("random", a, b, a / b); } } diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs new file mode 100644 index 000000000..1ffec5b99 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -0,0 +1,69 @@ +use ceno_emul::{InsnKind, StepRecord}; +use ff_ext::ExtensionField; + +use crate::{ + chip_handler::RegisterExpr, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::ToExpr, + instructions::riscv::insn_base::{StateInOut, WriteRD}, + tables::InsnRecord, + witness::LkMultiplicity, +}; +use core::mem::MaybeUninit; + +// Opcode: 1101111 + +/// This config handles the common part of the J-type instruction (JAL): +/// - PC, cycle, fetch +/// - Register access +/// +/// It does not witness the output rd value produced by the JAL opcode, but +/// does constrain next_pc = pc + imm using the instruction table lookup +#[derive(Debug)] +pub struct JInstructionConfig { + pub vm_state: StateInOut, + pub rd: WriteRD, +} + +impl JInstructionConfig { + pub fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + insn_kind: InsnKind, + rd_written: RegisterExpr, + ) -> Result { + // State in and out + let vm_state = StateInOut::construct_circuit(circuit_builder, true)?; + + // Registers + let rd = WriteRD::construct_circuit(circuit_builder, rd_written, vm_state.ts)?; + + // Fetch instruction + circuit_builder.lk_fetch(&InsnRecord::new( + vm_state.pc.expr(), + (insn_kind.codes().opcode as usize).into(), + rd.id.expr(), + 0.into(), + 0.into(), + 0.into(), + vm_state.next_pc.unwrap().expr() - vm_state.pc.expr(), + ))?; + + Ok(JInstructionConfig { vm_state, rd }) + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit<::BaseField>], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + self.vm_state.assign_instance(instance, step)?; + self.rd.assign_instance(instance, lk_multiplicity, step)?; + + // Fetch the instruction. + lk_multiplicity.fetch(step.pc().before.0); + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs new file mode 100644 index 000000000..21eb67187 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -0,0 +1,14 @@ +mod jal; + +use super::RIVInstruction; +use ceno_emul::InsnKind; +use jal::JalCircuit; + +#[cfg(test)] +mod test; + +pub struct JalOp; +impl RIVInstruction for JalOp { + const INST_KIND: InsnKind = InsnKind::JAL; +} +pub type JalInstruction = JalCircuit; diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal.rs b/ceno_zkvm/src/instructions/riscv/jump/jal.rs new file mode 100644 index 000000000..2c07206c1 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/jump/jal.rs @@ -0,0 +1,78 @@ +use std::{marker::PhantomData, mem::MaybeUninit}; + +use ff_ext::ExtensionField; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::ToExpr, + instructions::{ + riscv::{constants::UInt, j_insn::JInstructionConfig, RIVInstruction}, + Instruction, + }, + witness::LkMultiplicity, + Value, +}; +use ceno_emul::PC_STEP_SIZE; + +pub struct JalConfig { + pub j_insn: JInstructionConfig, + pub rd_written: UInt, +} + +pub struct JalCircuit(PhantomData<(E, I)>); + +/// JAL instruction circuit +/// +/// Note: does not validate that next_pc is aligned by 4-byte increments, which +/// should be verified by lookup argument of the next execution step against +/// the program table +/// +/// Assumption: values for valid initial program counter must lie between +/// 2^20 and 2^32 - 2^20 + 2 inclusive, probably enforced by the static +/// program lookup table. If this assumption does not hold, then resulting +/// value for next_pc may not correctly wrap mod 2^32 because of the use +/// of native WitIn values for address space arithmetic. +impl Instruction for JalCircuit { + type InstructionConfig = JalConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + ) -> Result, ZKVMError> { + let rd_written = UInt::new(|| "rd_written", circuit_builder)?; + + let j_insn = JInstructionConfig::construct_circuit( + circuit_builder, + I::INST_KIND, + rd_written.register_expr(), + )?; + + circuit_builder.require_equal( + || "jal rd_written", + rd_written.value(), + j_insn.vm_state.pc.expr() + PC_STEP_SIZE.into(), + )?; + + Ok(JalConfig { j_insn, rd_written }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .j_insn + .assign_instance(instance, lk_multiplicity, step)?; + + let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); + config.rd_written.assign_value(instance, rd_written); + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs new file mode 100644 index 000000000..dbc93cccb --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -0,0 +1,54 @@ +use ceno_emul::{ByteAddr, Change, StepRecord, PC_STEP_SIZE}; +use goldilocks::GoldilocksExt2; +use itertools::Itertools; +use multilinear_extensions::mle::IntoMLEs; + +use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::Instruction, + scheme::mock_prover::{MockProver, MOCK_PC_JAL, MOCK_PROGRAM}, +}; + +use super::JalInstruction; + +#[test] +fn test_opcode_jal() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "jal", + |cb| { + let config = JalInstruction::::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let pc_offset: i32 = -4i32; + let new_pc: ByteAddr = ByteAddr(MOCK_PC_JAL.0.wrapping_add_signed(pc_offset)); + let (raw_witin, _lkm) = JalInstruction::::assign_instances( + &config, + cb.cs.num_witin as usize, + vec![StepRecord::new_j_instruction( + 4, + Change::new(MOCK_PC_JAL, new_pc), + MOCK_PROGRAM[21], + Change::new(0, (MOCK_PC_JAL + PC_STEP_SIZE).into()), + 0, + )], + ) + .unwrap(); + + MockProver::assert_satisfied( + &mut cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + None, + ); +} diff --git a/ceno_zkvm/src/instructions/riscv/sltu.rs b/ceno_zkvm/src/instructions/riscv/sltu.rs index 40113d7f2..7919733c1 100644 --- a/ceno_zkvm/src/instructions/riscv/sltu.rs +++ b/ceno_zkvm/src/instructions/riscv/sltu.rs @@ -185,7 +185,6 @@ mod test { let mut rng = rand::thread_rng(); let a: u32 = rng.gen(); let b: u32 = rng.gen(); - println!("random: {}, {}", a, b); verify("random 1", a, b, (a < b) as u32); verify("random 2", b, a, !(a < b) as u32); } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 4714563ea..200a0a07b 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -82,6 +82,8 @@ pub const MOCK_PROGRAM: &[u32] = &[ 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b001 << 12 | MOCK_RD << 7 | 0x33, // srl x4, x2, x3 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b101 << 12 | MOCK_RD << 7 | 0x33, + // jal x4, 0xffffe + 0b_1_1111111110_1_11111111 << 12 | MOCK_RD << 7 | 0x6f, ]; // Addresses of particular instructions in the mock program. pub const MOCK_PC_ADD: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start()); @@ -105,6 +107,7 @@ pub const MOCK_PC_BGE: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 68); pub const MOCK_PC_MULHU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 72); pub const MOCK_PC_SLL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 76); pub const MOCK_PC_SRL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 80); +pub const MOCK_PC_JAL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 84); #[allow(clippy::enum_variant_names)] #[derive(Debug, PartialEq, Clone)]