Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: circuit poseidon2 babybear #870

Merged
merged 18 commits into from
Jun 3, 2024
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions recursion/circuit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,6 @@ p3-poseidon2 = { workspace = true }
zkhash = { git = "https://github.com/HorizenLabs/poseidon2" }
rand = "0.8.5"
sp1-recursion-gnark-ffi = { path = "../gnark-ffi" }

[features]
plonk = ["sp1-recursion-gnark-ffi/plonk"]
93 changes: 93 additions & 0 deletions recursion/circuit/src/poseidon2.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! An implementation of Poseidon2 over BN254.

use std::array;

use itertools::Itertools;
use p3_field::AbstractField;
use p3_field::Field;
Expand All @@ -16,6 +18,8 @@ pub trait Poseidon2CircuitBuilder<C: Config> {
fn p2_permute_mut(&mut self, state: [Var<C::N>; SPONGE_SIZE]);
fn p2_hash(&mut self, input: &[Felt<C::F>]) -> OuterDigestVariable<C>;
fn p2_compress(&mut self, input: [OuterDigestVariable<C>; 2]) -> OuterDigestVariable<C>;
fn p2_babybear_permute_mut(&mut self, state: [Felt<C::F>; 16]);
fn p2_babybear_hash(&mut self, input: &[Felt<C::F>]) -> [Felt<C::F>; 8];
}

impl<C: Config> Poseidon2CircuitBuilder<C> for Builder<C> {
Expand Down Expand Up @@ -52,6 +56,24 @@ impl<C: Config> Poseidon2CircuitBuilder<C> for Builder<C> {
self.p2_permute_mut(state);
[state[0]; DIGEST_SIZE]
}

fn p2_babybear_permute_mut(&mut self, state: [Felt<C::F>; 16]) {
self.push(DslIr::CircuitPoseidon2PermuteBabyBear(state));
}

fn p2_babybear_hash(&mut self, input: &[Felt<C::F>]) -> [Felt<C::F>; 8] {
let mut state: [Felt<C::F>; 16] = array::from_fn(|_| self.eval(C::F::zero()));

for block_chunk in &input.iter().chunks(8) {
state
.iter_mut()
.zip(block_chunk)
.for_each(|(s, i)| *s = self.eval(*i));
self.p2_babybear_permute_mut(state);
}

array::from_fn(|i| state[i])
}
}

