From 11b272aeeaa330854fcd6ad86d97ccaa1a34484d Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Wed, 7 Aug 2024 11:13:06 +0300 Subject: [PATCH] Blake statement --- crates/prover/src/examples/blake/air.rs | 380 ++++++++++++------ .../examples/blake/scheduler/constraints.rs | 7 +- .../src/examples/blake/scheduler/gen.rs | 8 +- .../src/examples/blake/xor_table/mod.rs | 15 +- 4 files changed, 274 insertions(+), 136 deletions(-) diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index 0185fb20d..1e89a03f2 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -1,38 +1,168 @@ use std::simd::u32x16; use itertools::{chain, multiunzip, Itertools}; +use num_traits::Zero; +use serde::Serialize; use tracing::{span, Level}; -use super::round::BlakeRoundComponent; +use super::round::{blake_round_info, BlakeRoundComponent}; use super::scheduler::BlakeSchedulerComponent; use super::xor_table::XorTableComponent; use crate::constraint_framework::constant_columns::gen_is_first; -use crate::core::air::{Air, AirProver, Component, ComponentProver}; +use crate::core::air::{Component, ComponentProver}; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::SimdBackend; use crate::core::channel::Channel; -use crate::core::pcs::CommitmentSchemeProver; +use crate::core::fields::qm31::SecureField; +use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, TreeVec}; use crate::core::poly::circle::{CanonicCoset, PolyOps}; -use crate::core::prover::{prove, StarkProof, LOG_BLOWUP_FACTOR}; +use crate::core::prover::{prove, verify, StarkProof, VerificationError, LOG_BLOWUP_FACTOR}; use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; use crate::core::InteractionElements; use crate::examples::blake::round::RoundElements; -use crate::examples::blake::scheduler::{self, BlakeElements, BlakeInput}; +use crate::examples::blake::scheduler::{self, blake_scheduler_info, BlakeElements, BlakeInput}; use crate::examples::blake::{ round, xor_table, BlakeXorElements, XorAccums, N_ROUNDS, ROUND_LOG_SPLIT, }; -pub struct BlakeAir { - pub scheduler_component: BlakeSchedulerComponent, - pub round_components: Vec, - pub xor12: XorTableComponent<12, 4>, - pub xor9: XorTableComponent<9, 2>, - pub xor8: XorTableComponent<8, 2>, - pub xor7: XorTableComponent<7, 2>, - pub xor4: XorTableComponent<4, 0>, +#[derive(Serialize)] +pub struct BlakeStatement0 { + log_size: u32, +} +impl BlakeStatement0 { + fn log_sizes(&self) -> TreeVec> { + let mut sizes = vec![]; + sizes.push( + blake_scheduler_info() + .mask_offsets + .as_cols_ref() + .map_cols(|_| self.log_size), + ); + for l in ROUND_LOG_SPLIT { + sizes.push( + blake_round_info() + .mask_offsets + .as_cols_ref() + .map_cols(|_| self.log_size + l), + ); + } + sizes.push(xor_table::trace_sizes::<12, 4>()); + sizes.push(xor_table::trace_sizes::<9, 2>()); + sizes.push(xor_table::trace_sizes::<8, 2>()); + sizes.push(xor_table::trace_sizes::<7, 2>()); + sizes.push(xor_table::trace_sizes::<4, 0>()); + + TreeVec::new( + (0..=2) + .map(|i| sizes.iter().flat_map(|x| x[i].clone()).collect()) + .collect(), + ) + } + fn mix_into(&self, channel: &mut impl Channel) { + // TODO(spapini): Do this better. + channel.mix_nonce(self.log_size as u64); + } +} + +pub struct AllElements { + blake_elements: BlakeElements, + round_elements: RoundElements, + xor_elements: BlakeXorElements, +} +impl AllElements { + pub fn draw(channel: &mut impl Channel) -> Self { + Self { + blake_elements: BlakeElements::draw(channel), + round_elements: RoundElements::draw(channel), + xor_elements: BlakeXorElements::draw(channel), + } + } +} + +pub struct BlakeStatement1 { + scheduler_claimed_sum: SecureField, + round_claimed_sums: Vec, + xor12_claimed_sum: SecureField, + xor9_claimed_sum: SecureField, + xor8_claimed_sum: SecureField, + xor7_claimed_sum: SecureField, + xor4_claimed_sum: SecureField, +} +impl BlakeStatement1 { + fn mix_into(&self, channel: &mut impl Channel) { + channel.mix_felts( + &chain![ + [ + self.scheduler_claimed_sum, + self.xor12_claimed_sum, + self.xor9_claimed_sum, + self.xor8_claimed_sum, + self.xor7_claimed_sum, + self.xor4_claimed_sum + ], + self.round_claimed_sums.clone() + ] + .collect_vec(), + ) + } } -impl Air for BlakeAir { +pub struct BlakeProof { + stmt0: BlakeStatement0, + stmt1: BlakeStatement1, + stark_proof: StarkProof, +} + +pub struct BlakeComponents { + scheduler_component: BlakeSchedulerComponent, + round_components: Vec, + xor12: XorTableComponent<12, 4>, + xor9: XorTableComponent<9, 2>, + xor8: XorTableComponent<8, 2>, + xor7: XorTableComponent<7, 2>, + xor4: XorTableComponent<4, 0>, +} +impl BlakeComponents { + fn new(stmt0: &BlakeStatement0, all_elements: &AllElements, stmt1: &BlakeStatement1) -> Self { + Self { + scheduler_component: BlakeSchedulerComponent { + log_size: stmt0.log_size, + blake_lookup_elements: all_elements.blake_elements.clone(), + round_lookup_elements: all_elements.round_elements.clone(), + claimed_sum: stmt1.scheduler_claimed_sum, + }, + round_components: ROUND_LOG_SPLIT + .iter() + .zip(stmt1.round_claimed_sums.clone()) + .map(|(l, claimed_sum)| BlakeRoundComponent { + log_size: stmt0.log_size + l, + xor_lookup_elements: all_elements.xor_elements.clone(), + round_lookup_elements: all_elements.round_elements.clone(), + claimed_sum, + }) + .collect(), + xor12: XorTableComponent { + lookup_elements: all_elements.xor_elements.xor12.clone(), + claimed_sum: stmt1.xor12_claimed_sum, + }, + xor9: XorTableComponent { + lookup_elements: all_elements.xor_elements.xor9.clone(), + claimed_sum: stmt1.xor9_claimed_sum, + }, + xor8: XorTableComponent { + lookup_elements: all_elements.xor_elements.xor8.clone(), + claimed_sum: stmt1.xor8_claimed_sum, + }, + xor7: XorTableComponent { + lookup_elements: all_elements.xor_elements.xor7.clone(), + claimed_sum: stmt1.xor7_claimed_sum, + }, + xor4: XorTableComponent { + lookup_elements: all_elements.xor_elements.xor4.clone(), + claimed_sum: stmt1.xor4_claimed_sum, + }, + } + } fn components(&self) -> Vec<&dyn Component> { chain![ [&self.scheduler_component as &dyn Component], @@ -47,9 +177,7 @@ impl Air for BlakeAir { ] .collect() } -} -impl AirProver for BlakeAir { fn component_provers(&self) -> Vec<&dyn ComponentProver> { chain![ [&self.scheduler_component as &dyn ComponentProver], @@ -69,7 +197,7 @@ impl AirProver for BlakeAir { } #[allow(unused)] -pub fn prove_blake(log_size: u32) -> (BlakeAir, StarkProof) +pub fn prove_blake(log_size: u32) -> (BlakeProof) where SimdBackend: MerkleOps, C: Channel, @@ -130,6 +258,10 @@ where let (xor_trace7, xor_lookup_data7) = xor_table::generate_trace(xor_accums.xor7); let (xor_trace4, xor_lookup_data4) = xor_table::generate_trace(xor_accums.xor4); + // Statement0. + let stmt0 = BlakeStatement0 { log_size }; + stmt0.mix_into(channel); + // Trace commitment. let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals( @@ -148,17 +280,15 @@ where span.exit(); // Draw lookup element. - let blake_lookup_elements = BlakeElements::draw(channel); - let round_lookup_elements = RoundElements::draw(channel); - let xor_lookup_elements = BlakeXorElements::draw(channel); + let all_elements = AllElements::draw(channel); // Interaction trace. let span = span!(Level::INFO, "Interaction").entered(); let (scheduler_trace, scheduler_claimed_sum) = scheduler::gen_interaction_trace( log_size, scheduler_lookup_data, - &round_lookup_elements, - &blake_lookup_elements, + &all_elements.round_elements, + &all_elements.blake_elements, ); let (round_traces, round_claimed_sums): (Vec<_>, Vec<_>) = multiunzip( @@ -169,22 +299,22 @@ where round::generate_interaction_trace( log_size + l, lookup_data, - &xor_lookup_elements, - &round_lookup_elements, + &all_elements.xor_elements, + &all_elements.round_elements, ) }), ); - let (xor_trace12, xor_claimed_sum12) = - xor_table::generate_interaction_trace(xor_lookup_data12, &xor_lookup_elements.xor12); - let (xor_trace9, xor_claimed_sum9) = - xor_table::generate_interaction_trace(xor_lookup_data9, &xor_lookup_elements.xor9); - let (xor_trace8, xor_claimed_sum8) = - xor_table::generate_interaction_trace(xor_lookup_data8, &xor_lookup_elements.xor8); - let (xor_trace7, xor_claimed_sum7) = - xor_table::generate_interaction_trace(xor_lookup_data7, &xor_lookup_elements.xor7); - let (xor_trace4, xor_claimed_sum4) = - xor_table::generate_interaction_trace(xor_lookup_data4, &xor_lookup_elements.xor4); + let (xor_trace12, xor12_claimed_sum) = + xor_table::generate_interaction_trace(xor_lookup_data12, &all_elements.xor_elements.xor12); + let (xor_trace9, xor9_claimed_sum) = + xor_table::generate_interaction_trace(xor_lookup_data9, &all_elements.xor_elements.xor9); + let (xor_trace8, xor8_claimed_sum) = + xor_table::generate_interaction_trace(xor_lookup_data8, &all_elements.xor_elements.xor8); + let (xor_trace7, xor7_claimed_sum) = + xor_table::generate_interaction_trace(xor_lookup_data7, &all_elements.xor_elements.xor7); + let (xor_trace4, xor4_claimed_sum) = + xor_table::generate_interaction_trace(xor_lookup_data4, &all_elements.xor_elements.xor4); let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals( @@ -199,6 +329,18 @@ where ] .collect_vec(), ); + + // Statement1. + let stmt1 = BlakeStatement1 { + scheduler_claimed_sum, + round_claimed_sums, + xor12_claimed_sum, + xor9_claimed_sum, + xor8_claimed_sum, + xor7_claimed_sum, + xor4_claimed_sum, + }; + stmt1.mix_into(channel); tree_builder.commit(channel); span.exit(); @@ -220,77 +362,93 @@ where tree_builder.commit(channel); span.exit(); + assert_eq!( + commitment_scheme + .polynomials() + .as_cols_ref() + .map_cols(|c| c.log_size()) + .0, + stmt0.log_sizes().0 + ); + // Prove constraints. - let scheduler_component = BlakeSchedulerComponent { - log_size, - blake_lookup_elements, - round_lookup_elements: round_lookup_elements.clone(), - claimed_sum: scheduler_claimed_sum, - }; - let round_components = round_claimed_sums - .into_iter() - .zip(ROUND_LOG_SPLIT) - .map(|(claimed_sum, l)| BlakeRoundComponent { - log_size: log_size + l, - xor_lookup_elements: xor_lookup_elements.clone(), - round_lookup_elements: round_lookup_elements.clone(), - claimed_sum, - }) - .collect(); - let xor12 = XorTableComponent::<12, 4> { - lookup_elements: xor_lookup_elements.xor12, - claimed_sum: xor_claimed_sum12, - }; - let xor9 = XorTableComponent::<9, 2> { - lookup_elements: xor_lookup_elements.xor9, - claimed_sum: xor_claimed_sum9, - }; - let xor8 = XorTableComponent::<8, 2> { - lookup_elements: xor_lookup_elements.xor8, - claimed_sum: xor_claimed_sum8, - }; - let xor7 = XorTableComponent::<7, 2> { - lookup_elements: xor_lookup_elements.xor7, - claimed_sum: xor_claimed_sum7, - }; - let xor4 = XorTableComponent::<4, 0> { - lookup_elements: xor_lookup_elements.xor4, - claimed_sum: xor_claimed_sum4, - }; - let air = BlakeAir { - scheduler_component, - round_components, - xor12, - xor9, - xor8, - xor7, - xor4, - }; - let proof = prove::( - &air.component_provers(), + let components = BlakeComponents::new(&stmt0, &all_elements, &stmt1); + let stark_proof = prove::( + &components.component_provers(), channel, &InteractionElements::default(), commitment_scheme, ) .unwrap(); - (air, proof) + BlakeProof { + stmt0, + stmt1, + stark_proof, + } +} + +#[allow(unused)] +pub fn verify_blake( + BlakeProof { + stmt0, + stmt1, + stark_proof, + }: BlakeProof, +) -> Result<(), VerificationError> +where + C: Channel, + H: MerkleHasher, +{ + let channel = &mut C::new(C::Digest::default()); + let commitment_scheme = &mut CommitmentSchemeVerifier::new(); + + let log_sizes = stmt0.log_sizes(); + + // Trace. + stmt0.mix_into(channel); + commitment_scheme.commit(stark_proof.commitments[0], &log_sizes[0], channel); + + // Draw interaction elements. + let all_elements = AllElements::draw(channel); + + // Interaction trace. + stmt1.mix_into(channel); + commitment_scheme.commit(stark_proof.commitments[1], &log_sizes[1], channel); + + // Constant trace. + commitment_scheme.commit(stark_proof.commitments[2], &log_sizes[2], channel); + + let components = BlakeComponents::new(&stmt0, &all_elements, &stmt1); + + // Check that all sums are correct. + let total_sum = stmt1.scheduler_claimed_sum + + stmt1.round_claimed_sums.iter().sum::() + + stmt1.xor12_claimed_sum + + stmt1.xor9_claimed_sum + + stmt1.xor8_claimed_sum + + stmt1.xor7_claimed_sum + + stmt1.xor4_claimed_sum; + + // TODO(spapini): Add inputs to sum, and constraint them. + assert_eq!(total_sum, SecureField::zero()); + + verify( + &components.components(), + channel, + &InteractionElements::default(), // Not in use. + commitment_scheme, + stark_proof, + ) } #[cfg(test)] mod tests { use std::env; - use crate::core::air::{Air, Components}; - use crate::core::channel::{Blake2sChannel, Channel}; - use crate::core::pcs::CommitmentSchemeVerifier; - use crate::core::prover::verify; - use crate::core::vcs::blake2_hash::Blake2sHash; + use crate::core::channel::Blake2sChannel; use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; - use crate::core::InteractionElements; - use crate::examples::blake::air::prove_blake; - use crate::examples::blake::round::RoundElements; - use crate::examples::blake::xor_table::XorElements; + use crate::examples::blake::air::{prove_blake, verify_blake}; // Note: this test is slow. Only run in release. #[ignore] @@ -308,45 +466,9 @@ mod tests { .unwrap(); // Prove. - let (air, proof) = prove_blake::(log_n_instances); + let proof = prove_blake::(log_n_instances); // Verify. - // TODO: Create Air instance independently. - let channel = &mut Blake2sChannel::new(Blake2sHash::default()); - let commitment_scheme = &mut CommitmentSchemeVerifier::new(); - - // Decommit. - let sizes = Components(air.components()).column_log_sizes(); - - // Trace columns. - commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); - // Draw lookup element. - let blake_lookup_elements = RoundElements::draw(channel); - let round_lookup_elements = RoundElements::draw(channel); - let xor_lookup_elements = XorElements::draw(channel); - assert_eq!( - blake_lookup_elements, - air.scheduler_component.blake_lookup_elements - ); - assert_eq!( - round_lookup_elements, - air.scheduler_component.round_lookup_elements - ); - assert_eq!(xor_lookup_elements, air.xor12.lookup_elements); - - // TODO(spapini): Check claimed sum against first and last instances. - // Interaction columns. - commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); - // Constant columns. - commitment_scheme.commit(proof.commitments[2], &sizes[2], channel); - - verify( - &air.components(), - channel, - &InteractionElements::default(), // Not in use. - commitment_scheme, - proof, - ) - .unwrap(); + verify_blake::(proof).unwrap(); } } diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index 395dac6f8..9d5057fed 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -1,5 +1,5 @@ use itertools::{chain, Itertools}; -use num_traits::One; +use num_traits::{One, Zero}; use super::BlakeElements; use crate::constraint_framework::logup::LogupAtRow; @@ -42,10 +42,11 @@ impl<'a, E: EvalAtRow> BlakeSchedulerEval<'a, E> { let input_state = &states[0]; let output_state = &states[N_ROUNDS]; - // TODO: support multiplicities. + // TODO(spapini): Support multiplicities. + // TODO(spapini): Change to -1. self.logup.push_lookup( &mut self.eval, - -E::EF::one(), + E::EF::zero(), &chain![ input_state.iter().copied().flat_map(Fu32::to_felts), output_state.iter().copied().flat_map(Fu32::to_felts), diff --git a/crates/prover/src/examples/blake/scheduler/gen.rs b/crates/prover/src/examples/blake/scheduler/gen.rs index 3b4437920..0581b2fe1 100644 --- a/crates/prover/src/examples/blake/scheduler/gen.rs +++ b/crates/prover/src/examples/blake/scheduler/gen.rs @@ -1,7 +1,7 @@ use std::simd::u32x16; use itertools::{chain, Itertools}; -use num_traits::One; +use num_traits::Zero; use tracing::{span, Level}; use super::{blake_scheduler_info, BlakeElements}; @@ -158,9 +158,11 @@ pub fn gen_interaction_trace( .each_ref() .map(|l| l.data[vec_row]), ); - col_gen.write_frac(vec_row, p_blake - p_round, p_round * p_blake); + // TODO(spapini): Change blake numerator to p_blake - p_round. + col_gen.write_frac(vec_row, p_blake, p_round * p_blake); } else { - col_gen.write_frac(vec_row, -PackedSecureField::one(), p_blake); + // TODO(spapini): Change numerator to -1. + col_gen.write_frac(vec_row, PackedSecureField::zero(), p_blake); } } col_gen.finalize_col(); diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index 5d2781c13..21c417cfd 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -17,13 +17,26 @@ use std::simd::u32x16; use constraints::XorTableEval; use itertools::Itertools; +use num_traits::Zero; pub use r#gen::{generate_constant_trace, generate_interaction_trace, generate_trace}; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; -use crate::constraint_framework::{EvalAtRow, FrameworkComponent}; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, InfoEvaluator}; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::Column; use crate::core::fields::qm31::SecureField; +use crate::core::pcs::TreeVec; + +pub fn trace_sizes() -> TreeVec> { + let component = XorTableComponent:: { + lookup_elements: LookupElements::<3>::dummy(), + claimed_sum: SecureField::zero(), + }; + let info = component.evaluate(InfoEvaluator::default()); + info.mask_offsets + .as_cols_ref() + .map_cols(|_| column_bits::()) +} const fn limb_bits() -> u32 { ELEM_BITS - EXPAND_BITS