Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/impl riscv ADD instruction #85

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 10 additions & 33 deletions gkr-graph/examples/series_connection_alt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,16 @@ use gkr::{
use gkr_graph::{
error::GKRGraphError,
structs::{
CircuitGraphAuxInfo, CircuitGraphBuilder, IOPProverState, IOPVerifierState, NodeOutputType,
PredType, TargetEvaluations,
CircuitGraphAuxInfo, CircuitGraphBuilder, IOPProverState, IOPVerifierState, NodeOutputType, PredType,
TargetEvaluations,
},
};
use goldilocks::{Goldilocks, GoldilocksExt2};
use simple_frontend::structs::{ChallengeId, CircuitBuilder, MixedCell};
use std::sync::Arc;
use transcript::Transcript;

fn construct_input<E: ExtensionField>(
input_size: usize,
challenge: ChallengeId,
) -> Arc<Circuit<E>> {
fn construct_input<E: ExtensionField>(input_size: usize, challenge: ChallengeId) -> Arc<Circuit<E>> {
let mut circuit_builder = CircuitBuilder::<E>::new();
let (_, inputs) = circuit_builder.create_witness_in(input_size);
let (_, lookup_inputs) = circuit_builder.create_ext_witness_out(input_size);
Expand All @@ -34,10 +31,7 @@ fn construct_input<E: ExtensionField>(

/// Construct a selector for n_instances and each instance contains `num`
/// items. `num` must be a power of 2.
pub(crate) fn construct_prefix_selector<E: ExtensionField>(
n_instances: usize,
num: usize,
) -> Arc<Circuit<E>> {
pub(crate) fn construct_prefix_selector<E: ExtensionField>(n_instances: usize, num: usize) -> Arc<Circuit<E>> {
assert_eq!(num, num.next_power_of_two());
let mut circuit_builder = CircuitBuilder::<E>::new();
let _ = circuit_builder.create_constant_in(n_instances * num, 1);
Expand All @@ -59,12 +53,7 @@ pub(crate) fn construct_inv_sum<E: ExtensionField>() -> Arc<Circuit<E>> {
let den_mul = circuit_builder.create_ext_cell();
circuit_builder.mul2_ext(&den_mul, &input[0], &input[1], E::BaseField::ONE);
let tmp = circuit_builder.create_ext_cell();
circuit_builder.sel_mixed_and_ext(
&tmp,
&MixedCell::Constant(E::BaseField::ONE),
&input[0],
cond[0],
);
circuit_builder.sel_mixed_and_ext(&tmp, &MixedCell::Constant(E::BaseField::ONE), &input[0], cond[0]);
circuit_builder.sel_ext(&output[0], &tmp, &den_mul, cond[1]);

// select the numerator 0 or 1 or input[0] + input[1]
Expand Down Expand Up @@ -143,11 +132,7 @@ fn main() -> Result<(), GKRGraphError> {
let mut prover_graph_builder = CircuitGraphBuilder::<GoldilocksExt2>::new();
let mut verifier_graph_builder = CircuitGraphBuilder::<GoldilocksExt2>::new();
let mut prover_transcript = Transcript::<GoldilocksExt2>::new(b"test");
let challenge = vec![
prover_transcript
.get_and_append_challenge(b"lookup challenge")
.elements,
];
let challenge = vec![prover_transcript.get_and_append_challenge(b"lookup challenge").elements];

let mut add_node_and_witness = |label: &'static str,
circuit: &Arc<Circuit<_>>,
Expand Down Expand Up @@ -225,12 +210,8 @@ fn main() -> Result<(), GKRGraphError> {
// Proofs generation
// =================
let output_point = vec![
prover_transcript
.get_and_append_challenge(b"output point")
.elements,
prover_transcript
.get_and_append_challenge(b"output point")
.elements,
prover_transcript.get_and_append_challenge(b"output point").elements,
prover_transcript.get_and_append_challenge(b"output point").elements,
];
let output_eval = circuit_witness
.node_witnesses
Expand Down Expand Up @@ -259,12 +240,8 @@ fn main() -> Result<(), GKRGraphError> {
.elements];

let output_point = vec![
verifier_transcript
.get_and_append_challenge(b"output point")
.elements,
verifier_transcript
.get_and_append_challenge(b"output point")
.elements,
verifier_transcript.get_and_append_challenge(b"output point").elements,
verifier_transcript.get_and_append_challenge(b"output point").elements,
];

IOPVerifierState::verify(
Expand Down
10 changes: 3 additions & 7 deletions gkr-graph/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@ use itertools::Itertools;
use crate::structs::{CircuitGraph, CircuitGraphWitness, NodeOutputType, TargetEvaluations};

impl<E: ExtensionField> CircuitGraph<E> {
pub fn target_evals(
&self,
witness: &CircuitGraphWitness<E::BaseField>,
point: &Point<E>,
) -> TargetEvaluations<E> {
pub fn target_evals(&self, witness: &CircuitGraphWitness<E::BaseField>, point: &Point<E>) -> TargetEvaluations<E> {
// println!("targets: {:?}, point: {:?}", self.targets, point);
let target_evals = self
.targets
Expand All @@ -24,8 +20,8 @@ impl<E: ExtensionField> CircuitGraph<E> {
.instances
.as_slice()
.original_mle(),
NodeOutputType::WireOut(node_id, wit_id) => witness.node_witnesses[*node_id]
.witness_out_ref()[*wit_id as usize]
NodeOutputType::WireOut(node_id, wit_id) => witness.node_witnesses[*node_id].witness_out_ref()
[*wit_id as usize]
.instances
.as_slice()
.original_mle(),
Expand Down
35 changes: 10 additions & 25 deletions gkr-graph/src/circuit_graph_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ use simple_frontend::structs::WitnessId;
use crate::{
error::GKRGraphError,
structs::{
CircuitGraph, CircuitGraphBuilder, CircuitGraphWitness, CircuitNode, NodeInputType,
NodeOutputType, PredType,
CircuitGraph, CircuitGraphBuilder, CircuitGraphWitness, CircuitNode, NodeInputType, NodeOutputType, PredType,
},
};

Expand Down Expand Up @@ -45,16 +44,13 @@ impl<E: ExtensionField> CircuitGraphBuilder<E> {
assert!(num_instances.is_power_of_two());
assert_eq!(sources.len(), circuit.n_witness_in);
assert!(
!sources.iter().any(
|source| source.instances.len() != 0 && source.instances.len() != num_instances
),
!sources
.iter()
.any(|source| source.instances.len() != 0 && source.instances.len() != num_instances),
"node_id: {}, num_instances: {}, sources_num_instances: {:?}",
id,
num_instances,
sources
.iter()
.map(|source| source.instances.len())
.collect_vec()
sources.iter().map(|source| source.instances.len()).collect_vec()
);

let mut witness = CircuitWitness::new(circuit, challenges);
Expand All @@ -65,14 +61,11 @@ impl<E: ExtensionField> CircuitGraphBuilder<E> {
let (id, out) = &match out {
NodeOutputType::OutputLayer(id) => (
*id,
&self.witness.node_witnesses[*id]
.output_layer_witness_ref()
.instances,
&self.witness.node_witnesses[*id].output_layer_witness_ref().instances,
),
NodeOutputType::WireOut(id, wit_id) => (
*id,
&self.witness.node_witnesses[*id].witness_out_ref()[*wit_id as usize]
.instances,
&self.witness.node_witnesses[*id].witness_out_ref()[*wit_id as usize].instances,
),
};
let old_num_instances = self.witness.node_witnesses[*id].n_instances();
Expand All @@ -94,10 +87,7 @@ impl<E: ExtensionField> CircuitGraphBuilder<E> {
out.iter()
.cloned()
.flat_map(|single_instance| {
single_instance
.into_iter()
.cycle()
.take(num_dups * old_size)
single_instance.into_iter().cycle().take(num_dups * old_size)
})
.chunks(old_size)
.into_iter()
Expand Down Expand Up @@ -146,9 +136,7 @@ impl<E: ExtensionField> CircuitGraphBuilder<E> {
}

/// Collect the information of `self.sources` and `self.targets`.
pub fn finalize_graph_and_witness(
mut self,
) -> (CircuitGraph<E>, CircuitGraphWitness<E::BaseField>) {
pub fn finalize_graph_and_witness(mut self) -> (CircuitGraph<E>, CircuitGraphWitness<E::BaseField>) {
// Generate all possible graph output
let outs = self
.graph
Expand Down Expand Up @@ -244,10 +232,7 @@ impl<E: ExtensionField> CircuitGraphBuilder<E> {
},
);

assert_eq!(
expected_target,
targets.iter().cloned().collect::<BTreeSet<_>>()
);
assert_eq!(expected_target, targets.iter().cloned().collect::<BTreeSet<_>>());

self.graph.sources = sources.into_iter().collect();
self.graph.targets = targets.to_vec();
Expand Down
4 changes: 2 additions & 2 deletions gkr-graph/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use transcript::Transcript;
use crate::{
error::GKRGraphError,
structs::{
CircuitGraph, CircuitGraphWitness, GKRProverState, IOPProof, IOPProverState,
NodeOutputType, PredType, TargetEvaluations,
CircuitGraph, CircuitGraphWitness, GKRProverState, IOPProof, IOPProverState, NodeOutputType, PredType,
TargetEvaluations,
},
};

Expand Down
92 changes: 43 additions & 49 deletions gkr-graph/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use transcript::Transcript;
use crate::{
error::GKRGraphError,
structs::{
CircuitGraph, CircuitGraphAuxInfo, GKRVerifierState, IOPProof, IOPVerifierState,
NodeOutputType, PredType, TargetEvaluations,
CircuitGraph, CircuitGraphAuxInfo, GKRVerifierState, IOPProof, IOPVerifierState, NodeOutputType, PredType,
TargetEvaluations,
},
};

Expand Down Expand Up @@ -50,56 +50,50 @@ impl<E: ExtensionField> IOPVerifierState<E> {

let new_instance_num_vars = aux_info.instance_num_vars[node.id];

izip!(&node.preds, input_claim.point_and_evals).for_each(
|(pred_type, point_and_eval)| {
match pred_type {
PredType::Source => {
// TODO: collect `(proof.point.clone(), *eval)` as `TargetEvaluations`
// for later PCS open?
}
PredType::PredWire(pred_out) | PredType::PredWireDup(pred_out) => {
let point = match pred_type {
PredType::PredWire(_) => point_and_eval.point.clone(),
PredType::PredWireDup(out) => {
let node_id = match out {
NodeOutputType::OutputLayer(id) => *id,
NodeOutputType::WireOut(id, _) => *id,
};
// Suppose the new point is
// [single_instance_slice ||
// new_instance_index_slice]. The old point
// is [single_instance_slices ||
// new_instance_index_slices[(new_instance_num_vars
// - old_instance_num_vars)..]]
let old_instance_num_vars = aux_info.instance_num_vars[node_id];
let num_vars =
point_and_eval.point.len() - new_instance_num_vars;
[
point_and_eval.point[..num_vars].to_vec(),
point_and_eval.point[num_vars
+ (new_instance_num_vars - old_instance_num_vars)..]
.to_vec(),
]
.concat()
}
_ => unreachable!(),
};
match pred_out {
NodeOutputType::OutputLayer(id) => output_evals[*id]
.push(PointAndEval::new_from_ref(&point, &point_and_eval.eval)),
NodeOutputType::WireOut(id, wire_id) => {
let evals = &mut wit_out_evals[*id][*wire_id as usize];
assert!(
evals.point.is_empty() && evals.eval.is_zero_vartime(),
"unimplemented",
);
*evals = PointAndEval::new(point, point_and_eval.eval);
}
izip!(&node.preds, input_claim.point_and_evals).for_each(|(pred_type, point_and_eval)| {
match pred_type {
PredType::Source => {
// TODO: collect `(proof.point.clone(), *eval)` as `TargetEvaluations`
// for later PCS open?
}
PredType::PredWire(pred_out) | PredType::PredWireDup(pred_out) => {
let point = match pred_type {
PredType::PredWire(_) => point_and_eval.point.clone(),
PredType::PredWireDup(out) => {
let node_id = match out {
NodeOutputType::OutputLayer(id) => *id,
NodeOutputType::WireOut(id, _) => *id,
};
// Suppose the new point is
// [single_instance_slice ||
// new_instance_index_slice]. The old point
// is [single_instance_slices ||
// new_instance_index_slices[(new_instance_num_vars
// - old_instance_num_vars)..]]
let old_instance_num_vars = aux_info.instance_num_vars[node_id];
let num_vars = point_and_eval.point.len() - new_instance_num_vars;
[
point_and_eval.point[..num_vars].to_vec(),
point_and_eval.point[num_vars + (new_instance_num_vars - old_instance_num_vars)..]
.to_vec(),
]
.concat()
}
_ => unreachable!(),
};
match pred_out {
NodeOutputType::OutputLayer(id) => {
output_evals[*id].push(PointAndEval::new_from_ref(&point, &point_and_eval.eval))
}
NodeOutputType::WireOut(id, wire_id) => {
let evals = &mut wit_out_evals[*id][*wire_id as usize];
assert!(evals.point.is_empty() && evals.eval.is_zero_vartime(), "unimplemented",);
*evals = PointAndEval::new(point, point_and_eval.eval);
}
}
}
},
);
}
});
}

Ok(())
Expand Down
10 changes: 2 additions & 8 deletions gkr/benches/keccak256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ const NUM_SAMPLES: usize = 10;
const RAYON_NUM_THREADS: usize = 8;

fn bench_keccak256(c: &mut Criterion) {
println!(
"#layers: {}",
keccak256_circuit::<GoldilocksExt2>().layers.len()
);
println!("#layers: {}", keccak256_circuit::<GoldilocksExt2>().layers.len());

let max_thread_id = {
if !is_power_of_2(RAYON_NUM_THREADS) {
Expand Down Expand Up @@ -74,10 +71,7 @@ fn bench_keccak256(c: &mut Criterion) {
BenchmarkId::new("prove_keccak256", format!("keccak256_log2_{}", log2_n)),
|b| {
b.iter(|| {
assert!(
prove_keccak256(log2_n, &circuit, (1 << log2_n).min(max_thread_id),)
.is_some()
);
assert!(prove_keccak256(log2_n, &circuit, (1 << log2_n).min(max_thread_id),).is_some());
});
},
);
Expand Down
25 changes: 5 additions & 20 deletions gkr/examples/keccak256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ use tracing_flame::FlameLayer;
use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry};

fn main() {
println!(
"#layers: {}",
keccak256_circuit::<GoldilocksExt2>().layers.len()
);
println!("#layers: {}", keccak256_circuit::<GoldilocksExt2>().layers.len());

#[allow(unused_mut)]
let mut max_thread_id: usize = env::var("RAYON_NUM_THREADS")
Expand All @@ -29,9 +26,7 @@ fn main() {
if !is_power_of_2(max_thread_id) {
#[cfg(not(feature = "non_pow2_rayon_thread"))]
{
panic!(
"add --features non_pow2_rayon_thread to support non pow of 2 rayon thread pool"
);
panic!("add --features non_pow2_rayon_thread to support non pow of 2 rayon thread pool");
}

#[cfg(feature = "non_pow2_rayon_thread")]
Expand All @@ -57,17 +52,12 @@ fn main() {
witness.add_instance(&circuit, all_zero);
witness.add_instance(&circuit, all_one);

izip!(
&witness.witness_out_ref()[0].instances,
[[0; 25], [u64::MAX; 25]]
)
.for_each(|(wire_out, state)| {
izip!(&witness.witness_out_ref()[0].instances, [[0; 25], [u64::MAX; 25]]).for_each(|(wire_out, state)| {
let output = wire_out[..256]
.chunks_exact(64)
.map(|bits| {
bits.iter().fold(0, |acc, bit| {
(acc << 1)
+ (*bit == <GoldilocksExt2 as ExtensionField>::BaseField::ONE) as u64
(acc << 1) + (*bit == <GoldilocksExt2 as ExtensionField>::BaseField::ONE) as u64
})
})
.collect_vec();
Expand All @@ -82,12 +72,7 @@ fn main() {

let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap();
let subscriber = Registry::default()
.with(
fmt::layer()
.compact()
.with_thread_ids(false)
.with_thread_names(false),
)
.with(fmt::layer().compact().with_thread_ids(false).with_thread_names(false))
.with(EnvFilter::from_default_env())
.with(flame_layer.with_threads_collapsed(true));
tracing::subscriber::set_global_default(subscriber).unwrap();
Expand Down
Loading
Loading