Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include mpcs in zkvm prover and verifier #216

Merged
merged 27 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
fa7b6ff
include pcs as type parameter in zkvm types
kunxian-xia Sep 12, 2024
ff456ec
keygen with commit
kunxian-xia Sep 14, 2024
4a7c687
commit to fixed trace during keygen
kunxian-xia Sep 16, 2024
04ac636
Merge remote-tracking branch 'origin/master' into feat/integrate_mpcs
kunxian-xia Sep 16, 2024
91fcf1f
include commit and opening proof in zkvm proof
kunxian-xia Sep 17, 2024
7106b77
Merge branch 'master' into feat/integrate_mpcs
kunxian-xia Sep 17, 2024
48a1c0c
verify pcs opening proof both in prover and verifier
kunxian-xia Sep 17, 2024
119d918
fix clippy errors in mpcs
kunxian-xia Sep 18, 2024
cf1808b
Fix compilation error in test
yczhangsjtu Sep 18, 2024
2904666
fix clippy errors
kunxian-xia Sep 18, 2024
1f59513
augment program_add_loop with ecall_halt since mpcs cannot support mu…
kunxian-xia Sep 18, 2024
0c2a150
replace CommitmentWithData by Commitment
kunxian-xia Sep 18, 2024
4936d54
more
kunxian-xia Sep 18, 2024
a926c07
fmt
kunxian-xia Sep 18, 2024
10aeaf4
add debug level log for pcs::simple_batch_verify
kunxian-xia Sep 18, 2024
0f043cc
Merge remote-tracking branch 'origin/master' into feat/integrate_mpcs
kunxian-xia Sep 19, 2024
a4bc363
trim pcs param only once
kunxian-xia Sep 19, 2024
b0b1df5
fix test
kunxian-xia Sep 19, 2024
29e816a
update Cargo.toml
kunxian-xia Sep 19, 2024
6f5d5d7
fix benches compilation error
kunxian-xia Sep 19, 2024
a7a5b8c
log commit time for each circuit
kunxian-xia Sep 19, 2024
81a2755
Merge remote-tracking branch 'origin/master' into feat/integrate_mpcs
kunxian-xia Sep 19, 2024
ce9fcf5
fix clippy errors in mpcs
kunxian-xia Sep 19, 2024
7559547
resolve conflict
kunxian-xia Sep 19, 2024
dff9bb1
print out mpcs opening proof time
kunxian-xia Sep 19, 2024
9a055a2
Merge branch 'master' into feat/integrate_mpcs
kunxian-xia Sep 19, 2024
1c8a650
Merge branch 'master' into feat/integrate_mpcs
kunxian-xia Sep 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ subtle = "2.2.1"
rand_core = "0.6.0"
rand_xorshift = "0.3"
rayon = "1.8"
rand_chacha = { version = "0.3.1", features = ["serde1"] }

[patch."https://github.com/zhenfeizhang/Goldilocks"]
goldilocks = { git = "https://github.com/hero78119/Goldilocks" }
2 changes: 2 additions & 0 deletions ceno_zkvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ ff.workspace = true
goldilocks.workspace = true
rayon.workspace = true
serde.workspace = true
rand_chacha.workspace = true

transcript = { path = "../transcript" }
sumcheck = { version = "0.1.0", path = "../sumcheck" }
multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" }
ff_ext = { path = "../ff_ext" }
ceno_emul = { path = "../ceno_emul" }
mpcs = { path = "../mpcs" }

