Skip to content

Commit

Permalink
Poseidon merkle hasher (#657)
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware authored Jul 8, 2024
1 parent 8866c28 commit b3f9285
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 87 deletions.
31 changes: 15 additions & 16 deletions crates/prover/src/core/channel/poseidon252.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::iter;

use starknet_crypto::poseidon_hash;
use starknet_crypto::{poseidon_hash, poseidon_hash_many};
use starknet_ff::FieldElement as FieldElement252;

use super::{Channel, ChannelTime};
Expand Down Expand Up @@ -65,25 +65,24 @@ impl Channel for Poseidon252Channel {
self.channel_time.inc_challenges();
}

// TODO(spapini): Optimize.
fn mix_felts(&mut self, felts: &[SecureField]) {
let shift = (1u64 << 31).into();
let mut cur = FieldElement252::default();
let mut in_chunk = 0;
for x in felts {
for y in x.to_m31_array() {
cur = cur * shift + y.0.into();
}
in_chunk += 1;
if in_chunk == 2 {
self.digest = poseidon_hash(self.digest, cur);
cur = FieldElement252::default();
in_chunk = 0;
}
}
if in_chunk > 0 {
self.digest = poseidon_hash(self.digest, cur);
let mut res = Vec::with_capacity(felts.len() / 2 + 2);
res.push(self.digest);
for chunk in felts.chunks(2) {
res.push(
chunk
.iter()
.flat_map(|x| x.to_m31_array())
.fold(FieldElement252::default(), |cur, y| {
cur * shift + y.0.into()
}),
);
}

self.digest = poseidon_hash_many(&res);

// TODO(spapini): do we need length padding?
self.channel_time.inc_challenges();
}
Expand Down
91 changes: 20 additions & 71 deletions crates/prover/src/core/vcs/blake2_merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,75 +39,24 @@ impl MerkleHasher for Blake2sMerkleHasher {

#[cfg(test)]
mod tests {
use std::collections::BTreeMap;

use itertools::Itertools;
use num_traits::Zero;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::vcs::blake2_merkle::{Blake2sHash, Blake2sMerkleHasher};
use crate::core::vcs::prover::{MerkleDecommitment, MerkleProver};
use crate::core::vcs::verifier::{MerkleVerificationError, MerkleVerifier};

type TestData = (
BTreeMap<u32, Vec<usize>>,
MerkleDecommitment<Blake2sMerkleHasher>,
Vec<Vec<BaseField>>,
MerkleVerifier<Blake2sMerkleHasher>,
);
fn prepare_merkle() -> TestData {
const N_COLS: usize = 400;
const N_QUERIES: usize = 7;
let log_size_range = 6..9;

let mut rng = SmallRng::seed_from_u64(0);
let log_sizes = (0..N_COLS)
.map(|_| rng.gen_range(log_size_range.clone()))
.collect_vec();
let cols = log_sizes
.iter()
.map(|&log_size| {
(0..(1 << log_size))
.map(|_| BaseField::from(rng.gen_range(0..(1 << 30))))
.collect_vec()
})
.collect_vec();
let merkle =
MerkleProver::<CpuBackend, Blake2sMerkleHasher>::commit(cols.iter().collect_vec());

let mut queries = BTreeMap::<u32, Vec<usize>>::new();
for log_size in log_size_range.rev() {
let layer_queries = (0..N_QUERIES)
.map(|_| rng.gen_range(0..(1 << log_size)))
.sorted()
.dedup()
.collect_vec();
queries.insert(log_size, layer_queries);
}

let (values, decommitment) = merkle.decommit(queries.clone(), cols.iter().collect_vec());

let verifier = MerkleVerifier {
root: merkle.root(),
column_log_sizes: log_sizes,
};
(queries, decommitment, values, verifier)
}
use crate::core::vcs::test_utils::prepare_merkle;
use crate::core::vcs::verifier::MerkleVerificationError;

#[test]
fn test_merkle_success() {
let (queries, decommitment, values, verifier) = prepare_merkle();
let (queries, decommitment, values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();

verifier.verify(queries, values, decommitment).unwrap();
}

#[test]
fn test_merkle_invalid_witness() {
let (queries, mut decommitment, values, verifier) = prepare_merkle();
decommitment.hash_witness[20] = Blake2sHash::default();
let (queries, mut decommitment, values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
decommitment.hash_witness[4] = Blake2sHash::default();

assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
Expand All @@ -117,8 +66,8 @@ mod tests {

#[test]
fn test_merkle_invalid_value() {
let (queries, decommitment, mut values, verifier) = prepare_merkle();
values[3][6] = BaseField::zero();
let (queries, decommitment, mut values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
values[3][2] = BaseField::zero();

assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
Expand All @@ -128,7 +77,7 @@ mod tests {

#[test]
fn test_merkle_witness_too_short() {
let (queries, mut decommitment, values, verifier) = prepare_merkle();
let (queries, mut decommitment, values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
decommitment.hash_witness.pop();

assert_eq!(
Expand All @@ -138,35 +87,35 @@ mod tests {
}

#[test]
fn test_merkle_column_values_too_long() {
let (queries, decommitment, mut values, verifier) = prepare_merkle();
values[3].push(BaseField::zero());
fn test_merkle_witness_too_long() {
let (queries, mut decommitment, values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
decommitment.hash_witness.push(Blake2sHash::default());

assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooLong
MerkleVerificationError::WitnessTooLong
);
}

#[test]
fn test_merkle_column_values_too_short() {
let (queries, decommitment, mut values, verifier) = prepare_merkle();
values[3].pop();
fn test_merkle_column_values_too_long() {
let (queries, decommitment, mut values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
values[3].push(BaseField::zero());

assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooShort
MerkleVerificationError::ColumnValuesTooLong
);
}

#[test]
fn test_merkle_witness_too_long() {
let (queries, mut decommitment, values, verifier) = prepare_merkle();
decommitment.hash_witness.push(Blake2sHash::default());
fn test_merkle_column_values_too_short() {
let (queries, decommitment, mut values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
values[3].pop();

assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
MerkleVerificationError::WitnessTooLong
MerkleVerificationError::ColumnValuesTooShort
);
}
}
5 changes: 5 additions & 0 deletions crates/prover/src/core/vcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ pub mod blake2s_ref;
pub mod blake3_hash;
pub mod hasher;
pub mod ops;
#[cfg(not(target_arch = "wasm32"))]
pub mod poseidon252_merkle;
pub mod prover;
mod utils;
pub mod verifier;

#[cfg(test)]
mod test_utils;
174 changes: 174 additions & 0 deletions crates/prover/src/core/vcs/poseidon252_merkle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
use itertools::Itertools;
use num_traits::Zero;
use starknet_crypto::poseidon_hash_many;
use starknet_ff::FieldElement as FieldElement252;

use super::ops::{MerkleHasher, MerkleOps};
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;

const ELEMENTS_IN_BLOCK: usize = 8;

#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
pub struct Poseidon252MerkleHasher;
impl MerkleHasher for Poseidon252MerkleHasher {
    type Hash = FieldElement252;

    /// Hashes a Merkle node from its (optional) children hashes and the column
    /// values assigned to this node.
    ///
    /// Column values are packed `ELEMENTS_IN_BLOCK` (8) M31 elements per
    /// felt252 word in base 2^31 (8 * 31 = 248 bits, which fits in a felt252),
    /// zero-padding the last partial block. The packed words, preceded by the
    /// two children hashes when present, are hashed with `poseidon_hash_many`.
    fn hash_node(
        children_hashes: Option<(Self::Hash, Self::Hash)>,
        column_values: &[BaseField],
    ) -> Self::Hash {
        let n_column_blocks = column_values.len().div_ceil(ELEMENTS_IN_BLOCK);
        // Capacity: 2 child hashes (when present) + one word per column block.
        let values_len = 2 + n_column_blocks;
        let mut values = Vec::with_capacity(values_len);

        if let Some((left, right)) = children_hashes {
            values.push(left);
            values.push(right);
        }

        // Zero-pad the column values to a whole number of blocks.
        let padding_length = ELEMENTS_IN_BLOCK * n_column_blocks - column_values.len();
        let padded_values = column_values
            .iter()
            .copied()
            .chain(std::iter::repeat(BaseField::zero()).take(padding_length));
        // Hoisted loop invariant: the 2^31 radix used to pack M31 values.
        // (Previously a fresh FieldElement252 was constructed per element.)
        let shift = FieldElement252::from(1u64 << 31);
        for chunk in padded_values.array_chunks::<ELEMENTS_IN_BLOCK>() {
            let mut word = FieldElement252::default();
            for x in chunk {
                word = word * shift + FieldElement252::from(x.0);
            }
            values.push(word);
        }
        poseidon_hash_many(&values)
    }
}

impl MerkleOps<Poseidon252MerkleHasher> for CpuBackend {
    /// Builds one Merkle layer of `1 << log_size` nodes. Each node hashes the
    /// pair of children under it in `prev_layer` (absent for the leaf layer)
    /// together with the values of all `columns` at its index.
    fn commit_on_layer(
        log_size: u32,
        prev_layer: Option<&Vec<FieldElement252>>,
        columns: &[&Vec<BaseField>],
    ) -> Vec<FieldElement252> {
        let n_nodes = 1usize << log_size;
        let mut layer = Vec::with_capacity(n_nodes);
        for node in 0..n_nodes {
            // Children of `node` sit at indices 2*node and 2*node+1 below it.
            let children = prev_layer.map(|prev| (prev[2 * node], prev[2 * node + 1]));
            let node_values = columns.iter().map(|column| column[node]).collect_vec();
            layer.push(Poseidon252MerkleHasher::hash_node(children, &node_values));
        }
        layer
    }
}

#[cfg(test)]
mod tests {
    use num_traits::Zero;
    use starknet_ff::FieldElement as FieldElement252;

    use crate::core::fields::m31::BaseField;
    use crate::core::vcs::ops::MerkleHasher;
    use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher;
    use crate::core::vcs::test_utils::prepare_merkle;
    use crate::core::vcs::verifier::MerkleVerificationError;
    use crate::m31;

    /// Pins `hash_node` against known-good test vectors so any change to the
    /// packing or hashing scheme is caught immediately.
    #[test]
    fn test_vector() {
        // Leaf node: no children, two column values.
        let leaf_hash = Poseidon252MerkleHasher::hash_node(None, &[m31!(0), m31!(1)]);
        let expected_leaf = FieldElement252::from_dec_str(
            "2552053700073128806553921687214114320458351061521275103654266875084493044716",
        )
        .unwrap();
        assert_eq!(leaf_hash, expected_leaf);

        // Inner node: two children hashes plus a single column value.
        let inner_hash = Poseidon252MerkleHasher::hash_node(
            Some((FieldElement252::from(1u32), FieldElement252::from(2u32))),
            &[m31!(3)],
        );
        let expected_inner = FieldElement252::from_dec_str(
            "159358216886023795422515519110998391754567506678525778721401012606792642769",
        )
        .unwrap();
        assert_eq!(inner_hash, expected_inner);
    }

    /// An untampered decommitment verifies successfully.
    #[test]
    fn test_merkle_success() {
        let (test_queries, proof, column_values, verifier) =
            prepare_merkle::<Poseidon252MerkleHasher>();
        verifier.verify(test_queries, column_values, proof).unwrap();
    }

    /// Corrupting a witness hash must surface as a root mismatch.
    #[test]
    fn test_merkle_invalid_witness() {
        let (test_queries, mut proof, column_values, verifier) =
            prepare_merkle::<Poseidon252MerkleHasher>();
        proof.hash_witness[4] = FieldElement252::default();

        let err = verifier
            .verify(test_queries, column_values, proof)
            .unwrap_err();
        assert_eq!(err, MerkleVerificationError::RootMismatch);
    }

    /// Corrupting a queried column value must surface as a root mismatch.
    #[test]
    fn test_merkle_invalid_value() {
        let (test_queries, proof, mut column_values, verifier) =
            prepare_merkle::<Poseidon252MerkleHasher>();
        column_values[3][2] = BaseField::zero();

        let err = verifier
            .verify(test_queries, column_values, proof)
            .unwrap_err();
        assert_eq!(err, MerkleVerificationError::RootMismatch);
    }

    /// Dropping a witness hash is detected before hashing completes.
    #[test]
    fn test_merkle_witness_too_short() {
        let (test_queries, mut proof, column_values, verifier) =
            prepare_merkle::<Poseidon252MerkleHasher>();
        proof.hash_witness.pop();

        let err = verifier
            .verify(test_queries, column_values, proof)
            .unwrap_err();
        assert_eq!(err, MerkleVerificationError::WitnessTooShort);
    }

    /// An extra, unconsumed witness hash is rejected.
    #[test]
    fn test_merkle_witness_too_long() {
        let (test_queries, mut proof, column_values, verifier) =
            prepare_merkle::<Poseidon252MerkleHasher>();
        proof.hash_witness.push(FieldElement252::default());

        let err = verifier
            .verify(test_queries, column_values, proof)
            .unwrap_err();
        assert_eq!(err, MerkleVerificationError::WitnessTooLong);
    }

    /// An extra, unconsumed column value is rejected.
    #[test]
    fn test_merkle_column_values_too_long() {
        let (test_queries, proof, mut column_values, verifier) =
            prepare_merkle::<Poseidon252MerkleHasher>();
        column_values[3].push(BaseField::zero());

        let err = verifier
            .verify(test_queries, column_values, proof)
            .unwrap_err();
        assert_eq!(err, MerkleVerificationError::ColumnValuesTooLong);
    }

    /// A missing column value is detected.
    #[test]
    fn test_merkle_column_values_too_short() {
        let (test_queries, proof, mut column_values, verifier) =
            prepare_merkle::<Poseidon252MerkleHasher>();
        column_values[3].pop();

        let err = verifier
            .verify(test_queries, column_values, proof)
            .unwrap_err();
        assert_eq!(err, MerkleVerificationError::ColumnValuesTooShort);
    }
}
Loading

0 comments on commit b3f9285

Please sign in to comment.