Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: recursion circuit public values chip #1183

Merged
merged 24 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions recursion/compiler/src/circuit/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::iter::repeat;

use p3_field::{AbstractExtensionField, AbstractField};
use sp1_recursion_core::air::RecursionPublicValues;

use crate::prelude::*;
use sp1_recursion_core_v2::{chips::poseidon2_skinny::WIDTH, D, DIGEST_SIZE, HASH_RATE};
Expand All @@ -24,6 +25,7 @@ pub trait CircuitV2Builder<C: Config> {
) -> [Felt<C::F>; DIGEST_SIZE];
fn fri_fold_v2(&mut self, input: CircuitV2FriFoldInput<C>) -> CircuitV2FriFoldOutput<C>;
fn ext2felt_v2(&mut self, ext: Ext<C::F, C::EF>) -> [Felt<C::F>; D];
fn commit_public_values_v2(&mut self, public_values: RecursionPublicValues<Felt<C::F>>);
fn cycle_tracker_v2_enter(&mut self, name: String);
fn cycle_tracker_v2_exit(&mut self);
fn hint_ext_v2(&mut self) -> Ext<C::F, C::EF>;
Expand Down Expand Up @@ -157,6 +159,12 @@ impl<C: Config> CircuitV2Builder<C> for Builder<C> {
felts
}

// Commits public values.
fn commit_public_values_v2(&mut self, public_values: RecursionPublicValues<Felt<C::F>>) {
self.operations
.push(DslIr::CircuitV2CommitPublicValues(Box::new(public_values)));
}

fn cycle_tracker_v2_enter(&mut self, name: String) {
self.operations.push(DslIr::CycleTrackerV2Enter(name));
}
Expand Down
29 changes: 28 additions & 1 deletion recursion/compiler/src/circuit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ use core::fmt::Debug;
use instruction::{FieldEltType, HintBitsInstr, HintExt2FeltsInstr, HintInstr, PrintInstr};
use p3_field::{AbstractExtensionField, AbstractField, Field, PrimeField, TwoAdicField};
use sp1_core::utils::SpanBuilder;
use sp1_recursion_core::air::Block;
use sp1_recursion_core::air::{Block, RecursionPublicValues, RECURSIVE_PROOF_NUM_PV_ELTS};
use sp1_recursion_core_v2::{BaseAluInstr, BaseAluOpcode};
use std::{
borrow::Borrow,
collections::{hash_map::Entry, HashMap},
iter::{repeat, zip},
mem::transmute,
};

use sp1_recursion_core_v2::*;
Expand Down Expand Up @@ -334,6 +336,26 @@ impl<C: Config> AsmCompiler<C> {
.into()
}

fn commit_public_values(
&mut self,
public_values: &RecursionPublicValues<Felt<C::F>>,
) -> CompileOneItem<C::F> {
let pv_addrs =
unsafe {
transmute::<
RecursionPublicValues<Felt<C::F>>,
[Felt<C::F>; RECURSIVE_PROOF_NUM_PV_ELTS],
>(*public_values)
}
.map(|pv| pv.read(self));

let public_values_a: &RecursionPublicValues<Address<C::F>> = pv_addrs.as_slice().borrow();
Instruction::CommitPublicValues(CommitPublicValuesInstr {
pv_addrs: *public_values_a,
})
.into()
}

fn print_f(&mut self, addr: impl Reg<C>) -> CompileOneItem<C::F> {
Instruction::Print(PrintInstr {
field_elt_type: FieldEltType::Base,
Expand Down Expand Up @@ -457,6 +479,9 @@ impl<C: Config> AsmCompiler<C> {
vec![self.hint_bit_decomposition(value, output)]
}
DslIr::CircuitV2FriFold(output, input) => vec![self.fri_fold(output, input)],
DslIr::CircuitV2CommitPublicValues(public_values) => {
vec![self.commit_public_values(&public_values)]
}

DslIr::PrintV(dst) => vec![self.print_f(dst)],
DslIr::PrintF(dst) => vec![self.print_f(dst)],
Expand Down Expand Up @@ -577,6 +602,7 @@ impl<C: Config> AsmCompiler<C> {
kind: MemAccessKind::Read,
..
})
| Instruction::CommitPublicValues(_)
| Instruction::Print(_) => vec![],
})
.for_each(|(mult, addr): (&mut C::F, &Address<C::F>)| {
Expand Down Expand Up @@ -620,6 +646,7 @@ const fn instr_name<F>(instr: &Instruction<F>) -> &'static str {
Instruction::Print(_) => "Print",
Instruction::HintExt2Felts(_) => "HintExt2Felts",
Instruction::Hint(_) => "Hint",
Instruction::CommitPublicValues(_) => "CommitPublicValues",
}
}

Expand Down
4 changes: 4 additions & 0 deletions recursion/compiler/src/ir/instructions.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use sp1_recursion_core::air::RecursionPublicValues;

