Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize witness cloning #89

Merged
merged 5 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/lints.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: make
args: fmt-check
args: fmt-check-selected-packages

- name: Run clippy
uses: actions-rs/cargo@v1
with:
command: make
args: clippy
args: clippy-check-selected-packages

27 changes: 27 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ members = [
"singer-utils",
"sumcheck",
"transcript",
"ceno_zkvm"
]

[workspace.package]
Expand Down
10 changes: 9 additions & 1 deletion Makefile.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ RAYON_NUM_THREADS = "${CORE}"

[tasks.tests]
command = "cargo"
args = ["test", "--lib", "--release", "--all"]
args = ["test", "--lib", "--release", "--workspace", "--exclude", "singer-pro"]

[tasks.fmt-check]
command = "cargo"
Expand All @@ -18,3 +18,11 @@ args = ["fmt", "--all"]
[tasks.clippy]
command = "cargo"
args = ["clippy", "--all-features", "--all-targets", "--", "-D", "warnings"]

[tasks.fmt-check-selected-packages]
command = "cargo"
args = ["fmt", "-p", "ceno_zkvm", "--", "--check"]

[tasks.clippy-check-selected-packages]
command = "cargo"
args = ["clippy", "-p", "ceno_zkvm", "--", "-D", "warnings"]
42 changes: 42 additions & 0 deletions ceno_zkvm/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
[package]
name = "ceno_zkvm"
version.workspace = true
edition.workspace = true
license.workspace = true

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ark-std.workspace = true
ff.workspace = true
goldilocks.workspace = true
rayon.workspace = true
serde.workspace = true

transcript = { path = "../transcript" }
sumcheck = { version = "0.1.0", path = "../sumcheck" }
multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" }
ff_ext = { path = "../ff_ext" }

itertools = "0.12.0"
strum = "0.25.0"
strum_macros = "0.25.3"
paste = "1.0.14"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
tracing-flame = "0.2.0"
tracing = "0.1.40"

[dev-dependencies]
pprof = { version = "0.13", features = ["flamegraph"]}
criterion = { version = "0.5", features = ["html_reports"] }
cfg-if = "1.0.0"
const_env = "0.1.2"

[features]

[profile.bench]
opt-level = 0

[[bench]]
name = "riscv_add"
harness = false
127 changes: 127 additions & 0 deletions ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#![allow(clippy::manual_memcpy)]
#![allow(clippy::needless_range_loop)]

use std::time::{Duration, Instant};

use ark_std::test_rng;
use ceno_zkvm::{
circuit_builder::CircuitBuilder,
instructions::{riscv::addsub::AddInstruction, Instruction},
scheme::prover::ZKVMProver,
};
use const_env::from_env;
use criterion::*;

use ff_ext::ff::Field;
use goldilocks::{Goldilocks, GoldilocksExt2};
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLE;
use transcript::Transcript;

cfg_if::cfg_if! {
if #[cfg(feature = "flamegraph")] {
criterion_group! {
name = op_add;
config = Criterion::default().warm_up_time(Duration::from_millis(3000)).with_profiler(pprof::criterion::PProfProfiler::new(100, pprof::criterion::Output::Flamegraph(None)));
targets = bench_add
}
} else {
criterion_group! {
name = op_add;
config = Criterion::default().warm_up_time(Duration::from_millis(3000));
targets = bench_add
}
}
}

criterion_main!(op_add);

const NUM_SAMPLES: usize = 10;
#[from_env]
const RAYON_NUM_THREADS: usize = 8;

pub fn is_power_of_2(x: usize) -> bool {
(x != 0) && ((x & (x - 1)) == 0)
}

fn bench_add(c: &mut Criterion) {
let max_threads = {
if !is_power_of_2(RAYON_NUM_THREADS) {
#[cfg(not(feature = "non_pow2_rayon_thread"))]
{
panic!(
"add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool"
);
}

#[cfg(feature = "non_pow2_rayon_thread")]
{
use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2};
let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS);
create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true);
max_thread_id
}
} else {
RAYON_NUM_THREADS
}
};
let mut circuit_builder = CircuitBuilder::<GoldilocksExt2>::new();
let _ = AddInstruction::construct_circuit(&mut circuit_builder);
let circuit = circuit_builder.finalize_circuit();
let num_witin = circuit.num_witin;

let prover = ZKVMProver::new(circuit); // circuit clone due to verifier alos need circuit reference
let mut transcript = Transcript::new(b"riscv");

for instance_num_vars in 20..22 {
// expand more input size once runtime is acceptable
let mut group = c.benchmark_group(format!("add_op_{}", instance_num_vars));
group.sample_size(NUM_SAMPLES);

// Benchmark the proving time
group.bench_function(
BenchmarkId::new("prove_add", format!("prove_add_log2_{}", instance_num_vars)),
|b| {
b.iter_with_setup(
|| {
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];
(rng, real_challenges)
},
|(mut rng, real_challenges)| {
// generate mock witness
let num_instances = 1 << instance_num_vars;
let wits_in = (0..num_witin as usize)
.map(|_| {
(0..num_instances)
.map(|_| Goldilocks::random(&mut rng))
.collect::<Vec<Goldilocks>>()
.into_mle()
.into()
})
.collect_vec();
let timer = Instant::now();
let _ = prover
.create_proof(
wits_in,
num_instances,
max_threads,
&mut transcript,
&real_challenges,
)
.expect("create_proof failed");
println!(
"AddInstruction::create_proof, instance_num_vars = {}, time = {}",
instance_num_vars,
timer.elapsed().as_secs_f64()
);
},
);
},
);

group.finish();
}

type E = GoldilocksExt2;
}
36 changes: 36 additions & 0 deletions ceno_zkvm/src/chip_handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use ff_ext::ExtensionField;

use crate::{
error::ZKVMError,
expression::WitIn,
structs::{PCUInt, TSUInt, UInt64},
};

pub mod general;
pub mod global_state;
pub mod register;

pub trait GlobalStateRegisterMachineChipOperations<E: ExtensionField> {
fn state_in(&mut self, pc: &PCUInt, ts: &TSUInt) -> Result<(), ZKVMError>;

fn state_out(&mut self, pc: &PCUInt, ts: &TSUInt) -> Result<(), ZKVMError>;
}

pub trait RegisterChipOperations<E: ExtensionField> {
fn register_read(
&mut self,
register_id: &WitIn,
prev_ts: &mut TSUInt,
ts: &mut TSUInt,
values: &UInt64,
) -> Result<TSUInt, ZKVMError>;

fn register_write(
&mut self,
register_id: &WitIn,
prev_ts: &mut TSUInt,
ts: &mut TSUInt,
prev_values: &UInt64,
values: &UInt64,
) -> Result<TSUInt, ZKVMError>;
}
Loading
Loading