Skip to content

Commit

Permalink
Blake xor table (#764)
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware authored Aug 4, 2024
1 parent 5b9966d commit 17591ea
Show file tree
Hide file tree
Showing 8 changed files with 376 additions and 4 deletions.
9 changes: 8 additions & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
eval: &mut E,
numerator: E::EF,
values: &[E::F],
lookup_elements: LookupElements,
lookup_elements: &LookupElements,
) {
let shifted_value = shifted_secure_combination(
values,
Expand Down Expand Up @@ -136,6 +136,13 @@ impl LookupElements {
{
shifted_secure_combination(values, EF::from(self.alpha), EF::from(self.z))
}
#[cfg(test)]
pub fn dummy() -> Self {
Self {
z: SecureField::one(),
alpha: SecureField::one(),
}
}
}

// SIMD backend generator for logup interaction trace.
Expand Down
4 changes: 3 additions & 1 deletion crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub trait EvalAtRow {
+ Mul<BaseField, Output = Self::F>
+ Add<SecureField, Output = Self::EF>
+ Mul<SecureField, Output = Self::EF>
+ Neg<Output = Self::F>
+ From<BaseField>;

/// A field type representing the closure of `F` with multiplying by [SecureField]. Constraints
Expand All @@ -56,7 +57,8 @@ pub trait EvalAtRow {
+ Mul<Self::F, Output = Self::EF>
+ Sub<Self::EF, Output = Self::EF>
+ Mul<Self::EF, Output = Self::EF>
+ From<SecureField>;
+ From<SecureField>
+ From<Self::F>;

/// Returns the next mask value for the first interaction at offset 0.
fn next_trace_mask(&mut self) -> Self::F {
Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/examples/blake/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mod xor_table;
52 changes: 52 additions & 0 deletions crates/prover/src/examples/blake/xor_table/constraints.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use super::limb_bits;
use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::EvalAtRow;
use crate::core::fields::m31::BaseField;

/// Constraints for the xor table.
pub struct XorTableEval<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> {
pub eval: E,
pub lookup_elements: &'a LookupElements,
pub logup: LogupAtRow<2, E>,
}
impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32>
XorTableEval<'a, E, ELEM_BITS, EXPAND_BITS>
{
pub fn eval(mut self) -> E {
// al, bl are the constant columns for the inputs: All pairs of elements in [0,
// 2^LIMB_BITS).
// cl is the constant column for the xor: al ^ bl.
let [al] = self.eval.next_interaction_mask(2, [0]);
let [bl] = self.eval.next_interaction_mask(2, [0]);
let [cl] = self.eval.next_interaction_mask(2, [0]);
for i in 0..1 << EXPAND_BITS {
for j in 0..1 << EXPAND_BITS {
let multiplicity = self.eval.next_trace_mask();

let a = al
+ E::F::from(BaseField::from_u32_unchecked(
i << limb_bits::<ELEM_BITS, EXPAND_BITS>(),
));
let b = bl
+ E::F::from(BaseField::from_u32_unchecked(
j << limb_bits::<ELEM_BITS, EXPAND_BITS>(),
));
let c = cl
+ E::F::from(BaseField::from_u32_unchecked(
(i ^ j) << limb_bits::<ELEM_BITS, EXPAND_BITS>(),
));

// Add with negative multiplicity. Consumers should lookup with positive
// multiplicity.
self.logup.push_lookup(
&mut self.eval,
(-multiplicity).into(),
&[a, b, c],
self.lookup_elements,
);
}
}
self.logup.finalize(&mut self.eval);
self.eval
}
}
170 changes: 170 additions & 0 deletions crates/prover/src/examples/blake/xor_table/gen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
use std::simd::u32x16;

use itertools::Itertools;
use tracing::{span, Level};

use super::{column_bits, limb_bits, XorAccumulator};
use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;

pub struct XorTableLookupData<const ELEM_BITS: u32, const EXPAND_BITS: u32> {
pub xor_accum: XorAccumulator<ELEM_BITS, EXPAND_BITS>,
}

pub fn generate_trace<const ELEM_BITS: u32, const EXPAND_BITS: u32>(
xor_accum: XorAccumulator<ELEM_BITS, EXPAND_BITS>,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
XorTableLookupData<ELEM_BITS, EXPAND_BITS>,
) {
(
xor_accum
.mults
.iter()
.map(|mult| {
CircleEvaluation::new(
CanonicCoset::new(column_bits::<ELEM_BITS, EXPAND_BITS>()).circle_domain(),
mult.clone(),
)
})
.collect_vec(),
XorTableLookupData { xor_accum },
)
}

/// Generates the interaction trace for the xor table.
/// Returns the interaction trace, the constant trace, and the claimed sum.
#[allow(clippy::type_complexity)]
pub fn generate_interaction_trace<const ELEM_BITS: u32, const EXPAND_BITS: u32>(
lookup_data: XorTableLookupData<ELEM_BITS, EXPAND_BITS>,
lookup_elements: &LookupElements,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
SecureField,
) {
let limb_bits = limb_bits::<ELEM_BITS, EXPAND_BITS>();
let _span = span!(Level::INFO, "Xor interaction trace").entered();
let offsets_vec = u32x16::from_array(std::array::from_fn(|i| i as u32));
let mut logup_gen = LogupTraceGenerator::new(column_bits::<ELEM_BITS, EXPAND_BITS>());

// Iterate each pair of columns, to batch their lookup together.
// There are 2^(2*EXPAND_BITS) column, for each combination of ah, bh.
let mut iter = lookup_data
.xor_accum
.mults
.iter()
.enumerate()
.array_chunks::<2>();
for [(i0, mults0), (i1, mults1)] in &mut iter {
let mut col_gen = logup_gen.new_col();

// Extract ah, bh from column index.
let ah0 = i0 as u32 >> EXPAND_BITS;
let bh0 = i0 as u32 & ((1 << EXPAND_BITS) - 1);
let ah1 = i1 as u32 >> EXPAND_BITS;
let bh1 = i1 as u32 & ((1 << EXPAND_BITS) - 1);

// Each column has 2^(2*LIMB_BITS) rows, packed in N_LANES.
#[allow(clippy::needless_range_loop)]
for vec_row in 0..(1 << (column_bits::<ELEM_BITS, EXPAND_BITS>() - LOG_N_LANES)) {
// vec_row is LIMB_BITS of al and LIMB_BITS - LOG_N_LANES of bl.
// Extract al, blh from vec_row.
let al = vec_row >> (limb_bits - LOG_N_LANES);
let blh = vec_row & ((1 << (limb_bits - LOG_N_LANES)) - 1);

// Construct the 3 vectors a, b, c.
let a0 = u32x16::splat((ah0 << limb_bits) | al);
let a1 = u32x16::splat((ah1 << limb_bits) | al);
// bll is just the consecutive numbers 0 .. N_LANES-1.
let b0 = u32x16::splat((bh0 << limb_bits) | (blh << LOG_N_LANES)) | offsets_vec;
let b1 = u32x16::splat((bh1 << limb_bits) | (blh << LOG_N_LANES)) | offsets_vec;

let c0 = a0 ^ b0;
let c1 = a1 ^ b1;

let p0: PackedSecureField = lookup_elements
.combine(&[a0, b0, c0].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) }));
let p1: PackedSecureField = lookup_elements
.combine(&[a1, b1, c1].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) }));

let num = p1 * mults0.data[vec_row as usize] + p0 * mults1.data[vec_row as usize];
let denom = p0 * p1;
col_gen.write_frac(vec_row as usize, -num, denom);
}
col_gen.finalize_col();
}

