Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Blake xor table #764

Merged
merged 1 commit into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
eval: &mut E,
numerator: E::EF,
values: &[E::F],
lookup_elements: LookupElements,
lookup_elements: &LookupElements,
) {
let shifted_value = shifted_secure_combination(
values,
Expand Down Expand Up @@ -136,6 +136,13 @@ impl LookupElements {
{
shifted_secure_combination(values, EF::from(self.alpha), EF::from(self.z))
}
#[cfg(test)]
pub fn dummy() -> Self {
Self {
z: SecureField::one(),
alpha: SecureField::one(),
}
}
}

// SIMD backend generator for logup interaction trace.
Expand Down
4 changes: 3 additions & 1 deletion crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub trait EvalAtRow {
+ Mul<BaseField, Output = Self::F>
+ Add<SecureField, Output = Self::EF>
+ Mul<SecureField, Output = Self::EF>
+ Neg<Output = Self::F>
+ From<BaseField>;

/// A field type representing the closure of `F` with multiplying by [SecureField]. Constraints
Expand All @@ -55,7 +56,8 @@ pub trait EvalAtRow {
+ Mul<Self::F, Output = Self::EF>
+ Sub<Self::EF, Output = Self::EF>
+ Mul<Self::EF, Output = Self::EF>
+ From<SecureField>;
+ From<SecureField>
+ From<Self::F>;

/// Returns the next mask value for the first interaction at offset 0.
fn next_trace_mask(&mut self) -> Self::F {
Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/examples/blake/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mod xor_table;
52 changes: 52 additions & 0 deletions crates/prover/src/examples/blake/xor_table/constraints.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use super::limb_bits;
use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::EvalAtRow;
use crate::core::fields::m31::BaseField;

/// Constraints for the xor table.
pub struct XorTableEval<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> {
pub eval: E,
pub lookup_elements: &'a LookupElements,
pub logup: LogupAtRow<2, E>,
}
impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32>
XorTableEval<'a, E, ELEM_BITS, EXPAND_BITS>
{
pub fn eval(mut self) -> E {
// al, bl are the constant columns for the inputs: All pairs of elements in [0,
// 2^LIMB_BITS).
// cl is the constant column for the xor: al ^ bl.
let [al] = self.eval.next_interaction_mask(2, [0]);
let [bl] = self.eval.next_interaction_mask(2, [0]);
let [cl] = self.eval.next_interaction_mask(2, [0]);
for i in 0..1 << EXPAND_BITS {
for j in 0..1 << EXPAND_BITS {
let multiplicity = self.eval.next_trace_mask();

let a = al
+ E::F::from(BaseField::from_u32_unchecked(
i << limb_bits::<ELEM_BITS, EXPAND_BITS>(),
));
let b = bl
+ E::F::from(BaseField::from_u32_unchecked(
j << limb_bits::<ELEM_BITS, EXPAND_BITS>(),
));
let c = cl
+ E::F::from(BaseField::from_u32_unchecked(
(i ^ j) << limb_bits::<ELEM_BITS, EXPAND_BITS>(),
));

// Add with negative multiplicity. Consumers should lookup with positive
// multiplicity.
self.logup.push_lookup(
&mut self.eval,
(-multiplicity).into(),
&[a, b, c],
self.lookup_elements,
);
}
}
self.logup.finalize(&mut self.eval);
self.eval
}
}
170 changes: 170 additions & 0 deletions crates/prover/src/examples/blake/xor_table/gen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
use std::simd::u32x16;

use itertools::Itertools;
use tracing::{span, Level};

use super::{column_bits, limb_bits, XorAccumulator};
use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;

pub struct XorTableLookupData<const ELEM_BITS: u32, const EXPAND_BITS: u32> {
pub xor_accum: XorAccumulator<ELEM_BITS, EXPAND_BITS>,
}

pub fn generate_trace<const ELEM_BITS: u32, const EXPAND_BITS: u32>(
xor_accum: XorAccumulator<ELEM_BITS, EXPAND_BITS>,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
XorTableLookupData<ELEM_BITS, EXPAND_BITS>,
) {
(
xor_accum
.mults
.iter()
.map(|mult| {
CircleEvaluation::new(
CanonicCoset::new(column_bits::<ELEM_BITS, EXPAND_BITS>()).circle_domain(),
mult.clone(),
)
})
.collect_vec(),
XorTableLookupData { xor_accum },
)
}

/// Generates the interaction trace for the xor table.
/// Returns the interaction trace, the constant trace, and the claimed sum.
#[allow(clippy::type_complexity)]
pub fn generate_interaction_trace<const ELEM_BITS: u32, const EXPAND_BITS: u32>(
lookup_data: XorTableLookupData<ELEM_BITS, EXPAND_BITS>,
lookup_elements: &LookupElements,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
SecureField,
) {
let limb_bits = limb_bits::<ELEM_BITS, EXPAND_BITS>();
let _span = span!(Level::INFO, "Xor interaction trace").entered();
let offsets_vec = u32x16::from_array(std::array::from_fn(|i| i as u32));
let mut logup_gen = LogupTraceGenerator::new(column_bits::<ELEM_BITS, EXPAND_BITS>());

// Iterate each pair of columns, to batch their lookup together.
// There are 2^(2*EXPAND_BITS) column, for each combination of ah, bh.
let mut iter = lookup_data
.xor_accum
.mults
.iter()
.enumerate()
.array_chunks::<2>();
for [(i0, mults0), (i1, mults1)] in &mut iter {
let mut col_gen = logup_gen.new_col();

// Extract ah, bh from column index.
let ah0 = i0 as u32 >> EXPAND_BITS;
let bh0 = i0 as u32 & ((1 << EXPAND_BITS) - 1);
let ah1 = i1 as u32 >> EXPAND_BITS;
let bh1 = i1 as u32 & ((1 << EXPAND_BITS) - 1);

// Each column has 2^(2*LIMB_BITS) rows, packed in N_LANES.
#[allow(clippy::needless_range_loop)]
for vec_row in 0..(1 << (column_bits::<ELEM_BITS, EXPAND_BITS>() - LOG_N_LANES)) {
// vec_row is LIMB_BITS of al and LIMB_BITS - LOG_N_LANES of bl.
// Extract al, blh from vec_row.
let al = vec_row >> (limb_bits - LOG_N_LANES);
let blh = vec_row & ((1 << (limb_bits - LOG_N_LANES)) - 1);

// Construct the 3 vectors a, b, c.
let a0 = u32x16::splat((ah0 << limb_bits) | al);
let a1 = u32x16::splat((ah1 << limb_bits) | al);
// bll is just the consecutive numbers 0 .. N_LANES-1.
let b0 = u32x16::splat((bh0 << limb_bits) | (blh << LOG_N_LANES)) | offsets_vec;
let b1 = u32x16::splat((bh1 << limb_bits) | (blh << LOG_N_LANES)) | offsets_vec;

let c0 = a0 ^ b0;
let c1 = a1 ^ b1;

let p0: PackedSecureField = lookup_elements
.combine(&[a0, b0, c0].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) }));
let p1: PackedSecureField = lookup_elements
.combine(&[a1, b1, c1].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) }));

let num = p1 * mults0.data[vec_row as usize] + p0 * mults1.data[vec_row as usize];
let denom = p0 * p1;
col_gen.write_frac(vec_row as usize, -num, denom);
}
col_gen.finalize_col();
}

