Skip to content

Commit

Permalink
Blake scheduler (#770)
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware authored Aug 6, 2024
1 parent 8344112 commit 925a424
Show file tree
Hide file tree
Showing 8 changed files with 379 additions and 17 deletions.
4 changes: 2 additions & 2 deletions crates/prover/src/core/backend/simd/blake2s.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const IV: [u32; 8] = [
0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
];

const SIGMA: [[u8; 16]; 10] = [
pub const SIGMA: [[u8; 16]; 10] = [
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3],
[11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4],
Expand Down Expand Up @@ -130,7 +130,7 @@ fn rotate<const N: u32>(x: u32x16) -> u32x16 {

// `inline(always)` can cause code parsing errors for wasm: "locals exceed maximum".
#[cfg_attr(not(target_arch = "wasm32"), inline(always))]
fn round(v: &mut [u32x16; 16], m: [u32x16; 16], r: usize) {
pub fn round(v: &mut [u32x16; 16], m: [u32x16; 16], r: usize) {
v[0] += m[SIGMA[r][0] as usize];
v[1] += m[SIGMA[r][2] as usize];
v[2] += m[SIGMA[r][4] as usize];
Expand Down
17 changes: 17 additions & 0 deletions crates/prover/src/examples/blake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,21 @@ use std::simd::u32x16;
use xor_table::{XorAccumulator, XorElements};

use crate::constraint_framework::logup::LookupElements;
use crate::core::backend::simd::m31::PackedBaseField;
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::FieldExpOps;

mod round;
mod scheduler;
mod xor_table;

/// Number of u32 words in the Blake state vector `v`.
const STATE_SIZE: usize = 16;
/// Number of u32 words in a Blake message block `m`.
const MESSAGE_SIZE: usize = 16;
/// Number of field elements used to represent a single u32 (its low and high 16-bit halves).
const N_FELTS_IN_U32: usize = 2;
/// Number of field elements in one round lookup entry: input state, output state and message,
/// each word taking `N_FELTS_IN_U32` field elements.
const N_ROUND_INPUT_FELTS: usize = (STATE_SIZE + STATE_SIZE + MESSAGE_SIZE) * N_FELTS_IN_U32;
/// Number of rounds the scheduler chains together for one compression.
const N_ROUNDS: usize = 10;

#[derive(Default)]
struct XorAccums {
xor12: XorAccumulator<12, 4>,
Expand Down Expand Up @@ -76,6 +84,7 @@ impl BlakeXorElements {
}
}

/// Utility for representing a u32 as two field elements, for constraint evaluation.
#[derive(Clone, Copy, Debug)]
struct Fu32<F>
where
Expand Down Expand Up @@ -104,3 +113,11 @@ where
[self.l, self.h]
}
}

/// Splits every lane of a packed u32 into two packed field elements, for trace generation.
///
/// The returned pair is `[low, high]`: the low 16 bits and the high 16 bits of each lane.
fn to_felts(x: &u32x16) -> [PackedBaseField; 2] {
    let low_halves = x & u32x16::splat(0xffff);
    let high_halves = x >> 16;
    // SAFETY: every lane of both vectors is below 2^16. This is presumably well inside the range
    // `from_simd_unchecked` requires of its input (confirm against its contract).
    unsafe {
        [
            PackedBaseField::from_simd_unchecked(low_halves),
            PackedBaseField::from_simd_unchecked(high_halves),
        ]
    }
}
6 changes: 3 additions & 3 deletions crates/prover/src/examples/blake/round/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::{BlakeXorElements, RoundElements};
use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::EvalAtRow;
use crate::core::fields::m31::BaseField;
use crate::examples::blake::Fu32;
use crate::examples::blake::{Fu32, STATE_SIZE};

const INV16: BaseField = BaseField::from_u32_unchecked(1 << 15);
const TWO: BaseField = BaseField::from_u32_unchecked(2);
Expand All @@ -18,9 +18,9 @@ pub struct BlakeRoundEval<'a, E: EvalAtRow> {
}
impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
pub fn eval(mut self) -> E {
let mut v: [Fu32<E::F>; 16] = std::array::from_fn(|_| self.next_u32());
let mut v: [Fu32<E::F>; STATE_SIZE] = std::array::from_fn(|_| self.next_u32());
let input_v = v;
let m: [Fu32<E::F>; 16] = std::array::from_fn(|_| self.next_u32());
let m: [Fu32<E::F>; STATE_SIZE] = std::array::from_fn(|_| self.next_u32());

self.g(v.get_many_mut([0, 4, 8, 12]).unwrap(), m[0], m[1]);
self.g(v.get_many_mut([1, 5, 9, 13]).unwrap(), m[2], m[3]);
Expand Down
19 changes: 9 additions & 10 deletions crates/prover/src/examples/blake/round/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@ use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;
use crate::examples::blake::round::blake_round_info;
use crate::examples::blake::XorAccums;
use crate::examples::blake::{
to_felts, XorAccums, MESSAGE_SIZE, N_FELTS_IN_U32, N_ROUND_INPUT_FELTS, STATE_SIZE,
};

pub struct BlakeRoundLookupData {
/// A vector of (w, [a_col, b_col, c_col]) for each xor lookup.
/// w is the xor width. c_col is the xor col of a_col and b_col.
xor_lookups: Vec<(u32, [BaseColumn; 3])>,
/// A column of round lookup values (v_in, v_out, m).
round_lookup: [BaseColumn; 16 * 3 * 2],
round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS],
}

pub struct TraceGenerator {
log_size: u32,
trace: Vec<BaseColumn>,
xor_lookups: Vec<(u32, [BaseColumn; 3])>,
round_lookup: [BaseColumn; 16 * 3 * 2],
round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS],
}
impl TraceGenerator {
fn new(log_size: u32) -> Self {
Expand Down Expand Up @@ -98,12 +100,9 @@ impl<'a> TraceGeneratorRow<'a> {
self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]);

chain![input_v.iter(), v.iter(), m.iter()]
.flat_map(|s| [s & u32x16::splat(0xffff), s >> 16])
.flat_map(to_felts)
.enumerate()
.for_each(|(i, val)| {
self.gen.round_lookup[i].data[self.vec_row] =
unsafe { PackedBaseField::from_simd_unchecked(val) }
});
.for_each(|(i, felt)| self.gen.round_lookup[i].data[self.vec_row] = felt);
}

fn g(&mut self, v: [&mut u32x16; 4], m0: u32x16, m1: u32x16) {
Expand Down Expand Up @@ -203,8 +202,8 @@ impl<'a> TraceGeneratorRow<'a> {

#[derive(Copy, Clone, Default)]
pub struct BlakeRoundInput {
pub v: [u32x16; 16],
pub m: [u32x16; 16],
pub v: [u32x16; STATE_SIZE],
pub m: [u32x16; STATE_SIZE],
}

pub fn generate_trace(
Expand Down
5 changes: 3 additions & 2 deletions crates/prover/src/examples/blake/round/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ mod gen;

use constraints::BlakeRoundEval;
use num_traits::Zero;
pub use r#gen::BlakeRoundInput;

use super::BlakeXorElements;
use super::{BlakeXorElements, N_ROUND_INPUT_FELTS, STATE_SIZE};
use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, InfoEvaluator};
use crate::core::fields::qm31::SecureField;
Expand All @@ -20,7 +21,7 @@ pub fn blake_round_info() -> InfoEvaluator {
component.evaluate(InfoEvaluator::default())
}

pub type RoundElements = LookupElements<{ 16 * 3 * 2 }>;
pub type RoundElements = LookupElements<N_ROUND_INPUT_FELTS>;
pub struct BlakeRoundComponent {
pub log_size: u32,
pub xor_lookup_elements: BlakeXorElements,
Expand Down
66 changes: 66 additions & 0 deletions crates/prover/src/examples/blake/scheduler/constraints.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use itertools::{chain, Itertools};
use num_traits::One;

use super::BlakeElements;
use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::EvalAtRow;
use crate::core::vcs::blake2s_ref::SIGMA;
use crate::examples::blake::round::RoundElements;
use crate::examples::blake::{Fu32, N_ROUNDS, STATE_SIZE};

/// Constraint evaluator for a single row of the Blake scheduler component.
///
/// Consumed by [`BlakeSchedulerEval::eval`], which reads the row's trace masks from `eval` and
/// pushes the round and blake lookups into `logup`.
pub struct BlakeSchedulerEval<'a, E: EvalAtRow> {
/// Row evaluator that the trace masks are read from and that is returned once evaluation ends.
pub eval: E,
/// Lookup elements used to combine the (initial state, final state, message) entry.
pub blake_lookup_elements: &'a BlakeElements,
/// Lookup elements used to combine each (input state, output state, round message) entry.
pub round_lookup_elements: &'a RoundElements,
/// Logup accumulator for this row.
// NOTE(review): the const argument `2` presumably is the lookup batching factor, matching the
// pairing done in the interaction trace generator — confirm against `LogupAtRow`.
pub logup: LogupAtRow<2, E>,
}
impl<'a, E: EvalAtRow> BlakeSchedulerEval<'a, E> {
    /// Evaluates one scheduler row and returns the evaluator once the logup is finalized.
    ///
    /// Trace masks are consumed in this order: the message words first, then the
    /// `N_ROUNDS + 1` states (the initial state followed by the state after each round).
    /// Every u32 takes two consecutive masks (low half, then high half).
    ///
    /// Lookups pushed, in order:
    /// * one per round, numerator `1`, over (input state, output state, permuted message),
    ///   combined with the round lookup elements;
    /// * a single one, numerator `-1`, over (initial state, final state, message), combined
    ///   with the blake lookup elements.
    pub fn eval(mut self) -> E {
        let message_words: [Fu32<E::F>; STATE_SIZE] = std::array::from_fn(|_| self.next_u32());
        let round_states: [[Fu32<E::F>; STATE_SIZE]; N_ROUNDS + 1] =
            std::array::from_fn(|_| std::array::from_fn(|_| self.next_u32()));

        // Schedule: consecutive states are the input and output of one round.
        let consecutive_states = round_states.iter().zip(round_states.iter().skip(1));
        for (round, (state_in, state_out)) in consecutive_states.enumerate() {
            // Message words in the order this round consumes them.
            let permuted_message = SIGMA[round].map(|j| message_words[j as usize]);
            let values = Self::triplet_felts(state_in, state_out, &permuted_message);
            // Use triplet in round lookup.
            self.logup.push_lookup(
                &mut self.eval,
                E::EF::one(),
                &values,
                self.round_lookup_elements,
            );
        }

        // TODO: support multiplicities.
        let values = Self::triplet_felts(
            &round_states[0],
            &round_states[N_ROUNDS],
            &message_words,
        );
        self.logup.push_lookup(
            &mut self.eval,
            -E::EF::one(),
            &values,
            self.blake_lookup_elements,
        );

        self.logup.finalize(&mut self.eval);
        self.eval
    }

    /// Flattens three word sequences, in the given order, into their field elements.
    fn triplet_felts(
        first: &[Fu32<E::F>],
        second: &[Fu32<E::F>],
        third: &[Fu32<E::F>],
    ) -> Vec<E::F> {
        chain![first, second, third]
            .copied()
            .flat_map(Fu32::to_felts)
            .collect_vec()
    }

    /// Reads the next two trace masks as the low and high halves of a u32.
    fn next_u32(&mut self) -> Fu32<E::F> {
        let low = self.eval.next_trace_mask();
        let high = self.eval.next_trace_mask();
        Fu32 { l: low, h: high }
    }
}
169 changes: 169 additions & 0 deletions crates/prover/src/examples/blake/scheduler/gen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
use std::f32::consts::E;
use std::simd::u32x16;

use itertools::{chain, Itertools};
use num_traits::One;
use tracing::{span, Level};

use super::{blake_scheduler_info, BlakeElements};
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::{blake2s, SimdBackend};
use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;
use crate::examples::blake::round::{BlakeRoundInput, RoundElements};
use crate::examples::blake::{to_felts, N_ROUNDS, N_ROUND_INPUT_FELTS, STATE_SIZE};

/// Input of the scheduler for one packed trace row: a state and a message, each lane of the
/// `u32x16` values belonging to a separate instance.
#[derive(Copy, Clone, Default)]
pub struct BlakeInput {
/// Initial state words fed into the first round.
pub v: [u32x16; STATE_SIZE],
/// Message words, in their original (unpermuted) order.
pub m: [u32x16; STATE_SIZE],
}

/// Lookup values recorded during scheduler trace generation, later consumed when generating
/// the interaction trace.
pub struct BlakeSchedulerLookupData {
/// For each round, the columns of (input state, output state, permuted message) felts.
pub round_lookups: [[BaseColumn; N_ROUND_INPUT_FELTS]; N_ROUNDS],
/// The columns of (initial state, final state, message) felts of the whole computation.
pub blake_lookups: [BaseColumn; N_ROUND_INPUT_FELTS],
}
impl BlakeSchedulerLookupData {
    /// Allocates every lookup column for a trace of `1 << log_size` rows.
    ///
    /// The columns are left uninitialized; the caller has to write every row before any read.
    fn new(log_size: u32) -> Self {
        // SAFETY: the columns are only handed out for filling. `gen_trace` writes every packed
        // row of every column before the data is read.
        // NOTE(review): confirm this satisfies the contract of `BaseColumn::uninitialized`.
        let uninit_column = || unsafe { BaseColumn::uninitialized(1 << log_size) };

        let round_lookups = std::array::from_fn(|_| std::array::from_fn(|_| uninit_column()));
        let blake_lookups = std::array::from_fn(|_| uninit_column());
        Self {
            round_lookups,
            blake_lookups,
        }
    }
}

pub fn gen_trace(
log_size: u32,
inputs: &[BlakeInput],
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
BlakeSchedulerLookupData,
Vec<BlakeRoundInput>,
) {
let mut lookup_data = BlakeSchedulerLookupData::new(log_size);
let mut round_inputs = Vec::with_capacity(inputs.len() * N_ROUNDS);

let mut trace = (0..blake_scheduler_info().mask_offsets[0].len())
.map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) })
.collect_vec();

for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
let mut col_index = 0;

let mut write_u32_array = |x: [u32x16; STATE_SIZE], col_index: &mut usize| {
x.iter().for_each(|x| {
to_felts(x).iter().for_each(|x| {
trace[*col_index].data[vec_row] = *x;
*col_index += 1;
});
});
};

let BlakeInput { mut v, m } = inputs.get(vec_row).copied().unwrap_or_default();
let initial_v = v;
write_u32_array(m, &mut col_index);
write_u32_array(v, &mut col_index);

for r in 0..N_ROUNDS {
let prev_v = v;
blake2s::round(&mut v, m, r);
write_u32_array(v, &mut col_index);

let round_m = blake2s::SIGMA[r].map(|i| m[i as usize]);
round_inputs.push(BlakeRoundInput {
v: prev_v,
m: round_m,
});

chain![
prev_v.iter().flat_map(to_felts),
v.iter().flat_map(to_felts),
round_m.iter().flat_map(to_felts)
]
.enumerate()
.for_each(|(i, val)| lookup_data.round_lookups[r][i].data[vec_row] = val);
}

chain![
initial_v.iter().flat_map(to_felts),
v.iter().flat_map(to_felts),
m.iter().flat_map(to_felts)
]
.enumerate()
.for_each(|(i, val)| lookup_data.blake_lookups[i].data[vec_row] = val);
}

