Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement (non-)volatile table modular circuits and e2e public io #457

Merged
merged 24 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions ceno_emul/src/platform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ impl Platform {
(self.rom_start()..=self.rom_end()).contains(&addr)
}

// TODO figure out proper region for program_data
pub const fn program_data_start(&self) -> Addr {
0x3000_0000
matthiasgoergens marked this conversation as resolved.
Show resolved Hide resolved
}

pub const fn program_data_end(&self) -> Addr {
0x3000_1000 - 1
matthiasgoergens marked this conversation as resolved.
Show resolved Hide resolved
}

// TODO figure out a proper region for public io
matthiasgoergens marked this conversation as resolved.
Show resolved Hide resolved
pub const fn public_io_start(&self) -> Addr {
0x3000_1000
}

pub const fn public_io_end(&self) -> Addr {
naure marked this conversation as resolved.
Show resolved Hide resolved
0x3000_2000 - 1
}

pub const fn ram_start(&self) -> Addr {
let ram_start = 0x8000_0000;
if cfg!(feature = "forbid_overflow") {
Expand Down
65 changes: 51 additions & 14 deletions ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use ceno_zkvm::{
instructions::riscv::{Rv32imConfig, constants::EXIT_PC},
scheme::prover::ZKVMProver,
state::GlobalState,
tables::{MemFinalRecord, ProgramTableCircuit, initial_memory, initial_registers},
tables::{
DynVolatileRamTable, MemFinalRecord, MemTable, ProgramTableCircuit, init_program_data,
initial_registers,
},
};
use clap::Parser;
use const_env::from_env;
Expand Down Expand Up @@ -45,11 +48,11 @@ const PROGRAM_CODE: [u32; PROGRAM_SIZE] = {
let mut program: [u32; PROGRAM_SIZE] = [ECALL_HALT; PROGRAM_SIZE];
declare_program!(
program,
// Load parameters from initial RAM.
encode_rv32(LUI, 0, 0, 10, CENO_PLATFORM.ram_start()), // lui x10, program_data
encode_rv32(LW, 10, 0, 1, 0), // lw x1, 0(x10)
encode_rv32(LW, 10, 0, 2, 4), // lw x2, 4(x10)
encode_rv32(LW, 10, 0, 3, 8), // lw x3, 8(x10)
// TODO load data from public io
encode_rv32(LUI, 0, 0, 10, CENO_PLATFORM.program_data_start()), // lui x10, program_data
encode_rv32(LW, 10, 0, 1, 0), // lw x1, 0(x10)
encode_rv32(LW, 10, 0, 2, 4), // lw x2, 4(x10)
encode_rv32(LW, 10, 0, 3, 8), // lw x3, 8(x10)
// Main loop.
encode_rv32(ADD, 1, 4, 4, 0), // add x4, x1, x4
encode_rv32(ADD, 2, 3, 3, 0), // add x3, x2, x3
Expand Down Expand Up @@ -126,6 +129,7 @@ fn main() {
let step_loop = 1 << (instance_num_vars - 1); // 1 step in loop contribute to 2 add instance

// init vm.x1 = 1, vm.x2 = -1, vm.x3 = step_loop
// TODO replace with public io
let program_data: &[u32] = &[1, u32::MAX, step_loop];

let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
Expand All @@ -137,9 +141,14 @@ fn main() {
);

let reg_init = initial_registers();
let mem_init = initial_memory(program_data);
let program_data_init = init_program_data(program_data);

config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces, &reg_init, &mem_init);
config.generate_fixed_traces(
&zkvm_cs,
&mut zkvm_fixed_traces,
&reg_init,
&program_data_init,
);

let pk = zkvm_cs
.clone()
Expand All @@ -154,12 +163,15 @@ fn main() {
let mut vm = VMState::new(CENO_PLATFORM);
let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr();

// init program
for (i, inst) in PROGRAM_CODE.iter().enumerate() {
vm.init_memory(pc_start + i, *inst);
}
for record in &mem_init {
// init program data
for record in &program_data_init {
vm.init_memory(record.addr.into(), record.value);
}
// TODO init public i/o mem

let all_records = vm
.iter_until_halt()
Expand All @@ -185,6 +197,8 @@ fn main() {
Tracer::SUBCYCLES_PER_INSN as u32,
EXIT_PC as u32,
end_cycle,
// TODO use correct public_io
vec![1, 2],
);

let mut zkvm_witness = ZKVMWitnesses::default();
Expand All @@ -207,11 +221,28 @@ fn main() {
})
.collect_vec();

// Find the final memory values and cycles.
let mem_final = mem_init
// Find the final program_data cycles.
let program_data_final = program_data_init
.iter()
.map(|rec| {
let vma: WordAddr = rec.addr.into();
MemFinalRecord {
value: rec.value,
cycle: *final_access.get(&vma).unwrap_or(&0),
}
})
.collect_vec();

// TODO Find the final public io cycles.

// Find the final mem data and cycles.
// TODO retrieve max address access property and avoid scan whole address space
// as we already support non-uniform proving of memory
let mem_start = MemTable::OFFSET_ADDR;
let mem_end = MemTable::END_ADDR;
let mem_final = (mem_start..mem_end)
.map(|addr| {
let vma = ByteAddr::from(addr).waddr();
MemFinalRecord {
value: vm.peek_memory(vma),
cycle: *final_access.get(&vma).unwrap_or(&0),
Expand All @@ -221,7 +252,13 @@ fn main() {

// assign table circuits
config
.assign_table_circuit(&zkvm_cs, &mut zkvm_witness, &reg_final, &mem_final)
.assign_table_circuit(
&zkvm_cs,
&mut zkvm_witness,
&reg_final,
&mem_final,
&program_data_final,
)
.unwrap();

// assign program circuit
Expand Down Expand Up @@ -255,8 +292,8 @@ fn main() {

let transcript = Transcript::new(b"riscv");
// change public input maliciously should cause verifier to reject proof
zkvm_proof.pv[0] = <GoldilocksExt2 as ff_ext::ExtensionField>::BaseField::ONE;
zkvm_proof.pv[1] = <GoldilocksExt2 as ff_ext::ExtensionField>::BaseField::ONE;
zkvm_proof.raw_pi[0] = vec![<GoldilocksExt2 as ff_ext::ExtensionField>::BaseField::ONE];
zkvm_proof.raw_pi[1] = vec![<GoldilocksExt2 as ff_ext::ExtensionField>::BaseField::ONE];

// capture panic message, if have
let default_hook = panic::take_hook();
Expand Down
19 changes: 12 additions & 7 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use ff_ext::ExtensionField;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
circuit_builder::{CircuitBuilder, ConstraintSystem, SetTableSpec},
error::ZKVMError,
expression::{Expression, Fixed, Instance, ToExpr, WitIn},
instructions::riscv::constants::{
END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX,
END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX,
UINT_LIMBS,
},
structs::ROMType,
tables::InsnRecord,
Expand All @@ -32,7 +33,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
self.cs.create_fixed(name_fn)
}

pub fn query_exit_code(&mut self) -> Result<[Instance; 2], ZKVMError> {
pub fn query_exit_code(&mut self) -> Result<[Instance; UINT_LIMBS], ZKVMError> {
Ok([
self.cs.query_instance(|| "exit_code_low", EXIT_CODE_IDX)?,
self.cs
Expand All @@ -56,6 +57,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
self.cs.query_instance(|| "end_cycle", END_CYCLE_IDX)
}

pub fn query_public_io(&mut self) -> Result<Instance, ZKVMError> {
self.cs.query_instance(|| "public_io", PUBLIC_IO_IDX)
}

pub fn lk_record<NR, N>(
&mut self,
name_fn: N,
Expand Down Expand Up @@ -87,27 +92,27 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
pub fn r_table_record<NR, N>(
&mut self,
name_fn: N,
table_len: usize,
table_spec: SetTableSpec,
rlc_record: Expression<E>,
) -> Result<(), ZKVMError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
self.cs.r_table_record(name_fn, table_len, rlc_record)
self.cs.r_table_record(name_fn, table_spec, rlc_record)
}

pub fn w_table_record<NR, N>(
&mut self,
name_fn: N,
table_len: usize,
table_spec: SetTableSpec,
rlc_record: Expression<E>,
) -> Result<(), ZKVMError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
self.cs.w_table_record(name_fn, table_len, rlc_record)
self.cs.w_table_record(name_fn, table_spec, rlc_record)
}

/// Fetch an instruction at a given PC from the Program table.
Expand Down
27 changes: 22 additions & 5 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use ceno_emul::Addr;
use itertools::Itertools;
use std::{collections::HashMap, marker::PhantomData};

Expand Down Expand Up @@ -72,10 +73,26 @@ pub struct LogupTableExpression<E: ExtensionField> {
pub table_len: usize,
}

#[derive(Clone, Debug)]
pub enum SetTableAddrType {
FixedAddr,
DynamicAddr,
}

#[derive(Clone, Debug)]
pub struct SetTableSpec {
pub addr_type: SetTableAddrType,
pub offset: Addr,
matthiasgoergens marked this conversation as resolved.
Show resolved Hide resolved
pub len: usize,
pub rw: bool,
}

#[derive(Clone, Debug)]
pub struct SetTableExpression<E: ExtensionField> {
pub values: Expression<E>,
pub table_len: usize,

// TODO diffentiate enum/struct, for which option is more friendly to be processed by ConstrainSystem + recursive verifier
pub table_spec: SetTableSpec,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -307,7 +324,7 @@ impl<E: ExtensionField> ConstraintSystem<E> {
pub fn r_table_record<NR, N>(
&mut self,
name_fn: N,
table_len: usize,
table_spec: SetTableSpec,
rlc_record: Expression<E>,
) -> Result<(), ZKVMError>
where
Expand All @@ -322,7 +339,7 @@ impl<E: ExtensionField> ConstraintSystem<E> {
);
self.r_table_expressions.push(SetTableExpression {
values: rlc_record,
table_len,
table_spec,
});
let path = self.ns.compute_path(name_fn().into());
self.r_table_expressions_namespace_map.push(path);
Expand All @@ -333,7 +350,7 @@ impl<E: ExtensionField> ConstraintSystem<E> {
pub fn w_table_record<NR, N>(
&mut self,
name_fn: N,
table_len: usize,
table_spec: SetTableSpec,
rlc_record: Expression<E>,
) -> Result<(), ZKVMError>
where
Expand All @@ -348,7 +365,7 @@ impl<E: ExtensionField> ConstraintSystem<E> {
);
self.w_table_expressions.push(SetTableExpression {
values: rlc_record,
table_len,
table_spec,
});
let path = self.ns.compute_path(name_fn().into());
self.w_table_expressions_namespace_map.push(path);
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/src/instructions/riscv/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub const INIT_PC_IDX: usize = 2;
pub const INIT_CYCLE_IDX: usize = 3;
pub const END_PC_IDX: usize = 4;
pub const END_CYCLE_IDX: usize = 5;
pub const PUBLIC_IO_IDX: usize = 6;
matthiasgoergens marked this conversation as resolved.
Show resolved Hide resolved

pub const LIMB_BITS: usize = 16;
pub const LIMB_MASK: u32 = 0xFFFF;
Expand Down
30 changes: 23 additions & 7 deletions ceno_zkvm/src/instructions/riscv/rv32im.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use crate::{
instructions::Instruction,
structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
tables::{
AndTableCircuit, LtuTableCircuit, MemFinalRecord, MemInitRecord, MemTableCircuit,
RegTableCircuit, TableCircuit, U16TableCircuit,
AndTableCircuit, LtuTableCircuit, MemCircuit, MemFinalRecord, MemInitRecord,
ProgramDataCircuit, RegTableCircuit, TableCircuit, U16TableCircuit,
},
};
use ceno_emul::{CENO_PLATFORM, InsnKind, StepRecord};
Expand Down Expand Up @@ -34,7 +34,8 @@ pub struct Rv32imConfig<E: ExtensionField> {

// RW tables.
pub reg_config: <RegTableCircuit<E> as TableCircuit<E>>::TableConfig,
pub mem_config: <MemTableCircuit<E> as TableCircuit<E>>::TableConfig,
pub mem_config: <MemCircuit<E> as TableCircuit<E>>::TableConfig,
pub program_data_config: <ProgramDataCircuit<E> as TableCircuit<E>>::TableConfig,
}

impl<E: ExtensionField> Rv32imConfig<E> {
Expand All @@ -54,7 +55,8 @@ impl<E: ExtensionField> Rv32imConfig<E> {

// RW tables
let reg_config = cs.register_table_circuit::<RegTableCircuit<E>>();
let mem_config = cs.register_table_circuit::<MemTableCircuit<E>>();
let mem_config = cs.register_table_circuit::<MemCircuit<E>>();
let program_data_config = cs.register_table_circuit::<ProgramDataCircuit<E>>();

Self {
add_config,
Expand All @@ -69,6 +71,7 @@ impl<E: ExtensionField> Rv32imConfig<E> {

reg_config,
mem_config,
program_data_config,
}
}

Expand All @@ -77,7 +80,7 @@ impl<E: ExtensionField> Rv32imConfig<E> {
cs: &ZKVMConstraintSystem<E>,
fixed: &mut ZKVMFixedTraces<E>,
reg_init: &[MemInitRecord],
mem_init: &[MemInitRecord],
program_data_init: &[MemInitRecord],
) {
fixed.register_opcode_circuit::<AddInstruction<E>>(cs);
fixed.register_opcode_circuit::<BltuInstruction>(cs);
Expand All @@ -91,7 +94,11 @@ impl<E: ExtensionField> Rv32imConfig<E> {
fixed.register_table_circuit::<LtuTableCircuit<E>>(cs, self.ltu_config.clone(), &());

fixed.register_table_circuit::<RegTableCircuit<E>>(cs, self.reg_config.clone(), reg_init);
fixed.register_table_circuit::<MemTableCircuit<E>>(cs, self.mem_config.clone(), mem_init);
fixed.register_table_circuit::<ProgramDataCircuit<E>>(
cs,
self.program_data_config.clone(),
program_data_init,
);
}

pub fn assign_opcode_circuit(
Expand Down Expand Up @@ -145,6 +152,7 @@ impl<E: ExtensionField> Rv32imConfig<E> {
witness: &mut ZKVMWitnesses<E>,
reg_final: &[MemFinalRecord],
mem_final: &[MemFinalRecord],
program_data_final: &[MemFinalRecord],
) -> Result<(), ZKVMError> {
witness.assign_table_circuit::<U16TableCircuit<E>>(cs, &self.u16_range_config, &())?;
witness.assign_table_circuit::<AndTableCircuit<E>>(cs, &self.and_config, &())?;
Expand All @@ -156,7 +164,15 @@ impl<E: ExtensionField> Rv32imConfig<E> {
.unwrap();
// assign memory finalization.
witness
.assign_table_circuit::<MemTableCircuit<E>>(cs, &self.mem_config, mem_final)
.assign_table_circuit::<MemCircuit<E>>(cs, &self.mem_config, mem_final)
.unwrap();
// assign program_data finalization.
witness
.assign_table_circuit::<ProgramDataCircuit<E>>(
cs,
&self.program_data_config,
program_data_final,
)
.unwrap();
Ok(())
}
Expand Down
Loading
Loading