// If there is an odd number of lookup expressions, handle the last one.
if let Some(rem) = iter.into_remainder() {
if let Some((i, mults)) = rem.collect_vec().pop() {
let mut col_gen = logup_gen.new_col();
let ah = i as u32 >> EXPAND_BITS;
let bh = i as u32 & ((1 << EXPAND_BITS) - 1);

#[allow(clippy::needless_range_loop)]
for vec_row in 0..(1 << (column_bits::<ELEM_BITS, EXPAND_BITS>() - LOG_N_LANES)) {
// vec_row is LIMB_BITS of a, and LIMB_BITS - LOG_N_LANES of b.
let al = vec_row >> (limb_bits - LOG_N_LANES);
let a = u32x16::splat((ah << limb_bits) | al);
let bm = vec_row & ((1 << (limb_bits - LOG_N_LANES)) - 1);
let b = u32x16::splat((bh << limb_bits) | (bm << LOG_N_LANES)) | offsets_vec;

let c = a ^ b;

let p: PackedSecureField = lookup_elements.combine(
&[a, b, c].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) }),
);

let num = mults.data[vec_row as usize];
let denom = p;
col_gen.write_frac(vec_row as usize, PackedSecureField::from(-num), denom);
}
col_gen.finalize_col();
}
}

let (interaction_trace, claimed_sum) = logup_gen.finalize();
(interaction_trace, claimed_sum)
}

/// Generates the constant trace for the xor table.
/// Returns the constant trace, the constant trace, and the claimed sum.
#[allow(clippy::type_complexity)]
pub fn generate_constant_trace<const ELEM_BITS: u32, const EXPAND_BITS: u32>(
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let limb_bits = limb_bits::<ELEM_BITS, EXPAND_BITS>();
let _span = span!(Level::INFO, "Xor constant trace").entered();

// Generate the constant columns. In reality, these should be generated before the proof
// even began.
let a_col: BaseColumn = (0..(1 << (column_bits::<ELEM_BITS, EXPAND_BITS>())))
.map(|i| BaseField::from_u32_unchecked((i >> limb_bits) as u32))
.collect();
let b_col: BaseColumn = (0..(1 << (column_bits::<ELEM_BITS, EXPAND_BITS>())))
.map(|i| BaseField::from_u32_unchecked((i & ((1 << limb_bits) - 1)) as u32))
.collect();
let c_col: BaseColumn = (0..(1 << (column_bits::<ELEM_BITS, EXPAND_BITS>())))
.map(|i| {
BaseField::from_u32_unchecked(((i >> limb_bits) ^ (i & ((1 << limb_bits) - 1))) as u32)
})
.collect();
let mut constant_trace = [a_col, b_col, c_col]
.map(|x| {
CircleEvaluation::new(
CanonicCoset::new(column_bits::<ELEM_BITS, EXPAND_BITS>()).circle_domain(),
x,
)
})
.to_vec();
constant_trace.insert(0, gen_is_first(column_bits::<ELEM_BITS, EXPAND_BITS>()));
constant_trace
}
Loading
Loading