Skip to content

Commit

Permalink
Extend profiling using tracing (#572)
Browse files Browse the repository at this point in the history
Improve profiling efforts by:
- refactoring tracing spans
- addressing a pitfall regarding spawned threads
- changing some subscriber configs
  • Loading branch information
mcalancea authored Nov 12, 2024
1 parent 54c8114 commit 82af85a
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 81 deletions.
34 changes: 25 additions & 9 deletions ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ use ff_ext::ff::Field;
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme};
use sumcheck::{entered_span, exit_span};
use tracing_flame::FlameLayer;
use tracing_subscriber::{EnvFilter, Registry, fmt, layer::SubscriberExt};
use tracing_subscriber::{EnvFilter, Registry, fmt, fmt::format::FmtSpan, layer::SubscriberExt};
use transcript::Transcript;

const PROGRAM_SIZE: usize = 16;
Expand Down Expand Up @@ -92,17 +93,29 @@ fn main() {
.collect(),
);
let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap();
let mut fmt_layer = fmt::layer()
.compact()
.with_span_events(FmtSpan::CLOSE)
.with_thread_ids(false)
.with_thread_names(false);
fmt_layer.set_ansi(false);

// Take filtering directives from the RUST_LOG environment variable.
// Directive syntax: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives
// Example: `RUST_LOG="info" cargo run ...` to get spans/events at info level (profiling spans are at info level)
// Example: `RUST_LOG="[sumcheck]" cargo run ...` to get only events under the "sumcheck" span
let filter = EnvFilter::from_default_env();

let subscriber = Registry::default()
.with(
fmt::layer()
.compact()
.with_thread_ids(false)
.with_thread_names(false),
)
.with(EnvFilter::from_default_env())
.with(fmt_layer)
.with(filter)
.with(flame_layer.with_threads_collapsed(true));
tracing::subscriber::set_global_default(subscriber).unwrap();

let top_level = entered_span!("TOPLEVEL");

let keygen = entered_span!("KEYGEN");

// keygen
let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup");
let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
Expand Down Expand Up @@ -138,6 +151,7 @@ fn main() {
.expect("keygen failed");
let vk = pk.get_vk();

exit_span!(keygen);
// proving
let prover = ZKVMProver::new(pk);
let verifier = ZKVMVerifier::new(vk);
Expand Down Expand Up @@ -284,14 +298,15 @@ fn main() {
let timer = Instant::now();

let transcript = Transcript::new(b"riscv");

let mut zkvm_proof = prover
.create_proof(zkvm_witness, pi, transcript)
.expect("create_proof failed");

println!(
"riscv_opcodes::create_proof, instance_num_vars = {}, time = {}",
instance_num_vars,
timer.elapsed().as_secs_f64()
timer.elapsed().as_secs()
);

let transcript = Transcript::new(b"riscv");
Expand Down Expand Up @@ -336,4 +351,5 @@ fn main() {
}
};
}
exit_span!(top_level);
}
76 changes: 47 additions & 29 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
}

/// create proof for zkvm execution
#[tracing::instrument(skip_all, name = "ZKVM_create_proof")]
pub fn create_proof(
&self,
witnesses: ZKVMWitnesses<E>,
Expand Down Expand Up @@ -87,10 +88,11 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
let mut commitments = BTreeMap::new();
let mut wits = BTreeMap::new();

let commit_to_traces_span = entered_span!("commit_to_traces");
// commit to opcode circuits first and then commit to table circuits, sorted by name
for (circuit_name, witness) in witnesses.into_iter_sorted() {
let commit_dur = std::time::Instant::now();
let num_instances = witness.num_instances();
let span = entered_span!("commit to iteration", circuit_name = circuit_name);
let witness = match num_instances {
0 => vec![],
_ => {
Expand All @@ -100,16 +102,13 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
PCS::batch_commit_and_write(&self.pk.pp, &witness, &mut transcript)
.map_err(ZKVMError::PCSError)?,
);
tracing::info!(
"commit to {} traces took {:?}",
circuit_name,
commit_dur.elapsed()
);
witness
}
};
exit_span!(span);
wits.insert(circuit_name, (witness, num_instances));
}
exit_span!(commit_to_traces_span);

// squeeze two challenges from transcript
let challenges = [
Expand All @@ -118,6 +117,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
];
tracing::debug!("challenges in prover: {:?}", challenges);

let main_proofs_span = entered_span!("main_proofs");
let mut transcripts = transcript.fork(self.pk.circuit_pks.len());
for ((circuit_name, pk), (i, transcript)) in self
.pk
Expand Down Expand Up @@ -193,6 +193,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
}
}
}
exit_span!(main_proofs_span);

Ok(vm_proof)
}
Expand All @@ -201,6 +202,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
/// 1: witness layer inferring from input -> output
/// 2: proof (sumcheck reduce) from output to input
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip_all, name = "create_opcode_proof", fields(circuit_name=name))]
pub fn create_opcode_proof(
&self,
name: &str,
Expand All @@ -226,8 +228,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
.all(|v| { v.evaluations().len() == next_pow2_instances })
);

