Skip to content

Commit

Permalink
feat/x0: Support x0 by redirecting writes to RD_NULL (#503)
Browse files Browse the repository at this point in the history
_Issue #245_

---------

Co-authored-by: Aurélien Nicolas <[email protected]>
  • Loading branch information
naure and Aurélien Nicolas authored Oct 31, 2024
1 parent 383950f commit 69d6be7
Show file tree
Hide file tree
Showing 16 changed files with 52 additions and 41 deletions.
8 changes: 2 additions & 6 deletions ceno_emul/src/platform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,6 @@ impl Platform {
(vma >> 8) as RegIdx
}

/// Virtual address of the program counter.
pub const fn pc_vma(&self) -> Addr {
self.register_vma(32)
}

// Startup.

pub const fn pc_base(&self) -> Addr {
Expand Down Expand Up @@ -128,6 +123,7 @@ impl Platform {
#[cfg(test)]
mod tests {
use super::*;
use crate::VMState;

#[test]
fn test_no_overlap() {
Expand All @@ -139,7 +135,7 @@ mod tests {
assert!(!p.is_ram(p.rom_start()));
assert!(!p.is_ram(p.rom_end()));
// Registers do not overlap with ROM or RAM.
for reg in [p.pc_vma(), p.register_vma(0), p.register_vma(31)] {
for reg in [p.register_vma(0), p.register_vma(VMState::REG_COUNT - 1)] {
assert!(!p.is_rom(reg));
assert!(!p.is_ram(reg));
}
Expand Down
21 changes: 10 additions & 11 deletions ceno_emul/src/rv32im.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ pub struct InsnCodes {
}

impl DecodedInstruction {
/// A virtual register which absorbs the writes to x0.
pub const RD_NULL: u32 = 32;

pub fn new(insn: u32) -> Self {
Self {
insn,
Expand All @@ -224,16 +227,12 @@ impl DecodedInstruction {
self.opcode
}

/// Get the rd field, regardless of the instruction format.
pub fn rd(&self) -> u32 {
self.rd
}

/// Get the register destination, or zero if the instruction does not write to a register.
pub fn rd_or_zero(&self) -> u32 {
/// The internal register destination. It is either the regular rd, or an internal RD_NULL if
/// the instruction does not write to a register or writes to x0.
pub fn rd_internal(&self) -> u32 {
match self.codes().format {
R | I | U | J => self.rd,
_ => 0,
R | I | U | J if self.rd != 0 => self.rd,
_ => Self::RD_NULL,
}
}

Expand Down Expand Up @@ -684,7 +683,7 @@ impl Emulator {
if !new_pc.is_aligned() {
return ctx.trap(TrapCause::InstructionAddressMisaligned);
}
ctx.store_register(decoded.rd as usize, out)?;
ctx.store_register(decoded.rd_internal() as usize, out)?;
ctx.set_pc(new_pc);
Ok(true)
}
Expand Down Expand Up @@ -771,7 +770,7 @@ impl Emulator {
}
_ => unreachable!(),
};
ctx.store_register(decoded.rd as usize, out)?;
ctx.store_register(decoded.rd_internal() as usize, out)?;
ctx.set_pc(ctx.get_pc() + WORD_SIZE);
Ok(true)
}
Expand Down
4 changes: 3 additions & 1 deletion ceno_emul/src/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ impl StepRecord {
previous_cycle,
}),
rd: rd.map(|rd| WriteOp {
addr: CENO_PLATFORM.register_vma(insn.rd() as RegIdx).into(),
addr: CENO_PLATFORM
.register_vma(insn.rd_internal() as RegIdx)
.into(),
value: rd,
previous_cycle,
}),
Expand Down
8 changes: 6 additions & 2 deletions ceno_emul/src/vm_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ pub struct VMState {
pc: Word,
/// Map a word-address (addr/4) to a word.
memory: HashMap<WordAddr, Word>,
registers: [Word; 32],
registers: [Word; VMState::REG_COUNT],
// Termination.
halted: bool,
tracer: Tracer,
}

impl VMState {
/// The number of registers that the VM uses.
/// 32 architectural registers + 1 register RD_NULL for dark writes to x0.
pub const REG_COUNT: usize = 32 + 1;

pub fn new(platform: Platform, program: Program) -> Self {
let pc = program.entry;
let program = Arc::new(program);
Expand All @@ -34,7 +38,7 @@ impl VMState {
platform,
program: program.clone(),
memory: HashMap::new(),
registers: [0; 32],
registers: [0; VMState::REG_COUNT],
halted: false,
tracer: Tracer::new(),
};
Expand Down
15 changes: 10 additions & 5 deletions ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use tracing_flame::FlameLayer;
use tracing_subscriber::{EnvFilter, Registry, fmt, layer::SubscriberExt};
use transcript::Transcript;

const PROGRAM_SIZE: usize = 512;
const PROGRAM_SIZE: usize = 16;
// For now, we assume registers
// - x0 is not touched,
// - x1 is initialized to 1,
Expand Down Expand Up @@ -185,10 +185,15 @@ fn main() {
.iter()
.map(|rec| {
let index = rec.addr as usize;
let vma: WordAddr = CENO_PLATFORM.register_vma(index).into();
MemFinalRecord {
value: vm.peek_register(index),
cycle: *final_access.get(&vma).unwrap_or(&0),
if index < VMState::REG_COUNT {
let vma: WordAddr = CENO_PLATFORM.register_vma(index).into();
MemFinalRecord {
value: vm.peek_register(index),
cycle: *final_access.get(&vma).unwrap_or(&0),
}
} else {
// The table is padded beyond the number of registers.
MemFinalRecord { value: 0, cycle: 0 }
}
})
.collect_vec();
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/b_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl<E: ExtensionField> BInstructionConfig<E> {
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
insn_kind.codes().opcode.into(),
0.into(),
None,
insn_kind.codes().func3.into(),
rs1.id.expr(),
rs2.id.expr(),
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/ecall_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl EcallInstructionConfig {
cb.lk_fetch(&InsnRecord::new(
pc.expr(),
(EANY.codes().opcode as usize).into(),
0.into(),
None,
(EANY.codes().func3 as usize).into(),
0.into(),
0.into(),
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/i_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl<E: ExtensionField> IInstructionConfig<E> {
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
insn_kind.codes().opcode.into(),
rd.id.expr(),
Some(rd.id.expr()),
insn_kind.codes().func3.into(),
rs1.id.expr(),
0.into(),
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/im_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl<E: ExtensionField> IMInstructionConfig<E> {
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
(insn_kind.codes().opcode as usize).into(),
rd.id.expr(),
Some(rd.id.expr()),
(insn_kind.codes().func3 as usize).into(),
rs1.id.expr(),
0.into(),
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/insn_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ impl<E: ExtensionField> WriteRD<E> {
lk_multiplicity: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
set_val!(instance, self.id, step.insn().rd() as u64);
set_val!(instance, self.id, step.insn().rd_internal() as u64);
set_val!(instance, self.prev_ts, step.rd().unwrap().previous_cycle);

// Register state
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/j_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl<E: ExtensionField> JInstructionConfig<E> {
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
(insn_kind.codes().opcode as usize).into(),
rd.id.expr(),
Some(rd.id.expr()),
0.into(),
0.into(),
0.into(),
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/r_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl<E: ExtensionField> RInstructionConfig<E> {
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
insn_kind.codes().opcode.into(),
rd.id.expr(),
Some(rd.id.expr()),
insn_kind.codes().func3.into(),
rs1.id.expr(),
rs2.id.expr(),
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/s_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl<E: ExtensionField> SInstructionConfig<E> {
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
(insn_kind.codes().opcode as usize).into(),
0.into(),
None,
(insn_kind.codes().func3 as usize).into(),
rs1.id.expr(),
rs2.id.expr(),
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/u_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl<E: ExtensionField> UInstructionConfig<E> {
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
(insn_kind.codes().opcode as usize).into(),
rd.id.expr(),
Some(rd.id.expr()),
0.into(),
0.into(),
0.into(),
Expand Down
15 changes: 10 additions & 5 deletions ceno_zkvm/src/tables/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ macro_rules! declare_program {
pub struct InsnRecord<T>([T; 7]);

impl<T> InsnRecord<T> {
pub fn new(pc: T, opcode: T, rd: T, funct3: T, rs1: T, rs2: T, imm_or_funct7: T) -> Self {
pub fn new(pc: T, opcode: T, rd: Option<T>, funct3: T, rs1: T, rs2: T, imm_or_funct7: T) -> Self
where
T: From<u32>,
{
let rd = rd.unwrap_or_else(|| T::from(DecodedInstruction::RD_NULL));
InsnRecord([pc, opcode, rd, funct3, rs1, rs2, imm_or_funct7])
}

Expand All @@ -50,7 +54,7 @@ impl<T> InsnRecord<T> {
&self.0[1]
}

pub fn rd_or_zero(&self) -> &T {
pub fn rd_or_null(&self) -> &T {
&self.0[2]
}

Expand Down Expand Up @@ -80,15 +84,15 @@ impl<T> InsnRecord<T> {

impl InsnRecord<u32> {
fn from_decoded(pc: u32, insn: &DecodedInstruction) -> Self {
InsnRecord::new(
InsnRecord([
pc,
insn.opcode(),
insn.rd_or_zero(),
insn.rd_internal(),
insn.funct3_or_zero(),
insn.rs1_or_zero(),
insn.rs2_or_zero(),
insn.imm_or_funct7(),
)
])
}

/// Interpret the immediate or funct7 as unsigned or signed depending on the instruction.
Expand Down Expand Up @@ -184,6 +188,7 @@ impl<E: ExtensionField, const PROGRAM_SIZE: usize> TableCircuit<E>
);
});

Self::padding_zero(&mut fixed, num_fixed).expect("padding error");
fixed
}

Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/tables/ram.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use ceno_emul::{Addr, CENO_PLATFORM, WORD_SIZE, Word};
use ceno_emul::{Addr, CENO_PLATFORM, VMState, WORD_SIZE, Word};
use ram_circuit::RamTableCircuit;

use crate::{instructions::riscv::constants::UINT_LIMBS, structs::RAMType};
Expand Down Expand Up @@ -34,7 +34,7 @@ impl RamTable for RegTable {
const V_LIMBS: usize = UINT_LIMBS; // See `RegisterExpr`.

fn len() -> usize {
32 // register size 32
VMState::REG_COUNT.next_power_of_two()
}

fn addr(entry_index: usize) -> Addr {
Expand Down

0 comments on commit 69d6be7

Please sign in to comment.