Skip to content

Commit

Permalink
feat/guest-example: generalize IO addresses
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurélien Nicolas committed Nov 17, 2024
1 parent 446072f commit 1be1545
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 17 deletions.
33 changes: 29 additions & 4 deletions ceno_zkvm/examples/fibonacci_elf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use ceno_zkvm::{
},
state::GlobalState,
structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit, initial_registers},
tables::{
MemFinalRecord, MemInitRecord, ProgramTableCircuit, init_public_io, initial_registers,
},
};
use clap::Parser;
use ff_ext::ff::Field;
Expand Down Expand Up @@ -105,14 +107,22 @@ fn main() {

let mut mem_init = chain!(program_addrs, stack_addrs).collect_vec();

address_padder.pad(&mut mem_init, MmuConfig::<E>::static_mem_size());
address_padder.pad_records(&mut mem_init, MmuConfig::<E>::static_mem_size());

mem_init
};

let io_addrs = init_public_io(&[]).iter().map(|v| v.addr).collect_vec();

let reg_init = initial_registers();
config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces);
mmu_config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces, &reg_init, &mem_init);
mmu_config.generate_fixed_traces(
&zkvm_cs,
&mut zkvm_fixed_traces,
&reg_init,
&mem_init,
&io_addrs,
);
dummy_config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces);

let pk = zkvm_cs
Expand Down Expand Up @@ -205,12 +215,27 @@ fn main() {
.collect_vec();
debug_memory_ranges(&vm, &mem_final);

let io_final = io_addrs
.iter()
.map(|&addr| MemFinalRecord {
addr,
value: 0,
cycle: 0, // IO was not used.
})
.collect_vec();

// assign table circuits
config
.assign_table_circuit(&zkvm_cs, &mut zkvm_witness)
.unwrap();
mmu_config
.assign_table_circuit(&zkvm_cs, &mut zkvm_witness, &reg_final, &mem_final, &[])
.assign_table_circuit(
&zkvm_cs,
&mut zkvm_witness,
&reg_final,
&mem_final,
&io_final,
)
.unwrap();
// assign program circuit
zkvm_witness
Expand Down
4 changes: 3 additions & 1 deletion ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,10 @@ fn main() {

let reg_init = initial_registers();

let io_addrs = init_public_io(&[]).iter().map(|v| v.addr).collect_vec();

config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces);
mmu_config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces, &reg_init, &[]);
mmu_config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces, &reg_init, &[], &io_addrs);

let pk = zkvm_cs
.clone()
Expand Down
39 changes: 36 additions & 3 deletions ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
error::ZKVMError,
structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
tables::{
MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, RegTableCircuit,
MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RegTableCircuit,
StaticMemCircuit, StaticMemTable, TableCircuit,
},
};
Expand Down Expand Up @@ -42,6 +42,7 @@ impl<E: ExtensionField> MmuConfig<E> {
fixed: &mut ZKVMFixedTraces<E>,
reg_init: &[MemInitRecord],
static_mem: &[MemInitRecord],
io_addrs: &[Addr],
) {
fixed.register_table_circuit::<RegTableCircuit<E>>(cs, &self.reg_config, reg_init);

Expand All @@ -51,7 +52,7 @@ impl<E: ExtensionField> MmuConfig<E> {
static_mem,
);

fixed.register_table_circuit::<PubIOCircuit<E>>(cs, &self.public_io_config, &());
fixed.register_table_circuit::<PubIOCircuit<E>>(cs, &self.public_io_config, io_addrs);
}

