diff --git a/ceno_emul/src/addr.rs b/ceno_emul/src/addr.rs index 0ce39f56b..200739ce7 100644 --- a/ceno_emul/src/addr.rs +++ b/ceno_emul/src/addr.rs @@ -17,6 +17,7 @@ use std::{fmt, ops}; pub const WORD_SIZE: usize = 4; +pub const PC_WORD_SIZE: usize = 4; pub const PC_STEP_SIZE: usize = 4; // Type aliases to clarify the code without wrapper types. @@ -26,10 +27,10 @@ pub type Addr = u32; pub type Cycle = u64; pub type RegIdx = usize; -#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ByteAddr(pub u32); -#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct WordAddr(u32); impl From for WordAddr { diff --git a/ceno_emul/src/elf.rs b/ceno_emul/src/elf.rs index f6a28f743..21d12779f 100644 --- a/ceno_emul/src/elf.rs +++ b/ceno_emul/src/elf.rs @@ -20,21 +20,47 @@ use alloc::collections::BTreeMap; use crate::addr::WORD_SIZE; use anyhow::{Context, Result, anyhow, bail}; -use elf::{ElfBytes, endian::LittleEndian, file::Class}; +use elf::{ + ElfBytes, + abi::{PF_R, PF_W, PF_X}, + endian::LittleEndian, + file::Class, +}; /// A RISC Zero program +#[derive(Clone, Debug)] pub struct Program { /// The entrypoint of the program pub entry: u32, - + /// This is the lowest address of the program's executable code + pub base_address: u32, + /// The instructions of the program + pub instructions: Vec, /// The initial memory image pub image: BTreeMap, } impl Program { + /// Create program + pub fn new( + entry: u32, + base_address: u32, + instructions: Vec, + image: BTreeMap, + ) -> Program { + Self { + entry, + base_address, + instructions, + image, + } + } /// Initialize a RISC Zero Program from an appropriate ELF file pub fn load_elf(input: &[u8], max_mem: u32) -> Result { + let mut instructions: Vec = Vec::new(); let mut image: BTreeMap = BTreeMap::new(); + let mut base_address = None; + let elf = ElfBytes::::minimal_parse(input) .map_err(|err| anyhow!("Elf parse error: {err}"))?; if elf.ehdr.class != Class::ELF32 { @@ -58,7 +84,18 @@ impl Program { if segments.len() > 256 { bail!("Too many program headers"); } - for segment in segments.iter().filter(|x| x.p_type == elf::abi::PT_LOAD) { + for (idx, segment) in segments + .iter() + .filter(|x| x.p_type == elf::abi::PT_LOAD) + .enumerate() + { + tracing::debug!( + "loadable segement {}: PF_R={}, PF_W={}, PF_X={}", + idx, + segment.p_flags & PF_R != 0, + segment.p_flags & PF_W != 0, + segment.p_flags & PF_X != 0, + ); let file_size: u32 = segment .p_filesz .try_into() @@ -77,6 +114,13 @@ impl Program { .p_vaddr .try_into() .map_err(|err| anyhow!("vaddr is larger than 32 bits. {err}"))?; + if (segment.p_flags & PF_X) != 0 { + if base_address.is_none() { + base_address = Some(vaddr); + } else { + return Err(anyhow!("only support one executable segment")); + } + } if vaddr % WORD_SIZE as u32 != 0 { bail!("vaddr {vaddr:08x} is unaligned"); } @@ -104,9 +148,25 @@ impl Program { word |= (*byte as u32) << (j * 8); } image.insert(addr, word); + if (segment.p_flags & PF_X) != 0 { + instructions.push(word); + } } } } - Ok(Program { entry, image }) + + if base_address.is_none() { + return Err(anyhow!("does not have executable segment")); + } + let base_address = base_address.unwrap(); + assert!(entry >= base_address); + assert!((entry - base_address) as usize <= instructions.len() * WORD_SIZE); + + Ok(Program { + entry, + base_address, + image, + instructions, + }) } } diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 77a725169..e58dd2687 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -5,19 +5,30 @@ use crate::addr::{Addr, RegIdx}; /// - the layout of virtual memory, /// - special addresses, such as the initial PC, /// - codes of environment calls. -pub struct Platform; +#[derive(Copy, Clone)] +pub struct Platform { + pub rom_start: Addr, + pub rom_end: Addr, + pub ram_start: Addr, + pub ram_end: Addr, +} -pub const CENO_PLATFORM: Platform = Platform; +pub const CENO_PLATFORM: Platform = Platform { + rom_start: 0x2000_0000, + rom_end: 0x3000_0000 - 1, + ram_start: 0x8000_0000, + ram_end: 0xFFFF_FFFF, +}; impl Platform { // Virtual memory layout. pub const fn rom_start(&self) -> Addr { - 0x2000_0000 + self.rom_start } pub const fn rom_end(&self) -> Addr { - 0x3000_0000 - 1 + self.rom_end } pub fn is_rom(&self, addr: Addr) -> bool { @@ -43,18 +54,17 @@ impl Platform { } pub const fn ram_start(&self) -> Addr { - let ram_start = 0x8000_0000; if cfg!(feature = "forbid_overflow") { // -1<<11 == 0x800 is the smallest negative 'immediate' // offset we can have in memory instructions. // So if we stay away from it, we are safe. - assert!(ram_start >= 0x800); + assert!(self.ram_start >= 0x800); } - ram_start + self.ram_start } pub const fn ram_end(&self) -> Addr { - 0xFFFF_FFFF + self.ram_end - if cfg!(feature = "forbid_overflow") { // (1<<11) - 1 == 0x7ff is the largest positive 'immediate' // offset we can have in memory instructions. @@ -95,7 +105,7 @@ impl Platform { // Startup. - pub const fn pc_start(&self) -> Addr { + pub const fn pc_base(&self) -> Addr { self.rom_start() } @@ -148,7 +158,7 @@ mod tests { #[test] fn test_no_overlap() { let p = CENO_PLATFORM; - assert!(p.can_execute(p.pc_start())); + assert!(p.can_execute(p.pc_base())); // ROM and RAM do not overlap. assert!(!p.is_rom(p.ram_start())); assert!(!p.is_rom(p.ram_end())); diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 2abd03cce..2deb4598a 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -88,7 +88,6 @@ pub struct Emulator { } #[derive(Debug)] -#[allow(dead_code)] pub enum TrapCause { InstructionAddressMisaligned, InstructionAccessFault, @@ -557,6 +556,7 @@ impl Emulator { let decoded = DecodedInstruction::new(word); let insn = self.table.lookup(&decoded); ctx.on_insn_decoded(&decoded); + tracing::trace!("pc: {:x}, kind: {:?}", pc.0, insn.kind); if match insn.category { InsnCategory::Compute => self.step_compute(ctx, insn.kind, &decoded)?, @@ -776,6 +776,7 @@ impl Emulator { let addr = ByteAddr(rs1.wrapping_add(decoded.imm_s())); let shift = 8 * (addr.0 & 3); if !ctx.check_data_store(addr) { + tracing::error!("mstore: addr={:x?},rs1={:x}", addr, rs1); return ctx.trap(TrapCause::StoreAccessFault); } let mut data = ctx.peek_memory(addr.waddr()); diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 2e590e0bd..b9423f418 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -115,6 +115,35 @@ impl StepRecord { ) } + pub fn new_im_instruction( + cycle: Cycle, + pc: ByteAddr, + insn_code: u32, + rs1_read: Word, + rd: Change, + mem_op: ReadOp, + prev_cycle: Cycle, + ) -> StepRecord { + let pc = Change::new(pc, pc + PC_STEP_SIZE); + StepRecord::new_insn( + cycle, + pc, + insn_code, + Some(rs1_read), + None, + Some(rd), + Some(WriteOp { + addr: mem_op.addr, + value: Change { + before: mem_op.value, + after: mem_op.value, + }, + previous_cycle: mem_op.previous_cycle, + }), + prev_cycle, + ) + } + pub fn new_u_instruction( cycle: Cycle, pc: ByteAddr, diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 5d43a3865..bff16fe09 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -9,10 +9,11 @@ use crate::{ tracer::{Change, StepRecord, Tracer}, }; use anyhow::{Result, anyhow}; -use std::iter::from_fn; +use std::{iter::from_fn, ops::Deref, sync::Arc}; /// An implementation of the machine state and of the side-effects of operations. pub struct VMState { + program: Arc, platform: Platform, pc: Word, /// Map a word-address (addr/4) to a word. @@ -24,28 +25,39 @@ pub struct VMState { } impl VMState { - pub fn new(platform: Platform) -> Self { - let pc = platform.pc_start(); - Self { - platform, + pub fn new(platform: Platform, program: Program) -> Self { + let pc = program.entry; + let program = Arc::new(program); + + let mut vm = Self { pc, + platform, + program: program.clone(), memory: HashMap::new(), registers: [0; 32], halted: false, tracer: Tracer::new(), + }; + + // init memory from program.image + for (&addr, &value) in program.image.iter() { + vm.init_memory(ByteAddr(addr).waddr(), value); } + + vm } pub fn new_from_elf(platform: Platform, elf: &[u8]) -> Result { - let mut state = Self::new(platform); - let program = Program::load_elf(elf, u32::MAX).unwrap(); - for (addr, word) in program.image.iter() { - let addr = ByteAddr(*addr).waddr(); - state.init_memory(addr, *word); - } - if program.entry != state.platform.pc_start() { - return Err(anyhow!("Invalid entrypoint {:x}", program.entry)); + let program = Program::load_elf(elf, u32::MAX)?; + let state = Self::new(platform, program); + + if state.program.base_address != state.platform.rom_start() { + return Err(anyhow!( + "Invalid base_address {:x}", + state.program.base_address + )); } + Ok(state) } @@ -57,7 +69,11 @@ impl VMState { &self.tracer } - /// Set a word in memory without side-effects. + pub fn program(&self) -> &Program { + self.program.deref() + } + + /// Set a word in memory without side effects. pub fn init_memory(&mut self, addr: WordAddr, value: Word) { self.memory.insert(addr, value); } diff --git a/ceno_emul/tests/test_vm_trace.rs b/ceno_emul/tests/test_vm_trace.rs index 57069fd75..c01977823 100644 --- a/ceno_emul/tests/test_vm_trace.rs +++ b/ceno_emul/tests/test_vm_trace.rs @@ -1,19 +1,30 @@ #![allow(clippy::unusual_byte_groupings)] use anyhow::Result; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use ceno_emul::{ - ByteAddr, CENO_PLATFORM, Cycle, EmuContext, InsnKind, StepRecord, Tracer, VMState, WordAddr, + CENO_PLATFORM, Cycle, EmuContext, InsnKind, Program, StepRecord, Tracer, VMState, WORD_SIZE, + WordAddr, }; #[test] fn test_vm_trace() -> Result<()> { - let mut ctx = VMState::new(CENO_PLATFORM); - - let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr(); - for (i, &inst) in PROGRAM_FIBONACCI_20.iter().enumerate() { - ctx.init_memory(pc_start + i as u32, inst); - } + let program = Program::new( + CENO_PLATFORM.pc_base(), + CENO_PLATFORM.pc_base(), + PROGRAM_FIBONACCI_20.to_vec(), + PROGRAM_FIBONACCI_20 + .iter() + .enumerate() + .map(|(insn_idx, &insn)| { + ( + CENO_PLATFORM.pc_base() + (WORD_SIZE * insn_idx) as u32, + insn, + ) + }) + .collect(), + ); + let mut ctx = VMState::new(CENO_PLATFORM, program); let steps = run(&mut ctx)?; @@ -35,7 +46,13 @@ fn test_vm_trace() -> Result<()> { #[test] fn test_empty_program() -> Result<()> { - let mut ctx = VMState::new(CENO_PLATFORM); + let empty_program = Program::new( + CENO_PLATFORM.pc_base(), + CENO_PLATFORM.pc_base(), + vec![], + BTreeMap::new(), + ); + let mut ctx = VMState::new(CENO_PLATFORM, empty_program); let res = run(&mut ctx); assert!(matches!(res, Err(e) if e.to_string().contains("IllegalInstruction(0)"))); Ok(()) diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index 3db8ba758..cd3cb220b 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -13,9 +13,9 @@ use ceno_zkvm::{ use clap::Parser; use ceno_emul::{ - ByteAddr, CENO_PLATFORM, EmuContext, + CENO_PLATFORM, EmuContext, InsnKind::{ADD, BLTU, EANY, JAL, LUI, LW}, - StepRecord, Tracer, VMState, WordAddr, encode_rv32, + PC_WORD_SIZE, Program, StepRecord, Tracer, VMState, WordAddr, encode_rv32, }; use ceno_zkvm::{ scheme::{PublicValues, constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier}, @@ -78,6 +78,21 @@ fn main() { type E = GoldilocksExt2; type Pcs = Basefold; + let program = Program::new( + CENO_PLATFORM.pc_base(), + CENO_PLATFORM.pc_base(), + PROGRAM_CODE.to_vec(), + PROGRAM_CODE + .iter() + .enumerate() + .map(|(insn_idx, &insn)| { + ( + (insn_idx * PC_WORD_SIZE) as u32 + CENO_PLATFORM.pc_base(), + insn, + ) + }) + .collect(), + ); let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); let subscriber = Registry::default() .with( @@ -104,7 +119,7 @@ fn main() { zkvm_fixed_traces.register_table_circuit::>( &zkvm_cs, prog_config.clone(), - &PROGRAM_CODE, + &program, ); let reg_init = initial_registers(); @@ -135,13 +150,8 @@ fn main() { // init vm.x1 = 1, vm.x2 = -1, vm.x3 = step_loop let public_io_init = init_public_io(&[1, u32::MAX, step_loop]); - let mut vm = VMState::new(CENO_PLATFORM); - let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr(); + let mut vm = VMState::new(CENO_PLATFORM, program.clone()); - // init program - for (i, inst) in PROGRAM_CODE.iter().enumerate() { - vm.init_memory(pc_start + i, *inst); - } // init mmio for record in program_data_init.iter().chain(public_io_init.iter()) { vm.init_memory(record.addr.into(), record.value); @@ -251,11 +261,7 @@ fn main() { // assign program circuit zkvm_witness - .assign_table_circuit::>( - &zkvm_cs, - &prog_config, - &PROGRAM_CODE.len(), - ) + .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); let timer = Instant::now(); diff --git a/ceno_zkvm/src/chip_handler.rs b/ceno_zkvm/src/chip_handler.rs index dc1422a73..8d16f342d 100644 --- a/ceno_zkvm/src/chip_handler.rs +++ b/ceno_zkvm/src/chip_handler.rs @@ -55,7 +55,6 @@ pub type AddressExpr = Expression; pub type MemoryExpr = Expression; pub trait MemoryChipOperations, N: FnOnce() -> NR> { - #[allow(dead_code)] fn memory_read( &mut self, name_fn: N, @@ -66,7 +65,6 @@ pub trait MemoryChipOperations, N: FnOnce() ) -> Result<(Expression, AssertLTConfig), ZKVMError>; #[allow(clippy::too_many_arguments)] - #[allow(dead_code)] fn memory_write( &mut self, name_fn: N, diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index f39f6b301..75d8268c3 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -17,7 +17,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { Self { cs } } - pub fn create_witin(&mut self, name_fn: N) -> Result + pub fn create_witin(&mut self, name_fn: N) -> WitIn where NR: Into, N: FnOnce() -> NR, @@ -148,6 +148,28 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.rlc_chip_record(records) } + pub fn create_u8(&mut self, name_fn: N) -> Result + where + NR: Into, + N: FnOnce() -> NR + Clone, + { + let byte = self.cs.create_witin(name_fn.clone()); + self.assert_ux::<_, _, 8>(name_fn, byte.expr())?; + + Ok(byte) + } + + pub fn create_u16(&mut self, name_fn: N) -> Result + where + NR: Into, + N: FnOnce() -> NR + Clone, + { + let limb = self.cs.create_witin(name_fn.clone()); + self.assert_ux::<_, _, 16>(name_fn, limb.expr())?; + + Ok(limb) + } + pub fn require_zero( &mut self, name_fn: N, @@ -376,8 +398,8 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { lhs: Expression, rhs: Expression, ) -> Result<(WitIn, WitIn), ZKVMError> { - let is_eq = self.create_witin(|| "is_eq")?; - let diff_inverse = self.create_witin(|| "diff_inverse")?; + let is_eq = self.create_witin(|| "is_eq"); + let diff_inverse = self.create_witin(|| "diff_inverse"); self.require_zero( || "is equal", diff --git a/ceno_zkvm/src/chip_handler/memory.rs b/ceno_zkvm/src/chip_handler/memory.rs index b89bb4078..9a58c8a04 100644 --- a/ceno_zkvm/src/chip_handler/memory.rs +++ b/ceno_zkvm/src/chip_handler/memory.rs @@ -12,7 +12,6 @@ use ff_ext::ExtensionField; impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> MemoryChipOperations for CircuitBuilder<'a, E> { - #[allow(dead_code)] fn memory_read( &mut self, name_fn: N, diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 1a29cbb88..cb326d345 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -214,14 +214,11 @@ impl ConstraintSystem { } } - pub fn create_witin, N: FnOnce() -> NR>( - &mut self, - n: N, - ) -> Result { + pub fn create_witin, N: FnOnce() -> NR>(&mut self, n: N) -> WitIn { let wit_in = WitIn { id: { let id = self.num_witin; - self.num_witin += 1; + self.num_witin = self.num_witin.strict_add(1); id }, }; @@ -229,7 +226,7 @@ impl ConstraintSystem { let path = self.ns.compute_path(n().into()); self.witin_namespace_map.push(path); - Ok(wit_in) + wit_in } pub fn create_fixed, N: FnOnce() -> NR>( @@ -443,8 +440,8 @@ impl ConstraintSystem { pub fn namespace, N: FnOnce() -> NR, T>( &mut self, name_fn: N, - cb: impl FnOnce(&mut ConstraintSystem) -> Result, - ) -> Result { + cb: impl FnOnce(&mut ConstraintSystem) -> T, + ) -> T { self.ns.push_namespace(name_fn().into()); let t = cb(self); self.ns.pop_namespace(); diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 4d8aec48c..efbf08502 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -621,7 +621,7 @@ impl WitIn { || "from_expr", |cb| { let name = name().into(); - let wit = cb.create_witin(|| name.clone())?; + let wit = cb.create_witin(|| name.clone()); if !debug { cb.require_zero(|| name.clone(), wit.expr() - input)?; } @@ -876,7 +876,7 @@ mod tests { type E = GoldilocksExt2; let mut cs = ConstraintSystem::new(|| "test_root"); let mut cb = CircuitBuilder::::new(&mut cs); - let x = cb.create_witin(|| "x").unwrap(); + let x = cb.create_witin(|| "x"); // scaledsum * challenge // 3 * x + 2 @@ -942,9 +942,9 @@ mod tests { type E = GoldilocksExt2; let mut cs = ConstraintSystem::new(|| "test_root"); let mut cb = CircuitBuilder::::new(&mut cs); - let x = cb.create_witin(|| "x").unwrap(); - let y = cb.create_witin(|| "y").unwrap(); - let z = cb.create_witin(|| "z").unwrap(); + let x = cb.create_witin(|| "x"); + let y = cb.create_witin(|| "y"); + let z = cb.create_witin(|| "z"); // scaledsum * challenge // 3 * x + 2 let expr: Expression = @@ -984,8 +984,8 @@ mod tests { type E = GoldilocksExt2; let mut cs = ConstraintSystem::new(|| "test_root"); let mut cb = CircuitBuilder::::new(&mut cs); - let x = cb.create_witin(|| "x").unwrap(); - let y = cb.create_witin(|| "y").unwrap(); + let x = cb.create_witin(|| "x"); + let y = cb.create_witin(|| "y"); // scaledsum * challenge // (x + 1) * (y + 1) let expr: Expression = (Into::>::into(1usize) + x.expr()) diff --git a/ceno_zkvm/src/expression/monomial.rs b/ceno_zkvm/src/expression/monomial.rs index fa030595f..4c73c557b 100644 --- a/ceno_zkvm/src/expression/monomial.rs +++ b/ceno_zkvm/src/expression/monomial.rs @@ -136,11 +136,6 @@ impl PartialOrd for Expression { } } -#[allow(dead_code)] -fn cmp_field(a: &F, b: &F) -> Ordering { - a.to_canonical_u64().cmp(&b.to_canonical_u64()) -} - fn cmp_ext(a: &E, b: &E) -> Ordering { let a = a.as_bases().iter().map(|f| f.to_canonical_u64()); let b = b.as_bases().iter().map(|f| f.to_canonical_u64()); diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index f8d40cdee..a716ee8d5 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -86,7 +86,7 @@ impl IsLtConfig { || "is_lt", |cb| { let name = name_fn(); - let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?; + let is_lt = cb.create_witin(|| format!("{name} is_lt witin")); cb.assert_bit(|| "is_lt_bit", is_lt.expr())?; let config = InnerLtConfig::construct_circuit( @@ -153,7 +153,7 @@ impl InnerLtConfig { cb.namespace( || format!("var {var_name}"), |cb| { - let witin = cb.create_witin(|| var_name.to_string())?; + let witin = cb.create_witin(|| var_name.to_string()); cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr())?; Ok(witin) }, @@ -293,7 +293,7 @@ impl SignedLtConfig { || "is_signed_lt", |cb| { let name = name_fn(); - let is_lt = cb.create_witin(|| format!("{name} is_signed_lt witin"))?; + let is_lt = cb.create_witin(|| format!("{name} is_signed_lt witin")); cb.assert_bit(|| "is_lt_bit", is_lt.expr())?; let config = InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, is_lt.expr())?; diff --git a/ceno_zkvm/src/gadgets/is_zero.rs b/ceno_zkvm/src/gadgets/is_zero.rs index 02994e4f7..ff8525afb 100644 --- a/ceno_zkvm/src/gadgets/is_zero.rs +++ b/ceno_zkvm/src/gadgets/is_zero.rs @@ -26,8 +26,8 @@ impl IsZeroConfig { x: Expression, ) -> Result { cb.namespace(name_fn, |cb| { - let is_zero = cb.create_witin(|| "is_zero")?; - let inverse = cb.create_witin(|| "inv")?; + let is_zero = cb.create_witin(|| "is_zero"); + let inverse = cb.create_witin(|| "inv"); // x==0 => is_zero=1 cb.require_one(|| "is_zero_1", is_zero.expr() + x.clone() * inverse.expr())?; diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs index 23790b375..60846581e 100644 --- a/ceno_zkvm/src/gadgets/mod.rs +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -1,8 +1,11 @@ mod div; mod is_lt; mod is_zero; +mod signed_ext; + pub use div::DivConfig; pub use is_lt::{ AssertLTConfig, AssertSignedLtConfig, InnerLtConfig, IsLtConfig, SignedLtConfig, cal_lt_diff, }; pub use is_zero::{IsEqualConfig, IsZeroConfig}; +pub use signed_ext::SignedExtendConfig; diff --git a/ceno_zkvm/src/gadgets/signed_ext.rs b/ceno_zkvm/src/gadgets/signed_ext.rs new file mode 100644 index 000000000..96706dc59 --- /dev/null +++ b/ceno_zkvm/src/gadgets/signed_ext.rs @@ -0,0 +1,93 @@ +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + instructions::riscv::constants::UInt, + set_val, + witness::LkMultiplicity, +}; +use ff_ext::ExtensionField; +use std::mem::MaybeUninit; + +pub struct SignedExtendConfig { + /// most significant bit + msb: WitIn, + /// number of bits contained in the value + n_bits: usize, +} + +impl SignedExtendConfig { + pub fn construct_limb( + cb: &mut CircuitBuilder, + val: Expression, + ) -> Result { + Self::construct_circuit(cb, 16, val) + } + + pub fn construct_byte( + cb: &mut CircuitBuilder, + val: Expression, + ) -> Result { + Self::construct_circuit(cb, 8, val) + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + n_bits: usize, + val: Expression, // it's assumed that val is within [0, 2^N_BITS) + ) -> Result { + assert!(n_bits == 8 || n_bits == 16); + + let msb = cb.create_witin(|| "msb"); + // require msb is boolean + cb.assert_bit(|| "msb is boolean", msb.expr())?; + + // assert 2*val - msb*2^N_BITS is within range [0, 2^N_BITS) + // - if val < 2^(N_BITS-1), then 2*val < 2^N_BITS, msb can only be zero. + // - otherwise, 2*val >= 2^N_BITS, then msb can only be one. + let assert_ux = match n_bits { + 8 => CircuitBuilder::::assert_ux::<_, _, 8>, + 16 => CircuitBuilder::::assert_ux::<_, _, 16>, + _ => unreachable!("unsupported n_bits = {}", n_bits), + }; + assert_ux( + cb, + || "0 <= 2*val - msb*2^N_BITS < 2^N_BITS", + 2 * val - msb.expr() * (1 << n_bits), + )?; + + Ok(SignedExtendConfig { msb, n_bits }) + } + + /// Get the signed extended value + pub fn signed_extended_value(&self, val: Expression) -> UInt { + assert_eq!(UInt::::LIMB_BITS, 16); + + let limb0 = match self.n_bits { + 8 => self.msb.expr() * 0xff00 + val, + 16 => val, + _ => unreachable!("unsupported N_BITS = {}", self.n_bits), + }; + UInt::from_exprs_unchecked(vec![limb0, self.msb.expr() * 0xffff]) + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, + val: u64, + ) -> Result<(), ZKVMError> { + let msb = val >> (self.n_bits - 1); + + let assert_ux = match self.n_bits { + 8 => LkMultiplicity::assert_ux::<8>, + 16 => LkMultiplicity::assert_ux::<16>, + _ => unreachable!("unsupported n_bits = {}", self.n_bits), + }; + + assert_ux(lk_multiplicity, 2 * val - (msb << self.n_bits)); + set_val!(instance, self.msb, E::BaseField::from(msb)); + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 1d0d2146a..0b0a287e6 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -171,8 +171,6 @@ impl Instruction for ArithInstruction Instruction for AddiInstruction { mod test { use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use goldilocks::GoldilocksExt2; - use itertools::Itertools; - use multilinear_extensions::mle::IntoMLEs; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, @@ -124,18 +122,7 @@ mod test { ) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } #[test] @@ -168,17 +155,6 @@ mod test { ) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } } diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index b7c74543f..2edd1e5b6 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] // TODO: remove after BLT, BEQ, … - use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; @@ -57,7 +55,7 @@ impl BInstructionConfig { let rs2 = ReadRS2::construct_circuit(circuit_builder, rs2_read, vm_state.ts)?; // Immediate - let imm = circuit_builder.create_witin(|| "imm")?; + let imm = circuit_builder.create_witin(|| "imm"); // Fetch instruction circuit_builder.lk_fetch(&InsnRecord::new( diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 36746fff3..9847c66d5 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -1,7 +1,5 @@ use ceno_emul::{ByteAddr, Change, PC_STEP_SIZE, StepRecord, Word, encode_rv32}; use goldilocks::GoldilocksExt2; -use itertools::Itertools; -use multilinear_extensions::mle::IntoMLEs; use super::*; use crate::{ @@ -49,18 +47,7 @@ fn impl_opcode_beq(equal: bool) { ]) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } #[test] @@ -98,18 +85,7 @@ fn impl_opcode_bne(equal: bool) { ]) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } #[test] @@ -150,18 +126,7 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { ]) .unwrap(); - MockProver::assert_satisfied( - &circuit_builder, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], None, Some(lkm)); Ok(()) } @@ -202,18 +167,7 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { ]) .unwrap(); - MockProver::assert_satisfied( - &circuit_builder, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], None, Some(lkm)); Ok(()) } @@ -255,18 +209,7 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { ]) .unwrap(); - MockProver::assert_satisfied( - &circuit_builder, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], None, Some(lkm)); Ok(()) } @@ -308,17 +251,6 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { ]) .unwrap(); - MockProver::assert_satisfied( - &circuit_builder, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], None, Some(lkm)); Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index 21ba4de0c..711ed51f4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -34,7 +34,7 @@ impl Instruction for HaltInstruction { } fn construct_circuit(cb: &mut CircuitBuilder) -> Result { - let prev_x10_ts = cb.create_witin(|| "prev_x10_ts")?; + let prev_x10_ts = cb.create_witin(|| "prev_x10_ts"); let exit_code = { let exit_code = cb.query_exit_code()?; [exit_code[0].expr(), exit_code[1].expr()] diff --git a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs index 49bc1d67a..d409b288a 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs @@ -28,8 +28,8 @@ impl EcallInstructionConfig { syscall_ret_value: Option>, next_pc: Option>, ) -> Result { - let pc = cb.create_witin(|| "pc")?; - let ts = cb.create_witin(|| "cur_ts")?; + let pc = cb.create_witin(|| "pc"); + let ts = cb.create_witin(|| "cur_ts"); cb.state_in(pc.expr(), ts.expr())?; cb.state_out( @@ -47,7 +47,7 @@ impl EcallInstructionConfig { 0.into(), // imm = 0 ))?; - let prev_x5_ts = cb.create_witin(|| "prev_x5_ts")?; + let prev_x5_ts = cb.create_witin(|| "prev_x5_ts"); // read syscall_id from x5 and write return value to x5 let (_, lt_x5_cfg) = cb.register_write( diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 5548786f5..cd799c29a 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -37,14 +37,14 @@ impl StateInOut { circuit_builder: &mut CircuitBuilder, branching: bool, ) -> Result { - let pc = circuit_builder.create_witin(|| "pc")?; + let pc = circuit_builder.create_witin(|| "pc"); let (next_pc_opt, next_pc_expr) = if branching { - let next_pc = circuit_builder.create_witin(|| "next_pc")?; + let next_pc = circuit_builder.create_witin(|| "next_pc"); (Some(next_pc), next_pc.expr()) } else { (None, pc.expr() + PC_STEP_SIZE) }; - let ts = circuit_builder.create_witin(|| "ts")?; + let ts = circuit_builder.create_witin(|| "ts"); let next_ts = ts.expr() + Tracer::SUBCYCLES_PER_INSN; circuit_builder.state_in(pc.expr(), ts.expr())?; circuit_builder.state_out(next_pc_expr, next_ts)?; @@ -87,8 +87,8 @@ impl ReadRS1 { rs1_read: RegisterExpr, cur_ts: WitIn, ) -> Result { - let id = circuit_builder.create_witin(|| "rs1_id")?; - let prev_ts = circuit_builder.create_witin(|| "prev_rs1_ts")?; + let id = circuit_builder.create_witin(|| "rs1_id"); + let prev_ts = circuit_builder.create_witin(|| "prev_rs1_ts"); let (_, lt_cfg) = circuit_builder.register_read( || "read_rs1", id, @@ -142,8 +142,8 @@ impl ReadRS2 { rs2_read: RegisterExpr, cur_ts: WitIn, ) -> Result { - let id = circuit_builder.create_witin(|| "rs2_id")?; - let prev_ts = circuit_builder.create_witin(|| "prev_rs2_ts")?; + let id = circuit_builder.create_witin(|| "rs2_id"); + let prev_ts = circuit_builder.create_witin(|| "prev_rs2_ts"); let (_, lt_cfg) = circuit_builder.register_read( || "read_rs2", id, @@ -197,8 +197,8 @@ impl WriteRD { rd_written: RegisterExpr, cur_ts: WitIn, ) -> Result { - let id = circuit_builder.create_witin(|| "rd_id")?; - let prev_ts = circuit_builder.create_witin(|| "prev_rd_ts")?; + let id = circuit_builder.create_witin(|| "rd_id"); + let prev_ts = circuit_builder.create_witin(|| "prev_rd_ts"); let prev_value = UInt::new_unchecked(|| "prev_rd_value", circuit_builder)?; let (_, lt_cfg) = circuit_builder.register_write( || "write_rd", @@ -258,7 +258,7 @@ impl ReadMEM { mem_read: Expression, cur_ts: WitIn, ) -> Result { - let prev_ts = circuit_builder.create_witin(|| "prev_ts")?; + let prev_ts = circuit_builder.create_witin(|| "prev_ts"); let (_, lt_cfg) = circuit_builder.memory_read( || "read_memory", &mem_addr, @@ -313,7 +313,7 @@ impl WriteMEM { new_value: MemoryExpr, cur_ts: WitIn, ) -> Result { - let prev_ts = circuit_builder.create_witin(|| "prev_ts")?; + let prev_ts = circuit_builder.create_witin(|| "prev_ts"); let (_, lt_cfg) = circuit_builder.memory_write( || "write_memory", @@ -408,7 +408,7 @@ impl MemAddr { // Witness and constrain the non-zero low bits. let low_bits = (n_zeros..Self::N_LOW_BITS) .map(|i| { - let bit = cb.create_witin(|| format!("addr_bit_{}", i))?; + let bit = cb.create_witin(|| format!("addr_bit_{}", i)); cb.assert_bit(|| format!("addr_bit_{}", i), bit.expr())?; Ok(bit) }) diff --git a/ceno_zkvm/src/instructions/riscv/jump/auipc.rs b/ceno_zkvm/src/instructions/riscv/jump/auipc.rs index 6c979ec25..7d20430b7 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/auipc.rs @@ -36,7 +36,7 @@ impl Instruction for AuipcInstruction { fn construct_circuit( circuit_builder: &mut CircuitBuilder, ) -> Result, ZKVMError> { - let imm = circuit_builder.create_witin(|| "imm")?; + let imm = circuit_builder.create_witin(|| "imm"); let rd_written = UInt::new(|| "rd_written", circuit_builder)?; let u_insn = UInstructionConfig::construct_circuit( @@ -46,7 +46,7 @@ impl Instruction for AuipcInstruction { rd_written.register_expr(), )?; - let overflow_bit = circuit_builder.create_witin(|| "overflow_bit")?; + let overflow_bit = circuit_builder.create_witin(|| "overflow_bit"); circuit_builder.assert_bit(|| "is_bit", overflow_bit.expr())?; // assert: imm + pc = rd_written + overflow_bit * 2^32 diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index 889f3eca8..a80e11c4a 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -44,7 +44,7 @@ impl Instruction for JalrInstruction { circuit_builder: &mut CircuitBuilder, ) -> Result, ZKVMError> { let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value - let imm = circuit_builder.create_witin(|| "imm")?; // signed 12-bit value + let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value let rd_written = UInt::new(|| "rd_written", circuit_builder)?; let i_insn = IInstructionConfig::construct_circuit( @@ -63,7 +63,7 @@ impl Instruction for JalrInstruction { // 3. next_pc = next_pc_addr aligned to even value (round down) let next_pc_addr = MemAddr::::construct_unaligned(circuit_builder)?; - let overflow = circuit_builder.create_witin(|| "overflow")?; + let overflow = circuit_builder.create_witin(|| "overflow"); circuit_builder.require_equal( || "rs1+imm = next_pc_unrounded + overflow*2^32", diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index a1b17e911..887db8da3 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -1,7 +1,5 @@ use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, Word, encode_rv32}; use goldilocks::GoldilocksExt2; -use itertools::Itertools; -use multilinear_extensions::mle::IntoMLEs; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, @@ -45,18 +43,7 @@ fn test_opcode_jal() { ) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } #[test] @@ -93,18 +80,7 @@ fn test_opcode_jalr() { ) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } #[test] @@ -137,18 +113,7 @@ fn test_opcode_lui() { ) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } #[test] @@ -181,16 +146,5 @@ fn test_opcode_auipc() { ) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } diff --git a/ceno_zkvm/src/instructions/riscv/logic/test.rs b/ceno_zkvm/src/instructions/riscv/logic/test.rs index 50d73751e..b35c12bcb 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/test.rs @@ -1,7 +1,5 @@ use ceno_emul::{Change, StepRecord, Word, encode_rv32}; use goldilocks::GoldilocksExt2; -use itertools::Itertools; -use multilinear_extensions::mle::IntoMLEs; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, @@ -52,18 +50,7 @@ fn test_opcode_and() { .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } #[test] @@ -103,18 +90,7 @@ fn test_opcode_or() { .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } #[test] @@ -154,16 +130,5 @@ fn test_opcode_xor() { .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index 8972766f5..7fb061b5c 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -126,8 +126,6 @@ impl LogicConfig { mod test { use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use goldilocks::GoldilocksExt2; - use itertools::Itertools; - use multilinear_extensions::mle::IntoMLEs; use crate::{ chip_handler::test::DebugIndex, @@ -205,17 +203,6 @@ mod test { cb.require_equal(|| "assert_rd_written", rd_written_expr, expected.value()) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } } diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 0980253db..8c6a36998 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -33,7 +33,7 @@ impl MemWordChange { -> Result, ZKVMError> { (0..num_bytes) .map(|i| { - let byte = cb.create_witin(|| format!("{}.le_bytes[{}]", anno, i))?; + let byte = cb.create_witin(|| format!("{}.le_bytes[{}]", anno, i)); cb.assert_ux::<_, _, 8>(|| "byte range check", byte.expr())?; Ok(byte) @@ -84,7 +84,7 @@ impl MemWordChange { )?; // alloc a new witIn to cache degree 2 expression - let expected_limb_change = cb.create_witin(|| "expected_limb_change")?; + let expected_limb_change = cb.create_witin(|| "expected_limb_change"); cb.condition_require_equal( || "expected_limb_change = select(low_bits[0], rs2 - prev)", low_bits[0].clone(), @@ -94,7 +94,7 @@ impl MemWordChange { )?; // alloc a new witIn to cache degree 2 expression - let expected_change = cb.create_witin(|| "expected_change")?; + let expected_change = cb.create_witin(|| "expected_change"); cb.condition_require_equal( || "expected_change = select(low_bits[1], limb_change*2^16, limb_change)", low_bits[1].clone(), @@ -117,7 +117,7 @@ impl MemWordChange { let prev_limbs = prev_word.expr(); let rs2_limbs = rs2_word.expr(); - let expected_change = cb.create_witin(|| "expected_change")?; + let expected_change = cb.create_witin(|| "expected_change"); // alloc a new witIn to cache degree 2 expression cb.condition_require_equal( diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 1ebc7c55b..1bc76f311 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -1,34 +1,75 @@ +// will remove #[allow(dead_code)] when we finished fibonacci integration test use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + gadgets::SignedExtendConfig, instructions::{ Instruction, - riscv::{RIVInstruction, constants::UInt, im_insn::IMInstructionConfig}, + riscv::{ + RIVInstruction, constants::UInt, im_insn::IMInstructionConfig, insn_base::MemAddr, + }, }, + set_val, + tables::InsnRecord, witness::LkMultiplicity, }; -use ceno_emul::{InsnKind, StepRecord}; +use ceno_emul::{ByteAddr, InsnKind, StepRecord}; use ff_ext::ExtensionField; +use itertools::izip; use std::{marker::PhantomData, mem::MaybeUninit}; pub struct LoadConfig { im_insn: IMInstructionConfig, rs1_read: UInt, - imm: UInt, + imm: WitIn, + memory_addr: MemAddr, + memory_read: UInt, - memory_addr: UInt, + target_limb: Option, + target_limb_bytes: Option>, + signed_extend_config: Option, } pub struct LoadInstruction(PhantomData<(E, I)>); -pub struct LWOp; +pub struct LwOp; -impl RIVInstruction for LWOp { +impl RIVInstruction for LwOp { const INST_KIND: InsnKind = InsnKind::LW; } -pub type LwInstruction = LoadInstruction; + +pub type LwInstruction = LoadInstruction; + +pub struct LhOp; +impl RIVInstruction for LhOp { + const INST_KIND: InsnKind = InsnKind::LH; +} +#[allow(dead_code)] +pub type LhInstruction = LoadInstruction; + +pub struct LhuOp; +impl RIVInstruction for LhuOp { + const INST_KIND: InsnKind = InsnKind::LHU; +} +#[allow(dead_code)] +pub type LhuInstruction = LoadInstruction; + +pub struct LbOp; +impl RIVInstruction for LbOp { + const INST_KIND: InsnKind = InsnKind::LB; +} +#[allow(dead_code)] +pub type LbInstruction = LoadInstruction; + +pub struct LbuOp; +impl RIVInstruction for LbuOp { + const INST_KIND: InsnKind = InsnKind::LBU; +} +#[allow(dead_code)] +pub type LbuInstruction = LoadInstruction; impl Instruction for LoadInstruction { type InstructionConfig = LoadConfig; @@ -40,14 +81,94 @@ impl Instruction for LoadInstruction, ) -> Result { - let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; - let imm = UInt::new(|| "imm", circuit_builder)?; + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value + let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value let memory_read = UInt::new(|| "memory_read", circuit_builder)?; - let (memory_addr, memory_value) = match I::INST_KIND { - InsnKind::LW => ( - rs1_read.add(|| "memory_addr", circuit_builder, &imm, true)?, - memory_read.register_expr(), + let memory_addr = match I::INST_KIND { + InsnKind::LW => MemAddr::construct_align4(circuit_builder), + InsnKind::LH | InsnKind::LHU => MemAddr::construct_align2(circuit_builder), + InsnKind::LB | InsnKind::LBU => MemAddr::construct_unaligned(circuit_builder), + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }?; + + circuit_builder.require_equal( + || "memory_addr = rs1_read + imm", + memory_addr.expr_unaligned(), + rs1_read.value() + imm.expr(), + )?; + + let addr_low_bits = memory_addr.low_bit_exprs(); + let memory_value = memory_read.expr(); + + // get target limb from memory word for load instructions except LW + let target_limb = match I::INST_KIND { + InsnKind::LB | InsnKind::LBU | InsnKind::LH | InsnKind::LHU => { + let target_limb = circuit_builder.create_witin(|| "target_limb"); + circuit_builder.condition_require_equal( + || "target_limb = memory_value[low_bits[1]]", + addr_low_bits[1].clone(), + target_limb.expr(), + memory_value[1].clone(), + memory_value[0].clone(), + )?; + Some(target_limb) + } + _ => None, + }; + + // get target byte from memory word for LB and LBU + let (target_byte_expr, target_limb_bytes) = match I::INST_KIND { + InsnKind::LB | InsnKind::LBU => { + let target_byte = circuit_builder.create_u8(|| "limb.le_bytes[low_bits[0]]")?; + let dummy_byte = circuit_builder.create_u8(|| "limb.le_bytes[1-low_bits[0]]")?; + + circuit_builder.condition_require_equal( + || "target_byte = target_limb[low_bits[0]]", + addr_low_bits[0].clone(), + target_limb.unwrap().expr(), + target_byte.expr() * (1<<8) + dummy_byte.expr(), // target_byte = limb.le_bytes[1] + dummy_byte.expr() * (1<<8) + target_byte.expr(), // target_byte = limb.le_bytes[0] + )?; + + ( + Some(target_byte.expr()), + Some(vec![target_byte, dummy_byte]), + ) + } + _ => (None, None), + }; + let (signed_extend_config, rd_written) = match I::INST_KIND { + InsnKind::LW => (None, memory_read.clone()), + InsnKind::LH => { + let val = target_limb.unwrap(); + let signed_extend_config = + SignedExtendConfig::construct_limb(circuit_builder, val.expr())?; + let rd_written = signed_extend_config.signed_extended_value(val.expr()); + + (Some(signed_extend_config), rd_written) + } + InsnKind::LHU => { + ( + None, + // it's safe to unwrap as `UInt::from_exprs_unchecked` never return error + UInt::from_exprs_unchecked(vec![ + target_limb.as_ref().map(|limb| limb.expr()).unwrap(), + Expression::ZERO, + ]), + ) + } + InsnKind::LB => { + let val = target_byte_expr.unwrap(); + let signed_extend_config = + SignedExtendConfig::construct_byte(circuit_builder, val.clone())?; + let rd_written = signed_extend_config.signed_extended_value(val); + + (Some(signed_extend_config), rd_written) + } + InsnKind::LBU => ( + None, + UInt::from_exprs_unchecked(vec![target_byte_expr.unwrap(), Expression::ZERO]), ), _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; @@ -55,19 +176,22 @@ impl Instruction for LoadInstruction::construct_circuit( circuit_builder, I::INST_KIND, - &imm.value(), + &imm.expr(), rs1_read.register_expr(), memory_read.memory_expr(), - memory_addr.address_expr(), - memory_value, + memory_addr.expr_align4(), + rd_written.register_expr(), )?; Ok(LoadConfig { im_insn, rs1_read, - memory_read, imm, memory_addr, + memory_read, + target_limb, + target_limb_bytes, + signed_extend_config, }) } @@ -78,10 +202,22 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { let rs1 = Value::new_unchecked(step.rs1().unwrap().value); - let memory_read = Value::new(step.memory_op().unwrap().value.before, lk_multiplicity); - let imm = Value::new(step.insn().imm_or_funct7(), lk_multiplicity); - let memory_addr = rs1.add(&imm, lk_multiplicity, true); + let memory_value = step.memory_op().unwrap().value.before; + let memory_read = Value::new(memory_value, lk_multiplicity); + // imm is signed 12-bit value + let imm: E::BaseField = InsnRecord::imm_or_funct7_field(&step.insn()); + let unaligned_addr = ByteAddr::from( + step.rs1() + .unwrap() + .value + .wrapping_add(step.insn().imm_or_funct7()), + ); + let shift = unaligned_addr.shift(); + let addr_low_bits = [shift & 0x01, (shift >> 1) & 0x01]; + let target_limb = memory_read.as_u16_limbs()[addr_low_bits[1] as usize]; + let mut target_limb_bytes = target_limb.to_le_bytes(); + set_val!(instance, config.imm, imm); config .im_insn .assign_instance(instance, lk_multiplicity, step)?; @@ -89,8 +225,32 @@ impl Instruction for LoadInstruction(byte); + set_val!(instance, col, E::BaseField::from(byte)); + } + } + let val = match I::INST_KIND { + InsnKind::LB | InsnKind::LBU => target_limb_bytes[0] as u64, + InsnKind::LH | InsnKind::LHU => target_limb as u64, + _ => 0, + }; + if let Some(signed_ext_config) = config.signed_extend_config.as_ref() { + signed_ext_config.assign_instance::(instance, lk_multiplicity, val)?; + } Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/memory/store.rs b/ceno_zkvm/src/instructions/riscv/memory/store.rs index fc8f0455f..878777de6 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store.rs @@ -38,8 +38,7 @@ impl RIVInstruction for SWOp { const INST_KIND: InsnKind = InsnKind::SW; } -// this is actually used in test -#[allow(dead_code)] +#[cfg(test)] pub type SwInstruction = StoreInstruction; pub struct SHOp; @@ -48,8 +47,7 @@ impl RIVInstruction for SHOp { const INST_KIND: InsnKind = InsnKind::SH; } -// this is actually used in test -#[allow(dead_code)] +#[cfg(test)] pub type ShInstruction = StoreInstruction; pub struct SBOp; @@ -58,8 +56,7 @@ impl RIVInstruction for SBOp { const INST_KIND: InsnKind = InsnKind::SB; } -// this is actually used in test -#[allow(dead_code)] +#[cfg(test)] pub type SbInstruction = StoreInstruction; impl Instruction @@ -77,7 +74,7 @@ impl Instruction let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; let prev_memory_value = UInt::new(|| "prev_memory_value", circuit_builder)?; - let imm = circuit_builder.create_witin(|| "imm")?; // signed 12-bit value + let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value let memory_addr = match I::INST_KIND { InsnKind::SW => MemAddr::construct_align4(circuit_builder), diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index 61bcfd2f4..6243f197b 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -5,18 +5,20 @@ use crate::{ riscv::{ RIVInstruction, memory::{ - SbInstruction, ShInstruction, SwInstruction, + LwInstruction, SbInstruction, ShInstruction, SwInstruction, + load::{ + LbInstruction, LbOp, LbuInstruction, LbuOp, LhInstruction, LhOp, + LhuInstruction, LhuOp, LwOp, + }, store::{SBOp, SHOp, SWOp}, }, }, }, scheme::mock_prover::{MOCK_PC_START, MockProver}, }; -use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, Word, WriteOp, encode_rv32}; +use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, Word, WriteOp, encode_rv32}; use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; -use itertools::Itertools; -use multilinear_extensions::mle::IntoMLEs; use std::hash::Hash; fn sb(prev: Word, rs2: Word, shift: u32) -> Word { @@ -43,6 +45,35 @@ fn sw(_prev: Word, rs2: Word) -> Word { rs2 } +fn signed_extend(val: u32, n_bits: u32) -> u32 { + match n_bits { + 8 => (val as i8) as u32, + 16 => (val as i16) as u32, + _ => unreachable!("unsupported n_bits = {}", n_bits), + } +} + +fn load(mem_value: Word, insn: InsnKind, shift: u32) -> Word { + let val = mem_value >> (8 * shift); + match insn { + InsnKind::LB => signed_extend(val & 0xff_u32, 8), + InsnKind::LBU => val & 0xff_u32, + InsnKind::LH => { + assert_eq!(shift & 0x01, 0); + signed_extend(val & 0xffff_u32, 16) + } + InsnKind::LHU => { + assert_eq!(shift & 0x01, 0); + val & 0xffff_u32 + } + InsnKind::LW => { + assert_eq!(shift & 0x03, 0); + mem_value + } + _ => unreachable!(), + } +} + fn impl_opcode_store>(imm: u32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); @@ -88,18 +119,51 @@ fn impl_opcode_store>(imm: u32) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || Inst::name(), + |cb| { + let config = Inst::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_code = encode_rv32(I::INST_KIND, 2, 3, 0, imm); + let mem_value = 0x40302010; + let rs1_word = Word::from(0x4000000_u32); + let prev_rd_word = Word::from(0x12345678_u32); + let unaligned_addr = ByteAddr::from(rs1_word.wrapping_add(imm)); + let new_rd_word = load(mem_value, I::INST_KIND, unaligned_addr.shift()); + let rd_change = Change { + before: prev_rd_word, + after: new_rd_word, + }; + let (raw_witin, lkm) = Inst::assign_instances(&config, cb.cs.num_witin as usize, vec![ + StepRecord::new_im_instruction( + 12, + MOCK_PC_START, + insn_code, + rs1_word, + rd_change, + ReadOp { + addr: unaligned_addr.waddr(), + value: mem_value, + previous_cycle: 4, + }, + 8, + ), + ]) + .unwrap(); + + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } fn impl_opcode_sb(imm: u32) { @@ -148,3 +212,64 @@ fn test_sw() { let neg_four = u32::MAX - 3; impl_opcode_sw(neg_four); } + +#[test] +fn test_lb() { + impl_opcode_load::>(0); + impl_opcode_load::>(1); + impl_opcode_load::>(2); + impl_opcode_load::>(3); + + let neg_one = u32::MAX; + // imm = -1, -2, -3 + for i in 0..3 { + impl_opcode_load::>(neg_one - i); + } +} + +#[test] +fn test_lbu() { + impl_opcode_load::>(0); + impl_opcode_load::>(1); + impl_opcode_load::>(2); + impl_opcode_load::>(3); + + let neg_one = u32::MAX; + // imm = -1, -2, -3 + for i in 0..3 { + impl_opcode_load::>(neg_one - i); + } +} + +#[test] +fn test_lh() { + impl_opcode_load::>(0); + impl_opcode_load::>(2); + impl_opcode_load::>(4); + + let neg_two = u32::MAX - 1; + // imm = -2, -4 + for i in [0, 2] { + impl_opcode_load::>(neg_two - i); + } +} + +#[test] +fn test_lhu() { + impl_opcode_load::>(0); + impl_opcode_load::>(2); + impl_opcode_load::>(4); + + let neg_two = u32::MAX - 1; + // imm = -2, -4 + for i in [0, 2] { + impl_opcode_load::>(neg_two - i); + } +} + +#[test] +fn test_lw() { + impl_opcode_load::>(0); + impl_opcode_load::>(4); + impl_opcode_load::>(u32::MAX - 3); // imm = -4 +} diff --git a/ceno_zkvm/src/instructions/riscv/mulh.rs b/ceno_zkvm/src/instructions/riscv/mulh.rs index b852d9e2f..aa9704bdc 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh.rs @@ -119,8 +119,6 @@ impl Instruction for MulhInstruction { // Tables. pub u16_range_config: as TableCircuit>::TableConfig, + pub u14_range_config: as TableCircuit>::TableConfig, pub and_config: as TableCircuit>::TableConfig, pub ltu_config: as TableCircuit>::TableConfig, @@ -51,6 +53,7 @@ impl Rv32imConfig { // tables let u16_range_config = cs.register_table_circuit::>(); + let u14_range_config = cs.register_table_circuit::>(); let and_config = cs.register_table_circuit::>(); let ltu_config = cs.register_table_circuit::>(); @@ -70,6 +73,7 @@ impl Rv32imConfig { lui_config, lw_config, u16_range_config, + u14_range_config, and_config, ltu_config, @@ -95,6 +99,7 @@ impl Rv32imConfig { fixed.register_opcode_circuit::>(cs); fixed.register_table_circuit::>(cs, self.u16_range_config.clone(), &()); + fixed.register_table_circuit::>(cs, self.u14_range_config.clone(), &()); fixed.register_table_circuit::>(cs, self.and_config.clone(), &()); fixed.register_table_circuit::>(cs, self.ltu_config.clone(), &()); @@ -162,6 +167,7 @@ impl Rv32imConfig { public_io_final: &[MemFinalRecord], ) -> Result<(), ZKVMError> { witness.assign_table_circuit::>(cs, &self.u16_range_config, &())?; + witness.assign_table_circuit::>(cs, &self.u14_range_config, &())?; witness.assign_table_circuit::>(cs, &self.and_config, &())?; witness.assign_table_circuit::>(cs, &self.ltu_config, &())?; diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 189811dd2..f5d0f8e8d 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -31,14 +31,16 @@ pub struct ShiftConfig { pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); -#[allow(dead_code)] +#[cfg(test)] struct SllOp; +#[cfg(test)] impl RIVInstruction for SllOp { const INST_KIND: InsnKind = InsnKind::SLL; } -#[allow(dead_code)] +#[cfg(test)] struct SrlOp; +#[cfg(test)] impl RIVInstruction for SrlOp { const INST_KIND: InsnKind = InsnKind::SRL; } @@ -54,7 +56,7 @@ impl Instruction for ShiftLogicalInstru circuit_builder: &mut crate::circuit_builder::CircuitBuilder, ) -> Result { let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; - let rs2_low5 = circuit_builder.create_witin(|| "rs2_low5")?; + let rs2_low5 = circuit_builder.create_witin(|| "rs2_low5"); // pow2_rs2_low5 is unchecked because it's assignment will be constrained due it's use in lookup_pow2 below let mut pow2_rs2_low5 = UInt::new_unchecked(|| "pow2_rs2_low5", circuit_builder)?; // rs2 = rs2_high | rs2_low5 @@ -193,8 +195,6 @@ impl Instruction for ShiftLogicalInstru mod tests { use ceno_emul::{Change, InsnKind, StepRecord, encode_rv32}; use goldilocks::GoldilocksExt2; - use itertools::Itertools; - use multilinear_extensions::mle::IntoMLEs; use crate::{ Value, @@ -294,17 +294,6 @@ mod tests { ) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } } diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 166d91e21..1f43da85c 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -57,11 +57,11 @@ impl Instruction for ShiftImmInstructio circuit_builder: &mut CircuitBuilder, ) -> Result { // Note: `imm` wtns is set to 2**imm (upto 32 bit) just for efficient verification. - let imm = circuit_builder.create_witin(|| "imm")?; + let imm = circuit_builder.create_witin(|| "imm"); let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; let rd_written = UInt::new(|| "rd_written", circuit_builder)?; - let outflow = circuit_builder.create_witin(|| "outflow")?; + let outflow = circuit_builder.create_witin(|| "outflow"); let assert_lt_config = AssertLTConfig::construct_circuit( circuit_builder, || "outflow < imm", @@ -89,7 +89,7 @@ impl Instruction for ShiftImmInstructio let is_rs1_neg = IsLtConfig::construct_circuit( circuit_builder, || "lhs_msb", - max_signed_limb_expr.clone(), + max_signed_limb_expr, rs1_read.limbs.iter().last().unwrap().expr(), // msb limb 1, )?; @@ -179,8 +179,6 @@ impl Instruction for ShiftImmInstructio mod test { use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use goldilocks::GoldilocksExt2; - use itertools::Itertools; - use multilinear_extensions::mle::IntoMLEs; use super::{ShiftImmInstruction, SlliOp, SraiOp, SrliOp}; use crate::{ @@ -300,17 +298,6 @@ mod test { ) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } } diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index c70f3bdb0..830f571cd 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -17,7 +17,7 @@ pub struct SltConfig { rs1_read: UInt, rs2_read: UInt, - #[allow(dead_code)] + #[cfg_attr(not(test), allow(dead_code))] rd_written: UInt, signed_lt: SignedLtConfig, @@ -39,7 +39,7 @@ impl Instruction for SltInstruction { let rs2_read = UInt::new_unchecked(|| "rs2_read", cb)?; let lt = SignedLtConfig::construct_circuit(cb, || "rs1 < rs2", &rs1_read, &rs2_read)?; - let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()])?; + let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()]); let r_insn = RInstructionConfig::::construct_circuit( cb, @@ -90,8 +90,6 @@ mod test { use ceno_emul::{Change, StepRecord, Word, encode_rv32}; use goldilocks::GoldilocksExt2; - use itertools::Itertools; - use multilinear_extensions::mle::IntoMLEs; use rand::Rng; use super::*; @@ -137,18 +135,7 @@ mod test { .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index af31961df..71d0bb26f 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -45,13 +45,13 @@ impl Instruction for SltiInstruction { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { // If rs1_read < imm, rd_written = 1. Otherwise rd_written = 0 let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; - let imm = cb.create_witin(|| "imm")?; + let imm = cb.create_witin(|| "imm"); let max_signed_limb_expr: Expression<_> = ((1 << (UInt::::LIMB_BITS - 1)) - 1).into(); let is_rs1_neg = IsLtConfig::construct_circuit( cb, || "lhs_msb", - max_signed_limb_expr.clone(), + max_signed_limb_expr, rs1_read.limbs.iter().last().unwrap().expr(), // msb limb 1, )?; @@ -63,7 +63,7 @@ impl Instruction for SltiInstruction { imm.expr(), UINT_LIMBS, )?; - let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()])?; + let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()]); let i_insn = IInstructionConfig::::construct_circuit( cb, @@ -122,8 +122,6 @@ mod test { use ceno_emul::{Change, PC_STEP_SIZE, StepRecord, Word, encode_rv32}; use goldilocks::GoldilocksExt2; - use itertools::Itertools; - use multilinear_extensions::mle::IntoMLEs; use rand::Rng; use super::*; @@ -168,18 +166,7 @@ mod test { .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) .unwrap(); - MockProver::assert_satisfied( - &cb, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), - &[insn_code], - None, - Some(lkm), - ); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/sltu.rs b/ceno_zkvm/src/instructions/riscv/sltu.rs index 64d73869a..26f15b9bc 100644 --- a/ceno_zkvm/src/instructions/riscv/sltu.rs +++ b/ceno_zkvm/src/instructions/riscv/sltu.rs @@ -21,7 +21,7 @@ pub struct ArithConfig { rs1_read: UInt, rs2_read: UInt, - #[allow(dead_code)] + #[cfg_attr(not(test), allow(dead_code))] rd_written: UInt, is_lt: IsLtConfig, @@ -57,7 +57,7 @@ impl Instruction for ArithInstruction::construct_circuit( circuit_builder, @@ -107,8 +107,6 @@ impl Instruction for ArithInstruction::new(cs); - let config = AddInstruction::construct_circuit(&mut circuit_builder); - Ok(config) - }, + |cs| AddInstruction::construct_circuit(&mut CircuitBuilder::::new(cs)), ); let _sub_config = cs.namespace( || "sub", - |cs| { - let mut circuit_builder = CircuitBuilder::::new(cs); - let config = SubInstruction::construct_circuit(&mut circuit_builder); - Ok(config) - }, + |cs| SubInstruction::construct_circuit(&mut CircuitBuilder::::new(cs)), ); let param = Pcs::setup(1 << 10).unwrap(); let (pp, _) = Pcs::trim(¶m, 1 << 10).unwrap(); diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index cb5978754..75b4b377d 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -1,6 +1,7 @@ #![feature(box_patterns)] #![feature(stmt_expr_attributes)] #![feature(variant_count)] +#![feature(strict_overflow_ops)] pub mod error; pub mod instructions; diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index ada3e028a..74fe11006 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -8,19 +8,19 @@ use crate::{ AndTable, LtuTable, OpsTable, OrTable, PowTable, ProgramTableCircuit, RangeTable, TableCircuit, U5Table, U8Table, U14Table, U16Table, XorTable, }, - witness::LkMultiplicity, + witness::{LkMultiplicity, RowMajorMatrix}, }; use ark_std::test_rng; use base64::{Engine, engine::general_purpose::STANDARD_NO_PAD}; -use ceno_emul::{ByteAddr, CENO_PLATFORM}; +use ceno_emul::{ByteAddr, CENO_PLATFORM, PC_WORD_SIZE, Program}; use ff::Field; use ff_ext::ExtensionField; use generic_static::StaticTypeMap; use goldilocks::SmallField; use itertools::{Itertools, izip}; -use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; +use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension}; use std::{ - collections::HashSet, + collections::{BTreeMap, HashSet}, fs::File, hash::Hash, io::{BufReader, ErrorKind}, @@ -31,7 +31,7 @@ use std::{ use strum::IntoEnumIterator; const MOCK_PROGRAM_SIZE: usize = 32; -pub const MOCK_PC_START: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start()); +pub const MOCK_PC_START: ByteAddr = ByteAddr(CENO_PLATFORM.pc_base()); #[allow(clippy::enum_variant_names)] #[derive(Debug, Clone)] @@ -389,10 +389,28 @@ impl<'a, E: ExtensionField + Hash> MockProver { lkm: Option, ) -> Result<(), Vec>> { // fix the program table - let mut programs = [0u32; MOCK_PROGRAM_SIZE]; - for (i, &program) in input_programs.iter().enumerate() { - programs[i] = program; - } + let instructions = input_programs + .iter() + .cloned() + .chain(std::iter::repeat(0)) + .take(MOCK_PROGRAM_SIZE) + .collect_vec(); + let image = instructions + .iter() + .enumerate() + .map(|(insn_idx, &insn)| { + ( + CENO_PLATFORM.pc_base() + (insn_idx * PC_WORD_SIZE) as u32, + insn, + ) + }) + .collect::>(); + let program = Program::new( + CENO_PLATFORM.pc_base(), + CENO_PLATFORM.pc_base(), + instructions, + image, + ); // load tables let (challenge, mut table) = if let Some(challenge) = challenge { @@ -401,7 +419,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { load_once_tables(cb) }; let mut prog_table = vec![]; - Self::load_program_table(&mut prog_table, &programs, challenge); + Self::load_program_table(&mut prog_table, &program, challenge); for prog in prog_table { table.insert(prog); } @@ -589,11 +607,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { } } - fn load_program_table( - t_vec: &mut Vec>, - programs: &[u32; MOCK_PROGRAM_SIZE], - challenge: [E; 2], - ) { + fn load_program_table(t_vec: &mut Vec>, program: &Program, challenge: [E; 2]) { let mut cs = ConstraintSystem::::new(|| "mock_program"); let mut cb = CircuitBuilder::new(&mut cs); let config = @@ -601,7 +615,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { let fixed = ProgramTableCircuit::::generate_fixed_traces( &config, cs.num_fixed, - programs, + program, ); for table_expr in &cs.lk_table_expressions { for row in fixed.iter_rows() { @@ -672,6 +686,21 @@ Hints: } } + pub fn assert_satisfied_raw( + cb: &CircuitBuilder, + raw_witin: RowMajorMatrix, + programs: &[u32], + challenge: Option<[E; 2]>, + lkm: Option, + ) { + let wits_in = raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(); + Self::assert_satisfied(cb, &wits_in, programs, challenge, lkm); + } pub fn assert_satisfied( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], @@ -698,13 +727,15 @@ mod tests { }; use ff::Field; use goldilocks::{Goldilocks, GoldilocksExt2}; - use multilinear_extensions::mle::{IntoMLE, IntoMLEs}; + use multilinear_extensions::mle::IntoMLE; #[derive(Debug)] - #[allow(dead_code)] struct AssertZeroCircuit { + #[allow(dead_code)] pub a: WitIn, + #[allow(dead_code)] pub b: WitIn, + #[allow(dead_code)] pub c: WitIn, } @@ -712,16 +743,16 @@ mod tests { pub fn construct_circuit( cb: &mut CircuitBuilder, ) -> Result { - let a = cb.create_witin(|| "a")?; - let b = cb.create_witin(|| "b")?; - let c = cb.create_witin(|| "c")?; + let a = cb.create_witin(|| "a"); + let b = cb.create_witin(|| "b"); + let c = cb.create_witin(|| "c"); // degree 1 cb.require_equal(|| "a + 1 == b", b.expr(), a.expr() + 1)?; cb.require_zero(|| "c - 2 == 0", c.expr() - 2)?; // degree > 1 - let d = cb.create_witin(|| "d")?; + let d = cb.create_witin(|| "d"); cb.require_zero( || "d*d - 6*d + 9 == 0", d.expr() * d.expr() - d.expr() * 6 + 9, @@ -766,7 +797,7 @@ mod tests { pub fn construct_circuit( cb: &mut CircuitBuilder, ) -> Result { - let a = cb.create_witin(|| "a")?; + let a = cb.create_witin(|| "a"); cb.assert_ux::<_, _, 5>(|| "assert u5", a.expr())?; Ok(Self { a }) } @@ -832,7 +863,6 @@ mod tests { assert_eq!(err[0].inst_id(), 0); } - #[allow(dead_code)] #[derive(Debug)] struct AssertLtCircuit { pub a: WitIn, @@ -847,8 +877,8 @@ mod tests { impl AssertLtCircuit { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { - let a = cb.create_witin(|| "a")?; - let b = cb.create_witin(|| "b")?; + let a = cb.create_witin(|| "a"); + let b = cb.create_witin(|| "b"); let lt_wtns = AssertLTConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; Ok(Self { a, b, lt_wtns }) } @@ -905,14 +935,9 @@ mod tests { ) .unwrap(); - MockProver::assert_satisfied( + MockProver::assert_satisfied_raw( &builder, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), + raw_witin, &[], Some([1.into(), 1000.into()]), None, @@ -943,14 +968,9 @@ mod tests { ) .unwrap(); - MockProver::assert_satisfied( + MockProver::assert_satisfied_raw( &builder, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), + raw_witin, &[], Some([1.into(), 1000.into()]), None, @@ -971,8 +991,8 @@ mod tests { impl LtCircuit { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { - let a = cb.create_witin(|| "a")?; - let b = cb.create_witin(|| "b")?; + let a = cb.create_witin(|| "a"); + let b = cb.create_witin(|| "b"); let lt_wtns = IsLtConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; Ok(Self { a, b, lt_wtns }) } @@ -1029,14 +1049,9 @@ mod tests { ) .unwrap(); - MockProver::assert_satisfied( + MockProver::assert_satisfied_raw( &builder, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), + raw_witin, &[], Some([1.into(), 1000.into()]), None, @@ -1068,14 +1083,9 @@ mod tests { ) .unwrap(); - MockProver::assert_satisfied( + MockProver::assert_satisfied_raw( &builder, - &raw_witin - .de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(), + raw_witin, &[], Some([1.into(), 1000.into()]), None, diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index b64e34d99..825f93724 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -1,9 +1,9 @@ use std::{marker::PhantomData, mem::MaybeUninit}; use ceno_emul::{ - ByteAddr, CENO_PLATFORM, + CENO_PLATFORM, InsnKind::{ADD, EANY}, - StepRecord, VMState, + PC_WORD_SIZE, Program, StepRecord, VMState, }; use ff::Field; use ff_ext::ExtensionField; @@ -50,7 +50,7 @@ impl Instruction for Test } fn construct_circuit(cb: &mut CircuitBuilder) -> Result { - let reg_id = cb.create_witin(|| "reg_id")?; + let reg_id = cb.create_witin(|| "reg_id"); (0..RW).try_for_each(|_| { let record = cb.rlc_chip_record(vec![ Expression::::Constant(E::BaseField::ONE), @@ -202,6 +202,23 @@ fn test_single_add_instance_e2e() { type E = GoldilocksExt2; type Pcs = Basefold; + // set up program + let program = Program::new( + CENO_PLATFORM.pc_base(), + CENO_PLATFORM.pc_base(), + PROGRAM_CODE.to_vec(), + PROGRAM_CODE + .iter() + .enumerate() + .map(|(insn_idx, &insn)| { + ( + (insn_idx * PC_WORD_SIZE) as u32 + CENO_PLATFORM.pc_base(), + insn, + ) + }) + .collect(), + ); + let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let mut zkvm_cs = ZKVMConstraintSystem::default(); @@ -225,7 +242,7 @@ fn test_single_add_instance_e2e() { zkvm_fixed_traces.register_table_circuit::>( &zkvm_cs, prog_config.clone(), - &PROGRAM_CODE, + &program, ); let pk = zkvm_cs @@ -235,11 +252,7 @@ fn test_single_add_instance_e2e() { let vk = pk.get_vk(); // single instance - let mut vm = VMState::new(CENO_PLATFORM); - let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr(); - for (i, insn) in PROGRAM_CODE.iter().enumerate() { - vm.init_memory(pc_start + i, *insn); - } + let mut vm = VMState::new(CENO_PLATFORM, program.clone()); let all_records = vm .iter_until_halt() .collect::, _>>() @@ -282,7 +295,7 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>( &zkvm_cs, &prog_config, - &PROGRAM_CODE.len(), + &program, ) .unwrap(); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 6944eb856..81ebfcf78 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -349,7 +349,7 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( ) } -#[allow(dead_code)] +#[cfg(test)] pub(crate) fn eval_by_expr( witnesses: &[E], challenges: &[E], @@ -358,7 +358,7 @@ pub(crate) fn eval_by_expr( eval_by_expr_with_fixed(&[], witnesses, challenges, expr) } -#[allow(dead_code)] +#[cfg(test)] pub(crate) fn eval_by_expr_with_fixed( fixed: &[E], witnesses: &[E], @@ -670,9 +670,9 @@ mod tests { type B = goldilocks::Goldilocks; let mut cs = ConstraintSystem::::new(|| "test"); let mut cb = CircuitBuilder::new(&mut cs); - let a = cb.create_witin(|| "a").unwrap(); - let b = cb.create_witin(|| "b").unwrap(); - let c = cb.create_witin(|| "c").unwrap(); + let a = cb.create_witin(|| "a"); + let b = cb.create_witin(|| "b"); + let c = cb.create_witin(|| "c"); let expr: Expression = a.expr() + b.expr() + a.expr() * b.expr() + (c.expr() * 3 + 2); @@ -696,9 +696,9 @@ mod tests { type B = goldilocks::Goldilocks; let mut cs = ConstraintSystem::::new(|| "test"); let mut cb = CircuitBuilder::new(&mut cs); - let a = cb.create_witin(|| "a").unwrap(); - let b = cb.create_witin(|| "b").unwrap(); - let c = cb.create_witin(|| "c").unwrap(); + let a = cb.create_witin(|| "a"); + let b = cb.create_witin(|| "b"); + let c = cb.create_witin(|| "c"); let expr: Expression = a.expr() + b.expr() diff --git a/ceno_zkvm/src/tables/ops/ops_impl.rs b/ceno_zkvm/src/tables/ops/ops_impl.rs index f9bf4d0c8..04939a596 100644 --- a/ceno_zkvm/src/tables/ops/ops_impl.rs +++ b/ceno_zkvm/src/tables/ops/ops_impl.rs @@ -32,7 +32,7 @@ impl OpTableConfig { cb.create_fixed(|| "b")?, cb.create_fixed(|| "c")?, ]; - let mlt = cb.create_witin(|| "mlt")?; + let mlt = cb.create_witin(|| "mlt"); let rlc_record = cb.rlc_chip_record(vec![ (rom_type as usize).into(), diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 3514365c8..d5b629957 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -10,7 +10,7 @@ use crate::{ tables::TableCircuit, witness::RowMajorMatrix, }; -use ceno_emul::{CENO_PLATFORM, DecodedInstruction, PC_STEP_SIZE, WORD_SIZE}; +use ceno_emul::{DecodedInstruction, PC_STEP_SIZE, Program, WORD_SIZE}; use ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::Itertools; @@ -117,8 +117,8 @@ impl TableCircuit for ProgramTableCircuit { type TableConfig = ProgramTableConfig; - type FixedInput = [u32; PROGRAM_SIZE]; - type WitnessInput = usize; + type FixedInput = Program; + type WitnessInput = Program; fn name() -> String { "PROGRAM".into() @@ -135,7 +135,7 @@ impl TableCircuit cb.create_fixed(|| "imm_or_funct7")?, ]); - let mlt = cb.create_witin(|| "mlt")?; + let mlt = cb.create_witin(|| "mlt"); let record_exprs = { let mut fields = vec![E::BaseField::from(ROMType::Instruction as u64).expr()]; @@ -153,9 +153,8 @@ impl TableCircuit num_fixed: usize, program: &Self::FixedInput, ) -> RowMajorMatrix { - // TODO: get bytecode of the program. - let num_instructions = program.len(); - let pc_start = CENO_PLATFORM.pc_start(); + let num_instructions = program.instructions.len(); + let pc_base = program.base_address; let mut fixed = RowMajorMatrix::::new(num_instructions, num_fixed); @@ -164,8 +163,8 @@ impl TableCircuit .with_min_len(MIN_PAR_SIZE) .zip((0..num_instructions).into_par_iter()) .for_each(|(row, i)| { - let pc = pc_start + (i * PC_STEP_SIZE) as u32; - let insn = DecodedInstruction::new(program[i]); + let pc = pc_base + (i * PC_STEP_SIZE) as u32; + let insn = DecodedInstruction::new(program.instructions[i]); let values = InsnRecord::from_decoded(pc, &insn); // Copy all the fields except immediate. @@ -192,13 +191,13 @@ impl TableCircuit config: &Self::TableConfig, num_witin: usize, multiplicity: &[HashMap], - num_instructions: &usize, + program: &Program, ) -> Result, ZKVMError> { let multiplicity = &multiplicity[ROMType::Instruction as usize]; - let mut prog_mlt = vec![0_usize; *num_instructions]; + let mut prog_mlt = vec![0_usize; program.instructions.len()]; for (pc, mlt) in multiplicity { - let i = (*pc as usize - CENO_PLATFORM.pc_start() as usize) / WORD_SIZE; + let i = (*pc as usize - program.base_address as usize) / WORD_SIZE; prog_mlt[i] = *mlt; } diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 6b11fec59..39b5f043a 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -41,12 +41,12 @@ impl NonVolatileTableConfig, ZKVMError>>()?; let addr = cb.create_fixed(|| "addr")?; - let final_cycle = cb.create_witin(|| "final_cycle")?; + let final_cycle = cb.create_witin(|| "final_cycle"); let final_v = if NVRAM::WRITABLE { Some( (0..NVRAM::V_LIMBS) .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) - .collect::, ZKVMError>>()?, + .collect::>(), ) } else { None diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index 8a14fe236..7e6bb62f6 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -28,7 +28,7 @@ impl RangeTableConfig { table_len: usize, ) -> Result { let fixed = cb.create_fixed(|| "fixed")?; - let mlt = cb.create_witin(|| "mlt")?; + let mlt = cb.create_witin(|| "mlt"); let rlc_record = cb.rlc_chip_record(vec![(rom_type as usize).into(), Expression::Fixed(fixed)]); diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index f243e3769..63f3dc7ac 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -90,7 +90,7 @@ impl UIntLimbs { limbs: UintLimb::WitIn( (0..Self::NUM_LIMBS) .map(|i| { - let w = cb.create_witin(|| format!("limb_{i}"))?; + let w = cb.create_witin(|| format!("limb_{i}")); if is_check { cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w.expr())?; } @@ -162,7 +162,7 @@ impl UIntLimbs { assert_eq!(expr_limbs.len(), Self::NUM_LIMBS); let limbs = (0..Self::NUM_LIMBS) .map(|i| { - let w = circuit_builder.create_witin(|| "wit for limb").unwrap(); + let w = circuit_builder.create_witin(|| "wit for limb"); circuit_builder .assert_ux::<_, _, C>(|| "range check", w.expr()) .unwrap(); @@ -302,7 +302,7 @@ impl UIntLimbs { .unwrap() }) .collect_vec(); - UIntLimbs::::from_exprs_unchecked(combined_limbs) + Ok(UIntLimbs::::from_exprs_unchecked(combined_limbs)) } pub fn to_u8_limbs( @@ -324,7 +324,7 @@ impl UIntLimbs { .flat_map(|large_limb| { let limbs = (0..k) .map(|_| { - let w = circuit_builder.create_witin(|| "").unwrap(); + let w = circuit_builder.create_witin(|| ""); circuit_builder.assert_byte(|| "", w.expr()).unwrap(); w.expr() }) @@ -345,8 +345,8 @@ impl UIntLimbs { UIntLimbs::::create_witin_from_exprs(circuit_builder, split_limbs) } - pub fn from_exprs_unchecked(expr_limbs: Vec>) -> Result { - let n = Self { + pub fn from_exprs_unchecked(expr_limbs: Vec>) -> Self { + Self { limbs: UintLimb::Expression( expr_limbs .into_iter() @@ -356,8 +356,7 @@ impl UIntLimbs { ), carries: None, carries_auxiliary_lt_config: None, - }; - Ok(n) + } } /// If current limbs are Expression, this function will create witIn and replace the limbs @@ -371,7 +370,7 @@ impl UIntLimbs { self.limbs = UintLimb::WitIn( (0..Self::NUM_LIMBS) .map(|i| { - let w = cb.create_witin(|| format!("limb_{i}"))?; + let w = cb.create_witin(|| format!("limb_{i}")); cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w.expr())?; Ok(w) }) @@ -401,7 +400,7 @@ impl UIntLimbs { self.carries = Some( (0..carries_len) .map(|i| { - let c = cb.create_witin(|| format!("carry_{i}"))?; + let c = cb.create_witin(|| format!("carry_{i}")); Ok(c) }) .collect::, ZKVMError>>()?, @@ -523,8 +522,8 @@ impl UIntLimbs { let mut self_lo = self.expr(); let self_hi = self_lo.split_off(self_lo.len() / 2); Ok(( - UIntLimbs::from_exprs_unchecked(self_lo)?, - UIntLimbs::from_exprs_unchecked(self_hi)?, + UIntLimbs::from_exprs_unchecked(self_lo), + UIntLimbs::from_exprs_unchecked(self_hi), )) } @@ -638,7 +637,6 @@ impl ValueMul { #[derive(Clone)] pub struct Value<'a, T: Into + From + Copy + Default> { - #[allow(dead_code)] val: T, pub limbs: Cow<'a, [u16]>, } diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 62754aa0d..01413f6a3 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -119,7 +119,7 @@ impl UIntLimbs { }; // with high limb, overall cell will be double let c_limbs: Vec = (0..num_limbs).try_fold(vec![], |mut c_limbs, i| { - let limb = circuit_builder.create_witin(|| format!("limb_{i}"))?; + let limb = circuit_builder.create_witin(|| format!("limb_{i}")); circuit_builder.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb.expr())?; c_limbs.push(limb); Result::, ZKVMError>::Ok(c_limbs) @@ -127,7 +127,7 @@ impl UIntLimbs { let c_carries: Vec = (0..num_limbs).try_fold(vec![], |mut c_carries, i| { // skip last carry if with_overflow == false if i != num_limbs - 1 || with_overflow { - let carry = circuit_builder.create_witin(|| format!("carry_{i}"))?; + let carry = circuit_builder.create_witin(|| format!("carry_{i}")); c_carries.push(carry); } Result::, ZKVMError>::Ok(c_carries) @@ -241,7 +241,7 @@ impl UIntLimbs { mul_hi } else { // lo limb - UIntLimbs::from_exprs_unchecked(mul.expr())? + UIntLimbs::from_exprs_unchecked(mul.expr()) }; let add = cb.namespace( || "add", @@ -302,7 +302,7 @@ impl UIntLimbs { where E: ExtensionField, { - let high_limb_no_msb = circuit_builder.create_witin(|| "high_limb_mask")?; + let high_limb_no_msb = circuit_builder.create_witin(|| "high_limb_mask"); let high_limb = self.limbs[Self::NUM_LIMBS - 1].expr(); circuit_builder.lookup_and_byte( @@ -329,7 +329,7 @@ impl UIntLimbs { let n_bytes = Self::NUM_LIMBS; let indexes: Vec = (0..n_bytes) .map(|_| circuit_builder.create_witin(|| "index")) - .collect::>()?; + .collect(); // indicate the first non-zero byte index i_0 of a[i] - b[i] // from high to low @@ -342,7 +342,7 @@ impl UIntLimbs { // circuit_builder.assert_bit(|| "bit assert", index_sum)?; // equal zero if a==b, otherwise equal (a[i_0]-b[i_0])^{-1} - let byte_diff_inv = circuit_builder.create_witin(|| "byte_diff_inverse")?; + let byte_diff_inv = circuit_builder.create_witin(|| "byte_diff_inverse"); // define accumulated index sum from high to low let si_expr: Vec> = indexes @@ -403,7 +403,7 @@ impl UIntLimbs { - index_ne.expr(), )?; - let is_ltu = circuit_builder.create_witin(|| "is_ltu")?; + let is_ltu = circuit_builder.create_witin(|| "is_ltu"); // now we know the first non-equal byte pairs is (lhs_ne_byte, rhs_ne_byte) circuit_builder.lookup_ltu_byte(lhs_ne_byte.expr(), rhs_ne_byte.expr(), is_ltu.expr())?; Ok(UIntLtuConfig { @@ -421,7 +421,7 @@ impl UIntLimbs { circuit_builder: &mut CircuitBuilder, rhs: &UIntLimbs, ) -> Result { - let is_lt = circuit_builder.create_witin(|| "is_lt")?; + let is_lt = circuit_builder.create_witin(|| "is_lt"); // circuit_builder.assert_bit(|| "assert_bit", is_lt.expr())?; let lhs_msb = self.msb_decompose(circuit_builder)?; diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index e8a5553d6..e9308c667 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -5,13 +5,6 @@ use itertools::Itertools; use multilinear_extensions::util::max_usable_threads; use transcript::Transcript; -/// convert ext field element to u64, assume it is inside the range -#[allow(dead_code)] -pub fn ext_to_u64(x: &E) -> u64 { - let bases = x.as_bases(); - bases[0].to_canonical_u64() -} - pub fn i64_to_base(x: i64) -> F { if x >= 0 { F::from(x as u64) @@ -20,20 +13,6 @@ pub fn i64_to_base(x: i64) -> F { } } -/// This is helper function to convert witness of u8 limb into u16 limb -/// TODO: need a better way to keep consistency of LIMB_BITS -#[allow(dead_code)] -pub fn limb_u8_to_u16(input: &[u8]) -> Vec { - input - .chunks(2) - .map(|chunk| { - let low = chunk[0] as u16; - let high = if chunk.len() > 1 { chunk[1] as u16 } else { 0 }; - high * 256 + low - }) - .collect() -} - pub fn split_to_u8>(value: u32) -> Vec { (0..(u32::BITS / 8)) .scan(value, |acc, _| { @@ -72,15 +51,6 @@ pub(crate) fn add_one_to_big_num(limb_modulo: F, limbs: &[F]) -> Vec(x: i64) -> E::BaseField { - if x >= 0 { - E::BaseField::from(x as u64) - } else { - -E::BaseField::from((-x) as u64) - } -} - /// derive challenge from transcript and return all pows result pub fn get_challenge_pows( size: usize, diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index 4019f2d22..9aaafdf4f 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -192,8 +192,8 @@ mod tests { fn test_add_mle_list_by_expr() { let mut cs = ConstraintSystem::new(|| "test_root"); let mut cb = CircuitBuilder::::new(&mut cs); - let x = cb.create_witin(|| "x").unwrap(); - let y = cb.create_witin(|| "y").unwrap(); + let x = cb.create_witin(|| "x"); + let y = cb.create_witin(|| "y"); let wits_in: Vec> = (0..cs.num_witin as usize) .map(|_| vec![Goldilocks::from(1)].into_mle().into()) diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index bdffa2ec6..41306a8ed 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -133,7 +133,6 @@ pub struct LkMultiplicity { multiplicity: Arc; mem::variant_count::()]>>>, } -#[allow(dead_code)] impl LkMultiplicity { /// assert within range #[inline(always)] diff --git a/gkr-graph/src/structs.rs b/gkr-graph/src/structs.rs index 00bb4bfb3..4e206ee2a 100644 --- a/gkr-graph/src/structs.rs +++ b/gkr-graph/src/structs.rs @@ -45,7 +45,7 @@ pub enum PredType { #[derive(Clone, Debug)] pub struct CircuitNode { pub(crate) id: usize, - // TODO(Matthias): See whether we can remove this field. + // Note: only for debug output. #[allow(dead_code)] pub(crate) label: &'static str, pub(crate) circuit: Arc>, diff --git a/gkr/src/circuit.rs b/gkr/src/circuit.rs index 5e8c1ec93..d5459f328 100644 --- a/gkr/src/circuit.rs +++ b/gkr/src/circuit.rs @@ -42,14 +42,6 @@ where in_eq_vec: &[E], challenges: &HashMap>, ) -> E; - // TODO(Matthias, by 2024-11-01): review whether we need this function after all. - #[allow(dead_code)] - fn fix_out_variables( - &self, - in_size: usize, - out_eq_vec: &[E], - challenges: &HashMap>, - ) -> Vec; } impl EvaluateGate1In for &[Gate1In>] @@ -68,18 +60,6 @@ where * gate.scalar.eval(challenges) }) } - fn fix_out_variables( - &self, - in_size: usize, - out_eq_vec: &[E], - challenges: &HashMap>, - ) -> Vec { - let mut ans = vec![E::ZERO; in_size]; - for gate in self.iter() { - ans[gate.idx_in[0]] += out_eq_vec[gate.idx_out] * gate.scalar.eval(challenges); - } - ans - } } pub trait EvaluateGate2In diff --git a/gkr/src/circuit/circuit_witness.rs b/gkr/src/circuit/circuit_witness.rs index 1089e8d80..dd88319aa 100644 --- a/gkr/src/circuit/circuit_witness.rs +++ b/gkr/src/circuit/circuit_witness.rs @@ -487,603 +487,3 @@ impl<'a, F: ExtensionField> Debug for CircuitWitness<'a, F> { writeln!(f, "}}") } } - -// #[cfg(test)] -// mod test { -// use std::{collections::HashMap, ops::Neg}; - -// use ff::Field; -// use ff_ext::ExtensionField; -// use goldilocks::GoldilocksExt2; -// use itertools::Itertools; -// use simple_frontend::structs::{ChallengeConst, ChallengeId, CircuitBuilder, ConstantType}; - -// use crate::{ -// structs::{Circuit, CircuitWitness, LayerWitness}, -// utils::i64_to_field, -// }; - -// fn copy_and_paste_circuit() -> Circuit { -// let mut circuit_builder = CircuitBuilder::::new(); -// // Layer 3 -// let (_, input) = circuit_builder.create_witness_in(4); - -// // Layer 2 -// let mul_01 = circuit_builder.create_cell(); -// circuit_builder.mul2(mul_01, input[0], input[1], Ext::BaseField::ONE); - -// // Layer 1 -// let mul_012 = circuit_builder.create_cell(); -// circuit_builder.mul2(mul_012, mul_01, input[2], Ext::BaseField::ONE); - -// // Layer 0 -// let (_, mul_001123) = circuit_builder.create_witness_out(1); -// circuit_builder.mul3( -// mul_001123[0], -// mul_01, -// mul_012, -// input[3], -// Ext::BaseField::ONE, -// ); - -// circuit_builder.configure(); -// let circuit = Circuit::new(&circuit_builder); - -// circuit -// } - -// fn copy_and_paste_witness() -> ( -// Vec>, -// CircuitWitness, -// ) { -// // witness_in, single instance -// let inputs = vec![vec![ -// i64_to_field(5), -// i64_to_field(7), -// i64_to_field(11), -// i64_to_field(13), -// ]]; -// let witness_in = vec![LayerWitness { instances: inputs }]; - -// let layers = vec![ -// LayerWitness { -// instances: vec![vec![i64_to_field(175175)]], -// }, -// LayerWitness { -// instances: vec![vec![ -// i64_to_field(385), -// i64_to_field(35), -// i64_to_field(13), -// i64_to_field(0), // pad -// ]], -// }, -// LayerWitness { -// instances: vec![vec![i64_to_field(35), i64_to_field(11)]], -// }, -// LayerWitness { -// instances: vec![vec![ -// i64_to_field(5), -// i64_to_field(7), -// i64_to_field(11), -// i64_to_field(13), -// ]], -// }, -// ]; - -// let outputs = vec![vec![i64_to_field(175175)]]; -// let witness_out = vec![LayerWitness { instances: outputs }]; - -// ( -// witness_in.clone(), -// CircuitWitness { -// layers, -// witness_in, -// witness_out, -// n_instances: 1, -// challenges: HashMap::new(), -// }, -// ) -// } - -// fn paste_from_wit_in_circuit() -> Circuit { -// let mut circuit_builder = CircuitBuilder::::new(); - -// // Layer 2 -// let (_leaf_id1, leaves1) = circuit_builder.create_witness_in(3); -// let (_leaf_id2, leaves2) = circuit_builder.create_witness_in(3); -// // Unused input elements should also be in the circuit. -// let (_dummy_id, _) = circuit_builder.create_witness_in(3); -// let _ = circuit_builder.create_counter_in(1); -// let _ = circuit_builder.create_constant_in(2, 1); - -// // Layer 1 -// let (_, inners) = circuit_builder.create_witness_out(2); -// circuit_builder.mul2(inners[0], leaves1[0], leaves1[1], Ext::BaseField::ONE); -// circuit_builder.mul2(inners[1], leaves1[2], leaves2[0], Ext::BaseField::ONE); - -// // Layer 0 -// let (_, root) = circuit_builder.create_witness_out(1); -// circuit_builder.mul2(root[0], inners[0], inners[1], Ext::BaseField::ONE); - -// circuit_builder.configure(); -// let circuit = Circuit::new(&circuit_builder); -// circuit -// } - -// fn paste_from_wit_in_witness() -> ( -// Vec>, -// CircuitWitness, -// ) { -// // witness_in, single instance -// let leaves1 = vec![vec![i64_to_field(5), i64_to_field(7), i64_to_field(11)]]; -// let leaves2 = vec![vec![i64_to_field(13), i64_to_field(17), i64_to_field(19)]]; -// let dummy = vec![vec![i64_to_field(13), i64_to_field(17), i64_to_field(19)]]; -// let witness_in = vec![ -// LayerWitness { instances: leaves1 }, -// LayerWitness { instances: leaves2 }, -// LayerWitness { instances: dummy }, -// ]; - -// let layers = vec![ -// LayerWitness { -// instances: vec![vec![ -// i64_to_field(5005), -// i64_to_field(35), -// i64_to_field(143), -// i64_to_field(0), // pad -// ]], -// }, -// LayerWitness { -// instances: vec![vec![i64_to_field(35), i64_to_field(143)]], -// }, -// LayerWitness { -// instances: vec![vec![ -// i64_to_field(5), // leaves1 -// i64_to_field(7), -// i64_to_field(11), -// i64_to_field(13), // leaves2 -// i64_to_field(17), -// i64_to_field(19), -// i64_to_field(13), // dummy -// i64_to_field(17), -// i64_to_field(19), -// i64_to_field(0), // counter -// i64_to_field(1), -// i64_to_field(1), // constant -// i64_to_field(1), -// i64_to_field(0), // pad -// i64_to_field(0), -// i64_to_field(0), -// ]], -// }, -// ]; - -// let outputs1 = vec![vec![i64_to_field(35), i64_to_field(143)]]; -// let outputs2 = vec![vec![i64_to_field(5005)]]; -// let witness_out = vec![ -// LayerWitness { -// instances: outputs1, -// }, -// LayerWitness { -// instances: outputs2, -// }, -// ]; - -// ( -// witness_in.clone(), -// CircuitWitness { -// layers, -// witness_in, -// witness_out, -// n_instances: 1, -// challenges: HashMap::new(), -// }, -// ) -// } - -// fn copy_to_wit_out_circuit() -> Circuit { -// let mut circuit_builder = CircuitBuilder::::new(); -// // Layer 2 -// let (_, leaves) = circuit_builder.create_witness_in(4); - -// // Layer 1 -// let (_inner_id, inners) = circuit_builder.create_witness_out(2); -// circuit_builder.mul2(inners[0], leaves[0], leaves[1], Ext::BaseField::ONE); -// circuit_builder.mul2(inners[1], leaves[2], leaves[3], Ext::BaseField::ONE); - -// // Layer 0 -// let root = circuit_builder.create_cell(); -// circuit_builder.mul2(root, inners[0], inners[1], Ext::BaseField::ONE); -// circuit_builder.assert_const(root, 5005); - -// circuit_builder.configure(); -// let circuit = Circuit::new(&circuit_builder); - -// circuit -// } - -// fn copy_to_wit_out_witness() -> ( -// Vec>, -// CircuitWitness, -// ) { -// // witness_in, single instance -// let leaves = vec![vec![ -// i64_to_field(5), -// i64_to_field(7), -// i64_to_field(11), -// i64_to_field(13), -// ]]; -// let witness_in = vec![LayerWitness { instances: leaves }]; - -// let layers = vec![ -// LayerWitness { -// instances: vec![vec![ -// i64_to_field(5005), -// i64_to_field(35), -// i64_to_field(143), -// i64_to_field(0), // pad -// ]], -// }, -// LayerWitness { -// instances: vec![vec![i64_to_field(35), i64_to_field(143)]], -// }, -// LayerWitness { -// instances: vec![vec![ -// i64_to_field(5), -// i64_to_field(7), -// i64_to_field(11), -// i64_to_field(13), -// ]], -// }, -// ]; - -// let outputs = vec![vec![i64_to_field(35), i64_to_field(143)]]; -// let witness_out = vec![LayerWitness { instances: outputs }]; - -// ( -// witness_in.clone(), -// CircuitWitness { -// layers, -// witness_in, -// witness_out, -// n_instances: 1, -// challenges: HashMap::new(), -// }, -// ) -// } - -// fn copy_to_wit_out_witness_2() -> ( -// Vec>, -// CircuitWitness, -// ) { -// // witness_in, 2 instances -// let leaves = vec![ -// vec![ -// i64_to_field(5), -// i64_to_field(7), -// i64_to_field(11), -// i64_to_field(13), -// ], -// vec![ -// i64_to_field(5), -// i64_to_field(13), -// i64_to_field(11), -// i64_to_field(7), -// ], -// ]; -// let witness_in = vec![LayerWitness { instances: leaves }]; - -// let layers = vec![ -// LayerWitness { -// instances: vec![ -// vec![ -// i64_to_field(5005), -// i64_to_field(35), -// i64_to_field(143), -// i64_to_field(0), // pad -// ], -// vec![ -// i64_to_field(5005), -// i64_to_field(65), -// i64_to_field(77), -// i64_to_field(0), // pad -// ], -// ], -// }, -// LayerWitness { -// instances: vec![ -// vec![i64_to_field(35), i64_to_field(143)], -// vec![i64_to_field(65), i64_to_field(77)], -// ], -// }, -// LayerWitness { -// instances: vec![ -// vec![ -// i64_to_field(5), -// i64_to_field(7), -// i64_to_field(11), -// i64_to_field(13), -// ], -// vec![ -// i64_to_field(5), -// i64_to_field(13), -// i64_to_field(11), -// i64_to_field(7), -// ], -// ], -// }, -// ]; - -// let outputs = vec![ -// vec![i64_to_field(35), i64_to_field(143)], -// vec![i64_to_field(65), i64_to_field(77)], -// ]; -// let witness_out = vec![LayerWitness { instances: outputs }]; - -// ( -// witness_in.clone(), -// CircuitWitness { -// layers, -// witness_in, -// witness_out, -// n_instances: 2, -// challenges: HashMap::new(), -// }, -// ) -// } - -// fn rlc_circuit() -> Circuit { -// let mut circuit_builder = CircuitBuilder::::new(); -// // Layer 2 -// let (_, leaves) = circuit_builder.create_witness_in(4); - -// // Layer 1 -// let inners = circuit_builder.create_ext_cells(2); -// circuit_builder.rlc(&inners[0], &[leaves[0], leaves[1]], 0 as ChallengeId); -// circuit_builder.rlc(&inners[1], &[leaves[2], leaves[3]], 1 as ChallengeId); - -// // Layer 0 -// let (_root_id, roots) = circuit_builder.create_ext_witness_out(1); -// circuit_builder.mul2_ext(&roots[0], &inners[0], &inners[1], Ext::BaseField::ONE); - -// circuit_builder.configure(); -// let circuit = Circuit::new(&circuit_builder); - -// circuit -// } - -// fn rlc_witness_2() -> ( -// Vec>, -// CircuitWitness, -// Vec, -// ) -// where -// Ext: ExtensionField, -// { -// let challenges = vec![ -// Ext::from_bases(&[i64_to_field(31), i64_to_field(37)]), -// Ext::from_bases(&[i64_to_field(97), i64_to_field(23)]), -// ]; -// let challenge_pows = challenges -// .iter() -// .enumerate() -// .map(|(i, x)| { -// (0..3) -// .map(|j| { -// ( -// ChallengeConst { -// challenge: i as u8, -// exp: j as u64, -// }, -// x.pow(&[j as u64]), -// ) -// }) -// .collect_vec() -// }) -// .collect_vec(); - -// // witness_in, double instances -// let leaves = vec![ -// vec![ -// i64_to_field(5), -// i64_to_field(7), -// i64_to_field(11), -// i64_to_field(13), -// ], -// vec![ -// i64_to_field(5), -// i64_to_field(13), -// i64_to_field(11), -// i64_to_field(7), -// ], -// ]; -// let witness_in = vec![LayerWitness { -// instances: leaves.clone(), -// }]; - -// let inner00: Ext = challenge_pows[0][0].1 * (&leaves[0][0]) -// + challenge_pows[0][1].1 * (&leaves[0][1]) -// + challenge_pows[0][2].1; -// let inner01: Ext = challenge_pows[1][0].1 * (&leaves[0][2]) -// + challenge_pows[1][1].1 * (&leaves[0][3]) -// + challenge_pows[1][2].1; -// let inner10: Ext = challenge_pows[0][0].1 * (&leaves[1][0]) -// + challenge_pows[0][1].1 * (&leaves[1][1]) -// + challenge_pows[0][2].1; -// let inner11: Ext = challenge_pows[1][0].1 * (&leaves[1][2]) -// + challenge_pows[1][1].1 * (&leaves[1][3]) -// + challenge_pows[1][2].1; - -// let inners = vec![ -// [ -// inner00.clone().as_bases().to_vec(), -// inner01.clone().as_bases().to_vec(), -// ] -// .concat(), -// [ -// inner10.clone().as_bases().to_vec(), -// inner11.clone().as_bases().to_vec(), -// ] -// .concat(), -// ]; - -// let root_tmp0 = vec![ -// inners[0][0] * inners[0][2], -// inners[0][0] * inners[0][3], -// inners[0][1] * inners[0][2], -// inners[0][1] * inners[0][3], -// ]; -// let root_tmp1 = vec![ -// inners[1][0] * inners[1][2], -// inners[1][0] * inners[1][3], -// inners[1][1] * inners[1][2], -// inners[1][1] * inners[1][3], -// ]; -// let root_tmps = vec![root_tmp0, root_tmp1]; - -// let root0 = inner00 * inner01; -// let root1 = inner10 * inner11; -// let roots = vec![root0.as_bases().to_vec(), root1.as_bases().to_vec()]; - -// let layers = vec![ -// LayerWitness { -// instances: roots.clone(), -// }, -// LayerWitness { -// instances: root_tmps, -// }, -// LayerWitness { instances: inners }, -// LayerWitness { instances: leaves }, -// ]; - -// let outputs = roots; -// let witness_out = vec![LayerWitness { instances: outputs }]; - -// ( -// witness_in.clone(), -// CircuitWitness { -// layers, -// witness_in, -// witness_out, -// n_instances: 2, -// challenges: challenge_pows -// .iter() -// .flatten() -// .cloned() -// .map(|(k, v)| (k, v.as_bases().to_vec())) -// .collect::>(), -// }, -// challenges, -// ) -// } - -// #[test] -// fn test_add_instances() { -// let circuit = copy_and_paste_circuit::(); -// let (wits_in, expect_circuit_wits) = copy_and_paste_witness::(); - -// let mut circuit_wits = CircuitWitness::new(&circuit, vec![]); -// circuit_wits.add_instances(&circuit, wits_in, 1); - -// assert_eq!(circuit_wits, expect_circuit_wits); - -// let circuit = paste_from_wit_in_circuit::(); -// let (wits_in, expect_circuit_wits) = paste_from_wit_in_witness::(); - -// let mut circuit_wits = CircuitWitness::new(&circuit, vec![]); -// circuit_wits.add_instances(&circuit, wits_in, 1); - -// assert_eq!(circuit_wits, expect_circuit_wits); - -// let circuit = copy_to_wit_out_circuit::(); -// let (wits_in, expect_circuit_wits) = copy_to_wit_out_witness::(); - -// let mut circuit_wits = CircuitWitness::new(&circuit, vec![]); -// circuit_wits.add_instances(&circuit, wits_in, 1); - -// assert_eq!(circuit_wits, expect_circuit_wits); - -// let (wits_in, expect_circuit_wits) = copy_to_wit_out_witness_2::(); -// let mut circuit_wits = CircuitWitness::new(&circuit, vec![]); -// circuit_wits.add_instances(&circuit, wits_in, 2); - -// assert_eq!(circuit_wits, expect_circuit_wits); -// } - -// #[test] -// fn test_check_correctness() { -// let circuit = copy_to_wit_out_circuit::(); -// let (_wits_in, expect_circuit_wits) = copy_to_wit_out_witness_2::(); - -// expect_circuit_wits.check_correctness(&circuit); -// } - -// #[test] -// fn test_challenges() { -// let circuit = rlc_circuit::(); -// let (wits_in, expect_circuit_wits, challenges) = rlc_witness_2::(); -// let mut circuit_wits = CircuitWitness::new(&circuit, challenges); -// circuit_wits.add_instances(&circuit, wits_in, 2); - -// assert_eq!(circuit_wits, expect_circuit_wits); -// } - -// #[test] -// fn test_orphan_const_input() { -// // create circuit -// let mut circuit_builder = CircuitBuilder::::new(); - -// let (_, leaves) = circuit_builder.create_witness_in(3); -// let mul_0_1_res = circuit_builder.create_cell(); - -// // 2 * 3 = 6 -// circuit_builder.mul2( -// mul_0_1_res, -// leaves[0], -// leaves[1], -// ::BaseField::ONE, -// ); - -// let (_, out) = circuit_builder.create_witness_out(2); -// // like a bypass gate, passing 6 to output out[0] -// circuit_builder.add( -// out[0], -// mul_0_1_res, -// ::BaseField::ONE, -// ); - -// // assert const 2 -// circuit_builder.assert_const(leaves[2], 5); - -// // 5 + -5 = 0, put in out[1] -// circuit_builder.add( -// out[1], -// leaves[2], -// ::BaseField::ONE, -// ); -// circuit_builder.add_const( -// out[1], -// ::BaseField::from(5).neg(), // -5 -// ); - -// // assert out[1] == 0 -// circuit_builder.assert_const(out[1], 0); - -// circuit_builder.configure(); -// let circuit = Circuit::new(&circuit_builder); - -// let mut circuit_wits = CircuitWitness::new(&circuit, vec![]); -// let witness_in = vec![LayerWitness { -// instances: vec![vec![i64_to_field(2), i64_to_field(3), i64_to_field(5)]], -// }]; -// circuit_wits.add_instances(&circuit, witness_in, 1); - -// println!("circuit_wits {:?}", circuit_wits); -// let output_layer_witness = &circuit_wits.layers[0]; -// for gate in circuit.assert_consts.iter() { -// if let ConstantType::Field(constant) = gate.scalar { -// assert_eq!(output_layer_witness.instances[0][gate.idx_out], constant); -// } -// } -// } -// } diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index 45e4e9fb7..7d889111e 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -8,7 +8,7 @@ use ark_std::{end_timer, iterable::Iterable, rand::Rng, start_timer}; use ff::{Field, PrimeField}; use ff_ext::ExtensionField; use rayon::{ - iter::{IntoParallelIterator, IntoParallelRefIterator}, + iter::IntoParallelIterator, prelude::{IndexedParallelIterator, ParallelIterator}, slice::ParallelSliceMut, }; @@ -478,37 +478,6 @@ pub fn build_eq_x_r_vec(r: &[E]) -> Vec { } } -/// A helper function to build eq(x, r) via dynamic programing tricks. -/// This function takes 2^num_var iterations, and per iteration with 1 multiplication. -#[allow(dead_code)] -fn build_eq_x_r_helper(r: &[E], buf: &mut [Vec; 2]) { - buf[0][0] = E::ONE; - if r.is_empty() { - buf[0].resize(1, E::ZERO); - return; - } - for (i, r) in r.iter().rev().enumerate() { - let [current, next] = buf; - let (cur_size, next_size) = (1 << i, 1 << (i + 1)); - // suppose at the previous step we processed buf [0..size] - // for the current step we are populating new buf[0..2*size] - // for j travese 0..size - // buf[2*j + 1] = r * buf[j] - // buf[2*j] = (1 - r) * buf[j] - current[0..cur_size] - .par_iter() - .zip_eq(next[0..next_size].par_chunks_mut(2)) - .with_min_len(64) - .for_each(|(prev_val, next_vals)| { - assert!(next_vals.len() == 2); - let tmp = *r * prev_val; - next_vals[1] = tmp; - next_vals[0] = *prev_val - tmp; - }); - buf.swap(0, 1); // swap rolling buffer - } -} - #[cfg(test)] mod tests { use crate::virtual_poly::{build_eq_x_r_vec, build_eq_x_r_vec_sequential}; diff --git a/singer/src/scheme.rs b/singer/src/scheme.rs index d3b5e847d..6c5448e2e 100644 --- a/singer/src/scheme.rs +++ b/singer/src/scheme.rs @@ -1,30 +1,15 @@ use ff_ext::ExtensionField; // TODO: to be changed to a real PCS scheme. -type BatchedPCSProof = Vec>; -type Commitment = Vec; pub mod prover; pub mod verifier; -pub struct CommitPhaseProof { - // TODO(Matthias): Check whether we need this field. - #[allow(dead_code)] - commitments: Vec>, -} - pub type GKRGraphProof = gkr_graph::structs::IOPProof; pub type GKRGraphProverState = gkr_graph::structs::IOPProverState; pub type GKRGraphVerifierState = gkr_graph::structs::IOPVerifierState; -pub struct OpenPhaseProof { - // TODO(Matthias): Check whether we need this field. - #[allow(dead_code)] - pcs_proof: BatchedPCSProof, -} - pub struct SingerProof { - // commitment_phase_proof: CommitPhaseProof, + // TODO: restore and implement `commitment_phase_proof` and `open_phase_proof` gkr_phase_proof: GKRGraphProof, - // open_phase_proof: OpenPhaseProof, } diff --git a/singer/src/utils.rs b/singer/src/utils.rs index 02df9a112..c9cca791c 100644 --- a/singer/src/utils.rs +++ b/singer/src/utils.rs @@ -3,16 +3,6 @@ use ff_ext::ExtensionField; use itertools::izip; use simple_frontend::structs::{CellId, CircuitBuilder}; -// TODO(Matthias): Check whether we need this function. -#[allow(dead_code)] -pub(crate) fn i64_to_base_field(x: i64) -> E::BaseField { - if x >= 0 { - E::BaseField::from(x as u64) - } else { - -E::BaseField::from((-x) as u64) - } -} - pub(crate) fn add_assign_each_cell( circuit_builder: &mut CircuitBuilder, dest: &[CellId], diff --git a/sumcheck/src/structs.rs b/sumcheck/src/structs.rs index 2397a9cb8..a6089722a 100644 --- a/sumcheck/src/structs.rs +++ b/sumcheck/src/structs.rs @@ -14,7 +14,6 @@ pub struct IOPProof { pub proofs: Vec>, } impl IOPProof { - #[allow(dead_code)] pub fn extract_sum(&self) -> E { self.proofs[0].evaluations[0] + self.proofs[0].evaluations[1] }