From c7246e2f57cefa8b12d74bb67129b26fe642fd5b Mon Sep 17 00:00:00 2001 From: Shahar Papini <43779613+spapinistarkware@users.noreply.github.com> Date: Wed, 10 Jul 2024 09:10:57 +0300 Subject: [PATCH] Eval framework preperation (#705) --- crates/prover/src/core/backend/simd/m31.rs | 24 +++++++- crates/prover/src/core/backend/simd/qm31.rs | 24 ++++++++ crates/prover/src/core/circle.rs | 8 +++ crates/prover/src/core/prover/mod.rs | 18 +++--- crates/prover/src/core/utils.rs | 61 ++++++++++++++++----- 5 files changed, 111 insertions(+), 24 deletions(-) diff --git a/crates/prover/src/core/backend/simd/m31.rs b/crates/prover/src/core/backend/simd/m31.rs index 654a6ac34..6ea2a7d65 100644 --- a/crates/prover/src/core/backend/simd/m31.rs +++ b/crates/prover/src/core/backend/simd/m31.rs @@ -8,8 +8,10 @@ use bytemuck::{Pod, Zeroable}; use num_traits::{One, Zero}; use rand::distributions::{Distribution, Standard}; +use super::qm31::PackedQM31; use crate::core::backend::simd::utils::{InterleaveEvens, InterleaveOdds}; use crate::core::fields::m31::{pow2147483645, BaseField, M31, P}; +use crate::core::fields::qm31::QM31; use crate::core::fields::FieldExpOps; pub const LOG_N_LANES: u32 = 4; @@ -149,15 +151,33 @@ impl Mul for PackedM31 { } } -impl Mul for PackedM31 { +impl Mul for PackedM31 { type Output = Self; #[inline(always)] - fn mul(self, rhs: BaseField) -> Self::Output { + fn mul(self, rhs: M31) -> Self::Output { self * PackedM31::broadcast(rhs) } } +impl Add for PackedM31 { + type Output = PackedQM31; + + #[inline(always)] + fn add(self, rhs: QM31) -> Self::Output { + PackedQM31::broadcast(rhs) + self + } +} + +impl Mul for PackedM31 { + type Output = PackedQM31; + + #[inline(always)] + fn mul(self, rhs: QM31) -> Self::Output { + PackedQM31::broadcast(rhs) * self + } +} + impl MulAssign for PackedM31 { #[inline(always)] fn mul_assign(&mut self, rhs: Self) { diff --git a/crates/prover/src/core/backend/simd/qm31.rs b/crates/prover/src/core/backend/simd/qm31.rs index 9b0a01ea1..6dc273630 100644 --- a/crates/prover/src/core/backend/simd/qm31.rs +++ b/crates/prover/src/core/backend/simd/qm31.rs @@ -198,6 +198,30 @@ impl Sub for PackedQM31 { } } +impl Add for PackedQM31 { + type Output = Self; + + fn add(self, rhs: QM31) -> Self::Output { + self + PackedQM31::broadcast(rhs) + } +} + +impl Sub for PackedQM31 { + type Output = Self; + + fn sub(self, rhs: QM31) -> Self::Output { + self - PackedQM31::broadcast(rhs) + } +} + +impl Mul for PackedQM31 { + type Output = Self; + + fn mul(self, rhs: QM31) -> Self::Output { + self * PackedQM31::broadcast(rhs) + } +} + impl SubAssign for PackedQM31 { fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; diff --git a/crates/prover/src/core/circle.rs b/crates/prover/src/core/circle.rs index 4ebf442ac..7aac1203c 100644 --- a/crates/prover/src/core/circle.rs +++ b/crates/prover/src/core/circle.rs @@ -104,6 +104,14 @@ impl CirclePoint { y: self.y.into(), } } + + pub fn mul_signed(&self, off: isize) -> CirclePoint { + if off > 0 { + self.mul(off as u128) + } else { + self.conjugate().mul(-off as u128) + } + } } impl Add for CirclePoint { diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index 1ba46fe7f..a6473d21b 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -98,13 +98,6 @@ pub fn generate_proof>( ) -> Result { let component_traces = air.component_traces(&commitment_scheme.trees); let lookup_values = air.lookup_values(&component_traces); - channel.mix_felts( - &lookup_values - .0 - .values() - .map(|v| SecureField::from(*v)) - .collect_vec(), - ); // Evaluate and commit on composition polynomial. let random_coeff = channel.draw_felt(); @@ -190,8 +183,17 @@ pub fn prove>( let (mut commitment_scheme, interaction_elements) = evaluate_and_commit_on_trace(air, channel, &twiddles, trace)?; + let air = air.to_air_prover(); + channel.mix_felts( + &air.lookup_values(&air.component_traces(&commitment_scheme.trees)) + .0 + .values() + .map(|v| SecureField::from(*v)) + .collect_vec(), + ); + generate_proof( - &air.to_air_prover(), + &air, channel, &interaction_elements, &twiddles, diff --git a/crates/prover/src/core/utils.rs b/crates/prover/src/core/utils.rs index 889a8cc8b..2105a4d81 100644 --- a/crates/prover/src/core/utils.rs +++ b/crates/prover/src/core/utils.rs @@ -1,5 +1,5 @@ use std::iter::Peekable; -use std::ops::Add; +use std::ops::{Add, Mul, Sub}; use num_traits::{One, Zero}; @@ -7,7 +7,7 @@ use super::circle::CirclePoint; use super::constraints::point_vanishing; use super::fields::m31::BaseField; use super::fields::qm31::SecureField; -use super::fields::{ExtensionOf, FieldExpOps}; +use super::fields::FieldExpOps; use super::poly::circle::CircleDomain; pub trait IteratorMutExt<'a, T: 'a>: Iterator { @@ -71,14 +71,23 @@ pub(crate) fn previous_bit_reversed_circle_domain_index( domain_log_size: u32, eval_log_size: u32, ) -> usize { - assert!(domain_log_size < eval_log_size); - let step_size = 1 << (eval_log_size - domain_log_size - 1) as usize; + offset_bit_reversed_circle_domain_index(i, domain_log_size, eval_log_size, -1) +} + +pub(crate) fn offset_bit_reversed_circle_domain_index( + i: usize, + domain_log_size: u32, + eval_log_size: u32, + offset: isize, +) -> usize { let mut prev_index = bit_reverse_index(i, eval_log_size); let half_size = 1 << (eval_log_size - 1); + let step_size = offset * (1 << (eval_log_size - domain_log_size - 1)) as isize; if prev_index < half_size { - prev_index = (prev_index + half_size - step_size) % half_size; + prev_index = (prev_index as isize + step_size).rem_euclid(half_size as isize) as usize; } else { - prev_index = ((prev_index + step_size) % half_size) + half_size; + prev_index = + ((prev_index as isize - step_size).rem_euclid(half_size as isize) as usize) + half_size; } bit_reverse_index(prev_index, eval_log_size) } @@ -139,17 +148,13 @@ pub fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec>( - values: &[F], - alpha: SecureField, - z: SecureField, -) -> SecureField +pub fn shifted_secure_combination(values: &[F], alpha: EF, z: EF) -> EF where - SecureField: Add, + EF: Copy + Zero + Mul + Add + Sub, { let res = values .iter() - .fold(SecureField::zero(), |acc, &value| acc * alpha + value); + .fold(EF::zero(), |acc, &value| acc * alpha + value); res - z } @@ -173,12 +178,15 @@ mod tests { use itertools::Itertools; use num_traits::One; + use super::{ + offset_bit_reversed_circle_domain_index, previous_bit_reversed_circle_domain_index, + }; use crate::core::backend::cpu::CpuCircleEvaluation; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; use crate::core::poly::circle::CanonicCoset; use crate::core::poly::NaturalOrder; - use crate::core::utils::{bit_reverse, previous_bit_reversed_circle_domain_index}; + use crate::core::utils::bit_reverse; use crate::{m31, qm31}; #[test] @@ -218,6 +226,31 @@ mod tests { assert_eq!(powers, vec![]); } + #[test] + fn test_offset_bit_reversed_circle_domain_index() { + let domain_log_size = 3; + let eval_log_size = 6; + let initial_index = 5; + + let actual = offset_bit_reversed_circle_domain_index( + initial_index, + domain_log_size, + eval_log_size, + -2, + ); + let expected_prev = previous_bit_reversed_circle_domain_index( + initial_index, + domain_log_size, + eval_log_size, + ); + let expected_prev2 = previous_bit_reversed_circle_domain_index( + expected_prev, + domain_log_size, + eval_log_size, + ); + assert_eq!(actual, expected_prev2); + } + #[test] fn test_previous_bit_reversed_circle_domain_index() { let log_size = 4;