pub fn assign_table_circuit(
Expand Down Expand Up @@ -82,6 +83,10 @@ impl<E: ExtensionField> MmuConfig<E> {
pub fn static_mem_size() -> usize {
<StaticMemTable as NonVolatileTable>::len()
}

pub fn public_io_size() -> usize {
<PubIOTable as NonVolatileTable>::len()
}
}

pub struct AddressPadder {
Expand All @@ -99,7 +104,7 @@ impl AddressPadder {

/// Pad `records` to `new_len` with valid records.
/// No addresses will be used more than once.
pub fn pad(&mut self, records: &mut Vec<MemInitRecord>, new_len: usize) {
pub fn pad_records(&mut self, records: &mut Vec<MemInitRecord>, new_len: usize) {
let old_len = records.len();
assert!(
old_len <= new_len,
Expand All @@ -126,4 +131,32 @@ impl AddressPadder {
"not enough addresses to pad memory records from {old_len} to {new_len}"
);
}

/// Pad `addresses` to `new_len` with valid records.
/// No addresses will be used more than once.
pub fn pad_addresses(&mut self, addresses: &mut Vec<Addr>, new_len: usize) {
let old_len = addresses.len();
assert!(
old_len <= new_len,
"cannot fit {old_len} memory addresses in {new_len} space"
);

// Keep track of addresses that were explicitly used.
self.used_addresses.extend(addresses.iter());

addresses.extend(
// Search for some addresses in the given range.
(&mut self.valid_addresses)
.step_by(WORD_SIZE)
// Exclude addresses already used.
.filter(|addr| !self.used_addresses.contains(addr))
// Create the padding.
.take(new_len - old_len),
);
assert_eq!(
addresses.len(),
new_len,
"not enough addresses to pad from {old_len} to {new_len}"
);
}
}
9 changes: 7 additions & 2 deletions ceno_zkvm/src/tables/ram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,13 @@ impl NonVolatileTable for PubIOTable {
const RAM_TYPE: RAMType = RAMType::Memory;
const V_LIMBS: usize = 1; // See `MemoryExpr`.
const WRITABLE: bool = false;
const OFFSET_ADDR: Addr = CENO_PLATFORM.public_io_start();
const END_ADDR: Addr = CENO_PLATFORM.public_io_end() + 1;
const OFFSET_ADDR: Addr = CENO_PLATFORM.public_io_start(); // TODO: remove.
const END_ADDR: Addr = CENO_PLATFORM.public_io_end() + 1; // TODO: remove.

fn len() -> usize {
// TODO: take as program parameter.
1 << 2 // words
}

fn name() -> &'static str {
"PubIOTable"
Expand Down
6 changes: 3 additions & 3 deletions ceno_zkvm/src/tables/ram/ram_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl<E: ExtensionField, NVRAM: NonVolatileTable + Send + Sync + Clone> TableCirc
for PubIORamCircuit<E, NVRAM>
{
type TableConfig = PubIOTableConfig<NVRAM>;
type FixedInput = ();
type FixedInput = [Addr];
type WitnessInput = [MemFinalRecord];

fn name() -> String {
Expand All @@ -120,10 +120,10 @@ impl<E: ExtensionField, NVRAM: NonVolatileTable + Send + Sync + Clone> TableCirc
fn generate_fixed_traces(
config: &Self::TableConfig,
num_fixed: usize,
_init_v: &Self::FixedInput,
io_addrs: &[Addr],
) -> RowMajorMatrix<E::BaseField> {
// assume returned table is well-formed include padding
config.gen_init_state(num_fixed)
config.gen_init_state(num_fixed, io_addrs)
}

fn assign_instances(
Expand Down
13 changes: 9 additions & 4 deletions ceno_zkvm/src/tables/ram/ram_impl.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit};

use ceno_emul::Addr;
use ff_ext::ExtensionField;
use goldilocks::SmallField;
use itertools::Itertools;
Expand Down Expand Up @@ -235,7 +236,11 @@ impl<NVRAM: NonVolatileTable + Send + Sync + Clone> PubIOTableConfig<NVRAM> {
}

/// assign to fixed address
pub fn gen_init_state<F: SmallField>(&self, num_fixed: usize) -> RowMajorMatrix<F> {
pub fn gen_init_state<F: SmallField>(
&self,
num_fixed: usize,
io_addrs: &[Addr],
) -> RowMajorMatrix<F> {
assert!(NVRAM::len().is_power_of_two());

// for ram in memory offline check
Expand All @@ -244,10 +249,10 @@ impl<NVRAM: NonVolatileTable + Send + Sync + Clone> PubIOTableConfig<NVRAM> {

init_table
.par_iter_mut()
.enumerate()
.with_min_len(MIN_PAR_SIZE)
.for_each(|(i, row)| {
set_fixed_val!(row, self.addr, (NVRAM::addr(i) as u64).into());
.zip_eq(io_addrs.into_par_iter())
.for_each(|(row, addr)| {
set_fixed_val!(row, self.addr, (*addr as u64).into());
});
init_table
}
Expand Down

0 comments on commit 1be1545

Please sign in to comment.