Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] non-uniform memory init/finalize PIOP design #430

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions ceno_zkvm/src/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,21 @@ pub struct ZKVMTableProof<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>>
pub wits_opening_proof: PCS::Proof,
}

#[derive(Clone, Serialize, Deserialize)]
pub struct ZKVMMemProof<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> {
// tower evaluation at layer 1
pub r_out_evals: [E; 2],
pub w_out_evals: [E; 2],

pub tower_proof: TowerProofs<E>,

pub fixed_in_evals: Vec<E>,
pub fixed_opening_proof: PCS::Proof,
pub wits_commit: PCS::Commitment,
pub wits_in_evals: Vec<E>,
pub wits_opening_proof: PCS::Proof,
}

#[derive(Default, Clone, Debug)]
pub struct PublicValues<T: Default + Clone + Debug> {
exit_code: T,
Expand Down
214 changes: 213 additions & 1 deletion ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ use crate::{
structs::{
Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses,
},
tables::MemTableConfig,
utils::{get_challenge_pows, next_pow2_instance_padding, proper_num_threads},
virtual_polys::VirtualPolynomials,
};

use super::{PublicValues, ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof};
use super::{PublicValues, ZKVMMemProof, ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof};

pub struct ZKVMProver<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> {
pub pk: ZKVMProvingKey<E, PCS>,
Expand Down Expand Up @@ -1035,6 +1036,217 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
wits_opening_proof,
})
}

