Skip to content

Commit

Permalink
update riscv_add example
Browse files Browse the repository at this point in the history
  • Loading branch information
kunxian-xia committed Sep 5, 2024
1 parent e714621 commit 7084e09
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 43 deletions.
108 changes: 77 additions & 31 deletions ceno_zkvm/examples/riscv_add.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::time::Instant;
use std::{collections::BTreeMap, time::Instant};

use ark_std::test_rng;
use ceno_zkvm::{
Expand All @@ -8,6 +8,12 @@ use ceno_zkvm::{
};
use const_env::from_env;

use ceno_emul::StepRecord;
use ceno_zkvm::{
circuit_builder::{ZKVMConstraintSystem, ZKVMVerifyingKey},
scheme::verifier::ZKVMVerifier,
tables::{RangeTableCircuit, TableCircuit},
};
use ff_ext::ff::Field;
use goldilocks::{Goldilocks, GoldilocksExt2};
use itertools::Itertools;
Expand All @@ -21,6 +27,8 @@ use transcript::Transcript;
const RAYON_NUM_THREADS: usize = 8;

fn main() {
type E = GoldilocksExt2;

let max_threads = {
if !is_power_of_2(RAYON_NUM_THREADS) {
#[cfg(not(feature = "non_pow2_rayon_thread"))]
Expand All @@ -41,16 +49,6 @@ fn main() {
RAYON_NUM_THREADS
}
};
let mut cs = ConstraintSystem::new(|| "risv_add");
let mut circuit_builder = CircuitBuilder::<GoldilocksExt2>::new(&mut cs);
let _ = AddInstruction::construct_circuit(&mut circuit_builder);
let pk = cs.key_gen(None);
let num_witin = pk.get_cs().num_witin;

let prover = ZKVMProver::new(pk);
let mut transcript = Transcript::new(b"riscv");
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];

let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap();
let subscriber = Registry::default()
Expand All @@ -64,34 +62,82 @@ fn main() {
.with(flame_layer.with_threads_collapsed(true));
tracing::subscriber::set_global_default(subscriber).unwrap();

// keygen
let mut zkvm_fixed_traces = BTreeMap::default();
let mut zkvm_cs = ZKVMConstraintSystem::default();

let (add_cs, add_config) = {
let mut cs = ConstraintSystem::new(|| "riscv_add");
let mut circuit_builder = CircuitBuilder::<E>::new(&mut cs);
let config = AddInstruction::construct_circuit(&mut circuit_builder).unwrap();
zkvm_cs.add_cs(AddInstruction::<E>::name(), cs.clone());
zkvm_fixed_traces.insert(AddInstruction::<E>::name(), None);
(cs, config)
};
let (range_cs, range_config) = {
let mut cs = ConstraintSystem::new(|| "riscv_range");
let mut circuit_builder = CircuitBuilder::<E>::new(&mut cs);
let config = RangeTableCircuit::construct_circuit(&mut circuit_builder).unwrap();
zkvm_cs.add_cs(
<RangeTableCircuit<E> as TableCircuit<E>>::name(),
cs.clone(),
);
zkvm_fixed_traces.insert(
<RangeTableCircuit<E> as TableCircuit<E>>::name(),
Some(RangeTableCircuit::<E>::generate_fixed_traces(
&config,
cs.num_fixed,
)),
);
(cs, config)
};
let pk = zkvm_cs.key_gen(zkvm_fixed_traces);
let vk = pk.get_vk();

// proving
let prover = ZKVMProver::new(pk);
let verifier = ZKVMVerifier::new(vk);

for instance_num_vars in 20..22 {
// generate mock witness
// TODO: witness generation from step records emitted by tracer
let num_instances = 1 << instance_num_vars;
let wits_in = (0..num_witin as usize)
.map(|_| {
(0..num_instances)
.map(|_| Goldilocks::random(&mut rng))
.collect::<Vec<Goldilocks>>()
.into_mle()
.into()
})
.collect_vec();
let mut zkvm_witness = BTreeMap::default();
let add_witness = AddInstruction::assign_instances(
&add_config,
add_cs.num_witin as usize,
vec![StepRecord::default(); num_instances],
)
.unwrap();
let range_witness = RangeTableCircuit::<E>::assign_instances(
&range_config,
range_cs.num_witin as usize,
&[],
)
.unwrap();

zkvm_witness.insert(AddInstruction::<E>::name(), add_witness);
zkvm_witness.insert(RangeTableCircuit::<E>::name(), range_witness);

let timer = Instant::now();
let _ = prover
.create_opcode_proof(
wits_in,
num_instances,
max_threads,
&mut transcript,
&real_challenges,
)

let mut transcript = Transcript::new(b"riscv");
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];

let zkvm_proof = prover
.create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges)
.expect("create_proof failed");

assert!(
verifier
.verify_proof(zkvm_proof, &mut transcript, &real_challenges,)
.expect("verify proof return with error"),
);

println!(
"AddInstruction::create_proof, instance_num_vars = {}, time = {}",
instance_num_vars,
timer.elapsed().as_secs_f64()
);
}

type E = GoldilocksExt2;
}
6 changes: 6 additions & 0 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,12 @@ pub struct ZKVMConstraintSystem<E: ExtensionField> {
pub circuit_css: BTreeMap<String, ConstraintSystem<E>>,
}