let domain = CanonicCoset::new(log_size).circle_domain();
let trace = trace
.into_iter()
.map(|eval| CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(domain, eval))
.collect_vec();

(trace, lookup_data, round_inputs)
}
pub fn gen_interaction_trace(
log_size: u32,
lookup_data: BlakeSchedulerLookupData,
round_lookup_elements: &RoundElements,
blake_lookup_elements: &BlakeElements,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
SecureField,
) {
let _span = span!(Level::INFO, "Generate scheduler interaction trace").entered();

let mut logup_gen = LogupTraceGenerator::new(log_size);

for [l0, l1] in lookup_data.round_lookups.array_chunks::<2>() {
let mut col_gen = logup_gen.new_col();

#[allow(clippy::needless_range_loop)]
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
let p0: PackedSecureField =
round_lookup_elements.combine(&l0.each_ref().map(|l| l.data[vec_row]));
let p1: PackedSecureField =
round_lookup_elements.combine(&l1.each_ref().map(|l| l.data[vec_row]));
#[allow(clippy::eq_op)]
col_gen.write_frac(vec_row, p0 + p1, p0 * p1);
}

col_gen.finalize_col();
}

// Last pair. If the number of round is odd (as in blake3), we combine that last round lookup
// with the entire blake lookup.
let mut col_gen = logup_gen.new_col();
#[allow(clippy::needless_range_loop)]
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
let p_blake: PackedSecureField = blake_lookup_elements.combine(
&lookup_data
.blake_lookups
.each_ref()
.map(|l| l.data[vec_row]),
);
if N_ROUNDS % 2 == 1 {
let p_round: PackedSecureField = round_lookup_elements.combine(
&lookup_data.round_lookups[N_ROUNDS - 1]
.each_ref()
.map(|l| l.data[vec_row]),
);
col_gen.write_frac(vec_row, p_blake - p_round, p_round * p_blake);
} else {
col_gen.write_frac(vec_row, -PackedSecureField::one(), p_blake);
}
}
col_gen.finalize_col();

logup_gen.finalize()
}
Loading

0 comments on commit 925a424

Please sign in to comment.