Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Blake round #767

Merged
merged 1 commit into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
is_first,
}
}
pub fn push_lookup(
pub fn push_lookup<const N: usize>(
&mut self,
eval: &mut E,
numerator: E::EF,
values: &[E::F],
lookup_elements: &LookupElements,
lookup_elements: &LookupElements<N>,
) {
let shifted_value = lookup_elements.combine(values);
self.push_frac(eval, numerator, shifted_value);
Expand Down Expand Up @@ -111,24 +111,24 @@ impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {

/// Interaction elements for the logup protocol.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LookupElements {
pub struct LookupElements<const N: usize> {
pub z: SecureField,
pub alpha: SecureField,
alpha_powers: Vec<SecureField>,
alpha_powers: [SecureField; N],
}
impl LookupElements {
pub fn draw(channel: &mut Blake2sChannel, n_powers: usize) -> Self {
impl<const N: usize> LookupElements<N> {
pub fn draw(channel: &mut Blake2sChannel) -> Self {
let [z, alpha] = channel.draw_felts(2).try_into().unwrap();
let mut cur = SecureField::one();
let alpha_powers = std::array::from_fn(|_| {
let res = cur;
cur *= alpha;
res
});
Self {
z,
alpha,
alpha_powers: (0..n_powers)
.scan(SecureField::one(), |acc, _| {
let res = *acc;
*acc *= alpha;
Some(res)
})
.collect(),
alpha_powers,
}
}
pub fn combine<F: Copy, EF>(&self, values: &[F]) -> EF
Expand All @@ -144,12 +144,12 @@ impl LookupElements {
})
- EF::from(self.z)
}
#[cfg(test)]
pub fn dummy(n_powers: usize) -> Self {
// TODO(spapini): Try to remove this.
pub fn dummy() -> Self {
Self {
z: SecureField::one(),
alpha: SecureField::one(),
alpha_powers: vec![SecureField::one(); n_powers],
alpha_powers: [SecureField::one(); N],
}
}
}
Expand Down
4 changes: 1 addition & 3 deletions crates/prover/src/core/pcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ pub mod quotients;
mod utils;
mod verifier;

pub use self::prover::{
CommitmentSchemeProof, CommitmentSchemeProver, CommitmentTreeProver, TreeBuilder,
};
pub use self::prover::{CommitmentSchemeProof, CommitmentSchemeProver, CommitmentTreeProver};
pub use self::utils::TreeVec;
pub use self::verifier::CommitmentSchemeVerifier;

Expand Down
105 changes: 105 additions & 0 deletions crates/prover/src/examples/blake/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,106 @@
//! AIR for blake2s and blake3.
//! See <https://en.wikipedia.org/wiki/BLAKE_(hash_function)>

#![allow(unused)]
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Sub};
use std::simd::u32x16;

use xor_table::{XorAccumulator, XorElements};

use crate::constraint_framework::logup::LookupElements;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::FieldExpOps;

mod round;
mod xor_table;

#[derive(Default)]
struct XorAccums {
xor12: XorAccumulator<12, 4>,
xor9: XorAccumulator<9, 2>,
xor8: XorAccumulator<8, 2>,
xor7: XorAccumulator<7, 2>,
xor4: XorAccumulator<4, 0>,
}
impl XorAccums {
fn add_input(&mut self, w: u32, a: u32x16, b: u32x16) {
match w {
12 => self.xor12.add_input(a, b),
9 => self.xor9.add_input(a, b),
8 => self.xor8.add_input(a, b),
7 => self.xor7.add_input(a, b),
4 => self.xor4.add_input(a, b),
_ => panic!("Invalid w"),
}
}
}

#[derive(Clone)]
pub struct BlakeXorElements {
xor12: XorElements,
xor9: XorElements,
xor8: XorElements,
xor7: XorElements,
xor4: XorElements,
}
impl BlakeXorElements {
fn draw(channel: &mut Blake2sChannel) -> Self {
Self {
xor12: XorElements::draw(channel),
xor9: XorElements::draw(channel),
xor8: XorElements::draw(channel),
xor7: XorElements::draw(channel),
xor4: XorElements::draw(channel),
}
}
fn dummy() -> Self {
Self {
xor12: XorElements::dummy(),
xor9: XorElements::dummy(),
xor8: XorElements::dummy(),
xor7: XorElements::dummy(),
xor4: XorElements::dummy(),
}
}
fn get(&self, w: u32) -> &XorElements {
match w {
12 => &self.xor12,
9 => &self.xor9,
8 => &self.xor8,
7 => &self.xor7,
4 => &self.xor4,
_ => panic!("Invalid w"),
}
}
}

#[derive(Clone, Copy, Debug)]
struct Fu32<F>
where
F: FieldExpOps
+ Copy
+ Debug
+ AddAssign<F>
+ Add<F, Output = F>
+ Sub<F, Output = F>
+ Mul<BaseField, Output = F>,
{
l: F,
h: F,
}
impl<F> Fu32<F>
where
F: FieldExpOps
+ Copy
+ Debug
+ AddAssign<F>
+ Add<F, Output = F>
+ Sub<F, Output = F>
+ Mul<BaseField, Output = F>,
{
fn to_felts(self) -> [F; 2] {
[self.l, self.h]
}
}
161 changes: 161 additions & 0 deletions crates/prover/src/examples/blake/round/constraints.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
use itertools::{chain, Itertools};
use num_traits::One;

use super::{BlakeXorElements, RoundElements};
use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::EvalAtRow;
use crate::core::fields::m31::BaseField;
use crate::examples::blake::Fu32;

const INV16: BaseField = BaseField::from_u32_unchecked(1 << 15);
const TWO: BaseField = BaseField::from_u32_unchecked(2);

pub struct BlakeRoundEval<'a, E: EvalAtRow> {
pub eval: E,
pub xor_lookup_elements: &'a BlakeXorElements,
pub round_lookup_elements: &'a RoundElements,
pub logup: LogupAtRow<2, E>,
}
impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
pub fn eval(mut self) -> E {
let mut v: [Fu32<E::F>; 16] = std::array::from_fn(|_| self.next_u32());
let input_v = v;
let m: [Fu32<E::F>; 16] = std::array::from_fn(|_| self.next_u32());

self.g(v.get_many_mut([0, 4, 8, 12]).unwrap(), m[0], m[1]);
self.g(v.get_many_mut([1, 5, 9, 13]).unwrap(), m[2], m[3]);
self.g(v.get_many_mut([2, 6, 10, 14]).unwrap(), m[4], m[5]);
self.g(v.get_many_mut([3, 7, 11, 15]).unwrap(), m[6], m[7]);
self.g(v.get_many_mut([0, 5, 10, 15]).unwrap(), m[8], m[9]);
self.g(v.get_many_mut([1, 6, 11, 12]).unwrap(), m[10], m[11]);
self.g(v.get_many_mut([2, 7, 8, 13]).unwrap(), m[12], m[13]);
self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]);

// Yield `Round(input_v, output_v, message)`.
self.logup.push_lookup(
&mut self.eval,
-E::EF::one(),
&chain![
input_v.iter().copied().flat_map(Fu32::to_felts),
v.iter().copied().flat_map(Fu32::to_felts),
m.iter().copied().flat_map(Fu32::to_felts)
]
.collect_vec(),
self.round_lookup_elements,
);

self.logup.finalize(&mut self.eval);
self.eval
}
fn next_u32(&mut self) -> Fu32<E::F> {
let l = self.eval.next_trace_mask();
let h = self.eval.next_trace_mask();
Fu32 { l, h }
}
fn g(&mut self, v: [&mut Fu32<E::F>; 4], m0: Fu32<E::F>, m1: Fu32<E::F>) {
let [a, b, c, d] = v;

*a = self.add3_u32_unchecked(*a, *b, m0);
*d = self.xor_rotr16_u32(*a, *d);
*c = self.add2_u32_unchecked(*c, *d);
*b = self.xor_rotr_u32(*b, *c, 12);
*a = self.add3_u32_unchecked(*a, *b, m1);
*d = self.xor_rotr_u32(*a, *d, 8);
*c = self.add2_u32_unchecked(*c, *d);
*b = self.xor_rotr_u32(*b, *c, 7);
}

