Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: witness input for v2 recursion #1255

Merged
merged 7 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions recursion/compiler/src/circuit/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ pub trait CircuitV2Builder<C: Config> {
fn ext2felt_v2(&mut self, ext: Ext<C::F, C::EF>) -> [Felt<C::F>; D];
fn cycle_tracker_v2_enter(&mut self, name: String);
fn cycle_tracker_v2_exit(&mut self);
fn hint_ext_v2(&mut self) -> Ext<C::F, C::EF>;
fn hint_felt_v2(&mut self) -> Felt<C::F>;
fn hint_exts_v2(&mut self, len: usize) -> Vec<Ext<C::F, C::EF>>;
fn hint_felts_v2(&mut self, len: usize) -> Vec<Felt<C::F>>;
}

impl<C: Config> CircuitV2Builder<C> for Builder<C> {
Expand All @@ -40,6 +44,7 @@ impl<C: Config> CircuitV2Builder<C> for Builder<C> {
}
num
}

/// Converts a felt to bits inside a circuit.
fn num2bits_v2_f(&mut self, num: Felt<C::F>, num_bits: usize) -> Vec<Felt<C::F>> {
let output = std::iter::from_fn(|| Some(self.uninit()))
Expand All @@ -60,6 +65,7 @@ impl<C: Config> CircuitV2Builder<C> for Builder<C> {

output
}

/// A version of `exp_reverse_bits_len` that uses the ExpReverseBitsLen precompile.
fn exp_reverse_bits_v2(
&mut self,
Expand All @@ -71,6 +77,7 @@ impl<C: Config> CircuitV2Builder<C> for Builder<C> {
.push(DslIr::CircuitV2ExpReverseBits(output, input, power_bits));
output
}

/// Applies the Poseidon2 permutation to the given array.
fn poseidon2_permute_v2_skinny(&mut self, array: [Felt<C::F>; WIDTH]) -> [Felt<C::F>; WIDTH] {
let output: [Felt<C::F>; WIDTH] = core::array::from_fn(|_| self.uninit());
Expand All @@ -87,6 +94,7 @@ impl<C: Config> CircuitV2Builder<C> for Builder<C> {
.push(DslIr::CircuitV2Poseidon2PermuteBabyBearWide(output, array));
output
}

/// Applies the Poseidon2 permutation to the given array.
///
/// Reference: [p3_symmetric::PaddingFreeSponge]
Expand All @@ -100,6 +108,7 @@ impl<C: Config> CircuitV2Builder<C> for Builder<C> {
let state: [Felt<C::F>; DIGEST_SIZE] = state[..DIGEST_SIZE].try_into().unwrap();
state
}

/// Applies the Poseidon2 compression function to the given array.
///
/// Reference: [p3_symmetric::TruncatedPermutation]
Expand All @@ -114,6 +123,7 @@ impl<C: Config> CircuitV2Builder<C> for Builder<C> {
let post: [Felt<C::F>; DIGEST_SIZE] = post[..DIGEST_SIZE].try_into().unwrap();
post
}

/// Runs FRI fold.
fn fri_fold_v2(&mut self, input: CircuitV2FriFoldInput<C>) -> CircuitV2FriFoldOutput<C> {
let mut uninit_vec = |len| {
Expand All @@ -129,6 +139,7 @@ impl<C: Config> CircuitV2Builder<C> for Builder<C> {
.push(DslIr::CircuitV2FriFold(output.clone(), input));
output
}

/// Decomposes an ext into its felt coordinates.
fn ext2felt_v2(&mut self, ext: Ext<C::F, C::EF>) -> [Felt<C::F>; D] {
let felts = core::array::from_fn(|_| self.uninit());
Expand All @@ -145,10 +156,40 @@ impl<C: Config> CircuitV2Builder<C> for Builder<C> {

felts
}

fn cycle_tracker_v2_enter(&mut self, name: String) {
self.operations.push(DslIr::CycleTrackerV2Enter(name));
}

fn cycle_tracker_v2_exit(&mut self) {
self.operations.push(DslIr::CycleTrackerV2Exit);
}

/// Hint a single felt.
fn hint_felt_v2(&mut self) -> Felt<C::F> {
self.hint_felts_v2(1)[0]
}

/// Hint a single ext.
fn hint_ext_v2(&mut self) -> Ext<C::F, C::EF> {
self.hint_exts_v2(1)[0]
}
tqn marked this conversation as resolved.
Show resolved Hide resolved

/// Hint a vector of felts.
fn hint_felts_v2(&mut self, len: usize) -> Vec<Felt<C::F>> {
let arr = std::iter::from_fn(|| Some(self.uninit()))
.take(len)
.collect::<Vec<_>>();
self.operations.push(DslIr::CircuitV2HintFelts(arr.clone()));
arr
}

/// Hint a vector of exts.
fn hint_exts_v2(&mut self, len: usize) -> Vec<Ext<C::F, C::EF>> {
let arr = std::iter::from_fn(|| Some(self.uninit()))
.take(len)
.collect::<Vec<_>>();
self.operations.push(DslIr::CircuitV2HintExts(arr.clone()));
arr
}
}
61 changes: 17 additions & 44 deletions recursion/compiler/src/circuit/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use chips::poseidon2_skinny::WIDTH;
use core::fmt::Debug;
use instruction::{FieldEltType, HintBitsInstr, HintExt2FeltsInstr, PrintInstr};
use instruction::{FieldEltType, HintBitsInstr, HintExt2FeltsInstr, HintInstr, PrintInstr};
use p3_field::{AbstractExtensionField, AbstractField, Field, PrimeField, TwoAdicField};
use sp1_core::utils::SpanBuilder;
use sp1_recursion_core::air::Block;
Expand Down Expand Up @@ -358,6 +358,16 @@ impl<C: Config> AsmCompiler<C> {
.into()
}

fn hint(&mut self, output: &[impl Reg<C>]) -> CompileOneItem<C::F> {
Instruction::Hint(HintInstr {
output_addrs_mults: output
.iter()
.map(|r| (r.write(self), C::F::zero()))
.collect(),
})
.into()
}

pub fn compile_one<F>(&mut self, ir_instr: DslIr<C>) -> Vec<CompileOneItem<C::F>>
where
F: PrimeField + TwoAdicField,
Expand Down Expand Up @@ -448,53 +458,12 @@ impl<C: Config> AsmCompiler<C> {
}
DslIr::CircuitV2FriFold(output, input) => vec![self.fri_fold(output, input)],

// DslIr::For(_, _, _, _, _) => todo!(),
// DslIr::IfEq(_, _, _, _) => todo!(),
// DslIr::IfNe(_, _, _, _) => todo!(),
// DslIr::IfEqI(_, _, _, _) => todo!(),
// DslIr::IfNeI(_, _, _, _) => todo!(),
// DslIr::Break => todo!(),
// DslIr::Alloc(_, _, _) => todo!(),
// DslIr::LoadV(_, _, _) => todo!(),
// DslIr::LoadF(_, _, _) => todo!(),
// DslIr::LoadE(_, _, _) => todo!(),
// DslIr::StoreV(_, _, _) => todo!(),
// DslIr::StoreF(_, _, _) => todo!(),
// DslIr::StoreE(_, _, _) => todo!(),
// DslIr::CircuitNum2BitsV(_, _, _) => todo!(),
// DslIr::Poseidon2CompressBabyBear(_, _, _) => todo!(),
// DslIr::Poseidon2AbsorbBabyBear(_, _) => todo!(),
// DslIr::Poseidon2FinalizeBabyBear(_, _) => todo!(),
// DslIr::CircuitPoseidon2Permute(_) => todo!(),
// DslIr::CircuitPoseidon2PermuteBabyBear(_) => todo!(),
// DslIr::HintBitsU(_, _) => todo!(),
// DslIr::HintBitsV(_, _) => todo!(),
// DslIr::HintBitsF(_, _) => todo!(),
DslIr::PrintV(dst) => vec![self.print_f(dst)],
DslIr::PrintF(dst) => vec![self.print_f(dst)],
DslIr::PrintE(dst) => vec![self.print_e(dst)],
// DslIr::Error() => todo!(),
// DslIr::HintExt2Felt(_, _) => todo!(),
// DslIr::HintLen(_) => todo!(),
// DslIr::HintVars(_) => todo!(),
// DslIr::HintFelts(_) => todo!(),
// DslIr::HintExts(_) => todo!(),
// DslIr::WitnessVar(_, _) => todo!(),
// DslIr::WitnessFelt(_, _) => todo!(),
// DslIr::WitnessExt(_, _) => todo!(),
// DslIr::Commit(_, _) => todo!(),
// DslIr::RegisterPublicValue(_) => todo!(),
// DslIr::Halt => todo!(),
// DslIr::CircuitCommitVkeyHash(_) => todo!(),
// DslIr::CircuitCommitCommitedValuesDigest(_) => todo!(),
// DslIr::FriFold(_, _) => todo!(),
// DslIr::CircuitSelectV(_, _, _, _) => todo!(),
// DslIr::CircuitSelectF(_, _, _, _) => todo!(),
// DslIr::CircuitSelectE(_, _, _, _) => todo!(),
DslIr::CircuitV2HintFelts(output) => vec![self.hint(&output)],
DslIr::CircuitV2HintExts(output) => vec![self.hint(&output)],
DslIr::CircuitExt2Felt(felts, ext) => vec![self.ext2felts(felts, ext)],
// DslIr::LessThan(_, _, _) => todo!(),
// DslIr::CycleTracker(_) => todo!(),
// DslIr::ExpReverseBitsLen(_, _, _) => todo!(),
DslIr::CycleTrackerV2Enter(name) => vec![CompileOneItem::CycleTrackerEnter(name)],
DslIr::CycleTrackerV2Exit => vec![CompileOneItem::CycleTrackerExit],
instr => panic!("unsupported instruction: {instr:?}"),
Expand Down Expand Up @@ -575,6 +544,9 @@ impl<C: Config> AsmCompiler<C> {
}) => vec![(mult, result)],
Instruction::HintBits(HintBitsInstr {
output_addrs_mults, ..
})
| Instruction::Hint(HintInstr {
output_addrs_mults, ..
}) => output_addrs_mults
.iter_mut()
.map(|(ref addr, mult)| (mult, addr))
Expand Down Expand Up @@ -647,6 +619,7 @@ const fn instr_name<F>(instr: &Instruction<F>) -> &'static str {
Instruction::FriFold(_) => "FriFold",
Instruction::Print(_) => "Print",
Instruction::HintExt2Felts(_) => "HintExt2Felts",
Instruction::Hint(_) => "Hint",
}
}

Expand Down
106 changes: 103 additions & 3 deletions recursion/compiler/src/circuit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub use compiler::*;
#[cfg(test)]
mod tests {
use p3_baby_bear::DiffusionMatrixBabyBear;
use p3_field::AbstractExtensionField;
use p3_field::{AbstractExtensionField, AbstractField};
use rand::{rngs::StdRng, Rng, SeedableRng};
use sp1_core::{
stark::{Chip, StarkGenericConfig, StarkMachine, PROOF_MAX_NUM_PVS},
Expand All @@ -23,10 +23,14 @@ mod tests {
poseidon2_wide::Poseidon2WideChip,
},
machine::RecursionAir,
Runtime,
Runtime, RuntimeError,
};

use crate::{asm::AsmBuilder, circuit::AsmCompiler, ir::*};
use crate::{
asm::AsmBuilder,
circuit::{AsmCompiler, CircuitV2Builder},
ir::*,
};

const DEGREE: usize = 3;

Expand Down Expand Up @@ -119,4 +123,100 @@ mod tests {

tracing::info!("num shard proofs: {}", result.shard_proofs.len());
}

#[test]
fn test_io() {
let mut builder = AsmBuilder::<F, EF>::default();

let felts = builder.hint_felts_v2(3);
assert_eq!(felts.len(), 3);
let sum: Felt<_> = builder.eval(felts[0] + felts[1]);
builder.assert_felt_eq(sum, felts[2]);

let exts = builder.hint_exts_v2(3);
assert_eq!(exts.len(), 3);
let sum: Ext<_, _> = builder.eval(exts[0] + exts[1]);
builder.assert_ext_ne(sum, exts[2]);

let x = builder.hint_ext_v2();
builder.assert_ext_eq(x, exts[0] + felts[0]);

let y = builder.hint_felt_v2();
let zero: Felt<_> = builder.constant(F::zero());
builder.assert_felt_eq(y, zero);

let operations = builder.operations;
let mut compiler = AsmCompiler::default();
let program = compiler.compile(operations);
let mut runtime = Runtime::<F, EF, DiffusionMatrixBabyBear>::new(&program, SC::new().perm);
runtime.witness_stream = [
vec![F::one().into(), F::one().into(), F::two().into()],
vec![F::zero().into(), F::one().into(), F::two().into()],
vec![F::one().into()],
vec![F::zero().into()],
]
.into();
runtime.run().unwrap();

let machine = A::machine(SC::new());

let (pk, vk) = machine.setup(&program);
let result =
run_test_machine(vec![runtime.record], machine, pk, vk.clone()).expect("should verify");

tracing::info!("num shard proofs: {}", result.shard_proofs.len());
}

#[test]
fn test_empty_witness_stream() {
let mut builder = AsmBuilder::<F, EF>::default();

let felts = builder.hint_felts_v2(3);
assert_eq!(felts.len(), 3);
let sum: Felt<_> = builder.eval(felts[0] + felts[1]);
builder.assert_felt_eq(sum, felts[2]);

let exts = builder.hint_exts_v2(3);
assert_eq!(exts.len(), 3);
let sum: Ext<_, _> = builder.eval(exts[0] + exts[1]);
builder.assert_ext_ne(sum, exts[2]);

let operations = builder.operations;
let mut compiler = AsmCompiler::default();
let program = compiler.compile(operations);
let mut runtime = Runtime::<F, EF, DiffusionMatrixBabyBear>::new(&program, SC::new().perm);
runtime.witness_stream = [vec![F::one().into(), F::one().into(), F::two().into()]].into();

match runtime.run() {
Err(RuntimeError::EmptyWitnessStream) => (),
Ok(_) => panic!("should not succeed"),
Err(x) => panic!("should not yield error variant: {}", x),
}
}

#[test]
fn test_mismatched_witness_size() {
const MEM_VEC_LEN: usize = 3;
const WITNESS_LEN: usize = 5;

let mut builder = AsmBuilder::<F, EF>::default();

let felts = builder.hint_felts_v2(MEM_VEC_LEN);
assert_eq!(felts.len(), MEM_VEC_LEN);

let operations = builder.operations;
let mut compiler = AsmCompiler::default();
let program = compiler.compile(operations);
let mut runtime = Runtime::<F, EF, DiffusionMatrixBabyBear>::new(&program, SC::new().perm);
runtime.witness_stream = [vec![F::zero().into(); WITNESS_LEN]].into();

match runtime.run() {
Err(RuntimeError::WitnessLenMismatch {
mem_vec_len: MEM_VEC_LEN,
witness_len: WITNESS_LEN,
}) => (),
Ok(_) => panic!("should not succeed"),
Err(x) => panic!("should not yield error variant: {}", x),
}
}
}
4 changes: 4 additions & 0 deletions recursion/compiler/src/ir/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ pub enum DslIr<C: Config> {
HintFelts(Array<C, Felt<C::F>>),
/// Hint an array of extension field elements.
HintExts(Array<C, Ext<C::F, C::EF>>),
/// Hint an array of field elements.
CircuitV2HintFelts(Vec<Felt<C::F>>),
/// Hint an array of extension field elements.
CircuitV2HintExts(Vec<Ext<C::F, C::EF>>),
/// Witness a variable. Should only be used when target is a gnark circuit.
WitnessVar(Var<C::N>, u32),
/// Witness a field element. Should only be used when target is a gnark circuit.
Expand Down
2 changes: 1 addition & 1 deletion recursion/core-v2/src/chips/alu_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl<F: PrimeField32> MachineAir<F> for BaseAluChip {
type Program = crate::RecursionProgram<F>;

fn name(&self) -> String {
"Base field Alu".to_string()
"BaseAlu".to_string()
}

fn preprocessed_width(&self) -> usize {
Expand Down
2 changes: 1 addition & 1 deletion recursion/core-v2/src/chips/alu_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl<F: PrimeField32 + BinomiallyExtendable<D>> MachineAir<F> for ExtAluChip {
type Program = crate::RecursionProgram<F>;

fn name(&self) -> String {
"Extension field Alu".to_string()
"ExtAlu".to_string()
}

fn preprocessed_width(&self) -> usize {
Expand Down
2 changes: 1 addition & 1 deletion recursion/core-v2/src/chips/dummy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl<F: PrimeField32, const COL_PADDING: usize> MachineAir<F> for DummyChip<COL_
type Program = crate::RecursionProgram<F>;

fn name(&self) -> String {
"Dummy wide".to_string()
"DummyWide".to_string()
}

fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
Expand Down
2 changes: 1 addition & 1 deletion recursion/core-v2/src/chips/mem/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl<F: PrimeField32> MachineAir<F> for MemoryChip<F> {
type Program = crate::RecursionProgram<F>;

fn name(&self) -> String {
"Memory Constants".to_string()
"MemoryConst".to_string()
}
fn preprocessed_width(&self) -> usize {
NUM_MEM_PREPROCESSED_INIT_COLS
Expand Down
Loading
Loading