use super::{
Array, CircuitV2FriFoldInput, CircuitV2FriFoldOutput, FriFoldInput, MemIndex, Ptr, TracedVec,
};
Expand Down Expand Up @@ -216,6 +218,8 @@ pub enum DslIr<C: Config> {
CircuitV2Poseidon2PermuteBabyBearSkinny([Felt<C::F>; 16], [Felt<C::F>; 16]),
/// Permutates an array of BabyBear elements in the circuit using the wide precompile.
CircuitV2Poseidon2PermuteBabyBearWide([Felt<C::F>; 16], [Felt<C::F>; 16]),
/// Commits the public values.
CircuitV2CommitPublicValues(Box<RecursionPublicValues<Felt<C::F>>>),

// Miscellaneous instructions.
/// Decompose hint operation of a usize into an array. (output = num2bits(usize)).
Expand Down
1 change: 1 addition & 0 deletions recursion/core-v2/src/chips/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ pub mod fri_fold;
pub mod mem;
pub mod poseidon2_skinny;
pub mod poseidon2_wide;
pub mod public_values;
275 changes: 275 additions & 0 deletions recursion/core-v2/src/chips/public_values.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
use std::borrow::{Borrow, BorrowMut};

use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
use p3_field::PrimeField32;
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use sp1_core::{air::MachineAir, utils::pad_rows_fixed};
use sp1_derive::AlignedBorrow;
use sp1_recursion_core::air::{RecursionPublicValues, RECURSIVE_PROOF_NUM_PV_ELTS};

use crate::{
builder::SP1RecursionAirBuilder,
runtime::{Instruction, RecursionProgram},
ExecutionRecord,
};

use crate::DIGEST_SIZE;

use super::mem::MemoryAccessCols;

pub const NUM_PUBLIC_VALUES_COLS: usize = core::mem::size_of::<PublicValuesCols<u8>>();
pub const NUM_PUBLIC_VALUES_PREPROCESSED_COLS: usize =
core::mem::size_of::<PublicValuesPreprocessedCols<u8>>();

#[derive(Default)]
pub struct PublicValuesChip {}

/// The preprocessed columns for the CommitPVHash instruction.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct PublicValuesPreprocessedCols<T: Copy> {
pub pv_idx: [T; DIGEST_SIZE],
pub pv_mem: MemoryAccessCols<T>,
}

/// The cols for a CommitPVHash invocation.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct PublicValuesCols<T: Copy> {
pub pv_element: T,
}

impl<F> BaseAir<F> for PublicValuesChip {
fn width(&self) -> usize {
NUM_PUBLIC_VALUES_COLS
}
}

impl<F: PrimeField32> MachineAir<F> for PublicValuesChip {
type Record = ExecutionRecord<F>;

type Program = RecursionProgram<F>;

fn name(&self) -> String {
"PublicValues".to_string()
}

fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
// This is a no-op.
}

fn preprocessed_width(&self) -> usize {
NUM_PUBLIC_VALUES_PREPROCESSED_COLS
}

fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
let mut rows: Vec<[F; NUM_PUBLIC_VALUES_PREPROCESSED_COLS]> = Vec::new();
let commit_pv_hash_instrs = program
.instructions
.iter()
.filter_map(|instruction| {
if let Instruction::CommitPublicValues(instr) = instruction {
Some(instr)
} else {
None
}
})
.collect::<Vec<_>>();

if commit_pv_hash_instrs.len() != 1 {
tracing::warn!("Expected exactly one CommitPVHash instruction.");
}

// We only take 1 commit pv hash instruction, since our air only checks for one public values hash.
for instr in commit_pv_hash_instrs.iter().take(1) {
for (i, addr) in instr.pv_addrs.digest.iter().enumerate() {
let mut row = [F::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS];
let cols: &mut PublicValuesPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
cols.pv_idx[i] = F::one();
cols.pv_mem = MemoryAccessCols {
addr: *addr,
mult: F::neg_one(),
};
rows.push(row);
}
}

// Pad the preprocessed rows to 8 rows.
pad_rows_fixed(
&mut rows,
|| [F::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS],
Some(3),
);

let trace = RowMajorMatrix::new(
rows.into_iter().flatten().collect(),
NUM_PUBLIC_VALUES_PREPROCESSED_COLS,
);
Some(trace)
}

fn generate_trace(
&self,
input: &ExecutionRecord<F>,
_: &mut ExecutionRecord<F>,
) -> RowMajorMatrix<F> {
if input.commit_pv_hash_events.len() != 1 {
tracing::warn!("Expected exactly one CommitPVHash event.");
}

let mut rows: Vec<[F; NUM_PUBLIC_VALUES_COLS]> = Vec::new();

// We only take 1 commit pv hash instruction, since our air only checks for one public values hash.
for event in input.commit_pv_hash_events.iter().take(1) {
for element in event.public_values.digest.iter() {
let mut row = [F::zero(); NUM_PUBLIC_VALUES_COLS];
let cols: &mut PublicValuesCols<F> = row.as_mut_slice().borrow_mut();

cols.pv_element = *element;
rows.push(row);
}
}

// Pad the trace to 8 rows.
pad_rows_fixed(&mut rows, || [F::zero(); NUM_PUBLIC_VALUES_COLS], Some(3));

// Convert the trace to a row major matrix.
RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_PUBLIC_VALUES_COLS)
}

