-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
87a136c
commit 1ba9ec1
Showing
4 changed files
with
657 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,95 @@ | ||
#![allow(unused)] | ||
use std::fmt::Debug; | ||
use std::ops::{Add, AddAssign, Mul, Sub}; | ||
use std::simd::u32x16; | ||
|
||
use xor_table::XorAccumulator; | ||
|
||
use crate::constraint_framework::logup::LookupElements; | ||
use crate::core::channel::Blake2sChannel; | ||
use crate::core::fields::m31::BaseField; | ||
use crate::core::fields::FieldExpOps; | ||
|
||
mod round; | ||
mod xor_table; | ||
|
||
#[derive(Default)] | ||
struct XorAccums { | ||
xor12: XorAccumulator<12, 4>, | ||
xor9: XorAccumulator<9, 2>, | ||
xor8: XorAccumulator<8, 2>, | ||
xor7: XorAccumulator<7, 2>, | ||
xor4: XorAccumulator<4, 0>, | ||
} | ||
impl XorAccums { | ||
fn add_input(&mut self, w: u32, a: u32x16, b: u32x16) { | ||
match w { | ||
12 => self.xor12.add_input(a, b), | ||
9 => self.xor9.add_input(a, b), | ||
8 => self.xor8.add_input(a, b), | ||
7 => self.xor7.add_input(a, b), | ||
4 => self.xor4.add_input(a, b), | ||
_ => panic!("Invalid w"), | ||
} | ||
} | ||
} | ||
|
||
#[derive(Clone)] | ||
pub struct XorLookupElements { | ||
xor12: LookupElements, | ||
xor9: LookupElements, | ||
xor8: LookupElements, | ||
xor7: LookupElements, | ||
xor4: LookupElements, | ||
} | ||
impl XorLookupElements { | ||
fn draw(channel: &mut Blake2sChannel) -> Self { | ||
Self { | ||
xor12: LookupElements::draw(channel, 3), | ||
xor9: LookupElements::draw(channel, 3), | ||
xor8: LookupElements::draw(channel, 3), | ||
xor7: LookupElements::draw(channel, 3), | ||
xor4: LookupElements::draw(channel, 3), | ||
} | ||
} | ||
|
||
fn get(&self, w: u32) -> &LookupElements { | ||
match w { | ||
12 => &self.xor12, | ||
9 => &self.xor9, | ||
8 => &self.xor8, | ||
7 => &self.xor7, | ||
4 => &self.xor4, | ||
_ => panic!("Invalid w"), | ||
} | ||
} | ||
} | ||
|
||
#[derive(Clone, Copy, Debug)] | ||
struct Fu32<F> | ||
where | ||
F: FieldExpOps | ||
+ Copy | ||
+ Debug | ||
+ AddAssign<F> | ||
+ Add<F, Output = F> | ||
+ Sub<F, Output = F> | ||
+ Mul<BaseField, Output = F>, | ||
{ | ||
l: F, | ||
h: F, | ||
} | ||
impl<F> Fu32<F> | ||
where | ||
F: FieldExpOps | ||
+ Copy | ||
+ Debug | ||
+ AddAssign<F> | ||
+ Add<F, Output = F> | ||
+ Sub<F, Output = F> | ||
+ Mul<BaseField, Output = F>, | ||
{ | ||
fn to_felts(self) -> [F; 2] { | ||
[self.l, self.h] | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
use itertools::{chain, Itertools}; | ||
use num_traits::One; | ||
|
||
use super::XorLookupElements; | ||
use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; | ||
use crate::constraint_framework::EvalAtRow; | ||
use crate::core::fields::m31::BaseField; | ||
use crate::examples::blake::Fu32; | ||
|
||
const I16: BaseField = BaseField::from_u32_unchecked(1 << 15); | ||
|
||
pub struct BlakeRoundEval<'a, E: EvalAtRow> { | ||
pub eval: E, | ||
pub xor_lookup_elements: &'a XorLookupElements, | ||
pub round_lookup_elements: &'a LookupElements, | ||
pub logup: LogupAtRow<2, E>, | ||
} | ||
impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { | ||
pub fn eval(mut self) -> E { | ||
let mut v: [Fu32<E::F>; 16] = std::array::from_fn(|_| self.next_u32()); | ||
let input_v = v; | ||
let m: [Fu32<E::F>; 16] = std::array::from_fn(|_| self.next_u32()); | ||
|
||
self.g(0, v.get_many_mut([0, 4, 8, 12]).unwrap(), m[0], m[1]); | ||
self.g(1, v.get_many_mut([1, 5, 9, 13]).unwrap(), m[2], m[3]); | ||
self.g(2, v.get_many_mut([2, 6, 10, 14]).unwrap(), m[4], m[5]); | ||
self.g(3, v.get_many_mut([3, 7, 11, 15]).unwrap(), m[6], m[7]); | ||
self.g(4, v.get_many_mut([0, 5, 10, 15]).unwrap(), m[8], m[9]); | ||
self.g(5, v.get_many_mut([1, 6, 11, 12]).unwrap(), m[10], m[11]); | ||
self.g(6, v.get_many_mut([2, 7, 8, 13]).unwrap(), m[12], m[13]); | ||
self.g(7, v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]); | ||
|
||
self.logup.push_lookup( | ||
&mut self.eval, | ||
-E::EF::one(), | ||
&chain![ | ||
input_v.iter().copied().flat_map(Fu32::to_felts), | ||
v.iter().copied().flat_map(Fu32::to_felts), | ||
m.iter().copied().flat_map(Fu32::to_felts) | ||
] | ||
.collect_vec(), | ||
self.round_lookup_elements, | ||
); | ||
|
||
self.logup.finalize(&mut self.eval); | ||
self.eval | ||
} | ||
fn next_u32(&mut self) -> Fu32<E::F> { | ||
let l = self.eval.next_trace_mask(); | ||
let h = self.eval.next_trace_mask(); | ||
Fu32 { l, h } | ||
} | ||
fn g(&mut self, _round: u32, v: [&mut Fu32<E::F>; 4], m0: Fu32<E::F>, m1: Fu32<E::F>) { | ||
let [a, b, c, d] = v; | ||
|
||
*a = self.add3_u32_unchecked(*a, *b, m0); | ||
*d = self.xor_rotr16_u32(*a, *d); | ||
*c = self.add2_u32_unchecked(*c, *d); | ||
*b = self.xor_rotr_u32(*b, *c, 12); | ||
*a = self.add3_u32_unchecked(*a, *b, m1); | ||
*d = self.xor_rotr_u32(*a, *d, 8); | ||
*c = self.add2_u32_unchecked(*c, *d); | ||
*b = self.xor_rotr_u32(*b, *c, 7); | ||
} | ||
|
||
/// Adds two u32s, returning the sum. | ||
/// Assumes a, b are properly range checked. | ||
/// The caller is responsible for checking: | ||
/// res.{l,h} not in [2^16, 2^17) or in [-2^16,0) | ||
fn add2_u32_unchecked(&mut self, a: Fu32<E::F>, b: Fu32<E::F>) -> Fu32<E::F> { | ||
let sl = self.eval.next_trace_mask(); | ||
let sh = self.eval.next_trace_mask(); | ||
|
||
let carry_l = (a.l + b.l - sl) * E::F::from(I16); | ||
self.eval.add_constraint(carry_l * carry_l - carry_l); | ||
|
||
let carry_h = (a.h + b.h + carry_l - sh) * E::F::from(I16); | ||
self.eval.add_constraint(carry_h * carry_h - carry_h); | ||
|
||
Fu32 { l: sl, h: sh } | ||
} | ||
|
||
/// Adds three u32s, returning the sum. | ||
/// Assumes a, b, c are properly range checked. | ||
/// Caller is responsible for checking: | ||
/// res.{l,h} not in [2^16, 3*2^16) or in [-2^17,0) | ||
fn add3_u32_unchecked(&mut self, a: Fu32<E::F>, b: Fu32<E::F>, c: Fu32<E::F>) -> Fu32<E::F> { | ||
let sl = self.eval.next_trace_mask(); | ||
let sh = self.eval.next_trace_mask(); | ||
|
||
let carry_l = (a.l + b.l + c.l - sl) * E::F::from(I16); | ||
self.eval.add_constraint( | ||
carry_l | ||
* (carry_l - E::F::from(BaseField::from_u32_unchecked(1 << 0))) | ||
* (carry_l - E::F::from(BaseField::from_u32_unchecked(1 << 1))), | ||
); | ||
|
||
let carry_h = (a.h + b.h + c.h + carry_l - sh) * E::F::from(I16); | ||
self.eval.add_constraint( | ||
carry_h | ||
* (carry_h - E::F::from(BaseField::from_u32_unchecked(1 << 0))) | ||
* (carry_h - E::F::from(BaseField::from_u32_unchecked(1 << 1))), | ||
); | ||
|
||
Fu32 { l: sl, h: sh } | ||
} | ||
|
||
/// Splits a felt at r. | ||
/// Caller is responsible for checking that the ranges of h * 2^r and l don't overlap. | ||
fn split_unchecked(&mut self, a: E::F, r: u32) -> (E::F, E::F) { | ||
let h = self.eval.next_trace_mask(); | ||
let l = a - h * E::F::from(BaseField::from_u32_unchecked(1 << r)); | ||
(l, h) | ||
} | ||
|
||
/// Checks that a, b are in range, and computes their xor rotate. | ||
fn xor_rotr_u32(&mut self, a: Fu32<E::F>, b: Fu32<E::F>, r: u32) -> Fu32<E::F> { | ||
let (all, alh) = self.split_unchecked(a.l, r); | ||
let (ahl, ahh) = self.split_unchecked(a.h, r); | ||
let (bll, blh) = self.split_unchecked(b.l, r); | ||
let (bhl, bhh) = self.split_unchecked(b.h, r); | ||
|
||
// These also guarantee that all elements are in range. | ||
let xorll = self.xor(r, all, bll); | ||
let xorlh = self.xor(16 - r, alh, blh); | ||
let xorhl = self.xor(r, ahl, bhl); | ||
let xorhh = self.xor(16 - r, ahh, bhh); | ||
|
||
Fu32 { | ||
l: xorhl * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorlh, | ||
h: xorll * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorhh, | ||
} | ||
} | ||
|
||
/// Checks that a, b are in range, and computes their xor 16 rotate. | ||
fn xor_rotr16_u32(&mut self, a: Fu32<E::F>, b: Fu32<E::F>) -> Fu32<E::F> { | ||
let (all, alh) = self.split_unchecked(a.l, 8); | ||
let (ahl, ahh) = self.split_unchecked(a.h, 8); | ||
let (bll, blh) = self.split_unchecked(b.l, 8); | ||
let (bhl, bhh) = self.split_unchecked(b.h, 8); | ||
|
||
// These also guarantee that all elements are in range. | ||
let xorll = self.xor(8, all, bll); | ||
let xorlh = self.xor(8, alh, blh); | ||
let xorhl = self.xor(8, ahl, bhl); | ||
let xorhh = self.xor(8, ahh, bhh); | ||
|
||
Fu32 { | ||
l: xorhh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorhl, | ||
h: xorlh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorll, | ||
} | ||
} | ||
|
||
/// Checks that a,b in in [0,2^w) and computes their xor. | ||
fn xor(&mut self, w: u32, a: E::F, b: E::F) -> E::F { | ||
// TODO: Separate lookups by w. | ||
let c = self.eval.next_trace_mask(); | ||
let lookup_elements = self.xor_lookup_elements.get(w); | ||
self.logup | ||
.push_lookup(&mut self.eval, E::EF::one(), &[a, b, c], lookup_elements); | ||
c | ||
} | ||
} |
Oops, something went wrong.