impl<E: ExtensionField> ZKVMConstraintSystem<E> {
pub fn add_cs(&mut self, name: String, cs: ConstraintSystem<E>) {
assert!(self.circuit_css.insert(name, cs).is_none());
}
}

#[derive(Default)]
pub struct ZKVMProvingKey<E: ExtensionField> {
// pk for opcode and table circuits
Expand Down
12 changes: 6 additions & 6 deletions ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use crate::{
};
use core::mem::MaybeUninit;

pub struct AddInstruction;
pub struct SubInstruction;
pub struct AddInstruction<E>(PhantomData<E>);
pub struct SubInstruction<E>(PhantomData<E>);

#[derive(Debug)]
pub struct InstructionConfig<E: ExtensionField> {
Expand All @@ -37,11 +37,11 @@ pub struct InstructionConfig<E: ExtensionField> {
phantom: PhantomData<E>,
}

impl<E: ExtensionField> RIVInstruction<E> for AddInstruction {
impl<E: ExtensionField> RIVInstruction<E> for AddInstruction<E> {
const OPCODE_TYPE: OpcodeType = OpcodeType::RType(OPType::Op, 0x000, 0x0000000);
}

impl<E: ExtensionField> RIVInstruction<E> for SubInstruction {
impl<E: ExtensionField> RIVInstruction<E> for SubInstruction<E> {
const OPCODE_TYPE: OpcodeType = OpcodeType::RType(OPType::Op, 0x000, 0x0100000);
}

Expand Down Expand Up @@ -134,7 +134,7 @@ fn add_sub_gadget<E: ExtensionField, const IS_ADD: bool>(
})
}

impl<E: ExtensionField> Instruction<E> for AddInstruction {
impl<E: ExtensionField> Instruction<E> for AddInstruction<E> {
// const NAME: &'static str = "ADD";
fn name() -> String {
"ADD".into()
Expand Down Expand Up @@ -183,7 +183,7 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction {
}
}

impl<E: ExtensionField> Instruction<E> for SubInstruction {
impl<E: ExtensionField> Instruction<E> for SubInstruction<E> {
// const NAME: &'static str = "ADD";
fn name() -> String {
"SUB".into()
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use ff_ext::ExtensionField;
use std::{
collections::{BTreeSet, HashMap},
collections::{BTreeMap, BTreeSet, HashMap},

Check failure on line 3 in ceno_zkvm/src/scheme/prover.rs

View workflow job for this annotation

GitHub Actions / Various lints

unused import: `HashMap`
sync::Arc,
};

Expand Down Expand Up @@ -48,7 +48,7 @@ impl<E: ExtensionField> ZKVMProver<E> {
/// create proof for zkvm execution
pub fn create_proof(
&self,
mut witnesses: HashMap<String, RowMajorMatrix<E::BaseField>>,
mut witnesses: BTreeMap<String, RowMajorMatrix<E::BaseField>>,
max_threads: usize,
transcript: &mut Transcript<E>,
challenges: &[E; 2],
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/src/tables/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, witness::RowMajor
use ff_ext::ExtensionField;

mod range;
pub use range::RangeTableCircuit;

pub trait TableCircuit<E: ExtensionField> {
type TableConfig: Send + Sync;
Expand Down
16 changes: 12 additions & 4 deletions ceno_zkvm/src/tables/range.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
use std::mem::MaybeUninit;
use std::{marker::PhantomData, mem::MaybeUninit};

use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, Fixed, ToExpr, WitIn}, set_fixed_val, set_val, structs::ROMType, tables::TableCircuit, witness::RowMajorMatrix};
use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
set_fixed_val, set_val,
structs::ROMType,
tables::TableCircuit,
witness::RowMajorMatrix,
};
use ff_ext::ExtensionField;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};

Expand All @@ -10,9 +18,9 @@ pub struct RangeTableConfig {
u16_mlt: WitIn,
}

pub struct RangeTableCircuit;
pub struct RangeTableCircuit<E>(PhantomData<E>);

impl<E: ExtensionField> TableCircuit<E> for RangeTableCircuit {
impl<E: ExtensionField> TableCircuit<E> for RangeTableCircuit<E> {
type TableConfig = RangeTableConfig;
type Input = usize;

Expand Down

0 comments on commit 7084e09

Please sign in to comment.