From 4d64300dc178d10d2eb7e9c43ac69cb73a634bb6 Mon Sep 17 00:00:00 2001 From: Alon-Ti <54235977+Alon-Ti@users.noreply.github.com> Date: Tue, 26 Nov 2024 10:27:31 +0200 Subject: [PATCH 01/69] Exported shared code to temp variables in ExprEvaluator. (#893) --- .../prover/src/constraint_framework/expr.rs | 240 +++++------------- .../prover/src/examples/state_machine/mod.rs | 62 +++-- 2 files changed, 91 insertions(+), 211 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 92526a6b6..5d8013402 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -64,7 +64,7 @@ impl Expr { Expr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()), Expr::Mul(a, b) => format!("({}) * ({})", a.format_expr(), b.format_expr()), Expr::Neg(a) => format!("-({})", a.format_expr()), - Expr::Inv(a) => format!("1/({})", a.format_expr()), + Expr::Inv(a) => format!("1 / ({})", a.format_expr()), } } } @@ -248,6 +248,7 @@ pub struct ExprEvaluator { pub cur_var_index: usize, pub constraints: Vec, pub logup: FormalLogupAtRow, + pub intermediates: Vec<(String, Expr)>, } impl ExprEvaluator { @@ -257,8 +258,35 @@ impl ExprEvaluator { cur_var_index: Default::default(), constraints: Default::default(), logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, has_partial_sum, log_size), + intermediates: vec![], } } + + pub fn add_intermediate(&mut self, expr: Expr) -> Expr { + let name = format!("intermediate{}", self.intermediates.len()); + let intermediate = Expr::Param(name.clone()); + self.intermediates.push((name, expr)); + intermediate + } + + pub fn format_constraints(&self) -> String { + let lets_string = self + .intermediates + .iter() + .map(|(name, expr)| format!("let {} = {};", name, expr.format_expr())) + .collect::>() + .join("\n"); + + let constraints_str = self + .constraints + .iter() + .enumerate() + .map(|(i, c)| format!("let constraint_{i} = ") + &c.format_expr() + ";") + .collect::>() + .join("\n\n"); + + lets_string + "\n\n" + &constraints_str + } } impl EvalAtRow for ExprEvaluator { @@ -286,7 +314,15 @@ impl EvalAtRow for ExprEvaluator { where Self::EF: std::ops::Mul, { - self.constraints.push(Expr::one() * constraint); + match Expr::one() * constraint { + Expr::Mul(one, constraint) => { + assert_eq!(*one, Expr::one()); + self.constraints.push(*constraint); + } + _ => { + unreachable!(); + } + } } fn combine_ef(values: [Self::F; 4]) -> Self::EF { @@ -310,7 +346,8 @@ impl EvalAtRow for ExprEvaluator { multiplicity, values, }| { - Fraction::new(multiplicity.clone(), combine_formal(*relation, values)) + let intermediate = self.add_intermediate(combine_formal(*relation, values)); + Fraction::new(multiplicity.clone(), intermediate) }, ) .collect(); @@ -324,187 +361,34 @@ impl EvalAtRow for ExprEvaluator { mod tests { use num_traits::One; - use crate::constraint_framework::expr::{ColumnExpr, Expr, ExprEvaluator}; - use crate::constraint_framework::{ - relation, EvalAtRow, FrameworkEval, RelationEntry, ORIGINAL_TRACE_IDX, - }; - use crate::core::fields::m31::M31; + use crate::constraint_framework::expr::ExprEvaluator; + use crate::constraint_framework::{relation, EvalAtRow, FrameworkEval, RelationEntry}; use crate::core::fields::FieldExpOps; - #[test] - fn test_expr_eval() { - let test_struct = TestStruct {}; - let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - assert_eq!(eval.constraints.len(), 2); - assert_eq!( - eval.constraints[0], - Expr::Mul( - 
Box::new(Expr::one()), - Box::new(Expr::Mul( - Box::new(Expr::Mul( - Box::new(Expr::Mul( - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 0, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 1, - offset: 0 - })) - )), - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 2, - offset: 0 - })) - )), - Box::new(Expr::Inv(Box::new(Expr::Add( - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 0, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 1, - offset: 0 - })) - )))) - )) - ) - ); - - assert_eq!( - eval.constraints[1], - Expr::Mul( - Box::new(Expr::Const(M31(1))), - Box::new(Expr::Sub( - Box::new(Expr::Mul( - Box::new(Expr::Sub( - Box::new(Expr::Sub( - Box::new(Expr::SecureCol([ - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 4, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 6, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 8, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 10, - offset: 0 - })) - ])), - Box::new(Expr::Sub( - Box::new(Expr::SecureCol([ - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 5, - offset: -1 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 7, - offset: -1 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 9, - offset: -1 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 11, - offset: -1 - })) - ])), - Box::new(Expr::Mul( - Box::new(Expr::Col(ColumnExpr { - interaction: 0, - idx: 3, - offset: 0 - })), - Box::new(Expr::Param("total_sum".into())) - )) - )) - )), - Box::new(Expr::Const(M31(0))) - )), - Box::new(Expr::Sub( - Box::new(Expr::Add( - Box::new(Expr::Add( - Box::new(Expr::Add( - Box::new(Expr::Const(M31(0))), - Box::new(Expr::Mul( - Box::new(Expr::Param( - "TestRelation_alpha0".to_string() - )), - Box::new(Expr::Col(ColumnExpr { - interaction: 1, - idx: 0, - offset: 0 - })) - )) - )), - Box::new(Expr::Mul( - Box::new(Expr::Param("TestRelation_alpha1".to_string())), - Box::new(Expr::Col(ColumnExpr { - interaction: 1, - idx: 1, - offset: 0 - })) - )) - )), - Box::new(Expr::Mul( - Box::new(Expr::Param("TestRelation_alpha2".to_string())), - Box::new(Expr::Col(ColumnExpr { - interaction: 1, - idx: 2, - offset: 0 - })) - )) - )), - Box::new(Expr::Param("TestRelation_z".to_string())) - )) - )), - Box::new(Expr::Const(M31(1))) - )) - ) - ); - } - #[test] fn test_format_expr() { let test_struct = TestStruct {}; let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - let constraint0_str = "(1) * ((((col_1_0[0]) * (col_1_1[0])) * (col_1_2[0])) * (1/(col_1_0[0] + col_1_1[0])))"; - assert_eq!(eval.constraints[0].format_expr(), constraint0_str); - let constraint1_str = "(1) \ - * ((SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \ - - (SecureCol(\ - col_2_5[-1], \ - col_2_7[-1], \ - col_2_9[-1], \ - col_2_11[-1]\ - ) - ((col_0_3[0]) * (total_sum))) \ - - (0)) \ - * (0 + (TestRelation_alpha0) * (col_1_0[0]) \ - + (TestRelation_alpha1) * (col_1_1[0]) \ - + (TestRelation_alpha2) * (col_1_2[0]) \ - - (TestRelation_z)) \ - - (1))"; - assert_eq!(eval.constraints[1].format_expr(), constraint1_str); + let expected = "let intermediate0 = 0 \ + + (TestRelation_alpha0) * (col_1_0[0]) \ + + (TestRelation_alpha1) * (col_1_1[0]) \ + + (TestRelation_alpha2) * (col_1_2[0]) \ + - (TestRelation_z); + +\ + let constraint_0 = \ + (((col_1_0[0]) * 
(col_1_1[0])) * (col_1_2[0])) * (1 / (col_1_0[0] + col_1_1[0])); + +\ + let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \ + - (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \ + - ((col_0_3[0]) * (total_sum))) \ + - (0)) \ + * (intermediate0) \ + - (1);" + .to_string(); + + assert_eq!(eval.format_constraints(), expected); } relation!(TestRelation, 3); diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index b17258d20..787394dd3 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -300,38 +300,34 @@ mod tests { ); let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true)); - - assert_eq!(eval.constraints.len(), 2); - let constraint0_str = "(1) \ - * ((SecureCol(\ - col_2_5[claimed_sum_offset], \ - col_2_8[claimed_sum_offset], \ - col_2_11[claimed_sum_offset], \ - col_2_14[claimed_sum_offset]\ - ) - (claimed_sum)) \ - * (col_0_2[0]))"; - assert_eq!(eval.constraints[0].format_expr(), constraint0_str); - let constraint1_str = "(1) \ - * ((SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \ - - (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \ - - ((col_0_2[0]) * (total_sum))) \ - - (0)) \ - * ((0 \ - + (StateMachineElements_alpha0) * (col_1_0[0]) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ - - (StateMachineElements_z)) \ - * (0 + (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ - - (StateMachineElements_z))) \ - - ((0 \ - + (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ - - (StateMachineElements_z)) \ - * (1) \ - + (0 + (StateMachineElements_alpha0) * (col_1_0[0]) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ - - (StateMachineElements_z)) \ - * (-(1))))"; - assert_eq!(eval.constraints[1].format_expr(), constraint1_str); + let expected = "let intermediate0 = 0 \ + + (StateMachineElements_alpha0) * (col_1_0[0]) \ + + (StateMachineElements_alpha1) * (col_1_1[0]) \ + - (StateMachineElements_z); +\ + let intermediate1 = 0 \ + + (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ + + (StateMachineElements_alpha1) * (col_1_1[0]) \ + - (StateMachineElements_z); + +\ + let constraint_0 = (SecureCol(\ + col_2_5[claimed_sum_offset], \ + col_2_8[claimed_sum_offset], \ + col_2_11[claimed_sum_offset], \ + col_2_14[claimed_sum_offset]\ + ) - (claimed_sum)) \ + * (col_0_2[0]); + +\ + let constraint_1 = (SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \ + - (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \ + - ((col_0_2[0]) * (total_sum))) \ + - (0)) \ + * ((intermediate0) * (intermediate1)) \ + - ((intermediate1) * (1) + (intermediate0) * (-(1)));" + .to_string(); + + assert_eq!(eval.format_constraints(), expected); } } From 9b2573809734c7ee067241bae46da174898a24de Mon Sep 17 00:00:00 2001 From: ilyalesokhin-starkware Date: Thu, 28 Nov 2024 09:36:16 +0200 Subject: [PATCH 02/69] Remove outdated comment. (#906) --- crates/prover/src/core/vcs/verifier.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/crates/prover/src/core/vcs/verifier.rs b/crates/prover/src/core/vcs/verifier.rs index 163fed2f1..9c1b0b39a 100644 --- a/crates/prover/src/core/vcs/verifier.rs +++ b/crates/prover/src/core/vcs/verifier.rs @@ -42,11 +42,6 @@ impl MerkleVerifier { /// * The column values are too short (missing values). /// * The computed root does not match the expected root. 
/// - /// # Panics - /// - /// This function will panic if the `values` vector is not sorted in descending order based on - /// the `log_size` of the columns. - /// /// # Returns /// /// Returns `Ok(())` if the decommitment is successfully verified. From 260bcb29c7097bfbdaa9c89faa58f68d5d936d4e Mon Sep 17 00:00:00 2001 From: ilyalesokhin-starkware Date: Thu, 28 Nov 2024 10:04:07 +0200 Subject: [PATCH 03/69] FRI: Enforce one column per size. (#905) --- crates/prover/src/core/fri.rs | 9 ++++++--- crates/prover/src/lib.rs | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/crates/prover/src/core/fri.rs b/crates/prover/src/core/fri.rs index 03dac47da..0f0a89b9b 100644 --- a/crates/prover/src/core/fri.rs +++ b/crates/prover/src/core/fri.rs @@ -136,7 +136,7 @@ pub struct FriProver<'a, B: FriOps + MerkleOps, MC: MerkleChannel> { impl<'a, B: FriOps + MerkleOps, MC: MerkleChannel> FriProver<'a, B, MC> { /// Commits to multiple circle polynomials. /// - /// `columns` must be provided in descending order by size. + /// `columns` must be provided in descending order by size with at most one column per size. /// /// This is a batched commitment that handles multiple mixed-degree polynomials, each /// evaluated over domains of varying sizes. Instead of combining these evaluations into @@ -147,7 +147,7 @@ impl<'a, B: FriOps + MerkleOps, MC: MerkleChannel> FriProver<'a, B, MC> { /// # Panics /// /// Panics if: - /// * `columns` is empty or not sorted in ascending order by domain size. + /// * `columns` is empty or not sorted in descending order by domain size. /// * An evaluation is not from a sufficiently low degree circle polynomial. /// * An evaluation's domain is smaller than the last layer. /// * An evaluation's domain is not a canonic circle domain. 
@@ -159,8 +159,11 @@ impl<'a, B: FriOps + MerkleOps, MC: MerkleChannel> FriProver<'a, B, MC> { twiddles: &TwiddleTree, ) -> Self { assert!(!columns.is_empty(), "no columns"); - assert!(columns.is_sorted_by_key(|e| Reverse(e.len())), "not sorted"); assert!(columns.iter().all(|e| e.domain.is_canonic()), "not canonic"); + assert!( + columns.array_windows().all(|[a, b]| a.len() > b.len()), + "column sizes not decreasing" + ); let first_layer = Self::commit_first_layer(channel, columns); let (inner_layers, last_layer_evaluation) = diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index 49adff0f3..34a5c6701 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -3,6 +3,7 @@ array_chunks, array_methods, array_try_from_fn, + array_windows, assert_matches, exact_size_is_empty, generic_const_exprs, From c2ef3ac1be9689fa78be9574e175df41082b7644 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Thu, 28 Nov 2024 14:35:59 +0100 Subject: [PATCH 04/69] clippy: apply missing_const_for_fn rule (#873) --- crates/prover/Cargo.toml | 3 +++ crates/prover/src/core/air/accumulation.rs | 2 +- crates/prover/src/core/backend/simd/cm31.rs | 4 ++-- crates/prover/src/core/backend/simd/column.rs | 2 +- crates/prover/src/core/backend/simd/fft/mod.rs | 2 +- crates/prover/src/core/backend/simd/m31.rs | 6 +++--- crates/prover/src/core/backend/simd/qm31.rs | 8 ++++---- crates/prover/src/core/backend/simd/utils.rs | 4 ++-- crates/prover/src/core/channel/blake2s.rs | 2 +- crates/prover/src/core/channel/poseidon252.rs | 2 +- crates/prover/src/core/circle.rs | 16 ++++++++-------- crates/prover/src/core/fields/cm31.rs | 2 +- crates/prover/src/core/fields/m31.rs | 8 ++++---- crates/prover/src/core/fields/qm31.rs | 6 +++--- crates/prover/src/core/fri.rs | 8 ++++---- crates/prover/src/core/lookups/utils.rs | 4 ++-- crates/prover/src/core/poly/circle/canonic.rs | 12 ++++++------ crates/prover/src/core/poly/circle/domain.rs | 6 +++--- crates/prover/src/core/poly/circle/poly.rs | 2 +- crates/prover/src/core/poly/line.rs | 10 +++++----- crates/prover/src/core/utils.rs | 8 ++++---- crates/prover/src/core/vcs/blake2s_ref.rs | 12 ++++++------ .../src/examples/xor/gkr_lookups/mle_eval.rs | 2 +- 23 files changed, 67 insertions(+), 64 deletions(-) diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 508c86e18..a9b80e9e4 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -56,6 +56,9 @@ nonstandard-style = "deny" rust-2018-idioms = "deny" unused = "deny" +[lints.clippy] +missing_const_for_fn = "warn" + [package.metadata.cargo-machete] ignored = ["downcast-rs"] diff --git a/crates/prover/src/core/air/accumulation.rs b/crates/prover/src/core/air/accumulation.rs index 2e7010ae6..f958f2029 100644 --- a/crates/prover/src/core/air/accumulation.rs +++ b/crates/prover/src/core/air/accumulation.rs @@ -39,7 +39,7 @@ impl PointEvaluationAccumulator { self.accumulation = self.accumulation * self.random_coeff + evaluation; } - pub fn finalize(self) -> SecureField { + pub const fn finalize(self) -> SecureField { self.accumulation } } diff --git a/crates/prover/src/core/backend/simd/cm31.rs b/crates/prover/src/core/backend/simd/cm31.rs index 31aba0a44..2155e8ff1 100644 --- a/crates/prover/src/core/backend/simd/cm31.rs +++ b/crates/prover/src/core/backend/simd/cm31.rs @@ -19,12 +19,12 @@ impl PackedCM31 { } /// Returns all `a` values such that each vector element is represented as `a + bi`. 
- pub fn a(&self) -> PackedM31 { + pub const fn a(&self) -> PackedM31 { self.0[0] } /// Returns all `b` values such that each vector element is represented as `a + bi`. - pub fn b(&self) -> PackedM31 { + pub const fn b(&self) -> PackedM31 { self.0[1] } diff --git a/crates/prover/src/core/backend/simd/column.rs b/crates/prover/src/core/backend/simd/column.rs index 07d7463ee..dd5578c0e 100644 --- a/crates/prover/src/core/backend/simd/column.rs +++ b/crates/prover/src/core/backend/simd/column.rs @@ -463,7 +463,7 @@ impl VeryPackedBaseColumn { /// # Safety /// /// The resulting pointer does not update the underlying `data`'s length. - pub unsafe fn transform_under_ref(value: &BaseColumn) -> &Self { + pub const unsafe fn transform_under_ref(value: &BaseColumn) -> &Self { &*(std::ptr::addr_of!(*value) as *const VeryPackedBaseColumn) } diff --git a/crates/prover/src/core/backend/simd/fft/mod.rs b/crates/prover/src/core/backend/simd/fft/mod.rs index ba091b145..b3ea4d700 100644 --- a/crates/prover/src/core/backend/simd/fft/mod.rs +++ b/crates/prover/src/core/backend/simd/fft/mod.rs @@ -97,7 +97,7 @@ pub fn compute_first_twiddles(twiddle1_dbl: u32x8) -> (u32x16, u32x16) { } #[inline] -unsafe fn load(mem_addr: *const u32) -> u32x16 { +const unsafe fn load(mem_addr: *const u32) -> u32x16 { std::ptr::read(mem_addr as *const u32x16) } diff --git a/crates/prover/src/core/backend/simd/m31.rs b/crates/prover/src/core/backend/simd/m31.rs index f6291626b..dbeec152f 100644 --- a/crates/prover/src/core/backend/simd/m31.rs +++ b/crates/prover/src/core/backend/simd/m31.rs @@ -78,14 +78,14 @@ impl PackedM31 { self + self } - pub fn into_simd(self) -> Simd { + pub const fn into_simd(self) -> Simd { self.0 } /// # Safety /// /// Vector elements must be in the range `[0, P]`. - pub unsafe fn from_simd_unchecked(v: Simd) -> Self { + pub const unsafe fn from_simd_unchecked(v: Simd) -> Self { Self(v) } @@ -93,7 +93,7 @@ impl PackedM31 { /// /// Behavior is undefined if the pointer does not have the same alignment as /// [`PackedM31`]. The loaded `u32` values must be in the range `[0, P]`. - pub unsafe fn load(mem_addr: *const u32) -> Self { + pub const unsafe fn load(mem_addr: *const u32) -> Self { Self(ptr::read(mem_addr as *const u32x16)) } diff --git a/crates/prover/src/core/backend/simd/qm31.rs b/crates/prover/src/core/backend/simd/qm31.rs index 078f6ef56..ce7231d0a 100644 --- a/crates/prover/src/core/backend/simd/qm31.rs +++ b/crates/prover/src/core/backend/simd/qm31.rs @@ -28,12 +28,12 @@ impl PackedQM31 { } /// Returns all `a` values such that each vector element is represented as `a + bu`. - pub fn a(&self) -> PackedCM31 { + pub const fn a(&self) -> PackedCM31 { self.0[0] } /// Returns all `b` values such that each vector element is represented as `a + bu`. - pub fn b(&self) -> PackedCM31 { + pub const fn b(&self) -> PackedCM31 { self.0[1] } @@ -80,14 +80,14 @@ impl PackedQM31 { /// Returns vectors `a, b, c, d` such that element `i` is represented as /// `QM31(a_i, b_i, c_i, d_i)`. - pub fn into_packed_m31s(self) -> [PackedM31; 4] { + pub const fn into_packed_m31s(self) -> [PackedM31; 4] { let Self([PackedCM31([a, b]), PackedCM31([c, d])]) = self; [a, b, c, d] } /// Creates an instance from vectors `a, b, c, d` such that element `i` /// is represented as `QM31(a_i, b_i, c_i, d_i)`. 
- pub fn from_packed_m31s([a, b, c, d]: [PackedM31; 4]) -> Self { + pub const fn from_packed_m31s([a, b, c, d]: [PackedM31; 4]) -> Self { Self([PackedCM31([a, b]), PackedCM31([c, d])]) } } diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs index b5cb9e986..d5f53a22b 100644 --- a/crates/prover/src/core/backend/simd/utils.rs +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -30,7 +30,7 @@ impl UnsafeMut { /// # Safety /// /// Returns a raw mutable pointer. - pub unsafe fn get(&self) -> *mut T { + pub const unsafe fn get(&self) -> *mut T { self.0 } } @@ -43,7 +43,7 @@ impl UnsafeConst { /// # Safety /// /// Returns a raw constant pointer. - pub unsafe fn get(&self) -> *const T { + pub const unsafe fn get(&self) -> *const T { self.0 } } diff --git a/crates/prover/src/core/channel/blake2s.rs b/crates/prover/src/core/channel/blake2s.rs index 86565658b..160d4754e 100644 --- a/crates/prover/src/core/channel/blake2s.rs +++ b/crates/prover/src/core/channel/blake2s.rs @@ -19,7 +19,7 @@ pub struct Blake2sChannel { } impl Blake2sChannel { - pub fn digest(&self) -> Blake2sHash { + pub const fn digest(&self) -> Blake2sHash { self.digest } pub fn update_digest(&mut self, new_digest: Blake2sHash) { diff --git a/crates/prover/src/core/channel/poseidon252.rs b/crates/prover/src/core/channel/poseidon252.rs index c0960fc3e..a02a82b1d 100644 --- a/crates/prover/src/core/channel/poseidon252.rs +++ b/crates/prover/src/core/channel/poseidon252.rs @@ -19,7 +19,7 @@ pub struct Poseidon252Channel { } impl Poseidon252Channel { - pub fn digest(&self) -> FieldElement252 { + pub const fn digest(&self) -> FieldElement252 { self.digest } pub fn update_digest(&mut self, new_digest: FieldElement252) { diff --git a/crates/prover/src/core/circle.rs b/crates/prover/src/core/circle.rs index 8cfe48ab8..9f7c99d82 100644 --- a/crates/prover/src/core/circle.rs +++ b/crates/prover/src/core/circle.rs @@ -223,15 +223,15 @@ pub const SECURE_FIELD_CIRCLE_ORDER: u128 = P4 - 1; pub struct CirclePointIndex(pub usize); impl CirclePointIndex { - pub fn zero() -> Self { + pub const fn zero() -> Self { Self(0) } - pub fn generator() -> Self { + pub const fn generator() -> Self { Self(1) } - pub fn reduce(self) -> Self { + pub const fn reduce(self) -> Self { Self(self.0 & ((1 << M31_CIRCLE_LOG_ORDER) - 1)) } @@ -343,16 +343,16 @@ impl Coset { } /// Returns the size of the coset. - pub fn size(&self) -> usize { + pub const fn size(&self) -> usize { 1 << self.log_size() } /// Returns the log size of the coset. 
- pub fn log_size(&self) -> u32 { + pub const fn log_size(&self) -> u32 { self.log_size } - pub fn iter(&self) -> CosetIterator> { + pub const fn iter(&self) -> CosetIterator> { CosetIterator { cur: self.initial, step: self.step, @@ -360,7 +360,7 @@ impl Coset { } } - pub fn iter_indices(&self) -> CosetIterator { + pub const fn iter_indices(&self) -> CosetIterator { CosetIterator { cur: self.initial_index, step: self.step_size, @@ -389,7 +389,7 @@ impl Coset { && *self == other.repeated_double(other.log_size - self.log_size) } - pub fn initial(&self) -> CirclePoint { + pub const fn initial(&self) -> CirclePoint { self.initial } diff --git a/crates/prover/src/core/fields/cm31.rs b/crates/prover/src/core/fields/cm31.rs index 6f1b6c2ef..e7f92dba7 100644 --- a/crates/prover/src/core/fields/cm31.rs +++ b/crates/prover/src/core/fields/cm31.rs @@ -24,7 +24,7 @@ impl CM31 { Self(M31::from_u32_unchecked(a), M31::from_u32_unchecked(b)) } - pub fn from_m31(a: M31, b: M31) -> CM31 { + pub const fn from_m31(a: M31, b: M31) -> CM31 { Self(a, b) } } diff --git a/crates/prover/src/core/fields/m31.rs b/crates/prover/src/core/fields/m31.rs index a7c3c57a2..7c28bf33a 100644 --- a/crates/prover/src/core/fields/m31.rs +++ b/crates/prover/src/core/fields/m31.rs @@ -55,7 +55,7 @@ impl M31 { /// let val = (P as u64).pow(2) - 19; /// assert_eq!(M31::reduce(val), M31::from(P - 19)); /// ``` - pub fn reduce(val: u64) -> Self { + pub const fn reduce(val: u64) -> Self { Self((((((val >> MODULUS_BITS) + val + 1) >> MODULUS_BITS) + val) & (P as u64)) as u32) } @@ -211,15 +211,15 @@ mod tests { use super::{M31, P}; use crate::core::fields::IntoSlice; - fn mul_p(a: u32, b: u32) -> u32 { + const fn mul_p(a: u32, b: u32) -> u32 { ((a as u64 * b as u64) % P as u64) as u32 } - fn add_p(a: u32, b: u32) -> u32 { + const fn add_p(a: u32, b: u32) -> u32 { (a + b) % P } - fn neg_p(a: u32) -> u32 { + const fn neg_p(a: u32) -> u32 { if a == 0 { 0 } else { diff --git a/crates/prover/src/core/fields/qm31.rs b/crates/prover/src/core/fields/qm31.rs index 6da19a3c0..41342ade6 100644 --- a/crates/prover/src/core/fields/qm31.rs +++ b/crates/prover/src/core/fields/qm31.rs @@ -32,15 +32,15 @@ impl QM31 { ) } - pub fn from_m31(a: M31, b: M31, c: M31, d: M31) -> Self { + pub const fn from_m31(a: M31, b: M31, c: M31, d: M31) -> Self { Self(CM31::from_m31(a, b), CM31::from_m31(c, d)) } - pub fn from_m31_array(array: [M31; SECURE_EXTENSION_DEGREE]) -> Self { + pub const fn from_m31_array(array: [M31; SECURE_EXTENSION_DEGREE]) -> Self { Self::from_m31(array[0], array[1], array[2], array[3]) } - pub fn to_m31_array(self) -> [M31; SECURE_EXTENSION_DEGREE] { + pub const fn to_m31_array(self) -> [M31; SECURE_EXTENSION_DEGREE] { [self.0 .0, self.0 .1, self.1 .0, self.1 .1] } diff --git a/crates/prover/src/core/fri.rs b/crates/prover/src/core/fri.rs index 0f0a89b9b..46eabeedb 100644 --- a/crates/prover/src/core/fri.rs +++ b/crates/prover/src/core/fri.rs @@ -70,7 +70,7 @@ impl FriConfig { } } - fn last_layer_domain_size(&self) -> usize { + const fn last_layer_domain_size(&self) -> usize { 1 << (self.log_last_layer_degree_bound + self.log_blowup_factor) } } @@ -597,13 +597,13 @@ pub struct CirclePolyDegreeBound { } impl CirclePolyDegreeBound { - pub fn new(log_degree_bound: u32) -> Self { + pub const fn new(log_degree_bound: u32) -> Self { Self { log_degree_bound } } /// Maps a circle polynomial's degree bound to the degree bound of the univariate (line) /// polynomial it gets folded into. 
- fn fold_to_line(&self) -> LinePolyDegreeBound { + const fn fold_to_line(&self) -> LinePolyDegreeBound { LinePolyDegreeBound { log_degree_bound: self.log_degree_bound - CIRCLE_TO_LINE_FOLD_STEP, } @@ -629,7 +629,7 @@ struct LinePolyDegreeBound { impl LinePolyDegreeBound { /// Returns [None] if the unfolded degree bound is smaller than the folding factor. - fn fold(self, n_folds: u32) -> Option { + const fn fold(self, n_folds: u32) -> Option { if self.log_degree_bound < n_folds { return None; } diff --git a/crates/prover/src/core/lookups/utils.rs b/crates/prover/src/core/lookups/utils.rs index ed67477f7..d66bd93f0 100644 --- a/crates/prover/src/core/lookups/utils.rs +++ b/crates/prover/src/core/lookups/utils.rs @@ -202,7 +202,7 @@ pub struct Fraction { } impl Fraction { - pub fn new(numerator: N, denominator: D) -> Self { + pub const fn new(numerator: N, denominator: D) -> Self { Self { numerator, denominator, @@ -256,7 +256,7 @@ pub struct Reciprocal { } impl Reciprocal { - pub fn new(x: T) -> Self { + pub const fn new(x: T) -> Self { Self { x } } } diff --git a/crates/prover/src/core/poly/circle/canonic.rs b/crates/prover/src/core/poly/circle/canonic.rs index 837e648d9..cda0fcc8c 100644 --- a/crates/prover/src/core/poly/circle/canonic.rs +++ b/crates/prover/src/core/poly/circle/canonic.rs @@ -31,7 +31,7 @@ impl CanonicCoset { } /// Gets the full coset represented G_{2n} + . - pub fn coset(&self) -> Coset { + pub const fn coset(&self) -> Coset { self.coset } @@ -46,24 +46,24 @@ impl CanonicCoset { } /// Returns the log size of the coset. - pub fn log_size(&self) -> u32 { + pub const fn log_size(&self) -> u32 { self.coset.log_size } /// Returns the size of the coset. - pub fn size(&self) -> usize { + pub const fn size(&self) -> usize { self.coset.size() } - pub fn initial_index(&self) -> CirclePointIndex { + pub const fn initial_index(&self) -> CirclePointIndex { self.coset.initial_index } - pub fn step_size(&self) -> CirclePointIndex { + pub const fn step_size(&self) -> CirclePointIndex { self.coset.step_size } - pub fn step(&self) -> CirclePoint { + pub const fn step(&self) -> CirclePoint { self.coset.step } diff --git a/crates/prover/src/core/poly/circle/domain.rs b/crates/prover/src/core/poly/circle/domain.rs index fba2bc3fb..2bffac773 100644 --- a/crates/prover/src/core/poly/circle/domain.rs +++ b/crates/prover/src/core/poly/circle/domain.rs @@ -20,7 +20,7 @@ pub struct CircleDomain { impl CircleDomain { /// Given a coset C + , constructs the circle domain +-C + (i.e., /// this coset and its conjugate). - pub fn new(half_coset: Coset) -> Self { + pub const fn new(half_coset: Coset) -> Self { Self { half_coset } } @@ -38,12 +38,12 @@ impl CircleDomain { } /// Returns the size of the domain. - pub fn size(&self) -> usize { + pub const fn size(&self) -> usize { 1 << self.log_size() } /// Returns the log size of the domain. 
- pub fn log_size(&self) -> u32 { + pub const fn log_size(&self) -> u32 { self.half_coset.log_size + 1 } diff --git a/crates/prover/src/core/poly/circle/poly.rs b/crates/prover/src/core/poly/circle/poly.rs index c10fc5e7a..6744a0c80 100644 --- a/crates/prover/src/core/poly/circle/poly.rs +++ b/crates/prover/src/core/poly/circle/poly.rs @@ -34,7 +34,7 @@ impl CirclePoly { Self { log_size, coeffs } } - pub fn log_size(&self) -> u32 { + pub const fn log_size(&self) -> u32 { self.log_size } diff --git a/crates/prover/src/core/poly/line.rs b/crates/prover/src/core/poly/line.rs index 2bf640c63..2b58e01fb 100644 --- a/crates/prover/src/core/poly/line.rs +++ b/crates/prover/src/core/poly/line.rs @@ -58,12 +58,12 @@ impl LineDomain { } /// Returns the size of the domain. - pub fn size(&self) -> usize { + pub const fn size(&self) -> usize { self.coset.size() } /// Returns the log size of the domain. - pub fn log_size(&self) -> u32 { + pub const fn log_size(&self) -> u32 { self.coset.log_size() } @@ -80,7 +80,7 @@ impl LineDomain { } /// Returns the domain's underlying coset. - pub fn coset(&self) -> Coset { + pub const fn coset(&self) -> Coset { self.coset } } @@ -209,11 +209,11 @@ impl> LineEvaluation { /// Returns the number of evaluations. #[allow(clippy::len_without_is_empty)] - pub fn len(&self) -> usize { + pub const fn len(&self) -> usize { 1 << self.domain.log_size() } - pub fn domain(&self) -> LineDomain { + pub const fn domain(&self) -> LineDomain { self.domain } diff --git a/crates/prover/src/core/utils.rs b/crates/prover/src/core/utils.rs index 05842fb2d..330e49ed2 100644 --- a/crates/prover/src/core/utils.rs +++ b/crates/prover/src/core/utils.rs @@ -54,7 +54,7 @@ impl<'a, I: Iterator> PeekableExt<'a, I> for Peekable { } /// Returns the bit reversed index of `i` which is represented by `log_size` bits. -pub fn bit_reverse_index(i: usize, log_size: u32) -> usize { +pub const fn bit_reverse_index(i: usize, log_size: u32) -> usize { if log_size == 0 { return i; } @@ -64,7 +64,7 @@ pub fn bit_reverse_index(i: usize, log_size: u32) -> usize { /// Returns the index of the previous element in a bit reversed /// [super::poly::circle::CircleEvaluation] of log size `eval_log_size` relative to a smaller domain /// of size `domain_log_size`. -pub fn previous_bit_reversed_circle_domain_index( +pub const fn previous_bit_reversed_circle_domain_index( i: usize, domain_log_size: u32, eval_log_size: u32, @@ -75,7 +75,7 @@ pub fn previous_bit_reversed_circle_domain_index( /// Returns the index of the offset element in a bit reversed /// [super::poly::circle::CircleEvaluation] of log size `eval_log_size` relative to a smaller domain /// of size `domain_log_size`. 
-pub fn offset_bit_reversed_circle_domain_index( +pub const fn offset_bit_reversed_circle_domain_index( i: usize, domain_log_size: u32, eval_log_size: u32, @@ -122,7 +122,7 @@ pub(crate) fn coset_order_to_circle_domain_order(values: &[F]) -> Vec< /// /// [`CircleDomain`]: crate::core::poly::circle::CircleDomain /// [`Coset`]: crate::core::circle::Coset -pub fn coset_index_to_circle_domain_index(coset_index: usize, log_domain_size: u32) -> usize { +pub const fn coset_index_to_circle_domain_index(coset_index: usize, log_domain_size: u32) -> usize { if coset_index % 2 == 0 { coset_index / 2 } else { diff --git a/crates/prover/src/core/vcs/blake2s_ref.rs b/crates/prover/src/core/vcs/blake2s_ref.rs index 9b982bba6..95665597c 100644 --- a/crates/prover/src/core/vcs/blake2s_ref.rs +++ b/crates/prover/src/core/vcs/blake2s_ref.rs @@ -19,32 +19,32 @@ pub const SIGMA: [[u8; 16]; 10] = [ ]; #[inline(always)] -fn add(a: u32, b: u32) -> u32 { +const fn add(a: u32, b: u32) -> u32 { a.wrapping_add(b) } #[inline(always)] -fn xor(a: u32, b: u32) -> u32 { +const fn xor(a: u32, b: u32) -> u32 { a ^ b } #[inline(always)] -fn rot16(x: u32) -> u32 { +const fn rot16(x: u32) -> u32 { (x >> 16) | (x << (32 - 16)) } #[inline(always)] -fn rot12(x: u32) -> u32 { +const fn rot12(x: u32) -> u32 { (x >> 12) | (x << (32 - 12)) } #[inline(always)] -fn rot8(x: u32) -> u32 { +const fn rot8(x: u32) -> u32 { (x >> 8) | (x << (32 - 8)) } #[inline(always)] -fn rot7(x: u32) -> u32 { +const fn rot7(x: u32) -> u32 { (x >> 7) | (x << (32 - 7)) } diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 69acedccb..cc7ebdaec 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -1213,7 +1213,7 @@ mod tests { } impl MleCoeffColumnEval { - pub fn new(interaction: usize, n_variables: usize) -> Self { + pub const fn new(interaction: usize, n_variables: usize) -> Self { Self { interaction, n_variables, From a0cce5c12ff31c214b1fd2685fa1ef8df211fc9c Mon Sep 17 00:00:00 2001 From: Shahar Samocha Date: Thu, 28 Nov 2024 16:49:29 +0200 Subject: [PATCH 05/69] Fix clippy const function error --- crates/prover/src/constraint_framework/component.rs | 2 +- crates/prover/src/constraint_framework/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 078e84bb8..23981cc06 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -92,7 +92,7 @@ impl TraceLocationAllocator { } } - pub fn preprocessed_columns(&self) -> &HashMap { + pub const fn preprocessed_columns(&self) -> &HashMap { &self.preprocessed_columns } diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 9bbf05402..0870e149d 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -234,7 +234,7 @@ pub struct RelationEntry<'a, F: Clone, EF: RelationEFTraitBound, R: Relation< values: &'a [F], } impl<'a, F: Clone, EF: RelationEFTraitBound, R: Relation> RelationEntry<'a, F, EF, R> { - pub fn new(relation: &'a R, multiplicity: EF, values: &'a [F]) -> Self { + pub const fn new(relation: &'a R, multiplicity: EF, values: &'a [F]) -> Self { Self { relation, multiplicity, From b6ea51249e664527358d819b14f99cc32e913d0c Mon Sep 17 00:00:00 2001 
From: Alon-Ti <54235977+Alon-Ti@users.noreply.github.com> Date: Thu, 28 Nov 2024 17:24:43 +0200 Subject: [PATCH 06/69] Simplify expressions. (#894) --- .../prover/src/constraint_framework/expr.rs | 100 ++++++++++++++++-- .../prover/src/examples/state_machine/mod.rs | 12 +-- 2 files changed, 97 insertions(+), 15 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 5d8013402..304c9f0e6 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -3,7 +3,7 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; use num_traits::{One, Zero}; use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; -use crate::core::fields::m31::{self, BaseField}; +use crate::core::fields::m31::{self, BaseField, M31}; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; use crate::core::lookups::utils::Fraction; @@ -36,7 +36,6 @@ pub enum Expr { } impl Expr { - #[allow(dead_code)] pub fn format_expr(&self) -> String { match self { Expr::Col(ColumnExpr { @@ -67,6 +66,10 @@ impl Expr { Expr::Inv(a) => format!("1 / ({})", a.format_expr()), } } + + pub fn simplify_and_format(&self) -> String { + simplify(self.clone()).format_expr() + } } impl From for Expr { @@ -190,6 +193,88 @@ impl AddAssign for Expr { } } +const ZERO: M31 = M31(0); +const ONE: M31 = M31(1); +const MINUS_ONE: M31 = M31(m31::P - 1); + +// TODO(alont) Add random point assignment test. +pub fn simplify(expr: Expr) -> Expr { + match expr { + Expr::Add(a, b) => { + let a = simplify(*a); + let b = simplify(*b); + match (a.clone(), b.clone()) { + (Expr::Const(a), Expr::Const(b)) => Expr::Const(a + b), + (Expr::Const(ZERO), _) => b, // 0 + b = b + (_, Expr::Const(ZERO)) => a, // a + 0 = a + // (-a + -b) = -(a + b) + (Expr::Neg(minus_a), Expr::Neg(minus_b)) => -(*minus_a + *minus_b), + (Expr::Neg(minus_a), _) => b - *minus_a, // -a + b = b - a + (_, Expr::Neg(minus_b)) => a - *minus_b, // a + -b = a - b + _ => Expr::Add(Box::new(a), Box::new(b)), + } + } + Expr::Sub(a, b) => { + let a = simplify(*a); + let b = simplify(*b); + match (a.clone(), b.clone()) { + (Expr::Const(a), Expr::Const(b)) => Expr::Const(a - b), + (Expr::Const(ZERO), _) => -b, // 0 - b = -b + (_, Expr::Const(ZERO)) => a, // a - 0 = a + // (-a - -b) = b - a + (Expr::Neg(minus_a), Expr::Neg(minus_b)) => *minus_b - *minus_a, + (Expr::Neg(minus_a), _) => -(*minus_a + b), // -a - b = -(a + b) + (_, Expr::Neg(minus_b)) => a + *minus_b, // a + -b = a - b + _ => Expr::Sub(Box::new(a), Box::new(b)), + } + } + Expr::Mul(a, b) => { + let a = simplify(*a); + let b = simplify(*b); + match (a.clone(), b.clone()) { + (Expr::Const(a), Expr::Const(b)) => Expr::Const(a * b), + (Expr::Const(ZERO), _) => Expr::zero(), // 0 * b = 0 + (_, Expr::Const(ZERO)) => Expr::zero(), // a * 0 = 0 + (Expr::Const(ONE), _) => b, // 1 * b = b + (_, Expr::Const(ONE)) => a, // a * 1 = a + // (-a) * (-b) = a * b + (Expr::Neg(minus_a), Expr::Neg(minus_b)) => *minus_a * *minus_b, + (Expr::Neg(minus_a), _) => -(*minus_a * b), // (-a) * b = -(a * b) + (_, Expr::Neg(minus_b)) => -(a * *minus_b), // a * (-b) = -(a * b) + (Expr::Const(MINUS_ONE), _) => -b, // -1 * b = -b + (_, Expr::Const(MINUS_ONE)) => -a, // a * -1 = -a + _ => Expr::Mul(Box::new(a), Box::new(b)), + } + } + Expr::Col(colexpr) => Expr::Col(colexpr), + Expr::SecureCol([a, b, c, d]) => Expr::SecureCol([ + Box::new(simplify(*a)), + Box::new(simplify(*b)), + Box::new(simplify(*c)), + 
Box::new(simplify(*d)), + ]), + Expr::Const(c) => Expr::Const(c), + Expr::Param(x) => Expr::Param(x), + Expr::Neg(a) => { + let a = simplify(*a); + match a { + Expr::Const(c) => Expr::Const(-c), + Expr::Neg(minus_a) => *minus_a, // -(-a) = a + Expr::Sub(a, b) => Expr::Sub(b, a), // -(a - b) = b - a + _ => Expr::Neg(Box::new(a)), + } + } + Expr::Inv(a) => { + let a = simplify(*a); + match a { + Expr::Inv(inv_a) => *inv_a, // 1 / (1 / a) = a + Expr::Const(c) => Expr::Const(c.inverse()), + _ => Expr::Inv(Box::new(a)), + } + } + } +} + /// Returns the expression /// `value[0] * _alpha0 + value[1] * _alpha1 + ... - _z.` fn combine_formal>(relation: &R, values: &[Expr]) -> Expr { @@ -273,7 +358,7 @@ impl ExprEvaluator { let lets_string = self .intermediates .iter() - .map(|(name, expr)| format!("let {} = {};", name, expr.format_expr())) + .map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format())) .collect::>() .join("\n"); @@ -281,7 +366,7 @@ impl ExprEvaluator { .constraints .iter() .enumerate() - .map(|(i, c)| format!("let constraint_{i} = ") + &c.format_expr() + ";") + .map(|(i, c)| format!("let constraint_{i} = ") + &c.simplify_and_format() + ";") .collect::>() .join("\n\n"); @@ -369,8 +454,7 @@ mod tests { fn test_format_expr() { let test_struct = TestStruct {}; let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - let expected = "let intermediate0 = 0 \ - + (TestRelation_alpha0) * (col_1_0[0]) \ + let expected = "let intermediate0 = (TestRelation_alpha0) * (col_1_0[0]) \ + (TestRelation_alpha1) * (col_1_1[0]) \ + (TestRelation_alpha2) * (col_1_2[0]) \ - (TestRelation_z); @@ -382,8 +466,8 @@ mod tests { \ let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \ - (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \ - - ((col_0_3[0]) * (total_sum))) \ - - (0)) \ + - ((col_0_3[0]) * (total_sum)))\ + ) \ * (intermediate0) \ - (1);" .to_string(); diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 787394dd3..2cf8bc2f2 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -300,13 +300,11 @@ mod tests { ); let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true)); - let expected = "let intermediate0 = 0 \ - + (StateMachineElements_alpha0) * (col_1_0[0]) \ + let expected = "let intermediate0 = (StateMachineElements_alpha0) * (col_1_0[0]) \ + (StateMachineElements_alpha1) * (col_1_1[0]) \ - (StateMachineElements_z); \ - let intermediate1 = 0 \ - + (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ + let intermediate1 = (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ + (StateMachineElements_alpha1) * (col_1_1[0]) \ - (StateMachineElements_z); @@ -322,10 +320,10 @@ mod tests { \ let constraint_1 = (SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \ - (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \ - - ((col_0_2[0]) * (total_sum))) \ - - (0)) \ + - ((col_0_2[0]) * (total_sum)))\ + ) \ * ((intermediate0) * (intermediate1)) \ - - ((intermediate1) * (1) + (intermediate0) * (-(1)));" + - (intermediate1 - (intermediate0));" .to_string(); assert_eq!(eval.format_constraints(), expected); From e0d930ed20f996b52ab246ac54cbf89ddbb3a352 Mon Sep 17 00:00:00 2001 From: Alon-Ti <54235977+Alon-Ti@users.noreply.github.com> Date: Thu, 28 Nov 2024 17:29:31 +0200 Subject: [PATCH 07/69] Added `From` in `add_constraint`. 
(#901) --- crates/prover/src/constraint_framework/assert.rs | 12 ++++++++---- crates/prover/src/constraint_framework/cpu_domain.rs | 2 +- crates/prover/src/constraint_framework/expr.rs | 12 ++---------- crates/prover/src/constraint_framework/mod.rs | 2 +- .../prover/src/constraint_framework/simd_domain.rs | 2 +- 5 files changed, 13 insertions(+), 17 deletions(-) diff --git a/crates/prover/src/constraint_framework/assert.rs b/crates/prover/src/constraint_framework/assert.rs index ce37a0cb3..34ab6fdec 100644 --- a/crates/prover/src/constraint_framework/assert.rs +++ b/crates/prover/src/constraint_framework/assert.rs @@ -1,4 +1,4 @@ -use num_traits::{One, Zero}; +use num_traits::Zero; use super::logup::{LogupAtRow, LogupSums}; use super::{EvalAtRow, INTERACTION_TRACE_IDX}; @@ -54,13 +54,17 @@ impl<'a> EvalAtRow for AssertEvaluator<'a> { fn add_constraint(&mut self, constraint: G) where - Self::EF: std::ops::Mul, + Self::EF: std::ops::Mul + From, { // Cast to SecureField. - let res = SecureField::one() * constraint; // The constraint should be zero at the given row, since we are evaluating on the trace // domain. - assert_eq!(res, SecureField::zero(), "row: {}", self.row); + assert_eq!( + Self::EF::from(constraint), + SecureField::zero(), + "row: {}", + self.row + ); } fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { diff --git a/crates/prover/src/constraint_framework/cpu_domain.rs b/crates/prover/src/constraint_framework/cpu_domain.rs index 72d285aeb..8c0f4beb9 100644 --- a/crates/prover/src/constraint_framework/cpu_domain.rs +++ b/crates/prover/src/constraint_framework/cpu_domain.rs @@ -85,7 +85,7 @@ impl<'a> EvalAtRow for CpuDomainEvaluator<'a> { fn add_constraint(&mut self, constraint: G) where - Self::EF: Mul, + Self::EF: Mul + From, { self.row_res += self.random_coeff_powers[self.constraint_index] * constraint; self.constraint_index += 1; diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 304c9f0e6..cb1f860aa 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -397,17 +397,9 @@ impl EvalAtRow for ExprEvaluator { fn add_constraint(&mut self, constraint: G) where - Self::EF: std::ops::Mul, + Self::EF: From, { - match Expr::one() * constraint { - Expr::Mul(one, constraint) => { - assert_eq!(*one, Expr::one()); - self.constraints.push(*constraint); - } - _ => { - unreachable!(); - } - } + self.constraints.push(constraint.into()); } fn combine_ef(values: [Self::F; 4]) -> Self::EF { diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 0870e149d..044f05ac8 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -108,7 +108,7 @@ pub trait EvalAtRow { /// Adds a constraint to the component. fn add_constraint(&mut self, constraint: G) where - Self::EF: Mul; + Self::EF: Mul + From; /// Combines 4 base field values into a single extension field value. 
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF; diff --git a/crates/prover/src/constraint_framework/simd_domain.rs b/crates/prover/src/constraint_framework/simd_domain.rs index b18e8e265..c85942228 100644 --- a/crates/prover/src/constraint_framework/simd_domain.rs +++ b/crates/prover/src/constraint_framework/simd_domain.rs @@ -98,7 +98,7 @@ impl<'a> EvalAtRow for SimdDomainEvaluator<'a> { } fn add_constraint(&mut self, constraint: G) where - Self::EF: Mul, + Self::EF: Mul + From, { self.row_res += VeryPackedSecureField::broadcast(self.random_coeff_powers[self.constraint_index]) From d84b24a6e6af5aa9612099359adc2e658f0629dd Mon Sep 17 00:00:00 2001 From: Alon-Ti <54235977+Alon-Ti@users.noreply.github.com> Date: Sun, 1 Dec 2024 15:48:49 +0200 Subject: [PATCH 08/69] Separated Base and Extension field expressions. (#903) --- .../prover/src/constraint_framework/expr.rs | 579 +++++++++++++----- .../prover/src/examples/state_machine/mod.rs | 2 +- 2 files changed, 410 insertions(+), 171 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index cb1f860aa..ab32829cf 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -3,8 +3,9 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; use num_traits::{One, Zero}; use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; -use crate::core::fields::m31::{self, BaseField, M31}; -use crate::core::fields::qm31::SecureField; +use crate::core::fields::cm31::CM31; +use crate::core::fields::m31::{self, BaseField}; +use crate::core::fields::qm31::{SecureField, QM31}; use crate::core::fields::FieldExpOps; use crate::core::lookups::utils::Fraction; @@ -16,122 +17,373 @@ pub struct ColumnExpr { offset: isize, } +/// An expression representing a base field value. Can be either: +/// * A column indexed by a `ColumnExpr`. +/// * A base field constant. +/// * A formal parameter to the AIR. +/// * A sum, difference, or product of two base field expressions. +/// * A negation or inverse of a base field expression. +/// +/// This type is meant to be used as an F associated type for EvalAtRow and interacts with +/// `ExtExpr`, `BaseField` and `SecureField` as expected. #[derive(Clone, Debug, PartialEq)] -pub enum Expr { +pub enum BaseExpr { Col(ColumnExpr), + Const(BaseField), + /// Formal parameter to the AIR, for example the interaction elements of a relation. + Param(String), + Add(Box, Box), + Sub(Box, Box), + Mul(Box, Box), + Neg(Box), + Inv(Box), +} + +/// An expression representing a secure field value. Can be either: +/// * A secure column constructed from 4 base field expressions. +/// * A secure field constant. +/// * A formal parameter to the AIR. +/// * A sum, difference, or product of two secure field expressions. +/// * A negation of a secure field expression. +/// +/// This type is meant to be used as an EF associated type for EvalAtRow and interacts with +/// `BaseExpr`, `BaseField` and `SecureField` as expected. +#[derive(Clone, Debug, PartialEq)] +pub enum ExtExpr { /// An atomic secure column constructed from 4 expressions. 
/// Expressions on the secure column are not reduced, i.e, /// if `a = SecureCol(a0, a1, a2, a3)`, `b = SecureCol(b0, b1, b2, b3)` then /// `a + b` evaluates to `Add(a, b)` rather than /// `SecureCol(Add(a0, b0), Add(a1, b1), Add(a2, b2), Add(a3, b3))` - SecureCol([Box; 4]), - Const(BaseField), + SecureCol([Box; 4]), + Const(SecureField), /// Formal parameter to the AIR, for example the interaction elements of a relation. Param(String), - Add(Box, Box), - Sub(Box, Box), - Mul(Box, Box), - Neg(Box), - Inv(Box), + Add(Box, Box), + Sub(Box, Box), + Mul(Box, Box), + Neg(Box), +} + +/// Applies simplifications to arithmetic expressions that can be used both for `BaseExpr` and for +/// `ExtExpr`. +macro_rules! simplify_arithmetic { + ($self:tt) => { + match $self.clone() { + Self::Add(a, b) => { + let a = a.simplify(); + let b = b.simplify(); + match (a.clone(), b.clone()) { + // Simplify constants. + (Self::Const(a), Self::Const(b)) => Self::Const(a + b), + (Self::Const(a_val), _) if a_val.is_zero() => b, // 0 + b = b + (_, Self::Const(b_val)) if b_val.is_zero() => a, // a + 0 = a + // Simplify Negs. + // (-a + -b) = -(a + b) + (Self::Neg(minus_a), Self::Neg(minus_b)) => -(*minus_a + *minus_b), + (Self::Neg(minus_a), _) => b - *minus_a, // -a + b = b - a + (_, Self::Neg(minus_b)) => a - *minus_b, // a + -b = a - b + // No simplification. + _ => a + b, + } + } + Self::Sub(a, b) => { + let a = a.simplify(); + let b = b.simplify(); + match (a.clone(), b.clone()) { + // Simplify constants. + (Self::Const(a), Self::Const(b)) => Self::Const(a - b), // Simplify consts. + (Self::Const(a_val), _) if a_val.is_zero() => -b, // 0 - b = -b + (_, Self::Const(b_val)) if b_val.is_zero() => a, // a - 0 = a + // Simplify Negs. + // (-a - -b) = b - a + (Self::Neg(minus_a), Self::Neg(minus_b)) => *minus_b - *minus_a, + (Self::Neg(minus_a), _) => -(*minus_a + b), // -a - b = -(a + b) + (_, Self::Neg(minus_b)) => a + *minus_b, // a + -b = a - b + // No Simplification. + _ => a - b, + } + } + Self::Mul(a, b) => { + let a = a.simplify(); + let b = b.simplify(); + match (a.clone(), b.clone()) { + // Simplify consts. + (Self::Const(a), Self::Const(b)) => Self::Const(a * b), + (Self::Const(a_val), _) if a_val.is_zero() => Self::zero(), // 0 * b = 0 + (_, Self::Const(b_val)) if b_val.is_zero() => Self::zero(), // a * 0 = 0 + (Self::Const(a_val), _) if a_val == One::one() => b, // 1 * b = b + (_, Self::Const(b_val)) if b_val == One::one() => a, // a * 1 = a + (Self::Const(a_val), _) if -a_val == One::one() => -b, // -1 * b = -b + (_, Self::Const(b_val)) if -b_val == One::one() => -a, // a * -1 = -a + // Simplify Negs. + // (-a) * (-b) = a * b + (Self::Neg(minus_a), Self::Neg(minus_b)) => *minus_a * *minus_b, + (Self::Neg(minus_a), _) => -(*minus_a * b), // (-a) * b = -(a * b) + (_, Self::Neg(minus_b)) => -(a * *minus_b), // a * (-b) = -(a * b) + // No simplification. + _ => a * b, + } + } + Self::Neg(a) => { + let a = a.simplify(); + match a { + Self::Const(c) => Self::Const(-c), + Self::Neg(minus_a) => *minus_a, // -(-a) = a + Self::Sub(a, b) => Self::Sub(b, a), // -(a - b) = b - a + _ => -a, // No simplification. + } + } + other => other, // No simplification. 
+ } + }; } -impl Expr { +impl BaseExpr { pub fn format_expr(&self) -> String { match self { - Expr::Col(ColumnExpr { + BaseExpr::Col(ColumnExpr { interaction, idx, offset, }) => { - let offset_str = if *offset == CLAIMED_SUM_DUMMY_OFFSET.try_into().unwrap() { + let offset_str = if *offset == CLAIMED_SUM_DUMMY_OFFSET as isize { "claimed_sum_offset".to_string() } else { offset.to_string() }; format!("col_{interaction}_{idx}[{offset_str}]") } - Expr::SecureCol([a, b, c, d]) => format!( - "SecureCol({}, {}, {}, {})", - a.format_expr(), - b.format_expr(), - c.format_expr(), - d.format_expr() - ), - Expr::Const(c) => c.0.to_string(), - Expr::Param(v) => v.to_string(), - Expr::Add(a, b) => format!("{} + {}", a.format_expr(), b.format_expr()), - Expr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()), - Expr::Mul(a, b) => format!("({}) * ({})", a.format_expr(), b.format_expr()), - Expr::Neg(a) => format!("-({})", a.format_expr()), - Expr::Inv(a) => format!("1 / ({})", a.format_expr()), + BaseExpr::Const(c) => c.to_string(), + BaseExpr::Param(v) => v.to_string(), + BaseExpr::Add(a, b) => format!("{} + {}", a.format_expr(), b.format_expr()), + BaseExpr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()), + BaseExpr::Mul(a, b) => format!("({}) * ({})", a.format_expr(), b.format_expr()), + BaseExpr::Neg(a) => format!("-({})", a.format_expr()), + BaseExpr::Inv(a) => format!("1 / ({})", a.format_expr()), + } + } + + pub fn simplify(&self) -> Self { + let simple = simplify_arithmetic!(self); + match simple { + Self::Inv(a) => { + let a = a.simplify(); + match a { + Self::Inv(inv_a) => *inv_a, // 1 / (1 / a) = a + Self::Const(c) => Self::Const(c.inverse()), + _ => Self::Inv(Box::new(a)), + } + } + other => other, + } + } + + pub fn simplify_and_format(&self) -> String { + self.simplify().format_expr() + } +} + +impl ExtExpr { + pub fn format_expr(&self) -> String { + match self { + ExtExpr::SecureCol([a, b, c, d]) => { + // If the expression's non-base components are all constant zeroes, return the base + // field representation of its first part. + if **b == BaseExpr::zero() && **c == BaseExpr::zero() && **d == BaseExpr::zero() { + a.format_expr() + } else { + format!( + "SecureCol({}, {}, {}, {})", + a.format_expr(), + b.format_expr(), + c.format_expr(), + d.format_expr() + ) + } + } + ExtExpr::Const(c) => { + if c.0 .1.is_zero() && c.1 .0.is_zero() && c.1 .1.is_zero() { + // If the constant is in the base field, display it as such. 
+ c.0 .0.to_string() + } else { + c.to_string() + } + } + ExtExpr::Param(v) => v.to_string(), + ExtExpr::Add(a, b) => format!("{} + {}", a.format_expr(), b.format_expr()), + ExtExpr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()), + ExtExpr::Mul(a, b) => format!("({}) * ({})", a.format_expr(), b.format_expr()), + ExtExpr::Neg(a) => format!("-({})", a.format_expr()), + } + } + + pub fn simplify(&self) -> Self { + let simple = simplify_arithmetic!(self); + match simple { + Self::SecureCol([a, b, c, d]) => { + let a = a.simplify(); + let b = b.simplify(); + let c = c.simplify(); + let d = d.simplify(); + match (a.clone(), b.clone(), c.clone(), d.clone()) { + ( + BaseExpr::Const(a_val), + BaseExpr::Const(b_val), + BaseExpr::Const(c_val), + BaseExpr::Const(d_val), + ) => ExtExpr::Const(SecureField::from_m31_array([a_val, b_val, c_val, d_val])), + _ => Self::SecureCol([Box::new(a), Box::new(b), Box::new(c), Box::new(d)]), + } + } + other => other, } } pub fn simplify_and_format(&self) -> String { - simplify(self.clone()).format_expr() + self.simplify().format_expr() } } -impl From for Expr { +impl From for BaseExpr { fn from(val: BaseField) -> Self { - Expr::Const(val) + BaseExpr::Const(val) } } -impl From for Expr { - fn from(val: SecureField) -> Self { - Expr::SecureCol([ - Box::new(val.0 .0.into()), - Box::new(val.0 .1.into()), - Box::new(val.1 .0.into()), - Box::new(val.1 .1.into()), +impl From for ExtExpr { + fn from(val: BaseField) -> Self { + ExtExpr::SecureCol([ + Box::new(BaseExpr::from(val)), + Box::new(BaseExpr::zero()), + Box::new(BaseExpr::zero()), + Box::new(BaseExpr::zero()), + ]) + } +} + +impl From for ExtExpr { + fn from(QM31(CM31(a, b), CM31(c, d)): SecureField) -> Self { + ExtExpr::SecureCol([ + Box::new(BaseExpr::from(a)), + Box::new(BaseExpr::from(b)), + Box::new(BaseExpr::from(c)), + Box::new(BaseExpr::from(d)), + ]) + } +} + +impl From for ExtExpr { + fn from(expr: BaseExpr) -> Self { + ExtExpr::SecureCol([ + Box::new(expr.clone()), + Box::new(BaseExpr::zero()), + Box::new(BaseExpr::zero()), + Box::new(BaseExpr::zero()), ]) } } -impl Add for Expr { +impl Add for BaseExpr { type Output = Self; fn add(self, rhs: Self) -> Self { - Expr::Add(Box::new(self), Box::new(rhs)) + BaseExpr::Add(Box::new(self), Box::new(rhs)) } } -impl Sub for Expr { +impl Sub for BaseExpr { type Output = Self; fn sub(self, rhs: Self) -> Self { - Expr::Sub(Box::new(self), Box::new(rhs)) + BaseExpr::Sub(Box::new(self), Box::new(rhs)) } } -impl Mul for Expr { +impl Mul for BaseExpr { type Output = Self; fn mul(self, rhs: Self) -> Self { - Expr::Mul(Box::new(self), Box::new(rhs)) + BaseExpr::Mul(Box::new(self), Box::new(rhs)) } } -impl AddAssign for Expr { +impl AddAssign for BaseExpr { fn add_assign(&mut self, rhs: Self) { *self = self.clone() + rhs } } -impl MulAssign for Expr { +impl MulAssign for BaseExpr { fn mul_assign(&mut self, rhs: Self) { *self = self.clone() * rhs } } -impl Neg for Expr { +impl Neg for BaseExpr { type Output = Self; fn neg(self) -> Self { - Expr::Neg(Box::new(self)) + BaseExpr::Neg(Box::new(self)) + } +} + +impl Add for ExtExpr { + type Output = Self; + fn add(self, rhs: Self) -> Self { + ExtExpr::Add(Box::new(self), Box::new(rhs)) + } +} + +impl Sub for ExtExpr { + type Output = Self; + fn sub(self, rhs: Self) -> Self { + ExtExpr::Sub(Box::new(self), Box::new(rhs)) + } +} + +impl Mul for ExtExpr { + type Output = Self; + fn mul(self, rhs: Self) -> Self { + ExtExpr::Mul(Box::new(self), Box::new(rhs)) + } +} + +impl AddAssign for ExtExpr { + fn 
add_assign(&mut self, rhs: Self) { + *self = self.clone() + rhs + } +} + +impl MulAssign for ExtExpr { + fn mul_assign(&mut self, rhs: Self) { + *self = self.clone() * rhs } } -impl Zero for Expr { +impl Neg for ExtExpr { + type Output = Self; + fn neg(self) -> Self { + ExtExpr::Neg(Box::new(self)) + } +} + +impl Zero for BaseExpr { + fn zero() -> Self { + BaseExpr::from(BaseField::zero()) + } + fn is_zero(&self) -> bool { + // TODO(alont): consider replacing `Zero` in the trait bound with a custom trait + // that only has `zero()`. + panic!("Can't check if an expression is zero."); + } +} + +impl One for BaseExpr { + fn one() -> Self { + BaseExpr::from(BaseField::one()) + } +} + +impl Zero for ExtExpr { fn zero() -> Self { - Expr::Const(BaseField::zero()) + ExtExpr::from(BaseField::zero()) } fn is_zero(&self) -> bool { // TODO(alont): consider replacing `Zero` in the trait bound with a custom trait @@ -140,154 +392,141 @@ impl Zero for Expr { } } -impl One for Expr { +impl One for ExtExpr { fn one() -> Self { - Expr::Const(BaseField::one()) + ExtExpr::from(BaseField::one()) } } -impl FieldExpOps for Expr { +impl FieldExpOps for BaseExpr { fn inverse(&self) -> Self { - Expr::Inv(Box::new(self.clone())) + BaseExpr::Inv(Box::new(self.clone())) + } +} + +impl Add for BaseExpr { + type Output = Self; + fn add(self, rhs: BaseField) -> Self { + self + BaseExpr::from(rhs) + } +} + +impl AddAssign for BaseExpr { + fn add_assign(&mut self, rhs: BaseField) { + *self = self.clone() + BaseExpr::from(rhs) + } +} + +impl Mul for BaseExpr { + type Output = Self; + fn mul(self, rhs: BaseField) -> Self { + self * BaseExpr::from(rhs) + } +} + +impl Mul for BaseExpr { + type Output = ExtExpr; + fn mul(self, rhs: SecureField) -> ExtExpr { + ExtExpr::from(self) * ExtExpr::from(rhs) } } -impl Add for Expr { +impl Add for BaseExpr { + type Output = ExtExpr; + fn add(self, rhs: SecureField) -> ExtExpr { + ExtExpr::from(self) + ExtExpr::from(rhs) + } +} + +impl Sub for BaseExpr { + type Output = ExtExpr; + fn sub(self, rhs: SecureField) -> ExtExpr { + ExtExpr::from(self) - ExtExpr::from(rhs) + } +} + +impl Add for ExtExpr { type Output = Self; fn add(self, rhs: BaseField) -> Self { - self + Expr::from(rhs) + self + ExtExpr::from(rhs) } } -impl Mul for Expr { +impl AddAssign for ExtExpr { + fn add_assign(&mut self, rhs: BaseField) { + *self = self.clone() + ExtExpr::from(rhs) + } +} + +impl Mul for ExtExpr { type Output = Self; fn mul(self, rhs: BaseField) -> Self { - self * Expr::from(rhs) + self * ExtExpr::from(rhs) } } -impl Mul for Expr { +impl Mul for ExtExpr { type Output = Self; fn mul(self, rhs: SecureField) -> Self { - self * Expr::from(rhs) + self * ExtExpr::from(rhs) } } -impl Add for Expr { +impl Add for ExtExpr { type Output = Self; fn add(self, rhs: SecureField) -> Self { - self + Expr::from(rhs) + self + ExtExpr::from(rhs) } } -impl Sub for Expr { +impl Sub for ExtExpr { type Output = Self; fn sub(self, rhs: SecureField) -> Self { - self - Expr::from(rhs) + self - ExtExpr::from(rhs) } } -impl AddAssign for Expr { - fn add_assign(&mut self, rhs: BaseField) { - *self = self.clone() + Expr::from(rhs) - } -} - -const ZERO: M31 = M31(0); -const ONE: M31 = M31(1); -const MINUS_ONE: M31 = M31(m31::P - 1); - -// TODO(alont) Add random point assignment test. 
-pub fn simplify(expr: Expr) -> Expr { - match expr { - Expr::Add(a, b) => { - let a = simplify(*a); - let b = simplify(*b); - match (a.clone(), b.clone()) { - (Expr::Const(a), Expr::Const(b)) => Expr::Const(a + b), - (Expr::Const(ZERO), _) => b, // 0 + b = b - (_, Expr::Const(ZERO)) => a, // a + 0 = a - // (-a + -b) = -(a + b) - (Expr::Neg(minus_a), Expr::Neg(minus_b)) => -(*minus_a + *minus_b), - (Expr::Neg(minus_a), _) => b - *minus_a, // -a + b = b - a - (_, Expr::Neg(minus_b)) => a - *minus_b, // a + -b = a - b - _ => Expr::Add(Box::new(a), Box::new(b)), - } - } - Expr::Sub(a, b) => { - let a = simplify(*a); - let b = simplify(*b); - match (a.clone(), b.clone()) { - (Expr::Const(a), Expr::Const(b)) => Expr::Const(a - b), - (Expr::Const(ZERO), _) => -b, // 0 - b = -b - (_, Expr::Const(ZERO)) => a, // a - 0 = a - // (-a - -b) = b - a - (Expr::Neg(minus_a), Expr::Neg(minus_b)) => *minus_b - *minus_a, - (Expr::Neg(minus_a), _) => -(*minus_a + b), // -a - b = -(a + b) - (_, Expr::Neg(minus_b)) => a + *minus_b, // a + -b = a - b - _ => Expr::Sub(Box::new(a), Box::new(b)), - } - } - Expr::Mul(a, b) => { - let a = simplify(*a); - let b = simplify(*b); - match (a.clone(), b.clone()) { - (Expr::Const(a), Expr::Const(b)) => Expr::Const(a * b), - (Expr::Const(ZERO), _) => Expr::zero(), // 0 * b = 0 - (_, Expr::Const(ZERO)) => Expr::zero(), // a * 0 = 0 - (Expr::Const(ONE), _) => b, // 1 * b = b - (_, Expr::Const(ONE)) => a, // a * 1 = a - // (-a) * (-b) = a * b - (Expr::Neg(minus_a), Expr::Neg(minus_b)) => *minus_a * *minus_b, - (Expr::Neg(minus_a), _) => -(*minus_a * b), // (-a) * b = -(a * b) - (_, Expr::Neg(minus_b)) => -(a * *minus_b), // a * (-b) = -(a * b) - (Expr::Const(MINUS_ONE), _) => -b, // -1 * b = -b - (_, Expr::Const(MINUS_ONE)) => -a, // a * -1 = -a - _ => Expr::Mul(Box::new(a), Box::new(b)), - } - } - Expr::Col(colexpr) => Expr::Col(colexpr), - Expr::SecureCol([a, b, c, d]) => Expr::SecureCol([ - Box::new(simplify(*a)), - Box::new(simplify(*b)), - Box::new(simplify(*c)), - Box::new(simplify(*d)), - ]), - Expr::Const(c) => Expr::Const(c), - Expr::Param(x) => Expr::Param(x), - Expr::Neg(a) => { - let a = simplify(*a); - match a { - Expr::Const(c) => Expr::Const(-c), - Expr::Neg(minus_a) => *minus_a, // -(-a) = a - Expr::Sub(a, b) => Expr::Sub(b, a), // -(a - b) = b - a - _ => Expr::Neg(Box::new(a)), - } - } - Expr::Inv(a) => { - let a = simplify(*a); - match a { - Expr::Inv(inv_a) => *inv_a, // 1 / (1 / a) = a - Expr::Const(c) => Expr::Const(c.inverse()), - _ => Expr::Inv(Box::new(a)), - } - } +impl Add for ExtExpr { + type Output = Self; + fn add(self, rhs: BaseExpr) -> Self { + self + ExtExpr::from(rhs) + } +} + +impl Mul for ExtExpr { + type Output = Self; + fn mul(self, rhs: BaseExpr) -> Self { + self * ExtExpr::from(rhs) + } +} + +impl Mul for BaseExpr { + type Output = ExtExpr; + fn mul(self, rhs: ExtExpr) -> ExtExpr { + rhs * self + } +} + +impl Sub for ExtExpr { + type Output = Self; + fn sub(self, rhs: BaseExpr) -> Self { + self - ExtExpr::from(rhs) } } /// Returns the expression /// `value[0] * _alpha0 + value[1] * _alpha1 + ... 
- _z.` -fn combine_formal>(relation: &R, values: &[Expr]) -> Expr { +fn combine_formal>(relation: &R, values: &[BaseExpr]) -> ExtExpr { const Z_SUFFIX: &str = "_z"; const ALPHA_SUFFIX: &str = "_alpha"; - let z = Expr::Param(relation.get_name().to_owned() + Z_SUFFIX); + let z = ExtExpr::Param(relation.get_name().to_owned() + Z_SUFFIX); let alpha_powers = (0..relation.get_size()) - .map(|i| Expr::Param(relation.get_name().to_owned() + ALPHA_SUFFIX + &i.to_string())); + .map(|i| ExtExpr::Param(relation.get_name().to_owned() + ALPHA_SUFFIX + &i.to_string())); values .iter() .zip(alpha_powers) - .fold(Expr::zero(), |acc, (value, power)| { + .fold(ExtExpr::zero(), |acc, (value, power)| { acc + power * value.clone() }) - z @@ -295,12 +534,12 @@ fn combine_formal>(relation: &R, values: &[Expr]) -> Exp pub struct FormalLogupAtRow { pub interaction: usize, - pub total_sum: Expr, - pub claimed_sum: Option<(Expr, usize)>, - pub prev_col_cumsum: Expr, - pub cur_frac: Option>, + pub total_sum: ExtExpr, + pub claimed_sum: Option<(ExtExpr, usize)>, + pub prev_col_cumsum: ExtExpr, + pub cur_frac: Option>, pub is_finalized: bool, - pub is_first: Expr, + pub is_first: BaseExpr, pub log_size: u32, } @@ -316,13 +555,13 @@ impl FormalLogupAtRow { Self { interaction, // TODO(alont): Should these be Expr::SecureField? - total_sum: Expr::Param(total_sum_name), + total_sum: ExtExpr::Param(total_sum_name), claimed_sum: has_partial_sum - .then_some((Expr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)), - prev_col_cumsum: Expr::zero(), + .then_some((ExtExpr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)), + prev_col_cumsum: ExtExpr::zero(), cur_frac: None, is_finalized: true, - is_first: Expr::zero(), + is_first: BaseExpr::zero(), log_size, } } @@ -331,9 +570,9 @@ impl FormalLogupAtRow { /// An Evaluator that saves all constraint expressions. pub struct ExprEvaluator { pub cur_var_index: usize, - pub constraints: Vec, + pub constraints: Vec, pub logup: FormalLogupAtRow, - pub intermediates: Vec<(String, Expr)>, + pub intermediates: Vec<(String, ExtExpr)>, } impl ExprEvaluator { @@ -347,9 +586,9 @@ impl ExprEvaluator { } } - pub fn add_intermediate(&mut self, expr: Expr) -> Expr { + pub fn add_intermediate(&mut self, expr: ExtExpr) -> ExtExpr { let name = format!("intermediate{}", self.intermediates.len()); - let intermediate = Expr::Param(name.clone()); + let intermediate = ExtExpr::Param(name.clone()); self.intermediates.push((name, expr)); intermediate } @@ -376,8 +615,8 @@ impl ExprEvaluator { impl EvalAtRow for ExprEvaluator { // TODO(alont): Should there be a version of this that disallows Secure fields for F? 
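    // This evaluator builds symbolic expressions rather than field values: base-field masks
    // become `BaseExpr`s and the accumulated constraints become `ExtExpr`s.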
- type F = Expr; - type EF = Expr; + type F = BaseExpr; + type EF = ExtExpr; fn next_interaction_mask( &mut self, @@ -391,7 +630,7 @@ impl EvalAtRow for ExprEvaluator { offset: offsets[i], }; self.cur_var_index += 1; - Expr::Col(col) + BaseExpr::Col(col) }) } @@ -403,7 +642,7 @@ impl EvalAtRow for ExprEvaluator { } fn combine_ef(values: [Self::F; 4]) -> Self::EF { - Expr::SecureCol([ + ExtExpr::SecureCol([ Box::new(values[0].clone()), Box::new(values[1].clone()), Box::new(values[2].clone()), @@ -458,7 +697,7 @@ mod tests { \ let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \ - (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \ - - ((col_0_3[0]) * (total_sum)))\ + - ((total_sum) * (col_0_3[0])))\ ) \ * (intermediate0) \ - (1);" diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 2cf8bc2f2..8dbe3a068 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -320,7 +320,7 @@ mod tests { \ let constraint_1 = (SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \ - (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \ - - ((col_0_2[0]) * (total_sum)))\ + - ((total_sum) * (col_0_2[0])))\ ) \ * ((intermediate0) * (intermediate1)) \ - (intermediate1 - (intermediate0));" From 4af2d445f35d9d0106cc6b74cf266c1cc57040c2 Mon Sep 17 00:00:00 2001 From: Alon-Ti <54235977+Alon-Ti@users.noreply.github.com> Date: Sun, 1 Dec 2024 17:37:40 +0200 Subject: [PATCH 09/69] Added eval on exprs. (#904) --- .../prover/src/constraint_framework/expr.rs | 133 +++++++++++++++++- 1 file changed, 126 insertions(+), 7 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index ab32829cf..c459cee27 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -1,4 +1,4 @@ -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; +use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub}; use num_traits::{One, Zero}; @@ -17,6 +17,16 @@ pub struct ColumnExpr { offset: isize, } +impl From<(usize, usize, isize)> for ColumnExpr { + fn from((interaction, idx, offset): (usize, usize, isize)) -> Self { + Self { + interaction, + idx, + offset, + } + } +} + /// An expression representing a base field value. Can be either: /// * A column indexed by a `ColumnExpr`. /// * A base field constant. @@ -182,6 +192,34 @@ impl BaseExpr { pub fn simplify_and_format(&self) -> String { self.simplify().format_expr() } + + /// Evaluates a base field expression. + /// Takes: + /// * `columns`: A mapping from triplets (interaction, idx, offset) to base field values. + /// * `vars`: A mapping from variable names to base field values. 
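+    ///
+    /// Note (with the `HashMap` mappings used in the tests below): indexing panics if the
+    /// expression references a column or parameter that is missing from the maps.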
+ pub fn eval_expr(&self, columns: &C, vars: &V) -> E::F + where + C: for<'a> Index<&'a (usize, usize, isize), Output = E::F>, + V: for<'a> Index<&'a String, Output = E::F>, + E: EvalAtRow, + { + match self { + Self::Col(col) => columns[&(col.interaction, col.idx, col.offset)].clone(), + Self::Const(c) => E::F::from(*c), + Self::Param(var) => vars[&var.to_string()].clone(), + Self::Add(a, b) => { + a.eval_expr::(columns, vars) + b.eval_expr::(columns, vars) + } + Self::Sub(a, b) => { + a.eval_expr::(columns, vars) - b.eval_expr::(columns, vars) + } + Self::Mul(a, b) => { + a.eval_expr::(columns, vars) * b.eval_expr::(columns, vars) + } + Self::Neg(a) => -a.eval_expr::(columns, vars), + Self::Inv(a) => a.eval_expr::(columns, vars).inverse(), + } + } } impl ExtExpr { @@ -243,6 +281,44 @@ impl ExtExpr { pub fn simplify_and_format(&self) -> String { self.simplify().format_expr() } + + /// Evaluates an extension field expression. + /// Takes: + /// * `columns`: A mapping from triplets (interaction, idx, offset) to base field values. + /// * `vars`: A mapping from variable names to base field values. + /// * `ext_vars`: A mapping from variable names to extension field values. + pub fn eval_expr(&self, columns: &C, vars: &V, ext_vars: &EV) -> E::EF + where + C: for<'a> Index<&'a (usize, usize, isize), Output = E::F>, + V: for<'a> Index<&'a String, Output = E::F>, + EV: for<'a> Index<&'a String, Output = E::EF>, + E: EvalAtRow, + { + match self { + Self::SecureCol([a, b, c, d]) => { + let a = a.eval_expr::(columns, vars); + let b = b.eval_expr::(columns, vars); + let c = c.eval_expr::(columns, vars); + let d = d.eval_expr::(columns, vars); + E::combine_ef([a, b, c, d]) + } + Self::Const(c) => E::EF::from(*c), + Self::Param(var) => ext_vars[&var.to_string()].clone(), + Self::Add(a, b) => { + a.eval_expr::(columns, vars, ext_vars) + + b.eval_expr::(columns, vars, ext_vars) + } + Self::Sub(a, b) => { + a.eval_expr::(columns, vars, ext_vars) + - b.eval_expr::(columns, vars, ext_vars) + } + Self::Mul(a, b) => { + a.eval_expr::(columns, vars, ext_vars) + * b.eval_expr::(columns, vars, ext_vars) + } + Self::Neg(a) => -a.eval_expr::(columns, vars, ext_vars), + } + } } impl From for BaseExpr { @@ -624,11 +700,7 @@ impl EvalAtRow for ExprEvaluator { offsets: [isize; N], ) -> [Self::F; N] { std::array::from_fn(|i| { - let col = ColumnExpr { - interaction, - idx: self.cur_var_index, - offset: offsets[i], - }; + let col = ColumnExpr::from((interaction, self.cur_var_index, offsets[i])); self.cur_var_index += 1; BaseExpr::Col(col) }) @@ -675,12 +747,59 @@ impl EvalAtRow for ExprEvaluator { #[cfg(test)] mod tests { + use std::collections::HashMap; + use num_traits::One; + use super::{BaseExpr, ExtExpr}; use crate::constraint_framework::expr::ExprEvaluator; - use crate::constraint_framework::{relation, EvalAtRow, FrameworkEval, RelationEntry}; + use crate::constraint_framework::{ + relation, AssertEvaluator, EvalAtRow, FrameworkEval, RelationEntry, + }; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; + #[test] + fn test_eval_expr() { + let col_1_0_0 = BaseField::from(12); + let col_1_1_0 = BaseField::from(5); + let var_a = BaseField::from(3); + let var_b = BaseField::from(4); + let var_c = SecureField::from_m31_array([ + BaseField::from(1), + BaseField::from(2), + BaseField::from(3), + BaseField::from(4), + ]); + + let columns: HashMap<(usize, usize, isize), BaseField> = + HashMap::from([((1, 0, 0), col_1_0_0), ((1, 1, 0), col_1_1_0)]); 
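+        // Assignments for the formal parameters: `a` and `b` live in the base field, `c` in the
+        // extension field.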
+ let vars = HashMap::from([("a".to_string(), var_a), ("b".to_string(), var_b)]); + let ext_vars = HashMap::from([("c".to_string(), var_c)]); + + let expr = ExtExpr::SecureCol([ + Box::new(BaseExpr::Col((1, 0, 0).into()) - BaseExpr::Col((1, 1, 0).into())), + Box::new(BaseExpr::Col((1, 1, 0).into()) * (-BaseExpr::Param("a".to_string()))), + Box::new(BaseExpr::Param("a".to_string()) + BaseExpr::Param("a".to_string()).inverse()), + Box::new(BaseExpr::Param("b".to_string()) * BaseExpr::Const(BaseField::from(7))), + ]) + ExtExpr::Param("c".to_string()) * ExtExpr::Param("c".to_string()) + - ExtExpr::Const(SecureField::one()); + + let expected = SecureField::from_m31_array([ + col_1_0_0 - col_1_1_0, + col_1_1_0 * (-var_a), + var_a + var_a.inverse(), + var_b * BaseField::from(7), + ]) + var_c * var_c + - SecureField::one(); + + assert_eq!( + expr.eval_expr::, _, _, _>(&columns, &vars, &ext_vars), + expected + ); + } + #[test] fn test_format_expr() { let test_struct = TestStruct {}; From a459c485ae9e6da79c1d52f4793b459eed529e4e Mon Sep 17 00:00:00 2001 From: Alon-Ti <54235977+Alon-Ti@users.noreply.github.com> Date: Mon, 2 Dec 2024 10:12:54 +0200 Subject: [PATCH 10/69] Added test for expression simplifier. (#908) --- .../prover/src/constraint_framework/expr.rs | 119 ++++++++++++++++-- 1 file changed, 111 insertions(+), 8 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index c459cee27..95f9f3342 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -750,16 +750,64 @@ mod tests { use std::collections::HashMap; use num_traits::One; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; use super::{BaseExpr, ExtExpr}; use crate::constraint_framework::expr::ExprEvaluator; use crate::constraint_framework::{ relation, AssertEvaluator, EvalAtRow, FrameworkEval, RelationEntry, }; - use crate::core::fields::m31::BaseField; + use crate::core::fields::m31::{self, BaseField}; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; + macro_rules! secure_col { + ($a:expr, $b:expr, $c:expr, $d:expr) => { + ExtExpr::SecureCol([ + Box::new($a.into()), + Box::new($b.into()), + Box::new($c.into()), + Box::new($d.into()), + ]) + }; + } + + macro_rules! col { + ($interaction:expr, $idx:expr, $offset:expr) => { + BaseExpr::Col(($interaction, $idx, $offset).into()) + }; + } + + macro_rules! var { + ($var:expr) => { + BaseExpr::Param($var.to_string()) + }; + } + + macro_rules! qvar { + ($var:expr) => { + ExtExpr::Param($var.to_string()) + }; + } + + macro_rules! felt { + ($val:expr) => { + BaseExpr::Const($val.into()) + }; + } + + macro_rules! 
qfelt { + ($a:expr, $b:expr, $c:expr, $d:expr) => { + ExtExpr::Const(SecureField::from_m31_array([ + $a.into(), + $b.into(), + $c.into(), + $d.into(), + ])) + }; + } + #[test] fn test_eval_expr() { let col_1_0_0 = BaseField::from(12); @@ -778,13 +826,13 @@ mod tests { let vars = HashMap::from([("a".to_string(), var_a), ("b".to_string(), var_b)]); let ext_vars = HashMap::from([("c".to_string(), var_c)]); - let expr = ExtExpr::SecureCol([ - Box::new(BaseExpr::Col((1, 0, 0).into()) - BaseExpr::Col((1, 1, 0).into())), - Box::new(BaseExpr::Col((1, 1, 0).into()) * (-BaseExpr::Param("a".to_string()))), - Box::new(BaseExpr::Param("a".to_string()) + BaseExpr::Param("a".to_string()).inverse()), - Box::new(BaseExpr::Param("b".to_string()) * BaseExpr::Const(BaseField::from(7))), - ]) + ExtExpr::Param("c".to_string()) * ExtExpr::Param("c".to_string()) - - ExtExpr::Const(SecureField::one()); + let expr = secure_col!( + col!(1, 0, 0) - col!(1, 1, 0), + col!(1, 1, 0) * (-var!("a")), + var!("a") + var!("a").inverse(), + var!("b") * felt!(7) + ) + qvar!("c") * qvar!("c") + - qfelt!(1, 0, 0, 0); let expected = SecureField::from_m31_array([ col_1_0_0 - col_1_1_0, @@ -800,6 +848,61 @@ mod tests { ); } + #[test] + fn test_simplify_expr() { + let c0 = col!(1, 0, 0); + let c1 = col!(1, 1, 0); + let a = var!("a"); + let b = qvar!("b"); + let zero = felt!(0); + let qzero = qfelt!(0, 0, 0, 0); + let one = felt!(1); + let qone = qfelt!(1, 0, 0, 0); + let minus_one = felt!(m31::P - 1); + let qminus_one = qfelt!(m31::P - 1, 0, 0, 0); + + let mut rng = SmallRng::seed_from_u64(0); + let columns: HashMap<(usize, usize, isize), BaseField> = + HashMap::from([((1, 0, 0), rng.gen()), ((1, 1, 0), rng.gen())]); + let vars: HashMap = HashMap::from([("a".to_string(), rng.gen())]); + let ext_vars: HashMap = HashMap::from([("b".to_string(), rng.gen())]); + + let base_expr = (((zero.clone() + c0.clone()) + (a.clone() + zero.clone())) + * ((-c1.clone()) + (-c0.clone())) + + (-(-(a.clone() + a.clone() + c0.clone()))) + - zero.clone()) + + (a.clone() - zero.clone()) + + (-c1.clone() - (a.clone() * a.clone())) + + (a.clone() * zero.clone()) + - (zero.clone() * c1.clone()) + + one.clone() + * a.clone() + * one.clone() + * c1.clone() + * (-a.clone()) + * c1.clone() + * (minus_one.clone() * c0.clone()); + + let expr = (qzero.clone() + + secure_col!( + base_expr.clone(), + base_expr.clone(), + zero.clone(), + one.clone() + ) + - qzero.clone()) + * qone.clone() + * b.clone() + * qminus_one.clone(); + + let full_eval = expr.eval_expr::, _, _, _>(&columns, &vars, &ext_vars); + let simplified_eval = expr + .simplify() + .eval_expr::, _, _, _>(&columns, &vars, &ext_vars); + + assert_eq!(full_eval, simplified_eval); + } + #[test] fn test_format_expr() { let test_struct = TestStruct {}; From e21f74f957b7263b15149dae1695469aa1f44cfa Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Mon, 25 Nov 2024 13:08:16 +0200 Subject: [PATCH 11/69] relation tracker eval --- crates/prover/src/constraint_framework/mod.rs | 1 + .../constraint_framework/relation_tracker.rs | 189 ++++++++++++++++++ crates/prover/src/core/pcs/utils.rs | 7 + 3 files changed, 197 insertions(+) create mode 100644 crates/prover/src/constraint_framework/relation_tracker.rs diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 044f05ac8..341e24511 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -7,6 +7,7 @@ mod info; pub 
mod logup; mod point; pub mod preprocessed_columns; +pub mod relation_tracker; mod simd_domain; use std::array; diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs new file mode 100644 index 000000000..df3996d63 --- /dev/null +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -0,0 +1,189 @@ +use std::fmt::Debug; + +use itertools::Itertools; +use num_traits::Zero; + +use super::logup::LogupSums; +use super::{ + EvalAtRow, FrameworkEval, InfoEvaluator, Relation, RelationEntry, TraceLocationAllocator, + INTERACTION_TRACE_IDX, +}; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::very_packed_m31::LOG_N_VERY_PACKED_ELEMS; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::m31::{BaseField, M31}; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::lookups::utils::Fraction; +use crate::core::pcs::{TreeSubspan, TreeVec}; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::{ + bit_reverse_index, coset_index_to_circle_domain_index, offset_bit_reversed_circle_domain_index, +}; + +#[derive(Debug)] +pub struct RelationTrackerEntry { + pub relation: String, + pub mult: M31, + pub values: Vec, +} + +pub struct RelationTrackerComponent { + eval: E, + trace_locations: TreeVec, + n_rows: usize, +} +impl RelationTrackerComponent { + pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E, n_rows: usize) -> Self { + let info = eval.evaluate(InfoEvaluator::new( + eval.log_size(), + vec![], + LogupSums::default(), + )); + let mut mask_offsets = info.mask_offsets; + mask_offsets.drain(INTERACTION_TRACE_IDX..); + let trace_locations = location_allocator.next_for_structure(&mask_offsets); + Self { + eval, + trace_locations, + n_rows, + } + } + + pub fn entries( + self, + trace: &TreeVec>>, + ) -> Vec { + let log_size = self.eval.log_size(); + + // Deref the sub-tree. Only copies the references. + let sub_tree = trace + .sub_tree(&self.trace_locations) + .map(|vec| vec.into_iter().copied().collect_vec()); + let mut entries = vec![]; + + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let evaluator = + RelationTrackerEvaluator::new(&sub_tree, vec_row, log_size, self.n_rows); + entries.extend(self.eval.evaluate(evaluator).entries()); + } + entries + } +} + +/// Aggregates relation entries. +// TODO(Ohad): write a summarize function, test. +pub struct RelationTrackerEvaluator<'a> { + entries: Vec, + pub trace_eval: + &'a TreeVec>>, + pub column_index_per_interaction: Vec, + pub vec_row: usize, + pub domain_log_size: u32, + pub n_rows: usize, +} +impl<'a> RelationTrackerEvaluator<'a> { + pub fn new( + trace_eval: &'a TreeVec>>, + vec_row: usize, + domain_log_size: u32, + n_rows: usize, + ) -> Self { + Self { + entries: vec![], + trace_eval, + column_index_per_interaction: vec![0; trace_eval.len()], + vec_row, + domain_log_size, + n_rows, + } + } + + pub fn entries(self) -> Vec { + self.entries + } +} +impl<'a> EvalAtRow for RelationTrackerEvaluator<'a> { + type F = PackedBaseField; + type EF = PackedSecureField; + + // TODO(Ohad): Add debug boundary checks. 
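+    // Reads the next column of `interaction` at the given mask offsets; a zero offset is read
+    // directly from the current packed row, other offsets are resolved through the bit-reversed
+    // circle-domain index of the row.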
+ fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N] { + assert_ne!(interaction, INTERACTION_TRACE_IDX); + let col_index = self.column_index_per_interaction[interaction]; + self.column_index_per_interaction[interaction] += 1; + offsets.map(|off| { + // If the offset is 0, we can just return the value directly from this row. + if off == 0 { + unsafe { + let col = &self + .trace_eval + .get_unchecked(interaction) + .get_unchecked(col_index) + .values; + return *col.data.get_unchecked(self.vec_row); + }; + } + // Otherwise, we need to look up the value at the offset. + // Since the domain is bit-reversed circle domain ordered, we need to look up the value + // at the bit-reversed natural order index at an offset. + PackedBaseField::from_array(std::array::from_fn(|i| { + let row_index = offset_bit_reversed_circle_domain_index( + (self.vec_row << (LOG_N_LANES + LOG_N_VERY_PACKED_ELEMS)) + i, + self.domain_log_size, + self.domain_log_size, + off, + ); + self.trace_eval[interaction][col_index].at(row_index) + })) + }) + } + fn add_constraint(&mut self, _constraint: G) {} + + fn combine_ef(_values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { + PackedSecureField::zero() + } + + fn write_logup_frac(&mut self, _fraction: Fraction) {} + + fn finalize_logup(&mut self) {} + + fn add_to_relation>( + &mut self, + entries: &[RelationEntry<'_, Self::F, Self::EF, R>], + ) { + for entry in entries { + let relation = entry.relation.get_name().to_owned(); + let values = entry.values.iter().map(|v| v.to_array()).collect_vec(); + let mult = entry.multiplicity.to_array(); + + // Unpack SIMD. + for j in 0..N_LANES { + // Skip padded values. + let cannonical_index = bit_reverse_index( + coset_index_to_circle_domain_index( + (self.vec_row << LOG_N_LANES) + j, + self.domain_log_size, + ), + self.domain_log_size, + ); + if cannonical_index >= self.n_rows { + continue; + } + let values = values.iter().map(|v| v[j]).collect_vec(); + let mult = mult[j].to_m31_array()[0]; + self.entries.push(RelationTrackerEntry { + relation: relation.clone(), + mult, + values, + }); + } + } + } +} diff --git a/crates/prover/src/core/pcs/utils.rs b/crates/prover/src/core/pcs/utils.rs index 36ef3a198..73c624f81 100644 --- a/crates/prover/src/core/pcs/utils.rs +++ b/crates/prover/src/core/pcs/utils.rs @@ -41,6 +41,13 @@ impl<'a, T> From<&'a TreeVec> for TreeVec<&'a T> { } } +/// Converts `&TreeVec<&Vec>` to `TreeVec>`. +impl<'a, T> From<&'a TreeVec<&'a Vec>> for TreeVec> { + fn from(val: &'a TreeVec<&'a Vec>) -> Self { + TreeVec(val.iter().map(|vec| vec.iter().collect()).collect()) + } +} + impl Deref for TreeVec { type Target = Vec; fn deref(&self) -> &Self::Target { From f7de6145106ede80091325aa8c32fde543ef7cc8 Mon Sep 17 00:00:00 2001 From: Ohad Agadi Date: Thu, 28 Nov 2024 14:43:35 +0200 Subject: [PATCH 12/69] relation summary --- .../constraint_framework/relation_tracker.rs | 45 ++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs index df3996d63..b5220130a 100644 --- a/crates/prover/src/constraint_framework/relation_tracker.rs +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::fmt::Debug; use itertools::Itertools; @@ -74,7 +75,7 @@ impl RelationTrackerComponent { } /// Aggregates relation entries. -// TODO(Ohad): write a summarize function, test. 
+// TODO(Ohad): test. pub struct RelationTrackerEvaluator<'a> { entries: Vec, pub trace_eval: @@ -187,3 +188,45 @@ impl<'a> EvalAtRow for RelationTrackerEvaluator<'a> { } } } + +type RelationInfo = (String, Vec<(Vec, M31)>); +pub struct RelationSummary(Vec); +impl RelationSummary { + /// Returns the sum of every entry's yields and uses. + /// The result is a map from relation name to a list of values(M31 vectors) and their sum. + pub fn summarize_relations(entries: &[RelationTrackerEntry]) -> Self { + let mut summary = vec![]; + let relations = entries.iter().group_by(|entry| entry.relation.clone()); + for (relation, entries) in &relations { + let mut relation_sums: HashMap, M31> = HashMap::new(); + for entry in entries { + let mult = relation_sums + .entry(entry.values.clone()) + .or_insert(M31::zero()); + *mult += entry.mult; + } + let relation_sums = relation_sums.into_iter().collect_vec(); + summary.push((relation.clone(), relation_sums)); + } + Self(summary) + } + + pub fn get_relation_info(&self, relation: &str) -> Option<&[(Vec, M31)]> { + self.0 + .iter() + .find(|(name, _)| name == relation) + .map(|(_, entries)| entries.as_slice()) + } +} +impl Debug for RelationSummary { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for (relation, entries) in &self.0 { + writeln!(f, "{}:", relation)?; + for (vector, sum) in entries { + let vector = vector.iter().map(|v| v.0).collect_vec(); + writeln!(f, " {:?} -> {}", vector, sum)?; + } + } + Ok(()) + } +} From 5f5a02f4b166560fb4846e7310b30e034e820e75 Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Mon, 25 Nov 2024 13:08:16 +0200 Subject: [PATCH 13/69] relation tracker --- .../constraint_framework/relation_tracker.rs | 1 - .../src/examples/state_machine/components.rs | 48 +++++++++++- .../prover/src/examples/state_machine/mod.rs | 76 ++++++++++++++++--- 3 files changed, 111 insertions(+), 14 deletions(-) diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs index b5220130a..3866df39a 100644 --- a/crates/prover/src/constraint_framework/relation_tracker.rs +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -75,7 +75,6 @@ impl RelationTrackerComponent { } /// Aggregates relation entries. -// TODO(Ohad): test. 
pub struct RelationTrackerEvaluator<'a> { entries: Vec, pub trace_eval: diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs index 2451eef23..4600a3cf0 100644 --- a/crates/prover/src/examples/state_machine/components.rs +++ b/crates/prover/src/examples/state_machine/components.rs @@ -1,16 +1,21 @@ use num_traits::{One, Zero}; use crate::constraint_framework::logup::ClaimedPrefixSum; +use crate::constraint_framework::relation_tracker::{ + RelationTrackerComponent, RelationTrackerEntry, +}; use crate::constraint_framework::{ relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, RelationEntry, - PREPROCESSED_TRACE_IDX, + TraceLocationAllocator, PREPROCESSED_TRACE_IDX, }; use crate::core::air::{Component, ComponentProver}; use crate::core::backend::simd::SimdBackend; use crate::core::channel::Channel; -use crate::core::fields::m31::M31; +use crate::core::fields::m31::{BaseField, M31}; use crate::core::fields::qm31::{SecureField, QM31}; use crate::core::pcs::TreeVec; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; use crate::core::prover::StarkProof; use crate::core::vcs::ops::MerkleHasher; @@ -124,6 +129,45 @@ impl StateMachineComponents { } } +pub fn track_state_machine_relations( + trace: &TreeVec<&Vec>>, + x_axis_log_n_rows: u32, + y_axis_log_n_rows: u32, + n_rows_x: u32, + n_rows_y: u32, +) -> Vec { + let tree_span_provider = &mut TraceLocationAllocator::default(); + let mut entries = vec![]; + entries.extend( + RelationTrackerComponent::new( + tree_span_provider, + StateTransitionEval::<0> { + log_n_rows: x_axis_log_n_rows, + lookup_elements: StateMachineElements::dummy(), + total_sum: QM31::zero(), + claimed_sum: (QM31::zero(), 0), + }, + n_rows_x as usize, + ) + .entries(&trace.into()), + ); + entries.extend( + RelationTrackerComponent::new( + tree_span_provider, + StateTransitionEval::<1> { + log_n_rows: y_axis_log_n_rows, + lookup_elements: StateMachineElements::dummy(), + total_sum: QM31::zero(), + claimed_sum: (QM31::zero(), 0), + }, + n_rows_y as usize, + ) + .entries(&trace.into()), + ); + + entries +} + pub struct StateMachineProof { pub public_input: [State; 2], // Initial and final state. 
pub stmt0: StateMachineStatement0, diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 8dbe3a068..23973a960 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -1,11 +1,12 @@ +use crate::constraint_framework::relation_tracker::RelationSummary; use crate::constraint_framework::Relation; pub mod components; pub mod gen; use components::{ - State, StateMachineComponents, StateMachineElements, StateMachineOp0Component, - StateMachineOp1Component, StateMachineProof, StateMachineStatement0, StateMachineStatement1, - StateTransitionEval, + track_state_machine_relations, State, StateMachineComponents, StateMachineElements, + StateMachineOp0Component, StateMachineOp1Component, StateMachineProof, StateMachineStatement0, + StateMachineStatement1, StateTransitionEval, }; use gen::{gen_interaction_trace, gen_trace}; use itertools::{chain, Itertools}; @@ -19,7 +20,7 @@ use crate::core::backend::simd::SimdBackend; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::M31; use crate::core::fields::qm31::QM31; -use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; +use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; use crate::core::poly::circle::{CanonicCoset, PolyOps}; use crate::core::prover::{prove, verify, VerificationError}; use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; @@ -30,9 +31,11 @@ pub fn prove_state_machine( initial_state: State, config: PcsConfig, channel: &mut Blake2sChannel, + track_relations: bool, ) -> ( StateMachineComponents, StateMachineProof, + Option, ) { let (x_axis_log_rows, y_axis_log_rows) = (log_n_rows, log_n_rows - 1); let (x_row, y_row) = (34, 56); @@ -62,14 +65,32 @@ pub fn prove_state_machine( ]; // Preprocessed trace. - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(gen_preprocessed_columns(preprocessed_columns.iter())); - tree_builder.commit(channel); + let preprocessed_trace = gen_preprocessed_columns(preprocessed_columns.iter()); // Trace. let trace_op0 = gen_trace(x_axis_log_rows, initial_state, 0); let trace_op1 = gen_trace(y_axis_log_rows, intermediate_state, 1); + let trace = chain![trace_op0.clone(), trace_op1.clone()].collect_vec(); + + let relation_summary = match track_relations { + false => None, + true => Some(RelationSummary::summarize_relations( + &track_state_machine_relations( + &TreeVec(vec![&preprocessed_trace, &trace]), + x_axis_log_rows, + y_axis_log_rows, + x_row, + y_row, + ), + )), + }; + + // Commitments. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(preprocessed_trace); + tree_builder.commit(channel); + let stmt0 = StateMachineStatement0 { n: x_axis_log_rows, m: y_axis_log_rows, @@ -135,7 +156,7 @@ pub fn prove_state_machine( stmt1, stark_proof, }; - (components, proof) + (components, proof, relation_summary) } pub fn verify_state_machine( @@ -250,7 +271,8 @@ mod tests { // Setup protocol. let channel = &mut Blake2sChannel::default(); - let (component, _) = prove_state_machine(log_n_rows, initial_state, config, channel); + let (component, ..) 
= + prove_state_machine(log_n_rows, initial_state, config, channel, false); let interaction_elements = component.component0.lookup_elements.clone(); let initial_state_comb: QM31 = interaction_elements.combine(&initial_state); @@ -262,6 +284,38 @@ mod tests { ); } + #[test] + fn test_relation_tracker() { + let log_n_rows = 8; + let config = PcsConfig::default(); + let initial_state = [M31::zero(); STATE_SIZE]; + let final_state = [M31::from_u32_unchecked(34), M31::from_u32_unchecked(56)]; + + // Summarize `StateMachineElements`. + let (_, _, summary) = prove_state_machine( + log_n_rows, + initial_state, + config, + &mut Blake2sChannel::default(), + true, + ); + let summary = summary.unwrap(); + let relation_info = summary.get_relation_info("StateMachineElements").unwrap(); + + // Check the final state inferred from the summary. + let mut curr_state = initial_state; + for entry in relation_info { + let x_step = entry.0[0]; + let y_step = entry.0[1]; + let mult = entry.1; + let next_state = [curr_state[0] - x_step * mult, curr_state[1] - y_step * mult]; + + curr_state = next_state; + } + + assert_eq!(curr_state, final_state); + } + #[test] fn test_state_machine_prove() { let log_n_rows = 8; @@ -270,8 +324,8 @@ mod tests { let prover_channel = &mut Blake2sChannel::default(); let verifier_channel = &mut Blake2sChannel::default(); - let (components, proof) = - prove_state_machine(log_n_rows, initial_state, config, prover_channel); + let (components, proof, _) = + prove_state_machine(log_n_rows, initial_state, config, prover_channel, false); verify_state_machine(config, verifier_channel, components, proof).unwrap(); } From df4a6427679370164fe54d395fd174a596f954b9 Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Mon, 2 Dec 2024 13:31:51 +0200 Subject: [PATCH 14/69] logup sum (#917) --- crates/prover/src/constraint_framework/logup.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index 650607c1d..d6af96e46 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -27,6 +27,16 @@ pub type ClaimedPrefixSum = (SecureField, usize); // (total_sum, claimed_sum) pub type LogupSums = (SecureField, Option); +pub trait LogupSumsExt { + fn value(&self) -> SecureField; +} + +impl LogupSumsExt for LogupSums { + fn value(&self) -> SecureField { + self.1.map(|(claimed_sum, _)| claimed_sum).unwrap_or(self.0) + } +} + /// Evaluates constraints for batched logups. /// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum. pub struct LogupAtRow { From bf21504fc0690ebd6911888230dfe70cfb64baad Mon Sep 17 00:00:00 2001 From: Alon Titelman Date: Sun, 1 Dec 2024 19:10:20 +0200 Subject: [PATCH 15/69] Add safe simplify for expressions that compares random assignments before and after. 
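
Each call to `simplify` now checks itself: the expression's columns and parameters are
collected, given a deterministic pseudo-random assignment, and both the original and the
simplified expression are evaluated under that assignment; the two results must agree.

A rough usage sketch (hypothetical, not part of this patch):

    let x = BaseExpr::Param("x".to_string());
    // Panics internally if the rewritten expression evaluates differently from the original.
    let _simplified = (x.clone() * BaseExpr::one() + BaseExpr::zero()).simplify();
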
--- .../prover/src/constraint_framework/expr.rs | 161 +++++++++++++++++- 1 file changed, 152 insertions(+), 9 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 95f9f3342..e01a03818 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -1,8 +1,12 @@ +use std::collections::{HashMap, HashSet}; use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub}; +use itertools::sorted; use num_traits::{One, Zero}; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; -use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; +use super::{AssertEvaluator, EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; use crate::core::fields::cm31::CM31; use crate::core::fields::m31::{self, BaseField}; use crate::core::fields::qm31::{SecureField, QM31}; @@ -10,7 +14,7 @@ use crate::core::fields::FieldExpOps; use crate::core::lookups::utils::Fraction; /// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ColumnExpr { interaction: usize, idx: usize, @@ -174,11 +178,14 @@ impl BaseExpr { } } - pub fn simplify(&self) -> Self { + /// Helper function, use [`simplify`] instead. + /// + /// Simplifies an expression by applying basic arithmetic rules. + fn unchecked_simplify(&self) -> Self { let simple = simplify_arithmetic!(self); match simple { Self::Inv(a) => { - let a = a.simplify(); + let a = a.unchecked_simplify(); match a { Self::Inv(inv_a) => *inv_a, // 1 / (1 / a) = a Self::Const(c) => Self::Const(c.inverse()), @@ -189,6 +196,14 @@ impl BaseExpr { } } + /// Simplifies an expression by applying basic arithmetic rules and ensures that the result is + /// equivalent to the original expression by assigning random values. + pub fn simplify(&self) -> Self { + let simplified = self.unchecked_simplify(); + assert_eq!(self.random_eval(), simplified.random_eval()); + simplified + } + pub fn simplify_and_format(&self) -> String { self.simplify().format_expr() } @@ -220,6 +235,25 @@ impl BaseExpr { Self::Inv(a) => a.eval_expr::(columns, vars).inverse(), } } + + pub fn collect_variables(&self) -> ExprVariables { + match self { + BaseExpr::Col(col) => ExprVariables::col(col.clone()), + BaseExpr::Const(_) => ExprVariables::default(), + BaseExpr::Param(param) => ExprVariables::param(param.to_string()), + BaseExpr::Add(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Neg(a) => a.collect_variables(), + BaseExpr::Inv(a) => a.collect_variables(), + } + } + + pub fn random_eval(&self) -> BaseField { + let assignment = self.collect_variables().random_assignment(); + assert!(assignment.2.is_empty()); + self.eval_expr::, _, _>(&assignment.0, &assignment.1) + } } impl ExtExpr { @@ -256,14 +290,17 @@ impl ExtExpr { } } - pub fn simplify(&self) -> Self { + /// Helper function, use [`simplify`] instead. + /// + /// Simplifies an expression by applying basic arithmetic rules. 
+ fn unchecked_simplify(&self) -> Self { let simple = simplify_arithmetic!(self); match simple { Self::SecureCol([a, b, c, d]) => { - let a = a.simplify(); - let b = b.simplify(); - let c = c.simplify(); - let d = d.simplify(); + let a = a.unchecked_simplify(); + let b = b.unchecked_simplify(); + let c = c.unchecked_simplify(); + let d = d.unchecked_simplify(); match (a.clone(), b.clone(), c.clone(), d.clone()) { ( BaseExpr::Const(a_val), @@ -278,6 +315,14 @@ impl ExtExpr { } } + /// Simplifies an expression by applying basic arithmetic rules and ensures that the result is + /// equivalent to the original expression by assigning random values. + pub fn simplify(&self) -> Self { + let simplified = self.unchecked_simplify(); + assert_eq!(self.random_eval(), simplified.random_eval()); + simplified + } + pub fn simplify_and_format(&self) -> String { self.simplify().format_expr() } @@ -319,6 +364,104 @@ impl ExtExpr { Self::Neg(a) => -a.eval_expr::(columns, vars, ext_vars), } } + + pub fn collect_variables(&self) -> ExprVariables { + match self { + ExtExpr::SecureCol([a, b, c, d]) => { + a.collect_variables() + + b.collect_variables() + + c.collect_variables() + + d.collect_variables() + } + ExtExpr::Const(_) => ExprVariables::default(), + ExtExpr::Param(param) => ExprVariables::ext_param(param.to_string()), + ExtExpr::Add(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Neg(a) => a.collect_variables(), + } + } + + pub fn random_eval(&self) -> SecureField { + let assignment = self.collect_variables().random_assignment(); + self.eval_expr::, _, _, _>(&assignment.0, &assignment.1, &assignment.2) + } +} + +/// An assignment to the variables that may appear in an expression. +pub type ExprVarAssignment = ( + HashMap<(usize, usize, isize), BaseField>, + HashMap, + HashMap, +); + +/// Three sets representing all the variables that can appear in an expression: +/// * `cols`: The columns of the AIR. +/// * `params`: The formal parameters to the AIR. +/// * `ext_params`: The extension field parameters to the AIR. +#[derive(Default)] +pub struct ExprVariables { + pub cols: HashSet, + pub params: HashSet, + pub ext_params: HashSet, +} + +impl ExprVariables { + pub fn col(col: ColumnExpr) -> Self { + Self { + cols: vec![col].into_iter().collect(), + params: HashSet::new(), + ext_params: HashSet::new(), + } + } + + pub fn param(param: String) -> Self { + Self { + cols: HashSet::new(), + params: vec![param].into_iter().collect(), + ext_params: HashSet::new(), + } + } + + pub fn ext_param(param: String) -> Self { + Self { + cols: HashSet::new(), + params: HashSet::new(), + ext_params: vec![param].into_iter().collect(), + } + } + + /// Generates a random assignment to the variables. + /// Note that the assignment is deterministic in the sets of variables (disregarding their + /// order), and this is required. 
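+    /// The returned assignment has one map per variable kind: columns, base-field parameters and
+    /// extension-field parameters.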
+ pub fn random_assignment(&self) -> ExprVarAssignment { + let mut rng = SmallRng::seed_from_u64(0); + + let cols = sorted(self.cols.iter()) + .map(|col| ((col.interaction, col.idx, col.offset), rng.gen())) + .collect(); + + let params = sorted(self.params.iter()) + .map(|param| (param.clone(), rng.gen())) + .collect(); + + let ext_params = sorted(self.ext_params.iter()) + .map(|param| (param.clone(), rng.gen())) + .collect(); + + (cols, params, ext_params) + } +} + +impl Add for ExprVariables { + type Output = Self; + fn add(self, rhs: Self) -> Self { + Self { + cols: self.cols.union(&rhs.cols).cloned().collect(), + params: self.params.union(&rhs.params).cloned().collect(), + ext_params: self.ext_params.union(&rhs.ext_params).cloned().collect(), + } + } } impl From for BaseExpr { From 6e7d2aa687071cd049c3ddac0cea778dfe755497 Mon Sep 17 00:00:00 2001 From: ilyalesokhin-starkware Date: Tue, 3 Dec 2024 13:36:00 +0200 Subject: [PATCH 16/69] Rearrange queried_values_by_layer for merkle. (#902) --- crates/prover/src/core/fri.rs | 28 +++--- crates/prover/src/core/pcs/prover.rs | 4 +- crates/prover/src/core/pcs/quotients.rs | 46 +++++----- crates/prover/src/core/pcs/verifier.rs | 18 ++-- crates/prover/src/core/vcs/blake2_merkle.rs | 10 +-- .../prover/src/core/vcs/poseidon252_merkle.rs | 14 +-- crates/prover/src/core/vcs/prover.rs | 51 ++--------- crates/prover/src/core/vcs/test_utils.rs | 7 +- crates/prover/src/core/vcs/verifier.rs | 90 +++++++++---------- 9 files changed, 113 insertions(+), 155 deletions(-) diff --git a/crates/prover/src/core/fri.rs b/crates/prover/src/core/fri.rs index 46eabeedb..607dfaf41 100644 --- a/crates/prover/src/core/fri.rs +++ b/crates/prover/src/core/fri.rs @@ -699,9 +699,9 @@ impl FriFirstLayerVerifier { let mut fri_witness = self.proof.fri_witness.iter().copied(); let mut decommitment_positions_by_log_size = BTreeMap::new(); - let mut all_column_decommitment_values = Vec::new(); let mut folded_evals_by_column = Vec::new(); + let mut decommitmented_values = vec![]; for (&column_domain, column_query_evals) in zip_eq(&self.column_commitment_domains, query_evals_by_column) { @@ -722,15 +722,13 @@ impl FriFirstLayerVerifier { decommitment_positions_by_log_size .insert(column_domain.log_size(), column_decommitment_positions); - // Prepare values in the structure needed for merkle decommitment. 
- let column_decommitment_values: SecureColumnByCoords = sparse_evaluation - .subset_evals - .iter() - .flatten() - .copied() - .collect(); - - all_column_decommitment_values.extend(column_decommitment_values.columns); + decommitmented_values.extend( + sparse_evaluation + .subset_evals + .iter() + .flatten() + .flat_map(|qm31| qm31.to_m31_array()), + ); let folded_evals = sparse_evaluation.fold_circle(self.folding_alpha, column_domain); folded_evals_by_column.push(folded_evals); @@ -752,7 +750,7 @@ impl FriFirstLayerVerifier { merkle_verifier .verify( &decommitment_positions_by_log_size, - all_column_decommitment_values, + decommitmented_values, self.proof.decommitment.clone(), ) .map_err(|error| FriVerificationError::FirstLayerCommitmentInvalid { error })?; @@ -814,12 +812,12 @@ impl FriInnerLayerVerifier { }); } - let decommitment_values: SecureColumnByCoords = sparse_evaluation + let decommitmented_values = sparse_evaluation .subset_evals .iter() .flatten() - .copied() - .collect(); + .flat_map(|qm31| qm31.to_m31_array()) + .collect_vec(); let merkle_verifier = MerkleVerifier::new( self.proof.commitment, @@ -829,7 +827,7 @@ impl FriInnerLayerVerifier { merkle_verifier .verify( &BTreeMap::from_iter([(self.domain.log_size(), decommitment_positions)]), - decommitment_values.columns.to_vec(), + decommitmented_values, self.proof.decommitment.clone(), ) .map_err(|e| FriVerificationError::InnerLayerCommitmentInvalid { diff --git a/crates/prover/src/core/pcs/prover.rs b/crates/prover/src/core/pcs/prover.rs index ef27f706d..59a2e8e8d 100644 --- a/crates/prover/src/core/pcs/prover.rs +++ b/crates/prover/src/core/pcs/prover.rs @@ -152,7 +152,7 @@ pub struct CommitmentSchemeProof { pub commitments: TreeVec, pub sampled_values: TreeVec>>, pub decommitments: TreeVec>, - pub queried_values: TreeVec>>, + pub queried_values: TreeVec>, pub proof_of_work: u64, pub fri_proof: FriProof, } @@ -231,7 +231,7 @@ impl, MC: MerkleChannel> CommitmentTreeProver { fn decommit( &self, queries: &BTreeMap>, - ) -> (ColumnVec>, MerkleDecommitment) { + ) -> (Vec, MerkleDecommitment) { let eval_vec = self .evaluations .iter() diff --git a/crates/prover/src/core/pcs/quotients.rs b/crates/prover/src/core/pcs/quotients.rs index aca0901c6..d41f17670 100644 --- a/crates/prover/src/core/pcs/quotients.rs +++ b/crates/prover/src/core/pcs/quotients.rs @@ -5,6 +5,7 @@ use std::iter::zip; use itertools::{izip, multiunzip, Itertools}; use tracing::{span, Level}; +use super::TreeVec; use crate::core::backend::cpu::quotients::{accumulate_row_quotients, quotient_constants}; use crate::core::circle::CirclePoint; use crate::core::fields::m31::BaseField; @@ -100,25 +101,30 @@ pub fn compute_fri_quotients( } pub fn fri_answers( - column_log_sizes: Vec, - samples: &[Vec], + column_log_sizes: TreeVec>, + samples: TreeVec>>, random_coeff: SecureField, query_positions_per_log_size: &BTreeMap>, - queried_values_per_column: &[Vec], + queried_values: TreeVec>, + n_columns_per_log_size: TreeVec<&BTreeMap>, ) -> Result>, VerificationError> { - izip!(column_log_sizes, samples, queried_values_per_column) + let mut queried_values = queried_values.map(|values| values.into_iter()); + + izip!(column_log_sizes.flatten(), samples.flatten().iter()) .sorted_by_key(|(log_size, ..)| Reverse(*log_size)) .group_by(|(log_size, ..)| *log_size) .into_iter() .map(|(log_size, tuples)| { - let (_, samples, queried_values_per_column): (Vec<_>, Vec<_>, Vec<_>) = - multiunzip(tuples); + let (_, samples): (Vec<_>, Vec<_>) = multiunzip(tuples); fri_answers_for_log_size( 
log_size, &samples, random_coeff, &query_positions_per_log_size[&log_size], - &queried_values_per_column, + &mut queried_values, + n_columns_per_log_size + .as_ref() + .map(|colums_log_sizes| *colums_log_sizes.get(&log_size).unwrap_or(&0)), ) }) .collect() @@ -129,27 +135,23 @@ pub fn fri_answers_for_log_size( samples: &[&Vec], random_coeff: SecureField, query_positions: &[usize], - queried_values_per_column: &[&Vec], + queried_values: &mut TreeVec>, + n_columns: TreeVec, ) -> Result, VerificationError> { - for queried_values in queried_values_per_column { - if queried_values.len() != query_positions.len() { - return Err(VerificationError::InvalidStructure( - "Insufficient number of queried values".to_string(), - )); - } - } - let sample_batches = ColumnSampleBatch::new_vec(samples); let quotient_constants = quotient_constants(&sample_batches, random_coeff); let commitment_domain = CanonicCoset::new(log_size).circle_domain(); - let mut quotient_evals_at_queries = Vec::new(); - for (row, &query_position) in query_positions.iter().enumerate() { + let mut quotient_evals_at_queries = Vec::new(); + for &query_position in query_positions { let domain_point = commitment_domain.at(bit_reverse_index(query_position, log_size)); - let queried_values_at_row = queried_values_per_column - .iter() - .map(|col| col[row]) - .collect_vec(); + + let queried_values_at_row = queried_values + .as_mut() + .zip_eq(n_columns.as_ref()) + .map(|(queried_values, n_columns)| queried_values.take(*n_columns).collect()) + .flatten(); + quotient_evals_at_queries.push(accumulate_row_quotients( &sample_batches, &queried_values_at_row, diff --git a/crates/prover/src/core/pcs/verifier.rs b/crates/prover/src/core/pcs/verifier.rs index 200fe98d5..db7bea3d3 100644 --- a/crates/prover/src/core/pcs/verifier.rs +++ b/crates/prover/src/core/pcs/verifier.rs @@ -99,21 +99,23 @@ impl CommitmentSchemeVerifier { .collect::>()?; // Answer FRI queries. 
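        // The queried values arrive ordered per tree, as produced by `MerkleProver::decommit`;
        // `fri_answers` consumes them per log size using each tree's column counts.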
- let samples = sampled_points - .zip_cols(proof.sampled_values) - .map_cols(|(sampled_points, sampled_values)| { + let samples = sampled_points.zip_cols(proof.sampled_values).map_cols( + |(sampled_points, sampled_values)| { zip(sampled_points, sampled_values) .map(|(point, value)| PointSample { point, value }) .collect_vec() - }) - .flatten(); + }, + ); + + let n_columns_per_log_size = self.trees.as_ref().map(|tree| &tree.n_columns_per_log_size); let fri_answers = fri_answers( - self.column_log_sizes().flatten().into_iter().collect(), - &samples, + self.column_log_sizes(), + samples, random_coeff, &query_positions_per_log_size, - &proof.queried_values.flatten(), + proof.queried_values, + n_columns_per_log_size, )?; fri_verifier.decommit(fri_answers)?; diff --git a/crates/prover/src/core/vcs/blake2_merkle.rs b/crates/prover/src/core/vcs/blake2_merkle.rs index 8401716fa..3664ea147 100644 --- a/crates/prover/src/core/vcs/blake2_merkle.rs +++ b/crates/prover/src/core/vcs/blake2_merkle.rs @@ -86,7 +86,7 @@ mod tests { #[test] fn test_merkle_invalid_value() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3][2] = BaseField::zero(); + values[6] = BaseField::zero(); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), @@ -119,22 +119,22 @@ mod tests { #[test] fn test_merkle_column_values_too_long() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].push(BaseField::zero()); + values.insert(3, BaseField::zero()); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooLong + MerkleVerificationError::TooManyQueriedValues ); } #[test] fn test_merkle_column_values_too_short() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].pop(); + values.remove(3); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooShort + MerkleVerificationError::TooFewQueriedValues ); } diff --git a/crates/prover/src/core/vcs/poseidon252_merkle.rs b/crates/prover/src/core/vcs/poseidon252_merkle.rs index 5ffba1ea6..f39a2c62d 100644 --- a/crates/prover/src/core/vcs/poseidon252_merkle.rs +++ b/crates/prover/src/core/vcs/poseidon252_merkle.rs @@ -114,7 +114,7 @@ mod tests { fn test_merkle_invalid_value() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3][2] = BaseField::zero(); + values[6] = BaseField::zero(); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), @@ -147,26 +147,26 @@ mod tests { } #[test] - fn test_merkle_column_values_too_long() { + fn test_merkle_values_too_long() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].push(BaseField::zero()); + values.insert(3, BaseField::zero()); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooLong + MerkleVerificationError::TooManyQueriedValues ); } #[test] - fn test_merkle_column_values_too_short() { + fn test_merkle_values_too_short() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].pop(); + values.remove(3); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooShort + MerkleVerificationError::TooFewQueriedValues ); } } diff --git a/crates/prover/src/core/vcs/prover.rs b/crates/prover/src/core/vcs/prover.rs index bc788e51f..da4695d3f 100644 --- 
a/crates/prover/src/core/vcs/prover.rs +++ b/crates/prover/src/core/vcs/prover.rs @@ -9,7 +9,6 @@ use super::utils::{next_decommitment_node, option_flatten_peekable}; use crate::core::backend::{Col, Column}; use crate::core::fields::m31::BaseField; use crate::core::utils::PeekableExt; -use crate::core::ColumnVec; pub struct MerkleProver, H: MerkleHasher> { /// Layers of the Merkle tree. @@ -48,6 +47,7 @@ impl, H: MerkleHasher> MerkleProver { .into_iter() .sorted_by_key(|c| Reverse(c.len())) .peekable(); + let mut layers: Vec> = Vec::new(); let max_log_size = columns.peek().unwrap().len().ilog2(); @@ -75,15 +75,16 @@ impl, H: MerkleHasher> MerkleProver { /// # Returns /// /// A tuple containing: - /// * A vector of vectors of queried values for each column, in the order of the input columns. + /// * A vector of queried values, sorted in the order they were queried, from the largest layer to + /// the smallest. /// * A `MerkleDecommitment` containing the hash and column witnesses. pub fn decommit( &self, queries_per_log_size: &BTreeMap>, columns: Vec<&Col>, - ) -> (ColumnVec>, MerkleDecommitment) { + ) -> (Vec, MerkleDecommitment) { // Prepare output buffers. - let mut queried_values_by_layer = vec![]; + let mut queried_values = vec![]; let mut decommitment = MerkleDecommitment::empty(); // Sort columns by layer. @@ -94,9 +95,6 @@ impl, H: MerkleHasher> MerkleProver { let mut last_layer_queries = vec![]; for layer_log_size in (0..self.layers.len() as u32).rev() { - // Prepare write buffer for queried values to the current layer. - let mut layer_queried_values = vec![]; - // Prepare write buffer for queries to the current layer. This will propagate to the // next layer. let mut layer_total_queries = vec![]; @@ -140,7 +138,7 @@ impl, H: MerkleHasher> MerkleProver { // If the column values were queried, return them. let node_values = layer_columns.iter().map(|c| c.at(node_index)); if layer_column_queries.next_if_eq(&node_index).is_some() { - layer_queried_values.push(node_values.collect_vec()); + queried_values.extend(node_values); } else { // Otherwise, add them to the witness. decommitment.column_witness.extend(node_values); @@ -149,50 +147,13 @@ impl, H: MerkleHasher> MerkleProver { layer_total_queries.push(node_index); } - queried_values_by_layer.push(layer_queried_values); - // Propagate queries to the next layer. last_layer_queries = layer_total_queries; } - queried_values_by_layer.reverse(); - - // Rearrange returned queried values according to input, and not by layer. - let queried_values = Self::rearrange_queried_values(queried_values_by_layer, columns); (queried_values, decommitment) } - /// Given queried values by layer, rearranges in the order of input columns. - fn rearrange_queried_values( - queried_values_by_layer: Vec>>, - columns: Vec<&Col>, - ) -> Vec> { - // Turn each column queried values into an iterator. - let mut queried_values_by_layer = queried_values_by_layer - .into_iter() - .map(|layer_results| { - layer_results - .into_iter() - .map(|x| x.into_iter()) - .collect_vec() - }) - .collect_vec(); - - // For each input column, fetch the queried values from the corresponding layer.
- let queried_values = columns - .iter() - .map(|column| { - queried_values_by_layer - .get_mut(column.len().ilog2() as usize) - .unwrap() - .iter_mut() - .map(|x| x.next().unwrap()) - .collect_vec() - }) - .collect_vec(); - queried_values - } - pub fn root(&self) -> H::Hash { self.layers.first().unwrap().at(0) } diff --git a/crates/prover/src/core/vcs/test_utils.rs b/crates/prover/src/core/vcs/test_utils.rs index b92f9e971..c906f05d0 100644 --- a/crates/prover/src/core/vcs/test_utils.rs +++ b/crates/prover/src/core/vcs/test_utils.rs @@ -14,7 +14,7 @@ use crate::core::vcs::prover::MerkleProver; pub type TestData = ( BTreeMap>, MerkleDecommitment, - Vec>, + Vec, MerkleVerifier, ); @@ -52,9 +52,6 @@ where let (values, decommitment) = merkle.decommit(&queries, cols.iter().collect_vec()); - let verifier = MerkleVerifier { - root: merkle.root(), - column_log_sizes: log_sizes, - }; + let verifier = MerkleVerifier::new(merkle.root(), log_sizes); (queries, decommitment, values, verifier) } diff --git a/crates/prover/src/core/vcs/verifier.rs b/crates/prover/src/core/vcs/verifier.rs index 9c1b0b39a..fcd0453a3 100644 --- a/crates/prover/src/core/vcs/verifier.rs +++ b/crates/prover/src/core/vcs/verifier.rs @@ -1,4 +1,3 @@ -use std::cmp::Reverse; use std::collections::BTreeMap; use itertools::Itertools; @@ -9,27 +8,35 @@ use super::prover::MerkleDecommitment; use super::utils::{next_decommitment_node, option_flatten_peekable}; use crate::core::fields::m31::BaseField; use crate::core::utils::PeekableExt; -use crate::core::ColumnVec; pub struct MerkleVerifier { pub root: H::Hash, pub column_log_sizes: Vec, + pub n_columns_per_log_size: BTreeMap, } impl MerkleVerifier { pub fn new(root: H::Hash, column_log_sizes: Vec) -> Self { + let mut n_columns_per_log_size = BTreeMap::new(); + for log_size in &column_log_sizes { + *n_columns_per_log_size.entry(*log_size).or_insert(0) += 1; + } + Self { root, column_log_sizes, + n_columns_per_log_size, } } /// Verifies the decommitment of the columns. /// + /// Returns `Ok(())` if the decommitment is successfully verified. + /// /// # Arguments /// /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that /// log_size. - /// * `queried_values` - A vector of vectors of queried values. For each column, there is a - /// vector of queried values to that column. + /// * `queried_values` - A vector of queried values according to the order in + /// [`MerkleProver::decommit()`]. /// * `decommitment` - The decommitment object containing the witness and column values. /// /// # Errors @@ -38,45 +45,35 @@ impl MerkleVerifier { /// /// * The witness is too long (not fully consumed). /// * The witness is too short (missing values). - /// * The column values are too long (not fully consumed). - /// * The column values are too short (missing values). + /// * Too many queried values (not fully consumed). + /// * Too few queried values (missing values). /// * The computed root does not match the expected root. /// - /// # Returns - /// - /// Returns `Ok(())` if the decommitment is successfully verified. + /// [`MerkleProver::decommit()`]: crate::core::...::MerkleProver::decommit + pub fn verify( &self, queries_per_log_size: &BTreeMap>, - queried_values: ColumnVec>, + queried_values: Vec, decommitment: MerkleDecommitment, ) -> Result<(), MerkleVerificationError> { let Some(max_log_size) = self.column_log_sizes.iter().max() else { return Ok(()); }; + let mut queried_values = queried_values.into_iter(); + // Prepare read buffers. 
- let mut queried_values_by_layer = self - .column_log_sizes - .iter() - .copied() - .zip( - queried_values - .into_iter() - .map(|column_values| column_values.into_iter()), - ) - .sorted_by_key(|(log_size, _)| Reverse(*log_size)) - .peekable(); + let mut hash_witness = decommitment.hash_witness.into_iter(); let mut column_witness = decommitment.column_witness.into_iter(); let mut last_layer_hashes: Option> = None; for layer_log_size in (0..=*max_log_size).rev() { - // Prepare read buffer for queried values to the current layer. - let mut layer_queried_values = queried_values_by_layer - .peek_take_while(|(log_size, _)| *log_size == layer_log_size) - .collect_vec(); - let n_columns_in_layer = layer_queried_values.len(); + let n_columns_in_layer = *self + .n_columns_per_log_size + .get(&layer_log_size) + .unwrap_or(&0); // Prepare write buffer for queries to the current layer. This will propagate to the // next layer. @@ -132,29 +129,26 @@ impl MerkleVerifier { .transpose()?; // If the column values were queried, read them from `queried_value`. - let node_values = if layer_column_queries.next_if_eq(&node_index).is_some() { - layer_queried_values - .iter_mut() - .map(|(_, ref mut column_queries)| { - column_queries - .next() - .ok_or(MerkleVerificationError::ColumnValuesTooShort) - }) - .collect::, _>>()? - } else { + let (err, node_values_iter) = match layer_column_queries.next_if_eq(&node_index) { + Some(_) => ( + MerkleVerificationError::TooFewQueriedValues, + &mut queried_values, + ), // Otherwise, read them from the witness. - (&mut column_witness).take(n_columns_in_layer).collect_vec() + None => ( + MerkleVerificationError::WitnessTooShort, + &mut column_witness, + ), }; + + let node_values = node_values_iter.take(n_columns_in_layer).collect_vec(); if node_values.len() != n_columns_in_layer { - return Err(MerkleVerificationError::WitnessTooShort); + return Err(err); } layer_total_queries.push((node_index, H::hash_node(node_hashes, &node_values))); } - if !layer_queried_values.iter().all(|(_, c)| c.is_empty()) { - return Err(MerkleVerificationError::ColumnValuesTooLong); - } last_layer_hashes = Some(layer_total_queries); } @@ -162,6 +156,9 @@ impl MerkleVerifier { if !hash_witness.is_empty() { return Err(MerkleVerificationError::WitnessTooLong); } + if !queried_values.is_empty() { + return Err(MerkleVerificationError::TooManyQueriedValues); + } if !column_witness.is_empty() { return Err(MerkleVerificationError::WitnessTooLong); } @@ -175,16 +172,17 @@ impl MerkleVerifier { } } +// TODO(ilya): Make error messages consistent. #[derive(Clone, Copy, Debug, Error, PartialEq, Eq)] pub enum MerkleVerificationError { - #[error("Witness is too short.")] + #[error("Witness is too short")] WitnessTooShort, #[error("Witness is too long.")] WitnessTooLong, - #[error("Column values are too long.")] - ColumnValuesTooLong, - #[error("Column values are too short.")] - ColumnValuesTooShort, + #[error("too many Queried values")] + TooManyQueriedValues, + #[error("too few queried values")] + TooFewQueriedValues, #[error("Root mismatch.")] RootMismatch, } From f519bb6844f4d88b4a77065b2c109daf2e6508b8 Mon Sep 17 00:00:00 2001 From: Alon-Ti <54235977+Alon-Ti@users.noreply.github.com> Date: Tue, 3 Dec 2024 15:19:04 +0200 Subject: [PATCH 17/69] Added `add_intermediate` and `add_secure_intermediate` to eval API. 
(#919) --- .../prover/src/constraint_framework/expr.rs | 69 +++++++++++++------ crates/prover/src/constraint_framework/mod.rs | 12 ++++ .../prover/src/examples/state_machine/mod.rs | 1 + 3 files changed, 61 insertions(+), 21 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index e01a03818..3098dcc56 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -791,34 +791,35 @@ pub struct ExprEvaluator { pub cur_var_index: usize, pub constraints: Vec, pub logup: FormalLogupAtRow, - pub intermediates: Vec<(String, ExtExpr)>, + pub intermediates: Vec<(String, BaseExpr)>, + pub ext_intermediates: Vec<(String, ExtExpr)>, } impl ExprEvaluator { - #[allow(dead_code)] pub fn new(log_size: u32, has_partial_sum: bool) -> Self { Self { cur_var_index: Default::default(), constraints: Default::default(), logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, has_partial_sum, log_size), intermediates: vec![], + ext_intermediates: vec![], } } - pub fn add_intermediate(&mut self, expr: ExtExpr) -> ExtExpr { - let name = format!("intermediate{}", self.intermediates.len()); - let intermediate = ExtExpr::Param(name.clone()); - self.intermediates.push((name, expr)); - intermediate - } - pub fn format_constraints(&self) -> String { let lets_string = self .intermediates .iter() .map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format())) .collect::>() - .join("\n"); + .join("\n\n"); + + let secure_lets_string = self + .ext_intermediates + .iter() + .map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format())) + .collect::>() + .join("\n\n"); let constraints_str = self .constraints @@ -828,7 +829,12 @@ impl ExprEvaluator { .collect::>() .join("\n\n"); - lets_string + "\n\n" + &constraints_str + [lets_string, secure_lets_string, constraints_str] + .iter() + .filter(|x| !x.is_empty()) + .cloned() + .collect::>() + .join("\n\n") } } @@ -877,7 +883,8 @@ impl EvalAtRow for ExprEvaluator { multiplicity, values, }| { - let intermediate = self.add_intermediate(combine_formal(*relation, values)); + let intermediate = + self.add_extension_intermediate(combine_formal(*relation, values)); Fraction::new(multiplicity.clone(), intermediate) }, ) @@ -885,6 +892,26 @@ impl EvalAtRow for ExprEvaluator { self.write_logup_frac(fracs.into_iter().sum()); } + fn add_intermediate(&mut self, expr: Self::F) -> Self::F { + let name = format!( + "intermediate{}", + self.intermediates.len() + self.ext_intermediates.len() + ); + let intermediate = BaseExpr::Param(name.clone()); + self.intermediates.push((name, expr)); + intermediate + } + + fn add_extension_intermediate(&mut self, expr: Self::EF) -> Self::EF { + let name = format!( + "intermediate{}", + self.intermediates.len() + self.ext_intermediates.len() + ); + let intermediate = ExtExpr::Param(name.clone()); + self.ext_intermediates.push((name, expr)); + intermediate + } + super::logup_proxy!(); } @@ -1050,21 +1077,22 @@ mod tests { fn test_format_expr() { let test_struct = TestStruct {}; let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - let expected = "let intermediate0 = (TestRelation_alpha0) * (col_1_0[0]) \ + let expected = "let intermediate0 = (col_1_1[0]) * (col_1_2[0]); + +\ + let intermediate1 = (TestRelation_alpha0) * (col_1_0[0]) \ + (TestRelation_alpha1) * (col_1_1[0]) \ + (TestRelation_alpha2) * (col_1_2[0]) \ - (TestRelation_z); \ - let constraint_0 = \ - (((col_1_0[0]) * (col_1_1[0])) * 
(col_1_2[0])) * (1 / (col_1_0[0] + col_1_1[0])); + let constraint_0 = ((col_1_0[0]) * (intermediate0)) * (1 / (col_1_0[0] + col_1_1[0])); \ let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \ - (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \ - - ((total_sum) * (col_0_3[0])))\ - ) \ - * (intermediate0) \ + - ((total_sum) * (col_0_3[0])))) \ + * (intermediate1) \ - (1);" .to_string(); @@ -1085,9 +1113,8 @@ mod tests { let x0 = eval.next_trace_mask(); let x1 = eval.next_trace_mask(); let x2 = eval.next_trace_mask(); - eval.add_constraint( - x0.clone() * x1.clone() * x2.clone() * (x0.clone() + x1.clone()).inverse(), - ); + let intermediate = eval.add_intermediate(x1.clone() * x2.clone()); + eval.add_constraint(x0.clone() * intermediate * (x0.clone() + x1.clone()).inverse()); eval.add_to_relation(&[RelationEntry::new( &TestRelation::dummy(), E::EF::one(), diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 341e24511..b03c08ce3 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -111,6 +111,18 @@ pub trait EvalAtRow { where Self::EF: Mul + From; + /// Adds an intermediate value in the base field to the component and returns its value. + /// Does nothing by default. + fn add_intermediate(&mut self, val: Self::F) -> Self::F { + val + } + + /// Adds an intermediate value in the extension field to the component and returns its value. + /// Does nothing by default. + fn add_extension_intermediate(&mut self, val: Self::EF) -> Self::EF { + val + } + /// Combines 4 base field values into a single extension field value. fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF; diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 23973a960..bdb265fff 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -357,6 +357,7 @@ mod tests { let expected = "let intermediate0 = (StateMachineElements_alpha0) * (col_1_0[0]) \ + (StateMachineElements_alpha1) * (col_1_1[0]) \ - (StateMachineElements_z); + \ let intermediate1 = (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ + (StateMachineElements_alpha1) * (col_1_1[0]) \ From 76af3c6c7afc5b6e582ec5d16a03e74efe5a11ee Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Wed, 4 Dec 2024 11:28:06 +0200 Subject: [PATCH 18/69] relation tracker bug fix (#921) --- .../constraint_framework/relation_tracker.rs | 38 ++++++++++++++++--- .../prover/src/examples/state_machine/mod.rs | 8 +++- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs index 3866df39a..b804d488e 100644 --- a/crates/prover/src/constraint_framework/relation_tracker.rs +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -194,14 +194,24 @@ impl RelationSummary { /// Returns the sum of every entry's yields and uses. /// The result is a map from relation name to a list of values(M31 vectors) and their sum. 
pub fn summarize_relations(entries: &[RelationTrackerEntry]) -> Self { + let mut entry_by_relation = HashMap::new(); + for entry in entries { + entry_by_relation + .entry(entry.relation.clone()) + .or_insert_with(Vec::new) + .push(entry); + } let mut summary = vec![]; - let relations = entries.iter().group_by(|entry| entry.relation.clone()); - for (relation, entries) in &relations { + for (relation, entries) in entry_by_relation { let mut relation_sums: HashMap, M31> = HashMap::new(); for entry in entries { - let mult = relation_sums - .entry(entry.values.clone()) - .or_insert(M31::zero()); + let mut values = entry.values.clone(); + + // Trailing zeroes do not affect the sum, remove for correct aggregation. + while values.last().is_some_and(|v| v.is_zero()) { + values.pop(); + } + let mult = relation_sums.entry(values).or_insert(M31::zero()); *mult += entry.mult; } let relation_sums = relation_sums.into_iter().collect_vec(); @@ -216,6 +226,24 @@ impl RelationSummary { .find(|(name, _)| name == relation) .map(|(_, entries)| entries.as_slice()) } + + /// Cleans up the summary by removing zero-sum entries, only keeping the non-zero ones. + /// Used for debugging. + pub fn cleaned(self) -> Self { + let mut cleaned = vec![]; + for (relation, entries) in self.0 { + let mut cleaned_entries = vec![]; + for (vector, sum) in entries { + if !sum.is_zero() { + cleaned_entries.push((vector, sum)); + } + } + if !cleaned_entries.is_empty() { + cleaned.push((relation, cleaned_entries)); + } + } + Self(cleaned) + } } impl Debug for RelationSummary { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index bdb265fff..84b64617c 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -305,8 +305,12 @@ mod tests { // Check the final state inferred from the summary. let mut curr_state = initial_state; for entry in relation_info { - let x_step = entry.0[0]; - let y_step = entry.0[1]; + let (x_step, y_step) = match entry.0.len() { + 2 => (entry.0[0], entry.0[1]), + 1 => (entry.0[0], M31::zero()), + 0 => (M31::zero(), M31::zero()), + _ => unreachable!(), + }; let mult = entry.1; let next_state = [curr_state[0] - x_step * mult, curr_state[1] - y_step * mult]; From 09860e0eb9482d5305081c8b2946c975408afd05 Mon Sep 17 00:00:00 2001 From: Shahar Samocha Date: Thu, 28 Nov 2024 16:56:48 +0200 Subject: [PATCH 19/69] Fix machete version --- .github/workflows/ci.yaml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b9b72d33d..29f73ae76 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -196,8 +196,14 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly-2024-01-04 + - uses: Swatinem/rust-cache@v2 + - name: Install Machete + run: cargo +nightly-2024-01-04 install --locked cargo-machete - name: Run Machete (detect unused dependencies) - uses: bnjbvr/cargo-machete@main + run: cargo +nightly-2024-01-04 machete all-tests: runs-on: ubuntu-latest From a201e77bf10be532f31f26dde6c22b4c2bf48ecb Mon Sep 17 00:00:00 2001 From: Alon-Ti <54235977+Alon-Ti@users.noreply.github.com> Date: Sun, 8 Dec 2024 16:08:28 +0200 Subject: [PATCH 20/69] Formal preprocessed columns in expressions. 
(#923) --- crates/prover/src/constraint_framework/expr.rs | 11 ++++++++--- .../constraint_framework/preprocessed_columns.rs | 10 ++++++++++ crates/prover/src/examples/state_machine/mod.rs | 16 ++++++++-------- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 3098dcc56..f5d61bd5c 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -6,6 +6,7 @@ use num_traits::{One, Zero}; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; +use super::preprocessed_columns::PreprocessedColumn; use super::{AssertEvaluator, EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; use crate::core::fields::cm31::CM31; use crate::core::fields::m31::{self, BaseField}; @@ -912,6 +913,10 @@ impl EvalAtRow for ExprEvaluator { intermediate } + fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F { + BaseExpr::Param(column.name().to_string()) + } + super::logup_proxy!(); } @@ -1089,9 +1094,9 @@ mod tests { let constraint_0 = ((col_1_0[0]) * (intermediate0)) * (1 / (col_1_0[0] + col_1_1[0])); \ - let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \ - - (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \ - - ((total_sum) * (col_0_3[0])))) \ + let constraint_1 = (SecureCol(col_2_3[0], col_2_5[0], col_2_7[0], col_2_9[0]) \ + - (SecureCol(col_2_4[-1], col_2_6[-1], col_2_8[-1], col_2_10[-1]) \ + - ((total_sum) * (preprocessed.is_first)))) \ * (intermediate1) \ - (1);" .to_string(); diff --git a/crates/prover/src/constraint_framework/preprocessed_columns.rs b/crates/prover/src/constraint_framework/preprocessed_columns.rs index bd7b4c9c9..f196567dd 100644 --- a/crates/prover/src/constraint_framework/preprocessed_columns.rs +++ b/crates/prover/src/constraint_framework/preprocessed_columns.rs @@ -14,6 +14,16 @@ pub enum PreprocessedColumn { Plonk(usize), } +impl PreprocessedColumn { + pub const fn name(&self) -> &'static str { + match self { + PreprocessedColumn::XorTable(..) => "preprocessed.xor_table", + PreprocessedColumn::IsFirst(_) => "preprocessed.is_first", + PreprocessedColumn::Plonk(_) => "preprocessed.plonk", + } + } +} + /// Generates a column with a single one at the first position, and zeros elsewhere. 
pub fn gen_is_first(log_size: u32) -> CircleEvaluation { let mut col = Col::::zeros(1 << log_size); diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 84b64617c..043cce416 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -369,17 +369,17 @@ mod tests { \ let constraint_0 = (SecureCol(\ - col_2_5[claimed_sum_offset], \ - col_2_8[claimed_sum_offset], \ - col_2_11[claimed_sum_offset], \ - col_2_14[claimed_sum_offset]\ + col_2_4[claimed_sum_offset], \ + col_2_7[claimed_sum_offset], \ + col_2_10[claimed_sum_offset], \ + col_2_13[claimed_sum_offset]\ ) - (claimed_sum)) \ - * (col_0_2[0]); + * (preprocessed.is_first); \ - let constraint_1 = (SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \ - - (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \ - - ((total_sum) * (col_0_2[0])))\ + let constraint_1 = (SecureCol(col_2_2[0], col_2_5[0], col_2_8[0], col_2_11[0]) \ + - (SecureCol(col_2_3[-1], col_2_6[-1], col_2_9[-1], col_2_12[-1]) \ + - ((total_sum) * (preprocessed.is_first)))\ ) \ * ((intermediate0) * (intermediate1)) \ - (intermediate1 - (intermediate0));" From 27980202c2b0e4cda99c30a09a03ebb8b833b1da Mon Sep 17 00:00:00 2001 From: Alon-Ti <54235977+Alon-Ti@users.noreply.github.com> Date: Sun, 8 Dec 2024 16:13:27 +0200 Subject: [PATCH 21/69] Fix expression bugs. (#925) --- .../prover/src/constraint_framework/expr.rs | 50 ++++++++++++------- .../prover/src/examples/state_machine/mod.rs | 10 ++-- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index f5d61bd5c..34f400f37 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -1,10 +1,9 @@ use std::collections::{HashMap, HashSet}; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub}; use itertools::sorted; use num_traits::{One, Zero}; -use rand::rngs::SmallRng; -use rand::{Rng, SeedableRng}; use super::preprocessed_columns::PreprocessedColumn; use super::{AssertEvaluator, EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; @@ -251,7 +250,7 @@ impl BaseExpr { } pub fn random_eval(&self) -> BaseField { - let assignment = self.collect_variables().random_assignment(); + let assignment = self.collect_variables().random_assignment(0); assert!(assignment.2.is_empty()); self.eval_expr::, _, _>(&assignment.0, &assignment.1) } @@ -384,7 +383,7 @@ impl ExtExpr { } pub fn random_eval(&self) -> SecureField { - let assignment = self.collect_variables().random_assignment(); + let assignment = self.collect_variables().random_assignment(0); self.eval_expr::, _, _, _>(&assignment.0, &assignment.1, &assignment.2) } } @@ -433,21 +432,37 @@ impl ExprVariables { } /// Generates a random assignment to the variables. - /// Note that the assignment is deterministic in the sets of variables (disregarding their - /// order), and this is required. - pub fn random_assignment(&self) -> ExprVarAssignment { - let mut rng = SmallRng::seed_from_u64(0); - + /// Note that the assignment is deterministically dependent on every variable and that this is + /// required. 
+ pub fn random_assignment(&self, salt: usize) -> ExprVarAssignment { let cols = sorted(self.cols.iter()) - .map(|col| ((col.interaction, col.idx, col.offset), rng.gen())) + .map(|col| { + ((col.interaction, col.idx, col.offset), { + let mut hasher = DefaultHasher::new(); + (salt, col).hash(&mut hasher); + (hasher.finish() as u32).into() + }) + }) .collect(); let params = sorted(self.params.iter()) - .map(|param| (param.clone(), rng.gen())) + .map(|param| { + (param.clone(), { + let mut hasher = DefaultHasher::new(); + (salt, param).hash(&mut hasher); + (hasher.finish() as u32).into() + }) + }) .collect(); let ext_params = sorted(self.ext_params.iter()) - .map(|param| (param.clone(), rng.gen())) + .map(|param| { + (param.clone(), { + let mut hasher = DefaultHasher::new(); + (salt, param).hash(&mut hasher); + (hasher.finish() as u32).into() + }) + }) .collect(); (cols, params, ext_params) @@ -849,11 +864,12 @@ impl EvalAtRow for ExprEvaluator { interaction: usize, offsets: [isize; N], ) -> [Self::F; N] { - std::array::from_fn(|i| { + let res = std::array::from_fn(|i| { let col = ColumnExpr::from((interaction, self.cur_var_index, offsets[i])); - self.cur_var_index += 1; BaseExpr::Col(col) - }) + }); + self.cur_var_index += 1; + res } fn add_constraint(&mut self, constraint: G) @@ -1094,8 +1110,8 @@ mod tests { let constraint_0 = ((col_1_0[0]) * (intermediate0)) * (1 / (col_1_0[0] + col_1_1[0])); \ - let constraint_1 = (SecureCol(col_2_3[0], col_2_5[0], col_2_7[0], col_2_9[0]) \ - - (SecureCol(col_2_4[-1], col_2_6[-1], col_2_8[-1], col_2_10[-1]) \ + let constraint_1 = (SecureCol(col_2_3[0], col_2_4[0], col_2_5[0], col_2_6[0]) \ + - (SecureCol(col_2_3[-1], col_2_4[-1], col_2_5[-1], col_2_6[-1]) \ - ((total_sum) * (preprocessed.is_first)))) \ * (intermediate1) \ - (1);" diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 043cce416..684f04f76 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -369,16 +369,16 @@ mod tests { \ let constraint_0 = (SecureCol(\ + col_2_2[claimed_sum_offset], \ + col_2_3[claimed_sum_offset], \ col_2_4[claimed_sum_offset], \ - col_2_7[claimed_sum_offset], \ - col_2_10[claimed_sum_offset], \ - col_2_13[claimed_sum_offset]\ + col_2_5[claimed_sum_offset]\ ) - (claimed_sum)) \ * (preprocessed.is_first); \ - let constraint_1 = (SecureCol(col_2_2[0], col_2_5[0], col_2_8[0], col_2_11[0]) \ - - (SecureCol(col_2_3[-1], col_2_6[-1], col_2_9[-1], col_2_12[-1]) \ + let constraint_1 = (SecureCol(col_2_2[0], col_2_3[0], col_2_4[0], col_2_5[0]) \ + - (SecureCol(col_2_2[-1], col_2_3[-1], col_2_4[-1], col_2_5[-1]) \ - ((total_sum) * (preprocessed.is_first)))\ ) \ * ((intermediate0) * (intermediate1)) \ From c2780f7f4ccb69338121ae710c0c79247dbbb118 Mon Sep 17 00:00:00 2001 From: Alon-Ti <54235977+Alon-Ti@users.noreply.github.com> Date: Mon, 9 Dec 2024 18:38:37 +0200 Subject: [PATCH 22/69] Decoupled batching from `add_to_relation`. 
(#922) --- .../prover/src/constraint_framework/expr.rs | 31 ++---- .../prover/src/constraint_framework/logup.rs | 9 +- crates/prover/src/constraint_framework/mod.rs | 100 ++++++++++++++---- .../constraint_framework/relation_tracker.rs | 52 ++++----- crates/prover/src/examples/blake/mod.rs | 40 +++---- .../src/examples/blake/round/constraints.rs | 6 +- .../examples/blake/scheduler/constraints.rs | 20 ++-- .../examples/blake/xor_table/constraints.rs | 56 ++++------ crates/prover/src/examples/plonk/mod.rs | 20 ++-- crates/prover/src/examples/poseidon/mod.rs | 12 ++- .../src/examples/state_machine/components.rs | 18 ++-- 11 files changed, 207 insertions(+), 157 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 34f400f37..12c841b51 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -771,8 +771,7 @@ pub struct FormalLogupAtRow { pub interaction: usize, pub total_sum: ExtExpr, pub claimed_sum: Option<(ExtExpr, usize)>, - pub prev_col_cumsum: ExtExpr, - pub cur_frac: Option>, + pub fracs: Vec>, pub is_finalized: bool, pub is_first: BaseExpr, pub log_size: u32, @@ -793,8 +792,7 @@ impl FormalLogupAtRow { total_sum: ExtExpr::Param(total_sum_name), claimed_sum: has_partial_sum .then_some((ExtExpr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)), - prev_col_cumsum: ExtExpr::zero(), - cur_frac: None, + fracs: vec![], is_finalized: true, is_first: BaseExpr::zero(), log_size, @@ -890,23 +888,12 @@ impl EvalAtRow for ExprEvaluator { fn add_to_relation>( &mut self, - entries: &[RelationEntry<'_, Self::F, Self::EF, R>], + entry: RelationEntry<'_, Self::F, Self::EF, R>, ) { - let fracs: Vec> = entries - .iter() - .map( - |RelationEntry { - relation, - multiplicity, - values, - }| { - let intermediate = - self.add_extension_intermediate(combine_formal(*relation, values)); - Fraction::new(multiplicity.clone(), intermediate) - }, - ) - .collect(); - self.write_logup_frac(fracs.into_iter().sum()); + let intermediate = + self.add_extension_intermediate(combine_formal(entry.relation, entry.values)); + let frac = Fraction::new(entry.multiplicity.clone(), intermediate); + self.write_logup_frac(frac); } fn add_intermediate(&mut self, expr: Self::F) -> Self::F { @@ -1136,11 +1123,11 @@ mod tests { let x2 = eval.next_trace_mask(); let intermediate = eval.add_intermediate(x1.clone() * x2.clone()); eval.add_constraint(x0.clone() * intermediate * (x0.clone() + x1.clone()).inverse()); - eval.add_to_relation(&[RelationEntry::new( + eval.add_to_relation(RelationEntry::new( &TestRelation::dummy(), E::EF::one(), &[x0, x1, x2], - )]); + )); eval.finalize_logup(); eval } diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index d6af96e46..bb05c6b5c 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -49,8 +49,7 @@ pub struct LogupAtRow { /// None if the claimed_sum is the total_sum. pub claimed_sum: Option, /// The evaluation of the last cumulative sum column. - pub prev_col_cumsum: E::EF, - pub cur_frac: Option>, + pub fracs: Vec>, pub is_finalized: bool, /// The value of the `is_first` constant column at current row. /// See [`super::preprocessed_columns::gen_is_first()`]. 
@@ -74,8 +73,7 @@ impl LogupAtRow { interaction, total_sum, claimed_sum, - prev_col_cumsum: E::EF::zero(), - cur_frac: None, + fracs: vec![], is_finalized: true, is_first: E::F::zero(), log_size, @@ -88,8 +86,7 @@ impl LogupAtRow { interaction: 100, total_sum: SecureField::one(), claimed_sum: None, - prev_col_cumsum: E::EF::zero(), - cur_frac: None, + fracs: vec![], is_finalized: true, is_first: E::F::zero(), log_size: 10, diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index b03c08ce3..bc188eb57 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -32,6 +32,13 @@ pub const PREPROCESSED_TRACE_IDX: usize = 0; pub const ORIGINAL_TRACE_IDX: usize = 1; pub const INTERACTION_TRACE_IDX: usize = 2; +/// A vector that describes the batching of logup entries. +/// Each vector member corresponds to a logup entry, and contains the batch number to which the +/// entry should be added. +/// Note that the batch numbers should be consecutive and start from 0, and that the vector's +/// length should be equal to the number of logup entries. +type Batching = Vec; + /// A trait for evaluating expressions at some point or row. pub trait EvalAtRow { // TODO(Ohad): Use a better trait for these, like 'Algebra' or something. @@ -132,25 +139,30 @@ pub trait EvalAtRow { /// multiplied. fn add_to_relation>( &mut self, - entries: &[RelationEntry<'_, Self::F, Self::EF, R>], + entry: RelationEntry<'_, Self::F, Self::EF, R>, ) { - let fracs = entries.iter().map( - |RelationEntry { - relation, - multiplicity, - values, - }| { Fraction::new(multiplicity.clone(), relation.combine(values)) }, + let frac = Fraction::new( + entry.multiplicity.clone(), + entry.relation.combine(entry.values), ); - self.write_logup_frac(fracs.sum()); + self.write_logup_frac(frac); } // TODO(alont): Remove these once LogupAtRow is no longer used. fn write_logup_frac(&mut self, _fraction: Fraction) { unimplemented!() } - fn finalize_logup(&mut self) { + fn finalize_logup_batched(&mut self, _batching: &Batching) { unimplemented!() } + + fn finalize_logup(&mut self) { + unimplemented!(); + } + + fn finalize_logup_in_pairs(&mut self) { + unimplemented!(); + } } /// Default implementation for evaluators that have an element called "logup" that works like a @@ -159,26 +171,59 @@ pub trait EvalAtRow { macro_rules! logup_proxy { () => { fn write_logup_frac(&mut self, fraction: Fraction) { - // Add a constraint that num / denom = diff. - if let Some(cur_frac) = self.logup.cur_frac.clone() { - let [cur_cumsum] = - self.next_extension_interaction_mask(self.logup.interaction, [0]); - let diff = cur_cumsum.clone() - self.logup.prev_col_cumsum.clone(); - self.logup.prev_col_cumsum = cur_cumsum; - self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator); - } else { + if self.logup.fracs.is_empty() { self.logup.is_first = self.get_preprocessed_column( super::preprocessed_columns::PreprocessedColumn::IsFirst(self.logup.log_size), ); self.logup.is_finalized = false; } - self.logup.cur_frac = Some(fraction); + self.logup.fracs.push(fraction.clone()); } - fn finalize_logup(&mut self) { + /// Finalize the logup by adding the constraints for the fractions, batched by + /// the given `batching`. + /// `batching` should contain the batch into which every logup entry should be inserted. 
+ fn finalize_logup_batched(&mut self, batching: &super::Batching) { assert!(!self.logup.is_finalized, "LogupAtRow was already finalized"); + assert_eq!( + batching.len(), + self.logup.fracs.len(), + "Batching must be of the same length as the number of entries" + ); + + let last_batch = *batching.iter().max().unwrap(); + + let mut fracs_by_batch = + std::collections::HashMap::>>::new(); + + for (batch, frac) in batching.iter().zip(self.logup.fracs.iter()) { + fracs_by_batch + .entry(*batch) + .or_insert_with(Vec::new) + .push(frac.clone()); + } + + let keys_set: std::collections::HashSet<_> = fracs_by_batch.keys().cloned().collect(); + let all_batches_set: std::collections::HashSet<_> = (0..last_batch + 1).collect(); - let frac = self.logup.cur_frac.clone().unwrap(); + assert_eq!( + keys_set, all_batches_set, + "Batching must contain all consecutive batches" + ); + + let mut prev_col_cumsum = ::zero(); + + // All batches except the last are cumulatively summed in new interaction columns. + for batch_id in (0..last_batch) { + let cur_frac: Fraction<_, _> = fracs_by_batch[&batch_id].iter().cloned().sum(); + let [cur_cumsum] = + self.next_extension_interaction_mask(self.logup.interaction, [0]); + let diff = cur_cumsum.clone() - prev_col_cumsum.clone(); + prev_col_cumsum = cur_cumsum; + self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator); + } + + let frac: Fraction<_, _> = fracs_by_batch[&last_batch].clone().into_iter().sum(); // TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted // offset from the is_first column when constant columns are supported. @@ -205,12 +250,25 @@ macro_rules! logup_proxy { // Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row. let fixed_prev_row_cumsum = prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum.clone(); - let diff = cur_cumsum - fixed_prev_row_cumsum - self.logup.prev_col_cumsum.clone(); + let diff = cur_cumsum - fixed_prev_row_cumsum - prev_col_cumsum.clone(); self.add_constraint(diff * frac.denominator - frac.numerator); self.logup.is_finalized = true; } + + /// Finalizes the row's logup in the default way. Currently, this means no batching. + fn finalize_logup(&mut self) { + let batches = (0..self.logup.fracs.len()).collect(); + self.finalize_logup_batched(&batches) + } + + /// Finalizes the row's logup, batched in pairs. + /// TODO(alont) Remove this once a better batching mechanism is implemented. 
+ fn finalize_logup_in_pairs(&mut self) { + let batches = (0..self.logup.fracs.len()).map(|n| n / 2).collect(); + self.finalize_logup_batched(&batches) + } }; } pub(crate) use logup_proxy; diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs index b804d488e..8311209d1 100644 --- a/crates/prover/src/constraint_framework/relation_tracker.rs +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -6,8 +6,8 @@ use num_traits::Zero; use super::logup::LogupSums; use super::{ - EvalAtRow, FrameworkEval, InfoEvaluator, Relation, RelationEntry, TraceLocationAllocator, - INTERACTION_TRACE_IDX, + Batching, EvalAtRow, FrameworkEval, InfoEvaluator, Relation, RelationEntry, + TraceLocationAllocator, INTERACTION_TRACE_IDX, }; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; @@ -152,38 +152,38 @@ impl<'a> EvalAtRow for RelationTrackerEvaluator<'a> { fn write_logup_frac(&mut self, _fraction: Fraction) {} + fn finalize_logup_batched(&mut self, _batching: &Batching) {} fn finalize_logup(&mut self) {} + fn finalize_logup_in_pairs(&mut self) {} fn add_to_relation>( &mut self, - entries: &[RelationEntry<'_, Self::F, Self::EF, R>], + entry: RelationEntry<'_, Self::F, Self::EF, R>, ) { - for entry in entries { - let relation = entry.relation.get_name().to_owned(); - let values = entry.values.iter().map(|v| v.to_array()).collect_vec(); - let mult = entry.multiplicity.to_array(); + let relation = entry.relation.get_name().to_owned(); + let values = entry.values.iter().map(|v| v.to_array()).collect_vec(); + let mult = entry.multiplicity.to_array(); - // Unpack SIMD. - for j in 0..N_LANES { - // Skip padded values. - let cannonical_index = bit_reverse_index( - coset_index_to_circle_domain_index( - (self.vec_row << LOG_N_LANES) + j, - self.domain_log_size, - ), + // Unpack SIMD. + for j in 0..N_LANES { + // Skip padded values. + let cannonical_index = bit_reverse_index( + coset_index_to_circle_domain_index( + (self.vec_row << LOG_N_LANES) + j, self.domain_log_size, - ); - if cannonical_index >= self.n_rows { - continue; - } - let values = values.iter().map(|v| v[j]).collect_vec(); - let mult = mult[j].to_m31_array()[0]; - self.entries.push(RelationTrackerEntry { - relation: relation.clone(), - mult, - values, - }); + ), + self.domain_log_size, + ); + if cannonical_index >= self.n_rows { + continue; } + let values = values.iter().map(|v| v[j]).collect_vec(); + let mult = mult[j].to_m31_array()[0]; + self.entries.push(RelationTrackerEntry { + relation: relation.clone(), + mult, + values, + }); } } } diff --git a/crates/prover/src/examples/blake/mod.rs b/crates/prover/src/examples/blake/mod.rs index ff62f9f7d..76feb7f8b 100644 --- a/crates/prover/src/examples/blake/mod.rs +++ b/crates/prover/src/examples/blake/mod.rs @@ -88,26 +88,26 @@ impl BlakeXorElements { // TODO(alont): Generalize this to variable sizes batches if ever used. 
fn use_relation(&self, eval: &mut E, w: u32, values: [&[E::F]; 2]) { match w { - 12 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor12, E::EF::one(), values[0]), - RelationEntry::new(&self.xor12, E::EF::one(), values[1]), - ]), - 9 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor9, E::EF::one(), values[0]), - RelationEntry::new(&self.xor9, E::EF::one(), values[1]), - ]), - 8 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor8, E::EF::one(), values[0]), - RelationEntry::new(&self.xor8, E::EF::one(), values[1]), - ]), - 7 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor7, E::EF::one(), values[0]), - RelationEntry::new(&self.xor7, E::EF::one(), values[1]), - ]), - 4 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor4, E::EF::one(), values[0]), - RelationEntry::new(&self.xor4, E::EF::one(), values[1]), - ]), + 12 => { + eval.add_to_relation(RelationEntry::new(&self.xor12, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor12, E::EF::one(), values[1])); + } + 9 => { + eval.add_to_relation(RelationEntry::new(&self.xor9, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor9, E::EF::one(), values[1])); + } + 8 => { + eval.add_to_relation(RelationEntry::new(&self.xor8, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor8, E::EF::one(), values[1])); + } + 7 => { + eval.add_to_relation(RelationEntry::new(&self.xor7, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor7, E::EF::one(), values[1])); + } + 4 => { + eval.add_to_relation(RelationEntry::new(&self.xor4, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor4, E::EF::one(), values[1])); + } _ => panic!("Invalid w"), }; } diff --git a/crates/prover/src/examples/blake/round/constraints.rs b/crates/prover/src/examples/blake/round/constraints.rs index ada5fb287..e15a225df 100644 --- a/crates/prover/src/examples/blake/round/constraints.rs +++ b/crates/prover/src/examples/blake/round/constraints.rs @@ -65,7 +65,7 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { ); // Yield `Round(input_v, output_v, message)`. - self.eval.add_to_relation(&[RelationEntry::new( + self.eval.add_to_relation(RelationEntry::new( self.round_lookup_elements, -E::EF::one(), &chain![ @@ -74,9 +74,9 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { m.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), - )]); + )); - self.eval.finalize_logup(); + self.eval.finalize_logup_in_pairs(); self.eval } fn next_u32(&mut self) -> Fu32 { diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index 1bf93d1aa..aceece2e8 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -30,17 +30,23 @@ pub fn eval_blake_scheduler_constraints( ] .collect_vec() }); - eval.add_to_relation(&[ - RelationEntry::new(round_lookup_elements, E::EF::one(), &elems_i), - RelationEntry::new(round_lookup_elements, E::EF::one(), &elems_j), - ]); + eval.add_to_relation(RelationEntry::new( + round_lookup_elements, + E::EF::one(), + &elems_i, + )); + eval.add_to_relation(RelationEntry::new( + round_lookup_elements, + E::EF::one(), + &elems_j, + )); } let input_state = &states[0]; let output_state = &states[N_ROUNDS]; // TODO(alont): Remove blake interaction. 
- eval.add_to_relation(&[RelationEntry::new( + eval.add_to_relation(RelationEntry::new( blake_lookup_elements, E::EF::zero(), &chain![ @@ -49,9 +55,9 @@ pub fn eval_blake_scheduler_constraints( messages.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), - )]); + )); - eval.finalize_logup(); + eval.finalize_logup_in_pairs(); } fn eval_next_u32(eval: &mut E) -> Fu32 { diff --git a/crates/prover/src/examples/blake/xor_table/constraints.rs b/crates/prover/src/examples/blake/xor_table/constraints.rs index 4df0a6c63..60fef8bfe 100644 --- a/crates/prover/src/examples/blake/xor_table/constraints.rs +++ b/crates/prover/src/examples/blake/xor_table/constraints.rs @@ -40,43 +40,31 @@ macro_rules! xor_table_eval { 2, )); - let entry_chunks = (0..(1 << (2 * EXPAND_BITS))) - .map(|i| { - let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32); - let multiplicity = self.eval.next_trace_mask(); + for i in (0..(1 << (2 * EXPAND_BITS))) { + let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32); + let multiplicity = self.eval.next_trace_mask(); - let a = al.clone() - + E::F::from(BaseField::from_u32_unchecked( - i << limb_bits::(), - )); - let b = bl.clone() - + E::F::from(BaseField::from_u32_unchecked( - j << limb_bits::(), - )); - let c = cl.clone() - + E::F::from(BaseField::from_u32_unchecked( - (i ^ j) << limb_bits::(), - )); + let a = al.clone() + + E::F::from(BaseField::from_u32_unchecked( + i << limb_bits::(), + )); + let b = bl.clone() + + E::F::from(BaseField::from_u32_unchecked( + j << limb_bits::(), + )); + let c = cl.clone() + + E::F::from(BaseField::from_u32_unchecked( + (i ^ j) << limb_bits::(), + )); - (self.lookup_elements, -multiplicity, [a, b, c]) - }) - .collect_vec(); - - for entry_chunk in entry_chunks.chunks(2) { - self.eval.add_to_relation( - &entry_chunk - .iter() - .map(|(lookup, multiplicity, values)| { - RelationEntry::new( - *lookup, - E::EF::from(multiplicity.clone()), - values, - ) - }) - .collect_vec(), - ); + self.eval.add_to_relation(RelationEntry::new( + self.lookup_elements, + -E::EF::from(multiplicity), + &[a, b, c], + )); } - self.eval.finalize_logup(); + + self.eval.finalize_logup_in_pairs(); self.eval } } diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index 49da86f8a..a1e0362c9 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -66,18 +66,24 @@ impl FrameworkEval for PlonkEval { + (E::F::one() - op) * a_val.clone() * b_val.clone(), ); - eval.add_to_relation(&[ - RelationEntry::new(&self.lookup_elements, E::EF::one(), &[a_wire, a_val]), - RelationEntry::new(&self.lookup_elements, E::EF::one(), &[b_wire, b_val]), - ]); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + E::EF::one(), + &[a_wire, a_val], + )); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + E::EF::one(), + &[b_wire, b_val], + )); - eval.add_to_relation(&[RelationEntry::new( + eval.add_to_relation(RelationEntry::new( &self.lookup_elements, (-mult).into(), &[c_wire, c_val], - )]); + )); - eval.finalize_logup(); + eval.finalize_logup_in_pairs(); eval } } diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 51b671580..808dcc74d 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -186,13 +186,15 @@ pub fn eval_poseidon_constraints(eval: &mut E, lookup_elements: &P }); // Provide state lookups. 
- eval.add_to_relation(&[ - RelationEntry::new(lookup_elements, E::EF::one(), &initial_state), - RelationEntry::new(lookup_elements, -E::EF::one(), &state), - ]) + eval.add_to_relation(RelationEntry::new( + lookup_elements, + E::EF::one(), + &initial_state, + )); + eval.add_to_relation(RelationEntry::new(lookup_elements, -E::EF::one(), &state)); } - eval.finalize_logup(); + eval.finalize_logup_in_pairs(); } pub struct LookupData { diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs index 4600a3cf0..23bcf2977 100644 --- a/crates/prover/src/examples/state_machine/components.rs +++ b/crates/prover/src/examples/state_machine/components.rs @@ -52,12 +52,18 @@ impl FrameworkEval for StateTransitionEval let mut output_state = input_state.clone(); output_state[COORDINATE] += E::F::one(); - eval.add_to_relation(&[ - RelationEntry::new(&self.lookup_elements, E::EF::one(), &input_state), - RelationEntry::new(&self.lookup_elements, -E::EF::one(), &output_state), - ]); - - eval.finalize_logup(); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + E::EF::one(), + &input_state, + )); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + -E::EF::one(), + &output_state, + )); + + eval.finalize_logup_in_pairs(); eval } } From 10a3f69d4eaf0433268a23f1b5e94ce1e71e3500 Mon Sep 17 00:00:00 2001 From: Gali Michlevich Date: Mon, 9 Dec 2024 17:54:42 +0200 Subject: [PATCH 23/69] Delete shifted_secure_combination Function --- crates/prover/src/core/utils.rs | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/crates/prover/src/core/utils.rs b/crates/prover/src/core/utils.rs index 330e49ed2..4561d54ff 100644 --- a/crates/prover/src/core/utils.rs +++ b/crates/prover/src/core/utils.rs @@ -1,7 +1,6 @@ use std::iter::Peekable; -use std::ops::{Add, Mul, Sub}; -use num_traits::{One, Zero}; +use num_traits::One; use super::fields::m31::BaseField; use super::fields::qm31::SecureField; @@ -175,18 +174,6 @@ pub fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec(values: &[F], alpha: EF, z: EF) -> EF -where - EF: Copy + Zero + Mul + Add + Sub, -{ - let res = values - .iter() - .fold(EF::zero(), |acc, &value| acc * alpha + value); - res - z -} - #[cfg(test)] mod tests { use itertools::Itertools; From 060f0e4e2b2d32bfad8eabc17ef66292d883e220 Mon Sep 17 00:00:00 2001 From: Gali Michlevich Date: Tue, 10 Dec 2024 10:21:30 +0200 Subject: [PATCH 24/69] Move bit_reverse function to backend/cpu --- crates/prover/benches/bit_rev.rs | 2 +- .../src/constraint_framework/component.rs | 5 +-- crates/prover/src/core/backend/cpu/circle.rs | 3 +- crates/prover/src/core/backend/cpu/mod.rs | 34 ++++++++++++++++++- .../src/core/backend/simd/bit_reverse.rs | 5 +-- crates/prover/src/core/backend/simd/domain.rs | 2 +- .../prover/src/core/backend/simd/fft/ifft.rs | 2 +- .../prover/src/core/backend/simd/fft/rfft.rs | 2 +- .../src/core/backend/simd/prefix_sum.rs | 5 ++- .../prover/src/core/backend/simd/quotients.rs | 2 +- crates/prover/src/core/poly/line.rs | 2 +- crates/prover/src/core/queries.rs | 2 +- crates/prover/src/core/utils.rs | 33 ------------------ .../src/examples/xor/gkr_lookups/mle_eval.rs | 8 +++-- 14 files changed, 55 insertions(+), 52 deletions(-) diff --git a/crates/prover/benches/bit_rev.rs b/crates/prover/benches/bit_rev.rs index 6e287e60f..b15a172c6 100644 --- a/crates/prover/benches/bit_rev.rs +++ b/crates/prover/benches/bit_rev.rs @@ -5,7 +5,7 @@ use itertools::Itertools; use 
stwo_prover::core::fields::m31::BaseField; pub fn cpu_bit_rev(c: &mut Criterion) { - use stwo_prover::core::utils::bit_reverse; + use stwo_prover::core::backend::cpu::bit_reverse; // TODO(andrew): Consider using same size for all. const SIZE: usize = 1 << 24; let data = (0..SIZE).map(BaseField::from).collect_vec(); diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 23981cc06..8f082f5f7 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -17,6 +17,7 @@ use super::{ }; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::{Component, ComponentProver, Trace}; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS}; @@ -30,7 +31,7 @@ use crate::core::fields::FieldExpOps; use crate::core::pcs::{TreeSubspan, TreeVec}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; -use crate::core::{utils, ColumnVec}; +use crate::core::ColumnVec; const CHUNK_SIZE: usize = 1; @@ -292,7 +293,7 @@ impl ComponentProver for FrameworkComponen let mut denom_inv = (0..1 << log_expand) .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) .collect_vec(); - utils::bit_reverse(&mut denom_inv); + bit_reverse(&mut denom_inv); // Accumulator. let [mut accum] = diff --git a/crates/prover/src/core/backend/cpu/circle.rs b/crates/prover/src/core/backend/cpu/circle.rs index 6e94078cd..21351d164 100644 --- a/crates/prover/src/core/backend/cpu/circle.rs +++ b/crates/prover/src/core/backend/cpu/circle.rs @@ -1,6 +1,7 @@ use num_traits::Zero; use super::CpuBackend; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::{Col, ColumnOps}; use crate::core::circle::{CirclePoint, Coset}; use crate::core::fft::{butterfly, ibutterfly}; @@ -13,7 +14,7 @@ use crate::core::poly::circle::{ use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold}; use crate::core::poly::BitReversedOrder; -use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; +use crate::core::utils::coset_order_to_circle_domain_order; impl PolyOps for CpuBackend { type Twiddles = Vec; diff --git a/crates/prover/src/core/backend/cpu/mod.rs b/crates/prover/src/core/backend/cpu/mod.rs index cfa514e4c..ea6e49c07 100644 --- a/crates/prover/src/core/backend/cpu/mod.rs +++ b/crates/prover/src/core/backend/cpu/mod.rs @@ -16,7 +16,7 @@ use super::{Backend, BackendForChannel, Column, ColumnOps, FieldOps}; use crate::core::fields::Field; use crate::core::lookups::mle::Mle; use crate::core::poly::circle::{CircleEvaluation, CirclePoly}; -use crate::core::utils::bit_reverse; +use crate::core::utils::bit_reverse_index; use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; #[cfg(not(target_arch = "wasm32"))] use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel; @@ -29,6 +29,23 @@ impl BackendForChannel for CpuBackend {} #[cfg(not(target_arch = "wasm32"))] impl BackendForChannel for CpuBackend {} +/// Performs a naive bit-reversal permutation inplace. +/// +/// # Panics +/// +/// Panics if the length of the slice is not a power of two. 
+pub fn bit_reverse(v: &mut [T]) { + let n = v.len(); + assert!(n.is_power_of_two()); + let log_n = n.ilog2(); + for i in 0..n { + let j = bit_reverse_index(i, log_n); + if j > i { + v.swap(i, j); + } + } +} + impl ColumnOps for CpuBackend { type Column = Vec; @@ -79,10 +96,25 @@ mod tests { use rand::prelude::*; use rand::rngs::SmallRng; + use crate::core::backend::cpu::bit_reverse; use crate::core::backend::{Column, CpuBackend, FieldOps}; use crate::core::fields::qm31::QM31; use crate::core::fields::FieldExpOps; + #[test] + fn bit_reverse_works() { + let mut data = [0, 1, 2, 3, 4, 5, 6, 7]; + bit_reverse(&mut data); + assert_eq!(data, [0, 4, 2, 6, 1, 5, 3, 7]); + } + + #[test] + #[should_panic] + fn bit_reverse_non_power_of_two_size_fails() { + let mut data = [0, 1, 2, 3, 4, 5]; + bit_reverse(&mut data); + } + #[test] fn batch_inverse_test() { let mut rng = SmallRng::seed_from_u64(0); diff --git a/crates/prover/src/core/backend/simd/bit_reverse.rs b/crates/prover/src/core/backend/simd/bit_reverse.rs index 1c2418d7d..cc2a55c98 100644 --- a/crates/prover/src/core/backend/simd/bit_reverse.rs +++ b/crates/prover/src/core/backend/simd/bit_reverse.rs @@ -6,11 +6,12 @@ use rayon::prelude::*; use super::column::{BaseColumn, SecureColumn}; use super::m31::PackedBaseField; use super::SimdBackend; +use crate::core::backend::cpu::bit_reverse as cpu_bit_reverse; use crate::core::backend::simd::utils::UnsafeMut; use crate::core::backend::ColumnOps; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::utils::{bit_reverse as cpu_bit_reverse, bit_reverse_index}; +use crate::core::utils::bit_reverse_index; use crate::parallel_iter; const VEC_BITS: u32 = 4; @@ -150,12 +151,12 @@ mod tests { use itertools::Itertools; use super::{bit_reverse16, bit_reverse_m31, MIN_LOG_SIZE}; + use crate::core::backend::cpu::bit_reverse as cpu_bit_reverse; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::simd::m31::{PackedM31, N_LANES}; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Column, ColumnOps}; use crate::core::fields::m31::BaseField; - use crate::core::utils::bit_reverse as cpu_bit_reverse; #[test] fn test_bit_reverse16() { diff --git a/crates/prover/src/core/backend/simd/domain.rs b/crates/prover/src/core/backend/simd/domain.rs index 209314175..d27cf2396 100644 --- a/crates/prover/src/core/backend/simd/domain.rs +++ b/crates/prover/src/core/backend/simd/domain.rs @@ -73,7 +73,7 @@ fn test_circle_domain_bit_rev_iterator() { 5, )); let mut expected = domain.iter().collect::>(); - crate::core::utils::bit_reverse(&mut expected); + crate::core::backend::cpu::bit_reverse(&mut expected); let actual = CircleDomainBitRevIterator::new(domain) .flat_map(|c| -> [_; 16] { std::array::from_fn(|i| CirclePoint { diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs index 77b096d9c..a6abb48e0 100644 --- a/crates/prover/src/core/backend/simd/fft/ifft.rs +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -9,11 +9,11 @@ use rayon::prelude::*; use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, }; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::utils::UnsafeMut; use crate::core::circle::Coset; use crate::core::fields::FieldExpOps; -use crate::core::utils::bit_reverse; use crate::parallel_iter; /// Performs an Inverse Circle 
Fast Fourier Transform (ICFFT) on the given values. diff --git a/crates/prover/src/core/backend/simd/fft/rfft.rs b/crates/prover/src/core/backend/simd/fft/rfft.rs index d28c8a00d..1249b11e4 100644 --- a/crates/prover/src/core/backend/simd/fft/rfft.rs +++ b/crates/prover/src/core/backend/simd/fft/rfft.rs @@ -10,10 +10,10 @@ use rayon::prelude::*; use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, }; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::utils::{UnsafeConst, UnsafeMut}; use crate::core::circle::Coset; -use crate::core::utils::bit_reverse; use crate::parallel_iter; /// Performs a Circle Fast Fourier Transform (CFFT) on the given values. diff --git a/crates/prover/src/core/backend/simd/prefix_sum.rs b/crates/prover/src/core/backend/simd/prefix_sum.rs index 652b484a1..8e7f07cdf 100644 --- a/crates/prover/src/core/backend/simd/prefix_sum.rs +++ b/crates/prover/src/core/backend/simd/prefix_sum.rs @@ -4,13 +4,12 @@ use std::ops::{AddAssign, Sub}; use itertools::{izip, Itertools}; use num_traits::Zero; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::simd::m31::{PackedBaseField, N_LANES}; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Col, Column}; use crate::core::fields::m31::BaseField; -use crate::core::utils::{ - bit_reverse, circle_domain_order_to_coset_order, coset_order_to_circle_domain_order, -}; +use crate::core::utils::{circle_domain_order_to_coset_order, coset_order_to_circle_domain_order}; /// Performs a inclusive prefix sum on values in `Coset` order when provided /// with evaluations in bit-reversed `CircleDomain` order. diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index 553e540a5..bac374292 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -8,6 +8,7 @@ use super::domain::CircleDomainBitRevIterator; use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; use super::qm31::PackedSecureField; use super::SimdBackend; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::cpu::quotients::{batch_random_coeffs, column_line_coeffs}; use crate::core::backend::{Column, CpuBackend}; use crate::core::fields::m31::BaseField; @@ -17,7 +18,6 @@ use crate::core::fields::FieldExpOps; use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; use crate::core::poly::circle::{CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation}; use crate::core::poly::BitReversedOrder; -use crate::core::utils::bit_reverse; pub struct QuotientConstants { pub line_coeffs: Vec>, diff --git a/crates/prover/src/core/poly/line.rs b/crates/prover/src/core/poly/line.rs index 2b58e01fb..9a8a4cf6d 100644 --- a/crates/prover/src/core/poly/line.rs +++ b/crates/prover/src/core/poly/line.rs @@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize}; use super::circle::CircleDomain; use super::utils::fold; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::{ColumnOps, CpuBackend}; use crate::core::circle::{CirclePoint, Coset, CosetIterator}; use crate::core::fft::ibutterfly; @@ -16,7 +17,6 @@ use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; use crate::core::fields::{ExtensionOf, FieldExpOps, FieldOps}; -use crate::core::utils::bit_reverse; /// Domain 
comprising of the x-coordinates of points in a [Coset]. /// diff --git a/crates/prover/src/core/queries.rs b/crates/prover/src/core/queries.rs index cd9546e6a..946380000 100644 --- a/crates/prover/src/core/queries.rs +++ b/crates/prover/src/core/queries.rs @@ -70,10 +70,10 @@ impl Deref for Queries { #[cfg(test)] mod tests { + use crate::core::backend::cpu::bit_reverse; use crate::core::channel::Blake2sChannel; use crate::core::poly::circle::CanonicCoset; use crate::core::queries::Queries; - use crate::core::utils::bit_reverse; #[test] fn test_generate_queries() { diff --git a/crates/prover/src/core/utils.rs b/crates/prover/src/core/utils.rs index 4561d54ff..745168c38 100644 --- a/crates/prover/src/core/utils.rs +++ b/crates/prover/src/core/utils.rs @@ -129,24 +129,6 @@ pub const fn coset_index_to_circle_domain_index(coset_index: usize, log_domain_s } } -/// Performs a naive bit-reversal permutation inplace. -/// -/// # Panics -/// -/// Panics if the length of the slice is not a power of two. -// TODO(alont): Move this to the cpu backend. -pub fn bit_reverse(v: &mut [T]) { - let n = v.len(); - assert!(n.is_power_of_two()); - let log_n = n.ilog2(); - for i in 0..n { - let j = bit_reverse_index(i, log_n); - if j > i { - v.swap(i, j); - } - } -} - /// Performs a coset-natural-order to circle-domain-bit-reversed-order permutation in-place. /// /// # Panics @@ -187,23 +169,8 @@ mod tests { use crate::core::fields::FieldExpOps; use crate::core::poly::circle::CanonicCoset; use crate::core::poly::NaturalOrder; - use crate::core::utils::bit_reverse; use crate::{m31, qm31}; - #[test] - fn bit_reverse_works() { - let mut data = [0, 1, 2, 3, 4, 5, 6, 7]; - bit_reverse(&mut data); - assert_eq!(data, [0, 4, 2, 6, 1, 5, 3, 7]); - } - - #[test] - #[should_panic] - fn bit_reverse_non_power_of_two_size_fails() { - let mut data = [0, 1, 2, 3, 4, 5]; - bit_reverse(&mut data); - } - #[test] fn generate_secure_powers_works() { let felt = qm31!(1, 2, 3, 4); diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index cc7ebdaec..79efe5998 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -14,6 +14,7 @@ use crate::constraint_framework::{ }; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::{Component, ComponentProver, Trace}; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::simd::column::{SecureColumn, VeryPackedSecureColumnByCoords}; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; @@ -36,7 +37,7 @@ use crate::core::poly::circle::{ }; use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::BitReversedOrder; -use crate::core::utils::{self, bit_reverse_index, coset_index_to_circle_domain_index}; +use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; use crate::core::ColumnVec; /// Prover component that carries out a univariate IOP for multilinear eval at point. @@ -231,7 +232,7 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> ComponentProver let mut denom_inv = (0..1 << log_expand) .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) .collect_vec(); - utils::bit_reverse(&mut denom_inv); + bit_reverse(&mut denom_inv); // Accumulator. 
let [mut acc] = accumulator.columns([(eval_domain.log_size(), self.n_constraints())]); @@ -752,6 +753,7 @@ mod tests { }; use crate::constraint_framework::{assert_constraints, EvalAtRow, TraceLocationAllocator}; use crate::core::air::{Component, ComponentProver, Components}; + use crate::core::backend::cpu::bit_reverse; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; @@ -765,7 +767,7 @@ mod tests { use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; use crate::core::prover::{prove, verify, VerificationError}; - use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; + use crate::core::utils::coset_order_to_circle_domain_order; use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; use crate::examples::xor::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; use crate::examples::xor::gkr_lookups::mle_eval::eval_step_selector_with_offset; From 8550b7d55d69e75d5fce5977011d97dadd9564e2 Mon Sep 17 00:00:00 2001 From: Alon-Ti <54235977+Alon-Ti@users.noreply.github.com> Date: Tue, 10 Dec 2024 13:34:58 +0200 Subject: [PATCH 25/69] =?UTF-8?q?Expression=20evaluator:=20Ordnung=20mu?= =?UTF-8?q?=C3=9F=20sein.=20(#927)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../prover/src/constraint_framework/expr.rs | 1135 ----------------- .../constraint_framework/expr/assignment.rs | 267 ++++ .../constraint_framework/expr/evaluator.rs | 244 ++++ .../src/constraint_framework/expr/format.rs | 64 + .../src/constraint_framework/expr/mod.rs | 351 +++++ .../src/constraint_framework/expr/simplify.rs | 217 ++++ .../src/constraint_framework/expr/utils.rs | 65 + crates/prover/src/constraint_framework/mod.rs | 6 +- 8 files changed, 1212 insertions(+), 1137 deletions(-) delete mode 100644 crates/prover/src/constraint_framework/expr.rs create mode 100644 crates/prover/src/constraint_framework/expr/assignment.rs create mode 100644 crates/prover/src/constraint_framework/expr/evaluator.rs create mode 100644 crates/prover/src/constraint_framework/expr/format.rs create mode 100644 crates/prover/src/constraint_framework/expr/mod.rs create mode 100644 crates/prover/src/constraint_framework/expr/simplify.rs create mode 100644 crates/prover/src/constraint_framework/expr/utils.rs diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs deleted file mode 100644 index 12c841b51..000000000 --- a/crates/prover/src/constraint_framework/expr.rs +++ /dev/null @@ -1,1135 +0,0 @@ -use std::collections::{HashMap, HashSet}; -use std::hash::{DefaultHasher, Hash, Hasher}; -use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub}; - -use itertools::sorted; -use num_traits::{One, Zero}; - -use super::preprocessed_columns::PreprocessedColumn; -use super::{AssertEvaluator, EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; -use crate::core::fields::cm31::CM31; -use crate::core::fields::m31::{self, BaseField}; -use crate::core::fields::qm31::{SecureField, QM31}; -use crate::core::fields::FieldExpOps; -use crate::core::lookups::utils::Fraction; - -/// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`. 
-#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct ColumnExpr { - interaction: usize, - idx: usize, - offset: isize, -} - -impl From<(usize, usize, isize)> for ColumnExpr { - fn from((interaction, idx, offset): (usize, usize, isize)) -> Self { - Self { - interaction, - idx, - offset, - } - } -} - -/// An expression representing a base field value. Can be either: -/// * A column indexed by a `ColumnExpr`. -/// * A base field constant. -/// * A formal parameter to the AIR. -/// * A sum, difference, or product of two base field expressions. -/// * A negation or inverse of a base field expression. -/// -/// This type is meant to be used as an F associated type for EvalAtRow and interacts with -/// `ExtExpr`, `BaseField` and `SecureField` as expected. -#[derive(Clone, Debug, PartialEq)] -pub enum BaseExpr { - Col(ColumnExpr), - Const(BaseField), - /// Formal parameter to the AIR, for example the interaction elements of a relation. - Param(String), - Add(Box, Box), - Sub(Box, Box), - Mul(Box, Box), - Neg(Box), - Inv(Box), -} - -/// An expression representing a secure field value. Can be either: -/// * A secure column constructed from 4 base field expressions. -/// * A secure field constant. -/// * A formal parameter to the AIR. -/// * A sum, difference, or product of two secure field expressions. -/// * A negation of a secure field expression. -/// -/// This type is meant to be used as an EF associated type for EvalAtRow and interacts with -/// `BaseExpr`, `BaseField` and `SecureField` as expected. -#[derive(Clone, Debug, PartialEq)] -pub enum ExtExpr { - /// An atomic secure column constructed from 4 expressions. - /// Expressions on the secure column are not reduced, i.e, - /// if `a = SecureCol(a0, a1, a2, a3)`, `b = SecureCol(b0, b1, b2, b3)` then - /// `a + b` evaluates to `Add(a, b)` rather than - /// `SecureCol(Add(a0, b0), Add(a1, b1), Add(a2, b2), Add(a3, b3))` - SecureCol([Box; 4]), - Const(SecureField), - /// Formal parameter to the AIR, for example the interaction elements of a relation. - Param(String), - Add(Box, Box), - Sub(Box, Box), - Mul(Box, Box), - Neg(Box), -} - -/// Applies simplifications to arithmetic expressions that can be used both for `BaseExpr` and for -/// `ExtExpr`. -macro_rules! simplify_arithmetic { - ($self:tt) => { - match $self.clone() { - Self::Add(a, b) => { - let a = a.simplify(); - let b = b.simplify(); - match (a.clone(), b.clone()) { - // Simplify constants. - (Self::Const(a), Self::Const(b)) => Self::Const(a + b), - (Self::Const(a_val), _) if a_val.is_zero() => b, // 0 + b = b - (_, Self::Const(b_val)) if b_val.is_zero() => a, // a + 0 = a - // Simplify Negs. - // (-a + -b) = -(a + b) - (Self::Neg(minus_a), Self::Neg(minus_b)) => -(*minus_a + *minus_b), - (Self::Neg(minus_a), _) => b - *minus_a, // -a + b = b - a - (_, Self::Neg(minus_b)) => a - *minus_b, // a + -b = a - b - // No simplification. - _ => a + b, - } - } - Self::Sub(a, b) => { - let a = a.simplify(); - let b = b.simplify(); - match (a.clone(), b.clone()) { - // Simplify constants. - (Self::Const(a), Self::Const(b)) => Self::Const(a - b), // Simplify consts. - (Self::Const(a_val), _) if a_val.is_zero() => -b, // 0 - b = -b - (_, Self::Const(b_val)) if b_val.is_zero() => a, // a - 0 = a - // Simplify Negs. - // (-a - -b) = b - a - (Self::Neg(minus_a), Self::Neg(minus_b)) => *minus_b - *minus_a, - (Self::Neg(minus_a), _) => -(*minus_a + b), // -a - b = -(a + b) - (_, Self::Neg(minus_b)) => a + *minus_b, // a + -b = a - b - // No Simplification. 
- _ => a - b, - } - } - Self::Mul(a, b) => { - let a = a.simplify(); - let b = b.simplify(); - match (a.clone(), b.clone()) { - // Simplify consts. - (Self::Const(a), Self::Const(b)) => Self::Const(a * b), - (Self::Const(a_val), _) if a_val.is_zero() => Self::zero(), // 0 * b = 0 - (_, Self::Const(b_val)) if b_val.is_zero() => Self::zero(), // a * 0 = 0 - (Self::Const(a_val), _) if a_val == One::one() => b, // 1 * b = b - (_, Self::Const(b_val)) if b_val == One::one() => a, // a * 1 = a - (Self::Const(a_val), _) if -a_val == One::one() => -b, // -1 * b = -b - (_, Self::Const(b_val)) if -b_val == One::one() => -a, // a * -1 = -a - // Simplify Negs. - // (-a) * (-b) = a * b - (Self::Neg(minus_a), Self::Neg(minus_b)) => *minus_a * *minus_b, - (Self::Neg(minus_a), _) => -(*minus_a * b), // (-a) * b = -(a * b) - (_, Self::Neg(minus_b)) => -(a * *minus_b), // a * (-b) = -(a * b) - // No simplification. - _ => a * b, - } - } - Self::Neg(a) => { - let a = a.simplify(); - match a { - Self::Const(c) => Self::Const(-c), - Self::Neg(minus_a) => *minus_a, // -(-a) = a - Self::Sub(a, b) => Self::Sub(b, a), // -(a - b) = b - a - _ => -a, // No simplification. - } - } - other => other, // No simplification. - } - }; -} - -impl BaseExpr { - pub fn format_expr(&self) -> String { - match self { - BaseExpr::Col(ColumnExpr { - interaction, - idx, - offset, - }) => { - let offset_str = if *offset == CLAIMED_SUM_DUMMY_OFFSET as isize { - "claimed_sum_offset".to_string() - } else { - offset.to_string() - }; - format!("col_{interaction}_{idx}[{offset_str}]") - } - BaseExpr::Const(c) => c.to_string(), - BaseExpr::Param(v) => v.to_string(), - BaseExpr::Add(a, b) => format!("{} + {}", a.format_expr(), b.format_expr()), - BaseExpr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()), - BaseExpr::Mul(a, b) => format!("({}) * ({})", a.format_expr(), b.format_expr()), - BaseExpr::Neg(a) => format!("-({})", a.format_expr()), - BaseExpr::Inv(a) => format!("1 / ({})", a.format_expr()), - } - } - - /// Helper function, use [`simplify`] instead. - /// - /// Simplifies an expression by applying basic arithmetic rules. - fn unchecked_simplify(&self) -> Self { - let simple = simplify_arithmetic!(self); - match simple { - Self::Inv(a) => { - let a = a.unchecked_simplify(); - match a { - Self::Inv(inv_a) => *inv_a, // 1 / (1 / a) = a - Self::Const(c) => Self::Const(c.inverse()), - _ => Self::Inv(Box::new(a)), - } - } - other => other, - } - } - - /// Simplifies an expression by applying basic arithmetic rules and ensures that the result is - /// equivalent to the original expression by assigning random values. - pub fn simplify(&self) -> Self { - let simplified = self.unchecked_simplify(); - assert_eq!(self.random_eval(), simplified.random_eval()); - simplified - } - - pub fn simplify_and_format(&self) -> String { - self.simplify().format_expr() - } - - /// Evaluates a base field expression. - /// Takes: - /// * `columns`: A mapping from triplets (interaction, idx, offset) to base field values. - /// * `vars`: A mapping from variable names to base field values. 
- pub fn eval_expr(&self, columns: &C, vars: &V) -> E::F - where - C: for<'a> Index<&'a (usize, usize, isize), Output = E::F>, - V: for<'a> Index<&'a String, Output = E::F>, - E: EvalAtRow, - { - match self { - Self::Col(col) => columns[&(col.interaction, col.idx, col.offset)].clone(), - Self::Const(c) => E::F::from(*c), - Self::Param(var) => vars[&var.to_string()].clone(), - Self::Add(a, b) => { - a.eval_expr::(columns, vars) + b.eval_expr::(columns, vars) - } - Self::Sub(a, b) => { - a.eval_expr::(columns, vars) - b.eval_expr::(columns, vars) - } - Self::Mul(a, b) => { - a.eval_expr::(columns, vars) * b.eval_expr::(columns, vars) - } - Self::Neg(a) => -a.eval_expr::(columns, vars), - Self::Inv(a) => a.eval_expr::(columns, vars).inverse(), - } - } - - pub fn collect_variables(&self) -> ExprVariables { - match self { - BaseExpr::Col(col) => ExprVariables::col(col.clone()), - BaseExpr::Const(_) => ExprVariables::default(), - BaseExpr::Param(param) => ExprVariables::param(param.to_string()), - BaseExpr::Add(a, b) => a.collect_variables() + b.collect_variables(), - BaseExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(), - BaseExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(), - BaseExpr::Neg(a) => a.collect_variables(), - BaseExpr::Inv(a) => a.collect_variables(), - } - } - - pub fn random_eval(&self) -> BaseField { - let assignment = self.collect_variables().random_assignment(0); - assert!(assignment.2.is_empty()); - self.eval_expr::, _, _>(&assignment.0, &assignment.1) - } -} - -impl ExtExpr { - pub fn format_expr(&self) -> String { - match self { - ExtExpr::SecureCol([a, b, c, d]) => { - // If the expression's non-base components are all constant zeroes, return the base - // field representation of its first part. - if **b == BaseExpr::zero() && **c == BaseExpr::zero() && **d == BaseExpr::zero() { - a.format_expr() - } else { - format!( - "SecureCol({}, {}, {}, {})", - a.format_expr(), - b.format_expr(), - c.format_expr(), - d.format_expr() - ) - } - } - ExtExpr::Const(c) => { - if c.0 .1.is_zero() && c.1 .0.is_zero() && c.1 .1.is_zero() { - // If the constant is in the base field, display it as such. - c.0 .0.to_string() - } else { - c.to_string() - } - } - ExtExpr::Param(v) => v.to_string(), - ExtExpr::Add(a, b) => format!("{} + {}", a.format_expr(), b.format_expr()), - ExtExpr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()), - ExtExpr::Mul(a, b) => format!("({}) * ({})", a.format_expr(), b.format_expr()), - ExtExpr::Neg(a) => format!("-({})", a.format_expr()), - } - } - - /// Helper function, use [`simplify`] instead. - /// - /// Simplifies an expression by applying basic arithmetic rules. - fn unchecked_simplify(&self) -> Self { - let simple = simplify_arithmetic!(self); - match simple { - Self::SecureCol([a, b, c, d]) => { - let a = a.unchecked_simplify(); - let b = b.unchecked_simplify(); - let c = c.unchecked_simplify(); - let d = d.unchecked_simplify(); - match (a.clone(), b.clone(), c.clone(), d.clone()) { - ( - BaseExpr::Const(a_val), - BaseExpr::Const(b_val), - BaseExpr::Const(c_val), - BaseExpr::Const(d_val), - ) => ExtExpr::Const(SecureField::from_m31_array([a_val, b_val, c_val, d_val])), - _ => Self::SecureCol([Box::new(a), Box::new(b), Box::new(c), Box::new(d)]), - } - } - other => other, - } - } - - /// Simplifies an expression by applying basic arithmetic rules and ensures that the result is - /// equivalent to the original expression by assigning random values. 
- pub fn simplify(&self) -> Self { - let simplified = self.unchecked_simplify(); - assert_eq!(self.random_eval(), simplified.random_eval()); - simplified - } - - pub fn simplify_and_format(&self) -> String { - self.simplify().format_expr() - } - - /// Evaluates an extension field expression. - /// Takes: - /// * `columns`: A mapping from triplets (interaction, idx, offset) to base field values. - /// * `vars`: A mapping from variable names to base field values. - /// * `ext_vars`: A mapping from variable names to extension field values. - pub fn eval_expr(&self, columns: &C, vars: &V, ext_vars: &EV) -> E::EF - where - C: for<'a> Index<&'a (usize, usize, isize), Output = E::F>, - V: for<'a> Index<&'a String, Output = E::F>, - EV: for<'a> Index<&'a String, Output = E::EF>, - E: EvalAtRow, - { - match self { - Self::SecureCol([a, b, c, d]) => { - let a = a.eval_expr::(columns, vars); - let b = b.eval_expr::(columns, vars); - let c = c.eval_expr::(columns, vars); - let d = d.eval_expr::(columns, vars); - E::combine_ef([a, b, c, d]) - } - Self::Const(c) => E::EF::from(*c), - Self::Param(var) => ext_vars[&var.to_string()].clone(), - Self::Add(a, b) => { - a.eval_expr::(columns, vars, ext_vars) - + b.eval_expr::(columns, vars, ext_vars) - } - Self::Sub(a, b) => { - a.eval_expr::(columns, vars, ext_vars) - - b.eval_expr::(columns, vars, ext_vars) - } - Self::Mul(a, b) => { - a.eval_expr::(columns, vars, ext_vars) - * b.eval_expr::(columns, vars, ext_vars) - } - Self::Neg(a) => -a.eval_expr::(columns, vars, ext_vars), - } - } - - pub fn collect_variables(&self) -> ExprVariables { - match self { - ExtExpr::SecureCol([a, b, c, d]) => { - a.collect_variables() - + b.collect_variables() - + c.collect_variables() - + d.collect_variables() - } - ExtExpr::Const(_) => ExprVariables::default(), - ExtExpr::Param(param) => ExprVariables::ext_param(param.to_string()), - ExtExpr::Add(a, b) => a.collect_variables() + b.collect_variables(), - ExtExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(), - ExtExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(), - ExtExpr::Neg(a) => a.collect_variables(), - } - } - - pub fn random_eval(&self) -> SecureField { - let assignment = self.collect_variables().random_assignment(0); - self.eval_expr::, _, _, _>(&assignment.0, &assignment.1, &assignment.2) - } -} - -/// An assignment to the variables that may appear in an expression. -pub type ExprVarAssignment = ( - HashMap<(usize, usize, isize), BaseField>, - HashMap, - HashMap, -); - -/// Three sets representing all the variables that can appear in an expression: -/// * `cols`: The columns of the AIR. -/// * `params`: The formal parameters to the AIR. -/// * `ext_params`: The extension field parameters to the AIR. -#[derive(Default)] -pub struct ExprVariables { - pub cols: HashSet, - pub params: HashSet, - pub ext_params: HashSet, -} - -impl ExprVariables { - pub fn col(col: ColumnExpr) -> Self { - Self { - cols: vec![col].into_iter().collect(), - params: HashSet::new(), - ext_params: HashSet::new(), - } - } - - pub fn param(param: String) -> Self { - Self { - cols: HashSet::new(), - params: vec![param].into_iter().collect(), - ext_params: HashSet::new(), - } - } - - pub fn ext_param(param: String) -> Self { - Self { - cols: HashSet::new(), - params: HashSet::new(), - ext_params: vec![param].into_iter().collect(), - } - } - - /// Generates a random assignment to the variables. - /// Note that the assignment is deterministically dependent on every variable and that this is - /// required. 
- pub fn random_assignment(&self, salt: usize) -> ExprVarAssignment { - let cols = sorted(self.cols.iter()) - .map(|col| { - ((col.interaction, col.idx, col.offset), { - let mut hasher = DefaultHasher::new(); - (salt, col).hash(&mut hasher); - (hasher.finish() as u32).into() - }) - }) - .collect(); - - let params = sorted(self.params.iter()) - .map(|param| { - (param.clone(), { - let mut hasher = DefaultHasher::new(); - (salt, param).hash(&mut hasher); - (hasher.finish() as u32).into() - }) - }) - .collect(); - - let ext_params = sorted(self.ext_params.iter()) - .map(|param| { - (param.clone(), { - let mut hasher = DefaultHasher::new(); - (salt, param).hash(&mut hasher); - (hasher.finish() as u32).into() - }) - }) - .collect(); - - (cols, params, ext_params) - } -} - -impl Add for ExprVariables { - type Output = Self; - fn add(self, rhs: Self) -> Self { - Self { - cols: self.cols.union(&rhs.cols).cloned().collect(), - params: self.params.union(&rhs.params).cloned().collect(), - ext_params: self.ext_params.union(&rhs.ext_params).cloned().collect(), - } - } -} - -impl From for BaseExpr { - fn from(val: BaseField) -> Self { - BaseExpr::Const(val) - } -} - -impl From for ExtExpr { - fn from(val: BaseField) -> Self { - ExtExpr::SecureCol([ - Box::new(BaseExpr::from(val)), - Box::new(BaseExpr::zero()), - Box::new(BaseExpr::zero()), - Box::new(BaseExpr::zero()), - ]) - } -} - -impl From for ExtExpr { - fn from(QM31(CM31(a, b), CM31(c, d)): SecureField) -> Self { - ExtExpr::SecureCol([ - Box::new(BaseExpr::from(a)), - Box::new(BaseExpr::from(b)), - Box::new(BaseExpr::from(c)), - Box::new(BaseExpr::from(d)), - ]) - } -} - -impl From for ExtExpr { - fn from(expr: BaseExpr) -> Self { - ExtExpr::SecureCol([ - Box::new(expr.clone()), - Box::new(BaseExpr::zero()), - Box::new(BaseExpr::zero()), - Box::new(BaseExpr::zero()), - ]) - } -} - -impl Add for BaseExpr { - type Output = Self; - fn add(self, rhs: Self) -> Self { - BaseExpr::Add(Box::new(self), Box::new(rhs)) - } -} - -impl Sub for BaseExpr { - type Output = Self; - fn sub(self, rhs: Self) -> Self { - BaseExpr::Sub(Box::new(self), Box::new(rhs)) - } -} - -impl Mul for BaseExpr { - type Output = Self; - fn mul(self, rhs: Self) -> Self { - BaseExpr::Mul(Box::new(self), Box::new(rhs)) - } -} - -impl AddAssign for BaseExpr { - fn add_assign(&mut self, rhs: Self) { - *self = self.clone() + rhs - } -} - -impl MulAssign for BaseExpr { - fn mul_assign(&mut self, rhs: Self) { - *self = self.clone() * rhs - } -} - -impl Neg for BaseExpr { - type Output = Self; - fn neg(self) -> Self { - BaseExpr::Neg(Box::new(self)) - } -} - -impl Add for ExtExpr { - type Output = Self; - fn add(self, rhs: Self) -> Self { - ExtExpr::Add(Box::new(self), Box::new(rhs)) - } -} - -impl Sub for ExtExpr { - type Output = Self; - fn sub(self, rhs: Self) -> Self { - ExtExpr::Sub(Box::new(self), Box::new(rhs)) - } -} - -impl Mul for ExtExpr { - type Output = Self; - fn mul(self, rhs: Self) -> Self { - ExtExpr::Mul(Box::new(self), Box::new(rhs)) - } -} - -impl AddAssign for ExtExpr { - fn add_assign(&mut self, rhs: Self) { - *self = self.clone() + rhs - } -} - -impl MulAssign for ExtExpr { - fn mul_assign(&mut self, rhs: Self) { - *self = self.clone() * rhs - } -} - -impl Neg for ExtExpr { - type Output = Self; - fn neg(self) -> Self { - ExtExpr::Neg(Box::new(self)) - } -} - -impl Zero for BaseExpr { - fn zero() -> Self { - BaseExpr::from(BaseField::zero()) - } - fn is_zero(&self) -> bool { - // TODO(alont): consider replacing `Zero` in the trait bound with a custom trait - // that 
only has `zero()`. - panic!("Can't check if an expression is zero."); - } -} - -impl One for BaseExpr { - fn one() -> Self { - BaseExpr::from(BaseField::one()) - } -} - -impl Zero for ExtExpr { - fn zero() -> Self { - ExtExpr::from(BaseField::zero()) - } - fn is_zero(&self) -> bool { - // TODO(alont): consider replacing `Zero` in the trait bound with a custom trait - // that only has `zero()`. - panic!("Can't check if an expression is zero."); - } -} - -impl One for ExtExpr { - fn one() -> Self { - ExtExpr::from(BaseField::one()) - } -} - -impl FieldExpOps for BaseExpr { - fn inverse(&self) -> Self { - BaseExpr::Inv(Box::new(self.clone())) - } -} - -impl Add for BaseExpr { - type Output = Self; - fn add(self, rhs: BaseField) -> Self { - self + BaseExpr::from(rhs) - } -} - -impl AddAssign for BaseExpr { - fn add_assign(&mut self, rhs: BaseField) { - *self = self.clone() + BaseExpr::from(rhs) - } -} - -impl Mul for BaseExpr { - type Output = Self; - fn mul(self, rhs: BaseField) -> Self { - self * BaseExpr::from(rhs) - } -} - -impl Mul for BaseExpr { - type Output = ExtExpr; - fn mul(self, rhs: SecureField) -> ExtExpr { - ExtExpr::from(self) * ExtExpr::from(rhs) - } -} - -impl Add for BaseExpr { - type Output = ExtExpr; - fn add(self, rhs: SecureField) -> ExtExpr { - ExtExpr::from(self) + ExtExpr::from(rhs) - } -} - -impl Sub for BaseExpr { - type Output = ExtExpr; - fn sub(self, rhs: SecureField) -> ExtExpr { - ExtExpr::from(self) - ExtExpr::from(rhs) - } -} - -impl Add for ExtExpr { - type Output = Self; - fn add(self, rhs: BaseField) -> Self { - self + ExtExpr::from(rhs) - } -} - -impl AddAssign for ExtExpr { - fn add_assign(&mut self, rhs: BaseField) { - *self = self.clone() + ExtExpr::from(rhs) - } -} - -impl Mul for ExtExpr { - type Output = Self; - fn mul(self, rhs: BaseField) -> Self { - self * ExtExpr::from(rhs) - } -} - -impl Mul for ExtExpr { - type Output = Self; - fn mul(self, rhs: SecureField) -> Self { - self * ExtExpr::from(rhs) - } -} - -impl Add for ExtExpr { - type Output = Self; - fn add(self, rhs: SecureField) -> Self { - self + ExtExpr::from(rhs) - } -} - -impl Sub for ExtExpr { - type Output = Self; - fn sub(self, rhs: SecureField) -> Self { - self - ExtExpr::from(rhs) - } -} - -impl Add for ExtExpr { - type Output = Self; - fn add(self, rhs: BaseExpr) -> Self { - self + ExtExpr::from(rhs) - } -} - -impl Mul for ExtExpr { - type Output = Self; - fn mul(self, rhs: BaseExpr) -> Self { - self * ExtExpr::from(rhs) - } -} - -impl Mul for BaseExpr { - type Output = ExtExpr; - fn mul(self, rhs: ExtExpr) -> ExtExpr { - rhs * self - } -} - -impl Sub for ExtExpr { - type Output = Self; - fn sub(self, rhs: BaseExpr) -> Self { - self - ExtExpr::from(rhs) - } -} - -/// Returns the expression -/// `value[0] * _alpha0 + value[1] * _alpha1 + ... 
- _z.` -fn combine_formal>(relation: &R, values: &[BaseExpr]) -> ExtExpr { - const Z_SUFFIX: &str = "_z"; - const ALPHA_SUFFIX: &str = "_alpha"; - - let z = ExtExpr::Param(relation.get_name().to_owned() + Z_SUFFIX); - let alpha_powers = (0..relation.get_size()) - .map(|i| ExtExpr::Param(relation.get_name().to_owned() + ALPHA_SUFFIX + &i.to_string())); - values - .iter() - .zip(alpha_powers) - .fold(ExtExpr::zero(), |acc, (value, power)| { - acc + power * value.clone() - }) - - z -} - -pub struct FormalLogupAtRow { - pub interaction: usize, - pub total_sum: ExtExpr, - pub claimed_sum: Option<(ExtExpr, usize)>, - pub fracs: Vec>, - pub is_finalized: bool, - pub is_first: BaseExpr, - pub log_size: u32, -} - -// P is an offset no column can reach, it signifies the variable -// offset, which is an input to the verifier. -const CLAIMED_SUM_DUMMY_OFFSET: usize = m31::P as usize; - -impl FormalLogupAtRow { - pub fn new(interaction: usize, has_partial_sum: bool, log_size: u32) -> Self { - let total_sum_name = "total_sum".to_string(); - let claimed_sum_name = "claimed_sum".to_string(); - - Self { - interaction, - // TODO(alont): Should these be Expr::SecureField? - total_sum: ExtExpr::Param(total_sum_name), - claimed_sum: has_partial_sum - .then_some((ExtExpr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)), - fracs: vec![], - is_finalized: true, - is_first: BaseExpr::zero(), - log_size, - } - } -} - -/// An Evaluator that saves all constraint expressions. -pub struct ExprEvaluator { - pub cur_var_index: usize, - pub constraints: Vec, - pub logup: FormalLogupAtRow, - pub intermediates: Vec<(String, BaseExpr)>, - pub ext_intermediates: Vec<(String, ExtExpr)>, -} - -impl ExprEvaluator { - pub fn new(log_size: u32, has_partial_sum: bool) -> Self { - Self { - cur_var_index: Default::default(), - constraints: Default::default(), - logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, has_partial_sum, log_size), - intermediates: vec![], - ext_intermediates: vec![], - } - } - - pub fn format_constraints(&self) -> String { - let lets_string = self - .intermediates - .iter() - .map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format())) - .collect::>() - .join("\n\n"); - - let secure_lets_string = self - .ext_intermediates - .iter() - .map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format())) - .collect::>() - .join("\n\n"); - - let constraints_str = self - .constraints - .iter() - .enumerate() - .map(|(i, c)| format!("let constraint_{i} = ") + &c.simplify_and_format() + ";") - .collect::>() - .join("\n\n"); - - [lets_string, secure_lets_string, constraints_str] - .iter() - .filter(|x| !x.is_empty()) - .cloned() - .collect::>() - .join("\n\n") - } -} - -impl EvalAtRow for ExprEvaluator { - // TODO(alont): Should there be a version of this that disallows Secure fields for F? 
- type F = BaseExpr; - type EF = ExtExpr; - - fn next_interaction_mask( - &mut self, - interaction: usize, - offsets: [isize; N], - ) -> [Self::F; N] { - let res = std::array::from_fn(|i| { - let col = ColumnExpr::from((interaction, self.cur_var_index, offsets[i])); - BaseExpr::Col(col) - }); - self.cur_var_index += 1; - res - } - - fn add_constraint(&mut self, constraint: G) - where - Self::EF: From, - { - self.constraints.push(constraint.into()); - } - - fn combine_ef(values: [Self::F; 4]) -> Self::EF { - ExtExpr::SecureCol([ - Box::new(values[0].clone()), - Box::new(values[1].clone()), - Box::new(values[2].clone()), - Box::new(values[3].clone()), - ]) - } - - fn add_to_relation>( - &mut self, - entry: RelationEntry<'_, Self::F, Self::EF, R>, - ) { - let intermediate = - self.add_extension_intermediate(combine_formal(entry.relation, entry.values)); - let frac = Fraction::new(entry.multiplicity.clone(), intermediate); - self.write_logup_frac(frac); - } - - fn add_intermediate(&mut self, expr: Self::F) -> Self::F { - let name = format!( - "intermediate{}", - self.intermediates.len() + self.ext_intermediates.len() - ); - let intermediate = BaseExpr::Param(name.clone()); - self.intermediates.push((name, expr)); - intermediate - } - - fn add_extension_intermediate(&mut self, expr: Self::EF) -> Self::EF { - let name = format!( - "intermediate{}", - self.intermediates.len() + self.ext_intermediates.len() - ); - let intermediate = ExtExpr::Param(name.clone()); - self.ext_intermediates.push((name, expr)); - intermediate - } - - fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F { - BaseExpr::Param(column.name().to_string()) - } - - super::logup_proxy!(); -} - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use num_traits::One; - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use super::{BaseExpr, ExtExpr}; - use crate::constraint_framework::expr::ExprEvaluator; - use crate::constraint_framework::{ - relation, AssertEvaluator, EvalAtRow, FrameworkEval, RelationEntry, - }; - use crate::core::fields::m31::{self, BaseField}; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::FieldExpOps; - - macro_rules! secure_col { - ($a:expr, $b:expr, $c:expr, $d:expr) => { - ExtExpr::SecureCol([ - Box::new($a.into()), - Box::new($b.into()), - Box::new($c.into()), - Box::new($d.into()), - ]) - }; - } - - macro_rules! col { - ($interaction:expr, $idx:expr, $offset:expr) => { - BaseExpr::Col(($interaction, $idx, $offset).into()) - }; - } - - macro_rules! var { - ($var:expr) => { - BaseExpr::Param($var.to_string()) - }; - } - - macro_rules! qvar { - ($var:expr) => { - ExtExpr::Param($var.to_string()) - }; - } - - macro_rules! felt { - ($val:expr) => { - BaseExpr::Const($val.into()) - }; - } - - macro_rules! 
qfelt { - ($a:expr, $b:expr, $c:expr, $d:expr) => { - ExtExpr::Const(SecureField::from_m31_array([ - $a.into(), - $b.into(), - $c.into(), - $d.into(), - ])) - }; - } - - #[test] - fn test_eval_expr() { - let col_1_0_0 = BaseField::from(12); - let col_1_1_0 = BaseField::from(5); - let var_a = BaseField::from(3); - let var_b = BaseField::from(4); - let var_c = SecureField::from_m31_array([ - BaseField::from(1), - BaseField::from(2), - BaseField::from(3), - BaseField::from(4), - ]); - - let columns: HashMap<(usize, usize, isize), BaseField> = - HashMap::from([((1, 0, 0), col_1_0_0), ((1, 1, 0), col_1_1_0)]); - let vars = HashMap::from([("a".to_string(), var_a), ("b".to_string(), var_b)]); - let ext_vars = HashMap::from([("c".to_string(), var_c)]); - - let expr = secure_col!( - col!(1, 0, 0) - col!(1, 1, 0), - col!(1, 1, 0) * (-var!("a")), - var!("a") + var!("a").inverse(), - var!("b") * felt!(7) - ) + qvar!("c") * qvar!("c") - - qfelt!(1, 0, 0, 0); - - let expected = SecureField::from_m31_array([ - col_1_0_0 - col_1_1_0, - col_1_1_0 * (-var_a), - var_a + var_a.inverse(), - var_b * BaseField::from(7), - ]) + var_c * var_c - - SecureField::one(); - - assert_eq!( - expr.eval_expr::, _, _, _>(&columns, &vars, &ext_vars), - expected - ); - } - - #[test] - fn test_simplify_expr() { - let c0 = col!(1, 0, 0); - let c1 = col!(1, 1, 0); - let a = var!("a"); - let b = qvar!("b"); - let zero = felt!(0); - let qzero = qfelt!(0, 0, 0, 0); - let one = felt!(1); - let qone = qfelt!(1, 0, 0, 0); - let minus_one = felt!(m31::P - 1); - let qminus_one = qfelt!(m31::P - 1, 0, 0, 0); - - let mut rng = SmallRng::seed_from_u64(0); - let columns: HashMap<(usize, usize, isize), BaseField> = - HashMap::from([((1, 0, 0), rng.gen()), ((1, 1, 0), rng.gen())]); - let vars: HashMap = HashMap::from([("a".to_string(), rng.gen())]); - let ext_vars: HashMap = HashMap::from([("b".to_string(), rng.gen())]); - - let base_expr = (((zero.clone() + c0.clone()) + (a.clone() + zero.clone())) - * ((-c1.clone()) + (-c0.clone())) - + (-(-(a.clone() + a.clone() + c0.clone()))) - - zero.clone()) - + (a.clone() - zero.clone()) - + (-c1.clone() - (a.clone() * a.clone())) - + (a.clone() * zero.clone()) - - (zero.clone() * c1.clone()) - + one.clone() - * a.clone() - * one.clone() - * c1.clone() - * (-a.clone()) - * c1.clone() - * (minus_one.clone() * c0.clone()); - - let expr = (qzero.clone() - + secure_col!( - base_expr.clone(), - base_expr.clone(), - zero.clone(), - one.clone() - ) - - qzero.clone()) - * qone.clone() - * b.clone() - * qminus_one.clone(); - - let full_eval = expr.eval_expr::, _, _, _>(&columns, &vars, &ext_vars); - let simplified_eval = expr - .simplify() - .eval_expr::, _, _, _>(&columns, &vars, &ext_vars); - - assert_eq!(full_eval, simplified_eval); - } - - #[test] - fn test_format_expr() { - let test_struct = TestStruct {}; - let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - let expected = "let intermediate0 = (col_1_1[0]) * (col_1_2[0]); - -\ - let intermediate1 = (TestRelation_alpha0) * (col_1_0[0]) \ - + (TestRelation_alpha1) * (col_1_1[0]) \ - + (TestRelation_alpha2) * (col_1_2[0]) \ - - (TestRelation_z); - -\ - let constraint_0 = ((col_1_0[0]) * (intermediate0)) * (1 / (col_1_0[0] + col_1_1[0])); - -\ - let constraint_1 = (SecureCol(col_2_3[0], col_2_4[0], col_2_5[0], col_2_6[0]) \ - - (SecureCol(col_2_3[-1], col_2_4[-1], col_2_5[-1], col_2_6[-1]) \ - - ((total_sum) * (preprocessed.is_first)))) \ - * (intermediate1) \ - - (1);" - .to_string(); - - assert_eq!(eval.format_constraints(), expected); - } - - 
relation!(TestRelation, 3); - - struct TestStruct {} - impl FrameworkEval for TestStruct { - fn log_size(&self) -> u32 { - 0 - } - fn max_constraint_log_degree_bound(&self) -> u32 { - 0 - } - fn evaluate(&self, mut eval: E) -> E { - let x0 = eval.next_trace_mask(); - let x1 = eval.next_trace_mask(); - let x2 = eval.next_trace_mask(); - let intermediate = eval.add_intermediate(x1.clone() * x2.clone()); - eval.add_constraint(x0.clone() * intermediate * (x0.clone() + x1.clone()).inverse()); - eval.add_to_relation(RelationEntry::new( - &TestRelation::dummy(), - E::EF::one(), - &[x0, x1, x2], - )); - eval.finalize_logup(); - eval - } - } -} diff --git a/crates/prover/src/constraint_framework/expr/assignment.rs b/crates/prover/src/constraint_framework/expr/assignment.rs new file mode 100644 index 000000000..1ba834139 --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/assignment.rs @@ -0,0 +1,267 @@ +use std::collections::{HashMap, HashSet}; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::ops::{Add, Index}; + +use itertools::sorted; + +use super::{BaseExpr, ColumnExpr, ExtExpr}; +use crate::constraint_framework::{AssertEvaluator, EvalAtRow}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldExpOps; + +/// An assignment to the variables that may appear in an expression. +pub type ExprVarAssignment = ( + HashMap<(usize, usize, isize), BaseField>, + HashMap, + HashMap, +); + +/// Three sets representing all the variables that can appear in an expression: +/// * `cols`: The columns of the AIR. +/// * `params`: The formal parameters to the AIR. +/// * `ext_params`: The extension field parameters to the AIR. +#[derive(Default)] +pub struct ExprVariables { + pub cols: HashSet, + pub params: HashSet, + pub ext_params: HashSet, +} + +impl ExprVariables { + pub fn col(col: ColumnExpr) -> Self { + Self { + cols: vec![col].into_iter().collect(), + params: HashSet::new(), + ext_params: HashSet::new(), + } + } + + pub fn param(param: String) -> Self { + Self { + cols: HashSet::new(), + params: vec![param].into_iter().collect(), + ext_params: HashSet::new(), + } + } + + pub fn ext_param(param: String) -> Self { + Self { + cols: HashSet::new(), + params: HashSet::new(), + ext_params: vec![param].into_iter().collect(), + } + } + + /// Generates a random assignment to the variables. + /// Note that the assignment is deterministic in the sets of variables (disregarding their + /// order), and this is required. 
+ pub fn random_assignment(&self, salt: usize) -> ExprVarAssignment { + let cols = sorted(self.cols.iter()) + .map(|col| { + ((col.interaction, col.idx, col.offset), { + let mut hasher = DefaultHasher::new(); + (salt, col).hash(&mut hasher); + (hasher.finish() as u32).into() + }) + }) + .collect(); + + let params = sorted(self.params.iter()) + .map(|param| { + (param.clone(), { + let mut hasher = DefaultHasher::new(); + (salt, param).hash(&mut hasher); + (hasher.finish() as u32).into() + }) + }) + .collect(); + + let ext_params = sorted(self.ext_params.iter()) + .map(|param| { + (param.clone(), { + let mut hasher = DefaultHasher::new(); + (salt, param).hash(&mut hasher); + (hasher.finish() as u32).into() + }) + }) + .collect(); + + (cols, params, ext_params) + } +} + +impl Add for ExprVariables { + type Output = Self; + fn add(self, rhs: Self) -> Self { + Self { + cols: self.cols.union(&rhs.cols).cloned().collect(), + params: self.params.union(&rhs.params).cloned().collect(), + ext_params: self.ext_params.union(&rhs.ext_params).cloned().collect(), + } + } +} + +impl BaseExpr { + /// Evaluates a base field expression. + /// Takes: + /// * `columns`: A mapping from triplets (interaction, idx, offset) to base field values. + /// * `vars`: A mapping from variable names to base field values. + pub fn eval_expr(&self, columns: &C, vars: &V) -> E::F + where + C: for<'a> Index<&'a (usize, usize, isize), Output = E::F>, + V: for<'a> Index<&'a String, Output = E::F>, + E: EvalAtRow, + { + match self { + Self::Col(col) => columns[&(col.interaction, col.idx, col.offset)].clone(), + Self::Const(c) => E::F::from(*c), + Self::Param(var) => vars[&var.to_string()].clone(), + Self::Add(a, b) => { + a.eval_expr::(columns, vars) + b.eval_expr::(columns, vars) + } + Self::Sub(a, b) => { + a.eval_expr::(columns, vars) - b.eval_expr::(columns, vars) + } + Self::Mul(a, b) => { + a.eval_expr::(columns, vars) * b.eval_expr::(columns, vars) + } + Self::Neg(a) => -a.eval_expr::(columns, vars), + Self::Inv(a) => a.eval_expr::(columns, vars).inverse(), + } + } + + pub fn collect_variables(&self) -> ExprVariables { + match self { + BaseExpr::Col(col) => ExprVariables::col(col.clone()), + BaseExpr::Const(_) => ExprVariables::default(), + BaseExpr::Param(param) => ExprVariables::param(param.to_string()), + BaseExpr::Add(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Neg(a) => a.collect_variables(), + BaseExpr::Inv(a) => a.collect_variables(), + } + } + + pub fn random_eval(&self) -> BaseField { + let assignment = self.collect_variables().random_assignment(0); + assert!(assignment.2.is_empty()); + self.eval_expr::, _, _>(&assignment.0, &assignment.1) + } +} + +impl ExtExpr { + /// Evaluates an extension field expression. + /// Takes: + /// * `columns`: A mapping from triplets (interaction, idx, offset) to base field values. + /// * `vars`: A mapping from variable names to base field values. + /// * `ext_vars`: A mapping from variable names to extension field values. 
+ pub fn eval_expr(&self, columns: &C, vars: &V, ext_vars: &EV) -> E::EF + where + C: for<'a> Index<&'a (usize, usize, isize), Output = E::F>, + V: for<'a> Index<&'a String, Output = E::F>, + EV: for<'a> Index<&'a String, Output = E::EF>, + E: EvalAtRow, + { + match self { + Self::SecureCol([a, b, c, d]) => { + let a = a.eval_expr::(columns, vars); + let b = b.eval_expr::(columns, vars); + let c = c.eval_expr::(columns, vars); + let d = d.eval_expr::(columns, vars); + E::combine_ef([a, b, c, d]) + } + Self::Const(c) => E::EF::from(*c), + Self::Param(var) => ext_vars[&var.to_string()].clone(), + Self::Add(a, b) => { + a.eval_expr::(columns, vars, ext_vars) + + b.eval_expr::(columns, vars, ext_vars) + } + Self::Sub(a, b) => { + a.eval_expr::(columns, vars, ext_vars) + - b.eval_expr::(columns, vars, ext_vars) + } + Self::Mul(a, b) => { + a.eval_expr::(columns, vars, ext_vars) + * b.eval_expr::(columns, vars, ext_vars) + } + Self::Neg(a) => -a.eval_expr::(columns, vars, ext_vars), + } + } + + pub fn collect_variables(&self) -> ExprVariables { + match self { + ExtExpr::SecureCol([a, b, c, d]) => { + a.collect_variables() + + b.collect_variables() + + c.collect_variables() + + d.collect_variables() + } + ExtExpr::Const(_) => ExprVariables::default(), + ExtExpr::Param(param) => ExprVariables::ext_param(param.to_string()), + ExtExpr::Add(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Neg(a) => a.collect_variables(), + } + } + + pub fn random_eval(&self) -> SecureField { + let assignment = self.collect_variables().random_assignment(0); + self.eval_expr::, _, _, _>(&assignment.0, &assignment.1, &assignment.2) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use num_traits::One; + + use crate::constraint_framework::expr::utils::*; + use crate::constraint_framework::AssertEvaluator; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::FieldExpOps; + + #[test] + fn test_eval_expr() { + let col_1_0_0 = BaseField::from(12); + let col_1_1_0 = BaseField::from(5); + let var_a = BaseField::from(3); + let var_b = BaseField::from(4); + let var_c = SecureField::from_m31_array([ + BaseField::from(1), + BaseField::from(2), + BaseField::from(3), + BaseField::from(4), + ]); + + let columns: HashMap<(usize, usize, isize), BaseField> = + HashMap::from([((1, 0, 0), col_1_0_0), ((1, 1, 0), col_1_1_0)]); + let vars = HashMap::from([("a".to_string(), var_a), ("b".to_string(), var_b)]); + let ext_vars = HashMap::from([("c".to_string(), var_c)]); + + let expr = secure_col!( + col!(1, 0, 0) - col!(1, 1, 0), + col!(1, 1, 0) * (-var!("a")), + var!("a") + var!("a").inverse(), + var!("b") * felt!(7) + ) + qvar!("c") * qvar!("c") + - qfelt!(1, 0, 0, 0); + + let expected = SecureField::from_m31_array([ + col_1_0_0 - col_1_1_0, + col_1_1_0 * (-var_a), + var_a + var_a.inverse(), + var_b * BaseField::from(7), + ]) + var_c * var_c + - SecureField::one(); + + assert_eq!( + expr.eval_expr::, _, _, _>(&columns, &vars, &ext_vars), + expected + ); + } +} diff --git a/crates/prover/src/constraint_framework/expr/evaluator.rs b/crates/prover/src/constraint_framework/expr/evaluator.rs new file mode 100644 index 000000000..7fef20254 --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/evaluator.rs @@ -0,0 +1,244 @@ +use num_traits::Zero; + +use super::{BaseExpr, ExtExpr}; +use 
crate::constraint_framework::expr::ColumnExpr; +use crate::constraint_framework::preprocessed_columns::PreprocessedColumn; +use crate::constraint_framework::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; +use crate::core::fields::m31; +use crate::core::lookups::utils::Fraction; + +pub struct FormalLogupAtRow { + pub interaction: usize, + pub total_sum: ExtExpr, + pub claimed_sum: Option<(ExtExpr, usize)>, + pub fracs: Vec>, + pub is_finalized: bool, + pub is_first: BaseExpr, + pub log_size: u32, +} + +// P is an offset no column can reach, it signifies the variable +// offset, which is an input to the verifier. +pub const CLAIMED_SUM_DUMMY_OFFSET: usize = m31::P as usize; + +impl FormalLogupAtRow { + pub fn new(interaction: usize, has_partial_sum: bool, log_size: u32) -> Self { + let total_sum_name = "total_sum".to_string(); + let claimed_sum_name = "claimed_sum".to_string(); + + Self { + interaction, + // TODO(alont): Should these be Expr::SecureField? + total_sum: ExtExpr::Param(total_sum_name), + claimed_sum: has_partial_sum + .then_some((ExtExpr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)), + fracs: vec![], + is_finalized: true, + is_first: BaseExpr::zero(), + log_size, + } + } +} + +/// Returns the expression +/// `value[0] * _alpha0 + value[1] * _alpha1 + ... - _z.` +fn combine_formal>(relation: &R, values: &[BaseExpr]) -> ExtExpr { + const Z_SUFFIX: &str = "_z"; + const ALPHA_SUFFIX: &str = "_alpha"; + + let z = ExtExpr::Param(relation.get_name().to_owned() + Z_SUFFIX); + let alpha_powers = (0..relation.get_size()) + .map(|i| ExtExpr::Param(relation.get_name().to_owned() + ALPHA_SUFFIX + &i.to_string())); + values + .iter() + .zip(alpha_powers) + .fold(ExtExpr::zero(), |acc, (value, power)| { + acc + power * value.clone() + }) + - z +} + +/// An Evaluator that saves all constraint expressions. +pub struct ExprEvaluator { + pub cur_var_index: usize, + pub constraints: Vec, + pub logup: FormalLogupAtRow, + pub intermediates: Vec<(String, BaseExpr)>, + pub ext_intermediates: Vec<(String, ExtExpr)>, +} + +impl ExprEvaluator { + pub fn new(log_size: u32, has_partial_sum: bool) -> Self { + Self { + cur_var_index: Default::default(), + constraints: Default::default(), + logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, has_partial_sum, log_size), + intermediates: vec![], + ext_intermediates: vec![], + } + } + + pub fn format_constraints(&self) -> String { + let lets_string = self + .intermediates + .iter() + .map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format())) + .collect::>() + .join("\n\n"); + + let secure_lets_string = self + .ext_intermediates + .iter() + .map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format())) + .collect::>() + .join("\n\n"); + + let constraints_str = self + .constraints + .iter() + .enumerate() + .map(|(i, c)| format!("let constraint_{i} = ") + &c.simplify_and_format() + ";") + .collect::>() + .join("\n\n"); + + [lets_string, secure_lets_string, constraints_str] + .iter() + .filter(|x| !x.is_empty()) + .cloned() + .collect::>() + .join("\n\n") + } +} + +impl EvalAtRow for ExprEvaluator { + // TODO(alont): Should there be a version of this that disallows Secure fields for F? 
+ type F = BaseExpr; + type EF = ExtExpr; + + fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N] { + let res = std::array::from_fn(|i| { + let col = ColumnExpr::from((interaction, self.cur_var_index, offsets[i])); + BaseExpr::Col(col) + }); + self.cur_var_index += 1; + res + } + + fn add_constraint(&mut self, constraint: G) + where + Self::EF: From, + { + self.constraints.push(constraint.into()); + } + + fn combine_ef(values: [Self::F; 4]) -> Self::EF { + ExtExpr::SecureCol([ + Box::new(values[0].clone()), + Box::new(values[1].clone()), + Box::new(values[2].clone()), + Box::new(values[3].clone()), + ]) + } + + fn add_to_relation>( + &mut self, + entry: RelationEntry<'_, Self::F, Self::EF, R>, + ) { + let intermediate = + self.add_extension_intermediate(combine_formal(entry.relation, entry.values)); + let frac = Fraction::new(entry.multiplicity.clone(), intermediate); + self.write_logup_frac(frac); + } + + fn add_intermediate(&mut self, expr: Self::F) -> Self::F { + let name = format!( + "intermediate{}", + self.intermediates.len() + self.ext_intermediates.len() + ); + let intermediate = BaseExpr::Param(name.clone()); + self.intermediates.push((name, expr)); + intermediate + } + + fn add_extension_intermediate(&mut self, expr: Self::EF) -> Self::EF { + let name = format!( + "intermediate{}", + self.intermediates.len() + self.ext_intermediates.len() + ); + let intermediate = ExtExpr::Param(name.clone()); + self.ext_intermediates.push((name, expr)); + intermediate + } + + fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F { + BaseExpr::Param(column.name().to_string()) + } + + crate::constraint_framework::logup_proxy!(); +} + +#[cfg(test)] +mod tests { + use num_traits::One; + + use crate::constraint_framework::expr::ExprEvaluator; + use crate::constraint_framework::{EvalAtRow, FrameworkEval, RelationEntry}; + use crate::core::fields::FieldExpOps; + use crate::relation; + + #[test] + fn test_expr_evaluator() { + let test_struct = TestStruct {}; + let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); + let expected = "let intermediate0 = (col_1_1[0]) * (col_1_2[0]); + +\ + let intermediate1 = (TestRelation_alpha0) * (col_1_0[0]) \ + + (TestRelation_alpha1) * (col_1_1[0]) \ + + (TestRelation_alpha2) * (col_1_2[0]) \ + - (TestRelation_z); + +\ + let constraint_0 = ((col_1_0[0]) * (intermediate0)) * (1 / (col_1_0[0] + col_1_1[0])); + +\ + let constraint_1 = (SecureCol(col_2_3[0], col_2_4[0], col_2_5[0], col_2_6[0]) \ + - (SecureCol(col_2_3[-1], col_2_4[-1], col_2_5[-1], col_2_6[-1]) \ + - ((total_sum) * (preprocessed.is_first)))) \ + * (intermediate1) \ + - (1);" + .to_string(); + + assert_eq!(eval.format_constraints(), expected); + } + + relation!(TestRelation, 3); + + struct TestStruct {} + impl FrameworkEval for TestStruct { + fn log_size(&self) -> u32 { + 0 + } + fn max_constraint_log_degree_bound(&self) -> u32 { + 0 + } + fn evaluate(&self, mut eval: E) -> E { + let x0 = eval.next_trace_mask(); + let x1 = eval.next_trace_mask(); + let x2 = eval.next_trace_mask(); + let intermediate = eval.add_intermediate(x1.clone() * x2.clone()); + eval.add_constraint(x0.clone() * intermediate * (x0.clone() + x1.clone()).inverse()); + eval.add_to_relation(RelationEntry::new( + &TestRelation::dummy(), + E::EF::one(), + &[x0, x1, x2], + )); + eval.finalize_logup(); + eval + } + } +} diff --git a/crates/prover/src/constraint_framework/expr/format.rs b/crates/prover/src/constraint_framework/expr/format.rs new file mode 
100644 index 000000000..f6d30163d --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/format.rs @@ -0,0 +1,64 @@ +use num_traits::Zero; + +use super::{BaseExpr, ColumnExpr, ExtExpr, CLAIMED_SUM_DUMMY_OFFSET}; + +impl BaseExpr { + pub fn format_expr(&self) -> String { + match self { + BaseExpr::Col(ColumnExpr { + interaction, + idx, + offset, + }) => { + let offset_str = if *offset == CLAIMED_SUM_DUMMY_OFFSET as isize { + "claimed_sum_offset".to_string() + } else { + offset.to_string() + }; + format!("col_{interaction}_{idx}[{offset_str}]") + } + BaseExpr::Const(c) => c.to_string(), + BaseExpr::Param(v) => v.to_string(), + BaseExpr::Add(a, b) => format!("{} + {}", a.format_expr(), b.format_expr()), + BaseExpr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()), + BaseExpr::Mul(a, b) => format!("({}) * ({})", a.format_expr(), b.format_expr()), + BaseExpr::Neg(a) => format!("-({})", a.format_expr()), + BaseExpr::Inv(a) => format!("1 / ({})", a.format_expr()), + } + } +} + +impl ExtExpr { + pub fn format_expr(&self) -> String { + match self { + ExtExpr::SecureCol([a, b, c, d]) => { + // If the expression's non-base components are all constant zeroes, return the base + // field representation of its first part. + if **b == BaseExpr::zero() && **c == BaseExpr::zero() && **d == BaseExpr::zero() { + a.format_expr() + } else { + format!( + "SecureCol({}, {}, {}, {})", + a.format_expr(), + b.format_expr(), + c.format_expr(), + d.format_expr() + ) + } + } + ExtExpr::Const(c) => { + if c.0 .1.is_zero() && c.1 .0.is_zero() && c.1 .1.is_zero() { + // If the constant is in the base field, display it as such. + c.0 .0.to_string() + } else { + c.to_string() + } + } + ExtExpr::Param(v) => v.to_string(), + ExtExpr::Add(a, b) => format!("{} + {}", a.format_expr(), b.format_expr()), + ExtExpr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()), + ExtExpr::Mul(a, b) => format!("({}) * ({})", a.format_expr(), b.format_expr()), + ExtExpr::Neg(a) => format!("-({})", a.format_expr()), + } + } +} diff --git a/crates/prover/src/constraint_framework/expr/mod.rs b/crates/prover/src/constraint_framework/expr/mod.rs new file mode 100644 index 000000000..14668c6e2 --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/mod.rs @@ -0,0 +1,351 @@ +pub mod assignment; +pub mod evaluator; +pub mod format; +pub mod simplify; +pub mod utils; + +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; + +pub use evaluator::ExprEvaluator; +use num_traits::{One, Zero}; + +use crate::constraint_framework::expr::evaluator::CLAIMED_SUM_DUMMY_OFFSET; +use crate::core::fields::cm31::CM31; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::{SecureField, QM31}; +use crate::core::fields::FieldExpOps; + +/// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`. +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct ColumnExpr { + interaction: usize, + idx: usize, + offset: isize, +} + +impl From<(usize, usize, isize)> for ColumnExpr { + fn from((interaction, idx, offset): (usize, usize, isize)) -> Self { + Self { + interaction, + idx, + offset, + } + } +} + +/// An expression representing a base field value. Can be either: +/// * A column indexed by a `ColumnExpr`. +/// * A base field constant. +/// * A formal parameter to the AIR. +/// * A sum, difference, or product of two base field expressions. +/// * A negation or inverse of a base field expression. 
+/// +/// This type is meant to be used as an F associated type for EvalAtRow and interacts with +/// `ExtExpr`, `BaseField` and `SecureField` as expected. +#[derive(Clone, Debug, PartialEq)] +pub enum BaseExpr { + Col(ColumnExpr), + Const(BaseField), + /// Formal parameter to the AIR, for example the interaction elements of a relation. + Param(String), + Add(Box, Box), + Sub(Box, Box), + Mul(Box, Box), + Neg(Box), + Inv(Box), +} + +/// An expression representing a secure field value. Can be either: +/// * A secure column constructed from 4 base field expressions. +/// * A secure field constant. +/// * A formal parameter to the AIR. +/// * A sum, difference, or product of two secure field expressions. +/// * A negation of a secure field expression. +/// +/// This type is meant to be used as an EF associated type for EvalAtRow and interacts with +/// `BaseExpr`, `BaseField` and `SecureField` as expected. +#[derive(Clone, Debug, PartialEq)] +pub enum ExtExpr { + /// An atomic secure column constructed from 4 expressions. + /// Expressions on the secure column are not reduced, i.e, + /// if `a = SecureCol(a0, a1, a2, a3)`, `b = SecureCol(b0, b1, b2, b3)` then + /// `a + b` evaluates to `Add(a, b)` rather than + /// `SecureCol(Add(a0, b0), Add(a1, b1), Add(a2, b2), Add(a3, b3))` + SecureCol([Box; 4]), + Const(SecureField), + /// Formal parameter to the AIR, for example the interaction elements of a relation. + Param(String), + Add(Box, Box), + Sub(Box, Box), + Mul(Box, Box), + Neg(Box), +} + +impl From for BaseExpr { + fn from(val: BaseField) -> Self { + BaseExpr::Const(val) + } +} + +impl From for ExtExpr { + fn from(val: BaseField) -> Self { + ExtExpr::SecureCol([ + Box::new(BaseExpr::from(val)), + Box::new(BaseExpr::zero()), + Box::new(BaseExpr::zero()), + Box::new(BaseExpr::zero()), + ]) + } +} + +impl From for ExtExpr { + fn from(QM31(CM31(a, b), CM31(c, d)): SecureField) -> Self { + ExtExpr::SecureCol([ + Box::new(BaseExpr::from(a)), + Box::new(BaseExpr::from(b)), + Box::new(BaseExpr::from(c)), + Box::new(BaseExpr::from(d)), + ]) + } +} + +impl From for ExtExpr { + fn from(expr: BaseExpr) -> Self { + ExtExpr::SecureCol([ + Box::new(expr.clone()), + Box::new(BaseExpr::zero()), + Box::new(BaseExpr::zero()), + Box::new(BaseExpr::zero()), + ]) + } +} + +impl Add for BaseExpr { + type Output = Self; + fn add(self, rhs: Self) -> Self { + BaseExpr::Add(Box::new(self), Box::new(rhs)) + } +} + +impl Sub for BaseExpr { + type Output = Self; + fn sub(self, rhs: Self) -> Self { + BaseExpr::Sub(Box::new(self), Box::new(rhs)) + } +} + +impl Mul for BaseExpr { + type Output = Self; + fn mul(self, rhs: Self) -> Self { + BaseExpr::Mul(Box::new(self), Box::new(rhs)) + } +} + +impl AddAssign for BaseExpr { + fn add_assign(&mut self, rhs: Self) { + *self = self.clone() + rhs + } +} + +impl MulAssign for BaseExpr { + fn mul_assign(&mut self, rhs: Self) { + *self = self.clone() * rhs + } +} + +impl Neg for BaseExpr { + type Output = Self; + fn neg(self) -> Self { + BaseExpr::Neg(Box::new(self)) + } +} + +impl Add for ExtExpr { + type Output = Self; + fn add(self, rhs: Self) -> Self { + ExtExpr::Add(Box::new(self), Box::new(rhs)) + } +} + +impl Sub for ExtExpr { + type Output = Self; + fn sub(self, rhs: Self) -> Self { + ExtExpr::Sub(Box::new(self), Box::new(rhs)) + } +} + +impl Mul for ExtExpr { + type Output = Self; + fn mul(self, rhs: Self) -> Self { + ExtExpr::Mul(Box::new(self), Box::new(rhs)) + } +} + +impl AddAssign for ExtExpr { + fn add_assign(&mut self, rhs: Self) { + *self = self.clone() + rhs + } 
+} + +impl MulAssign for ExtExpr { + fn mul_assign(&mut self, rhs: Self) { + *self = self.clone() * rhs + } +} + +impl Neg for ExtExpr { + type Output = Self; + fn neg(self) -> Self { + ExtExpr::Neg(Box::new(self)) + } +} + +impl Zero for BaseExpr { + fn zero() -> Self { + BaseExpr::from(BaseField::zero()) + } + fn is_zero(&self) -> bool { + // TODO(alont): consider replacing `Zero` in the trait bound with a custom trait + // that only has `zero()`. + panic!("Can't check if an expression is zero."); + } +} + +impl One for BaseExpr { + fn one() -> Self { + BaseExpr::from(BaseField::one()) + } +} + +impl Zero for ExtExpr { + fn zero() -> Self { + ExtExpr::from(BaseField::zero()) + } + fn is_zero(&self) -> bool { + // TODO(alont): consider replacing `Zero` in the trait bound with a custom trait + // that only has `zero()`. + panic!("Can't check if an expression is zero."); + } +} + +impl One for ExtExpr { + fn one() -> Self { + ExtExpr::from(BaseField::one()) + } +} + +impl FieldExpOps for BaseExpr { + fn inverse(&self) -> Self { + BaseExpr::Inv(Box::new(self.clone())) + } +} + +impl Add for BaseExpr { + type Output = Self; + fn add(self, rhs: BaseField) -> Self { + self + BaseExpr::from(rhs) + } +} + +impl AddAssign for BaseExpr { + fn add_assign(&mut self, rhs: BaseField) { + *self = self.clone() + BaseExpr::from(rhs) + } +} + +impl Mul for BaseExpr { + type Output = Self; + fn mul(self, rhs: BaseField) -> Self { + self * BaseExpr::from(rhs) + } +} + +impl Mul for BaseExpr { + type Output = ExtExpr; + fn mul(self, rhs: SecureField) -> ExtExpr { + ExtExpr::from(self) * ExtExpr::from(rhs) + } +} + +impl Add for BaseExpr { + type Output = ExtExpr; + fn add(self, rhs: SecureField) -> ExtExpr { + ExtExpr::from(self) + ExtExpr::from(rhs) + } +} + +impl Sub for BaseExpr { + type Output = ExtExpr; + fn sub(self, rhs: SecureField) -> ExtExpr { + ExtExpr::from(self) - ExtExpr::from(rhs) + } +} + +impl Add for ExtExpr { + type Output = Self; + fn add(self, rhs: BaseField) -> Self { + self + ExtExpr::from(rhs) + } +} + +impl AddAssign for ExtExpr { + fn add_assign(&mut self, rhs: BaseField) { + *self = self.clone() + ExtExpr::from(rhs) + } +} + +impl Mul for ExtExpr { + type Output = Self; + fn mul(self, rhs: BaseField) -> Self { + self * ExtExpr::from(rhs) + } +} + +impl Mul for ExtExpr { + type Output = Self; + fn mul(self, rhs: SecureField) -> Self { + self * ExtExpr::from(rhs) + } +} + +impl Add for ExtExpr { + type Output = Self; + fn add(self, rhs: SecureField) -> Self { + self + ExtExpr::from(rhs) + } +} + +impl Sub for ExtExpr { + type Output = Self; + fn sub(self, rhs: SecureField) -> Self { + self - ExtExpr::from(rhs) + } +} + +impl Add for ExtExpr { + type Output = Self; + fn add(self, rhs: BaseExpr) -> Self { + self + ExtExpr::from(rhs) + } +} + +impl Mul for ExtExpr { + type Output = Self; + fn mul(self, rhs: BaseExpr) -> Self { + self * ExtExpr::from(rhs) + } +} + +impl Mul for BaseExpr { + type Output = ExtExpr; + fn mul(self, rhs: ExtExpr) -> ExtExpr { + rhs * self + } +} + +impl Sub for ExtExpr { + type Output = Self; + fn sub(self, rhs: BaseExpr) -> Self { + self - ExtExpr::from(rhs) + } +} diff --git a/crates/prover/src/constraint_framework/expr/simplify.rs b/crates/prover/src/constraint_framework/expr/simplify.rs new file mode 100644 index 000000000..528b23627 --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/simplify.rs @@ -0,0 +1,217 @@ +use num_traits::{One, Zero}; + +use super::{BaseExpr, ExtExpr}; +use crate::core::fields::qm31::SecureField; +use 
crate::core::fields::FieldExpOps; + +/// Applies simplifications to arithmetic expressions that can be used both for `BaseExpr` and for +/// `ExtExpr`. +macro_rules! simplify_arithmetic { + ($self:tt) => { + match $self.clone() { + Self::Add(a, b) => { + let a = a.simplify(); + let b = b.simplify(); + match (a.clone(), b.clone()) { + // Simplify constants. + (Self::Const(a), Self::Const(b)) => Self::Const(a + b), + (Self::Const(a_val), _) if a_val.is_zero() => b, // 0 + b = b + (_, Self::Const(b_val)) if b_val.is_zero() => a, // a + 0 = a + // Simplify Negs. + // (-a + -b) = -(a + b) + (Self::Neg(minus_a), Self::Neg(minus_b)) => -(*minus_a + *minus_b), + (Self::Neg(minus_a), _) => b - *minus_a, // -a + b = b - a + (_, Self::Neg(minus_b)) => a - *minus_b, // a + -b = a - b + // No simplification. + _ => a + b, + } + } + Self::Sub(a, b) => { + let a = a.simplify(); + let b = b.simplify(); + match (a.clone(), b.clone()) { + // Simplify constants. + (Self::Const(a), Self::Const(b)) => Self::Const(a - b), // Simplify consts. + (Self::Const(a_val), _) if a_val.is_zero() => -b, // 0 - b = -b + (_, Self::Const(b_val)) if b_val.is_zero() => a, // a - 0 = a + // Simplify Negs. + // (-a - -b) = b - a + (Self::Neg(minus_a), Self::Neg(minus_b)) => *minus_b - *minus_a, + (Self::Neg(minus_a), _) => -(*minus_a + b), // -a - b = -(a + b) + (_, Self::Neg(minus_b)) => a + *minus_b, // a + -b = a - b + // No Simplification. + _ => a - b, + } + } + Self::Mul(a, b) => { + let a = a.simplify(); + let b = b.simplify(); + match (a.clone(), b.clone()) { + // Simplify consts. + (Self::Const(a), Self::Const(b)) => Self::Const(a * b), + (Self::Const(a_val), _) if a_val.is_zero() => Self::zero(), // 0 * b = 0 + (_, Self::Const(b_val)) if b_val.is_zero() => Self::zero(), // a * 0 = 0 + (Self::Const(a_val), _) if a_val == One::one() => b, // 1 * b = b + (_, Self::Const(b_val)) if b_val == One::one() => a, // a * 1 = a + (Self::Const(a_val), _) if -a_val == One::one() => -b, // -1 * b = -b + (_, Self::Const(b_val)) if -b_val == One::one() => -a, // a * -1 = -a + // Simplify Negs. + // (-a) * (-b) = a * b + (Self::Neg(minus_a), Self::Neg(minus_b)) => *minus_a * *minus_b, + (Self::Neg(minus_a), _) => -(*minus_a * b), // (-a) * b = -(a * b) + (_, Self::Neg(minus_b)) => -(a * *minus_b), // a * (-b) = -(a * b) + // No simplification. + _ => a * b, + } + } + Self::Neg(a) => { + let a = a.simplify(); + match a { + Self::Const(c) => Self::Const(-c), + Self::Neg(minus_a) => *minus_a, // -(-a) = a + Self::Sub(a, b) => Self::Sub(b, a), // -(a - b) = b - a + _ => -a, // No simplification. + } + } + other => other, // No simplification. + } + }; +} + +impl BaseExpr { + /// Helper function, use [`simplify`] instead. + /// + /// Simplifies an expression by applying basic arithmetic rules. + fn unchecked_simplify(&self) -> Self { + let simple = simplify_arithmetic!(self); + match simple { + Self::Inv(a) => { + let a = a.unchecked_simplify(); + match a { + Self::Inv(inv_a) => *inv_a, // 1 / (1 / a) = a + Self::Const(c) => Self::Const(c.inverse()), + _ => Self::Inv(Box::new(a)), + } + } + other => other, + } + } + + /// Simplifies an expression by applying basic arithmetic rules and ensures that the result is + /// equivalent to the original expression by assigning random values. 
+ pub fn simplify(&self) -> Self { + let simplified = self.unchecked_simplify(); + assert_eq!(self.random_eval(), simplified.random_eval()); + simplified + } + + pub fn simplify_and_format(&self) -> String { + self.simplify().format_expr() + } +} + +impl ExtExpr { + /// Helper function, use [`simplify`] instead. + /// + /// Simplifies an expression by applying basic arithmetic rules. + fn unchecked_simplify(&self) -> Self { + let simple = simplify_arithmetic!(self); + match simple { + Self::SecureCol([a, b, c, d]) => { + let a = a.unchecked_simplify(); + let b = b.unchecked_simplify(); + let c = c.unchecked_simplify(); + let d = d.unchecked_simplify(); + match (a.clone(), b.clone(), c.clone(), d.clone()) { + ( + BaseExpr::Const(a_val), + BaseExpr::Const(b_val), + BaseExpr::Const(c_val), + BaseExpr::Const(d_val), + ) => ExtExpr::Const(SecureField::from_m31_array([a_val, b_val, c_val, d_val])), + _ => Self::SecureCol([Box::new(a), Box::new(b), Box::new(c), Box::new(d)]), + } + } + other => other, + } + } + + /// Simplifies an expression by applying basic arithmetic rules and ensures that the result is + /// equivalent to the original expression by assigning random values. + pub fn simplify(&self) -> Self { + let simplified = self.unchecked_simplify(); + assert_eq!(self.random_eval(), simplified.random_eval()); + simplified + } + + pub fn simplify_and_format(&self) -> String { + self.simplify().format_expr() + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use crate::constraint_framework::expr::utils::*; + use crate::constraint_framework::AssertEvaluator; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + #[test] + fn test_simplify_expr() { + let c0 = col!(1, 0, 0); + let c1 = col!(1, 1, 0); + let a = var!("a"); + let b = qvar!("b"); + let zero = felt!(0); + let qzero = qfelt!(0, 0, 0, 0); + let one = felt!(1); + let qone = qfelt!(1, 0, 0, 0); + let minus_one = felt!(crate::core::fields::m31::P - 1); + let qminus_one = qfelt!(crate::core::fields::m31::P - 1, 0, 0, 0); + + let mut rng = SmallRng::seed_from_u64(0); + let columns: HashMap<(usize, usize, isize), BaseField> = + HashMap::from([((1, 0, 0), rng.gen()), ((1, 1, 0), rng.gen())]); + let vars: HashMap = HashMap::from([("a".to_string(), rng.gen())]); + let ext_vars: HashMap = HashMap::from([("b".to_string(), rng.gen())]); + + let base_expr = (((zero.clone() + c0.clone()) + (a.clone() + zero.clone())) + * ((-c1.clone()) + (-c0.clone())) + + (-(-(a.clone() + a.clone() + c0.clone()))) + - zero.clone()) + + (a.clone() - zero.clone()) + + (-c1.clone() - (a.clone() * a.clone())) + + (a.clone() * zero.clone()) + - (zero.clone() * c1.clone()) + + one.clone() + * a.clone() + * one.clone() + * c1.clone() + * (-a.clone()) + * c1.clone() + * (minus_one.clone() * c0.clone()); + + let expr = (qzero.clone() + + secure_col!( + base_expr.clone(), + base_expr.clone(), + zero.clone(), + one.clone() + ) + - qzero.clone()) + * qone.clone() + * b.clone() + * qminus_one.clone(); + + let full_eval = expr.eval_expr::, _, _, _>(&columns, &vars, &ext_vars); + let simplified_eval = expr + .simplify() + .eval_expr::, _, _, _>(&columns, &vars, &ext_vars); + + assert_eq!(full_eval, simplified_eval); + } +} diff --git a/crates/prover/src/constraint_framework/expr/utils.rs b/crates/prover/src/constraint_framework/expr/utils.rs new file mode 100644 index 000000000..724840a06 --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/utils.rs 
@@ -0,0 +1,65 @@ +#[cfg(test)] +macro_rules! secure_col { + ($a:expr, $b:expr, $c:expr, $d:expr) => { + crate::constraint_framework::expr::ExtExpr::SecureCol([ + Box::new($a.into()), + Box::new($b.into()), + Box::new($c.into()), + Box::new($d.into()), + ]) + }; +} +#[cfg(test)] +pub(crate) use secure_col; + +#[cfg(test)] +macro_rules! col { + ($interaction:expr, $idx:expr, $offset:expr) => { + crate::constraint_framework::expr::BaseExpr::Col(($interaction, $idx, $offset).into()) + }; +} +#[cfg(test)] +pub(crate) use col; + +#[cfg(test)] +macro_rules! var { + ($var:expr) => { + crate::constraint_framework::expr::BaseExpr::Param($var.to_string()) + }; +} +#[cfg(test)] +pub(crate) use var; + +#[cfg(test)] +macro_rules! qvar { + ($var:expr) => { + crate::constraint_framework::expr::ExtExpr::Param($var.to_string()) + }; +} +#[cfg(test)] +pub(crate) use qvar; + +#[cfg(test)] +macro_rules! felt { + ($val:expr) => { + crate::constraint_framework::expr::BaseExpr::Const($val.into()) + }; +} +#[cfg(test)] +pub(crate) use felt; + +#[cfg(test)] +macro_rules! qfelt { + ($a:expr, $b:expr, $c:expr, $d:expr) => { + crate::constraint_framework::expr::ExtExpr::Const( + crate::core::fields::qm31::SecureField::from_m31_array([ + $a.into(), + $b.into(), + $c.into(), + $d.into(), + ]), + ) + }; +} +#[cfg(test)] +pub(crate) use qfelt; diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index bc188eb57..37baa6167 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -173,7 +173,9 @@ macro_rules! logup_proxy { fn write_logup_frac(&mut self, fraction: Fraction) { if self.logup.fracs.is_empty() { self.logup.is_first = self.get_preprocessed_column( - super::preprocessed_columns::PreprocessedColumn::IsFirst(self.logup.log_size), + crate::constraint_framework::preprocessed_columns::PreprocessedColumn::IsFirst( + self.logup.log_size, + ), ); self.logup.is_finalized = false; } @@ -183,7 +185,7 @@ macro_rules! logup_proxy { /// Finalize the logup by adding the constraints for the fractions, batched by /// the given `batching`. /// `batching` should contain the batch into which every logup entry should be inserted. 
- fn finalize_logup_batched(&mut self, batching: &super::Batching) { + fn finalize_logup_batched(&mut self, batching: &crate::constraint_framework::Batching) { assert!(!self.logup.is_finalized, "LogupAtRow was already finalized"); assert_eq!( batching.len(), From bcb4ec35ef9828d5fdd8c1b186c530a60f78e626 Mon Sep 17 00:00:00 2001 From: alon-dotan-starkware Date: Sun, 15 Dec 2024 15:40:50 +0200 Subject: [PATCH 26/69] chore: fix benchmarks (#926) --- .github/workflows/benchmarks-pages.yaml | 2 +- .github/workflows/ci.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/benchmarks-pages.yaml b/.github/workflows/benchmarks-pages.yaml index 3672b6c2c..4625dd88d 100644 --- a/.github/workflows/benchmarks-pages.yaml +++ b/.github/workflows/benchmarks-pages.yaml @@ -25,7 +25,7 @@ jobs: uses: actions/cache@v4 with: path: ./cache - key: ${{ runner.os }}-benchmark + key: ${{ runner.os }}-${{github.event.pull_request.base.ref}}-benchmark - name: Store benchmark result uses: benchmark-action/github-action-benchmark@v1 with: diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 29f73ae76..5fb3258e6 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -123,7 +123,7 @@ jobs: uses: actions/cache@v4 with: path: ./cache - key: ${{ runner.os }}-benchmark + key: ${{ runner.os }}-${{github.event.pull_request.base.ref}}-benchmark - name: Store benchmark result uses: benchmark-action/github-action-benchmark@v1 with: @@ -149,7 +149,7 @@ jobs: uses: actions/cache@v4 with: path: ./cache - key: ${{ runner.os }}-benchmark + key: ${{ runner.os }}-${{github.event.pull_request.base.ref}}-benchmark - name: Store benchmark result uses: benchmark-action/github-action-benchmark@v1 with: From bb82d1a149978e9374b8deac7df745e212701e8a Mon Sep 17 00:00:00 2001 From: VitaliiH Date: Mon, 16 Dec 2024 12:48:34 +0400 Subject: [PATCH 27/69] wip - utils draft, wip fold, fibonacci switch to icicle --- Cargo.lock | 8 +- crates/prover/Cargo.toml | 10 +- crates/prover/src/core/backend/icicle/mod.rs | 2 + .../prover/src/core/backend/icicle/utils.rs | 113 ++++++++++++++++++ crates/prover/src/examples/mod.rs | 1 + .../prover/src/examples/state_machine/mod.rs | 10 +- crates/prover/src/examples/utils.rs | 8 ++ .../prover/src/examples/wide_fibonacci/mod.rs | 22 ++-- 8 files changed, 150 insertions(+), 24 deletions(-) create mode 100644 crates/prover/src/core/backend/icicle/utils.rs create mode 100644 crates/prover/src/examples/utils.rs diff --git a/Cargo.lock b/Cargo.lock index dd0806e96..9b5a5061e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -760,7 +760,7 @@ dependencies = [ [[package]] name = "icicle-core" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=c33902295916c5b9bebdbc4ec32e0f531502b141#c33902295916c5b9bebdbc4ec32e0f531502b141" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0#eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0" dependencies = [ "criterion 0.3.6", "hex", @@ -771,7 +771,7 @@ dependencies = [ [[package]] name = "icicle-cuda-runtime" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=c33902295916c5b9bebdbc4ec32e0f531502b141#c33902295916c5b9bebdbc4ec32e0f531502b141" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0#eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0" dependencies = [ "bindgen", "bitflags 1.3.2", @@ -780,7 +780,7 @@ dependencies = [ [[package]] name = "icicle-hash" version = "2.8.0" -source 
= "git+https://github.com/ingonyama-zk/icicle.git?rev=c33902295916c5b9bebdbc4ec32e0f531502b141#c33902295916c5b9bebdbc4ec32e0f531502b141" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0#eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0" dependencies = [ "cmake", "icicle-core", @@ -790,7 +790,7 @@ dependencies = [ [[package]] name = "icicle-m31" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=c33902295916c5b9bebdbc4ec32e0f531502b141#c33902295916c5b9bebdbc4ec32e0f531502b141" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0#eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0" dependencies = [ "cmake", "criterion 0.3.6", diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 284c548a8..61999031c 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -4,7 +4,7 @@ version.workspace = true edition.workspace = true [features] -default = ["icicle"] +default = ["icicle", "parallel"] parallel = ["rayon"] slow-tests = [] icicle = ["icicle-cuda-runtime", "icicle-core", "icicle-m31", "icicle-hash", "nvtx"] @@ -29,10 +29,10 @@ tracing.workspace = true rayon = { version = "1.10.0", optional = true } serde = { version = "1.0", features = ["derive"] } -icicle-cuda-runtime = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="c33902295916c5b9bebdbc4ec32e0f531502b141"} -icicle-core = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="c33902295916c5b9bebdbc4ec32e0f531502b141"} -icicle-m31 = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="c33902295916c5b9bebdbc4ec32e0f531502b141"} -icicle-hash = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="c33902295916c5b9bebdbc4ec32e0f531502b141"} +icicle-cuda-runtime = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0"} +icicle-core = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0"} +icicle-m31 = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0"} +icicle-hash = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0"} nvtx = { version = "*", optional = true } diff --git a/crates/prover/src/core/backend/icicle/mod.rs b/crates/prover/src/core/backend/icicle/mod.rs index 2bbf555f3..51519f82c 100644 --- a/crates/prover/src/core/backend/icicle/mod.rs +++ b/crates/prover/src/core/backend/icicle/mod.rs @@ -45,6 +45,8 @@ use crate::core::ColumnVec; #[derive(Copy, Clone, Debug, Deserialize, Serialize, Default)] pub struct IcicleBackend; +pub mod utils; + impl Backend for IcicleBackend {} // stwo/crates/prover/src/core/backend/cpu/lookups/gkr.rs diff --git a/crates/prover/src/core/backend/icicle/utils.rs b/crates/prover/src/core/backend/icicle/utils.rs new file mode 100644 index 000000000..887c831aa --- /dev/null +++ b/crates/prover/src/core/backend/icicle/utils.rs @@ -0,0 +1,113 @@ +use crate::core::fields::{ExtensionOf, Field}; + +/// Folds values recursively in `O(n)` by a hierarchical application of folding factors. +/// +/// i.e. 
folding `n = 8` values with `folding_factors = [x, y, z]`: +/// +/// ```text +/// n2=n1+x*n2 +/// / \ +/// n1=n3+y*n4 n2=n5+y*n6 +/// / \ / \ +/// n3=a+z*b n4=c+z*d n5=e+z*f n6=g+z*h +/// / \ / \ / \ / \ +/// a b c d e f g h +/// ``` +/// +/// # Panics +/// +/// Panics if the number of values is not a power of two or if an incorrect number of folding +/// factors is provided. +// TODO(Andrew): Can be made to run >10x faster by unrolling lower layers of recursion +pub fn fold>(values: &[F], folding_factors: &[E]) -> E { + let n = values.len(); + assert_eq!(n, 1 << folding_factors.len()); + if n == 1 { + let res: E = values[0].into(); + return res; + } + let (lhs_values, rhs_values) = values.split_at(n / 2); + let (folding_factor, folding_factors) = folding_factors.split_first().unwrap(); + let lhs_val = fold(lhs_values, folding_factors); + let rhs_val = fold(rhs_values, folding_factors); + // println!( + // "n={:?} lhs_val{:?} + rhs_val{:?} x folding_factor: {:?}", + // n, lhs_val, rhs_val, *folding_factor + // ); + let res = lhs_val + rhs_val * *folding_factor; + // println!("res = {:?}; ", res); + res +} + +pub fn fold_gpu>(values: &[F], folding_factors: &[E]) -> E { + let n = values.len(); + assert_eq!(n, 1 << folding_factors.len()); + if n == 1 { + return values[0].into(); + } + let (lhs_values, rhs_values) = values.split_at(n / 2); + let (folding_factor, folding_factors) = folding_factors.split_first().unwrap(); + let lhs_val = fold(lhs_values, folding_factors); + let rhs_val = fold(rhs_values, folding_factors); + lhs_val + rhs_val * *folding_factor +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::fields::m31::M31; + use crate::core::fields::qm31::QM31; + #[test] + fn test_fold_works() { + // Example input: power-of-two values and appropriate folding factors + let values = vec![ + M31(1), + M31(2), + M31(3), + M31(4), + M31(5), + M31(6), + M31(7), + M31(8), + ]; + let folding_factors = vec![ + QM31::from_u32_unchecked(2, 0, 0, 0), + QM31::from_u32_unchecked(3, 0, 0, 0), + QM31::from_u32_unchecked(4, 0, 0, 0), + ]; + let result = fold(&values, &folding_factors); + + let expected = QM31::from_u32_unchecked(358, 0, 0, 0); + assert_eq!(result, expected, "Result for simple folding is incorrect"); + + // Set the desired length for folding_factors + let folding_factors_length = 20; // Example length + let values_length = 1 << folding_factors_length; // 2^folding_factors_length + + // Initialize the `values` vector + let mut values: Vec = Vec::with_capacity(values_length); + use rayon::iter::IntoParallelIterator; + use rayon::prelude::*; + + let values: Vec = (1..=values_length) + .into_par_iter() + .map(|i| M31(i as u32)) + .collect(); + + // Initialize the `folding_factors` vector + let mut folding_factors = Vec::with_capacity(folding_factors_length); + for i in 2..(2 + folding_factors_length) { + folding_factors.push(QM31::from_u32_unchecked(i as u32, 0, 0, 0)); + } + let time = std::time::Instant::now(); + let result = fold(&values, &folding_factors); + let elapsed = time.elapsed(); + println!( + "Elapsed time for 2^{}: {:?}", + folding_factors_length, elapsed + ); + + let expected = QM31::from_u32_unchecked(223550878, 0, 0, 0); + assert_eq!(result, expected, "Result for large folding is incorrect"); + } +} diff --git a/crates/prover/src/examples/mod.rs b/crates/prover/src/examples/mod.rs index 4a3511b51..0ca3346e7 100644 --- a/crates/prover/src/examples/mod.rs +++ b/crates/prover/src/examples/mod.rs @@ -4,3 +4,4 @@ pub mod poseidon; pub mod state_machine; pub mod 
wide_fibonacci; pub mod xor; +pub mod utils; diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 57697f2d6..558cb3b58 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -493,6 +493,7 @@ mod tests { use crate::core::fields::FieldExpOps; use crate::core::pcs::{PcsConfig, TreeVec}; use crate::core::poly::circle::CanonicCoset; + use crate::examples::utils::get_env_var; #[test] fn test_state_machine_constraints() { @@ -559,14 +560,7 @@ mod tests { #[test] fn test_state_machine_prove() { - fn get_env_log2(key: &str, default: u32) -> u32 { - std::env::var(key) - .unwrap_or_else(|_| default.to_string()) - .parse() - .unwrap_or(default) - } - - let log_n_rows = get_env_log2("TSMP_LOG2", 8); + let log_n_rows = get_env_var("TSMP_LOG2", 8u32); let config = PcsConfig::default(); let initial_state = [M31::zero(); STATE_SIZE]; diff --git a/crates/prover/src/examples/utils.rs b/crates/prover/src/examples/utils.rs new file mode 100644 index 000000000..2620d212a --- /dev/null +++ b/crates/prover/src/examples/utils.rs @@ -0,0 +1,8 @@ +use std::{fmt::Display, str::FromStr}; + +pub fn get_env_var(key: &str, default: T) -> T { + std::env::var(key) + .unwrap_or_else(|_| default.to_string()) + .parse() + .unwrap_or(default) +} \ No newline at end of file diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index 115a1cc39..dee4bcc3e 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -229,16 +229,18 @@ mod tests { #[test] #[cfg(feature = "icicle")] fn test_wide_fib_prove_with_blake_icicle() { - // use crate::core::backend::icicle::IcicleBackend; - use crate::core::backend::CpuBackend; + use crate::core::backend::icicle::IcicleBackend; + // use crate::core::backend::CpuBackend; use crate::core::fields::m31::M31; - // type TheBackend = IcicleBackend; + use crate::examples::utils::get_env_var; + type TheBackend = IcicleBackend; + // type TheBackend = CpuBackend; - type TheBackend = CpuBackend; + let min_log = get_env_var("MIN_FIB_LOG", 2u32); + let max_log = get_env_var("MAX_FIB_LOG", 6u32); - // for log_n_instances in 2..=6 { - let log_n_instances = 6; - { + for log_n_instances in min_log..=max_log { + println!("proving for 2^{:?}...", log_n_instances); let config = PcsConfig::default(); // Precompute twiddles. let twiddles = TheBackend::precompute_twiddles( @@ -277,12 +279,18 @@ mod tests { (SecureField::zero(), None), ); + let start = std::time::Instant::now(); let proof = prove::( &[&component], prover_channel, commitment_scheme, ) .unwrap(); + println!( + "proving for 2^{:?} took {:?} ms", + log_n_instances, + start.elapsed().as_millis() + ); // Verify. 
let verifier_channel = &mut Blake2sChannel::default(); From be2dcd2d5d1aba940947d72c15745ee50fde9e4a Mon Sep 17 00:00:00 2001 From: VitaliiH Date: Mon, 16 Dec 2024 14:03:48 +0400 Subject: [PATCH 28/69] wip new domain --- Cargo.lock | 9 +- crates/prover/Cargo.toml | 8 +- crates/prover/src/core/backend/icicle/mod.rs | 157 ++++++++++--------- 3 files changed, 89 insertions(+), 85 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9b5a5061e..3b357b2c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -760,7 +760,7 @@ dependencies = [ [[package]] name = "icicle-core" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0#eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb97ee08ef26d4de6e0695df3484abeddd83c08c#eb97ee08ef26d4de6e0695df3484abeddd83c08c" dependencies = [ "criterion 0.3.6", "hex", @@ -771,7 +771,7 @@ dependencies = [ [[package]] name = "icicle-cuda-runtime" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0#eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb97ee08ef26d4de6e0695df3484abeddd83c08c#eb97ee08ef26d4de6e0695df3484abeddd83c08c" dependencies = [ "bindgen", "bitflags 1.3.2", @@ -780,7 +780,7 @@ dependencies = [ [[package]] name = "icicle-hash" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0#eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb97ee08ef26d4de6e0695df3484abeddd83c08c#eb97ee08ef26d4de6e0695df3484abeddd83c08c" dependencies = [ "cmake", "icicle-core", @@ -790,12 +790,13 @@ dependencies = [ [[package]] name = "icicle-m31" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0#eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb97ee08ef26d4de6e0695df3484abeddd83c08c#eb97ee08ef26d4de6e0695df3484abeddd83c08c" dependencies = [ "cmake", "criterion 0.3.6", "icicle-core", "icicle-cuda-runtime", + "rayon", ] [[package]] diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 61999031c..ca5930ddc 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -29,10 +29,10 @@ tracing.workspace = true rayon = { version = "1.10.0", optional = true } serde = { version = "1.0", features = ["derive"] } -icicle-cuda-runtime = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0"} -icicle-core = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0"} -icicle-m31 = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0"} -icicle-hash = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0"} +icicle-cuda-runtime = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb97ee08ef26d4de6e0695df3484abeddd83c08c"} +icicle-core = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb97ee08ef26d4de6e0695df3484abeddd83c08c"} +icicle-m31 = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb97ee08ef26d4de6e0695df3484abeddd83c08c"} +icicle-hash = { git = 
"https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb97ee08ef26d4de6e0695df3484abeddd83c08c"} nvtx = { version = "*", optional = true } diff --git a/crates/prover/src/core/backend/icicle/mod.rs b/crates/prover/src/core/backend/icicle/mod.rs index 51519f82c..4af7695ca 100644 --- a/crates/prover/src/core/backend/icicle/mod.rs +++ b/crates/prover/src/core/backend/icicle/mod.rs @@ -12,7 +12,7 @@ use icicle_core::vec_ops::{accumulate_scalars, VecOpsConfig}; use icicle_core::Matrix; use icicle_hash::blake2s::build_blake2s_mmcs; use icicle_m31::dcct::{evaluate, get_dcct_root_of_unity, initialize_dcct_domain, interpolate}; -use icicle_m31::fri::{self, fold_circle_into_line, FriConfig}; +use icicle_m31::fri::{self, fold_circle_into_line, fold_circle_into_line_new, FriConfig}; use icicle_m31::quotient; use itertools::Itertools; use serde::{Deserialize, Serialize}; @@ -509,30 +509,30 @@ impl FriOps for IcicleBackend { let length = src.values.len(); let dom_vals_len = length / 2; - let _domain_log_size = domain.log_size(); + let domain_log_size = domain.log_size(); + + // let mut domain_rev = Vec::new(); + // for i in 0..dom_vals_len { + // // TODO: on-device batch + // // TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer. + // let p = domain.at(bit_reverse_index( + // i << CIRCLE_TO_LINE_FOLD_STEP, + // domain.log_size(), + // )); + // let p = p.y.inverse(); + // domain_rev.push(p); + // } - let mut domain_rev = Vec::new(); - for i in 0..dom_vals_len { - // TODO: on-device batch - // TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer. - let p = domain.at(bit_reverse_index( - i << CIRCLE_TO_LINE_FOLD_STEP, - domain.log_size(), - )); - let p = p.y.inverse(); - domain_rev.push(p); - } + // let domain_vals = (0..dom_vals_len) + // .map(|i| { + // let p = domain_rev[i]; + // ScalarField::from_u32(p.0) + // }) + // .collect::>(); - let domain_vals = (0..dom_vals_len) - .map(|i| { - let p = domain_rev[i]; - ScalarField::from_u32(p.0) - }) - .collect::>(); - - let domain_icicle_host = HostSlice::from_slice(domain_vals.as_slice()); - let mut d_domain_icicle = DeviceVec::::cuda_malloc(dom_vals_len).unwrap(); - d_domain_icicle.copy_from_host(domain_icicle_host).unwrap(); + // let domain_icicle_host = HostSlice::from_slice(domain_vals.as_slice()); + // let mut d_domain_icicle = DeviceVec::::cuda_malloc(dom_vals_len).unwrap(); + // d_domain_icicle.copy_from_host(domain_icicle_host).unwrap(); let mut d_evals_icicle = DeviceVec::::cuda_malloc(length).unwrap(); SecureColumnByCoords::convert_to_icicle(&src.values, &mut d_evals_icicle); @@ -546,12 +546,14 @@ impl FriOps for IcicleBackend { let cfg = FriConfig::default(); let icicle_alpha = unsafe { transmute(alpha) }; - let _ = fold_circle_into_line( + let _ = fold_circle_into_line_new( &d_evals_icicle[..], - &d_domain_icicle[..], + domain.half_coset.initial_index.0 as _, + domain.half_coset.log_size, &mut d_folded_eval[..], icicle_alpha, &cfg, + ) .unwrap(); @@ -606,59 +608,60 @@ impl QuotientOps for IcicleBackend { log_blowup_factor: u32, ) -> SecureEvaluation { - // unsafe { - // transmute(CpuBackend::accumulate_quotients( - // domain, - // unsafe { transmute(columns) }, - // random_coeff, - // sample_batches, - // log_blowup_factor, - // )) - // } + // TODO: the fn accumulate_quotients( fix seems doesn't work for this branch https://github.com/ingonyama-zk/icicle/commit/eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0 + unsafe { + transmute(CpuBackend::accumulate_quotients( + domain, + unsafe { 
transmute(columns) }, + random_coeff, + sample_batches, + log_blowup_factor, + )) + } - let icicle_columns_raw = columns - .iter() - .flat_map(|x| x.iter().map(|&y| unsafe { transmute(y) })) - .collect_vec(); - let icicle_columns = HostSlice::from_slice(&icicle_columns_raw); - let icicle_sample_batches = sample_batches - .into_iter() - .map(|sample| { - let (columns, values) = sample - .columns_and_values - .iter() - .map(|(index, value)| { - ((*index) as u32, unsafe { - transmute::(*value) - }) - }) - .unzip(); - - quotient::ColumnSampleBatch { - point: unsafe { transmute(sample.point) }, - columns, - values, - } - }) - .collect_vec(); - let mut icicle_result_raw = vec![QuarticExtensionField::zero(); domain.size()]; - let icicle_result = HostSlice::from_mut_slice(icicle_result_raw.as_mut_slice()); - let cfg = quotient::QuotientConfig::default(); - - quotient::accumulate_quotients_wrapped( - domain.half_coset.initial_index.0 as u32, - domain.half_coset.step_size.0 as u32, - domain.log_size() as u32, - icicle_columns, - unsafe { transmute(random_coeff) }, - &icicle_sample_batches, - icicle_result, - &cfg, - ); - // TODO: make it on cuda side - let mut result = unsafe { SecureColumnByCoords::uninitialized(domain.size()) }; - (0..domain.size()).for_each(|i| result.set(i, unsafe { transmute(icicle_result_raw[i]) })); - SecureEvaluation::new(domain, result) + // let icicle_columns_raw = columns + // .iter() + // .flat_map(|x| x.iter().map(|&y| unsafe { transmute(y) })) + // .collect_vec(); + // let icicle_columns = HostSlice::from_slice(&icicle_columns_raw); + // let icicle_sample_batches = sample_batches + // .into_iter() + // .map(|sample| { + // let (columns, values) = sample + // .columns_and_values + // .iter() + // .map(|(index, value)| { + // ((*index) as u32, unsafe { + // transmute::(*value) + // }) + // }) + // .unzip(); + + // quotient::ColumnSampleBatch { + // point: unsafe { transmute(sample.point) }, + // columns, + // values, + // } + // }) + // .collect_vec(); + // let mut icicle_result_raw = vec![QuarticExtensionField::zero(); domain.size()]; + // let icicle_result = HostSlice::from_mut_slice(icicle_result_raw.as_mut_slice()); + // let cfg = quotient::QuotientConfig::default(); + + // quotient::accumulate_quotients_wrapped( + // // domain.half_coset.initial_index.0 as u32, + // // domain.half_coset.step_size.0 as u32, + // domain.log_size() as u32, + // icicle_columns, + // unsafe { transmute(random_coeff) }, + // &icicle_sample_batches, + // icicle_result, + // &cfg, + // ); + // // TODO: make it on cuda side + // let mut result = unsafe { SecureColumnByCoords::uninitialized(domain.size()) }; + // (0..domain.size()).for_each(|i| result.set(i, unsafe { transmute(icicle_result_raw[i]) })); + // SecureEvaluation::new(domain, result) } } From 2c773848e9976afbf24c3623215e44660d7923ec Mon Sep 17 00:00:00 2001 From: Alon-Ti <54235977+Alon-Ti@users.noreply.github.com> Date: Wed, 18 Dec 2024 12:01:04 +0200 Subject: [PATCH 29/69] Added expression degrees. 
(#931) --- .../src/constraint_framework/expr/degree.rs | 100 ++++++++++++++++++ .../src/constraint_framework/expr/mod.rs | 1 + 2 files changed, 101 insertions(+) create mode 100644 crates/prover/src/constraint_framework/expr/degree.rs diff --git a/crates/prover/src/constraint_framework/expr/degree.rs b/crates/prover/src/constraint_framework/expr/degree.rs new file mode 100644 index 000000000..27c848336 --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/degree.rs @@ -0,0 +1,100 @@ +/// Finds a degree bound for an expression. The degree is given with respect to columns as +/// variables. +/// Computes the actual degree with the following caveats: +/// 1. The constant expression 0 receives degree 0 like all other constants rather than the +/// mathematically correct -infinity. This means, for example, that expressions of the +/// type 0 * expr will return degree deg expr. This should be mitigated by +/// simplification. +/// 2. If expressions p and q cancel out under some operation, this will not be accounted +/// for, so that (x^2 + 1) - (x^2 + x) will return degree 2. +use std::collections::HashMap; + +use super::{BaseExpr, ExtExpr}; + +type Degree = usize; + +/// A struct of named expressions that can be searched when determining the degree bound for an +/// expression that contains parameters. +/// Required because expressions that contain parameters that are actually intermediates have to +/// account for the degree of the intermediate. +pub struct NamedExprs { + exprs: HashMap, + ext_exprs: HashMap, +} + +impl NamedExprs { + pub fn degree_bound(&self, name: String) -> Degree { + if let Some(expr) = self.exprs.get(&name) { + expr.degree_bound(self) + } else if let Some(expr) = self.ext_exprs.get(&name) { + expr.degree_bound(self) + } else if name.starts_with("preprocessed.") { + // TODO(alont): Fix this hack. + 1 + } else { + // If the expression isn't found, assume it's an external variable, effectively a const. + 0 + } + } +} + +impl BaseExpr { + pub fn degree_bound(&self, named_exprs: &NamedExprs) -> Degree { + match self { + BaseExpr::Col(_) => 1, + BaseExpr::Const(_) => 0, + BaseExpr::Param(name) => named_exprs.degree_bound(name.clone()), + BaseExpr::Add(a, b) => a.degree_bound(named_exprs).max(b.degree_bound(named_exprs)), + BaseExpr::Sub(a, b) => a.degree_bound(named_exprs).max(b.degree_bound(named_exprs)), + BaseExpr::Mul(a, b) => a.degree_bound(named_exprs) + b.degree_bound(named_exprs), + BaseExpr::Neg(a) => a.degree_bound(named_exprs), + // TODO(alont): Consider handling this in the type system. 
+ BaseExpr::Inv(_) => panic!("Cannot compute the degree of an inverse."), + } + } +} + +impl ExtExpr { + pub fn degree_bound(&self, named_exprs: &NamedExprs) -> Degree { + match self { + ExtExpr::SecureCol(coefs) => coefs + .iter() + .cloned() + .map(|coef| coef.degree_bound(named_exprs)) + .max() + .unwrap(), + ExtExpr::Const(_) => 0, + ExtExpr::Param(name) => named_exprs.degree_bound(name.clone()), + ExtExpr::Add(a, b) => a.degree_bound(named_exprs).max(b.degree_bound(named_exprs)), + ExtExpr::Sub(a, b) => a.degree_bound(named_exprs).max(b.degree_bound(named_exprs)), + ExtExpr::Mul(a, b) => a.degree_bound(named_exprs) + b.degree_bound(named_exprs), + ExtExpr::Neg(a) => a.degree_bound(named_exprs), + } + } +} + +#[cfg(test)] +mod tests { + use crate::constraint_framework::expr::degree::NamedExprs; + use crate::constraint_framework::expr::utils::*; + + #[test] + fn test_degree_bound() { + let intermediate = (felt!(12) + col!(1, 1, 0)) * var!("a") * col!(1, 0, 0); + let qintermediate = secure_col!(intermediate.clone(), felt!(12), var!("b"), felt!(0)); + + let named_exprs = NamedExprs { + exprs: [("intermediate".to_string(), intermediate.clone())].into(), + ext_exprs: [("qintermediate".to_string(), qintermediate.clone())].into(), + }; + + let expr = var!("intermediate") * col!(2, 1, 0); + let qexpr = + var!("qintermediate") * secure_col!(col!(2, 1, 0), expr.clone(), felt!(0), felt!(1)); + + assert_eq!(intermediate.degree_bound(&named_exprs), 2); + assert_eq!(qintermediate.degree_bound(&named_exprs), 2); + assert_eq!(expr.degree_bound(&named_exprs), 3); + assert_eq!(qexpr.degree_bound(&named_exprs), 5); + } +} diff --git a/crates/prover/src/constraint_framework/expr/mod.rs b/crates/prover/src/constraint_framework/expr/mod.rs index 14668c6e2..7b17fc73e 100644 --- a/crates/prover/src/constraint_framework/expr/mod.rs +++ b/crates/prover/src/constraint_framework/expr/mod.rs @@ -1,4 +1,5 @@ pub mod assignment; +pub mod degree; pub mod evaluator; pub mod format; pub mod simplify; From f0f28c5d275d465963f93a137ab000c106a89941 Mon Sep 17 00:00:00 2001 From: Gali Michlevich Date: Tue, 10 Dec 2024 13:16:46 +0200 Subject: [PATCH 30/69] Add secure powers generation for simd --- crates/prover/src/core/backend/simd/utils.rs | 49 +++++++++++++++++++- crates/prover/src/core/utils.rs | 1 + 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs index d5f53a22b..64931145a 100644 --- a/crates/prover/src/core/backend/simd/utils.rs +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -1,5 +1,12 @@ use std::simd::Swizzle; +use itertools::Itertools; + +use crate::core::backend::simd::m31::N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::fields::qm31::SecureField; +use crate::core::utils::generate_secure_powers; + /// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. pub struct InterleaveEvens; @@ -51,11 +58,35 @@ impl UnsafeConst { unsafe impl Send for UnsafeConst {} unsafe impl Sync for UnsafeConst {} +// TODO(Gali): Remove #[allow(dead_code)]. +#[allow(dead_code)] +/// Generates the first `n_powers` powers of `felt` using SIMD. +/// Refer to [`generate_secure_powers`] for the scalar implementation. 
+pub fn generate_secure_powers_simd(felt: SecureField, n_powers: usize) -> Vec { + let base_arr = generate_secure_powers(felt, N_LANES).try_into().unwrap(); + let base = PackedSecureField::from_array(base_arr); + let step = PackedSecureField::broadcast(base_arr[N_LANES - 1] * felt); + let size = n_powers.div_ceil(N_LANES); + + // Collects the next N_LANES powers of `felt` in each iteration. + (0..size) + .scan(base, |acc, _| { + let res = *acc; + *acc *= step; + Some(res) + }) + .flat_map(|x| x.to_array()) + .take(n_powers) + .collect_vec() +} + #[cfg(test)] mod tests { use std::simd::{u32x4, Swizzle}; - use super::{InterleaveEvens, InterleaveOdds}; + use super::{generate_secure_powers_simd, InterleaveEvens, InterleaveOdds}; + use crate::core::utils::generate_secure_powers; + use crate::qm31; #[test] fn interleave_evens() { @@ -76,4 +107,20 @@ mod tests { assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); } + + #[test] + fn test_generate_secure_powers_simd() { + let felt = qm31!(1, 2, 3, 4); + let n_powers_vec = [0, 16, 100]; + + n_powers_vec.iter().for_each(|&n_powers| { + let expected = generate_secure_powers(felt, n_powers); + let actual = generate_secure_powers_simd(felt, n_powers); + assert_eq!( + expected, actual, + "Error generating secure powers in n_powers = {}.", + n_powers + ); + }); + } } diff --git a/crates/prover/src/core/utils.rs b/crates/prover/src/core/utils.rs index 745168c38..1babd20af 100644 --- a/crates/prover/src/core/utils.rs +++ b/crates/prover/src/core/utils.rs @@ -146,6 +146,7 @@ pub fn bit_reverse_coset_to_circle_domain_order(v: &mut [T]) { } } +/// Generates the first `n_powers` powers of `felt`. pub fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec { (0..n_powers) .scan(SecureField::one(), |acc, _| { From cf1eca6439158e88cd503a7e31ade6551ba4dd18 Mon Sep 17 00:00:00 2001 From: Gali Michlevich Date: Mon, 16 Dec 2024 12:42:22 +0200 Subject: [PATCH 31/69] generic backend secure power generation --- crates/prover/src/core/air/accumulation.rs | 6 ++- .../src/core/backend/cpu/accumulation.rs | 48 ++++++++++++++++- crates/prover/src/core/backend/cpu/mod.rs | 2 +- .../src/core/backend/simd/accumulation.rs | 54 ++++++++++++++++++- crates/prover/src/core/backend/simd/utils.rs | 49 +---------------- crates/prover/src/core/utils.rs | 42 +-------------- .../examples/xor/gkr_lookups/accumulation.rs | 5 +- 7 files changed, 110 insertions(+), 96 deletions(-) diff --git a/crates/prover/src/core/air/accumulation.rs b/crates/prover/src/core/air/accumulation.rs index f958f2029..d02235a5f 100644 --- a/crates/prover/src/core/air/accumulation.rs +++ b/crates/prover/src/core/air/accumulation.rs @@ -13,7 +13,6 @@ use crate::core::fields::secure_column::SecureColumnByCoords; use crate::core::fields::FieldOps; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, SecureCirclePoly}; use crate::core::poly::BitReversedOrder; -use crate::core::utils::generate_secure_powers; /// Accumulates N evaluations of u_i(P0) at a single point. /// Computes f(P0), the combined polynomial at that point. 
@@ -63,7 +62,7 @@ impl DomainEvaluationAccumulator { pub fn new(random_coeff: SecureField, max_log_size: u32, total_columns: usize) -> Self { let max_log_size = max_log_size as usize; Self { - random_coeff_powers: generate_secure_powers(random_coeff, total_columns), + random_coeff_powers: B::generate_secure_powers(random_coeff, total_columns), sub_accumulations: (0..(max_log_size + 1)).map(|_| None).collect(), } } @@ -106,6 +105,9 @@ pub trait AccumulationOps: FieldOps + Sized { /// Accumulates other into column: /// column = column + other. fn accumulate(column: &mut SecureColumnByCoords, other: &SecureColumnByCoords); + + /// Generates the first `n_powers` powers of `felt`. + fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec; } impl DomainEvaluationAccumulator { diff --git a/crates/prover/src/core/backend/cpu/accumulation.rs b/crates/prover/src/core/backend/cpu/accumulation.rs index 63a49bf15..5d3e8895c 100644 --- a/crates/prover/src/core/backend/cpu/accumulation.rs +++ b/crates/prover/src/core/backend/cpu/accumulation.rs @@ -1,5 +1,8 @@ -use super::CpuBackend; +use num_traits::One; + use crate::core::air::accumulation::AccumulationOps; +use crate::core::backend::cpu::CpuBackend; +use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; impl AccumulationOps for CpuBackend { @@ -9,4 +12,47 @@ impl AccumulationOps for CpuBackend { column.set(i, res_coeff); } } + + fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec { + (0..n_powers) + .scan(SecureField::one(), |acc, _| { + let res = *acc; + *acc *= felt; + Some(res) + }) + .collect() + } +} + +#[cfg(test)] +mod tests { + use num_traits::One; + + use crate::core::air::accumulation::AccumulationOps; + use crate::core::backend::CpuBackend; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::FieldExpOps; + use crate::qm31; + #[test] + fn generate_secure_powers_works() { + let felt = qm31!(1, 2, 3, 4); + let n_powers = 10; + + let powers = ::generate_secure_powers(felt, n_powers); + + assert_eq!(powers.len(), n_powers); + assert_eq!(powers[0], SecureField::one()); + assert_eq!(powers[1], felt); + assert_eq!(powers[7], felt.pow(7)); + } + + #[test] + fn generate_empty_secure_powers_works() { + let felt = qm31!(1, 2, 3, 4); + let max_log_size = 0; + + let powers = ::generate_secure_powers(felt, max_log_size); + + assert_eq!(powers, vec![]); + } } diff --git a/crates/prover/src/core/backend/cpu/mod.rs b/crates/prover/src/core/backend/cpu/mod.rs index ea6e49c07..457ca9d2a 100644 --- a/crates/prover/src/core/backend/cpu/mod.rs +++ b/crates/prover/src/core/backend/cpu/mod.rs @@ -1,4 +1,4 @@ -mod accumulation; +pub mod accumulation; mod blake2s; pub mod circle; mod fri; diff --git a/crates/prover/src/core/backend/simd/accumulation.rs b/crates/prover/src/core/backend/simd/accumulation.rs index c9705df6b..f2b59d1e4 100644 --- a/crates/prover/src/core/backend/simd/accumulation.rs +++ b/crates/prover/src/core/backend/simd/accumulation.rs @@ -1,5 +1,11 @@ -use super::SimdBackend; +use itertools::Itertools; + use crate::core::air::accumulation::AccumulationOps; +use crate::core::backend::simd::m31::N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::CpuBackend; +use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; impl AccumulationOps for SimdBackend { @@ -9,4 +15,50 @@ impl AccumulationOps for SimdBackend { unsafe 
{ column.set_packed(i, res_coeff) }; } } + + /// Generates the first `n_powers` powers of `felt` using SIMD. + /// Refer to `CpuBackend::generate_secure_powers` for the scalar CPU implementation. + fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec { + let base_arr = ::generate_secure_powers(felt, N_LANES) + .try_into() + .unwrap(); + let base = PackedSecureField::from_array(base_arr); + let step = PackedSecureField::broadcast(base_arr[N_LANES - 1] * felt); + let size = n_powers.div_ceil(N_LANES); + + // Collects the next N_LANES powers of `felt` in each iteration. + (0..size) + .scan(base, |acc, _| { + let res = *acc; + *acc *= step; + Some(res) + }) + .flat_map(|x| x.to_array()) + .take(n_powers) + .collect_vec() + } +} + +#[cfg(test)] +mod tests { + use crate::core::air::accumulation::AccumulationOps; + use crate::core::backend::cpu::CpuBackend; + use crate::core::backend::simd::SimdBackend; + use crate::qm31; + + #[test] + fn test_generate_secure_powers_simd() { + let felt = qm31!(1, 2, 3, 4); + let n_powers_vec = [0, 16, 100]; + + n_powers_vec.iter().for_each(|&n_powers| { + let expected = ::generate_secure_powers(felt, n_powers); + let actual = ::generate_secure_powers(felt, n_powers); + assert_eq!( + expected, actual, + "Error generating secure powers in n_powers = {}.", + n_powers + ); + }); + } } diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs index 64931145a..d5f53a22b 100644 --- a/crates/prover/src/core/backend/simd/utils.rs +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -1,12 +1,5 @@ use std::simd::Swizzle; -use itertools::Itertools; - -use crate::core::backend::simd::m31::N_LANES; -use crate::core::backend::simd::qm31::PackedSecureField; -use crate::core::fields::qm31::SecureField; -use crate::core::utils::generate_secure_powers; - /// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. pub struct InterleaveEvens; @@ -58,35 +51,11 @@ impl UnsafeConst { unsafe impl Send for UnsafeConst {} unsafe impl Sync for UnsafeConst {} -// TODO(Gali): Remove #[allow(dead_code)]. -#[allow(dead_code)] -/// Generates the first `n_powers` powers of `felt` using SIMD. -/// Refer to [`generate_secure_powers`] for the scalar implementation. -pub fn generate_secure_powers_simd(felt: SecureField, n_powers: usize) -> Vec { - let base_arr = generate_secure_powers(felt, N_LANES).try_into().unwrap(); - let base = PackedSecureField::from_array(base_arr); - let step = PackedSecureField::broadcast(base_arr[N_LANES - 1] * felt); - let size = n_powers.div_ceil(N_LANES); - - // Collects the next N_LANES powers of `felt` in each iteration. 
- (0..size) - .scan(base, |acc, _| { - let res = *acc; - *acc *= step; - Some(res) - }) - .flat_map(|x| x.to_array()) - .take(n_powers) - .collect_vec() -} - #[cfg(test)] mod tests { use std::simd::{u32x4, Swizzle}; - use super::{generate_secure_powers_simd, InterleaveEvens, InterleaveOdds}; - use crate::core::utils::generate_secure_powers; - use crate::qm31; + use super::{InterleaveEvens, InterleaveOdds}; #[test] fn interleave_evens() { @@ -107,20 +76,4 @@ mod tests { assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); } - - #[test] - fn test_generate_secure_powers_simd() { - let felt = qm31!(1, 2, 3, 4); - let n_powers_vec = [0, 16, 100]; - - n_powers_vec.iter().for_each(|&n_powers| { - let expected = generate_secure_powers(felt, n_powers); - let actual = generate_secure_powers_simd(felt, n_powers); - assert_eq!( - expected, actual, - "Error generating secure powers in n_powers = {}.", - n_powers - ); - }); - } } diff --git a/crates/prover/src/core/utils.rs b/crates/prover/src/core/utils.rs index 1babd20af..df7c77e4c 100644 --- a/crates/prover/src/core/utils.rs +++ b/crates/prover/src/core/utils.rs @@ -1,9 +1,6 @@ use std::iter::Peekable; -use num_traits::One; - use super::fields::m31::BaseField; -use super::fields::qm31::SecureField; use super::fields::Field; pub trait IteratorMutExt<'a, T: 'a>: Iterator { @@ -146,54 +143,17 @@ pub fn bit_reverse_coset_to_circle_domain_order(v: &mut [T]) { } } -/// Generates the first `n_powers` powers of `felt`. -pub fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec { - (0..n_powers) - .scan(SecureField::one(), |acc, _| { - let res = *acc; - *acc *= felt; - Some(res) - }) - .collect() -} - #[cfg(test)] mod tests { use itertools::Itertools; - use num_traits::One; use super::{ offset_bit_reversed_circle_domain_index, previous_bit_reversed_circle_domain_index, }; use crate::core::backend::cpu::CpuCircleEvaluation; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::FieldExpOps; use crate::core::poly::circle::CanonicCoset; use crate::core::poly::NaturalOrder; - use crate::{m31, qm31}; - - #[test] - fn generate_secure_powers_works() { - let felt = qm31!(1, 2, 3, 4); - let n_powers = 10; - - let powers = super::generate_secure_powers(felt, n_powers); - - assert_eq!(powers.len(), n_powers); - assert_eq!(powers[0], SecureField::one()); - assert_eq!(powers[1], felt); - assert_eq!(powers[7], felt.pow(7)); - } - - #[test] - fn generate_empty_secure_powers_works() { - let felt = qm31!(1, 2, 3, 4); - let max_log_size = 0; - - let powers = super::generate_secure_powers(felt, max_log_size); - - assert_eq!(powers, vec![]); - } + use crate::m31; #[test] fn test_offset_bit_reversed_circle_domain_index() { diff --git a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs b/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs index 986289572..8e0ae2d74 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs @@ -4,13 +4,13 @@ use std::ops::{AddAssign, Mul}; use educe::Educe; use num_traits::One; +use crate::core::air::accumulation::AccumulationOps; use crate::core::backend::simd::SimdBackend; use crate::core::backend::Backend; use crate::core::circle::M31_CIRCLE_LOG_ORDER; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::lookups::mle::Mle; -use crate::core::utils::generate_secure_powers; pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1; @@ -59,7 +59,8 @@ fn mle_random_linear_combination( 
assert!(!mles.is_empty()); let n_variables = mles[0].n_variables(); assert!(mles.iter().all(|mle| mle.n_variables() == n_variables)); - let coeff_powers = generate_secure_powers(random_coeff, mles.len()); + let coeff_powers = + ::generate_secure_powers(random_coeff, mles.len()); let mut mle_and_coeff = zip(mles, coeff_powers.into_iter().rev()); // The last value can initialize the accumulator. From af0b35fab244e6c03e796c267e1a05bb45b92a0c Mon Sep 17 00:00:00 2001 From: Gali Michlevich Date: Tue, 17 Dec 2024 10:21:17 +0200 Subject: [PATCH 32/69] Unite impls of DomainEvaluationAccumulator --- crates/prover/src/core/air/accumulation.rs | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/crates/prover/src/core/air/accumulation.rs b/crates/prover/src/core/air/accumulation.rs index d02235a5f..c01828fd4 100644 --- a/crates/prover/src/core/air/accumulation.rs +++ b/crates/prover/src/core/air/accumulation.rs @@ -99,18 +99,7 @@ impl DomainEvaluationAccumulator { pub fn log_size(&self) -> u32 { (self.sub_accumulations.len() - 1) as u32 } -} - -pub trait AccumulationOps: FieldOps + Sized { - /// Accumulates other into column: - /// column = column + other. - fn accumulate(column: &mut SecureColumnByCoords, other: &SecureColumnByCoords); - - /// Generates the first `n_powers` powers of `felt`. - fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec; -} -impl DomainEvaluationAccumulator { /// Computes f(P) as coefficients. pub fn finalize(self) -> SecureCirclePoly { assert_eq!( @@ -159,6 +148,15 @@ impl DomainEvaluationAccumulator { } } +pub trait AccumulationOps: FieldOps + Sized { + /// Accumulates other into column: + /// column = column + other. + fn accumulate(column: &mut SecureColumnByCoords, other: &SecureColumnByCoords); + + /// Generates the first `n_powers` powers of `felt`. + fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec; +} + /// A domain accumulator for polynomials of a single size. pub struct ColumnAccumulator<'a, B: Backend> { pub random_coeff_powers: Vec, From 0e106eca86d7005f1ba074e894d38c026cdba921 Mon Sep 17 00:00:00 2001 From: Ohad Agadi Date: Wed, 18 Dec 2024 13:35:58 +0200 Subject: [PATCH 33/69] opened issue to rustc --- crates/prover/src/core/circle.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/prover/src/core/circle.rs b/crates/prover/src/core/circle.rs index 9f7c99d82..a20ee3451 100644 --- a/crates/prover/src/core/circle.rs +++ b/crates/prover/src/core/circle.rs @@ -126,7 +126,8 @@ impl + FieldExpOps + Sub + Neg type Output = Self; fn add(self, rhs: Self) -> Self::Output { - let x = self.x.clone() * rhs.x.clone() - self.y.clone() * rhs.y.clone(); + // TODO(ShaharS): Revert once Rust solves compiler [issue](https://github.com/rust-lang/rust/issues/134457). 
+ let x = self.x.clone() * rhs.x.clone() + (-self.y.clone() * rhs.y.clone()); let y = self.x * rhs.y + self.y * rhs.x; Self { x, y } } From 172cd35bf86de4312b74a429d5fa12383aee5372 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Sat, 21 Sep 2024 14:09:26 -1000 Subject: [PATCH 34/69] Update toolchain --- .github/workflows/benchmarks-pages.yaml | 6 +- .github/workflows/ci.yaml | 40 +- .github/workflows/coverage.yaml | 4 +- Cargo.lock | 475 ++++++++------ crates/prover/Cargo.toml | 4 - crates/prover/benches/fft.rs | 14 +- crates/prover/benches/merkle.rs | 2 +- .../prover/src/constraint_framework/assert.rs | 2 +- .../src/constraint_framework/component.rs | 5 +- .../src/constraint_framework/cpu_domain.rs | 2 +- .../prover/src/constraint_framework/logup.rs | 2 +- .../prover/src/constraint_framework/point.rs | 2 +- .../constraint_framework/relation_tracker.rs | 2 +- .../src/constraint_framework/simd_domain.rs | 2 +- crates/prover/src/core/air/accumulation.rs | 3 +- crates/prover/src/core/air/components.rs | 4 +- crates/prover/src/core/air/mod.rs | 8 +- .../src/core/backend/cpu/lookups/gkr.rs | 2 +- .../prover/src/core/backend/cpu/quotients.rs | 13 +- .../src/core/backend/simd/bit_reverse.rs | 2 +- .../prover/src/core/backend/simd/blake2s.rs | 8 +- crates/prover/src/core/backend/simd/circle.rs | 11 +- crates/prover/src/core/backend/simd/column.rs | 6 +- .../prover/src/core/backend/simd/fft/ifft.rs | 13 +- .../prover/src/core/backend/simd/fft/mod.rs | 16 +- .../prover/src/core/backend/simd/fft/rfft.rs | 19 +- crates/prover/src/core/backend/simd/fri.rs | 15 +- .../src/core/backend/simd/lookups/gkr.rs | 2 +- .../src/core/backend/simd/lookups/mle.rs | 3 +- crates/prover/src/core/backend/simd/m31.rs | 578 +++++++++--------- .../prover/src/core/backend/simd/quotients.rs | 4 +- crates/prover/src/core/backend/simd/utils.rs | 91 +-- crates/prover/src/core/channel/blake2s.rs | 2 +- crates/prover/src/core/constraints.rs | 10 +- crates/prover/src/core/fri.rs | 7 +- crates/prover/src/core/lookups/gkr_prover.rs | 4 +- .../prover/src/core/lookups/gkr_verifier.rs | 4 +- crates/prover/src/core/pcs/mod.rs | 1 + crates/prover/src/core/pcs/prover.rs | 2 +- crates/prover/src/core/pcs/utils.rs | 2 +- crates/prover/src/core/pcs/verifier.rs | 2 +- crates/prover/src/core/poly/circle/canonic.rs | 14 +- crates/prover/src/core/poly/circle/domain.rs | 5 +- .../prover/src/core/poly/circle/evaluation.rs | 4 +- crates/prover/src/core/poly/twiddles.rs | 1 + crates/prover/src/core/utils.rs | 2 +- crates/prover/src/core/vcs/blake2_merkle.rs | 13 +- crates/prover/src/core/vcs/blake2s_ref.rs | 8 +- crates/prover/src/core/vcs/ops.rs | 11 +- crates/prover/src/core/vcs/prover.rs | 5 +- crates/prover/src/core/vcs/verifier.rs | 3 +- .../src/examples/blake/round/constraints.rs | 7 +- crates/prover/src/examples/blake/round/gen.rs | 2 +- crates/prover/src/examples/blake/round/mod.rs | 4 +- .../src/examples/blake/scheduler/mod.rs | 2 + crates/prover/src/examples/poseidon/mod.rs | 2 +- .../src/examples/xor/gkr_lookups/mle_eval.rs | 10 +- crates/prover/src/lib.rs | 13 +- rust-toolchain.toml | 2 +- scripts/clippy.sh | 2 +- scripts/rust_fmt.sh | 2 +- scripts/test_avx.sh | 2 +- 62 files changed, 814 insertions(+), 694 deletions(-) diff --git a/.github/workflows/benchmarks-pages.yaml b/.github/workflows/benchmarks-pages.yaml index 4625dd88d..5e2f877b3 100644 --- a/.github/workflows/benchmarks-pages.yaml +++ b/.github/workflows/benchmarks-pages.yaml @@ -1,4 +1,4 @@ -name: +name: on: push: @@ -18,7 +18,7 @@ jobs: - uses: actions/checkout@v4 - 
uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - name: Run benchmark run: ./scripts/bench.sh -- --output-format bencher | tee output.txt - name: Download previous benchmark data @@ -29,7 +29,7 @@ jobs: - name: Store benchmark result uses: benchmark-action/github-action-benchmark@v1 with: - tool: 'cargo' + tool: "cargo" output-file-path: output.txt github-token: ${{ secrets.GITHUB_TOKEN }} auto-push: true diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5fb3258e6..8a00b5467 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,7 +25,7 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: rustfmt - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - run: scripts/rust_fmt.sh --check @@ -36,7 +36,7 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: clippy - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - run: scripts/clippy.sh @@ -46,9 +46,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 doc + - run: cargo +nightly-2024-11-06 doc run-wasm32-wasi-tests: runs-on: ubuntu-latest @@ -56,7 +56,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 targets: wasm32-wasi - uses: taiki-e/install-action@v2 with: @@ -73,7 +73,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 targets: wasm32-unknown-unknown - uses: Swatinem/rust-cache@v2 - uses: jetli/wasm-pack-action@v0.4.0 @@ -89,9 +89,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test + - run: cargo +nightly-2024-11-06 test env: RUSTFLAGS: -C target-feature=+neon @@ -104,9 +104,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test + - run: cargo +nightly-2024-11-06 test env: RUSTFLAGS: -C target-cpu=native -C target-feature=+${{ matrix.target-feature }} @@ -116,7 +116,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - name: Run benchmark run: ./scripts/bench.sh -- --output-format bencher | tee output.txt - name: Download previous benchmark data @@ -142,7 +142,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - name: Run benchmark run: ./scripts/bench.sh --features="parallel" -- --output-format bencher | tee output.txt - name: Download previous benchmark data @@ -168,9 +168,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test + - run: cargo +nightly-2024-11-06 test run-slow-tests: runs-on: ubuntu-latest @@ -178,9 +178,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: 
nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test --release --features="slow-tests" + - run: cargo +nightly-2024-11-06 test --release --features="slow-tests" run-tests-parallel: runs-on: ubuntu-latest @@ -188,9 +188,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test --features="parallel" + - run: cargo +nightly-2024-11-06 test --features="parallel" machete: runs-on: ubuntu-latest @@ -201,9 +201,9 @@ jobs: toolchain: nightly-2024-01-04 - uses: Swatinem/rust-cache@v2 - name: Install Machete - run: cargo +nightly-2024-01-04 install --locked cargo-machete + run: cargo +nightly-2024-11-06 install --locked cargo-machete - name: Run Machete (detect unused dependencies) - run: cargo +nightly-2024-01-04 machete + run: cargo +nightly-2024-11-06 machete all-tests: runs-on: ubuntu-latest diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index 504cd67bb..508e0f11b 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -12,14 +12,14 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: rustfmt - toolchain: nightly-2024-01-04 + toolchain: nightly-2024-11-06 - uses: Swatinem/rust-cache@v2 - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov # TODO: Merge coverage reports for tests on different architectures. # - name: Generate code coverage - run: cargo +nightly-2024-01-04 llvm-cov --codecov --output-path codecov.json + run: cargo +nightly-2024-11-06 llvm-cov --codecov --output-path codecov.json env: RUSTFLAGS: "-C target-feature=+avx512f" - name: Upload coverage to Codecov diff --git a/Cargo.lock b/Cargo.lock index c14183025..a7e009352 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
-version = 3 +version = 4 [[package]] name = "aho-corasick" @@ -26,11 +26,54 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + [[package]] name = "anstyle" -version = "1.0.6" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +dependencies = [ + "anstyle", + "windows-sys 0.59.0", +] [[package]] name = "ark-ff" @@ -98,15 +141,15 @@ dependencies = [ [[package]] name = "arrayref" -version = "0.3.7" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b4930d2cb77ce62f89ee5d5289b4ac049559b1c45539271f5ed4fdc7db34545" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" [[package]] name = "arrayvec" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "as-slice" @@ -119,9 +162,9 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.2.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "bigdecimal" @@ -146,9 +189,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.5.1" +version = "1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30cca6d3674597c30ddf2c587bf8d9d65c9a84d2326d941cc79c9842dfe0ef52" +checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e" dependencies = [ "arrayref", "arrayvec", @@ -174,24 +217,30 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.15.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" +checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.6.0" +version = "1.8.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" +checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.90", ] +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "cast" version = "0.3.0" @@ -200,9 +249,12 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.95" +version = "1.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d32a725bc159af97c3e629873bb9f88fb8cf8a4867175f76dc987815ea07c83b" +checksum = "9157bbaa6b165880c27a4293a474c91cdcf265cc68cc829bf10be0964a391caf" +dependencies = [ + "shlex", +] [[package]] name = "cfg-if" @@ -239,18 +291,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.4" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" +checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.2" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" +checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" dependencies = [ "anstyle", "clap_lex", @@ -258,31 +310,27 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.0" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" [[package]] -name = "console_error_panic_hook" -version = "0.1.7" +name = "colorchoice" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc" -dependencies = [ - "cfg-if", - "wasm-bindgen", -] +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "constant_time_eq" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" [[package]] name = "cpufeatures" -version = "0.2.12" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" dependencies = [ "libc", ] @@ -325,9 +373,9 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" dependencies = [ "crossbeam-epoch", "crossbeam-utils", @@ -344,9 +392,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.19" +version = "0.8.21" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" @@ -397,12 +445,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "downcast-rs" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" - [[package]] name = "educe" version = "0.5.11" @@ -412,14 +454,14 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.90", ] [[package]] name = "either" -version = "1.11.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "enum-ordinalize" @@ -438,24 +480,26 @@ checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.90", ] [[package]] name = "env_filter" -version = "0.1.0" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea" +checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" dependencies = [ "log", ] [[package]] name = "env_logger" -version = "0.11.3" +version = "0.11.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b35839ba51819680ba087cd351788c9a3c476841207e0b8cee0b04722343b9" +checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" dependencies = [ + "anstream", + "anstyle", "env_filter", "log", ] @@ -495,9 +539,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.3.9" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" [[package]] name = "hex" @@ -516,15 +560,21 @@ dependencies = [ [[package]] name = "is-terminal" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" dependencies = [ "hermit-abi", "libc", - "windows-sys", + "windows-sys 0.52.0", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + [[package]] name = "itertools" version = "0.10.5" @@ -545,36 +595,37 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7" dependencies = [ + "once_cell", "wasm-bindgen", ] [[package]] 
name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.168" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d" [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "matchers" @@ -587,15 +638,15 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "minicov" -version = "0.3.5" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c71e683cd655513b99affab7d317deb690528255a0d5f717f1024093c12b169" +checksum = "f27fe9f1cc3c22e1687f9446c2083c4c5fc7f0bcf1c7a86bdbded14985895b4b" dependencies = [ "cc", "walkdir", @@ -613,9 +664,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", @@ -632,24 +683,24 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.18" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", ] [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "oorandom" -version = "11.1.3" +version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" [[package]] name = "overload" @@ -665,15 +716,15 @@ checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "plotters" -version = "0.3.5" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +checksum = 
"5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" dependencies = [ "num-traits", "plotters-backend", @@ -684,39 +735,42 @@ dependencies = [ [[package]] name = "plotters-backend" -version = "0.3.5" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" [[package]] name = "plotters-svg" -version = "0.3.5" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" dependencies = [ "plotters-backend", ] [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] [[package]] name = "proc-macro2" -version = "1.0.81" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -769,14 +823,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.4" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.6", - "regex-syntax 0.8.3", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", ] [[package]] @@ -790,13 +844,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.6" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.3", + "regex-syntax 0.8.5", ] [[package]] @@ -807,9 +861,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "rfc6979" @@ -823,18 +877,18 @@ dependencies = [ [[package]] name = "rustc_version" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" dependencies = [ "semver", ] [[package]] name = "ryu" -version = "1.0.17" +version = "1.0.18" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "same-file" @@ -853,37 +907,38 @@ checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" [[package]] name = "semver" -version = "1.0.23" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" [[package]] name = "serde" -version = "1.0.198" +version = "1.0.216" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" +checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.198" +version = "1.0.216" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" +checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.90", ] [[package]] name = "serde_json" -version = "1.0.116" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -908,6 +963,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "smallvec" version = "1.13.2" @@ -948,7 +1009,7 @@ checksum = "bbc159a1934c7be9761c237333a57febe060ace2bc9e3b337a59a37af206d19f" dependencies = [ "starknet-curve", "starknet-ff", - "syn 2.0.60", + "syn 2.0.90", ] [[package]] @@ -984,7 +1045,6 @@ dependencies = [ "bytemuck", "cfg-if", "criterion", - "downcast-rs", "educe", "hex", "itertools 0.12.1", @@ -1003,9 +1063,9 @@ dependencies = [ [[package]] name = "subtle" -version = "2.5.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" @@ -1020,9 +1080,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.60" +version = "2.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" +checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" dependencies = [ "proc-macro2", "quote", @@ -1031,9 +1091,9 @@ dependencies = [ [[package]] name = "test-log" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b319995299c65d522680decf80f2c108d85b861d81dfe340a10d16cee29d9e6" +checksum = "3dffced63c2b5c7be278154d76b479f9f9920ed34e7574201407f0b14e2bbb93" dependencies = [ "env_logger", "test-log-macros", @@ -1042,33 +1102,33 @@ dependencies = [ [[package]] name = "test-log-macros" -version = "0.2.15" +version = "0.2.16" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8f546451eaa38373f549093fe9fd05e7d2bade739e2ddf834b9968621d60107" +checksum = "5999e24eaa32083191ba4e425deb75cdf25efefabe5aaccb7446dd0d4122a3f5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.90", ] [[package]] name = "thiserror" -version = "1.0.59" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.59" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.90", ] [[package]] @@ -1093,9 +1153,9 @@ dependencies = [ [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -1104,20 +1164,20 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.90", ] [[package]] name = "tracing-core" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", "valuable", @@ -1136,9 +1196,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.18" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" dependencies = [ "matchers", "nu-ansi-term", @@ -1160,9 +1220,15 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" + +[[package]] +name = "utf8parse" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "valuable" @@ -1172,9 +1238,9 @@ checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = 
"0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "walkdir" @@ -1194,9 +1260,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" dependencies = [ "cfg-if", "once_cell", @@ -1205,36 +1271,36 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" dependencies = [ "bumpalo", "log", - "once_cell", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.90", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "38176d9b44ea84e9184eff0bc34cc167ed044f816accfe5922e54d84cf48eca2" dependencies = [ "cfg-if", "js-sys", + "once_cell", "wasm-bindgen", "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1242,30 +1308,29 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.90", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" [[package]] name = "wasm-bindgen-test" -version = "0.3.43" +version = "0.3.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68497a05fb21143a08a7d24fc81763384a3072ee43c44e86aad1744d6adef9d9" +checksum = "c61d44563646eb934577f2772656c7ad5e9c90fac78aa8013d776fcdaf24625d" dependencies = [ - "console_error_panic_hook", "js-sys", "minicov", "scoped-tls", @@ -1276,20 +1341,20 @@ dependencies = [ [[package]] name = "wasm-bindgen-test-macro" -version = "0.3.43" +version = "0.3.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b8220be1fa9e4c889b30fd207d4906657e7e90b12e0e6b0c8b8d8709f5de021" +checksum = "54171416ce73aa0b9c377b51cc3cb542becee1cd678204812e8392e5b0e4a031" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.90", ] [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" 
+checksum = "04dd7223427d52553d3702c004d3b2fe07c148165faa56313cb00211e31c12bc" dependencies = [ "js-sys", "wasm-bindgen", @@ -1313,11 +1378,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.6" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "winapi", + "windows-sys 0.59.0", ] [[package]] @@ -1335,11 +1400,20 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -1353,51 +1427,72 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] [[package]] name = "zeroize" @@ -1416,5 +1511,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.90", ] diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index a9b80e9e4..03e99ae28 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -14,7 +14,6 @@ blake2.workspace = true blake3.workspace = true bytemuck = { workspace = true, features = ["derive", "extern_crate_alloc"] } cfg-if = "1.0.0" -downcast-rs = "1.2" educe.workspace = true hex.workspace = true itertools.workspace = true @@ -59,9 +58,6 @@ unused = "deny" [lints.clippy] missing_const_for_fn = "warn" -[package.metadata.cargo-machete] -ignored = ["downcast-rs"] - [[bench]] harness = false name = "bit_rev" diff --git a/crates/prover/benches/fft.rs b/crates/prover/benches/fft.rs index 35841d7e8..cbb0c9e80 100644 --- a/crates/prover/benches/fft.rs +++ b/crates/prover/benches/fft.rs @@ -29,7 +29,7 @@ pub fn simd_ifft(c: &mut Criterion) { || values.clone().data, |mut data| unsafe { ifft( - transmute(data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(data.as_mut_ptr()), black_box(&twiddle_dbls_refs), black_box(log_size as usize), ); @@ -58,7 +58,7 @@ pub fn simd_ifft_parts(c: &mut Criterion) { || values.clone().data, |mut values| unsafe { ifft_vecwise_loop( - transmute(values.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.as_mut_ptr()), black_box(&twiddle_dbls_refs), black_box(9), black_box(0), @@ -72,7 +72,7 @@ pub fn simd_ifft_parts(c: &mut Criterion) { || values.clone().data, |mut values| unsafe { ifft3_loop( - transmute(values.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.as_mut_ptr()), black_box(&twiddle_dbls_refs[3..]), black_box(7), black_box(4), @@ -91,7 +91,7 @@ pub fn simd_ifft_parts(c: &mut Criterion) { || transpose_values.clone().data, |mut values| unsafe { transpose_vecs( - transmute(values.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.as_mut_ptr()), black_box(TRANSPOSE_LOG_SIZE as usize - 4), ) }, @@ -115,8 +115,10 @@ pub fn simd_rfft(c: &mut Criterion) { target.set_len(values.data.len()); fft( - black_box(transmute(values.data.as_ptr())), - transmute(target.as_mut_ptr()), + black_box(transmute::<*const PackedBaseField, *const u32>( + values.data.as_ptr(), + )), + transmute::<*mut PackedBaseField, *mut u32>(target.as_mut_ptr()), black_box(&twiddle_dbls_refs), black_box(LOG_SIZE as usize), ) diff --git a/crates/prover/benches/merkle.rs b/crates/prover/benches/merkle.rs index c039be77e..9a63a3c38 100644 --- a/crates/prover/benches/merkle.rs +++ b/crates/prover/benches/merkle.rs @@ -21,7 +21,7 @@ fn bench_blake2s_merkle>(c: &mut Criterion, id let n_elements = 1 << (LOG_N_COLS + LOG_N_ROWS); 
group.throughput(Throughput::Elements(n_elements)); group.throughput(Throughput::Bytes(N_BYTES_FELT as u64 * n_elements)); - group.bench_function(&format!("{id} merkle"), |b| { + group.bench_function(format!("{id} merkle"), |b| { b.iter_with_large_drop(|| B::commit_on_layer(LOG_N_ROWS, None, &col_refs)) }); } diff --git a/crates/prover/src/constraint_framework/assert.rs b/crates/prover/src/constraint_framework/assert.rs index 34ab6fdec..376ff80b1 100644 --- a/crates/prover/src/constraint_framework/assert.rs +++ b/crates/prover/src/constraint_framework/assert.rs @@ -33,7 +33,7 @@ impl<'a> AssertEvaluator<'a> { } } } -impl<'a> EvalAtRow for AssertEvaluator<'a> { +impl EvalAtRow for AssertEvaluator<'_> { type F = BaseField; type EF = SecureField; diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 8f082f5f7..86f00609c 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -110,9 +110,10 @@ impl TraceLocationAllocator { } /// A component defined solely in means of the constraints framework. +/// /// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for -/// the SIMD backend. -/// Note that the constraint framework only support components with columns of the same size. +/// the SIMD backend. Note that the constraint framework only supports components with columns of +/// the same size. pub trait FrameworkEval { fn log_size(&self) -> u32; diff --git a/crates/prover/src/constraint_framework/cpu_domain.rs b/crates/prover/src/constraint_framework/cpu_domain.rs index 8c0f4beb9..03089bd17 100644 --- a/crates/prover/src/constraint_framework/cpu_domain.rs +++ b/crates/prover/src/constraint_framework/cpu_domain.rs @@ -52,7 +52,7 @@ impl<'a> CpuDomainEvaluator<'a> { } } -impl<'a> EvalAtRow for CpuDomainEvaluator<'a> { +impl EvalAtRow for CpuDomainEvaluator<'_> { type F = BaseField; type EF = SecureField; diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index bb05c6b5c..370987e4c 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -238,7 +238,7 @@ pub struct LogupColGenerator<'a> { /// Numerator expressions (i.e. multiplicities) being generated for the current lookup. numerator: SecureColumnByCoords, } -impl<'a> LogupColGenerator<'a> { +impl LogupColGenerator<'_> { /// Write a fraction to the column at a row. 
pub fn write_frac( &mut self, diff --git a/crates/prover/src/constraint_framework/point.rs b/crates/prover/src/constraint_framework/point.rs index 3fc2ad510..ea01c647d 100644 --- a/crates/prover/src/constraint_framework/point.rs +++ b/crates/prover/src/constraint_framework/point.rs @@ -35,7 +35,7 @@ impl<'a> PointEvaluator<'a> { } } } -impl<'a> EvalAtRow for PointEvaluator<'a> { +impl EvalAtRow for PointEvaluator<'_> { type F = SecureField; type EF = SecureField; diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs index 8311209d1..8b522b615 100644 --- a/crates/prover/src/constraint_framework/relation_tracker.rs +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -105,7 +105,7 @@ impl<'a> RelationTrackerEvaluator<'a> { self.entries } } -impl<'a> EvalAtRow for RelationTrackerEvaluator<'a> { +impl EvalAtRow for RelationTrackerEvaluator<'_> { type F = PackedBaseField; type EF = PackedSecureField; diff --git a/crates/prover/src/constraint_framework/simd_domain.rs b/crates/prover/src/constraint_framework/simd_domain.rs index c85942228..65c52708c 100644 --- a/crates/prover/src/constraint_framework/simd_domain.rs +++ b/crates/prover/src/constraint_framework/simd_domain.rs @@ -57,7 +57,7 @@ impl<'a> SimdDomainEvaluator<'a> { } } } -impl<'a> EvalAtRow for SimdDomainEvaluator<'a> { +impl EvalAtRow for SimdDomainEvaluator<'_> { type F = VeryPackedBaseField; type EF = VeryPackedSecureField; diff --git a/crates/prover/src/core/air/accumulation.rs b/crates/prover/src/core/air/accumulation.rs index c01828fd4..f53519929 100644 --- a/crates/prover/src/core/air/accumulation.rs +++ b/crates/prover/src/core/air/accumulation.rs @@ -1,4 +1,5 @@ //! Accumulators for a random linear combination of circle polynomials. +//! //! Given N polynomials, u_0(P), ... u_{N-1}(P), and a random alpha, the combined polynomial is //! defined as //! f(p) = sum_i alpha^{N-1-i} u_i(P). @@ -162,7 +163,7 @@ pub struct ColumnAccumulator<'a, B: Backend> { pub random_coeff_powers: Vec, pub col: &'a mut SecureColumnByCoords, } -impl<'a> ColumnAccumulator<'a, CpuBackend> { +impl ColumnAccumulator<'_, CpuBackend> { pub fn accumulate(&mut self, index: usize, evaluation: SecureField) { let val = self.col.at(index) + evaluation; self.col.set(index, val); diff --git a/crates/prover/src/core/air/components.rs b/crates/prover/src/core/air/components.rs index 3f9bf78ad..1008f0bfb 100644 --- a/crates/prover/src/core/air/components.rs +++ b/crates/prover/src/core/air/components.rs @@ -17,7 +17,7 @@ pub struct Components<'a> { pub n_preprocessed_columns: usize, } -impl<'a> Components<'a> { +impl Components<'_> { pub fn composition_log_degree_bound(&self) -> u32 { self.components .iter() @@ -108,7 +108,7 @@ pub struct ComponentProvers<'a, B: Backend> { pub n_preprocessed_columns: usize, } -impl<'a, B: Backend> ComponentProvers<'a, B> { +impl ComponentProvers<'_, B> { pub fn components(&self) -> Components<'_> { Components { components: self diff --git a/crates/prover/src/core/air/mod.rs b/crates/prover/src/core/air/mod.rs index 671d05048..fbcf6c736 100644 --- a/crates/prover/src/core/air/mod.rs +++ b/crates/prover/src/core/air/mod.rs @@ -15,10 +15,10 @@ mod components; pub mod mask; /// Arithmetic Intermediate Representation (AIR). -/// An Air instance is assumed to already contain all the information needed to -/// evaluate the constraints. -/// For instance, all interaction elements are assumed to be present in it. 
-/// Therefore, an AIR is generated only after the initial trace commitment phase. +/// +/// An Air instance is assumed to already contain all the information needed to evaluate the +/// constraints. For instance, all interaction elements are assumed to be present in it. Therefore, +/// an AIR is generated only after the initial trace commitment phase. pub trait Air { fn components(&self) -> Vec<&dyn Component>; } diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index ae9ab6b65..9c1d60093 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -265,7 +265,7 @@ enum MleExpr<'a, F: Field> { Mle(&'a Mle), } -impl<'a, F: Field> Index for MleExpr<'a, F> { +impl Index for MleExpr<'_, F> { type Output = F; fn index(&self, index: usize) -> &F { diff --git a/crates/prover/src/core/backend/cpu/quotients.rs b/crates/prover/src/core/backend/cpu/quotients.rs index f157b76ca..16f0647b6 100644 --- a/crates/prover/src/core/backend/cpu/quotients.rs +++ b/crates/prover/src/core/backend/cpu/quotients.rs @@ -73,10 +73,10 @@ pub fn accumulate_row_quotients( row_accumulator } -/// Precompute the complex conjugate line coefficients for each column in each sample batch. -/// Specifically, for the i-th (in a sample batch) column's numerator term -/// `alpha^i * (c * F(p) - (a * p.y + b))`, we precompute and return the constants: -/// (`alpha^i * a`, `alpha^i * b`, `alpha^i * c`). +/// Precomputes the complex conjugate line coefficients for each column in each sample batch. +/// +/// For the `i`-th (in a sample batch) column's numerator term `alpha^i * (c * F(p) - (a * p.y + +/// b))`, we precompute and return the constants: (`alpha^i * a`, `alpha^i * b`, `alpha^i * c`). pub fn column_line_coeffs( sample_batches: &[ColumnSampleBatch], random_coeff: SecureField, @@ -101,8 +101,9 @@ pub fn column_line_coeffs( .collect() } -/// Precompute the random coefficients used to linearly combine the batched quotients. -/// Specifically, for each sample batch we compute random_coeff^(number of columns in the batch), +/// Precomputes the random coefficients used to linearly combine the batched quotients. +/// +/// For each sample batch we compute random_coeff^(number of columns in the batch), /// which is used to linearly combine the batch with the next one. 
pub fn batch_random_coeffs( sample_batches: &[ColumnSampleBatch], diff --git a/crates/prover/src/core/backend/simd/bit_reverse.rs b/crates/prover/src/core/backend/simd/bit_reverse.rs index cc2a55c98..fe27a149c 100644 --- a/crates/prover/src/core/backend/simd/bit_reverse.rs +++ b/crates/prover/src/core/backend/simd/bit_reverse.rs @@ -166,7 +166,7 @@ mod tests { let res = bit_reverse16(values.data.try_into().unwrap()); - assert_eq!(res.map(PackedM31::to_array).flatten(), expected); + assert_eq!(res.map(PackedM31::to_array).as_flattened(), expected); } #[test] diff --git a/crates/prover/src/core/backend/simd/blake2s.rs b/crates/prover/src/core/backend/simd/blake2s.rs index 4f2297d19..3f4d46b8f 100644 --- a/crates/prover/src/core/backend/simd/blake2s.rs +++ b/crates/prover/src/core/backend/simd/blake2s.rs @@ -364,8 +364,12 @@ mod tests { let res_vectorized: [[u32; 8]; 16] = unsafe { transmute(untranspose_states(compress16( - transpose_states(transmute(states)), - transpose_msgs(transmute(msgs)), + transpose_states(transmute::, [u32x16; 8]>( + states, + )), + transpose_msgs(transmute::, [u32x16; 16]>( + msgs, + )), u32x16::splat(count_low), u32x16::splat(count_high), u32x16::splat(lastblock), diff --git a/crates/prover/src/core/backend/simd/circle.rs b/crates/prover/src/core/backend/simd/circle.rs index a20721a4f..61588ffe3 100644 --- a/crates/prover/src/core/backend/simd/circle.rs +++ b/crates/prover/src/core/backend/simd/circle.rs @@ -89,10 +89,7 @@ impl SimdBackend { // Generates twiddle steps for efficiently computing the twiddles. // steps[i] = t_i/(t_0*t_1*...*t_i-1). - fn twiddle_steps(mappings: &[F]) -> Vec - where - F: FieldExpOps, - { + fn twiddle_steps(mappings: &[F]) -> Vec { let mut denominators: Vec = vec![mappings[0]]; for i in 1..mappings.len() { @@ -159,7 +156,7 @@ impl PolyOps for SimdBackend { // Safe because [PackedBaseField] is aligned on 64 bytes. unsafe { ifft::ifft( - transmute(values.data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.data.as_mut_ptr()), &twiddles, log_size as usize, ); @@ -269,8 +266,8 @@ impl PolyOps for SimdBackend { // FFT from the coefficients buffer to the values chunk. unsafe { rfft::fft( - transmute(poly.coeffs.data.as_ptr()), - transmute( + transmute::<*const PackedBaseField, *const u32>(poly.coeffs.data.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>( values[i << (fft_log_size - LOG_N_LANES) ..(i + 1) << (fft_log_size - LOG_N_LANES)] .as_mut_ptr(), diff --git a/crates/prover/src/core/backend/simd/column.rs b/crates/prover/src/core/backend/simd/column.rs index dd5578c0e..29a6e58c5 100644 --- a/crates/prover/src/core/backend/simd/column.rs +++ b/crates/prover/src/core/backend/simd/column.rs @@ -207,7 +207,7 @@ impl FromIterator for CM31Column { /// A mutable slice of a BaseColumn. pub struct BaseColumnMutSlice<'a>(pub &'a mut [PackedBaseField]); -impl<'a> BaseColumnMutSlice<'a> { +impl BaseColumnMutSlice<'_> { pub fn at(&self, index: usize) -> BaseField { self.0[index / N_LANES].to_array()[index % N_LANES] } @@ -323,7 +323,7 @@ impl FromIterator for SecureColumn { /// A mutable slice of a SecureColumnByCoords. pub struct SecureColumnByCoordsMutSlice<'a>(pub [BaseColumnMutSlice<'a>; SECURE_EXTENSION_DEGREE]); -impl<'a> SecureColumnByCoordsMutSlice<'a> { +impl SecureColumnByCoordsMutSlice<'_> { /// # Safety /// /// `vec_index` must be a valid index. 
@@ -357,7 +357,7 @@ pub struct VeryPackedSecureColumnByCoordsMutSlice<'a>( pub [VeryPackedBaseColumnMutSlice<'a>; SECURE_EXTENSION_DEGREE], ); -impl<'a> VeryPackedSecureColumnByCoordsMutSlice<'a> { +impl VeryPackedSecureColumnByCoordsMutSlice<'_> { /// # Safety /// /// `vec_index` must be a valid index. diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs index a6abb48e0..feab2ab54 100644 --- a/crates/prover/src/core/backend/simd/fft/ifft.rs +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -598,7 +598,7 @@ mod tests { let mut res = values; unsafe { ifft3( - transmute(res.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.as_mut_ptr()), 0, LOG_N_LANES as usize, twiddles0_dbl, @@ -664,7 +664,7 @@ mod tests { [val0.to_array(), val1.to_array()].concat() }; - assert_eq!(res, ground_truth_ifft(domain, values.flatten())); + assert_eq!(res, ground_truth_ifft(domain, values.as_flattened())); } #[test] @@ -678,7 +678,7 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { ifft_lower_with_vecwise( - transmute(res.data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, log_size as usize, @@ -700,11 +700,14 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { ifft( - transmute(res.data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, ); - transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4); + transpose_vecs( + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), + log_size as usize - 4, + ); } assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values)); diff --git a/crates/prover/src/core/backend/simd/fft/mod.rs b/crates/prover/src/core/backend/simd/fft/mod.rs index b3ea4d700..78624d9e0 100644 --- a/crates/prover/src/core/backend/simd/fft/mod.rs +++ b/crates/prover/src/core/backend/simd/fft/mod.rs @@ -102,7 +102,7 @@ const unsafe fn load(mem_addr: *const u32) -> u32x16 { } #[inline] -unsafe fn store(mem_addr: *mut u32, a: u32x16) { +const unsafe fn store(mem_addr: *mut u32, a: u32x16) { std::ptr::write(mem_addr as *mut u32x16, a); } @@ -111,19 +111,19 @@ fn mul_twiddle(v: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { // TODO: Come up with a better approach than `cfg`ing on target_feature. // TODO: Ensure all these branches get tested in the CI. cfg_if::cfg_if! { - if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { + if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { // TODO: For architectures that when multiplying require doubling then the twiddles // should be precomputed as double. For other architectures, the twiddle should be // precomputed without doubling. 
- crate::core::backend::simd::m31::_mul_doubled_neon(v, twiddle_dbl) - } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { - crate::core::backend::simd::m31::_mul_doubled_wasm(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_neon(v, twiddle_dbl) + } else if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] { + crate::core::backend::simd::m31::mul_doubled_wasm(v, twiddle_dbl) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { - crate::core::backend::simd::m31::_mul_doubled_avx512(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_avx512(v, twiddle_dbl) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { - crate::core::backend::simd::m31::_mul_doubled_avx2(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_avx2(v, twiddle_dbl) } else { - crate::core::backend::simd::m31::_mul_doubled_simd(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_simd(v, twiddle_dbl) } } } diff --git a/crates/prover/src/core/backend/simd/fft/rfft.rs b/crates/prover/src/core/backend/simd/fft/rfft.rs index 1249b11e4..4500f64ef 100644 --- a/crates/prover/src/core/backend/simd/fft/rfft.rs +++ b/crates/prover/src/core/backend/simd/fft/rfft.rs @@ -624,8 +624,8 @@ mod tests { let mut res = values; unsafe { fft3( - transmute(res.as_ptr()), - transmute(res.as_mut_ptr()), + transmute::<*const PackedBaseField, *const u32>(res.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.as_mut_ptr()), 0, LOG_N_LANES as usize, twiddles0_dbl, @@ -695,7 +695,7 @@ mod tests { [val0.to_array(), val1.to_array()].concat() }; - assert_eq!(res, ground_truth_fft(domain, values.flatten())); + assert_eq!(res, ground_truth_fft(domain, values.as_flattened())); } #[test] @@ -709,8 +709,8 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { fft_lower_with_vecwise( - transmute(res.data.as_ptr()), - transmute(res.data.as_mut_ptr()), + transmute::<*const PackedBaseField, *const u32>(res.data.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, log_size as usize, @@ -731,10 +731,13 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { - transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4); + transpose_vecs( + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), + log_size as usize - 4, + ); fft( - transmute(res.data.as_ptr()), - transmute(res.data.as_mut_ptr()), + transmute::<*const PackedBaseField, *const u32>(res.data.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, ); diff --git a/crates/prover/src/core/backend/simd/fri.rs b/crates/prover/src/core/backend/simd/fri.rs index 8804a7015..3ced43459 100644 --- a/crates/prover/src/core/backend/simd/fri.rs +++ b/crates/prover/src/core/backend/simd/fri.rs @@ -1,5 +1,5 @@ use std::array; -use std::simd::u32x8; +use std::simd::{u32x16, u32x8}; use num_traits::Zero; @@ -38,14 +38,15 @@ impl FriOps for SimdBackend { let mut folded_values = SecureColumnByCoords::::zeros(1 << (log_size - 1)); for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) { - let value = unsafe { - let twiddle_dbl: [u32; 16] = - array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i)); - let val0 = eval.values.packed_at(vec_index * 2).into_packed_m31s(); - let val1 = eval.values.packed_at(vec_index * 2 + 
1).into_packed_m31s(); + let value = { + let twiddle_dbl = u32x16::from_array(array::from_fn(|i| unsafe { + *itwiddles.get_unchecked(vec_index * 16 + i) + })); + let val0 = unsafe { eval.values.packed_at(vec_index * 2) }.into_packed_m31s(); + let val1 = unsafe { eval.values.packed_at(vec_index * 2 + 1) }.into_packed_m31s(); let pairs: [_; 4] = array::from_fn(|i| { let (a, b) = val0[i].deinterleave(val1[i]); - simd_ibutterfly(a, b, std::mem::transmute(twiddle_dbl)) + simd_ibutterfly(a, b, twiddle_dbl) }); let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0)); let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1)); diff --git a/crates/prover/src/core/backend/simd/lookups/gkr.rs b/crates/prover/src/core/backend/simd/lookups/gkr.rs index 017948dee..74d7f7c43 100644 --- a/crates/prover/src/core/backend/simd/lookups/gkr.rs +++ b/crates/prover/src/core/backend/simd/lookups/gkr.rs @@ -25,7 +25,7 @@ impl GkrOps for SimdBackend { } // Start DP with CPU backend to avoid dealing with instances smaller than a SIMD vector. - let (y_last_chunk, y_rem) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap(); + let (y_rem, y_last_chunk) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap(); let initial = SecureColumn::from_iter(cpu_gen_eq_evals(y_last_chunk, v)); assert_eq!(initial.len(), N_LANES); diff --git a/crates/prover/src/core/backend/simd/lookups/mle.rs b/crates/prover/src/core/backend/simd/lookups/mle.rs index 0e2fe73f7..07f175bbc 100644 --- a/crates/prover/src/core/backend/simd/lookups/mle.rs +++ b/crates/prover/src/core/backend/simd/lookups/mle.rs @@ -30,9 +30,8 @@ impl MleOps for SimdBackend { let (evals_at_0x, evals_at_1x) = mle.data.split_at(packed_midpoint); let res = zip(evals_at_0x, evals_at_1x) - .enumerate() // MLE at points `({0, 1}, rev(bits(i)), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - .map(|(_i, (&packed_eval_at_0iv, &packed_eval_at_1iv))| { + .map(|(&packed_eval_at_0iv, &packed_eval_at_1iv)| { fold_packed_mle_evals(packed_assignment, packed_eval_at_0iv, packed_eval_at_1iv) }) .collect(); diff --git a/crates/prover/src/core/backend/simd/m31.rs b/crates/prover/src/core/backend/simd/m31.rs index dbeec152f..3d10be8c0 100644 --- a/crates/prover/src/core/backend/simd/m31.rs +++ b/crates/prover/src/core/backend/simd/m31.rs @@ -3,14 +3,13 @@ use std::mem::transmute; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::ptr; use std::simd::cmp::SimdOrd; -use std::simd::{u32x16, Simd, Swizzle}; +use std::simd::{u32x16, Simd}; use bytemuck::{Pod, Zeroable}; use num_traits::{One, Zero}; use rand::distributions::{Distribution, Standard}; use super::qm31::PackedQM31; -use crate::core::backend::simd::utils::{InterleaveEvens, InterleaveOdds}; use crate::core::fields::m31::{pow2147483645, BaseField, M31, P}; use crate::core::fields::qm31::QM31; use crate::core::fields::FieldExpOps; @@ -101,7 +100,7 @@ impl PackedM31 { /// /// Behavior is undefined if the pointer does not have the same alignment as /// [`PackedM31`]. - pub unsafe fn store(self, dst: *mut u32) { + pub const unsafe fn store(self, dst: *mut u32) { ptr::write(dst as *mut u32x16, self.0) } } @@ -142,16 +141,16 @@ impl Mul for PackedM31 { // TODO: Come up with a better approach than `cfg`ing on target_feature. // TODO: Ensure all these branches get tested in the CI. cfg_if::cfg_if! 
{ - if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { - _mul_neon(self, rhs) - } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { - _mul_wasm(self, rhs) + if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { + mul_neon(self, rhs) + } else if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] { + mul_wasm(self, rhs) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { - _mul_avx512(self, rhs) + mul_avx512(self, rhs) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { - _mul_avx2(self, rhs) + mul_avx2(self, rhs) } else { - _mul_simd(self, rhs) + mul_simd(self, rhs) } } } @@ -286,290 +285,299 @@ impl Sum for PackedM31 { } } -/// Returns `a * b`. -#[cfg(target_arch = "aarch64")] -pub(crate) fn _mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 { - use core::arch::aarch64::{int32x2_t, vqdmull_s32}; - use std::simd::u32x4; - - let [a0, a1, a2, a3, a4, a5, a6, a7]: [int32x2_t; 8] = unsafe { transmute(a) }; - let [b0, b1, b2, b3, b4, b5, b6, b7]: [int32x2_t; 8] = unsafe { transmute(b) }; - - // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| - let c0: u32x4 = unsafe { transmute(vqdmull_s32(a0, b0)) }; - let c1: u32x4 = unsafe { transmute(vqdmull_s32(a1, b1)) }; - let c2: u32x4 = unsafe { transmute(vqdmull_s32(a2, b2)) }; - let c3: u32x4 = unsafe { transmute(vqdmull_s32(a3, b3)) }; - let c4: u32x4 = unsafe { transmute(vqdmull_s32(a4, b4)) }; - let c5: u32x4 = unsafe { transmute(vqdmull_s32(a5, b5)) }; - let c6: u32x4 = unsafe { transmute(vqdmull_s32(a6, b6)) }; - let c7: u32x4 = unsafe { transmute(vqdmull_s32(a7, b7)) }; - - // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. - // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. - let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); - let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); - let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); - let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); - - // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. - c0_c1_lo >>= 1; - c2_c3_lo >>= 1; - c4_c5_lo >>= 1; - c6_c7_lo >>= 1; - - let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; - let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; - - lo + hi -} +cfg_if::cfg_if! { + if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { + use core::arch::aarch64::{uint32x2_t, vmull_u32, int32x2_t, vqdmull_s32}; + use std::simd::u32x4; + + /// Returns `a * b`. + pub(crate) fn mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 { + let [a0, a1, a2, a3, a4, a5, a6, a7]: [int32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [int32x2_t; 8] = unsafe { transmute(b) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vqdmull_s32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vqdmull_s32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vqdmull_s32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vqdmull_s32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vqdmull_s32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vqdmull_s32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vqdmull_s32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vqdmull_s32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. 
+ let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. + c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi + } -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "aarch64")] -pub(crate) fn _mul_doubled_neon(a: PackedM31, b_double: u32x16) -> PackedM31 { - use core::arch::aarch64::{uint32x2_t, vmull_u32}; - use std::simd::u32x4; - - let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) }; - let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) }; - - // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| - let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) }; - let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) }; - let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) }; - let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) }; - let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) }; - let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) }; - let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) }; - let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) }; - - // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. - // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. - let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); - let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); - let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); - let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); - - // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. - c0_c1_lo >>= 1; - c2_c3_lo >>= 1; - c4_c5_lo >>= 1; - c6_c7_lo >>= 1; - - let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; - let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; - - lo + hi -} + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_neon(a: PackedM31, b_double: u32x16) -> PackedM31 { + let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. + let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. 
+ c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi + } + } else if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] { + use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128}; + use std::simd::u32x4; -/// Returns `a * b`. -#[cfg(target_arch = "wasm32")] -pub(crate) fn _mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_wasm(a, b.0 + b.0) -} + /// Returns `a * b`. + pub(crate) fn mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_wasm(a, b.0 + b.0) + } -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "wasm32")] -pub(crate) fn _mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 { - use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128}; - use std::simd::u32x4; - - let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) }; - let [b_double0, b_double1, b_double2, b_double3]: [v128; 4] = unsafe { transmute(b_double) }; - - let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_double0)) }; - let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_double0)) }; - let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_double1)) }; - let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_double1)) }; - let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_double2)) }; - let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_double2)) }; - let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_double3)) }; - let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_double3)) }; - - let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi); - let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi); - let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi); - let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi); - c0_even >>= 1; - c1_even >>= 1; - c2_even >>= 1; - c3_even >>= 1; - - let even: PackedM31 = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) }; - let odd: PackedM31 = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) }; - - even + odd -} + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. 
+ pub(crate) fn mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 { + let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) }; + let [b_double0, b_double1, b_double2, b_double3]: [v128; 4] = unsafe { transmute(b_double) }; + + let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_double0)) }; + let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_double0)) }; + let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_double1)) }; + let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_double1)) }; + let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_double2)) }; + let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_double2)) }; + let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_double3)) }; + let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_double3)) }; + + let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi); + let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi); + let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi); + let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi); + c0_even >>= 1; + c1_even >>= 1; + c2_even >>= 1; + c3_even >>= 1; + + let even: PackedM31 = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) }; + let odd: PackedM31 = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) }; + + even + odd + } + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { + use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64}; + use std::simd::Swizzle; -/// Returns `a * b`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_avx512(a, b.0 + b.0) -} + use crate::core::backend::simd::utils::swizzle::{InterleaveEvens, InterleaveOdds}; -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 { - use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64}; - - let a: __m512i = unsafe { transmute(a) }; - let b_double: __m512i = unsafe { transmute(b_double) }; - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a_e = a; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a_o = unsafe { _mm512_srli_epi64(a, 32) }; - - let b_dbl_e = b_double; - let b_dbl_o = unsafe { _mm512_srli_epi64(b_double, 32) }; - - // To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd. - let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) }; - let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) }; - - // The result of a multiplication holds a*b in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_dbl_e - |0|prod_e_h|prod_e_l|0| - // prod_dbl_o - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: - let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| - // Divide by 2: - prod_lo >>= 1; - // prod_lo - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: - let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_hi - |0|prod_o_h|0|prod_e_h| - - PackedM31(prod_lo) + PackedM31(prod_hi) -} + /// Returns `a * b`. 
+ pub(crate) fn mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_avx512(a, b.0 + b.0) + } -/// Returns `a * b`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_avx2(a, b.0 + b.0) -} + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 { + let a: __m512i = unsafe { transmute(a) }; + let b_double: __m512i = unsafe { transmute(b_double) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = a; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { _mm512_srli_epi64(a, 32) }; + + let b_dbl_e = b_double; + let b_dbl_o = unsafe { _mm512_srli_epi64(b_double, 32) }; + + // To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd. + let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) }; + let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) }; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + PackedM31(prod_lo) + PackedM31(prod_hi) + } + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { + use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64}; + use std::simd::Swizzle; -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 { - use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64}; - - let [a0, a1]: [__m256i; 2] = unsafe { transmute(a) }; - let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute(b_double) }; - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a0_e = a0; - let a1_e = a1; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a0_o = unsafe { _mm256_srli_epi64(a0, 32) }; - let a1_o = unsafe { _mm256_srli_epi64(a1, 32) }; - - let b0_dbl_e = b0_dbl; - let b1_dbl_e = b1_dbl; - let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) }; - let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) }; - - // To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd. - let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) }; - let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) }; - let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) }; - let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) }; - - let prod_dbl_e: u32x16 = unsafe { transmute([prod0_dbl_e, prod1_dbl_e]) }; - let prod_dbl_o: u32x16 = unsafe { transmute([prod0_dbl_o, prod1_dbl_o]) }; - - // The result of a multiplication holds a*b in as 64-bits. 
- // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_dbl_e - |0|prod_e_h|prod_e_l|0| - // prod_dbl_o - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: - let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| - // Divide by 2: - prod_lo >>= 1; - // prod_lo - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: - let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_hi - |0|prod_o_h|0|prod_e_h| - - PackedM31(prod_lo) + PackedM31(prod_hi) -} + use crate::core::backend::simd::utils::swizzle::{InterleaveEvens, InterleaveOdds}; -/// Returns `a * b`. -/// -/// Should only be used in the absence of a platform specific implementation. -pub(crate) fn _mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_simd(a, b.0 + b.0) -} + /// Returns `a * b`. + pub(crate) fn mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_avx2(a, b.0 + b.0) + } -/// Returns `a * b`. -/// -/// Should only be used in the absence of a platform specific implementation. -/// -/// `b_double` should be in the range `[0, 2P]`. -pub(crate) fn _mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 { - const MASK_EVENS: Simd = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]); - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a_e = unsafe { transmute::<_, Simd>(a.0) & MASK_EVENS }; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a_o = unsafe { transmute::<_, Simd>(a) >> 32 }; - - let b_dbl_e = unsafe { transmute::<_, Simd>(b_double) & MASK_EVENS }; - let b_dbl_o = unsafe { transmute::<_, Simd>(b_double) >> 32 }; - - // To compute prod = a * b start by multiplying - // a_e/o by b_dbl_e/o. - let prod_e_dbl = a_e * b_dbl_e; - let prod_o_dbl = a_o * b_dbl_o; - - // The result of a multiplication holds a*b in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_e_dbl - |0|prod_e_h|prod_e_l|0| - // prod_o_dbl - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: - // let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, - // prod_o_dbl); - // prod_ls - |prod_o_l|0|prod_e_l|0| - let mut prod_lows = InterleaveEvens::concat_swizzle( - unsafe { transmute::<_, Simd>(prod_e_dbl) }, - unsafe { transmute::<_, Simd>(prod_o_dbl) }, - ); - // Divide by 2: - prod_lows >>= 1; - // prod_ls - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: - let prod_highs = InterleaveOdds::concat_swizzle( - unsafe { transmute::<_, Simd>(prod_e_dbl) }, - unsafe { transmute::<_, Simd>(prod_o_dbl) }, - ); - - // prod_hs - |0|prod_o_h|0|prod_e_h| - PackedM31(prod_lows) + PackedM31(prod_highs) + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 { + let [a0, a1]: [__m256i; 2] = unsafe { transmute::(a) }; + let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute::(b_double) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a0_e = a0; + let a1_e = a1; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. 
+ let a0_o = unsafe { _mm256_srli_epi64(a0, 32) }; + let a1_o = unsafe { _mm256_srli_epi64(a1, 32) }; + + let b0_dbl_e = b0_dbl; + let b1_dbl_e = b1_dbl; + let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) }; + let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) }; + + // To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd. + let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) }; + let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) }; + let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) }; + let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) }; + + let prod_dbl_e: u32x16 = + unsafe { transmute::<[__m256i; 2], u32x16>([prod0_dbl_e, prod1_dbl_e]) }; + let prod_dbl_o: u32x16 = + unsafe { transmute::<[__m256i; 2], u32x16>([prod0_dbl_o, prod1_dbl_o]) }; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + PackedM31(prod_lo) + PackedM31(prod_hi) + } + } else { + use std::simd::Swizzle; + + use crate::core::backend::simd::utils::swizzle::{InterleaveEvens, InterleaveOdds}; + + /// Returns `a * b`. + /// + /// Should only be used in the absence of a platform specific implementation. + pub(crate) fn mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_simd(a, b.0 + b.0) + } + + /// Returns `a * b`. + /// + /// Should only be used in the absence of a platform specific implementation. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 { + const MASK_EVENS: Simd = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]); + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = + unsafe { transmute::, Simd>(a.0) & MASK_EVENS }; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { transmute::>(a) >> 32 }; + + let b_dbl_e = unsafe { + transmute::, Simd>(b_double) & MASK_EVENS + }; + let b_dbl_o = + unsafe { transmute::, Simd>(b_double) >> 32 }; + + // To compute prod = a * b start by multiplying + // a_e/o by b_dbl_e/o. + let prod_e_dbl = a_e * b_dbl_e; + let prod_o_dbl = a_o * b_dbl_o; + + // The result of a multiplication holds a*b in as 64-bits. 
+ // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_e_dbl - |0|prod_e_h|prod_e_l|0| + // prod_o_dbl - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: + // let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, + // prod_o_dbl); + // prod_ls - |prod_o_l|0|prod_e_l|0| + let mut prod_lows = InterleaveEvens::concat_swizzle( + unsafe { transmute::, Simd>(prod_e_dbl) }, + unsafe { transmute::, Simd>(prod_o_dbl) }, + ); + // Divide by 2: + prod_lows >>= 1; + // prod_ls - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: + let prod_highs = InterleaveOdds::concat_swizzle( + unsafe { transmute::, Simd>(prod_e_dbl) }, + unsafe { transmute::, Simd>(prod_o_dbl) }, + ); + + // prod_hs - |0|prod_o_h|0|prod_e_h| + PackedM31(prod_lows) + PackedM31(prod_highs) + } + } } #[cfg(test)] diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index bac374292..9dd30f7fd 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -286,13 +286,13 @@ mod tests { let e1: BaseColumn = (0..small_domain.size()) .map(|i| BaseField::from(2 * i)) .collect(); - let polys = vec![ + let polys = [ CircleEvaluation::::new(small_domain, e0) .interpolate(), CircleEvaluation::::new(small_domain, e1) .interpolate(), ]; - let columns = vec![polys[0].evaluate(domain), polys[1].evaluate(domain)]; + let columns = [polys[0].evaluate(domain), polys[1].evaluate(domain)]; let random_coeff = qm31!(1, 2, 3, 4); let a = polys[0].eval_at_point(SECURE_FIELD_CIRCLE_GEN); let b = polys[1].eval_at_point(SECURE_FIELD_CIRCLE_GEN); diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs index d5f53a22b..a3d1b614c 100644 --- a/crates/prover/src/core/backend/simd/utils.rs +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -1,29 +1,3 @@ -use std::simd::Swizzle; - -/// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. -pub struct InterleaveEvens; - -impl Swizzle for InterleaveEvens { - const INDEX: [usize; N] = parity_interleave(false); -} - -/// Used with [`Swizzle::concat_swizzle`] to interleave the odd values of two vectors. -pub struct InterleaveOdds; - -impl Swizzle for InterleaveOdds { - const INDEX: [usize; N] = parity_interleave(true); -} - -const fn parity_interleave(odd: bool) -> [usize; N] { - let mut res = [0; N]; - let mut i = 0; - while i < N { - res[i] = (i % 2) * N + (i / 2) * 2 + if odd { 1 } else { 0 }; - i += 1; - } - res -} - // TODO(andrew): Examine usage of unsafe in SIMD FFT. pub struct UnsafeMut(pub *mut T); impl UnsafeMut { @@ -51,29 +25,60 @@ impl UnsafeConst { unsafe impl Send for UnsafeConst {} unsafe impl Sync for UnsafeConst {} -#[cfg(test)] -mod tests { - use std::simd::{u32x4, Swizzle}; - - use super::{InterleaveEvens, InterleaveOdds}; +#[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "wasm32", target_feature = "simd128") +)))] +pub mod swizzle { + use std::simd::Swizzle; + + /// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. 
+ pub struct InterleaveEvens; + impl Swizzle for InterleaveEvens { + const INDEX: [usize; N] = parity_interleave(false); + } - #[test] - fn interleave_evens() { - let lo = u32x4::from_array([0, 1, 2, 3]); - let hi = u32x4::from_array([4, 5, 6, 7]); + /// Used with [`Swizzle::concat_swizzle`] to interleave the odd values of two vectors. + pub struct InterleaveOdds; - let res = InterleaveEvens::concat_swizzle(lo, hi); + impl Swizzle for InterleaveOdds { + const INDEX: [usize; N] = parity_interleave(true); + } - assert_eq!(res, u32x4::from_array([0, 4, 2, 6])); + const fn parity_interleave(odd: bool) -> [usize; N] { + let mut res = [0; N]; + let mut i = 0; + while i < N { + res[i] = (i % 2) * N + (i / 2) * 2 + if odd { 1 } else { 0 }; + i += 1; + } + res } - #[test] - fn interleave_odds() { - let lo = u32x4::from_array([0, 1, 2, 3]); - let hi = u32x4::from_array([4, 5, 6, 7]); + #[cfg(test)] + mod tests { + use std::simd::{u32x4, Swizzle}; + + use super::{InterleaveEvens, InterleaveOdds}; + + #[test] + fn interleave_evens() { + let lo = u32x4::from_array([0, 1, 2, 3]); + let hi = u32x4::from_array([4, 5, 6, 7]); + + let res = InterleaveEvens::concat_swizzle(lo, hi); + + assert_eq!(res, u32x4::from_array([0, 4, 2, 6])); + } + + #[test] + fn interleave_odds() { + let lo = u32x4::from_array([0, 1, 2, 3]); + let hi = u32x4::from_array([4, 5, 6, 7]); - let res = InterleaveOdds::concat_swizzle(lo, hi); + let res = InterleaveOdds::concat_swizzle(lo, hi); - assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); + assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); + } } } diff --git a/crates/prover/src/core/channel/blake2s.rs b/crates/prover/src/core/channel/blake2s.rs index 160d4754e..62218b5ba 100644 --- a/crates/prover/src/core/channel/blake2s.rs +++ b/crates/prover/src/core/channel/blake2s.rs @@ -75,7 +75,7 @@ impl Channel for Blake2sChannel { let res = compress(std::array::from_fn(|i| digest[i]), msg, 0, 0, 0, 0); // TODO(shahars) Channel should always finalize hash. - self.update_digest(unsafe { std::mem::transmute(res) }); + self.update_digest(unsafe { std::mem::transmute::<[u32; 8], Blake2sHash>(res) }); } fn draw_felt(&mut self) -> SecureField { diff --git a/crates/prover/src/core/constraints.rs b/crates/prover/src/core/constraints.rs index 31711d98e..f66c8d93d 100644 --- a/crates/prover/src/core/constraints.rs +++ b/crates/prover/src/core/constraints.rs @@ -90,11 +90,11 @@ pub fn complex_conjugate_line( / (point.complex_conjugate().y - point.y) } -/// Evaluates the coefficients of a line between a point and its complex conjugate. Specifically, -/// `a, b, and c, s.t. a*x + b -c*y = 0` for (x,y) being (sample.y, sample.value) and -/// (conj(sample.y), conj(sample.value)). -/// Relies on the fact that every polynomial F over the base -/// field holds: F(p*) == F(p)* (* being the complex conjugate). +/// Evaluates the coefficients of a line between a point and its complex conjugate. +/// +/// Specifically, `a, b, and c, s.t. a*x + b -c*y = 0` for (x,y) being (sample.y, sample.value) and +/// (conj(sample.y), conj(sample.value)). Relies on the fact that every polynomial F over the base +/// field holds: `F(p*) == F(p)*` (`*` being the complex conjugate). 
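As a quick illustration of the fact the rewritten doc above relies on, here is a toy check over `f64`-based complex numbers that one valid choice of coefficients, `a = conj(v) - v`, `c = conj(y) - y`, `b = v*c - a*y`, puts both `(y, v)` and `(conj(y), conj(v))` on the line `a*x + b - c*u = 0`. The concrete coefficient choice and the types are illustrative only; the crate itself works over secure-field extensions, not `f64`:

```rust
// Stand-in complex numbers (re, im) just to sanity-check the line coefficients.
#[derive(Clone, Copy)]
struct C { re: f64, im: f64 }
impl C {
    fn conj(self) -> C { C { re: self.re, im: -self.im } }
    fn mul(self, o: C) -> C {
        C { re: self.re * o.re - self.im * o.im, im: self.re * o.im + self.im * o.re }
    }
    fn add(self, o: C) -> C { C { re: self.re + o.re, im: self.im + o.im } }
    fn sub(self, o: C) -> C { C { re: self.re - o.re, im: self.im - o.im } }
}

fn main() {
    // A sample (y, value) pair and its complex conjugate pair.
    let y = C { re: 2.0, im: 3.0 };
    let v = C { re: -1.0, im: 5.0 };

    // Coefficients of `a*x + b - c*u = 0` through (y, v) and (conj(y), conj(v)).
    let a = v.conj().sub(v);
    let c = y.conj().sub(y);
    let b = v.mul(c).sub(a.mul(y));

    // Both points lie on the line.
    let on_line = |x: C, u: C| a.mul(x).add(b).sub(c.mul(u));
    for p in [(y, v), (y.conj(), v.conj())] {
        let r = on_line(p.0, p.1);
        assert!(r.re.abs() < 1e-9 && r.im.abs() < 1e-9);
    }
}
```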
pub fn complex_conjugate_line_coeffs( sample: &PointSample, alpha: SecureField, diff --git a/crates/prover/src/core/fri.rs b/crates/prover/src/core/fri.rs index 607dfaf41..d3684c1a8 100644 --- a/crates/prover/src/core/fri.rs +++ b/crates/prover/src/core/fri.rs @@ -98,8 +98,7 @@ pub trait FriOps: FieldOps + PolyOps + Sized + FieldOps /// Let `src` be the evaluation of a circle polynomial `f` on a /// [`CircleDomain`] `E`. This function computes evaluations of `f' = f0 /// + alpha * f1` on the x-coordinates of `E` such that `2f(p) = f0(px) + py * f1(px)`. The - /// evaluations of `f'` are accumulated into `dst` by the formula `dst = dst * alpha^2 + - /// f'`. + /// evaluations of `f'` are accumulated into `dst` by the formula `dst = dst * alpha^2 + f'`. /// /// # Panics /// @@ -979,7 +978,7 @@ fn compute_decommitment_positions_and_witness_evals( let mut witness_evals = Vec::new(); // Group queries by the folding coset they reside in. - for subset_queries in query_positions.group_by(|a, b| a >> fold_step == b >> fold_step) { + for subset_queries in query_positions.chunk_by(|a, b| a >> fold_step == b >> fold_step) { let subset_start = (subset_queries[0] >> fold_step) << fold_step; let subset_decommitment_positions = subset_start..subset_start + (1 << fold_step); let mut subset_queries_iter = subset_queries.iter().peekable(); @@ -1020,7 +1019,7 @@ fn compute_decommitment_positions_and_rebuild_evals( let mut subset_domain_index_initials = Vec::new(); // Group queries by the subset they reside in. - for subset_queries in queries.group_by(|a, b| a >> fold_step == b >> fold_step) { + for subset_queries in queries.chunk_by(|a, b| a >> fold_step == b >> fold_step) { let subset_start = (subset_queries[0] >> fold_step) << fold_step; let subset_decommitment_positions = subset_start..subset_start + (1 << fold_step); decommitment_positions.extend(subset_decommitment_positions.clone()); diff --git a/crates/prover/src/core/lookups/gkr_prover.rs b/crates/prover/src/core/lookups/gkr_prover.rs index 6e6ed2586..c2d6df1bd 100644 --- a/crates/prover/src/core/lookups/gkr_prover.rs +++ b/crates/prover/src/core/lookups/gkr_prover.rs @@ -299,7 +299,7 @@ pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> { pub lambda: SecureField, } -impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> { +impl MultivariatePolyOracle for GkrMultivariatePolyOracle<'_, B> { fn n_variables(&self) -> usize { self.input_layer.n_variables() - 1 } @@ -470,7 +470,7 @@ pub fn prove_batch( // Seed the channel with the layer masks. for (&instance, mask) in zip(&sumcheck_instances, &masks) { - channel.mix_felts(mask.columns().flatten()); + channel.mix_felts(mask.columns().as_flattened()); layer_masks_by_instance[instance].push(mask.clone()); } diff --git a/crates/prover/src/core/lookups/gkr_verifier.rs b/crates/prover/src/core/lookups/gkr_verifier.rs index b65ceb162..f7ffefc9d 100644 --- a/crates/prover/src/core/lookups/gkr_verifier.rs +++ b/crates/prover/src/core/lookups/gkr_verifier.rs @@ -120,7 +120,7 @@ pub fn partially_verify_batch( for &instance in &sumcheck_instances { let n_unused = n_layers - instance_n_layers(instance); let mask = &layer_masks_by_instance[instance][layer - n_unused]; - channel.mix_felts(mask.columns().flatten()); + channel.mix_felts(mask.columns().as_flattened()); } // Set the OOD evaluation point for layer above. 
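`slice::group_by` was renamed to `chunk_by` when it stabilized, which is all the `fri.rs` change above is tracking. A self-contained sketch of the grouping it performs there (the positions and `fold_step` below are made-up values):

```rust
/// Groups sorted query positions into the folding cosets they belong to:
/// positions sharing the high bits `p >> fold_step` end up in the same chunk.
fn group_queries_by_coset(query_positions: &[usize], fold_step: usize) -> Vec<Vec<usize>> {
    query_positions
        .chunk_by(|a, b| a >> fold_step == b >> fold_step)
        .map(|chunk| chunk.to_vec())
        .collect()
}

fn main() {
    // With fold_step = 1, positions 4 and 5 share a coset, as do 8 and 9.
    let grouped = group_queries_by_coset(&[4, 5, 8, 9, 12], 1);
    assert_eq!(grouped, vec![vec![4, 5], vec![8, 9], vec![12]]);
}
```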
@@ -223,7 +223,7 @@ pub struct GkrMask { } impl GkrMask { - pub fn new(columns: Vec<[SecureField; 2]>) -> Self { + pub const fn new(columns: Vec<[SecureField; 2]>) -> Self { Self { columns } } diff --git a/crates/prover/src/core/pcs/mod.rs b/crates/prover/src/core/pcs/mod.rs index d9acf524b..1a551d1eb 100644 --- a/crates/prover/src/core/pcs/mod.rs +++ b/crates/prover/src/core/pcs/mod.rs @@ -1,4 +1,5 @@ //! Implements a FRI polynomial commitment scheme. +//! //! This is a protocol where the prover can commit on a set of polynomials and then prove their //! opening on a set of points. //! Note: This implementation is not really a polynomial commitment scheme, because we are not in diff --git a/crates/prover/src/core/pcs/prover.rs b/crates/prover/src/core/pcs/prover.rs index 59a2e8e8d..ae017c0e5 100644 --- a/crates/prover/src/core/pcs/prover.rs +++ b/crates/prover/src/core/pcs/prover.rs @@ -162,7 +162,7 @@ pub struct TreeBuilder<'a, 'b, B: BackendForChannel, MC: MerkleChannel> { commitment_scheme: &'a mut CommitmentSchemeProver<'b, B, MC>, polys: ColumnVec>, } -impl<'a, 'b, B: BackendForChannel, MC: MerkleChannel> TreeBuilder<'a, 'b, B, MC> { +impl, MC: MerkleChannel> TreeBuilder<'_, '_, B, MC> { pub fn extend_evals( &mut self, columns: impl IntoIterator>, diff --git a/crates/prover/src/core/pcs/utils.rs b/crates/prover/src/core/pcs/utils.rs index 73c624f81..1a5ce7ccb 100644 --- a/crates/prover/src/core/pcs/utils.rs +++ b/crates/prover/src/core/pcs/utils.rs @@ -12,7 +12,7 @@ use crate::core::ColumnVec; pub struct TreeVec(pub Vec); impl TreeVec { - pub fn new(vec: Vec) -> TreeVec { + pub const fn new(vec: Vec) -> TreeVec { TreeVec(vec) } pub fn map U>(self, f: F) -> TreeVec { diff --git a/crates/prover/src/core/pcs/verifier.rs b/crates/prover/src/core/pcs/verifier.rs index db7bea3d3..d6ecd334e 100644 --- a/crates/prover/src/core/pcs/verifier.rs +++ b/crates/prover/src/core/pcs/verifier.rs @@ -96,7 +96,7 @@ impl CommitmentSchemeVerifier { }) .0 .into_iter() - .collect::>()?; + .collect::>()?; // Answer FRI queries. let samples = sampled_points.zip_cols(proof.sampled_values).map_cols( diff --git a/crates/prover/src/core/poly/circle/canonic.rs b/crates/prover/src/core/poly/circle/canonic.rs index cda0fcc8c..1a559fd3c 100644 --- a/crates/prover/src/core/poly/circle/canonic.rs +++ b/crates/prover/src/core/poly/circle/canonic.rs @@ -2,12 +2,14 @@ use super::CircleDomain; use crate::core::circle::{CirclePoint, CirclePointIndex, Coset}; use crate::core::fields::m31::BaseField; -/// A coset of the form G_{2n} + , where G_n is the generator of the -/// subgroup of order n. The ordering on this coset is G_2n + i * G_n. -/// These cosets can be used as a [CircleDomain], and be interpolated on. -/// Note that this changes the ordering on the coset to be like [CircleDomain], -/// which is G_2n + i * G_n/2 and then -G_2n -i * G_n/2. -/// For example, the Xs below are a canonic coset with n=8. +/// A coset of the form `G_{2n} + `, where `G_n` is the generator of the subgroup of order `n`. +/// +/// The ordering on this coset is `G_2n + i * G_n`. +/// These cosets can be used as a [`CircleDomain`], and be interpolated on. +/// Note that this changes the ordering on the coset to be like [`CircleDomain`], +/// which is `G_{2n} + i * G_{n/2}` and then `-G_{2n} -i * G_{n/2}`. +/// For example, the `X`s below are a canonic coset with `n=8`. 
+/// /// ```text /// X O X /// O O diff --git a/crates/prover/src/core/poly/circle/domain.rs b/crates/prover/src/core/poly/circle/domain.rs index 2bffac773..83765e6d8 100644 --- a/crates/prover/src/core/poly/circle/domain.rs +++ b/crates/prover/src/core/poly/circle/domain.rs @@ -10,8 +10,9 @@ use crate::core::fields::m31::BaseField; pub const MAX_CIRCLE_DOMAIN_LOG_SIZE: u32 = M31_CIRCLE_LOG_ORDER - 1; /// A valid domain for circle polynomial interpolation and evaluation. -/// Valid domains are a disjoint union of two conjugate cosets: +-C + . -/// The ordering defined on this domain is C + iG_n, and then -C - iG_n. +/// +/// Valid domains are a disjoint union of two conjugate cosets: `+-C + `. +/// The ordering defined on this domain is `C + iG_n`, and then `-C - iG_n`. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct CircleDomain { pub half_coset: Coset, diff --git a/crates/prover/src/core/poly/circle/evaluation.rs b/crates/prover/src/core/poly/circle/evaluation.rs index faa2b7284..6094df399 100644 --- a/crates/prover/src/core/poly/circle/evaluation.rs +++ b/crates/prover/src/core/poly/circle/evaluation.rs @@ -146,7 +146,7 @@ impl<'a, F: ExtensionOf> CosetSubEvaluation<'a, F> { } } -impl<'a, F: ExtensionOf> Index for CosetSubEvaluation<'a, F> { +impl> Index for CosetSubEvaluation<'_, F> { type Output = F; fn index(&self, index: isize) -> &Self::Output { @@ -156,7 +156,7 @@ impl<'a, F: ExtensionOf> Index for CosetSubEvaluation<'a, F> { } } -impl<'a, F: ExtensionOf> Index for CosetSubEvaluation<'a, F> { +impl> Index for CosetSubEvaluation<'_, F> { type Output = F; fn index(&self, index: usize) -> &Self::Output { diff --git a/crates/prover/src/core/poly/twiddles.rs b/crates/prover/src/core/poly/twiddles.rs index f3b186376..2e172a2cf 100644 --- a/crates/prover/src/core/poly/twiddles.rs +++ b/crates/prover/src/core/poly/twiddles.rs @@ -2,6 +2,7 @@ use super::circle::PolyOps; use crate::core::circle::Coset; /// Precomputed twiddles for a specific coset tower. +/// /// A coset tower is every repeated doubling of a `root_coset`. /// The largest CircleDomain that can be ffted using these twiddles is one with `root_coset` as /// its `half_coset`. 
diff --git a/crates/prover/src/core/utils.rs b/crates/prover/src/core/utils.rs index df7c77e4c..e19e68f5e 100644 --- a/crates/prover/src/core/utils.rs +++ b/crates/prover/src/core/utils.rs @@ -20,7 +20,7 @@ pub struct PeekTakeWhile<'a, I: Iterator, P: FnMut(&I::Item) -> bool> { iter: &'a mut Peekable, predicate: P, } -impl<'a, I: Iterator, P: FnMut(&I::Item) -> bool> Iterator for PeekTakeWhile<'a, I, P> { +impl bool> Iterator for PeekTakeWhile<'_, I, P> { type Item = I::Item; fn next(&mut self) -> Option { diff --git a/crates/prover/src/core/vcs/blake2_merkle.rs b/crates/prover/src/core/vcs/blake2_merkle.rs index 3664ea147..15e043723 100644 --- a/crates/prover/src/core/vcs/blake2_merkle.rs +++ b/crates/prover/src/core/vcs/blake2_merkle.rs @@ -20,7 +20,7 @@ impl MerkleHasher for Blake2sMerkleHasher { if let Some((left, right)) = children_hashes { state = compress( state, - unsafe { std::mem::transmute([left, right]) }, + unsafe { std::mem::transmute::<[Blake2sHash; 2], [u32; 16]>([left, right]) }, 0, 0, 0, @@ -33,9 +33,16 @@ impl MerkleHasher for Blake2sMerkleHasher { .copied() .chain(std::iter::repeat(BaseField::zero()).take(rem)); for chunk in padded_values.array_chunks::<16>() { - state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0); + state = compress( + state, + unsafe { std::mem::transmute::<[BaseField; 16], [u32; 16]>(chunk) }, + 0, + 0, + 0, + 0, + ); } - state.map(|x| x.to_le_bytes()).flatten().into() + state.map(|x| x.to_le_bytes()).as_flattened().into() } } diff --git a/crates/prover/src/core/vcs/blake2s_ref.rs b/crates/prover/src/core/vcs/blake2s_ref.rs index 95665597c..3776a8830 100644 --- a/crates/prover/src/core/vcs/blake2s_ref.rs +++ b/crates/prover/src/core/vcs/blake2s_ref.rs @@ -30,22 +30,22 @@ const fn xor(a: u32, b: u32) -> u32 { #[inline(always)] const fn rot16(x: u32) -> u32 { - (x >> 16) | (x << (32 - 16)) + x.rotate_right(16) } #[inline(always)] const fn rot12(x: u32) -> u32 { - (x >> 12) | (x << (32 - 12)) + x.rotate_right(12) } #[inline(always)] const fn rot8(x: u32) -> u32 { - (x >> 8) | (x << (32 - 8)) + x.rotate_right(8) } #[inline(always)] const fn rot7(x: u32) -> u32 { - (x >> 7) | (x << (32 - 7)) + x.rotate_right(7) } #[inline(always)] diff --git a/crates/prover/src/core/vcs/ops.rs b/crates/prover/src/core/vcs/ops.rs index 14093e536..b40a91bef 100644 --- a/crates/prover/src/core/vcs/ops.rs +++ b/crates/prover/src/core/vcs/ops.rs @@ -6,13 +6,12 @@ use crate::core::backend::{Col, ColumnOps}; use crate::core::fields::m31::BaseField; use crate::core::vcs::hash::Hash; -/// A Merkle node hash is a hash of: -/// [left_child_hash, right_child_hash], column0_value, column1_value, ... -/// "[]" denotes optional values. +/// A Merkle node hash is a hash of: `[left_child_hash, right_child_hash], column0_value, +/// column1_value, ...` where `[]` denotes optional values. +/// /// The largest Merkle layer has no left and right child hashes. The rest of the layers have -/// children hashes. -/// At each layer, the tree may have multiple columns of the same length as the layer. -/// Each node in that layer contains one value from each column. +/// children hashes. At each layer, the tree may have multiple columns of the same length as the +/// layer. Each node in that layer contains one value from each column. pub trait MerkleHasher: Debug + Default + Clone { type Hash: Hash; /// Hashes a single Merkle node. See [MerkleHasher] for more details. 
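The `blake2s_ref.rs` hunk above replaces the hand-rolled shift-or rotations with `u32::rotate_right`, which std already provides as a `const fn`. A tiny equivalence check (the test values are arbitrary):

```rust
/// The rotation previously written by hand in `blake2s_ref.rs`.
const fn rot7_manual(x: u32) -> u32 {
    (x >> 7) | (x << (32 - 7))
}

fn main() {
    for &x in &[0u32, 1, 0xdead_beef, u32::MAX] {
        // `rotate_right` produces the same bit pattern for a fixed nonzero shift.
        assert_eq!(rot7_manual(x), x.rotate_right(7));
    }
}
```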
diff --git a/crates/prover/src/core/vcs/prover.rs b/crates/prover/src/core/vcs/prover.rs index da4695d3f..2d0983837 100644 --- a/crates/prover/src/core/vcs/prover.rs +++ b/crates/prover/src/core/vcs/prover.rs @@ -68,8 +68,7 @@ impl, H: MerkleHasher> MerkleProver { /// /// # Arguments /// - /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that - /// log_size. + /// * `queries_per_log_size` - Maps a log_size to a vector of queries for columns of that size. /// * `columns` - A vector of references to columns. /// /// # Returns @@ -171,7 +170,7 @@ pub struct MerkleDecommitment { pub column_witness: Vec, } impl MerkleDecommitment { - fn empty() -> Self { + const fn empty() -> Self { Self { hash_witness: Vec::new(), column_witness: Vec::new(), diff --git a/crates/prover/src/core/vcs/verifier.rs b/crates/prover/src/core/vcs/verifier.rs index fcd0453a3..bb4969f5d 100644 --- a/crates/prover/src/core/vcs/verifier.rs +++ b/crates/prover/src/core/vcs/verifier.rs @@ -34,7 +34,7 @@ impl MerkleVerifier { /// # Arguments /// /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that - /// log_size. + /// log_size. /// * `queried_values` - A vector of queried values according to the order in /// [`MerkleProver::decommit()`]. /// * `decommitment` - The decommitment object containing the witness and column values. @@ -50,7 +50,6 @@ impl MerkleVerifier { /// * The computed root does not match the expected root. /// /// [`MerkleProver::decommit()`]: crate::core::...::MerkleProver::decommit - pub fn verify( &self, queries_per_log_size: &BTreeMap>, diff --git a/crates/prover/src/examples/blake/round/constraints.rs b/crates/prover/src/examples/blake/round/constraints.rs index e15a225df..f291f44e1 100644 --- a/crates/prover/src/examples/blake/round/constraints.rs +++ b/crates/prover/src/examples/blake/round/constraints.rs @@ -14,10 +14,11 @@ pub struct BlakeRoundEval<'a, E: EvalAtRow> { pub eval: E, pub xor_lookup_elements: &'a BlakeXorElements, pub round_lookup_elements: &'a RoundElements, - pub total_sum: SecureField, - pub log_size: u32, + // TODO(first): validate logup. 
+ pub _total_sum: SecureField, + pub _log_size: u32, } -impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { +impl BlakeRoundEval<'_, E> { pub fn eval(mut self) -> E { let mut v: [Fu32; STATE_SIZE] = std::array::from_fn(|_| self.next_u32()); let input_v = v.clone(); diff --git a/crates/prover/src/examples/blake/round/gen.rs b/crates/prover/src/examples/blake/round/gen.rs index 3b9d0c853..6f4f11a9b 100644 --- a/crates/prover/src/examples/blake/round/gen.rs +++ b/crates/prover/src/examples/blake/round/gen.rs @@ -68,7 +68,7 @@ struct TraceGeneratorRow<'a> { vec_row: usize, xor_lookups_index: usize, } -impl<'a> TraceGeneratorRow<'a> { +impl TraceGeneratorRow<'_> { fn append_felt(&mut self, val: u32x16) { self.gen.trace[self.col_index].data[self.vec_row] = unsafe { PackedBaseField::from_simd_unchecked(val) }; diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs index 8fa238b26..4926fe218 100644 --- a/crates/prover/src/examples/blake/round/mod.rs +++ b/crates/prover/src/examples/blake/round/mod.rs @@ -33,8 +33,8 @@ impl FrameworkEval for BlakeRoundEval { eval, xor_lookup_elements: &self.xor_lookup_elements, round_lookup_elements: &self.round_lookup_elements, - total_sum: self.total_sum, - log_size: self.log_size, + _total_sum: self.total_sum, + _log_size: self.log_size, }; blake_eval.eval() } diff --git a/crates/prover/src/examples/blake/scheduler/mod.rs b/crates/prover/src/examples/blake/scheduler/mod.rs index b69318ce4..c998ed61b 100644 --- a/crates/prover/src/examples/blake/scheduler/mod.rs +++ b/crates/prover/src/examples/blake/scheduler/mod.rs @@ -16,10 +16,12 @@ pub type BlakeSchedulerComponent = FrameworkComponent; relation!(BlakeElements, N_ROUND_INPUT_FELTS); +#[allow(dead_code)] pub struct BlakeSchedulerEval { pub log_size: u32, pub blake_lookup_elements: BlakeElements, pub round_lookup_elements: RoundElements, + // TODO(first): validate logup. 
pub total_sum: SecureField, } impl FrameworkEval for BlakeSchedulerEval { diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 808dcc74d..481d30fd6 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -454,7 +454,7 @@ mod tests { for i in 0..16 { internal_matrix[i][i] += BaseField::from_u32_unchecked(1 << (i + 1)); } - let matrix = RowMajorMatrix::::new(internal_matrix.flatten().to_vec()); + let matrix = RowMajorMatrix::::new(internal_matrix.as_flattened().to_vec()); let expected_state = matrix.mul(state); apply_internal_round_matrix(&mut state); diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 79efe5998..3c0c17d0c 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -115,9 +115,7 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> MleEvalProverComponent<'twiddl } } -impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> Component - for MleEvalProverComponent<'twiddles, 'oracle, O> -{ +impl Component for MleEvalProverComponent<'_, '_, O> { fn n_constraints(&self) -> usize { self.eval_info().n_constraints } @@ -191,9 +189,7 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> Component } } -impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> ComponentProver - for MleEvalProverComponent<'twiddles, 'oracle, O> -{ +impl ComponentProver for MleEvalProverComponent<'_, '_, O> { fn evaluate_constraint_quotients_on_domain( &self, trace: &Trace<'_, SimdBackend>, @@ -330,7 +326,7 @@ impl<'oracle, O: MleCoeffColumnOracle> MleEvalVerifierComponent<'oracle, O> { } } -impl<'oracle, O: MleCoeffColumnOracle> Component for MleEvalVerifierComponent<'oracle, O> { +impl Component for MleEvalVerifierComponent<'_, O> { fn n_constraints(&self) -> usize { self.eval_info().n_constraints } diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index 34a5c6701..43f932f50 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -1,23 +1,20 @@ #![allow(incomplete_features)] +#![cfg_attr( + all(target_arch = "x86_64", target_feature = "avx512f"), + feature(stdarch_x86_avx512) +)] #![feature( array_chunks, - array_methods, array_try_from_fn, array_windows, assert_matches, exact_size_is_empty, - generic_const_exprs, get_many_mut, int_roundings, - is_sorted, iter_array_chunks, - new_uninit, portable_simd, - slice_first_last_chunk, - slice_flatten, - slice_group_by, slice_ptr_get, - stdsimd + trait_upcasting )] pub mod constraint_framework; pub mod core; diff --git a/rust-toolchain.toml b/rust-toolchain.toml index a0f1a930e..85b121b33 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-01-04" +channel = "nightly-2024-11-06" diff --git a/scripts/clippy.sh b/scripts/clippy.sh index 8361cd25d..43f11957e 100755 --- a/scripts/clippy.sh +++ b/scripts/clippy.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly-2024-01-04 clippy "$@" --all-targets --all-features -- -D warnings -D future-incompatible \ +cargo +nightly-2024-11-06 clippy "$@" --all-targets --all-features -- -D warnings -D future-incompatible \ -D nonstandard-style -D rust-2018-idioms -D unused diff --git a/scripts/rust_fmt.sh b/scripts/rust_fmt.sh index e4223f999..9f95485b0 100755 --- a/scripts/rust_fmt.sh +++ b/scripts/rust_fmt.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly-2024-01-04 fmt --all -- "$@" +cargo +nightly-2024-11-06 
fmt --all -- "$@" diff --git a/scripts/test_avx.sh b/scripts/test_avx.sh index d911a2479..f7755f61d 100755 --- a/scripts/test_avx.sh +++ b/scripts/test_avx.sh @@ -1,4 +1,4 @@ #!/bin/bash # Can be used as a drop in replacement for `cargo test` with avx512f flag on. # For example, `./scripts/test_avx.sh` will run all tests(not only avx). -RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo +nightly-2024-01-04 test "$@" +RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo +nightly-2024-11-06 test "$@" From f24cde6f8de3f1db75d04f87807f052aa889b077 Mon Sep 17 00:00:00 2001 From: shaharsamocha7 <70577611+shaharsamocha7@users.noreply.github.com> Date: Thu, 19 Dec 2024 08:21:10 +0200 Subject: [PATCH 35/69] Update toolchain version (#940) --- .github/workflows/benchmarks-pages.yaml | 2 +- .github/workflows/ci.yaml | 50 ++++++++++++------------- .github/workflows/coverage.yaml | 4 +- rust-toolchain.toml | 2 +- scripts/clippy.sh | 2 +- scripts/rust_fmt.sh | 2 +- scripts/test_avx.sh | 2 +- 7 files changed, 32 insertions(+), 32 deletions(-) diff --git a/.github/workflows/benchmarks-pages.yaml b/.github/workflows/benchmarks-pages.yaml index 5e2f877b3..ef8269911 100644 --- a/.github/workflows/benchmarks-pages.yaml +++ b/.github/workflows/benchmarks-pages.yaml @@ -18,7 +18,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-11-06 + toolchain: nightly-2024-12-17 - name: Run benchmark run: ./scripts/bench.sh -- --output-format bencher | tee output.txt - name: Download previous benchmark data diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8a00b5467..a37679771 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,7 +25,7 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: rustfmt - toolchain: nightly-2024-11-06 + toolchain: nightly-2024-12-17 - uses: Swatinem/rust-cache@v2 - run: scripts/rust_fmt.sh --check @@ -36,7 +36,7 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: clippy - toolchain: nightly-2024-11-06 + toolchain: nightly-2024-12-17 - uses: Swatinem/rust-cache@v2 - run: scripts/clippy.sh @@ -46,25 +46,25 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-11-06 + toolchain: nightly-2024-12-17 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-11-06 doc + - run: cargo +nightly-2024-12-17 doc - run-wasm32-wasi-tests: + run-wasm32-wasip1-tests: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-11-06 - targets: wasm32-wasi + toolchain: nightly-2024-12-17 + targets: wasm32-wasip1 - uses: taiki-e/install-action@v2 with: tool: wasmtime - uses: Swatinem/rust-cache@v2 - - run: cargo test --target wasm32-wasi + - run: cargo test --target wasm32-wasip1 env: - CARGO_TARGET_WASM32_WASI_RUNNER: "wasmtime run --" + CARGO_TARGET_WASM32_WASIP1_RUNNER: "wasmtime run --" RUSTFLAGS: -C target-feature=+simd128 run-wasm32-unknown-tests: @@ -73,7 +73,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-11-06 + toolchain: nightly-2024-12-17 targets: wasm32-unknown-unknown - uses: Swatinem/rust-cache@v2 - uses: jetli/wasm-pack-action@v0.4.0 @@ -89,9 +89,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-11-06 + toolchain: nightly-2024-12-17 - uses: 
Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-11-06 test + - run: cargo +nightly-2024-12-17 test env: RUSTFLAGS: -C target-feature=+neon @@ -104,9 +104,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-11-06 + toolchain: nightly-2024-12-17 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-11-06 test + - run: cargo +nightly-2024-12-17 test env: RUSTFLAGS: -C target-cpu=native -C target-feature=+${{ matrix.target-feature }} @@ -116,7 +116,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-11-06 + toolchain: nightly-2024-12-17 - name: Run benchmark run: ./scripts/bench.sh -- --output-format bencher | tee output.txt - name: Download previous benchmark data @@ -142,7 +142,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-11-06 + toolchain: nightly-2024-12-17 - name: Run benchmark run: ./scripts/bench.sh --features="parallel" -- --output-format bencher | tee output.txt - name: Download previous benchmark data @@ -168,9 +168,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-11-06 + toolchain: nightly-2024-12-17 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-11-06 test + - run: cargo +nightly-2024-12-17 test run-slow-tests: runs-on: ubuntu-latest @@ -178,9 +178,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-11-06 + toolchain: nightly-2024-12-17 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-11-06 test --release --features="slow-tests" + - run: cargo +nightly-2024-12-17 test --release --features="slow-tests" run-tests-parallel: runs-on: ubuntu-latest @@ -188,9 +188,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-11-06 + toolchain: nightly-2024-12-17 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-11-06 test --features="parallel" + - run: cargo +nightly-2024-12-17 test --features="parallel" machete: runs-on: ubuntu-latest @@ -201,9 +201,9 @@ jobs: toolchain: nightly-2024-01-04 - uses: Swatinem/rust-cache@v2 - name: Install Machete - run: cargo +nightly-2024-11-06 install --locked cargo-machete + run: cargo +nightly-2024-12-17 install --locked cargo-machete - name: Run Machete (detect unused dependencies) - run: cargo +nightly-2024-11-06 machete + run: cargo +nightly-2024-12-17 machete all-tests: runs-on: ubuntu-latest @@ -213,7 +213,7 @@ jobs: - run-tests - run-avx-tests - run-neon-tests - - run-wasm32-wasi-tests + - run-wasm32-wasip1-tests - run-slow-tests - run-tests-parallel - machete diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index 508e0f11b..34b92be13 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -12,14 +12,14 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: rustfmt - toolchain: nightly-2024-11-06 + toolchain: nightly-2024-12-17 - uses: Swatinem/rust-cache@v2 - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov # TODO: Merge coverage reports for tests on different architectures. 
# - name: Generate code coverage - run: cargo +nightly-2024-11-06 llvm-cov --codecov --output-path codecov.json + run: cargo +nightly-2024-12-17 llvm-cov --codecov --output-path codecov.json env: RUSTFLAGS: "-C target-feature=+avx512f" - name: Upload coverage to Codecov diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 85b121b33..690b698f9 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-11-06" +channel = "nightly-2024-12-17" diff --git a/scripts/clippy.sh b/scripts/clippy.sh index 43f11957e..a3f74f4b8 100755 --- a/scripts/clippy.sh +++ b/scripts/clippy.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly-2024-11-06 clippy "$@" --all-targets --all-features -- -D warnings -D future-incompatible \ +cargo +nightly-2024-12-17 clippy "$@" --all-targets --all-features -- -D warnings -D future-incompatible \ -D nonstandard-style -D rust-2018-idioms -D unused diff --git a/scripts/rust_fmt.sh b/scripts/rust_fmt.sh index 9f95485b0..ae4a9f7f8 100755 --- a/scripts/rust_fmt.sh +++ b/scripts/rust_fmt.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly-2024-11-06 fmt --all -- "$@" +cargo +nightly-2024-12-17 fmt --all -- "$@" diff --git a/scripts/test_avx.sh b/scripts/test_avx.sh index f7755f61d..cb0ac2445 100755 --- a/scripts/test_avx.sh +++ b/scripts/test_avx.sh @@ -1,4 +1,4 @@ #!/bin/bash # Can be used as a drop in replacement for `cargo test` with avx512f flag on. # For example, `./scripts/test_avx.sh` will run all tests(not only avx). -RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo +nightly-2024-11-06 test "$@" +RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo +nightly-2024-12-17 test "$@" From 5c53025ba2e08044872782966c46a63a702bbdcc Mon Sep 17 00:00:00 2001 From: Gali Michlevich Date: Tue, 24 Dec 2024 17:03:20 +0200 Subject: [PATCH 36/69] Add from_simd_vec() to BaseColumn --- crates/prover/src/core/backend/simd/column.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/crates/prover/src/core/backend/simd/column.rs b/crates/prover/src/core/backend/simd/column.rs index 29a6e58c5..819f79cbb 100644 --- a/crates/prover/src/core/backend/simd/column.rs +++ b/crates/prover/src/core/backend/simd/column.rs @@ -64,6 +64,13 @@ impl BaseColumn { values.into_iter().collect() } + pub fn from_simd(values: Vec) -> Self { + Self { + length: values.len() * N_LANES, + data: values, + } + } + /// Returns a vector of `BaseColumnMutSlice`s, each mutably owning /// `chunk_size` `PackedBaseField`s (i.e, `chuck_size` * `N_LANES` elements). 
pub fn chunks_mut(&mut self, chunk_size: usize) -> Vec> { From 571e25600a3bbdcc4ccf72c1ad4e0b78e09f0974 Mon Sep 17 00:00:00 2001 From: Ohad Agadi Date: Mon, 23 Dec 2024 11:48:07 +0200 Subject: [PATCH 37/69] iterable trace --- Cargo.lock | 9 ++ Cargo.toml | 2 +- crates/air_utils/Cargo.toml | 12 ++ crates/air_utils/src/lib.rs | 2 + crates/air_utils/src/trace/component_trace.rs | 127 ++++++++++++++++++ crates/air_utils/src/trace/mod.rs | 1 + 6 files changed, 152 insertions(+), 1 deletion(-) create mode 100644 crates/air_utils/Cargo.toml create mode 100644 crates/air_utils/src/lib.rs create mode 100644 crates/air_utils/src/trace/component_trace.rs create mode 100644 crates/air_utils/src/trace/mod.rs diff --git a/Cargo.lock b/Cargo.lock index a7e009352..54d4d4ff7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1035,6 +1035,15 @@ dependencies = [ "serde", ] +[[package]] +name = "stwo-air-utils" +version = "0.1.1" +dependencies = [ + "bytemuck", + "itertools 0.12.1", + "stwo-prover", +] + [[package]] name = "stwo-prover" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index 0f314a496..fadd620de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["crates/prover"] +members = ["crates/prover", "crates/air_utils"] resolver = "2" [workspace.package] diff --git a/crates/air_utils/Cargo.toml b/crates/air_utils/Cargo.toml new file mode 100644 index 000000000..c021b76cb --- /dev/null +++ b/crates/air_utils/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "stwo-air-utils" +version.workspace = true +edition.workspace = true + +[dependencies] +bytemuck.workspace = true +itertools.workspace = true +stwo-prover = { path = "../prover" } + +[lib] +bench = false diff --git a/crates/air_utils/src/lib.rs b/crates/air_utils/src/lib.rs new file mode 100644 index 000000000..8603c2cee --- /dev/null +++ b/crates/air_utils/src/lib.rs @@ -0,0 +1,2 @@ +#![feature(exact_size_is_empty, raw_slice_split, portable_simd)] +pub mod trace; diff --git a/crates/air_utils/src/trace/component_trace.rs b/crates/air_utils/src/trace/component_trace.rs new file mode 100644 index 000000000..1fe6928f0 --- /dev/null +++ b/crates/air_utils/src/trace/component_trace.rs @@ -0,0 +1,127 @@ +use std::marker::PhantomData; + +use bytemuck::Zeroable; +use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::poly::circle::CircleEvaluation; +use stwo_prover::core::poly::BitReversedOrder; + +/// A 2D Matrix of [`PackedM31`] values. +/// Used for generating the witness of 'Stwo' proofs. +/// Stored as an array of `N` columns, each column is a vector of [`PackedM31`] values. +/// Exposes an iterator over mutable references to the rows of the matrix. 
+/// +/// # Example: +/// +/// ```text +/// Computation trace of a^2 + (a + 1)^2 for a in 0..256 +/// ``` +/// ``` +/// use stwo_air_utils::trace::component_trace::ComponentTrace; +/// use itertools::Itertools; +/// use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES}; +/// use stwo_prover::core::fields::m31::M31; +/// use stwo_prover::core::fields::FieldExpOps; +/// +/// const N_COLUMNS: usize = 3; +/// const LOG_SIZE: u32 = 8; +/// let mut trace = ComponentTrace::::zeroed(LOG_SIZE); +/// let example_input = (0..1 << LOG_SIZE).map(M31::from).collect_vec(); // 0..256 +/// trace +/// .iter_mut() +/// .zip(example_input.chunks(N_LANES)) +/// .chunks(4) +/// .into_iter() +/// .for_each(|chunk| { +/// chunk.into_iter().for_each(|(row, input)| { +/// *row[0] = PackedM31::from_array(input.try_into().unwrap()); +/// *row[1] = *row[0] + PackedM31::broadcast(M31(1)); +/// *row[2] = row[0].square() + row[1].square(); +/// }) +/// }); +/// +/// let first_3_rows = (0..N_COLUMNS).map(|i| trace.row_at(i)).collect::>(); +/// assert_eq!(first_3_rows, [[0,1,1], [1,2,5], [2,3,13]].map(|row| row.map(M31::from))); +/// ``` +#[derive(Debug)] +pub struct ComponentTrace { + data: [Vec; N], + + /// Log number of non-packed rows in each column. + log_size: u32, +} + +impl ComponentTrace { + pub fn zeroed(log_size: u32) -> Self { + let n_simd_elems = 1 << (log_size - LOG_N_LANES); + let data = [(); N].map(|_| vec![PackedM31::zeroed(); n_simd_elems]); + Self { data, log_size } + } + + /// # Safety + /// The caller must ensure that the column is populated before being used. + #[allow(clippy::uninit_vec)] + pub unsafe fn uninitialized(_log_size: u32) -> Self { + todo!() + } + + pub fn log_size(&self) -> u32 { + self.log_size + } + + pub fn iter_mut(&mut self) -> RowIterMut<'_, N> { + RowIterMut::new(self.data.each_mut().map(|column| column.as_mut_slice())) + } + + pub fn to_evals(self) -> [CircleEvaluation; N] { + todo!() + } + + pub fn row_at(&self, row: usize) -> [M31; N] { + assert!(row < 1 << self.log_size); + let packed_row = row / N_LANES; + let idx_in_simd_vector = row % N_LANES; + self.data + .each_ref() + .map(|column| column[packed_row].to_array()[idx_in_simd_vector]) + } +} + +pub type MutRow<'trace, const N: usize> = [&'trace mut PackedM31; N]; + +/// An iterator over mutable references to the rows of a [`ComponentTrace`]. +pub struct RowIterMut<'trace, const N: usize> { + v: [*mut [PackedM31]; N], + phantom: PhantomData<&'trace ()>, +} +impl<'trace, const N: usize> RowIterMut<'trace, N> { + pub fn new(slice: [&'trace mut [PackedM31]; N]) -> Self { + Self { + v: slice.map(|s| s as *mut _), + phantom: PhantomData, + } + } +} +impl<'trace, const N: usize> Iterator for RowIterMut<'trace, N> { + type Item = MutRow<'trace, N>; + + fn next(&mut self) -> Option { + if self.v[0].is_empty() { + return None; + } + let item = std::array::from_fn(|i| unsafe { + // SAFETY: The self.v contract ensures that any split_at_mut is valid. 
+ let (head, tail) = self.v[i].split_at_mut(1); + self.v[i] = tail; + &mut (*head)[0] + }); + Some(item) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.v[0].len(); + (len, Some(len)) + } +} +impl ExactSizeIterator for RowIterMut<'_, N> {} diff --git a/crates/air_utils/src/trace/mod.rs b/crates/air_utils/src/trace/mod.rs new file mode 100644 index 000000000..03a022de5 --- /dev/null +++ b/crates/air_utils/src/trace/mod.rs @@ -0,0 +1 @@ +pub mod component_trace; From 95b66e90cd331b42554c77d00bfa3a063c1f0ddc Mon Sep 17 00:00:00 2001 From: Ohad Agadi Date: Mon, 23 Dec 2024 11:58:26 +0200 Subject: [PATCH 38/69] par trace --- Cargo.lock | 1 + crates/air_utils/Cargo.toml | 1 + crates/air_utils/src/trace/component_trace.rs | 86 +++++++----- crates/air_utils/src/trace/mod.rs | 1 + crates/air_utils/src/trace/row_iterator.rs | 126 ++++++++++++++++++ 5 files changed, 181 insertions(+), 34 deletions(-) create mode 100644 crates/air_utils/src/trace/row_iterator.rs diff --git a/Cargo.lock b/Cargo.lock index 54d4d4ff7..c87fc6628 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1041,6 +1041,7 @@ version = "0.1.1" dependencies = [ "bytemuck", "itertools 0.12.1", + "rayon", "stwo-prover", ] diff --git a/crates/air_utils/Cargo.toml b/crates/air_utils/Cargo.toml index c021b76cb..7d09a7eaf 100644 --- a/crates/air_utils/Cargo.toml +++ b/crates/air_utils/Cargo.toml @@ -6,6 +6,7 @@ edition.workspace = true [dependencies] bytemuck.workspace = true itertools.workspace = true +rayon = { version = "1.10.0", optional = false } stwo-prover = { path = "../prover" } [lib] diff --git a/crates/air_utils/src/trace/component_trace.rs b/crates/air_utils/src/trace/component_trace.rs index 1fe6928f0..20a96579d 100644 --- a/crates/air_utils/src/trace/component_trace.rs +++ b/crates/air_utils/src/trace/component_trace.rs @@ -1,5 +1,3 @@ -use std::marker::PhantomData; - use bytemuck::Zeroable; use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; use stwo_prover::core::backend::simd::SimdBackend; @@ -7,6 +5,8 @@ use stwo_prover::core::fields::m31::M31; use stwo_prover::core::poly::circle::CircleEvaluation; use stwo_prover::core::poly::BitReversedOrder; +use super::row_iterator::{ParRowIterMut, RowIterMut}; + /// A 2D Matrix of [`PackedM31`] values. /// Used for generating the witness of 'Stwo' proofs. /// Stored as an array of `N` columns, each column is a vector of [`PackedM31`] values. @@ -74,6 +74,10 @@ impl ComponentTrace { RowIterMut::new(self.data.each_mut().map(|column| column.as_mut_slice())) } + pub fn par_iter_mut(&mut self) -> ParRowIterMut<'_, N> { + ParRowIterMut::new(self.data.each_mut().map(|column| column.as_mut_slice())) + } + pub fn to_evals(self) -> [CircleEvaluation; N] { todo!() } @@ -88,40 +92,54 @@ impl ComponentTrace { } } -pub type MutRow<'trace, const N: usize> = [&'trace mut PackedM31; N]; +#[cfg(test)] +mod tests { + use itertools::Itertools; + use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES}; + use stwo_prover::core::fields::m31::M31; + use stwo_prover::core::fields::FieldExpOps; -/// An iterator over mutable references to the rows of a [`ComponentTrace`]. 
-pub struct RowIterMut<'trace, const N: usize> { - v: [*mut [PackedM31]; N], - phantom: PhantomData<&'trace ()>, -} -impl<'trace, const N: usize> RowIterMut<'trace, N> { - pub fn new(slice: [&'trace mut [PackedM31]; N]) -> Self { - Self { - v: slice.map(|s| s as *mut _), - phantom: PhantomData, - } - } -} -impl<'trace, const N: usize> Iterator for RowIterMut<'trace, N> { - type Item = MutRow<'trace, N>; + #[test] + fn test_parallel_trace() { + use rayon::iter::{IndexedParallelIterator, ParallelIterator}; + use rayon::slice::ParallelSlice; - fn next(&mut self) -> Option { - if self.v[0].is_empty() { - return None; - } - let item = std::array::from_fn(|i| unsafe { - // SAFETY: The self.v contract ensures that any split_at_mut is valid. - let (head, tail) = self.v[i].split_at_mut(1); - self.v[i] = tail; - &mut (*head)[0] - }); - Some(item) - } + const N_COLUMNS: usize = 3; + const LOG_SIZE: u32 = 8; + const CHUNK_SIZE: usize = 4; + let mut trace = super::ComponentTrace::::zeroed(LOG_SIZE); + let arr = (0..1 << LOG_SIZE).map(M31::from).collect_vec(); + let expected = arr + .iter() + .map(|&a| { + let b = a + M31::from(1); + let c = a.square() + b.square(); + (a, b, c) + }) + .multiunzip(); + + trace + .par_iter_mut() + .zip(arr.par_chunks(N_LANES)) + .chunks(CHUNK_SIZE) + .for_each(|chunk| { + chunk.into_iter().for_each(|(row, input)| { + *row[0] = PackedM31::from_array(input.try_into().unwrap()); + *row[1] = *row[0] + PackedM31::broadcast(M31(1)); + *row[2] = row[0].square() + row[1].square(); + }); + }); + let actual = trace + .data + .map(|c| { + c.into_iter() + .flat_map(|packed| packed.to_array()) + .collect_vec() + }) + .into_iter() + .next_tuple() + .unwrap(); - fn size_hint(&self) -> (usize, Option) { - let len = self.v[0].len(); - (len, Some(len)) + assert_eq!(expected, actual); } } -impl ExactSizeIterator for RowIterMut<'_, N> {} diff --git a/crates/air_utils/src/trace/mod.rs b/crates/air_utils/src/trace/mod.rs index 03a022de5..6e44c9033 100644 --- a/crates/air_utils/src/trace/mod.rs +++ b/crates/air_utils/src/trace/mod.rs @@ -1 +1,2 @@ pub mod component_trace; +mod row_iterator; diff --git a/crates/air_utils/src/trace/row_iterator.rs b/crates/air_utils/src/trace/row_iterator.rs new file mode 100644 index 000000000..78d03ebea --- /dev/null +++ b/crates/air_utils/src/trace/row_iterator.rs @@ -0,0 +1,126 @@ +use std::marker::PhantomData; + +use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer}; +use rayon::prelude::*; +use stwo_prover::core::backend::simd::m31::PackedM31; + +pub type MutRow<'trace, const N: usize> = [&'trace mut PackedM31; N]; + +/// An iterator over mutable references to the rows of a [`super::component_trace::ComponentTrace`]. +// TODO(Ohad): Iterating over single rows is not optimal, figure out optimal chunk size when using +// this iterator. +pub struct RowIterMut<'trace, const N: usize> { + v: [*mut [PackedM31]; N], + phantom: PhantomData<&'trace ()>, +} +impl<'trace, const N: usize> RowIterMut<'trace, N> { + pub fn new(slice: [&'trace mut [PackedM31]; N]) -> Self { + Self { + v: slice.map(|s| s as *mut _), + phantom: PhantomData, + } + } +} +impl<'trace, const N: usize> Iterator for RowIterMut<'trace, N> { + type Item = MutRow<'trace, N>; + + fn next(&mut self) -> Option { + if self.v[0].is_empty() { + return None; + } + let item = std::array::from_fn(|i| unsafe { + // SAFETY: The self.v contract ensures that any split_at_mut is valid. 
+ let (head, tail) = self.v[i].split_at_mut(1); + self.v[i] = tail; + &mut (*head)[0] + }); + Some(item) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.v[0].len(); + (len, Some(len)) + } +} +impl ExactSizeIterator for RowIterMut<'_, N> {} +impl DoubleEndedIterator for RowIterMut<'_, N> { + fn next_back(&mut self) -> Option { + if self.v[0].is_empty() { + return None; + } + let item = std::array::from_fn(|i| unsafe { + // SAFETY: The self.v contract ensures that any split_at_mut is valid. + let (head, tail) = self.v[i].split_at_mut(self.v[i].len() - 1); + self.v[i] = head; + &mut (*tail)[0] + }); + Some(item) + } +} + +struct RowProducer<'trace, const N: usize> { + data: [&'trace mut [PackedM31]; N], +} +impl<'trace, const N: usize> Producer for RowProducer<'trace, N> { + type Item = MutRow<'trace, N>; + + fn split_at(self, index: usize) -> (Self, Self) { + let mut left: [_; N] = unsafe { std::mem::zeroed() }; + let mut right: [_; N] = unsafe { std::mem::zeroed() }; + for (i, slice) in self.data.into_iter().enumerate() { + let (lhs, rhs) = slice.split_at_mut(index); + left[i] = lhs; + right[i] = rhs; + } + (RowProducer { data: left }, RowProducer { data: right }) + } + + type IntoIter = RowIterMut<'trace, N>; + + fn into_iter(self) -> Self::IntoIter { + RowIterMut { + v: self.data.map(|s| s as *mut _), + phantom: PhantomData, + } + } +} + +/// A parallel iterator over mutable references to the rows of a +/// [`super::component_trace::ComponentTrace`]. [`super::component_trace::ComponentTrace`] is an +/// array of columns, hence iterating over rows is not trivial. Iteration is done by iterating over +/// `N` columns in parallel. +pub struct ParRowIterMut<'trace, const N: usize> { + data: [&'trace mut [PackedM31]; N], +} +impl<'trace, const N: usize> ParRowIterMut<'trace, N> { + pub(super) fn new(data: [&'trace mut [PackedM31]; N]) -> Self { + Self { data } + } +} +impl<'trace, const N: usize> ParallelIterator for ParRowIterMut<'trace, N> { + type Item = MutRow<'trace, N>; + + fn drive_unindexed(self, consumer: D) -> D::Result + where + D: UnindexedConsumer, + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } +} +impl IndexedParallelIterator for ParRowIterMut<'_, N> { + fn len(&self) -> usize { + self.data[0].len() + } + + fn drive>(self, consumer: D) -> D::Result { + bridge(self, consumer) + } + + fn with_producer>(self, callback: CB) -> CB::Output { + callback.callback(RowProducer { data: self.data }) + } +} From cc4641daf542d6e53b46cbaa804bddb857a7b3f9 Mon Sep 17 00:00:00 2001 From: VitaliiH Date: Thu, 26 Dec 2024 19:43:01 +0100 Subject: [PATCH 39/69] wip eval_at_point - correctness up to 2^8 --- crates/prover/src/core/backend/icicle/mod.rs | 122 ++++++++-------- .../prover/src/core/backend/icicle/utils.rs | 80 +++++++---- .../prover/src/examples/wide_fibonacci/mod.rs | 132 +++++++++--------- 3 files changed, 187 insertions(+), 147 deletions(-) diff --git a/crates/prover/src/core/backend/icicle/mod.rs b/crates/prover/src/core/backend/icicle/mod.rs index 4af7695ca..ccee18192 100644 --- a/crates/prover/src/core/backend/icicle/mod.rs +++ b/crates/prover/src/core/backend/icicle/mod.rs @@ -271,7 +271,20 @@ impl PolyOps for IcicleBackend { fn eval_at_point(poly: &CirclePoly, point: CirclePoint) -> SecureField { // todo!() - unsafe { CpuBackend::eval_at_point(transmute(poly), point) } + // unsafe { CpuBackend::eval_at_point(transmute(poly), point) } + if poly.log_size() == 0 { + return poly.coeffs[0].into(); + } + // TODO: to gpu after 
correctness fix + let mut mappings = vec![point.y]; + let mut x = point.x; + for _ in 1..poly.log_size() { + mappings.push(x); + x = CirclePoint::double_x(x); + } + mappings.reverse(); + + crate::core::backend::icicle::utils::fold(&poly.coeffs, &mappings) } fn extend(poly: &CirclePoly, log_size: u32) -> CirclePoly { @@ -553,7 +566,6 @@ impl FriOps for IcicleBackend { &mut d_folded_eval[..], icicle_alpha, &cfg, - ) .unwrap(); @@ -607,61 +619,61 @@ impl QuotientOps for IcicleBackend { sample_batches: &[ColumnSampleBatch], log_blowup_factor: u32, ) -> SecureEvaluation { - // TODO: the fn accumulate_quotients( fix seems doesn't work for this branch https://github.com/ingonyama-zk/icicle/commit/eb82fbe20d116829eebf63d9b77e9a2eb2b0b0b0 - unsafe { - transmute(CpuBackend::accumulate_quotients( - domain, - unsafe { transmute(columns) }, - random_coeff, - sample_batches, - log_blowup_factor, - )) - } - // let icicle_columns_raw = columns - // .iter() - // .flat_map(|x| x.iter().map(|&y| unsafe { transmute(y) })) - // .collect_vec(); - // let icicle_columns = HostSlice::from_slice(&icicle_columns_raw); - // let icicle_sample_batches = sample_batches - // .into_iter() - // .map(|sample| { - // let (columns, values) = sample - // .columns_and_values - // .iter() - // .map(|(index, value)| { - // ((*index) as u32, unsafe { - // transmute::(*value) - // }) - // }) - // .unzip(); - - // quotient::ColumnSampleBatch { - // point: unsafe { transmute(sample.point) }, - // columns, - // values, - // } - // }) - // .collect_vec(); - // let mut icicle_result_raw = vec![QuarticExtensionField::zero(); domain.size()]; - // let icicle_result = HostSlice::from_mut_slice(icicle_result_raw.as_mut_slice()); - // let cfg = quotient::QuotientConfig::default(); - - // quotient::accumulate_quotients_wrapped( - // // domain.half_coset.initial_index.0 as u32, - // // domain.half_coset.step_size.0 as u32, - // domain.log_size() as u32, - // icicle_columns, - // unsafe { transmute(random_coeff) }, - // &icicle_sample_batches, - // icicle_result, - // &cfg, - // ); - // // TODO: make it on cuda side - // let mut result = unsafe { SecureColumnByCoords::uninitialized(domain.size()) }; - // (0..domain.size()).for_each(|i| result.set(i, unsafe { transmute(icicle_result_raw[i]) })); - // SecureEvaluation::new(domain, result) + // unsafe { + // transmute(CpuBackend::accumulate_quotients( + // domain, + // unsafe { transmute(columns) }, + // random_coeff, + // sample_batches, + // log_blowup_factor, + // )) + // } + + let icicle_columns_raw = columns + .iter() + .flat_map(|x| x.iter().map(|&y| unsafe { transmute(y) })) + .collect_vec(); + let icicle_columns = HostSlice::from_slice(&icicle_columns_raw); + let icicle_sample_batches = sample_batches + .into_iter() + .map(|sample| { + let (columns, values) = sample + .columns_and_values + .iter() + .map(|(index, value)| { + ((*index) as u32, unsafe { + transmute::(*value) + }) + }) + .unzip(); + + quotient::ColumnSampleBatch { + point: unsafe { transmute(sample.point) }, + columns, + values, + } + }) + .collect_vec(); + let mut icicle_result_raw = vec![QuarticExtensionField::zero(); domain.size()]; + let icicle_result = HostSlice::from_mut_slice(icicle_result_raw.as_mut_slice()); + let cfg = quotient::QuotientConfig::default(); + + quotient::accumulate_quotients_wrapped( + // domain.half_coset.initial_index.0 as u32, + // domain.half_coset.step_size.0 as u32, + domain.log_size() as u32, + icicle_columns, + unsafe { transmute(random_coeff) }, + &icicle_sample_batches, + icicle_result, + 
&cfg, + ); + // TODO: make it on cuda side + let mut result = unsafe { SecureColumnByCoords::uninitialized(domain.size()) }; + (0..domain.size()).for_each(|i| result.set(i, unsafe { transmute(icicle_result_raw[i]) })); + SecureEvaluation::new(domain, result) } } diff --git a/crates/prover/src/core/backend/icicle/utils.rs b/crates/prover/src/core/backend/icicle/utils.rs index 887c831aa..7de84d099 100644 --- a/crates/prover/src/core/backend/icicle/utils.rs +++ b/crates/prover/src/core/backend/icicle/utils.rs @@ -1,5 +1,29 @@ +use std::mem::transmute; + +use icicle_core::ntt::FieldImpl; +use icicle_core::vec_ops::{fold_scalars, VecOps, VecOpsConfig}; +use icicle_cuda_runtime::memory::HostSlice; +use icicle_m31::field::{ComplexExtensionField, QuarticExtensionField, ScalarField}; + +use crate::core::fields::m31::M31; +use crate::core::fields::qm31::QM31; use crate::core::fields::{ExtensionOf, Field}; +macro_rules! select_result_type { + (1) => { + ScalarField + }; + (2) => { + ComplexExtensionField + }; + (4) => { + QuarticExtensionField + }; + ($other:expr) => { + compile_error!("Unsupported limbs count") + }; +} + /// Folds values recursively in `O(n)` by a hierarchical application of folding factors. /// /// i.e. folding `n = 8` values with `folding_factors = [x, y, z]`: @@ -19,37 +43,35 @@ use crate::core::fields::{ExtensionOf, Field}; /// Panics if the number of values is not a power of two or if an incorrect number of of folding /// factors is provided. // TODO(Andrew): Can be made to run >10x faster by unrolling lower layers of recursion -pub fn fold>(values: &[F], folding_factors: &[E]) -> E { - let n = values.len(); - assert_eq!(n, 1 << folding_factors.len()); - if n == 1 { - let res: E = values[0].into(); - return res; - } - let (lhs_values, rhs_values) = values.split_at(n / 2); - let (folding_factor, folding_factors) = folding_factors.split_first().unwrap(); - let lhs_val = fold(lhs_values, folding_factors); - let rhs_val = fold(rhs_values, folding_factors); - // println!( - // "n={:?} lhs_val{:?} + rhs_val{:?} x folding_factor: {:?}", - // n, lhs_val, rhs_val, *folding_factor - // ); - let res = lhs_val + rhs_val * *folding_factor; - // println!("res = {:?}; ", res); - res -} +pub fn fold<'a, F: Field, E: ExtensionOf + Sized>( + values: &'a [F], + folding_factors: &'a [E], +) -> E { + assert!(values.len().is_power_of_two()); -pub fn fold_gpu>(values: &[F], folding_factors: &[E]) -> E { - let n = values.len(); - assert_eq!(n, 1 << folding_factors.len()); - if n == 1 { - return values[0].into(); + let a = HostSlice::from_slice(unsafe { transmute(values) }); + let b = HostSlice::from_slice(unsafe { transmute(folding_factors) }); + let mut result = vec![QuarticExtensionField::zero()]; + let res = HostSlice::from_mut_slice(&mut result); + + let cfg = VecOpsConfig::default(); + + // TODO: generic macro for selecting appropriate result type + // let limbs_count: usize = std::mem::size_of::() / 4; + // type EE = select_result_type!(limbs_count); + // let limbs_count: usize = std::mem::size_of::() / 4; + // type FF = select_result_type!(limbs_count); + + fold_scalars::(a, b, res, &cfg).unwrap(); + + unsafe { + let vec: Vec = transmute(result); + if let Some(first) = vec.first() { + *first + } else { + panic!("Fold result empty."); + } } - let (lhs_values, rhs_values) = values.split_at(n / 2); - let (folding_factor, folding_factors) = folding_factors.split_first().unwrap(); - let lhs_val = fold(lhs_values, folding_factors); - let rhs_val = fold(rhs_values, folding_factors); - lhs_val + 
rhs_val * *folding_factor } #[cfg(test)] diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index dee4bcc3e..1251f9335 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -237,71 +237,77 @@ mod tests { // type TheBackend = CpuBackend; let min_log = get_env_var("MIN_FIB_LOG", 2u32); - let max_log = get_env_var("MAX_FIB_LOG", 6u32); + let max_log = get_env_var("MAX_FIB_LOG", 18u32); for log_n_instances in min_log..=max_log { - println!("proving for 2^{:?}...", log_n_instances); - let config = PcsConfig::default(); - // Precompute twiddles. - let twiddles = TheBackend::precompute_twiddles( - CanonicCoset::new(log_n_instances + 1 + config.fri_config.log_blowup_factor) - .circle_domain() - .half_coset, - ); - - // Setup protocol. - let prover_channel = &mut Blake2sChannel::default(); - let mut commitment_scheme = - CommitmentSchemeProver::::new(config, &twiddles); - - // Preprocessed trace - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals([]); - tree_builder.commit(prover_channel); - - // Trace. - let trace: Vec> = - generate_test_trace(log_n_instances) - .iter() - .map(|c| unsafe { std::mem::transmute(c.to_cpu()) }) - .collect_vec(); - - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace); - tree_builder.commit(prover_channel); - - // Prove constraints. - let component = WideFibonacciComponent::new( - &mut TraceLocationAllocator::default(), - WideFibonacciEval:: { - log_n_rows: log_n_instances, - }, - (SecureField::zero(), None), - ); - - let start = std::time::Instant::now(); - let proof = prove::( - &[&component], - prover_channel, - commitment_scheme, - ) - .unwrap(); - println!( - "proving for 2^{:?} took {:?} ms", - log_n_instances, - start.elapsed().as_millis() - ); - - // Verify. - let verifier_channel = &mut Blake2sChannel::default(); - let commitment_scheme = - &mut CommitmentSchemeVerifier::::new(config); - - // Retrieve the expected column sizes in each commitment interaction, from the AIR. - let sizes = component.trace_log_degree_bounds(); - commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); - commitment_scheme.commit(proof.commitments[1], &sizes[1], verifier_channel); - verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap(); + for _ in 0..1 { + println!("proving for 2^{:?}...", log_n_instances); + let config = PcsConfig::default(); + // Precompute twiddles. + let twiddles = TheBackend::precompute_twiddles( + CanonicCoset::new(log_n_instances + 1 + config.fri_config.log_blowup_factor) + .circle_domain() + .half_coset, + ); + + // Setup protocol. + let prover_channel = &mut Blake2sChannel::default(); + let mut commitment_scheme = CommitmentSchemeProver::< + TheBackend, + Blake2sMerkleChannel, + >::new(config, &twiddles); + + // Preprocessed trace + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals([]); + tree_builder.commit(prover_channel); + + // Trace. + let trace: Vec> = + generate_test_trace(log_n_instances) + .iter() + .map(|c| unsafe { std::mem::transmute(c.to_cpu()) }) + .collect_vec(); + + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(trace); + tree_builder.commit(prover_channel); + + // Prove constraints. 
+ let component = WideFibonacciComponent::new( + &mut TraceLocationAllocator::default(), + WideFibonacciEval:: { + log_n_rows: log_n_instances, + }, + (SecureField::zero(), None), + ); + + let start = std::time::Instant::now(); + let proof = prove::( + &[&component], + prover_channel, + commitment_scheme, + ) + .unwrap(); + println!( + "proving for 2^{:?} took {:?} ms", + log_n_instances, + start.elapsed().as_millis() + ); + + // Verify. + let verifier_channel = &mut Blake2sChannel::default(); + let commitment_scheme = + &mut CommitmentSchemeVerifier::::new(config); + + // Retrieve the expected column sizes in each commitment interaction, from the AIR. + let sizes = component.trace_log_degree_bounds(); + commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); + commitment_scheme.commit(proof.commitments[1], &sizes[1], verifier_channel); + verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap_or_else(|err| { + println!("verify failed for {} with: {}", log_n_instances, err); + }); + } } } From a865b2b2ff3a7fb72c0b99b25bbafefb3b33f927 Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Sun, 29 Dec 2024 16:15:24 +0200 Subject: [PATCH 40/69] move inverse behind impl (#954) --- crates/prover/src/constraint_framework/expr/simplify.rs | 1 - crates/prover/src/core/backend/simd/fft/ifft.rs | 1 - crates/prover/src/core/fields/m31.rs | 8 ++++++-- crates/prover/src/core/poly/line.rs | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr/simplify.rs b/crates/prover/src/constraint_framework/expr/simplify.rs index 528b23627..3632c09e2 100644 --- a/crates/prover/src/constraint_framework/expr/simplify.rs +++ b/crates/prover/src/constraint_framework/expr/simplify.rs @@ -2,7 +2,6 @@ use num_traits::{One, Zero}; use super::{BaseExpr, ExtExpr}; use crate::core::fields::qm31::SecureField; -use crate::core::fields::FieldExpOps; /// Applies simplifications to arithmetic expressions that can be used both for `BaseExpr` and for /// `ExtExpr`. diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs index feab2ab54..b41cc70de 100644 --- a/crates/prover/src/core/backend/simd/fft/ifft.rs +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -13,7 +13,6 @@ use crate::core::backend::cpu::bit_reverse; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::utils::UnsafeMut; use crate::core::circle::Coset; -use crate::core::fields::FieldExpOps; use crate::parallel_iter; /// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values. 
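The motivation for this patch is call-site ergonomics: with `inverse` exposed as an inherent method on `M31`, code such as the ifft above no longer needs the `FieldExpOps` trait in scope, which is why those imports are dropped. The helper used below, `pow2147483645`, raises to the power p - 2 = 2^31 - 3, i.e. Fermat inversion for the Mersenne prime p = 2^31 - 1. A minimal call-site sketch (illustrative only; it assumes just `M31`, `from_u32_unchecked`, and the inherent `inverse` added below):

```rust
use stwo_prover::core::fields::m31::M31;

fn main() {
    // Inherent method: no `FieldExpOps` import is needed at the call site.
    let x = M31::from_u32_unchecked(7);
    assert_eq!(x.inverse() * x, M31::from_u32_unchecked(1));
}
```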
diff --git a/crates/prover/src/core/fields/m31.rs b/crates/prover/src/core/fields/m31.rs index 7c28bf33a..9ef981ecf 100644 --- a/crates/prover/src/core/fields/m31.rs +++ b/crates/prover/src/core/fields/m31.rs @@ -62,6 +62,11 @@ impl M31 { pub const fn from_u32_unchecked(arg: u32) -> Self { Self(arg) } + + pub fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + pow2147483645(*self) + } } impl Display for M31 { @@ -112,8 +117,7 @@ impl FieldExpOps for M31 { /// assert_eq!(v.inverse() * v, BaseField::one()); /// ``` fn inverse(&self) -> Self { - assert!(!self.is_zero(), "0 has no inverse"); - pow2147483645(*self) + self.inverse() } } diff --git a/crates/prover/src/core/poly/line.rs b/crates/prover/src/core/poly/line.rs index 9a8a4cf6d..d684d0c58 100644 --- a/crates/prover/src/core/poly/line.rs +++ b/crates/prover/src/core/poly/line.rs @@ -16,7 +16,7 @@ use crate::core::fft::ibutterfly; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; -use crate::core::fields::{ExtensionOf, FieldExpOps, FieldOps}; +use crate::core::fields::{ExtensionOf, FieldOps}; /// Domain comprising of the x-coordinates of points in a [Coset]. /// From 949c053102559a01adefd4a0ecf28db9d948293a Mon Sep 17 00:00:00 2001 From: Ohad Agadi Date: Mon, 23 Dec 2024 11:58:26 +0200 Subject: [PATCH 41/69] trace to evals --- crates/air_utils/src/trace/component_trace.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/crates/air_utils/src/trace/component_trace.rs b/crates/air_utils/src/trace/component_trace.rs index 20a96579d..7252eb901 100644 --- a/crates/air_utils/src/trace/component_trace.rs +++ b/crates/air_utils/src/trace/component_trace.rs @@ -1,8 +1,9 @@ use bytemuck::Zeroable; +use stwo_prover::core::backend::simd::column::BaseColumn; use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; use stwo_prover::core::backend::simd::SimdBackend; use stwo_prover::core::fields::m31::M31; -use stwo_prover::core::poly::circle::CircleEvaluation; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; use super::row_iterator::{ParRowIterMut, RowIterMut}; @@ -10,6 +11,7 @@ use super::row_iterator::{ParRowIterMut, RowIterMut}; /// A 2D Matrix of [`PackedM31`] values. /// Used for generating the witness of 'Stwo' proofs. /// Stored as an array of `N` columns, each column is a vector of [`PackedM31`] values. +/// All columns are of the same length. /// Exposes an iterator over mutable references to the rows of the matrix. /// /// # Example: @@ -46,6 +48,7 @@ use super::row_iterator::{ParRowIterMut, RowIterMut}; /// ``` #[derive(Debug)] pub struct ComponentTrace { + /// Columns are assumed to be of the same length. data: [Vec; N], /// Log number of non-packed rows in each column. 
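The hunk below turns a populated witness matrix into prover input: each `Vec<PackedM31>` column becomes a `CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>` over the canonic domain of size `2^log_size`. A minimal sketch of the intended flow (illustrative only; the `main` wrapper and the column/row counts are not from the patch):

```rust
use stwo_air_utils::trace::component_trace::ComponentTrace;

fn main() {
    // Allocate a 3-column trace with 2^8 rows...
    let trace = ComponentTrace::<3>::zeroed(8);
    // ...populate it via `iter_mut()` / `par_iter_mut()`, then hand the columns
    // to the prover as circle evaluations in bit-reversed order.
    let _evals = trace.to_evals();
}
```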
@@ -79,7 +82,13 @@ impl ComponentTrace { } pub fn to_evals(self) -> [CircleEvaluation; N] { - todo!() + let domain = CanonicCoset::new(self.log_size).circle_domain(); + self.data.map(|column| { + CircleEvaluation::::new( + domain, + BaseColumn::from_simd(column), + ) + }) } pub fn row_at(&self, row: usize) -> [M31; N] { From c5869945cf780f381a83fde74195e97659d33286 Mon Sep 17 00:00:00 2001 From: Ohad Agadi Date: Mon, 23 Dec 2024 13:08:31 +0200 Subject: [PATCH 42/69] uninitialized trace --- crates/air_utils/src/trace/component_trace.rs | 41 ++++++++++++++++++- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/crates/air_utils/src/trace/component_trace.rs b/crates/air_utils/src/trace/component_trace.rs index 7252eb901..ba59964ea 100644 --- a/crates/air_utils/src/trace/component_trace.rs +++ b/crates/air_utils/src/trace/component_trace.rs @@ -56,17 +56,39 @@ pub struct ComponentTrace { } impl ComponentTrace { + /// Creates a new `ComponentTrace` with all values initialized to zero. + /// The number of rows in each column is `2^log_size`. + /// # Panics: + /// if log_size < 4. pub fn zeroed(log_size: u32) -> Self { + assert!( + log_size >= LOG_N_LANES, + "log_size < LOG_N_LANES not supported!" + ); let n_simd_elems = 1 << (log_size - LOG_N_LANES); let data = [(); N].map(|_| vec![PackedM31::zeroed(); n_simd_elems]); Self { data, log_size } } + /// Creates a new `ComponentTrace` with all values uninitialized. /// # Safety /// The caller must ensure that the column is populated before being used. + /// The number of rows in each column is `2^log_size`. + /// # Panics: + /// if `log_size` < 4. #[allow(clippy::uninit_vec)] - pub unsafe fn uninitialized(_log_size: u32) -> Self { - todo!() + pub unsafe fn uninitialized(log_size: u32) -> Self { + assert!( + log_size >= LOG_N_LANES, + "log_size < LOG_N_LANES not supported!" 
+ ); + let n_simd_elems = 1 << (log_size - LOG_N_LANES); + let data = [(); N].map(|_| { + let mut vec = Vec::with_capacity(n_simd_elems); + vec.set_len(n_simd_elems); + vec + }); + Self { data, log_size } } pub fn log_size(&self) -> u32 { @@ -151,4 +173,19 @@ mod tests { assert_eq!(expected, actual); } + + #[test] + fn test_component_trace_uninitialized_success() { + const N_COLUMNS: usize = 3; + const LOG_SIZE: u32 = 4; + unsafe { super::ComponentTrace::::uninitialized(LOG_SIZE) }; + } + + #[should_panic = "log_size < LOG_N_LANES not supported!"] + #[test] + fn test_component_trace_uninitialized_fails() { + const N_COLUMNS: usize = 3; + const LOG_SIZE: u32 = 3; + unsafe { super::ComponentTrace::::uninitialized(LOG_SIZE) }; + } } From 506794b37269eeacd9cb8384e46d8b2fb8d363e1 Mon Sep 17 00:00:00 2001 From: Gali Michlevich Date: Thu, 26 Dec 2024 11:33:48 +0200 Subject: [PATCH 43/69] Add Seq type and packed_at func to PreprocessedColumn --- .../preprocessed_columns.rs | 125 ++++++++++++++++-- 1 file changed, 111 insertions(+), 14 deletions(-) diff --git a/crates/prover/src/constraint_framework/preprocessed_columns.rs b/crates/prover/src/constraint_framework/preprocessed_columns.rs index f196567dd..32a234a38 100644 --- a/crates/prover/src/constraint_framework/preprocessed_columns.rs +++ b/crates/prover/src/constraint_framework/preprocessed_columns.rs @@ -1,25 +1,80 @@ -use num_traits::One; +use std::simd::Simd; +use num_traits::{One, Zero}; + +use crate::core::backend::simd::m31::{PackedM31, N_LANES}; use crate::core::backend::{Backend, Col, Column}; -use crate::core::fields::m31::BaseField; +use crate::core::fields::m31::{BaseField, M31}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; use crate::core::poly::BitReversedOrder; use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; +const SIMD_ENUMERATION_0: PackedM31 = unsafe { + PackedM31::from_simd_unchecked(Simd::from_array([ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + ])) +}; + // TODO(ilya): Where should this enum be placed? +// TODO(Gali): Consider making it a trait, add documentation for the rest of the variants. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum PreprocessedColumn { - XorTable(u32, u32, usize), + /// A column with `1` at the first position, and `0` elsewhere. IsFirst(u32), Plonk(usize), + /// A column with the numbers [0..2^log_size-1]. + Seq(u32), + XorTable(u32, u32, usize), } impl PreprocessedColumn { pub const fn name(&self) -> &'static str { match self { - PreprocessedColumn::XorTable(..) => "preprocessed.xor_table", PreprocessedColumn::IsFirst(_) => "preprocessed.is_first", PreprocessedColumn::Plonk(_) => "preprocessed.plonk", + PreprocessedColumn::Seq(_) => "preprocessed.seq", + PreprocessedColumn::XorTable(..) => "preprocessed.xor_table", + } + } + + /// Returns the values of the column at the given row. + pub fn packed_at(&self, vec_row: usize) -> PackedM31 { + match self { + PreprocessedColumn::IsFirst(log_size) => { + assert!(vec_row < (1 << log_size) / N_LANES); + if vec_row == 0 { + unsafe { + PackedM31::from_simd_unchecked(Simd::from_array(std::array::from_fn(|i| { + if i == 0 { + 1 + } else { + 0 + } + }))) + } + } else { + PackedM31::zero() + } + } + PreprocessedColumn::Seq(log_size) => { + assert!(vec_row < (1 << log_size) / N_LANES); + PackedM31::broadcast(M31::from(vec_row * N_LANES)) + SIMD_ENUMERATION_0 + } + + _ => unimplemented!(), + } + } + + /// Generates a column according to the preprocessed column chosen. 
+ pub fn gen_preprocessed_column( + preprocessed_column: &PreprocessedColumn, + ) -> CircleEvaluation { + match preprocessed_column { + PreprocessedColumn::IsFirst(log_size) => gen_is_first(*log_size), + PreprocessedColumn::Plonk(_) | PreprocessedColumn::XorTable(..) => { + unimplemented!("eval_preprocessed_column: Plonk and XorTable are not supported.") + } + PreprocessedColumn::Seq(log_size) => gen_seq(*log_size), } } } @@ -54,19 +109,61 @@ pub fn gen_is_step_with_offset( CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col) } -pub fn gen_preprocessed_column( - preprocessed_column: &PreprocessedColumn, -) -> CircleEvaluation { - match preprocessed_column { - PreprocessedColumn::IsFirst(log_size) => gen_is_first(*log_size), - PreprocessedColumn::Plonk(_) | PreprocessedColumn::XorTable(..) => { - unimplemented!("eval_preprocessed_column: Plonk and XorTable are not supported.") - } - } +/// Generates a column with sequence of numbers from 0 to 2^log_size - 1. +pub fn gen_seq(log_size: u32) -> CircleEvaluation { + let col = Col::::from_iter((0..(1 << log_size)).map(BaseField::from)); + CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col) } pub fn gen_preprocessed_columns<'a, B: Backend>( columns: impl Iterator, ) -> Vec> { - columns.map(gen_preprocessed_column).collect() + columns + .map(PreprocessedColumn::gen_preprocessed_column) + .collect() +} + +#[cfg(test)] +mod tests { + use crate::core::backend::simd::m31::N_LANES; + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::Column; + use crate::core::fields::m31::{BaseField, M31}; + const LOG_SIZE: u32 = 8; + + #[test] + fn test_gen_seq() { + let seq = super::gen_seq::(LOG_SIZE); + + for i in 0..(1 << LOG_SIZE) { + assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32)); + } + } + + // TODO(Gali): Add packed_at tests for xor_table and plonk. 
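    // Illustrative sketch (not part of the patch): `Seq(log_size).packed_at(vec_row)` packs the
    // lane-wise values `vec_row * N_LANES + lane`, so with N_LANES = 16:
    //
    //     let seq = super::PreprocessedColumn::Seq(LOG_SIZE);
    //     let lanes = seq.packed_at(2).to_array(); // [M31(32), M31(33), ..., M31(47)]
    //     assert_eq!(lanes, std::array::from_fn(|i| M31::from(32 + i as u32)));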
+ #[test] + fn test_packed_at_is_first() { + let is_first = super::PreprocessedColumn::IsFirst(LOG_SIZE); + let expected_is_first = super::gen_is_first::(LOG_SIZE).to_cpu(); + + for i in 0..(1 << LOG_SIZE) / N_LANES { + assert_eq!( + is_first.packed_at(i).to_array(), + expected_is_first[i * N_LANES..(i + 1) * N_LANES] + ); + } + } + + #[test] + fn test_packed_at_seq() { + let seq = super::PreprocessedColumn::Seq(LOG_SIZE); + let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32)); + + let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| { + seq.packed_at(i).to_array() + }) + .concat(); + + assert_eq!(packed_seq, expected_seq); + } } From 81154fc68caa679ce65c1201b39121459b99b5b5 Mon Sep 17 00:00:00 2001 From: VitaliiH Date: Mon, 30 Dec 2024 13:58:56 +0200 Subject: [PATCH 44/69] no default parallel feature --- crates/prover/Cargo.toml | 2 +- crates/prover/src/core/backend/icicle/utils.rs | 9 +++++++++ crates/prover/src/examples/wide_fibonacci/mod.rs | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index a771ebcbb..50c01a7b6 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -4,7 +4,7 @@ version.workspace = true edition.workspace = true [features] -default = ["icicle", "parallel"] +default = ["icicle"] parallel = ["rayon"] slow-tests = [] icicle = ["icicle-cuda-runtime", "icicle-core", "icicle-m31", "icicle-hash", "nvtx"] diff --git a/crates/prover/src/core/backend/icicle/utils.rs b/crates/prover/src/core/backend/icicle/utils.rs index 7de84d099..bdb28c63d 100644 --- a/crates/prover/src/core/backend/icicle/utils.rs +++ b/crates/prover/src/core/backend/icicle/utils.rs @@ -108,14 +108,23 @@ mod tests { // Initialize the `values` vector let mut values: Vec = Vec::with_capacity(values_length); + #[cfg(feature = "parallel")] use rayon::iter::IntoParallelIterator; + #[cfg(feature = "parallel")] use rayon::prelude::*; + #[cfg(feature = "parallel")] let values: Vec = (1..=values_length) .into_par_iter() .map(|i| M31(i as u32)) .collect(); + #[cfg(not(feature = "parallel"))] + let values: Vec = (1..=values_length) + .into_iter() + .map(|i| M31(i as u32)) + .collect(); + // Initialize the `folding_factors` vector let mut folding_factors = Vec::with_capacity(folding_factors_length); for i in 2..(2 + folding_factors_length) { diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index 1251f9335..f015afa39 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -237,7 +237,7 @@ mod tests { // type TheBackend = CpuBackend; let min_log = get_env_var("MIN_FIB_LOG", 2u32); - let max_log = get_env_var("MAX_FIB_LOG", 18u32); + let max_log = get_env_var("MAX_FIB_LOG", 23u32); for log_n_instances in min_log..=max_log { for _ in 0..1 { From f5e54997214ed5265852eff5c2b3ea1dc7fa9c6b Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Mon, 30 Dec 2024 16:42:50 +0200 Subject: [PATCH 45/69] derive uninit (#952) --- Cargo.lock | 20 +++ Cargo.toml | 2 +- crates/air_utils/Cargo.toml | 1 + crates/air_utils/src/lib.rs | 3 +- crates/air_utils/src/lookup_data/mod.rs | 33 ++++ crates/air_utils_derive/Cargo.toml | 13 ++ crates/air_utils_derive/src/allocation.rs | 30 ++++ crates/air_utils_derive/src/iterable_field.rs | 143 ++++++++++++++++++ crates/air_utils_derive/src/lib.rs | 17 +++ 9 files changed, 260 insertions(+), 2 
deletions(-) create mode 100644 crates/air_utils/src/lookup_data/mod.rs create mode 100644 crates/air_utils_derive/Cargo.toml create mode 100644 crates/air_utils_derive/src/allocation.rs create mode 100644 crates/air_utils_derive/src/iterable_field.rs create mode 100644 crates/air_utils_derive/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index c87fc6628..5eecc10f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -593,6 +593,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.14" @@ -1042,9 +1051,20 @@ dependencies = [ "bytemuck", "itertools 0.12.1", "rayon", + "stwo-air-utils-derive", "stwo-prover", ] +[[package]] +name = "stwo-air-utils-derive" +version = "0.1.0" +dependencies = [ + "itertools 0.13.0", + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "stwo-prover" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index fadd620de..d4bb782ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["crates/prover", "crates/air_utils"] +members = ["crates/prover", "crates/air_utils", "crates/air_utils_derive"] resolver = "2" [workspace.package] diff --git a/crates/air_utils/Cargo.toml b/crates/air_utils/Cargo.toml index 7d09a7eaf..4463faf8b 100644 --- a/crates/air_utils/Cargo.toml +++ b/crates/air_utils/Cargo.toml @@ -8,6 +8,7 @@ bytemuck.workspace = true itertools.workspace = true rayon = { version = "1.10.0", optional = false } stwo-prover = { path = "../prover" } +stwo-air-utils-derive = { path = "../air_utils_derive" } [lib] bench = false diff --git a/crates/air_utils/src/lib.rs b/crates/air_utils/src/lib.rs index 8603c2cee..813e60110 100644 --- a/crates/air_utils/src/lib.rs +++ b/crates/air_utils/src/lib.rs @@ -1,2 +1,3 @@ -#![feature(exact_size_is_empty, raw_slice_split, portable_simd)] +#![feature(exact_size_is_empty, raw_slice_split, portable_simd, array_chunks)] +pub mod lookup_data; pub mod trace; diff --git a/crates/air_utils/src/lookup_data/mod.rs b/crates/air_utils/src/lookup_data/mod.rs new file mode 100644 index 000000000..e7de1e864 --- /dev/null +++ b/crates/air_utils/src/lookup_data/mod.rs @@ -0,0 +1,33 @@ +#[cfg(test)] +mod tests { + use itertools::all; + use stwo_air_utils_derive::Uninitialized; + use stwo_prover::core::backend::simd::m31::PackedM31; + + /// Lookup data for the example ComponentTrace. + /// Vectors are assumed to be of the same length. 
+ #[derive(Uninitialized)] + struct LookupData { + field0: Vec, + field1: Vec<[PackedM31; 2]>, + field2: [Vec<[PackedM31; 2]>; 2], + } + + #[test] + fn test_derived_lookup_data() { + const LOG_SIZE: u32 = 6; + let LookupData { + field0, + field1, + field2, + } = unsafe { LookupData::uninitialized(LOG_SIZE) }; + + let lengths = [ + [field0.len()].as_slice(), + [field1.len()].as_slice(), + field2.map(|v| v.len()).as_slice(), + ] + .concat(); + assert!(all(lengths, |len| len == 1 << LOG_SIZE)); + } +} diff --git a/crates/air_utils_derive/Cargo.toml b/crates/air_utils_derive/Cargo.toml new file mode 100644 index 000000000..0f36c43af --- /dev/null +++ b/crates/air_utils_derive/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "stwo-air-utils-derive" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = "2.0.90" +quote = "1.0.37" +itertools = "0.13.0" +proc-macro2 = "1.0.92" diff --git a/crates/air_utils_derive/src/allocation.rs b/crates/air_utils_derive/src/allocation.rs new file mode 100644 index 000000000..bff4897a6 --- /dev/null +++ b/crates/air_utils_derive/src/allocation.rs @@ -0,0 +1,30 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::Ident; + +use crate::iterable_field::IterableField; + +/// Implements an "Uninitialized" function for the struct. +/// Allocates 2^`log_size` slots for every Vector. +pub fn expand_uninitialized_impl( + struct_name: &Ident, + iterable_fields: &[IterableField], +) -> TokenStream { + let (field_names, allocations): (Vec<_>, Vec<_>) = iterable_fields + .iter() + .map(|f| (f.name(), f.uninitialized_field_allocation())) + .unzip(); + quote! { + impl #struct_name { + /// # Safety + /// The caller must ensure that the trace is populated before being used. + #[automatically_derived] + pub unsafe fn uninitialized(log_size: u32) -> Self { + let len = 1 << log_size; + #(#allocations)* + Self { + #(#field_names,)* + } + } + }} +} diff --git a/crates/air_utils_derive/src/iterable_field.rs b/crates/air_utils_derive/src/iterable_field.rs new file mode 100644 index 000000000..49b51d8c3 --- /dev/null +++ b/crates/air_utils_derive/src/iterable_field.rs @@ -0,0 +1,143 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::{Data, DeriveInput, Expr, Field, Fields, Ident, Type}; + +/// Each variant represents a field that can be iterated over. +/// Used to derive implementations of `Uninitialized`, `MutIter`, and `ParMutIter`. +/// Currently supports `Vec` and `[Vec; N]` fields only. +pub(super) enum IterableField { + /// A single Vec field, e.g. `Vec`, `Vec<[u32; K]>`. + PlainVec(PlainVec), + /// An array of Vec fields, e.g. `[Vec; N]`, `[Vec<[u32; K]>; N]`. + ArrayOfVecs(ArrayOfVecs), +} + +pub(super) struct PlainVec { + name: Ident, + _ty: Type, +} +pub(super) struct ArrayOfVecs { + name: Ident, + _inner_type: Type, + outer_array_size: Expr, +} + +impl IterableField { + pub fn from_field(field: &Field) -> Result { + // Check if the field is a vector or array of vectors. + match field.ty { + // Case that type is [Vec; N]. + Type::Array(ref outer_array) => { + let inner_type = match outer_array.elem.as_ref() { + Type::Path(ref type_path) => parse_inner_type(type_path)?, + _ => Err(syn::Error::new_spanned( + outer_array.elem.clone(), + "Expected Vec type", + ))?, + }; + Ok(Self::ArrayOfVecs(ArrayOfVecs { + name: field.ident.clone().unwrap(), + outer_array_size: outer_array.len.clone(), + _inner_type: inner_type.clone(), + })) + } + // Case that type is Vec. 
+ Type::Path(ref type_path) => { + let _ty = parse_inner_type(type_path)?; + Ok(Self::PlainVec(PlainVec { + name: field.ident.clone().unwrap(), + _ty, + })) + } + _ => Err(syn::Error::new_spanned( + field, + "Expected vector or array of vectors", + )), + } + } + + pub fn name(&self) -> &Ident { + match self { + IterableField::PlainVec(plain_vec) => &plain_vec.name, + IterableField::ArrayOfVecs(array_of_vecs) => &array_of_vecs.name, + } + } + + /// Generate the uninitialized allocation for the field. + /// E.g. `Vec::with_capacity(len); vec.set_len(len);` for a `Vec` field. + /// E.g. `[(); N].map(|_| { Vec::with_capacity(len); vec.set_len(len); })` for `[Vec; N]`. + pub fn uninitialized_field_allocation(&self) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + quote! { + let mut #name= Vec::with_capacity(len); + #name.set_len(len); + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + let outer_array_size = &array_of_vecs.outer_array_size; + quote! { + let #name = [(); #outer_array_size].map(|_| { + let mut vec = Vec::with_capacity(len); + vec.set_len(len); + vec + }); + } + } + } + } +} + +// Extract the inner vector type from a path. +// Returns an error if the path is not of the form ::Vec. +fn parse_inner_type(type_path: &syn::TypePath) -> Result { + match type_path.path.segments.last() { + Some(last_segment) => { + if last_segment.ident != "Vec" { + return Err(syn::Error::new_spanned( + last_segment.ident.clone(), + "Expected Vec type", + )); + } + match &last_segment.arguments { + syn::PathArguments::AngleBracketed(args) => match args.args.first() { + Some(syn::GenericArgument::Type(inner_type)) => Ok(inner_type.clone()), + _ => Err(syn::Error::new_spanned( + args.args.first().unwrap(), + "Expected exactly one generic argument: Vec", + )), + }, + _ => Err(syn::Error::new_spanned( + last_segment.arguments.clone(), + "Expected angle-bracketed arguments: Vec<..>", + )), + } + } + _ => Err(syn::Error::new_spanned( + type_path.path.clone(), + "Expected last segment", + )), + } +} + +pub(super) fn to_iterable_fields(input: DeriveInput) -> Result, syn::Error> { + let struct_name = &input.ident; + let input = match input.data { + Data::Struct(data_struct) => Ok(data_struct), + _ => Err(syn::Error::new_spanned(struct_name, "Expected a struct")), + }?; + + match input.fields { + Fields::Named(fields) => Ok(fields + .named + .iter() + .map(IterableField::from_field) + .collect::>()?), + _ => Err(syn::Error::new_spanned( + input.fields, + "Expected named fields", + )), + } +} diff --git a/crates/air_utils_derive/src/lib.rs b/crates/air_utils_derive/src/lib.rs new file mode 100644 index 000000000..a43b1dc9a --- /dev/null +++ b/crates/air_utils_derive/src/lib.rs @@ -0,0 +1,17 @@ +mod allocation; +mod iterable_field; +use iterable_field::to_iterable_fields; +use syn::{parse_macro_input, DeriveInput}; + +#[proc_macro_derive(Uninitialized)] +pub fn derive_uninitialized(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let struct_name = input.ident.clone(); + + let iterable_fields = match to_iterable_fields(input) { + Ok(iterable_fields) => iterable_fields, + Err(err) => return err.into_compile_error().into(), + }; + + allocation::expand_uninitialized_impl(&struct_name, &iterable_fields).into() +} From c67301628b7f16b96533d544b96b2e651d2b220b Mon Sep 17 00:00:00 2001 From: shaharsamocha7 <70577611+shaharsamocha7@users.noreply.github.com> 
Date: Tue, 31 Dec 2024 09:57:21 +0200 Subject: [PATCH 46/69] Relation tracker - preprocessed columns (#956) --- crates/prover/src/constraint_framework/relation_tracker.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs index 8b522b615..e4606e7bf 100644 --- a/crates/prover/src/constraint_framework/relation_tracker.rs +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -5,6 +5,7 @@ use itertools::Itertools; use num_traits::Zero; use super::logup::LogupSums; +use super::preprocessed_columns::PreprocessedColumn; use super::{ Batching, EvalAtRow, FrameworkEval, InfoEvaluator, Relation, RelationEntry, TraceLocationAllocator, INTERACTION_TRACE_IDX, @@ -144,6 +145,11 @@ impl EvalAtRow for RelationTrackerEvaluator<'_> { })) }) } + + fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F { + column.packed_at(self.vec_row) + } + fn add_constraint(&mut self, _constraint: G) {} fn combine_ef(_values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { From 6153b25b97d8f3d5c14a3e75c1e50d4fc164de3c Mon Sep 17 00:00:00 2001 From: Ohad Agadi Date: Mon, 23 Dec 2024 11:58:26 +0200 Subject: [PATCH 47/69] iter mut --- crates/air_utils/src/lookup_data/mod.rs | 76 ++++++- crates/air_utils_derive/src/iter_mut.rs | 167 ++++++++++++++ crates/air_utils_derive/src/iterable_field.rs | 203 +++++++++++++++++- crates/air_utils_derive/src/lib.rs | 14 ++ 4 files changed, 449 insertions(+), 11 deletions(-) create mode 100644 crates/air_utils_derive/src/iter_mut.rs diff --git a/crates/air_utils/src/lookup_data/mod.rs b/crates/air_utils/src/lookup_data/mod.rs index e7de1e864..2453994fe 100644 --- a/crates/air_utils/src/lookup_data/mod.rs +++ b/crates/air_utils/src/lookup_data/mod.rs @@ -1,12 +1,15 @@ #[cfg(test)] mod tests { - use itertools::all; - use stwo_air_utils_derive::Uninitialized; - use stwo_prover::core::backend::simd::m31::PackedM31; + use itertools::{all, Itertools}; + use stwo_air_utils_derive::{IterMut, Uninitialized}; + use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; + use stwo_prover::core::fields::m31::M31; + + use crate::trace::component_trace::ComponentTrace; /// Lookup data for the example ComponentTrace. /// Vectors are assumed to be of the same length. 
- #[derive(Uninitialized)] + #[derive(Uninitialized, IterMut)] struct LookupData { field0: Vec, field1: Vec<[PackedM31; 2]>, @@ -30,4 +33,69 @@ mod tests { .concat(); assert!(all(lengths, |len| len == 1 << LOG_SIZE)); } + + #[test] + fn test_derived_lookup_data_iter_mut() { + const N_COLUMNS: usize = 5; + const LOG_N_ROWS: u32 = 8; + let mut trace = ComponentTrace::::zeroed(LOG_N_ROWS); + let arr = (0..1 << LOG_N_ROWS).map(M31::from).collect_vec(); + let mut lookup_data = unsafe { LookupData::uninitialized(LOG_N_ROWS - LOG_N_LANES) }; + let (expected_field0, expected_field1, expected_field2): ( + Vec<_>, + Vec<_>, + (Vec<_>, Vec<_>), + ) = arr + .array_chunks::() + .map(|x| { + let x = PackedM31::from_array(*x); + let x1 = x + PackedM31::broadcast(M31(1)); + let x2 = x + x1; + let x3 = x + x1 + x2; + let x4 = x + x1 + x2 + x3; + ( + x4, + [x1, x1.double()], + ([x2, x2.double()], [x3, x3.double()]), + ) + }) + .multiunzip(); + + trace + .iter_mut() + .zip(arr.chunks(N_LANES)) + .zip(lookup_data.iter_mut()) + .for_each(|((row, input), lookup_data)| { + *row[0] = PackedM31::from_array(input.try_into().unwrap()); + *row[1] = *row[0] + PackedM31::broadcast(M31(1)); + *row[2] = *row[0] + *row[1]; + *row[3] = *row[0] + *row[1] + *row[2]; + *row[4] = *row[0] + *row[1] + *row[2] + *row[3]; + *lookup_data.field0 = *row[4]; + *lookup_data.field1 = [*row[1], row[1].double()]; + *lookup_data.field2[0] = [*row[2], row[2].double()]; + *lookup_data.field2[1] = [*row[3], row[3].double()]; + }); + let (actual0, actual1, actual2) = ( + lookup_data.field0, + lookup_data.field1, + (lookup_data.field2[0].clone(), lookup_data.field2[1].clone()), + ); + + assert_eq!( + format!("{expected_field0:?}"), + format!("{actual0:?}"), + "Failed on Vec" + ); + assert_eq!( + format!("{expected_field1:?}"), + format!("{actual1:?}"), + "Failed on Vec<[PackedM31; 2]>" + ); + assert_eq!( + format!("{expected_field2:?}"), + format!("{actual2:?}"), + "Failed on [Vec<[PackedM31; 2]>; 2]" + ); + } } diff --git a/crates/air_utils_derive/src/iter_mut.rs b/crates/air_utils_derive/src/iter_mut.rs new file mode 100644 index 000000000..0909d7341 --- /dev/null +++ b/crates/air_utils_derive/src/iter_mut.rs @@ -0,0 +1,167 @@ +use itertools::Itertools; +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote}; +use syn::{Ident, Lifetime}; + +use crate::iterable_field::IterableField; + +pub fn expand_iter_mut_structs( + struct_name: &Ident, + iterable_fields: &[IterableField], +) -> TokenStream { + let impl_struct_name = expand_impl_struct_name(struct_name, iterable_fields); + let mut_chunk_struct = expand_mut_chunk_struct(struct_name, iterable_fields); + let iter_mut_struct = expand_iter_mut_struct(struct_name, iterable_fields); + let iterator_impl = expand_iterator_impl(struct_name, iterable_fields); + let exact_size_iterator = expand_exact_size_iterator(struct_name); + let double_ended_iterator = expand_double_ended_iterator(struct_name, iterable_fields); + + quote! { + #impl_struct_name + #mut_chunk_struct + #iter_mut_struct + #iterator_impl + #exact_size_iterator + #double_ended_iterator + } +} + +fn expand_impl_struct_name(struct_name: &Ident, iterable_fields: &[IterableField]) -> TokenStream { + let iter_mut_name = format_ident!("{}IterMut", struct_name); + let as_mut_slice = iterable_fields + .iter() + .map(|f| f.as_mut_slice()) + .collect_vec(); + quote! 
{ + impl #struct_name { + pub fn iter_mut(&mut self) -> #iter_mut_name<'_> { + #iter_mut_name::new( + #(self.#as_mut_slice,)* + ) + } + } + } +} + +fn expand_mut_chunk_struct(struct_name: &Ident, iterable_fields: &[IterableField]) -> TokenStream { + let lifetime = Lifetime::new("'a", Span::call_site()); + let mut_chunk_name = format_ident!("{}MutChunk", struct_name); + let (field_names, mut_chunk_types): (Vec<_>, Vec<_>) = iterable_fields + .iter() + .map(|f| (f.name(), f.mut_chunk_type(&lifetime))) + .unzip(); + + quote! { + pub struct #mut_chunk_name<#lifetime> { + #(#field_names: #mut_chunk_types,)* + } + } +} + +fn expand_iter_mut_struct(struct_name: &Ident, iterable_fields: &[IterableField]) -> TokenStream { + let lifetime = Lifetime::new("'a", Span::call_site()); + let iter_mut_name = format_ident!("{}IterMut", struct_name); + let (field_names, mut_slice_types, mut_ptr_types, as_mut_ptr): ( + Vec<_>, + Vec<_>, + Vec<_>, + Vec<_>, + ) = iterable_fields + .iter() + .map(|f| { + ( + f.name(), + f.mut_slice_type(&lifetime), + f.mut_slice_ptr_type(), + f.as_mut_ptr(), + ) + }) + .multiunzip(); + + quote! { + pub struct #iter_mut_name<#lifetime> { + #(#field_names: #mut_ptr_types,)* + phantom: std::marker::PhantomData<&#lifetime ()>, + } + impl<#lifetime> #iter_mut_name<#lifetime> { + pub fn new( + #(#field_names: #mut_slice_types,)* + ) -> Self { + Self { + #(#field_names: #as_mut_ptr,)* + phantom: std::marker::PhantomData, + } + } + } + } +} + +fn expand_iterator_impl(struct_name: &Ident, iterable_fields: &[IterableField]) -> TokenStream { + let lifetime = Lifetime::new("'a", Span::call_site()); + let iter_mut_name = format_ident!("{}IterMut", struct_name); + let mut_chunk_name = format_ident!("{}MutChunk", struct_name); + let (field_names, split_first): (Vec<_>, Vec<_>) = iterable_fields + .iter() + .map(|f| (f.name(), f.split_first())) + .unzip(); + let get_length = iterable_fields.first().unwrap().get_len(); + + quote! { + impl<#lifetime> Iterator for #iter_mut_name<#lifetime> { + type Item = #mut_chunk_name<#lifetime>; + fn next(&mut self) -> Option { + if self.#get_length == 0 { + return None; + } + let item = unsafe { + #(#split_first)* + #mut_chunk_name { + #(#field_names,)* + } + }; + Some(item) + } + fn size_hint(&self) -> (usize, Option) { + let len = self.#get_length; + (len, Some(len)) + } + } + } +} + +fn expand_exact_size_iterator(struct_name: &Ident) -> TokenStream { + let iter_mut_name = format_ident!("{}IterMut", struct_name); + quote! { + impl ExactSizeIterator for #iter_mut_name<'_> {} + } +} + +fn expand_double_ended_iterator( + struct_name: &Ident, + iterable_fields: &[IterableField], +) -> TokenStream { + let iter_mut_name = format_ident!("{}IterMut", struct_name); + let mut_chunk_name = format_ident!("{}MutChunk", struct_name); + let (field_names, split_last): (Vec<_>, Vec<_>) = iterable_fields + .iter() + .map(|f| (f.name(), f.split_last(&format_ident!("len")))) + .unzip(); + let get_length = iterable_fields.first().unwrap().get_len(); + quote! 
{ + impl DoubleEndedIterator for #iter_mut_name<'_> { + fn next_back(&mut self) -> Option { + let len = self.#get_length; + if len == 0 { + return None; + } + let item = unsafe { + #(#split_last)* + #mut_chunk_name { + #(#field_names,)* + } + }; + Some(item) + } + } + } +} diff --git a/crates/air_utils_derive/src/iterable_field.rs b/crates/air_utils_derive/src/iterable_field.rs index 49b51d8c3..ecfe3f92a 100644 --- a/crates/air_utils_derive/src/iterable_field.rs +++ b/crates/air_utils_derive/src/iterable_field.rs @@ -1,6 +1,6 @@ use proc_macro2::TokenStream; -use quote::quote; -use syn::{Data, DeriveInput, Expr, Field, Fields, Ident, Type}; +use quote::{format_ident, quote}; +use syn::{Data, DeriveInput, Expr, Field, Fields, Ident, Lifetime, Type}; /// Each variant represents a field that can be iterated over. /// Used to derive implementations of `Uninitialized`, `MutIter`, and `ParMutIter`. @@ -14,11 +14,11 @@ pub(super) enum IterableField { pub(super) struct PlainVec { name: Ident, - _ty: Type, + ty: Type, } pub(super) struct ArrayOfVecs { name: Ident, - _inner_type: Type, + inner_type: Type, outer_array_size: Expr, } @@ -38,15 +38,15 @@ impl IterableField { Ok(Self::ArrayOfVecs(ArrayOfVecs { name: field.ident.clone().unwrap(), outer_array_size: outer_array.len.clone(), - _inner_type: inner_type.clone(), + inner_type: inner_type.clone(), })) } // Case that type is Vec. Type::Path(ref type_path) => { - let _ty = parse_inner_type(type_path)?; + let ty = parse_inner_type(type_path)?; Ok(Self::PlainVec(PlainVec { name: field.ident.clone().unwrap(), - _ty, + ty, })) } _ => Err(syn::Error::new_spanned( @@ -63,9 +63,76 @@ impl IterableField { } } + /// Generate the type of a mutable slice of the field. + /// E.g. `&'a mut [u32]` for a `Vec` field. + /// E.g. [`&'a mut [u32]; N]` for a `[Vec; N]` field. + /// Used in the `IterMut` struct. + pub fn mut_slice_type(&self, lifetime: &Lifetime) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let ty = &plain_vec.ty; + quote! { + &#lifetime mut [#ty] + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let inner_type = &array_of_vecs.inner_type; + let outer_array_size = &array_of_vecs.outer_array_size; + quote! { + [&#lifetime mut [#inner_type]; #outer_array_size] + } + } + } + } + + /// Generate the type of a mutable chunk of the field. + /// E.g. `&'a mut u32` for a `Vec` field. + /// E.g. [`&'a mut u32; N]` for a `[Vec; N]` field. + /// Used in the `MutChunk` struct. + pub fn mut_chunk_type(&self, lifetime: &Lifetime) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let ty = &plain_vec.ty; + quote! { + &#lifetime mut #ty + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let inner_type = &array_of_vecs.inner_type; + let array_size = &array_of_vecs.outer_array_size; + quote! { + [&#lifetime mut #inner_type; #array_size] + } + } + } + } + + /// Generate the type of a mutable slice pointer to the field. + /// E.g. `*mut [u32]` for a `Vec` field. + /// E.g. [`*mut [u32]; N]` for a `[Vec; N]` field. + /// Used in the `IterMut` struct. + pub fn mut_slice_ptr_type(&self) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let ty = &plain_vec.ty; + quote! { + *mut [#ty] + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let inner_type = &array_of_vecs.inner_type; + let outer_array_size = &array_of_vecs.outer_array_size; + quote! { + [*mut [#inner_type]; #outer_array_size] + } + } + } + } + /// Generate the uninitialized allocation for the field. /// E.g. 
`Vec::with_capacity(len); vec.set_len(len);` for a `Vec` field. /// E.g. `[(); N].map(|_| { Vec::with_capacity(len); vec.set_len(len); })` for `[Vec; N]`. + /// Used to generate the `uninitialized` function. pub fn uninitialized_field_allocation(&self) -> TokenStream { match self { IterableField::PlainVec(plain_vec) => { @@ -88,6 +155,128 @@ impl IterableField { } } } + + /// Generate the code to split the first element from the field. + /// E.g. `let (head, tail) = self.field.split_at_mut(1); + /// self.field = tail; let field = &mut (*head)[0];` + /// Used for the `next` function in the iterator struct. + pub fn split_first(&self) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + let head = format_ident!("{}_head", name); + let tail = format_ident!("{}_tail", name); + quote! { + let (#head, #tail) = self.#name.split_at_mut(1); + self.#name = #tail; + let #name = &mut (*(#head))[0]; + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + quote! { + let #name = self.#name.each_mut().map(|v| { + let (head, tail) = v.split_at_mut(1); + *v = tail; + &mut (*head)[0] + }); + } + } + } + } + + /// Generate the code to split the last element from the field. + /// E.g. `let (head, tail) = self.field.split_at_mut(len - 1); + /// Used for the `next_back` function in the DoubleEnded impl. + pub fn split_last(&self, length: &Ident) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + let head = format_ident!("{}_head", name); + let tail = format_ident!("{}_tail", name); + quote! { + let ( + #head, + #tail, + ) = self.#name.split_at_mut(#length - 1); + self.#name = #head; + let #name = &mut (*#tail)[0]; + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + quote! { + let #name = self.#name.each_mut().map(|v| { + let (head, tail) = v.split_at_mut(#length - 1); + *v = head; + &mut (*tail)[0] + }); + } + } + } + } + + /// Generate the code to get a mutable slice of the field. + /// E.g. `self.field.as_mut_slice()` + /// E.g. `self.field.each_mut().map(|v| v.as_mut_slice())` + /// Used to generate the arguments for the IterMut 'new' function call. + pub fn as_mut_slice(&self) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + quote! { + #name.as_mut_slice() + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + quote! { + #name.each_mut().map(|v| v.as_mut_slice()) + } + } + } + } + + /// Generate the code to get a mutable pointer a mutable slice of the field. + /// E.g. `'a mut [u32]` -> `*mut [u32]`. Achieved by casting: `as *mut _`. + /// Used for the `IterMut` struct. + pub fn as_mut_ptr(&self) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + quote! { + #name as *mut _ + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + quote! { + #name.map(|v| v as *mut _) + } + } + } + } + + /// Generate the code to get the length of the field. + /// Length is assumed to be the same for all fields on every coordinate. + /// E.g. `self.field.len()` + /// E.g. `self.field[0].len()` + pub fn get_len(&self) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + quote! { + #name.len() + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + quote! 
{ + #name[0].len() + } + } + } + } } // Extract the inner vector type from a path. diff --git a/crates/air_utils_derive/src/lib.rs b/crates/air_utils_derive/src/lib.rs index a43b1dc9a..d8ee8178a 100644 --- a/crates/air_utils_derive/src/lib.rs +++ b/crates/air_utils_derive/src/lib.rs @@ -1,4 +1,5 @@ mod allocation; +mod iter_mut; mod iterable_field; use iterable_field::to_iterable_fields; use syn::{parse_macro_input, DeriveInput}; @@ -15,3 +16,16 @@ pub fn derive_uninitialized(input: proc_macro::TokenStream) -> proc_macro::Token allocation::expand_uninitialized_impl(&struct_name, &iterable_fields).into() } + +#[proc_macro_derive(IterMut)] +pub fn derive_mut_iter(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let struct_name = input.ident.clone(); + + let iterable_fields = match to_iterable_fields(input) { + Ok(iterable_fields) => iterable_fields, + Err(err) => return err.into_compile_error().into(), + }; + + iter_mut::expand_iter_mut_structs(&struct_name, &iterable_fields).into() +} From 3b018eab5f24ec653d4a9d6f09bd59eb3f32f1e8 Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Tue, 31 Dec 2024 13:45:25 +0200 Subject: [PATCH 48/69] derive par iter (#955) --- crates/air_utils/src/lookup_data/mod.rs | 73 +++++++- crates/air_utils_derive/src/iterable_field.rs | 37 ++++ crates/air_utils_derive/src/lib.rs | 14 ++ crates/air_utils_derive/src/par_iter.rs | 163 ++++++++++++++++++ 4 files changed, 284 insertions(+), 3 deletions(-) create mode 100644 crates/air_utils_derive/src/par_iter.rs diff --git a/crates/air_utils/src/lookup_data/mod.rs b/crates/air_utils/src/lookup_data/mod.rs index 2453994fe..a8be8a495 100644 --- a/crates/air_utils/src/lookup_data/mod.rs +++ b/crates/air_utils/src/lookup_data/mod.rs @@ -1,7 +1,9 @@ #[cfg(test)] mod tests { use itertools::{all, Itertools}; - use stwo_air_utils_derive::{IterMut, Uninitialized}; + use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; + use rayon::slice::ParallelSlice; + use stwo_air_utils_derive::{IterMut, ParMutIter, Uninitialized}; use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; use stwo_prover::core::fields::m31::M31; @@ -9,7 +11,7 @@ mod tests { /// Lookup data for the example ComponentTrace. /// Vectors are assumed to be of the same length. 
- #[derive(Uninitialized, IterMut)] + #[derive(Uninitialized, IterMut, ParMutIter)] struct LookupData { field0: Vec, field1: Vec<[PackedM31; 2]>, @@ -35,7 +37,7 @@ mod tests { } #[test] - fn test_derived_lookup_data_iter_mut() { + fn test_derived_lookup_data_iter() { const N_COLUMNS: usize = 5; const LOG_N_ROWS: u32 = 8; let mut trace = ComponentTrace::::zeroed(LOG_N_ROWS); @@ -98,4 +100,69 @@ mod tests { "Failed on [Vec<[PackedM31; 2]>; 2]" ); } + + #[test] + fn test_derived_lookup_data_par_iter() { + const N_COLUMNS: usize = 5; + const LOG_N_ROWS: u32 = 8; + let mut trace = ComponentTrace::::zeroed(LOG_N_ROWS); + let arr = (0..1 << LOG_N_ROWS).map(M31::from).collect_vec(); + let mut lookup_data = unsafe { LookupData::uninitialized(LOG_N_ROWS - LOG_N_LANES) }; + let (expected_field0, expected_field1, expected_field2): ( + Vec<_>, + Vec<_>, + (Vec<_>, Vec<_>), + ) = arr + .array_chunks::() + .map(|x| { + let x = PackedM31::from_array(*x); + let x1 = x + PackedM31::broadcast(M31(1)); + let x2 = x + x1; + let x3 = x + x1 + x2; + let x4 = x + x1 + x2 + x3; + ( + x4, + [x1, x1.double()], + ([x2, x2.double()], [x3, x3.double()]), + ) + }) + .multiunzip(); + + trace + .par_iter_mut() + .zip(arr.par_chunks(N_LANES).into_par_iter()) + .zip(lookup_data.par_iter_mut()) + .for_each(|((row, input), lookup_data)| { + *row[0] = PackedM31::from_array(input.try_into().unwrap()); + *row[1] = *row[0] + PackedM31::broadcast(M31(1)); + *row[2] = *row[0] + *row[1]; + *row[3] = *row[0] + *row[1] + *row[2]; + *row[4] = *row[0] + *row[1] + *row[2] + *row[3]; + *lookup_data.field0 = *row[4]; + *lookup_data.field1 = [*row[1], row[1].double()]; + *lookup_data.field2[0] = [*row[2], row[2].double()]; + *lookup_data.field2[1] = [*row[3], row[3].double()]; + }); + let (actual0, actual1, actual2) = ( + lookup_data.field0, + lookup_data.field1, + (lookup_data.field2[0].clone(), lookup_data.field2[1].clone()), + ); + + assert_eq!( + format!("{expected_field0:?}"), + format!("{actual0:?}"), + "Failed on Vec" + ); + assert_eq!( + format!("{expected_field1:?}"), + format!("{actual1:?}"), + "Failed on Vec<[PackedM31; 2]>" + ); + assert_eq!( + format!("{expected_field2:?}"), + format!("{actual2:?}"), + "Failed on [Vec<[PackedM31; 2]>; 2]" + ); + } } diff --git a/crates/air_utils_derive/src/iterable_field.rs b/crates/air_utils_derive/src/iterable_field.rs index ecfe3f92a..e78af08b7 100644 --- a/crates/air_utils_derive/src/iterable_field.rs +++ b/crates/air_utils_derive/src/iterable_field.rs @@ -216,6 +216,43 @@ impl IterableField { } } + /// Generate the code to split the field at a given index. + /// E.g. `let (head, tail) = self.field.split_at_mut(index);` + /// E.g. `let (head, tail) = self.field.each_mut().map(|v| v.split_at_mut(index));` + /// Used for the `split_at` function in the Producer impl. + pub fn split_at(&self, index: &Ident) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + let head = format_ident!("{}_head", name); + let tail = format_ident!("{}_tail", name); + quote! { + let ( + #head, + #tail + ) = self.#name.split_at_mut(#index); + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + let head = format_ident!("{}_head", name); + let tail = format_ident!("{}_tail", name); + let array_size = &array_of_vecs.outer_array_size; + quote! 
{ + let ( + mut #head, + mut #tail + ):([_; #array_size],[_; #array_size]) = unsafe { (std::mem::zeroed(), std::mem::zeroed()) }; + self.#name.into_iter().enumerate().for_each(|(i, v)| { + let (head, tail) = v.split_at_mut(#index); + #head[i] = head; + #tail[i] = tail; + }); + } + } + } + } + /// Generate the code to get a mutable slice of the field. /// E.g. `self.field.as_mut_slice()` /// E.g. `self.field.each_mut().map(|v| v.as_mut_slice())` diff --git a/crates/air_utils_derive/src/lib.rs b/crates/air_utils_derive/src/lib.rs index d8ee8178a..081b389e6 100644 --- a/crates/air_utils_derive/src/lib.rs +++ b/crates/air_utils_derive/src/lib.rs @@ -1,6 +1,7 @@ mod allocation; mod iter_mut; mod iterable_field; +mod par_iter; use iterable_field::to_iterable_fields; use syn::{parse_macro_input, DeriveInput}; @@ -29,3 +30,16 @@ pub fn derive_mut_iter(input: proc_macro::TokenStream) -> proc_macro::TokenStrea iter_mut::expand_iter_mut_structs(&struct_name, &iterable_fields).into() } + +#[proc_macro_derive(ParMutIter)] +pub fn derive_par_mut_iter(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let struct_name = input.ident.clone(); + + let iterable_fields = match to_iterable_fields(input) { + Ok(iterable_fields) => iterable_fields, + Err(err) => return err.into_compile_error().into(), + }; + + par_iter::expand_par_iter_mut_structs(&struct_name, &iterable_fields).into() +} diff --git a/crates/air_utils_derive/src/par_iter.rs b/crates/air_utils_derive/src/par_iter.rs new file mode 100644 index 000000000..e7e2b8947 --- /dev/null +++ b/crates/air_utils_derive/src/par_iter.rs @@ -0,0 +1,163 @@ +use itertools::Itertools; +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote}; +use syn::{Ident, Lifetime}; + +use crate::iterable_field::IterableField; + +pub fn expand_par_iter_mut_structs( + struct_name: &Ident, + iterable_fields: &[IterableField], +) -> TokenStream { + let lifetime = Lifetime::new("'a", Span::call_site()); + let split_index = format_ident!("index"); + + let struct_code = generate_struct_impl(struct_name, iterable_fields); + let producer_code = + generate_row_producer(struct_name, iterable_fields, &lifetime, &split_index); + let oar_iter_struct = generate_par_iter_struct(struct_name, iterable_fields, &lifetime); + let impl_par_iter = generate_parallel_iterator_impls(struct_name, iterable_fields, &lifetime); + + quote! { + #struct_code + #producer_code + #oar_iter_struct + #impl_par_iter + } +} + +fn generate_struct_impl(struct_name: &Ident, iterable_fields: &[IterableField]) -> TokenStream { + let par_iter_mut_name = format_ident!("{}ParIterMut", struct_name); + let as_mut_slice = iterable_fields.iter().map(|f| f.as_mut_slice()); + quote! 
{ + impl #struct_name { + pub fn par_iter_mut(&mut self) -> #par_iter_mut_name<'_> { + #par_iter_mut_name::new( + #(self.#as_mut_slice,)* + ) + } + } + } +} + +fn generate_row_producer( + struct_name: &Ident, + iterable_fields: &[IterableField], + lifetime: &Lifetime, + split_index: &Ident, +) -> TokenStream { + let row_producer_name = format_ident!("{}RowProducer", struct_name); + let mut_chunk_name = format_ident!("{}MutChunk", struct_name); + let iter_mut_name = format_ident!("{}IterMut", struct_name); + let (field_names, mut_slice_types, split_at): (Vec<_>, Vec<_>, Vec<_>) = iterable_fields + .iter() + .map(|f| { + ( + f.name(), + f.mut_slice_type(lifetime), + f.split_at(split_index), + ) + }) + .multiunzip(); + let field_names_head = field_names.iter().map(|f| format_ident!("{}_head", f)); + let field_names_tail = field_names.iter().map(|f| format_ident!("{}_tail", f)); + quote! { + pub struct #row_producer_name<#lifetime> { + #(#field_names: #mut_slice_types,)* + } + impl<#lifetime> rayon::iter::plumbing::Producer for #row_producer_name<#lifetime> { + type Item = #mut_chunk_name<#lifetime>; + type IntoIter = #iter_mut_name<#lifetime>; + + #[allow(invalid_value)] + fn split_at(self, index: usize) -> (Self, Self) { + #(#split_at)* + ( + #row_producer_name { + #(#field_names: #field_names_head,)* + }, + #row_producer_name { + #(#field_names: #field_names_tail,)* + } + ) + } + + fn into_iter(self) -> Self::IntoIter { + #iter_mut_name::new(#(self.#field_names),*) + } + } + } +} + +fn generate_par_iter_struct( + struct_name: &Ident, + iterable_fields: &[IterableField], + lifetime: &Lifetime, +) -> TokenStream { + let par_iter_mut_name = format_ident!("{struct_name}ParIterMut"); + let (field_names, mut_slice_types): (Vec<_>, Vec<_>) = iterable_fields + .iter() + .map(|f| (f.name(), f.mut_slice_type(lifetime))) + .unzip(); + quote! { + pub struct #par_iter_mut_name<#lifetime> { + #(#field_names: #mut_slice_types,)* + } + + impl<#lifetime> #par_iter_mut_name<#lifetime> { + pub fn new( + #(#field_names: #mut_slice_types,)* + ) -> Self { + Self { + #(#field_names,)* + } + } + } + } +} + +fn generate_parallel_iterator_impls( + struct_name: &Ident, + iterable_fields: &[IterableField], + lifetime: &Lifetime, +) -> TokenStream { + let par_iter_mut_name = format_ident!("{}ParIterMut", struct_name); + let mut_chunk_name = format_ident!("{}MutChunk", struct_name); + let row_producer_name = format_ident!("{}RowProducer", struct_name); + let field_names = iterable_fields.iter().map(|f| f.name()); + let get_length = iterable_fields.first().unwrap().get_len(); + quote! 
{ + impl<#lifetime> rayon::prelude::ParallelIterator for #par_iter_mut_name<#lifetime> { + type Item = #mut_chunk_name<#lifetime>; + + fn drive_unindexed(self, consumer: D) -> D::Result + where + D: rayon::iter::plumbing::UnindexedConsumer, + { + rayon::iter::plumbing::bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } + } + + impl rayon::iter::IndexedParallelIterator for #par_iter_mut_name<'_> { + fn len(&self) -> usize { + self.#get_length + } + + fn drive>(self, consumer: D) -> D::Result { + rayon::iter::plumbing::bridge(self, consumer) + } + + fn with_producer>(self, callback: CB) -> CB::Output { + callback.callback( + #row_producer_name { + #(#field_names : self.#field_names,)* + } + ) + } + } + } +} From c5c6eda01db70a43bcfa04b358c246522150c567 Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Tue, 31 Dec 2024 15:42:10 +0200 Subject: [PATCH 49/69] rename par iter (#957) --- crates/air_utils/src/lookup_data/mod.rs | 4 ++-- crates/air_utils_derive/src/iterable_field.rs | 2 +- crates/air_utils_derive/src/lib.rs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/air_utils/src/lookup_data/mod.rs b/crates/air_utils/src/lookup_data/mod.rs index a8be8a495..234e4be44 100644 --- a/crates/air_utils/src/lookup_data/mod.rs +++ b/crates/air_utils/src/lookup_data/mod.rs @@ -3,7 +3,7 @@ mod tests { use itertools::{all, Itertools}; use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; use rayon::slice::ParallelSlice; - use stwo_air_utils_derive::{IterMut, ParMutIter, Uninitialized}; + use stwo_air_utils_derive::{IterMut, ParIterMut, Uninitialized}; use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; use stwo_prover::core::fields::m31::M31; @@ -11,7 +11,7 @@ mod tests { /// Lookup data for the example ComponentTrace. /// Vectors are assumed to be of the same length. - #[derive(Uninitialized, IterMut, ParMutIter)] + #[derive(Uninitialized, IterMut, ParIterMut)] struct LookupData { field0: Vec, field1: Vec<[PackedM31; 2]>, diff --git a/crates/air_utils_derive/src/iterable_field.rs b/crates/air_utils_derive/src/iterable_field.rs index e78af08b7..6cb80ea15 100644 --- a/crates/air_utils_derive/src/iterable_field.rs +++ b/crates/air_utils_derive/src/iterable_field.rs @@ -3,7 +3,7 @@ use quote::{format_ident, quote}; use syn::{Data, DeriveInput, Expr, Field, Fields, Ident, Lifetime, Type}; /// Each variant represents a field that can be iterated over. -/// Used to derive implementations of `Uninitialized`, `MutIter`, and `ParMutIter`. +/// Used to derive implementations of `Uninitialized`, `MutIter`, and `ParIterMut`. /// Currently supports `Vec` and `[Vec; N]` fields only. pub(super) enum IterableField { /// A single Vec field, e.g. `Vec`, `Vec<[u32; K]>`. 
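Note: taken together, the `Uninitialized`, `IterMut`, and `ParIterMut` derives give a
lookup-data struct an allocate-then-fill workflow. A minimal usage sketch, assuming the
`LookupData` test struct above (not a library type), a `log_size` variable, and the rayon
prelude:

    use rayon::prelude::*;

    // Allocate uninitialized rows, then populate every field before reading them back.
    let mut data = unsafe { LookupData::uninitialized(log_size) };
    data.par_iter_mut().enumerate().for_each(|(i, chunk)| {
        // `chunk.field0` is a `&mut PackedM31` for one SIMD row.
        *chunk.field0 = PackedM31::broadcast(M31(i as u32));
        // `field1` / `field2` would be filled the same way before the data is used.
    });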
diff --git a/crates/air_utils_derive/src/lib.rs b/crates/air_utils_derive/src/lib.rs index 081b389e6..33bd91bf7 100644 --- a/crates/air_utils_derive/src/lib.rs +++ b/crates/air_utils_derive/src/lib.rs @@ -31,7 +31,7 @@ pub fn derive_mut_iter(input: proc_macro::TokenStream) -> proc_macro::TokenStrea iter_mut::expand_iter_mut_structs(&struct_name, &iterable_fields).into() } -#[proc_macro_derive(ParMutIter)] +#[proc_macro_derive(ParIterMut)] pub fn derive_par_mut_iter(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); let struct_name = input.ident.clone(); From 9d3afa43d724577ae5d50da90992f55382755626 Mon Sep 17 00:00:00 2001 From: shaharsamocha7 <70577611+shaharsamocha7@users.noreply.github.com> Date: Tue, 31 Dec 2024 16:11:20 +0200 Subject: [PATCH 50/69] log_size to preprocessed columns (#958) --- .../src/constraint_framework/preprocessed_columns.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/crates/prover/src/constraint_framework/preprocessed_columns.rs b/crates/prover/src/constraint_framework/preprocessed_columns.rs index 32a234a38..5c2a52df7 100644 --- a/crates/prover/src/constraint_framework/preprocessed_columns.rs +++ b/crates/prover/src/constraint_framework/preprocessed_columns.rs @@ -37,6 +37,15 @@ impl PreprocessedColumn { } } + pub fn log_size(&self) -> u32 { + match self { + PreprocessedColumn::IsFirst(log_size) => *log_size, + PreprocessedColumn::Seq(log_size) => *log_size, + PreprocessedColumn::XorTable(log_size, ..) => *log_size, + PreprocessedColumn::Plonk(_) => unimplemented!(), + } + } + /// Returns the values of the column at the given row. pub fn packed_at(&self, vec_row: usize) -> PackedM31 { match self { From bede666c0aa248b2f6b375d82c38f7251da0d591 Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Thu, 2 Jan 2025 13:00:24 +0200 Subject: [PATCH 51/69] update toolchain 20250102 (#959) --- .github/workflows/benchmarks-pages.yaml | 2 +- .github/workflows/ci.yaml | 40 ++++++++++++------------- .github/workflows/coverage.yaml | 4 +-- rust-toolchain.toml | 2 +- scripts/clippy.sh | 2 +- scripts/rust_fmt.sh | 2 +- scripts/test_avx.sh | 2 +- 7 files changed, 27 insertions(+), 27 deletions(-) diff --git a/.github/workflows/benchmarks-pages.yaml b/.github/workflows/benchmarks-pages.yaml index ef8269911..fc87b2d94 100644 --- a/.github/workflows/benchmarks-pages.yaml +++ b/.github/workflows/benchmarks-pages.yaml @@ -18,7 +18,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 - name: Run benchmark run: ./scripts/bench.sh -- --output-format bencher | tee output.txt - name: Download previous benchmark data diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a37679771..4d2bb45fa 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,7 +25,7 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: rustfmt - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - run: scripts/rust_fmt.sh --check @@ -36,7 +36,7 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: clippy - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - run: scripts/clippy.sh @@ -46,9 +46,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 
- - run: cargo +nightly-2024-12-17 doc + - run: cargo +nightly-2025-01-02 doc run-wasm32-wasip1-tests: runs-on: ubuntu-latest @@ -56,7 +56,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 targets: wasm32-wasip1 - uses: taiki-e/install-action@v2 with: @@ -73,7 +73,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 targets: wasm32-unknown-unknown - uses: Swatinem/rust-cache@v2 - uses: jetli/wasm-pack-action@v0.4.0 @@ -89,9 +89,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-12-17 test + - run: cargo +nightly-2025-01-02 test env: RUSTFLAGS: -C target-feature=+neon @@ -104,9 +104,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-12-17 test + - run: cargo +nightly-2025-01-02 test env: RUSTFLAGS: -C target-cpu=native -C target-feature=+${{ matrix.target-feature }} @@ -116,7 +116,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 - name: Run benchmark run: ./scripts/bench.sh -- --output-format bencher | tee output.txt - name: Download previous benchmark data @@ -142,7 +142,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 - name: Run benchmark run: ./scripts/bench.sh --features="parallel" -- --output-format bencher | tee output.txt - name: Download previous benchmark data @@ -168,9 +168,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-12-17 test + - run: cargo +nightly-2025-01-02 test run-slow-tests: runs-on: ubuntu-latest @@ -178,9 +178,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-12-17 test --release --features="slow-tests" + - run: cargo +nightly-2025-01-02 test --release --features="slow-tests" run-tests-parallel: runs-on: ubuntu-latest @@ -188,9 +188,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-12-17 test --features="parallel" + - run: cargo +nightly-2025-01-02 test --features="parallel" machete: runs-on: ubuntu-latest @@ -201,9 +201,9 @@ jobs: toolchain: nightly-2024-01-04 - uses: Swatinem/rust-cache@v2 - name: Install Machete - run: cargo +nightly-2024-12-17 install --locked cargo-machete + run: cargo +nightly-2025-01-02 install --locked cargo-machete - name: Run Machete (detect unused dependencies) - run: cargo +nightly-2024-12-17 machete + run: cargo +nightly-2025-01-02 machete all-tests: runs-on: ubuntu-latest diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index 34b92be13..05a1482cb 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -12,14 +12,14 @@ jobs: - 
uses: dtolnay/rust-toolchain@master with: components: rustfmt - toolchain: nightly-2024-12-17 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov # TODO: Merge coverage reports for tests on different architectures. # - name: Generate code coverage - run: cargo +nightly-2024-12-17 llvm-cov --codecov --output-path codecov.json + run: cargo +nightly-2025-01-02 llvm-cov --codecov --output-path codecov.json env: RUSTFLAGS: "-C target-feature=+avx512f" - name: Upload coverage to Codecov diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 690b698f9..27381425f 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-12-17" +channel = "nightly-2025-01-02" diff --git a/scripts/clippy.sh b/scripts/clippy.sh index a3f74f4b8..9198c8648 100755 --- a/scripts/clippy.sh +++ b/scripts/clippy.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly-2024-12-17 clippy "$@" --all-targets --all-features -- -D warnings -D future-incompatible \ +cargo +nightly-2025-01-02 clippy "$@" --all-targets --all-features -- -D warnings -D future-incompatible \ -D nonstandard-style -D rust-2018-idioms -D unused diff --git a/scripts/rust_fmt.sh b/scripts/rust_fmt.sh index ae4a9f7f8..9f80b191c 100755 --- a/scripts/rust_fmt.sh +++ b/scripts/rust_fmt.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly-2024-12-17 fmt --all -- "$@" +cargo +nightly-2025-01-02 fmt --all -- "$@" diff --git a/scripts/test_avx.sh b/scripts/test_avx.sh index cb0ac2445..eb4429d3a 100755 --- a/scripts/test_avx.sh +++ b/scripts/test_avx.sh @@ -1,4 +1,4 @@ #!/bin/bash # Can be used as a drop in replacement for `cargo test` with avx512f flag on. # For example, `./scripts/test_avx.sh` will run all tests(not only avx). -RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo +nightly-2024-12-17 test "$@" +RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo +nightly-2025-01-02 test "$@" From 36b17ac3b3d66106dbe5268305a5812bcd490e5c Mon Sep 17 00:00:00 2001 From: ilyalesokhin-starkware Date: Thu, 2 Jan 2025 17:15:40 +0200 Subject: [PATCH 52/69] Fix mask so that -1 appears before 0. (#960) --- crates/prover/src/constraint_framework/mod.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 37baa6167..22809152d 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -231,10 +231,10 @@ macro_rules! logup_proxy { // offset from the is_first column when constant columns are supported. let (cur_cumsum, prev_row_cumsum) = match self.logup.claimed_sum.clone() { Some((claimed_sum, claimed_row_index)) => { - let [cur_cumsum, prev_row_cumsum, claimed_cumsum] = self + let [prev_row_cumsum, cur_cumsum, claimed_cumsum] = self .next_extension_interaction_mask( self.logup.interaction, - [0, -1, claimed_row_index as isize], + [-1, 0, claimed_row_index as isize], ); // Constrain that the claimed_sum in case that it is not equal to the total_sum. @@ -244,8 +244,8 @@ macro_rules! 
logup_proxy { (cur_cumsum, prev_row_cumsum) } None => { - let [cur_cumsum, prev_row_cumsum] = - self.next_extension_interaction_mask(self.logup.interaction, [0, -1]); + let [prev_row_cumsum, cur_cumsum] = + self.next_extension_interaction_mask(self.logup.interaction, [-1, 0]); (cur_cumsum, prev_row_cumsum) } }; From 95def20b6bd9fc72606b842a38ca5b737371ac45 Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Mon, 6 Jan 2025 11:55:53 +0200 Subject: [PATCH 53/69] clone for stark proofs (#961) --- crates/prover/src/core/fri.rs | 4 ++-- crates/prover/src/core/pcs/prover.rs | 2 +- crates/prover/src/core/prover/mod.rs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/prover/src/core/fri.rs b/crates/prover/src/core/fri.rs index d3684c1a8..6bcb3cb2a 100644 --- a/crates/prover/src/core/fri.rs +++ b/crates/prover/src/core/fri.rs @@ -639,7 +639,7 @@ impl LinePolyDegreeBound { } /// A FRI proof. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct FriProof { pub first_layer: FriLayerProof, pub inner_layers: Vec>, @@ -654,7 +654,7 @@ pub const FOLD_STEP: u32 = 1; pub const CIRCLE_TO_LINE_FOLD_STEP: u32 = 1; /// Proof of an individual FRI layer. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct FriLayerProof { /// Values that the verifier needs but cannot deduce from previous computations, in the /// order they are needed. This complements the values that were queried. These must be diff --git a/crates/prover/src/core/pcs/prover.rs b/crates/prover/src/core/pcs/prover.rs index ae017c0e5..d60c82bb1 100644 --- a/crates/prover/src/core/pcs/prover.rs +++ b/crates/prover/src/core/pcs/prover.rs @@ -147,7 +147,7 @@ impl<'a, B: BackendForChannel, MC: MerkleChannel> CommitmentSchemeProver<'a, } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct CommitmentSchemeProof { pub commitments: TreeVec, pub sampled_values: TreeVec>>, diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index 20e70c533..8a04432b1 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -155,7 +155,7 @@ pub enum VerificationError { ProofOfWork, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct StarkProof(pub CommitmentSchemeProof); impl StarkProof { From af5475cb946a29de7d04dfa8c77a17034f888df4 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Tue, 7 Jan 2025 00:49:11 -0500 Subject: [PATCH 54/69] Update formatted expressions (#949) --- .../constraint_framework/expr/evaluator.rs | 18 ++++++------ .../src/constraint_framework/expr/format.rs | 23 +++++++-------- .../preprocessed_columns.rs | 8 +++--- .../prover/src/examples/state_machine/mod.rs | 28 +++++++++---------- 4 files changed, 39 insertions(+), 38 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr/evaluator.rs b/crates/prover/src/constraint_framework/expr/evaluator.rs index 7fef20254..6b9238d23 100644 --- a/crates/prover/src/constraint_framework/expr/evaluator.rs +++ b/crates/prover/src/constraint_framework/expr/evaluator.rs @@ -194,23 +194,23 @@ mod tests { fn test_expr_evaluator() { let test_struct = TestStruct {}; let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - let expected = "let intermediate0 = (col_1_1[0]) * (col_1_2[0]); + let expected = "let intermediate0 = (trace_1_column_1_offset_0) * 
(trace_1_column_2_offset_0); \ - let intermediate1 = (TestRelation_alpha0) * (col_1_0[0]) \ - + (TestRelation_alpha1) * (col_1_1[0]) \ - + (TestRelation_alpha2) * (col_1_2[0]) \ + let intermediate1 = (TestRelation_alpha0) * (trace_1_column_0_offset_0) \ + + (TestRelation_alpha1) * (trace_1_column_1_offset_0) \ + + (TestRelation_alpha2) * (trace_1_column_2_offset_0) \ - (TestRelation_z); \ - let constraint_0 = ((col_1_0[0]) * (intermediate0)) * (1 / (col_1_0[0] + col_1_1[0])); + let constraint_0 = ((trace_1_column_0_offset_0) * (intermediate0)) * (1 / (trace_1_column_0_offset_0 + trace_1_column_1_offset_0)); \ - let constraint_1 = (SecureCol(col_2_3[0], col_2_4[0], col_2_5[0], col_2_6[0]) \ - - (SecureCol(col_2_3[-1], col_2_4[-1], col_2_5[-1], col_2_6[-1]) \ - - ((total_sum) * (preprocessed.is_first)))) \ + let constraint_1 = (QM31Impl::from_partial_evals([trace_2_column_3_offset_0, trace_2_column_4_offset_0, trace_2_column_5_offset_0, trace_2_column_6_offset_0]) \ + - (QM31Impl::from_partial_evals([trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1, trace_2_column_6_offset_neg_1]) \ + - ((total_sum) * (preprocessed_is_first)))) \ * (intermediate1) \ - - (1);" + - (qm31(1, 0, 0, 0));" .to_string(); assert_eq!(eval.format_constraints(), expected); diff --git a/crates/prover/src/constraint_framework/expr/format.rs b/crates/prover/src/constraint_framework/expr/format.rs index f6d30163d..f286f1135 100644 --- a/crates/prover/src/constraint_framework/expr/format.rs +++ b/crates/prover/src/constraint_framework/expr/format.rs @@ -11,13 +11,18 @@ impl BaseExpr { offset, }) => { let offset_str = if *offset == CLAIMED_SUM_DUMMY_OFFSET as isize { - "claimed_sum_offset".to_string() + "claimed_sum".to_string() } else { - offset.to_string() + let offset_abs = offset.abs(); + if *offset >= 0 { + offset.to_string() + } else { + format!("neg_{offset_abs}") + } }; - format!("col_{interaction}_{idx}[{offset_str}]") + format!("trace_{interaction}_column_{idx}_offset_{offset_str}") } - BaseExpr::Const(c) => c.to_string(), + BaseExpr::Const(c) => format!("m31({c}).into()"), BaseExpr::Param(v) => v.to_string(), BaseExpr::Add(a, b) => format!("{} + {}", a.format_expr(), b.format_expr()), BaseExpr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()), @@ -38,7 +43,7 @@ impl ExtExpr { a.format_expr() } else { format!( - "SecureCol({}, {}, {}, {})", + "QM31Impl::from_partial_evals([{}, {}, {}, {}])", a.format_expr(), b.format_expr(), c.format_expr(), @@ -47,12 +52,8 @@ impl ExtExpr { } } ExtExpr::Const(c) => { - if c.0 .1.is_zero() && c.1 .0.is_zero() && c.1 .1.is_zero() { - // If the constant is in the base field, display it as such. 
- c.0 .0.to_string() - } else { - c.to_string() - } + let [v0, v1, v2, v3] = c.to_m31_array(); + format!("qm31({v0}, {v1}, {v2}, {v3})") } ExtExpr::Param(v) => v.to_string(), ExtExpr::Add(a, b) => format!("{} + {}", a.format_expr(), b.format_expr()), diff --git a/crates/prover/src/constraint_framework/preprocessed_columns.rs b/crates/prover/src/constraint_framework/preprocessed_columns.rs index 5c2a52df7..a54ebc734 100644 --- a/crates/prover/src/constraint_framework/preprocessed_columns.rs +++ b/crates/prover/src/constraint_framework/preprocessed_columns.rs @@ -30,10 +30,10 @@ pub enum PreprocessedColumn { impl PreprocessedColumn { pub const fn name(&self) -> &'static str { match self { - PreprocessedColumn::IsFirst(_) => "preprocessed.is_first", - PreprocessedColumn::Plonk(_) => "preprocessed.plonk", - PreprocessedColumn::Seq(_) => "preprocessed.seq", - PreprocessedColumn::XorTable(..) => "preprocessed.xor_table", + PreprocessedColumn::IsFirst(_) => "preprocessed_is_first", + PreprocessedColumn::Plonk(_) => "preprocessed_plonk", + PreprocessedColumn::Seq(_) => "preprocessed_seq", + PreprocessedColumn::XorTable(..) => "preprocessed_xor_table", } } diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 684f04f76..5de104188 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -358,28 +358,28 @@ mod tests { ); let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true)); - let expected = "let intermediate0 = (StateMachineElements_alpha0) * (col_1_0[0]) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ + let expected = "let intermediate0 = (StateMachineElements_alpha0) * (trace_1_column_0_offset_0) \ + + (StateMachineElements_alpha1) * (trace_1_column_1_offset_0) \ - (StateMachineElements_z); \ - let intermediate1 = (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ + let intermediate1 = (StateMachineElements_alpha0) * (trace_1_column_0_offset_0 + m31(1).into()) \ + + (StateMachineElements_alpha1) * (trace_1_column_1_offset_0) \ - (StateMachineElements_z); \ - let constraint_0 = (SecureCol(\ - col_2_2[claimed_sum_offset], \ - col_2_3[claimed_sum_offset], \ - col_2_4[claimed_sum_offset], \ - col_2_5[claimed_sum_offset]\ - ) - (claimed_sum)) \ - * (preprocessed.is_first); + let constraint_0 = (QM31Impl::from_partial_evals([\ + trace_2_column_2_offset_claimed_sum, \ + trace_2_column_3_offset_claimed_sum, \ + trace_2_column_4_offset_claimed_sum, \ + trace_2_column_5_offset_claimed_sum\ + ]) - (claimed_sum)) \ + * (preprocessed_is_first); \ - let constraint_1 = (SecureCol(col_2_2[0], col_2_3[0], col_2_4[0], col_2_5[0]) \ - - (SecureCol(col_2_2[-1], col_2_3[-1], col_2_4[-1], col_2_5[-1]) \ - - ((total_sum) * (preprocessed.is_first)))\ + let constraint_1 = (QM31Impl::from_partial_evals([trace_2_column_2_offset_0, trace_2_column_3_offset_0, trace_2_column_4_offset_0, trace_2_column_5_offset_0]) \ + - (QM31Impl::from_partial_evals([trace_2_column_2_offset_neg_1, trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1]) \ + - ((total_sum) * (preprocessed_is_first)))\ ) \ * ((intermediate0) * (intermediate1)) \ - (intermediate1 - (intermediate0));" From b393d23a4f9d785792b4ec48affb41b21bff8c97 Mon Sep 17 00:00:00 2001 From: VitaliiH Date: Tue, 7 Jan 2025 14:27:44 +0100 Subject: [PATCH 55/69] merged latest icicle-backend in - full correctness --- Cargo.lock | 8 +- 
crates/prover/Cargo.toml | 14 +- .../src/constraint_framework/component.rs | 16 ++ crates/prover/src/core/air/components.rs | 12 +- crates/prover/src/core/backend/icicle/mod.rs | 138 +++++++++++++---- .../prover/src/core/backend/icicle/utils.rs | 9 ++ .../src/core/backend/simd/accumulation.rs | 2 + crates/prover/src/core/backend/simd/circle.rs | 33 +++- crates/prover/src/core/backend/simd/fri.rs | 24 ++- .../prover/src/core/backend/simd/quotients.rs | 14 +- crates/prover/src/core/fri.rs | 18 ++- crates/prover/src/core/pcs/prover.rs | 28 ++-- crates/prover/src/core/pcs/quotients.rs | 4 - crates/prover/src/core/prover/mod.rs | 14 +- .../prover/src/examples/wide_fibonacci/mod.rs | 142 +++++++++--------- 15 files changed, 338 insertions(+), 138 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3b357b2c9..19c067fda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -760,7 +760,7 @@ dependencies = [ [[package]] name = "icicle-core" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb97ee08ef26d4de6e0695df3484abeddd83c08c#eb97ee08ef26d4de6e0695df3484abeddd83c08c" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=cc4fc660ea331bab1fe725cd8a93b371bb59c62a#cc4fc660ea331bab1fe725cd8a93b371bb59c62a" dependencies = [ "criterion 0.3.6", "hex", @@ -771,7 +771,7 @@ dependencies = [ [[package]] name = "icicle-cuda-runtime" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb97ee08ef26d4de6e0695df3484abeddd83c08c#eb97ee08ef26d4de6e0695df3484abeddd83c08c" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=cc4fc660ea331bab1fe725cd8a93b371bb59c62a#cc4fc660ea331bab1fe725cd8a93b371bb59c62a" dependencies = [ "bindgen", "bitflags 1.3.2", @@ -780,7 +780,7 @@ dependencies = [ [[package]] name = "icicle-hash" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb97ee08ef26d4de6e0695df3484abeddd83c08c#eb97ee08ef26d4de6e0695df3484abeddd83c08c" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=cc4fc660ea331bab1fe725cd8a93b371bb59c62a#cc4fc660ea331bab1fe725cd8a93b371bb59c62a" dependencies = [ "cmake", "icicle-core", @@ -790,7 +790,7 @@ dependencies = [ [[package]] name = "icicle-m31" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=eb97ee08ef26d4de6e0695df3484abeddd83c08c#eb97ee08ef26d4de6e0695df3484abeddd83c08c" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=cc4fc660ea331bab1fe725cd8a93b371bb59c62a#cc4fc660ea331bab1fe725cd8a93b371bb59c62a" dependencies = [ "cmake", "criterion 0.3.6", diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index ca5930ddc..a3bd229d1 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -29,10 +29,16 @@ tracing.workspace = true rayon = { version = "1.10.0", optional = true } serde = { version = "1.0", features = ["derive"] } -icicle-cuda-runtime = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb97ee08ef26d4de6e0695df3484abeddd83c08c"} -icicle-core = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb97ee08ef26d4de6e0695df3484abeddd83c08c"} -icicle-m31 = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb97ee08ef26d4de6e0695df3484abeddd83c08c"} -icicle-hash = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="eb97ee08ef26d4de6e0695df3484abeddd83c08c"} +icicle-cuda-runtime = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="cc4fc660ea331bab1fe725cd8a93b371bb59c62a"} 
+icicle-core = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="cc4fc660ea331bab1fe725cd8a93b371bb59c62a"} +icicle-m31 = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="cc4fc660ea331bab1fe725cd8a93b371bb59c62a"} +icicle-hash = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="cc4fc660ea331bab1fe725cd8a93b371bb59c62a"} + +# icicle-cuda-runtime = { path = "/home/vhnat/repo/icicle/wrappers/rust/icicle-cuda-runtime", optional = true} +# icicle-core = { path = "/home/vhnat/repo/icicle/wrappers/rust/icicle-core", optional = true} +# icicle-m31 = { path = "/home/vhnat/repo/icicle/wrappers/rust/icicle-fields/icicle-m31", optional = true} +# icicle-hash = { path = "/home/vhnat/repo/icicle/wrappers/rust/icicle-hash", optional = true} + nvtx = { version = "*", optional = true } diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index a1724830c..ab957fdac 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -527,26 +527,35 @@ impl ComponentProver for FrameworkCompon return; } + nvtx::range_push!("create eval domain"); let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); + nvtx::range_pop!(); + nvtx::range_push!("create trace domain"); let trace_domain = CanonicCoset::new(self.eval.log_size()); + nvtx::range_pop!(); + nvtx::range_push!("component_polys"); let mut component_polys = trace.polys.sub_tree(&self.trace_locations); component_polys[PREPROCESSED_TRACE_IDX] = self .preprocessed_column_indices .iter() .map(|idx| &trace.polys[PREPROCESSED_TRACE_IDX][*idx]) .collect(); + nvtx::range_pop!(); + nvtx::range_push!("component_evals"); let mut component_evals = trace.evals.sub_tree(&self.trace_locations); component_evals[PREPROCESSED_TRACE_IDX] = self .preprocessed_column_indices .iter() .map(|idx| &trace.evals[PREPROCESSED_TRACE_IDX][*idx]) .collect(); + nvtx::range_pop!(); // Extend trace if necessary. // TODO: Don't extend when eval_size < committed_size. Instead, pick a good // subdomain. (For larger blowup factors). + nvtx::range_push!("extend trace"); let need_to_extend = component_evals .iter() .flatten() @@ -562,22 +571,28 @@ impl ComponentProver for FrameworkCompon } else { component_evals.clone().map_cols(|c| Cow::Borrowed(*c)) }; + nvtx::range_pop!(); // Denom inverses. + nvtx::range_push!("denom inverses"); let log_expand = eval_domain.log_size() - trace_domain.log_size(); let mut denom_inv = (0..1 << log_expand) .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) .collect_vec(); utils::bit_reverse(&mut denom_inv); + nvtx::range_pop!(); // Accumulator. 
+ nvtx::range_push!("accum"); let [mut accum] = evaluation_accumulator.columns([(eval_domain.log_size(), self.n_constraints())]); accum.random_coeff_powers.reverse(); + nvtx::range_pop!(); let _span = span!(Level::INFO, "Constraint point-wise eval").entered(); let col = accum.col; + nvtx::range_push!("eval constr at row loop"); for row in 0..(1 << eval_domain.log_size()) { let trace_cols = trace.as_cols_ref().map_cols(|cow| match cow { Cow::Borrowed(borrowed) => *borrowed, @@ -600,6 +615,7 @@ impl ComponentProver for FrameworkCompon let denom_inv = denom_inv[row >> trace_domain.log_size()]; col.set(row, col.at(row) + row_res * denom_inv) } + nvtx::range_pop!(); accum.col = col; return; } diff --git a/crates/prover/src/core/air/components.rs b/crates/prover/src/core/air/components.rs index 3f9bf78ad..244bc0aac 100644 --- a/crates/prover/src/core/air/components.rs +++ b/crates/prover/src/core/air/components.rs @@ -124,15 +124,25 @@ impl<'a, B: Backend> ComponentProvers<'a, B> { random_coeff: SecureField, trace: &Trace<'_, B>, ) -> SecureCirclePoly { + nvtx::range_push!("total_constraints"); let total_constraints: usize = self.components.iter().map(|c| c.n_constraints()).sum(); + nvtx::range_pop!(); + nvtx::range_push!("accumulator"); let mut accumulator = DomainEvaluationAccumulator::new( random_coeff, self.components().composition_log_degree_bound(), total_constraints, ); + nvtx::range_pop!(); + nvtx::range_push!("eval_constr_quot_on_domain"); for component in &self.components { component.evaluate_constraint_quotients_on_domain(trace, &mut accumulator) } - accumulator.finalize() + nvtx::range_pop!(); + nvtx::range_push!("accum.finalize"); + let res = accumulator.finalize(); + nvtx::range_pop!(); + + res } } diff --git a/crates/prover/src/core/backend/icicle/mod.rs b/crates/prover/src/core/backend/icicle/mod.rs index ccee18192..af984c4a2 100644 --- a/crates/prover/src/core/backend/icicle/mod.rs +++ b/crates/prover/src/core/backend/icicle/mod.rs @@ -9,11 +9,12 @@ use std::mem::{size_of_val, transmute}; use icicle_core::tree::{merkle_tree_digests_len, TreeBuilderConfig}; use icicle_core::vec_ops::{accumulate_scalars, VecOpsConfig}; -use icicle_core::Matrix; +use icicle_core::{field::Field as IcicleField, Matrix}; use icicle_hash::blake2s::build_blake2s_mmcs; use icicle_m31::dcct::{evaluate, get_dcct_root_of_unity, initialize_dcct_domain, interpolate}; use icicle_m31::fri::{self, fold_circle_into_line, fold_circle_into_line_new, FriConfig}; use icicle_m31::quotient; +use icicle_m31::field::ScalarCfg; use itertools::Itertools; use serde::{Deserialize, Serialize}; use twiddles::TwiddleTree; @@ -124,9 +125,11 @@ impl AccumulationOps for IcicleBackend { let is_a_on_host = get_device_from_pointer(a_ptr).unwrap() == 18446744073709551614; let mut col_a; if is_a_on_host { + nvtx::range_push!("[ICICLE] convert + move"); col_a = DeviceVec::::cuda_malloc(n).unwrap(); d_a_slice = &mut col_a[..]; SecureColumnByCoords::convert_to_icicle(column, d_a_slice); + nvtx::range_pop!(); } else { let mut v_ptr = a_ptr as *mut QuarticExtensionField; let rr = unsafe { slice::from_raw_parts_mut(v_ptr, n) }; @@ -136,21 +139,27 @@ impl AccumulationOps for IcicleBackend { let mut d_b_slice; let mut col_b; if get_device_from_pointer(b_ptr).unwrap() == 18446744073709551614 { + nvtx::range_push!("[ICICLE] convert + move"); col_b = DeviceVec::::cuda_malloc(n).unwrap(); d_b_slice = &mut col_b[..]; SecureColumnByCoords::convert_to_icicle(other, d_b_slice); + nvtx::range_pop!(); } else { let mut v_ptr = b_ptr as *mut 
QuarticExtensionField; let rr = unsafe { slice::from_raw_parts_mut(v_ptr, n) }; d_b_slice = DeviceSlice::from_mut_slice(rr); } - + + nvtx::range_push!("[ICICLE] accum scalars"); accumulate_scalars(d_a_slice, d_b_slice, &cfg); + nvtx::range_pop!(); + nvtx::range_push!("[ICICLE] convert + move to SecureColumnByCoords"); let mut v_ptr = d_a_slice.as_mut_ptr() as *mut _; let d_slice = unsafe { slice::from_raw_parts_mut(v_ptr, secure_degree * n) }; let d_a_slice = DeviceSlice::from_mut_slice(d_slice); SecureColumnByCoords::convert_from_icicle(column, d_a_slice); + nvtx::range_pop!(); } } } @@ -167,6 +176,7 @@ impl MerkleOps for IcicleBackend { config.digest_elements = 32; config.sort_inputs = false; + nvtx::range_push!("[ICICLE] log_max"); let log_max = columns .iter() .sorted_by_key(|c| Reverse(c.len())) @@ -174,21 +184,29 @@ impl MerkleOps for IcicleBackend { .unwrap() .len() .ilog2(); + nvtx::range_pop!(); let mut matrices = vec![]; + nvtx::range_push!("[ICICLE] create matrix"); for col in columns.into_iter().sorted_by_key(|c| Reverse(c.len())) { matrices.push(Matrix::from_slice(col, 4, col.len())); } + nvtx::range_pop!(); + nvtx::range_push!("[ICICLE] merkle_tree_digests_len"); let digests_len = merkle_tree_digests_len(log_max as u32, 2, 32); + nvtx::range_pop!(); let mut digests = vec![0u8; digests_len]; let digests_slice = HostSlice::from_mut_slice(&mut digests); - + + nvtx::range_push!("[ICICLE] build_blake2s_mmcs"); build_blake2s_mmcs(&matrices, digests_slice, &config).unwrap(); + nvtx::range_pop!(); let mut digests: &[::Hash] = unsafe { std::mem::transmute(digests.as_mut_slice()) }; // Transmute digests into stwo format let mut layers = vec![]; let mut offset = 0usize; + nvtx::range_push!("[ICICLE] convert to CPU layer"); for log in 0..=log_max { let inv_log = log_max - log; let number_of_rows = 1 << inv_log; @@ -203,6 +221,7 @@ impl MerkleOps for IcicleBackend { } layers.reverse(); + nvtx::range_pop!(); layers } @@ -251,19 +270,27 @@ impl PolyOps for IcicleBackend { } let values = eval.values; + nvtx::range_push!("[ICICLE] get_dcct_root_of_unity"); let rou = get_dcct_root_of_unity(eval.domain.size() as _); + nvtx::range_pop!(); + + nvtx::range_push!("[ICICLE] initialize_dcct_domain"); initialize_dcct_domain(eval.domain.log_size(), rou, &DeviceContext::default()).unwrap(); + nvtx::range_pop!(); let mut evaluations = vec![ScalarField::zero(); values.len()]; let values: Vec = unsafe { transmute(values) }; let mut cfg = NTTConfig::default(); cfg.ordering = Ordering::kMN; + nvtx::range_push!("[ICICLE] interpolate"); interpolate( HostSlice::from_slice(&values), &cfg, HostSlice::from_mut_slice(&mut evaluations), ) .unwrap(); + nvtx::range_pop!(); + let values: Vec = unsafe { transmute(evaluations) }; CirclePoly::new(values) @@ -276,6 +303,7 @@ impl PolyOps for IcicleBackend { return poly.coeffs[0].into(); } // TODO: to gpu after correctness fix + nvtx::range_push!("[ICICLE] create mappings"); let mut mappings = vec![point.y]; let mut x = point.x; for _ in 1..poly.log_size() { @@ -283,8 +311,12 @@ impl PolyOps for IcicleBackend { x = CirclePoint::double_x(x); } mappings.reverse(); - - crate::core::backend::icicle::utils::fold(&poly.coeffs, &mappings) + nvtx::range_pop!(); + + nvtx::range_push!("[ICICLE] fold"); + let folded = crate::core::backend::icicle::utils::fold(&poly.coeffs, &mappings); + nvtx::range_pop!(); + folded } fn extend(poly: &CirclePoly, log_size: u32) -> CirclePoly { @@ -309,20 +341,25 @@ impl PolyOps for IcicleBackend { } let values = poly.extend(domain.log_size()).coeffs; - 
+ nvtx::range_push!("[ICICLE] get_dcct_root_of_unity"); let rou = get_dcct_root_of_unity(domain.size() as _); + nvtx::range_pop!(); + nvtx::range_push!("[ICICLE] initialize_dcct_domain"); initialize_dcct_domain(domain.log_size(), rou, &DeviceContext::default()).unwrap(); + nvtx::range_pop!(); let mut evaluations = vec![ScalarField::zero(); values.len()]; let values: Vec = unsafe { transmute(values) }; let mut cfg = NTTConfig::default(); cfg.ordering = Ordering::kNM; + nvtx::range_push!("[ICICLE] evaluate"); evaluate( HostSlice::from_slice(&values), &cfg, HostSlice::from_mut_slice(&mut evaluations), ) .unwrap(); + nvtx::range_pop!(); unsafe { transmute(IcicleCircleEvaluation::::new( domain, @@ -467,6 +504,7 @@ impl FriOps for IcicleBackend { let mut domain_vals = Vec::new(); let line_domain_log_size = domain.log_size(); + nvtx::range_push!("[ICICLE] calc domain values"); for i in 0..dom_vals_len { // TODO: on-device batch // TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer. @@ -477,21 +515,27 @@ impl FriOps for IcicleBackend { .0, )); } + nvtx::range_pop!(); + nvtx::range_push!("[ICICLE] domain values to device"); let domain_icicle_host = HostSlice::from_slice(domain_vals.as_slice()); let mut d_domain_icicle = DeviceVec::::cuda_malloc(dom_vals_len).unwrap(); d_domain_icicle.copy_from_host(domain_icicle_host).unwrap(); - + nvtx::range_pop!(); + + nvtx::range_push!("[ICICLE] domain evals convert + move"); let mut d_evals_icicle = DeviceVec::::cuda_malloc(length).unwrap(); SecureColumnByCoords::::convert_to_icicle( unsafe { transmute(&eval.values) }, &mut d_evals_icicle, ); + nvtx::range_pop!(); let mut d_folded_eval = DeviceVec::::cuda_malloc(dom_vals_len).unwrap(); let cfg = FriConfig::default(); let icicle_alpha = unsafe { transmute(alpha) }; + nvtx::range_push!("[ICICLE] fold_line"); let _ = fri::fold_line( &d_evals_icicle[..], &d_domain_icicle[..], @@ -500,14 +544,19 @@ impl FriOps for IcicleBackend { &cfg, ) .unwrap(); + nvtx::range_pop!(); + nvtx::range_push!("[ICICLE] convert to SecureColumnByCoords"); let mut folded_values = unsafe { SecureColumnByCoords::uninitialized(dom_vals_len) }; SecureColumnByCoords::::convert_from_icicle_q31( &mut folded_values, &mut d_folded_eval[..], ); - LineEvaluation::new(domain.double(), folded_values) + let line_eval = LineEvaluation::new(domain.double(), folded_values); + nvtx::range_pop!(); + + line_eval } fn fold_circle_into_line( @@ -549,9 +598,13 @@ impl FriOps for IcicleBackend { let mut d_evals_icicle = DeviceVec::::cuda_malloc(length).unwrap(); SecureColumnByCoords::convert_to_icicle(&src.values, &mut d_evals_icicle); + nvtx::range_pop!(); + + nvtx::range_push!("[ICICLE] d_folded_evals"); let mut d_folded_eval = - DeviceVec::::cuda_malloc(dom_vals_len).unwrap(); + DeviceVec::::cuda_malloc(dom_vals_len).unwrap(); SecureColumnByCoords::convert_to_icicle(&dst.values, &mut d_folded_eval); + nvtx::range_pop!(); let mut folded_eval_raw = vec![QuarticExtensionField::zero(); dom_vals_len]; let folded_eval = HostSlice::from_mut_slice(folded_eval_raw.as_mut_slice()); @@ -559,6 +612,7 @@ impl FriOps for IcicleBackend { let cfg = FriConfig::default(); let icicle_alpha = unsafe { transmute(alpha) }; + nvtx::range_push!("[ICICLE] fold circle"); let _ = fold_circle_into_line_new( &d_evals_icicle[..], domain.half_coset.initial_index.0 as _, @@ -568,10 +622,13 @@ impl FriOps for IcicleBackend { &cfg, ) .unwrap(); + nvtx::range_pop!(); d_folded_eval.copy_to_host(folded_eval).unwrap(); + nvtx::range_push!("[ICICLE] convert to 
SecureColumnByCoords"); SecureColumnByCoords::convert_from_icicle_q31(&mut dst.values, &mut d_folded_eval[..]); + nvtx::range_pop!(); } fn decompose( @@ -631,49 +688,66 @@ impl QuotientOps for IcicleBackend { // )) // } - let icicle_columns_raw = columns - .iter() - .flat_map(|x| x.iter().map(|&y| unsafe { transmute(y) })) - .collect_vec(); - let icicle_columns = HostSlice::from_slice(&icicle_columns_raw); + let total_columns_size = columns.iter().fold(0, |acc, column| acc + column.values.len()); + let mut icicle_device_columns = DeviceVec::cuda_malloc(total_columns_size).unwrap(); + let mut start = 0; + nvtx::range_push!("[ICICLE] columns to device"); + columns.iter().for_each(|column| { + let end = start + column.values.len(); + let device_slice = &mut icicle_device_columns[start..end]; + let transmuted: Vec> = unsafe { transmute(column.values.clone()) }; + device_slice.copy_from_host(&HostSlice::from_slice(&transmuted)); + start += column.values.len(); + }); + nvtx::range_pop!(); + + nvtx::range_push!("[ICICLE] column sample batch"); let icicle_sample_batches = sample_batches .into_iter() .map(|sample| { let (columns, values) = sample - .columns_and_values - .iter() - .map(|(index, value)| { - ((*index) as u32, unsafe { - transmute::(*value) + .columns_and_values + .iter() + .map(|(index, value)| { + ((*index) as u32, unsafe { + transmute::(*value) + }) }) - }) - .unzip(); - - quotient::ColumnSampleBatch { - point: unsafe { transmute(sample.point) }, - columns, - values, - } - }) - .collect_vec(); + .unzip(); + + quotient::ColumnSampleBatch { + point: unsafe { transmute(sample.point) }, + columns, + values, + } + }) + .collect_vec(); + nvtx::range_pop!(); + let mut icicle_result_raw = vec![QuarticExtensionField::zero(); domain.size()]; let icicle_result = HostSlice::from_mut_slice(icicle_result_raw.as_mut_slice()); let cfg = quotient::QuotientConfig::default(); - + + nvtx::range_push!("[ICICLE] accumulate_quotients_wrapped"); quotient::accumulate_quotients_wrapped( // domain.half_coset.initial_index.0 as u32, // domain.half_coset.step_size.0 as u32, domain.log_size() as u32, - icicle_columns, + &icicle_device_columns[..], unsafe { transmute(random_coeff) }, &icicle_sample_batches, icicle_result, &cfg, ); + nvtx::range_pop!(); // TODO: make it on cuda side + nvtx::range_push!("[ICICLE] res to SecureEvaluation"); let mut result = unsafe { SecureColumnByCoords::uninitialized(domain.size()) }; (0..domain.size()).for_each(|i| result.set(i, unsafe { transmute(icicle_result_raw[i]) })); - SecureEvaluation::new(domain, result) + let ret = SecureEvaluation::new(domain, result); + nvtx::range_pop!(); + + ret } } diff --git a/crates/prover/src/core/backend/icicle/utils.rs b/crates/prover/src/core/backend/icicle/utils.rs index 7de84d099..bdb28c63d 100644 --- a/crates/prover/src/core/backend/icicle/utils.rs +++ b/crates/prover/src/core/backend/icicle/utils.rs @@ -108,14 +108,23 @@ mod tests { // Initialize the `values` vector let mut values: Vec = Vec::with_capacity(values_length); + #[cfg(feature = "parallel")] use rayon::iter::IntoParallelIterator; + #[cfg(feature = "parallel")] use rayon::prelude::*; + #[cfg(feature = "parallel")] let values: Vec = (1..=values_length) .into_par_iter() .map(|i| M31(i as u32)) .collect(); + #[cfg(not(feature = "parallel"))] + let values: Vec = (1..=values_length) + .into_iter() + .map(|i| M31(i as u32)) + .collect(); + // Initialize the `folding_factors` vector let mut folding_factors = Vec::with_capacity(folding_factors_length); for i in 2..(2 + 
folding_factors_length) { diff --git a/crates/prover/src/core/backend/simd/accumulation.rs b/crates/prover/src/core/backend/simd/accumulation.rs index c9705df6b..bea476b56 100644 --- a/crates/prover/src/core/backend/simd/accumulation.rs +++ b/crates/prover/src/core/backend/simd/accumulation.rs @@ -4,9 +4,11 @@ use crate::core::fields::secure_column::SecureColumnByCoords; impl AccumulationOps for SimdBackend { fn accumulate(column: &mut SecureColumnByCoords, other: &SecureColumnByCoords) { + nvtx::range_push!("[SIMD] loop pack"); for i in 0..column.packed_len() { let res_coeff = unsafe { column.packed_at(i) + other.packed_at(i) }; unsafe { column.set_packed(i, res_coeff) }; } + nvtx::range_pop!(); } } diff --git a/crates/prover/src/core/backend/simd/circle.rs b/crates/prover/src/core/backend/simd/circle.rs index a20721a4f..5963c5e8d 100644 --- a/crates/prover/src/core/backend/simd/circle.rs +++ b/crates/prover/src/core/backend/simd/circle.rs @@ -154,9 +154,12 @@ impl PolyOps for SimdBackend { } let mut values = eval.values; + nvtx::range_push!("[SIMD] domain_line_twiddles_from_tree"); let twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles); - + nvtx::range_pop!(); + // Safe because [PackedBaseField] is aligned on 64 bytes. + nvtx::range_push!("[SIMD] ifft"); unsafe { ifft::ifft( transmute(values.data.as_mut_ptr()), @@ -164,10 +167,13 @@ impl PolyOps for SimdBackend { log_size as usize, ); } - + nvtx::range_pop!(); + // TODO(alont): Cache this inversion. + nvtx::range_push!("[SIMD] invert"); let inv = PackedBaseField::broadcast(BaseField::from(eval.domain.size()).inverse()); values.data.iter_mut().for_each(|x| *x *= inv); + nvtx::range_pop!(); CirclePoly::new(values) } @@ -179,18 +185,26 @@ impl PolyOps for SimdBackend { return slow_eval_at_point(poly, point); } + nvtx::range_push!("[SIMD] generate mappings"); let mappings = Self::generate_evaluation_mappings(point, poly.log_size()); + nvtx::range_pop!(); // 8 lowest mappings produce the first 2^8 twiddles. Separate to optimize each calculation. + nvtx::range_push!("[SIMD] twiddle_lows"); let (map_low, map_high) = mappings.split_at(4); let twiddle_lows = PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_low, i))); + nvtx::range_pop!(); + nvtx::range_push!("[SIMD] twiddle_mid"); let (map_mid, map_high) = map_high.split_at(4); let twiddle_mids = PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_mid, i))); + nvtx::range_pop!(); // Compute the high twiddle steps. + nvtx::range_push!("[SIMD] twiddle_steps"); let twiddle_steps = Self::twiddle_steps(map_high); + nvtx::range_pop!(); // Every twiddle is a product of mappings that correspond to '1's in the bit representation // of the current index. For every 2^n alligned chunk of 2^n elements, the twiddle @@ -212,10 +226,16 @@ impl PolyOps for SimdBackend { } // Advance twiddle high. 
+ nvtx::range_push!("[SIMD] advance_twiddle"); twiddle_high = Self::advance_twiddle(twiddle_high, &twiddle_steps, i); + nvtx::range_pop!(); } - (sum * twiddle_lows).pointwise_sum() + nvtx::range_push!("[SIMD] pointwise_sum"); + let pointwise_sum = (sum * twiddle_lows).pointwise_sum(); + nvtx::range_pop!(); + + pointwise_sum } fn extend(poly: &CirclePoly, log_size: u32) -> CirclePoly { @@ -245,7 +265,9 @@ impl PolyOps for SimdBackend { ); } + nvtx::range_push!("[SIMD] domain_line_twiddles_from_tree"); let twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles); + nvtx::range_pop!(); // Evaluate on a big domains by evaluating on several subdomains. let log_subdomains = log_size - fft_log_size; @@ -259,14 +281,16 @@ impl PolyOps for SimdBackend { for i in 0..(1 << log_subdomains) { // The subdomain twiddles are a slice of the large domain twiddles. + nvtx::range_push!("[SIMD] calc subdomain twiddles"); let subdomain_twiddles = (0..(fft_log_size - 1)) .map(|layer_i| { &twiddles[layer_i as usize] [i << (fft_log_size - 2 - layer_i)..(i + 1) << (fft_log_size - 2 - layer_i)] }) .collect::>(); - + nvtx::range_pop!(); // FFT from the coefficients buffer to the values chunk. + nvtx::range_push!("[SIMD] fft"); unsafe { rfft::fft( transmute(poly.coeffs.data.as_ptr()), @@ -279,6 +303,7 @@ impl PolyOps for SimdBackend { fft_log_size as usize, ); } + nvtx::range_pop!(); } CircleEvaluation::new( diff --git a/crates/prover/src/core/backend/simd/fri.rs b/crates/prover/src/core/backend/simd/fri.rs index 4e9f4271d..8aba27e4b 100644 --- a/crates/prover/src/core/backend/simd/fri.rs +++ b/crates/prover/src/core/backend/simd/fri.rs @@ -33,7 +33,9 @@ impl FriOps for SimdBackend { } let domain = eval.domain(); + nvtx::range_push!("[SIMD] domain_line_twiddles_from_tree"); let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0]; + nvtx::range_pop!(); let mut folded_values = SecureColumnByCoords::::zeros(1 << (log_size - 1)); @@ -44,14 +46,22 @@ impl FriOps for SimdBackend { let val0 = eval.values.packed_at(vec_index * 2).into_packed_m31s(); let val1 = eval.values.packed_at(vec_index * 2 + 1).into_packed_m31s(); let pairs: [_; 4] = array::from_fn(|i| { + nvtx::range_push!("[SIMD] deinterleave"); let (a, b) = val0[i].deinterleave(val1[i]); - simd_ibutterfly(a, b, std::mem::transmute(twiddle_dbl)) + nvtx::range_pop!(); + nvtx::range_push!("[SIMD] simd_ibutterfly"); + let butterfly = simd_ibutterfly(a, b, std::mem::transmute(twiddle_dbl)); + nvtx::range_pop!(); + + butterfly }); let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0)); let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1)); val0 + PackedSecureField::broadcast(alpha) * val1 }; + nvtx::range_push!("[SIMD] simd_ibutterfly"); unsafe { folded_values.set_packed(vec_index, value) }; + nvtx::range_pop!(); } LineEvaluation::new(domain.double(), folded_values) @@ -77,7 +87,9 @@ impl FriOps for SimdBackend { let domain = src.domain; let alpha_sq = alpha * alpha; + nvtx::range_push!("[SIMD] domain_line_twiddles_from_tree"); let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0]; + nvtx::range_pop!(); for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) { let value = unsafe { @@ -90,13 +102,20 @@ impl FriOps for SimdBackend { let val0 = src.values.packed_at(vec_index * 2).into_packed_m31s(); let val1 = src.values.packed_at(vec_index * 2 + 1).into_packed_m31s(); let pairs: [_; 4] = array::from_fn(|i| { + nvtx::range_push!("[SIMD] deinterleave"); let (a, b) = 
val0[i].deinterleave(val1[i]); - simd_ibutterfly(a, b, t0) + nvtx::range_pop!(); + nvtx::range_push!("[SIMD] simd_ibutterfly"); + let butter = simd_ibutterfly(a, b, t0); + nvtx::range_pop!(); + + butter }); let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0)); let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1)); val0 + PackedSecureField::broadcast(alpha) * val1 }; + nvtx::range_push!("[SIMD] set packed"); unsafe { dst.values.set_packed( vec_index, @@ -104,6 +123,7 @@ impl FriOps for SimdBackend { + value, ) }; + nvtx::range_pop!(); } } diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index 553e540a5..2620b5b17 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -67,8 +67,11 @@ impl QuotientOps for SimdBackend { // b2 b3 b4 b5 is indeed a circle domain, with a bigger jump. // Traversing the domain in bit-reversed order, after we finish with b5, b4, b3, b2, // we need to change b1 and then b0. This is the bit reverse of the shift b0 b1. + nvtx::range_push!("[SIMD] bit_reverse"); bit_reverse(&mut subdomain_shifts); + nvtx::range_pop!(); + nvtx::range_push!("[SIMD] accumulate_quotients_on_subdomain"); let (span, mut extended_eval, subeval_polys) = accumulate_quotients_on_subdomain( subdomain, sample_batches, @@ -76,12 +79,14 @@ impl QuotientOps for SimdBackend { columns, domain, ); + nvtx::range_pop!(); // Extend the evaluation to the full domain. // TODO(Ohad): Try to optimize out all these copies. + nvtx::range_push!("[SIMD] extend to full domain"); for (ci, &c) in subdomain_shifts.iter().enumerate() { let subdomain = subdomain.shift(c); - + let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset); #[allow(clippy::needless_range_loop)] for i in 0..SECURE_EXTENSION_DEGREE { @@ -92,8 +97,11 @@ impl QuotientOps for SimdBackend { } } span.exit(); - - SecureEvaluation::new(domain, extended_eval) + + let ret = SecureEvaluation::new(domain, extended_eval); + nvtx::range_pop!(); + + ret } } diff --git a/crates/prover/src/core/fri.rs b/crates/prover/src/core/fri.rs index 2e9d17c5e..12a3b06df 100644 --- a/crates/prover/src/core/fri.rs +++ b/crates/prover/src/core/fri.rs @@ -160,17 +160,19 @@ impl<'a, B: FriOps + MerkleOps, MC: MerkleChannel> FriProver<'a, B, MC> { columns: &'a [SecureEvaluation], twiddles: &TwiddleTree, ) -> Self { - #[cfg(feature = "icicle")] - nvtx::range_push!("fn FriPorover::commit("); assert!(!columns.is_empty(), "no columns"); assert!(columns.is_sorted_by_key(|e| Reverse(e.len())), "not sorted"); assert!(columns.iter().all(|e| e.domain.is_canonic()), "not canonic"); + nvtx::range_push!("commit_first_layer"); let first_layer = Self::commit_first_layer(channel, columns); + nvtx::range_pop!(); + nvtx::range_push!("commit_inner_layer"); let (inner_layers, last_layer_evaluation) = Self::commit_inner_layers(channel, config, columns, twiddles); + nvtx::range_pop!(); + nvtx::range_push!("commit_last_layer"); let last_layer_poly = Self::commit_last_layer(channel, config, last_layer_evaluation); - #[cfg(feature = "icicle")] nvtx::range_pop!(); Self { config, @@ -223,18 +225,22 @@ impl<'a, B: FriOps + MerkleOps, MC: MerkleChannel> FriProver<'a, B, MC> { while layer_evaluation.len() > config.last_layer_domain_size() { // Check for circle polys in the first layer that should be combined in this layer. 
while let Some(column) = columns.next_if(|c| folded_size(c) == layer_evaluation.len()) { + nvtx::range_push!("fold circle"); B::fold_circle_into_line( &mut layer_evaluation, column, circle_poly_folding_alpha, twiddles, ); + nvtx::range_pop!(); } let layer = FriInnerLayerProver::new(layer_evaluation); MC::mix_root(channel, layer.merkle_tree.root()); let folding_alpha = channel.draw_felt(); + nvtx::range_push!("fold line"); let folded_layer_evaluation = B::fold_line(&layer.evaluation, folding_alpha, twiddles); + nvtx::range_pop!(); layer_evaluation = folded_layer_evaluation; layers.push(layer); @@ -856,8 +862,12 @@ struct FriFirstLayerProver<'a, B: FriOps + MerkleOps, H: MerkleHasher> { impl<'a, B: FriOps + MerkleOps, H: MerkleHasher> FriFirstLayerProver<'a, B, H> { fn new(columns: &'a [SecureEvaluation]) -> Self { + nvtx::range_push!("extract columns"); let coordinate_columns = extract_coordinate_columns(columns); + nvtx::range_pop!(); + nvtx::range_push!("Merkle commit"); let merkle_tree = MerkleProver::commit(coordinate_columns); + nvtx::range_pop!(); FriFirstLayerProver { columns, @@ -941,7 +951,9 @@ struct FriInnerLayerProver, H: MerkleHasher> { impl, H: MerkleHasher> FriInnerLayerProver { fn new(evaluation: LineEvaluation) -> Self { + nvtx::range_push!("Merkle commit"); let merkle_tree = MerkleProver::commit(evaluation.values.columns.iter().collect_vec()); + nvtx::range_pop!(); FriInnerLayerProver { evaluation, merkle_tree, diff --git a/crates/prover/src/core/pcs/prover.rs b/crates/prover/src/core/pcs/prover.rs index 9d51f28b0..199e01abd 100644 --- a/crates/prover/src/core/pcs/prover.rs +++ b/crates/prover/src/core/pcs/prover.rs @@ -85,11 +85,9 @@ impl<'a, B: BackendForChannel, MC: MerkleChannel> CommitmentSchemeProver<'a, sampled_points: TreeVec>>>, channel: &mut MC::C, ) -> CommitmentSchemeProof { - #[cfg(feature = "icicle")] - nvtx::range_push!("fn prove_values("); - // Evaluate polynomials on open points. let span = span!(Level::INFO, "Evaluate columns out of domain").entered(); + nvtx::range_push!("sample points w/ values"); let samples = self .polynomials() .zip_cols(&sampled_points) @@ -102,40 +100,53 @@ impl<'a, B: BackendForChannel, MC: MerkleChannel> CommitmentSchemeProver<'a, }) .collect_vec() }); + nvtx::range_pop!(); span.exit(); + nvtx::range_push!("mix values"); let sampled_values = samples .as_cols_ref() .map_cols(|x| x.iter().map(|o| o.value).collect()); channel.mix_felts(&sampled_values.clone().flatten_cols()); + nvtx::range_pop!(); // Compute oods quotients for boundary constraints on the sampled points. let columns = self.evaluations().flatten(); + nvtx::range_push!("fri_quotients"); let quotients = compute_fri_quotients( &columns, &samples.flatten(), channel.draw_felt(), self.config.fri_config.log_blowup_factor, ); + nvtx::range_pop!(); // Run FRI commitment phase on the oods quotients. + nvtx::range_push!("FRI commit"); let fri_prover = FriProver::::commit(channel, self.config.fri_config, "ients, self.twiddles); - + nvtx::range_pop!(); + // Proof of work. + nvtx::range_push!("Proof of work"); let span1 = span!(Level::INFO, "Grind").entered(); let proof_of_work = B::grind(channel, self.config.pow_bits); span1.exit(); channel.mix_u64(proof_of_work); - + nvtx::range_pop!(); + // FRI decommitment phase. + nvtx::range_push!("FRI Decommit"); let (fri_proof, query_positions_per_log_size) = fri_prover.decommit(channel); - + nvtx::range_pop!(); + // Decommit the FRI queries on the merkle trees. 
+ nvtx::range_push!("Tree Decommit"); let decommitment_results = self .trees .as_ref() .map(|tree| tree.decommit(&query_positions_per_log_size)); - + nvtx::range_pop!(); + let queried_values = decommitment_results.as_ref().map(|(v, _)| v.clone()); let decommitments = decommitment_results.map(|(_, d)| d); @@ -147,9 +158,6 @@ impl<'a, B: BackendForChannel, MC: MerkleChannel> CommitmentSchemeProver<'a, proof_of_work, fri_proof, }; - - #[cfg(feature = "icicle")] - nvtx::range_pop!(); result } } diff --git a/crates/prover/src/core/pcs/quotients.rs b/crates/prover/src/core/pcs/quotients.rs index 04f3fe532..9e9092e18 100644 --- a/crates/prover/src/core/pcs/quotients.rs +++ b/crates/prover/src/core/pcs/quotients.rs @@ -79,7 +79,6 @@ pub fn compute_fri_quotients( log_blowup_factor: u32, ) -> Vec> { #[cfg(feature = "icicle")] - nvtx::range_push!("fn compute_fri_quotients("); let _span = span!(Level::INFO, "Compute FRI quotients").entered(); let result = zip(columns, samples) .sorted_by_key(|(c, _)| Reverse(c.domain.log_size())) @@ -99,9 +98,6 @@ pub fn compute_fri_quotients( ) }) .collect(); - - #[cfg(feature = "icicle")] - nvtx::range_pop!(); return result; } diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index 3e045da56..5bef4def1 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -32,7 +32,7 @@ pub fn prove, MC: MerkleChannel>( #[cfg(feature = "icicle")] nvtx::name_thread!("stark_prover"); #[cfg(feature = "icicle")] - nvtx::range_push!("fn prove("); + nvtx::range_push!("fn prove"); let n_preprocessed_columns = commitment_scheme.trees[PREPROCESSED_TRACE_IDX] .polynomials @@ -48,25 +48,37 @@ pub fn prove, MC: MerkleChannel>( let span = span!(Level::INFO, "Composition").entered(); let span1 = span!(Level::INFO, "Generation").entered(); + nvtx::range_push!("fn compute_composition_polynomial"); let composition_poly = component_provers.compute_composition_polynomial(random_coeff, &trace); + nvtx::range_pop!(); span1.exit(); + nvtx::range_push!("tree builder + commit"); let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_polys(composition_poly.into_coordinate_polys()); + nvtx::range_push!("commit"); tree_builder.commit(channel); + nvtx::range_pop!(); + nvtx::range_pop!(); span.exit(); // Draw OODS point. + nvtx::range_push!("Draw OODS point"); let oods_point = CirclePoint::::get_random_point(channel); + nvtx::range_pop!(); // Get mask sample points relative to oods point. + nvtx::range_push!("mask sample points"); let mut sample_points = component_provers.components().mask_points(oods_point); + nvtx::range_pop!(); // Add the composition polynomial mask points. sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]); // Prove the trace and composition OODS values, and retrieve them. 
+ nvtx::range_push!("fn prove_values"); let commitment_scheme_proof = commitment_scheme.prove_values(sample_points, channel); + nvtx::range_pop!(); let proof = StarkProof(commitment_scheme_proof); info!(proof_size_estimate = proof.size_estimate()); diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index 1251f9335..ff656d3a1 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -171,8 +171,13 @@ mod tests { } #[test_log::test] - fn test_wide_fib_prove_with_blake() { - for log_n_instances in 2..=6 { + fn test_wide_fib_prove_with_blake_simd() { + use crate::examples::utils::get_env_var; + + let min_log = get_env_var("MIN_FIB_LOG", 2u32); + let max_log = get_env_var("MAX_FIB_LOG", 18u32); + + for log_n_instances in min_log..=max_log { let config = PcsConfig::default(); // Precompute twiddles. let twiddles = SimdBackend::precompute_twiddles( @@ -237,77 +242,74 @@ mod tests { // type TheBackend = CpuBackend; let min_log = get_env_var("MIN_FIB_LOG", 2u32); - let max_log = get_env_var("MAX_FIB_LOG", 18u32); + let max_log = get_env_var("MAX_FIB_LOG", 23u32); for log_n_instances in min_log..=max_log { - for _ in 0..1 { - println!("proving for 2^{:?}...", log_n_instances); - let config = PcsConfig::default(); - // Precompute twiddles. - let twiddles = TheBackend::precompute_twiddles( - CanonicCoset::new(log_n_instances + 1 + config.fri_config.log_blowup_factor) - .circle_domain() - .half_coset, - ); - - // Setup protocol. - let prover_channel = &mut Blake2sChannel::default(); - let mut commitment_scheme = CommitmentSchemeProver::< - TheBackend, - Blake2sMerkleChannel, - >::new(config, &twiddles); - - // Preprocessed trace - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals([]); - tree_builder.commit(prover_channel); - - // Trace. - let trace: Vec> = - generate_test_trace(log_n_instances) - .iter() - .map(|c| unsafe { std::mem::transmute(c.to_cpu()) }) - .collect_vec(); - - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace); - tree_builder.commit(prover_channel); - - // Prove constraints. - let component = WideFibonacciComponent::new( - &mut TraceLocationAllocator::default(), - WideFibonacciEval:: { - log_n_rows: log_n_instances, - }, - (SecureField::zero(), None), - ); - - let start = std::time::Instant::now(); - let proof = prove::( - &[&component], - prover_channel, - commitment_scheme, - ) - .unwrap(); - println!( - "proving for 2^{:?} took {:?} ms", - log_n_instances, - start.elapsed().as_millis() - ); - - // Verify. - let verifier_channel = &mut Blake2sChannel::default(); - let commitment_scheme = - &mut CommitmentSchemeVerifier::::new(config); - - // Retrieve the expected column sizes in each commitment interaction, from the AIR. - let sizes = component.trace_log_degree_bounds(); - commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); - commitment_scheme.commit(proof.commitments[1], &sizes[1], verifier_channel); - verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap_or_else(|err| { + let config = PcsConfig::default(); + // Precompute twiddles. + let twiddles = TheBackend::precompute_twiddles( + CanonicCoset::new(log_n_instances + 1 + config.fri_config.log_blowup_factor) + .circle_domain() + .half_coset, + ); + + // Setup protocol. 
+ let prover_channel = &mut Blake2sChannel::default(); + let mut commitment_scheme = + CommitmentSchemeProver::::new(config, &twiddles); + + // Preprocessed trace + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals([]); + tree_builder.commit(prover_channel); + + // Trace. + let trace: Vec> = + generate_test_trace(log_n_instances) + .iter() + .map(|c| unsafe { std::mem::transmute(c.to_cpu()) }) + .collect_vec(); + + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(trace); + tree_builder.commit(prover_channel); + + // Prove constraints. + let component = WideFibonacciComponent::new( + &mut TraceLocationAllocator::default(), + WideFibonacciEval:: { + log_n_rows: log_n_instances, + }, + (SecureField::zero(), None), + ); + + let start = std::time::Instant::now(); + let proof = prove::( + &[&component], + prover_channel, + commitment_scheme, + ) + .unwrap(); + println!( + "proving for 2^{:?} took {:?} ms", + log_n_instances, + start.elapsed().as_millis() + ); + + // Verify. + let verifier_channel = &mut Blake2sChannel::default(); + let commitment_scheme = + &mut CommitmentSchemeVerifier::::new(config); + + // Retrieve the expected column sizes in each commitment interaction, from the AIR. + let sizes = component.trace_log_degree_bounds(); + commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); + commitment_scheme.commit(proof.commitments[1], &sizes[1], verifier_channel); + verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap_or_else( + |err| { println!("verify failed for {} with: {}", log_n_instances, err); - }); - } + }, + ); } } From c5af6763a70b371cacfb7156ca562429dc4c350c Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Tue, 7 Jan 2025 15:55:40 +0200 Subject: [PATCH 56/69] par fri accumulation (#963) --- .../prover/src/core/backend/simd/quotients.rs | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index 9dd30f7fd..f0155ebb1 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -1,5 +1,7 @@ use itertools::{izip, zip_eq, Itertools}; use num_traits::Zero; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use tracing::{span, Level}; use super::cm31::PackedCM31; @@ -114,10 +116,17 @@ fn accumulate_quotients_on_subdomain( let quotient_constants = quotient_constants(sample_batches, random_coeff, subdomain); let span = span!(Level::INFO, "Quotient accumulation").entered(); - for (quad_row, points) in CircleDomainBitRevIterator::new(subdomain) + let quad_rows = CircleDomainBitRevIterator::new(subdomain) .array_chunks::<4>() - .enumerate() - { + .collect_vec(); + + #[cfg(not(feature = "parallel"))] + let iter = quad_rows.iter().zip(values.chunks_mut(4)).enumerate(); + + #[cfg(feature = "parallel")] + let iter = quad_rows.par_iter().zip(values.chunks_mut(4)).enumerate(); + + iter.for_each(|(quad_row, (points, mut values_dst))| { // TODO(andrew): Spapini said: Use optimized domain iteration. Is there a better way to do // this? 
let (y01, _) = points[0].y.deinterleave(points[1].y); @@ -130,11 +139,13 @@ fn accumulate_quotients_on_subdomain( quad_row, spaced_ys, ); - #[allow(clippy::needless_range_loop)] - for i in 0..4 { - unsafe { values.set_packed((quad_row << 2) + i, row_accumulator[i]) }; + unsafe { + values_dst.set_packed(0, row_accumulator[0]); + values_dst.set_packed(1, row_accumulator[1]); + values_dst.set_packed(2, row_accumulator[2]); + values_dst.set_packed(3, row_accumulator[3]); } - } + }); span.exit(); let span = span!(Level::INFO, "Quotient extension").entered(); From cd72092c1d7358dcb6ac3c95ecf8412114ba0913 Mon Sep 17 00:00:00 2001 From: ilyalesokhin-starkware Date: Wed, 8 Jan 2025 10:51:46 +0200 Subject: [PATCH 57/69] Add soundness todo. (#962) --- crates/prover/src/core/pcs/quotients.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/prover/src/core/pcs/quotients.rs b/crates/prover/src/core/pcs/quotients.rs index d41f17670..742a78f3a 100644 --- a/crates/prover/src/core/pcs/quotients.rs +++ b/crates/prover/src/core/pcs/quotients.rs @@ -139,6 +139,7 @@ pub fn fri_answers_for_log_size( n_columns: TreeVec, ) -> Result, VerificationError> { let sample_batches = ColumnSampleBatch::new_vec(samples); + // TODO(ilya): Is it ok to use the same `random_coeff` for all log sizes. let quotient_constants = quotient_constants(&sample_batches, random_coeff); let commitment_domain = CanonicCoset::new(log_size).circle_domain(); From 4a205b5b25a7613bfc9a144562729e8018b8887d Mon Sep 17 00:00:00 2001 From: VitaliiH Date: Wed, 8 Jan 2025 12:12:49 +0100 Subject: [PATCH 58/69] domaininit icicle --- Cargo.lock | 8 +- crates/prover/Cargo.toml | 14 +-- crates/prover/src/core/backend/icicle/mod.rs | 91 ++++++++----------- .../prover/src/examples/wide_fibonacci/mod.rs | 2 + 4 files changed, 48 insertions(+), 67 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 19c067fda..3279ac177 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -760,7 +760,7 @@ dependencies = [ [[package]] name = "icicle-core" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=cc4fc660ea331bab1fe725cd8a93b371bb59c62a#cc4fc660ea331bab1fe725cd8a93b371bb59c62a" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=7e7c1c8d96af1963df94c9ab6e7fdc37176e9543#7e7c1c8d96af1963df94c9ab6e7fdc37176e9543" dependencies = [ "criterion 0.3.6", "hex", @@ -771,7 +771,7 @@ dependencies = [ [[package]] name = "icicle-cuda-runtime" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=cc4fc660ea331bab1fe725cd8a93b371bb59c62a#cc4fc660ea331bab1fe725cd8a93b371bb59c62a" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=7e7c1c8d96af1963df94c9ab6e7fdc37176e9543#7e7c1c8d96af1963df94c9ab6e7fdc37176e9543" dependencies = [ "bindgen", "bitflags 1.3.2", @@ -780,7 +780,7 @@ dependencies = [ [[package]] name = "icicle-hash" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=cc4fc660ea331bab1fe725cd8a93b371bb59c62a#cc4fc660ea331bab1fe725cd8a93b371bb59c62a" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=7e7c1c8d96af1963df94c9ab6e7fdc37176e9543#7e7c1c8d96af1963df94c9ab6e7fdc37176e9543" dependencies = [ "cmake", "icicle-core", @@ -790,7 +790,7 @@ dependencies = [ [[package]] name = "icicle-m31" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=cc4fc660ea331bab1fe725cd8a93b371bb59c62a#cc4fc660ea331bab1fe725cd8a93b371bb59c62a" +source = 
"git+https://github.com/ingonyama-zk/icicle.git?rev=7e7c1c8d96af1963df94c9ab6e7fdc37176e9543#7e7c1c8d96af1963df94c9ab6e7fdc37176e9543" dependencies = [ "cmake", "criterion 0.3.6", diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index a3bd229d1..2df0cf7a0 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -29,19 +29,13 @@ tracing.workspace = true rayon = { version = "1.10.0", optional = true } serde = { version = "1.0", features = ["derive"] } -icicle-cuda-runtime = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="cc4fc660ea331bab1fe725cd8a93b371bb59c62a"} -icicle-core = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="cc4fc660ea331bab1fe725cd8a93b371bb59c62a"} -icicle-m31 = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="cc4fc660ea331bab1fe725cd8a93b371bb59c62a"} -icicle-hash = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="cc4fc660ea331bab1fe725cd8a93b371bb59c62a"} - -# icicle-cuda-runtime = { path = "/home/vhnat/repo/icicle/wrappers/rust/icicle-cuda-runtime", optional = true} -# icicle-core = { path = "/home/vhnat/repo/icicle/wrappers/rust/icicle-core", optional = true} -# icicle-m31 = { path = "/home/vhnat/repo/icicle/wrappers/rust/icicle-fields/icicle-m31", optional = true} -# icicle-hash = { path = "/home/vhnat/repo/icicle/wrappers/rust/icicle-hash", optional = true} +icicle-cuda-runtime = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="7e7c1c8d96af1963df94c9ab6e7fdc37176e9543"} +icicle-core = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="7e7c1c8d96af1963df94c9ab6e7fdc37176e9543"} +icicle-m31 = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="7e7c1c8d96af1963df94c9ab6e7fdc37176e9543"} +icicle-hash = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="7e7c1c8d96af1963df94c9ab6e7fdc37176e9543"} nvtx = { version = "*", optional = true } - [dev-dependencies] aligned = "0.4.2" test-log = { version = "0.2.15", features = ["trace"] } diff --git a/crates/prover/src/core/backend/icicle/mod.rs b/crates/prover/src/core/backend/icicle/mod.rs index af984c4a2..bc8fed885 100644 --- a/crates/prover/src/core/backend/icicle/mod.rs +++ b/crates/prover/src/core/backend/icicle/mod.rs @@ -7,14 +7,15 @@ use std::ffi::c_void; use std::iter::zip; use std::mem::{size_of_val, transmute}; +use icicle_core::field::Field as IcicleField; use icicle_core::tree::{merkle_tree_digests_len, TreeBuilderConfig}; use icicle_core::vec_ops::{accumulate_scalars, VecOpsConfig}; -use icicle_core::{field::Field as IcicleField, Matrix}; +use icicle_core::Matrix; use icicle_hash::blake2s::build_blake2s_mmcs; use icicle_m31::dcct::{evaluate, get_dcct_root_of_unity, initialize_dcct_domain, interpolate}; +use icicle_m31::field::ScalarCfg; use icicle_m31::fri::{self, fold_circle_into_line, fold_circle_into_line_new, FriConfig}; use icicle_m31::quotient; -use icicle_m31::field::ScalarCfg; use itertools::Itertools; use serde::{Deserialize, Serialize}; use twiddles::TwiddleTree; @@ -149,7 +150,7 @@ impl AccumulationOps for IcicleBackend { let rr = unsafe { slice::from_raw_parts_mut(v_ptr, n) }; d_b_slice = DeviceSlice::from_mut_slice(rr); } - + nvtx::range_push!("[ICICLE] accum scalars"); accumulate_scalars(d_a_slice, d_b_slice, &cfg); nvtx::range_pop!(); @@ -196,7 +197,7 @@ impl MerkleOps for IcicleBackend { nvtx::range_pop!(); let mut digests = vec![0u8; digests_len]; let 
digests_slice = HostSlice::from_mut_slice(&mut digests); - + nvtx::range_push!("[ICICLE] build_blake2s_mmcs"); build_blake2s_mmcs(&matrices, digests_slice, &config).unwrap(); nvtx::range_pop!(); @@ -273,7 +274,7 @@ impl PolyOps for IcicleBackend { nvtx::range_push!("[ICICLE] get_dcct_root_of_unity"); let rou = get_dcct_root_of_unity(eval.domain.size() as _); nvtx::range_pop!(); - + nvtx::range_push!("[ICICLE] initialize_dcct_domain"); initialize_dcct_domain(eval.domain.log_size(), rou, &DeviceContext::default()).unwrap(); nvtx::range_pop!(); @@ -290,7 +291,7 @@ impl PolyOps for IcicleBackend { ) .unwrap(); nvtx::range_pop!(); - + let values: Vec = unsafe { transmute(evaluations) }; CirclePoly::new(values) @@ -312,7 +313,7 @@ impl PolyOps for IcicleBackend { } mappings.reverse(); nvtx::range_pop!(); - + nvtx::range_push!("[ICICLE] fold"); let folded = crate::core::backend::icicle::utils::fold(&poly.coeffs, &mappings); nvtx::range_pop!(); @@ -522,7 +523,7 @@ impl FriOps for IcicleBackend { let mut d_domain_icicle = DeviceVec::::cuda_malloc(dom_vals_len).unwrap(); d_domain_icicle.copy_from_host(domain_icicle_host).unwrap(); nvtx::range_pop!(); - + nvtx::range_push!("[ICICLE] domain evals convert + move"); let mut d_evals_icicle = DeviceVec::::cuda_malloc(length).unwrap(); SecureColumnByCoords::::convert_to_icicle( @@ -573,36 +574,13 @@ impl FriOps for IcicleBackend { let dom_vals_len = length / 2; let domain_log_size = domain.log_size(); - // let mut domain_rev = Vec::new(); - // for i in 0..dom_vals_len { - // // TODO: on-device batch - // // TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer. - // let p = domain.at(bit_reverse_index( - // i << CIRCLE_TO_LINE_FOLD_STEP, - // domain.log_size(), - // )); - // let p = p.y.inverse(); - // domain_rev.push(p); - // } - - // let domain_vals = (0..dom_vals_len) - // .map(|i| { - // let p = domain_rev[i]; - // ScalarField::from_u32(p.0) - // }) - // .collect::>(); - - // let domain_icicle_host = HostSlice::from_slice(domain_vals.as_slice()); - // let mut d_domain_icicle = DeviceVec::::cuda_malloc(dom_vals_len).unwrap(); - // d_domain_icicle.copy_from_host(domain_icicle_host).unwrap(); - let mut d_evals_icicle = DeviceVec::::cuda_malloc(length).unwrap(); SecureColumnByCoords::convert_to_icicle(&src.values, &mut d_evals_icicle); nvtx::range_pop!(); - + nvtx::range_push!("[ICICLE] d_folded_evals"); let mut d_folded_eval = - DeviceVec::::cuda_malloc(dom_vals_len).unwrap(); + DeviceVec::::cuda_malloc(dom_vals_len).unwrap(); SecureColumnByCoords::convert_to_icicle(&dst.values, &mut d_folded_eval); nvtx::range_pop!(); @@ -613,6 +591,10 @@ impl FriOps for IcicleBackend { let icicle_alpha = unsafe { transmute(alpha) }; nvtx::range_push!("[ICICLE] fold circle"); + println!( + "index {}, half_log_size {}", + domain.half_coset.initial_index.0, domain.half_coset.log_size + ); let _ = fold_circle_into_line_new( &d_evals_icicle[..], domain.half_coset.initial_index.0 as _, @@ -688,46 +670,49 @@ impl QuotientOps for IcicleBackend { // )) // } - let total_columns_size = columns.iter().fold(0, |acc, column| acc + column.values.len()); + let total_columns_size = columns + .iter() + .fold(0, |acc, column| acc + column.values.len()); let mut icicle_device_columns = DeviceVec::cuda_malloc(total_columns_size).unwrap(); let mut start = 0; nvtx::range_push!("[ICICLE] columns to device"); columns.iter().for_each(|column| { let end = start + column.values.len(); let device_slice = &mut icicle_device_columns[start..end]; - let transmuted: Vec> = unsafe 
{ transmute(column.values.clone()) }; + let transmuted: Vec> = + unsafe { transmute(column.values.clone()) }; device_slice.copy_from_host(&HostSlice::from_slice(&transmuted)); start += column.values.len(); }); nvtx::range_pop!(); - + nvtx::range_push!("[ICICLE] column sample batch"); let icicle_sample_batches = sample_batches .into_iter() .map(|sample| { let (columns, values) = sample - .columns_and_values - .iter() - .map(|(index, value)| { - ((*index) as u32, unsafe { - transmute::(*value) - }) + .columns_and_values + .iter() + .map(|(index, value)| { + ((*index) as u32, unsafe { + transmute::(*value) }) - .unzip(); - - quotient::ColumnSampleBatch { - point: unsafe { transmute(sample.point) }, - columns, - values, - } - }) - .collect_vec(); + }) + .unzip(); + + quotient::ColumnSampleBatch { + point: unsafe { transmute(sample.point) }, + columns, + values, + } + }) + .collect_vec(); nvtx::range_pop!(); let mut icicle_result_raw = vec![QuarticExtensionField::zero(); domain.size()]; let icicle_result = HostSlice::from_mut_slice(icicle_result_raw.as_mut_slice()); let cfg = quotient::QuotientConfig::default(); - + nvtx::range_push!("[ICICLE] accumulate_quotients_wrapped"); quotient::accumulate_quotients_wrapped( // domain.half_coset.initial_index.0 as u32, @@ -746,7 +731,7 @@ impl QuotientOps for IcicleBackend { (0..domain.size()).for_each(|i| result.set(i, unsafe { transmute(icicle_result_raw[i]) })); let ret = SecureEvaluation::new(domain, result); nvtx::range_pop!(); - + ret } } diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index ff656d3a1..de703c584 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -283,6 +283,8 @@ mod tests { (SecureField::zero(), None), ); + icicle_m31::fri::precompute_fri_twiddles(log_n_instances).unwrap(); + println!("++++++++ proving for 2^{:?}", log_n_instances); let start = std::time::Instant::now(); let proof = prove::( &[&component], From ed2076e76a3eef023fb33a71dd7c1fd40a5c81d7 Mon Sep 17 00:00:00 2001 From: VitaliiH Date: Wed, 8 Jan 2025 12:13:25 +0100 Subject: [PATCH 59/69] -O3 by default --- crates/prover/Cargo.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 2df0cf7a0..cb86db769 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -119,3 +119,6 @@ name = "pcs" [[bench]] harness = false name = "accumulate" + +[profile.release] +opt-level = 3 \ No newline at end of file From 31e8dbcc4752240b596774743946c561ab5b9cd1 Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Thu, 9 Jan 2025 09:55:21 +0200 Subject: [PATCH 60/69] improved component trace docs (#970) --- crates/air_utils/src/trace/component_trace.rs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/crates/air_utils/src/trace/component_trace.rs b/crates/air_utils/src/trace/component_trace.rs index ba59964ea..cde923672 100644 --- a/crates/air_utils/src/trace/component_trace.rs +++ b/crates/air_utils/src/trace/component_trace.rs @@ -9,12 +9,13 @@ use stwo_prover::core::poly::BitReversedOrder; use super::row_iterator::{ParRowIterMut, RowIterMut}; /// A 2D Matrix of [`PackedM31`] values. -/// Used for generating the witness of 'Stwo' proofs. +/// +/// Used for generating the witness of 'Stwo' proofs.\ /// Stored as an array of `N` columns, each column is a vector of [`PackedM31`] values. 
-/// All columns are of the same length. +/// All columns are of the same length.\ /// Exposes an iterator over mutable references to the rows of the matrix. /// -/// # Example: +/// # Example /// /// ```text /// Computation trace of a^2 + (a + 1)^2 for a in 0..256 @@ -58,7 +59,9 @@ pub struct ComponentTrace { impl ComponentTrace { /// Creates a new `ComponentTrace` with all values initialized to zero. /// The number of rows in each column is `2^log_size`. - /// # Panics: + /// + /// # Panics + /// /// if log_size < 4. pub fn zeroed(log_size: u32) -> Self { assert!( @@ -71,10 +74,14 @@ impl ComponentTrace { } /// Creates a new `ComponentTrace` with all values uninitialized. + /// /// # Safety + /// /// The caller must ensure that the column is populated before being used. /// The number of rows in each column is `2^log_size`. - /// # Panics: + /// + /// # Panics + /// /// if `log_size` < 4. #[allow(clippy::uninit_vec)] pub unsafe fn uninitialized(log_size: u32) -> Self { From 4c55de2d382f1d4662085e8f37af29fdb854d83f Mon Sep 17 00:00:00 2001 From: VitaliiH Date: Thu, 9 Jan 2025 10:51:31 +0100 Subject: [PATCH 61/69] merge in latest dev cd72092c --- .github/workflows/benchmarks-pages.yaml | 8 +- .github/workflows/ci.yaml | 58 +- .github/workflows/coverage.yaml | 4 +- Cargo.lock | 73 ++- Cargo.toml | 2 +- crates/air_utils/Cargo.toml | 14 + crates/air_utils/src/lib.rs | 3 + crates/air_utils/src/lookup_data/mod.rs | 168 +++++ crates/air_utils/src/trace/component_trace.rs | 191 ++++++ crates/air_utils/src/trace/mod.rs | 2 + crates/air_utils/src/trace/row_iterator.rs | 126 ++++ crates/air_utils_derive/Cargo.toml | 13 + crates/air_utils_derive/src/allocation.rs | 30 + crates/air_utils_derive/src/iter_mut.rs | 167 +++++ crates/air_utils_derive/src/iterable_field.rs | 369 +++++++++++ crates/air_utils_derive/src/lib.rs | 45 ++ crates/air_utils_derive/src/par_iter.rs | 163 +++++ crates/prover/Cargo.toml | 5 +- crates/prover/benches/bit_rev.rs | 5 +- crates/prover/benches/fft.rs | 14 +- crates/prover/benches/merkle.rs | 2 +- .../prover/src/constraint_framework/assert.rs | 14 +- .../src/constraint_framework/component.rs | 16 +- .../src/constraint_framework/cpu_domain.rs | 4 +- .../prover/src/constraint_framework/expr.rs | 536 ---------------- .../constraint_framework/expr/assignment.rs | 267 ++++++++ .../src/constraint_framework/expr/degree.rs | 100 +++ .../constraint_framework/expr/evaluator.rs | 244 ++++++++ .../src/constraint_framework/expr/format.rs | 65 ++ .../src/constraint_framework/expr/mod.rs | 352 +++++++++++ .../src/constraint_framework/expr/simplify.rs | 216 +++++++ .../src/constraint_framework/expr/utils.rs | 65 ++ .../prover/src/constraint_framework/logup.rs | 21 +- crates/prover/src/constraint_framework/mod.rs | 129 +++- .../prover/src/constraint_framework/point.rs | 2 +- .../preprocessed_columns.rs | 142 ++++- .../constraint_framework/relation_tracker.rs | 265 ++++++++ .../src/constraint_framework/simd_domain.rs | 4 +- crates/prover/src/core/air/accumulation.rs | 25 +- crates/prover/src/core/air/components.rs | 4 +- crates/prover/src/core/air/mod.rs | 8 +- .../src/core/backend/cpu/accumulation.rs | 48 +- crates/prover/src/core/backend/cpu/circle.rs | 3 +- .../src/core/backend/cpu/lookups/gkr.rs | 2 +- crates/prover/src/core/backend/cpu/mod.rs | 80 ++- .../prover/src/core/backend/cpu/quotients.rs | 13 +- crates/prover/src/core/backend/icicle/mod.rs | 10 +- .../src/core/backend/simd/accumulation.rs | 54 +- .../src/core/backend/simd/bit_reverse.rs | 7 +- 
.../prover/src/core/backend/simd/blake2s.rs | 8 +- crates/prover/src/core/backend/simd/circle.rs | 11 +- crates/prover/src/core/backend/simd/cm31.rs | 4 +- crates/prover/src/core/backend/simd/column.rs | 15 +- crates/prover/src/core/backend/simd/domain.rs | 2 +- .../prover/src/core/backend/simd/fft/ifft.rs | 16 +- .../prover/src/core/backend/simd/fft/mod.rs | 18 +- .../prover/src/core/backend/simd/fft/rfft.rs | 21 +- crates/prover/src/core/backend/simd/fri.rs | 15 +- .../src/core/backend/simd/lookups/gkr.rs | 2 +- .../src/core/backend/simd/lookups/mle.rs | 3 +- crates/prover/src/core/backend/simd/m31.rs | 584 +++++++++--------- .../src/core/backend/simd/prefix_sum.rs | 5 +- crates/prover/src/core/backend/simd/qm31.rs | 8 +- .../prover/src/core/backend/simd/quotients.rs | 31 +- crates/prover/src/core/backend/simd/utils.rs | 95 +-- crates/prover/src/core/channel/blake2s.rs | 4 +- crates/prover/src/core/channel/poseidon252.rs | 2 +- crates/prover/src/core/circle.rs | 19 +- crates/prover/src/core/constraints.rs | 10 +- crates/prover/src/core/fields/cm31.rs | 2 +- crates/prover/src/core/fields/m31.rs | 16 +- crates/prover/src/core/fields/qm31.rs | 6 +- crates/prover/src/core/fri.rs | 56 +- crates/prover/src/core/lookups/gkr_prover.rs | 4 +- .../prover/src/core/lookups/gkr_verifier.rs | 4 +- crates/prover/src/core/lookups/utils.rs | 4 +- crates/prover/src/core/pcs/mod.rs | 1 + crates/prover/src/core/pcs/prover.rs | 8 +- crates/prover/src/core/pcs/quotients.rs | 47 +- crates/prover/src/core/pcs/utils.rs | 9 +- crates/prover/src/core/pcs/verifier.rs | 20 +- crates/prover/src/core/poly/circle/canonic.rs | 26 +- crates/prover/src/core/poly/circle/domain.rs | 11 +- .../prover/src/core/poly/circle/evaluation.rs | 4 +- crates/prover/src/core/poly/circle/poly.rs | 2 +- crates/prover/src/core/poly/line.rs | 14 +- crates/prover/src/core/poly/twiddles.rs | 1 + crates/prover/src/core/prover/mod.rs | 2 +- crates/prover/src/core/queries.rs | 2 +- crates/prover/src/core/utils.rs | 137 +--- crates/prover/src/core/vcs/blake2_merkle.rs | 23 +- crates/prover/src/core/vcs/blake2s_ref.rs | 20 +- crates/prover/src/core/vcs/ops.rs | 11 +- .../prover/src/core/vcs/poseidon252_merkle.rs | 14 +- crates/prover/src/core/vcs/prover.rs | 55 +- crates/prover/src/core/vcs/test_utils.rs | 7 +- crates/prover/src/core/vcs/verifier.rs | 96 ++- crates/prover/src/examples/blake/mod.rs | 40 +- .../src/examples/blake/round/constraints.rs | 13 +- crates/prover/src/examples/blake/round/gen.rs | 2 +- crates/prover/src/examples/blake/round/mod.rs | 4 +- .../examples/blake/scheduler/constraints.rs | 20 +- .../src/examples/blake/scheduler/mod.rs | 2 + .../examples/blake/xor_table/constraints.rs | 56 +- crates/prover/src/examples/plonk/mod.rs | 20 +- crates/prover/src/examples/poseidon/mod.rs | 14 +- .../src/examples/state_machine/components.rs | 66 +- .../prover/src/examples/state_machine/mod.rs | 137 ++-- .../examples/xor/gkr_lookups/accumulation.rs | 5 +- .../src/examples/xor/gkr_lookups/mle_eval.rs | 20 +- crates/prover/src/lib.rs | 14 +- rust-toolchain.toml | 2 +- scripts/clippy.sh | 2 +- scripts/rust_fmt.sh | 2 +- scripts/test_avx.sh | 2 +- 115 files changed, 4394 insertions(+), 1660 deletions(-) create mode 100644 crates/air_utils/Cargo.toml create mode 100644 crates/air_utils/src/lib.rs create mode 100644 crates/air_utils/src/lookup_data/mod.rs create mode 100644 crates/air_utils/src/trace/component_trace.rs create mode 100644 crates/air_utils/src/trace/mod.rs create mode 100644 crates/air_utils/src/trace/row_iterator.rs create mode 
100644 crates/air_utils_derive/Cargo.toml create mode 100644 crates/air_utils_derive/src/allocation.rs create mode 100644 crates/air_utils_derive/src/iter_mut.rs create mode 100644 crates/air_utils_derive/src/iterable_field.rs create mode 100644 crates/air_utils_derive/src/lib.rs create mode 100644 crates/air_utils_derive/src/par_iter.rs delete mode 100644 crates/prover/src/constraint_framework/expr.rs create mode 100644 crates/prover/src/constraint_framework/expr/assignment.rs create mode 100644 crates/prover/src/constraint_framework/expr/degree.rs create mode 100644 crates/prover/src/constraint_framework/expr/evaluator.rs create mode 100644 crates/prover/src/constraint_framework/expr/format.rs create mode 100644 crates/prover/src/constraint_framework/expr/mod.rs create mode 100644 crates/prover/src/constraint_framework/expr/simplify.rs create mode 100644 crates/prover/src/constraint_framework/expr/utils.rs create mode 100644 crates/prover/src/constraint_framework/relation_tracker.rs diff --git a/.github/workflows/benchmarks-pages.yaml b/.github/workflows/benchmarks-pages.yaml index 3672b6c2c..fc87b2d94 100644 --- a/.github/workflows/benchmarks-pages.yaml +++ b/.github/workflows/benchmarks-pages.yaml @@ -1,4 +1,4 @@ -name: +name: on: push: @@ -18,18 +18,18 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2025-01-02 - name: Run benchmark run: ./scripts/bench.sh -- --output-format bencher | tee output.txt - name: Download previous benchmark data uses: actions/cache@v4 with: path: ./cache - key: ${{ runner.os }}-benchmark + key: ${{ runner.os }}-${{github.event.pull_request.base.ref}}-benchmark - name: Store benchmark result uses: benchmark-action/github-action-benchmark@v1 with: - tool: 'cargo' + tool: "cargo" output-file-path: output.txt github-token: ${{ secrets.GITHUB_TOKEN }} auto-push: true diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b9b72d33d..4d2bb45fa 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,7 +25,7 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: rustfmt - toolchain: nightly-2024-01-04 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - run: scripts/rust_fmt.sh --check @@ -36,7 +36,7 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: clippy - toolchain: nightly-2024-01-04 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - run: scripts/clippy.sh @@ -46,25 +46,25 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 doc + - run: cargo +nightly-2025-01-02 doc - run-wasm32-wasi-tests: + run-wasm32-wasip1-tests: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 - targets: wasm32-wasi + toolchain: nightly-2025-01-02 + targets: wasm32-wasip1 - uses: taiki-e/install-action@v2 with: tool: wasmtime - uses: Swatinem/rust-cache@v2 - - run: cargo test --target wasm32-wasi + - run: cargo test --target wasm32-wasip1 env: - CARGO_TARGET_WASM32_WASI_RUNNER: "wasmtime run --" + CARGO_TARGET_WASM32_WASIP1_RUNNER: "wasmtime run --" RUSTFLAGS: -C target-feature=+simd128 run-wasm32-unknown-tests: @@ -73,7 +73,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2025-01-02 targets: 
wasm32-unknown-unknown - uses: Swatinem/rust-cache@v2 - uses: jetli/wasm-pack-action@v0.4.0 @@ -89,9 +89,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test + - run: cargo +nightly-2025-01-02 test env: RUSTFLAGS: -C target-feature=+neon @@ -104,9 +104,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test + - run: cargo +nightly-2025-01-02 test env: RUSTFLAGS: -C target-cpu=native -C target-feature=+${{ matrix.target-feature }} @@ -116,14 +116,14 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2025-01-02 - name: Run benchmark run: ./scripts/bench.sh -- --output-format bencher | tee output.txt - name: Download previous benchmark data uses: actions/cache@v4 with: path: ./cache - key: ${{ runner.os }}-benchmark + key: ${{ runner.os }}-${{github.event.pull_request.base.ref}}-benchmark - name: Store benchmark result uses: benchmark-action/github-action-benchmark@v1 with: @@ -142,14 +142,14 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2025-01-02 - name: Run benchmark run: ./scripts/bench.sh --features="parallel" -- --output-format bencher | tee output.txt - name: Download previous benchmark data uses: actions/cache@v4 with: path: ./cache - key: ${{ runner.os }}-benchmark + key: ${{ runner.os }}-${{github.event.pull_request.base.ref}}-benchmark - name: Store benchmark result uses: benchmark-action/github-action-benchmark@v1 with: @@ -168,9 +168,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test + - run: cargo +nightly-2025-01-02 test run-slow-tests: runs-on: ubuntu-latest @@ -178,9 +178,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test --release --features="slow-tests" + - run: cargo +nightly-2025-01-02 test --release --features="slow-tests" run-tests-parallel: runs-on: ubuntu-latest @@ -188,16 +188,22 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-01-04 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - - run: cargo +nightly-2024-01-04 test --features="parallel" + - run: cargo +nightly-2025-01-02 test --features="parallel" machete: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly-2024-01-04 + - uses: Swatinem/rust-cache@v2 + - name: Install Machete + run: cargo +nightly-2025-01-02 install --locked cargo-machete - name: Run Machete (detect unused dependencies) - uses: bnjbvr/cargo-machete@main + run: cargo +nightly-2025-01-02 machete all-tests: runs-on: ubuntu-latest @@ -207,7 +213,7 @@ jobs: - run-tests - run-avx-tests - run-neon-tests - - run-wasm32-wasi-tests + - run-wasm32-wasip1-tests - run-slow-tests - run-tests-parallel - machete diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index 
504cd67bb..05a1482cb 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -12,14 +12,14 @@ jobs: - uses: dtolnay/rust-toolchain@master with: components: rustfmt - toolchain: nightly-2024-01-04 + toolchain: nightly-2025-01-02 - uses: Swatinem/rust-cache@v2 - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov # TODO: Merge coverage reports for tests on different architectures. # - name: Generate code coverage - run: cargo +nightly-2024-01-04 llvm-cov --codecov --output-path codecov.json + run: cargo +nightly-2025-01-02 llvm-cov --codecov --output-path codecov.json env: RUSTFLAGS: "-C target-feature=+avx512f" - name: Upload coverage to Codecov diff --git a/Cargo.lock b/Cargo.lock index 3279ac177..4e715e06b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "aho-corasick" @@ -208,7 +208,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.89", + "syn 2.0.95", "which", ] @@ -278,7 +278,7 @@ checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.95", ] [[package]] @@ -598,12 +598,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "downcast-rs" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" - [[package]] name = "educe" version = "0.5.11" @@ -613,7 +607,7 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.95", ] [[package]] @@ -639,7 +633,7 @@ checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.95", ] [[package]] @@ -834,6 +828,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.14" @@ -1051,7 +1054,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" dependencies = [ "proc-macro2", - "syn 2.0.89", + "syn 2.0.95", ] [[package]] @@ -1254,7 +1257,7 @@ checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.95", ] [[package]] @@ -1335,7 +1338,7 @@ checksum = "bbc159a1934c7be9761c237333a57febe060ace2bc9e3b337a59a37af206d19f" dependencies = [ "starknet-curve", "starknet-ff", - "syn 2.0.89", + "syn 2.0.95", ] [[package]] @@ -1361,6 +1364,27 @@ dependencies = [ "serde", ] +[[package]] +name = "stwo-air-utils" +version = "0.1.1" +dependencies = [ + "bytemuck", + "itertools 0.12.1", + "rayon", + "stwo-air-utils-derive", + "stwo-prover", +] + +[[package]] +name = "stwo-air-utils-derive" +version = "0.1.0" +dependencies = [ + "itertools 0.13.0", + "proc-macro2", + "quote", + "syn 2.0.95", +] + [[package]] name = "stwo-prover" version = "0.1.1" @@ -1371,7 +1395,6 @@ dependencies = [ "bytemuck", "cfg-if", "criterion 0.5.1", - "downcast-rs", "educe", "hex", "icicle-core", @@ -1412,9 +1435,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.89" +version = "2.0.95" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" +checksum = "46f71c0377baf4ef1cc3e3402ded576dccc315800fbc62dfc7fe04b009773b4a" dependencies = [ "proc-macro2", "quote", @@ -1440,7 +1463,7 @@ checksum = "5999e24eaa32083191ba4e425deb75cdf25efefabe5aaccb7446dd0d4122a3f5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.95", ] [[package]] @@ -1469,7 +1492,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.95", ] [[package]] @@ -1511,7 +1534,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.95", ] [[package]] @@ -1627,7 +1650,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.95", "wasm-bindgen-shared", ] @@ -1661,7 +1684,7 @@ checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.95", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1695,7 +1718,7 @@ checksum = "c97b2ef2c8d627381e51c071c2ab328eac606d3f69dd82bcbca20a9e389d95f0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.95", ] [[package]] @@ -1851,7 +1874,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.95", ] [[package]] @@ -1871,5 +1894,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.95", ] diff --git a/Cargo.toml b/Cargo.toml index 0f314a496..d4bb782ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["crates/prover"] +members = ["crates/prover", "crates/air_utils", "crates/air_utils_derive"] resolver = "2" [workspace.package] diff --git a/crates/air_utils/Cargo.toml b/crates/air_utils/Cargo.toml new file mode 100644 index 000000000..4463faf8b --- /dev/null +++ b/crates/air_utils/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "stwo-air-utils" +version.workspace = true +edition.workspace = true + +[dependencies] +bytemuck.workspace = true +itertools.workspace = true +rayon = { version = "1.10.0", optional = false } +stwo-prover = { path = "../prover" } +stwo-air-utils-derive = { path = "../air_utils_derive" } + +[lib] +bench = false diff --git a/crates/air_utils/src/lib.rs b/crates/air_utils/src/lib.rs new file mode 100644 index 000000000..813e60110 --- /dev/null +++ b/crates/air_utils/src/lib.rs @@ -0,0 +1,3 @@ +#![feature(exact_size_is_empty, raw_slice_split, portable_simd, array_chunks)] +pub mod lookup_data; +pub mod trace; diff --git a/crates/air_utils/src/lookup_data/mod.rs b/crates/air_utils/src/lookup_data/mod.rs new file mode 100644 index 000000000..234e4be44 --- /dev/null +++ b/crates/air_utils/src/lookup_data/mod.rs @@ -0,0 +1,168 @@ +#[cfg(test)] +mod tests { + use itertools::{all, Itertools}; + use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; + use rayon::slice::ParallelSlice; + use stwo_air_utils_derive::{IterMut, ParIterMut, Uninitialized}; + use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; + use stwo_prover::core::fields::m31::M31; + + use crate::trace::component_trace::ComponentTrace; + + /// Lookup data for the example ComponentTrace. 
+ /// Vectors are assumed to be of the same length. + #[derive(Uninitialized, IterMut, ParIterMut)] + struct LookupData { + field0: Vec, + field1: Vec<[PackedM31; 2]>, + field2: [Vec<[PackedM31; 2]>; 2], + } + + #[test] + fn test_derived_lookup_data() { + const LOG_SIZE: u32 = 6; + let LookupData { + field0, + field1, + field2, + } = unsafe { LookupData::uninitialized(LOG_SIZE) }; + + let lengths = [ + [field0.len()].as_slice(), + [field1.len()].as_slice(), + field2.map(|v| v.len()).as_slice(), + ] + .concat(); + assert!(all(lengths, |len| len == 1 << LOG_SIZE)); + } + + #[test] + fn test_derived_lookup_data_iter() { + const N_COLUMNS: usize = 5; + const LOG_N_ROWS: u32 = 8; + let mut trace = ComponentTrace::::zeroed(LOG_N_ROWS); + let arr = (0..1 << LOG_N_ROWS).map(M31::from).collect_vec(); + let mut lookup_data = unsafe { LookupData::uninitialized(LOG_N_ROWS - LOG_N_LANES) }; + let (expected_field0, expected_field1, expected_field2): ( + Vec<_>, + Vec<_>, + (Vec<_>, Vec<_>), + ) = arr + .array_chunks::() + .map(|x| { + let x = PackedM31::from_array(*x); + let x1 = x + PackedM31::broadcast(M31(1)); + let x2 = x + x1; + let x3 = x + x1 + x2; + let x4 = x + x1 + x2 + x3; + ( + x4, + [x1, x1.double()], + ([x2, x2.double()], [x3, x3.double()]), + ) + }) + .multiunzip(); + + trace + .iter_mut() + .zip(arr.chunks(N_LANES)) + .zip(lookup_data.iter_mut()) + .for_each(|((row, input), lookup_data)| { + *row[0] = PackedM31::from_array(input.try_into().unwrap()); + *row[1] = *row[0] + PackedM31::broadcast(M31(1)); + *row[2] = *row[0] + *row[1]; + *row[3] = *row[0] + *row[1] + *row[2]; + *row[4] = *row[0] + *row[1] + *row[2] + *row[3]; + *lookup_data.field0 = *row[4]; + *lookup_data.field1 = [*row[1], row[1].double()]; + *lookup_data.field2[0] = [*row[2], row[2].double()]; + *lookup_data.field2[1] = [*row[3], row[3].double()]; + }); + let (actual0, actual1, actual2) = ( + lookup_data.field0, + lookup_data.field1, + (lookup_data.field2[0].clone(), lookup_data.field2[1].clone()), + ); + + assert_eq!( + format!("{expected_field0:?}"), + format!("{actual0:?}"), + "Failed on Vec" + ); + assert_eq!( + format!("{expected_field1:?}"), + format!("{actual1:?}"), + "Failed on Vec<[PackedM31; 2]>" + ); + assert_eq!( + format!("{expected_field2:?}"), + format!("{actual2:?}"), + "Failed on [Vec<[PackedM31; 2]>; 2]" + ); + } + + #[test] + fn test_derived_lookup_data_par_iter() { + const N_COLUMNS: usize = 5; + const LOG_N_ROWS: u32 = 8; + let mut trace = ComponentTrace::::zeroed(LOG_N_ROWS); + let arr = (0..1 << LOG_N_ROWS).map(M31::from).collect_vec(); + let mut lookup_data = unsafe { LookupData::uninitialized(LOG_N_ROWS - LOG_N_LANES) }; + let (expected_field0, expected_field1, expected_field2): ( + Vec<_>, + Vec<_>, + (Vec<_>, Vec<_>), + ) = arr + .array_chunks::() + .map(|x| { + let x = PackedM31::from_array(*x); + let x1 = x + PackedM31::broadcast(M31(1)); + let x2 = x + x1; + let x3 = x + x1 + x2; + let x4 = x + x1 + x2 + x3; + ( + x4, + [x1, x1.double()], + ([x2, x2.double()], [x3, x3.double()]), + ) + }) + .multiunzip(); + + trace + .par_iter_mut() + .zip(arr.par_chunks(N_LANES).into_par_iter()) + .zip(lookup_data.par_iter_mut()) + .for_each(|((row, input), lookup_data)| { + *row[0] = PackedM31::from_array(input.try_into().unwrap()); + *row[1] = *row[0] + PackedM31::broadcast(M31(1)); + *row[2] = *row[0] + *row[1]; + *row[3] = *row[0] + *row[1] + *row[2]; + *row[4] = *row[0] + *row[1] + *row[2] + *row[3]; + *lookup_data.field0 = *row[4]; + *lookup_data.field1 = [*row[1], row[1].double()]; + 
*lookup_data.field2[0] = [*row[2], row[2].double()]; + *lookup_data.field2[1] = [*row[3], row[3].double()]; + }); + let (actual0, actual1, actual2) = ( + lookup_data.field0, + lookup_data.field1, + (lookup_data.field2[0].clone(), lookup_data.field2[1].clone()), + ); + + assert_eq!( + format!("{expected_field0:?}"), + format!("{actual0:?}"), + "Failed on Vec" + ); + assert_eq!( + format!("{expected_field1:?}"), + format!("{actual1:?}"), + "Failed on Vec<[PackedM31; 2]>" + ); + assert_eq!( + format!("{expected_field2:?}"), + format!("{actual2:?}"), + "Failed on [Vec<[PackedM31; 2]>; 2]" + ); + } +} diff --git a/crates/air_utils/src/trace/component_trace.rs b/crates/air_utils/src/trace/component_trace.rs new file mode 100644 index 000000000..ba59964ea --- /dev/null +++ b/crates/air_utils/src/trace/component_trace.rs @@ -0,0 +1,191 @@ +use bytemuck::Zeroable; +use stwo_prover::core::backend::simd::column::BaseColumn; +use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::BitReversedOrder; + +use super::row_iterator::{ParRowIterMut, RowIterMut}; + +/// A 2D Matrix of [`PackedM31`] values. +/// Used for generating the witness of 'Stwo' proofs. +/// Stored as an array of `N` columns, each column is a vector of [`PackedM31`] values. +/// All columns are of the same length. +/// Exposes an iterator over mutable references to the rows of the matrix. +/// +/// # Example: +/// +/// ```text +/// Computation trace of a^2 + (a + 1)^2 for a in 0..256 +/// ``` +/// ``` +/// use stwo_air_utils::trace::component_trace::ComponentTrace; +/// use itertools::Itertools; +/// use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES}; +/// use stwo_prover::core::fields::m31::M31; +/// use stwo_prover::core::fields::FieldExpOps; +/// +/// const N_COLUMNS: usize = 3; +/// const LOG_SIZE: u32 = 8; +/// let mut trace = ComponentTrace::::zeroed(LOG_SIZE); +/// let example_input = (0..1 << LOG_SIZE).map(M31::from).collect_vec(); // 0..256 +/// trace +/// .iter_mut() +/// .zip(example_input.chunks(N_LANES)) +/// .chunks(4) +/// .into_iter() +/// .for_each(|chunk| { +/// chunk.into_iter().for_each(|(row, input)| { +/// *row[0] = PackedM31::from_array(input.try_into().unwrap()); +/// *row[1] = *row[0] + PackedM31::broadcast(M31(1)); +/// *row[2] = row[0].square() + row[1].square(); +/// }) +/// }); +/// +/// let first_3_rows = (0..N_COLUMNS).map(|i| trace.row_at(i)).collect::>(); +/// assert_eq!(first_3_rows, [[0,1,1], [1,2,5], [2,3,13]].map(|row| row.map(M31::from))); +/// ``` +#[derive(Debug)] +pub struct ComponentTrace { + /// Columns are assumed to be of the same length. + data: [Vec; N], + + /// Log number of non-packed rows in each column. + log_size: u32, +} + +impl ComponentTrace { + /// Creates a new `ComponentTrace` with all values initialized to zero. + /// The number of rows in each column is `2^log_size`. + /// # Panics: + /// if log_size < 4. + pub fn zeroed(log_size: u32) -> Self { + assert!( + log_size >= LOG_N_LANES, + "log_size < LOG_N_LANES not supported!" + ); + let n_simd_elems = 1 << (log_size - LOG_N_LANES); + let data = [(); N].map(|_| vec![PackedM31::zeroed(); n_simd_elems]); + Self { data, log_size } + } + + /// Creates a new `ComponentTrace` with all values uninitialized. + /// # Safety + /// The caller must ensure that the column is populated before being used. 
+ /// The number of rows in each column is `2^log_size`. + /// # Panics: + /// if `log_size` < 4. + #[allow(clippy::uninit_vec)] + pub unsafe fn uninitialized(log_size: u32) -> Self { + assert!( + log_size >= LOG_N_LANES, + "log_size < LOG_N_LANES not supported!" + ); + let n_simd_elems = 1 << (log_size - LOG_N_LANES); + let data = [(); N].map(|_| { + let mut vec = Vec::with_capacity(n_simd_elems); + vec.set_len(n_simd_elems); + vec + }); + Self { data, log_size } + } + + pub fn log_size(&self) -> u32 { + self.log_size + } + + pub fn iter_mut(&mut self) -> RowIterMut<'_, N> { + RowIterMut::new(self.data.each_mut().map(|column| column.as_mut_slice())) + } + + pub fn par_iter_mut(&mut self) -> ParRowIterMut<'_, N> { + ParRowIterMut::new(self.data.each_mut().map(|column| column.as_mut_slice())) + } + + pub fn to_evals(self) -> [CircleEvaluation; N] { + let domain = CanonicCoset::new(self.log_size).circle_domain(); + self.data.map(|column| { + CircleEvaluation::::new( + domain, + BaseColumn::from_simd(column), + ) + }) + } + + pub fn row_at(&self, row: usize) -> [M31; N] { + assert!(row < 1 << self.log_size); + let packed_row = row / N_LANES; + let idx_in_simd_vector = row % N_LANES; + self.data + .each_ref() + .map(|column| column[packed_row].to_array()[idx_in_simd_vector]) + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES}; + use stwo_prover::core::fields::m31::M31; + use stwo_prover::core::fields::FieldExpOps; + + #[test] + fn test_parallel_trace() { + use rayon::iter::{IndexedParallelIterator, ParallelIterator}; + use rayon::slice::ParallelSlice; + + const N_COLUMNS: usize = 3; + const LOG_SIZE: u32 = 8; + const CHUNK_SIZE: usize = 4; + let mut trace = super::ComponentTrace::::zeroed(LOG_SIZE); + let arr = (0..1 << LOG_SIZE).map(M31::from).collect_vec(); + let expected = arr + .iter() + .map(|&a| { + let b = a + M31::from(1); + let c = a.square() + b.square(); + (a, b, c) + }) + .multiunzip(); + + trace + .par_iter_mut() + .zip(arr.par_chunks(N_LANES)) + .chunks(CHUNK_SIZE) + .for_each(|chunk| { + chunk.into_iter().for_each(|(row, input)| { + *row[0] = PackedM31::from_array(input.try_into().unwrap()); + *row[1] = *row[0] + PackedM31::broadcast(M31(1)); + *row[2] = row[0].square() + row[1].square(); + }); + }); + let actual = trace + .data + .map(|c| { + c.into_iter() + .flat_map(|packed| packed.to_array()) + .collect_vec() + }) + .into_iter() + .next_tuple() + .unwrap(); + + assert_eq!(expected, actual); + } + + #[test] + fn test_component_trace_uninitialized_success() { + const N_COLUMNS: usize = 3; + const LOG_SIZE: u32 = 4; + unsafe { super::ComponentTrace::::uninitialized(LOG_SIZE) }; + } + + #[should_panic = "log_size < LOG_N_LANES not supported!"] + #[test] + fn test_component_trace_uninitialized_fails() { + const N_COLUMNS: usize = 3; + const LOG_SIZE: u32 = 3; + unsafe { super::ComponentTrace::::uninitialized(LOG_SIZE) }; + } +} diff --git a/crates/air_utils/src/trace/mod.rs b/crates/air_utils/src/trace/mod.rs new file mode 100644 index 000000000..6e44c9033 --- /dev/null +++ b/crates/air_utils/src/trace/mod.rs @@ -0,0 +1,2 @@ +pub mod component_trace; +mod row_iterator; diff --git a/crates/air_utils/src/trace/row_iterator.rs b/crates/air_utils/src/trace/row_iterator.rs new file mode 100644 index 000000000..78d03ebea --- /dev/null +++ b/crates/air_utils/src/trace/row_iterator.rs @@ -0,0 +1,126 @@ +use std::marker::PhantomData; + +use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, 
UnindexedConsumer}; +use rayon::prelude::*; +use stwo_prover::core::backend::simd::m31::PackedM31; + +pub type MutRow<'trace, const N: usize> = [&'trace mut PackedM31; N]; + +/// An iterator over mutable references to the rows of a [`super::component_trace::ComponentTrace`]. +// TODO(Ohad): Iterating over single rows is not optimal, figure out optimal chunk size when using +// this iterator. +pub struct RowIterMut<'trace, const N: usize> { + v: [*mut [PackedM31]; N], + phantom: PhantomData<&'trace ()>, +} +impl<'trace, const N: usize> RowIterMut<'trace, N> { + pub fn new(slice: [&'trace mut [PackedM31]; N]) -> Self { + Self { + v: slice.map(|s| s as *mut _), + phantom: PhantomData, + } + } +} +impl<'trace, const N: usize> Iterator for RowIterMut<'trace, N> { + type Item = MutRow<'trace, N>; + + fn next(&mut self) -> Option { + if self.v[0].is_empty() { + return None; + } + let item = std::array::from_fn(|i| unsafe { + // SAFETY: The self.v contract ensures that any split_at_mut is valid. + let (head, tail) = self.v[i].split_at_mut(1); + self.v[i] = tail; + &mut (*head)[0] + }); + Some(item) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.v[0].len(); + (len, Some(len)) + } +} +impl ExactSizeIterator for RowIterMut<'_, N> {} +impl DoubleEndedIterator for RowIterMut<'_, N> { + fn next_back(&mut self) -> Option { + if self.v[0].is_empty() { + return None; + } + let item = std::array::from_fn(|i| unsafe { + // SAFETY: The self.v contract ensures that any split_at_mut is valid. + let (head, tail) = self.v[i].split_at_mut(self.v[i].len() - 1); + self.v[i] = head; + &mut (*tail)[0] + }); + Some(item) + } +} + +struct RowProducer<'trace, const N: usize> { + data: [&'trace mut [PackedM31]; N], +} +impl<'trace, const N: usize> Producer for RowProducer<'trace, N> { + type Item = MutRow<'trace, N>; + + fn split_at(self, index: usize) -> (Self, Self) { + let mut left: [_; N] = unsafe { std::mem::zeroed() }; + let mut right: [_; N] = unsafe { std::mem::zeroed() }; + for (i, slice) in self.data.into_iter().enumerate() { + let (lhs, rhs) = slice.split_at_mut(index); + left[i] = lhs; + right[i] = rhs; + } + (RowProducer { data: left }, RowProducer { data: right }) + } + + type IntoIter = RowIterMut<'trace, N>; + + fn into_iter(self) -> Self::IntoIter { + RowIterMut { + v: self.data.map(|s| s as *mut _), + phantom: PhantomData, + } + } +} + +/// A parallel iterator over mutable references to the rows of a +/// [`super::component_trace::ComponentTrace`]. [`super::component_trace::ComponentTrace`] is an +/// array of columns, hence iterating over rows is not trivial. Iteration is done by iterating over +/// `N` columns in parallel. 
+pub struct ParRowIterMut<'trace, const N: usize> { + data: [&'trace mut [PackedM31]; N], +} +impl<'trace, const N: usize> ParRowIterMut<'trace, N> { + pub(super) fn new(data: [&'trace mut [PackedM31]; N]) -> Self { + Self { data } + } +} +impl<'trace, const N: usize> ParallelIterator for ParRowIterMut<'trace, N> { + type Item = MutRow<'trace, N>; + + fn drive_unindexed(self, consumer: D) -> D::Result + where + D: UnindexedConsumer, + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } +} +impl IndexedParallelIterator for ParRowIterMut<'_, N> { + fn len(&self) -> usize { + self.data[0].len() + } + + fn drive>(self, consumer: D) -> D::Result { + bridge(self, consumer) + } + + fn with_producer>(self, callback: CB) -> CB::Output { + callback.callback(RowProducer { data: self.data }) + } +} diff --git a/crates/air_utils_derive/Cargo.toml b/crates/air_utils_derive/Cargo.toml new file mode 100644 index 000000000..0f36c43af --- /dev/null +++ b/crates/air_utils_derive/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "stwo-air-utils-derive" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = "2.0.90" +quote = "1.0.37" +itertools = "0.13.0" +proc-macro2 = "1.0.92" diff --git a/crates/air_utils_derive/src/allocation.rs b/crates/air_utils_derive/src/allocation.rs new file mode 100644 index 000000000..bff4897a6 --- /dev/null +++ b/crates/air_utils_derive/src/allocation.rs @@ -0,0 +1,30 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::Ident; + +use crate::iterable_field::IterableField; + +/// Implements an "Uninitialized" function for the struct. +/// Allocates 2^`log_size` slots for every Vector. +pub fn expand_uninitialized_impl( + struct_name: &Ident, + iterable_fields: &[IterableField], +) -> TokenStream { + let (field_names, allocations): (Vec<_>, Vec<_>) = iterable_fields + .iter() + .map(|f| (f.name(), f.uninitialized_field_allocation())) + .unzip(); + quote! { + impl #struct_name { + /// # Safety + /// The caller must ensure that the trace is populated before being used. + #[automatically_derived] + pub unsafe fn uninitialized(log_size: u32) -> Self { + let len = 1 << log_size; + #(#allocations)* + Self { + #(#field_names,)* + } + } + }} +} diff --git a/crates/air_utils_derive/src/iter_mut.rs b/crates/air_utils_derive/src/iter_mut.rs new file mode 100644 index 000000000..0909d7341 --- /dev/null +++ b/crates/air_utils_derive/src/iter_mut.rs @@ -0,0 +1,167 @@ +use itertools::Itertools; +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote}; +use syn::{Ident, Lifetime}; + +use crate::iterable_field::IterableField; + +pub fn expand_iter_mut_structs( + struct_name: &Ident, + iterable_fields: &[IterableField], +) -> TokenStream { + let impl_struct_name = expand_impl_struct_name(struct_name, iterable_fields); + let mut_chunk_struct = expand_mut_chunk_struct(struct_name, iterable_fields); + let iter_mut_struct = expand_iter_mut_struct(struct_name, iterable_fields); + let iterator_impl = expand_iterator_impl(struct_name, iterable_fields); + let exact_size_iterator = expand_exact_size_iterator(struct_name); + let double_ended_iterator = expand_double_ended_iterator(struct_name, iterable_fields); + + quote! 
{ + #impl_struct_name + #mut_chunk_struct + #iter_mut_struct + #iterator_impl + #exact_size_iterator + #double_ended_iterator + } +} + +fn expand_impl_struct_name(struct_name: &Ident, iterable_fields: &[IterableField]) -> TokenStream { + let iter_mut_name = format_ident!("{}IterMut", struct_name); + let as_mut_slice = iterable_fields + .iter() + .map(|f| f.as_mut_slice()) + .collect_vec(); + quote! { + impl #struct_name { + pub fn iter_mut(&mut self) -> #iter_mut_name<'_> { + #iter_mut_name::new( + #(self.#as_mut_slice,)* + ) + } + } + } +} + +fn expand_mut_chunk_struct(struct_name: &Ident, iterable_fields: &[IterableField]) -> TokenStream { + let lifetime = Lifetime::new("'a", Span::call_site()); + let mut_chunk_name = format_ident!("{}MutChunk", struct_name); + let (field_names, mut_chunk_types): (Vec<_>, Vec<_>) = iterable_fields + .iter() + .map(|f| (f.name(), f.mut_chunk_type(&lifetime))) + .unzip(); + + quote! { + pub struct #mut_chunk_name<#lifetime> { + #(#field_names: #mut_chunk_types,)* + } + } +} + +fn expand_iter_mut_struct(struct_name: &Ident, iterable_fields: &[IterableField]) -> TokenStream { + let lifetime = Lifetime::new("'a", Span::call_site()); + let iter_mut_name = format_ident!("{}IterMut", struct_name); + let (field_names, mut_slice_types, mut_ptr_types, as_mut_ptr): ( + Vec<_>, + Vec<_>, + Vec<_>, + Vec<_>, + ) = iterable_fields + .iter() + .map(|f| { + ( + f.name(), + f.mut_slice_type(&lifetime), + f.mut_slice_ptr_type(), + f.as_mut_ptr(), + ) + }) + .multiunzip(); + + quote! { + pub struct #iter_mut_name<#lifetime> { + #(#field_names: #mut_ptr_types,)* + phantom: std::marker::PhantomData<&#lifetime ()>, + } + impl<#lifetime> #iter_mut_name<#lifetime> { + pub fn new( + #(#field_names: #mut_slice_types,)* + ) -> Self { + Self { + #(#field_names: #as_mut_ptr,)* + phantom: std::marker::PhantomData, + } + } + } + } +} + +fn expand_iterator_impl(struct_name: &Ident, iterable_fields: &[IterableField]) -> TokenStream { + let lifetime = Lifetime::new("'a", Span::call_site()); + let iter_mut_name = format_ident!("{}IterMut", struct_name); + let mut_chunk_name = format_ident!("{}MutChunk", struct_name); + let (field_names, split_first): (Vec<_>, Vec<_>) = iterable_fields + .iter() + .map(|f| (f.name(), f.split_first())) + .unzip(); + let get_length = iterable_fields.first().unwrap().get_len(); + + quote! { + impl<#lifetime> Iterator for #iter_mut_name<#lifetime> { + type Item = #mut_chunk_name<#lifetime>; + fn next(&mut self) -> Option { + if self.#get_length == 0 { + return None; + } + let item = unsafe { + #(#split_first)* + #mut_chunk_name { + #(#field_names,)* + } + }; + Some(item) + } + fn size_hint(&self) -> (usize, Option) { + let len = self.#get_length; + (len, Some(len)) + } + } + } +} + +fn expand_exact_size_iterator(struct_name: &Ident) -> TokenStream { + let iter_mut_name = format_ident!("{}IterMut", struct_name); + quote! { + impl ExactSizeIterator for #iter_mut_name<'_> {} + } +} + +fn expand_double_ended_iterator( + struct_name: &Ident, + iterable_fields: &[IterableField], +) -> TokenStream { + let iter_mut_name = format_ident!("{}IterMut", struct_name); + let mut_chunk_name = format_ident!("{}MutChunk", struct_name); + let (field_names, split_last): (Vec<_>, Vec<_>) = iterable_fields + .iter() + .map(|f| (f.name(), f.split_last(&format_ident!("len")))) + .unzip(); + let get_length = iterable_fields.first().unwrap().get_len(); + quote! 
{ + impl DoubleEndedIterator for #iter_mut_name<'_> { + fn next_back(&mut self) -> Option { + let len = self.#get_length; + if len == 0 { + return None; + } + let item = unsafe { + #(#split_last)* + #mut_chunk_name { + #(#field_names,)* + } + }; + Some(item) + } + } + } +} diff --git a/crates/air_utils_derive/src/iterable_field.rs b/crates/air_utils_derive/src/iterable_field.rs new file mode 100644 index 000000000..6cb80ea15 --- /dev/null +++ b/crates/air_utils_derive/src/iterable_field.rs @@ -0,0 +1,369 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{Data, DeriveInput, Expr, Field, Fields, Ident, Lifetime, Type}; + +/// Each variant represents a field that can be iterated over. +/// Used to derive implementations of `Uninitialized`, `MutIter`, and `ParIterMut`. +/// Currently supports `Vec` and `[Vec; N]` fields only. +pub(super) enum IterableField { + /// A single Vec field, e.g. `Vec`, `Vec<[u32; K]>`. + PlainVec(PlainVec), + /// An array of Vec fields, e.g. `[Vec; N]`, `[Vec<[u32; K]>; N]`. + ArrayOfVecs(ArrayOfVecs), +} + +pub(super) struct PlainVec { + name: Ident, + ty: Type, +} +pub(super) struct ArrayOfVecs { + name: Ident, + inner_type: Type, + outer_array_size: Expr, +} + +impl IterableField { + pub fn from_field(field: &Field) -> Result { + // Check if the field is a vector or array of vectors. + match field.ty { + // Case that type is [Vec; N]. + Type::Array(ref outer_array) => { + let inner_type = match outer_array.elem.as_ref() { + Type::Path(ref type_path) => parse_inner_type(type_path)?, + _ => Err(syn::Error::new_spanned( + outer_array.elem.clone(), + "Expected Vec type", + ))?, + }; + Ok(Self::ArrayOfVecs(ArrayOfVecs { + name: field.ident.clone().unwrap(), + outer_array_size: outer_array.len.clone(), + inner_type: inner_type.clone(), + })) + } + // Case that type is Vec. + Type::Path(ref type_path) => { + let ty = parse_inner_type(type_path)?; + Ok(Self::PlainVec(PlainVec { + name: field.ident.clone().unwrap(), + ty, + })) + } + _ => Err(syn::Error::new_spanned( + field, + "Expected vector or array of vectors", + )), + } + } + + pub fn name(&self) -> &Ident { + match self { + IterableField::PlainVec(plain_vec) => &plain_vec.name, + IterableField::ArrayOfVecs(array_of_vecs) => &array_of_vecs.name, + } + } + + /// Generate the type of a mutable slice of the field. + /// E.g. `&'a mut [u32]` for a `Vec` field. + /// E.g. [`&'a mut [u32]; N]` for a `[Vec; N]` field. + /// Used in the `IterMut` struct. + pub fn mut_slice_type(&self, lifetime: &Lifetime) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let ty = &plain_vec.ty; + quote! { + &#lifetime mut [#ty] + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let inner_type = &array_of_vecs.inner_type; + let outer_array_size = &array_of_vecs.outer_array_size; + quote! { + [&#lifetime mut [#inner_type]; #outer_array_size] + } + } + } + } + + /// Generate the type of a mutable chunk of the field. + /// E.g. `&'a mut u32` for a `Vec` field. + /// E.g. [`&'a mut u32; N]` for a `[Vec; N]` field. + /// Used in the `MutChunk` struct. + pub fn mut_chunk_type(&self, lifetime: &Lifetime) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let ty = &plain_vec.ty; + quote! { + &#lifetime mut #ty + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let inner_type = &array_of_vecs.inner_type; + let array_size = &array_of_vecs.outer_array_size; + quote! 
{ + [&#lifetime mut #inner_type; #array_size] + } + } + } + } + + /// Generate the type of a mutable slice pointer to the field. + /// E.g. `*mut [u32]` for a `Vec` field. + /// E.g. [`*mut [u32]; N]` for a `[Vec; N]` field. + /// Used in the `IterMut` struct. + pub fn mut_slice_ptr_type(&self) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let ty = &plain_vec.ty; + quote! { + *mut [#ty] + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let inner_type = &array_of_vecs.inner_type; + let outer_array_size = &array_of_vecs.outer_array_size; + quote! { + [*mut [#inner_type]; #outer_array_size] + } + } + } + } + + /// Generate the uninitialized allocation for the field. + /// E.g. `Vec::with_capacity(len); vec.set_len(len);` for a `Vec` field. + /// E.g. `[(); N].map(|_| { Vec::with_capacity(len); vec.set_len(len); })` for `[Vec; N]`. + /// Used to generate the `uninitialized` function. + pub fn uninitialized_field_allocation(&self) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + quote! { + let mut #name= Vec::with_capacity(len); + #name.set_len(len); + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + let outer_array_size = &array_of_vecs.outer_array_size; + quote! { + let #name = [(); #outer_array_size].map(|_| { + let mut vec = Vec::with_capacity(len); + vec.set_len(len); + vec + }); + } + } + } + } + + /// Generate the code to split the first element from the field. + /// E.g. `let (head, tail) = self.field.split_at_mut(1); + /// self.field = tail; let field = &mut (*head)[0];` + /// Used for the `next` function in the iterator struct. + pub fn split_first(&self) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + let head = format_ident!("{}_head", name); + let tail = format_ident!("{}_tail", name); + quote! { + let (#head, #tail) = self.#name.split_at_mut(1); + self.#name = #tail; + let #name = &mut (*(#head))[0]; + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + quote! { + let #name = self.#name.each_mut().map(|v| { + let (head, tail) = v.split_at_mut(1); + *v = tail; + &mut (*head)[0] + }); + } + } + } + } + + /// Generate the code to split the last element from the field. + /// E.g. `let (head, tail) = self.field.split_at_mut(len - 1); + /// Used for the `next_back` function in the DoubleEnded impl. + pub fn split_last(&self, length: &Ident) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + let head = format_ident!("{}_head", name); + let tail = format_ident!("{}_tail", name); + quote! { + let ( + #head, + #tail, + ) = self.#name.split_at_mut(#length - 1); + self.#name = #head; + let #name = &mut (*#tail)[0]; + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + quote! { + let #name = self.#name.each_mut().map(|v| { + let (head, tail) = v.split_at_mut(#length - 1); + *v = head; + &mut (*tail)[0] + }); + } + } + } + } + + /// Generate the code to split the field at a given index. + /// E.g. `let (head, tail) = self.field.split_at_mut(index);` + /// E.g. `let (head, tail) = self.field.each_mut().map(|v| v.split_at_mut(index));` + /// Used for the `split_at` function in the Producer impl. 
+ pub fn split_at(&self, index: &Ident) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + let head = format_ident!("{}_head", name); + let tail = format_ident!("{}_tail", name); + quote! { + let ( + #head, + #tail + ) = self.#name.split_at_mut(#index); + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + let head = format_ident!("{}_head", name); + let tail = format_ident!("{}_tail", name); + let array_size = &array_of_vecs.outer_array_size; + quote! { + let ( + mut #head, + mut #tail + ):([_; #array_size],[_; #array_size]) = unsafe { (std::mem::zeroed(), std::mem::zeroed()) }; + self.#name.into_iter().enumerate().for_each(|(i, v)| { + let (head, tail) = v.split_at_mut(#index); + #head[i] = head; + #tail[i] = tail; + }); + } + } + } + } + + /// Generate the code to get a mutable slice of the field. + /// E.g. `self.field.as_mut_slice()` + /// E.g. `self.field.each_mut().map(|v| v.as_mut_slice())` + /// Used to generate the arguments for the IterMut 'new' function call. + pub fn as_mut_slice(&self) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + quote! { + #name.as_mut_slice() + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + quote! { + #name.each_mut().map(|v| v.as_mut_slice()) + } + } + } + } + + /// Generate the code to get a mutable pointer a mutable slice of the field. + /// E.g. `'a mut [u32]` -> `*mut [u32]`. Achieved by casting: `as *mut _`. + /// Used for the `IterMut` struct. + pub fn as_mut_ptr(&self) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + quote! { + #name as *mut _ + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + quote! { + #name.map(|v| v as *mut _) + } + } + } + } + + /// Generate the code to get the length of the field. + /// Length is assumed to be the same for all fields on every coordinate. + /// E.g. `self.field.len()` + /// E.g. `self.field[0].len()` + pub fn get_len(&self) -> TokenStream { + match self { + IterableField::PlainVec(plain_vec) => { + let name = &plain_vec.name; + quote! { + #name.len() + } + } + IterableField::ArrayOfVecs(array_of_vecs) => { + let name = &array_of_vecs.name; + quote! { + #name[0].len() + } + } + } + } +} + +// Extract the inner vector type from a path. +// Returns an error if the path is not of the form ::Vec. 
+fn parse_inner_type(type_path: &syn::TypePath) -> Result { + match type_path.path.segments.last() { + Some(last_segment) => { + if last_segment.ident != "Vec" { + return Err(syn::Error::new_spanned( + last_segment.ident.clone(), + "Expected Vec type", + )); + } + match &last_segment.arguments { + syn::PathArguments::AngleBracketed(args) => match args.args.first() { + Some(syn::GenericArgument::Type(inner_type)) => Ok(inner_type.clone()), + _ => Err(syn::Error::new_spanned( + args.args.first().unwrap(), + "Expected exactly one generic argument: Vec", + )), + }, + _ => Err(syn::Error::new_spanned( + last_segment.arguments.clone(), + "Expected angle-bracketed arguments: Vec<..>", + )), + } + } + _ => Err(syn::Error::new_spanned( + type_path.path.clone(), + "Expected last segment", + )), + } +} + +pub(super) fn to_iterable_fields(input: DeriveInput) -> Result, syn::Error> { + let struct_name = &input.ident; + let input = match input.data { + Data::Struct(data_struct) => Ok(data_struct), + _ => Err(syn::Error::new_spanned(struct_name, "Expected a struct")), + }?; + + match input.fields { + Fields::Named(fields) => Ok(fields + .named + .iter() + .map(IterableField::from_field) + .collect::>()?), + _ => Err(syn::Error::new_spanned( + input.fields, + "Expected named fields", + )), + } +} diff --git a/crates/air_utils_derive/src/lib.rs b/crates/air_utils_derive/src/lib.rs new file mode 100644 index 000000000..33bd91bf7 --- /dev/null +++ b/crates/air_utils_derive/src/lib.rs @@ -0,0 +1,45 @@ +mod allocation; +mod iter_mut; +mod iterable_field; +mod par_iter; +use iterable_field::to_iterable_fields; +use syn::{parse_macro_input, DeriveInput}; + +#[proc_macro_derive(Uninitialized)] +pub fn derive_uninitialized(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let struct_name = input.ident.clone(); + + let iterable_fields = match to_iterable_fields(input) { + Ok(iterable_fields) => iterable_fields, + Err(err) => return err.into_compile_error().into(), + }; + + allocation::expand_uninitialized_impl(&struct_name, &iterable_fields).into() +} + +#[proc_macro_derive(IterMut)] +pub fn derive_mut_iter(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let struct_name = input.ident.clone(); + + let iterable_fields = match to_iterable_fields(input) { + Ok(iterable_fields) => iterable_fields, + Err(err) => return err.into_compile_error().into(), + }; + + iter_mut::expand_iter_mut_structs(&struct_name, &iterable_fields).into() +} + +#[proc_macro_derive(ParIterMut)] +pub fn derive_par_mut_iter(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let struct_name = input.ident.clone(); + + let iterable_fields = match to_iterable_fields(input) { + Ok(iterable_fields) => iterable_fields, + Err(err) => return err.into_compile_error().into(), + }; + + par_iter::expand_par_iter_mut_structs(&struct_name, &iterable_fields).into() +} diff --git a/crates/air_utils_derive/src/par_iter.rs b/crates/air_utils_derive/src/par_iter.rs new file mode 100644 index 000000000..e7e2b8947 --- /dev/null +++ b/crates/air_utils_derive/src/par_iter.rs @@ -0,0 +1,163 @@ +use itertools::Itertools; +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote}; +use syn::{Ident, Lifetime}; + +use crate::iterable_field::IterableField; + +pub fn expand_par_iter_mut_structs( + struct_name: &Ident, + iterable_fields: &[IterableField], +) -> 
TokenStream { + let lifetime = Lifetime::new("'a", Span::call_site()); + let split_index = format_ident!("index"); + + let struct_code = generate_struct_impl(struct_name, iterable_fields); + let producer_code = + generate_row_producer(struct_name, iterable_fields, &lifetime, &split_index); + let oar_iter_struct = generate_par_iter_struct(struct_name, iterable_fields, &lifetime); + let impl_par_iter = generate_parallel_iterator_impls(struct_name, iterable_fields, &lifetime); + + quote! { + #struct_code + #producer_code + #oar_iter_struct + #impl_par_iter + } +} + +fn generate_struct_impl(struct_name: &Ident, iterable_fields: &[IterableField]) -> TokenStream { + let par_iter_mut_name = format_ident!("{}ParIterMut", struct_name); + let as_mut_slice = iterable_fields.iter().map(|f| f.as_mut_slice()); + quote! { + impl #struct_name { + pub fn par_iter_mut(&mut self) -> #par_iter_mut_name<'_> { + #par_iter_mut_name::new( + #(self.#as_mut_slice,)* + ) + } + } + } +} + +fn generate_row_producer( + struct_name: &Ident, + iterable_fields: &[IterableField], + lifetime: &Lifetime, + split_index: &Ident, +) -> TokenStream { + let row_producer_name = format_ident!("{}RowProducer", struct_name); + let mut_chunk_name = format_ident!("{}MutChunk", struct_name); + let iter_mut_name = format_ident!("{}IterMut", struct_name); + let (field_names, mut_slice_types, split_at): (Vec<_>, Vec<_>, Vec<_>) = iterable_fields + .iter() + .map(|f| { + ( + f.name(), + f.mut_slice_type(lifetime), + f.split_at(split_index), + ) + }) + .multiunzip(); + let field_names_head = field_names.iter().map(|f| format_ident!("{}_head", f)); + let field_names_tail = field_names.iter().map(|f| format_ident!("{}_tail", f)); + quote! { + pub struct #row_producer_name<#lifetime> { + #(#field_names: #mut_slice_types,)* + } + impl<#lifetime> rayon::iter::plumbing::Producer for #row_producer_name<#lifetime> { + type Item = #mut_chunk_name<#lifetime>; + type IntoIter = #iter_mut_name<#lifetime>; + + #[allow(invalid_value)] + fn split_at(self, index: usize) -> (Self, Self) { + #(#split_at)* + ( + #row_producer_name { + #(#field_names: #field_names_head,)* + }, + #row_producer_name { + #(#field_names: #field_names_tail,)* + } + ) + } + + fn into_iter(self) -> Self::IntoIter { + #iter_mut_name::new(#(self.#field_names),*) + } + } + } +} + +fn generate_par_iter_struct( + struct_name: &Ident, + iterable_fields: &[IterableField], + lifetime: &Lifetime, +) -> TokenStream { + let par_iter_mut_name = format_ident!("{struct_name}ParIterMut"); + let (field_names, mut_slice_types): (Vec<_>, Vec<_>) = iterable_fields + .iter() + .map(|f| (f.name(), f.mut_slice_type(lifetime))) + .unzip(); + quote! { + pub struct #par_iter_mut_name<#lifetime> { + #(#field_names: #mut_slice_types,)* + } + + impl<#lifetime> #par_iter_mut_name<#lifetime> { + pub fn new( + #(#field_names: #mut_slice_types,)* + ) -> Self { + Self { + #(#field_names,)* + } + } + } + } +} + +fn generate_parallel_iterator_impls( + struct_name: &Ident, + iterable_fields: &[IterableField], + lifetime: &Lifetime, +) -> TokenStream { + let par_iter_mut_name = format_ident!("{}ParIterMut", struct_name); + let mut_chunk_name = format_ident!("{}MutChunk", struct_name); + let row_producer_name = format_ident!("{}RowProducer", struct_name); + let field_names = iterable_fields.iter().map(|f| f.name()); + let get_length = iterable_fields.first().unwrap().get_len(); + quote! 
{ + impl<#lifetime> rayon::prelude::ParallelIterator for #par_iter_mut_name<#lifetime> { + type Item = #mut_chunk_name<#lifetime>; + + fn drive_unindexed(self, consumer: D) -> D::Result + where + D: rayon::iter::plumbing::UnindexedConsumer, + { + rayon::iter::plumbing::bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } + } + + impl rayon::iter::IndexedParallelIterator for #par_iter_mut_name<'_> { + fn len(&self) -> usize { + self.#get_length + } + + fn drive>(self, consumer: D) -> D::Result { + rayon::iter::plumbing::bridge(self, consumer) + } + + fn with_producer>(self, callback: CB) -> CB::Output { + callback.callback( + #row_producer_name { + #(#field_names : self.#field_names,)* + } + ) + } + } + } +} diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 1aa546e5f..994cdcd76 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -16,7 +16,6 @@ blake2.workspace = true blake3.workspace = true bytemuck = { workspace = true, features = ["derive", "extern_crate_alloc"] } cfg-if = "1.0.0" -downcast-rs = "1.2" educe.workspace = true hex.workspace = true itertools.workspace = true @@ -65,8 +64,8 @@ nonstandard-style = "deny" rust-2018-idioms = "deny" unused = "deny" -[package.metadata.cargo-machete] -ignored = ["downcast-rs"] +[lints.clippy] +missing_const_for_fn = "warn" [[bench]] harness = false diff --git a/crates/prover/benches/bit_rev.rs b/crates/prover/benches/bit_rev.rs index 7e8865c43..fe0a80f1d 100644 --- a/crates/prover/benches/bit_rev.rs +++ b/crates/prover/benches/bit_rev.rs @@ -8,9 +8,10 @@ const LOG_SIZE: usize = 28; const SIZE: usize = 1 << LOG_SIZE; pub fn cpu_bit_rev(c: &mut Criterion) { - use stwo_prover::core::utils::bit_reverse; - #[cfg(not(feature = "icicle"))] + use stwo_prover::core::backend::cpu::bit_reverse; + // TODO(andrew): Consider using same size for all. 
+ let data = (0..SIZE).map(BaseField::from).collect_vec(); #[cfg(feature = "icicle")] diff --git a/crates/prover/benches/fft.rs b/crates/prover/benches/fft.rs index 35841d7e8..cbb0c9e80 100644 --- a/crates/prover/benches/fft.rs +++ b/crates/prover/benches/fft.rs @@ -29,7 +29,7 @@ pub fn simd_ifft(c: &mut Criterion) { || values.clone().data, |mut data| unsafe { ifft( - transmute(data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(data.as_mut_ptr()), black_box(&twiddle_dbls_refs), black_box(log_size as usize), ); @@ -58,7 +58,7 @@ pub fn simd_ifft_parts(c: &mut Criterion) { || values.clone().data, |mut values| unsafe { ifft_vecwise_loop( - transmute(values.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.as_mut_ptr()), black_box(&twiddle_dbls_refs), black_box(9), black_box(0), @@ -72,7 +72,7 @@ pub fn simd_ifft_parts(c: &mut Criterion) { || values.clone().data, |mut values| unsafe { ifft3_loop( - transmute(values.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.as_mut_ptr()), black_box(&twiddle_dbls_refs[3..]), black_box(7), black_box(4), @@ -91,7 +91,7 @@ pub fn simd_ifft_parts(c: &mut Criterion) { || transpose_values.clone().data, |mut values| unsafe { transpose_vecs( - transmute(values.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.as_mut_ptr()), black_box(TRANSPOSE_LOG_SIZE as usize - 4), ) }, @@ -115,8 +115,10 @@ pub fn simd_rfft(c: &mut Criterion) { target.set_len(values.data.len()); fft( - black_box(transmute(values.data.as_ptr())), - transmute(target.as_mut_ptr()), + black_box(transmute::<*const PackedBaseField, *const u32>( + values.data.as_ptr(), + )), + transmute::<*mut PackedBaseField, *mut u32>(target.as_mut_ptr()), black_box(&twiddle_dbls_refs), black_box(LOG_SIZE as usize), ) diff --git a/crates/prover/benches/merkle.rs b/crates/prover/benches/merkle.rs index c039be77e..9a63a3c38 100644 --- a/crates/prover/benches/merkle.rs +++ b/crates/prover/benches/merkle.rs @@ -21,7 +21,7 @@ fn bench_blake2s_merkle>(c: &mut Criterion, id let n_elements = 1 << (LOG_N_COLS + LOG_N_ROWS); group.throughput(Throughput::Elements(n_elements)); group.throughput(Throughput::Bytes(N_BYTES_FELT as u64 * n_elements)); - group.bench_function(&format!("{id} merkle"), |b| { + group.bench_function(format!("{id} merkle"), |b| { b.iter_with_large_drop(|| B::commit_on_layer(LOG_N_ROWS, None, &col_refs)) }); } diff --git a/crates/prover/src/constraint_framework/assert.rs b/crates/prover/src/constraint_framework/assert.rs index ce37a0cb3..376ff80b1 100644 --- a/crates/prover/src/constraint_framework/assert.rs +++ b/crates/prover/src/constraint_framework/assert.rs @@ -1,4 +1,4 @@ -use num_traits::{One, Zero}; +use num_traits::Zero; use super::logup::{LogupAtRow, LogupSums}; use super::{EvalAtRow, INTERACTION_TRACE_IDX}; @@ -33,7 +33,7 @@ impl<'a> AssertEvaluator<'a> { } } } -impl<'a> EvalAtRow for AssertEvaluator<'a> { +impl EvalAtRow for AssertEvaluator<'_> { type F = BaseField; type EF = SecureField; @@ -54,13 +54,17 @@ impl<'a> EvalAtRow for AssertEvaluator<'a> { fn add_constraint(&mut self, constraint: G) where - Self::EF: std::ops::Mul, + Self::EF: std::ops::Mul + From, { // Cast to SecureField. - let res = SecureField::one() * constraint; // The constraint should be zero at the given row, since we are evaluating on the trace // domain. 
- assert_eq!(res, SecureField::zero(), "row: {}", self.row); + assert_eq!( + Self::EF::from(constraint), + SecureField::zero(), + "row: {}", + self.row + ); } fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index ab957fdac..caae6fec6 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -17,6 +17,7 @@ use super::{ }; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::{Component, ComponentProver, Trace}; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS}; @@ -31,7 +32,7 @@ use crate::core::fields::FieldExpOps; use crate::core::pcs::{TreeSubspan, TreeVec}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; -use crate::core::{utils, ColumnVec}; +use crate::core::ColumnVec; const CHUNK_SIZE: usize = 1; @@ -93,7 +94,7 @@ impl TraceLocationAllocator { } } - pub fn preprocessed_columns(&self) -> &HashMap { + pub const fn preprocessed_columns(&self) -> &HashMap { &self.preprocessed_columns } @@ -110,9 +111,10 @@ impl TraceLocationAllocator { } /// A component defined solely in means of the constraints framework. +/// /// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for -/// the SIMD backend. -/// Note that the constraint framework only support components with columns of the same size. +/// the SIMD backend. Note that the constraint framework only supports components with columns of +/// the same size. pub trait FrameworkEval { fn log_size(&self) -> u32; @@ -293,7 +295,7 @@ impl ComponentProver for FrameworkComponen let mut denom_inv = (0..1 << log_expand) .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) .collect_vec(); - utils::bit_reverse(&mut denom_inv); + bit_reverse(&mut denom_inv); // Accumulator. let [mut accum] = @@ -471,7 +473,7 @@ impl ComponentProver for FrameworkComponent let mut denom_inv = (0..1 << log_expand) .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) .collect_vec(); - utils::bit_reverse(&mut denom_inv); + bit_reverse(&mut denom_inv); // Accumulator. let [mut accum] = @@ -579,7 +581,7 @@ impl ComponentProver for FrameworkCompon let mut denom_inv = (0..1 << log_expand) .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) .collect_vec(); - utils::bit_reverse(&mut denom_inv); + bit_reverse(&mut denom_inv); nvtx::range_pop!(); // Accumulator. 
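The hunks above only change where `bit_reverse` is imported from; the permutation applied to `denom_inv` is unchanged. As a rough illustration of that permutation, here is a minimal, self-contained sketch of an in-place bit-reversal reorder for a power-of-two-length slice (the names and details are illustrative and are not taken from the crate's `cpu::bit_reverse`):

/// Moves the element at index `i` to the index obtained by reversing the bits of
/// `i` within `log2(v.len())` bits. Illustrative sketch only.
fn bit_reverse_permute<T>(v: &mut [T]) {
    let n = v.len();
    assert!(n.is_power_of_two());
    if n <= 1 {
        return;
    }
    let log_n = n.trailing_zeros();
    for i in 0..n {
        let j = ((i as u32).reverse_bits() >> (32 - log_n)) as usize;
        // Swap each pair only once so the permutation is applied in place.
        if i < j {
            v.swap(i, j);
        }
    }
}

For example, with a length-8 slice the resulting order of original indices is [0, 4, 2, 6, 1, 5, 3, 7], which is the bit-reversed ordering the surrounding evaluation code expects for `denom_inv`.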
diff --git a/crates/prover/src/constraint_framework/cpu_domain.rs b/crates/prover/src/constraint_framework/cpu_domain.rs index 72d285aeb..03089bd17 100644 --- a/crates/prover/src/constraint_framework/cpu_domain.rs +++ b/crates/prover/src/constraint_framework/cpu_domain.rs @@ -52,7 +52,7 @@ impl<'a> CpuDomainEvaluator<'a> { } } -impl<'a> EvalAtRow for CpuDomainEvaluator<'a> { +impl EvalAtRow for CpuDomainEvaluator<'_> { type F = BaseField; type EF = SecureField; @@ -85,7 +85,7 @@ impl<'a> EvalAtRow for CpuDomainEvaluator<'a> { fn add_constraint(&mut self, constraint: G) where - Self::EF: Mul, + Self::EF: Mul + From, { self.row_res += self.random_coeff_powers[self.constraint_index] * constraint; self.constraint_index += 1; diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs deleted file mode 100644 index 9e2e7bed5..000000000 --- a/crates/prover/src/constraint_framework/expr.rs +++ /dev/null @@ -1,536 +0,0 @@ -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; - -use num_traits::{One, Zero}; - -use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; -use crate::core::fields::m31::{self, BaseField}; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::FieldExpOps; -use crate::core::lookups::utils::Fraction; - -/// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`. -#[derive(Clone, Debug, PartialEq)] -pub struct ColumnExpr { - interaction: usize, - idx: usize, - offset: isize, -} - -#[derive(Clone, Debug, PartialEq)] -pub enum Expr { - Col(ColumnExpr), - /// An atomic secure column constructed from 4 expressions. - /// Expressions on the secure column are not reduced, i.e, - /// if `a = SecureCol(a0, a1, a2, a3)`, `b = SecureCol(b0, b1, b2, b3)` then - /// `a + b` evaluates to `Add(a, b)` rather than - /// `SecureCol(Add(a0, b0), Add(a1, b1), Add(a2, b2), Add(a3, b3))` - SecureCol([Box; 4]), - Const(BaseField), - /// Formal parameter to the AIR, for example the interaction elements of a relation. 
- Param(String), - Add(Box, Box), - Sub(Box, Box), - Mul(Box, Box), - Neg(Box), - Inv(Box), -} - -impl Expr { - #[allow(dead_code)] - pub fn format_expr(&self) -> String { - match self { - Expr::Col(ColumnExpr { - interaction, - idx, - offset, - }) => { - let offset_str = if *offset == CLAIMED_SUM_DUMMY_OFFSET as isize { - "claimed_sum_offset".to_string() - } else { - offset.to_string() - }; - format!("col_{interaction}_{idx}[{offset_str}]") - } - Expr::SecureCol([a, b, c, d]) => format!( - "SecureCol({}, {}, {}, {})", - a.format_expr(), - b.format_expr(), - c.format_expr(), - d.format_expr() - ), - Expr::Const(c) => c.0.to_string(), - Expr::Param(v) => v.to_string(), - Expr::Add(a, b) => format!("{} + {}", a.format_expr(), b.format_expr()), - Expr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()), - Expr::Mul(a, b) => format!("({}) * ({})", a.format_expr(), b.format_expr()), - Expr::Neg(a) => format!("-({})", a.format_expr()), - Expr::Inv(a) => format!("1/({})", a.format_expr()), - } - } -} - -impl From for Expr { - fn from(val: BaseField) -> Self { - Expr::Const(val) - } -} - -impl From for Expr { - fn from(val: SecureField) -> Self { - Expr::SecureCol([ - Box::new(val.0 .0.into()), - Box::new(val.0 .1.into()), - Box::new(val.1 .0.into()), - Box::new(val.1 .1.into()), - ]) - } -} - -impl Add for Expr { - type Output = Self; - fn add(self, rhs: Self) -> Self { - Expr::Add(Box::new(self), Box::new(rhs)) - } -} - -impl Sub for Expr { - type Output = Self; - fn sub(self, rhs: Self) -> Self { - Expr::Sub(Box::new(self), Box::new(rhs)) - } -} - -impl Mul for Expr { - type Output = Self; - fn mul(self, rhs: Self) -> Self { - Expr::Mul(Box::new(self), Box::new(rhs)) - } -} - -impl AddAssign for Expr { - fn add_assign(&mut self, rhs: Self) { - *self = self.clone() + rhs - } -} - -impl MulAssign for Expr { - fn mul_assign(&mut self, rhs: Self) { - *self = self.clone() * rhs - } -} - -impl Neg for Expr { - type Output = Self; - fn neg(self) -> Self { - Expr::Neg(Box::new(self)) - } -} - -impl Zero for Expr { - fn zero() -> Self { - Expr::Const(BaseField::zero()) - } - fn is_zero(&self) -> bool { - // TODO(alont): consider replacing `Zero` in the trait bound with a custom trait - // that only has `zero()`. - panic!("Can't check if an expression is zero."); - } -} - -impl One for Expr { - fn one() -> Self { - Expr::Const(BaseField::one()) - } -} - -impl FieldExpOps for Expr { - fn inverse(&self) -> Self { - Expr::Inv(Box::new(self.clone())) - } -} - -impl Add for Expr { - type Output = Self; - fn add(self, rhs: BaseField) -> Self { - self + Expr::from(rhs) - } -} - -impl Mul for Expr { - type Output = Self; - fn mul(self, rhs: BaseField) -> Self { - self * Expr::from(rhs) - } -} - -impl Mul for Expr { - type Output = Self; - fn mul(self, rhs: SecureField) -> Self { - self * Expr::from(rhs) - } -} - -impl Add for Expr { - type Output = Self; - fn add(self, rhs: SecureField) -> Self { - self + Expr::from(rhs) - } -} - -impl Sub for Expr { - type Output = Self; - fn sub(self, rhs: SecureField) -> Self { - self - Expr::from(rhs) - } -} - -impl AddAssign for Expr { - fn add_assign(&mut self, rhs: BaseField) { - *self = self.clone() + Expr::from(rhs) - } -} - -/// Returns the expression -/// `value[0] * _alpha0 + value[1] * _alpha1 + ... 
- _z.` -fn combine_formal>(relation: &R, values: &[Expr]) -> Expr { - const Z_SUFFIX: &str = "_z"; - const ALPHA_SUFFIX: &str = "_alpha"; - - let z = Expr::Param(relation.get_name().to_owned() + Z_SUFFIX); - let alpha_powers = (0..relation.get_size()) - .map(|i| Expr::Param(relation.get_name().to_owned() + ALPHA_SUFFIX + &i.to_string())); - values - .iter() - .zip(alpha_powers) - .fold(Expr::zero(), |acc, (value, power)| { - acc + power * value.clone() - }) - - z -} - -pub struct FormalLogupAtRow { - pub interaction: usize, - pub total_sum: Expr, - pub claimed_sum: Option<(Expr, usize)>, - pub prev_col_cumsum: Expr, - pub cur_frac: Option>, - pub is_finalized: bool, - pub is_first: Expr, - pub log_size: u32, -} - -// P is an offset no column can reach, it signifies the variable -// offset, which is an input to the verifier. -const CLAIMED_SUM_DUMMY_OFFSET: usize = m31::P as usize; - -impl FormalLogupAtRow { - pub fn new(interaction: usize, has_partial_sum: bool, log_size: u32) -> Self { - let total_sum_name = "total_sum".to_string(); - let claimed_sum_name = "claimed_sum".to_string(); - - Self { - interaction, - // TODO(alont): Should these be Expr::SecureField? - total_sum: Expr::Param(total_sum_name), - claimed_sum: has_partial_sum - .then_some((Expr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)), - prev_col_cumsum: Expr::zero(), - cur_frac: None, - is_finalized: true, - is_first: Expr::zero(), - log_size, - } - } -} - -/// An Evaluator that saves all constraint expressions. -pub struct ExprEvaluator { - pub cur_var_index: usize, - pub constraints: Vec, - pub logup: FormalLogupAtRow, -} - -impl ExprEvaluator { - #[allow(dead_code)] - pub fn new(log_size: u32, has_partial_sum: bool) -> Self { - Self { - cur_var_index: Default::default(), - constraints: Default::default(), - logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, has_partial_sum, log_size), - } - } -} - -impl EvalAtRow for ExprEvaluator { - // TODO(alont): Should there be a version of this that disallows Secure fields for F? 
- type F = Expr; - type EF = Expr; - - fn next_interaction_mask( - &mut self, - interaction: usize, - offsets: [isize; N], - ) -> [Self::F; N] { - std::array::from_fn(|i| { - let col = ColumnExpr { - interaction, - idx: self.cur_var_index, - offset: offsets[i], - }; - self.cur_var_index += 1; - Expr::Col(col) - }) - } - - fn add_constraint(&mut self, constraint: G) - where - Self::EF: std::ops::Mul, - { - self.constraints.push(Expr::one() * constraint); - } - - fn combine_ef(values: [Self::F; 4]) -> Self::EF { - Expr::SecureCol([ - Box::new(values[0].clone()), - Box::new(values[1].clone()), - Box::new(values[2].clone()), - Box::new(values[3].clone()), - ]) - } - - fn add_to_relation>( - &mut self, - entries: &[RelationEntry<'_, Self::F, Self::EF, R>], - ) { - let fracs: Vec> = entries - .iter() - .map( - |RelationEntry { - relation, - multiplicity, - values, - }| { - Fraction::new(multiplicity.clone(), combine_formal(*relation, values)) - }, - ) - .collect(); - self.write_logup_frac(fracs.into_iter().sum()); - } - - super::logup_proxy!(); -} - -#[cfg(test)] -mod tests { - use num_traits::One; - - use crate::constraint_framework::expr::{ColumnExpr, Expr, ExprEvaluator}; - use crate::constraint_framework::{ - relation, EvalAtRow, FrameworkEval, RelationEntry, ORIGINAL_TRACE_IDX, - }; - use crate::core::fields::m31::M31; - use crate::core::fields::FieldExpOps; - - #[test] - fn test_expr_eval() { - let test_struct = TestStruct {}; - let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - assert_eq!(eval.constraints.len(), 2); - assert_eq!( - eval.constraints[0], - Expr::Mul( - Box::new(Expr::one()), - Box::new(Expr::Mul( - Box::new(Expr::Mul( - Box::new(Expr::Mul( - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 0, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 1, - offset: 0 - })) - )), - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 2, - offset: 0 - })) - )), - Box::new(Expr::Inv(Box::new(Expr::Add( - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 0, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 1, - offset: 0 - })) - )))) - )) - ) - ); - - assert_eq!( - eval.constraints[1], - Expr::Mul( - Box::new(Expr::Const(M31(1))), - Box::new(Expr::Sub( - Box::new(Expr::Mul( - Box::new(Expr::Sub( - Box::new(Expr::Sub( - Box::new(Expr::SecureCol([ - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 4, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 6, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 8, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 10, - offset: 0 - })) - ])), - Box::new(Expr::Sub( - Box::new(Expr::SecureCol([ - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 5, - offset: -1 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 7, - offset: -1 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 9, - offset: -1 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 11, - offset: -1 - })) - ])), - Box::new(Expr::Mul( - Box::new(Expr::Col(ColumnExpr { - interaction: 0, - idx: 3, - offset: 0 - })), - Box::new(Expr::Param("total_sum".into())) - )) - )) - )), - Box::new(Expr::Const(M31(0))) - )), - Box::new(Expr::Sub( - Box::new(Expr::Add( - Box::new(Expr::Add( - Box::new(Expr::Add( - Box::new(Expr::Const(M31(0))), - Box::new(Expr::Mul( - Box::new(Expr::Param( - 
"TestRelation_alpha0".to_string() - )), - Box::new(Expr::Col(ColumnExpr { - interaction: 1, - idx: 0, - offset: 0 - })) - )) - )), - Box::new(Expr::Mul( - Box::new(Expr::Param("TestRelation_alpha1".to_string())), - Box::new(Expr::Col(ColumnExpr { - interaction: 1, - idx: 1, - offset: 0 - })) - )) - )), - Box::new(Expr::Mul( - Box::new(Expr::Param("TestRelation_alpha2".to_string())), - Box::new(Expr::Col(ColumnExpr { - interaction: 1, - idx: 2, - offset: 0 - })) - )) - )), - Box::new(Expr::Param("TestRelation_z".to_string())) - )) - )), - Box::new(Expr::Const(M31(1))) - )) - ) - ); - } - - #[test] - fn test_format_expr() { - let test_struct = TestStruct {}; - let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - let constraint0_str = "(1) * ((((col_1_0[0]) * (col_1_1[0])) * (col_1_2[0])) * (1/(col_1_0[0] + col_1_1[0])))"; - assert_eq!(eval.constraints[0].format_expr(), constraint0_str); - let constraint1_str = "(1) \ - * ((SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \ - - (SecureCol(\ - col_2_5[-1], \ - col_2_7[-1], \ - col_2_9[-1], \ - col_2_11[-1]\ - ) - ((col_0_3[0]) * (total_sum))) \ - - (0)) \ - * (0 + (TestRelation_alpha0) * (col_1_0[0]) \ - + (TestRelation_alpha1) * (col_1_1[0]) \ - + (TestRelation_alpha2) * (col_1_2[0]) \ - - (TestRelation_z)) \ - - (1))"; - assert_eq!(eval.constraints[1].format_expr(), constraint1_str); - } - - relation!(TestRelation, 3); - - struct TestStruct {} - impl FrameworkEval for TestStruct { - fn log_size(&self) -> u32 { - 0 - } - fn max_constraint_log_degree_bound(&self) -> u32 { - 0 - } - fn evaluate(&self, mut eval: E) -> E { - let x0 = eval.next_trace_mask(); - let x1 = eval.next_trace_mask(); - let x2 = eval.next_trace_mask(); - eval.add_constraint( - x0.clone() * x1.clone() * x2.clone() * (x0.clone() + x1.clone()).inverse(), - ); - eval.add_to_relation(&[RelationEntry::new( - &TestRelation::dummy(), - E::EF::one(), - &[x0, x1, x2], - )]); - eval.finalize_logup(); - eval - } - } -} diff --git a/crates/prover/src/constraint_framework/expr/assignment.rs b/crates/prover/src/constraint_framework/expr/assignment.rs new file mode 100644 index 000000000..1ba834139 --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/assignment.rs @@ -0,0 +1,267 @@ +use std::collections::{HashMap, HashSet}; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::ops::{Add, Index}; + +use itertools::sorted; + +use super::{BaseExpr, ColumnExpr, ExtExpr}; +use crate::constraint_framework::{AssertEvaluator, EvalAtRow}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldExpOps; + +/// An assignment to the variables that may appear in an expression. +pub type ExprVarAssignment = ( + HashMap<(usize, usize, isize), BaseField>, + HashMap, + HashMap, +); + +/// Three sets representing all the variables that can appear in an expression: +/// * `cols`: The columns of the AIR. +/// * `params`: The formal parameters to the AIR. +/// * `ext_params`: The extension field parameters to the AIR. 
+#[derive(Default)] +pub struct ExprVariables { + pub cols: HashSet, + pub params: HashSet, + pub ext_params: HashSet, +} + +impl ExprVariables { + pub fn col(col: ColumnExpr) -> Self { + Self { + cols: vec![col].into_iter().collect(), + params: HashSet::new(), + ext_params: HashSet::new(), + } + } + + pub fn param(param: String) -> Self { + Self { + cols: HashSet::new(), + params: vec![param].into_iter().collect(), + ext_params: HashSet::new(), + } + } + + pub fn ext_param(param: String) -> Self { + Self { + cols: HashSet::new(), + params: HashSet::new(), + ext_params: vec![param].into_iter().collect(), + } + } + + /// Generates a random assignment to the variables. + /// Note that the assignment is deterministic in the sets of variables (disregarding their + /// order), and this is required. + pub fn random_assignment(&self, salt: usize) -> ExprVarAssignment { + let cols = sorted(self.cols.iter()) + .map(|col| { + ((col.interaction, col.idx, col.offset), { + let mut hasher = DefaultHasher::new(); + (salt, col).hash(&mut hasher); + (hasher.finish() as u32).into() + }) + }) + .collect(); + + let params = sorted(self.params.iter()) + .map(|param| { + (param.clone(), { + let mut hasher = DefaultHasher::new(); + (salt, param).hash(&mut hasher); + (hasher.finish() as u32).into() + }) + }) + .collect(); + + let ext_params = sorted(self.ext_params.iter()) + .map(|param| { + (param.clone(), { + let mut hasher = DefaultHasher::new(); + (salt, param).hash(&mut hasher); + (hasher.finish() as u32).into() + }) + }) + .collect(); + + (cols, params, ext_params) + } +} + +impl Add for ExprVariables { + type Output = Self; + fn add(self, rhs: Self) -> Self { + Self { + cols: self.cols.union(&rhs.cols).cloned().collect(), + params: self.params.union(&rhs.params).cloned().collect(), + ext_params: self.ext_params.union(&rhs.ext_params).cloned().collect(), + } + } +} + +impl BaseExpr { + /// Evaluates a base field expression. + /// Takes: + /// * `columns`: A mapping from triplets (interaction, idx, offset) to base field values. + /// * `vars`: A mapping from variable names to base field values. 
+ pub fn eval_expr(&self, columns: &C, vars: &V) -> E::F + where + C: for<'a> Index<&'a (usize, usize, isize), Output = E::F>, + V: for<'a> Index<&'a String, Output = E::F>, + E: EvalAtRow, + { + match self { + Self::Col(col) => columns[&(col.interaction, col.idx, col.offset)].clone(), + Self::Const(c) => E::F::from(*c), + Self::Param(var) => vars[&var.to_string()].clone(), + Self::Add(a, b) => { + a.eval_expr::(columns, vars) + b.eval_expr::(columns, vars) + } + Self::Sub(a, b) => { + a.eval_expr::(columns, vars) - b.eval_expr::(columns, vars) + } + Self::Mul(a, b) => { + a.eval_expr::(columns, vars) * b.eval_expr::(columns, vars) + } + Self::Neg(a) => -a.eval_expr::(columns, vars), + Self::Inv(a) => a.eval_expr::(columns, vars).inverse(), + } + } + + pub fn collect_variables(&self) -> ExprVariables { + match self { + BaseExpr::Col(col) => ExprVariables::col(col.clone()), + BaseExpr::Const(_) => ExprVariables::default(), + BaseExpr::Param(param) => ExprVariables::param(param.to_string()), + BaseExpr::Add(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Neg(a) => a.collect_variables(), + BaseExpr::Inv(a) => a.collect_variables(), + } + } + + pub fn random_eval(&self) -> BaseField { + let assignment = self.collect_variables().random_assignment(0); + assert!(assignment.2.is_empty()); + self.eval_expr::, _, _>(&assignment.0, &assignment.1) + } +} + +impl ExtExpr { + /// Evaluates an extension field expression. + /// Takes: + /// * `columns`: A mapping from triplets (interaction, idx, offset) to base field values. + /// * `vars`: A mapping from variable names to base field values. + /// * `ext_vars`: A mapping from variable names to extension field values. 
+ pub fn eval_expr(&self, columns: &C, vars: &V, ext_vars: &EV) -> E::EF + where + C: for<'a> Index<&'a (usize, usize, isize), Output = E::F>, + V: for<'a> Index<&'a String, Output = E::F>, + EV: for<'a> Index<&'a String, Output = E::EF>, + E: EvalAtRow, + { + match self { + Self::SecureCol([a, b, c, d]) => { + let a = a.eval_expr::(columns, vars); + let b = b.eval_expr::(columns, vars); + let c = c.eval_expr::(columns, vars); + let d = d.eval_expr::(columns, vars); + E::combine_ef([a, b, c, d]) + } + Self::Const(c) => E::EF::from(*c), + Self::Param(var) => ext_vars[&var.to_string()].clone(), + Self::Add(a, b) => { + a.eval_expr::(columns, vars, ext_vars) + + b.eval_expr::(columns, vars, ext_vars) + } + Self::Sub(a, b) => { + a.eval_expr::(columns, vars, ext_vars) + - b.eval_expr::(columns, vars, ext_vars) + } + Self::Mul(a, b) => { + a.eval_expr::(columns, vars, ext_vars) + * b.eval_expr::(columns, vars, ext_vars) + } + Self::Neg(a) => -a.eval_expr::(columns, vars, ext_vars), + } + } + + pub fn collect_variables(&self) -> ExprVariables { + match self { + ExtExpr::SecureCol([a, b, c, d]) => { + a.collect_variables() + + b.collect_variables() + + c.collect_variables() + + d.collect_variables() + } + ExtExpr::Const(_) => ExprVariables::default(), + ExtExpr::Param(param) => ExprVariables::ext_param(param.to_string()), + ExtExpr::Add(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Neg(a) => a.collect_variables(), + } + } + + pub fn random_eval(&self) -> SecureField { + let assignment = self.collect_variables().random_assignment(0); + self.eval_expr::, _, _, _>(&assignment.0, &assignment.1, &assignment.2) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use num_traits::One; + + use crate::constraint_framework::expr::utils::*; + use crate::constraint_framework::AssertEvaluator; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::FieldExpOps; + + #[test] + fn test_eval_expr() { + let col_1_0_0 = BaseField::from(12); + let col_1_1_0 = BaseField::from(5); + let var_a = BaseField::from(3); + let var_b = BaseField::from(4); + let var_c = SecureField::from_m31_array([ + BaseField::from(1), + BaseField::from(2), + BaseField::from(3), + BaseField::from(4), + ]); + + let columns: HashMap<(usize, usize, isize), BaseField> = + HashMap::from([((1, 0, 0), col_1_0_0), ((1, 1, 0), col_1_1_0)]); + let vars = HashMap::from([("a".to_string(), var_a), ("b".to_string(), var_b)]); + let ext_vars = HashMap::from([("c".to_string(), var_c)]); + + let expr = secure_col!( + col!(1, 0, 0) - col!(1, 1, 0), + col!(1, 1, 0) * (-var!("a")), + var!("a") + var!("a").inverse(), + var!("b") * felt!(7) + ) + qvar!("c") * qvar!("c") + - qfelt!(1, 0, 0, 0); + + let expected = SecureField::from_m31_array([ + col_1_0_0 - col_1_1_0, + col_1_1_0 * (-var_a), + var_a + var_a.inverse(), + var_b * BaseField::from(7), + ]) + var_c * var_c + - SecureField::one(); + + assert_eq!( + expr.eval_expr::, _, _, _>(&columns, &vars, &ext_vars), + expected + ); + } +} diff --git a/crates/prover/src/constraint_framework/expr/degree.rs b/crates/prover/src/constraint_framework/expr/degree.rs new file mode 100644 index 000000000..27c848336 --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/degree.rs @@ -0,0 +1,100 @@ +/// Finds a degree bound for an expressions. 
The degree is given with respect to columns as +/// variables. +/// Computes the actual degree with the following caveats: +/// 1. The constant expression 0 receives degree 0 like all other constants rather than the +/// mathematically correct -infinity. This means, for example, that expressions of the +/// type 0 * expr will return degree deg expr. This should be mitigated by +/// simplification. +/// 2. If expressions p and q cancel out under some operation, this will not be accounted +/// for, so that (x^2 + 1) - (x^2 + x) will return degree 2. +use std::collections::HashMap; + +use super::{BaseExpr, ExtExpr}; + +type Degree = usize; + +/// A struct of named expressions that can be searched when determining the degree bound for an +/// expression that contains parameters. +/// Required because expressions that contain parameters that are actually intermediates have to +/// account for the degree of the intermediate. +pub struct NamedExprs { + exprs: HashMap<String, BaseExpr>, + ext_exprs: HashMap<String, ExtExpr>, +} + +impl NamedExprs { + pub fn degree_bound(&self, name: String) -> Degree { + if let Some(expr) = self.exprs.get(&name) { + expr.degree_bound(self) + } else if let Some(expr) = self.ext_exprs.get(&name) { + expr.degree_bound(self) + } else if name.starts_with("preprocessed.") { + // TODO(alont): Fix this hack. + 1 + } else { + // If expression isn't found assume it's an external variable, effectively a const. + 0 + } + } +} + +impl BaseExpr { + pub fn degree_bound(&self, named_exprs: &NamedExprs) -> Degree { + match self { + BaseExpr::Col(_) => 1, + BaseExpr::Const(_) => 0, + BaseExpr::Param(name) => named_exprs.degree_bound(name.clone()), + BaseExpr::Add(a, b) => a.degree_bound(named_exprs).max(b.degree_bound(named_exprs)), + BaseExpr::Sub(a, b) => a.degree_bound(named_exprs).max(b.degree_bound(named_exprs)), + BaseExpr::Mul(a, b) => a.degree_bound(named_exprs) + b.degree_bound(named_exprs), + BaseExpr::Neg(a) => a.degree_bound(named_exprs), + // TODO(alont): Consider handling this in the type system.
+ BaseExpr::Inv(_) => panic!("Cannot compute the degree of an inverse."), + } + } +} + +impl ExtExpr { + pub fn degree_bound(&self, named_exprs: &NamedExprs) -> Degree { + match self { + ExtExpr::SecureCol(coefs) => coefs + .iter() + .cloned() + .map(|coef| coef.degree_bound(named_exprs)) + .max() + .unwrap(), + ExtExpr::Const(_) => 0, + ExtExpr::Param(name) => named_exprs.degree_bound(name.clone()), + ExtExpr::Add(a, b) => a.degree_bound(named_exprs).max(b.degree_bound(named_exprs)), + ExtExpr::Sub(a, b) => a.degree_bound(named_exprs).max(b.degree_bound(named_exprs)), + ExtExpr::Mul(a, b) => a.degree_bound(named_exprs) + b.degree_bound(named_exprs), + ExtExpr::Neg(a) => a.degree_bound(named_exprs), + } + } +} + +#[cfg(test)] +mod tests { + use crate::constraint_framework::expr::degree::NamedExprs; + use crate::constraint_framework::expr::utils::*; + + #[test] + fn test_degree_bound() { + let intermediate = (felt!(12) + col!(1, 1, 0)) * var!("a") * col!(1, 0, 0); + let qintermediate = secure_col!(intermediate.clone(), felt!(12), var!("b"), felt!(0)); + + let named_exprs = NamedExprs { + exprs: [("intermediate".to_string(), intermediate.clone())].into(), + ext_exprs: [("qintermediate".to_string(), qintermediate.clone())].into(), + }; + + let expr = var!("intermediate") * col!(2, 1, 0); + let qexpr = + var!("qintermediate") * secure_col!(col!(2, 1, 0), expr.clone(), felt!(0), felt!(1)); + + assert_eq!(intermediate.degree_bound(&named_exprs), 2); + assert_eq!(qintermediate.degree_bound(&named_exprs), 2); + assert_eq!(expr.degree_bound(&named_exprs), 3); + assert_eq!(qexpr.degree_bound(&named_exprs), 5); + } +} diff --git a/crates/prover/src/constraint_framework/expr/evaluator.rs b/crates/prover/src/constraint_framework/expr/evaluator.rs new file mode 100644 index 000000000..6b9238d23 --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/evaluator.rs @@ -0,0 +1,244 @@ +use num_traits::Zero; + +use super::{BaseExpr, ExtExpr}; +use crate::constraint_framework::expr::ColumnExpr; +use crate::constraint_framework::preprocessed_columns::PreprocessedColumn; +use crate::constraint_framework::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; +use crate::core::fields::m31; +use crate::core::lookups::utils::Fraction; + +pub struct FormalLogupAtRow { + pub interaction: usize, + pub total_sum: ExtExpr, + pub claimed_sum: Option<(ExtExpr, usize)>, + pub fracs: Vec>, + pub is_finalized: bool, + pub is_first: BaseExpr, + pub log_size: u32, +} + +// P is an offset no column can reach, it signifies the variable +// offset, which is an input to the verifier. +pub const CLAIMED_SUM_DUMMY_OFFSET: usize = m31::P as usize; + +impl FormalLogupAtRow { + pub fn new(interaction: usize, has_partial_sum: bool, log_size: u32) -> Self { + let total_sum_name = "total_sum".to_string(); + let claimed_sum_name = "claimed_sum".to_string(); + + Self { + interaction, + // TODO(alont): Should these be Expr::SecureField? + total_sum: ExtExpr::Param(total_sum_name), + claimed_sum: has_partial_sum + .then_some((ExtExpr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)), + fracs: vec![], + is_finalized: true, + is_first: BaseExpr::zero(), + log_size, + } + } +} + +/// Returns the expression +/// `value[0] * _alpha0 + value[1] * _alpha1 + ... 
- _z.` +fn combine_formal>(relation: &R, values: &[BaseExpr]) -> ExtExpr { + const Z_SUFFIX: &str = "_z"; + const ALPHA_SUFFIX: &str = "_alpha"; + + let z = ExtExpr::Param(relation.get_name().to_owned() + Z_SUFFIX); + let alpha_powers = (0..relation.get_size()) + .map(|i| ExtExpr::Param(relation.get_name().to_owned() + ALPHA_SUFFIX + &i.to_string())); + values + .iter() + .zip(alpha_powers) + .fold(ExtExpr::zero(), |acc, (value, power)| { + acc + power * value.clone() + }) + - z +} + +/// An Evaluator that saves all constraint expressions. +pub struct ExprEvaluator { + pub cur_var_index: usize, + pub constraints: Vec, + pub logup: FormalLogupAtRow, + pub intermediates: Vec<(String, BaseExpr)>, + pub ext_intermediates: Vec<(String, ExtExpr)>, +} + +impl ExprEvaluator { + pub fn new(log_size: u32, has_partial_sum: bool) -> Self { + Self { + cur_var_index: Default::default(), + constraints: Default::default(), + logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, has_partial_sum, log_size), + intermediates: vec![], + ext_intermediates: vec![], + } + } + + pub fn format_constraints(&self) -> String { + let lets_string = self + .intermediates + .iter() + .map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format())) + .collect::>() + .join("\n\n"); + + let secure_lets_string = self + .ext_intermediates + .iter() + .map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format())) + .collect::>() + .join("\n\n"); + + let constraints_str = self + .constraints + .iter() + .enumerate() + .map(|(i, c)| format!("let constraint_{i} = ") + &c.simplify_and_format() + ";") + .collect::>() + .join("\n\n"); + + [lets_string, secure_lets_string, constraints_str] + .iter() + .filter(|x| !x.is_empty()) + .cloned() + .collect::>() + .join("\n\n") + } +} + +impl EvalAtRow for ExprEvaluator { + // TODO(alont): Should there be a version of this that disallows Secure fields for F? 
+ type F = BaseExpr; + type EF = ExtExpr; + + fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N] { + let res = std::array::from_fn(|i| { + let col = ColumnExpr::from((interaction, self.cur_var_index, offsets[i])); + BaseExpr::Col(col) + }); + self.cur_var_index += 1; + res + } + + fn add_constraint(&mut self, constraint: G) + where + Self::EF: From, + { + self.constraints.push(constraint.into()); + } + + fn combine_ef(values: [Self::F; 4]) -> Self::EF { + ExtExpr::SecureCol([ + Box::new(values[0].clone()), + Box::new(values[1].clone()), + Box::new(values[2].clone()), + Box::new(values[3].clone()), + ]) + } + + fn add_to_relation>( + &mut self, + entry: RelationEntry<'_, Self::F, Self::EF, R>, + ) { + let intermediate = + self.add_extension_intermediate(combine_formal(entry.relation, entry.values)); + let frac = Fraction::new(entry.multiplicity.clone(), intermediate); + self.write_logup_frac(frac); + } + + fn add_intermediate(&mut self, expr: Self::F) -> Self::F { + let name = format!( + "intermediate{}", + self.intermediates.len() + self.ext_intermediates.len() + ); + let intermediate = BaseExpr::Param(name.clone()); + self.intermediates.push((name, expr)); + intermediate + } + + fn add_extension_intermediate(&mut self, expr: Self::EF) -> Self::EF { + let name = format!( + "intermediate{}", + self.intermediates.len() + self.ext_intermediates.len() + ); + let intermediate = ExtExpr::Param(name.clone()); + self.ext_intermediates.push((name, expr)); + intermediate + } + + fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F { + BaseExpr::Param(column.name().to_string()) + } + + crate::constraint_framework::logup_proxy!(); +} + +#[cfg(test)] +mod tests { + use num_traits::One; + + use crate::constraint_framework::expr::ExprEvaluator; + use crate::constraint_framework::{EvalAtRow, FrameworkEval, RelationEntry}; + use crate::core::fields::FieldExpOps; + use crate::relation; + + #[test] + fn test_expr_evaluator() { + let test_struct = TestStruct {}; + let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); + let expected = "let intermediate0 = (trace_1_column_1_offset_0) * (trace_1_column_2_offset_0); + +\ + let intermediate1 = (TestRelation_alpha0) * (trace_1_column_0_offset_0) \ + + (TestRelation_alpha1) * (trace_1_column_1_offset_0) \ + + (TestRelation_alpha2) * (trace_1_column_2_offset_0) \ + - (TestRelation_z); + +\ + let constraint_0 = ((trace_1_column_0_offset_0) * (intermediate0)) * (1 / (trace_1_column_0_offset_0 + trace_1_column_1_offset_0)); + +\ + let constraint_1 = (QM31Impl::from_partial_evals([trace_2_column_3_offset_0, trace_2_column_4_offset_0, trace_2_column_5_offset_0, trace_2_column_6_offset_0]) \ + - (QM31Impl::from_partial_evals([trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1, trace_2_column_6_offset_neg_1]) \ + - ((total_sum) * (preprocessed_is_first)))) \ + * (intermediate1) \ + - (qm31(1, 0, 0, 0));" + .to_string(); + + assert_eq!(eval.format_constraints(), expected); + } + + relation!(TestRelation, 3); + + struct TestStruct {} + impl FrameworkEval for TestStruct { + fn log_size(&self) -> u32 { + 0 + } + fn max_constraint_log_degree_bound(&self) -> u32 { + 0 + } + fn evaluate(&self, mut eval: E) -> E { + let x0 = eval.next_trace_mask(); + let x1 = eval.next_trace_mask(); + let x2 = eval.next_trace_mask(); + let intermediate = eval.add_intermediate(x1.clone() * x2.clone()); + eval.add_constraint(x0.clone() * intermediate * (x0.clone() + 
x1.clone()).inverse()); + eval.add_to_relation(RelationEntry::new( + &TestRelation::dummy(), + E::EF::one(), + &[x0, x1, x2], + )); + eval.finalize_logup(); + eval + } + } +} diff --git a/crates/prover/src/constraint_framework/expr/format.rs b/crates/prover/src/constraint_framework/expr/format.rs new file mode 100644 index 000000000..f286f1135 --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/format.rs @@ -0,0 +1,65 @@ +use num_traits::Zero; + +use super::{BaseExpr, ColumnExpr, ExtExpr, CLAIMED_SUM_DUMMY_OFFSET}; + +impl BaseExpr { + pub fn format_expr(&self) -> String { + match self { + BaseExpr::Col(ColumnExpr { + interaction, + idx, + offset, + }) => { + let offset_str = if *offset == CLAIMED_SUM_DUMMY_OFFSET as isize { + "claimed_sum".to_string() + } else { + let offset_abs = offset.abs(); + if *offset >= 0 { + offset.to_string() + } else { + format!("neg_{offset_abs}") + } + }; + format!("trace_{interaction}_column_{idx}_offset_{offset_str}") + } + BaseExpr::Const(c) => format!("m31({c}).into()"), + BaseExpr::Param(v) => v.to_string(), + BaseExpr::Add(a, b) => format!("{} + {}", a.format_expr(), b.format_expr()), + BaseExpr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()), + BaseExpr::Mul(a, b) => format!("({}) * ({})", a.format_expr(), b.format_expr()), + BaseExpr::Neg(a) => format!("-({})", a.format_expr()), + BaseExpr::Inv(a) => format!("1 / ({})", a.format_expr()), + } + } +} + +impl ExtExpr { + pub fn format_expr(&self) -> String { + match self { + ExtExpr::SecureCol([a, b, c, d]) => { + // If the expression's non-base components are all constant zeroes, return the base + // field representation of its first part. + if **b == BaseExpr::zero() && **c == BaseExpr::zero() && **d == BaseExpr::zero() { + a.format_expr() + } else { + format!( + "QM31Impl::from_partial_evals([{}, {}, {}, {}])", + a.format_expr(), + b.format_expr(), + c.format_expr(), + d.format_expr() + ) + } + } + ExtExpr::Const(c) => { + let [v0, v1, v2, v3] = c.to_m31_array(); + format!("qm31({v0}, {v1}, {v2}, {v3})") + } + ExtExpr::Param(v) => v.to_string(), + ExtExpr::Add(a, b) => format!("{} + {}", a.format_expr(), b.format_expr()), + ExtExpr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()), + ExtExpr::Mul(a, b) => format!("({}) * ({})", a.format_expr(), b.format_expr()), + ExtExpr::Neg(a) => format!("-({})", a.format_expr()), + } + } +} diff --git a/crates/prover/src/constraint_framework/expr/mod.rs b/crates/prover/src/constraint_framework/expr/mod.rs new file mode 100644 index 000000000..7b17fc73e --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/mod.rs @@ -0,0 +1,352 @@ +pub mod assignment; +pub mod degree; +pub mod evaluator; +pub mod format; +pub mod simplify; +pub mod utils; + +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; + +pub use evaluator::ExprEvaluator; +use num_traits::{One, Zero}; + +use crate::constraint_framework::expr::evaluator::CLAIMED_SUM_DUMMY_OFFSET; +use crate::core::fields::cm31::CM31; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::{SecureField, QM31}; +use crate::core::fields::FieldExpOps; + +/// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`. 
+#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct ColumnExpr { + interaction: usize, + idx: usize, + offset: isize, +} + +impl From<(usize, usize, isize)> for ColumnExpr { + fn from((interaction, idx, offset): (usize, usize, isize)) -> Self { + Self { + interaction, + idx, + offset, + } + } +} + +/// An expression representing a base field value. Can be either: +/// * A column indexed by a `ColumnExpr`. +/// * A base field constant. +/// * A formal parameter to the AIR. +/// * A sum, difference, or product of two base field expressions. +/// * A negation or inverse of a base field expression. +/// +/// This type is meant to be used as an F associated type for EvalAtRow and interacts with +/// `ExtExpr`, `BaseField` and `SecureField` as expected. +#[derive(Clone, Debug, PartialEq)] +pub enum BaseExpr { + Col(ColumnExpr), + Const(BaseField), + /// Formal parameter to the AIR, for example the interaction elements of a relation. + Param(String), + Add(Box, Box), + Sub(Box, Box), + Mul(Box, Box), + Neg(Box), + Inv(Box), +} + +/// An expression representing a secure field value. Can be either: +/// * A secure column constructed from 4 base field expressions. +/// * A secure field constant. +/// * A formal parameter to the AIR. +/// * A sum, difference, or product of two secure field expressions. +/// * A negation of a secure field expression. +/// +/// This type is meant to be used as an EF associated type for EvalAtRow and interacts with +/// `BaseExpr`, `BaseField` and `SecureField` as expected. +#[derive(Clone, Debug, PartialEq)] +pub enum ExtExpr { + /// An atomic secure column constructed from 4 expressions. + /// Expressions on the secure column are not reduced, i.e, + /// if `a = SecureCol(a0, a1, a2, a3)`, `b = SecureCol(b0, b1, b2, b3)` then + /// `a + b` evaluates to `Add(a, b)` rather than + /// `SecureCol(Add(a0, b0), Add(a1, b1), Add(a2, b2), Add(a3, b3))` + SecureCol([Box; 4]), + Const(SecureField), + /// Formal parameter to the AIR, for example the interaction elements of a relation. 
+ Param(String), + Add(Box, Box), + Sub(Box, Box), + Mul(Box, Box), + Neg(Box), +} + +impl From for BaseExpr { + fn from(val: BaseField) -> Self { + BaseExpr::Const(val) + } +} + +impl From for ExtExpr { + fn from(val: BaseField) -> Self { + ExtExpr::SecureCol([ + Box::new(BaseExpr::from(val)), + Box::new(BaseExpr::zero()), + Box::new(BaseExpr::zero()), + Box::new(BaseExpr::zero()), + ]) + } +} + +impl From for ExtExpr { + fn from(QM31(CM31(a, b), CM31(c, d)): SecureField) -> Self { + ExtExpr::SecureCol([ + Box::new(BaseExpr::from(a)), + Box::new(BaseExpr::from(b)), + Box::new(BaseExpr::from(c)), + Box::new(BaseExpr::from(d)), + ]) + } +} + +impl From for ExtExpr { + fn from(expr: BaseExpr) -> Self { + ExtExpr::SecureCol([ + Box::new(expr.clone()), + Box::new(BaseExpr::zero()), + Box::new(BaseExpr::zero()), + Box::new(BaseExpr::zero()), + ]) + } +} + +impl Add for BaseExpr { + type Output = Self; + fn add(self, rhs: Self) -> Self { + BaseExpr::Add(Box::new(self), Box::new(rhs)) + } +} + +impl Sub for BaseExpr { + type Output = Self; + fn sub(self, rhs: Self) -> Self { + BaseExpr::Sub(Box::new(self), Box::new(rhs)) + } +} + +impl Mul for BaseExpr { + type Output = Self; + fn mul(self, rhs: Self) -> Self { + BaseExpr::Mul(Box::new(self), Box::new(rhs)) + } +} + +impl AddAssign for BaseExpr { + fn add_assign(&mut self, rhs: Self) { + *self = self.clone() + rhs + } +} + +impl MulAssign for BaseExpr { + fn mul_assign(&mut self, rhs: Self) { + *self = self.clone() * rhs + } +} + +impl Neg for BaseExpr { + type Output = Self; + fn neg(self) -> Self { + BaseExpr::Neg(Box::new(self)) + } +} + +impl Add for ExtExpr { + type Output = Self; + fn add(self, rhs: Self) -> Self { + ExtExpr::Add(Box::new(self), Box::new(rhs)) + } +} + +impl Sub for ExtExpr { + type Output = Self; + fn sub(self, rhs: Self) -> Self { + ExtExpr::Sub(Box::new(self), Box::new(rhs)) + } +} + +impl Mul for ExtExpr { + type Output = Self; + fn mul(self, rhs: Self) -> Self { + ExtExpr::Mul(Box::new(self), Box::new(rhs)) + } +} + +impl AddAssign for ExtExpr { + fn add_assign(&mut self, rhs: Self) { + *self = self.clone() + rhs + } +} + +impl MulAssign for ExtExpr { + fn mul_assign(&mut self, rhs: Self) { + *self = self.clone() * rhs + } +} + +impl Neg for ExtExpr { + type Output = Self; + fn neg(self) -> Self { + ExtExpr::Neg(Box::new(self)) + } +} + +impl Zero for BaseExpr { + fn zero() -> Self { + BaseExpr::from(BaseField::zero()) + } + fn is_zero(&self) -> bool { + // TODO(alont): consider replacing `Zero` in the trait bound with a custom trait + // that only has `zero()`. + panic!("Can't check if an expression is zero."); + } +} + +impl One for BaseExpr { + fn one() -> Self { + BaseExpr::from(BaseField::one()) + } +} + +impl Zero for ExtExpr { + fn zero() -> Self { + ExtExpr::from(BaseField::zero()) + } + fn is_zero(&self) -> bool { + // TODO(alont): consider replacing `Zero` in the trait bound with a custom trait + // that only has `zero()`. 
+ panic!("Can't check if an expression is zero."); + } +} + +impl One for ExtExpr { + fn one() -> Self { + ExtExpr::from(BaseField::one()) + } +} + +impl FieldExpOps for BaseExpr { + fn inverse(&self) -> Self { + BaseExpr::Inv(Box::new(self.clone())) + } +} + +impl Add for BaseExpr { + type Output = Self; + fn add(self, rhs: BaseField) -> Self { + self + BaseExpr::from(rhs) + } +} + +impl AddAssign for BaseExpr { + fn add_assign(&mut self, rhs: BaseField) { + *self = self.clone() + BaseExpr::from(rhs) + } +} + +impl Mul for BaseExpr { + type Output = Self; + fn mul(self, rhs: BaseField) -> Self { + self * BaseExpr::from(rhs) + } +} + +impl Mul for BaseExpr { + type Output = ExtExpr; + fn mul(self, rhs: SecureField) -> ExtExpr { + ExtExpr::from(self) * ExtExpr::from(rhs) + } +} + +impl Add for BaseExpr { + type Output = ExtExpr; + fn add(self, rhs: SecureField) -> ExtExpr { + ExtExpr::from(self) + ExtExpr::from(rhs) + } +} + +impl Sub for BaseExpr { + type Output = ExtExpr; + fn sub(self, rhs: SecureField) -> ExtExpr { + ExtExpr::from(self) - ExtExpr::from(rhs) + } +} + +impl Add for ExtExpr { + type Output = Self; + fn add(self, rhs: BaseField) -> Self { + self + ExtExpr::from(rhs) + } +} + +impl AddAssign for ExtExpr { + fn add_assign(&mut self, rhs: BaseField) { + *self = self.clone() + ExtExpr::from(rhs) + } +} + +impl Mul for ExtExpr { + type Output = Self; + fn mul(self, rhs: BaseField) -> Self { + self * ExtExpr::from(rhs) + } +} + +impl Mul for ExtExpr { + type Output = Self; + fn mul(self, rhs: SecureField) -> Self { + self * ExtExpr::from(rhs) + } +} + +impl Add for ExtExpr { + type Output = Self; + fn add(self, rhs: SecureField) -> Self { + self + ExtExpr::from(rhs) + } +} + +impl Sub for ExtExpr { + type Output = Self; + fn sub(self, rhs: SecureField) -> Self { + self - ExtExpr::from(rhs) + } +} + +impl Add for ExtExpr { + type Output = Self; + fn add(self, rhs: BaseExpr) -> Self { + self + ExtExpr::from(rhs) + } +} + +impl Mul for ExtExpr { + type Output = Self; + fn mul(self, rhs: BaseExpr) -> Self { + self * ExtExpr::from(rhs) + } +} + +impl Mul for BaseExpr { + type Output = ExtExpr; + fn mul(self, rhs: ExtExpr) -> ExtExpr { + rhs * self + } +} + +impl Sub for ExtExpr { + type Output = Self; + fn sub(self, rhs: BaseExpr) -> Self { + self - ExtExpr::from(rhs) + } +} diff --git a/crates/prover/src/constraint_framework/expr/simplify.rs b/crates/prover/src/constraint_framework/expr/simplify.rs new file mode 100644 index 000000000..3632c09e2 --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/simplify.rs @@ -0,0 +1,216 @@ +use num_traits::{One, Zero}; + +use super::{BaseExpr, ExtExpr}; +use crate::core::fields::qm31::SecureField; + +/// Applies simplifications to arithmetic expressions that can be used both for `BaseExpr` and for +/// `ExtExpr`. +macro_rules! simplify_arithmetic { + ($self:tt) => { + match $self.clone() { + Self::Add(a, b) => { + let a = a.simplify(); + let b = b.simplify(); + match (a.clone(), b.clone()) { + // Simplify constants. + (Self::Const(a), Self::Const(b)) => Self::Const(a + b), + (Self::Const(a_val), _) if a_val.is_zero() => b, // 0 + b = b + (_, Self::Const(b_val)) if b_val.is_zero() => a, // a + 0 = a + // Simplify Negs. + // (-a + -b) = -(a + b) + (Self::Neg(minus_a), Self::Neg(minus_b)) => -(*minus_a + *minus_b), + (Self::Neg(minus_a), _) => b - *minus_a, // -a + b = b - a + (_, Self::Neg(minus_b)) => a - *minus_b, // a + -b = a - b + // No simplification. 
+ _ => a + b, + } + } + Self::Sub(a, b) => { + let a = a.simplify(); + let b = b.simplify(); + match (a.clone(), b.clone()) { + // Simplify constants. + (Self::Const(a), Self::Const(b)) => Self::Const(a - b), // Simplify consts. + (Self::Const(a_val), _) if a_val.is_zero() => -b, // 0 - b = -b + (_, Self::Const(b_val)) if b_val.is_zero() => a, // a - 0 = a + // Simplify Negs. + // (-a - -b) = b - a + (Self::Neg(minus_a), Self::Neg(minus_b)) => *minus_b - *minus_a, + (Self::Neg(minus_a), _) => -(*minus_a + b), // -a - b = -(a + b) + (_, Self::Neg(minus_b)) => a + *minus_b, // a + -b = a - b + // No Simplification. + _ => a - b, + } + } + Self::Mul(a, b) => { + let a = a.simplify(); + let b = b.simplify(); + match (a.clone(), b.clone()) { + // Simplify consts. + (Self::Const(a), Self::Const(b)) => Self::Const(a * b), + (Self::Const(a_val), _) if a_val.is_zero() => Self::zero(), // 0 * b = 0 + (_, Self::Const(b_val)) if b_val.is_zero() => Self::zero(), // a * 0 = 0 + (Self::Const(a_val), _) if a_val == One::one() => b, // 1 * b = b + (_, Self::Const(b_val)) if b_val == One::one() => a, // a * 1 = a + (Self::Const(a_val), _) if -a_val == One::one() => -b, // -1 * b = -b + (_, Self::Const(b_val)) if -b_val == One::one() => -a, // a * -1 = -a + // Simplify Negs. + // (-a) * (-b) = a * b + (Self::Neg(minus_a), Self::Neg(minus_b)) => *minus_a * *minus_b, + (Self::Neg(minus_a), _) => -(*minus_a * b), // (-a) * b = -(a * b) + (_, Self::Neg(minus_b)) => -(a * *minus_b), // a * (-b) = -(a * b) + // No simplification. + _ => a * b, + } + } + Self::Neg(a) => { + let a = a.simplify(); + match a { + Self::Const(c) => Self::Const(-c), + Self::Neg(minus_a) => *minus_a, // -(-a) = a + Self::Sub(a, b) => Self::Sub(b, a), // -(a - b) = b - a + _ => -a, // No simplification. + } + } + other => other, // No simplification. + } + }; +} + +impl BaseExpr { + /// Helper function, use [`simplify`] instead. + /// + /// Simplifies an expression by applying basic arithmetic rules. + fn unchecked_simplify(&self) -> Self { + let simple = simplify_arithmetic!(self); + match simple { + Self::Inv(a) => { + let a = a.unchecked_simplify(); + match a { + Self::Inv(inv_a) => *inv_a, // 1 / (1 / a) = a + Self::Const(c) => Self::Const(c.inverse()), + _ => Self::Inv(Box::new(a)), + } + } + other => other, + } + } + + /// Simplifies an expression by applying basic arithmetic rules and ensures that the result is + /// equivalent to the original expression by assigning random values. + pub fn simplify(&self) -> Self { + let simplified = self.unchecked_simplify(); + assert_eq!(self.random_eval(), simplified.random_eval()); + simplified + } + + pub fn simplify_and_format(&self) -> String { + self.simplify().format_expr() + } +} + +impl ExtExpr { + /// Helper function, use [`simplify`] instead. + /// + /// Simplifies an expression by applying basic arithmetic rules. 
+ fn unchecked_simplify(&self) -> Self { + let simple = simplify_arithmetic!(self); + match simple { + Self::SecureCol([a, b, c, d]) => { + let a = a.unchecked_simplify(); + let b = b.unchecked_simplify(); + let c = c.unchecked_simplify(); + let d = d.unchecked_simplify(); + match (a.clone(), b.clone(), c.clone(), d.clone()) { + ( + BaseExpr::Const(a_val), + BaseExpr::Const(b_val), + BaseExpr::Const(c_val), + BaseExpr::Const(d_val), + ) => ExtExpr::Const(SecureField::from_m31_array([a_val, b_val, c_val, d_val])), + _ => Self::SecureCol([Box::new(a), Box::new(b), Box::new(c), Box::new(d)]), + } + } + other => other, + } + } + + /// Simplifies an expression by applying basic arithmetic rules and ensures that the result is + /// equivalent to the original expression by assigning random values. + pub fn simplify(&self) -> Self { + let simplified = self.unchecked_simplify(); + assert_eq!(self.random_eval(), simplified.random_eval()); + simplified + } + + pub fn simplify_and_format(&self) -> String { + self.simplify().format_expr() + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use crate::constraint_framework::expr::utils::*; + use crate::constraint_framework::AssertEvaluator; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + #[test] + fn test_simplify_expr() { + let c0 = col!(1, 0, 0); + let c1 = col!(1, 1, 0); + let a = var!("a"); + let b = qvar!("b"); + let zero = felt!(0); + let qzero = qfelt!(0, 0, 0, 0); + let one = felt!(1); + let qone = qfelt!(1, 0, 0, 0); + let minus_one = felt!(crate::core::fields::m31::P - 1); + let qminus_one = qfelt!(crate::core::fields::m31::P - 1, 0, 0, 0); + + let mut rng = SmallRng::seed_from_u64(0); + let columns: HashMap<(usize, usize, isize), BaseField> = + HashMap::from([((1, 0, 0), rng.gen()), ((1, 1, 0), rng.gen())]); + let vars: HashMap = HashMap::from([("a".to_string(), rng.gen())]); + let ext_vars: HashMap = HashMap::from([("b".to_string(), rng.gen())]); + + let base_expr = (((zero.clone() + c0.clone()) + (a.clone() + zero.clone())) + * ((-c1.clone()) + (-c0.clone())) + + (-(-(a.clone() + a.clone() + c0.clone()))) + - zero.clone()) + + (a.clone() - zero.clone()) + + (-c1.clone() - (a.clone() * a.clone())) + + (a.clone() * zero.clone()) + - (zero.clone() * c1.clone()) + + one.clone() + * a.clone() + * one.clone() + * c1.clone() + * (-a.clone()) + * c1.clone() + * (minus_one.clone() * c0.clone()); + + let expr = (qzero.clone() + + secure_col!( + base_expr.clone(), + base_expr.clone(), + zero.clone(), + one.clone() + ) + - qzero.clone()) + * qone.clone() + * b.clone() + * qminus_one.clone(); + + let full_eval = expr.eval_expr::, _, _, _>(&columns, &vars, &ext_vars); + let simplified_eval = expr + .simplify() + .eval_expr::, _, _, _>(&columns, &vars, &ext_vars); + + assert_eq!(full_eval, simplified_eval); + } +} diff --git a/crates/prover/src/constraint_framework/expr/utils.rs b/crates/prover/src/constraint_framework/expr/utils.rs new file mode 100644 index 000000000..724840a06 --- /dev/null +++ b/crates/prover/src/constraint_framework/expr/utils.rs @@ -0,0 +1,65 @@ +#[cfg(test)] +macro_rules! secure_col { + ($a:expr, $b:expr, $c:expr, $d:expr) => { + crate::constraint_framework::expr::ExtExpr::SecureCol([ + Box::new($a.into()), + Box::new($b.into()), + Box::new($c.into()), + Box::new($d.into()), + ]) + }; +} +#[cfg(test)] +pub(crate) use secure_col; + +#[cfg(test)] +macro_rules! 
col { + ($interaction:expr, $idx:expr, $offset:expr) => { + crate::constraint_framework::expr::BaseExpr::Col(($interaction, $idx, $offset).into()) + }; +} +#[cfg(test)] +pub(crate) use col; + +#[cfg(test)] +macro_rules! var { + ($var:expr) => { + crate::constraint_framework::expr::BaseExpr::Param($var.to_string()) + }; +} +#[cfg(test)] +pub(crate) use var; + +#[cfg(test)] +macro_rules! qvar { + ($var:expr) => { + crate::constraint_framework::expr::ExtExpr::Param($var.to_string()) + }; +} +#[cfg(test)] +pub(crate) use qvar; + +#[cfg(test)] +macro_rules! felt { + ($val:expr) => { + crate::constraint_framework::expr::BaseExpr::Const($val.into()) + }; +} +#[cfg(test)] +pub(crate) use felt; + +#[cfg(test)] +macro_rules! qfelt { + ($a:expr, $b:expr, $c:expr, $d:expr) => { + crate::constraint_framework::expr::ExtExpr::Const( + crate::core::fields::qm31::SecureField::from_m31_array([ + $a.into(), + $b.into(), + $c.into(), + $d.into(), + ]), + ) + }; +} +#[cfg(test)] +pub(crate) use qfelt; diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index 650607c1d..370987e4c 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -27,6 +27,16 @@ pub type ClaimedPrefixSum = (SecureField, usize); // (total_sum, claimed_sum) pub type LogupSums = (SecureField, Option); +pub trait LogupSumsExt { + fn value(&self) -> SecureField; +} + +impl LogupSumsExt for LogupSums { + fn value(&self) -> SecureField { + self.1.map(|(claimed_sum, _)| claimed_sum).unwrap_or(self.0) + } +} + /// Evaluates constraints for batched logups. /// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum. pub struct LogupAtRow { @@ -39,8 +49,7 @@ pub struct LogupAtRow { /// None if the claimed_sum is the total_sum. pub claimed_sum: Option, /// The evaluation of the last cumulative sum column. - pub prev_col_cumsum: E::EF, - pub cur_frac: Option>, + pub fracs: Vec>, pub is_finalized: bool, /// The value of the `is_first` constant column at current row. /// See [`super::preprocessed_columns::gen_is_first()`]. @@ -64,8 +73,7 @@ impl LogupAtRow { interaction, total_sum, claimed_sum, - prev_col_cumsum: E::EF::zero(), - cur_frac: None, + fracs: vec![], is_finalized: true, is_first: E::F::zero(), log_size, @@ -78,8 +86,7 @@ impl LogupAtRow { interaction: 100, total_sum: SecureField::one(), claimed_sum: None, - prev_col_cumsum: E::EF::zero(), - cur_frac: None, + fracs: vec![], is_finalized: true, is_first: E::F::zero(), log_size: 10, @@ -231,7 +238,7 @@ pub struct LogupColGenerator<'a> { /// Numerator expressions (i.e. multiplicities) being generated for the current lookup. numerator: SecureColumnByCoords, } -impl<'a> LogupColGenerator<'a> { +impl LogupColGenerator<'_> { /// Write a fraction to the column at a row. pub fn write_frac( &mut self, diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 9bbf05402..22809152d 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -7,6 +7,7 @@ mod info; pub mod logup; mod point; pub mod preprocessed_columns; +pub mod relation_tracker; mod simd_domain; use std::array; @@ -31,6 +32,13 @@ pub const PREPROCESSED_TRACE_IDX: usize = 0; pub const ORIGINAL_TRACE_IDX: usize = 1; pub const INTERACTION_TRACE_IDX: usize = 2; +/// A vector that describes the batching of logup entries. 
+/// Each vector member corresponds to a logup entry, and contains the batch number to which the +/// entry should be added. +/// Note that the batch numbers should be consecutive and start from 0, and that the vector's +/// length should be equal to the number of logup entries. +type Batching = Vec; + /// A trait for evaluating expressions at some point or row. pub trait EvalAtRow { // TODO(Ohad): Use a better trait for these, like 'Algebra' or something. @@ -108,7 +116,19 @@ pub trait EvalAtRow { /// Adds a constraint to the component. fn add_constraint(&mut self, constraint: G) where - Self::EF: Mul; + Self::EF: Mul + From; + + /// Adds an intermediate value in the base field to the component and returns its value. + /// Does nothing by default. + fn add_intermediate(&mut self, val: Self::F) -> Self::F { + val + } + + /// Adds an intermediate value in the extension field to the component and returns its value. + /// Does nothing by default. + fn add_extension_intermediate(&mut self, val: Self::EF) -> Self::EF { + val + } /// Combines 4 base field values into a single extension field value. fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF; @@ -119,25 +139,30 @@ pub trait EvalAtRow { /// multiplied. fn add_to_relation>( &mut self, - entries: &[RelationEntry<'_, Self::F, Self::EF, R>], + entry: RelationEntry<'_, Self::F, Self::EF, R>, ) { - let fracs = entries.iter().map( - |RelationEntry { - relation, - multiplicity, - values, - }| { Fraction::new(multiplicity.clone(), relation.combine(values)) }, + let frac = Fraction::new( + entry.multiplicity.clone(), + entry.relation.combine(entry.values), ); - self.write_logup_frac(fracs.sum()); + self.write_logup_frac(frac); } // TODO(alont): Remove these once LogupAtRow is no longer used. fn write_logup_frac(&mut self, _fraction: Fraction) { unimplemented!() } - fn finalize_logup(&mut self) { + fn finalize_logup_batched(&mut self, _batching: &Batching) { unimplemented!() } + + fn finalize_logup(&mut self) { + unimplemented!(); + } + + fn finalize_logup_in_pairs(&mut self) { + unimplemented!(); + } } /// Default implementation for evaluators that have an element called "logup" that works like a @@ -146,35 +171,70 @@ pub trait EvalAtRow { macro_rules! logup_proxy { () => { fn write_logup_frac(&mut self, fraction: Fraction) { - // Add a constraint that num / denom = diff. - if let Some(cur_frac) = self.logup.cur_frac.clone() { - let [cur_cumsum] = - self.next_extension_interaction_mask(self.logup.interaction, [0]); - let diff = cur_cumsum.clone() - self.logup.prev_col_cumsum.clone(); - self.logup.prev_col_cumsum = cur_cumsum; - self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator); - } else { + if self.logup.fracs.is_empty() { self.logup.is_first = self.get_preprocessed_column( - super::preprocessed_columns::PreprocessedColumn::IsFirst(self.logup.log_size), + crate::constraint_framework::preprocessed_columns::PreprocessedColumn::IsFirst( + self.logup.log_size, + ), ); self.logup.is_finalized = false; } - self.logup.cur_frac = Some(fraction); + self.logup.fracs.push(fraction.clone()); } - fn finalize_logup(&mut self) { + /// Finalize the logup by adding the constraints for the fractions, batched by + /// the given `batching`. + /// `batching` should contain the batch into which every logup entry should be inserted. 
+ fn finalize_logup_batched(&mut self, batching: &crate::constraint_framework::Batching) { assert!(!self.logup.is_finalized, "LogupAtRow was already finalized"); + assert_eq!( + batching.len(), + self.logup.fracs.len(), + "Batching must be of the same length as the number of entries" + ); + + let last_batch = *batching.iter().max().unwrap(); - let frac = self.logup.cur_frac.clone().unwrap(); + let mut fracs_by_batch = + std::collections::HashMap::>>::new(); + + for (batch, frac) in batching.iter().zip(self.logup.fracs.iter()) { + fracs_by_batch + .entry(*batch) + .or_insert_with(Vec::new) + .push(frac.clone()); + } + + let keys_set: std::collections::HashSet<_> = fracs_by_batch.keys().cloned().collect(); + let all_batches_set: std::collections::HashSet<_> = (0..last_batch + 1).collect(); + + assert_eq!( + keys_set, all_batches_set, + "Batching must contain all consecutive batches" + ); + + let mut prev_col_cumsum = ::zero(); + + // All batches except the last are cumulatively summed in new interaction columns. + for batch_id in (0..last_batch) { + let cur_frac: Fraction<_, _> = fracs_by_batch[&batch_id].iter().cloned().sum(); + let [cur_cumsum] = + self.next_extension_interaction_mask(self.logup.interaction, [0]); + let diff = cur_cumsum.clone() - prev_col_cumsum.clone(); + prev_col_cumsum = cur_cumsum; + self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator); + } + + let frac: Fraction<_, _> = fracs_by_batch[&last_batch].clone().into_iter().sum(); // TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted // offset from the is_first column when constant columns are supported. let (cur_cumsum, prev_row_cumsum) = match self.logup.claimed_sum.clone() { Some((claimed_sum, claimed_row_index)) => { - let [cur_cumsum, prev_row_cumsum, claimed_cumsum] = self + let [prev_row_cumsum, cur_cumsum, claimed_cumsum] = self .next_extension_interaction_mask( self.logup.interaction, - [0, -1, claimed_row_index as isize], + [-1, 0, claimed_row_index as isize], ); // Constrain that the claimed_sum in case that it is not equal to the total_sum. @@ -184,20 +244,33 @@ macro_rules! logup_proxy { (cur_cumsum, prev_row_cumsum) } None => { - let [cur_cumsum, prev_row_cumsum] = - self.next_extension_interaction_mask(self.logup.interaction, [0, -1]); + let [prev_row_cumsum, cur_cumsum] = + self.next_extension_interaction_mask(self.logup.interaction, [-1, 0]); (cur_cumsum, prev_row_cumsum) } }; // Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row. let fixed_prev_row_cumsum = prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum.clone(); - let diff = cur_cumsum - fixed_prev_row_cumsum - self.logup.prev_col_cumsum.clone(); + let diff = cur_cumsum - fixed_prev_row_cumsum - prev_col_cumsum.clone(); self.add_constraint(diff * frac.denominator - frac.numerator); self.logup.is_finalized = true; } + + /// Finalizes the row's logup in the default way. Currently, this means no batching. + fn finalize_logup(&mut self) { + let batches = (0..self.logup.fracs.len()).collect(); + self.finalize_logup_batched(&batches) + } + + /// Finalizes the row's logup, batched in pairs. + /// TODO(alont) Remove this once a better batching mechanism is implemented. 
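+        /// For example, five fractions are batched as `[0, 0, 1, 1, 2]`.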
+ fn finalize_logup_in_pairs(&mut self) { + let batches = (0..self.logup.fracs.len()).map(|n| n / 2).collect(); + self.finalize_logup_batched(&batches) + } }; } pub(crate) use logup_proxy; @@ -234,7 +307,7 @@ pub struct RelationEntry<'a, F: Clone, EF: RelationEFTraitBound, R: Relation< values: &'a [F], } impl<'a, F: Clone, EF: RelationEFTraitBound, R: Relation> RelationEntry<'a, F, EF, R> { - pub fn new(relation: &'a R, multiplicity: EF, values: &'a [F]) -> Self { + pub const fn new(relation: &'a R, multiplicity: EF, values: &'a [F]) -> Self { Self { relation, multiplicity, diff --git a/crates/prover/src/constraint_framework/point.rs b/crates/prover/src/constraint_framework/point.rs index 3fc2ad510..ea01c647d 100644 --- a/crates/prover/src/constraint_framework/point.rs +++ b/crates/prover/src/constraint_framework/point.rs @@ -35,7 +35,7 @@ impl<'a> PointEvaluator<'a> { } } } -impl<'a> EvalAtRow for PointEvaluator<'a> { +impl EvalAtRow for PointEvaluator<'_> { type F = SecureField; type EF = SecureField; diff --git a/crates/prover/src/constraint_framework/preprocessed_columns.rs b/crates/prover/src/constraint_framework/preprocessed_columns.rs index bd7b4c9c9..a54ebc734 100644 --- a/crates/prover/src/constraint_framework/preprocessed_columns.rs +++ b/crates/prover/src/constraint_framework/preprocessed_columns.rs @@ -1,17 +1,91 @@ -use num_traits::One; +use std::simd::Simd; +use num_traits::{One, Zero}; + +use crate::core::backend::simd::m31::{PackedM31, N_LANES}; use crate::core::backend::{Backend, Col, Column}; -use crate::core::fields::m31::BaseField; +use crate::core::fields::m31::{BaseField, M31}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; use crate::core::poly::BitReversedOrder; use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; +const SIMD_ENUMERATION_0: PackedM31 = unsafe { + PackedM31::from_simd_unchecked(Simd::from_array([ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + ])) +}; + // TODO(ilya): Where should this enum be placed? +// TODO(Gali): Consider making it a trait, add documentation for the rest of the variants. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum PreprocessedColumn { - XorTable(u32, u32, usize), + /// A column with `1` at the first position, and `0` elsewhere. IsFirst(u32), Plonk(usize), + /// A column with the numbers [0..2^log_size-1]. + Seq(u32), + XorTable(u32, u32, usize), +} + +impl PreprocessedColumn { + pub const fn name(&self) -> &'static str { + match self { + PreprocessedColumn::IsFirst(_) => "preprocessed_is_first", + PreprocessedColumn::Plonk(_) => "preprocessed_plonk", + PreprocessedColumn::Seq(_) => "preprocessed_seq", + PreprocessedColumn::XorTable(..) => "preprocessed_xor_table", + } + } + + pub fn log_size(&self) -> u32 { + match self { + PreprocessedColumn::IsFirst(log_size) => *log_size, + PreprocessedColumn::Seq(log_size) => *log_size, + PreprocessedColumn::XorTable(log_size, ..) => *log_size, + PreprocessedColumn::Plonk(_) => unimplemented!(), + } + } + + /// Returns the values of the column at the given row. 
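+    /// For example, `Seq(log_size).packed_at(0)` packs the values `0..N_LANES`, and
+    /// `IsFirst(log_size).packed_at(0)` packs a one followed by zeros.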
+ pub fn packed_at(&self, vec_row: usize) -> PackedM31 { + match self { + PreprocessedColumn::IsFirst(log_size) => { + assert!(vec_row < (1 << log_size) / N_LANES); + if vec_row == 0 { + unsafe { + PackedM31::from_simd_unchecked(Simd::from_array(std::array::from_fn(|i| { + if i == 0 { + 1 + } else { + 0 + } + }))) + } + } else { + PackedM31::zero() + } + } + PreprocessedColumn::Seq(log_size) => { + assert!(vec_row < (1 << log_size) / N_LANES); + PackedM31::broadcast(M31::from(vec_row * N_LANES)) + SIMD_ENUMERATION_0 + } + + _ => unimplemented!(), + } + } + + /// Generates a column according to the preprocessed column chosen. + pub fn gen_preprocessed_column( + preprocessed_column: &PreprocessedColumn, + ) -> CircleEvaluation { + match preprocessed_column { + PreprocessedColumn::IsFirst(log_size) => gen_is_first(*log_size), + PreprocessedColumn::Plonk(_) | PreprocessedColumn::XorTable(..) => { + unimplemented!("eval_preprocessed_column: Plonk and XorTable are not supported.") + } + PreprocessedColumn::Seq(log_size) => gen_seq(*log_size), + } + } } /// Generates a column with a single one at the first position, and zeros elsewhere. @@ -44,19 +118,61 @@ pub fn gen_is_step_with_offset( CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col) } -pub fn gen_preprocessed_column( - preprocessed_column: &PreprocessedColumn, -) -> CircleEvaluation { - match preprocessed_column { - PreprocessedColumn::IsFirst(log_size) => gen_is_first(*log_size), - PreprocessedColumn::Plonk(_) | PreprocessedColumn::XorTable(..) => { - unimplemented!("eval_preprocessed_column: Plonk and XorTable are not supported.") - } - } +/// Generates a column with sequence of numbers from 0 to 2^log_size - 1. +pub fn gen_seq(log_size: u32) -> CircleEvaluation { + let col = Col::::from_iter((0..(1 << log_size)).map(BaseField::from)); + CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col) } pub fn gen_preprocessed_columns<'a, B: Backend>( columns: impl Iterator, ) -> Vec> { - columns.map(gen_preprocessed_column).collect() + columns + .map(PreprocessedColumn::gen_preprocessed_column) + .collect() +} + +#[cfg(test)] +mod tests { + use crate::core::backend::simd::m31::N_LANES; + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::Column; + use crate::core::fields::m31::{BaseField, M31}; + const LOG_SIZE: u32 = 8; + + #[test] + fn test_gen_seq() { + let seq = super::gen_seq::(LOG_SIZE); + + for i in 0..(1 << LOG_SIZE) { + assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32)); + } + } + + // TODO(Gali): Add packed_at tests for xor_table and plonk. 
+ #[test] + fn test_packed_at_is_first() { + let is_first = super::PreprocessedColumn::IsFirst(LOG_SIZE); + let expected_is_first = super::gen_is_first::(LOG_SIZE).to_cpu(); + + for i in 0..(1 << LOG_SIZE) / N_LANES { + assert_eq!( + is_first.packed_at(i).to_array(), + expected_is_first[i * N_LANES..(i + 1) * N_LANES] + ); + } + } + + #[test] + fn test_packed_at_seq() { + let seq = super::PreprocessedColumn::Seq(LOG_SIZE); + let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32)); + + let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| { + seq.packed_at(i).to_array() + }) + .concat(); + + assert_eq!(packed_seq, expected_seq); + } } diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs new file mode 100644 index 000000000..e4606e7bf --- /dev/null +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -0,0 +1,265 @@ +use std::collections::HashMap; +use std::fmt::Debug; + +use itertools::Itertools; +use num_traits::Zero; + +use super::logup::LogupSums; +use super::preprocessed_columns::PreprocessedColumn; +use super::{ + Batching, EvalAtRow, FrameworkEval, InfoEvaluator, Relation, RelationEntry, + TraceLocationAllocator, INTERACTION_TRACE_IDX, +}; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::very_packed_m31::LOG_N_VERY_PACKED_ELEMS; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::m31::{BaseField, M31}; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::lookups::utils::Fraction; +use crate::core::pcs::{TreeSubspan, TreeVec}; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::{ + bit_reverse_index, coset_index_to_circle_domain_index, offset_bit_reversed_circle_domain_index, +}; + +#[derive(Debug)] +pub struct RelationTrackerEntry { + pub relation: String, + pub mult: M31, + pub values: Vec, +} + +pub struct RelationTrackerComponent { + eval: E, + trace_locations: TreeVec, + n_rows: usize, +} +impl RelationTrackerComponent { + pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E, n_rows: usize) -> Self { + let info = eval.evaluate(InfoEvaluator::new( + eval.log_size(), + vec![], + LogupSums::default(), + )); + let mut mask_offsets = info.mask_offsets; + mask_offsets.drain(INTERACTION_TRACE_IDX..); + let trace_locations = location_allocator.next_for_structure(&mask_offsets); + Self { + eval, + trace_locations, + n_rows, + } + } + + pub fn entries( + self, + trace: &TreeVec>>, + ) -> Vec { + let log_size = self.eval.log_size(); + + // Deref the sub-tree. Only copies the references. + let sub_tree = trace + .sub_tree(&self.trace_locations) + .map(|vec| vec.into_iter().copied().collect_vec()); + let mut entries = vec![]; + + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let evaluator = + RelationTrackerEvaluator::new(&sub_tree, vec_row, log_size, self.n_rows); + entries.extend(self.eval.evaluate(evaluator).entries()); + } + entries + } +} + +/// Aggregates relation entries. 
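+/// Unlike the constraint evaluators, it ignores constraints entirely and only records the
+/// values and multiplicities passed to `add_to_relation`, one entry per unpadded row.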
+pub struct RelationTrackerEvaluator<'a> { + entries: Vec, + pub trace_eval: + &'a TreeVec>>, + pub column_index_per_interaction: Vec, + pub vec_row: usize, + pub domain_log_size: u32, + pub n_rows: usize, +} +impl<'a> RelationTrackerEvaluator<'a> { + pub fn new( + trace_eval: &'a TreeVec>>, + vec_row: usize, + domain_log_size: u32, + n_rows: usize, + ) -> Self { + Self { + entries: vec![], + trace_eval, + column_index_per_interaction: vec![0; trace_eval.len()], + vec_row, + domain_log_size, + n_rows, + } + } + + pub fn entries(self) -> Vec { + self.entries + } +} +impl EvalAtRow for RelationTrackerEvaluator<'_> { + type F = PackedBaseField; + type EF = PackedSecureField; + + // TODO(Ohad): Add debug boundary checks. + fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N] { + assert_ne!(interaction, INTERACTION_TRACE_IDX); + let col_index = self.column_index_per_interaction[interaction]; + self.column_index_per_interaction[interaction] += 1; + offsets.map(|off| { + // If the offset is 0, we can just return the value directly from this row. + if off == 0 { + unsafe { + let col = &self + .trace_eval + .get_unchecked(interaction) + .get_unchecked(col_index) + .values; + return *col.data.get_unchecked(self.vec_row); + }; + } + // Otherwise, we need to look up the value at the offset. + // Since the domain is bit-reversed circle domain ordered, we need to look up the value + // at the bit-reversed natural order index at an offset. + PackedBaseField::from_array(std::array::from_fn(|i| { + let row_index = offset_bit_reversed_circle_domain_index( + (self.vec_row << (LOG_N_LANES + LOG_N_VERY_PACKED_ELEMS)) + i, + self.domain_log_size, + self.domain_log_size, + off, + ); + self.trace_eval[interaction][col_index].at(row_index) + })) + }) + } + + fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F { + column.packed_at(self.vec_row) + } + + fn add_constraint(&mut self, _constraint: G) {} + + fn combine_ef(_values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { + PackedSecureField::zero() + } + + fn write_logup_frac(&mut self, _fraction: Fraction) {} + + fn finalize_logup_batched(&mut self, _batching: &Batching) {} + fn finalize_logup(&mut self) {} + fn finalize_logup_in_pairs(&mut self) {} + + fn add_to_relation>( + &mut self, + entry: RelationEntry<'_, Self::F, Self::EF, R>, + ) { + let relation = entry.relation.get_name().to_owned(); + let values = entry.values.iter().map(|v| v.to_array()).collect_vec(); + let mult = entry.multiplicity.to_array(); + + // Unpack SIMD. + for j in 0..N_LANES { + // Skip padded values. + let cannonical_index = bit_reverse_index( + coset_index_to_circle_domain_index( + (self.vec_row << LOG_N_LANES) + j, + self.domain_log_size, + ), + self.domain_log_size, + ); + if cannonical_index >= self.n_rows { + continue; + } + let values = values.iter().map(|v| v[j]).collect_vec(); + let mult = mult[j].to_m31_array()[0]; + self.entries.push(RelationTrackerEntry { + relation: relation.clone(), + mult, + values, + }); + } + } +} + +type RelationInfo = (String, Vec<(Vec, M31)>); +pub struct RelationSummary(Vec); +impl RelationSummary { + /// Returns the sum of every entry's yields and uses. + /// The result is a map from relation name to a list of values(M31 vectors) and their sum. 
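+    /// Entries with identical value vectors are aggregated, so a value used with multiplicity
+    /// `1` in one component and `-1` in another nets to zero in the summary.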
+ pub fn summarize_relations(entries: &[RelationTrackerEntry]) -> Self { + let mut entry_by_relation = HashMap::new(); + for entry in entries { + entry_by_relation + .entry(entry.relation.clone()) + .or_insert_with(Vec::new) + .push(entry); + } + let mut summary = vec![]; + for (relation, entries) in entry_by_relation { + let mut relation_sums: HashMap, M31> = HashMap::new(); + for entry in entries { + let mut values = entry.values.clone(); + + // Trailing zeroes do not affect the sum, remove for correct aggregation. + while values.last().is_some_and(|v| v.is_zero()) { + values.pop(); + } + let mult = relation_sums.entry(values).or_insert(M31::zero()); + *mult += entry.mult; + } + let relation_sums = relation_sums.into_iter().collect_vec(); + summary.push((relation.clone(), relation_sums)); + } + Self(summary) + } + + pub fn get_relation_info(&self, relation: &str) -> Option<&[(Vec, M31)]> { + self.0 + .iter() + .find(|(name, _)| name == relation) + .map(|(_, entries)| entries.as_slice()) + } + + /// Cleans up the summary by removing zero-sum entries, only keeping the non-zero ones. + /// Used for debugging. + pub fn cleaned(self) -> Self { + let mut cleaned = vec![]; + for (relation, entries) in self.0 { + let mut cleaned_entries = vec![]; + for (vector, sum) in entries { + if !sum.is_zero() { + cleaned_entries.push((vector, sum)); + } + } + if !cleaned_entries.is_empty() { + cleaned.push((relation, cleaned_entries)); + } + } + Self(cleaned) + } +} +impl Debug for RelationSummary { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for (relation, entries) in &self.0 { + writeln!(f, "{}:", relation)?; + for (vector, sum) in entries { + let vector = vector.iter().map(|v| v.0).collect_vec(); + writeln!(f, " {:?} -> {}", vector, sum)?; + } + } + Ok(()) + } +} diff --git a/crates/prover/src/constraint_framework/simd_domain.rs b/crates/prover/src/constraint_framework/simd_domain.rs index b18e8e265..65c52708c 100644 --- a/crates/prover/src/constraint_framework/simd_domain.rs +++ b/crates/prover/src/constraint_framework/simd_domain.rs @@ -57,7 +57,7 @@ impl<'a> SimdDomainEvaluator<'a> { } } } -impl<'a> EvalAtRow for SimdDomainEvaluator<'a> { +impl EvalAtRow for SimdDomainEvaluator<'_> { type F = VeryPackedBaseField; type EF = VeryPackedSecureField; @@ -98,7 +98,7 @@ impl<'a> EvalAtRow for SimdDomainEvaluator<'a> { } fn add_constraint(&mut self, constraint: G) where - Self::EF: Mul, + Self::EF: Mul + From, { self.row_res += VeryPackedSecureField::broadcast(self.random_coeff_powers[self.constraint_index]) diff --git a/crates/prover/src/core/air/accumulation.rs b/crates/prover/src/core/air/accumulation.rs index eed58de8d..6297f97f9 100644 --- a/crates/prover/src/core/air/accumulation.rs +++ b/crates/prover/src/core/air/accumulation.rs @@ -1,4 +1,5 @@ //! Accumulators for a random linear combination of circle polynomials. +//! //! Given N polynomials, u_0(P), ... u_{N-1}(P), and a random alpha, the combined polynomial is //! defined as //! f(p) = sum_i alpha^{N-1-i} u_i(P). @@ -13,7 +14,6 @@ use crate::core::fields::secure_column::SecureColumnByCoords; use crate::core::fields::FieldOps; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, SecureCirclePoly}; use crate::core::poly::BitReversedOrder; -use crate::core::utils::generate_secure_powers; /// Accumulates N evaluations of u_i(P0) at a single point. /// Computes f(P0), the combined polynomial at that point. 
@@ -39,7 +39,7 @@ impl PointEvaluationAccumulator { self.accumulation = self.accumulation * self.random_coeff + evaluation; } - pub fn finalize(self) -> SecureField { + pub const fn finalize(self) -> SecureField { self.accumulation } } @@ -63,7 +63,7 @@ impl DomainEvaluationAccumulator { pub fn new(random_coeff: SecureField, max_log_size: u32, total_columns: usize) -> Self { let max_log_size = max_log_size as usize; Self { - random_coeff_powers: generate_secure_powers(random_coeff, total_columns), + random_coeff_powers: B::generate_secure_powers(random_coeff, total_columns), sub_accumulations: (0..(max_log_size + 1)).map(|_| None).collect(), } } @@ -100,15 +100,7 @@ impl DomainEvaluationAccumulator { pub fn log_size(&self) -> u32 { (self.sub_accumulations.len() - 1) as u32 } -} - -pub trait AccumulationOps: FieldOps + Sized { - /// Accumulates other into column: - /// column = column + other. - fn accumulate(column: &mut SecureColumnByCoords, other: &SecureColumnByCoords); -} -impl DomainEvaluationAccumulator { /// Computes f(P) as coefficients. pub fn finalize(self) -> SecureCirclePoly { assert_eq!( @@ -157,12 +149,21 @@ impl DomainEvaluationAccumulator { } } +pub trait AccumulationOps: FieldOps + Sized { + /// Accumulates other into column: + /// column = column + other. + fn accumulate(column: &mut SecureColumnByCoords, other: &SecureColumnByCoords); + + /// Generates the first `n_powers` powers of `felt`. + fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec; +} + /// A domain accumulator for polynomials of a single size. pub struct ColumnAccumulator<'a, B: Backend> { pub random_coeff_powers: Vec, pub col: &'a mut SecureColumnByCoords, } -impl<'a> ColumnAccumulator<'a, CpuBackend> { +impl ColumnAccumulator<'_, CpuBackend> { pub fn accumulate(&mut self, index: usize, evaluation: SecureField) { let val = self.col.at(index) + evaluation; self.col.set(index, val); diff --git a/crates/prover/src/core/air/components.rs b/crates/prover/src/core/air/components.rs index 244bc0aac..4835a6918 100644 --- a/crates/prover/src/core/air/components.rs +++ b/crates/prover/src/core/air/components.rs @@ -17,7 +17,7 @@ pub struct Components<'a> { pub n_preprocessed_columns: usize, } -impl<'a> Components<'a> { +impl Components<'_> { pub fn composition_log_degree_bound(&self) -> u32 { self.components .iter() @@ -108,7 +108,7 @@ pub struct ComponentProvers<'a, B: Backend> { pub n_preprocessed_columns: usize, } -impl<'a, B: Backend> ComponentProvers<'a, B> { +impl ComponentProvers<'_, B> { pub fn components(&self) -> Components<'_> { Components { components: self diff --git a/crates/prover/src/core/air/mod.rs b/crates/prover/src/core/air/mod.rs index 671d05048..fbcf6c736 100644 --- a/crates/prover/src/core/air/mod.rs +++ b/crates/prover/src/core/air/mod.rs @@ -15,10 +15,10 @@ mod components; pub mod mask; /// Arithmetic Intermediate Representation (AIR). -/// An Air instance is assumed to already contain all the information needed to -/// evaluate the constraints. -/// For instance, all interaction elements are assumed to be present in it. -/// Therefore, an AIR is generated only after the initial trace commitment phase. +/// +/// An Air instance is assumed to already contain all the information needed to evaluate the +/// constraints. For instance, all interaction elements are assumed to be present in it. Therefore, +/// an AIR is generated only after the initial trace commitment phase. 
pub trait Air { fn components(&self) -> Vec<&dyn Component>; } diff --git a/crates/prover/src/core/backend/cpu/accumulation.rs b/crates/prover/src/core/backend/cpu/accumulation.rs index eac3def2e..394efdf81 100644 --- a/crates/prover/src/core/backend/cpu/accumulation.rs +++ b/crates/prover/src/core/backend/cpu/accumulation.rs @@ -1,8 +1,11 @@ use std::mem::{size_of_val, transmute}; use std::os::raw::c_void; -use super::CpuBackend; +use num_traits::One; + use crate::core::air::accumulation::AccumulationOps; +use crate::core::backend::cpu::CpuBackend; +use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; impl AccumulationOps for CpuBackend { @@ -12,4 +15,47 @@ impl AccumulationOps for CpuBackend { column.set(i, res_coeff); } } + + fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec { + (0..n_powers) + .scan(SecureField::one(), |acc, _| { + let res = *acc; + *acc *= felt; + Some(res) + }) + .collect() + } +} + +#[cfg(test)] +mod tests { + use num_traits::One; + + use crate::core::air::accumulation::AccumulationOps; + use crate::core::backend::CpuBackend; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::FieldExpOps; + use crate::qm31; + #[test] + fn generate_secure_powers_works() { + let felt = qm31!(1, 2, 3, 4); + let n_powers = 10; + + let powers = ::generate_secure_powers(felt, n_powers); + + assert_eq!(powers.len(), n_powers); + assert_eq!(powers[0], SecureField::one()); + assert_eq!(powers[1], felt); + assert_eq!(powers[7], felt.pow(7)); + } + + #[test] + fn generate_empty_secure_powers_works() { + let felt = qm31!(1, 2, 3, 4); + let max_log_size = 0; + + let powers = ::generate_secure_powers(felt, max_log_size); + + assert_eq!(powers, vec![]); + } } diff --git a/crates/prover/src/core/backend/cpu/circle.rs b/crates/prover/src/core/backend/cpu/circle.rs index 9383af219..f0ddb4a6b 100644 --- a/crates/prover/src/core/backend/cpu/circle.rs +++ b/crates/prover/src/core/backend/cpu/circle.rs @@ -3,6 +3,7 @@ use std::mem::transmute; use num_traits::Zero; use super::CpuBackend; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::{Col, ColumnOps}; use crate::core::circle::{CirclePoint, Coset}; use crate::core::fft::{butterfly, ibutterfly}; @@ -15,7 +16,7 @@ use crate::core::poly::circle::{ use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold}; use crate::core::poly::BitReversedOrder; -use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; +use crate::core::utils::coset_order_to_circle_domain_order; impl PolyOps for CpuBackend { type Twiddles = Vec; diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index ae9ab6b65..9c1d60093 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -265,7 +265,7 @@ enum MleExpr<'a, F: Field> { Mle(&'a Mle), } -impl<'a, F: Field> Index for MleExpr<'a, F> { +impl Index for MleExpr<'_, F> { type Output = F; fn index(&self, index: usize) -> &F { diff --git a/crates/prover/src/core/backend/cpu/mod.rs b/crates/prover/src/core/backend/cpu/mod.rs index 4090b1571..5e96518de 100644 --- a/crates/prover/src/core/backend/cpu/mod.rs +++ b/crates/prover/src/core/backend/cpu/mod.rs @@ -1,4 +1,4 @@ -mod accumulation; +pub mod accumulation; mod blake2s; pub mod circle; mod fri; @@ -16,7 +16,7 @@ use super::{Backend, BackendForChannel, Column, ColumnOps, FieldOps}; 
use crate::core::fields::Field; use crate::core::lookups::mle::Mle; use crate::core::poly::circle::{CircleEvaluation, CirclePoly}; -use crate::core::utils::bit_reverse; +use crate::core::utils::bit_reverse_index; use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; #[cfg(not(target_arch = "wasm32"))] use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel; @@ -29,6 +29,67 @@ impl BackendForChannel for CpuBackend {} #[cfg(not(target_arch = "wasm32"))] impl BackendForChannel for CpuBackend {} +/// Performs a naive bit-reversal permutation inplace. +/// +/// # Panics +/// +/// Panics if the length of the slice is not a power of two. +// TODO(alont): Move this to the cpu backend. +pub fn bit_reverse(v: &mut [T]) { + let n = v.len(); + assert!(n.is_power_of_two()); + #[cfg(not(feature = "icicle"))] + { + let log_n = n.ilog2(); + for i in 0..n { + let j = bit_reverse_index(i, log_n); + if j > i { + v.swap(i, j); + } + } + } + + #[cfg(feature = "icicle")] + unsafe { + let limbs_count: usize = size_of_val(&v[0]) / 4; + use std::slice; + + use icicle_core::traits::FieldImpl; + use icicle_core::vec_ops::{bit_reverse_inplace, BitReverseConfig, VecOps}; + use icicle_cuda_runtime::device::get_device_from_pointer; + use icicle_cuda_runtime::memory::{DeviceSlice, HostSlice}; + use icicle_m31::field::{ComplexExtensionField, QuarticExtensionField, ScalarField}; + + fn bit_rev_generic(v: &mut [T], n: usize) + where + F: FieldImpl, + ::Config: VecOps, + { + let cfg = BitReverseConfig::default(); + + // Check if v is a DeviceSlice or some other slice type + let mut v_ptr = v.as_mut_ptr() as *mut F; + let rr = unsafe { slice::from_raw_parts_mut(v_ptr, n) }; + + // means data already on device (some finite device id, instead of huge number for host + // pointer) + if get_device_from_pointer(v_ptr as _).unwrap() <= 1024 { + bit_reverse_inplace(unsafe { DeviceSlice::from_mut_slice(rr) }, &cfg).unwrap(); + } else { + bit_reverse_inplace(HostSlice::from_mut_slice(rr), &cfg).unwrap(); + } + } + + if limbs_count == 1 { + bit_rev_generic::(v, n); + } else if limbs_count == 2 { + bit_rev_generic::(v, n); + } else if limbs_count == 4 { + bit_rev_generic::(v, n); + } + } +} + impl ColumnOps for CpuBackend { type Column = Vec; @@ -79,10 +140,25 @@ mod tests { use rand::prelude::*; use rand::rngs::SmallRng; + use crate::core::backend::cpu::bit_reverse; use crate::core::backend::{Column, CpuBackend, FieldOps}; use crate::core::fields::qm31::QM31; use crate::core::fields::FieldExpOps; + #[test] + fn bit_reverse_works() { + let mut data = [0, 1, 2, 3, 4, 5, 6, 7]; + bit_reverse(&mut data); + assert_eq!(data, [0, 4, 2, 6, 1, 5, 3, 7]); + } + + #[test] + #[should_panic] + fn bit_reverse_non_power_of_two_size_fails() { + let mut data = [0, 1, 2, 3, 4, 5]; + bit_reverse(&mut data); + } + #[test] fn batch_inverse_test() { let mut rng = SmallRng::seed_from_u64(0); diff --git a/crates/prover/src/core/backend/cpu/quotients.rs b/crates/prover/src/core/backend/cpu/quotients.rs index f157b76ca..16f0647b6 100644 --- a/crates/prover/src/core/backend/cpu/quotients.rs +++ b/crates/prover/src/core/backend/cpu/quotients.rs @@ -73,10 +73,10 @@ pub fn accumulate_row_quotients( row_accumulator } -/// Precompute the complex conjugate line coefficients for each column in each sample batch. -/// Specifically, for the i-th (in a sample batch) column's numerator term -/// `alpha^i * (c * F(p) - (a * p.y + b))`, we precompute and return the constants: -/// (`alpha^i * a`, `alpha^i * b`, `alpha^i * c`). 
+/// Precomputes the complex conjugate line coefficients for each column in each sample batch. +/// +/// For the `i`-th (in a sample batch) column's numerator term `alpha^i * (c * F(p) - (a * p.y + +/// b))`, we precompute and return the constants: (`alpha^i * a`, `alpha^i * b`, `alpha^i * c`). pub fn column_line_coeffs( sample_batches: &[ColumnSampleBatch], random_coeff: SecureField, @@ -101,8 +101,9 @@ pub fn column_line_coeffs( .collect() } -/// Precompute the random coefficients used to linearly combine the batched quotients. -/// Specifically, for each sample batch we compute random_coeff^(number of columns in the batch), +/// Precomputes the random coefficients used to linearly combine the batched quotients. +/// +/// For each sample batch we compute random_coeff^(number of columns in the batch), /// which is used to linearly combine the batch with the next one. pub fn batch_random_coeffs( sample_batches: &[ColumnSampleBatch], diff --git a/crates/prover/src/core/backend/icicle/mod.rs b/crates/prover/src/core/backend/icicle/mod.rs index e988252e5..3ed265f04 100644 --- a/crates/prover/src/core/backend/icicle/mod.rs +++ b/crates/prover/src/core/backend/icicle/mod.rs @@ -162,6 +162,11 @@ impl AccumulationOps for IcicleBackend { nvtx::range_pop!(); } } + + fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec { + //todo!() + CpuBackend::generate_secure_powers(felt, n_powers) + } } // stwo/crates/prover/src/core/backend/cpu/blake2s.rs @@ -1089,10 +1094,7 @@ mod tests { let (values, decommitment) = merkle.decommit(&queries, cols.iter().collect_vec()); - let verifier = MerkleVerifier { - root: merkle.root(), - column_log_sizes: log_sizes, - }; + let verifier = MerkleVerifier::new(merkle.root(), log_sizes); verifier.verify(&queries, values, decommitment).unwrap(); } diff --git a/crates/prover/src/core/backend/simd/accumulation.rs b/crates/prover/src/core/backend/simd/accumulation.rs index bea476b56..09cb30a28 100644 --- a/crates/prover/src/core/backend/simd/accumulation.rs +++ b/crates/prover/src/core/backend/simd/accumulation.rs @@ -1,5 +1,11 @@ -use super::SimdBackend; +use itertools::Itertools; + use crate::core::air::accumulation::AccumulationOps; +use crate::core::backend::simd::m31::N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::CpuBackend; +use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; impl AccumulationOps for SimdBackend { @@ -11,4 +17,50 @@ impl AccumulationOps for SimdBackend { } nvtx::range_pop!(); } + + /// Generates the first `n_powers` powers of `felt` using SIMD. + /// Refer to `CpuBackend::generate_secure_powers` for the scalar CPU implementation. + fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec { + let base_arr = ::generate_secure_powers(felt, N_LANES) + .try_into() + .unwrap(); + let base = PackedSecureField::from_array(base_arr); + let step = PackedSecureField::broadcast(base_arr[N_LANES - 1] * felt); + let size = n_powers.div_ceil(N_LANES); + + // Collects the next N_LANES powers of `felt` in each iteration. 
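+        // E.g. with N_LANES = 16, `base` holds felt^0..=felt^15 and `step` broadcasts felt^16,
+        // so every scan step advances the packed window by N_LANES powers.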
+ (0..size) + .scan(base, |acc, _| { + let res = *acc; + *acc *= step; + Some(res) + }) + .flat_map(|x| x.to_array()) + .take(n_powers) + .collect_vec() + } +} + +#[cfg(test)] +mod tests { + use crate::core::air::accumulation::AccumulationOps; + use crate::core::backend::cpu::CpuBackend; + use crate::core::backend::simd::SimdBackend; + use crate::qm31; + + #[test] + fn test_generate_secure_powers_simd() { + let felt = qm31!(1, 2, 3, 4); + let n_powers_vec = [0, 16, 100]; + + n_powers_vec.iter().for_each(|&n_powers| { + let expected = ::generate_secure_powers(felt, n_powers); + let actual = ::generate_secure_powers(felt, n_powers); + assert_eq!( + expected, actual, + "Error generating secure powers in n_powers = {}.", + n_powers + ); + }); + } } diff --git a/crates/prover/src/core/backend/simd/bit_reverse.rs b/crates/prover/src/core/backend/simd/bit_reverse.rs index 3efb88979..67ec28f35 100644 --- a/crates/prover/src/core/backend/simd/bit_reverse.rs +++ b/crates/prover/src/core/backend/simd/bit_reverse.rs @@ -6,11 +6,12 @@ use rayon::prelude::*; use super::column::{BaseColumn, SecureColumn}; use super::m31::PackedBaseField; use super::SimdBackend; +use crate::core::backend::cpu::bit_reverse as cpu_bit_reverse; use crate::core::backend::simd::utils::UnsafeMut; use crate::core::backend::ColumnOps; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::utils::{bit_reverse as cpu_bit_reverse, bit_reverse_index}; +use crate::core::utils::bit_reverse_index; use crate::parallel_iter; const VEC_BITS: u32 = 4; @@ -170,12 +171,12 @@ mod tests { use itertools::Itertools; use super::{bit_reverse16, bit_reverse_m31, MIN_LOG_SIZE}; + use crate::core::backend::cpu::bit_reverse as cpu_bit_reverse; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::simd::m31::{PackedM31, N_LANES}; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Column, ColumnOps}; use crate::core::fields::m31::BaseField; - use crate::core::utils::bit_reverse as cpu_bit_reverse; #[test] fn test_bit_reverse16() { @@ -185,7 +186,7 @@ mod tests { let res = bit_reverse16(values.data.try_into().unwrap()); - assert_eq!(res.map(PackedM31::to_array).flatten(), expected); + assert_eq!(res.map(PackedM31::to_array).as_flattened(), expected); } #[test] diff --git a/crates/prover/src/core/backend/simd/blake2s.rs b/crates/prover/src/core/backend/simd/blake2s.rs index 590326068..b499efa8f 100644 --- a/crates/prover/src/core/backend/simd/blake2s.rs +++ b/crates/prover/src/core/backend/simd/blake2s.rs @@ -366,8 +366,12 @@ mod tests { let res_vectorized: [[u32; 8]; 16] = unsafe { transmute(untranspose_states(compress16( - transpose_states(transmute(states)), - transpose_msgs(transmute(msgs)), + transpose_states(transmute::, [u32x16; 8]>( + states, + )), + transpose_msgs(transmute::, [u32x16; 16]>( + msgs, + )), u32x16::splat(count_low), u32x16::splat(count_high), u32x16::splat(lastblock), diff --git a/crates/prover/src/core/backend/simd/circle.rs b/crates/prover/src/core/backend/simd/circle.rs index 5963c5e8d..ab772cd63 100644 --- a/crates/prover/src/core/backend/simd/circle.rs +++ b/crates/prover/src/core/backend/simd/circle.rs @@ -89,10 +89,7 @@ impl SimdBackend { // Generates twiddle steps for efficiently computing the twiddles. // steps[i] = t_i/(t_0*t_1*...*t_i-1). 
- fn twiddle_steps(mappings: &[F]) -> Vec - where - F: FieldExpOps, - { + fn twiddle_steps(mappings: &[F]) -> Vec { let mut denominators: Vec = vec![mappings[0]]; for i in 1..mappings.len() { @@ -162,7 +159,7 @@ impl PolyOps for SimdBackend { nvtx::range_push!("[SIMD] ifft"); unsafe { ifft::ifft( - transmute(values.data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(values.data.as_mut_ptr()), &twiddles, log_size as usize, ); @@ -293,8 +290,8 @@ impl PolyOps for SimdBackend { nvtx::range_push!("[SIMD] fft"); unsafe { rfft::fft( - transmute(poly.coeffs.data.as_ptr()), - transmute( + transmute::<*const PackedBaseField, *const u32>(poly.coeffs.data.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>( values[i << (fft_log_size - LOG_N_LANES) ..(i + 1) << (fft_log_size - LOG_N_LANES)] .as_mut_ptr(), diff --git a/crates/prover/src/core/backend/simd/cm31.rs b/crates/prover/src/core/backend/simd/cm31.rs index 31aba0a44..2155e8ff1 100644 --- a/crates/prover/src/core/backend/simd/cm31.rs +++ b/crates/prover/src/core/backend/simd/cm31.rs @@ -19,12 +19,12 @@ impl PackedCM31 { } /// Returns all `a` values such that each vector element is represented as `a + bi`. - pub fn a(&self) -> PackedM31 { + pub const fn a(&self) -> PackedM31 { self.0[0] } /// Returns all `b` values such that each vector element is represented as `a + bi`. - pub fn b(&self) -> PackedM31 { + pub const fn b(&self) -> PackedM31 { self.0[1] } diff --git a/crates/prover/src/core/backend/simd/column.rs b/crates/prover/src/core/backend/simd/column.rs index 3659a96cf..d63c55f33 100644 --- a/crates/prover/src/core/backend/simd/column.rs +++ b/crates/prover/src/core/backend/simd/column.rs @@ -64,6 +64,13 @@ impl BaseColumn { values.into_iter().collect() } + pub fn from_simd(values: Vec) -> Self { + Self { + length: values.len() * N_LANES, + data: values, + } + } + /// Returns a vector of `BaseColumnMutSlice`s, each mutably owning /// `chunk_size` `PackedBaseField`s (i.e, `chuck_size` * `N_LANES` elements). pub fn chunks_mut(&mut self, chunk_size: usize) -> Vec> { @@ -207,7 +214,7 @@ impl FromIterator for CM31Column { /// A mutable slice of a BaseColumn. pub struct BaseColumnMutSlice<'a>(pub &'a mut [PackedBaseField]); -impl<'a> BaseColumnMutSlice<'a> { +impl BaseColumnMutSlice<'_> { pub fn at(&self, index: usize) -> BaseField { self.0[index / N_LANES].to_array()[index % N_LANES] } @@ -323,7 +330,7 @@ impl FromIterator for SecureColumn { /// A mutable slice of a SecureColumnByCoords. pub struct SecureColumnByCoordsMutSlice<'a>(pub [BaseColumnMutSlice<'a>; SECURE_EXTENSION_DEGREE]); -impl<'a> SecureColumnByCoordsMutSlice<'a> { +impl SecureColumnByCoordsMutSlice<'_> { /// # Safety /// /// `vec_index` must be a valid index. @@ -357,7 +364,7 @@ pub struct VeryPackedSecureColumnByCoordsMutSlice<'a>( pub [VeryPackedBaseColumnMutSlice<'a>; SECURE_EXTENSION_DEGREE], ); -impl<'a> VeryPackedSecureColumnByCoordsMutSlice<'a> { +impl VeryPackedSecureColumnByCoordsMutSlice<'_> { /// # Safety /// /// `vec_index` must be a valid index. @@ -463,7 +470,7 @@ impl VeryPackedBaseColumn { /// # Safety /// /// The resulting pointer does not update the underlying `data`'s length. 
- pub unsafe fn transform_under_ref(value: &BaseColumn) -> &Self { + pub const unsafe fn transform_under_ref(value: &BaseColumn) -> &Self { &*(std::ptr::addr_of!(*value) as *const VeryPackedBaseColumn) } diff --git a/crates/prover/src/core/backend/simd/domain.rs b/crates/prover/src/core/backend/simd/domain.rs index 209314175..d27cf2396 100644 --- a/crates/prover/src/core/backend/simd/domain.rs +++ b/crates/prover/src/core/backend/simd/domain.rs @@ -73,7 +73,7 @@ fn test_circle_domain_bit_rev_iterator() { 5, )); let mut expected = domain.iter().collect::>(); - crate::core::utils::bit_reverse(&mut expected); + crate::core::backend::cpu::bit_reverse(&mut expected); let actual = CircleDomainBitRevIterator::new(domain) .flat_map(|c| -> [_; 16] { std::array::from_fn(|i| CirclePoint { diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs index 77b096d9c..b41cc70de 100644 --- a/crates/prover/src/core/backend/simd/fft/ifft.rs +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -9,11 +9,10 @@ use rayon::prelude::*; use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, }; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::utils::UnsafeMut; use crate::core::circle::Coset; -use crate::core::fields::FieldExpOps; -use crate::core::utils::bit_reverse; use crate::parallel_iter; /// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values. @@ -598,7 +597,7 @@ mod tests { let mut res = values; unsafe { ifft3( - transmute(res.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.as_mut_ptr()), 0, LOG_N_LANES as usize, twiddles0_dbl, @@ -664,7 +663,7 @@ mod tests { [val0.to_array(), val1.to_array()].concat() }; - assert_eq!(res, ground_truth_ifft(domain, values.flatten())); + assert_eq!(res, ground_truth_ifft(domain, values.as_flattened())); } #[test] @@ -678,7 +677,7 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { ifft_lower_with_vecwise( - transmute(res.data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, log_size as usize, @@ -700,11 +699,14 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { ifft( - transmute(res.data.as_mut_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, ); - transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4); + transpose_vecs( + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), + log_size as usize - 4, + ); } assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values)); diff --git a/crates/prover/src/core/backend/simd/fft/mod.rs b/crates/prover/src/core/backend/simd/fft/mod.rs index ba091b145..78624d9e0 100644 --- a/crates/prover/src/core/backend/simd/fft/mod.rs +++ b/crates/prover/src/core/backend/simd/fft/mod.rs @@ -97,12 +97,12 @@ pub fn compute_first_twiddles(twiddle1_dbl: u32x8) -> (u32x16, u32x16) { } #[inline] -unsafe fn load(mem_addr: *const u32) -> u32x16 { +const unsafe fn load(mem_addr: *const u32) -> u32x16 { std::ptr::read(mem_addr as *const u32x16) } #[inline] -unsafe fn store(mem_addr: *mut u32, a: u32x16) { +const unsafe fn store(mem_addr: *mut u32, a: u32x16) { std::ptr::write(mem_addr as *mut u32x16, a); } @@ -111,19 +111,19 @@ fn 
mul_twiddle(v: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { // TODO: Come up with a better approach than `cfg`ing on target_feature. // TODO: Ensure all these branches get tested in the CI. cfg_if::cfg_if! { - if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { + if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { // TODO: For architectures that when multiplying require doubling then the twiddles // should be precomputed as double. For other architectures, the twiddle should be // precomputed without doubling. - crate::core::backend::simd::m31::_mul_doubled_neon(v, twiddle_dbl) - } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { - crate::core::backend::simd::m31::_mul_doubled_wasm(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_neon(v, twiddle_dbl) + } else if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] { + crate::core::backend::simd::m31::mul_doubled_wasm(v, twiddle_dbl) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { - crate::core::backend::simd::m31::_mul_doubled_avx512(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_avx512(v, twiddle_dbl) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { - crate::core::backend::simd::m31::_mul_doubled_avx2(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_avx2(v, twiddle_dbl) } else { - crate::core::backend::simd::m31::_mul_doubled_simd(v, twiddle_dbl) + crate::core::backend::simd::m31::mul_doubled_simd(v, twiddle_dbl) } } } diff --git a/crates/prover/src/core/backend/simd/fft/rfft.rs b/crates/prover/src/core/backend/simd/fft/rfft.rs index d28c8a00d..4500f64ef 100644 --- a/crates/prover/src/core/backend/simd/fft/rfft.rs +++ b/crates/prover/src/core/backend/simd/fft/rfft.rs @@ -10,10 +10,10 @@ use rayon::prelude::*; use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, }; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::utils::{UnsafeConst, UnsafeMut}; use crate::core::circle::Coset; -use crate::core::utils::bit_reverse; use crate::parallel_iter; /// Performs a Circle Fast Fourier Transform (CFFT) on the given values. 
@@ -624,8 +624,8 @@ mod tests { let mut res = values; unsafe { fft3( - transmute(res.as_ptr()), - transmute(res.as_mut_ptr()), + transmute::<*const PackedBaseField, *const u32>(res.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.as_mut_ptr()), 0, LOG_N_LANES as usize, twiddles0_dbl, @@ -695,7 +695,7 @@ mod tests { [val0.to_array(), val1.to_array()].concat() }; - assert_eq!(res, ground_truth_fft(domain, values.flatten())); + assert_eq!(res, ground_truth_fft(domain, values.as_flattened())); } #[test] @@ -709,8 +709,8 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { fft_lower_with_vecwise( - transmute(res.data.as_ptr()), - transmute(res.data.as_mut_ptr()), + transmute::<*const PackedBaseField, *const u32>(res.data.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, log_size as usize, @@ -731,10 +731,13 @@ mod tests { let mut res = values.iter().copied().collect::(); unsafe { - transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4); + transpose_vecs( + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), + log_size as usize - 4, + ); fft( - transmute(res.data.as_ptr()), - transmute(res.data.as_mut_ptr()), + transmute::<*const PackedBaseField, *const u32>(res.data.as_ptr()), + transmute::<*mut PackedBaseField, *mut u32>(res.data.as_mut_ptr()), &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, ); diff --git a/crates/prover/src/core/backend/simd/fri.rs b/crates/prover/src/core/backend/simd/fri.rs index 8aba27e4b..958b13372 100644 --- a/crates/prover/src/core/backend/simd/fri.rs +++ b/crates/prover/src/core/backend/simd/fri.rs @@ -1,5 +1,5 @@ use std::array; -use std::simd::u32x8; +use std::simd::{u32x16, u32x8}; use num_traits::Zero; @@ -40,17 +40,18 @@ impl FriOps for SimdBackend { let mut folded_values = SecureColumnByCoords::::zeros(1 << (log_size - 1)); for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) { - let value = unsafe { - let twiddle_dbl: [u32; 16] = - array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i)); - let val0 = eval.values.packed_at(vec_index * 2).into_packed_m31s(); - let val1 = eval.values.packed_at(vec_index * 2 + 1).into_packed_m31s(); + let value = { + let twiddle_dbl = u32x16::from_array(array::from_fn(|i| unsafe { + *itwiddles.get_unchecked(vec_index * 16 + i) + })); + let val0 = unsafe { eval.values.packed_at(vec_index * 2) }.into_packed_m31s(); + let val1 = unsafe { eval.values.packed_at(vec_index * 2 + 1) }.into_packed_m31s(); let pairs: [_; 4] = array::from_fn(|i| { nvtx::range_push!("[SIMD] deinterleave"); let (a, b) = val0[i].deinterleave(val1[i]); nvtx::range_pop!(); nvtx::range_push!("[SIMD] simd_ibutterfly"); - let butterfly = simd_ibutterfly(a, b, std::mem::transmute(twiddle_dbl)); + let butterfly = simd_ibutterfly(a, b, unsafe { std::mem::transmute(twiddle_dbl)}); nvtx::range_pop!(); butterfly diff --git a/crates/prover/src/core/backend/simd/lookups/gkr.rs b/crates/prover/src/core/backend/simd/lookups/gkr.rs index 017948dee..74d7f7c43 100644 --- a/crates/prover/src/core/backend/simd/lookups/gkr.rs +++ b/crates/prover/src/core/backend/simd/lookups/gkr.rs @@ -25,7 +25,7 @@ impl GkrOps for SimdBackend { } // Start DP with CPU backend to avoid dealing with instances smaller than a SIMD vector. 
- let (y_last_chunk, y_rem) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap(); + let (y_rem, y_last_chunk) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap(); let initial = SecureColumn::from_iter(cpu_gen_eq_evals(y_last_chunk, v)); assert_eq!(initial.len(), N_LANES); diff --git a/crates/prover/src/core/backend/simd/lookups/mle.rs b/crates/prover/src/core/backend/simd/lookups/mle.rs index 0e2fe73f7..07f175bbc 100644 --- a/crates/prover/src/core/backend/simd/lookups/mle.rs +++ b/crates/prover/src/core/backend/simd/lookups/mle.rs @@ -30,9 +30,8 @@ impl MleOps for SimdBackend { let (evals_at_0x, evals_at_1x) = mle.data.split_at(packed_midpoint); let res = zip(evals_at_0x, evals_at_1x) - .enumerate() // MLE at points `({0, 1}, rev(bits(i)), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - .map(|(_i, (&packed_eval_at_0iv, &packed_eval_at_1iv))| { + .map(|(&packed_eval_at_0iv, &packed_eval_at_1iv)| { fold_packed_mle_evals(packed_assignment, packed_eval_at_0iv, packed_eval_at_1iv) }) .collect(); diff --git a/crates/prover/src/core/backend/simd/m31.rs b/crates/prover/src/core/backend/simd/m31.rs index f6291626b..3d10be8c0 100644 --- a/crates/prover/src/core/backend/simd/m31.rs +++ b/crates/prover/src/core/backend/simd/m31.rs @@ -3,14 +3,13 @@ use std::mem::transmute; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::ptr; use std::simd::cmp::SimdOrd; -use std::simd::{u32x16, Simd, Swizzle}; +use std::simd::{u32x16, Simd}; use bytemuck::{Pod, Zeroable}; use num_traits::{One, Zero}; use rand::distributions::{Distribution, Standard}; use super::qm31::PackedQM31; -use crate::core::backend::simd::utils::{InterleaveEvens, InterleaveOdds}; use crate::core::fields::m31::{pow2147483645, BaseField, M31, P}; use crate::core::fields::qm31::QM31; use crate::core::fields::FieldExpOps; @@ -78,14 +77,14 @@ impl PackedM31 { self + self } - pub fn into_simd(self) -> Simd { + pub const fn into_simd(self) -> Simd { self.0 } /// # Safety /// /// Vector elements must be in the range `[0, P]`. - pub unsafe fn from_simd_unchecked(v: Simd) -> Self { + pub const unsafe fn from_simd_unchecked(v: Simd) -> Self { Self(v) } @@ -93,7 +92,7 @@ impl PackedM31 { /// /// Behavior is undefined if the pointer does not have the same alignment as /// [`PackedM31`]. The loaded `u32` values must be in the range `[0, P]`. - pub unsafe fn load(mem_addr: *const u32) -> Self { + pub const unsafe fn load(mem_addr: *const u32) -> Self { Self(ptr::read(mem_addr as *const u32x16)) } @@ -101,7 +100,7 @@ impl PackedM31 { /// /// Behavior is undefined if the pointer does not have the same alignment as /// [`PackedM31`]. - pub unsafe fn store(self, dst: *mut u32) { + pub const unsafe fn store(self, dst: *mut u32) { ptr::write(dst as *mut u32x16, self.0) } } @@ -142,16 +141,16 @@ impl Mul for PackedM31 { // TODO: Come up with a better approach than `cfg`ing on target_feature. // TODO: Ensure all these branches get tested in the CI. cfg_if::cfg_if! 
{ - if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { - _mul_neon(self, rhs) - } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { - _mul_wasm(self, rhs) + if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { + mul_neon(self, rhs) + } else if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] { + mul_wasm(self, rhs) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { - _mul_avx512(self, rhs) + mul_avx512(self, rhs) } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { - _mul_avx2(self, rhs) + mul_avx2(self, rhs) } else { - _mul_simd(self, rhs) + mul_simd(self, rhs) } } } @@ -286,290 +285,299 @@ impl Sum for PackedM31 { } } -/// Returns `a * b`. -#[cfg(target_arch = "aarch64")] -pub(crate) fn _mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 { - use core::arch::aarch64::{int32x2_t, vqdmull_s32}; - use std::simd::u32x4; - - let [a0, a1, a2, a3, a4, a5, a6, a7]: [int32x2_t; 8] = unsafe { transmute(a) }; - let [b0, b1, b2, b3, b4, b5, b6, b7]: [int32x2_t; 8] = unsafe { transmute(b) }; - - // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| - let c0: u32x4 = unsafe { transmute(vqdmull_s32(a0, b0)) }; - let c1: u32x4 = unsafe { transmute(vqdmull_s32(a1, b1)) }; - let c2: u32x4 = unsafe { transmute(vqdmull_s32(a2, b2)) }; - let c3: u32x4 = unsafe { transmute(vqdmull_s32(a3, b3)) }; - let c4: u32x4 = unsafe { transmute(vqdmull_s32(a4, b4)) }; - let c5: u32x4 = unsafe { transmute(vqdmull_s32(a5, b5)) }; - let c6: u32x4 = unsafe { transmute(vqdmull_s32(a6, b6)) }; - let c7: u32x4 = unsafe { transmute(vqdmull_s32(a7, b7)) }; - - // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. - // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. - let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); - let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); - let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); - let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); - - // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. - c0_c1_lo >>= 1; - c2_c3_lo >>= 1; - c4_c5_lo >>= 1; - c6_c7_lo >>= 1; - - let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; - let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; - - lo + hi -} +cfg_if::cfg_if! { + if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { + use core::arch::aarch64::{uint32x2_t, vmull_u32, int32x2_t, vqdmull_s32}; + use std::simd::u32x4; + + /// Returns `a * b`. + pub(crate) fn mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 { + let [a0, a1, a2, a3, a4, a5, a6, a7]: [int32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [int32x2_t; 8] = unsafe { transmute(b) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vqdmull_s32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vqdmull_s32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vqdmull_s32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vqdmull_s32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vqdmull_s32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vqdmull_s32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vqdmull_s32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vqdmull_s32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. 
+ let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. + c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi + } -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "aarch64")] -pub(crate) fn _mul_doubled_neon(a: PackedM31, b_double: u32x16) -> PackedM31 { - use core::arch::aarch64::{uint32x2_t, vmull_u32}; - use std::simd::u32x4; - - let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) }; - let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) }; - - // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| - let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) }; - let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) }; - let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) }; - let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) }; - let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) }; - let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) }; - let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) }; - let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) }; - - // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. - // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. - let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); - let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); - let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); - let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); - - // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. - c0_c1_lo >>= 1; - c2_c3_lo >>= 1; - c4_c5_lo >>= 1; - c6_c7_lo >>= 1; - - let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; - let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; - - lo + hi -} + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_neon(a: PackedM31, b_double: u32x16) -> PackedM31 { + let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. + let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. 
+ c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi + } + } else if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] { + use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128}; + use std::simd::u32x4; -/// Returns `a * b`. -#[cfg(target_arch = "wasm32")] -pub(crate) fn _mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_wasm(a, b.0 + b.0) -} + /// Returns `a * b`. + pub(crate) fn mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_wasm(a, b.0 + b.0) + } -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "wasm32")] -pub(crate) fn _mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 { - use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128}; - use std::simd::u32x4; - - let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) }; - let [b_double0, b_double1, b_double2, b_double3]: [v128; 4] = unsafe { transmute(b_double) }; - - let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_double0)) }; - let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_double0)) }; - let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_double1)) }; - let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_double1)) }; - let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_double2)) }; - let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_double2)) }; - let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_double3)) }; - let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_double3)) }; - - let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi); - let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi); - let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi); - let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi); - c0_even >>= 1; - c1_even >>= 1; - c2_even >>= 1; - c3_even >>= 1; - - let even: PackedM31 = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) }; - let odd: PackedM31 = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) }; - - even + odd -} + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. 
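+ /// With `b_double = 2 * b`, every widening product below equals `2 * a * b`: its low
+ /// 32-bit word is `2 * prod_lo` and its high word is `prod_hi`, so one right shift plus a
+ /// field addition (using `2^31 = 1 (mod P)`) yields the reduced product.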
+ pub(crate) fn mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 { + let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) }; + let [b_double0, b_double1, b_double2, b_double3]: [v128; 4] = unsafe { transmute(b_double) }; + + let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_double0)) }; + let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_double0)) }; + let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_double1)) }; + let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_double1)) }; + let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_double2)) }; + let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_double2)) }; + let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_double3)) }; + let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_double3)) }; + + let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi); + let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi); + let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi); + let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi); + c0_even >>= 1; + c1_even >>= 1; + c2_even >>= 1; + c3_even >>= 1; + + let even: PackedM31 = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) }; + let odd: PackedM31 = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) }; + + even + odd + } + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { + use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64}; + use std::simd::Swizzle; -/// Returns `a * b`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_avx512(a, b.0 + b.0) -} + use crate::core::backend::simd::utils::swizzle::{InterleaveEvens, InterleaveOdds}; -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 { - use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64}; - - let a: __m512i = unsafe { transmute(a) }; - let b_double: __m512i = unsafe { transmute(b_double) }; - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a_e = a; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a_o = unsafe { _mm512_srli_epi64(a, 32) }; - - let b_dbl_e = b_double; - let b_dbl_o = unsafe { _mm512_srli_epi64(b_double, 32) }; - - // To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd. - let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) }; - let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) }; - - // The result of a multiplication holds a*b in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_dbl_e - |0|prod_e_h|prod_e_l|0| - // prod_dbl_o - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: - let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| - // Divide by 2: - prod_lo >>= 1; - // prod_lo - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: - let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_hi - |0|prod_o_h|0|prod_e_h| - - PackedM31(prod_lo) + PackedM31(prod_hi) -} + /// Returns `a * b`. 
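+ /// Doubles `b` into the `[0, 2P]` range and delegates to [`mul_doubled_avx512`].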
+ pub(crate) fn mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_avx512(a, b.0 + b.0) + } -/// Returns `a * b`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_avx2(a, b.0 + b.0) -} + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 { + let a: __m512i = unsafe { transmute(a) }; + let b_double: __m512i = unsafe { transmute(b_double) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = a; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { _mm512_srli_epi64(a, 32) }; + + let b_dbl_e = b_double; + let b_dbl_o = unsafe { _mm512_srli_epi64(b_double, 32) }; + + // To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd. + let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) }; + let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) }; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + PackedM31(prod_lo) + PackedM31(prod_hi) + } + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { + use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64}; + use std::simd::Swizzle; -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 { - use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64}; - - let [a0, a1]: [__m256i; 2] = unsafe { transmute(a) }; - let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute(b_double) }; - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a0_e = a0; - let a1_e = a1; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a0_o = unsafe { _mm256_srli_epi64(a0, 32) }; - let a1_o = unsafe { _mm256_srli_epi64(a1, 32) }; - - let b0_dbl_e = b0_dbl; - let b1_dbl_e = b1_dbl; - let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) }; - let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) }; - - // To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd. - let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) }; - let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) }; - let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) }; - let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) }; - - let prod_dbl_e: u32x16 = unsafe { transmute([prod0_dbl_e, prod1_dbl_e]) }; - let prod_dbl_o: u32x16 = unsafe { transmute([prod0_dbl_o, prod1_dbl_o]) }; - - // The result of a multiplication holds a*b in as 64-bits. 
- // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_dbl_e - |0|prod_e_h|prod_e_l|0| - // prod_dbl_o - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: - let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| - // Divide by 2: - prod_lo >>= 1; - // prod_lo - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: - let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_hi - |0|prod_o_h|0|prod_e_h| - - PackedM31(prod_lo) + PackedM31(prod_hi) -} + use crate::core::backend::simd::utils::swizzle::{InterleaveEvens, InterleaveOdds}; -/// Returns `a * b`. -/// -/// Should only be used in the absence of a platform specific implementation. -pub(crate) fn _mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_simd(a, b.0 + b.0) -} + /// Returns `a * b`. + pub(crate) fn mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_avx2(a, b.0 + b.0) + } -/// Returns `a * b`. -/// -/// Should only be used in the absence of a platform specific implementation. -/// -/// `b_double` should be in the range `[0, 2P]`. -pub(crate) fn _mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 { - const MASK_EVENS: Simd = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]); - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a_e = unsafe { transmute::<_, Simd>(a.0) & MASK_EVENS }; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a_o = unsafe { transmute::<_, Simd>(a) >> 32 }; - - let b_dbl_e = unsafe { transmute::<_, Simd>(b_double) & MASK_EVENS }; - let b_dbl_o = unsafe { transmute::<_, Simd>(b_double) >> 32 }; - - // To compute prod = a * b start by multiplying - // a_e/o by b_dbl_e/o. - let prod_e_dbl = a_e * b_dbl_e; - let prod_o_dbl = a_o * b_dbl_o; - - // The result of a multiplication holds a*b in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_e_dbl - |0|prod_e_h|prod_e_l|0| - // prod_o_dbl - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: - // let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, - // prod_o_dbl); - // prod_ls - |prod_o_l|0|prod_e_l|0| - let mut prod_lows = InterleaveEvens::concat_swizzle( - unsafe { transmute::<_, Simd>(prod_e_dbl) }, - unsafe { transmute::<_, Simd>(prod_o_dbl) }, - ); - // Divide by 2: - prod_lows >>= 1; - // prod_ls - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: - let prod_highs = InterleaveOdds::concat_swizzle( - unsafe { transmute::<_, Simd>(prod_e_dbl) }, - unsafe { transmute::<_, Simd>(prod_o_dbl) }, - ); - - // prod_hs - |0|prod_o_h|0|prod_e_h| - PackedM31(prod_lows) + PackedM31(prod_highs) + /// Returns `a * b`. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 { + let [a0, a1]: [__m256i; 2] = unsafe { transmute::(a) }; + let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute::(b_double) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a0_e = a0; + let a1_e = a1; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. 
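+ // `_mm256_mul_epu32` only reads the low (even) 32-bit lane of each 64-bit element, so the
+ // odd lanes are shifted down into that position first.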
+ let a0_o = unsafe { _mm256_srli_epi64(a0, 32) }; + let a1_o = unsafe { _mm256_srli_epi64(a1, 32) }; + + let b0_dbl_e = b0_dbl; + let b1_dbl_e = b1_dbl; + let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) }; + let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) }; + + // To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd. + let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) }; + let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) }; + let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) }; + let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) }; + + let prod_dbl_e: u32x16 = + unsafe { transmute::<[__m256i; 2], u32x16>([prod0_dbl_e, prod1_dbl_e]) }; + let prod_dbl_o: u32x16 = + unsafe { transmute::<[__m256i; 2], u32x16>([prod0_dbl_o, prod1_dbl_o]) }; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + PackedM31(prod_lo) + PackedM31(prod_hi) + } + } else { + use std::simd::Swizzle; + + use crate::core::backend::simd::utils::swizzle::{InterleaveEvens, InterleaveOdds}; + + /// Returns `a * b`. + /// + /// Should only be used in the absence of a platform specific implementation. + pub(crate) fn mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 { + mul_doubled_simd(a, b.0 + b.0) + } + + /// Returns `a * b`. + /// + /// Should only be used in the absence of a platform specific implementation. + /// + /// `b_double` should be in the range `[0, 2P]`. + pub(crate) fn mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 { + const MASK_EVENS: Simd = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]); + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = + unsafe { transmute::, Simd>(a.0) & MASK_EVENS }; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { transmute::>(a) >> 32 }; + + let b_dbl_e = unsafe { + transmute::, Simd>(b_double) & MASK_EVENS + }; + let b_dbl_o = + unsafe { transmute::, Simd>(b_double) >> 32 }; + + // To compute prod = a * b start by multiplying + // a_e/o by b_dbl_e/o. + let prod_e_dbl = a_e * b_dbl_e; + let prod_o_dbl = a_o * b_dbl_o; + + // The result of a multiplication holds a*b in as 64-bits. 
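+ // Since `2^31 = 1 (mod P)`, the product `prod_hi * 2^31 + prod_lo` reduces to
+ // `prod_hi + prod_lo`; e.g. for `a = b = 1 << 16`, `a * b = 1 << 32` gives `prod_hi = 2`,
+ // `prod_lo = 0`, and indeed `2^32 mod P = 2`. The interleaves below extract the two halves.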
+ // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_e_dbl - |0|prod_e_h|prod_e_l|0| + // prod_o_dbl - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: + // let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, + // prod_o_dbl); + // prod_ls - |prod_o_l|0|prod_e_l|0| + let mut prod_lows = InterleaveEvens::concat_swizzle( + unsafe { transmute::, Simd>(prod_e_dbl) }, + unsafe { transmute::, Simd>(prod_o_dbl) }, + ); + // Divide by 2: + prod_lows >>= 1; + // prod_ls - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: + let prod_highs = InterleaveOdds::concat_swizzle( + unsafe { transmute::, Simd>(prod_e_dbl) }, + unsafe { transmute::, Simd>(prod_o_dbl) }, + ); + + // prod_hs - |0|prod_o_h|0|prod_e_h| + PackedM31(prod_lows) + PackedM31(prod_highs) + } + } } #[cfg(test)] diff --git a/crates/prover/src/core/backend/simd/prefix_sum.rs b/crates/prover/src/core/backend/simd/prefix_sum.rs index 652b484a1..8e7f07cdf 100644 --- a/crates/prover/src/core/backend/simd/prefix_sum.rs +++ b/crates/prover/src/core/backend/simd/prefix_sum.rs @@ -4,13 +4,12 @@ use std::ops::{AddAssign, Sub}; use itertools::{izip, Itertools}; use num_traits::Zero; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::simd::m31::{PackedBaseField, N_LANES}; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Col, Column}; use crate::core::fields::m31::BaseField; -use crate::core::utils::{ - bit_reverse, circle_domain_order_to_coset_order, coset_order_to_circle_domain_order, -}; +use crate::core::utils::{circle_domain_order_to_coset_order, coset_order_to_circle_domain_order}; /// Performs a inclusive prefix sum on values in `Coset` order when provided /// with evaluations in bit-reversed `CircleDomain` order. diff --git a/crates/prover/src/core/backend/simd/qm31.rs b/crates/prover/src/core/backend/simd/qm31.rs index 078f6ef56..ce7231d0a 100644 --- a/crates/prover/src/core/backend/simd/qm31.rs +++ b/crates/prover/src/core/backend/simd/qm31.rs @@ -28,12 +28,12 @@ impl PackedQM31 { } /// Returns all `a` values such that each vector element is represented as `a + bu`. - pub fn a(&self) -> PackedCM31 { + pub const fn a(&self) -> PackedCM31 { self.0[0] } /// Returns all `b` values such that each vector element is represented as `a + bu`. - pub fn b(&self) -> PackedCM31 { + pub const fn b(&self) -> PackedCM31 { self.0[1] } @@ -80,14 +80,14 @@ impl PackedQM31 { /// Returns vectors `a, b, c, d` such that element `i` is represented as /// `QM31(a_i, b_i, c_i, d_i)`. - pub fn into_packed_m31s(self) -> [PackedM31; 4] { + pub const fn into_packed_m31s(self) -> [PackedM31; 4] { let Self([PackedCM31([a, b]), PackedCM31([c, d])]) = self; [a, b, c, d] } /// Creates an instance from vectors `a, b, c, d` such that element `i` /// is represented as `QM31(a_i, b_i, c_i, d_i)`. 
- pub fn from_packed_m31s([a, b, c, d]: [PackedM31; 4]) -> Self { + pub const fn from_packed_m31s([a, b, c, d]: [PackedM31; 4]) -> Self { Self([PackedCM31([a, b]), PackedCM31([c, d])]) } } diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index 2620b5b17..68a89f299 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -1,5 +1,7 @@ use itertools::{izip, zip_eq, Itertools}; use num_traits::Zero; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use tracing::{span, Level}; use super::cm31::PackedCM31; @@ -8,6 +10,7 @@ use super::domain::CircleDomainBitRevIterator; use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; use super::qm31::PackedSecureField; use super::SimdBackend; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::cpu::quotients::{batch_random_coeffs, column_line_coeffs}; use crate::core::backend::{Column, CpuBackend}; use crate::core::fields::m31::BaseField; @@ -17,7 +20,6 @@ use crate::core::fields::FieldExpOps; use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; use crate::core::poly::circle::{CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation}; use crate::core::poly::BitReversedOrder; -use crate::core::utils::bit_reverse; pub struct QuotientConstants { pub line_coeffs: Vec>, @@ -122,10 +124,17 @@ fn accumulate_quotients_on_subdomain( let quotient_constants = quotient_constants(sample_batches, random_coeff, subdomain); let span = span!(Level::INFO, "Quotient accumulation").entered(); - for (quad_row, points) in CircleDomainBitRevIterator::new(subdomain) + let quad_rows = CircleDomainBitRevIterator::new(subdomain) .array_chunks::<4>() - .enumerate() - { + .collect_vec(); + + #[cfg(not(feature = "parallel"))] + let iter = quad_rows.iter().zip(values.chunks_mut(4)).enumerate(); + + #[cfg(feature = "parallel")] + let iter = quad_rows.par_iter().zip(values.chunks_mut(4)).enumerate(); + + iter.for_each(|(quad_row, (points, mut values_dst))| { // TODO(andrew): Spapini said: Use optimized domain iteration. Is there a better way to do // this? 
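+ // This closure handles one 4-row chunk; `values_dst` is that chunk's disjoint slice of
+ // `values`, so the (optionally parallel) iteration never writes through aliasing slices.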
let (y01, _) = points[0].y.deinterleave(points[1].y); @@ -138,11 +147,13 @@ fn accumulate_quotients_on_subdomain( quad_row, spaced_ys, ); - #[allow(clippy::needless_range_loop)] - for i in 0..4 { - unsafe { values.set_packed((quad_row << 2) + i, row_accumulator[i]) }; + unsafe { + values_dst.set_packed(0, row_accumulator[0]); + values_dst.set_packed(1, row_accumulator[1]); + values_dst.set_packed(2, row_accumulator[2]); + values_dst.set_packed(3, row_accumulator[3]); } - } + }); span.exit(); let span = span!(Level::INFO, "Quotient extension").entered(); @@ -294,13 +305,13 @@ mod tests { let e1: BaseColumn = (0..small_domain.size()) .map(|i| BaseField::from(2 * i)) .collect(); - let polys = vec![ + let polys = [ CircleEvaluation::::new(small_domain, e0) .interpolate(), CircleEvaluation::::new(small_domain, e1) .interpolate(), ]; - let columns = vec![polys[0].evaluate(domain), polys[1].evaluate(domain)]; + let columns = [polys[0].evaluate(domain), polys[1].evaluate(domain)]; let random_coeff = qm31!(1, 2, 3, 4); let a = polys[0].eval_at_point(SECURE_FIELD_CIRCLE_GEN); let b = polys[1].eval_at_point(SECURE_FIELD_CIRCLE_GEN); diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs index b5cb9e986..a3d1b614c 100644 --- a/crates/prover/src/core/backend/simd/utils.rs +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -1,36 +1,10 @@ -use std::simd::Swizzle; - -/// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. -pub struct InterleaveEvens; - -impl Swizzle for InterleaveEvens { - const INDEX: [usize; N] = parity_interleave(false); -} - -/// Used with [`Swizzle::concat_swizzle`] to interleave the odd values of two vectors. -pub struct InterleaveOdds; - -impl Swizzle for InterleaveOdds { - const INDEX: [usize; N] = parity_interleave(true); -} - -const fn parity_interleave(odd: bool) -> [usize; N] { - let mut res = [0; N]; - let mut i = 0; - while i < N { - res[i] = (i % 2) * N + (i / 2) * 2 + if odd { 1 } else { 0 }; - i += 1; - } - res -} - // TODO(andrew): Examine usage of unsafe in SIMD FFT. pub struct UnsafeMut(pub *mut T); impl UnsafeMut { /// # Safety /// /// Returns a raw mutable pointer. - pub unsafe fn get(&self) -> *mut T { + pub const unsafe fn get(&self) -> *mut T { self.0 } } @@ -43,7 +17,7 @@ impl UnsafeConst { /// # Safety /// /// Returns a raw constant pointer. - pub unsafe fn get(&self) -> *const T { + pub const unsafe fn get(&self) -> *const T { self.0 } } @@ -51,29 +25,60 @@ impl UnsafeConst { unsafe impl Send for UnsafeConst {} unsafe impl Sync for UnsafeConst {} -#[cfg(test)] -mod tests { - use std::simd::{u32x4, Swizzle}; - - use super::{InterleaveEvens, InterleaveOdds}; +#[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "wasm32", target_feature = "simd128") +)))] +pub mod swizzle { + use std::simd::Swizzle; + + /// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. + pub struct InterleaveEvens; + impl Swizzle for InterleaveEvens { + const INDEX: [usize; N] = parity_interleave(false); + } - #[test] - fn interleave_evens() { - let lo = u32x4::from_array([0, 1, 2, 3]); - let hi = u32x4::from_array([4, 5, 6, 7]); + /// Used with [`Swizzle::concat_swizzle`] to interleave the odd values of two vectors. 
+ pub struct InterleaveOdds; - let res = InterleaveEvens::concat_swizzle(lo, hi); + impl Swizzle for InterleaveOdds { + const INDEX: [usize; N] = parity_interleave(true); + } - assert_eq!(res, u32x4::from_array([0, 4, 2, 6])); + const fn parity_interleave(odd: bool) -> [usize; N] { + let mut res = [0; N]; + let mut i = 0; + while i < N { + res[i] = (i % 2) * N + (i / 2) * 2 + if odd { 1 } else { 0 }; + i += 1; + } + res } - #[test] - fn interleave_odds() { - let lo = u32x4::from_array([0, 1, 2, 3]); - let hi = u32x4::from_array([4, 5, 6, 7]); + #[cfg(test)] + mod tests { + use std::simd::{u32x4, Swizzle}; + + use super::{InterleaveEvens, InterleaveOdds}; + + #[test] + fn interleave_evens() { + let lo = u32x4::from_array([0, 1, 2, 3]); + let hi = u32x4::from_array([4, 5, 6, 7]); + + let res = InterleaveEvens::concat_swizzle(lo, hi); + + assert_eq!(res, u32x4::from_array([0, 4, 2, 6])); + } + + #[test] + fn interleave_odds() { + let lo = u32x4::from_array([0, 1, 2, 3]); + let hi = u32x4::from_array([4, 5, 6, 7]); - let res = InterleaveOdds::concat_swizzle(lo, hi); + let res = InterleaveOdds::concat_swizzle(lo, hi); - assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); + assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); + } } } diff --git a/crates/prover/src/core/channel/blake2s.rs b/crates/prover/src/core/channel/blake2s.rs index 86565658b..62218b5ba 100644 --- a/crates/prover/src/core/channel/blake2s.rs +++ b/crates/prover/src/core/channel/blake2s.rs @@ -19,7 +19,7 @@ pub struct Blake2sChannel { } impl Blake2sChannel { - pub fn digest(&self) -> Blake2sHash { + pub const fn digest(&self) -> Blake2sHash { self.digest } pub fn update_digest(&mut self, new_digest: Blake2sHash) { @@ -75,7 +75,7 @@ impl Channel for Blake2sChannel { let res = compress(std::array::from_fn(|i| digest[i]), msg, 0, 0, 0, 0); // TODO(shahars) Channel should always finalize hash. - self.update_digest(unsafe { std::mem::transmute(res) }); + self.update_digest(unsafe { std::mem::transmute::<[u32; 8], Blake2sHash>(res) }); } fn draw_felt(&mut self) -> SecureField { diff --git a/crates/prover/src/core/channel/poseidon252.rs b/crates/prover/src/core/channel/poseidon252.rs index c0960fc3e..a02a82b1d 100644 --- a/crates/prover/src/core/channel/poseidon252.rs +++ b/crates/prover/src/core/channel/poseidon252.rs @@ -19,7 +19,7 @@ pub struct Poseidon252Channel { } impl Poseidon252Channel { - pub fn digest(&self) -> FieldElement252 { + pub const fn digest(&self) -> FieldElement252 { self.digest } pub fn update_digest(&mut self, new_digest: FieldElement252) { diff --git a/crates/prover/src/core/circle.rs b/crates/prover/src/core/circle.rs index 8cfe48ab8..a20ee3451 100644 --- a/crates/prover/src/core/circle.rs +++ b/crates/prover/src/core/circle.rs @@ -126,7 +126,8 @@ impl + FieldExpOps + Sub + Neg type Output = Self; fn add(self, rhs: Self) -> Self::Output { - let x = self.x.clone() * rhs.x.clone() - self.y.clone() * rhs.y.clone(); + // TODO(ShaharS): Revert once Rust solves compiler [issue](https://github.com/rust-lang/rust/issues/134457). 
+ let x = self.x.clone() * rhs.x.clone() + (-self.y.clone() * rhs.y.clone()); let y = self.x * rhs.y + self.y * rhs.x; Self { x, y } } @@ -223,15 +224,15 @@ pub const SECURE_FIELD_CIRCLE_ORDER: u128 = P4 - 1; pub struct CirclePointIndex(pub usize); impl CirclePointIndex { - pub fn zero() -> Self { + pub const fn zero() -> Self { Self(0) } - pub fn generator() -> Self { + pub const fn generator() -> Self { Self(1) } - pub fn reduce(self) -> Self { + pub const fn reduce(self) -> Self { Self(self.0 & ((1 << M31_CIRCLE_LOG_ORDER) - 1)) } @@ -343,16 +344,16 @@ impl Coset { } /// Returns the size of the coset. - pub fn size(&self) -> usize { + pub const fn size(&self) -> usize { 1 << self.log_size() } /// Returns the log size of the coset. - pub fn log_size(&self) -> u32 { + pub const fn log_size(&self) -> u32 { self.log_size } - pub fn iter(&self) -> CosetIterator> { + pub const fn iter(&self) -> CosetIterator> { CosetIterator { cur: self.initial, step: self.step, @@ -360,7 +361,7 @@ impl Coset { } } - pub fn iter_indices(&self) -> CosetIterator { + pub const fn iter_indices(&self) -> CosetIterator { CosetIterator { cur: self.initial_index, step: self.step_size, @@ -389,7 +390,7 @@ impl Coset { && *self == other.repeated_double(other.log_size - self.log_size) } - pub fn initial(&self) -> CirclePoint { + pub const fn initial(&self) -> CirclePoint { self.initial } diff --git a/crates/prover/src/core/constraints.rs b/crates/prover/src/core/constraints.rs index 31711d98e..f66c8d93d 100644 --- a/crates/prover/src/core/constraints.rs +++ b/crates/prover/src/core/constraints.rs @@ -90,11 +90,11 @@ pub fn complex_conjugate_line( / (point.complex_conjugate().y - point.y) } -/// Evaluates the coefficients of a line between a point and its complex conjugate. Specifically, -/// `a, b, and c, s.t. a*x + b -c*y = 0` for (x,y) being (sample.y, sample.value) and -/// (conj(sample.y), conj(sample.value)). -/// Relies on the fact that every polynomial F over the base -/// field holds: F(p*) == F(p)* (* being the complex conjugate). +/// Evaluates the coefficients of a line between a point and its complex conjugate. +/// +/// Specifically, `a, b, and c, s.t. a*x + b -c*y = 0` for (x,y) being (sample.y, sample.value) and +/// (conj(sample.y), conj(sample.value)). Relies on the fact that every polynomial F over the base +/// field holds: `F(p*) == F(p)*` (`*` being the complex conjugate). 
pub fn complex_conjugate_line_coeffs( sample: &PointSample, alpha: SecureField, diff --git a/crates/prover/src/core/fields/cm31.rs b/crates/prover/src/core/fields/cm31.rs index 6f1b6c2ef..e7f92dba7 100644 --- a/crates/prover/src/core/fields/cm31.rs +++ b/crates/prover/src/core/fields/cm31.rs @@ -24,7 +24,7 @@ impl CM31 { Self(M31::from_u32_unchecked(a), M31::from_u32_unchecked(b)) } - pub fn from_m31(a: M31, b: M31) -> CM31 { + pub const fn from_m31(a: M31, b: M31) -> CM31 { Self(a, b) } } diff --git a/crates/prover/src/core/fields/m31.rs b/crates/prover/src/core/fields/m31.rs index a7c3c57a2..9ef981ecf 100644 --- a/crates/prover/src/core/fields/m31.rs +++ b/crates/prover/src/core/fields/m31.rs @@ -55,13 +55,18 @@ impl M31 { /// let val = (P as u64).pow(2) - 19; /// assert_eq!(M31::reduce(val), M31::from(P - 19)); /// ``` - pub fn reduce(val: u64) -> Self { + pub const fn reduce(val: u64) -> Self { Self((((((val >> MODULUS_BITS) + val + 1) >> MODULUS_BITS) + val) & (P as u64)) as u32) } pub const fn from_u32_unchecked(arg: u32) -> Self { Self(arg) } + + pub fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + pow2147483645(*self) + } } impl Display for M31 { @@ -112,8 +117,7 @@ impl FieldExpOps for M31 { /// assert_eq!(v.inverse() * v, BaseField::one()); /// ``` fn inverse(&self) -> Self { - assert!(!self.is_zero(), "0 has no inverse"); - pow2147483645(*self) + self.inverse() } } @@ -211,15 +215,15 @@ mod tests { use super::{M31, P}; use crate::core::fields::IntoSlice; - fn mul_p(a: u32, b: u32) -> u32 { + const fn mul_p(a: u32, b: u32) -> u32 { ((a as u64 * b as u64) % P as u64) as u32 } - fn add_p(a: u32, b: u32) -> u32 { + const fn add_p(a: u32, b: u32) -> u32 { (a + b) % P } - fn neg_p(a: u32) -> u32 { + const fn neg_p(a: u32) -> u32 { if a == 0 { 0 } else { diff --git a/crates/prover/src/core/fields/qm31.rs b/crates/prover/src/core/fields/qm31.rs index 6da19a3c0..41342ade6 100644 --- a/crates/prover/src/core/fields/qm31.rs +++ b/crates/prover/src/core/fields/qm31.rs @@ -32,15 +32,15 @@ impl QM31 { ) } - pub fn from_m31(a: M31, b: M31, c: M31, d: M31) -> Self { + pub const fn from_m31(a: M31, b: M31, c: M31, d: M31) -> Self { Self(CM31::from_m31(a, b), CM31::from_m31(c, d)) } - pub fn from_m31_array(array: [M31; SECURE_EXTENSION_DEGREE]) -> Self { + pub const fn from_m31_array(array: [M31; SECURE_EXTENSION_DEGREE]) -> Self { Self::from_m31(array[0], array[1], array[2], array[3]) } - pub fn to_m31_array(self) -> [M31; SECURE_EXTENSION_DEGREE] { + pub const fn to_m31_array(self) -> [M31; SECURE_EXTENSION_DEGREE] { [self.0 .0, self.0 .1, self.1 .0, self.1 .1] } diff --git a/crates/prover/src/core/fri.rs b/crates/prover/src/core/fri.rs index 12a3b06df..13504cad5 100644 --- a/crates/prover/src/core/fri.rs +++ b/crates/prover/src/core/fri.rs @@ -72,7 +72,7 @@ impl FriConfig { } } - fn last_layer_domain_size(&self) -> usize { + const fn last_layer_domain_size(&self) -> usize { 1 << (self.log_last_layer_degree_bound + self.log_blowup_factor) } } @@ -100,8 +100,7 @@ pub trait FriOps: FieldOps + PolyOps + Sized + FieldOps /// Let `src` be the evaluation of a circle polynomial `f` on a /// [`CircleDomain`] `E`. This function computes evaluations of `f' = f0 /// + alpha * f1` on the x-coordinates of `E` such that `2f(p) = f0(px) + py * f1(px)`. The - /// evaluations of `f'` are accumulated into `dst` by the formula `dst = dst * alpha^2 + - /// f'`. + /// evaluations of `f'` are accumulated into `dst` by the formula `dst = dst * alpha^2 + f'`. 
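+/// Concretely, writing `-p = (px, -py)` for the conjugate of `p = (px, py)`, this relation gives
+/// `f0(px) = f(p) + f(-p)` and `f1(px) = (f(p) - f(-p)) / py`.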
/// /// # Panics /// @@ -138,7 +137,7 @@ pub struct FriProver<'a, B: FriOps + MerkleOps, MC: MerkleChannel> { impl<'a, B: FriOps + MerkleOps, MC: MerkleChannel> FriProver<'a, B, MC> { /// Commits to multiple circle polynomials. /// - /// `columns` must be provided in descending order by size. + /// `columns` must be provided in descending order by size with at most one column per size. /// /// This is a batched commitment that handles multiple mixed-degree polynomials, each /// evaluated over domains of varying sizes. Instead of combining these evaluations into @@ -149,7 +148,7 @@ impl<'a, B: FriOps + MerkleOps, MC: MerkleChannel> FriProver<'a, B, MC> { /// # Panics /// /// Panics if: - /// * `columns` is empty or not sorted in ascending order by domain size. + /// * `columns` is empty or not sorted in descending order by domain size. /// * An evaluation is not from a sufficiently low degree circle polynomial. /// * An evaluation's domain is smaller than the last layer. /// * An evaluation's domain is not a canonic circle domain. @@ -161,8 +160,11 @@ impl<'a, B: FriOps + MerkleOps, MC: MerkleChannel> FriProver<'a, B, MC> { twiddles: &TwiddleTree, ) -> Self { assert!(!columns.is_empty(), "no columns"); - assert!(columns.is_sorted_by_key(|e| Reverse(e.len())), "not sorted"); assert!(columns.iter().all(|e| e.domain.is_canonic()), "not canonic"); + assert!( + columns.array_windows().all(|[a, b]| a.len() > b.len()), + "column sizes not decreasing" + ); nvtx::range_push!("commit_first_layer"); let first_layer = Self::commit_first_layer(channel, columns); @@ -605,13 +607,13 @@ pub struct CirclePolyDegreeBound { } impl CirclePolyDegreeBound { - pub fn new(log_degree_bound: u32) -> Self { + pub const fn new(log_degree_bound: u32) -> Self { Self { log_degree_bound } } /// Maps a circle polynomial's degree bound to the degree bound of the univariate (line) /// polynomial it gets folded into. - fn fold_to_line(&self) -> LinePolyDegreeBound { + const fn fold_to_line(&self) -> LinePolyDegreeBound { LinePolyDegreeBound { log_degree_bound: self.log_degree_bound - CIRCLE_TO_LINE_FOLD_STEP, } @@ -637,7 +639,7 @@ struct LinePolyDegreeBound { impl LinePolyDegreeBound { /// Returns [None] if the unfolded degree bound is smaller than the folding factor. - fn fold(self, n_folds: u32) -> Option { + const fn fold(self, n_folds: u32) -> Option { if self.log_degree_bound < n_folds { return None; } @@ -648,7 +650,7 @@ impl LinePolyDegreeBound { } /// A FRI proof. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct FriProof { pub first_layer: FriLayerProof, pub inner_layers: Vec>, @@ -663,7 +665,7 @@ pub const FOLD_STEP: u32 = 1; pub const CIRCLE_TO_LINE_FOLD_STEP: u32 = 1; /// Proof of an individual FRI layer. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct FriLayerProof { /// Values that the verifier needs but cannot deduce from previous computations, in the /// order they are needed. This complements the values that were queried. 
These must be @@ -707,9 +709,9 @@ impl FriFirstLayerVerifier { let mut fri_witness = self.proof.fri_witness.iter().copied(); let mut decommitment_positions_by_log_size = BTreeMap::new(); - let mut all_column_decommitment_values = Vec::new(); let mut folded_evals_by_column = Vec::new(); + let mut decommitmented_values = vec![]; for (&column_domain, column_query_evals) in zip_eq(&self.column_commitment_domains, query_evals_by_column) { @@ -730,15 +732,13 @@ impl FriFirstLayerVerifier { decommitment_positions_by_log_size .insert(column_domain.log_size(), column_decommitment_positions); - // Prepare values in the structure needed for merkle decommitment. - let column_decommitment_values: SecureColumnByCoords = sparse_evaluation - .subset_evals - .iter() - .flatten() - .copied() - .collect(); - - all_column_decommitment_values.extend(column_decommitment_values.columns); + decommitmented_values.extend( + sparse_evaluation + .subset_evals + .iter() + .flatten() + .flat_map(|qm31| qm31.to_m31_array()), + ); let folded_evals = sparse_evaluation.fold_circle(self.folding_alpha, column_domain); folded_evals_by_column.push(folded_evals); @@ -760,7 +760,7 @@ impl FriFirstLayerVerifier { merkle_verifier .verify( &decommitment_positions_by_log_size, - all_column_decommitment_values, + decommitmented_values, self.proof.decommitment.clone(), ) .map_err(|error| FriVerificationError::FirstLayerCommitmentInvalid { error })?; @@ -822,12 +822,12 @@ impl FriInnerLayerVerifier { }); } - let decommitment_values: SecureColumnByCoords = sparse_evaluation + let decommitmented_values = sparse_evaluation .subset_evals .iter() .flatten() - .copied() - .collect(); + .flat_map(|qm31| qm31.to_m31_array()) + .collect_vec(); let merkle_verifier = MerkleVerifier::new( self.proof.commitment, @@ -837,7 +837,7 @@ impl FriInnerLayerVerifier { merkle_verifier .verify( &BTreeMap::from_iter([(self.domain.log_size(), decommitment_positions)]), - decommitment_values.columns.to_vec(), + decommitmented_values, self.proof.decommitment.clone(), ) .map_err(|e| FriVerificationError::InnerLayerCommitmentInvalid { @@ -995,7 +995,7 @@ fn compute_decommitment_positions_and_witness_evals( let mut witness_evals = Vec::new(); // Group queries by the folding coset they reside in. - for subset_queries in query_positions.group_by(|a, b| a >> fold_step == b >> fold_step) { + for subset_queries in query_positions.chunk_by(|a, b| a >> fold_step == b >> fold_step) { let subset_start = (subset_queries[0] >> fold_step) << fold_step; let subset_decommitment_positions = subset_start..subset_start + (1 << fold_step); let mut subset_queries_iter = subset_queries.iter().peekable(); @@ -1036,7 +1036,7 @@ fn compute_decommitment_positions_and_rebuild_evals( let mut subset_domain_index_initials = Vec::new(); // Group queries by the subset they reside in. 
- for subset_queries in queries.group_by(|a, b| a >> fold_step == b >> fold_step) { + for subset_queries in queries.chunk_by(|a, b| a >> fold_step == b >> fold_step) { let subset_start = (subset_queries[0] >> fold_step) << fold_step; let subset_decommitment_positions = subset_start..subset_start + (1 << fold_step); decommitment_positions.extend(subset_decommitment_positions.clone()); diff --git a/crates/prover/src/core/lookups/gkr_prover.rs b/crates/prover/src/core/lookups/gkr_prover.rs index 6e6ed2586..c2d6df1bd 100644 --- a/crates/prover/src/core/lookups/gkr_prover.rs +++ b/crates/prover/src/core/lookups/gkr_prover.rs @@ -299,7 +299,7 @@ pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> { pub lambda: SecureField, } -impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> { +impl MultivariatePolyOracle for GkrMultivariatePolyOracle<'_, B> { fn n_variables(&self) -> usize { self.input_layer.n_variables() - 1 } @@ -470,7 +470,7 @@ pub fn prove_batch( // Seed the channel with the layer masks. for (&instance, mask) in zip(&sumcheck_instances, &masks) { - channel.mix_felts(mask.columns().flatten()); + channel.mix_felts(mask.columns().as_flattened()); layer_masks_by_instance[instance].push(mask.clone()); } diff --git a/crates/prover/src/core/lookups/gkr_verifier.rs b/crates/prover/src/core/lookups/gkr_verifier.rs index b65ceb162..f7ffefc9d 100644 --- a/crates/prover/src/core/lookups/gkr_verifier.rs +++ b/crates/prover/src/core/lookups/gkr_verifier.rs @@ -120,7 +120,7 @@ pub fn partially_verify_batch( for &instance in &sumcheck_instances { let n_unused = n_layers - instance_n_layers(instance); let mask = &layer_masks_by_instance[instance][layer - n_unused]; - channel.mix_felts(mask.columns().flatten()); + channel.mix_felts(mask.columns().as_flattened()); } // Set the OOD evaluation point for layer above. @@ -223,7 +223,7 @@ pub struct GkrMask { } impl GkrMask { - pub fn new(columns: Vec<[SecureField; 2]>) -> Self { + pub const fn new(columns: Vec<[SecureField; 2]>) -> Self { Self { columns } } diff --git a/crates/prover/src/core/lookups/utils.rs b/crates/prover/src/core/lookups/utils.rs index ed67477f7..d66bd93f0 100644 --- a/crates/prover/src/core/lookups/utils.rs +++ b/crates/prover/src/core/lookups/utils.rs @@ -202,7 +202,7 @@ pub struct Fraction { } impl Fraction { - pub fn new(numerator: N, denominator: D) -> Self { + pub const fn new(numerator: N, denominator: D) -> Self { Self { numerator, denominator, @@ -256,7 +256,7 @@ pub struct Reciprocal { } impl Reciprocal { - pub fn new(x: T) -> Self { + pub const fn new(x: T) -> Self { Self { x } } } diff --git a/crates/prover/src/core/pcs/mod.rs b/crates/prover/src/core/pcs/mod.rs index d9acf524b..1a551d1eb 100644 --- a/crates/prover/src/core/pcs/mod.rs +++ b/crates/prover/src/core/pcs/mod.rs @@ -1,4 +1,5 @@ //! Implements a FRI polynomial commitment scheme. +//! //! This is a protocol where the prover can commit on a set of polynomials and then prove their //! opening on a set of points. //! 
Note: This implementation is not really a polynomial commitment scheme, because we are not in diff --git a/crates/prover/src/core/pcs/prover.rs b/crates/prover/src/core/pcs/prover.rs index 199e01abd..3f98b77a3 100644 --- a/crates/prover/src/core/pcs/prover.rs +++ b/crates/prover/src/core/pcs/prover.rs @@ -162,12 +162,12 @@ impl<'a, B: BackendForChannel, MC: MerkleChannel> CommitmentSchemeProver<'a, } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct CommitmentSchemeProof { pub commitments: TreeVec, pub sampled_values: TreeVec>>, pub decommitments: TreeVec>, - pub queried_values: TreeVec>>, + pub queried_values: TreeVec>, pub proof_of_work: u64, pub fri_proof: FriProof, } @@ -177,7 +177,7 @@ pub struct TreeBuilder<'a, 'b, B: BackendForChannel, MC: MerkleChannel> { commitment_scheme: &'a mut CommitmentSchemeProver<'b, B, MC>, polys: ColumnVec>, } -impl<'a, 'b, B: BackendForChannel, MC: MerkleChannel> TreeBuilder<'a, 'b, B, MC> { +impl, MC: MerkleChannel> TreeBuilder<'_, '_, B, MC> { pub fn extend_evals( &mut self, columns: impl IntoIterator>, @@ -246,7 +246,7 @@ impl, MC: MerkleChannel> CommitmentTreeProver { fn decommit( &self, queries: &BTreeMap>, - ) -> (ColumnVec>, MerkleDecommitment) { + ) -> (Vec, MerkleDecommitment) { let eval_vec = self .evaluations .iter() diff --git a/crates/prover/src/core/pcs/quotients.rs b/crates/prover/src/core/pcs/quotients.rs index 9e9092e18..ce253d166 100644 --- a/crates/prover/src/core/pcs/quotients.rs +++ b/crates/prover/src/core/pcs/quotients.rs @@ -5,6 +5,7 @@ use std::iter::zip; use itertools::{izip, multiunzip, Itertools}; use tracing::{span, Level}; +use super::TreeVec; use crate::core::backend::cpu::quotients::{accumulate_row_quotients, quotient_constants}; use crate::core::circle::CirclePoint; use crate::core::fields::m31::BaseField; @@ -102,25 +103,30 @@ pub fn compute_fri_quotients( } pub fn fri_answers( - column_log_sizes: Vec, - samples: &[Vec], + column_log_sizes: TreeVec>, + samples: TreeVec>>, random_coeff: SecureField, query_positions_per_log_size: &BTreeMap>, - queried_values_per_column: &[Vec], + queried_values: TreeVec>, + n_columns_per_log_size: TreeVec<&BTreeMap>, ) -> Result>, VerificationError> { - izip!(column_log_sizes, samples, queried_values_per_column) + let mut queried_values = queried_values.map(|values| values.into_iter()); + + izip!(column_log_sizes.flatten(), samples.flatten().iter()) .sorted_by_key(|(log_size, ..)| Reverse(*log_size)) .group_by(|(log_size, ..)| *log_size) .into_iter() .map(|(log_size, tuples)| { - let (_, samples, queried_values_per_column): (Vec<_>, Vec<_>, Vec<_>) = - multiunzip(tuples); + let (_, samples): (Vec<_>, Vec<_>) = multiunzip(tuples); fri_answers_for_log_size( log_size, &samples, random_coeff, &query_positions_per_log_size[&log_size], - &queried_values_per_column, + &mut queried_values, + n_columns_per_log_size + .as_ref() + .map(|colums_log_sizes| *colums_log_sizes.get(&log_size).unwrap_or(&0)), ) }) .collect() @@ -131,27 +137,24 @@ pub fn fri_answers_for_log_size( samples: &[&Vec], random_coeff: SecureField, query_positions: &[usize], - queried_values_per_column: &[&Vec], + queried_values: &mut TreeVec>, + n_columns: TreeVec, ) -> Result, VerificationError> { - for queried_values in queried_values_per_column { - if queried_values.len() != query_positions.len() { - return Err(VerificationError::InvalidStructure( - "Insufficient number of queried values".to_string(), - )); - } - } - let sample_batches = ColumnSampleBatch::new_vec(samples); + 
// TODO(ilya): Is it ok to use the same `random_coeff` for all log sizes. let quotient_constants = quotient_constants(&sample_batches, random_coeff); let commitment_domain = CanonicCoset::new(log_size).circle_domain(); - let mut quotient_evals_at_queries = Vec::new(); - for (row, &query_position) in query_positions.iter().enumerate() { + let mut quotient_evals_at_queries = Vec::new(); + for &query_position in query_positions { let domain_point = commitment_domain.at(bit_reverse_index(query_position, log_size)); - let queried_values_at_row = queried_values_per_column - .iter() - .map(|col| col[row]) - .collect_vec(); + + let queried_values_at_row = queried_values + .as_mut() + .zip_eq(n_columns.as_ref()) + .map(|(queried_values, n_columns)| queried_values.take(*n_columns).collect()) + .flatten(); + quotient_evals_at_queries.push(accumulate_row_quotients( &sample_batches, &queried_values_at_row, diff --git a/crates/prover/src/core/pcs/utils.rs b/crates/prover/src/core/pcs/utils.rs index 36ef3a198..1a5ce7ccb 100644 --- a/crates/prover/src/core/pcs/utils.rs +++ b/crates/prover/src/core/pcs/utils.rs @@ -12,7 +12,7 @@ use crate::core::ColumnVec; pub struct TreeVec(pub Vec); impl TreeVec { - pub fn new(vec: Vec) -> TreeVec { + pub const fn new(vec: Vec) -> TreeVec { TreeVec(vec) } pub fn map U>(self, f: F) -> TreeVec { @@ -41,6 +41,13 @@ impl<'a, T> From<&'a TreeVec> for TreeVec<&'a T> { } } +/// Converts `&TreeVec<&Vec>` to `TreeVec>`. +impl<'a, T> From<&'a TreeVec<&'a Vec>> for TreeVec> { + fn from(val: &'a TreeVec<&'a Vec>) -> Self { + TreeVec(val.iter().map(|vec| vec.iter().collect()).collect()) + } +} + impl Deref for TreeVec { type Target = Vec; fn deref(&self) -> &Self::Target { diff --git a/crates/prover/src/core/pcs/verifier.rs b/crates/prover/src/core/pcs/verifier.rs index 200fe98d5..d6ecd334e 100644 --- a/crates/prover/src/core/pcs/verifier.rs +++ b/crates/prover/src/core/pcs/verifier.rs @@ -96,24 +96,26 @@ impl CommitmentSchemeVerifier { }) .0 .into_iter() - .collect::>()?; + .collect::>()?; // Answer FRI queries. - let samples = sampled_points - .zip_cols(proof.sampled_values) - .map_cols(|(sampled_points, sampled_values)| { + let samples = sampled_points.zip_cols(proof.sampled_values).map_cols( + |(sampled_points, sampled_values)| { zip(sampled_points, sampled_values) .map(|(point, value)| PointSample { point, value }) .collect_vec() - }) - .flatten(); + }, + ); + + let n_columns_per_log_size = self.trees.as_ref().map(|tree| &tree.n_columns_per_log_size); let fri_answers = fri_answers( - self.column_log_sizes().flatten().into_iter().collect(), - &samples, + self.column_log_sizes(), + samples, random_coeff, &query_positions_per_log_size, - &proof.queried_values.flatten(), + proof.queried_values, + n_columns_per_log_size, )?; fri_verifier.decommit(fri_answers)?; diff --git a/crates/prover/src/core/poly/circle/canonic.rs b/crates/prover/src/core/poly/circle/canonic.rs index 837e648d9..1a559fd3c 100644 --- a/crates/prover/src/core/poly/circle/canonic.rs +++ b/crates/prover/src/core/poly/circle/canonic.rs @@ -2,12 +2,14 @@ use super::CircleDomain; use crate::core::circle::{CirclePoint, CirclePointIndex, Coset}; use crate::core::fields::m31::BaseField; -/// A coset of the form G_{2n} + , where G_n is the generator of the -/// subgroup of order n. The ordering on this coset is G_2n + i * G_n. -/// These cosets can be used as a [CircleDomain], and be interpolated on. 
-/// Note that this changes the ordering on the coset to be like [CircleDomain], -/// which is G_2n + i * G_n/2 and then -G_2n -i * G_n/2. -/// For example, the Xs below are a canonic coset with n=8. +/// A coset of the form `G_{2n} + `, where `G_n` is the generator of the subgroup of order `n`. +/// +/// The ordering on this coset is `G_2n + i * G_n`. +/// These cosets can be used as a [`CircleDomain`], and be interpolated on. +/// Note that this changes the ordering on the coset to be like [`CircleDomain`], +/// which is `G_{2n} + i * G_{n/2}` and then `-G_{2n} -i * G_{n/2}`. +/// For example, the `X`s below are a canonic coset with `n=8`. +/// /// ```text /// X O X /// O O @@ -31,7 +33,7 @@ impl CanonicCoset { } /// Gets the full coset represented G_{2n} + . - pub fn coset(&self) -> Coset { + pub const fn coset(&self) -> Coset { self.coset } @@ -46,24 +48,24 @@ impl CanonicCoset { } /// Returns the log size of the coset. - pub fn log_size(&self) -> u32 { + pub const fn log_size(&self) -> u32 { self.coset.log_size } /// Returns the size of the coset. - pub fn size(&self) -> usize { + pub const fn size(&self) -> usize { self.coset.size() } - pub fn initial_index(&self) -> CirclePointIndex { + pub const fn initial_index(&self) -> CirclePointIndex { self.coset.initial_index } - pub fn step_size(&self) -> CirclePointIndex { + pub const fn step_size(&self) -> CirclePointIndex { self.coset.step_size } - pub fn step(&self) -> CirclePoint { + pub const fn step(&self) -> CirclePoint { self.coset.step } diff --git a/crates/prover/src/core/poly/circle/domain.rs b/crates/prover/src/core/poly/circle/domain.rs index fba2bc3fb..83765e6d8 100644 --- a/crates/prover/src/core/poly/circle/domain.rs +++ b/crates/prover/src/core/poly/circle/domain.rs @@ -10,8 +10,9 @@ use crate::core::fields::m31::BaseField; pub const MAX_CIRCLE_DOMAIN_LOG_SIZE: u32 = M31_CIRCLE_LOG_ORDER - 1; /// A valid domain for circle polynomial interpolation and evaluation. -/// Valid domains are a disjoint union of two conjugate cosets: +-C + . -/// The ordering defined on this domain is C + iG_n, and then -C - iG_n. +/// +/// Valid domains are a disjoint union of two conjugate cosets: `+-C + `. +/// The ordering defined on this domain is `C + iG_n`, and then `-C - iG_n`. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct CircleDomain { pub half_coset: Coset, @@ -20,7 +21,7 @@ pub struct CircleDomain { impl CircleDomain { /// Given a coset C + , constructs the circle domain +-C + (i.e., /// this coset and its conjugate). - pub fn new(half_coset: Coset) -> Self { + pub const fn new(half_coset: Coset) -> Self { Self { half_coset } } @@ -38,12 +39,12 @@ impl CircleDomain { } /// Returns the size of the domain. - pub fn size(&self) -> usize { + pub const fn size(&self) -> usize { 1 << self.log_size() } /// Returns the log size of the domain. 
- pub fn log_size(&self) -> u32 { + pub const fn log_size(&self) -> u32 { self.half_coset.log_size + 1 } diff --git a/crates/prover/src/core/poly/circle/evaluation.rs b/crates/prover/src/core/poly/circle/evaluation.rs index faa2b7284..6094df399 100644 --- a/crates/prover/src/core/poly/circle/evaluation.rs +++ b/crates/prover/src/core/poly/circle/evaluation.rs @@ -146,7 +146,7 @@ impl<'a, F: ExtensionOf> CosetSubEvaluation<'a, F> { } } -impl<'a, F: ExtensionOf> Index for CosetSubEvaluation<'a, F> { +impl> Index for CosetSubEvaluation<'_, F> { type Output = F; fn index(&self, index: isize) -> &Self::Output { @@ -156,7 +156,7 @@ impl<'a, F: ExtensionOf> Index for CosetSubEvaluation<'a, F> { } } -impl<'a, F: ExtensionOf> Index for CosetSubEvaluation<'a, F> { +impl> Index for CosetSubEvaluation<'_, F> { type Output = F; fn index(&self, index: usize) -> &Self::Output { diff --git a/crates/prover/src/core/poly/circle/poly.rs b/crates/prover/src/core/poly/circle/poly.rs index c10fc5e7a..6744a0c80 100644 --- a/crates/prover/src/core/poly/circle/poly.rs +++ b/crates/prover/src/core/poly/circle/poly.rs @@ -34,7 +34,7 @@ impl CirclePoly { Self { log_size, coeffs } } - pub fn log_size(&self) -> u32 { + pub const fn log_size(&self) -> u32 { self.log_size } diff --git a/crates/prover/src/core/poly/line.rs b/crates/prover/src/core/poly/line.rs index 2bf640c63..d684d0c58 100644 --- a/crates/prover/src/core/poly/line.rs +++ b/crates/prover/src/core/poly/line.rs @@ -9,14 +9,14 @@ use serde::{Deserialize, Serialize}; use super::circle::CircleDomain; use super::utils::fold; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::{ColumnOps, CpuBackend}; use crate::core::circle::{CirclePoint, Coset, CosetIterator}; use crate::core::fft::ibutterfly; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; -use crate::core::fields::{ExtensionOf, FieldExpOps, FieldOps}; -use crate::core::utils::bit_reverse; +use crate::core::fields::{ExtensionOf, FieldOps}; /// Domain comprising of the x-coordinates of points in a [Coset]. /// @@ -58,12 +58,12 @@ impl LineDomain { } /// Returns the size of the domain. - pub fn size(&self) -> usize { + pub const fn size(&self) -> usize { self.coset.size() } /// Returns the log size of the domain. - pub fn log_size(&self) -> u32 { + pub const fn log_size(&self) -> u32 { self.coset.log_size() } @@ -80,7 +80,7 @@ impl LineDomain { } /// Returns the domain's underlying coset. - pub fn coset(&self) -> Coset { + pub const fn coset(&self) -> Coset { self.coset } } @@ -209,11 +209,11 @@ impl> LineEvaluation { /// Returns the number of evaluations. #[allow(clippy::len_without_is_empty)] - pub fn len(&self) -> usize { + pub const fn len(&self) -> usize { 1 << self.domain.log_size() } - pub fn domain(&self) -> LineDomain { + pub const fn domain(&self) -> LineDomain { self.domain } diff --git a/crates/prover/src/core/poly/twiddles.rs b/crates/prover/src/core/poly/twiddles.rs index f3b186376..2e172a2cf 100644 --- a/crates/prover/src/core/poly/twiddles.rs +++ b/crates/prover/src/core/poly/twiddles.rs @@ -2,6 +2,7 @@ use super::circle::PolyOps; use crate::core::circle::Coset; /// Precomputed twiddles for a specific coset tower. +/// /// A coset tower is every repeated doubling of a `root_coset`. /// The largest CircleDomain that can be ffted using these twiddles is one with `root_coset` as /// its `half_coset`. 
diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index 5bef4def1..36f067c02 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -174,7 +174,7 @@ pub enum VerificationError { ProofOfWork, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct StarkProof(pub CommitmentSchemeProof); impl StarkProof { diff --git a/crates/prover/src/core/queries.rs b/crates/prover/src/core/queries.rs index cd9546e6a..946380000 100644 --- a/crates/prover/src/core/queries.rs +++ b/crates/prover/src/core/queries.rs @@ -70,10 +70,10 @@ impl Deref for Queries { #[cfg(test)] mod tests { + use crate::core::backend::cpu::bit_reverse; use crate::core::channel::Blake2sChannel; use crate::core::poly::circle::CanonicCoset; use crate::core::queries::Queries; - use crate::core::utils::bit_reverse; #[test] fn test_generate_queries() { diff --git a/crates/prover/src/core/utils.rs b/crates/prover/src/core/utils.rs index c7bfd4e4e..e9f79ad08 100644 --- a/crates/prover/src/core/utils.rs +++ b/crates/prover/src/core/utils.rs @@ -5,7 +5,6 @@ use std::ops::{Add, Mul, Sub}; use num_traits::{One, Zero}; use super::fields::m31::BaseField; -use super::fields::qm31::SecureField; use super::fields::Field; pub trait IteratorMutExt<'a, T: 'a>: Iterator { @@ -25,7 +24,7 @@ pub struct PeekTakeWhile<'a, I: Iterator, P: FnMut(&I::Item) -> bool> { iter: &'a mut Peekable, predicate: P, } -impl<'a, I: Iterator, P: FnMut(&I::Item) -> bool> Iterator for PeekTakeWhile<'a, I, P> { +impl bool> Iterator for PeekTakeWhile<'_, I, P> { type Item = I::Item; fn next(&mut self) -> Option { @@ -55,7 +54,7 @@ impl<'a, I: Iterator> PeekableExt<'a, I> for Peekable { } /// Returns the bit reversed index of `i` which is represented by `log_size` bits. -pub fn bit_reverse_index(i: usize, log_size: u32) -> usize { +pub const fn bit_reverse_index(i: usize, log_size: u32) -> usize { if log_size == 0 { return i; } @@ -65,7 +64,7 @@ pub fn bit_reverse_index(i: usize, log_size: u32) -> usize { /// Returns the index of the previous element in a bit reversed /// [super::poly::circle::CircleEvaluation] of log size `eval_log_size` relative to a smaller domain /// of size `domain_log_size`. -pub fn previous_bit_reversed_circle_domain_index( +pub const fn previous_bit_reversed_circle_domain_index( i: usize, domain_log_size: u32, eval_log_size: u32, @@ -76,7 +75,7 @@ pub fn previous_bit_reversed_circle_domain_index( /// Returns the index of the offset element in a bit reversed /// [super::poly::circle::CircleEvaluation] of log size `eval_log_size` relative to a smaller domain /// of size `domain_log_size`. -pub fn offset_bit_reversed_circle_domain_index( +pub const fn offset_bit_reversed_circle_domain_index( i: usize, domain_log_size: u32, eval_log_size: u32, @@ -123,7 +122,7 @@ pub(crate) fn coset_order_to_circle_domain_order(values: &[F]) -> Vec< /// /// [`CircleDomain`]: crate::core::poly::circle::CircleDomain /// [`Coset`]: crate::core::circle::Coset -pub fn coset_index_to_circle_domain_index(coset_index: usize, log_domain_size: u32) -> usize { +pub const fn coset_index_to_circle_domain_index(coset_index: usize, log_domain_size: u32) -> usize { if coset_index % 2 == 0 { coset_index / 2 } else { @@ -131,67 +130,6 @@ pub fn coset_index_to_circle_domain_index(coset_index: usize, log_domain_size: u } } -/// Performs a naive bit-reversal permutation inplace. -/// -/// # Panics -/// -/// Panics if the length of the slice is not a power of two. 
-// TODO(alont): Move this to the cpu backend. -pub fn bit_reverse(v: &mut [T]) { - let n = v.len(); - assert!(n.is_power_of_two()); - #[cfg(not(feature = "icicle"))] - { - let log_n = n.ilog2(); - for i in 0..n { - let j = bit_reverse_index(i, log_n); - if j > i { - v.swap(i, j); - } - } - } - - #[cfg(feature = "icicle")] - unsafe { - let limbs_count: usize = size_of_val(&v[0]) / 4; - use std::slice; - - use icicle_core::traits::FieldImpl; - use icicle_core::vec_ops::{bit_reverse_inplace, BitReverseConfig, VecOps}; - use icicle_cuda_runtime::device::get_device_from_pointer; - use icicle_cuda_runtime::memory::{DeviceSlice, HostSlice}; - use icicle_m31::field::{ComplexExtensionField, QuarticExtensionField, ScalarField}; - - fn bit_rev_generic(v: &mut [T], n: usize) - where - F: FieldImpl, - ::Config: VecOps, - { - let cfg = BitReverseConfig::default(); - - // Check if v is a DeviceSlice or some other slice type - let mut v_ptr = v.as_mut_ptr() as *mut F; - let rr = unsafe { slice::from_raw_parts_mut(v_ptr, n) }; - - // means data already on device (some finite device id, instead of huge number for host - // pointer) - if get_device_from_pointer(v_ptr as _).unwrap() <= 1024 { - bit_reverse_inplace(unsafe { DeviceSlice::from_mut_slice(rr) }, &cfg).unwrap(); - } else { - bit_reverse_inplace(HostSlice::from_mut_slice(rr), &cfg).unwrap(); - } - } - - if limbs_count == 1 { - bit_rev_generic::(v, n); - } else if limbs_count == 2 { - bit_rev_generic::(v, n); - } else if limbs_count == 4 { - bit_rev_generic::(v, n); - } - } -} - /// Performs a coset-natural-order to circle-domain-bit-reversed-order permutation in-place. /// /// # Panics @@ -209,80 +147,17 @@ pub fn bit_reverse_coset_to_circle_domain_order(v: &mut [T]) { } } -pub fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec { - (0..n_powers) - .scan(SecureField::one(), |acc, _| { - let res = *acc; - *acc *= felt; - Some(res) - }) - .collect() -} - -/// Securely combines the given values using the given random alpha and z. -/// Alpha and z should be secure field elements for soundness. 
-pub fn shifted_secure_combination(values: &[F], alpha: EF, z: EF) -> EF -where - EF: Copy + Zero + Mul + Add + Sub, -{ - let res = values - .iter() - .fold(EF::zero(), |acc, &value| acc * alpha + value); - res - z -} - #[cfg(test)] mod tests { use itertools::Itertools; - use num_traits::One; use super::{ offset_bit_reversed_circle_domain_index, previous_bit_reversed_circle_domain_index, }; use crate::core::backend::cpu::CpuCircleEvaluation; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::FieldExpOps; use crate::core::poly::circle::CanonicCoset; use crate::core::poly::NaturalOrder; - use crate::core::utils::bit_reverse; - use crate::{m31, qm31}; - - #[test] - fn bit_reverse_works() { - let mut data = [0, 1, 2, 3, 4, 5, 6, 7]; - bit_reverse(&mut data); - assert_eq!(data, [0, 4, 2, 6, 1, 5, 3, 7]); - } - - #[test] - #[should_panic] - fn bit_reverse_non_power_of_two_size_fails() { - let mut data = [0, 1, 2, 3, 4, 5]; - bit_reverse(&mut data); - } - - #[test] - fn generate_secure_powers_works() { - let felt = qm31!(1, 2, 3, 4); - let n_powers = 10; - - let powers = super::generate_secure_powers(felt, n_powers); - - assert_eq!(powers.len(), n_powers); - assert_eq!(powers[0], SecureField::one()); - assert_eq!(powers[1], felt); - assert_eq!(powers[7], felt.pow(7)); - } - - #[test] - fn generate_empty_secure_powers_works() { - let felt = qm31!(1, 2, 3, 4); - let max_log_size = 0; - - let powers = super::generate_secure_powers(felt, max_log_size); - - assert_eq!(powers, vec![]); - } + use crate::m31; #[test] fn test_offset_bit_reversed_circle_domain_index() { diff --git a/crates/prover/src/core/vcs/blake2_merkle.rs b/crates/prover/src/core/vcs/blake2_merkle.rs index 8401716fa..15e043723 100644 --- a/crates/prover/src/core/vcs/blake2_merkle.rs +++ b/crates/prover/src/core/vcs/blake2_merkle.rs @@ -20,7 +20,7 @@ impl MerkleHasher for Blake2sMerkleHasher { if let Some((left, right)) = children_hashes { state = compress( state, - unsafe { std::mem::transmute([left, right]) }, + unsafe { std::mem::transmute::<[Blake2sHash; 2], [u32; 16]>([left, right]) }, 0, 0, 0, @@ -33,9 +33,16 @@ impl MerkleHasher for Blake2sMerkleHasher { .copied() .chain(std::iter::repeat(BaseField::zero()).take(rem)); for chunk in padded_values.array_chunks::<16>() { - state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0); + state = compress( + state, + unsafe { std::mem::transmute::<[BaseField; 16], [u32; 16]>(chunk) }, + 0, + 0, + 0, + 0, + ); } - state.map(|x| x.to_le_bytes()).flatten().into() + state.map(|x| x.to_le_bytes()).as_flattened().into() } } @@ -86,7 +93,7 @@ mod tests { #[test] fn test_merkle_invalid_value() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3][2] = BaseField::zero(); + values[6] = BaseField::zero(); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), @@ -119,22 +126,22 @@ mod tests { #[test] fn test_merkle_column_values_too_long() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].push(BaseField::zero()); + values.insert(3, BaseField::zero()); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooLong + MerkleVerificationError::TooManyQueriedValues ); } #[test] fn test_merkle_column_values_too_short() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].pop(); + values.remove(3); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), - 
MerkleVerificationError::ColumnValuesTooShort + MerkleVerificationError::TooFewQueriedValues ); } diff --git a/crates/prover/src/core/vcs/blake2s_ref.rs b/crates/prover/src/core/vcs/blake2s_ref.rs index 9b982bba6..3776a8830 100644 --- a/crates/prover/src/core/vcs/blake2s_ref.rs +++ b/crates/prover/src/core/vcs/blake2s_ref.rs @@ -19,33 +19,33 @@ pub const SIGMA: [[u8; 16]; 10] = [ ]; #[inline(always)] -fn add(a: u32, b: u32) -> u32 { +const fn add(a: u32, b: u32) -> u32 { a.wrapping_add(b) } #[inline(always)] -fn xor(a: u32, b: u32) -> u32 { +const fn xor(a: u32, b: u32) -> u32 { a ^ b } #[inline(always)] -fn rot16(x: u32) -> u32 { - (x >> 16) | (x << (32 - 16)) +const fn rot16(x: u32) -> u32 { + x.rotate_right(16) } #[inline(always)] -fn rot12(x: u32) -> u32 { - (x >> 12) | (x << (32 - 12)) +const fn rot12(x: u32) -> u32 { + x.rotate_right(12) } #[inline(always)] -fn rot8(x: u32) -> u32 { - (x >> 8) | (x << (32 - 8)) +const fn rot8(x: u32) -> u32 { + x.rotate_right(8) } #[inline(always)] -fn rot7(x: u32) -> u32 { - (x >> 7) | (x << (32 - 7)) +const fn rot7(x: u32) -> u32 { + x.rotate_right(7) } #[inline(always)] diff --git a/crates/prover/src/core/vcs/ops.rs b/crates/prover/src/core/vcs/ops.rs index 6f0f728da..b55646f14 100644 --- a/crates/prover/src/core/vcs/ops.rs +++ b/crates/prover/src/core/vcs/ops.rs @@ -6,13 +6,12 @@ use crate::core::backend::{Col, ColumnOps}; use crate::core::fields::m31::BaseField; use crate::core::vcs::hash::Hash; -/// A Merkle node hash is a hash of: -/// [left_child_hash, right_child_hash], column0_value, column1_value, ... -/// "[]" denotes optional values. +/// A Merkle node hash is a hash of: `[left_child_hash, right_child_hash], column0_value, +/// column1_value, ...` where `[]` denotes optional values. +/// /// The largest Merkle layer has no left and right child hashes. The rest of the layers have -/// children hashes. -/// At each layer, the tree may have multiple columns of the same length as the layer. -/// Each node in that layer contains one value from each column. +/// children hashes. At each layer, the tree may have multiple columns of the same length as the +/// layer. Each node in that layer contains one value from each column. pub trait MerkleHasher: Debug + Default + Clone { type Hash: Hash; /// Hashes a single Merkle node. See [MerkleHasher] for more details. 
diff --git a/crates/prover/src/core/vcs/poseidon252_merkle.rs b/crates/prover/src/core/vcs/poseidon252_merkle.rs index 5ffba1ea6..f39a2c62d 100644 --- a/crates/prover/src/core/vcs/poseidon252_merkle.rs +++ b/crates/prover/src/core/vcs/poseidon252_merkle.rs @@ -114,7 +114,7 @@ mod tests { fn test_merkle_invalid_value() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3][2] = BaseField::zero(); + values[6] = BaseField::zero(); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), @@ -147,26 +147,26 @@ mod tests { } #[test] - fn test_merkle_column_values_too_long() { + fn test_merkle_values_too_long() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].push(BaseField::zero()); + values.insert(3, BaseField::zero()); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooLong + MerkleVerificationError::TooManyQueriedValues ); } #[test] - fn test_merkle_column_values_too_short() { + fn test_merkle_values_too_short() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].pop(); + values.remove(3); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooShort + MerkleVerificationError::TooFewQueriedValues ); } } diff --git a/crates/prover/src/core/vcs/prover.rs b/crates/prover/src/core/vcs/prover.rs index 3390e5e60..4f623bf9f 100644 --- a/crates/prover/src/core/vcs/prover.rs +++ b/crates/prover/src/core/vcs/prover.rs @@ -9,7 +9,6 @@ use super::utils::{next_decommitment_node, option_flatten_peekable}; use crate::core::backend::{Col, Column}; use crate::core::fields::m31::BaseField; use crate::core::utils::PeekableExt; -use crate::core::ColumnVec; pub struct MerkleProver, H: MerkleHasher> { /// Layers of the Merkle tree. @@ -80,22 +79,22 @@ impl, H: MerkleHasher> MerkleProver { /// /// # Arguments /// - /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that - /// log_size. + /// * `queries_per_log_size` - Maps a log_size to a vector of queries for columns of that size. /// * `columns` - A vector of references to columns. /// /// # Returns /// /// A tuple containing: - /// * A vector of vectors of queried values for each column, in the order of the input columns. + /// * A vector queried values sorted by the order they were queried from the largest layer to + /// the smallest. /// * A `MerkleDecommitment` containing the hash and column witnesses. pub fn decommit( &self, queries_per_log_size: &BTreeMap>, columns: Vec<&Col>, - ) -> (ColumnVec>, MerkleDecommitment) { + ) -> (Vec, MerkleDecommitment) { // Prepare output buffers. - let mut queried_values_by_layer = vec![]; + let mut queried_values = vec![]; let mut decommitment = MerkleDecommitment::empty(); // Sort columns by layer. @@ -106,9 +105,6 @@ impl, H: MerkleHasher> MerkleProver { let mut last_layer_queries = vec![]; for layer_log_size in (0..self.layers.len() as u32).rev() { - // Prepare write buffer for queried values to the current layer. - let mut layer_queried_values = vec![]; - // Prepare write buffer for queries to the current layer. This will propagate to the // next layer. let mut layer_total_queries = vec![]; @@ -152,7 +148,7 @@ impl, H: MerkleHasher> MerkleProver { // If the column values were queried, return them. 
let node_values = layer_columns.iter().map(|c| c.at(node_index)); if layer_column_queries.next_if_eq(&node_index).is_some() { - layer_queried_values.push(node_values.collect_vec()); + queried_values.extend(node_values); } else { // Otherwise, add them to the witness. decommitment.column_witness.extend(node_values); @@ -161,50 +157,13 @@ impl, H: MerkleHasher> MerkleProver { layer_total_queries.push(node_index); } - queried_values_by_layer.push(layer_queried_values); - // Propagate queries to the next layer. last_layer_queries = layer_total_queries; } - queried_values_by_layer.reverse(); - - // Rearrange returned queried values according to input, and not by layer. - let queried_values = Self::rearrange_queried_values(queried_values_by_layer, columns); (queried_values, decommitment) } - /// Given queried values by layer, rearranges in the order of input columns. - fn rearrange_queried_values( - queried_values_by_layer: Vec>>, - columns: Vec<&Col>, - ) -> Vec> { - // Turn each column queried values into an iterator. - let mut queried_values_by_layer = queried_values_by_layer - .into_iter() - .map(|layer_results| { - layer_results - .into_iter() - .map(|x| x.into_iter()) - .collect_vec() - }) - .collect_vec(); - - // For each input column, fetch the queried values from the corresponding layer. - let queried_values = columns - .iter() - .map(|column| { - queried_values_by_layer - .get_mut(column.len().ilog2() as usize) - .unwrap() - .iter_mut() - .map(|x| x.next().unwrap()) - .collect_vec() - }) - .collect_vec(); - queried_values - } - pub fn root(&self) -> H::Hash { self.layers.first().unwrap().at(0) } @@ -222,7 +181,7 @@ pub struct MerkleDecommitment { pub column_witness: Vec, } impl MerkleDecommitment { - fn empty() -> Self { + const fn empty() -> Self { Self { hash_witness: Vec::new(), column_witness: Vec::new(), diff --git a/crates/prover/src/core/vcs/test_utils.rs b/crates/prover/src/core/vcs/test_utils.rs index b92f9e971..c906f05d0 100644 --- a/crates/prover/src/core/vcs/test_utils.rs +++ b/crates/prover/src/core/vcs/test_utils.rs @@ -14,7 +14,7 @@ use crate::core::vcs::prover::MerkleProver; pub type TestData = ( BTreeMap>, MerkleDecommitment, - Vec>, + Vec, MerkleVerifier, ); @@ -52,9 +52,6 @@ where let (values, decommitment) = merkle.decommit(&queries, cols.iter().collect_vec()); - let verifier = MerkleVerifier { - root: merkle.root(), - column_log_sizes: log_sizes, - }; + let verifier = MerkleVerifier::new(merkle.root(), log_sizes); (queries, decommitment, values, verifier) } diff --git a/crates/prover/src/core/vcs/verifier.rs b/crates/prover/src/core/vcs/verifier.rs index 163fed2f1..bb4969f5d 100644 --- a/crates/prover/src/core/vcs/verifier.rs +++ b/crates/prover/src/core/vcs/verifier.rs @@ -1,4 +1,3 @@ -use std::cmp::Reverse; use std::collections::BTreeMap; use itertools::Itertools; @@ -9,27 +8,35 @@ use super::prover::MerkleDecommitment; use super::utils::{next_decommitment_node, option_flatten_peekable}; use crate::core::fields::m31::BaseField; use crate::core::utils::PeekableExt; -use crate::core::ColumnVec; pub struct MerkleVerifier { pub root: H::Hash, pub column_log_sizes: Vec, + pub n_columns_per_log_size: BTreeMap, } impl MerkleVerifier { pub fn new(root: H::Hash, column_log_sizes: Vec) -> Self { + let mut n_columns_per_log_size = BTreeMap::new(); + for log_size in &column_log_sizes { + *n_columns_per_log_size.entry(*log_size).or_insert(0) += 1; + } + Self { root, column_log_sizes, + n_columns_per_log_size, } } /// Verifies the decommitment of the columns. 
/// + /// Returns `Ok(())` if the decommitment is successfully verified. + /// /// # Arguments /// /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that - /// log_size. - /// * `queried_values` - A vector of vectors of queried values. For each column, there is a - /// vector of queried values to that column. + /// log_size. + /// * `queried_values` - A vector of queried values according to the order in + /// [`MerkleProver::decommit()`]. /// * `decommitment` - The decommitment object containing the witness and column values. /// /// # Errors @@ -38,50 +45,34 @@ impl MerkleVerifier { /// /// * The witness is too long (not fully consumed). /// * The witness is too short (missing values). - /// * The column values are too long (not fully consumed). - /// * The column values are too short (missing values). + /// * Too many queried values (not fully consumed). + /// * Too few queried values (missing values). /// * The computed root does not match the expected root. /// - /// # Panics - /// - /// This function will panic if the `values` vector is not sorted in descending order based on - /// the `log_size` of the columns. - /// - /// # Returns - /// - /// Returns `Ok(())` if the decommitment is successfully verified. + /// [`MerkleProver::decommit()`]: crate::core::...::MerkleProver::decommit pub fn verify( &self, queries_per_log_size: &BTreeMap>, - queried_values: ColumnVec>, + queried_values: Vec, decommitment: MerkleDecommitment, ) -> Result<(), MerkleVerificationError> { let Some(max_log_size) = self.column_log_sizes.iter().max() else { return Ok(()); }; + let mut queried_values = queried_values.into_iter(); + // Prepare read buffers. - let mut queried_values_by_layer = self - .column_log_sizes - .iter() - .copied() - .zip( - queried_values - .into_iter() - .map(|column_values| column_values.into_iter()), - ) - .sorted_by_key(|(log_size, _)| Reverse(*log_size)) - .peekable(); + let mut hash_witness = decommitment.hash_witness.into_iter(); let mut column_witness = decommitment.column_witness.into_iter(); let mut last_layer_hashes: Option> = None; for layer_log_size in (0..=*max_log_size).rev() { - // Prepare read buffer for queried values to the current layer. - let mut layer_queried_values = queried_values_by_layer - .peek_take_while(|(log_size, _)| *log_size == layer_log_size) - .collect_vec(); - let n_columns_in_layer = layer_queried_values.len(); + let n_columns_in_layer = *self + .n_columns_per_log_size + .get(&layer_log_size) + .unwrap_or(&0); // Prepare write buffer for queries to the current layer. This will propagate to the // next layer. @@ -137,29 +128,26 @@ impl MerkleVerifier { .transpose()?; // If the column values were queried, read them from `queried_value`. - let node_values = if layer_column_queries.next_if_eq(&node_index).is_some() { - layer_queried_values - .iter_mut() - .map(|(_, ref mut column_queries)| { - column_queries - .next() - .ok_or(MerkleVerificationError::ColumnValuesTooShort) - }) - .collect::, _>>()? - } else { + let (err, node_values_iter) = match layer_column_queries.next_if_eq(&node_index) { + Some(_) => ( + MerkleVerificationError::TooFewQueriedValues, + &mut queried_values, + ), // Otherwise, read them from the witness. 
- (&mut column_witness).take(n_columns_in_layer).collect_vec() + None => ( + MerkleVerificationError::WitnessTooShort, + &mut column_witness, + ), }; + + let node_values = node_values_iter.take(n_columns_in_layer).collect_vec(); if node_values.len() != n_columns_in_layer { - return Err(MerkleVerificationError::WitnessTooShort); + return Err(err); } layer_total_queries.push((node_index, H::hash_node(node_hashes, &node_values))); } - if !layer_queried_values.iter().all(|(_, c)| c.is_empty()) { - return Err(MerkleVerificationError::ColumnValuesTooLong); - } last_layer_hashes = Some(layer_total_queries); } @@ -167,6 +155,9 @@ impl MerkleVerifier { if !hash_witness.is_empty() { return Err(MerkleVerificationError::WitnessTooLong); } + if !queried_values.is_empty() { + return Err(MerkleVerificationError::TooManyQueriedValues); + } if !column_witness.is_empty() { return Err(MerkleVerificationError::WitnessTooLong); } @@ -180,16 +171,17 @@ impl MerkleVerifier { } } +// TODO(ilya): Make error messages consistent. #[derive(Clone, Copy, Debug, Error, PartialEq, Eq)] pub enum MerkleVerificationError { - #[error("Witness is too short.")] + #[error("Witness is too short")] WitnessTooShort, #[error("Witness is too long.")] WitnessTooLong, - #[error("Column values are too long.")] - ColumnValuesTooLong, - #[error("Column values are too short.")] - ColumnValuesTooShort, + #[error("too many Queried values")] + TooManyQueriedValues, + #[error("too few queried values")] + TooFewQueriedValues, #[error("Root mismatch.")] RootMismatch, } diff --git a/crates/prover/src/examples/blake/mod.rs b/crates/prover/src/examples/blake/mod.rs index ff62f9f7d..76feb7f8b 100644 --- a/crates/prover/src/examples/blake/mod.rs +++ b/crates/prover/src/examples/blake/mod.rs @@ -88,26 +88,26 @@ impl BlakeXorElements { // TODO(alont): Generalize this to variable sizes batches if ever used. 
fn use_relation(&self, eval: &mut E, w: u32, values: [&[E::F]; 2]) { match w { - 12 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor12, E::EF::one(), values[0]), - RelationEntry::new(&self.xor12, E::EF::one(), values[1]), - ]), - 9 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor9, E::EF::one(), values[0]), - RelationEntry::new(&self.xor9, E::EF::one(), values[1]), - ]), - 8 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor8, E::EF::one(), values[0]), - RelationEntry::new(&self.xor8, E::EF::one(), values[1]), - ]), - 7 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor7, E::EF::one(), values[0]), - RelationEntry::new(&self.xor7, E::EF::one(), values[1]), - ]), - 4 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor4, E::EF::one(), values[0]), - RelationEntry::new(&self.xor4, E::EF::one(), values[1]), - ]), + 12 => { + eval.add_to_relation(RelationEntry::new(&self.xor12, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor12, E::EF::one(), values[1])); + } + 9 => { + eval.add_to_relation(RelationEntry::new(&self.xor9, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor9, E::EF::one(), values[1])); + } + 8 => { + eval.add_to_relation(RelationEntry::new(&self.xor8, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor8, E::EF::one(), values[1])); + } + 7 => { + eval.add_to_relation(RelationEntry::new(&self.xor7, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor7, E::EF::one(), values[1])); + } + 4 => { + eval.add_to_relation(RelationEntry::new(&self.xor4, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor4, E::EF::one(), values[1])); + } _ => panic!("Invalid w"), }; } diff --git a/crates/prover/src/examples/blake/round/constraints.rs b/crates/prover/src/examples/blake/round/constraints.rs index ada5fb287..f291f44e1 100644 --- a/crates/prover/src/examples/blake/round/constraints.rs +++ b/crates/prover/src/examples/blake/round/constraints.rs @@ -14,10 +14,11 @@ pub struct BlakeRoundEval<'a, E: EvalAtRow> { pub eval: E, pub xor_lookup_elements: &'a BlakeXorElements, pub round_lookup_elements: &'a RoundElements, - pub total_sum: SecureField, - pub log_size: u32, + // TODO(first): validate logup. + pub _total_sum: SecureField, + pub _log_size: u32, } -impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { +impl BlakeRoundEval<'_, E> { pub fn eval(mut self) -> E { let mut v: [Fu32; STATE_SIZE] = std::array::from_fn(|_| self.next_u32()); let input_v = v.clone(); @@ -65,7 +66,7 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { ); // Yield `Round(input_v, output_v, message)`. 
- self.eval.add_to_relation(&[RelationEntry::new( + self.eval.add_to_relation(RelationEntry::new( self.round_lookup_elements, -E::EF::one(), &chain![ @@ -74,9 +75,9 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { m.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), - )]); + )); - self.eval.finalize_logup(); + self.eval.finalize_logup_in_pairs(); self.eval } fn next_u32(&mut self) -> Fu32 { diff --git a/crates/prover/src/examples/blake/round/gen.rs b/crates/prover/src/examples/blake/round/gen.rs index 3b9d0c853..6f4f11a9b 100644 --- a/crates/prover/src/examples/blake/round/gen.rs +++ b/crates/prover/src/examples/blake/round/gen.rs @@ -68,7 +68,7 @@ struct TraceGeneratorRow<'a> { vec_row: usize, xor_lookups_index: usize, } -impl<'a> TraceGeneratorRow<'a> { +impl TraceGeneratorRow<'_> { fn append_felt(&mut self, val: u32x16) { self.gen.trace[self.col_index].data[self.vec_row] = unsafe { PackedBaseField::from_simd_unchecked(val) }; diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs index 8fa238b26..4926fe218 100644 --- a/crates/prover/src/examples/blake/round/mod.rs +++ b/crates/prover/src/examples/blake/round/mod.rs @@ -33,8 +33,8 @@ impl FrameworkEval for BlakeRoundEval { eval, xor_lookup_elements: &self.xor_lookup_elements, round_lookup_elements: &self.round_lookup_elements, - total_sum: self.total_sum, - log_size: self.log_size, + _total_sum: self.total_sum, + _log_size: self.log_size, }; blake_eval.eval() } diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index 1bf93d1aa..aceece2e8 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -30,17 +30,23 @@ pub fn eval_blake_scheduler_constraints( ] .collect_vec() }); - eval.add_to_relation(&[ - RelationEntry::new(round_lookup_elements, E::EF::one(), &elems_i), - RelationEntry::new(round_lookup_elements, E::EF::one(), &elems_j), - ]); + eval.add_to_relation(RelationEntry::new( + round_lookup_elements, + E::EF::one(), + &elems_i, + )); + eval.add_to_relation(RelationEntry::new( + round_lookup_elements, + E::EF::one(), + &elems_j, + )); } let input_state = &states[0]; let output_state = &states[N_ROUNDS]; // TODO(alont): Remove blake interaction. - eval.add_to_relation(&[RelationEntry::new( + eval.add_to_relation(RelationEntry::new( blake_lookup_elements, E::EF::zero(), &chain![ @@ -49,9 +55,9 @@ pub fn eval_blake_scheduler_constraints( messages.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), - )]); + )); - eval.finalize_logup(); + eval.finalize_logup_in_pairs(); } fn eval_next_u32(eval: &mut E) -> Fu32 { diff --git a/crates/prover/src/examples/blake/scheduler/mod.rs b/crates/prover/src/examples/blake/scheduler/mod.rs index b69318ce4..c998ed61b 100644 --- a/crates/prover/src/examples/blake/scheduler/mod.rs +++ b/crates/prover/src/examples/blake/scheduler/mod.rs @@ -16,10 +16,12 @@ pub type BlakeSchedulerComponent = FrameworkComponent; relation!(BlakeElements, N_ROUND_INPUT_FELTS); +#[allow(dead_code)] pub struct BlakeSchedulerEval { pub log_size: u32, pub blake_lookup_elements: BlakeElements, pub round_lookup_elements: RoundElements, + // TODO(first): validate logup. 
pub total_sum: SecureField, } impl FrameworkEval for BlakeSchedulerEval { diff --git a/crates/prover/src/examples/blake/xor_table/constraints.rs b/crates/prover/src/examples/blake/xor_table/constraints.rs index 4df0a6c63..60fef8bfe 100644 --- a/crates/prover/src/examples/blake/xor_table/constraints.rs +++ b/crates/prover/src/examples/blake/xor_table/constraints.rs @@ -40,43 +40,31 @@ macro_rules! xor_table_eval { 2, )); - let entry_chunks = (0..(1 << (2 * EXPAND_BITS))) - .map(|i| { - let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32); - let multiplicity = self.eval.next_trace_mask(); + for i in (0..(1 << (2 * EXPAND_BITS))) { + let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32); + let multiplicity = self.eval.next_trace_mask(); - let a = al.clone() - + E::F::from(BaseField::from_u32_unchecked( - i << limb_bits::(), - )); - let b = bl.clone() - + E::F::from(BaseField::from_u32_unchecked( - j << limb_bits::(), - )); - let c = cl.clone() - + E::F::from(BaseField::from_u32_unchecked( - (i ^ j) << limb_bits::(), - )); + let a = al.clone() + + E::F::from(BaseField::from_u32_unchecked( + i << limb_bits::(), + )); + let b = bl.clone() + + E::F::from(BaseField::from_u32_unchecked( + j << limb_bits::(), + )); + let c = cl.clone() + + E::F::from(BaseField::from_u32_unchecked( + (i ^ j) << limb_bits::(), + )); - (self.lookup_elements, -multiplicity, [a, b, c]) - }) - .collect_vec(); - - for entry_chunk in entry_chunks.chunks(2) { - self.eval.add_to_relation( - &entry_chunk - .iter() - .map(|(lookup, multiplicity, values)| { - RelationEntry::new( - *lookup, - E::EF::from(multiplicity.clone()), - values, - ) - }) - .collect_vec(), - ); + self.eval.add_to_relation(RelationEntry::new( + self.lookup_elements, + -E::EF::from(multiplicity), + &[a, b, c], + )); } - self.eval.finalize_logup(); + + self.eval.finalize_logup_in_pairs(); self.eval } } diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index 49da86f8a..a1e0362c9 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -66,18 +66,24 @@ impl FrameworkEval for PlonkEval { + (E::F::one() - op) * a_val.clone() * b_val.clone(), ); - eval.add_to_relation(&[ - RelationEntry::new(&self.lookup_elements, E::EF::one(), &[a_wire, a_val]), - RelationEntry::new(&self.lookup_elements, E::EF::one(), &[b_wire, b_val]), - ]); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + E::EF::one(), + &[a_wire, a_val], + )); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + E::EF::one(), + &[b_wire, b_val], + )); - eval.add_to_relation(&[RelationEntry::new( + eval.add_to_relation(RelationEntry::new( &self.lookup_elements, (-mult).into(), &[c_wire, c_val], - )]); + )); - eval.finalize_logup(); + eval.finalize_logup_in_pairs(); eval } } diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 51b671580..481d30fd6 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -186,13 +186,15 @@ pub fn eval_poseidon_constraints(eval: &mut E, lookup_elements: &P }); // Provide state lookups. 
- eval.add_to_relation(&[ - RelationEntry::new(lookup_elements, E::EF::one(), &initial_state), - RelationEntry::new(lookup_elements, -E::EF::one(), &state), - ]) + eval.add_to_relation(RelationEntry::new( + lookup_elements, + E::EF::one(), + &initial_state, + )); + eval.add_to_relation(RelationEntry::new(lookup_elements, -E::EF::one(), &state)); } - eval.finalize_logup(); + eval.finalize_logup_in_pairs(); } pub struct LookupData { @@ -452,7 +454,7 @@ mod tests { for i in 0..16 { internal_matrix[i][i] += BaseField::from_u32_unchecked(1 << (i + 1)); } - let matrix = RowMajorMatrix::::new(internal_matrix.flatten().to_vec()); + let matrix = RowMajorMatrix::::new(internal_matrix.as_flattened().to_vec()); let expected_state = matrix.mul(state); apply_internal_round_matrix(&mut state); diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs index d5ea6ce4c..888190399 100644 --- a/crates/prover/src/examples/state_machine/components.rs +++ b/crates/prover/src/examples/state_machine/components.rs @@ -1,17 +1,22 @@ use num_traits::{One, Zero}; use crate::constraint_framework::logup::ClaimedPrefixSum; +use crate::constraint_framework::relation_tracker::{ + RelationTrackerComponent, RelationTrackerEntry, +}; use crate::constraint_framework::{ relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, RelationEntry, - PREPROCESSED_TRACE_IDX, + TraceLocationAllocator, PREPROCESSED_TRACE_IDX, }; use crate::core::air::{Component, ComponentProver}; use crate::core::backend::simd::SimdBackend; use crate::core::backend::CpuBackend; use crate::core::channel::Channel; -use crate::core::fields::m31::M31; +use crate::core::fields::m31::{BaseField, M31}; use crate::core::fields::qm31::{SecureField, QM31}; use crate::core::pcs::TreeVec; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; use crate::core::prover::StarkProof; use crate::core::vcs::ops::MerkleHasher; @@ -48,12 +53,18 @@ impl FrameworkEval for StateTransitionEval let mut output_state = input_state.clone(); output_state[COORDINATE] += E::F::one(); - eval.add_to_relation(&[ - RelationEntry::new(&self.lookup_elements, E::EF::one(), &input_state), - RelationEntry::new(&self.lookup_elements, -E::EF::one(), &output_state), - ]); - - eval.finalize_logup(); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + E::EF::one(), + &input_state, + )); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + -E::EF::one(), + &output_state, + )); + + eval.finalize_logup_in_pairs(); eval } } @@ -143,6 +154,45 @@ impl StateMachineComponents { } } +pub fn track_state_machine_relations( + trace: &TreeVec<&Vec>>, + x_axis_log_n_rows: u32, + y_axis_log_n_rows: u32, + n_rows_x: u32, + n_rows_y: u32, +) -> Vec { + let tree_span_provider = &mut TraceLocationAllocator::default(); + let mut entries = vec![]; + entries.extend( + RelationTrackerComponent::new( + tree_span_provider, + StateTransitionEval::<0> { + log_n_rows: x_axis_log_n_rows, + lookup_elements: StateMachineElements::dummy(), + total_sum: QM31::zero(), + claimed_sum: (QM31::zero(), 0), + }, + n_rows_x as usize, + ) + .entries(&trace.into()), + ); + entries.extend( + RelationTrackerComponent::new( + tree_span_provider, + StateTransitionEval::<1> { + log_n_rows: y_axis_log_n_rows, + lookup_elements: StateMachineElements::dummy(), + total_sum: QM31::zero(), + claimed_sum: (QM31::zero(), 0), + }, + n_rows_y as usize, + ) + .entries(&trace.into()), + ); + + entries +} + 
pub struct StateMachineProof { pub public_input: [State; 2], // Initial and final state. pub stmt0: StateMachineStatement0, diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 558cb3b58..6eeff552c 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -1,12 +1,13 @@ +use crate::constraint_framework::relation_tracker::RelationSummary; use crate::constraint_framework::Relation; pub mod components; pub mod gen; use components::{ - State, StateMachineComponents, StateMachineElements, StateMachineOp0Component, - StateMachineOp1Component, StateMachineProof, StateMachineStatement0, StateMachineStatement1, - StateTransitionEval, + track_state_machine_relations, State, StateMachineComponents, StateMachineElements, + StateMachineOp0Component, StateMachineOp1Component, StateMachineProof, StateMachineStatement0, + StateMachineStatement1, StateTransitionEval, }; use gen::{gen_interaction_trace, gen_trace}; use itertools::{chain, Itertools}; @@ -20,7 +21,7 @@ use crate::core::backend::simd::SimdBackend; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::M31; use crate::core::fields::qm31::QM31; -use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; +use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; use crate::core::poly::circle::{CanonicCoset, PolyOps}; use crate::core::prover::{prove, verify, VerificationError}; use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; @@ -31,9 +32,11 @@ pub fn prove_state_machine( initial_state: State, config: PcsConfig, channel: &mut Blake2sChannel, + track_relations: bool, ) -> ( StateMachineComponents, StateMachineProof, + Option, ) { let (x_axis_log_rows, y_axis_log_rows) = (log_n_rows, log_n_rows - 1); let (x_row, y_row) = (34, 56); @@ -63,14 +66,32 @@ pub fn prove_state_machine( ]; // Preprocessed trace. - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(gen_preprocessed_columns(preprocessed_columns.iter())); - tree_builder.commit(channel); + let preprocessed_trace = gen_preprocessed_columns(preprocessed_columns.iter()); // Trace. let trace_op0 = gen_trace(x_axis_log_rows, initial_state, 0); let trace_op1 = gen_trace(y_axis_log_rows, intermediate_state, 1); + let trace = chain![trace_op0.clone(), trace_op1.clone()].collect_vec(); + + let relation_summary = match track_relations { + false => None, + true => Some(RelationSummary::summarize_relations( + &track_state_machine_relations( + &TreeVec(vec![&preprocessed_trace, &trace]), + x_axis_log_rows, + y_axis_log_rows, + x_row, + y_row, + ), + )), + }; + + // Commitments. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(preprocessed_trace); + tree_builder.commit(channel); + let stmt0 = StateMachineStatement0 { n: x_axis_log_rows, m: y_axis_log_rows, @@ -136,7 +157,7 @@ pub fn prove_state_machine( stmt1, stark_proof, }; - (components, proof) + (components, proof, relation_summary) } pub fn verify_state_machine( @@ -546,7 +567,8 @@ mod tests { // Setup protocol. let channel = &mut Blake2sChannel::default(); - let (component, _) = prove_state_machine(log_n_rows, initial_state, config, channel); + let (component, ..) 
= + prove_state_machine(log_n_rows, initial_state, config, channel, false); let interaction_elements = component.component0.lookup_elements.clone(); let initial_state_comb: QM31 = interaction_elements.combine(&initial_state); @@ -558,6 +580,42 @@ mod tests { ); } + #[test] + fn test_relation_tracker() { + let log_n_rows = 8; + let config = PcsConfig::default(); + let initial_state = [M31::zero(); STATE_SIZE]; + let final_state = [M31::from_u32_unchecked(34), M31::from_u32_unchecked(56)]; + + // Summarize `StateMachineElements`. + let (_, _, summary) = prove_state_machine( + log_n_rows, + initial_state, + config, + &mut Blake2sChannel::default(), + true, + ); + let summary = summary.unwrap(); + let relation_info = summary.get_relation_info("StateMachineElements").unwrap(); + + // Check the final state inferred from the summary. + let mut curr_state = initial_state; + for entry in relation_info { + let (x_step, y_step) = match entry.0.len() { + 2 => (entry.0[0], entry.0[1]), + 1 => (entry.0[0], M31::zero()), + 0 => (M31::zero(), M31::zero()), + _ => unreachable!(), + }; + let mult = entry.1; + let next_state = [curr_state[0] - x_step * mult, curr_state[1] - y_step * mult]; + + curr_state = next_state; + } + + assert_eq!(curr_state, final_state); + } + #[test] fn test_state_machine_prove() { let log_n_rows = get_env_var("TSMP_LOG2", 8u32); @@ -601,38 +659,33 @@ mod tests { ); let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true)); - - assert_eq!(eval.constraints.len(), 2); - let constraint0_str = "(1) \ - * ((SecureCol(\ - col_2_5[claimed_sum_offset], \ - col_2_8[claimed_sum_offset], \ - col_2_11[claimed_sum_offset], \ - col_2_14[claimed_sum_offset]\ - ) - (claimed_sum)) \ - * (col_0_2[0]))"; - assert_eq!(eval.constraints[0].format_expr(), constraint0_str); - let constraint1_str = "(1) \ - * ((SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \ - - (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \ - - ((col_0_2[0]) * (total_sum))) \ - - (0)) \ - * ((0 \ - + (StateMachineElements_alpha0) * (col_1_0[0]) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ - - (StateMachineElements_z)) \ - * (0 + (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ - - (StateMachineElements_z))) \ - - ((0 \ - + (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ - - (StateMachineElements_z)) \ - * (1) \ - + (0 + (StateMachineElements_alpha0) * (col_1_0[0]) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ - - (StateMachineElements_z)) \ - * (-(1))))"; - assert_eq!(eval.constraints[1].format_expr(), constraint1_str); + let expected = "let intermediate0 = (StateMachineElements_alpha0) * (trace_1_column_0_offset_0) \ + + (StateMachineElements_alpha1) * (trace_1_column_1_offset_0) \ + - (StateMachineElements_z); + +\ + let intermediate1 = (StateMachineElements_alpha0) * (trace_1_column_0_offset_0 + m31(1).into()) \ + + (StateMachineElements_alpha1) * (trace_1_column_1_offset_0) \ + - (StateMachineElements_z); + +\ + let constraint_0 = (QM31Impl::from_partial_evals([\ + trace_2_column_2_offset_claimed_sum, \ + trace_2_column_3_offset_claimed_sum, \ + trace_2_column_4_offset_claimed_sum, \ + trace_2_column_5_offset_claimed_sum\ + ]) - (claimed_sum)) \ + * (preprocessed_is_first); + +\ + let constraint_1 = (QM31Impl::from_partial_evals([trace_2_column_2_offset_0, trace_2_column_3_offset_0, trace_2_column_4_offset_0, trace_2_column_5_offset_0]) \ + - 
(QM31Impl::from_partial_evals([trace_2_column_2_offset_neg_1, trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1]) \ + - ((total_sum) * (preprocessed_is_first)))\ + ) \ + * ((intermediate0) * (intermediate1)) \ + - (intermediate1 - (intermediate0));" + .to_string(); + + assert_eq!(eval.format_constraints(), expected); } } diff --git a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs b/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs index 986289572..8e0ae2d74 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs @@ -4,13 +4,13 @@ use std::ops::{AddAssign, Mul}; use educe::Educe; use num_traits::One; +use crate::core::air::accumulation::AccumulationOps; use crate::core::backend::simd::SimdBackend; use crate::core::backend::Backend; use crate::core::circle::M31_CIRCLE_LOG_ORDER; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::lookups::mle::Mle; -use crate::core::utils::generate_secure_powers; pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1; @@ -59,7 +59,8 @@ fn mle_random_linear_combination( assert!(!mles.is_empty()); let n_variables = mles[0].n_variables(); assert!(mles.iter().all(|mle| mle.n_variables() == n_variables)); - let coeff_powers = generate_secure_powers(random_coeff, mles.len()); + let coeff_powers = + ::generate_secure_powers(random_coeff, mles.len()); let mut mle_and_coeff = zip(mles, coeff_powers.into_iter().rev()); // The last value can initialize the accumulator. diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 69acedccb..3c0c17d0c 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -14,6 +14,7 @@ use crate::constraint_framework::{ }; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::{Component, ComponentProver, Trace}; +use crate::core::backend::cpu::bit_reverse; use crate::core::backend::simd::column::{SecureColumn, VeryPackedSecureColumnByCoords}; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; @@ -36,7 +37,7 @@ use crate::core::poly::circle::{ }; use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::BitReversedOrder; -use crate::core::utils::{self, bit_reverse_index, coset_index_to_circle_domain_index}; +use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; use crate::core::ColumnVec; /// Prover component that carries out a univariate IOP for multilinear eval at point. 
@@ -114,9 +115,7 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> MleEvalProverComponent<'twiddl } } -impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> Component - for MleEvalProverComponent<'twiddles, 'oracle, O> -{ +impl Component for MleEvalProverComponent<'_, '_, O> { fn n_constraints(&self) -> usize { self.eval_info().n_constraints } @@ -190,9 +189,7 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> Component } } -impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> ComponentProver - for MleEvalProverComponent<'twiddles, 'oracle, O> -{ +impl ComponentProver for MleEvalProverComponent<'_, '_, O> { fn evaluate_constraint_quotients_on_domain( &self, trace: &Trace<'_, SimdBackend>, @@ -231,7 +228,7 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> ComponentProver let mut denom_inv = (0..1 << log_expand) .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) .collect_vec(); - utils::bit_reverse(&mut denom_inv); + bit_reverse(&mut denom_inv); // Accumulator. let [mut acc] = accumulator.columns([(eval_domain.log_size(), self.n_constraints())]); @@ -329,7 +326,7 @@ impl<'oracle, O: MleCoeffColumnOracle> MleEvalVerifierComponent<'oracle, O> { } } -impl<'oracle, O: MleCoeffColumnOracle> Component for MleEvalVerifierComponent<'oracle, O> { +impl Component for MleEvalVerifierComponent<'_, O> { fn n_constraints(&self) -> usize { self.eval_info().n_constraints } @@ -752,6 +749,7 @@ mod tests { }; use crate::constraint_framework::{assert_constraints, EvalAtRow, TraceLocationAllocator}; use crate::core::air::{Component, ComponentProver, Components}; + use crate::core::backend::cpu::bit_reverse; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; @@ -765,7 +763,7 @@ mod tests { use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; use crate::core::prover::{prove, verify, VerificationError}; - use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; + use crate::core::utils::coset_order_to_circle_domain_order; use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; use crate::examples::xor::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; use crate::examples::xor::gkr_lookups::mle_eval::eval_step_selector_with_offset; @@ -1213,7 +1211,7 @@ mod tests { } impl MleCoeffColumnEval { - pub fn new(interaction: usize, n_variables: usize) -> Self { + pub const fn new(interaction: usize, n_variables: usize) -> Self { Self { interaction, n_variables, diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index 02aa6a4e8..4c6e96b3d 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -1,22 +1,20 @@ #![allow(incomplete_features)] +#![cfg_attr( + all(target_arch = "x86_64", target_feature = "avx512f"), + feature(stdarch_x86_avx512) +)] #![feature( array_chunks, - array_methods, array_try_from_fn, + array_windows, assert_matches, exact_size_is_empty, - generic_const_exprs, get_many_mut, int_roundings, - is_sorted, iter_array_chunks, - new_uninit, portable_simd, - slice_first_last_chunk, - slice_flatten, - slice_group_by, slice_ptr_get, - stdsimd + trait_upcasting )] pub mod constraint_framework; diff --git a/rust-toolchain.toml b/rust-toolchain.toml index a0f1a930e..27381425f 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-01-04" +channel = "nightly-2025-01-02" diff --git a/scripts/clippy.sh 
b/scripts/clippy.sh index 8361cd25d..9198c8648 100755 --- a/scripts/clippy.sh +++ b/scripts/clippy.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly-2024-01-04 clippy "$@" --all-targets --all-features -- -D warnings -D future-incompatible \ +cargo +nightly-2025-01-02 clippy "$@" --all-targets --all-features -- -D warnings -D future-incompatible \ -D nonstandard-style -D rust-2018-idioms -D unused diff --git a/scripts/rust_fmt.sh b/scripts/rust_fmt.sh index e4223f999..9f80b191c 100755 --- a/scripts/rust_fmt.sh +++ b/scripts/rust_fmt.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly-2024-01-04 fmt --all -- "$@" +cargo +nightly-2025-01-02 fmt --all -- "$@" diff --git a/scripts/test_avx.sh b/scripts/test_avx.sh index d911a2479..eb4429d3a 100755 --- a/scripts/test_avx.sh +++ b/scripts/test_avx.sh @@ -1,4 +1,4 @@ #!/bin/bash # Can be used as a drop in replacement for `cargo test` with avx512f flag on. # For example, `./scripts/test_avx.sh` will run all tests(not only avx). -RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo +nightly-2024-01-04 test "$@" +RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo +nightly-2025-01-02 test "$@" From 033ab83c4fd288ccbdee33a5c60eb3c3593d64a2 Mon Sep 17 00:00:00 2001 From: VitaliiH Date: Mon, 13 Jan 2025 16:18:51 +0200 Subject: [PATCH 62/69] minor cleanup --- crates/prover/src/core/backend/icicle/mod.rs | 4 ---- crates/prover/src/examples/wide_fibonacci/mod.rs | 16 ++++++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/crates/prover/src/core/backend/icicle/mod.rs b/crates/prover/src/core/backend/icicle/mod.rs index e988252e5..e15bc8313 100644 --- a/crates/prover/src/core/backend/icicle/mod.rs +++ b/crates/prover/src/core/backend/icicle/mod.rs @@ -585,10 +585,6 @@ impl FriOps for IcicleBackend { let icicle_alpha = unsafe { transmute(alpha) }; nvtx::range_push!("[ICICLE] fold circle"); - println!( - "index {}, half_log_size {}", - domain.half_coset.initial_index.0, domain.half_coset.log_size - ); let _ = fold_circle_into_line_new( &d_evals_icicle[..], domain.half_coset.initial_index.0 as _, diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index af7d5a35f..cb3578127 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -175,7 +175,7 @@ mod tests { use crate::examples::utils::get_env_var; let min_log = get_env_var("MIN_FIB_LOG", 2u32); - let max_log = get_env_var("MAX_FIB_LOG", 18u32); + let max_log = get_env_var("MAX_FIB_LOG", 25u32); for log_n_instances in min_log..=max_log { let config = PcsConfig::default(); @@ -211,12 +211,18 @@ mod tests { (SecureField::zero(), None), ); + let start = std::time::Instant::now(); let proof = prove::( &[&component], prover_channel, commitment_scheme, ) .unwrap(); + println!( + "SIMD proving for 2^{:?} took {:?} ms", + log_n_instances, + start.elapsed().as_millis() + ); // Verify. let verifier_channel = &mut Blake2sChannel::default(); @@ -255,10 +261,8 @@ mod tests { // Setup protocol. 
let prover_channel = &mut Blake2sChannel::default(); - let mut commitment_scheme = CommitmentSchemeProver::< - TheBackend, - Blake2sMerkleChannel, - >::new(config, &twiddles); + let mut commitment_scheme = + CommitmentSchemeProver::::new(config, &twiddles); // Preprocessed trace let mut tree_builder = commitment_scheme.tree_builder(); @@ -286,7 +290,7 @@ mod tests { ); icicle_m31::fri::precompute_fri_twiddles(log_n_instances).unwrap(); - println!("++++++++ proving for 2^{:?}", log_n_instances); + let start = std::time::Instant::now(); let proof = prove::( &[&component], From 603499b721c5790b7df16a016675315fef144ff1 Mon Sep 17 00:00:00 2001 From: VitaliiH Date: Tue, 14 Jan 2025 05:30:05 +0100 Subject: [PATCH 63/69] fix fold line with domain calc --- crates/prover/src/core/backend/icicle/mod.rs | 26 ++++---------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/crates/prover/src/core/backend/icicle/mod.rs b/crates/prover/src/core/backend/icicle/mod.rs index 6364768a5..b8ee57b79 100644 --- a/crates/prover/src/core/backend/icicle/mod.rs +++ b/crates/prover/src/core/backend/icicle/mod.rs @@ -503,26 +503,8 @@ impl FriOps for IcicleBackend { let dom_vals_len = length / 2; - let mut domain_vals = Vec::new(); let line_domain_log_size = domain.log_size(); - nvtx::range_push!("[ICICLE] calc domain values"); - for i in 0..dom_vals_len { - // TODO: on-device batch - // TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer. - domain_vals.push(ScalarField::from_u32( - domain - .at(bit_reverse_index(i << FOLD_STEP, line_domain_log_size)) - .inverse() - .0, - )); - } - nvtx::range_pop!(); - - nvtx::range_push!("[ICICLE] domain values to device"); - let domain_icicle_host = HostSlice::from_slice(domain_vals.as_slice()); - let mut d_domain_icicle = DeviceVec::::cuda_malloc(dom_vals_len).unwrap(); - d_domain_icicle.copy_from_host(domain_icicle_host).unwrap(); - nvtx::range_pop!(); + nvtx::range_push!("[ICICLE] domain evals convert + move"); let mut d_evals_icicle = DeviceVec::::cuda_malloc(length).unwrap(); SecureColumnByCoords::::convert_to_icicle( @@ -536,9 +518,11 @@ impl FriOps for IcicleBackend { let cfg = FriConfig::default(); let icicle_alpha = unsafe { transmute(alpha) }; nvtx::range_push!("[ICICLE] fold_line"); - let _ = fri::fold_line( + let _ = fri::fold_line_new( &d_evals_icicle[..], - &d_domain_icicle[..], + // &d_domain_icicle[..], + domain.coset().initial_index.0 as _, + domain.coset().log_size, &mut d_folded_eval[..], icicle_alpha, &cfg, From a76c4c059d5348690467bae190e1d1d2a0732413 Mon Sep 17 00:00:00 2001 From: VitaliiH Date: Tue, 14 Jan 2025 06:37:10 +0200 Subject: [PATCH 64/69] benches added --- crates/prover/benches/eval_at_point.rs | 2 ++ crates/prover/benches/fri.rs | 2 +- crates/prover/benches/quotients.rs | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/crates/prover/benches/eval_at_point.rs b/crates/prover/benches/eval_at_point.rs index 64d1eecc7..82a4debdd 100644 --- a/crates/prover/benches/eval_at_point.rs +++ b/crates/prover/benches/eval_at_point.rs @@ -2,6 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; use stwo_prover::core::backend::cpu::CpuBackend; +use stwo_prover::core::backend::icicle::IcicleBackend; use stwo_prover::core::backend::simd::SimdBackend; use stwo_prover::core::circle::CirclePoint; use stwo_prover::core::fields::m31::BaseField; @@ -26,6 +27,7 @@ fn bench_eval_at_secure_point(c: &mut Criterion, id: &str) { fn 
eval_at_secure_point_benches(c: &mut Criterion) { bench_eval_at_secure_point::(c, "simd"); bench_eval_at_secure_point::(c, "cpu"); + bench_eval_at_secure_point::(c, "icicle"); } criterion_group!( diff --git a/crates/prover/benches/fri.rs b/crates/prover/benches/fri.rs index f20fb4674..7bbb8cb49 100644 --- a/crates/prover/benches/fri.rs +++ b/crates/prover/benches/fri.rs @@ -137,7 +137,7 @@ fn icicle_raw_folding_benchmark(c: &mut Criterion) { use icicle_cuda_runtime::memory::{DeviceVec, HostOrDeviceSlice, HostSlice}; use icicle_m31::field::{QuarticExtensionField, ScalarField}; use icicle_m31::fri::{self, fold_circle_into_line, FriConfig}; - use stwo_prover::core::fields::FieldExpOps; + // use stwo_prover::core::fields::FieldExpOps; use stwo_prover::core::fri::{CIRCLE_TO_LINE_FOLD_STEP, FOLD_STEP}; use stwo_prover::core::poly::BitReversedOrder; use stwo_prover::core::utils::bit_reverse_index; diff --git a/crates/prover/benches/quotients.rs b/crates/prover/benches/quotients.rs index e1b592d9c..07eceaca6 100644 --- a/crates/prover/benches/quotients.rs +++ b/crates/prover/benches/quotients.rs @@ -49,7 +49,7 @@ fn quotients_benches(c: &mut Criterion) { bench_quotients::(c, "simd"); #[cfg(feature = "icicle")] bench_quotients::(c, "icicle"); - bench_quotients::(c, "cpu"); + bench_quotients::(c, "cpu"); } criterion_group!( From 81dc838553a5d22e04d462e39878a2f5e022eb64 Mon Sep 17 00:00:00 2001 From: Jeremy Felder Date: Tue, 21 Jan 2025 10:10:45 +0000 Subject: [PATCH 65/69] Remove unnecessary device mem operations. Fix profile ranges from additional pop --- crates/prover/src/core/backend/icicle/mod.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/crates/prover/src/core/backend/icicle/mod.rs b/crates/prover/src/core/backend/icicle/mod.rs index b8ee57b79..7b2b9fe68 100644 --- a/crates/prover/src/core/backend/icicle/mod.rs +++ b/crates/prover/src/core/backend/icicle/mod.rs @@ -559,7 +559,6 @@ impl FriOps for IcicleBackend { let mut d_evals_icicle = DeviceVec::::cuda_malloc(length).unwrap(); SecureColumnByCoords::convert_to_icicle(&src.values, &mut d_evals_icicle); - nvtx::range_pop!(); nvtx::range_push!("[ICICLE] d_folded_evals"); let mut d_folded_eval = @@ -567,9 +566,6 @@ impl FriOps for IcicleBackend { SecureColumnByCoords::convert_to_icicle(&dst.values, &mut d_folded_eval); nvtx::range_pop!(); - let mut folded_eval_raw = vec![QuarticExtensionField::zero(); dom_vals_len]; - let folded_eval = HostSlice::from_mut_slice(folded_eval_raw.as_mut_slice()); - let cfg = FriConfig::default(); let icicle_alpha = unsafe { transmute(alpha) }; @@ -585,8 +581,6 @@ impl FriOps for IcicleBackend { .unwrap(); nvtx::range_pop!(); - d_folded_eval.copy_to_host(folded_eval).unwrap(); - nvtx::range_push!("[ICICLE] convert to SecureColumnByCoords"); SecureColumnByCoords::convert_from_icicle_q31(&mut dst.values, &mut d_folded_eval[..]); nvtx::range_pop!(); From 743a8c150f4673e5b440f13f0a21a495b9eadb8a Mon Sep 17 00:00:00 2001 From: VitaliiH Date: Tue, 21 Jan 2025 14:54:42 +0100 Subject: [PATCH 66/69] fix compute poly to use icicle --- .../src/constraint_framework/component.rs | 5 +- .../src/constraint_framework/icicle_domain.rs | 99 +++++++++++++++++++ crates/prover/src/constraint_framework/mod.rs | 3 + 3 files changed, 105 insertions(+), 2 deletions(-) create mode 100644 crates/prover/src/constraint_framework/icicle_domain.rs diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index caae6fec6..d501c2b3a 100644 --- 
a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -15,6 +15,7 @@ use super::preprocessed_columns::PreprocessedColumn; use super::{ EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX, }; +use crate::constraint_framework::icicle_domain::IcicleDomainEvaluator; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::{Component, ComponentProver, Trace}; use crate::core::backend::cpu::bit_reverse; @@ -602,8 +603,8 @@ impl ComponentProver for FrameworkCompon }); // Evaluate constrains at row. - let eval = CpuDomainEvaluator::new( - unsafe { std::mem::transmute(&trace_cols) }, + let eval = IcicleDomainEvaluator::new( + &trace_cols, row, &accum.random_coeff_powers, trace_domain.log_size(), diff --git a/crates/prover/src/constraint_framework/icicle_domain.rs b/crates/prover/src/constraint_framework/icicle_domain.rs new file mode 100644 index 000000000..3496b180b --- /dev/null +++ b/crates/prover/src/constraint_framework/icicle_domain.rs @@ -0,0 +1,99 @@ +use std::ops::Mul; + +use num_traits::Zero; + +use super::logup::{LogupAtRow, LogupSums}; +use super::{EvalAtRow, INTERACTION_TRACE_IDX}; +use crate::core::backend::icicle::IcicleBackend; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::lookups::utils::Fraction; +use crate::core::pcs::TreeVec; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::offset_bit_reversed_circle_domain_index; + +/// Evaluates constraints at an evaluation domain points. +pub struct IcicleDomainEvaluator<'a> { + pub trace_eval: &'a TreeVec>>, + pub column_index_per_interaction: Vec, + pub row: usize, + pub random_coeff_powers: &'a [SecureField], + pub row_res: SecureField, + pub constraint_index: usize, + pub domain_log_size: u32, + pub eval_domain_log_size: u32, + pub logup: LogupAtRow, +} + +impl<'a> IcicleDomainEvaluator<'a> { + #[allow(dead_code)] + pub fn new( + trace_eval: &'a TreeVec>>, + row: usize, + random_coeff_powers: &'a [SecureField], + domain_log_size: u32, + eval_log_size: u32, + log_size: u32, + logup_sums: LogupSums, + ) -> Self { + Self { + trace_eval, + column_index_per_interaction: vec![0; trace_eval.len()], + row, + random_coeff_powers, + row_res: SecureField::zero(), + constraint_index: 0, + domain_log_size, + eval_domain_log_size: eval_log_size, + logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size), + } + } +} + +impl EvalAtRow for IcicleDomainEvaluator<'_> { + type F = BaseField; + type EF = SecureField; + + // TODO(spapini): Remove all boundary checks. + fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N] { + let col_index = self.column_index_per_interaction[interaction]; + self.column_index_per_interaction[interaction] += 1; + offsets.map(|off| { + // If the offset is 0, we can just return the value directly from this row. + if off == 0 { + let col = &self.trace_eval[interaction][col_index]; + return col[self.row]; + } + // Otherwise, we need to look up the value at the offset. + // Since the domain is bit-reversed circle domain ordered, we need to look up the value + // at the bit-reversed natural order index at an offset. 
+            let row = offset_bit_reversed_circle_domain_index(
+                self.row,
+                self.domain_log_size,
+                self.eval_domain_log_size,
+                off,
+            );
+            self.trace_eval[interaction][col_index][row]
+        })
+    }
+
+    fn add_constraint(&mut self, constraint: G)
+    where
+        Self::EF: Mul + From,
+    {
+        self.row_res += self.random_coeff_powers[self.constraint_index] * constraint;
+        self.constraint_index += 1;
+    }
+
+    fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
+        SecureField::from_m31_array(values)
+    }
+
+    super::logup_proxy!();
+}
diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs
index 22809152d..b7f07f575 100644
--- a/crates/prover/src/constraint_framework/mod.rs
+++ b/crates/prover/src/constraint_framework/mod.rs
@@ -10,6 +10,9 @@ pub mod preprocessed_columns;
 pub mod relation_tracker;
 mod simd_domain;
+#[cfg(feature = "icicle")]
+mod icicle_domain;
+
 use std::array;
 use std::fmt::Debug;
 use std::ops::{Add, AddAssign, Mul, Neg, Sub};
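
The evaluator above never materializes a shifted column: for a non-zero offset it maps the current row back to natural order, applies the offset there, and maps the result into the bit-reversed storage order through `offset_bit_reversed_circle_domain_index`. A self-contained toy of that round trip over a plain bit-reversed array (it ignores the circle-domain wrap-around handling of the real helper):

/// Reverses the lowest `log_size` bits of `i` (toy version of the prover's bit_reverse_index).
fn bit_reverse_index(i: usize, log_size: u32) -> usize {
    i.reverse_bits() >> (usize::BITS - log_size)
}

fn main() {
    let log_size = 3;
    let n: usize = 1 << log_size;
    // Values stored in bit-reversed order: position p holds the value bit_reverse(p).
    let values: Vec<usize> = (0..n).map(|p| bit_reverse_index(p, log_size)).collect();

    // Look up the neighbor at natural-order offset +1 of the row stored at position 3.
    let natural = bit_reverse_index(3, log_size);          // natural index of the stored row
    let neighbor = (natural + 1) % n;                       // apply the offset in natural order
    let position = bit_reverse_index(neighbor, log_size);   // map back to storage order
    assert_eq!(values[position], neighbor);
}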
From 77abeb88135869d869e2d5c81acb1202c6d34c65 Mon Sep 17 00:00:00 2001
From: Jeremy Felder
Date: Mon, 27 Jan 2025 13:38:20 +0000
Subject: [PATCH 67/69] commit_layer for icicle backend

---
 Cargo.lock                                    |  8 +-
 crates/prover/Cargo.toml                      |  8 +-
 crates/prover/src/core/backend/cpu/blake2s.rs |  2 -
 .../src/core/backend/cpu/poseidon252.rs       |  2 -
 crates/prover/src/core/backend/icicle/mod.rs  | 97 ++++++------------
 .../prover/src/core/backend/simd/blake2s.rs   |  2 -
 .../src/core/backend/simd/poseidon252.rs      |  2 -
 crates/prover/src/core/vcs/ops.rs             |  5 -
 crates/prover/src/core/vcs/prover.rs          | 44 +++------
 9 files changed, 56 insertions(+), 114 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 4e715e06b..f7f6a5826 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -754,7 +754,7 @@ dependencies = [
 [[package]]
 name = "icicle-core"
 version = "2.8.0"
-source = "git+https://github.com/ingonyama-zk/icicle.git?rev=7e7c1c8d96af1963df94c9ab6e7fdc37176e9543#7e7c1c8d96af1963df94c9ab6e7fdc37176e9543"
+source = "git+https://github.com/ingonyama-zk/icicle.git?rev=9d33e77dee08e4146da4da7b1bc6d5c77a531159#9d33e77dee08e4146da4da7b1bc6d5c77a531159"
 dependencies = [
  "criterion 0.3.6",
  "hex",
@@ -765,7 +765,7 @@ dependencies = [
 [[package]]
 name = "icicle-cuda-runtime"
 version = "2.8.0"
-source = "git+https://github.com/ingonyama-zk/icicle.git?rev=7e7c1c8d96af1963df94c9ab6e7fdc37176e9543#7e7c1c8d96af1963df94c9ab6e7fdc37176e9543"
+source = "git+https://github.com/ingonyama-zk/icicle.git?rev=9d33e77dee08e4146da4da7b1bc6d5c77a531159#9d33e77dee08e4146da4da7b1bc6d5c77a531159"
 dependencies = [
  "bindgen",
  "bitflags 1.3.2",
@@ -774,7 +774,7 @@ dependencies = [
 [[package]]
 name = "icicle-hash"
 version = "2.8.0"
-source = "git+https://github.com/ingonyama-zk/icicle.git?rev=7e7c1c8d96af1963df94c9ab6e7fdc37176e9543#7e7c1c8d96af1963df94c9ab6e7fdc37176e9543"
+source = "git+https://github.com/ingonyama-zk/icicle.git?rev=9d33e77dee08e4146da4da7b1bc6d5c77a531159#9d33e77dee08e4146da4da7b1bc6d5c77a531159"
 dependencies = [
  "cmake",
  "icicle-core",
@@ -784,7 +784,7 @@ dependencies = [
 [[package]]
 name = "icicle-m31"
 version = "2.8.0"
-source = "git+https://github.com/ingonyama-zk/icicle.git?rev=7e7c1c8d96af1963df94c9ab6e7fdc37176e9543#7e7c1c8d96af1963df94c9ab6e7fdc37176e9543"
+source = "git+https://github.com/ingonyama-zk/icicle.git?rev=9d33e77dee08e4146da4da7b1bc6d5c77a531159#9d33e77dee08e4146da4da7b1bc6d5c77a531159"
 dependencies = [
  "cmake",
  "criterion 0.3.6",
diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml
index dac1f7d2a..6c8223147 100644
--- a/crates/prover/Cargo.toml
+++ b/crates/prover/Cargo.toml
@@ -28,10 +28,10 @@ tracing.workspace = true
 rayon = { version = "1.10.0", optional = true }
 serde = { version = "1.0", features = ["derive"] }
-icicle-cuda-runtime = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="7e7c1c8d96af1963df94c9ab6e7fdc37176e9543"}
-icicle-core = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="7e7c1c8d96af1963df94c9ab6e7fdc37176e9543"}
-icicle-m31 = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="7e7c1c8d96af1963df94c9ab6e7fdc37176e9543"}
-icicle-hash = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="7e7c1c8d96af1963df94c9ab6e7fdc37176e9543"}
+icicle-cuda-runtime = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="9d33e77dee08e4146da4da7b1bc6d5c77a531159"}
+icicle-core = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="9d33e77dee08e4146da4da7b1bc6d5c77a531159"}
+icicle-m31 = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="9d33e77dee08e4146da4da7b1bc6d5c77a531159"}
+icicle-hash = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="9d33e77dee08e4146da4da7b1bc6d5c77a531159"}
 nvtx = { version = "*", optional = true }
diff --git a/crates/prover/src/core/backend/cpu/blake2s.rs b/crates/prover/src/core/backend/cpu/blake2s.rs
index a4ea918e1..a87a5ae00 100644
--- a/crates/prover/src/core/backend/cpu/blake2s.rs
+++ b/crates/prover/src/core/backend/cpu/blake2s.rs
@@ -7,8 +7,6 @@ use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
 use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
 impl MerkleOps for CpuBackend {
-    const COMMIT_IMPLEMENTED: bool = false;
-
     fn commit_on_layer(
         log_size: u32,
         prev_layer: Option<&Vec>,
diff --git a/crates/prover/src/core/backend/cpu/poseidon252.rs b/crates/prover/src/core/backend/cpu/poseidon252.rs
index f4b2089fc..8cc5dd9d5 100644
--- a/crates/prover/src/core/backend/cpu/poseidon252.rs
+++ b/crates/prover/src/core/backend/cpu/poseidon252.rs
@@ -7,8 +7,6 @@ use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
 use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher;
 impl MerkleOps for CpuBackend {
-    const COMMIT_IMPLEMENTED: bool = false;
-
     fn commit_on_layer(
         log_size: u32,
         prev_layer: Option<&Vec>,
diff --git a/crates/prover/src/core/backend/icicle/mod.rs b/crates/prover/src/core/backend/icicle/mod.rs
index 7b2b9fe68..51380b8b5 100644
--- a/crates/prover/src/core/backend/icicle/mod.rs
+++ b/crates/prover/src/core/backend/icicle/mod.rs
@@ -11,7 +11,7 @@ use icicle_core::field::Field as IcicleField;
 use icicle_core::tree::{merkle_tree_digests_len, TreeBuilderConfig};
 use icicle_core::vec_ops::{accumulate_scalars, VecOpsConfig};
 use icicle_core::Matrix;
-use icicle_hash::blake2s::build_blake2s_mmcs;
+use icicle_hash::blake2s::blake2s_commit_layer;
 use icicle_m31::dcct::{evaluate, get_dcct_root_of_unity, initialize_dcct_domain, interpolate};
 use icicle_m31::field::ScalarCfg;
 use icicle_m31::fri::{self, fold_circle_into_line, fold_circle_into_line_new, FriConfig};
@@ -171,73 +171,42 @@ impl AccumulationOps for IcicleBackend {
 // stwo/crates/prover/src/core/backend/cpu/blake2s.rs
 impl MerkleOps for IcicleBackend {
-    const COMMIT_IMPLEMENTED: bool = true;
-
-    fn commit_columns(
-        columns: Vec<&Col>,
-    ) -> Vec::Hash>> {
-        let mut config = TreeBuilderConfig::default();
-        config.arity = 2;
-        config.digest_elements = 32;
-        config.sort_inputs = false;
-
-        nvtx::range_push!("[ICICLE] log_max");
-        let log_max = columns
-            .iter()
-            .sorted_by_key(|c| Reverse(c.len()))
-            .next()
-            .unwrap()
-            .len()
-            .ilog2();
-        nvtx::range_pop!();
-        let mut matrices = vec![];
-        nvtx::range_push!("[ICICLE] create matrix");
-        for col in columns.into_iter().sorted_by_key(|c| Reverse(c.len())) {
-            matrices.push(Matrix::from_slice(col, 4, col.len()));
-        }
-        nvtx::range_pop!();
-        nvtx::range_push!("[ICICLE] merkle_tree_digests_len");
-        let digests_len = merkle_tree_digests_len(log_max as u32, 2, 32);
-        nvtx::range_pop!();
-        let mut digests = vec![0u8; digests_len];
-        let digests_slice = HostSlice::from_mut_slice(&mut digests);
-        nvtx::range_push!("[ICICLE] build_blake2s_mmcs");
-        build_blake2s_mmcs(&matrices, digests_slice, &config).unwrap();
-        nvtx::range_pop!();
-
-        let mut digests: &[::Hash] =
-            unsafe { std::mem::transmute(digests.as_mut_slice()) };
-        // Transmute digests into stwo format
-        let mut layers = vec![];
-        let mut offset = 0usize;
-        nvtx::range_push!("[ICICLE] convert to CPU layer");
-        for log in 0..=log_max {
-            let inv_log = log_max - log;
-            let number_of_rows = 1 << inv_log;
-
-            let mut layer = vec![];
-            layer.extend_from_slice(&digests[offset..offset + number_of_rows]);
-            layers.push(layer);
-
-            if log != log_max {
-                offset += number_of_rows;
-            }
-        }
-
-        layers.reverse();
-        nvtx::range_pop!();
-        layers
-    }
-
     fn commit_on_layer(
         log_size: u32,
         prev_layer: Option<&Col::Hash>>,
         columns: &[&Col],
     ) -> Col::Hash> {
-        // todo!()
-        >::commit_on_layer(
-            log_size, prev_layer, columns,
-        )
+        let prev_layer = match prev_layer {
+            Some(layer) => unsafe{ transmute(layer) },
+            None => &Vec::::with_capacity(0),
+        };
+
+        let mut columns_as_matrices = vec![];
+        for &col in columns {
+            columns_as_matrices.push(Matrix::from_slice(col, 4, col.len()));
+        }
+
+        let digest_bytes = (1 << log_size) * 32;
+        let mut d_digests_slice = DeviceVec::cuda_malloc(digest_bytes).unwrap();
+
+        blake2s_commit_layer(
+            HostSlice::from_slice(prev_layer),
+            false,
+            &columns_as_matrices,
+            false,
+            columns.len() as u32,
+            1 << log_size,
+            &mut d_digests_slice[..],
+        ).unwrap();
+
+        let mut digests = vec![0u8; digest_bytes];
+        let mut digests_slice = HostSlice::from_mut_slice(&mut digests);
+
+        d_digests_slice
+            .copy_to_host(&mut digests_slice)
+            .unwrap();
+
+        unsafe { std::mem::transmute(digests) }
     }
 }
@@ -695,8 +664,6 @@ impl QuotientOps for IcicleBackend {
 // stwo/crates/prover/src/core/vcs/poseidon252_merkle.rs
 impl MerkleOps for IcicleBackend {
-    const COMMIT_IMPLEMENTED: bool = false;
-
     fn commit_on_layer(
         log_size: u32,
         prev_layer: Option<&Col::Hash>>,
diff --git a/crates/prover/src/core/backend/simd/blake2s.rs b/crates/prover/src/core/backend/simd/blake2s.rs
index b499efa8f..3f4d46b8f 100644
--- a/crates/prover/src/core/backend/simd/blake2s.rs
+++ b/crates/prover/src/core/backend/simd/blake2s.rs
@@ -46,8 +46,6 @@ impl ColumnOps for SimdBackend {
 }
 impl MerkleOps for SimdBackend {
-    const COMMIT_IMPLEMENTED: bool = false;
-
     fn commit_on_layer(
         log_size: u32,
         prev_layer: Option<&Vec>,
diff --git a/crates/prover/src/core/backend/simd/poseidon252.rs b/crates/prover/src/core/backend/simd/poseidon252.rs
index f0f6a210b..b001481b4 100644
--- a/crates/prover/src/core/backend/simd/poseidon252.rs
+++ b/crates/prover/src/core/backend/simd/poseidon252.rs
@@ -18,8 +18,6 @@ impl ColumnOps for SimdBackend {
 }
 impl MerkleOps for SimdBackend {
-    const COMMIT_IMPLEMENTED: bool = false;
-
     // TODO(ShaharS): replace with SIMD implementation.
     fn commit_on_layer(
         log_size: u32,
diff --git a/crates/prover/src/core/vcs/ops.rs b/crates/prover/src/core/vcs/ops.rs
index b55646f14..b40a91bef 100644
--- a/crates/prover/src/core/vcs/ops.rs
+++ b/crates/prover/src/core/vcs/ops.rs
@@ -25,11 +25,6 @@ pub trait MerkleHasher: Debug + Default + Clone {
 pub trait MerkleOps: ColumnOps + ColumnOps + for<'de> Deserialize<'de> + Serialize {
-    const COMMIT_IMPLEMENTED: bool;
-
-    fn commit_columns(columns: Vec<&Col>) -> Vec> {
-        Vec::new()
-    }
     /// Commits on an entire layer of the Merkle tree.
     /// See [MerkleHasher] for more details.
     ///
diff --git a/crates/prover/src/core/vcs/prover.rs b/crates/prover/src/core/vcs/prover.rs
index 4f623bf9f..b79b5ad43 100644
--- a/crates/prover/src/core/vcs/prover.rs
+++ b/crates/prover/src/core/vcs/prover.rs
@@ -39,39 +39,27 @@ impl, H: MerkleHasher> MerkleProver {
     pub fn commit(columns: Vec<&Col>) -> Self {
         if columns.is_empty() {
             return Self {
-                // TODO: does our Merkle support this?
                 layers: vec![B::commit_on_layer(0, None, &[])],
             };
         }
-        if B::COMMIT_IMPLEMENTED {
-            Self {
-                layers: B::commit_columns(columns),
-            }
-        } else {
-            let columns = &mut columns
-                .into_iter()
-                .sorted_by_key(|c| Reverse(c.len()))
-                .peekable();
-            let mut layers: Vec> = Vec::new();
-
-            let max_log_size = columns.peek().unwrap().len().ilog2();
-            for log_size in (0..=max_log_size).rev() {
-                // Take columns of the current log_size.
-                let layer_columns = columns
-                    .peek_take_while(|column| column.len().ilog2() == log_size)
-                    .collect_vec();
-
-                // TO DO: Remove on clean up
-                // for col in &layer_columns {
-                //     println!("First element equals {:02x}", col.at(0).0 & 0xff );
-                // }
-
-                layers.push(B::commit_on_layer(log_size, layers.last(), &layer_columns));
-            }
-            layers.reverse();
-            Self { layers }
+        let columns = &mut columns
+            .into_iter()
+            .sorted_by_key(|c| Reverse(c.len()))
+            .peekable();
+        let mut layers: Vec> = Vec::new();
+
+        let max_log_size = columns.peek().unwrap().len().ilog2();
+        for log_size in (0..=max_log_size).rev() {
+            // Take columns of the current log_size.
+            let layer_columns = columns
+                .peek_take_while(|column| column.len().ilog2() == log_size)
+                .collect_vec();
+
+            layers.push(B::commit_on_layer(log_size, layers.last(), &layer_columns));
         }
+        layers.reverse();
+        Self { layers }
     }

     /// Decommits to columns on the given queries.
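
With `COMMIT_IMPLEMENTED` and `commit_columns` removed, every backend now goes through the same loop in `MerkleProver::commit`: sort the columns by size, walk `log_size` from the largest column down to zero, and at each step hash the previous layer together with the columns injected at that size. A schematic, std-only version of that loop with a toy hasher and plain `Vec` columns (not the crate's `MerkleOps`/`MerkleHasher` types):

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Toy layer hash: combines the two children (if any) with the column values at this row.
fn hash_node(children: Option<(u64, u64)>, row_values: &[u32]) -> u64 {
    let mut h = DefaultHasher::new();
    children.hash(&mut h);
    row_values.hash(&mut h);
    h.finish()
}

/// Commits layer by layer, largest columns first, mirroring the structure of MerkleProver::commit.
fn commit(mut columns: Vec<Vec<u32>>) -> Vec<Vec<u64>> {
    columns.sort_by_key(|c| std::cmp::Reverse(c.len()));
    let max_log_size = columns[0].len().ilog2();
    let mut layers: Vec<Vec<u64>> = Vec::new();

    for log_size in (0..=max_log_size).rev() {
        let size = 1usize << log_size;
        // Columns injected at this layer are exactly those of length `size`.
        let layer_columns: Vec<&Vec<u32>> = columns.iter().filter(|c| c.len() == size).collect();
        let prev = layers.last();
        let layer = (0..size)
            .map(|i| {
                let children = prev.map(|p| (p[2 * i], p[2 * i + 1]));
                let row: Vec<u32> = layer_columns.iter().map(|c| c[i]).collect();
                hash_node(children, &row)
            })
            .collect();
        layers.push(layer);
    }
    layers.reverse(); // root layer first, as in the prover
    layers
}

fn main() {
    let layers = commit(vec![vec![1, 2, 3, 4], vec![5, 6]]);
    assert_eq!(layers[0].len(), 1); // single root
    assert_eq!(layers.len(), 3);    // layers for log sizes 2, 1 and 0
}

Because the per-layer entry point is the only one left, a backend such as icicle only has to provide a fast `commit_on_layer`; the tree-walking logic stays in one place.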
From 9afd39db0edc047669c30e94cc19ed703d5c7f62 Mon Sep 17 00:00:00 2001
From: Jeremy Felder
Date: Tue, 28 Jan 2025 09:51:35 +0200
Subject: [PATCH 68/69] Profiling ranges for commit-layer

---
 crates/prover/src/core/backend/icicle/mod.rs | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/crates/prover/src/core/backend/icicle/mod.rs b/crates/prover/src/core/backend/icicle/mod.rs
index 51380b8b5..4fe2e2835 100644
--- a/crates/prover/src/core/backend/icicle/mod.rs
+++ b/crates/prover/src/core/backend/icicle/mod.rs
@@ -176,19 +176,26 @@ impl MerkleOps for IcicleBackend {
         prev_layer: Option<&Col::Hash>>,
         columns: &[&Col],
     ) -> Col::Hash> {
+        nvtx::range_push!("[ICICLE] Extract prev_layer");
         let prev_layer = match prev_layer {
             Some(layer) => unsafe{ transmute(layer) },
             None => &Vec::::with_capacity(0),
         };
+        nvtx::range_pop!();
+        nvtx::range_push!("[ICICLE] Create matrices");
         let mut columns_as_matrices = vec![];
         for &col in columns {
             columns_as_matrices.push(Matrix::from_slice(col, 4, col.len()));
         }
-
+        nvtx::range_pop!();
+
+        nvtx::range_push!("[ICICLE] Cuda malloc digests");
         let digest_bytes = (1 << log_size) * 32;
         let mut d_digests_slice = DeviceVec::cuda_malloc(digest_bytes).unwrap();
-
+        nvtx::range_pop!();
+
+        nvtx::range_push!("[ICICLE] cuda commit layer");
         blake2s_commit_layer(
             HostSlice::from_slice(prev_layer),
             false,
@@ -198,13 +205,16 @@ impl MerkleOps for IcicleBackend {
             1 << log_size,
             &mut d_digests_slice[..],
         ).unwrap();
-
+        nvtx::range_pop!();
+
+        nvtx::range_push!("[ICICLE] Copy digests back to host");
         let mut digests = vec![0u8; digest_bytes];
         let mut digests_slice = HostSlice::from_mut_slice(&mut digests);
-
+
         d_digests_slice
             .copy_to_host(&mut digests_slice)
            .unwrap();
+        nvtx::range_pop!();
         unsafe { std::mem::transmute(digests) }
     }

From 5ccd3372cb4715b2a712d083a757a7eed11c3cb0 Mon Sep 17 00:00:00 2001
From: Jeremy Felder
Date: Tue, 28 Jan 2025 15:11:19 +0200
Subject: [PATCH 69/69] Update blake commit_layer to work with DeviceVecs

---
 Cargo.lock                                    |   8 +-
 crates/prover/Cargo.toml                      |   8 +-
 .../prover/src/core/backend/icicle/blake2s.rs | 213 +++++++++++++-----
 .../src/core/backend/icicle/poseidon252.rs    |   2 -
 4 files changed, 159 insertions(+), 72 deletions(-)
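
The patch below keeps each committed Merkle layer in device memory (`DeviceColumnBlake` wraps a `DeviceVec` of Blake2s hashes), so the next `commit_on_layer` call can consume the previous layer without a host round-trip; only `at`, `set`, and `to_cpu` page hashes back. A rough, std-only sketch of that access pattern, with an ordinary `Vec` standing in for device memory (this mirrors the intent of the design, not the icicle API):

use std::cell::Cell;

/// Toy device-resident column: the data sits behind an opaque handle and every
/// host access is an explicit, counted copy.
struct DeviceColumn {
    device_data: Vec<[u8; 32]>, // stand-in for a DeviceVec of 32-byte hashes
    host_copies: Cell<usize>,
}

impl DeviceColumn {
    fn from_cpu(values: &[[u8; 32]]) -> Self {
        // Host -> device upload (simulated by a clone).
        Self { device_data: values.to_vec(), host_copies: Cell::new(0) }
    }
    /// Single-hash device -> host copy, like the patch's `at`.
    fn at(&self, index: usize) -> [u8; 32] {
        self.host_copies.set(self.host_copies.get() + 1);
        self.device_data[index]
    }
    /// The next layer is built directly from the device buffer: no host copy involved.
    fn fold_pairs(&self) -> DeviceColumn {
        let folded: Vec<[u8; 32]> = self
            .device_data
            .chunks(2)
            .map(|pair| pair[0]) // stand-in for hashing the pair on the device
            .collect();
        DeviceColumn::from_cpu(&folded)
    }
}

fn main() {
    let leaves = DeviceColumn::from_cpu(&[[1u8; 32], [2u8; 32], [3u8; 32], [4u8; 32]]);
    let next = leaves.fold_pairs();          // layer chaining stays on the "device"
    assert_eq!(next.device_data.len(), 2);
    assert_eq!(leaves.host_copies.get(), 0); // no host copies were needed to build the next layer
    assert_eq!(next.at(0), [1u8; 32]);       // decommitment-style access pages one hash back
}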
"git+https://github.com/ingonyama-zk/icicle.git?rev=6c74c54997a104702debe24fac68ab51ec3fe154#6c74c54997a104702debe24fac68ab51ec3fe154" dependencies = [ "cmake", "icicle-core", @@ -784,7 +784,7 @@ dependencies = [ [[package]] name = "icicle-m31" version = "2.8.0" -source = "git+https://github.com/ingonyama-zk/icicle.git?rev=ff80aea90686d2001989f0cb5c7a0ed652c395ae#ff80aea90686d2001989f0cb5c7a0ed652c395ae" +source = "git+https://github.com/ingonyama-zk/icicle.git?rev=6c74c54997a104702debe24fac68ab51ec3fe154#6c74c54997a104702debe24fac68ab51ec3fe154" dependencies = [ "cmake", "criterion 0.3.6", diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 71759ab69..e578db8a3 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -28,10 +28,10 @@ tracing.workspace = true rayon = { version = "1.10.0", optional = true } serde = { version = "1.0", features = ["derive"] } -icicle-cuda-runtime = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="ff80aea90686d2001989f0cb5c7a0ed652c395ae"} -icicle-core = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="ff80aea90686d2001989f0cb5c7a0ed652c395ae"} -icicle-m31 = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="ff80aea90686d2001989f0cb5c7a0ed652c395ae"} -icicle-hash = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="ff80aea90686d2001989f0cb5c7a0ed652c395ae"} +icicle-cuda-runtime = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="6c74c54997a104702debe24fac68ab51ec3fe154"} +icicle-core = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="6c74c54997a104702debe24fac68ab51ec3fe154"} +icicle-m31 = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="6c74c54997a104702debe24fac68ab51ec3fe154"} +icicle-hash = { git = "https://github.com/ingonyama-zk/icicle.git", optional = true, rev="6c74c54997a104702debe24fac68ab51ec3fe154"} nvtx = { version = "*", optional = true } diff --git a/crates/prover/src/core/backend/icicle/blake2s.rs b/crates/prover/src/core/backend/icicle/blake2s.rs index 111192f70..64ea544fb 100644 --- a/crates/prover/src/core/backend/icicle/blake2s.rs +++ b/crates/prover/src/core/backend/icicle/blake2s.rs @@ -1,10 +1,14 @@ use std::cmp::Reverse; +use std::mem::transmute; +use std::fmt::Debug; +use std::ops::Deref; use icicle_core::tree::{merkle_tree_digests_len, TreeBuilderConfig}; use icicle_core::Matrix; -use icicle_cuda_runtime::memory::HostSlice; -use icicle_hash::blake2s::build_blake2s_mmcs; +use icicle_hash::blake2s::blake2s_commit_layer; +use icicle_cuda_runtime::memory::{HostSlice, DeviceVec, DeviceSlice, HostOrDeviceSlice}; use itertools::Itertools; +use icicle_core::vec_ops::{are_bytes_equal, VecOpsConfig}; use super::IcicleBackend; use crate::core::backend::{BackendForChannel, Col, Column, ColumnOps, CpuBackend}; @@ -13,84 +17,169 @@ use crate::core::vcs::blake2_hash::Blake2sHash; use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; +impl BackendForChannel for IcicleBackend {} + impl ColumnOps for IcicleBackend { - type Column = Vec; + type Column = DeviceColumnBlake; fn bit_reverse_column(_column: &mut Self::Column) { unimplemented!() } } -impl MerkleOps for IcicleBackend { - const COMMIT_IMPLEMENTED: bool = true; - - fn commit_columns( - columns: Vec<&Col>, - ) -> Vec::Hash>> { - let mut config = TreeBuilderConfig::default(); - config.arity = 2; - 
-        config.digest_elements = 32;
-        config.sort_inputs = false;
-
-        nvtx::range_push!("[ICICLE] log_max");
-        let log_max = columns
-            .iter()
-            .sorted_by_key(|c| Reverse(c.len()))
-            .next()
-            .unwrap()
-            .len()
-            .ilog2();
-        nvtx::range_pop!();
-        let mut matrices = vec![];
-        nvtx::range_push!("[ICICLE] create matrix");
-        for col in columns.into_iter().sorted_by_key(|c| Reverse(c.len())) {
-            matrices.push(Matrix::from_slice(col, 4, col.len()));
+pub struct DeviceColumnBlake {
+    pub data: DeviceVec,
+    pub length: usize,
+}
+
+impl PartialEq for DeviceColumnBlake {
+    fn eq(&self, other: &Self) -> bool {
+        if self.length != other.length {
+            return false;
         }
-        nvtx::range_pop!();
-        nvtx::range_push!("[ICICLE] merkle_tree_digests_len");
-        let digests_len = merkle_tree_digests_len(log_max as u32, 2, 32);
-        nvtx::range_pop!();
-        let mut digests = vec![0u8; digests_len];
-        let digests_slice = HostSlice::from_mut_slice(&mut digests);
+        let cfg = VecOpsConfig::default();
+        are_bytes_equal::(self.data.deref(), other.data.deref(), &cfg)
+    }
+
+    fn ne(&self, other: &Self) -> bool {
+        !self.eq(other)
+    }
+}
-        nvtx::range_push!("[ICICLE] build_blake2s_mmcs");
-        build_blake2s_mmcs(&matrices, digests_slice, &config).unwrap();
-        nvtx::range_pop!();
+impl Clone for DeviceColumnBlake {
+    fn clone(&self) -> Self {
+        let mut data = DeviceVec::cuda_malloc(self.length).unwrap();
+        data.copy_from_device(&self.data);
+        Self{data, length: self.length}
+    }
+}
+
+impl Debug for DeviceColumnBlake {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let data = self.to_cpu();
+        f.debug_struct("DeviceColumnBlake").field("data", &data.as_slice()).field("length", &self.length).finish()
+    }
+}
-        let mut digests: &[::Hash] =
-            unsafe { std::mem::transmute(digests.as_mut_slice()) };
-        // Transmute digests into stwo format
-        let mut layers = vec![];
-        let mut offset = 0usize;
-        nvtx::range_push!("[ICICLE] convert to CPU layer");
-        for log in 0..=log_max {
-            let inv_log = log_max - log;
-            let number_of_rows = 1 << inv_log;
-
-            let mut layer = vec![];
-            layer.extend_from_slice(&digests[offset..offset + number_of_rows]);
-            layers.push(layer);
-
-            if log != log_max {
-                offset += number_of_rows;
-            }
+impl DeviceColumnBlake {
+    pub fn from_cpu(values: &[Blake2sHash]) -> Self {
+        let length = values.len();
+        let mut data: DeviceVec = DeviceVec::cuda_malloc(length).unwrap();
+        data.copy_from_host(HostSlice::from_slice(&values));
+        Self{data, length}
+    }
+
+    pub fn len(&self) -> usize {
+        self.length
+    }
+}
+
+impl Column for DeviceColumnBlake {
+    fn zeros(length: usize) -> Self {
+        let mut data = DeviceVec::cuda_malloc(length).unwrap();
+
+        let host_data = vec![Blake2sHash::default(); length];
+        data.copy_from_host(HostSlice::from_slice(&host_data));
+
+        Self { data, length }
+    }
+
+    #[allow(clippy::uninit_vec)]
+    unsafe fn uninitialized(length: usize) -> Self {
+        let mut data = DeviceVec::cuda_malloc(length).unwrap();
+        Self { data, length }
+    }
+
+    fn to_cpu(&self) -> Vec {
+        let mut host_data = Vec::::with_capacity(self.length);
+        self.data
+            .copy_to_host(HostSlice::from_mut_slice(&mut host_data));
+        host_data
+    }
+
+    fn len(&self) -> usize {
+        self.length
+    }
+
+    fn at(&self, index: usize) -> Blake2sHash {
+        let mut host_vec = vec![Blake2sHash::default(); 1];
+        unsafe {
+            DeviceSlice::from_slice(std::slice::from_raw_parts(self.data.as_ptr().add(index), 1))
+                .copy_to_host(HostSlice::from_mut_slice(&mut host_vec))
+                .unwrap();
         }
+        host_vec[0]
+    }
-        layers.reverse();
-        nvtx::range_pop!();
-        layers
+    fn set(&mut self, index: usize, value: Blake2sHash) {
+        let host_vec = vec![value; 1];
+        unsafe {
+            DeviceSlice::from_mut_slice(std::slice::from_raw_parts_mut(
+                self.data.as_mut_ptr().add(index),
+                1,
+            ))
+            .copy_from_host(HostSlice::from_slice(&host_vec))
+            .unwrap();
+        }
+    }
+}
+
+impl FromIterator for DeviceColumnBlake {
+    fn from_iter>(iter: I) -> Self {
+        let host_data = iter.into_iter().collect_vec();
+        let length = host_data.len();
+        let mut data = DeviceVec::cuda_malloc(length).unwrap();
+        data.copy_from_host(HostSlice::from_slice(&host_data))
+            .unwrap();
+
+        Self { data, length }
     }
+}
+impl MerkleOps for IcicleBackend {
     fn commit_on_layer(
         log_size: u32,
         prev_layer: Option<&Col::Hash>>,
         columns: &[&Col],
     ) -> Col::Hash> {
-        // todo!()
-        >::commit_on_layer(
-            log_size, prev_layer, columns,
-        )
+        nvtx::range_push!("[ICICLE] Extract prev_layer");
+        let prev_layer = match prev_layer {
+            Some(layer) => layer,
+            // Hacky, since creating a DeviceVec of size 0 seems to not work
+            // NOTE: blake2s_commit_layer uses a length of 1 as an indicator that
+            // the prev_layer does not exist
+            None => unsafe { &::Hash> as Column>::uninitialized(1) },
+        };
+        nvtx::range_pop!();
+
+        nvtx::range_push!("[ICICLE] Create matrices");
+        let mut columns_as_matrices = vec![];
+        for &col in columns {
+            let col_as_slice = col.data[..].as_slice();
+            columns_as_matrices.push(Matrix::from_slice(&col_as_slice, 4, col.len()));
+        }
+        nvtx::range_pop!();
+
+        nvtx::range_push!("[ICICLE] Cuda malloc digests");
+        let digests_bytes = (1 << log_size) * 32;
+        let mut d_digests_slice = DeviceVec::cuda_malloc(digests_bytes).unwrap();
+        nvtx::range_pop!();
+
+        nvtx::range_push!("[ICICLE] cuda commit layer");
+        blake2s_commit_layer(
+            &(unsafe { transmute::<&DeviceVec, &DeviceVec>(&prev_layer.data) })[..],
+            false,
+            &columns_as_matrices,
+            false,
+            columns.len() as u32,
+            1 << log_size,
+            &mut d_digests_slice[..],
+        ).unwrap();
+        nvtx::range_pop!();
+
+        DeviceColumnBlake {
+            data: unsafe { transmute(d_digests_slice) },
+            length: 1 << log_size,
+        }
     }
 }
-
-impl BackendForChannel for IcicleBackend {}
diff --git a/crates/prover/src/core/backend/icicle/poseidon252.rs b/crates/prover/src/core/backend/icicle/poseidon252.rs
index 91886318a..8bb024ce2 100644
--- a/crates/prover/src/core/backend/icicle/poseidon252.rs
+++ b/crates/prover/src/core/backend/icicle/poseidon252.rs
@@ -15,8 +15,6 @@ impl ColumnOps for IcicleBackend {
 }
 impl MerkleOps for IcicleBackend {
-    const COMMIT_IMPLEMENTED: bool = false;
-
     fn commit_on_layer(
         log_size: u32,
         prev_layer: Option<&Col::Hash>>,