diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 2037b4910..4192607cf 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -84,6 +84,11 @@ impl Platform { 10 } + /// Register containing the 2nd function argument. (x11, a1) + pub const fn reg_arg1(&self) -> RegIdx { + 11 + } + /// The code of ecall HALT. pub const fn ecall_halt(&self) -> u32 { 0 diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index b45470e4b..d4441aec7 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -93,7 +93,10 @@ impl EmuContext for VMState { fn ecall(&mut self) -> Result { let function = self.load_register(self.platform.reg_ecall())?; if function == self.platform.ecall_halt() { - let _exit_code = self.load_register(self.platform.reg_arg0())?; + let exit_code = self.load_register(self.platform.reg_arg0())?; + tracing::debug!("halt with exit_code={}", exit_code); + + self.set_pc(ByteAddr(0)); self.halted = true; Ok(true) } else { diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 5e5a53fb7..b789e72ca 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -126,6 +126,7 @@ fn bench_add(c: &mut Criterion) { &circuit_pk, wits_in.into_iter().map(|mle| mle.into()).collect_vec(), commit, + &[], num_instances, max_threads, &mut transcript, diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index 23a01d604..771357044 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -10,14 +10,16 @@ use const_env::from_env; use ceno_emul::{ ByteAddr, - InsnKind::{ADD, BLTU}, + InsnKind::{ADD, BLTU, EANY}, StepRecord, VMState, CENO_PLATFORM, }; use ceno_zkvm::{ - scheme::{constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier}, + instructions::riscv::ecall::HaltInstruction, + scheme::{constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier, PublicValues}, structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{AndTableCircuit, LtuTableCircuit, U16TableCircuit}, }; +use ff_ext::ff::Field; use goldilocks::GoldilocksExt2; use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme}; use rand_chacha::ChaCha8Rng; @@ -103,6 +105,7 @@ fn main() { // opcode circuits let add_config = zkvm_cs.register_opcode_circuit::>(); let bltu_config = zkvm_cs.register_opcode_circuit::(); + let halt_config = zkvm_cs.register_opcode_circuit::>(); // tables let u16_range_config = zkvm_cs.register_table_circuit::>(); let and_config = zkvm_cs.register_table_circuit::>(); @@ -118,6 +121,7 @@ fn main() { let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); zkvm_fixed_traces.register_opcode_circuit::(&zkvm_cs); + zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); zkvm_fixed_traces.register_table_circuit::>( &zkvm_cs, @@ -172,15 +176,25 @@ fn main() { .collect::>(); let mut add_records = Vec::new(); let mut bltu_records = Vec::new(); - all_records.iter().for_each(|record| { + let mut halt_records = Vec::new(); + all_records.into_iter().for_each(|record| { let kind = record.insn().kind().1; - if kind == ADD { - add_records.push(record.clone()); - } else if kind == BLTU { - bltu_records.push(record.clone()); + match kind { + ADD => add_records.push(record), + BLTU => bltu_records.push(record), + EANY => { + if record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt() { + halt_records.push(record); + } + } + _ => {} } }); + assert_eq!(halt_records.len(), 1); + let exit_code = halt_records[0].rs2().unwrap().value; + let pi = PublicValues::new(exit_code, 0); + tracing::info!( "tracer generated {} ADD records, {} BLTU records", add_records.len(), @@ -195,6 +209,10 @@ fn main() { zkvm_witness .assign_opcode_circuit::(&zkvm_cs, &bltu_config, bltu_records) .unwrap(); + zkvm_witness + .assign_opcode_circuit::>(&zkvm_cs, &halt_config, halt_records) + .unwrap(); + zkvm_witness.finalize_lk_multiplicities(); // assign table circuits zkvm_witness @@ -217,8 +235,8 @@ fn main() { let timer = Instant::now(); let transcript = Transcript::new(b"riscv"); - let zkvm_proof = prover - .create_proof(zkvm_witness, max_threads, transcript) + let mut zkvm_proof = prover + .create_proof(zkvm_witness, pi, max_threads, transcript) .expect("create_proof failed"); println!( @@ -230,8 +248,16 @@ fn main() { let transcript = Transcript::new(b"riscv"); assert!( verifier - .verify_proof(zkvm_proof, transcript) + .verify_proof(zkvm_proof.clone(), transcript) .expect("verify proof return with error"), ); + + let transcript = Transcript::new(b"riscv"); + // change public input maliciously should cause verifier to reject proof + zkvm_proof.pv[0] = ::BaseField::ONE; + zkvm_proof.pv[1] = ::BaseField::ONE; + verifier + .verify_proof(zkvm_proof, transcript) + .expect_err("verify proof should return with error"); } } diff --git a/ceno_zkvm/src/chip_handler.rs b/ceno_zkvm/src/chip_handler.rs index 57170be4b..bb27df204 100644 --- a/ceno_zkvm/src/chip_handler.rs +++ b/ceno_zkvm/src/chip_handler.rs @@ -2,7 +2,7 @@ use ff_ext::ExtensionField; use crate::{ error::ZKVMError, - expression::{Expression, WitIn}, + expression::{Expression, ToExpr, WitIn}, gadgets::IsLtConfig, instructions::riscv::constants::UINT_LIMBS, }; @@ -27,7 +27,7 @@ pub trait RegisterChipOperations, N: FnOnce( fn register_read( &mut self, name_fn: N, - register_id: &WitIn, + register_id: impl ToExpr>, prev_ts: Expression, ts: Expression, value: RegisterExpr, @@ -37,7 +37,7 @@ pub trait RegisterChipOperations, N: FnOnce( fn register_write( &mut self, name_fn: N, - register_id: &WitIn, + register_id: impl ToExpr>, prev_ts: Expression, ts: Expression, prev_values: RegisterExpr, diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index bf3116caa..62bf68888 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -5,8 +5,9 @@ use ff_ext::ExtensionField; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, - expression::{Expression, Fixed, ToExpr, WitIn}, + expression::{Expression, Fixed, Instance, ToExpr, WitIn}, gadgets::IsLtConfig, + instructions::riscv::constants::EXIT_CODE_IDX, structs::ROMType, tables::InsnRecord, }; @@ -34,6 +35,14 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.create_fixed(name_fn) } + pub fn query_exit_code(&mut self) -> Result<[Instance; 2], ZKVMError> { + Ok([ + self.cs.query_instance(|| "exit_code_low", EXIT_CODE_IDX)?, + self.cs + .query_instance(|| "exit_code_high", EXIT_CODE_IDX + 1)?, + ]) + } + pub fn lk_record( &mut self, name_fn: N, diff --git a/ceno_zkvm/src/chip_handler/register.rs b/ceno_zkvm/src/chip_handler/register.rs index 474ce32bd..a18019405 100644 --- a/ceno_zkvm/src/chip_handler/register.rs +++ b/ceno_zkvm/src/chip_handler/register.rs @@ -3,7 +3,7 @@ use ff_ext::ExtensionField; use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, - expression::{Expression, ToExpr, WitIn}, + expression::{Expression, ToExpr}, gadgets::IsLtConfig, instructions::riscv::constants::UINT_LIMBS, structs::RAMType, @@ -17,7 +17,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe fn register_read( &mut self, name_fn: N, - register_id: &WitIn, + register_id: impl ToExpr>, prev_ts: Expression, ts: Expression, value: RegisterExpr, @@ -68,12 +68,13 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe fn register_write( &mut self, name_fn: N, - register_id: &WitIn, + register_id: impl ToExpr>, prev_ts: Expression, ts: Expression, prev_values: RegisterExpr, value: RegisterExpr, ) -> Result<(Expression, IsLtConfig), ZKVMError> { + assert!(register_id.expr().degree() <= 1); self.namespace(name_fn, |cb| { // READ (a, v, t) let read_record = cb.rlc_chip_record( diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 773136e66..3967c5ff0 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -1,12 +1,12 @@ use itertools::Itertools; -use std::marker::PhantomData; +use std::{collections::HashMap, marker::PhantomData}; use ff_ext::ExtensionField; use mpcs::PolynomialCommitmentScheme; use crate::{ error::ZKVMError, - expression::{Expression, Fixed, WitIn}, + expression::{Expression, Fixed, Instance, WitIn}, structs::{ProvingKey, VerifyingKey, WitnessId}, witness::RowMajorMatrix, }; @@ -79,6 +79,8 @@ pub struct ConstraintSystem { pub num_fixed: usize, pub fixed_namespace_map: Vec, + pub instance_name_map: HashMap, + pub r_expressions: Vec>, pub r_expressions_namespace_map: Vec, @@ -117,6 +119,7 @@ impl ConstraintSystem { num_fixed: 0, fixed_namespace_map: vec![], ns: NameSpace::new(root_name_fn), + instance_name_map: HashMap::new(), r_expressions: vec![], r_expressions_namespace_map: vec![], w_expressions: vec![], @@ -193,6 +196,19 @@ impl ConstraintSystem { Ok(f) } + pub fn query_instance, N: FnOnce() -> NR>( + &mut self, + n: N, + idx: usize, + ) -> Result { + let i = Instance(idx); + + let name = n().into(); + self.instance_name_map.insert(i, name); + + Ok(i) + } + pub fn lk_record, N: FnOnce() -> NR>( &mut self, name_fn: N, diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 82dbb1fcd..0a6887600 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -27,6 +27,8 @@ pub enum Expression { WitIn(WitnessId), /// Fixed Fixed(Fixed), + /// Public Values + Instance(Instance), /// Constant poly Constant(E::BaseField), /// This is the sum of two expression @@ -34,6 +36,7 @@ pub enum Expression { /// This is the product of two polynomials Product(Box>, Box>), /// This is x, a, b expr to represent ax + b polynomial + /// and x is one of wit / fixed / instance, a and b are either constant or challenge ScaledSum(Box>, Box>, Box>), Challenge(ChallengeId, usize, E, E), // (challenge_id, power, scalar, offset) } @@ -53,10 +56,11 @@ impl Expression { match self { Expression::Fixed(_) => 1, Expression::WitIn(_) => 1, + Expression::Instance(_) => 0, Expression::Constant(_) => 0, Expression::Sum(a_expr, b_expr) => max(a_expr.degree(), b_expr.degree()), Expression::Product(a_expr, b_expr) => a_expr.degree() + b_expr.degree(), - Expression::ScaledSum(_, _, _) => 1, + Expression::ScaledSum(x, _, _) => x.degree(), Expression::Challenge(_, _, _, _) => 0, } } @@ -71,25 +75,64 @@ impl Expression { sum: &impl Fn(T, T) -> T, product: &impl Fn(T, T) -> T, scaled: &impl Fn(T, T, T) -> T, + ) -> T { + self.evaluate_with_instance( + fixed_in, + wit_in, + &|_| unreachable!(), + constant, + challenge, + sum, + product, + scaled, + ) + } + + #[allow(clippy::too_many_arguments)] + pub fn evaluate_with_instance( + &self, + fixed_in: &impl Fn(&Fixed) -> T, + wit_in: &impl Fn(WitnessId) -> T, // witin id + instance: &impl Fn(Instance) -> T, + constant: &impl Fn(E::BaseField) -> T, + challenge: &impl Fn(ChallengeId, usize, E, E) -> T, + sum: &impl Fn(T, T) -> T, + product: &impl Fn(T, T) -> T, + scaled: &impl Fn(T, T, T) -> T, ) -> T { match self { Expression::Fixed(f) => fixed_in(f), Expression::WitIn(witness_id) => wit_in(*witness_id), + Expression::Instance(i) => instance(*i), Expression::Constant(scalar) => constant(*scalar), Expression::Sum(a, b) => { - let a = a.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); - let b = b.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); + let a = a.evaluate_with_instance( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + ); + let b = b.evaluate_with_instance( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + ); sum(a, b) } Expression::Product(a, b) => { - let a = a.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); - let b = b.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); + let a = a.evaluate_with_instance( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + ); + let b = b.evaluate_with_instance( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + ); product(a, b) } Expression::ScaledSum(x, a, b) => { - let x = x.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); - let a = a.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); - let b = b.evaluate(fixed_in, wit_in, constant, challenge, sum, product, scaled); + let x = x.evaluate_with_instance( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + ); + let a = a.evaluate_with_instance( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + ); + let b = b.evaluate_with_instance( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + ); scaled(x, a, b) } Expression::Challenge(challenge_id, pow, scalar, offset) => { @@ -117,6 +160,7 @@ impl Expression { match expr { Expression::Fixed(_) => false, Expression::WitIn(_) => false, + Expression::Instance(_) => false, Expression::Constant(c) => *c == E::BaseField::ZERO, Expression::Sum(a, b) => Self::is_zero_expr(a) && Self::is_zero_expr(b), Expression::Product(a, b) => Self::is_zero_expr(a) || Self::is_zero_expr(b), @@ -133,7 +177,8 @@ impl Expression { Expression::Fixed(_) | Expression::WitIn(_) | Expression::Challenge(..) - | Expression::Constant(_), + | Expression::Constant(_) + | Expression::Instance(_), _, ) => true, (Expression::Sum(a, b), MonomialState::SumTerm) => { @@ -161,11 +206,13 @@ impl Neg for Expression { type Output = Expression; fn neg(self) -> Self::Output { match self { - Expression::Fixed(_) | Expression::WitIn(_) => Expression::ScaledSum( - Box::new(self), - Box::new(Expression::Constant(E::BaseField::ONE.neg())), - Box::new(Expression::Constant(E::BaseField::ZERO)), - ), + Expression::Fixed(_) | Expression::WitIn(_) | Expression::Instance(_) => { + Expression::ScaledSum( + Box::new(self), + Box::new(Expression::Constant(E::BaseField::ONE.neg())), + Box::new(Expression::Constant(E::BaseField::ZERO)), + ) + } Expression::Constant(c1) => Expression::Constant(c1.neg()), Expression::Sum(a, b) => { Expression::Sum(Box::new(-a.deref().clone()), Box::new(-b.deref().clone())) @@ -189,6 +236,40 @@ impl Add for Expression { type Output = Expression; fn add(self, rhs: Expression) -> Expression { match (&self, &rhs) { + // constant + witness + // constant + fixed + // constant + instance + (Expression::WitIn(_), Expression::Constant(_)) + | (Expression::Fixed(_), Expression::Constant(_)) + | (Expression::Instance(_), Expression::Constant(_)) => Expression::ScaledSum( + Box::new(self), + Box::new(Expression::Constant(E::BaseField::ONE)), + Box::new(rhs), + ), + (Expression::Constant(_), Expression::WitIn(_)) + | (Expression::Constant(_), Expression::Fixed(_)) + | (Expression::Constant(_), Expression::Instance(_)) => Expression::ScaledSum( + Box::new(rhs), + Box::new(Expression::Constant(E::BaseField::ONE)), + Box::new(self), + ), + // challenge + witness + // challenge + fixed + // challenge + instance + (Expression::WitIn(_), Expression::Challenge(..)) + | (Expression::Fixed(_), Expression::Challenge(..)) + | (Expression::Instance(_), Expression::Challenge(..)) => Expression::ScaledSum( + Box::new(self), + Box::new(Expression::Constant(E::BaseField::ONE)), + Box::new(rhs), + ), + (Expression::Challenge(..), Expression::WitIn(_)) + | (Expression::Challenge(..), Expression::Fixed(_)) + | (Expression::Challenge(..), Expression::Instance(_)) => Expression::ScaledSum( + Box::new(rhs), + Box::new(Expression::Constant(E::BaseField::ONE)), + Box::new(self), + ), // constant + challenge ( Expression::Constant(c1), @@ -219,7 +300,7 @@ impl Add for Expression { // constant + constant (Expression::Constant(c1), Expression::Constant(c2)) => Expression::Constant(*c1 + c2), - // constant + scaledsum + // constant + scaled sum (c1 @ Expression::Constant(_), Expression::ScaledSum(x, a, b)) | (Expression::ScaledSum(x, a, b), c1 @ Expression::Constant(_)) => { Expression::ScaledSum( @@ -229,16 +310,6 @@ impl Add for Expression { ) } - // challenge + scaledsum - (c1 @ Expression::Challenge(..), Expression::ScaledSum(x, a, b)) - | (Expression::ScaledSum(x, a, b), c1 @ Expression::Challenge(..)) => { - Expression::ScaledSum( - x.clone(), - a.clone(), - Box::new(b.deref().clone() + c1.clone()), - ) - } - _ => Expression::Sum(Box::new(self), Box::new(rhs)), } } @@ -254,6 +325,50 @@ impl Sub for Expression { type Output = Expression; fn sub(self, rhs: Expression) -> Expression { match (&self, &rhs) { + // witness - constant + // fixed - constant + // instance - constant + (Expression::WitIn(_), Expression::Constant(_)) + | (Expression::Fixed(_), Expression::Constant(_)) + | (Expression::Instance(_), Expression::Constant(_)) => Expression::ScaledSum( + Box::new(self), + Box::new(Expression::Constant(E::BaseField::ONE)), + Box::new(rhs.neg()), + ), + + // constant - witness + // constant - fixed + // constant - instance + (Expression::Constant(_), Expression::WitIn(_)) + | (Expression::Constant(_), Expression::Fixed(_)) + | (Expression::Constant(_), Expression::Instance(_)) => Expression::ScaledSum( + Box::new(rhs), + Box::new(Expression::Constant(E::BaseField::ONE.neg())), + Box::new(self), + ), + + // witness - challenge + // fixed - challenge + // instance - challenge + (Expression::WitIn(_), Expression::Challenge(..)) + | (Expression::Fixed(_), Expression::Challenge(..)) + | (Expression::Instance(_), Expression::Challenge(..)) => Expression::ScaledSum( + Box::new(self), + Box::new(Expression::Constant(E::BaseField::ONE)), + Box::new(rhs.neg()), + ), + + // challenge - witness + // challenge - fixed + // challenge - instance + (Expression::Challenge(..), Expression::WitIn(_)) + | (Expression::Challenge(..), Expression::Fixed(_)) + | (Expression::Challenge(..), Expression::Instance(_)) => Expression::ScaledSum( + Box::new(rhs), + Box::new(Expression::Constant(E::BaseField::ONE.neg())), + Box::new(self), + ), + // constant - challenge ( Expression::Constant(c1), @@ -332,15 +447,31 @@ impl Mul for Expression { fn mul(self, rhs: Expression) -> Expression { match (&self, &rhs) { // constant * witin + // constant * fixed (c @ Expression::Constant(_), w @ Expression::WitIn(..)) - | (w @ Expression::WitIn(..), c @ Expression::Constant(_)) => Expression::ScaledSum( + | (w @ Expression::WitIn(..), c @ Expression::Constant(_)) + | (c @ Expression::Constant(_), w @ Expression::Fixed(..)) + | (w @ Expression::Fixed(..), c @ Expression::Constant(_)) => Expression::ScaledSum( Box::new(w.clone()), Box::new(c.clone()), Box::new(Expression::Constant(E::BaseField::ZERO)), ), // challenge * witin + // challenge * fixed (c @ Expression::Challenge(..), w @ Expression::WitIn(..)) - | (w @ Expression::WitIn(..), c @ Expression::Challenge(..)) => Expression::ScaledSum( + | (w @ Expression::WitIn(..), c @ Expression::Challenge(..)) + | (c @ Expression::Challenge(..), w @ Expression::Fixed(..)) + | (w @ Expression::Fixed(..), c @ Expression::Challenge(..)) => Expression::ScaledSum( + Box::new(w.clone()), + Box::new(c.clone()), + Box::new(Expression::Constant(E::BaseField::ZERO)), + ), + // instance * witin + // instance * fixed + (c @ Expression::Instance(..), w @ Expression::WitIn(..)) + | (w @ Expression::WitIn(..), c @ Expression::Instance(..)) + | (c @ Expression::Instance(..), w @ Expression::Fixed(..)) + | (w @ Expression::Fixed(..), c @ Expression::Instance(..)) => Expression::ScaledSum( Box::new(w.clone()), Box::new(c.clone()), Box::new(Expression::Constant(E::BaseField::ZERO)), @@ -433,9 +564,12 @@ pub struct WitIn { pub id: WitnessId, } -#[derive(Clone, Debug, Ord, PartialOrd, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq)] pub struct Fixed(pub usize); +#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)] +pub struct Instance(pub usize); + impl WitIn { pub fn from_expr( name: N, @@ -497,6 +631,20 @@ impl ToExpr for WitIn { } } +impl ToExpr for &WitIn { + type Output = Expression; + fn expr(&self) -> Expression { + Expression::WitIn(self.id) + } +} + +impl ToExpr for Instance { + type Output = Expression; + fn expr(&self) -> Expression { + Expression::Instance(*self) + } +} + impl> ToExpr for F { type Output = Expression; fn expr(&self) -> Expression { @@ -551,6 +699,7 @@ pub mod fmt { } Expression::Constant(constant) => base_field::(constant, true).to_string(), Expression::Fixed(fixed) => format!("{:?}", fixed), + Expression::Instance(i) => format!("{:?}", i), Expression::Sum(left, right) => { let s = format!("{} + {}", expr(left, wtns, false), expr(right, wtns, false)); if add_prn_sum { format!("({})", s) } else { s } diff --git a/ceno_zkvm/src/expression/monomial.rs b/ceno_zkvm/src/expression/monomial.rs index c030e0620..f1d436b23 100644 --- a/ceno_zkvm/src/expression/monomial.rs +++ b/ceno_zkvm/src/expression/monomial.rs @@ -19,7 +19,7 @@ impl Expression { }] } - Fixed(_) | WitIn(_) | Challenge(..) => { + Fixed(_) | WitIn(_) | Instance(_) | Challenge(..) => { vec![Term { coeff: Expression::ONE, vars: vec![self.clone()], @@ -101,7 +101,7 @@ impl Ord for Expression { match (self, other) { (Fixed(a), Fixed(b)) => a.cmp(b), (WitIn(a), WitIn(b)) => a.cmp(b), - (Constant(a), Constant(b)) => cmp_field(a, b), + (Instance(a), Instance(b)) => a.cmp(b), (Challenge(a, b, c, d), Challenge(e, f, g, h)) => { let cmp = a.cmp(e); if cmp == Equal { @@ -116,30 +116,16 @@ impl Ord for Expression { cmp } } - (Sum(a, b), Sum(c, d)) => { - let cmp = a.cmp(c); - if cmp == Equal { b.cmp(d) } else { cmp } - } - (Product(a, b), Product(c, d)) => { - let cmp = a.cmp(c); - if cmp == Equal { b.cmp(d) } else { cmp } - } - (ScaledSum(x, a, b), ScaledSum(y, c, d)) => { - let cmp = x.cmp(y); - if cmp == Equal { - let cmp = a.cmp(c); - if cmp == Equal { b.cmp(d) } else { cmp } - } else { - cmp - } - } (Fixed(_), _) => Less, + (Instance(_), Fixed(_)) => Greater, + (Instance(_), _) => Less, + (WitIn(_), Fixed(_)) => Greater, + (WitIn(_), Instance(_)) => Greater, (WitIn(_), _) => Less, - (Constant(_), _) => Less, - (Challenge(..), _) => Less, - (Sum(..), _) => Less, - (Product(..), _) => Less, - (ScaledSum(..), _) => Less, + (Challenge(..), Fixed(_)) => Greater, + (Challenge(..), Instance(_)) => Greater, + (Challenge(..), WitIn(_)) => Greater, + _ => unreachable!(), } } } @@ -150,6 +136,7 @@ impl PartialOrd for Expression { } } +#[allow(dead_code)] fn cmp_field(a: &F, b: &F) -> Ordering { a.to_canonical_u64().cmp(&b.to_canonical_u64()) } diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index cedf9f96e..18a1fd7dc 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -6,6 +6,7 @@ pub mod branch; pub mod config; pub mod constants; pub mod divu; +pub mod ecall; pub mod logic; pub mod shift_imm; pub mod sltu; @@ -13,6 +14,8 @@ pub mod sltu; mod b_insn; mod i_insn; mod insn_base; + +mod ecall_insn; mod r_insn; #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 041e55405..5289fbd68 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -1,6 +1,9 @@ use crate::uint::UIntLimbs; pub use ceno_emul::PC_STEP_SIZE; +pub const ECALL_HALT_OPCODE: [usize; 2] = [0x00_00, 0x00_00]; +pub const EXIT_PC: usize = 0; +pub const EXIT_CODE_IDX: usize = 0; pub const VALUE_BIT_WIDTH: usize = 16; #[cfg(feature = "riv32")] diff --git a/ceno_zkvm/src/instructions/riscv/ecall.rs b/ceno_zkvm/src/instructions/riscv/ecall.rs new file mode 100644 index 000000000..76c1c04e6 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/ecall.rs @@ -0,0 +1,3 @@ +mod halt; + +pub use halt::HaltInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs new file mode 100644 index 000000000..0f1aa16d5 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -0,0 +1,103 @@ +use crate::{ + chip_handler::RegisterChipOperations, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{ToExpr, WitIn}, + gadgets::IsLtConfig, + instructions::{ + riscv::{ + constants::{ECALL_HALT_OPCODE, EXIT_PC}, + ecall_insn::EcallInstructionConfig, + }, + Instruction, + }, + set_val, + witness::LkMultiplicity, +}; +use ceno_emul::StepRecord; +use ff_ext::ExtensionField; +use std::{marker::PhantomData, mem::MaybeUninit}; + +pub struct HaltConfig { + ecall_cfg: EcallInstructionConfig, + prev_x10_ts: WitIn, + lt_x10_cfg: IsLtConfig, +} + +pub struct HaltInstruction(PhantomData); + +impl Instruction for HaltInstruction { + type InstructionConfig = HaltConfig; + + fn name() -> String { + "ECALL_HALT".into() + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + let prev_x10_ts = cb.create_witin(|| "prev_x10_ts")?; + let exit_code = { + let exit_code = cb.query_exit_code()?; + [exit_code[0].expr(), exit_code[1].expr()] + }; + + let ecall_cfg = EcallInstructionConfig::construct_circuit( + cb, + [ECALL_HALT_OPCODE[0].into(), ECALL_HALT_OPCODE[1].into()], + None, + Some(EXIT_PC.into()), + )?; + + // read exit_code from arg0 (X10 register) + let (_, lt_x10_cfg) = cb.register_read( + || "read x10", + E::BaseField::from(ceno_emul::CENO_PLATFORM.reg_arg0() as u64), + prev_x10_ts.expr(), + ecall_cfg.ts.expr(), + exit_code, + )?; + + Ok(HaltConfig { + ecall_cfg, + prev_x10_ts, + lt_x10_cfg, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + assert_eq!( + step.rs1().unwrap().value, + (ECALL_HALT_OPCODE[0] + (ECALL_HALT_OPCODE[1] << 16)) as u32 + ); + assert_eq!( + step.pc().after.0, + 0, + "pc after ecall/halt {:x}", + step.pc().after.0 + ); + + // the access of X10 register is stored in rs2() + set_val!( + instance, + config.prev_x10_ts, + step.rs2().unwrap().previous_cycle + ); + + config.lt_x10_cfg.assign_instance( + instance, + lk_multiplicity, + step.rs2().unwrap().previous_cycle, + step.cycle(), + )?; + + config + .ecall_cfg + .assign_instance::(instance, lk_multiplicity, step)?; + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs new file mode 100644 index 000000000..1d33b7443 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs @@ -0,0 +1,97 @@ +use crate::{ + chip_handler::{ + GlobalStateRegisterMachineChipOperations, RegisterChipOperations, RegisterExpr, + }, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + gadgets::IsLtConfig, + set_val, + tables::InsnRecord, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind::EANY, StepRecord, CENO_PLATFORM, PC_STEP_SIZE}; +use ff_ext::ExtensionField; +use std::mem::MaybeUninit; + +pub struct EcallInstructionConfig { + pub pc: WitIn, + pub ts: WitIn, + prev_x5_ts: WitIn, + lt_x5_cfg: IsLtConfig, +} + +impl EcallInstructionConfig { + pub fn construct_circuit( + cb: &mut CircuitBuilder, + syscall_id: RegisterExpr, + syscall_ret_value: Option>, + next_pc: Option>, + ) -> Result { + let pc = cb.create_witin(|| "pc")?; + let ts = cb.create_witin(|| "cur_ts")?; + + cb.state_in(pc.expr(), ts.expr())?; + + cb.lk_fetch(&InsnRecord::new( + pc.expr(), + (EANY.codes().opcode as usize).into(), + 0.into(), + (EANY.codes().func3 as usize).into(), + 0.into(), + 0.into(), + 0.into(), // imm = 0 + ))?; + + let prev_x5_ts = cb.create_witin(|| "prev_x5_ts")?; + + // read syscall_id from x5 and write return value to x5 + let (_, lt_x5_cfg) = cb.register_write( + || "write x5", + E::BaseField::from(CENO_PLATFORM.reg_ecall() as u64), + prev_x5_ts.expr(), + ts.expr(), + syscall_id.clone(), + syscall_ret_value.map_or(syscall_id, |v| v), + )?; + + cb.state_out( + next_pc.map_or(pc.expr() + PC_STEP_SIZE.into(), |next_pc| next_pc), + ts.expr() + 4.into(), + )?; + + Ok(Self { + pc, + ts, + prev_x5_ts, + lt_x5_cfg, + }) + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + set_val!(instance, self.pc, step.pc().before.0 as u64); + set_val!(instance, self.ts, step.cycle()); + lk_multiplicity.fetch(step.pc().before.0); + + // the access of X5 register is stored in rs1() + set_val!( + instance, + self.prev_x5_ts, + step.rs1().unwrap().previous_cycle + ); + + self.lt_x5_cfg.assign_instance( + instance, + lk_multiplicity, + step.rs1().unwrap().previous_cycle, + step.cycle(), + )?; + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 9f3d7f869..8a3aef28a 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -88,7 +88,7 @@ impl ReadRS1 { let prev_ts = circuit_builder.create_witin(|| "prev_rs1_ts")?; let (_, lt_cfg) = circuit_builder.register_read( || "read_rs1", - &id, + id, prev_ts.expr(), cur_ts.expr() + (Tracer::SUBCYCLE_RS1 as usize).into(), rs1_read, @@ -143,7 +143,7 @@ impl ReadRS2 { let prev_ts = circuit_builder.create_witin(|| "prev_rs2_ts")?; let (_, lt_cfg) = circuit_builder.register_read( || "read_rs2", - &id, + id, prev_ts.expr(), cur_ts.expr() + (Tracer::SUBCYCLE_RS2 as usize).into(), rs2_read, @@ -199,7 +199,7 @@ impl WriteRD { let prev_value = UInt::new_unchecked(|| "prev_rd_value", circuit_builder)?; let (_, lt_cfg) = circuit_builder.register_write( || "write_rd", - &id, + id, prev_ts.expr(), cur_ts.expr() + (Tracer::SUBCYCLE_RD as usize).into(), prev_value.register_expr(), diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 2dfdaf44d..24b6aebd3 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -1,6 +1,6 @@ use ff_ext::ExtensionField; use mpcs::PolynomialCommitmentScheme; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, fmt::Debug}; use sumcheck::structs::IOPProverMessage; use crate::structs::TowerProofs; @@ -66,18 +66,39 @@ pub struct ZKVMTableProof> pub wits_opening_proof: PCS::Proof, } +#[derive(Default, Clone, Debug)] +pub struct PublicValues { + exit_code: T, + end_pc: T, +} + +impl PublicValues { + pub fn new(exit_code: u32, end_pc: u32) -> Self { + Self { exit_code, end_pc } + } + pub fn to_vec(&self) -> Vec { + vec![ + E::BaseField::from((self.exit_code & 0xffff) as u64), + E::BaseField::from(((self.exit_code >> 16) & 0xffff) as u64), + E::BaseField::from(self.end_pc as u64), + ] + } +} + /// Map circuit names to /// - an opcode or table proof, /// - an index unique across both types. #[derive(Clone)] pub struct ZKVMProof> { + pub pv: Vec, opcode_proofs: BTreeMap)>, table_proofs: BTreeMap)>, } impl> ZKVMProof { - pub fn empty() -> Self { + pub fn empty(pv: PublicValues) -> Self { Self { + pv: pv.to_vec::(), opcode_proofs: BTreeMap::new(), table_proofs: BTreeMap::new(), } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 3c7c88ce8..070be3c21 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -308,19 +308,20 @@ impl<'a, E: ExtensionField + Hash> MockProver { wits_in: &[ArcMultilinearExtension<'a, E>], challenge: [E; 2], ) -> Result<(), Vec>> { - Self::run_maybe_challenge(cb, wits_in, Some(challenge)) + Self::run_maybe_challenge(cb, wits_in, &[], Some(challenge)) } pub fn run( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], ) -> Result<(), Vec>> { - Self::run_maybe_challenge(cb, wits_in, None) + Self::run_maybe_challenge(cb, wits_in, &[], None) } fn run_maybe_challenge( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], + pi: &[E::BaseField], challenge: Option<[E; 2]>, ) -> Result<(), Vec>> { let table = challenge.map(|challenge| load_tables(cb, challenge)); @@ -344,11 +345,13 @@ impl<'a, E: ExtensionField + Hash> MockProver { .chain(&cb.cs.assert_zero_sumcheck_expressions_namespace_map), ) { - if name.contains("require_equal") { + // require_equal does not always have the form of Expr::Sum as + // the sum of witness and constant is expressed as scaled sum + if name.contains("require_equal") && expr.unpack_sum().is_some() { let (left, right) = expr.unpack_sum().unwrap(); let right = right.neg(); - let left_evaluated = wit_infer_by_expr(&[], wits_in, &challenge, &left); + let left_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &left); let left_evaluated = left_evaluated .get_ext_field_vec_optn() .map(|v| v.to_vec()) @@ -360,7 +363,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { .collect_vec() }); - let right_evaluated = wit_infer_by_expr(&[], wits_in, &challenge, &right); + let right_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &right); let right_evaluated = right_evaluated .get_ext_field_vec_optn() .map(|v| v.to_vec()) @@ -389,7 +392,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { } } else { // contains require_zero - let expr_evaluated = wit_infer_by_expr(&[], wits_in, &challenge, &expr); + let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &expr); let expr_evaluated = expr_evaluated .get_ext_field_vec_optn() .map(|v| v.to_vec()) @@ -421,7 +424,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { .iter() .zip_eq(cb.cs.lk_expressions_namespace_map.iter()) { - let expr_evaluated = wit_infer_by_expr(&[], wits_in, &challenge, expr); + let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr); let expr_evaluated = expr_evaluated.get_ext_field_vec(); // Check each lookup expr exists in t vec @@ -481,6 +484,7 @@ mod tests { gadgets::IsLtConfig, set_val, witness::{LkMultiplicity, RowMajorMatrix}, + ROMType::U5, }; use ff::Field; use goldilocks::{Goldilocks, GoldilocksExt2}; @@ -592,15 +596,20 @@ mod tests { assert_eq!( err, vec![MockProverError::LookupError { - expression: Expression::ScaledSum( - Box::new(Expression::WitIn(0)), - Box::new(Expression::Challenge( - 1, - 1, - // TODO this still uses default challenge in ConstraintSystem, but challengeId - // helps to evaluate the expression correctly. Shoudl challenge be just challengeId? - GoldilocksExt2::ONE, - GoldilocksExt2::ZERO, + expression: Expression::Sum( + Box::new(Expression::ScaledSum( + Box::new(Expression::WitIn(0)), + Box::new(Expression::Challenge( + 1, + 1, + // TODO this still uses default challenge in ConstraintSystem, but challengeId + // helps to evaluate the expression correctly. Shoudl challenge be just challengeId? + GoldilocksExt2::ONE, + GoldilocksExt2::ZERO, + )), + Box::new(Expression::Constant( + ::BaseField::from(U5 as u64) + )), )), Box::new(Expression::Challenge( 0, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 5803ca115..da896499c 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -35,7 +35,7 @@ use crate::{ virtual_polys::VirtualPolynomials, }; -use super::{ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof}; +use super::{PublicValues, ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof}; pub struct ZKVMProver> { pub pk: ZKVMProvingKey, @@ -50,10 +50,12 @@ impl> ZKVMProver { pub fn create_proof( &self, witnesses: ZKVMWitnesses, + pi: PublicValues, max_threads: usize, mut transcript: Transcript, ) -> Result, ZKVMError> { - let mut vm_proof = ZKVMProof::empty(); + let mut vm_proof = ZKVMProof::empty(pi); + let pi = &vm_proof.pv; // commit to fixed commitment for (_, pk) in self.pk.circuit_pks.iter() { @@ -125,6 +127,7 @@ impl> ZKVMProver { pk, witness.into_iter().map(|w| w.into()).collect_vec(), wits_commit, + pi, num_instances, max_threads, transcript, @@ -145,6 +148,7 @@ impl> ZKVMProver { pk, witness.into_iter().map(|v| v.into()).collect_vec(), wits_commit, + pi, num_instances, max_threads, transcript, @@ -175,6 +179,7 @@ impl> ZKVMProver { circuit_pk: &ProvingKey, witnesses: Vec>, wits_commit: PCS::CommitmentWithData, + pi: &[E::BaseField], num_instances: usize, max_threads: usize, transcript: &mut Transcript, @@ -202,7 +207,7 @@ impl> ZKVMProver { .chain(cs.lk_expressions.par_iter()) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr(&[], &witnesses, challenges, expr) + wit_infer_by_expr(&[], &witnesses, pi, challenges, expr) }) .collect(); let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len()); @@ -479,7 +484,8 @@ impl> ZKVMProver { { // sanity check in debug build and output != instance index for zero check sumcheck poly if cfg!(debug_assertions) { - let expected_zero_poly = wit_infer_by_expr(&[], &witnesses, challenges, expr); + let expected_zero_poly = + wit_infer_by_expr(&[], &witnesses, pi, challenges, expr); let top_100_errors = expected_zero_poly .get_ext_field_vec() .iter() @@ -609,6 +615,7 @@ impl> ZKVMProver { circuit_pk: &ProvingKey, witnesses: Vec>, wits_commit: PCS::CommitmentWithData, + pi: &[E::BaseField], num_instances: usize, max_threads: usize, transcript: &mut Transcript, @@ -646,7 +653,7 @@ impl> ZKVMProver { ) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr(&fixed, &witnesses, challenges, expr) + wit_infer_by_expr(&fixed, &witnesses, pi, challenges, expr) }) .collect(); let (lk_d_wit, lk_n_wit) = records_wit.split_at(cs.lk_table_expressions.len()); diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 47de45ed6..30969d2ab 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -1,6 +1,10 @@ use std::{marker::PhantomData, mem::MaybeUninit}; -use ceno_emul::{Change, StepRecord}; +use ceno_emul::{ + ByteAddr, + InsnKind::{ADD, EANY}, + StepRecord, VMState, CENO_PLATFORM, +}; use ff::Field; use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; @@ -13,7 +17,10 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, ToExpr, WitIn}, - instructions::{riscv::arith::AddInstruction, Instruction}, + instructions::{ + riscv::{arith::AddInstruction, ecall::HaltInstruction}, + Instruction, + }, set_val, structs::{PointAndEval, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ProgramTableCircuit, U16TableCircuit}, @@ -24,6 +31,7 @@ use super::{ constants::{MAX_NUM_VARIABLES, NUM_FANIN}, prover::ZKVMProver, verifier::ZKVMVerifier, + PublicValues, }; struct TestConfig { @@ -130,6 +138,7 @@ fn test_rw_lk_expression_combination() { prover.pk.circuit_pks.get(&name).unwrap(), wits_in, commit, + &[], num_instances, 1, &mut transcript, @@ -154,6 +163,7 @@ fn test_rw_lk_expression_combination() { &vk.vp, verifier.vk.circuit_vks.get(&name).unwrap(), &proof, + &[], &mut v_transcript, NUM_FANIN, &PointAndEval::default(), @@ -178,6 +188,7 @@ const PROGRAM_CODE: [u32; 4] = [ ECALL_HALT, // ecall halt ECALL_HALT, // ecall halt ]; +#[ignore = "this case is already tested in riscv_example as ecall_halt has only one instance"] #[test] fn test_single_add_instance_e2e() { type E = GoldilocksExt2; @@ -188,6 +199,7 @@ fn test_single_add_instance_e2e() { let mut zkvm_cs = ZKVMConstraintSystem::default(); // opcode circuits let add_config = zkvm_cs.register_opcode_circuit::>(); + let halt_config = zkvm_cs.register_opcode_circuit::>(); let u16_range_config = zkvm_cs.register_table_circuit::>(); let prog_config = zkvm_cs.register_table_circuit::>(); @@ -195,6 +207,7 @@ fn test_single_add_instance_e2e() { let program_code: Vec = PROGRAM_CODE.to_vec(); let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); + zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); zkvm_fixed_traces.register_table_circuit::>( &zkvm_cs, @@ -215,18 +228,34 @@ fn test_single_add_instance_e2e() { let vk = pk.get_vk(); // single instance - let add_records = vec![StepRecord::new_r_instruction( - 4, - 0x20000000.into(), - 4227635, - 1, - 0, - Change { - before: 0, - after: 1, - }, - 3, - )]; + let mut vm = VMState::new(CENO_PLATFORM); + let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr(); + for (i, insn) in PROGRAM_CODE.iter().enumerate() { + vm.init_memory(pc_start + i, *insn); + } + let all_records = vm + .iter_until_halt() + .collect::, _>>() + .expect("vm exec failed") + .into_iter() + .collect::>(); + let mut add_records = vec![]; + let mut halt_records = vec![]; + all_records.into_iter().for_each(|record| { + let kind = record.insn().kind().1; + match kind { + ADD => add_records.push(record), + EANY => { + if record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt() { + halt_records.push(record); + } + } + _ => {} + } + }); + assert_eq!(add_records.len(), 1); + assert_eq!(halt_records.len(), 1); + // proving let prover = ZKVMProver::new(pk); let verifier = ZKVMVerifier::new(vk); @@ -235,6 +264,9 @@ fn test_single_add_instance_e2e() { zkvm_witness .assign_opcode_circuit::>(&zkvm_cs, &add_config, add_records) .unwrap(); + zkvm_witness + .assign_opcode_circuit::>(&zkvm_cs, &halt_config, halt_records) + .unwrap(); zkvm_witness.finalize_lk_multiplicities(); zkvm_witness .assign_table_circuit::>(&zkvm_cs, &u16_range_config, &()) @@ -243,9 +275,10 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program_code.len()) .unwrap(); + let pi = PublicValues::new(0, 0); let transcript = Transcript::new(b"riscv"); let zkvm_proof = prover - .create_proof(zkvm_witness, 1, transcript) + .create_proof(zkvm_witness, pi, 1, transcript) .expect("create_proof failed"); let transcript = Transcript::new(b"riscv"); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 1a10c1e92..b788648e5 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -226,12 +226,17 @@ pub(crate) fn infer_tower_product_witness( pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( fixed: &[ArcMultilinearExtension<'a, E>], witnesses: &[ArcMultilinearExtension<'a, E>], + instance: &[E::BaseField], challenges: &[E; N], expr: &Expression, ) -> ArcMultilinearExtension<'a, E> { - expr.evaluate::>( + expr.evaluate_with_instance::>( &|f| fixed[f.0].clone(), &|witness_id| witnesses[witness_id as usize].clone(), + &|i| { + let i = instance[i.0]; + Arc::new(DenseMultilinearExtension::from_evaluations_vec(0, vec![i])) + }, &|scalar| { let scalar: ArcMultilinearExtension = Arc::new( DenseMultilinearExtension::from_evaluations_vec(0, vec![scalar]), @@ -239,7 +244,7 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( scalar }, &|challenge_id, pow, scalar, offset| { - // TODO cache challenge power to be aquire once for each power + // TODO cache challenge power to be acquired once for each power let challenge = challenges[challenge_id as usize]; let challenge: ArcMultilinearExtension = Arc::new(DenseMultilinearExtension::from_evaluations_ext_vec( @@ -345,6 +350,7 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( ) } +#[allow(dead_code)] pub(crate) fn eval_by_expr( witnesses: &[E], challenges: &[E], @@ -353,6 +359,7 @@ pub(crate) fn eval_by_expr( eval_by_expr_with_fixed(&[], witnesses, challenges, expr) } +#[allow(dead_code)] pub(crate) fn eval_by_expr_with_fixed( fixed: &[E], witnesses: &[E], @@ -374,6 +381,29 @@ pub(crate) fn eval_by_expr_with_fixed( ) } +pub(crate) fn eval_by_expr_with_instance( + fixed: &[E], + witnesses: &[E], + instance: &[E::BaseField], + challenges: &[E], + expr: &Expression, +) -> E { + expr.evaluate_with_instance::( + &|f| fixed[f.0], + &|witness_id| witnesses[witness_id as usize], + &|i| E::from(instance[i.0]), + &|scalar| scalar.into(), + &|challenge_id, pow, scalar, offset| { + // TODO cache challenge power to be acquired once for each power + let challenge = challenges[challenge_id as usize]; + challenge.pow([pow as u64]) * scalar + offset + }, + &|a, b| a + b, + &|a, b| a * b, + &|x, a, b| a * x + b, + ) +} + #[cfg(test)] mod tests { use ff::Field; diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 0c7a13266..f78ccaf5b 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -15,17 +15,17 @@ use transcript::Transcript; use crate::{ error::ZKVMError, + instructions::{riscv::ecall::HaltInstruction, Instruction}, scheme::{ constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, - utils::eval_by_expr_with_fixed, + utils::eval_by_expr_with_instance, }, structs::{Point, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, utils::{eq_eval_less_or_equal_than, get_challenge_pows, next_pow2_instance_padding}, }; use super::{ - constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, utils::eval_by_expr, ZKVMOpcodeProof, ZKVMProof, - ZKVMTableProof, + constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof, }; pub struct ZKVMVerifier> { @@ -46,6 +46,22 @@ impl> ZKVMVerifier let mut prod_r = E::ONE; let mut prod_w = E::ONE; let mut logup_sum = E::ZERO; + let pi = &vm_proof.pv; + + // require ecall/halt proof to exist + { + if let Some((_, proof)) = vm_proof.opcode_proofs.get(&HaltInstruction::::name()) { + if proof.num_instances != 1 { + return Err(ZKVMError::VerifyError( + "ecall/halt num_instances != 1".into(), + )); + } + } else { + return Err(ZKVMError::VerifyError( + "ecall/halt proof does not exist".into(), + )); + } + } // write fixed commitment to transcript for (_, vk) in self.vk.circuit_vks.iter() { @@ -89,6 +105,7 @@ impl> ZKVMVerifier &self.vk.vp, circuit_vk, &opcode_proof, + pi, transcript, NUM_FANIN, &point_eval, @@ -127,6 +144,7 @@ impl> ZKVMVerifier &self.vk.vp, circuit_vk, &table_proof, + &vm_proof.pv, transcript, NUM_FANIN_LOGUP, &point_eval, @@ -165,6 +183,7 @@ impl> ZKVMVerifier vp: &PCS::VerifierParam, circuit_vk: &VerifyingKey, proof: &ZKVMOpcodeProof, + pi: &[E::BaseField], transcript: &mut Transcript, num_product_fanin: usize, _out_evals: &PointAndEval, @@ -339,7 +358,14 @@ impl> ZKVMVerifier .zip_eq(alpha_pow_iter) .map(|(expr, alpha)| { // evaluate zero expression by all wits_in_evals because they share the unique input_opening_point opening - *alpha * eval_by_expr(&proof.wits_in_evals, challenges, expr) + *alpha + * eval_by_expr_with_instance( + &[], + &proof.wits_in_evals, + pi, + challenges, + expr, + ) }) .sum::() }, @@ -364,7 +390,8 @@ impl> ZKVMVerifier .chain(proof.lk_records_in_evals[..lk_counts_per_instance].iter()), ) .any(|(expr, expected_evals)| { - eval_by_expr(&proof.wits_in_evals, challenges, expr) != *expected_evals + eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) + != *expected_evals }) { return Err(ZKVMError::VerifyError( @@ -373,11 +400,9 @@ impl> ZKVMVerifier } // verify zero expression (degree = 1) statement, thus no sumcheck - if cs - .assert_zero_expressions - .iter() - .any(|expr| eval_by_expr(&proof.wits_in_evals, challenges, expr) != E::ZERO) - { + if cs.assert_zero_expressions.iter().any(|expr| { + eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) != E::ZERO + }) { // TODO add me back // return Err(ZKVMError::VerifyError("zero expression != 0")); } @@ -408,6 +433,7 @@ impl> ZKVMVerifier vp: &PCS::VerifierParam, circuit_vk: &VerifyingKey, proof: &ZKVMTableProof, + pi: &[E::BaseField], transcript: &mut Transcript, num_logup_fanin: usize, _out_evals: &PointAndEval, @@ -512,9 +538,10 @@ impl> ZKVMVerifier .chain(proof.lk_n_in_evals[..lk_counts_per_instance].iter()), ) .any(|(expr, expected_evals)| { - eval_by_expr_with_fixed( + eval_by_expr_with_instance( &proof.fixed_in_evals, &proof.wits_in_evals, + pi, challenges, expr, ) != *expected_evals diff --git a/ceno_zkvm/src/tables/ops/ops_impl.rs b/ceno_zkvm/src/tables/ops/ops_impl.rs index 4425a2573..937e3fa2e 100644 --- a/ceno_zkvm/src/tables/ops/ops_impl.rs +++ b/ceno_zkvm/src/tables/ops/ops_impl.rs @@ -35,9 +35,9 @@ impl OpTableConfig { let rlc_record = cb.rlc_chip_record(vec![ (rom_type as usize).into(), - Expression::Fixed(abc[0].clone()), - Expression::Fixed(abc[1].clone()), - Expression::Fixed(abc[2].clone()), + Expression::Fixed(abc[0]), + Expression::Fixed(abc[1]), + Expression::Fixed(abc[2]), ]); cb.lk_table_record(|| "record", rlc_record, mlt.expr())?; diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 8a3b50070..164efad5a 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -123,12 +123,7 @@ impl TableCircuit for ProgramTableCircuit { let record_exprs = { let mut fields = vec![E::BaseField::from(ROMType::Instruction as u64).expr()]; - fields.extend( - record - .as_slice() - .iter() - .map(|f| Expression::Fixed(f.clone())), - ); + fields.extend(record.as_slice().iter().map(|f| Expression::Fixed(*f))); cb.rlc_chip_record(fields) }; diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index 8344c43f1..f40f5d279 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -29,10 +29,8 @@ impl RangeTableConfig { let fixed = cb.create_fixed(|| "fixed")?; let mlt = cb.create_witin(|| "mlt")?; - let rlc_record = cb.rlc_chip_record(vec![ - (rom_type as usize).into(), - Expression::Fixed(fixed.clone()), - ]); + let rlc_record = + cb.rlc_chip_record(vec![(rom_type as usize).into(), Expression::Fixed(fixed)]); cb.lk_table_record(|| "record", rlc_record, mlt.expr())?;