/// Adds two u32s, returning the sum.
/// Assumes a, b are properly range checked.
/// The caller is responsible for checking:
/// res.{l,h} not in [2^16, 2^17) or in [-2^16,0)
fn add2_u32_unchecked(&mut self, a: Fu32<E::F>, b: Fu32<E::F>) -> Fu32<E::F> {
let sl = self.eval.next_trace_mask();
let sh = self.eval.next_trace_mask();

let carry_l = (a.l + b.l - sl) * E::F::from(INV16);
self.eval.add_constraint(carry_l * carry_l - carry_l);

let carry_h = (a.h + b.h + carry_l - sh) * E::F::from(INV16);
self.eval.add_constraint(carry_h * carry_h - carry_h);

Fu32 { l: sl, h: sh }
}

/// Adds three u32s, returning the sum.
/// Assumes a, b, c are properly range checked.
/// Caller is responsible for checking:
/// res.{l,h} not in [2^16, 3*2^16) or in [-2^17,0)
fn add3_u32_unchecked(&mut self, a: Fu32<E::F>, b: Fu32<E::F>, c: Fu32<E::F>) -> Fu32<E::F> {
let sl = self.eval.next_trace_mask();
let sh = self.eval.next_trace_mask();

let carry_l = (a.l + b.l + c.l - sl) * E::F::from(INV16);
self.eval
.add_constraint(carry_l * (carry_l - E::F::one()) * (carry_l - E::F::from(TWO)));

let carry_h = (a.h + b.h + c.h + carry_l - sh) * E::F::from(INV16);
self.eval
.add_constraint(carry_h * (carry_h - E::F::one()) * (carry_h - E::F::from(TWO)));

Fu32 { l: sl, h: sh }
}

