diff --git a/ceno_emul/src/addr.rs b/ceno_emul/src/addr.rs index 0ce39f56b..200739ce7 100644 --- a/ceno_emul/src/addr.rs +++ b/ceno_emul/src/addr.rs @@ -17,6 +17,7 @@ use std::{fmt, ops}; pub const WORD_SIZE: usize = 4; +pub const PC_WORD_SIZE: usize = 4; pub const PC_STEP_SIZE: usize = 4; // Type aliases to clarify the code without wrapper types. @@ -26,10 +27,10 @@ pub type Addr = u32; pub type Cycle = u64; pub type RegIdx = usize; -#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ByteAddr(pub u32); -#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct WordAddr(u32); impl From for WordAddr { diff --git a/ceno_emul/src/elf.rs b/ceno_emul/src/elf.rs index f6a28f743..21d12779f 100644 --- a/ceno_emul/src/elf.rs +++ b/ceno_emul/src/elf.rs @@ -20,21 +20,47 @@ use alloc::collections::BTreeMap; use crate::addr::WORD_SIZE; use anyhow::{Context, Result, anyhow, bail}; -use elf::{ElfBytes, endian::LittleEndian, file::Class}; +use elf::{ + ElfBytes, + abi::{PF_R, PF_W, PF_X}, + endian::LittleEndian, + file::Class, +}; /// A RISC Zero program +#[derive(Clone, Debug)] pub struct Program { /// The entrypoint of the program pub entry: u32, - + /// This is the lowest address of the program's executable code + pub base_address: u32, + /// The instructions of the program + pub instructions: Vec, /// The initial memory image pub image: BTreeMap, } impl Program { + /// Create program + pub fn new( + entry: u32, + base_address: u32, + instructions: Vec, + image: BTreeMap, + ) -> Program { + Self { + entry, + base_address, + instructions, + image, + } + } /// Initialize a RISC Zero Program from an appropriate ELF file pub fn load_elf(input: &[u8], max_mem: u32) -> Result { + let mut instructions: Vec = Vec::new(); let mut image: BTreeMap = BTreeMap::new(); + let mut base_address = None; + let elf = ElfBytes::::minimal_parse(input) .map_err(|err| anyhow!("Elf parse error: {err}"))?; if elf.ehdr.class != Class::ELF32 { @@ -58,7 +84,18 @@ impl Program { if segments.len() > 256 { bail!("Too many program headers"); } - for segment in segments.iter().filter(|x| x.p_type == elf::abi::PT_LOAD) { + for (idx, segment) in segments + .iter() + .filter(|x| x.p_type == elf::abi::PT_LOAD) + .enumerate() + { + tracing::debug!( + "loadable segement {}: PF_R={}, PF_W={}, PF_X={}", + idx, + segment.p_flags & PF_R != 0, + segment.p_flags & PF_W != 0, + segment.p_flags & PF_X != 0, + ); let file_size: u32 = segment .p_filesz .try_into() @@ -77,6 +114,13 @@ impl Program { .p_vaddr .try_into() .map_err(|err| anyhow!("vaddr is larger than 32 bits. {err}"))?; + if (segment.p_flags & PF_X) != 0 { + if base_address.is_none() { + base_address = Some(vaddr); + } else { + return Err(anyhow!("only support one executable segment")); + } + } if vaddr % WORD_SIZE as u32 != 0 { bail!("vaddr {vaddr:08x} is unaligned"); } @@ -104,9 +148,25 @@ impl Program { word |= (*byte as u32) << (j * 8); } image.insert(addr, word); + if (segment.p_flags & PF_X) != 0 { + instructions.push(word); + } } } } - Ok(Program { entry, image }) + + if base_address.is_none() { + return Err(anyhow!("does not have executable segment")); + } + let base_address = base_address.unwrap(); + assert!(entry >= base_address); + assert!((entry - base_address) as usize <= instructions.len() * WORD_SIZE); + + Ok(Program { + entry, + base_address, + image, + instructions, + }) } } diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 264d810f5..163757057 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -5,19 +5,30 @@ use crate::addr::{Addr, RegIdx}; /// - the layout of virtual memory, /// - special addresses, such as the initial PC, /// - codes of environment calls. -pub struct Platform; +#[derive(Copy, Clone)] +pub struct Platform { + pub rom_start: Addr, + pub rom_end: Addr, + pub ram_start: Addr, + pub ram_end: Addr, +} -pub const CENO_PLATFORM: Platform = Platform; +pub const CENO_PLATFORM: Platform = Platform { + rom_start: 0x2000_0000, + rom_end: 0x3000_0000 - 1, + ram_start: 0x8000_0000, + ram_end: 0xFFFF_FFFF, +}; impl Platform { // Virtual memory layout. pub const fn rom_start(&self) -> Addr { - 0x2000_0000 + self.rom_start } pub const fn rom_end(&self) -> Addr { - 0x3000_0000 - 1 + self.rom_end } pub fn is_rom(&self, addr: Addr) -> bool { @@ -25,18 +36,17 @@ impl Platform { } pub const fn ram_start(&self) -> Addr { - let ram_start = 0x8000_0000; if cfg!(feature = "forbid_overflow") { // -1<<11 == 0x800 is the smallest negative 'immediate' // offset we can have in memory instructions. // So if we stay away from it, we are safe. - assert!(ram_start >= 0x800); + assert!(self.ram_start >= 0x800); } - ram_start + self.ram_start } pub const fn ram_end(&self) -> Addr { - 0xFFFF_FFFF + self.ram_end - if cfg!(feature = "forbid_overflow") { // (1<<11) - 1 == 0x7ff is the largest positive 'immediate' // offset we can have in memory instructions. @@ -69,7 +79,7 @@ impl Platform { // Startup. - pub const fn pc_start(&self) -> Addr { + pub const fn pc_base(&self) -> Addr { self.rom_start() } @@ -122,7 +132,7 @@ mod tests { #[test] fn test_no_overlap() { let p = CENO_PLATFORM; - assert!(p.can_execute(p.pc_start())); + assert!(p.can_execute(p.pc_base())); // ROM and RAM do not overlap. assert!(!p.is_rom(p.ram_start())); assert!(!p.is_rom(p.ram_end())); diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 25a13f22b..2deb4598a 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -556,6 +556,7 @@ impl Emulator { let decoded = DecodedInstruction::new(word); let insn = self.table.lookup(&decoded); ctx.on_insn_decoded(&decoded); + tracing::trace!("pc: {:x}, kind: {:?}", pc.0, insn.kind); if match insn.category { InsnCategory::Compute => self.step_compute(ctx, insn.kind, &decoded)?, @@ -775,6 +776,7 @@ impl Emulator { let addr = ByteAddr(rs1.wrapping_add(decoded.imm_s())); let shift = 8 * (addr.0 & 3); if !ctx.check_data_store(addr) { + tracing::error!("mstore: addr={:x?},rs1={:x}", addr, rs1); return ctx.trap(TrapCause::StoreAccessFault); } let mut data = ctx.peek_memory(addr.waddr()); diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 5d43a3865..bff16fe09 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -9,10 +9,11 @@ use crate::{ tracer::{Change, StepRecord, Tracer}, }; use anyhow::{Result, anyhow}; -use std::iter::from_fn; +use std::{iter::from_fn, ops::Deref, sync::Arc}; /// An implementation of the machine state and of the side-effects of operations. pub struct VMState { + program: Arc, platform: Platform, pc: Word, /// Map a word-address (addr/4) to a word. @@ -24,28 +25,39 @@ pub struct VMState { } impl VMState { - pub fn new(platform: Platform) -> Self { - let pc = platform.pc_start(); - Self { - platform, + pub fn new(platform: Platform, program: Program) -> Self { + let pc = program.entry; + let program = Arc::new(program); + + let mut vm = Self { pc, + platform, + program: program.clone(), memory: HashMap::new(), registers: [0; 32], halted: false, tracer: Tracer::new(), + }; + + // init memory from program.image + for (&addr, &value) in program.image.iter() { + vm.init_memory(ByteAddr(addr).waddr(), value); } + + vm } pub fn new_from_elf(platform: Platform, elf: &[u8]) -> Result { - let mut state = Self::new(platform); - let program = Program::load_elf(elf, u32::MAX).unwrap(); - for (addr, word) in program.image.iter() { - let addr = ByteAddr(*addr).waddr(); - state.init_memory(addr, *word); - } - if program.entry != state.platform.pc_start() { - return Err(anyhow!("Invalid entrypoint {:x}", program.entry)); + let program = Program::load_elf(elf, u32::MAX)?; + let state = Self::new(platform, program); + + if state.program.base_address != state.platform.rom_start() { + return Err(anyhow!( + "Invalid base_address {:x}", + state.program.base_address + )); } + Ok(state) } @@ -57,7 +69,11 @@ impl VMState { &self.tracer } - /// Set a word in memory without side-effects. + pub fn program(&self) -> &Program { + self.program.deref() + } + + /// Set a word in memory without side effects. pub fn init_memory(&mut self, addr: WordAddr, value: Word) { self.memory.insert(addr, value); } diff --git a/ceno_emul/tests/test_vm_trace.rs b/ceno_emul/tests/test_vm_trace.rs index 57069fd75..c01977823 100644 --- a/ceno_emul/tests/test_vm_trace.rs +++ b/ceno_emul/tests/test_vm_trace.rs @@ -1,19 +1,30 @@ #![allow(clippy::unusual_byte_groupings)] use anyhow::Result; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use ceno_emul::{ - ByteAddr, CENO_PLATFORM, Cycle, EmuContext, InsnKind, StepRecord, Tracer, VMState, WordAddr, + CENO_PLATFORM, Cycle, EmuContext, InsnKind, Program, StepRecord, Tracer, VMState, WORD_SIZE, + WordAddr, }; #[test] fn test_vm_trace() -> Result<()> { - let mut ctx = VMState::new(CENO_PLATFORM); - - let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr(); - for (i, &inst) in PROGRAM_FIBONACCI_20.iter().enumerate() { - ctx.init_memory(pc_start + i as u32, inst); - } + let program = Program::new( + CENO_PLATFORM.pc_base(), + CENO_PLATFORM.pc_base(), + PROGRAM_FIBONACCI_20.to_vec(), + PROGRAM_FIBONACCI_20 + .iter() + .enumerate() + .map(|(insn_idx, &insn)| { + ( + CENO_PLATFORM.pc_base() + (WORD_SIZE * insn_idx) as u32, + insn, + ) + }) + .collect(), + ); + let mut ctx = VMState::new(CENO_PLATFORM, program); let steps = run(&mut ctx)?; @@ -35,7 +46,13 @@ fn test_vm_trace() -> Result<()> { #[test] fn test_empty_program() -> Result<()> { - let mut ctx = VMState::new(CENO_PLATFORM); + let empty_program = Program::new( + CENO_PLATFORM.pc_base(), + CENO_PLATFORM.pc_base(), + vec![], + BTreeMap::new(), + ); + let mut ctx = VMState::new(CENO_PLATFORM, empty_program); let res = run(&mut ctx); assert!(matches!(res, Err(e) if e.to_string().contains("IllegalInstruction(0)"))); Ok(()) diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index b979975d2..2ba51773d 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -10,9 +10,9 @@ use ceno_zkvm::{ use clap::Parser; use ceno_emul::{ - ByteAddr, CENO_PLATFORM, EmuContext, + CENO_PLATFORM, EmuContext, InsnKind::{ADD, BLTU, EANY, JAL, LUI, LW}, - StepRecord, Tracer, VMState, WordAddr, encode_rv32, + PC_WORD_SIZE, Program, StepRecord, Tracer, VMState, WordAddr, encode_rv32, }; use ceno_zkvm::{ scheme::{PublicValues, constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier}, @@ -76,6 +76,21 @@ fn main() { type E = GoldilocksExt2; type Pcs = Basefold; + let program = Program::new( + CENO_PLATFORM.pc_base(), + CENO_PLATFORM.pc_base(), + PROGRAM_CODE.to_vec(), + PROGRAM_CODE + .iter() + .enumerate() + .map(|(insn_idx, &insn)| { + ( + (insn_idx * PC_WORD_SIZE) as u32 + CENO_PLATFORM.pc_base(), + insn, + ) + }) + .collect(), + ); let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); let subscriber = Registry::default() .with( @@ -108,7 +123,7 @@ fn main() { zkvm_fixed_traces.register_table_circuit::>( &zkvm_cs, prog_config.clone(), - &PROGRAM_CODE, + &program, ); let reg_init = initial_registers(); @@ -126,12 +141,8 @@ fn main() { let prover = ZKVMProver::new(pk); let verifier = ZKVMVerifier::new(vk); - let mut vm = VMState::new(CENO_PLATFORM); - let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr(); + let mut vm = VMState::new(CENO_PLATFORM, program.clone()); - for (i, inst) in PROGRAM_CODE.iter().enumerate() { - vm.init_memory(pc_start + i, *inst); - } for record in &mem_init { vm.init_memory(record.addr.into(), record.value); } @@ -201,11 +212,7 @@ fn main() { // assign program circuit zkvm_witness - .assign_table_circuit::>( - &zkvm_cs, - &prog_config, - &PROGRAM_CODE.len(), - ) + .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); let timer = Instant::now(); diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 429bba66d..2d6f8ba39 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -16,7 +16,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { Self { cs } } - pub fn create_witin(&mut self, name_fn: N) -> Result + pub fn create_witin(&mut self, name_fn: N) -> WitIn where NR: Into, N: FnOnce() -> NR, @@ -148,7 +148,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR + Clone, { - let byte = self.cs.create_witin(name_fn.clone())?; + let byte = self.cs.create_witin(name_fn.clone()); self.assert_ux::<_, _, 8>(name_fn, byte.expr())?; Ok(byte) @@ -159,7 +159,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR + Clone, { - let limb = self.cs.create_witin(name_fn.clone())?; + let limb = self.cs.create_witin(name_fn.clone()); self.assert_ux::<_, _, 16>(name_fn, limb.expr())?; Ok(limb) @@ -393,8 +393,8 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { lhs: Expression, rhs: Expression, ) -> Result<(WitIn, WitIn), ZKVMError> { - let is_eq = self.create_witin(|| "is_eq")?; - let diff_inverse = self.create_witin(|| "diff_inverse")?; + let is_eq = self.create_witin(|| "is_eq"); + let diff_inverse = self.create_witin(|| "diff_inverse"); self.require_zero( || "is equal", diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index f5ef37825..b5a8cd13b 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -194,14 +194,11 @@ impl ConstraintSystem { } } - pub fn create_witin, N: FnOnce() -> NR>( - &mut self, - n: N, - ) -> Result { + pub fn create_witin, N: FnOnce() -> NR>(&mut self, n: N) -> WitIn { let wit_in = WitIn { id: { let id = self.num_witin; - self.num_witin += 1; + self.num_witin = self.num_witin.strict_add(1); id }, }; @@ -209,7 +206,7 @@ impl ConstraintSystem { let path = self.ns.compute_path(n().into()); self.witin_namespace_map.push(path); - Ok(wit_in) + wit_in } pub fn create_fixed, N: FnOnce() -> NR>( diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index e59a3985d..a8db8f294 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -621,7 +621,7 @@ impl WitIn { || "from_expr", |cb| { let name = name().into(); - let wit = cb.create_witin(|| name.clone())?; + let wit = cb.create_witin(|| name.clone()); if !debug { cb.require_zero(|| name.clone(), wit.expr() - input)?; } @@ -876,7 +876,7 @@ mod tests { type E = GoldilocksExt2; let mut cs = ConstraintSystem::new(|| "test_root"); let mut cb = CircuitBuilder::::new(&mut cs); - let x = cb.create_witin(|| "x").unwrap(); + let x = cb.create_witin(|| "x"); // scaledsum * challenge // 3 * x + 2 @@ -942,9 +942,9 @@ mod tests { type E = GoldilocksExt2; let mut cs = ConstraintSystem::new(|| "test_root"); let mut cb = CircuitBuilder::::new(&mut cs); - let x = cb.create_witin(|| "x").unwrap(); - let y = cb.create_witin(|| "y").unwrap(); - let z = cb.create_witin(|| "z").unwrap(); + let x = cb.create_witin(|| "x"); + let y = cb.create_witin(|| "y"); + let z = cb.create_witin(|| "z"); // scaledsum * challenge // 3 * x + 2 let expr: Expression = @@ -984,8 +984,8 @@ mod tests { type E = GoldilocksExt2; let mut cs = ConstraintSystem::new(|| "test_root"); let mut cb = CircuitBuilder::::new(&mut cs); - let x = cb.create_witin(|| "x").unwrap(); - let y = cb.create_witin(|| "y").unwrap(); + let x = cb.create_witin(|| "x"); + let y = cb.create_witin(|| "y"); // scaledsum * challenge // (x + 1) * (y + 1) let expr: Expression = (Into::>::into(1usize) + x.expr()) diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index f8d40cdee..a716ee8d5 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -86,7 +86,7 @@ impl IsLtConfig { || "is_lt", |cb| { let name = name_fn(); - let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?; + let is_lt = cb.create_witin(|| format!("{name} is_lt witin")); cb.assert_bit(|| "is_lt_bit", is_lt.expr())?; let config = InnerLtConfig::construct_circuit( @@ -153,7 +153,7 @@ impl InnerLtConfig { cb.namespace( || format!("var {var_name}"), |cb| { - let witin = cb.create_witin(|| var_name.to_string())?; + let witin = cb.create_witin(|| var_name.to_string()); cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr())?; Ok(witin) }, @@ -293,7 +293,7 @@ impl SignedLtConfig { || "is_signed_lt", |cb| { let name = name_fn(); - let is_lt = cb.create_witin(|| format!("{name} is_signed_lt witin"))?; + let is_lt = cb.create_witin(|| format!("{name} is_signed_lt witin")); cb.assert_bit(|| "is_lt_bit", is_lt.expr())?; let config = InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, is_lt.expr())?; diff --git a/ceno_zkvm/src/gadgets/is_zero.rs b/ceno_zkvm/src/gadgets/is_zero.rs index 02994e4f7..ff8525afb 100644 --- a/ceno_zkvm/src/gadgets/is_zero.rs +++ b/ceno_zkvm/src/gadgets/is_zero.rs @@ -26,8 +26,8 @@ impl IsZeroConfig { x: Expression, ) -> Result { cb.namespace(name_fn, |cb| { - let is_zero = cb.create_witin(|| "is_zero")?; - let inverse = cb.create_witin(|| "inv")?; + let is_zero = cb.create_witin(|| "is_zero"); + let inverse = cb.create_witin(|| "inv"); // x==0 => is_zero=1 cb.require_one(|| "is_zero_1", is_zero.expr() + x.clone() * inverse.expr())?; diff --git a/ceno_zkvm/src/gadgets/signed_ext.rs b/ceno_zkvm/src/gadgets/signed_ext.rs index b81c009f7..96706dc59 100644 --- a/ceno_zkvm/src/gadgets/signed_ext.rs +++ b/ceno_zkvm/src/gadgets/signed_ext.rs @@ -38,7 +38,7 @@ impl SignedExtendConfig { ) -> Result { assert!(n_bits == 8 || n_bits == 16); - let msb = cb.create_witin(|| "msb")?; + let msb = cb.create_witin(|| "msb"); // require msb is boolean cb.assert_bit(|| "msb is boolean", msb.expr())?; diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index 399798398..2edd1e5b6 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -55,7 +55,7 @@ impl BInstructionConfig { let rs2 = ReadRS2::construct_circuit(circuit_builder, rs2_read, vm_state.ts)?; // Immediate - let imm = circuit_builder.create_witin(|| "imm")?; + let imm = circuit_builder.create_witin(|| "imm"); // Fetch instruction circuit_builder.lk_fetch(&InsnRecord::new( diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index 21ba4de0c..711ed51f4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -34,7 +34,7 @@ impl Instruction for HaltInstruction { } fn construct_circuit(cb: &mut CircuitBuilder) -> Result { - let prev_x10_ts = cb.create_witin(|| "prev_x10_ts")?; + let prev_x10_ts = cb.create_witin(|| "prev_x10_ts"); let exit_code = { let exit_code = cb.query_exit_code()?; [exit_code[0].expr(), exit_code[1].expr()] diff --git a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs index 49bc1d67a..d409b288a 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs @@ -28,8 +28,8 @@ impl EcallInstructionConfig { syscall_ret_value: Option>, next_pc: Option>, ) -> Result { - let pc = cb.create_witin(|| "pc")?; - let ts = cb.create_witin(|| "cur_ts")?; + let pc = cb.create_witin(|| "pc"); + let ts = cb.create_witin(|| "cur_ts"); cb.state_in(pc.expr(), ts.expr())?; cb.state_out( @@ -47,7 +47,7 @@ impl EcallInstructionConfig { 0.into(), // imm = 0 ))?; - let prev_x5_ts = cb.create_witin(|| "prev_x5_ts")?; + let prev_x5_ts = cb.create_witin(|| "prev_x5_ts"); // read syscall_id from x5 and write return value to x5 let (_, lt_x5_cfg) = cb.register_write( diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 5548786f5..cd799c29a 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -37,14 +37,14 @@ impl StateInOut { circuit_builder: &mut CircuitBuilder, branching: bool, ) -> Result { - let pc = circuit_builder.create_witin(|| "pc")?; + let pc = circuit_builder.create_witin(|| "pc"); let (next_pc_opt, next_pc_expr) = if branching { - let next_pc = circuit_builder.create_witin(|| "next_pc")?; + let next_pc = circuit_builder.create_witin(|| "next_pc"); (Some(next_pc), next_pc.expr()) } else { (None, pc.expr() + PC_STEP_SIZE) }; - let ts = circuit_builder.create_witin(|| "ts")?; + let ts = circuit_builder.create_witin(|| "ts"); let next_ts = ts.expr() + Tracer::SUBCYCLES_PER_INSN; circuit_builder.state_in(pc.expr(), ts.expr())?; circuit_builder.state_out(next_pc_expr, next_ts)?; @@ -87,8 +87,8 @@ impl ReadRS1 { rs1_read: RegisterExpr, cur_ts: WitIn, ) -> Result { - let id = circuit_builder.create_witin(|| "rs1_id")?; - let prev_ts = circuit_builder.create_witin(|| "prev_rs1_ts")?; + let id = circuit_builder.create_witin(|| "rs1_id"); + let prev_ts = circuit_builder.create_witin(|| "prev_rs1_ts"); let (_, lt_cfg) = circuit_builder.register_read( || "read_rs1", id, @@ -142,8 +142,8 @@ impl ReadRS2 { rs2_read: RegisterExpr, cur_ts: WitIn, ) -> Result { - let id = circuit_builder.create_witin(|| "rs2_id")?; - let prev_ts = circuit_builder.create_witin(|| "prev_rs2_ts")?; + let id = circuit_builder.create_witin(|| "rs2_id"); + let prev_ts = circuit_builder.create_witin(|| "prev_rs2_ts"); let (_, lt_cfg) = circuit_builder.register_read( || "read_rs2", id, @@ -197,8 +197,8 @@ impl WriteRD { rd_written: RegisterExpr, cur_ts: WitIn, ) -> Result { - let id = circuit_builder.create_witin(|| "rd_id")?; - let prev_ts = circuit_builder.create_witin(|| "prev_rd_ts")?; + let id = circuit_builder.create_witin(|| "rd_id"); + let prev_ts = circuit_builder.create_witin(|| "prev_rd_ts"); let prev_value = UInt::new_unchecked(|| "prev_rd_value", circuit_builder)?; let (_, lt_cfg) = circuit_builder.register_write( || "write_rd", @@ -258,7 +258,7 @@ impl ReadMEM { mem_read: Expression, cur_ts: WitIn, ) -> Result { - let prev_ts = circuit_builder.create_witin(|| "prev_ts")?; + let prev_ts = circuit_builder.create_witin(|| "prev_ts"); let (_, lt_cfg) = circuit_builder.memory_read( || "read_memory", &mem_addr, @@ -313,7 +313,7 @@ impl WriteMEM { new_value: MemoryExpr, cur_ts: WitIn, ) -> Result { - let prev_ts = circuit_builder.create_witin(|| "prev_ts")?; + let prev_ts = circuit_builder.create_witin(|| "prev_ts"); let (_, lt_cfg) = circuit_builder.memory_write( || "write_memory", @@ -408,7 +408,7 @@ impl MemAddr { // Witness and constrain the non-zero low bits. let low_bits = (n_zeros..Self::N_LOW_BITS) .map(|i| { - let bit = cb.create_witin(|| format!("addr_bit_{}", i))?; + let bit = cb.create_witin(|| format!("addr_bit_{}", i)); cb.assert_bit(|| format!("addr_bit_{}", i), bit.expr())?; Ok(bit) }) diff --git a/ceno_zkvm/src/instructions/riscv/jump/auipc.rs b/ceno_zkvm/src/instructions/riscv/jump/auipc.rs index 6c979ec25..7d20430b7 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/auipc.rs @@ -36,7 +36,7 @@ impl Instruction for AuipcInstruction { fn construct_circuit( circuit_builder: &mut CircuitBuilder, ) -> Result, ZKVMError> { - let imm = circuit_builder.create_witin(|| "imm")?; + let imm = circuit_builder.create_witin(|| "imm"); let rd_written = UInt::new(|| "rd_written", circuit_builder)?; let u_insn = UInstructionConfig::construct_circuit( @@ -46,7 +46,7 @@ impl Instruction for AuipcInstruction { rd_written.register_expr(), )?; - let overflow_bit = circuit_builder.create_witin(|| "overflow_bit")?; + let overflow_bit = circuit_builder.create_witin(|| "overflow_bit"); circuit_builder.assert_bit(|| "is_bit", overflow_bit.expr())?; // assert: imm + pc = rd_written + overflow_bit * 2^32 diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index 889f3eca8..a80e11c4a 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -44,7 +44,7 @@ impl Instruction for JalrInstruction { circuit_builder: &mut CircuitBuilder, ) -> Result, ZKVMError> { let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value - let imm = circuit_builder.create_witin(|| "imm")?; // signed 12-bit value + let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value let rd_written = UInt::new(|| "rd_written", circuit_builder)?; let i_insn = IInstructionConfig::construct_circuit( @@ -63,7 +63,7 @@ impl Instruction for JalrInstruction { // 3. next_pc = next_pc_addr aligned to even value (round down) let next_pc_addr = MemAddr::::construct_unaligned(circuit_builder)?; - let overflow = circuit_builder.create_witin(|| "overflow")?; + let overflow = circuit_builder.create_witin(|| "overflow"); circuit_builder.require_equal( || "rs1+imm = next_pc_unrounded + overflow*2^32", diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 0980253db..8c6a36998 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -33,7 +33,7 @@ impl MemWordChange { -> Result, ZKVMError> { (0..num_bytes) .map(|i| { - let byte = cb.create_witin(|| format!("{}.le_bytes[{}]", anno, i))?; + let byte = cb.create_witin(|| format!("{}.le_bytes[{}]", anno, i)); cb.assert_ux::<_, _, 8>(|| "byte range check", byte.expr())?; Ok(byte) @@ -84,7 +84,7 @@ impl MemWordChange { )?; // alloc a new witIn to cache degree 2 expression - let expected_limb_change = cb.create_witin(|| "expected_limb_change")?; + let expected_limb_change = cb.create_witin(|| "expected_limb_change"); cb.condition_require_equal( || "expected_limb_change = select(low_bits[0], rs2 - prev)", low_bits[0].clone(), @@ -94,7 +94,7 @@ impl MemWordChange { )?; // alloc a new witIn to cache degree 2 expression - let expected_change = cb.create_witin(|| "expected_change")?; + let expected_change = cb.create_witin(|| "expected_change"); cb.condition_require_equal( || "expected_change = select(low_bits[1], limb_change*2^16, limb_change)", low_bits[1].clone(), @@ -117,7 +117,7 @@ impl MemWordChange { let prev_limbs = prev_word.expr(); let rs2_limbs = rs2_word.expr(); - let expected_change = cb.create_witin(|| "expected_change")?; + let expected_change = cb.create_witin(|| "expected_change"); // alloc a new witIn to cache degree 2 expression cb.condition_require_equal( diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index e7d7e20b5..1bc76f311 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -82,7 +82,7 @@ impl Instruction for LoadInstruction, ) -> Result { let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value - let imm = circuit_builder.create_witin(|| "imm")?; // signed 12-bit value + let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value let memory_read = UInt::new(|| "memory_read", circuit_builder)?; let memory_addr = match I::INST_KIND { @@ -104,7 +104,7 @@ impl Instruction for LoadInstruction { - let target_limb = circuit_builder.create_witin(|| "target_limb")?; + let target_limb = circuit_builder.create_witin(|| "target_limb"); circuit_builder.condition_require_equal( || "target_limb = memory_value[low_bits[1]]", addr_low_bits[1].clone(), diff --git a/ceno_zkvm/src/instructions/riscv/memory/store.rs b/ceno_zkvm/src/instructions/riscv/memory/store.rs index ddf76d0fd..878777de6 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store.rs @@ -74,7 +74,7 @@ impl Instruction let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; let prev_memory_value = UInt::new(|| "prev_memory_value", circuit_builder)?; - let imm = circuit_builder.create_witin(|| "imm")?; // signed 12-bit value + let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value let memory_addr = match I::INST_KIND { InsnKind::SW => MemAddr::construct_align4(circuit_builder), diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 41d39a352..f5d0f8e8d 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -56,7 +56,7 @@ impl Instruction for ShiftLogicalInstru circuit_builder: &mut crate::circuit_builder::CircuitBuilder, ) -> Result { let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; - let rs2_low5 = circuit_builder.create_witin(|| "rs2_low5")?; + let rs2_low5 = circuit_builder.create_witin(|| "rs2_low5"); // pow2_rs2_low5 is unchecked because it's assignment will be constrained due it's use in lookup_pow2 below let mut pow2_rs2_low5 = UInt::new_unchecked(|| "pow2_rs2_low5", circuit_builder)?; // rs2 = rs2_high | rs2_low5 diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 81688089e..1f43da85c 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -57,11 +57,11 @@ impl Instruction for ShiftImmInstructio circuit_builder: &mut CircuitBuilder, ) -> Result { // Note: `imm` wtns is set to 2**imm (upto 32 bit) just for efficient verification. - let imm = circuit_builder.create_witin(|| "imm")?; + let imm = circuit_builder.create_witin(|| "imm"); let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; let rd_written = UInt::new(|| "rd_written", circuit_builder)?; - let outflow = circuit_builder.create_witin(|| "outflow")?; + let outflow = circuit_builder.create_witin(|| "outflow"); let assert_lt_config = AssertLTConfig::construct_circuit( circuit_builder, || "outflow < imm", diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 4ba2ced3a..71d0bb26f 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -45,7 +45,7 @@ impl Instruction for SltiInstruction { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { // If rs1_read < imm, rd_written = 1. Otherwise rd_written = 0 let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; - let imm = cb.create_witin(|| "imm")?; + let imm = cb.create_witin(|| "imm"); let max_signed_limb_expr: Expression<_> = ((1 << (UInt::::LIMB_BITS - 1)) - 1).into(); let is_rs1_neg = IsLtConfig::construct_circuit( diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index cb5978754..75b4b377d 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -1,6 +1,7 @@ #![feature(box_patterns)] #![feature(stmt_expr_attributes)] #![feature(variant_count)] +#![feature(strict_overflow_ops)] pub mod error; pub mod instructions; diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 74cc15de8..541236d28 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -12,7 +12,7 @@ use crate::{ }; use ark_std::test_rng; use base64::{Engine, engine::general_purpose::STANDARD_NO_PAD}; -use ceno_emul::{ByteAddr, CENO_PLATFORM}; +use ceno_emul::{ByteAddr, CENO_PLATFORM, PC_WORD_SIZE, Program}; use ff::Field; use ff_ext::ExtensionField; use generic_static::StaticTypeMap; @@ -20,7 +20,7 @@ use goldilocks::SmallField; use itertools::{Itertools, izip}; use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension}; use std::{ - collections::HashSet, + collections::{BTreeMap, HashSet}, fs::File, hash::Hash, io::{BufReader, ErrorKind}, @@ -31,7 +31,7 @@ use std::{ use strum::IntoEnumIterator; const MOCK_PROGRAM_SIZE: usize = 32; -pub const MOCK_PC_START: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start()); +pub const MOCK_PC_START: ByteAddr = ByteAddr(CENO_PLATFORM.pc_base()); #[allow(clippy::enum_variant_names)] #[derive(Debug, Clone)] @@ -389,10 +389,28 @@ impl<'a, E: ExtensionField + Hash> MockProver { lkm: Option, ) -> Result<(), Vec>> { // fix the program table - let mut programs = [0u32; MOCK_PROGRAM_SIZE]; - for (i, &program) in input_programs.iter().enumerate() { - programs[i] = program; - } + let instructions = input_programs + .iter() + .cloned() + .chain(std::iter::repeat(0)) + .take(MOCK_PROGRAM_SIZE) + .collect_vec(); + let image = instructions + .iter() + .enumerate() + .map(|(insn_idx, &insn)| { + ( + CENO_PLATFORM.pc_base() + (insn_idx * PC_WORD_SIZE) as u32, + insn, + ) + }) + .collect::>(); + let program = Program::new( + CENO_PLATFORM.pc_base(), + CENO_PLATFORM.pc_base(), + instructions, + image, + ); // load tables let (challenge, mut table) = if let Some(challenge) = challenge { @@ -401,7 +419,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { load_once_tables(cb) }; let mut prog_table = vec![]; - Self::load_program_table(&mut prog_table, &programs, challenge); + Self::load_program_table(&mut prog_table, &program, challenge); for prog in prog_table { table.insert(prog); } @@ -589,11 +607,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { } } - fn load_program_table( - t_vec: &mut Vec>, - programs: &[u32; MOCK_PROGRAM_SIZE], - challenge: [E; 2], - ) { + fn load_program_table(t_vec: &mut Vec>, program: &Program, challenge: [E; 2]) { let mut cs = ConstraintSystem::::new(|| "mock_program"); let mut cb = CircuitBuilder::new(&mut cs); let config = @@ -601,7 +615,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { let fixed = ProgramTableCircuit::::generate_fixed_traces( &config, cs.num_fixed, - programs, + program, ); for table_expr in &cs.lk_table_expressions { for row in fixed.iter_rows() { @@ -729,16 +743,16 @@ mod tests { pub fn construct_circuit( cb: &mut CircuitBuilder, ) -> Result { - let a = cb.create_witin(|| "a")?; - let b = cb.create_witin(|| "b")?; - let c = cb.create_witin(|| "c")?; + let a = cb.create_witin(|| "a"); + let b = cb.create_witin(|| "b"); + let c = cb.create_witin(|| "c"); // degree 1 cb.require_equal(|| "a + 1 == b", b.expr(), a.expr() + 1)?; cb.require_zero(|| "c - 2 == 0", c.expr() - 2)?; // degree > 1 - let d = cb.create_witin(|| "d")?; + let d = cb.create_witin(|| "d"); cb.require_zero( || "d*d - 6*d + 9 == 0", d.expr() * d.expr() - d.expr() * 6 + 9, @@ -783,7 +797,7 @@ mod tests { pub fn construct_circuit( cb: &mut CircuitBuilder, ) -> Result { - let a = cb.create_witin(|| "a")?; + let a = cb.create_witin(|| "a"); cb.assert_ux::<_, _, 5>(|| "assert u5", a.expr())?; Ok(Self { a }) } @@ -863,8 +877,8 @@ mod tests { impl AssertLtCircuit { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { - let a = cb.create_witin(|| "a")?; - let b = cb.create_witin(|| "b")?; + let a = cb.create_witin(|| "a"); + let b = cb.create_witin(|| "b"); let lt_wtns = AssertLTConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; Ok(Self { a, b, lt_wtns }) } @@ -977,8 +991,8 @@ mod tests { impl LtCircuit { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { - let a = cb.create_witin(|| "a")?; - let b = cb.create_witin(|| "b")?; + let a = cb.create_witin(|| "a"); + let b = cb.create_witin(|| "b"); let lt_wtns = IsLtConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; Ok(Self { a, b, lt_wtns }) } diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index f928a8918..18fe04973 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -1,9 +1,9 @@ use std::{marker::PhantomData, mem::MaybeUninit}; use ceno_emul::{ - ByteAddr, CENO_PLATFORM, + CENO_PLATFORM, InsnKind::{ADD, EANY}, - StepRecord, VMState, + PC_WORD_SIZE, Program, StepRecord, VMState, }; use ff::Field; use ff_ext::ExtensionField; @@ -50,7 +50,7 @@ impl Instruction for Test } fn construct_circuit(cb: &mut CircuitBuilder) -> Result { - let reg_id = cb.create_witin(|| "reg_id")?; + let reg_id = cb.create_witin(|| "reg_id"); (0..RW).try_for_each(|_| { let record = cb.rlc_chip_record(vec![ Expression::::Constant(E::BaseField::ONE), @@ -202,6 +202,23 @@ fn test_single_add_instance_e2e() { type E = GoldilocksExt2; type Pcs = Basefold; + // set up program + let program = Program::new( + CENO_PLATFORM.pc_base(), + CENO_PLATFORM.pc_base(), + PROGRAM_CODE.to_vec(), + PROGRAM_CODE + .iter() + .enumerate() + .map(|(insn_idx, &insn)| { + ( + (insn_idx * PC_WORD_SIZE) as u32 + CENO_PLATFORM.pc_base(), + insn, + ) + }) + .collect(), + ); + let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let mut zkvm_cs = ZKVMConstraintSystem::default(); @@ -225,7 +242,7 @@ fn test_single_add_instance_e2e() { zkvm_fixed_traces.register_table_circuit::>( &zkvm_cs, prog_config.clone(), - &PROGRAM_CODE, + &program, ); let pk = zkvm_cs @@ -235,11 +252,7 @@ fn test_single_add_instance_e2e() { let vk = pk.get_vk(); // single instance - let mut vm = VMState::new(CENO_PLATFORM); - let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr(); - for (i, insn) in PROGRAM_CODE.iter().enumerate() { - vm.init_memory(pc_start + i, *insn); - } + let mut vm = VMState::new(CENO_PLATFORM, program.clone()); let all_records = vm .iter_until_halt() .collect::, _>>() @@ -282,7 +295,7 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>( &zkvm_cs, &prog_config, - &PROGRAM_CODE.len(), + &program, ) .unwrap(); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index fc4b7a444..8e605bcf9 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -673,9 +673,9 @@ mod tests { type B = goldilocks::Goldilocks; let mut cs = ConstraintSystem::::new(|| "test"); let mut cb = CircuitBuilder::new(&mut cs); - let a = cb.create_witin(|| "a").unwrap(); - let b = cb.create_witin(|| "b").unwrap(); - let c = cb.create_witin(|| "c").unwrap(); + let a = cb.create_witin(|| "a"); + let b = cb.create_witin(|| "b"); + let c = cb.create_witin(|| "c"); let expr: Expression = a.expr() + b.expr() + a.expr() * b.expr() + (c.expr() * 3 + 2); @@ -699,9 +699,9 @@ mod tests { type B = goldilocks::Goldilocks; let mut cs = ConstraintSystem::::new(|| "test"); let mut cb = CircuitBuilder::new(&mut cs); - let a = cb.create_witin(|| "a").unwrap(); - let b = cb.create_witin(|| "b").unwrap(); - let c = cb.create_witin(|| "c").unwrap(); + let a = cb.create_witin(|| "a"); + let b = cb.create_witin(|| "b"); + let c = cb.create_witin(|| "c"); let expr: Expression = a.expr() + b.expr() diff --git a/ceno_zkvm/src/tables/ops/ops_impl.rs b/ceno_zkvm/src/tables/ops/ops_impl.rs index f9bf4d0c8..04939a596 100644 --- a/ceno_zkvm/src/tables/ops/ops_impl.rs +++ b/ceno_zkvm/src/tables/ops/ops_impl.rs @@ -32,7 +32,7 @@ impl OpTableConfig { cb.create_fixed(|| "b")?, cb.create_fixed(|| "c")?, ]; - let mlt = cb.create_witin(|| "mlt")?; + let mlt = cb.create_witin(|| "mlt"); let rlc_record = cb.rlc_chip_record(vec![ (rom_type as usize).into(), diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 3514365c8..d5b629957 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -10,7 +10,7 @@ use crate::{ tables::TableCircuit, witness::RowMajorMatrix, }; -use ceno_emul::{CENO_PLATFORM, DecodedInstruction, PC_STEP_SIZE, WORD_SIZE}; +use ceno_emul::{DecodedInstruction, PC_STEP_SIZE, Program, WORD_SIZE}; use ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::Itertools; @@ -117,8 +117,8 @@ impl TableCircuit for ProgramTableCircuit { type TableConfig = ProgramTableConfig; - type FixedInput = [u32; PROGRAM_SIZE]; - type WitnessInput = usize; + type FixedInput = Program; + type WitnessInput = Program; fn name() -> String { "PROGRAM".into() @@ -135,7 +135,7 @@ impl TableCircuit cb.create_fixed(|| "imm_or_funct7")?, ]); - let mlt = cb.create_witin(|| "mlt")?; + let mlt = cb.create_witin(|| "mlt"); let record_exprs = { let mut fields = vec![E::BaseField::from(ROMType::Instruction as u64).expr()]; @@ -153,9 +153,8 @@ impl TableCircuit num_fixed: usize, program: &Self::FixedInput, ) -> RowMajorMatrix { - // TODO: get bytecode of the program. - let num_instructions = program.len(); - let pc_start = CENO_PLATFORM.pc_start(); + let num_instructions = program.instructions.len(); + let pc_base = program.base_address; let mut fixed = RowMajorMatrix::::new(num_instructions, num_fixed); @@ -164,8 +163,8 @@ impl TableCircuit .with_min_len(MIN_PAR_SIZE) .zip((0..num_instructions).into_par_iter()) .for_each(|(row, i)| { - let pc = pc_start + (i * PC_STEP_SIZE) as u32; - let insn = DecodedInstruction::new(program[i]); + let pc = pc_base + (i * PC_STEP_SIZE) as u32; + let insn = DecodedInstruction::new(program.instructions[i]); let values = InsnRecord::from_decoded(pc, &insn); // Copy all the fields except immediate. @@ -192,13 +191,13 @@ impl TableCircuit config: &Self::TableConfig, num_witin: usize, multiplicity: &[HashMap], - num_instructions: &usize, + program: &Program, ) -> Result, ZKVMError> { let multiplicity = &multiplicity[ROMType::Instruction as usize]; - let mut prog_mlt = vec![0_usize; *num_instructions]; + let mut prog_mlt = vec![0_usize; program.instructions.len()]; for (pc, mlt) in multiplicity { - let i = (*pc as usize - CENO_PLATFORM.pc_start() as usize) / WORD_SIZE; + let i = (*pc as usize - program.base_address as usize) / WORD_SIZE; prog_mlt[i] = *mlt; } diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index a4d263123..99d50b0ae 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -40,10 +40,10 @@ impl RamTableConfig { .collect::, ZKVMError>>()?; let addr = cb.create_fixed(|| "addr")?; - let final_v = (0..RAM::V_LIMBS) + let final_v: Vec<_> = (0..RAM::V_LIMBS) .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) - .collect::, ZKVMError>>()?; - let final_cycle = cb.create_witin(|| "final_cycle")?; + .collect(); + let final_cycle = cb.create_witin(|| "final_cycle"); let init_table = cb.rlc_chip_record( [ diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index 8a14fe236..7e6bb62f6 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -28,7 +28,7 @@ impl RangeTableConfig { table_len: usize, ) -> Result { let fixed = cb.create_fixed(|| "fixed")?; - let mlt = cb.create_witin(|| "mlt")?; + let mlt = cb.create_witin(|| "mlt"); let rlc_record = cb.rlc_chip_record(vec![(rom_type as usize).into(), Expression::Fixed(fixed)]); diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index d8f53f200..63f3dc7ac 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -90,7 +90,7 @@ impl UIntLimbs { limbs: UintLimb::WitIn( (0..Self::NUM_LIMBS) .map(|i| { - let w = cb.create_witin(|| format!("limb_{i}"))?; + let w = cb.create_witin(|| format!("limb_{i}")); if is_check { cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w.expr())?; } @@ -162,7 +162,7 @@ impl UIntLimbs { assert_eq!(expr_limbs.len(), Self::NUM_LIMBS); let limbs = (0..Self::NUM_LIMBS) .map(|i| { - let w = circuit_builder.create_witin(|| "wit for limb").unwrap(); + let w = circuit_builder.create_witin(|| "wit for limb"); circuit_builder .assert_ux::<_, _, C>(|| "range check", w.expr()) .unwrap(); @@ -324,7 +324,7 @@ impl UIntLimbs { .flat_map(|large_limb| { let limbs = (0..k) .map(|_| { - let w = circuit_builder.create_witin(|| "").unwrap(); + let w = circuit_builder.create_witin(|| ""); circuit_builder.assert_byte(|| "", w.expr()).unwrap(); w.expr() }) @@ -370,7 +370,7 @@ impl UIntLimbs { self.limbs = UintLimb::WitIn( (0..Self::NUM_LIMBS) .map(|i| { - let w = cb.create_witin(|| format!("limb_{i}"))?; + let w = cb.create_witin(|| format!("limb_{i}")); cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w.expr())?; Ok(w) }) @@ -400,7 +400,7 @@ impl UIntLimbs { self.carries = Some( (0..carries_len) .map(|i| { - let c = cb.create_witin(|| format!("carry_{i}"))?; + let c = cb.create_witin(|| format!("carry_{i}")); Ok(c) }) .collect::, ZKVMError>>()?, diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index c93bb73ae..01413f6a3 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -119,7 +119,7 @@ impl UIntLimbs { }; // with high limb, overall cell will be double let c_limbs: Vec = (0..num_limbs).try_fold(vec![], |mut c_limbs, i| { - let limb = circuit_builder.create_witin(|| format!("limb_{i}"))?; + let limb = circuit_builder.create_witin(|| format!("limb_{i}")); circuit_builder.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb.expr())?; c_limbs.push(limb); Result::, ZKVMError>::Ok(c_limbs) @@ -127,7 +127,7 @@ impl UIntLimbs { let c_carries: Vec = (0..num_limbs).try_fold(vec![], |mut c_carries, i| { // skip last carry if with_overflow == false if i != num_limbs - 1 || with_overflow { - let carry = circuit_builder.create_witin(|| format!("carry_{i}"))?; + let carry = circuit_builder.create_witin(|| format!("carry_{i}")); c_carries.push(carry); } Result::, ZKVMError>::Ok(c_carries) @@ -302,7 +302,7 @@ impl UIntLimbs { where E: ExtensionField, { - let high_limb_no_msb = circuit_builder.create_witin(|| "high_limb_mask")?; + let high_limb_no_msb = circuit_builder.create_witin(|| "high_limb_mask"); let high_limb = self.limbs[Self::NUM_LIMBS - 1].expr(); circuit_builder.lookup_and_byte( @@ -329,7 +329,7 @@ impl UIntLimbs { let n_bytes = Self::NUM_LIMBS; let indexes: Vec = (0..n_bytes) .map(|_| circuit_builder.create_witin(|| "index")) - .collect::>()?; + .collect(); // indicate the first non-zero byte index i_0 of a[i] - b[i] // from high to low @@ -342,7 +342,7 @@ impl UIntLimbs { // circuit_builder.assert_bit(|| "bit assert", index_sum)?; // equal zero if a==b, otherwise equal (a[i_0]-b[i_0])^{-1} - let byte_diff_inv = circuit_builder.create_witin(|| "byte_diff_inverse")?; + let byte_diff_inv = circuit_builder.create_witin(|| "byte_diff_inverse"); // define accumulated index sum from high to low let si_expr: Vec> = indexes @@ -403,7 +403,7 @@ impl UIntLimbs { - index_ne.expr(), )?; - let is_ltu = circuit_builder.create_witin(|| "is_ltu")?; + let is_ltu = circuit_builder.create_witin(|| "is_ltu"); // now we know the first non-equal byte pairs is (lhs_ne_byte, rhs_ne_byte) circuit_builder.lookup_ltu_byte(lhs_ne_byte.expr(), rhs_ne_byte.expr(), is_ltu.expr())?; Ok(UIntLtuConfig { @@ -421,7 +421,7 @@ impl UIntLimbs { circuit_builder: &mut CircuitBuilder, rhs: &UIntLimbs, ) -> Result { - let is_lt = circuit_builder.create_witin(|| "is_lt")?; + let is_lt = circuit_builder.create_witin(|| "is_lt"); // circuit_builder.assert_bit(|| "assert_bit", is_lt.expr())?; let lhs_msb = self.msb_decompose(circuit_builder)?; diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index 4019f2d22..9aaafdf4f 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -192,8 +192,8 @@ mod tests { fn test_add_mle_list_by_expr() { let mut cs = ConstraintSystem::new(|| "test_root"); let mut cb = CircuitBuilder::::new(&mut cs); - let x = cb.create_witin(|| "x").unwrap(); - let y = cb.create_witin(|| "y").unwrap(); + let x = cb.create_witin(|| "x"); + let y = cb.create_witin(|| "y"); let wits_in: Vec> = (0..cs.num_witin as usize) .map(|_| vec![Goldilocks::from(1)].into_mle().into())