#[allow(clippy::too_many_arguments)]
/// create memory proof for initialize/finalize
pub fn create_mem_proof(
&self,
name: &str,
pp: &PCS::ProverParam,
circuit_pk: &ProvingKey<E, PCS>,
witnesses: Vec<ArcMultilinearExtension<'_, E>>,
wits_commit: PCS::CommitmentWithData,
// pi: &[E::BaseField],
max_threads: usize,
transcript: &mut Transcript<E>,
challenges: &[E; 2],
) -> Result<ZKVMMemProof<E, PCS>, ZKVMError> {
let cs = circuit_pk.get_cs();
let fixed = circuit_pk
.fixed_traces
.as_ref()
.expect("pk.fixed_traces must not be none for table circuit")
.iter()
.map(|f| -> ArcMultilinearExtension<E> { Arc::new(f.get_ranged_mle(1, 0)) })
.collect::<Vec<ArcMultilinearExtension<E>>>();

// sanity check
assert_eq!(witnesses.len(), cs.num_witin as usize);
assert_eq!(witnesses.len(), 1);

let num_vars = witnesses[0].num_vars();

assert_eq!(fixed.len(), cs.num_fixed);
// check all witness size are power of 2
assert!(
witnesses
.iter()
.all(|v| { v.evaluations().len().is_power_of_two() })
);
assert!(!cs.r_table_expressions.is_empty() || !cs.w_table_expressions.is_empty());
assert!(
cs.r_table_expressions
.iter()
.zip_eq(cs.w_table_expressions.iter())
.all(|(r, w)| r.table_len == w.table_len)
);

assert!(
num_vars < MemTableConfig::ADDR_RANGE[1] && num_vars >= MemTableConfig::ADDR_RANGE[0]
);
let fixed_addr_index = num_vars - MemTableConfig::ADDR_RANGE[0];

// non-uniform PIOP by selecting expression via auxiliary input addr_index
let w_table_expr = &cs.w_table_expressions[fixed_addr_index];
assert_eq!(w_table_expr.values.degree(), 1);
let r_table_expr = &cs.r_table_expressions[fixed_addr_index];
assert_eq!(r_table_expr.values.degree(), 1);

let span = entered_span!("wit_inference::record");
let (w_set_wit, r_set_wit) = rayon::join(
|| wit_infer_by_expr(&fixed, &witnesses, &[], challenges, &w_table_expr.values),
|| wit_infer_by_expr(&fixed, &witnesses, &[], challenges, &r_table_expr.values),
);
exit_span!(span);

// infer all tower witness after last layer
let span = entered_span!("wit_inference::tower_witness_lk_last_layer");
let r_set_last_layer = {
let (first, second) = r_set_wit
.get_ext_field_vec()
.split_at(r_set_wit.evaluations().len() / 2);
let res = vec![
first.to_vec().into_mle().into(),
second.to_vec().into_mle().into(),
];
assert_eq!(res.len(), NUM_FANIN_LOGUP);
res
};
let w_set_last_layer = {
let (first, second) = w_set_wit
.get_ext_field_vec()
.split_at(r_set_wit.evaluations().len() / 2);
let res = vec![
first.to_vec().into_mle().into(),
second.to_vec().into_mle().into(),
];
assert_eq!(res.len(), NUM_FANIN_LOGUP);
res
};
exit_span!(span);

let span = entered_span!("wit_inference::tower_witness_rw_layers");
let r_wit_layers =
infer_tower_product_witness(r_set_wit.num_vars(), r_set_last_layer, NUM_FANIN);
let w_wit_layers =
infer_tower_product_witness(w_set_wit.num_vars(), w_set_last_layer, NUM_FANIN);
exit_span!(span);

if cfg!(test) {
// sanity check
assert_eq!(r_wit_layers.len(), r_set_wit.num_vars());
assert!(r_wit_layers.iter().enumerate().all(|(i, w)| {
let expected_size = 1 << i;
w[0].evaluations().len() == expected_size
&& w[1].evaluations().len() == expected_size
}));
assert_eq!(w_wit_layers.len(), w_set_wit.num_vars());
assert!(w_wit_layers.iter().enumerate().all(|(i, w)| {
let expected_size = 1 << i;
w[0].evaluations().len() == expected_size
&& w[1].evaluations().len() == expected_size
}));
}

// product constraint tower sumcheck
let span = entered_span!("sumcheck::tower");
// final evals for verifier
let r_out_evals: [E; 2] = {
[
r_wit_layers[0][0].get_ext_field_vec()[0],
r_wit_layers[0][1].get_ext_field_vec()[0],
]
};
let w_out_evals: [E; 2] = {
[
w_wit_layers[0][0].get_ext_field_vec()[0],
w_wit_layers[0][1].get_ext_field_vec()[0],
]
};

let (rt_tower, tower_proof) = TowerProver::create_proof(
max_threads,
// pattern [r1, w1] same pair are chain together
vec![
TowerProverSpec {
witness: r_wit_layers,
},
TowerProverSpec {
witness: w_wit_layers,
},
],
vec![],
NUM_FANIN_LOGUP,
transcript,
);
assert_eq!(
rt_tower.len(), // num var length should equal to max_num_instance
num_vars
);
exit_span!(span);

let input_open_point = rt_tower;

let span = entered_span!("fixed::evals + witin::evals");
let mut evals = witnesses
.par_iter()
.chain(fixed.par_iter())
.map(|poly| poly.evaluate(&input_open_point[..poly.num_vars()]))
.collect::<Vec<_>>();
let fixed_in_evals = evals.split_off(witnesses.len());
let wits_in_evals = evals;
exit_span!(span);

let span = entered_span!("pcs_opening");
let fixed_opening_proof = PCS::simple_batch_open(
pp,
&fixed,
circuit_pk.fixed_commit_wd.as_ref().unwrap(),
&input_open_point,
fixed_in_evals.as_slice(),
transcript,
)
.map_err(ZKVMError::PCSError)?;
let fixed_commit = PCS::get_pure_commitment(circuit_pk.fixed_commit_wd.as_ref().unwrap());
tracing::debug!(
"[table {}] build opening proof for {} fixed polys at {:?}: values = {:?}, commit = {:?}",
name,
fixed.len(),
input_open_point,
fixed_in_evals,
fixed_commit,
);
let wits_opening_proof = PCS::simple_batch_open(
pp,
&witnesses,
&wits_commit,
&input_open_point,
wits_in_evals.as_slice(),
transcript,
)
.map_err(ZKVMError::PCSError)?;
exit_span!(span);
let wits_commit = PCS::get_pure_commitment(&wits_commit);
tracing::debug!(
"[table {}] build opening proof for {} polys at {:?}: values = {:?}, commit = {:?}",
name,
witnesses.len(),
input_open_point,
wits_in_evals,
wits_commit,
);

Ok(ZKVMMemProof {
r_out_evals,
w_out_evals,
tower_proof,
fixed_in_evals,
fixed_opening_proof,
wits_in_evals,
wits_commit,
wits_opening_proof,
})
}
}

/// TowerProofs
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl<E: ExtensionField> ZKVMFixedTraces<E> {
assert!(self.circuit_fixed_traces.insert(OC::name(), None).is_none());
}

pub fn register_table_circuit<TC: TableCircuit<E>>(
pub fn register_table_circuit<TC: TableCircuit<E, FixedOutput = RowMajorMatrix<E::BaseField>>>(
&mut self,
cs: &ZKVMConstraintSystem<E>,
config: TC::TableConfig,
Expand Down
Loading
Loading