/// Splits a felt at r.
/// Caller is responsible for checking that the ranges of h * 2^r and l don't overlap.
fn split_unchecked(&mut self, a: E::F, r: u32) -> (E::F, E::F) {
let h = self.eval.next_trace_mask();
let l = a - h * E::F::from(BaseField::from_u32_unchecked(1 << r));
(l, h)
}

/// Checks that a, b are in range, and computes their xor rotated right by `r` bits.
/// Guarantees that all elements are in range.
fn xor_rotr_u32(&mut self, a: Fu32<E::F>, b: Fu32<E::F>, r: u32) -> Fu32<E::F> {
let (all, alh) = self.split_unchecked(a.l, r);
let (ahl, ahh) = self.split_unchecked(a.h, r);
let (bll, blh) = self.split_unchecked(b.l, r);
let (bhl, bhh) = self.split_unchecked(b.h, r);

// These also guarantee that all elements are in range.
let xorll = self.xor(r, all, bll);
let xorlh = self.xor(16 - r, alh, blh);
let xorhl = self.xor(r, ahl, bhl);
let xorhh = self.xor(16 - r, ahh, bhh);

Fu32 {
l: xorhl * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorlh,
h: xorll * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorhh,
}
}

/// Checks that a, b are in range, and computes their xor rotated right by 16 bits.
/// Guarantees that all elements are in range.
fn xor_rotr16_u32(&mut self, a: Fu32<E::F>, b: Fu32<E::F>) -> Fu32<E::F> {
let (all, alh) = self.split_unchecked(a.l, 8);
let (ahl, ahh) = self.split_unchecked(a.h, 8);
let (bll, blh) = self.split_unchecked(b.l, 8);
let (bhl, bhh) = self.split_unchecked(b.h, 8);

// These also guarantee that all elements are in range.
let xorll = self.xor(8, all, bll);
let xorlh = self.xor(8, alh, blh);
let xorhl = self.xor(8, ahl, bhl);
let xorhh = self.xor(8, ahh, bhh);

Fu32 {
l: xorhh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorhl,
h: xorlh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorll,
}
}

/// Checks that a, b are in [0, 2^w) and computes their xor.
fn xor(&mut self, w: u32, a: E::F, b: E::F) -> E::F {
// TODO: Separate lookups by w.
let c = self.eval.next_trace_mask();
let lookup_elements = self.xor_lookup_elements.get(w);
self.logup
.push_lookup(&mut self.eval, E::EF::one(), &[a, b, c], lookup_elements);
c
}
}
Loading
Loading