itertools = "0.12.0"
strum = "0.25.0"
Expand Down
39 changes: 27 additions & 12 deletions ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ use ceno_zkvm::{
use const_env::from_env;
use criterion::*;

use ceno_zkvm::scheme::constants::MAX_NUM_VARIABLES;
use ff_ext::ff::Field;
use goldilocks::{Goldilocks, GoldilocksExt2};
use itertools::Itertools;
use mpcs::{BasefoldDefault, PolynomialCommitmentScheme};
use multilinear_extensions::mle::IntoMLE;
use transcript::Transcript;

Expand Down Expand Up @@ -43,6 +45,7 @@ pub fn is_power_of_2(x: usize) -> bool {
}

fn bench_add(c: &mut Criterion) {
type Pcs = BasefoldDefault<E>;
let max_threads = {
if !is_power_of_2(RAYON_NUM_THREADS) {
#[cfg(not(feature = "non_pow2_rayon_thread"))]
Expand All @@ -68,9 +71,12 @@ fn bench_add(c: &mut Criterion) {
let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);

let param = Pcs::setup(1 << MAX_NUM_VARIABLES).unwrap();
let (pp, vp) = Pcs::trim(&param, 1 << MAX_NUM_VARIABLES).unwrap();

let pk = zkvm_cs
.clone()
.key_gen(zkvm_fixed_traces)
.key_gen::<Pcs>(pp, vp, zkvm_fixed_traces)
.expect("keygen failed");

let circuit_pk = pk
Expand All @@ -81,7 +87,6 @@ fn bench_add(c: &mut Criterion) {
let num_witin = circuit_pk.get_cs().num_witin;

let prover = ZKVMProver::new(pk);
let mut transcript = Transcript::new(b"riscv");

for instance_num_vars in 20..22 {
// expand more input size once runtime is acceptable
Expand All @@ -94,31 +99,41 @@ fn bench_add(c: &mut Criterion) {
|b| {
b.iter_with_setup(
|| {
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];
(rng, real_challenges)
},
|(mut rng, real_challenges)| {
// generate mock witness
let mut rng = test_rng();
let num_instances = 1 << instance_num_vars;
let wits_in = (0..num_witin as usize)
(0..num_witin as usize)
.map(|_| {
(0..num_instances)
.map(|_| Goldilocks::random(&mut rng))
.collect::<Vec<Goldilocks>>()
.into_mle()
.into()
})
.collect_vec();
.collect_vec()
},
|wits_in| {
let timer = Instant::now();
let num_instances = 1 << instance_num_vars;
let mut transcript = Transcript::new(b"riscv");
let commit =
Pcs::batch_commit_and_write(&prover.pk.pp, &wits_in, &mut transcript)
.unwrap();
let challenges = [
transcript.read_challenge().elements,
transcript.read_challenge().elements,
];

let _ = prover
.create_opcode_proof(
"ADD",
&prover.pk.pp,
&circuit_pk,
wits_in,
wits_in.into_iter().map(|mle| mle.into()).collect_vec(),
commit,
num_instances,
max_threads,
&mut transcript,
&real_challenges,
&challenges,
)
.expect("create_proof failed");
println!(
Expand Down
36 changes: 22 additions & 14 deletions ceno_zkvm/examples/riscv_add.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::time::Instant;
use std::{iter, time::Instant};

use ark_std::test_rng;
use ceno_zkvm::{
instructions::riscv::arith::AddInstruction, scheme::prover::ZKVMProver,
tables::ProgramTableCircuit,
Expand All @@ -10,12 +9,13 @@

use ceno_emul::{ByteAddr, InsnKind::ADD, StepRecord, VMState, CENO_PLATFORM};
use ceno_zkvm::{
scheme::verifier::ZKVMVerifier,
scheme::{constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier},
structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
tables::U16TableCircuit,
};
use ff_ext::ff::Field;
use goldilocks::GoldilocksExt2;
use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme};
use rand_chacha::ChaCha8Rng;
use sumcheck::util::is_power_of_2;
use tracing_flame::FlameLayer;
use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry};
Expand All @@ -31,12 +31,14 @@
// - x3 is initialized to loop bound.
// we use x4 to hold the acc_sum.
#[allow(clippy::unusual_byte_groupings)]
const ECALL_HALT: u32 = 0b_000000000000_00000_000_00000_1110011;
#[allow(clippy::unusual_byte_groupings)]
const PROGRAM_ADD_LOOP: [u32; 4] = [
// func7 rs2 rs1 f3 rd opcode
0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1
0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1
0b_1_111111_00000_00011_001_1100_1_1100011, // bne x3, x0, -8
0b_000000000000_00000_000_00000_1110011, // ecall halt
ECALL_HALT, // ecall halt
];

/// Simple program to greet a person
Expand All @@ -55,17 +57,18 @@
fn main() {
let args = Args::parse();
type E = GoldilocksExt2;
type Pcs = Basefold<GoldilocksExt2, BasefoldRSParams, ChaCha8Rng>;

let max_threads = {
if !is_power_of_2(RAYON_NUM_THREADS) {
#[cfg(not(feature = "non_pow2_rayon_thread"))]

Check warning on line 64 in ceno_zkvm/examples/riscv_add.rs

View workflow job for this annotation

GitHub Actions / Various lints (x86_64-unknown-linux-gnu)

unexpected `cfg` condition value: `non_pow2_rayon_thread`
{
panic!(
"add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool"
);
}

#[cfg(feature = "non_pow2_rayon_thread")]

Check warning on line 71 in ceno_zkvm/examples/riscv_add.rs

View workflow job for this annotation

GitHub Actions / Various lints (x86_64-unknown-linux-gnu)

unexpected `cfg` condition value: `non_pow2_rayon_thread`
{
use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2};
let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS);
Expand All @@ -90,11 +93,19 @@
tracing::subscriber::set_global_default(subscriber).unwrap();

// keygen
let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup");
let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
let mut zkvm_cs = ZKVMConstraintSystem::default();
let add_config = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
let range_config = zkvm_cs.register_table_circuit::<U16TableCircuit<E>>();
let prog_config = zkvm_cs.register_table_circuit::<ProgramTableCircuit<E>>();

let program_add_loop: Vec<u32> = PROGRAM_ADD_LOOP
.iter()
.cloned()
.chain(iter::repeat(ECALL_HALT))
.take(512)
.collect();
let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);
zkvm_fixed_traces.register_table_circuit::<U16TableCircuit<E>>(
Expand All @@ -105,12 +116,12 @@
zkvm_fixed_traces.register_table_circuit::<ProgramTableCircuit<E>>(
&zkvm_cs,
prog_config.clone(),
&PROGRAM_ADD_LOOP,
&program_add_loop,
);

let pk = zkvm_cs
.clone()
.key_gen(zkvm_fixed_traces)
.key_gen::<Pcs>(pp, vp, zkvm_fixed_traces)
.expect("keygen failed");
let vk = pk.get_vk();

Expand All @@ -128,7 +139,7 @@
vm.init_register_unsafe(1usize, 1);
vm.init_register_unsafe(2usize, u32::MAX); // -1 in two's complement
vm.init_register_unsafe(3usize, step_loop as u32);
for (i, inst) in PROGRAM_ADD_LOOP.iter().enumerate() {
for (i, inst) in program_add_loop.iter().enumerate() {
vm.init_memory(pc_start + i, *inst);
}
let records = vm
Expand All @@ -154,18 +165,15 @@
.assign_table_circuit::<ProgramTableCircuit<E>>(
&zkvm_cs,
&prog_config,
&PROGRAM_ADD_LOOP.len(),
&program_add_loop.len(),
)
.unwrap();

let timer = Instant::now();

let transcript = Transcript::new(b"riscv");
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];

let zkvm_proof = prover
.create_proof(zkvm_witness, max_threads, transcript, &real_challenges)
.create_proof(zkvm_witness, max_threads, transcript)
.expect("create_proof failed");

println!(
Expand All @@ -177,7 +185,7 @@
let transcript = Transcript::new(b"riscv");
assert!(
verifier
.verify_proof(zkvm_proof, transcript, &real_challenges)
.verify_proof(zkvm_proof, transcript)
.expect("verify proof return with error"),
);
}
Expand Down
22 changes: 18 additions & 4 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use itertools::Itertools;
use std::marker::PhantomData;

use ff_ext::ExtensionField;
use mpcs::PolynomialCommitmentScheme;
use multilinear_extensions::mle::IntoMLEs;

use crate::{
Expand Down Expand Up @@ -137,16 +138,29 @@ impl<E: ExtensionField> ConstraintSystem<E> {
}
}

pub fn key_gen(self, fixed_traces: Option<RowMajorMatrix<E::BaseField>>) -> ProvingKey<E> {
// TODO: commit to fixed_traces

pub fn key_gen<PCS: PolynomialCommitmentScheme<E>>(
self,
pp: &PCS::ProverParam,
fixed_traces: Option<RowMajorMatrix<E::BaseField>>,
) -> ProvingKey<E, PCS> {
// transpose from row-major to column-major
let fixed_traces =
fixed_traces.map(|t| t.de_interleaving().into_mles().into_iter().collect_vec());

let fixed_commit_wd = fixed_traces
.as_ref()
.map(|traces| PCS::batch_commit(pp, traces).unwrap());
let fixed_commit = fixed_commit_wd
.as_ref()
.map(|commit_wd| PCS::get_pure_commitment(commit_wd));

ProvingKey {
fixed_traces,
vk: VerifyingKey { cs: self },
fixed_commit_wd,
vk: VerifyingKey {
cs: self,
fixed_commit,
kunxian-xia marked this conversation as resolved.
Show resolved Hide resolved
},
}
}

Expand Down
3 changes: 3 additions & 0 deletions ceno_zkvm/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use mpcs::Error;

#[derive(Debug)]
pub enum UtilError {
UIntError(String),
Expand All @@ -11,6 +13,7 @@ pub enum ZKVMError {
VKNotFound(String),
FixedTraceNotFound(String),
VerifyError(String),
PCSError(Error),
}

impl From<UtilError> for ZKVMError {
Expand Down
12 changes: 9 additions & 3 deletions ceno_zkvm/src/instructions/riscv/test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use goldilocks::GoldilocksExt2;
use mpcs::{BasefoldDefault, PolynomialCommitmentScheme};

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
Expand All @@ -9,22 +10,27 @@ use super::arith::{AddInstruction, SubInstruction};

#[test]
fn test_multiple_opcode() {
type E = GoldilocksExt2;
type PCS = BasefoldDefault<E>;

let mut cs = ConstraintSystem::new(|| "riscv");
let _add_config = cs.namespace(
|| "add",
|cs| {
let mut circuit_builder = CircuitBuilder::<GoldilocksExt2>::new(cs);
let mut circuit_builder = CircuitBuilder::<E>::new(cs);
let config = AddInstruction::construct_circuit(&mut circuit_builder);
Ok(config)
},
);
let _sub_config = cs.namespace(
|| "sub",
|cs| {
let mut circuit_builder = CircuitBuilder::<GoldilocksExt2>::new(cs);
let mut circuit_builder = CircuitBuilder::<E>::new(cs);
let config = SubInstruction::construct_circuit(&mut circuit_builder);
Ok(config)
},
);
cs.key_gen(None);
let param = PCS::setup(1 << 10).unwrap();
let (pp, _) = PCS::trim(&param, 1 << 10).unwrap();
cs.key_gen::<PCS>(&pp, None);
}
11 changes: 7 additions & 4 deletions ceno_zkvm/src/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,24 @@ use crate::{
structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMProvingKey},
};
use ff_ext::ExtensionField;
use mpcs::PolynomialCommitmentScheme;

impl<E: ExtensionField> ZKVMConstraintSystem<E> {
pub fn key_gen(
pub fn key_gen<PCS: PolynomialCommitmentScheme<E>>(
self,
pp: PCS::ProverParam,
vp: PCS::VerifierParam,
mut vm_fixed_traces: ZKVMFixedTraces<E>,
) -> Result<ZKVMProvingKey<E>, ZKVMError> {
let mut vm_pk = ZKVMProvingKey::default();
) -> Result<ZKVMProvingKey<E, PCS>, ZKVMError> {
let mut vm_pk = ZKVMProvingKey::new(pp, vp);

for (c_name, cs) in self.circuit_css.into_iter() {
let fixed_traces = vm_fixed_traces
.circuit_fixed_traces
.remove(&c_name)
.ok_or(ZKVMError::FixedTraceNotFound(c_name.clone()))?;

let circuit_pk = cs.key_gen(fixed_traces);
let circuit_pk = cs.key_gen(&vm_pk.pp, fixed_traces);
assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none());
}

Expand Down
Loading
Loading