// If there is an odd number of lookup expressions, handle the last one.
if let Some(rem) = iter.into_remainder() {
if let Some((i, mults)) = rem.collect_vec().pop() {
let mut col_gen = logup_gen.new_col();
let ah = i as u32 >> EXPAND_BITS;
let bh = i as u32 & ((1 << EXPAND_BITS) - 1);

#[allow(clippy::needless_range_loop)]
for vec_row in 0..(1 << (column_bits::<ELEM_BITS, EXPAND_BITS>() - LOG_N_LANES)) {
// vec_row is LIMB_BITS of a, and LIMB_BITS - LOG_N_LANES of b.
let al = vec_row >> (limb_bits - LOG_N_LANES);
let a = u32x16::splat((ah << limb_bits) | al);
let bm = vec_row & ((1 << (limb_bits - LOG_N_LANES)) - 1);
let b = u32x16::splat((bh << limb_bits) | (bm << LOG_N_LANES)) | offsets_vec;

let c = a ^ b;

let p: PackedSecureField = lookup_elements.combine(
&[a, b, c].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) }),
);

let num = mults.data[vec_row as usize];
let denom = p;
col_gen.write_frac(vec_row as usize, PackedSecureField::from(-num), denom);
}
col_gen.finalize_col();
}
}

let (interaction_trace, claimed_sum) = logup_gen.finalize();
(interaction_trace, claimed_sum)
}

/// Generates the constant trace for the xor table.
/// Returns the constant trace, the constant trace, and the claimed sum.
#[allow(clippy::type_complexity)]
pub fn generate_constant_trace<const ELEM_BITS: u32, const EXPAND_BITS: u32>(
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let limb_bits = limb_bits::<ELEM_BITS, EXPAND_BITS>();
let _span = span!(Level::INFO, "Xor constant trace").entered();

// Generate the constant columns. In reality, these should be generated before the proof
// even began.
let a_col: BaseColumn = (0..(1 << (column_bits::<ELEM_BITS, EXPAND_BITS>())))
.map(|i| BaseField::from_u32_unchecked((i >> limb_bits) as u32))
.collect();
let b_col: BaseColumn = (0..(1 << (column_bits::<ELEM_BITS, EXPAND_BITS>())))
.map(|i| BaseField::from_u32_unchecked((i & ((1 << limb_bits) - 1)) as u32))
.collect();
let c_col: BaseColumn = (0..(1 << (column_bits::<ELEM_BITS, EXPAND_BITS>())))
.map(|i| {
BaseField::from_u32_unchecked(((i >> limb_bits) ^ (i & ((1 << limb_bits) - 1))) as u32)
})
.collect();
let mut constant_trace = [a_col, b_col, c_col]
.map(|x| {
CircleEvaluation::new(
CanonicCoset::new(column_bits::<ELEM_BITS, EXPAND_BITS>()).circle_domain(),
x,
)
})
.to_vec();
constant_trace.insert(0, gen_is_first(column_bits::<ELEM_BITS, EXPAND_BITS>()));
constant_trace
}
Loading

0 comments on commit 17591ea

Please sign in to comment.