#[cfg(test)]
Expand All @@ -60,6 +82,9 @@ pub mod tests {
use p3_bn254_fr::Bn254Fr;
use p3_field::AbstractField;
use p3_symmetric::{CryptographicHasher, Permutation, PseudoCompressionFunction};
use rand::thread_rng;
use rand::Rng;
use sp1_core::utils::{inner_perm, InnerHash};
use sp1_recursion_compiler::config::OuterConfig;
use sp1_recursion_compiler::constraints::ConstraintCompiler;
use sp1_recursion_compiler::ir::{Builder, Felt, Var, Witness};
Expand Down Expand Up @@ -95,6 +120,25 @@ pub mod tests {
PlonkBn254Prover::test::<OuterConfig>(constraints.clone(), Witness::default());
}

#[test]
fn test_p2_babybear_permute_mut() {
let mut rng = thread_rng();
let mut builder = Builder::<OuterConfig>::default();
let input: [BabyBear; 16] = [rng.gen(); 16];
let input_vars: [Felt<_>; 16] = input.map(|x| builder.eval(x));
builder.p2_babybear_permute_mut(input_vars);

let perm = inner_perm();
let result = perm.permute(input);
for i in 0..16 {
builder.assert_felt_eq(input_vars[i], result[i]);
}

let mut backend = ConstraintCompiler::<OuterConfig>::default();
let constraints = backend.emit(builder.operations);
PlonkBn254Prover::test::<OuterConfig>(constraints.clone(), Witness::default());
}

#[test]
fn test_p2_hash() {
let perm = outer_perm();
Expand Down Expand Up @@ -147,4 +191,53 @@ pub mod tests {
let constraints = backend.emit(builder.operations);
PlonkBn254Prover::test::<OuterConfig>(constraints.clone(), Witness::default());
}

#[test]
fn test_p2_babybear_hash() {
let perm = inner_perm();
let hasher = InnerHash::new(perm.clone());

let input: [BabyBear; 26] = [
BabyBear::from_canonical_u32(0),
BabyBear::from_canonical_u32(1),
BabyBear::from_canonical_u32(2),
BabyBear::from_canonical_u32(2),
BabyBear::from_canonical_u32(2),
BabyBear::from_canonical_u32(2),
BabyBear::from_canonical_u32(2),
BabyBear::from_canonical_u32(2),
BabyBear::from_canonical_u32(2),
BabyBear::from_canonical_u32(2),
BabyBear::from_canonical_u32(2),
BabyBear::from_canonical_u32(2),
BabyBear::from_canonical_u32(2),
BabyBear::from_canonical_u32(2),
BabyBear::from_canonical_u32(2),
BabyBear::from_canonical_u32(3),
BabyBear::from_canonical_u32(3),
BabyBear::from_canonical_u32(3),
BabyBear::from_canonical_u32(3),
BabyBear::from_canonical_u32(3),
BabyBear::from_canonical_u32(3),
BabyBear::from_canonical_u32(3),
BabyBear::from_canonical_u32(3),
BabyBear::from_canonical_u32(3),
BabyBear::from_canonical_u32(3),
BabyBear::from_canonical_u32(3),
];
let output = hasher.hash_iter(input);
println!("{:?}", output);

let mut builder = Builder::<OuterConfig>::default();
let input_felts: [Felt<_>; 26] = input.map(|x| builder.eval(x));
let result = builder.p2_babybear_hash(input_felts.as_slice());

for i in 0..8 {
builder.assert_felt_eq(result[i], output[i]);
}

let mut backend = ConstraintCompiler::<OuterConfig>::default();
let constraints = backend.emit(builder.operations);
PlonkBn254Prover::test::<OuterConfig>(constraints.clone(), Witness::default());
}
}
12 changes: 11 additions & 1 deletion recursion/circuit/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::borrow::Borrow;
use std::marker::PhantomData;

use crate::fri::verify_two_adic_pcs;
use crate::poseidon2::Poseidon2CircuitBuilder;
use crate::types::OuterDigestVariable;
use crate::utils::{babybear_bytes_to_bn254, babybears_to_bn254, words_to_bytes};
use crate::witness::Witnessable;
Expand All @@ -20,7 +21,7 @@ use sp1_recursion_compiler::constraints::{Constraint, ConstraintCompiler};
use sp1_recursion_compiler::ir::{Builder, Config, Ext, Felt, Var};
use sp1_recursion_compiler::ir::{Usize, Witness};
use sp1_recursion_compiler::prelude::SymbolicVar;
use sp1_recursion_core::air::RecursionPublicValues;
use sp1_recursion_core::air::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH};
use sp1_recursion_core::stark::config::{outer_fri_config, BabyBearPoseidon2Outer};
use sp1_recursion_core::stark::RecursionAirSkinnyDeg9;
use sp1_recursion_program::commit::PolynomialSpaceVariable;
Expand Down Expand Up @@ -270,7 +271,9 @@ pub fn build_wrap_circuit(
let element = builder.get(&proof.public_values, i);
pv_elements.push(element);
}

let pv: &RecursionPublicValues<_> = pv_elements.as_slice().borrow();

let one_felt: Felt<_> = builder.constant(BabyBear::one());
// Proof must be complete. In the reduce program, this will ensure that the SP1 proof has been
// fully accumulated.
Expand Down Expand Up @@ -347,6 +350,13 @@ pub fn build_wrap_circuit(
}
builder.assert_ext_eq(cumulative_sum, zero_ext);

// Verify the public values digest.
let calculated_digest = builder.p2_babybear_hash(&pv_elements[0..NUM_PV_ELMS_TO_HASH]);
let expected_digest = pv.digest;
for (calculated_elm, expected_elm) in calculated_digest.iter().zip(expected_digest.iter()) {
builder.assert_felt_eq(*expected_elm, *calculated_elm);
}