fn included(&self, _record: &Self::Record) -> bool {
true
}
}

impl<AB> Air<AB> for PublicValuesChip
where
AB: SP1RecursionAirBuilder + PairBuilder,
{
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local = main.row_slice(0);
let local: &PublicValuesCols<AB::Var> = (*local).borrow();
let prepr = builder.preprocessed();
let local_prepr = prepr.row_slice(0);
let local_prepr: &PublicValuesPreprocessedCols<AB::Var> = (*local_prepr).borrow();
let pv = builder.public_values();
let pv_elms: [AB::Expr; RECURSIVE_PROOF_NUM_PV_ELTS] =
core::array::from_fn(|i| pv[i].into());
let public_values: &RecursionPublicValues<AB::Expr> = pv_elms.as_slice().borrow();

// Constrain mem read for the public value element.
builder.send_single(
local_prepr.pv_mem.addr,
local.pv_element,
local_prepr.pv_mem.mult,
);

for (i, pv_elm) in public_values.digest.iter().enumerate() {
// Ensure that the public value element is the same for all rows within a fri fold invocation.
builder
.when(local_prepr.pv_idx[i])
.assert_eq(pv_elm.clone(), local.pv_element);
}
}
}

#[cfg(test)]
mod tests {
use rand::rngs::StdRng;
use rand::Rng;
use rand::SeedableRng;
use sp1_core::air::MachineAir;
use sp1_core::utils::run_test_machine;
use sp1_core::utils::setup_logger;
use sp1_core::utils::BabyBearPoseidon2;
use sp1_core::utils::DIGEST_SIZE;
use sp1_recursion_core::air::RecursionPublicValues;
use sp1_recursion_core::air::NUM_PV_ELMS_TO_HASH;
use sp1_recursion_core::air::RECURSIVE_PROOF_NUM_PV_ELTS;
use sp1_recursion_core::stark::config::BabyBearPoseidon2Outer;
use std::array;
use std::borrow::Borrow;

use p3_baby_bear::BabyBear;
use p3_baby_bear::DiffusionMatrixBabyBear;
use p3_field::AbstractField;
use p3_matrix::dense::RowMajorMatrix;
use sp1_core::stark::StarkGenericConfig;

use crate::chips::public_values::PublicValuesChip;
use crate::CommitPublicValuesEvent;
use crate::{
machine::RecursionAir,
runtime::{instruction as instr, ExecutionRecord},
MemAccessKind, RecursionProgram, Runtime,
};

#[test]
fn prove_babybear_circuit_public_values() {
setup_logger();
type SC = BabyBearPoseidon2Outer;
type F = <SC as StarkGenericConfig>::Val;
type EF = <SC as StarkGenericConfig>::Challenge;
type A = RecursionAir<F, 3, 1>;

let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
let mut random_felt = move || -> F { F::from_canonical_u32(rng.gen_range(0..1 << 16)) };
let random_pv_elms: [F; RECURSIVE_PROOF_NUM_PV_ELTS] = array::from_fn(|_| random_felt());
let addr = 0u32;
let public_values_a: [u32; RECURSIVE_PROOF_NUM_PV_ELTS] =
array::from_fn(|i| i as u32 + addr);

let mut instructions = Vec::new();
// Allocate the memory for the public values hash.

for i in 0..RECURSIVE_PROOF_NUM_PV_ELTS {
let mult = (NUM_PV_ELMS_TO_HASH..NUM_PV_ELMS_TO_HASH + DIGEST_SIZE).contains(&i);
instructions.push(instr::mem_block(
MemAccessKind::Write,
mult as u32,
public_values_a[i],
random_pv_elms[i].into(),
));
}
let public_values_a: &RecursionPublicValues<u32> = public_values_a.as_slice().borrow();
instructions.push(instr::commit_public_values(public_values_a));

let program = RecursionProgram {
instructions,
traces: Default::default(),
};

let config = SC::new();

let mut runtime =
Runtime::<F, EF, DiffusionMatrixBabyBear>::new(&program, BabyBearPoseidon2::new().perm);
runtime.run().unwrap();
let machine = A::machine(config);
let (pk, vk) = machine.setup(&program);
let result = run_test_machine(vec![runtime.record], machine, pk, vk);
if let Err(e) = result {
panic!("Verification failed: {:?}", e);
}
}

#[test]
fn generate_public_values_circuit_trace() {
type F = BabyBear;

let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
let random_felts: [F; RECURSIVE_PROOF_NUM_PV_ELTS] =
array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..1 << 16)));
let random_public_values: &RecursionPublicValues<F> = random_felts.as_slice().borrow();

let shard = ExecutionRecord {
commit_pv_hash_events: vec![CommitPublicValuesEvent {
public_values: *random_public_values,
}],
..Default::default()
};
let chip = PublicValuesChip::default();
let trace: RowMajorMatrix<F> = chip.generate_trace(&shard, &mut ExecutionRecord::default());
println!("{:?}", trace.values)
}
}
Loading
Loading