diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs
index 1d9c4e180..413ee00c1 100644
--- a/crates/prover/src/constraint_framework/logup.rs
+++ b/crates/prover/src/constraint_framework/logup.rs
@@ -49,12 +49,12 @@ impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
             is_first,
         }
     }
-    pub fn push_lookup(
+    pub fn push_lookup<const N: usize>(
         &mut self,
         eval: &mut E,
         numerator: E::EF,
         values: &[E::F],
-        lookup_elements: &LookupElements,
+        lookup_elements: &LookupElements<N>,
     ) {
         let shifted_value = lookup_elements.combine(values);
         self.push_frac(eval, numerator, shifted_value);
@@ -111,24 +111,24 @@ impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
 
 /// Interaction elements for the logup protocol.
 #[derive(Clone, Debug, PartialEq, Eq)]
-pub struct LookupElements {
+pub struct LookupElements<const N: usize> {
     pub z: SecureField,
     pub alpha: SecureField,
-    alpha_powers: Vec<SecureField>,
+    alpha_powers: [SecureField; N],
 }
-impl LookupElements {
-    pub fn draw(channel: &mut Blake2sChannel, n_powers: usize) -> Self {
+impl<const N: usize> LookupElements<N> {
+    pub fn draw(channel: &mut Blake2sChannel) -> Self {
         let [z, alpha] = channel.draw_felts(2).try_into().unwrap();
+        let mut cur = SecureField::one();
+        let alpha_powers = std::array::from_fn(|_| {
+            let res = cur;
+            cur *= alpha;
+            res
+        });
         Self {
             z,
             alpha,
-            alpha_powers: (0..n_powers)
-                .scan(SecureField::one(), |acc, _| {
-                    let res = *acc;
-                    *acc *= alpha;
-                    Some(res)
-                })
-                .collect(),
+            alpha_powers,
         }
     }
     pub fn combine<F: Copy, EF>(&self, values: &[F]) -> EF
@@ -144,12 +144,12 @@ impl LookupElements {
             })
             - EF::from(self.z)
     }
-    #[cfg(test)]
-    pub fn dummy(n_powers: usize) -> Self {
+    // TODO(spapini): Try to remove this.
+    pub fn dummy() -> Self {
         Self {
             z: SecureField::one(),
             alpha: SecureField::one(),
-            alpha_powers: vec![SecureField::one(); n_powers],
+            alpha_powers: [SecureField::one(); N],
         }
     }
 }
diff --git a/crates/prover/src/core/pcs/mod.rs b/crates/prover/src/core/pcs/mod.rs
index ec4e4ba4c..0494cc7d1 100644
--- a/crates/prover/src/core/pcs/mod.rs
+++ b/crates/prover/src/core/pcs/mod.rs
@@ -11,9 +11,7 @@
 pub mod quotients;
 mod utils;
 mod verifier;
 
-pub use self::prover::{
-    CommitmentSchemeProof, CommitmentSchemeProver, CommitmentTreeProver, TreeBuilder,
-};
+pub use self::prover::{CommitmentSchemeProof, CommitmentSchemeProver, CommitmentTreeProver};
 pub use self::utils::TreeVec;
 pub use self::verifier::CommitmentSchemeVerifier;
diff --git a/crates/prover/src/examples/blake/mod.rs b/crates/prover/src/examples/blake/mod.rs
index aa69ad098..2bb98efe2 100644
--- a/crates/prover/src/examples/blake/mod.rs
+++ b/crates/prover/src/examples/blake/mod.rs
@@ -1 +1,106 @@
+//! AIR for blake2s and blake3.
+//! See <https://en.wikipedia.org/wiki/BLAKE_(hash_function)>
+
+#![allow(unused)]
+use std::fmt::Debug;
+use std::ops::{Add, AddAssign, Mul, Sub};
+use std::simd::u32x16;
+
+use xor_table::{XorAccumulator, XorElements};
+
+use crate::constraint_framework::logup::LookupElements;
+use crate::core::channel::Blake2sChannel;
+use crate::core::fields::m31::BaseField;
+use crate::core::fields::FieldExpOps;
+
+mod round;
+mod xor_table;
+
+#[derive(Default)]
+struct XorAccums {
+    xor12: XorAccumulator<12, 4>,
+    xor9: XorAccumulator<9, 2>,
+    xor8: XorAccumulator<8, 2>,
+    xor7: XorAccumulator<7, 2>,
+    xor4: XorAccumulator<4, 0>,
+}
+impl XorAccums {
+    fn add_input(&mut self, w: u32, a: u32x16, b: u32x16) {
+        match w {
+            12 => self.xor12.add_input(a, b),
+            9 => self.xor9.add_input(a, b),
+            8 => self.xor8.add_input(a, b),
+            7 => self.xor7.add_input(a, b),
+            4 => self.xor4.add_input(a, b),
+            _ => panic!("Invalid w"),
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct BlakeXorElements {
+    xor12: XorElements,
+    xor9: XorElements,
+    xor8: XorElements,
+    xor7: XorElements,
+    xor4: XorElements,
+}
+impl BlakeXorElements {
+    fn draw(channel: &mut Blake2sChannel) -> Self {
+        Self {
+            xor12: XorElements::draw(channel),
+            xor9: XorElements::draw(channel),
+            xor8: XorElements::draw(channel),
+            xor7: XorElements::draw(channel),
+            xor4: XorElements::draw(channel),
+        }
+    }
+    fn dummy() -> Self {
+        Self {
+            xor12: XorElements::dummy(),
+            xor9: XorElements::dummy(),
+            xor8: XorElements::dummy(),
+            xor7: XorElements::dummy(),
+            xor4: XorElements::dummy(),
+        }
+    }
+    fn get(&self, w: u32) -> &XorElements {
+        match w {
+            12 => &self.xor12,
+            9 => &self.xor9,
+            8 => &self.xor8,
+            7 => &self.xor7,
+            4 => &self.xor4,
+            _ => panic!("Invalid w"),
+        }
+    }
+}
+
+#[derive(Clone, Copy, Debug)]
+struct Fu32<F>
+where
+    F: FieldExpOps
+        + Copy
+        + Debug
+        + AddAssign<F>
+        + Add<F, Output = F>
+        + Sub<F, Output = F>
+        + Mul<F, Output = F>,
+{
+    l: F,
+    h: F,
+}
+impl<F> Fu32<F>
+where
+    F: FieldExpOps
+        + Copy
+        + Debug
+        + AddAssign<F>
+        + Add<F, Output = F>
+        + Sub<F, Output = F>
+        + Mul<F, Output = F>,
+{
+    fn to_felts(self) -> [F; 2] {
+        [self.l, self.h]
+    }
+}
diff --git a/crates/prover/src/examples/blake/round/constraints.rs b/crates/prover/src/examples/blake/round/constraints.rs
new file mode 100644
index 000000000..03b80d124
--- /dev/null
+++ b/crates/prover/src/examples/blake/round/constraints.rs
@@ -0,0 +1,161 @@
+use itertools::{chain, Itertools};
+use num_traits::One;
+
+use super::{BlakeXorElements, RoundElements};
+use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
+use crate::constraint_framework::EvalAtRow;
+use crate::core::fields::m31::BaseField;
+use crate::examples::blake::Fu32;
+
+const INV16: BaseField = BaseField::from_u32_unchecked(1 << 15);
+const TWO: BaseField = BaseField::from_u32_unchecked(2);
+
+pub struct BlakeRoundEval<'a, E: EvalAtRow> {
+    pub eval: E,
+    pub xor_lookup_elements: &'a BlakeXorElements,
+    pub round_lookup_elements: &'a RoundElements,
+    pub logup: LogupAtRow<2, E>,
+}
+impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
+    pub fn eval(mut self) -> E {
+        let mut v: [Fu32<E::F>; 16] = std::array::from_fn(|_| self.next_u32());
+        let input_v = v;
+        let m: [Fu32<E::F>; 16] = std::array::from_fn(|_| self.next_u32());
+
+        self.g(v.get_many_mut([0, 4, 8, 12]).unwrap(), m[0], m[1]);
+        self.g(v.get_many_mut([1, 5, 9, 13]).unwrap(), m[2], m[3]);
+        self.g(v.get_many_mut([2, 6, 10, 14]).unwrap(), m[4], m[5]);
+        self.g(v.get_many_mut([3, 7, 11, 15]).unwrap(), m[6], m[7]);
+        self.g(v.get_many_mut([0, 5, 10, 15]).unwrap(), m[8], m[9]);
+        self.g(v.get_many_mut([1, 6, 11, 12]).unwrap(), m[10], m[11]);
+        self.g(v.get_many_mut([2, 7, 8, 13]).unwrap(), m[12], m[13]);
+        self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]);
+
+        // Yield `Round(input_v, output_v, message)`.
+        self.logup.push_lookup(
+            &mut self.eval,
+            -E::EF::one(),
+            &chain![
+                input_v.iter().copied().flat_map(Fu32::to_felts),
+                v.iter().copied().flat_map(Fu32::to_felts),
+                m.iter().copied().flat_map(Fu32::to_felts)
+            ]
+            .collect_vec(),
+            self.round_lookup_elements,
+        );
+
+        self.logup.finalize(&mut self.eval);
+        self.eval
+    }
+    fn next_u32(&mut self) -> Fu32<E::F> {
+        let l = self.eval.next_trace_mask();
+        let h = self.eval.next_trace_mask();
+        Fu32 { l, h }
+    }
+    fn g(&mut self, v: [&mut Fu32<E::F>; 4], m0: Fu32<E::F>, m1: Fu32<E::F>) {
+        let [a, b, c, d] = v;
+
+        *a = self.add3_u32_unchecked(*a, *b, m0);
+        *d = self.xor_rotr16_u32(*a, *d);
+        *c = self.add2_u32_unchecked(*c, *d);
+        *b = self.xor_rotr_u32(*b, *c, 12);
+        *a = self.add3_u32_unchecked(*a, *b, m1);
+        *d = self.xor_rotr_u32(*a, *d, 8);
+        *c = self.add2_u32_unchecked(*c, *d);
+        *b = self.xor_rotr_u32(*b, *c, 7);
+    }
+
+    /// Adds two u32s, returning the sum.
+    /// Assumes a, b are properly range checked.
+    /// The caller is responsible for checking:
+    /// res.{l,h} not in [2^16, 2^17) or in [-2^16,0)
+    fn add2_u32_unchecked(&mut self, a: Fu32<E::F>, b: Fu32<E::F>) -> Fu32<E::F> {
+        let sl = self.eval.next_trace_mask();
+        let sh = self.eval.next_trace_mask();
+
+        let carry_l = (a.l + b.l - sl) * E::F::from(INV16);
+        self.eval.add_constraint(carry_l * carry_l - carry_l);
+
+        let carry_h = (a.h + b.h + carry_l - sh) * E::F::from(INV16);
+        self.eval.add_constraint(carry_h * carry_h - carry_h);
+
+        Fu32 { l: sl, h: sh }
+    }
+
+    /// Adds three u32s, returning the sum.
+    /// Assumes a, b, c are properly range checked.
+    /// Caller is responsible for checking:
+    /// res.{l,h} not in [2^16, 3*2^16) or in [-2^17,0)
+    fn add3_u32_unchecked(&mut self, a: Fu32<E::F>, b: Fu32<E::F>, c: Fu32<E::F>) -> Fu32<E::F> {
+        let sl = self.eval.next_trace_mask();
+        let sh = self.eval.next_trace_mask();
+
+        let carry_l = (a.l + b.l + c.l - sl) * E::F::from(INV16);
+        self.eval
+            .add_constraint(carry_l * (carry_l - E::F::one()) * (carry_l - E::F::from(TWO)));
+
+        let carry_h = (a.h + b.h + c.h + carry_l - sh) * E::F::from(INV16);
+        self.eval
+            .add_constraint(carry_h * (carry_h - E::F::one()) * (carry_h - E::F::from(TWO)));
+
+        Fu32 { l: sl, h: sh }
+    }
+
+    /// Splits a felt at r.
+    /// Caller is responsible for checking that the ranges of h * 2^r and l don't overlap.
+    fn split_unchecked(&mut self, a: E::F, r: u32) -> (E::F, E::F) {
+        let h = self.eval.next_trace_mask();
+        let l = a - h * E::F::from(BaseField::from_u32_unchecked(1 << r));
+        (l, h)
+    }
+
+    /// Checks that a, b are in range, and computes their xor rotated right by `r` bits.
+    /// Guarantees that all elements are in range.
+    fn xor_rotr_u32(&mut self, a: Fu32<E::F>, b: Fu32<E::F>, r: u32) -> Fu32<E::F> {
+        let (all, alh) = self.split_unchecked(a.l, r);
+        let (ahl, ahh) = self.split_unchecked(a.h, r);
+        let (bll, blh) = self.split_unchecked(b.l, r);
+        let (bhl, bhh) = self.split_unchecked(b.h, r);
+
+        // These also guarantee that all elements are in range.
+        let xorll = self.xor(r, all, bll);
+        let xorlh = self.xor(16 - r, alh, blh);
+        let xorhl = self.xor(r, ahl, bhl);
+        let xorhh = self.xor(16 - r, ahh, bhh);
+
+        Fu32 {
+            l: xorhl * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorlh,
+            h: xorll * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorhh,
+        }
+    }
+
+    /// Checks that a, b are in range, and computes their xor rotated right by 16 bits.
+    /// Guarantees that all elements are in range.
+    fn xor_rotr16_u32(&mut self, a: Fu32<E::F>, b: Fu32<E::F>) -> Fu32<E::F> {
+        let (all, alh) = self.split_unchecked(a.l, 8);
+        let (ahl, ahh) = self.split_unchecked(a.h, 8);
+        let (bll, blh) = self.split_unchecked(b.l, 8);
+        let (bhl, bhh) = self.split_unchecked(b.h, 8);
+
+        // These also guarantee that all elements are in range.
+        let xorll = self.xor(8, all, bll);
+        let xorlh = self.xor(8, alh, blh);
+        let xorhl = self.xor(8, ahl, bhl);
+        let xorhh = self.xor(8, ahh, bhh);
+
+        Fu32 {
+            l: xorhh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorhl,
+            h: xorlh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorll,
+        }
+    }
+
+    /// Checks that a, b are in [0, 2^w) and computes their xor.
+    fn xor(&mut self, w: u32, a: E::F, b: E::F) -> E::F {
+        // TODO: Separate lookups by w.
+        let c = self.eval.next_trace_mask();
+        let lookup_elements = self.xor_lookup_elements.get(w);
+        self.logup
+            .push_lookup(&mut self.eval, E::EF::one(), &[a, b, c], lookup_elements);
+        c
+    }
+}
diff --git a/crates/prover/src/examples/blake/round/gen.rs b/crates/prover/src/examples/blake/round/gen.rs
new file mode 100644
index 000000000..267574fbd
--- /dev/null
+++ b/crates/prover/src/examples/blake/round/gen.rs
@@ -0,0 +1,283 @@
+use std::simd::u32x16;
+use std::vec;
+
+use itertools::{chain, Itertools};
+use num_traits::One;
+use tracing::{span, Level};
+
+use super::{BlakeXorElements, RoundElements};
+use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
+use crate::core::backend::simd::column::BaseColumn;
+use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
+use crate::core::backend::simd::qm31::PackedSecureField;
+use crate::core::backend::simd::SimdBackend;
+use crate::core::backend::{Col, Column};
+use crate::core::fields::m31::BaseField;
+use crate::core::fields::qm31::SecureField;
+use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
+use crate::core::poly::BitReversedOrder;
+use crate::core::ColumnVec;
+use crate::examples::blake::round::blake_round_info;
+use crate::examples::blake::XorAccums;
+
+pub struct BlakeRoundLookupData {
+    /// A vector of (w, [a_col, b_col, c_col]) for each xor lookup.
+    /// w is the xor width. c_col is the xor col of a_col and b_col.
+    xor_lookups: Vec<(u32, [BaseColumn; 3])>,
+    /// A column of round lookup values (v_in, v_out, m).
+    round_lookup: [BaseColumn; 16 * 3 * 2],
+}
+
+pub struct TraceGenerator {
+    log_size: u32,
+    trace: Vec<BaseColumn>,
+    xor_lookups: Vec<(u32, [BaseColumn; 3])>,
+    round_lookup: [BaseColumn; 16 * 3 * 2],
+}
+impl TraceGenerator {
+    fn new(log_size: u32) -> Self {
+        assert!(log_size >= LOG_N_LANES);
+        let trace = (0..blake_round_info().mask_offsets[0].len())
+            .map(|_| unsafe { Col::<SimdBackend, BaseField>::uninitialized(1 << log_size) })
+            .collect_vec();
+        Self {
+            log_size,
+            trace,
+            xor_lookups: vec![],
+            round_lookup: std::array::from_fn(|_| unsafe {
+                BaseColumn::uninitialized(1 << log_size)
+            }),
+        }
+    }
+
+    fn gen_row(&mut self, vec_row: usize) -> TraceGeneratorRow<'_> {
+        TraceGeneratorRow {
+            gen: self,
+            col_index: 0,
+            vec_row,
+            xor_lookups_index: 0,
+        }
+    }
+}
+
+/// Trace generator for the constraints defined at [`super::constraints::BlakeRoundEval`]
+struct TraceGeneratorRow<'a> {
+    gen: &'a mut TraceGenerator,
+    col_index: usize,
+    vec_row: usize,
+    xor_lookups_index: usize,
+}
+impl<'a> TraceGeneratorRow<'a> {
+    fn append_felt(&mut self, val: u32x16) {
+        self.gen.trace[self.col_index].data[self.vec_row] =
+            unsafe { PackedBaseField::from_simd_unchecked(val) };
+        self.col_index += 1;
+    }
+
+    fn append_u32(&mut self, val: u32x16) {
+        self.append_felt(val & u32x16::splat(0xffff));
+        self.append_felt(val >> 16);
+    }
+
+    fn generate(&mut self, mut v: [u32x16; 16], m: [u32x16; 16]) {
+        let input_v = v;
+        v.iter().for_each(|s| {
+            self.append_u32(*s);
+        });
+        m.iter().for_each(|s| {
+            self.append_u32(*s);
+        });
+
+        self.g(v.get_many_mut([0, 4, 8, 12]).unwrap(), m[0], m[1]);
+        self.g(v.get_many_mut([1, 5, 9, 13]).unwrap(), m[2], m[3]);
+        self.g(v.get_many_mut([2, 6, 10, 14]).unwrap(), m[4], m[5]);
+        self.g(v.get_many_mut([3, 7, 11, 15]).unwrap(), m[6], m[7]);
+        self.g(v.get_many_mut([0, 5, 10, 15]).unwrap(), m[8], m[9]);
+        self.g(v.get_many_mut([1, 6, 11, 12]).unwrap(), m[10], m[11]);
+        self.g(v.get_many_mut([2, 7, 8, 13]).unwrap(), m[12], m[13]);
+        self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]);
+
+        chain![input_v.iter(), v.iter(), m.iter()]
+            .flat_map(|s| [s & u32x16::splat(0xffff), s >> 16])
+            .enumerate()
+            .for_each(|(i, val)| {
+                self.gen.round_lookup[i].data[self.vec_row] =
+                    unsafe { PackedBaseField::from_simd_unchecked(val) }
+            });
+    }
+
+    fn g(&mut self, v: [&mut u32x16; 4], m0: u32x16, m1: u32x16) {
+        let [a, b, c, d] = v;
+
+        *a = self.add3_u32s(*a, *b, m0);
+        *d = self.xor_rotr16_u32(*a, *d);
+        *c = self.add2_u32s(*c, *d);
+        *b = self.xor_rotr_u32(*b, *c, 12);
+        *a = self.add3_u32s(*a, *b, m1);
+        *d = self.xor_rotr_u32(*a, *d, 8);
+        *c = self.add2_u32s(*c, *d);
+        *b = self.xor_rotr_u32(*b, *c, 7);
+    }
+
+    /// Adds two u32s, returning the sum.
+    fn add2_u32s(&mut self, a: u32x16, b: u32x16) -> u32x16 {
+        let s = a + b;
+        self.append_u32(s);
+        s
+    }
+
+    /// Adds three u32s, returning the sum.
+    fn add3_u32s(&mut self, a: u32x16, b: u32x16, c: u32x16) -> u32x16 {
+        let s = a + b + c;
+        self.append_u32(s);
+        s
+    }
+
+    /// Splits a felt at r.
+    fn split(&mut self, a: u32x16, r: u32) -> (u32x16, u32x16) {
+        let h = a >> r;
+        let l = a & u32x16::splat((1 << r) - 1);
+        self.append_felt(h);
+        (l, h)
+    }
+
+    /// Checks that a, b are in range, and computes their xor rotated right by `r` bits.
+    fn xor_rotr_u32(&mut self, a: u32x16, b: u32x16, r: u32) -> u32x16 {
+        let c = a ^ b;
+        let cr = (c >> r) | (c << (32 - r));
+
+        let (all, alh) = self.split(a & u32x16::splat(0xffff), r);
+        let (ahl, ahh) = self.split(a >> 16, r);
+        let (bll, blh) = self.split(b & u32x16::splat(0xffff), r);
+        let (bhl, bhh) = self.split(b >> 16, r);
+
+        self.xor(r, all, bll);
+        self.xor(16 - r, alh, blh);
+        self.xor(r, ahl, bhl);
+        self.xor(16 - r, ahh, bhh);
+
+        cr
+    }
+
+    /// Checks that a, b are in range, and computes their xor rotated right by 16 bits.
+    fn xor_rotr16_u32(&mut self, a: u32x16, b: u32x16) -> u32x16 {
+        let c = a ^ b;
+        let cr = (c >> 16) | (c << 16);
+
+        let (all, alh) = self.split(a & u32x16::splat(0xffff), 8);
+        let (ahl, ahh) = self.split(a >> 16, 8);
+        let (bll, blh) = self.split(b & u32x16::splat(0xffff), 8);
+        let (bhl, bhh) = self.split(b >> 16, 8);
+
+        self.xor(8, all, bll);
+        self.xor(8, alh, blh);
+        self.xor(8, ahl, bhl);
+        self.xor(8, ahh, bhh);
+
+        cr
+    }
+
+    /// Checks that a, b are in [0, 2^w) and computes their xor.
+    /// a,b,a^b are assumed to fit in a single felt.
+    fn xor(&mut self, w: u32, a: u32x16, b: u32x16) -> u32x16 {
+        let c = a ^ b;
+        self.append_felt(c);
+        if self.gen.xor_lookups.len() <= self.xor_lookups_index {
+            self.gen.xor_lookups.push((
+                w,
+                std::array::from_fn(|_| unsafe {
+                    BaseColumn::uninitialized(1 << self.gen.log_size)
+                }),
+            ));
+        }
+        self.gen.xor_lookups[self.xor_lookups_index].1[0].data[self.vec_row] =
+            unsafe { PackedBaseField::from_simd_unchecked(a) };
+        self.gen.xor_lookups[self.xor_lookups_index].1[1].data[self.vec_row] =
+            unsafe { PackedBaseField::from_simd_unchecked(b) };
+        self.gen.xor_lookups[self.xor_lookups_index].1[2].data[self.vec_row] =
+            unsafe { PackedBaseField::from_simd_unchecked(c) };
+        self.xor_lookups_index += 1;
+        c
+    }
+}
+
+#[derive(Copy, Clone, Default)]
+pub struct BlakeRoundInput {
+    pub v: [u32x16; 16],
+    pub m: [u32x16; 16],
+}
+
+pub fn generate_trace(
+    log_size: u32,
+    inputs: &[BlakeRoundInput],
+    xor_accum: &mut XorAccums,
+) -> (
+    ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
+    BlakeRoundLookupData,
+) {
+    let mut generator = TraceGenerator::new(log_size);
+
+    for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
+        let mut row_gen = generator.gen_row(vec_row);
+        let BlakeRoundInput { v, m } = inputs.get(vec_row).copied().unwrap_or_default();
+        row_gen.generate(v, m);
+        for (w, [a, b, _c]) in &generator.xor_lookups {
+            let a = a.data[vec_row].into_simd();
+            let b = b.data[vec_row].into_simd();
+            xor_accum.add_input(*w, a, b);
+        }
+    }
+    let domain = CanonicCoset::new(log_size).circle_domain();
+    (
+        generator
+            .trace
+            .into_iter()
+            .map(|eval| CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(domain, eval))
+            .collect_vec(),
+        BlakeRoundLookupData {
+            xor_lookups: generator.xor_lookups,
+            round_lookup: generator.round_lookup,
+        },
+    )
+}
+
+pub fn generate_interaction_trace(
+    log_size: u32,
+    lookup_data: BlakeRoundLookupData,
+    xor_lookup_elements: &BlakeXorElements,
+    round_lookup_elements: &RoundElements,
+) -> (
+    ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
+    SecureField,
+) {
+    let _span = span!(Level::INFO, "Generate round interaction trace").entered();
+    let mut logup_gen = LogupTraceGenerator::new(log_size);
+
+    for [(w0, l0), (w1, l1)] in lookup_data.xor_lookups.array_chunks::<2>() {
+        let mut col_gen = logup_gen.new_col();
+
+        #[allow(clippy::needless_range_loop)]
+        for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
+            let p0: PackedSecureField = xor_lookup_elements
+                .get(*w0)
+                .combine(&l0.each_ref().map(|l| l.data[vec_row]));
+            let p1: PackedSecureField = xor_lookup_elements
+                .get(*w1)
+                .combine(&l1.each_ref().map(|l| l.data[vec_row]));
+            col_gen.write_frac(vec_row, p0 + p1, p0 * p1);
+        }
+
+        col_gen.finalize_col();
+    }
+
+    let mut col_gen = logup_gen.new_col();
+    #[allow(clippy::needless_range_loop)]
+    for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
+        let p = round_lookup_elements
+            .combine(&lookup_data.round_lookup.each_ref().map(|l| l.data[vec_row]));
+        col_gen.write_frac(vec_row, -PackedSecureField::one(), p);
+    }
+    col_gen.finalize_col();
+
+    logup_gen.finalize()
+}
diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs
new file mode 100644
index 000000000..fec7d3976
--- /dev/null
+++ b/crates/prover/src/examples/blake/round/mod.rs
@@ -0,0 +1,113 @@
+mod constraints;
+mod gen;
+
+use constraints::BlakeRoundEval;
+use num_traits::Zero;
+
+use super::BlakeXorElements;
+use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
+use crate::constraint_framework::{EvalAtRow, FrameworkComponent, InfoEvaluator};
+use crate::core::fields::qm31::SecureField;
+use crate::examples::blake::XorAccums;
+
+pub fn blake_round_info() -> InfoEvaluator {
+    let component = BlakeRoundComponent {
+        log_size: 1,
+        xor_lookup_elements: BlakeXorElements::dummy(),
+        round_lookup_elements: RoundElements::dummy(),
+        claimed_sum: SecureField::zero(),
+    };
+    component.evaluate(InfoEvaluator::default())
+}
+
+pub type RoundElements = LookupElements<{ 16 * 3 * 2 }>;
+pub struct BlakeRoundComponent {
+    pub log_size: u32,
+    pub xor_lookup_elements: BlakeXorElements,
+    pub round_lookup_elements: RoundElements,
+    pub claimed_sum: SecureField,
+}
+
+impl FrameworkComponent for BlakeRoundComponent {
+    fn log_size(&self) -> u32 {
+        self.log_size
+    }
+    fn max_constraint_log_degree_bound(&self) -> u32 {
+        self.log_size + 1
+    }
+    fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
+        let [is_first] = eval.next_interaction_mask(2, [0]);
+        let blake_eval = BlakeRoundEval {
+            eval,
+            xor_lookup_elements: &self.xor_lookup_elements,
+            round_lookup_elements: &self.round_lookup_elements,
+            logup: LogupAtRow::new(1, self.claimed_sum, is_first),
+        };
+        blake_eval.eval()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::simd::Simd;
+
+    use itertools::Itertools;
+
+    use crate::constraint_framework::constant_columns::gen_is_first;
+    use crate::constraint_framework::logup::LookupElements;
+    use crate::constraint_framework::FrameworkComponent;
+    use crate::core::backend::simd::SimdBackend;
+    use crate::core::fields::m31::BaseField;
+    use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
+    use crate::core::poly::BitReversedOrder;
+    use crate::examples::blake::round::r#gen::{
+        generate_interaction_trace, generate_trace, BlakeRoundInput,
+    };
+    use crate::examples::blake::round::{BlakeRoundComponent, RoundElements};
+    use crate::examples::blake::{BlakeXorElements, XorAccums};
+
+    #[test]
+    fn test_blake_round() {
+        use crate::core::pcs::TreeVec;
+
+        const LOG_SIZE: u32 = 10;
+
+        let mut xor_accum = XorAccums::default();
+        let (trace, lookup_data) = generate_trace(
+            LOG_SIZE,
+            &(0..(1 << LOG_SIZE))
+                .map(|i| BlakeRoundInput {
+                    v: std::array::from_fn(|i| Simd::splat(i as u32)),
+                    m: std::array::from_fn(|i| Simd::splat((i + 1) as u32)),
+                })
+                .collect_vec(),
+            &mut xor_accum,
+        );
+
+        let xor_lookup_elements = BlakeXorElements::dummy();
+        let round_lookup_elements = RoundElements::dummy();
+        let (interaction_trace, claimed_sum) = generate_interaction_trace(
+            LOG_SIZE,
+            lookup_data,
+            &xor_lookup_elements,
+            &round_lookup_elements,
+        );
+
+        let trace = TreeVec::new(vec![trace, interaction_trace, vec![gen_is_first(LOG_SIZE)]]);
+        let trace_polys = trace.map_cols(|c| c.interpolate());
+
+        let component = BlakeRoundComponent {
+            log_size: LOG_SIZE,
+            xor_lookup_elements,
+            round_lookup_elements,
+            claimed_sum,
+        };
+        crate::constraint_framework::assert_constraints(
+            &trace_polys,
+            CanonicCoset::new(LOG_SIZE),
+            |eval| {
+                component.evaluate(eval);
+            },
+        )
+    }
+}
diff --git a/crates/prover/src/examples/blake/xor_table/constraints.rs b/crates/prover/src/examples/blake/xor_table/constraints.rs
index f29379a5c..00a658311 100644
--- a/crates/prover/src/examples/blake/xor_table/constraints.rs
+++ b/crates/prover/src/examples/blake/xor_table/constraints.rs
@@ -1,4 +1,4 @@
-use super::limb_bits;
+use super::{limb_bits, XorElements};
 use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
 use crate::constraint_framework::EvalAtRow;
 use crate::core::fields::m31::BaseField;
