From 6f3bfa04cccbd6708baab2685744d78326b80a58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Nicolas?= Date: Wed, 11 Sep 2024 18:06:59 +0200 Subject: [PATCH] hide-fast-decode: rename and privatize specialized instruction form --- ceno_emul/src/rv32im.rs | 22 +++++++++++----------- ceno_emul/src/tracer.rs | 14 +++----------- ceno_emul/src/vm_state.rs | 8 ++------ ceno_emul/tests/test_vm_trace.rs | 5 +---- ceno_zkvm/examples/riscv_add.rs | 2 +- 5 files changed, 18 insertions(+), 33 deletions(-) diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 64e758f55..7dbba19b7 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -30,10 +30,10 @@ pub trait EmuContext { fn trap(&self, cause: TrapCause) -> Result; // Callback when instructions are decoded - fn on_insn_decoded(&mut self, kind: &Instruction, decoded: &DecodedInstruction); + fn on_insn_decoded(&mut self, _decoded: &DecodedInstruction) {} // Callback when instructions end normally - fn on_normal_end(&mut self, insn: &Instruction, decoded: &DecodedInstruction); + fn on_normal_end(&mut self, _decoded: &DecodedInstruction) {} // Get the program counter fn get_pc(&self) -> ByteAddr; @@ -175,7 +175,7 @@ pub enum InsnKind { } #[derive(Clone, Copy, Debug)] -pub struct Instruction { +struct FastDecodeEntry { pub kind: InsnKind, category: InsnCategory, pub opcode: u32, @@ -183,7 +183,7 @@ pub struct Instruction { pub func7: u32, } -impl Default for Instruction { +impl Default for FastDecodeEntry { fn default() -> Self { insn(InsnKind::INVALID, InsnCategory::Invalid, 0x00, 0x0, 0x00) } @@ -275,8 +275,8 @@ const fn insn( opcode: u32, func3: i32, func7: i32, -) -> Instruction { - Instruction { +) -> FastDecodeEntry { + FastDecodeEntry { kind, category, opcode, @@ -285,7 +285,7 @@ const fn insn( } } -type InstructionTable = [Instruction; 48]; +type InstructionTable = [FastDecodeEntry; 48]; type FastInstructionTable = [u8; 1 << 10]; const RV32IM_ISA: InstructionTable = [ @@ -379,7 +379,7 @@ impl FastDecodeTable { ((op_high << 5) | (func72bits << 3) | func3) as usize } - fn add_insn(table: &mut FastInstructionTable, insn: &Instruction, isa_idx: usize) { + fn add_insn(table: &mut FastInstructionTable, insn: &FastDecodeEntry, isa_idx: usize) { let op_high = insn.opcode >> 2; if (insn.func3 as i32) < 0 { for f3 in 0..8 { @@ -398,7 +398,7 @@ impl FastDecodeTable { } } - fn lookup(&self, decoded: &DecodedInstruction) -> Instruction { + fn lookup(&self, decoded: &DecodedInstruction) -> FastDecodeEntry { let isa_idx = self.table[Self::map10(decoded.opcode, decoded.func3, decoded.func7)]; RV32IM_ISA[isa_idx as usize] } @@ -434,7 +434,7 @@ impl Emulator { let decoded = DecodedInstruction::new(word); let insn = self.table.lookup(&decoded); - ctx.on_insn_decoded(&insn, &decoded); + ctx.on_insn_decoded(&decoded); if match insn.category { InsnCategory::Compute => self.step_compute(ctx, insn.kind, &decoded)?, @@ -443,7 +443,7 @@ impl Emulator { InsnCategory::System => self.step_system(ctx, insn.kind, &decoded)?, InsnCategory::Invalid => ctx.trap(TrapCause::IllegalInstruction(word))?, } { - ctx.on_normal_end(&insn, &decoded); + ctx.on_normal_end(&decoded); }; Ok(()) diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 625b86ef1..d419a2296 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -5,7 +5,6 @@ use crate::{ rv32im::DecodedInstruction, CENO_PLATFORM, }; -use crate::rv32im::Instruction; /// An instruction and its context in an execution trace. That is concrete values of registers and memory. /// @@ -23,7 +22,6 @@ pub struct StepRecord { pub cycle: Cycle, pub pc: Change, pub insn_code: Word, - pub insn: Instruction, pub rs1: Option, pub rs2: Option, @@ -63,18 +61,16 @@ impl StepRecord { self.pc } + /// The instruction as a raw code. pub fn insn_code(&self) -> Word { self.insn_code } - pub fn insn_decoded(&self) -> DecodedInstruction { + /// The instruction as a decoded structure. + pub fn insn(&self) -> DecodedInstruction { DecodedInstruction::new(self.insn_code) } - pub fn insn(&self) -> Instruction { - self.insn - } - pub fn rs1(&self) -> Option { self.rs1.clone() } @@ -147,10 +143,6 @@ impl Tracer { self.record.insn_code = value; } - pub fn store_insn(&mut self, insn: Instruction) { - self.record.insn = insn; - } - pub fn load_register(&mut self, idx: RegIdx, value: Word) { let addr = CENO_PLATFORM.register_vma(idx).into(); diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index ffd449535..3d8c70a03 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -4,7 +4,7 @@ use super::rv32im::EmuContext; use crate::{ addr::{ByteAddr, RegIdx, Word, WordAddr}, platform::Platform, - rv32im::{DecodedInstruction, Emulator, Instruction, TrapCause}, + rv32im::{DecodedInstruction, Emulator, TrapCause}, tracer::{Change, StepRecord, Tracer}, Program, }; @@ -113,11 +113,7 @@ impl EmuContext for VMState { Err(anyhow!("Trap {:?}", cause)) // Crash. } - fn on_insn_decoded(&mut self, insn: &Instruction, _decoded: &DecodedInstruction) { - self.tracer.store_insn(*insn); - } - - fn on_normal_end(&mut self, _kind: &Instruction, _decoded: &DecodedInstruction) { + fn on_normal_end(&mut self, _decoded: &DecodedInstruction) { self.tracer.store_pc(ByteAddr(self.pc)); } diff --git a/ceno_emul/tests/test_vm_trace.rs b/ceno_emul/tests/test_vm_trace.rs index 8f8051230..b3394b53a 100644 --- a/ceno_emul/tests/test_vm_trace.rs +++ b/ceno_emul/tests/test_vm_trace.rs @@ -21,10 +21,7 @@ fn test_vm_trace() -> Result<()> { assert_eq!(ctx.peek_register(2), x2); assert_eq!(ctx.peek_register(3), x3); - let ops: Vec = steps - .iter() - .map(|step| step.insn_decoded().kind().1) - .collect(); + let ops: Vec = steps.iter().map(|step| step.insn().kind().1).collect(); assert_eq!(ops, expected_ops_fibonacci_20()); assert_eq!( diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index ec8732085..1f9b905b9 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -108,7 +108,7 @@ fn main() { .collect::, _>>() .expect("vm exec failed") .into_iter() - .filter(|record| record.insn().kind == ADD) + .filter(|record| record.insn().kind().1 == ADD) .collect::>(); tracing::info!("tracer generated {} ADD records", records.len());