Skip to content

Commit

Permalink
fix clippy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
kunxian-xia committed Sep 18, 2024
1 parent cf1808b commit 2904666
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 46 deletions.
41 changes: 28 additions & 13 deletions ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@ use std::time::{Duration, Instant};

use ark_std::test_rng;
use ceno_zkvm::{
self,
instructions::{riscv::addsub::AddInstruction, Instruction},
scheme::prover::ZKVMProver,
structs::{ZKVMConstraintSystem, ZKVMFixedTraces},
};
use const_env::from_env;
use criterion::*;

use ceno_zkvm::scheme::constants::MAX_NUM_VARIABLES;
use ff_ext::ff::Field;
use goldilocks::{Goldilocks, GoldilocksExt2};
use itertools::Itertools;
use mpcs::{BasefoldDefault, PolynomialCommitmentScheme};
use multilinear_extensions::mle::IntoMLE;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use transcript::Transcript;

cfg_if::cfg_if! {
Expand Down Expand Up @@ -43,6 +46,7 @@ pub fn is_power_of_2(x: usize) -> bool {
}

fn bench_add(c: &mut Criterion) {
type Pcs = BasefoldDefault<E>;
let max_threads = {
if !is_power_of_2(RAYON_NUM_THREADS) {
#[cfg(not(feature = "non_pow2_rayon_thread"))]
Expand All @@ -68,9 +72,13 @@ fn bench_add(c: &mut Criterion) {
let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);

let rng = ChaCha8Rng::from_seed([0u8; 32]);
let param = Pcs::setup(1 << MAX_NUM_VARIABLES, &rng).unwrap();
let (pp, _) = Pcs::trim(&param, 1 << MAX_NUM_VARIABLES).unwrap();

let pk = zkvm_cs
.clone()
.key_gen(zkvm_fixed_traces)
.key_gen::<Pcs>(param, zkvm_fixed_traces)
.expect("keygen failed");

let circuit_pk = pk
Expand All @@ -81,7 +89,6 @@ fn bench_add(c: &mut Criterion) {
let num_witin = circuit_pk.get_cs().num_witin;

let prover = ZKVMProver::new(pk);
let mut transcript = Transcript::new(b"riscv");

for instance_num_vars in 20..22 {
// expand more input size once runtime is acceptable
Expand All @@ -94,31 +101,39 @@ fn bench_add(c: &mut Criterion) {
|b| {
b.iter_with_setup(
|| {
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];
(rng, real_challenges)
},
|(mut rng, real_challenges)| {
// generate mock witness
let mut rng = test_rng();
let num_instances = 1 << instance_num_vars;
let wits_in = (0..num_witin as usize)
(0..num_witin as usize)
.map(|_| {
(0..num_instances)
.map(|_| Goldilocks::random(&mut rng))
.collect::<Vec<Goldilocks>>()
.into_mle()
.into()
})
.collect_vec();
.collect_vec()
},
|wits_in| {
let timer = Instant::now();
let num_instances = 1 << instance_num_vars;
let mut transcript = Transcript::new(b"riscv");
let commit =
Pcs::batch_commit_and_write(&pp, &wits_in, &mut transcript).unwrap();
let challenges = [
transcript.read_challenge().elements,
transcript.read_challenge().elements,
];

let _ = prover
.create_opcode_proof(
&pp,
&circuit_pk,
wits_in,
wits_in.into_iter().map(|mle| mle.into()).collect_vec(),
commit,
num_instances,
max_threads,
&mut transcript,
&real_challenges,
&challenges,
)
.expect("create_proof failed");
println!(
Expand Down
15 changes: 12 additions & 3 deletions ceno_zkvm/src/instructions/riscv/test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use goldilocks::GoldilocksExt2;
use mpcs::{BasefoldDefault, PolynomialCommitmentScheme};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
Expand All @@ -9,22 +12,28 @@ use super::addsub::{AddInstruction, SubInstruction};

#[test]
fn test_multiple_opcode() {
type E = GoldilocksExt2;
type PCS = BasefoldDefault<E>;

let mut cs = ConstraintSystem::new(|| "riscv");
let _add_config = cs.namespace(
|| "add",
|cs| {
let mut circuit_builder = CircuitBuilder::<GoldilocksExt2>::new(cs);
let mut circuit_builder = CircuitBuilder::<E>::new(cs);
let config = AddInstruction::construct_circuit(&mut circuit_builder);
Ok(config)
},
);
let _sub_config = cs.namespace(
|| "sub",
|cs| {
let mut circuit_builder = CircuitBuilder::<GoldilocksExt2>::new(cs);
let mut circuit_builder = CircuitBuilder::<E>::new(cs);
let config = SubInstruction::construct_circuit(&mut circuit_builder);
Ok(config)
},
);
cs.key_gen(None);
let rng = ChaCha8Rng::from_seed([0u8; 32]);
let param = PCS::setup(1 << 10, &rng).unwrap();
let (pp, _) = PCS::trim(&param, 1 << 10).unwrap();
cs.key_gen::<PCS>(&pp, None);
}
3 changes: 1 addition & 2 deletions ceno_zkvm/src/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ impl<E: ExtensionField> ZKVMConstraintSystem<E> {
mut vm_fixed_traces: ZKVMFixedTraces<E>,
) -> Result<ZKVMProvingKey<E, PCS>, ZKVMError> {
let mut vm_pk = ZKVMProvingKey::new(pp);
let (pp, _) =
PCS::trim(&vm_pk.pp, 1 << MAX_NUM_VARIABLES).map_err(|err| ZKVMError::PCSError(err))?;
let (pp, _) = PCS::trim(&vm_pk.pp, 1 << MAX_NUM_VARIABLES).map_err(ZKVMError::PCSError)?;

for (c_name, cs) in self.circuit_css.into_iter() {
let fixed_traces = vm_fixed_traces
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,9 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
wits_in: &[ArcMultilinearExtension<'a, E>],
challenge: Option<[E; 2]>,
) -> Result<(), Vec<MockProverError<E>>> {
let table = challenge.map(|challenge| load_tables(cb, challenge));
let (challenge, table) = if let Some(challenge) = challenge {
(challenge, &load_tables(cb, challenge))
(challenge, table.as_ref().unwrap())
} else {
load_once_tables(cb)
};
Expand Down Expand Up @@ -521,7 +522,6 @@ mod tests {

use super::*;
use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
error::ZKVMError,
expression::{ToExpr, WitIn},
instructions::riscv::config::{ExprLtConfig, ExprLtInput},
Expand Down
20 changes: 13 additions & 7 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ use transcript::Transcript;
use crate::{
error::ZKVMError,
scheme::{
constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP},
constants::{
MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, MAX_NUM_VARIABLES, NUM_FANIN, NUM_FANIN_LOGUP,
},
utils::{
infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles,
wit_infer_by_expr,
Expand Down Expand Up @@ -58,7 +60,8 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
// TODO: commit to fixed commitment

// commit to main traces
let (pp, _) = PCS::trim(&self.pk.pp, 1 << 24).map_err(|e| ZKVMError::PCSError(e))?;
let (pp, _) =
PCS::trim(&self.pk.pp, 1 << MAX_NUM_VARIABLES).map_err(ZKVMError::PCSError)?;
let mut commitments = BTreeMap::new();
let mut wits = BTreeMap::new();
// sort by circuit name, and we rely on an assumption that
Expand All @@ -69,7 +72,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
commitments.insert(
circuit_name.clone(),
PCS::batch_commit_and_write(&pp, &witness, &mut transcript)
.map_err(|e| ZKVMError::PCSError(e))?,
.map_err(ZKVMError::PCSError)?,
);
wits.insert(circuit_name, (witness, num_instances));
}
Expand All @@ -79,6 +82,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
transcript.read_challenge().elements,
transcript.read_challenge().elements,
];
tracing::debug!("challenges in prover: {:?}", challenges);

let mut transcripts = transcript.fork(self.pk.circuit_pks.len());
for ((circuit_name, pk), (i, transcript)) in self
Expand Down Expand Up @@ -108,8 +112,8 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
tracing::debug!("opcode circuit {}: {}", circuit_name, lk_s);
}
let opcode_proof = self.create_opcode_proof(
pk,
&pp,
pk,
witness.into_iter().map(|w| w.into()).collect_vec(),
wits_commit,
num_instances,
Expand Down Expand Up @@ -153,10 +157,11 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
/// major flow break down into
/// 1: witness layer inferring from input -> output
/// 2: proof (sumcheck reduce) from output to input
#[allow(clippy::too_many_arguments)]
pub fn create_opcode_proof(
&self,
circuit_pk: &ProvingKey<E, PCS>,
pp: &PCS::ProverParam,
circuit_pk: &ProvingKey<E, PCS>,
witnesses: Vec<ArcMultilinearExtension<'_, E>>,
wits_commit: PCS::CommitmentWithData,
num_instances: usize,
Expand Down Expand Up @@ -533,7 +538,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
wits_in_evals.as_slice(),
transcript,
)
.map_err(|e| ZKVMError::PCSError(e))?;
.map_err(ZKVMError::PCSError)?;
exit_span!(span);

Ok(ZKVMOpcodeProof {
Expand All @@ -555,6 +560,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
})
}

#[allow(clippy::too_many_arguments)]
pub fn create_table_proof(
&self,
circuit_pk: &ProvingKey<E, PCS>,
Expand Down Expand Up @@ -745,7 +751,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
wits_in_evals.as_slice(),
transcript,
)
.map_err(|e| ZKVMError::PCSError(e))?;
.map_err(ZKVMError::PCSError)?;
exit_span!(span);

Ok(ZKVMTableProof {
Expand Down
49 changes: 34 additions & 15 deletions ceno_zkvm/src/scheme/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use ff::Field;
use ff_ext::ExtensionField;
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLEs;
use rand::rngs::ThreadRng;
use mpcs::{BasefoldDefault, PolynomialCommitmentScheme};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use transcript::Transcript;

use crate::{
Expand Down Expand Up @@ -73,14 +74,21 @@ impl<E: ExtensionField, const L: usize, const RW: usize> Instruction<E> for Test
fn test_rw_lk_expression_combination() {
fn test_rw_lk_expression_combination_inner<const L: usize, const RW: usize>() {
type E = GoldilocksExt2;
type Pcs = BasefoldDefault<E>;
let rng = ChaCha8Rng::from_seed([0u8; 32]);
let param = Pcs::setup(1 << 20, &rng).unwrap();
let (pp, vp) = Pcs::trim(&param, 1 << 20).unwrap();
let name = TestCircuit::<E, RW, L>::name();
let mut zkvm_cs = ZKVMConstraintSystem::default();
let config = zkvm_cs.register_opcode_circuit::<TestCircuit<E, RW, L>>();

let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
zkvm_fixed_traces.register_opcode_circuit::<TestCircuit<E, RW, L>>(&zkvm_cs);

let pk = zkvm_cs.clone().key_gen(zkvm_fixed_traces).unwrap();
let pk = zkvm_cs
.clone()
.key_gen::<Pcs>(param, zkvm_fixed_traces)
.unwrap();
let vk = pk.get_vk();

// generate mock witness
Expand All @@ -97,22 +105,21 @@ fn test_rw_lk_expression_combination() {
// get proof
let prover = ZKVMProver::new(pk);
let mut transcript = Transcript::new(b"test");
let mut rng = ThreadRng::default();
let challenges = [E::random(&mut rng), E::random(&mut rng)];

let wits_in = zkvm_witness
.witnesses
.remove(&name)
.unwrap()
.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec();
let wits_in = zkvm_witness.witnesses.remove(&name).unwrap().into_mles();
// commit to main traces
let commit = Pcs::batch_commit_and_write(&pp, &wits_in, &mut transcript).unwrap();
let wits_in = wits_in.into_iter().map(|v| v.into()).collect_vec();
let challenges = [
transcript.read_challenge().elements,
transcript.read_challenge().elements,
];

let proof = prover
.create_opcode_proof(
&pp,
prover.pk.circuit_pks.get(&name).unwrap(),
wits_in,
commit,
num_instances,
1,
&mut transcript,
Expand All @@ -122,8 +129,20 @@ fn test_rw_lk_expression_combination() {

let verifier = ZKVMVerifier::new(vk.clone());
let mut v_transcript = Transcript::new(b"test");
// write commitment into transcript and derive challenges from it
Pcs::write_commitment(
&Pcs::get_pure_commitment(&proof.wits_commit),
&mut v_transcript,
)
.unwrap();
let challenges = [
transcript.read_challenge().elements,
transcript.read_challenge().elements,
];

let _rt_input = verifier
.verify_opcode_proof(
&vp,
verifier.vk.circuit_vks.get(&name).unwrap(),
&proof,
&mut v_transcript,
Expand Down
Loading

0 comments on commit 2904666

Please sign in to comment.