@@ -6,7 +6,7 @@ use crate::core::fields::m31::BaseField;
 /// Constraints for the xor table.
 pub struct XorTableEval<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> {
     pub eval: E,
-    pub lookup_elements: &'a LookupElements,
+    pub lookup_elements: &'a XorElements,
     pub logup: LogupAtRow<2, E>,
 }
 impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32>
diff --git a/crates/prover/src/examples/blake/xor_table/gen.rs b/crates/prover/src/examples/blake/xor_table/gen.rs
index ee26f0e0a..0158b4b1e 100644
--- a/crates/prover/src/examples/blake/xor_table/gen.rs
+++ b/crates/prover/src/examples/blake/xor_table/gen.rs
@@ -3,7 +3,7 @@ use std::simd::u32x16;
 use itertools::Itertools;
 use tracing::{span, Level};
 
-use super::{column_bits, limb_bits, XorAccumulator};
+use super::{column_bits, limb_bits, XorAccumulator, XorElements};
 use crate::constraint_framework::constant_columns::gen_is_first;
 use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
 use crate::core::backend::simd::column::BaseColumn;
@@ -46,7 +46,7 @@ pub fn generate_trace<const ELEM_BITS: u32, const EXPAND_BITS: u32>(
 #[allow(clippy::type_complexity)]
 pub fn generate_interaction_trace<const ELEM_BITS: u32, const EXPAND_BITS: u32>(
     lookup_data: XorTableLookupData<ELEM_BITS, EXPAND_BITS>,
-    lookup_elements: &LookupElements,
+    lookup_elements: &XorElements,
 ) -> (
     ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
     SecureField,
diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs
index 1a55a4611..3a09454fa 100644
--- a/crates/prover/src/examples/blake/xor_table/mod.rs
+++ b/crates/prover/src/examples/blake/xor_table/mod.rs
@@ -69,8 +69,9 @@ impl<const ELEM_BITS: u32, const EXPAND_BITS: u32> XorAccumulator<ELEM_BITS, EXPAND_BITS> {
 }
 
+pub type XorElements = LookupElements<3>;
 pub struct XorTableComponent<const ELEM_BITS: u32, const EXPAND_BITS: u32> {
-    pub lookup_elements: LookupElements,
+    pub lookup_elements: XorElements,
     pub claimed_sum: SecureField,
 }
 
 impl<const ELEM_BITS: u32, const EXPAND_BITS: u32> FrameworkComponent
@@ -116,7 +117,7 @@ mod tests {
         xor_accum.add_input(u32x16::splat(1), u32x16::splat(2));
 
         let (trace, lookup_data) = generate_trace(xor_accum);
-        let lookup_elements = LookupElements::dummy(3);
+        let lookup_elements = crate::examples::blake::xor_table::XorElements::dummy();
        let (interaction_trace, claimed_sum) =
             generate_interaction_trace(lookup_data, &lookup_elements);
         let constant_trace = generate_constant_trace::<ELEM_BITS, EXPAND_BITS>();
diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs
index 7435b2c88..a37705fc7 100644
--- a/crates/prover/src/examples/poseidon/mod.rs
+++ b/crates/prover/src/examples/poseidon/mod.rs
@@ -53,10 +53,11 @@ impl Air for PoseidonAir {
     }
 }
 
+pub type PoseidonElements = LookupElements<{ N_STATE * 2 }>;
 #[derive(Clone)]
 pub struct PoseidonComponent {
     pub log_n_rows: u32,
-    pub lookup_elements: LookupElements,
+    pub lookup_elements: PoseidonElements,
     pub claimed_sum: SecureField,
 }
 impl FrameworkComponent for PoseidonComponent {
@@ -149,7 +150,7 @@ fn pow5<F: FieldExpOps>(x: F) -> F {
 struct PoseidonEval<'a, E: EvalAtRow> {
     eval: E,
     logup: LogupAtRow<2, E>,
-    lookup_elements: &'a LookupElements,
+    lookup_elements: &'a PoseidonElements,
 }
 
 impl<'a, E: EvalAtRow> PoseidonEval<'a, E> {
@@ -311,7 +312,7 @@ pub fn gen_trace(
 pub fn gen_interaction_trace(
     log_size: u32,
     lookup_data: LookupData,
-    lookup_elements: &LookupElements,
+    lookup_elements: &PoseidonElements,
 ) -> (
     ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
     SecureField,
@@ -369,7 +370,7 @@ pub fn prove_poseidon(log_n_instances: u32) -> (PoseidonAir, StarkProof