let wit_inference_span = entered_span!("wit_inference");
// main constraint: read/write record witness inference
let span = entered_span!("wit_inference::record");
let record_span = entered_span!("record");
let records_wit: Vec<ArcMultilinearExtension<'_, E>> = cs
.r_expressions
.par_iter()
Expand All @@ -240,7 +243,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
.collect();
let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len());
let (w_records_wit, lk_records_wit) = w_lk_records_wit.split_at(cs.w_expressions.len());
exit_span!(span);
exit_span!(record_span);

// product constraint: tower witness inference
let (r_counts_per_instance, w_counts_per_instance, lk_counts_per_instance) = (
Expand All @@ -255,47 +258,48 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
);
// process the last layer by interleaving all the read/write records respectively,
// as the last layer is the output of the sel stage
let span = entered_span!("wit_inference::tower_witness_r_last_layer");
let span = entered_span!("tower_witness_r_last_layer");
// TODO optimize last layer to avoid alloc new vector to save memory
let r_records_last_layer =
interleaving_mles_to_mles(r_records_wit, num_instances, NUM_FANIN, E::ONE);
assert_eq!(r_records_last_layer.len(), NUM_FANIN);
exit_span!(span);

// infer all tower witness after last layer
let span = entered_span!("wit_inference::tower_witness_r_layers");
let span = entered_span!("tower_witness_r_layers");
let r_wit_layers = infer_tower_product_witness(
log2_num_instances + log2_r_count,
r_records_last_layer,
NUM_FANIN,
);
exit_span!(span);

let span = entered_span!("wit_inference::tower_witness_w_last_layer");
let span = entered_span!("tower_witness_w_last_layer");
// TODO optimize last layer to avoid alloc new vector to save memory
let w_records_last_layer =
interleaving_mles_to_mles(w_records_wit, num_instances, NUM_FANIN, E::ONE);
assert_eq!(w_records_last_layer.len(), NUM_FANIN);
exit_span!(span);

let span = entered_span!("wit_inference::tower_witness_w_layers");
let span = entered_span!("tower_witness_w_layers");
let w_wit_layers = infer_tower_product_witness(
log2_num_instances + log2_w_count,
w_records_last_layer,
NUM_FANIN,
);
exit_span!(span);

let span = entered_span!("wit_inference::tower_witness_lk_last_layer");
let span = entered_span!("tower_witness_lk_last_layer");
// TODO optimize last layer to avoid alloc new vector to save memory
let lk_records_last_layer =
interleaving_mles_to_mles(lk_records_wit, num_instances, NUM_FANIN, chip_record_alpha);
assert_eq!(lk_records_last_layer.len(), 2);
exit_span!(span);

let span = entered_span!("wit_inference::tower_witness_lk_layers");
let span = entered_span!("tower_witness_lk_layers");
let lk_wit_layers = infer_tower_logup_witness(None, lk_records_last_layer);
exit_span!(span);
exit_span!(wit_inference_span);

if cfg!(test) {
// sanity check
Expand Down Expand Up @@ -326,8 +330,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
}));
}

let sumcheck_span = entered_span!("SUMCHECK");
// product constraint tower sumcheck
let span = entered_span!("sumcheck::tower");
let tower_span = entered_span!("tower");
// final evals for verifier
let record_r_out_evals: Vec<E> = r_wit_layers[0]
.iter()
Expand Down Expand Up @@ -365,10 +370,10 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
.max()
.unwrap()
);
exit_span!(span);
exit_span!(tower_span);

// batch sumcheck: selector + main degree > 1 constraints
let span = entered_span!("sumcheck::main_sel");
let main_sel_span = entered_span!("main_sel");
let (rt_r, rt_w, rt_lk, rt_non_lc_sumcheck): (Vec<E>, Vec<E>, Vec<E>, Vec<E>) = (
tower_proof.prod_specs_points[0]
.last()
Expand Down Expand Up @@ -581,7 +586,8 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
);
let input_open_point = main_sel_sumcheck_proofs.point.clone();
assert!(input_open_point.len() == log2_num_instances);
exit_span!(span);
exit_span!(main_sel_span);
exit_span!(sumcheck_span);

let span = entered_span!("witin::evals");
let wits_in_evals: Vec<E> = witnesses
Expand All @@ -590,7 +596,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
.collect();
exit_span!(span);

let span = entered_span!("pcs_open");
let pcs_open_span = entered_span!("pcs_open");
let opening_dur = std::time::Instant::now();
tracing::debug!(
"[opcode {}]: build opening proof for {} polys at {:?}",
Expand All @@ -612,7 +618,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
name,
opening_dur.elapsed(),
);
exit_span!(span);
exit_span!(pcs_open_span);
let wits_commit = PCS::get_pure_commitment(&wits_commit);

