diff --git a/Cargo.lock b/Cargo.lock index 6be119d21..c80f87249 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1141,6 +1141,10 @@ dependencies = [ "multilinear_extensions", "num-bigint", "num-integer", + "p3-field", + "p3-goldilocks", + "p3-mds", + "p3-symmetric", "plonky2", "poseidon", "rand", diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index f977328cc..dbdedfdfd 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -25,6 +25,10 @@ num-bigint = "0.4" num-integer = "0.1" plonky2.workspace = true poseidon.workspace = true +p3-field.workspace = true +p3-goldilocks.workspace = true +p3-mds.workspace = true +p3-symmetric.workspace = true rand.workspace = true rand_chacha.workspace = true rayon = { workspace = true, optional = true } diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index 6204ed038..6f6e91127 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -28,13 +28,15 @@ pub use encoding::{ }; use ff_ext::ExtensionField; use multilinear_extensions::mle::MultilinearExtension; +use p3_mds::MdsPermutation; +use poseidon::SPONGE_WIDTH; use query_phase::{ BatchedQueriesResultWithMerklePath, QueriesResultWithMerklePath, SimpleBatchQueriesResultWithMerklePath, batch_prover_query_phase, batch_verifier_query_phase, prover_query_phase, simple_batch_prover_query_phase, simple_batch_verifier_query_phase, verifier_query_phase, }; -use std::{borrow::BorrowMut, ops::Deref}; +use std::{borrow::BorrowMut, fmt::Debug, ops::Deref}; pub use structure::BasefoldSpec; use structure::{BasefoldProof, ProofQueriesResultWithMerklePath}; use transcript::Transcript; @@ -79,7 +81,7 @@ enum PolyEvalsCodeword { TooBig(usize), } -impl> Basefold +impl, Mds> Basefold where E: Serialize + DeserializeOwned, E::BaseField: Serialize + DeserializeOwned, @@ -265,18 +267,20 @@ where /// positions are (i >> k) and (i >> k) XOR 1. /// (c) The verifier checks that the folding has been correctly computed /// at these positions. -impl> PolynomialCommitmentScheme for Basefold +impl, Mds> PolynomialCommitmentScheme + for Basefold where E: Serialize + DeserializeOwned, E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default + Debug, { type Param = BasefoldParams; type ProverParam = BasefoldProverParams; type VerifierParam = BasefoldVerifierParams; - type CommitmentWithWitness = BasefoldCommitmentWithWitness; + type CommitmentWithWitness = BasefoldCommitmentWithWitness; type Commitment = BasefoldCommitment; type CommitmentChunk = Digest; - type Proof = BasefoldProof; + type Proof = BasefoldProof; fn setup(poly_size: usize) -> Result { let pp = >::setup(log2_strict(poly_size)); @@ -323,7 +327,7 @@ where // (2) The encoding of the coefficient vector (need an interpolation) let ret = match Self::get_poly_bh_evals_and_codeword(pp, poly) { PolyEvalsCodeword::Normal((bh_evals, codeword)) => { - let codeword_tree = MerkleTree::::from_leaves(codeword); + let codeword_tree = MerkleTree::::from_leaves(codeword); // All these values are stored in the `CommitmentWithWitness` because // they are useful in opening, and we don't want to recompute them. @@ -336,7 +340,7 @@ where }) } PolyEvalsCodeword::TooSmall(evals) => { - let codeword_tree = MerkleTree::::from_leaves(evals.clone()); + let codeword_tree = MerkleTree::::from_leaves(evals.clone()); // All these values are stored in the `CommitmentWithWitness` because // they are useful in opening, and we don't want to recompute them. @@ -412,7 +416,7 @@ where } }) .collect::<(Vec<_>, Vec<_>)>(); - let codeword_tree = MerkleTree::::from_batch_leaves(codewords); + let codeword_tree = MerkleTree::::from_batch_leaves(codewords); Self::CommitmentWithWitness { codeword_tree, polynomials_bh_evals: bh_evals, @@ -432,7 +436,7 @@ where } }) .collect::>(); - let codeword_tree = MerkleTree::::from_batch_leaves(bh_evals.clone()); + let codeword_tree = MerkleTree::::from_batch_leaves(bh_evals.clone()); Self::CommitmentWithWitness { codeword_tree, polynomials_bh_evals: bh_evals, @@ -494,7 +498,7 @@ where // part, the prover needs to prepare the answers to the // queries, so the prover needs the oracles and the Merkle // trees built over them. - let (trees, commit_phase_proof) = commit_phase::( + let (trees, commit_phase_proof) = commit_phase::( &pp.encoding_params, point, comm, @@ -594,7 +598,7 @@ where evals.iter().map(Evaluation::value), &evals .iter() - .map(|eval| E::from(1 << (num_vars - points[eval.point()].len()))) + .map(|eval| E::from_canonical_u64(1 << (num_vars - points[eval.point()].len()))) .collect_vec(), &poly_iter_ext(&eq_xt).take(evals.len()).collect_vec(), ); @@ -645,8 +649,8 @@ where inner_product( &poly_iter_ext(poly).collect_vec(), build_eq_x_r_vec(point).iter(), - ) * scalar - * E::from(1 << (num_vars - poly.num_vars)) + ) * *scalar + * E::from_canonical_u64(1 << (num_vars - poly.num_vars)) // When this polynomial is smaller, it will be repeatedly summed over the cosets of the hypercube }) .sum::(); @@ -719,7 +723,7 @@ where let point = challenges; - let (trees, commit_phase_proof) = batch_commit_phase::( + let (trees, commit_phase_proof) = batch_commit_phase::( &pp.encoding_params, &point, comms, @@ -815,7 +819,7 @@ where // The remaining tasks for the prover is to prove that // sum_i coeffs[i] poly_evals[i] is equal to // the new target sum, where coeffs is computed as follows - let (trees, commit_phase_proof) = simple_batch_commit_phase::( + let (trees, commit_phase_proof) = simple_batch_commit_phase::( &pp.encoding_params, point, &eq_xt, @@ -864,7 +868,7 @@ where if proof.is_trivial() { let trivial_proof = &proof.trivial_proof; - let merkle_tree = MerkleTree::from_batch_leaves(trivial_proof.clone()); + let merkle_tree = MerkleTree::::from_batch_leaves(trivial_proof.clone()); if comm.root() == merkle_tree.root() { return Ok(()); } else { @@ -919,7 +923,7 @@ where let mut eq = build_eq_x_r_vec(&point[..point.len() - fold_challenges.len()]); eq.par_iter_mut().for_each(|e| *e *= coeff); - verifier_query_phase::( + verifier_query_phase::( queries.as_slice(), &vp.encoding_params, query_result_with_merkle_path, @@ -977,7 +981,7 @@ where evals.iter().map(Evaluation::value), &evals .iter() - .map(|eval| E::from(1 << (num_vars - points[eval.point()].len()))) + .map(|eval| E::from_canonical_u64(1 << (num_vars - points[eval.point()].len()))) .collect_vec(), &poly_iter_ext(&eq_xt).take(evals.len()).collect_vec(), ); @@ -1044,7 +1048,7 @@ where ); eq.par_iter_mut().for_each(|e| *e *= coeff); - batch_verifier_query_phase::( + batch_verifier_query_phase::( queries.as_slice(), &vp.encoding_params, query_result_with_merkle_path, @@ -1079,7 +1083,7 @@ where if proof.is_trivial() { let trivial_proof = &proof.trivial_proof; - let merkle_tree = MerkleTree::from_batch_leaves(trivial_proof.clone()); + let merkle_tree = MerkleTree::::from_batch_leaves(trivial_proof.clone()); if comm.root() == merkle_tree.root() { return Ok(()); } else { @@ -1144,7 +1148,7 @@ where let mut eq = build_eq_x_r_vec(&point[..point.len() - fold_challenges.len()]); eq.par_iter_mut().for_each(|e| *e *= coeff); - simple_batch_verifier_query_phase::( + simple_batch_verifier_query_phase::( queries.as_slice(), &vp.encoding_params, query_result_with_merkle_path, @@ -1165,15 +1169,20 @@ where } } -impl> NoninteractivePCS for Basefold +impl, Mds> NoninteractivePCS + for Basefold where E: Serialize + DeserializeOwned, E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default + Debug, { } #[cfg(test)] mod test { + use ff_ext::GoldilocksExt2; + use p3_goldilocks::MdsMatrixGoldilocks; + use crate::{ basefold::Basefold, test_util::{ @@ -1181,83 +1190,93 @@ mod test { run_commit_open_verify, run_simple_batch_commit_open_verify, }, }; - use goldilocks::GoldilocksExt2; use super::{BasefoldRSParams, structure::BasefoldBasecodeParams}; - type PcsGoldilocksRSCode = Basefold; - type PcsGoldilocksBaseCode = Basefold; + type PcsGoldilocksRSCode = Basefold; + type PcsGoldilocksBaseCode = + Basefold; #[test] fn commit_open_verify_goldilocks() { for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { // Challenge is over extension field, poly over the base field - run_commit_open_verify::(gen_rand_poly, 10, 11); - // Test trivial proof with small num vars - run_commit_open_verify::(gen_rand_poly, 4, 6); - // Challenge is over extension field, poly over the base field - run_commit_open_verify::(gen_rand_poly, 10, 11); - // Test trivial proof with small num vars - run_commit_open_verify::(gen_rand_poly, 4, 6); - } - } - - #[test] - fn simple_batch_commit_open_verify_goldilocks() { - for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { - // Both challenge and poly are over base field - run_simple_batch_commit_open_verify::( - gen_rand_poly, - 10, - 11, - 1, - ); - run_simple_batch_commit_open_verify::( + run_commit_open_verify::( gen_rand_poly, 10, 11, - 4, ); // Test trivial proof with small num vars - run_simple_batch_commit_open_verify::( + run_commit_open_verify::( gen_rand_poly, 4, 6, - 4, ); - // Both challenge and poly are over base field - run_simple_batch_commit_open_verify::( - gen_rand_poly, - 10, - 11, - 1, - ); - run_simple_batch_commit_open_verify::( + // Challenge is over extension field, poly over the base field + run_commit_open_verify::( gen_rand_poly, 10, 11, - 4, ); // Test trivial proof with small num vars - run_simple_batch_commit_open_verify::( + run_commit_open_verify::( gen_rand_poly, 4, 6, - 4, ); } } + #[test] + fn simple_batch_commit_open_verify_goldilocks() { + for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { + // Both challenge and poly are over base field + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 10, 11, 1); + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 10, 11, 4); + // Test trivial proof with small num vars + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 4, 6, 4); + // Both challenge and poly are over base field + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 10, 11, 1); + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 10, 11, 4); + // Test trivial proof with small num vars + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 4, 6, 4); + } + } + #[test] fn batch_commit_open_verify() { for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { // Both challenge and poly are over base field - run_batch_commit_open_verify::( - gen_rand_poly, - 10, - 11, - ); - run_batch_commit_open_verify::( + run_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 10, 11); + run_batch_commit_open_verify::( gen_rand_poly, 10, 11, diff --git a/mpcs/src/basefold/commit_phase.rs b/mpcs/src/basefold/commit_phase.rs index 55a6acea5..f61424fda 100644 --- a/mpcs/src/basefold/commit_phase.rs +++ b/mpcs/src/basefold/commit_phase.rs @@ -16,6 +16,8 @@ use crate::util::{ use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; use itertools::Itertools; +use p3_mds::MdsPermutation; +use poseidon::SPONGE_WIDTH; use serde::{Serialize, de::DeserializeOwned}; use transcript::Transcript; @@ -30,16 +32,17 @@ use rayon::prelude::{ use super::structure::BasefoldCommitmentWithWitness; // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) -pub fn commit_phase>( +pub fn commit_phase, Mds>( pp: &>::ProverParameters, point: &[E], - comm: &BasefoldCommitmentWithWitness, + comm: &BasefoldCommitmentWithWitness, transcript: &mut impl Transcript, num_vars: usize, num_rounds: usize, -) -> (Vec>, BasefoldCommitPhaseProof) +) -> (Vec>, BasefoldCommitPhaseProof) where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let timer = start_timer!(|| "Commit phase"); #[cfg(feature = "sanity-check")] @@ -98,7 +101,7 @@ where ); if i > 0 { - let running_tree = MerkleTree::::from_inner_leaves( + let running_tree = MerkleTree::::from_inner_leaves( running_tree_inner, FieldType::Ext(running_oracle), ); @@ -116,8 +119,8 @@ where // Then the oracle will be used to fold to the next oracle in the next // round. After that, this oracle is free to be moved to build the // complete Merkle tree. - running_tree_inner = MerkleTree::::compute_inner_ext(&new_running_oracle); - let running_root = MerkleTree::::root_from_inner(&running_tree_inner); + running_tree_inner = MerkleTree::::compute_inner_ext(&new_running_oracle); + let running_root = MerkleTree::::root_from_inner(&running_tree_inner); write_digest_to_transcript(&running_root, transcript); roots.push(running_root.clone()); @@ -176,17 +179,18 @@ where // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) #[allow(clippy::too_many_arguments)] -pub fn batch_commit_phase>( +pub fn batch_commit_phase, Mds>( pp: &>::ProverParameters, point: &[E], - comms: &[BasefoldCommitmentWithWitness], + comms: &[BasefoldCommitmentWithWitness], transcript: &mut impl Transcript, num_vars: usize, num_rounds: usize, coeffs: &[E], -) -> (Vec>, BasefoldCommitPhaseProof) +) -> (Vec>, BasefoldCommitPhaseProof) where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let timer = start_timer!(|| "Batch Commit phase"); assert_eq!(point.len(), num_vars); @@ -266,7 +270,7 @@ where ); if i > 0 { - let running_tree = MerkleTree::::from_inner_leaves( + let running_tree = MerkleTree::::from_inner_leaves( running_tree_inner, FieldType::Ext(running_oracle), ); @@ -277,8 +281,8 @@ where last_sumcheck_message = sum_check_challenge_round(&mut eq, &mut sum_of_all_evals_for_sumcheck, challenge); sumcheck_messages.push(last_sumcheck_message.clone()); - running_tree_inner = MerkleTree::::compute_inner_ext(&new_running_oracle); - let running_root = MerkleTree::::root_from_inner(&running_tree_inner); + running_tree_inner = MerkleTree::::compute_inner_ext(&new_running_oracle); + let running_root = MerkleTree::::root_from_inner(&running_tree_inner); write_digest_to_transcript(&running_root, transcript); roots.push(running_root); @@ -346,17 +350,18 @@ where // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) #[allow(clippy::too_many_arguments)] -pub fn simple_batch_commit_phase>( +pub fn simple_batch_commit_phase, Mds>( pp: &>::ProverParameters, point: &[E], batch_coeffs: &[E], - comm: &BasefoldCommitmentWithWitness, + comm: &BasefoldCommitmentWithWitness, transcript: &mut impl Transcript, num_vars: usize, num_rounds: usize, -) -> (Vec>, BasefoldCommitPhaseProof) +) -> (Vec>, BasefoldCommitPhaseProof) where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let timer = start_timer!(|| "Simple batch commit phase"); assert_eq!(point.len(), num_vars); @@ -416,7 +421,7 @@ where ); if i > 0 { - let running_tree = MerkleTree::::from_inner_leaves( + let running_tree = MerkleTree::::from_inner_leaves( running_tree_inner, FieldType::Ext(running_oracle), ); @@ -426,8 +431,8 @@ where if i < num_rounds - 1 { last_sumcheck_message = sum_check_challenge_round(&mut eq, &mut running_evals, challenge); - running_tree_inner = MerkleTree::::compute_inner_ext(&new_running_oracle); - let running_root = MerkleTree::::root_from_inner(&running_tree_inner); + running_tree_inner = MerkleTree::::compute_inner_ext(&new_running_oracle); + let running_root = MerkleTree::::root_from_inner(&running_tree_inner); write_digest_to_transcript(&running_root, transcript); roots.push(running_root); running_oracle = new_running_oracle; diff --git a/mpcs/src/basefold/encoding.rs b/mpcs/src/basefold/encoding.rs index 410d35970..6c3a03d2f 100644 --- a/mpcs/src/basefold/encoding.rs +++ b/mpcs/src/basefold/encoding.rs @@ -173,7 +173,9 @@ pub(crate) mod test_util { pub fn test_codeword_folding>() { let num_vars = 12; - let poly: Vec = (0..(1 << num_vars)).map(|i| E::from(i)).collect(); + let poly: Vec = (0..(1 << num_vars)) + .map(|i| E::from_canonical_u64(i)) + .collect(); let mut poly = FieldType::Ext(poly); let pp: Code::PublicParameters = Code::setup(num_vars); diff --git a/mpcs/src/basefold/encoding/basecode.rs b/mpcs/src/basefold/encoding/basecode.rs index 9fbee84f1..04ea5892d 100644 --- a/mpcs/src/basefold/encoding/basecode.rs +++ b/mpcs/src/basefold/encoding/basecode.rs @@ -10,10 +10,10 @@ use crate::{ }; use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; use ark_std::{end_timer, start_timer}; -use ff::{BatchInvert, Field, PrimeField}; use ff_ext::ExtensionField; use generic_array::GenericArray; use multilinear_extensions::mle::FieldType; +use p3_field::{Field, FieldAlgebra, batch_multiplicative_inverse}; use rand::SeedableRng; use rayon::prelude::{ParallelIterator, ParallelSlice, ParallelSliceMut}; @@ -216,7 +216,7 @@ where let x0: E::BaseField = query_root_table_from_rng_aes::(level, index, &mut cipher); let x1 = -x0; - let w = (x1 - x0).invert().unwrap(); + let w = (x1 - x0).try_inverse().unwrap(); (E::from(x0), E::from(x1), E::from(w)) } @@ -351,13 +351,13 @@ pub fn get_table_aes( assert_eq!(flat_table.len(), 1 << lg_n); // Multiply -2 to every element to get the weights. Now weights = { -2x } - let mut weights: Vec = flat_table + let weights: Vec = flat_table .par_iter() .map(|el| E::BaseField::ZERO - *el - *el) .collect(); // Then invert all the elements. Now weights = { -1/2x } - BatchInvert::batch_invert(&mut weights); + let weights = batch_multiplicative_inverse(&weights); // Zip x and -1/2x together. The result is the list { (x, -1/2x) } // What is this -1/2x? It is used in linear interpolation over the domain (x, -x), which @@ -399,13 +399,13 @@ pub fn query_root_table_from_rng_aes( } let pos = ((level_offset + (reverse_bits(index, level) as u128)) - * ((E::BaseField::NUM_BITS as usize).next_power_of_two() as u128)) + * ((E::BaseField::bits() as usize).next_power_of_two() as u128)) .checked_div(8) .unwrap(); cipher.seek(pos); - let bytes = (E::BaseField::NUM_BITS as usize).next_power_of_two() / 8; + let bytes = (E::BaseField::bits() as usize).next_power_of_two() / 8; let mut dest: Vec = vec![0u8; bytes]; cipher.apply_keystream(&mut dest); @@ -417,7 +417,7 @@ mod tests { use crate::basefold::encoding::test_util::test_codeword_folding; use super::*; - use goldilocks::GoldilocksExt2; + use ff_ext::GoldilocksExt2; use multilinear_extensions::mle::DenseMultilinearExtension; #[test] diff --git a/mpcs/src/basefold/encoding/rs.rs b/mpcs/src/basefold/encoding/rs.rs index 2bcac0826..9ae955d28 100644 --- a/mpcs/src/basefold/encoding/rs.rs +++ b/mpcs/src/basefold/encoding/rs.rs @@ -7,9 +7,9 @@ use crate::{ vec_mut, }; use ark_std::{end_timer, start_timer}; -use ff::{Field, PrimeField}; use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; +use p3_field::{Field, FieldAlgebra, PrimeField}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; @@ -71,8 +71,7 @@ fn ifft( let n = poly.len(); let lg_n = log2_strict(n); let n_inv = (E::BaseField::ONE + E::BaseField::ONE) - .invert() - .unwrap() + .inverse() .pow([lg_n as u64]); fft(poly, zero_factor, root_table); @@ -310,13 +309,13 @@ where } let mut gamma_powers = Vec::with_capacity(max_message_size_log); let mut gamma_powers_inv = Vec::with_capacity(max_message_size_log); - gamma_powers.push(E::BaseField::MULTIPLICATIVE_GENERATOR); - gamma_powers_inv.push(E::BaseField::MULTIPLICATIVE_GENERATOR.invert().unwrap()); + gamma_powers.push(E::BaseField::GENERATOR); + gamma_powers_inv.push(E::BaseField::GENERATOR.inverse()); for i in 1..max_message_size_log + Spec::get_rate_log() { gamma_powers.push(gamma_powers[i - 1].square()); gamma_powers_inv.push(gamma_powers_inv[i - 1].square()); } - let inv_of_two = E::BaseField::from(2).invert().unwrap(); + let inv_of_two = E::BaseField::from_canonical_u64(2).inverse(); gamma_powers_inv.iter_mut().for_each(|x| *x *= inv_of_two); pp.fft_root_table .truncate(max_message_size_log + Spec::get_rate_log()); @@ -493,7 +492,7 @@ impl RSCode { let k = 1 << (full_message_size_log - lg_m); coset_fft( &mut ret, - E::BaseField::MULTIPLICATIVE_GENERATOR.pow([k]), + E::BaseField::GENERATOR.pow([k]), Spec::get_rate_log(), fft_root_table, ); @@ -514,7 +513,7 @@ impl RSCode { let x0 = E::BaseField::ROOT_OF_UNITY .pow([1 << (E::BaseField::S - (level as u32 + 1))]) .pow([index as u64]) - * E::BaseField::MULTIPLICATIVE_GENERATOR + * E::BaseField::GENERATOR .pow([1 << (full_message_size_log + Spec::get_rate_log() - level - 1)]); let x1 = -x0; let w = (x1 - x0).invert().unwrap(); @@ -546,19 +545,24 @@ fn naive_fft(poly: &[E], rate: usize, shift: E::BaseField) -> #[cfg(test)] mod tests { + use ff_ext::GoldilocksExt2; + use p3_goldilocks::Goldilocks; + use crate::{ basefold::encoding::test_util::test_codeword_folding, util::{field_type_index_ext, plonky2_util::reverse_index_bits_in_place_field_type}, }; + use ff_ext::FromUniformBytes; use super::*; - use goldilocks::{Goldilocks, GoldilocksExt2}; #[test] fn test_naive_fft() { let num_vars = 5; - let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); + let poly: Vec = (0..(1 << num_vars)) + .map(GoldilocksExt2::from_canonical_u64) + .collect(); let mut poly2 = FieldType::Ext(poly.clone()); let naive = naive_fft::(&poly, 1, Goldilocks::ONE); @@ -583,15 +587,10 @@ mod tests { .collect(); let mut poly2 = FieldType::Ext(poly.clone()); - let naive = naive_fft::(&poly, 1, Goldilocks::MULTIPLICATIVE_GENERATOR); + let naive = naive_fft::(&poly, 1, Goldilocks::GENERATOR); let root_table = fft_root_table(num_vars); - coset_fft::( - &mut poly2, - Goldilocks::MULTIPLICATIVE_GENERATOR, - 0, - &root_table, - ); + coset_fft::(&mut poly2, Goldilocks::GENERATOR, 0, &root_table); let poly2 = match poly2 { FieldType::Ext(coeffs) => coeffs, @@ -613,19 +612,10 @@ mod tests { poly2.as_mut_slice()[..poly.len()].copy_from_slice(poly.as_slice()); let mut poly2 = FieldType::Ext(poly2.clone()); - let naive = naive_fft::( - &poly, - 1 << rate_bits, - Goldilocks::MULTIPLICATIVE_GENERATOR, - ); + let naive = naive_fft::(&poly, 1 << rate_bits, Goldilocks::GENERATOR); let root_table = fft_root_table(num_vars + rate_bits); - coset_fft::( - &mut poly2, - Goldilocks::MULTIPLICATIVE_GENERATOR, - rate_bits, - &root_table, - ); + coset_fft::(&mut poly2, Goldilocks::GENERATOR, rate_bits, &root_table); let poly2 = match poly2 { FieldType::Ext(coeffs) => coeffs, @@ -638,7 +628,9 @@ mod tests { fn test_ifft() { let num_vars = 5; - let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); + let poly: Vec = (0..(1 << num_vars)) + .map(GoldilocksExt2::from_canonical_u64) + .collect(); let mut poly = FieldType::Ext(poly); let original = poly.clone(); @@ -686,14 +678,14 @@ mod tests { pub fn test_colinearity() { let num_vars = 10; - let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); + let poly: Vec = (0..(1 << num_vars)).map(E::from_canonical_u64).collect(); let poly = FieldType::Ext(poly); let pp = >::setup(num_vars); let (pp, _) = Code::trim(pp, num_vars).unwrap(); let mut codeword = Code::encode(&pp, &poly); reverse_index_bits_in_place_field_type(&mut codeword); - let challenge = E::from(2); + let challenge = E::from_canonical_u64(2); let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); let codeword = match codeword { FieldType::Ext(coeffs) => coeffs, @@ -712,8 +704,8 @@ mod tests { // which is equivalent to // (x0-challenge)*(b[1]-a) = (x1-challenge)*(b[0]-a) assert_eq!( - (x0 - challenge) * (b[1] - a), - (x1 - challenge) * (b[0] - a), + (x0 - challenge) * (b[1] - *a), + (x1 - challenge) * (b[0] - *a), "failed for i = {}", i ); @@ -724,7 +716,7 @@ mod tests { pub fn test_low_degree() { let num_vars = 10; - let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); + let poly: Vec = (0..(1 << num_vars)).map(E::from_canonical_u64).collect(); let poly = FieldType::Ext(poly); let pp = >::setup(num_vars); @@ -789,7 +781,7 @@ mod tests { "check low degree of (left-right)*omega^(-i)", ); - let challenge = E::from(2); + let challenge = E::from_canonical_u64(2); let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); let c_fold = folded_codeword[0]; let c_fold1 = folded_codeword[folded_codeword.len() >> 1]; @@ -800,7 +792,7 @@ mod tests { // The top level folding coefficient should have shift factor gamma let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 0); - assert_eq!(folding_coeffs.0, E::from(F::MULTIPLICATIVE_GENERATOR)); + assert_eq!(folding_coeffs.0, E::from(F::GENERATOR)); assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO); assert_eq!( (folding_coeffs.1 - folding_coeffs.0) * folding_coeffs.2, @@ -815,17 +807,16 @@ mod tests { // So the folded value should be equal to // (gamma^{-1} * alpha * (c0 - c_mid) + (c0 + c_mid)) / 2 assert_eq!( - c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2), - challenge * (c0 - c_mid) + (c0 + c_mid) * F::MULTIPLICATIVE_GENERATOR + c_fold * F::GENERATOR * F::from_canonical_u64(2), + challenge * (c0 - c_mid) + (c0 + c_mid) * F::GENERATOR ); assert_eq!( - c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2), - challenge * left_right_diff[0] + left_right_sum[0] * F::MULTIPLICATIVE_GENERATOR + c_fold * F::GENERATOR * F::from_canonical_u64(2), + challenge * left_right_diff[0] + left_right_sum[0] * F::GENERATOR ); assert_eq!( - c_fold * F::from(2), - challenge * left_right_diff[0] * F::MULTIPLICATIVE_GENERATOR.invert().unwrap() - + left_right_sum[0] + c_fold * F::from_canonical_u64(2), + challenge * left_right_diff[0] * F::GENERATOR.inverse() + left_right_sum[0] ); let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 1); @@ -835,8 +826,7 @@ mod tests { assert_eq!(root_of_unity.pow([(codeword.len() >> 1) as u64]), -F::ONE); assert_eq!( folding_coeffs.0, - E::from(F::MULTIPLICATIVE_GENERATOR) - * E::from(root_of_unity).pow([(codeword.len() >> 2) as u64]) + E::from(F::GENERATOR) * E::from(root_of_unity).pow([(codeword.len() >> 2) as u64]) ); assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO); assert_eq!( @@ -849,14 +839,14 @@ mod tests { // The coefficients are respectively 1/2 and gamma^{-1}/2 * alpha. // In another word, the folded codeword multipled by 2 is the linear // combination by coeffs: 1 and gamma^{-1} * alpha - let gamma_inv = F::MULTIPLICATIVE_GENERATOR.invert().unwrap(); + let gamma_inv = F::GENERATOR.inverse(); let b = challenge * gamma_inv; let folded_codeword_vec = match &folded_codeword { FieldType::Ext(coeffs) => coeffs.clone(), _ => panic!("Wrong field type"), }; assert_eq!( - c_fold * F::from(2), + c_fold * F::from_canonical_u64(2), left_right_diff[0] * b + left_right_sum[0] ); for (i, (c, (diff, sum))) in folded_codeword_vec @@ -864,7 +854,7 @@ mod tests { .zip(left_right_diff.iter().zip(left_right_sum.iter())) .enumerate() { - assert_eq!(*c + c, *sum + b * diff, "failed for i = {}", i); + assert_eq!(*c + *c, *sum + b * *diff, "failed for i = {}", i); } check_low_degree(&folded_codeword, "low degree check for folded"); diff --git a/mpcs/src/basefold/query_phase.rs b/mpcs/src/basefold/query_phase.rs index 9ec15d36b..a72fefb83 100644 --- a/mpcs/src/basefold/query_phase.rs +++ b/mpcs/src/basefold/query_phase.rs @@ -12,6 +12,8 @@ use ark_std::{end_timer, start_timer}; use core::fmt::Debug; use ff_ext::ExtensionField; use itertools::Itertools; +use p3_mds::MdsPermutation; +use poseidon::SPONGE_WIDTH; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use transcript::Transcript; @@ -28,14 +30,15 @@ use super::{ structure::{BasefoldCommitment, BasefoldCommitmentWithWitness, BasefoldSpec}, }; -pub fn prover_query_phase( +pub fn prover_query_phase( transcript: &mut impl Transcript, - comm: &BasefoldCommitmentWithWitness, - trees: &[MerkleTree], + comm: &BasefoldCommitmentWithWitness, + trees: &[MerkleTree], num_verifier_queries: usize, ) -> QueriesResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let queries: Vec<_> = (0..num_verifier_queries) .map(|_| { @@ -57,22 +60,23 @@ where .map(|x_index| { ( *x_index, - basefold_get_query::(&comm.get_codewords()[0], trees, *x_index), + basefold_get_query::(&comm.get_codewords()[0], trees, *x_index), ) }) .collect(), } } -pub fn batch_prover_query_phase( +pub fn batch_prover_query_phase( transcript: &mut impl Transcript, codeword_size: usize, - comms: &[BasefoldCommitmentWithWitness], - trees: &[MerkleTree], + comms: &[BasefoldCommitmentWithWitness], + trees: &[MerkleTree], num_verifier_queries: usize, ) -> BatchedQueriesResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let queries: Vec<_> = (0..num_verifier_queries) .map(|_| { @@ -94,21 +98,22 @@ where .map(|x_index| { ( *x_index, - batch_basefold_get_query::(comms, trees, codeword_size, *x_index), + batch_basefold_get_query::(comms, trees, codeword_size, *x_index), ) }) .collect(), } } -pub fn simple_batch_prover_query_phase( +pub fn simple_batch_prover_query_phase( transcript: &mut impl Transcript, - comm: &BasefoldCommitmentWithWitness, - trees: &[MerkleTree], + comm: &BasefoldCommitmentWithWitness, + trees: &[MerkleTree], num_verifier_queries: usize, ) -> SimpleBatchQueriesResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let queries: Vec<_> = (0..num_verifier_queries) .map(|_| { @@ -130,7 +135,11 @@ where .map(|x_index| { ( *x_index, - simple_batch_basefold_get_query::(comm.get_codewords(), trees, *x_index), + simple_batch_basefold_get_query::( + comm.get_codewords(), + trees, + *x_index, + ), ) }) .collect(), @@ -138,10 +147,10 @@ where } #[allow(clippy::too_many_arguments)] -pub fn verifier_query_phase>( +pub fn verifier_query_phase, Mds>( indices: &[usize], vp: &>::VerifierParameters, - queries: &QueriesResultWithMerklePath, + queries: &QueriesResultWithMerklePath, sum_check_messages: &[Vec], fold_challenges: &[E], num_rounds: usize, @@ -153,6 +162,7 @@ pub fn verifier_query_phase>( eval: &E, ) where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let timer = start_timer!(|| "Verifier query phase"); @@ -210,10 +220,10 @@ pub fn verifier_query_phase>( } #[allow(clippy::too_many_arguments)] -pub fn batch_verifier_query_phase>( +pub fn batch_verifier_query_phase, Mds>( indices: &[usize], vp: &>::VerifierParameters, - queries: &BatchedQueriesResultWithMerklePath, + queries: &BatchedQueriesResultWithMerklePath, sum_check_messages: &[Vec], fold_challenges: &[E], num_rounds: usize, @@ -226,6 +236,7 @@ pub fn batch_verifier_query_phase>( eval: &E, ) where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let timer = start_timer!(|| "Verifier batch query phase"); let encode_timer = start_timer!(|| "Encode final codeword"); @@ -286,10 +297,10 @@ pub fn batch_verifier_query_phase>( } #[allow(clippy::too_many_arguments)] -pub fn simple_batch_verifier_query_phase>( +pub fn simple_batch_verifier_query_phase, Mds>( indices: &[usize], vp: &>::VerifierParameters, - queries: &SimpleBatchQueriesResultWithMerklePath, + queries: &SimpleBatchQueriesResultWithMerklePath, sum_check_messages: &[Vec], fold_challenges: &[E], batch_coeffs: &[E], @@ -302,6 +313,7 @@ pub fn simple_batch_verifier_query_phase + Default, { let timer = start_timer!(|| "Verifier query phase"); @@ -364,13 +376,14 @@ pub fn simple_batch_verifier_query_phase( +fn basefold_get_query( poly_codeword: &FieldType, - trees: &[MerkleTree], + trees: &[MerkleTree], x_index: usize, ) -> SingleQueryResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let mut index = x_index; let p1 = index | 1; @@ -410,14 +423,15 @@ where } } -fn batch_basefold_get_query( - comms: &[BasefoldCommitmentWithWitness], - trees: &[MerkleTree], +fn batch_basefold_get_query( + comms: &[BasefoldCommitmentWithWitness], + trees: &[MerkleTree], codeword_size: usize, x_index: usize, ) -> BatchedSingleQueryResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let mut oracle_list_queries = Vec::with_capacity(trees.len()); @@ -465,13 +479,14 @@ where } } -fn simple_batch_basefold_get_query( +fn simple_batch_basefold_get_query( poly_codewords: &[FieldType], - trees: &[MerkleTree], + trees: &[MerkleTree], x_index: usize, ) -> SimpleBatchSingleQueryResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let mut index = x_index; let p1 = index | 1; @@ -529,6 +544,7 @@ where #[derive(Debug, Copy, Clone, Serialize, Deserialize)] enum CodewordPointPair { + #[serde(bound = "")] Ext(E, E), Base(E::BaseField, E::BaseField), } @@ -543,6 +559,7 @@ impl CodewordPointPair { } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] enum SimpleBatchLeavesPair where E::BaseField: Serialize + DeserializeOwned, @@ -588,6 +605,7 @@ where } #[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct CodewordSingleQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -630,17 +648,20 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -struct CodewordSingleQueryResultWithMerklePath +#[serde(bound = "")] +struct CodewordSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { query: CodewordSingleQueryResult, - merkle_path: MerklePathWithoutLeafOrRoot, + merkle_path: MerklePathWithoutLeafOrRoot, } -impl CodewordSingleQueryResultWithMerklePath +impl CodewordSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn check_merkle_path(&self, root: &Digest) { // let timer = start_timer!(|| "CodewordSingleQuery::Check Merkle Path"); @@ -659,6 +680,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct OracleListQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -667,6 +689,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct CommitmentsQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -675,24 +698,29 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -struct OracleListQueryResultWithMerklePath +#[serde(bound = "")] +struct OracleListQueryResultWithMerklePath where - E::BaseField: Serialize + DeserializeOwned, + E::BaseField: Serialize, + Mds: MdsPermutation + Default, { - inner: Vec>, + inner: Vec>, } #[derive(Debug, Clone, Serialize, Deserialize)] -struct CommitmentsQueryResultWithMerklePath +#[serde(bound = "")] +struct CommitmentsQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - inner: Vec>, + inner: Vec>, } -impl ListQueryResult for OracleListQueryResult +impl ListQueryResult for OracleListQueryResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { fn get_inner(&self) -> &Vec> { &self.inner @@ -703,9 +731,10 @@ where } } -impl ListQueryResult for CommitmentsQueryResult +impl ListQueryResult for CommitmentsQueryResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { fn get_inner(&self) -> &Vec> { &self.inner @@ -716,35 +745,40 @@ where } } -impl ListQueryResultWithMerklePath for OracleListQueryResultWithMerklePath +impl ListQueryResultWithMerklePath + for OracleListQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - fn get_inner(&self) -> &Vec> { + fn get_inner(&self) -> &Vec> { &self.inner } - fn new(inner: Vec>) -> Self { + fn new(inner: Vec>) -> Self { Self { inner } } } -impl ListQueryResultWithMerklePath for CommitmentsQueryResultWithMerklePath +impl ListQueryResultWithMerklePath + for CommitmentsQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - fn get_inner(&self) -> &Vec> { + fn get_inner(&self) -> &Vec> { &self.inner } - fn new(inner: Vec>) -> Self { + fn new(inner: Vec>) -> Self { Self { inner } } } -trait ListQueryResult +trait ListQueryResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { fn get_inner(&self) -> &Vec>; @@ -752,8 +786,8 @@ where fn merkle_path( &self, - path: impl Fn(usize, usize) -> MerklePathWithoutLeafOrRoot, - ) -> Vec> { + path: impl Fn(usize, usize) -> MerklePathWithoutLeafOrRoot, + ) -> Vec> { let ret = self .get_inner() .iter() @@ -764,17 +798,18 @@ where } } -trait ListQueryResultWithMerklePath: Sized +trait ListQueryResultWithMerklePath: Sized where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - fn new(inner: Vec>) -> Self; + fn new(inner: Vec>) -> Self; - fn get_inner(&self) -> &Vec>; + fn get_inner(&self) -> &Vec>; - fn from_query_and_trees>( + fn from_query_and_trees>( query_result: LQR, - path: impl Fn(usize, usize) -> MerklePathWithoutLeafOrRoot, + path: impl Fn(usize, usize) -> MerklePathWithoutLeafOrRoot, ) -> Self { Self::new( query_result @@ -804,6 +839,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct SingleQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -813,22 +849,25 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -struct SingleQueryResultWithMerklePath +#[serde(bound = "")] +struct SingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - oracle_query: OracleListQueryResultWithMerklePath, - commitment_query: CodewordSingleQueryResultWithMerklePath, + oracle_query: OracleListQueryResultWithMerklePath, + commitment_query: CodewordSingleQueryResultWithMerklePath, } -impl SingleQueryResultWithMerklePath +impl SingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn from_single_query_result( single_query_result: SingleQueryResult, - oracle_trees: &[MerkleTree], - commitment: &BasefoldCommitmentWithWitness, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithWitness, ) -> Self { assert!(commitment.codeword_tree.height() > 0); Self { @@ -909,16 +948,19 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct QueriesResultWithMerklePath +#[serde(bound = "")] +pub struct QueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - inner: Vec<(usize, SingleQueryResultWithMerklePath)>, + inner: Vec<(usize, SingleQueryResultWithMerklePath)>, } -impl QueriesResultWithMerklePath +impl QueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn empty() -> Self { Self { inner: vec![] } @@ -926,8 +968,8 @@ where pub fn from_query_result( query_result: QueriesResult, - oracle_trees: &[MerkleTree], - commitment: &BasefoldCommitmentWithWitness, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithWitness, ) -> Self { Self { inner: query_result @@ -978,6 +1020,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct BatchedSingleQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -987,22 +1030,25 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -struct BatchedSingleQueryResultWithMerklePath +#[serde(bound = "")] +struct BatchedSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - oracle_query: OracleListQueryResultWithMerklePath, - commitments_query: CommitmentsQueryResultWithMerklePath, + oracle_query: OracleListQueryResultWithMerklePath, + commitments_query: CommitmentsQueryResultWithMerklePath, } -impl BatchedSingleQueryResultWithMerklePath +impl BatchedSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn from_batched_single_query_result( batched_single_query_result: BatchedSingleQueryResult, - oracle_trees: &[MerkleTree], - commitments: &[BasefoldCommitmentWithWitness], + oracle_trees: &[MerkleTree], + commitments: &[BasefoldCommitmentWithWitness], ) -> Self { Self { oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( @@ -1133,21 +1179,24 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BatchedQueriesResultWithMerklePath +#[serde(bound = "")] +pub struct BatchedQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - inner: Vec<(usize, BatchedSingleQueryResultWithMerklePath)>, + inner: Vec<(usize, BatchedSingleQueryResultWithMerklePath)>, } -impl BatchedQueriesResultWithMerklePath +impl BatchedQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn from_batched_query_result( batched_query_result: BatchedQueriesResult, - oracle_trees: &[MerkleTree], - commitments: &[BasefoldCommitmentWithWitness], + oracle_trees: &[MerkleTree], + commitments: &[BasefoldCommitmentWithWitness], ) -> Self { Self { inner: batched_query_result @@ -1202,6 +1251,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct SimpleBatchCommitmentSingleQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -1246,17 +1296,20 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -struct SimpleBatchCommitmentSingleQueryResultWithMerklePath +#[serde(bound = "")] +struct SimpleBatchCommitmentSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { query: SimpleBatchCommitmentSingleQueryResult, - merkle_path: MerklePathWithoutLeafOrRoot, + merkle_path: MerklePathWithoutLeafOrRoot, } -impl SimpleBatchCommitmentSingleQueryResultWithMerklePath +impl SimpleBatchCommitmentSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn check_merkle_path(&self, root: &Digest) { // let timer = start_timer!(|| "CodewordSingleQuery::Check Merkle Path"); @@ -1283,6 +1336,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct SimpleBatchSingleQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -1292,22 +1346,25 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -struct SimpleBatchSingleQueryResultWithMerklePath +#[serde(bound = "")] +struct SimpleBatchSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - oracle_query: OracleListQueryResultWithMerklePath, - commitment_query: SimpleBatchCommitmentSingleQueryResultWithMerklePath, + oracle_query: OracleListQueryResultWithMerklePath, + commitment_query: SimpleBatchCommitmentSingleQueryResultWithMerklePath, } -impl SimpleBatchSingleQueryResultWithMerklePath +impl SimpleBatchSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn from_single_query_result( single_query_result: SimpleBatchSingleQueryResult, - oracle_trees: &[MerkleTree], - commitment: &BasefoldCommitmentWithWitness, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithWitness, ) -> Self { Self { oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( @@ -1389,21 +1446,24 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SimpleBatchQueriesResultWithMerklePath +#[serde(bound = "")] +pub struct SimpleBatchQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - inner: Vec<(usize, SimpleBatchSingleQueryResultWithMerklePath)>, + inner: Vec<(usize, SimpleBatchSingleQueryResultWithMerklePath)>, } -impl SimpleBatchQueriesResultWithMerklePath +impl SimpleBatchQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn from_query_result( query_result: SimpleBatchQueriesResult, - oracle_trees: &[MerkleTree], - commitment: &BasefoldCommitmentWithWitness, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithWitness, ) -> Self { Self { inner: query_result diff --git a/mpcs/src/basefold/structure.rs b/mpcs/src/basefold/structure.rs index 547dd7c51..b89ac7c7e 100644 --- a/mpcs/src/basefold/structure.rs +++ b/mpcs/src/basefold/structure.rs @@ -5,6 +5,8 @@ use crate::{ use core::fmt::Debug; use ff_ext::ExtensionField; +use p3_mds::MdsPermutation; +use poseidon::SPONGE_WIDTH; use serde::{Deserialize, Serialize, Serializer, de::DeserializeOwned}; use multilinear_extensions::mle::FieldType; @@ -59,20 +61,22 @@ pub struct BasefoldVerifierParams> { /// A polynomial commitment together with all the data (e.g., the codeword, and Merkle tree) /// used to generate this commitment and for assistant in opening #[derive(Clone, Debug, Default)] -pub struct BasefoldCommitmentWithWitness +pub struct BasefoldCommitmentWithWitness where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - pub(crate) codeword_tree: MerkleTree, + pub(crate) codeword_tree: MerkleTree, pub(crate) polynomials_bh_evals: Vec>, pub(crate) num_vars: usize, pub(crate) is_base: bool, pub(crate) num_polys: usize, } -impl BasefoldCommitmentWithWitness +impl BasefoldCommitmentWithWitness where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn to_commitment(&self) -> BasefoldCommitment { BasefoldCommitment::new( @@ -132,20 +136,22 @@ where } } -impl From> for Digest +impl From> for Digest where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - fn from(val: BasefoldCommitmentWithWitness) -> Self { + fn from(val: BasefoldCommitmentWithWitness) -> Self { val.get_root_as() } } -impl From<&BasefoldCommitmentWithWitness> for BasefoldCommitment +impl From<&BasefoldCommitmentWithWitness> for BasefoldCommitment where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - fn from(val: &BasefoldCommitmentWithWitness) -> Self { + fn from(val: &BasefoldCommitmentWithWitness) -> Self { val.to_commitment() } } @@ -193,9 +199,10 @@ where } } -impl PartialEq for BasefoldCommitmentWithWitness +impl PartialEq for BasefoldCommitmentWithWitness where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { fn eq(&self, other: &Self) -> bool { self.get_codewords().eq(other.get_codewords()) @@ -203,8 +210,10 @@ where } } -impl Eq for BasefoldCommitmentWithWitness where - E::BaseField: Serialize + DeserializeOwned +impl Eq for BasefoldCommitmentWithWitness +where + E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { } @@ -245,9 +254,9 @@ where } #[derive(Debug)] -pub struct Basefold>(PhantomData<(E, Spec)>); +pub struct Basefold, Mds>(PhantomData<(E, Spec, Mds)>); -impl> Serialize for Basefold { +impl, Mds> Serialize for Basefold { fn serialize(&self, serializer: S) -> Result where S: Serializer, @@ -256,9 +265,9 @@ impl> Serialize for Basefold { } } -pub type BasefoldDefault = Basefold; +pub type BasefoldDefault = Basefold; -impl> Clone for Basefold { +impl, Mds> Clone for Basefold { fn clone(&self) -> Self { Self(PhantomData) } @@ -274,9 +283,10 @@ where } } -impl AsRef<[Digest]> for BasefoldCommitmentWithWitness +impl AsRef<[Digest]> for BasefoldCommitmentWithWitness where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { fn as_ref(&self) -> &[Digest] { let root = self.get_root_ref(); @@ -285,34 +295,37 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -pub enum ProofQueriesResultWithMerklePath +#[serde(bound = "")] +pub enum ProofQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - Single(QueriesResultWithMerklePath), - Batched(BatchedQueriesResultWithMerklePath), - SimpleBatched(SimpleBatchQueriesResultWithMerklePath), + Single(QueriesResultWithMerklePath), + Batched(BatchedQueriesResultWithMerklePath), + SimpleBatched(SimpleBatchQueriesResultWithMerklePath), } -impl ProofQueriesResultWithMerklePath +impl ProofQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - pub fn as_single(&self) -> &QueriesResultWithMerklePath { + pub fn as_single(&self) -> &QueriesResultWithMerklePath { match self { Self::Single(x) => x, _ => panic!("Not a single query result"), } } - pub fn as_batched(&self) -> &BatchedQueriesResultWithMerklePath { + pub fn as_batched(&self) -> &BatchedQueriesResultWithMerklePath { match self { Self::Batched(x) => x, _ => panic!("Not a batched query result"), } } - pub fn as_simple_batched(&self) -> &SimpleBatchQueriesResultWithMerklePath { + pub fn as_simple_batched(&self) -> &SimpleBatchQueriesResultWithMerklePath { match self { Self::SimpleBatched(x) => x, _ => panic!("Not a simple batched query result"), @@ -321,21 +334,24 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BasefoldProof +#[serde(bound = "")] +pub struct BasefoldProof where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub(crate) sumcheck_messages: Vec>, pub(crate) roots: Vec>, pub(crate) final_message: Vec, - pub(crate) query_result_with_merkle_path: ProofQueriesResultWithMerklePath, + pub(crate) query_result_with_merkle_path: ProofQueriesResultWithMerklePath, pub(crate) sumcheck_proof: Option>>, pub(crate) trivial_proof: Vec>, } -impl BasefoldProof +impl BasefoldProof where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn trivial(evals: Vec>) -> Self { Self { @@ -356,6 +372,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] pub struct BasefoldCommitPhaseProof where E::BaseField: Serialize + DeserializeOwned, diff --git a/mpcs/src/basefold/sumcheck.rs b/mpcs/src/basefold/sumcheck.rs index ede813e21..6cefc31b1 100644 --- a/mpcs/src/basefold/sumcheck.rs +++ b/mpcs/src/basefold/sumcheck.rs @@ -1,6 +1,6 @@ -use ff::Field; use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; +use p3_field::Field; use rayon::prelude::{ IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, ParallelSliceMut, @@ -101,9 +101,9 @@ fn parallel_pi(evals: &[F], eq: &[F]) -> Vec { } }); - coeffs[0] = firsts.par_iter().sum(); - coeffs[1] = seconds.par_iter().sum(); - coeffs[2] = thirds.par_iter().sum(); + coeffs[0] = firsts.par_iter().copied().sum(); + coeffs[1] = seconds.par_iter().copied().sum(); + coeffs[2] = thirds.par_iter().copied().sum(); coeffs } @@ -136,9 +136,9 @@ fn parallel_pi_base(evals: &[E::BaseField], eq: &[E]) -> Vec< } }); - coeffs[0] = firsts.par_iter().sum(); - coeffs[1] = seconds.par_iter().sum(); - coeffs[2] = thirds.par_iter().sum(); + coeffs[0] = firsts.par_iter().copied().sum(); + coeffs[1] = seconds.par_iter().copied().sum(); + coeffs[2] = thirds.par_iter().copied().sum(); coeffs } @@ -169,8 +169,7 @@ pub fn sum_check_last_round(eq: &mut Vec, bh_values: &mut Vec, c #[cfg(test)] mod tests { - use ff::Field; - use goldilocks::Goldilocks; + use p3_goldilocks::Goldilocks; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index fcfd1ba69..7c1938b88 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -2,6 +2,8 @@ use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::DenseMultilinearExtension; +use p3_mds::MdsPermutation; +use poseidon::SPONGE_WIDTH; use serde::{Serialize, de::DeserializeOwned}; use std::fmt::Debug; use transcript::{BasicTranscript, Transcript}; @@ -221,10 +223,11 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { ) -> Result<(), Error>; } -pub trait NoninteractivePCS: +pub trait NoninteractivePCS: PolynomialCommitmentScheme> where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { fn ni_open( pp: &Self::ProverParam, @@ -233,7 +236,7 @@ where point: &[E], eval: &E, ) -> Result { - let mut transcript = BasicTranscript::::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Self::open(pp, poly, comm, point, eval, &mut transcript) } @@ -244,7 +247,7 @@ where points: &[Vec], evals: &[Evaluation], ) -> Result { - let mut transcript = BasicTranscript::::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Self::batch_open(pp, polys, comms, points, evals, &mut transcript) } @@ -255,7 +258,7 @@ where eval: &E, proof: &Self::Proof, ) -> Result<(), Error> { - let mut transcript = BasicTranscript::::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Self::verify(vp, comm, point, eval, proof, &mut transcript) } @@ -269,7 +272,7 @@ where where Self::Commitment: 'a, { - let mut transcript = BasicTranscript::::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Self::batch_verify(vp, comms, points, evals, proof, &mut transcript) } } @@ -379,6 +382,10 @@ pub mod test_util { use multilinear_extensions::{ mle::MultilinearExtension, virtual_poly::ArcMultilinearExtension, }; + #[cfg(test)] + use p3_mds::MdsPermutation; + #[cfg(test)] + use poseidon::SPONGE_WIDTH; use rand::rngs::OsRng; #[cfg(test)] use transcript::BasicTranscript; @@ -445,19 +452,20 @@ pub mod test_util { } #[cfg(test)] - pub fn run_commit_open_verify( + pub fn run_commit_open_verify( gen_rand_poly: fn(usize) -> DenseMultilinearExtension, num_vars_start: usize, num_vars_end: usize, ) where Pcs: PolynomialCommitmentScheme, + Mds: MdsPermutation + Default, { for num_vars in num_vars_start..num_vars_end { let (pp, vp) = setup_pcs::(num_vars); // Commit and open let (comm, eval, proof, challenge) = { - let mut transcript = BasicTranscript::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); let poly = gen_rand_poly(num_vars); let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); @@ -473,7 +481,7 @@ pub mod test_util { }; // Verify { - let mut transcript = BasicTranscript::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); transcript.append_field_element_ext(&eval); @@ -486,12 +494,13 @@ pub mod test_util { } #[cfg(test)] - pub fn run_batch_commit_open_verify( + pub fn run_batch_commit_open_verify( gen_rand_poly: fn(usize) -> DenseMultilinearExtension, num_vars_start: usize, num_vars_end: usize, ) where E: ExtensionField, + Mds: MdsPermutation + Default, Pcs: PolynomialCommitmentScheme, { for num_vars in num_vars_start..num_vars_end { @@ -508,7 +517,7 @@ pub mod test_util { .collect_vec(); let (comms, evals, proof, challenge) = { - let mut transcript = BasicTranscript::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); let polys = gen_rand_polys(|i| num_vars - (i >> 1), batch_size, gen_rand_poly); let comms = @@ -539,7 +548,7 @@ pub mod test_util { }; // Batch verify { - let mut transcript = BasicTranscript::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); let comms = comms .iter() .map(|comm| { @@ -567,20 +576,21 @@ pub mod test_util { } #[cfg(test)] - pub(super) fn run_simple_batch_commit_open_verify( + pub(super) fn run_simple_batch_commit_open_verify( gen_rand_poly: fn(usize) -> DenseMultilinearExtension, num_vars_start: usize, num_vars_end: usize, batch_size: usize, ) where E: ExtensionField, + Mds: MdsPermutation + Default, Pcs: PolynomialCommitmentScheme, { for num_vars in num_vars_start..num_vars_end { let (pp, vp) = setup_pcs::(num_vars); let (comm, evals, proof, challenge) = { - let mut transcript = BasicTranscript::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); let polys = gen_rand_polys(|_| num_vars, batch_size, gen_rand_poly); let comm = Pcs::batch_commit_and_write(&pp, polys.as_slice(), &mut transcript).unwrap(); @@ -604,7 +614,7 @@ pub mod test_util { }; // Batch verify { - let mut transcript = BasicTranscript::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); diff --git a/mpcs/src/sum_check.rs b/mpcs/src/sum_check.rs index f2fcf0e47..7406025ca 100644 --- a/mpcs/src/sum_check.rs +++ b/mpcs/src/sum_check.rs @@ -9,10 +9,10 @@ use crate::{ use std::{collections::HashMap, fmt::Debug}; use classic::{ClassicSumCheckRoundMessage, SumcheckProof}; -use ff::PrimeField; use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::DenseMultilinearExtension; +use p3_field::Field; use serde::{Serialize, de::DeserializeOwned}; use transcript::Transcript; @@ -113,27 +113,30 @@ pub fn evaluate( ) } -pub fn lagrange_eval(x: &[F], b: usize) -> F { +pub fn lagrange_eval(x: &[F], b: usize) -> F { assert!(!x.is_empty()); product(x.iter().enumerate().map( |(idx, x_i)| { - if b.nth_bit(idx) { *x_i } else { F::ONE - x_i } + if b.nth_bit(idx) { *x_i } else { F::ONE - *x_i } }, )) } -pub fn eq_xy_eval(x: &[F], y: &[F]) -> F { +pub fn eq_xy_eval(x: &[F], y: &[F]) -> F { assert!(!x.is_empty()); assert_eq!(x.len(), y.len()); product( x.iter() .zip(y) - .map(|(x_i, y_i)| (*x_i * y_i).double() + F::ONE - x_i - y_i), + .map(|(x_i, y_i)| (*x_i * *y_i).double() + F::ONE - *x_i - *y_i), ) } -fn identity_eval(x: &[F]) -> F { - inner_product(x, &powers(F::from(2)).take(x.len()).collect_vec()) +fn identity_eval(x: &[F]) -> F { + inner_product( + x, + &powers(F::from_canonical_u64(2)).take(x.len()).collect_vec(), + ) } diff --git a/mpcs/src/sum_check/classic.rs b/mpcs/src/sum_check/classic.rs index f99d832df..ea7bfc92f 100644 --- a/mpcs/src/sum_check/classic.rs +++ b/mpcs/src/sum_check/classic.rs @@ -9,7 +9,6 @@ use crate::{ }, }; use ark_std::{end_timer, start_timer}; -use ff::Field; use ff_ext::ExtensionField; use itertools::Itertools; use num_integer::Integer; @@ -24,6 +23,7 @@ use multilinear_extensions::{ pub(crate) use coeff::Coefficients; pub use coeff::CoefficientsProver; +use p3_field::FieldAlgebra; #[derive(Debug)] pub struct ProverState<'a, E: ExtensionField> { @@ -99,12 +99,12 @@ impl<'a, E: ExtensionField> ProverState<'a, E> { fn next_round(&mut self, sum: E, challenge: &E) { self.sum = sum; - self.identity += E::from(1 << self.round) * challenge; + self.identity += E::from_canonical_u64(1 << self.round) * *challenge; self.lagranges.values_mut().for_each(|(b, value)| { if b.is_even() { - *value *= &(E::ONE - challenge); + *value *= E::ONE - *challenge; } else { - *value *= challenge; + *value *= *challenge; } *b >>= 1; }); @@ -324,51 +324,58 @@ mod tests { use transcript::BasicTranscript; use super::*; - use goldilocks::{Goldilocks as Fr, GoldilocksExt2 as E}; + use ff_ext::GoldilocksExt2 as E; + use p3_goldilocks::{Goldilocks as Fr, MdsMatrixGoldilocks}; #[test] fn test_sum_check_protocol() { let polys = [ DenseMultilinearExtension::::from_evaluations_vec(2, vec![ - Fr::from(1), - Fr::from(2), - Fr::from(3), - Fr::from(4), + Fr::from_canonical_u64(1), + Fr::from_canonical_u64(2), + Fr::from_canonical_u64(3), + Fr::from_canonical_u64(4), ]), DenseMultilinearExtension::from_evaluations_vec(2, vec![ - Fr::from(0), - Fr::from(1), - Fr::from(1), - Fr::from(0), + Fr::from_canonical_u64(0), + Fr::from_canonical_u64(1), + Fr::from_canonical_u64(1), + Fr::from_canonical_u64(0), ]), - DenseMultilinearExtension::from_evaluations_vec(1, vec![Fr::from(0), Fr::from(1)]), + DenseMultilinearExtension::from_evaluations_vec(1, vec![ + Fr::from_canonical_u64(0), + Fr::from_canonical_u64(1), + ]), + ]; + let points = vec![ + vec![E::from_canonical_u64(1), E::from_canonical_u64(2)], + vec![E::from_canonical_u64(1)], ]; - let points = vec![vec![E::from(1), E::from(2)], vec![E::from(1)]]; let expression = Expression::::eq_xy(0) * Expression::Polynomial(Query::new(0, Rotation::cur())) - * E::from(Fr::from(2)) + * E::from(Fr::from_canonical_u64(2)) + Expression::::eq_xy(0) * Expression::Polynomial(Query::new(1, Rotation::cur())) - * E::from(Fr::from(3)) + * E::from(Fr::from_canonical_u64(3)) + Expression::::eq_xy(1) * Expression::Polynomial(Query::new(2, Rotation::cur())) - * E::from(Fr::from(4)); + * E::from(Fr::from_canonical_u64(4)); let virtual_poly = VirtualPolynomial::::new(&expression, polys.iter(), &[], points.as_slice()); let sum = inner_product( &poly_iter_ext(&polys[0]).collect_vec(), &build_eq_x_r_vec(&points[0]), - ) * Fr::from(2) + ) * Fr::from_canonical_u64(2) + inner_product( &poly_iter_ext(&polys[1]).collect_vec(), &build_eq_x_r_vec(&points[0]), - ) * Fr::from(3) + ) * Fr::from_canonical_u64(3) + inner_product( &poly_iter_ext(&polys[2]).collect_vec(), &build_eq_x_r_vec(&points[1]), - ) * Fr::from(4) - * Fr::from(2); // The third polynomial is summed twice because the hypercube is larger - let mut transcript = BasicTranscript::::new(b"sumcheck"); + ) * Fr::from_canonical_u64(4) + * Fr::from_canonical_u64(2); // The third polynomial is summed twice because the hypercube is larger + let mut transcript = BasicTranscript::::new(b"sumcheck"); let (challenges, evals, proof) = > as SumCheck>::prove( &(), @@ -383,7 +390,7 @@ mod tests { assert_eq!(polys[1].evaluate(&challenges), evals[1]); assert_eq!(polys[2].evaluate(&challenges[..1]), evals[2]); - let mut transcript = BasicTranscript::::new(b"sumcheck"); + let mut transcript = BasicTranscript::::new(b"sumcheck"); let (new_sum, verifier_challenges) = > as SumCheck< E, @@ -395,12 +402,12 @@ mod tests { assert_eq!(verifier_challenges, challenges); assert_eq!( new_sum, - evals[0] * eq_xy_eval(&points[0], &challenges[..2]) * Fr::from(2) - + evals[1] * eq_xy_eval(&points[0], &challenges[..2]) * Fr::from(3) - + evals[2] * eq_xy_eval(&points[1], &challenges[..1]) * Fr::from(4) + evals[0] * eq_xy_eval(&points[0], &challenges[..2]) * Fr::from_canonical_u64(2) + + evals[1] * eq_xy_eval(&points[0], &challenges[..2]) * Fr::from_canonical_u64(3) + + evals[2] * eq_xy_eval(&points[1], &challenges[..1]) * Fr::from_canonical_u64(4) ); - let mut transcript = BasicTranscript::::new(b"sumcheck"); + let mut transcript = BasicTranscript::::new(b"sumcheck"); > as SumCheck>::verify( &(), diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index 10d5c1c20..a571deedd 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -32,6 +32,7 @@ macro_rules! zip_self { } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] pub struct Coefficients(FieldType); impl ClassicSumCheckRoundMessage for Coefficients { @@ -49,7 +50,7 @@ impl ClassicSumCheckRoundMessage for Coefficients { } fn sum(&self) -> E { - self[0] + self[..].iter().sum::() + self[0] + self[..].iter().copied().sum::() } fn evaluate(&self, _: &Self::Auxiliary, challenge: &E) -> E { @@ -60,7 +61,7 @@ impl ClassicSumCheckRoundMessage for Coefficients { impl<'rhs, E: ExtensionField> AddAssign<&'rhs E> for Coefficients { fn add_assign(&mut self, rhs: &'rhs E) { match &mut self.0 { - FieldType::Ext(coeffs) => coeffs[0] += rhs, + FieldType::Ext(coeffs) => coeffs[0] += *rhs, FieldType::Base(_) => panic!("Cannot add extension element to base coefficients"), FieldType::Unreachable => unreachable!(), } @@ -74,11 +75,11 @@ impl<'rhs, E: ExtensionField> AddAssign<(&'rhs E, &'rhs Coefficients)> for Co if scalar == &E::ONE { lhs.iter_mut() .zip(rhs.iter()) - .for_each(|(lhs, rhs)| *lhs += rhs) + .for_each(|(lhs, rhs)| *lhs += *rhs) } else if scalar != &E::ZERO { lhs.iter_mut() .zip(rhs.iter()) - .for_each(|(lhs, rhs)| *lhs += &(*scalar * rhs)) + .for_each(|(lhs, rhs)| *lhs += *scalar * *rhs) } } _ => panic!("Cannot add base coefficients to extension coefficients"), @@ -116,7 +117,7 @@ impl CoefficientsProver { result.iter_mut().enumerate().for_each(|(i, v)| { *v += poly_index_ext(lhs, i % lhs.evaluations.len()) * poly_index_ext(rhs, i % rhs.evaluations.len()) - * scalar; + * *scalar; }) } _ => unimplemented!(), @@ -162,7 +163,7 @@ impl ClassicSumCheckProver for CoefficientsProver { outputs.extend( products .iter() - .map(|(scalar, polys)| (constant * scalar, polys.clone())), + .map(|(scalar, polys)| (constant * *scalar, polys.clone())), ) } } @@ -170,7 +171,7 @@ impl ClassicSumCheckProver for CoefficientsProver { lhs_products.iter().cartesian_product(rhs_products.iter()) { outputs.push(( - *lhs_scalar * rhs_scalar, + *lhs_scalar * *rhs_scalar, iter::empty() .chain(lhs_polys) .chain(rhs_polys) @@ -182,7 +183,7 @@ impl ClassicSumCheckProver for CoefficientsProver { }, &|(constant, mut products), rhs| { products.iter_mut().for_each(|(lhs, _)| { - *lhs *= &rhs; + *lhs *= rhs; }); (constant * rhs, products) }, @@ -194,7 +195,7 @@ impl ClassicSumCheckProver for CoefficientsProver { // Initialize h(X) to zero let mut coeffs = Coefficients(FieldType::Ext(vec![E::ZERO; state.expression.degree() + 1])); // First, sum the constant over the hypercube and add to h(X) - coeffs += &(E::from(state.size() as u64) * self.0); + coeffs += &(E::from_canonical_u64(state.size() as u64) * self.0); // Next, for every product of polynomials, where each product is assumed to be exactly 2 // put this into h(X). if self.1.iter().all(|(_, products)| products.len() == 2) { @@ -219,7 +220,7 @@ impl ClassicSumCheckProver for CoefficientsProver { } fn sum(&self, state: &ProverState) -> E { - self.evals(state).iter().sum() + self.evals(state).iter().copied().sum() } } @@ -267,10 +268,10 @@ impl CoefficientsProver { .for_each(|((lhs_0, lhs_1), (rhs_0, rhs_1))| { let coeff_0 = lhs_0 * rhs_0; let coeff_2 = (lhs_1 - lhs_0) * (rhs_1 - rhs_0); - coeffs[0] += &coeff_0; - coeffs[2] += &coeff_2; + coeffs[0] += coeff_0; + coeffs[2] += coeff_2; if !LAZY { - coeffs[1] += &(lhs_1 * rhs_1 - coeff_0 - coeff_2); + coeffs[1] += lhs_1 * rhs_1 - coeff_0 - coeff_2; } }); }; diff --git a/mpcs/src/util.rs b/mpcs/src/util.rs index 7688b53ec..49be09983 100644 --- a/mpcs/src/util.rs +++ b/mpcs/src/util.rs @@ -3,14 +3,13 @@ pub mod expression; pub mod hash; pub mod parallel; pub mod plonky2_util; -use ff::{Field, PrimeField}; -use ff_ext::ExtensionField; -use goldilocks::SmallField; +use ff_ext::{ExtensionField, SmallField}; use itertools::{Either, Itertools, izip}; use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; pub mod merkle_tree; use crate::{Error, util::parallel::parallelize}; +use p3_field::{FieldAlgebra, PrimeField}; pub use plonky2_util::log2_strict; pub fn ext_to_usize(x: &E) -> usize { @@ -23,7 +22,7 @@ pub fn base_to_usize(x: &E::BaseField) -> usize { } pub fn u32_to_field(x: u32) -> E::BaseField { - E::BaseField::from(x as u64) + E::BaseField::from_canonical_u32(x) } pub trait BitIndex { @@ -38,7 +37,7 @@ impl BitIndex for usize { /// How many bytes are required to store n field elements? pub fn num_of_bytes(n: usize) -> usize { - (F::NUM_BITS as usize).next_power_of_two() * n / 8 + (F::bits() as usize).next_power_of_two() * n / 8 } macro_rules! impl_index { @@ -118,8 +117,8 @@ pub fn field_type_index_mul_base( scalar: &E::BaseField, ) { match poly { - FieldType::Ext(coeffs) => coeffs[index] *= scalar, - FieldType::Base(coeffs) => coeffs[index] *= scalar, + FieldType::Ext(coeffs) => coeffs[index] *= *scalar, + FieldType::Base(coeffs) => coeffs[index] *= *scalar, _ => unreachable!(), } } @@ -194,13 +193,13 @@ pub fn multiply_poly(poly: &mut DenseMultilinearExtension, match &mut poly.evaluations { FieldType::Ext(coeffs) => { for coeff in coeffs.iter_mut() { - *coeff *= scalar; + *coeff *= *scalar; } } FieldType::Base(coeffs) => { *poly = DenseMultilinearExtension::::from_evaluations_ext_vec( poly.num_vars, - coeffs.iter().map(|x| E::from(*x) * scalar).collect(), + coeffs.iter().map(|x| E::from(*x) * *scalar).collect(), ); } _ => unreachable!(), @@ -320,11 +319,12 @@ pub fn ext_try_into_base(x: &E) -> Result(mut rng: impl RngCore) -> [F; N] { + pub fn rand_array(mut rng: impl RngCore) -> [F; N] { array::from_fn(|_| F::random(&mut rng)) } - pub fn rand_vec(n: usize, mut rng: impl RngCore) -> Vec { + pub fn rand_vec(n: usize, mut rng: impl RngCore) -> Vec { iter::repeat_with(|| F::random(&mut rng)).take(n).collect() } #[test] pub fn test_field_transform() { - assert_eq!(F::from(2) * F::from(3), F::from(6)); + assert_eq!( + F::from_canonical_u64(2) * F::from_canonical_u64(3), + F::from_canonical_u64(6) + ); assert_eq!(base_to_usize::(&u32_to_field::(1u32)), 1); assert_eq!(base_to_usize::(&u32_to_field::(10u32)), 10); } diff --git a/mpcs/src/util/arithmetic.rs b/mpcs/src/util/arithmetic.rs index 609f65455..657d76646 100644 --- a/mpcs/src/util/arithmetic.rs +++ b/mpcs/src/util/arithmetic.rs @@ -1,8 +1,7 @@ -use ff::{BatchInvert, Field, PrimeField}; - use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; use num_integer::Integer; +use p3_field::{Field, PrimeField}; use std::{borrow::Borrow, iter}; mod bh; @@ -13,6 +12,7 @@ pub use hypercube::{ interpolate_field_type_over_boolean_hypercube, interpolate_over_boolean_hypercube, }; use num_bigint::BigUint; +use p3_field::FieldAlgebra; use itertools::Itertools; @@ -29,7 +29,7 @@ pub fn horner(coeffs: &[F], x: &F) -> F { let coeff_vec: Vec<&F> = coeffs.iter().rev().collect(); let mut acc = F::ZERO; for c in coeff_vec { - acc = acc * x + c; + acc = acc * *x + *c; } acc // 2 @@ -40,7 +40,7 @@ pub fn horner(coeffs: &[F], x: &F) -> F { pub fn horner_base(coeffs: &[E::BaseField], x: &E) -> E { let mut acc = E::ZERO; for c in coeffs.iter().rev() { - acc = acc * x + E::from(*c); + acc = acc * *x + E::from(*c); } acc // 2 @@ -52,11 +52,11 @@ pub fn steps(start: F) -> impl Iterator { } pub fn steps_by(start: F, step: F) -> impl Iterator { - iter::successors(Some(start), move |state| Some(step + state)) + iter::successors(Some(start), move |state| Some(step + *state)) } pub fn powers(scalar: F) -> impl Iterator { - iter::successors(Some(F::ONE), move |power| Some(scalar * power)) + iter::successors(Some(F::ONE), move |power| Some(scalar * *power)) } pub fn squares(scalar: F) -> impl Iterator { @@ -66,13 +66,13 @@ pub fn squares(scalar: F) -> impl Iterator { pub fn product(values: impl IntoIterator>) -> F { values .into_iter() - .fold(F::ONE, |acc, value| acc * value.borrow()) + .fold(F::ONE, |acc, value| acc * *value.borrow()) } pub fn sum(values: impl IntoIterator>) -> F { values .into_iter() - .fold(F::ZERO, |acc, value| acc + value.borrow()) + .fold(F::ZERO, |acc, value| acc + *value.borrow()) } pub fn inner_product<'a, 'b, F: Field>( @@ -81,7 +81,7 @@ pub fn inner_product<'a, 'b, F: Field>( ) -> F { lhs.into_iter() .zip_eq(rhs) - .map(|(lhs, rhs)| *lhs * rhs) + .map(|(lhs, rhs)| *lhs * *rhs) .reduce(|acc, product| acc + product) .unwrap_or_default() } @@ -94,42 +94,11 @@ pub fn inner_product_three<'a, 'b, 'c, F: Field>( a.into_iter() .zip_eq(b) .zip_eq(c) - .map(|((a, b), c)| *a * b * c) + .map(|((a, b), c)| *a * *b * *c) .reduce(|acc, product| acc + product) .unwrap_or_default() } -pub fn barycentric_weights(points: &[F]) -> Vec { - let mut weights = points - .iter() - .enumerate() - .map(|(j, point_j)| { - points - .iter() - .enumerate() - .filter(|&(i, _point_i)| (i != j)) - .map(|(_i, point_i)| *point_j - point_i) - .reduce(|acc, value| acc * value) - .unwrap_or(F::ONE) - }) - .collect_vec(); - weights.iter_mut().batch_invert(); - weights -} - -pub fn barycentric_interpolate(weights: &[F], points: &[F], evals: &[F], x: &F) -> F { - let (coeffs, sum_inv) = { - let mut coeffs = points.iter().map(|point| *x - point).collect_vec(); - coeffs.iter_mut().batch_invert(); - coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| { - *coeff *= weight; - }); - let sum_inv = coeffs.iter().fold(F::ZERO, |sum, coeff| sum + coeff); - (coeffs, sum_inv.invert().unwrap()) - }; - inner_product(&coeffs, evals) * sum_inv -} - pub fn modulus() -> BigUint { BigUint::from_bytes_le((-F::ONE).to_repr().as_ref()) + 1u64 } @@ -215,7 +184,7 @@ pub fn interpolate2(points: [(F, F); 2], x: F) -> F { let (a0, a1) = points[0]; let (b0, b1) = points[1]; assert_ne!(a0, b0); - a1 + (x - a0) * (b1 - a1) * (b0 - a0).invert().unwrap() + a1 + (x - a0) * (b1 - a1) * (b0 - a0).inverse() } pub fn degree_2_zero_plus_one(poly: &[F]) -> F { @@ -229,7 +198,7 @@ pub fn degree_2_eval(poly: &[F], point: F) -> F { pub fn base_from_raw_bytes(bytes: &[u8]) -> E::BaseField { let mut res = E::BaseField::ZERO; bytes.iter().for_each(|b| { - res += E::BaseField::from(u64::from(*b)); + res += E::BaseField::from_canonical_u8(*b); }); res } diff --git a/mpcs/src/util/arithmetic/hypercube.rs b/mpcs/src/util/arithmetic/hypercube.rs index fc4edc248..19e42372e 100644 --- a/mpcs/src/util/arithmetic/hypercube.rs +++ b/mpcs/src/util/arithmetic/hypercube.rs @@ -1,6 +1,6 @@ -use ff::Field; use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; +use p3_field::Field; use rayon::prelude::{ParallelIterator, ParallelSliceMut}; use crate::util::log2_strict; diff --git a/mpcs/src/util/expression.rs b/mpcs/src/util/expression.rs index 97298371c..bc77edcb0 100644 --- a/mpcs/src/util/expression.rs +++ b/mpcs/src/util/expression.rs @@ -1,5 +1,4 @@ use crate::util::{Deserialize, Itertools, Serialize, izip}; -use ff::Field; use std::{ collections::BTreeSet, fmt::Debug, @@ -8,6 +7,8 @@ use std::{ ops::{Add, Mul, Neg, Sub}, }; +use p3_field::Field; + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub struct Rotation(pub i32); diff --git a/mpcs/src/util/hash.rs b/mpcs/src/util/hash.rs index 2dc4c8bf4..7140129d2 100644 --- a/mpcs/src/util/hash.rs +++ b/mpcs/src/util/hash.rs @@ -1,11 +1,11 @@ -use ff_ext::ExtensionField; -use goldilocks::SmallField; -use poseidon::poseidon_hash::PoseidonHash; +use ff_ext::{ExtensionField, SmallField}; +use p3_field::PrimeField; +use p3_mds::MdsPermutation; +use poseidon::{SPONGE_WIDTH, poseidon_hash::PoseidonHash}; use transcript::Transcript; pub use poseidon::digest::Digest; -use poseidon::poseidon::PrimeField; pub fn write_digest_to_transcript( digest: &Digest, @@ -17,33 +17,50 @@ pub fn write_digest_to_transcript( .for_each(|x| transcript.append_field_element(x)); } -pub fn hash_two_leaves_ext(a: &E, b: &E) -> Digest { +pub fn hash_two_leaves_ext(a: &E, b: &E) -> Digest +where + Mds: MdsPermutation + Default, +{ let input = [a.as_bases(), b.as_bases()].concat(); - PoseidonHash::hash_or_noop(&input) + PoseidonHash::::hash_or_noop(&input) } -pub fn hash_two_leaves_base( +pub fn hash_two_leaves_base( a: &E::BaseField, b: &E::BaseField, -) -> Digest { - PoseidonHash::hash_or_noop(&[*a, *b]) +) -> Digest +where + Mds: MdsPermutation + Default, +{ + PoseidonHash::::hash_or_noop(&[*a, *b]) } -pub fn hash_two_leaves_batch_ext(a: &[E], b: &[E]) -> Digest { - let a_m_to_1_hash = PoseidonHash::hash_or_noop_iter(a.iter().flat_map(|v| v.as_bases())); - let b_m_to_1_hash = PoseidonHash::hash_or_noop_iter(b.iter().flat_map(|v| v.as_bases())); - hash_two_digests(&a_m_to_1_hash, &b_m_to_1_hash) +pub fn hash_two_leaves_batch_ext(a: &[E], b: &[E]) -> Digest +where + Mds: MdsPermutation + Default, +{ + let a_m_to_1_hash = + PoseidonHash::::hash_or_noop_iter(a.iter().flat_map(|v| v.as_bases())); + let b_m_to_1_hash = + PoseidonHash::::hash_or_noop_iter(b.iter().flat_map(|v| v.as_bases())); + hash_two_digests::(&a_m_to_1_hash, &b_m_to_1_hash) } -pub fn hash_two_leaves_batch_base( +pub fn hash_two_leaves_batch_base( a: &[E::BaseField], b: &[E::BaseField], -) -> Digest { - let a_m_to_1_hash = PoseidonHash::hash_or_noop_iter(a.iter()); - let b_m_to_1_hash = PoseidonHash::hash_or_noop_iter(b.iter()); - hash_two_digests(&a_m_to_1_hash, &b_m_to_1_hash) +) -> Digest +where + Mds: MdsPermutation + Default, +{ + let a_m_to_1_hash = PoseidonHash::::hash_or_noop_iter(a.iter()); + let b_m_to_1_hash = PoseidonHash::::hash_or_noop_iter(b.iter()); + hash_two_digests::(&a_m_to_1_hash, &b_m_to_1_hash) } -pub fn hash_two_digests(a: &Digest, b: &Digest) -> Digest { - PoseidonHash::two_to_one(a, b) +pub fn hash_two_digests(a: &Digest, b: &Digest) -> Digest +where + Mds: MdsPermutation + Default, +{ + PoseidonHash::::two_to_one(a, b) } diff --git a/mpcs/src/util/merkle_tree.rs b/mpcs/src/util/merkle_tree.rs index d24840496..e26954c84 100644 --- a/mpcs/src/util/merkle_tree.rs +++ b/mpcs/src/util/merkle_tree.rs @@ -1,6 +1,10 @@ +use std::marker::PhantomData; + use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::FieldType; +use p3_mds::MdsPermutation; +use poseidon::SPONGE_WIDTH; use rayon::{ iter::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator, @@ -24,28 +28,30 @@ use super::hash::write_digest_to_transcript; #[derive(Clone, Debug, Default, Serialize, Deserialize)] #[serde(bound(serialize = "E: Serialize", deserialize = "E: DeserializeOwned"))] -pub struct MerkleTree +pub struct MerkleTree where E::BaseField: Serialize + DeserializeOwned, { inner: Vec>>, leaves: Vec>, + _phantom: PhantomData, } -impl MerkleTree +impl MerkleTree where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn compute_inner(leaves: &FieldType) -> Vec>> { - merkelize::(&[leaves]) + merkelize::(&[leaves]) } pub fn compute_inner_base(leaves: &[E::BaseField]) -> Vec>> { - merkelize_base::(&[leaves]) + merkelize_base::(&[leaves]) } pub fn compute_inner_ext(leaves: &[E]) -> Vec>> { - merkelize_ext::(&[leaves]) + merkelize_ext::(&[leaves]) } pub fn root_from_inner(inner: &[Vec>]) -> Digest { @@ -56,6 +62,7 @@ where Self { inner, leaves: vec![leaves], + _phantom: PhantomData, } } @@ -63,13 +70,15 @@ where Self { inner: Self::compute_inner(&leaves), leaves: vec![leaves], + _phantom: PhantomData, } } pub fn from_batch_leaves(leaves: Vec>) -> Self { Self { - inner: merkelize::(&leaves.iter().collect_vec()), + inner: merkelize::(&leaves.iter().collect_vec()), leaves, + _phantom: PhantomData, } } @@ -139,9 +148,9 @@ where pub fn merkle_path_without_leaf_sibling_or_root( &self, leaf_index: usize, - ) -> MerklePathWithoutLeafOrRoot { + ) -> MerklePathWithoutLeafOrRoot { assert!(leaf_index < self.size().1); - MerklePathWithoutLeafOrRoot::::new( + MerklePathWithoutLeafOrRoot::::new( self.inner .iter() .take(self.height() - 1) @@ -155,19 +164,24 @@ where } #[derive(Clone, Debug, Default, Serialize, Deserialize)] -pub struct MerklePathWithoutLeafOrRoot +pub struct MerklePathWithoutLeafOrRoot where E::BaseField: Serialize + DeserializeOwned, { inner: Vec>, + _phantom: PhantomData, } -impl MerklePathWithoutLeafOrRoot +impl MerklePathWithoutLeafOrRoot where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn new(inner: Vec>) -> Self { - Self { inner } + Self { + inner, + _phantom: PhantomData, + } } pub fn is_empty(&self) -> bool { @@ -195,7 +209,7 @@ where index: usize, root: &Digest, ) { - authenticate_merkle_path_root::( + authenticate_merkle_path_root::( &self.inner, FieldType::Ext(vec![left, right]), index, @@ -210,7 +224,7 @@ where index: usize, root: &Digest, ) { - authenticate_merkle_path_root::( + authenticate_merkle_path_root::( &self.inner, FieldType::Base(vec![left, right]), index, @@ -225,7 +239,7 @@ where index: usize, root: &Digest, ) { - authenticate_merkle_path_root_batch::( + authenticate_merkle_path_root_batch::( &self.inner, FieldType::Ext(left), FieldType::Ext(right), @@ -241,7 +255,7 @@ where index: usize, root: &Digest, ) { - authenticate_merkle_path_root_batch::( + authenticate_merkle_path_root_batch::( &self.inner, FieldType::Base(left), FieldType::Base(right), @@ -253,7 +267,10 @@ where /// Merkle tree construction /// TODO: Support merkelizing mixed-type values -fn merkelize(values: &[&FieldType]) -> Vec>> { +fn merkelize(values: &[&FieldType]) -> Vec>> +where + Mds: MdsPermutation + Default, +{ #[cfg(feature = "sanity-check")] for i in 0..(values.len() - 1) { assert_eq!(values[i].len(), values[i + 1].len()); @@ -267,10 +284,10 @@ fn merkelize(values: &[&FieldType]) -> Vec { - hash_two_leaves_base::(&values[i << 1], &values[(i << 1) + 1]) + hash_two_leaves_base::(&values[i << 1], &values[(i << 1) + 1]) } FieldType::Ext(values) => { - hash_two_leaves_ext::(&values[i << 1], &values[(i << 1) + 1]) + hash_two_leaves_ext::(&values[i << 1], &values[(i << 1) + 1]) } FieldType::Unreachable => unreachable!(), }; @@ -278,7 +295,7 @@ fn merkelize(values: &[&FieldType]) -> Vec hash_two_leaves_batch_base::( + FieldType::Base(_) => hash_two_leaves_batch_base::( values .iter() .map(|values| field_type_index_base(values, i << 1)) @@ -290,7 +307,7 @@ fn merkelize(values: &[&FieldType]) -> Vec hash_two_leaves_batch_ext::( + FieldType::Ext(_) => hash_two_leaves_batch_ext::( values .iter() .map(|values| field_type_index_ext(values, i << 1)) @@ -312,7 +329,7 @@ fn merkelize(values: &[&FieldType]) -> Vec(&ys[0], &ys[1])) .collect::>(); tree.push(oracle); @@ -321,7 +338,12 @@ fn merkelize(values: &[&FieldType]) -> Vec(values: &[&[E::BaseField]]) -> Vec>> { +fn merkelize_base( + values: &[&[E::BaseField]], +) -> Vec>> +where + Mds: MdsPermutation + Default, +{ #[cfg(feature = "sanity-check")] for i in 0..(values.len() - 1) { assert_eq!(values[i].len(), values[i + 1].len()); @@ -333,11 +355,11 @@ fn merkelize_base(values: &[&[E::BaseField]]) -> Vec> 1]; if values.len() == 1 { hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { - *hash = hash_two_leaves_base::(&values[0][i << 1], &values[0][(i << 1) + 1]); + *hash = hash_two_leaves_base::(&values[0][i << 1], &values[0][(i << 1) + 1]); }); } else { hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { - *hash = hash_two_leaves_batch_base::( + *hash = hash_two_leaves_batch_base::( values .iter() .map(|values| values[i << 1]) @@ -357,7 +379,7 @@ fn merkelize_base(values: &[&[E::BaseField]]) -> Vec(&ys[0], &ys[1])) .collect::>(); tree.push(oracle); @@ -366,7 +388,10 @@ fn merkelize_base(values: &[&[E::BaseField]]) -> Vec(values: &[&[E]]) -> Vec>> { +fn merkelize_ext(values: &[&[E]]) -> Vec>> +where + Mds: MdsPermutation + Default, +{ #[cfg(feature = "sanity-check")] for i in 0..(values.len() - 1) { assert_eq!(values[i].len(), values[i + 1].len()); @@ -378,11 +403,11 @@ fn merkelize_ext(values: &[&[E]]) -> Vec> 1]; if values.len() == 1 { hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { - *hash = hash_two_leaves_ext::(&values[0][i << 1], &values[0][(i << 1) + 1]); + *hash = hash_two_leaves_ext::(&values[0][i << 1], &values[0][(i << 1) + 1]); }); } else { hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { - *hash = hash_two_leaves_batch_ext::( + *hash = hash_two_leaves_batch_ext::( values .iter() .map(|values| values[i << 1]) @@ -402,7 +427,7 @@ fn merkelize_ext(values: &[&[E]]) -> Vec(&ys[0], &ys[1])) .collect::>(); tree.push(oracle); @@ -411,17 +436,19 @@ fn merkelize_ext(values: &[&[E]]) -> Vec( +fn authenticate_merkle_path_root( path: &[Digest], leaves: FieldType, x_index: usize, root: &Digest, -) { +) where + Mds: MdsPermutation + Default, +{ let mut x_index = x_index; assert_eq!(leaves.len(), 2); let mut hash = match leaves { - FieldType::Base(leaves) => hash_two_leaves_base::(&leaves[0], &leaves[1]), - FieldType::Ext(leaves) => hash_two_leaves_ext(&leaves[0], &leaves[1]), + FieldType::Base(leaves) => hash_two_leaves_base::(&leaves[0], &leaves[1]), + FieldType::Ext(leaves) => hash_two_leaves_ext::(&leaves[0], &leaves[1]), FieldType::Unreachable => unreachable!(), }; @@ -429,40 +456,42 @@ fn authenticate_merkle_path_root( x_index >>= 1; for path_i in path.iter() { hash = if x_index & 1 == 0 { - hash_two_digests(&hash, path_i) + hash_two_digests::(&hash, path_i) } else { - hash_two_digests(path_i, &hash) + hash_two_digests::(path_i, &hash) }; x_index >>= 1; } assert_eq!(&hash, root); } -fn authenticate_merkle_path_root_batch( +fn authenticate_merkle_path_root_batch( path: &[Digest], left: FieldType, right: FieldType, x_index: usize, root: &Digest, -) { +) where + Mds: MdsPermutation + Default, +{ let mut x_index = x_index; let mut hash = if left.len() > 1 { match (left, right) { (FieldType::Base(left), FieldType::Base(right)) => { - hash_two_leaves_batch_base::(&left, &right) + hash_two_leaves_batch_base::(&left, &right) } (FieldType::Ext(left), FieldType::Ext(right)) => { - hash_two_leaves_batch_ext::(&left, &right) + hash_two_leaves_batch_ext::(&left, &right) } _ => unreachable!(), } } else { match (left, right) { (FieldType::Base(left), FieldType::Base(right)) => { - hash_two_leaves_base::(&left[0], &right[0]) + hash_two_leaves_base::(&left[0], &right[0]) } (FieldType::Ext(left), FieldType::Ext(right)) => { - hash_two_leaves_ext::(&left[0], &right[0]) + hash_two_leaves_ext::(&left[0], &right[0]) } _ => unreachable!(), } @@ -472,9 +501,9 @@ fn authenticate_merkle_path_root_batch( x_index >>= 1; for path_i in path.iter() { hash = if x_index & 1 == 0 { - hash_two_digests(&hash, path_i) + hash_two_digests::(&hash, path_i) } else { - hash_two_digests(path_i, &hash) + hash_two_digests::(path_i, &hash) }; x_index >>= 1; } diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index ad944820c..c6ae022b9 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use ark_std::{rand::RngCore, test_rng}; use ff_ext::{ExtensionField, GoldilocksExt2}; use multilinear_extensions::virtual_poly::VirtualPolynomial; -use p3_field::FieldAlgebra; +use p3_field::{Field, FieldAlgebra}; use p3_goldilocks::MdsMatrixGoldilocks; use p3_mds::MdsPermutation; use poseidon::SPONGE_WIDTH;