Skip to content

Commit

Permalink
Feat: Ecall/Halt (#258)
Browse files Browse the repository at this point in the history
Fixes #125.

- [x] Ecall instruction config
- [x] Ecall/Halt circuit.
- [x] Added public input in the prover framework and allows the circuit
builder to access public inputs.
  • Loading branch information
kunxian-xia authored Oct 7, 2024
1 parent b7eb465 commit d3ea040
Show file tree
Hide file tree
Showing 25 changed files with 665 additions and 139 deletions.
5 changes: 5 additions & 0 deletions ceno_emul/src/platform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ impl Platform {
10
}

/// Register containing the 2nd function argument. (x11, a1)
pub const fn reg_arg1(&self) -> RegIdx {
11
}

/// The code of ecall HALT.
pub const fn ecall_halt(&self) -> u32 {
0
Expand Down
5 changes: 4 additions & 1 deletion ceno_emul/src/vm_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ impl EmuContext for VMState {
fn ecall(&mut self) -> Result<bool> {
let function = self.load_register(self.platform.reg_ecall())?;
if function == self.platform.ecall_halt() {
let _exit_code = self.load_register(self.platform.reg_arg0())?;
let exit_code = self.load_register(self.platform.reg_arg0())?;
tracing::debug!("halt with exit_code={}", exit_code);

self.set_pc(ByteAddr(0));
self.halted = true;
Ok(true)
} else {
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ fn bench_add(c: &mut Criterion) {
&circuit_pk,
wits_in.into_iter().map(|mle| mle.into()).collect_vec(),
commit,
&[],
num_instances,
max_threads,
&mut transcript,
Expand Down
46 changes: 36 additions & 10 deletions ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ use const_env::from_env;

use ceno_emul::{
ByteAddr,
InsnKind::{ADD, BLTU},
InsnKind::{ADD, BLTU, EANY},
StepRecord, VMState, CENO_PLATFORM,
};
use ceno_zkvm::{
scheme::{constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier},
instructions::riscv::ecall::HaltInstruction,
scheme::{constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier, PublicValues},
structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
tables::{AndTableCircuit, LtuTableCircuit, U16TableCircuit},
};
use ff_ext::ff::Field;
use goldilocks::GoldilocksExt2;
use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme};
use rand_chacha::ChaCha8Rng;
Expand Down Expand Up @@ -103,6 +105,7 @@ fn main() {
// opcode circuits
let add_config = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
let bltu_config = zkvm_cs.register_opcode_circuit::<BltuInstruction>();
let halt_config = zkvm_cs.register_opcode_circuit::<HaltInstruction<E>>();
// tables
let u16_range_config = zkvm_cs.register_table_circuit::<U16TableCircuit<E>>();
let and_config = zkvm_cs.register_table_circuit::<AndTableCircuit<E>>();
Expand All @@ -118,6 +121,7 @@ fn main() {
let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);
zkvm_fixed_traces.register_opcode_circuit::<BltuInstruction>(&zkvm_cs);
zkvm_fixed_traces.register_opcode_circuit::<HaltInstruction<E>>(&zkvm_cs);

zkvm_fixed_traces.register_table_circuit::<U16TableCircuit<E>>(
&zkvm_cs,
Expand Down Expand Up @@ -172,15 +176,25 @@ fn main() {
.collect::<Vec<_>>();
let mut add_records = Vec::new();
let mut bltu_records = Vec::new();
all_records.iter().for_each(|record| {
let mut halt_records = Vec::new();
all_records.into_iter().for_each(|record| {
let kind = record.insn().kind().1;
if kind == ADD {
add_records.push(record.clone());
} else if kind == BLTU {
bltu_records.push(record.clone());
match kind {
ADD => add_records.push(record),
BLTU => bltu_records.push(record),
EANY => {
if record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt() {
halt_records.push(record);
}
}
_ => {}
}
});

assert_eq!(halt_records.len(), 1);
let exit_code = halt_records[0].rs2().unwrap().value;
let pi = PublicValues::new(exit_code, 0);

tracing::info!(
"tracer generated {} ADD records, {} BLTU records",
add_records.len(),
Expand All @@ -195,6 +209,10 @@ fn main() {
zkvm_witness
.assign_opcode_circuit::<BltuInstruction>(&zkvm_cs, &bltu_config, bltu_records)
.unwrap();
zkvm_witness
.assign_opcode_circuit::<HaltInstruction<E>>(&zkvm_cs, &halt_config, halt_records)
.unwrap();

zkvm_witness.finalize_lk_multiplicities();
// assign table circuits
zkvm_witness
Expand All @@ -217,8 +235,8 @@ fn main() {
let timer = Instant::now();

let transcript = Transcript::new(b"riscv");
let zkvm_proof = prover
.create_proof(zkvm_witness, max_threads, transcript)
let mut zkvm_proof = prover
.create_proof(zkvm_witness, pi, max_threads, transcript)
.expect("create_proof failed");

println!(
Expand All @@ -230,8 +248,16 @@ fn main() {
let transcript = Transcript::new(b"riscv");
assert!(
verifier
.verify_proof(zkvm_proof, transcript)
.verify_proof(zkvm_proof.clone(), transcript)
.expect("verify proof return with error"),
);

let transcript = Transcript::new(b"riscv");
// change public input maliciously should cause verifier to reject proof
zkvm_proof.pv[0] = <GoldilocksExt2 as ff_ext::ExtensionField>::BaseField::ONE;
zkvm_proof.pv[1] = <GoldilocksExt2 as ff_ext::ExtensionField>::BaseField::ONE;
verifier
.verify_proof(zkvm_proof, transcript)
.expect_err("verify proof should return with error");
}
}
6 changes: 3 additions & 3 deletions ceno_zkvm/src/chip_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ff_ext::ExtensionField;

use crate::{
error::ZKVMError,
expression::{Expression, WitIn},
expression::{Expression, ToExpr, WitIn},
gadgets::IsLtConfig,
instructions::riscv::constants::UINT_LIMBS,
};
Expand All @@ -27,7 +27,7 @@ pub trait RegisterChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce(
fn register_read(
&mut self,
name_fn: N,
register_id: &WitIn,
register_id: impl ToExpr<E, Output = Expression<E>>,
prev_ts: Expression<E>,
ts: Expression<E>,
value: RegisterExpr<E>,
Expand All @@ -37,7 +37,7 @@ pub trait RegisterChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce(
fn register_write(
&mut self,
name_fn: N,
register_id: &WitIn,
register_id: impl ToExpr<E, Output = Expression<E>>,
prev_ts: Expression<E>,
ts: Expression<E>,
prev_values: RegisterExpr<E>,
Expand Down
11 changes: 10 additions & 1 deletion ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use ff_ext::ExtensionField;
use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
expression::{Expression, Fixed, Instance, ToExpr, WitIn},
gadgets::IsLtConfig,
instructions::riscv::constants::EXIT_CODE_IDX,
structs::ROMType,
tables::InsnRecord,
};
Expand Down Expand Up @@ -34,6 +35,14 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
self.cs.create_fixed(name_fn)
}

pub fn query_exit_code(&mut self) -> Result<[Instance; 2], ZKVMError> {
Ok([
self.cs.query_instance(|| "exit_code_low", EXIT_CODE_IDX)?,
self.cs
.query_instance(|| "exit_code_high", EXIT_CODE_IDX + 1)?,
])
}

pub fn lk_record<NR, N>(
&mut self,
name_fn: N,
Expand Down
7 changes: 4 additions & 3 deletions ceno_zkvm/src/chip_handler/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use ff_ext::ExtensionField;
use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, ToExpr, WitIn},
expression::{Expression, ToExpr},
gadgets::IsLtConfig,
instructions::riscv::constants::UINT_LIMBS,
structs::RAMType,
Expand All @@ -17,7 +17,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
fn register_read(
&mut self,
name_fn: N,
register_id: &WitIn,
register_id: impl ToExpr<E, Output = Expression<E>>,
prev_ts: Expression<E>,
ts: Expression<E>,
value: RegisterExpr<E>,
Expand Down Expand Up @@ -68,12 +68,13 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
fn register_write(
&mut self,
name_fn: N,
register_id: &WitIn,
register_id: impl ToExpr<E, Output = Expression<E>>,
prev_ts: Expression<E>,
ts: Expression<E>,
prev_values: RegisterExpr<E>,
value: RegisterExpr<E>,
) -> Result<(Expression<E>, IsLtConfig), ZKVMError> {
assert!(register_id.expr().degree() <= 1);
self.namespace(name_fn, |cb| {
// READ (a, v, t)
let read_record = cb.rlc_chip_record(
Expand Down
20 changes: 18 additions & 2 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use itertools::Itertools;
use std::marker::PhantomData;
use std::{collections::HashMap, marker::PhantomData};

use ff_ext::ExtensionField;
use mpcs::PolynomialCommitmentScheme;

use crate::{
error::ZKVMError,
expression::{Expression, Fixed, WitIn},
expression::{Expression, Fixed, Instance, WitIn},
structs::{ProvingKey, VerifyingKey, WitnessId},
witness::RowMajorMatrix,
};
Expand Down Expand Up @@ -79,6 +79,8 @@ pub struct ConstraintSystem<E: ExtensionField> {
pub num_fixed: usize,
pub fixed_namespace_map: Vec<String>,

pub instance_name_map: HashMap<Instance, String>,

pub r_expressions: Vec<Expression<E>>,
pub r_expressions_namespace_map: Vec<String>,

Expand Down Expand Up @@ -117,6 +119,7 @@ impl<E: ExtensionField> ConstraintSystem<E> {
num_fixed: 0,
fixed_namespace_map: vec![],
ns: NameSpace::new(root_name_fn),
instance_name_map: HashMap::new(),
r_expressions: vec![],
r_expressions_namespace_map: vec![],
w_expressions: vec![],
Expand Down Expand Up @@ -193,6 +196,19 @@ impl<E: ExtensionField> ConstraintSystem<E> {
Ok(f)
}

pub fn query_instance<NR: Into<String>, N: FnOnce() -> NR>(
&mut self,
n: N,
idx: usize,
) -> Result<Instance, ZKVMError> {
let i = Instance(idx);

let name = n().into();
self.instance_name_map.insert(i, name);

Ok(i)
}

pub fn lk_record<NR: Into<String>, N: FnOnce() -> NR>(
&mut self,
name_fn: N,
Expand Down
Loading

0 comments on commit d3ea040

Please sign in to comment.