let mut backend = ConstraintCompiler::<OuterConfig>::default();
backend.emit(builder.operations)
}
Expand Down
4 changes: 4 additions & 0 deletions recursion/compiler/src/constraints/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ impl<C: Config + Debug> ConstraintCompiler<C> {
opcode: ConstraintOpcode::Permute,
args: state.iter().map(|x| vec![x.id()]).collect(),
}),
DslIr::CircuitPoseidon2PermuteBabyBear(state) => constraints.push(Constraint {
opcode: ConstraintOpcode::PermuteBabyBear,
args: state.iter().map(|x| vec![x.id()]).collect(),
}),
DslIr::CircuitSelectV(cond, a, b, out) => {
constraints.push(Constraint {
opcode: ConstraintOpcode::SelectV,
Expand Down
1 change: 1 addition & 0 deletions recursion/compiler/src/constraints/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ pub enum ConstraintOpcode {
CommitVkeyHash,
CommitCommitedValuesDigest,
CircuitFelts2Ext,
PermuteBabyBear,
}
2 changes: 2 additions & 0 deletions recursion/compiler/src/ir/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ pub enum DslIr<C: Config> {
/// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should only
/// be used when target is a gnark circuit.
CircuitPoseidon2Permute([Var<C::N>; 3]),
/// Permutates an array of BabyBear elements in the circuit.
CircuitPoseidon2PermuteBabyBear([Felt<C::F>; 16]),

// Miscellaneous instructions.
/// Decompose hint operation of a usize into an array. (output = num2bits(usize)).
Expand Down
2 changes: 2 additions & 0 deletions recursion/gnark-ffi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ edition = "2021"

[dependencies]
p3-field = { workspace = true }
p3-symmetric = { workspace = true }
p3-baby-bear = { workspace = true }
sp1-recursion-compiler = { path = "../compiler" }
sp1-core = { path = "../../core" }
serde = "1.0.201"
serde_json = "1.0.117"
tempfile = "3.10.1"
Expand Down
74 changes: 74 additions & 0 deletions recursion/gnark-ffi/go/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ import (
"sync"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend/groth16"
"github.com/consensys/gnark/backend/plonk"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/frontend/cs/scs"
"github.com/consensys/gnark/test/unsafekzg"
"github.com/succinctlabs/sp1-recursion-gnark/sp1"
"github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear"
"github.com/succinctlabs/sp1-recursion-gnark/sp1/poseidon2"
)

func main() {}
Expand Down Expand Up @@ -141,3 +145,73 @@ func TestMain() error {

return nil
}

//export TestPoseidonBabyBear2
func TestPoseidonBabyBear2() *C.char {
input := [poseidon2.BABYBEAR_WIDTH]babybear.Variable{
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
babybear.NewF("0"),
}

expectedOutput := [poseidon2.BABYBEAR_WIDTH]babybear.Variable{
babybear.NewF("348670919"),
babybear.NewF("1568590631"),
babybear.NewF("1535107508"),
babybear.NewF("186917780"),
babybear.NewF("587749971"),
babybear.NewF("1827585060"),
babybear.NewF("1218809104"),
babybear.NewF("691692291"),
babybear.NewF("1480664293"),
babybear.NewF("1491566329"),
babybear.NewF("366224457"),
babybear.NewF("490018300"),
babybear.NewF("732772134"),
babybear.NewF("560796067"),
babybear.NewF("484676252"),
babybear.NewF("405025962"),
}

circuit := sp1.TestPoseidon2BabyBearCircuit{Input: input, ExpectedOutput: expectedOutput}
assignment := sp1.TestPoseidon2BabyBearCircuit{Input: input, ExpectedOutput: expectedOutput}

builder := r1cs.NewBuilder
r1cs, err := frontend.Compile(ecc.BN254.ScalarField(), builder, &circuit)
if err != nil {
return C.CString(err.Error())
}

var pk groth16.ProvingKey
pk, err = groth16.DummySetup(r1cs)
if err != nil {
return C.CString(err.Error())
}

// Generate witness.
witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField())
if err != nil {
return C.CString(err.Error())
}

// Generate the proof.
_, err = groth16.Prove(r1cs, pk, witness)
if err != nil {
return C.CString(err.Error())
}

return nil
}
Loading
Loading