Ok(ZKVMOpcodeProof {
Expand All @@ -638,6 +644,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
/// support batch prove for logup + product arguments each with different num_vars()
/// side effect: concurrency will be determined based on min(thread, num_vars()),
/// so avoid batching a too-small table (size < threads) together with a large table
#[tracing::instrument(skip_all, name = "create_table_proof", fields(table_name=name))]
pub fn create_table_proof(
&self,
name: &str,
Expand Down Expand Up @@ -681,8 +688,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
.all(|(r, w)| r.table_spec.len == w.table_spec.len)
);

let wit_inference_span = entered_span!("wit_inference");
// main constraint: lookup denominator and numerator record witness inference
let span = entered_span!("wit_inference::record");
let record_span = entered_span!("record");
let mut records_wit: Vec<ArcMultilinearExtension<'_, E>> = cs
.r_table_expressions
.par_iter()
Expand All @@ -707,10 +715,10 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
let (lk_d_wit, _empty) = remains.split_at_mut(cs.lk_table_expressions.len());
assert!(_empty.is_empty());

exit_span!(span);
exit_span!(record_span);

// infer all tower witness after last layer
let span = entered_span!("wit_inference::tower_witness_lk_last_layer");
let span = entered_span!("tower_witness_lk_last_layer");
let mut r_set_last_layer = r_set_wit
.iter()
.chain(w_set_wit.iter())
Expand Down Expand Up @@ -758,7 +766,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
.collect::<Vec<_>>();
exit_span!(span);

let span = entered_span!("wit_inference::tower_witness_lk_layers");
let span = entered_span!("tower_witness_lk_layers");
let r_wit_layers = r_set_last_layer
.into_iter()
.zip(r_set_wit.iter())
Expand All @@ -779,6 +787,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
.map(|(lk_n, lk_d)| infer_tower_logup_witness(Some(lk_n), lk_d))
.collect_vec();
exit_span!(span);
exit_span!(wit_inference_span);

if cfg!(test) {
// sanity check
Expand Down Expand Up @@ -831,8 +840,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
}));
}

let sumcheck_span = entered_span!("sumcheck");
// product constraint tower sumcheck
let span = entered_span!("sumcheck::tower");
let tower_span = entered_span!("tower");
// final evals for verifier
let r_out_evals = r_wit_layers
.iter()
Expand Down Expand Up @@ -889,7 +899,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
rt_tower.len(), // num var length should equal to max_num_instance
max_log2_num_instance
);
exit_span!(span);
exit_span!(tower_span);

// same point sumcheck is optional when all witin + fixed are in same num_vars
let is_skip_same_point_sumcheck = witnesses
Expand All @@ -904,7 +914,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
} else {
// one sumcheck to make them opening on same point r (with different prefix)
// If all table length are the same, we can skip this sumcheck
let span = entered_span!("sumcheck::opening_same_point");
let span = entered_span!("opening_same_point");
// NOTE: max concurrency will be dominated by smallest table since it will blo
let num_threads = optimal_sumcheck_threads(min_log2_num_instance);
let alpha_pow = get_challenge_pows(
Expand Down Expand Up @@ -993,6 +1003,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
)
};

exit_span!(sumcheck_span);
let span = entered_span!("fixed::evals + witin::evals");
let mut evals = witnesses
.par_iter()
Expand Down Expand Up @@ -1025,7 +1036,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
.collect_vec();
// TODO implement mechanism to skip commitment

let span = entered_span!("pcs_opening");
let pcs_opening = entered_span!("pcs_opening");
let (fixed_opening_proof, fixed_commit) = if !fixed.is_empty() {
(
Some(
Expand Down Expand Up @@ -1064,7 +1075,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
transcript,
)
.map_err(ZKVMError::PCSError)?;
exit_span!(span);
exit_span!(pcs_opening);
let wits_commit = PCS::get_pure_commitment(&wits_commit);
tracing::debug!(
"[table {}] build opening proof for {} polys at {:?}: values = {:?}, commit = {:?}",
Expand Down Expand Up @@ -1132,6 +1143,7 @@ impl<E: ExtensionField> TowerProofs<E> {

/// Tower Prover
impl TowerProver {
#[tracing::instrument(skip_all, name = "tower_prover_create_proof")]
pub fn create_proof<'a, E: ExtensionField>(
prod_specs: Vec<TowerProverSpec<'a, E>>,
logup_specs: Vec<TowerProverSpec<'a, E>>,
Expand Down Expand Up @@ -1226,11 +1238,17 @@ impl TowerProver {
}
}

let wrap_batch_span = entered_span!("wrap_batch");
// NOTE: at the time of adding this span, visualizing it with the flamegraph layer
// shows it to be (inexplicably) much more time-consuming than the call to `prove_batch_polys`.
// This is likely a bug in the tracing-flame crate.
let (sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys(
num_threads,
virtual_polys.get_batched_polys(),
transcript,
);
exit_span!(wrap_batch_span);

proofs.push_sumcheck_proofs(sumcheck_proofs.proofs);

// rt' = r_merge || rt
Expand Down
Loading

0 comments on commit 82af85a

Please sign in to comment.