From b6b2988faa7e369d3a0227365935dbe67d56d5c4 Mon Sep 17 00:00:00 2001 From: Tej Qu Nair Date: Fri, 2 Aug 2024 15:39:26 -0700 Subject: [PATCH 1/7] hintfelts and hintexts types in ir --- recursion/compiler/src/circuit/builder.rs | 29 ++++++++++++++ recursion/compiler/src/circuit/compiler.rs | 45 +--------------------- recursion/compiler/src/ir/instructions.rs | 4 ++ 3 files changed, 35 insertions(+), 43 deletions(-) diff --git a/recursion/compiler/src/circuit/builder.rs b/recursion/compiler/src/circuit/builder.rs index b1a32527bc..6b0dc7bb55 100644 --- a/recursion/compiler/src/circuit/builder.rs +++ b/recursion/compiler/src/circuit/builder.rs @@ -26,6 +26,10 @@ pub trait CircuitV2Builder { fn ext2felt_v2(&mut self, ext: Ext) -> [Felt; D]; fn cycle_tracker_v2_enter(&mut self, name: String); fn cycle_tracker_v2_exit(&mut self); + fn hint_ext_v2(&mut self) -> Ext; + fn hint_felt_v2(&mut self) -> Felt; + fn hint_exts_v2(&mut self, len: usize) -> Vec>; + fn hint_felts_v2(&mut self, len: usize) -> Vec>; } impl CircuitV2Builder for Builder { @@ -151,4 +155,29 @@ impl CircuitV2Builder for Builder { fn cycle_tracker_v2_exit(&mut self) { self.operations.push(DslIr::CycleTrackerV2Exit); } + /// Hint a single felt. + fn hint_felt_v2(&mut self) -> Felt { + self.hint_felts_v2(1)[0] + } + + /// Hint a single ext. + fn hint_ext_v2(&mut self) -> Ext { + self.hint_exts_v2(1)[0] + } + /// Hint a vector of felts. + fn hint_felts_v2(&mut self, len: usize) -> Vec> { + let arr = std::iter::from_fn(|| Some(self.uninit())) + .take(len) + .collect::>(); + self.operations.push(DslIr::CircuitV2HintFelts(arr.clone())); + arr + } + /// Hint a vector of exts. + fn hint_exts_v2(&mut self, len: usize) -> Vec> { + let arr = std::iter::from_fn(|| Some(self.uninit())) + .take(len) + .collect::>(); + self.operations.push(DslIr::CircuitV2HintExts(arr.clone())); + arr + } } diff --git a/recursion/compiler/src/circuit/compiler.rs b/recursion/compiler/src/circuit/compiler.rs index e991247072..4adfcaeffc 100644 --- a/recursion/compiler/src/circuit/compiler.rs +++ b/recursion/compiler/src/circuit/compiler.rs @@ -448,53 +448,12 @@ impl AsmCompiler { } DslIr::CircuitV2FriFold(output, input) => vec![self.fri_fold(output, input)], - // DslIr::For(_, _, _, _, _) => todo!(), - // DslIr::IfEq(_, _, _, _) => todo!(), - // DslIr::IfNe(_, _, _, _) => todo!(), - // DslIr::IfEqI(_, _, _, _) => todo!(), - // DslIr::IfNeI(_, _, _, _) => todo!(), - // DslIr::Break => todo!(), - // DslIr::Alloc(_, _, _) => todo!(), - // DslIr::LoadV(_, _, _) => todo!(), - // DslIr::LoadF(_, _, _) => todo!(), - // DslIr::LoadE(_, _, _) => todo!(), - // DslIr::StoreV(_, _, _) => todo!(), - // DslIr::StoreF(_, _, _) => todo!(), - // DslIr::StoreE(_, _, _) => todo!(), - // DslIr::CircuitNum2BitsV(_, _, _) => todo!(), - // DslIr::Poseidon2CompressBabyBear(_, _, _) => todo!(), - // DslIr::Poseidon2AbsorbBabyBear(_, _) => todo!(), - // DslIr::Poseidon2FinalizeBabyBear(_, _) => todo!(), - // DslIr::CircuitPoseidon2Permute(_) => todo!(), - // DslIr::CircuitPoseidon2PermuteBabyBear(_) => todo!(), - // DslIr::HintBitsU(_, _) => todo!(), - // DslIr::HintBitsV(_, _) => todo!(), - // DslIr::HintBitsF(_, _) => todo!(), DslIr::PrintV(dst) => vec![self.print_f(dst)], DslIr::PrintF(dst) => vec![self.print_f(dst)], DslIr::PrintE(dst) => vec![self.print_e(dst)], - // DslIr::Error() => todo!(), - // DslIr::HintExt2Felt(_, _) => todo!(), - // DslIr::HintLen(_) => todo!(), - // DslIr::HintVars(_) => todo!(), - // DslIr::HintFelts(_) => todo!(), - // DslIr::HintExts(_) => todo!(), - // DslIr::WitnessVar(_, _) => todo!(), - // DslIr::WitnessFelt(_, _) => todo!(), - // DslIr::WitnessExt(_, _) => todo!(), - // DslIr::Commit(_, _) => todo!(), - // DslIr::RegisterPublicValue(_) => todo!(), - // DslIr::Halt => todo!(), - // DslIr::CircuitCommitVkeyHash(_) => todo!(), - // DslIr::CircuitCommitCommitedValuesDigest(_) => todo!(), - // DslIr::FriFold(_, _) => todo!(), - // DslIr::CircuitSelectV(_, _, _, _) => todo!(), - // DslIr::CircuitSelectF(_, _, _, _) => todo!(), - // DslIr::CircuitSelectE(_, _, _, _) => todo!(), + DslIr::CircuitV2HintFelts(output) => todo!(), + DslIr::CircuitV2HintExts(output) => todo!(), DslIr::CircuitExt2Felt(felts, ext) => vec![self.ext2felts(felts, ext)], - // DslIr::LessThan(_, _, _) => todo!(), - // DslIr::CycleTracker(_) => todo!(), - // DslIr::ExpReverseBitsLen(_, _, _) => todo!(), DslIr::CycleTrackerV2Enter(name) => vec![CompileOneItem::CycleTrackerEnter(name)], DslIr::CycleTrackerV2Exit => vec![CompileOneItem::CycleTrackerExit], instr => panic!("unsupported instruction: {instr:?}"), diff --git a/recursion/compiler/src/ir/instructions.rs b/recursion/compiler/src/ir/instructions.rs index d4d40bfbd9..075b1ef234 100644 --- a/recursion/compiler/src/ir/instructions.rs +++ b/recursion/compiler/src/ir/instructions.rs @@ -245,6 +245,10 @@ pub enum DslIr { HintFelts(Array>), /// Hint an array of extension field elements. HintExts(Array>), + /// Hint an array of field elements. + CircuitV2HintFelts(Vec>), + /// Hint an array of extension field elements. + CircuitV2HintExts(Vec>), /// Witness a variable. Should only be used when target is a gnark circuit. WitnessVar(Var, u32), /// Witness a field element. Should only be used when target is a gnark circuit. From b714d071c2e74e9721689efc628475067ed5f5ef Mon Sep 17 00:00:00 2001 From: Tej Qu Nair Date: Fri, 2 Aug 2024 17:21:53 -0700 Subject: [PATCH 2/7] (wip) hint felts/exts in runtime --- recursion/compiler/src/circuit/compiler.rs | 36 ++++++++++++++++++-- recursion/core-v2/src/runtime/instruction.rs | 14 ++++++++ recursion/core-v2/src/runtime/mod.rs | 35 ++++++++++++++++--- 3 files changed, 77 insertions(+), 8 deletions(-) diff --git a/recursion/compiler/src/circuit/compiler.rs b/recursion/compiler/src/circuit/compiler.rs index 4adfcaeffc..49f988ff6d 100644 --- a/recursion/compiler/src/circuit/compiler.rs +++ b/recursion/compiler/src/circuit/compiler.rs @@ -1,6 +1,8 @@ use chips::poseidon2_skinny::WIDTH; use core::fmt::Debug; -use instruction::{FieldEltType, HintBitsInstr, HintExt2FeltsInstr, PrintInstr}; +use instruction::{ + FieldEltType, HintBitsInstr, HintExt2FeltsInstr, HintExtsInstr, HintFeltsInstr, PrintInstr, +}; use p3_field::{AbstractExtensionField, AbstractField, Field, PrimeField, TwoAdicField}; use sp1_core::utils::SpanBuilder; use sp1_recursion_core::air::Block; @@ -358,6 +360,26 @@ impl AsmCompiler { .into() } + fn hint_felts(&mut self, output: &[impl Reg]) -> CompileOneItem { + Instruction::HintFelts(HintFeltsInstr { + output_addrs_mults: output + .iter() + .map(|r| (r.write(self), C::F::zero())) + .collect(), + }) + .into() + } + + fn hint_exts(&mut self, output: &[impl Reg]) -> CompileOneItem { + Instruction::HintExts(HintExtsInstr { + output_addrs_mults: output + .iter() + .map(|r| (r.write(self), C::F::zero())) + .collect(), + }) + .into() + } + pub fn compile_one(&mut self, ir_instr: DslIr) -> Vec> where F: PrimeField + TwoAdicField, @@ -451,8 +473,8 @@ impl AsmCompiler { DslIr::PrintV(dst) => vec![self.print_f(dst)], DslIr::PrintF(dst) => vec![self.print_f(dst)], DslIr::PrintE(dst) => vec![self.print_e(dst)], - DslIr::CircuitV2HintFelts(output) => todo!(), - DslIr::CircuitV2HintExts(output) => todo!(), + DslIr::CircuitV2HintFelts(output) => vec![self.hint_felts(&output)], + DslIr::CircuitV2HintExts(output) => vec![self.hint_exts(&output)], DslIr::CircuitExt2Felt(felts, ext) => vec![self.ext2felts(felts, ext)], DslIr::CycleTrackerV2Enter(name) => vec![CompileOneItem::CycleTrackerEnter(name)], DslIr::CycleTrackerV2Exit => vec![CompileOneItem::CycleTrackerExit], @@ -534,6 +556,12 @@ impl AsmCompiler { }) => vec![(mult, result)], Instruction::HintBits(HintBitsInstr { output_addrs_mults, .. + }) + | Instruction::HintFelts(HintFeltsInstr { + output_addrs_mults, .. + }) + | Instruction::HintExts(HintExtsInstr { + output_addrs_mults, .. }) => output_addrs_mults .iter_mut() .map(|(ref addr, mult)| (mult, addr)) @@ -606,6 +634,8 @@ const fn instr_name(instr: &Instruction) -> &'static str { Instruction::FriFold(_) => "FriFold", Instruction::Print(_) => "Print", Instruction::HintExt2Felts(_) => "HintExt2Felts", + Instruction::HintFelts(_) => "HintFelts", + Instruction::HintExts(_) => "HintExts", } } diff --git a/recursion/core-v2/src/runtime/instruction.rs b/recursion/core-v2/src/runtime/instruction.rs index 757ae8e94f..52f8773fcf 100644 --- a/recursion/core-v2/src/runtime/instruction.rs +++ b/recursion/core-v2/src/runtime/instruction.rs @@ -15,6 +15,8 @@ pub enum Instruction { FriFold(FriFoldInstr), Print(PrintInstr), HintExt2Felts(HintExt2FeltsInstr), + HintFelts(HintFeltsInstr), + HintExts(HintExtsInstr), } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -31,6 +33,18 @@ pub struct PrintInstr { pub addr: Address, } +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct HintFeltsInstr { + /// Addresses and mults of the output felts. + pub output_addrs_mults: Vec<(Address, F)>, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct HintExtsInstr { + /// Addresses and mults of the output exts. + pub output_addrs_mults: Vec<(Address, F)>, +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct HintExt2FeltsInstr { /// Addresses and mults of the output bits. diff --git a/recursion/core-v2/src/runtime/mod.rs b/recursion/core-v2/src/runtime/mod.rs index f22a53ef32..f97c38e96d 100644 --- a/recursion/core-v2/src/runtime/mod.rs +++ b/recursion/core-v2/src/runtime/mod.rs @@ -6,15 +6,19 @@ mod record; // Avoid triggering annoying branch of thiserror derive macro. use backtrace::Backtrace as Trace; pub use instruction::Instruction; -use instruction::{FieldEltType, HintBitsInstr, HintExt2FeltsInstr, PrintInstr}; +use instruction::{ + FieldEltType, HintBitsInstr, HintExt2FeltsInstr, HintExtsInstr, HintFeltsInstr, PrintInstr, +}; pub use opcode::*; pub use program::*; pub use record::*; use std::{ + collections::VecDeque, fmt::Debug, io::{stdout, Write}, - {marker::PhantomData, sync::Arc}, + marker::PhantomData, + sync::Arc, }; use hashbrown::hash_map::Entry; @@ -106,6 +110,8 @@ pub struct Runtime<'a, F: PrimeField32, EF: ExtensionField, Diffusion> { /// The execution record. pub record: ExecutionRecord, + pub witness_stream: VecDeque>>, + pub cycle_tracker: HashMap, /// The stream that print statements write to. @@ -151,6 +157,10 @@ pub enum RuntimeError { pc: usize, trace: Option<(usize, Trace)>, }, + #[error("failed to print to `debug_stdout`: {0}")] + DebugPrint(#[from] std::io::Error), + #[error("attempted to read `Vec<{0:?}>` from empty witness tream")] + EmptyWitnessStream(FieldEltType), } impl<'a, F: PrimeField32, EF: ExtensionField, Diffusion> Runtime<'a, F, EF, Diffusion> @@ -195,6 +205,7 @@ where pc: F::zero(), memory: HashMap::new(), record, + witness_stream: VecDeque::new(), cycle_tracker: HashMap::new(), debug_stdout: Box::new(stdout()), perm: Some(perm), @@ -531,14 +542,15 @@ where FieldEltType::Base => { self.nb_print_f += 1; let f = self.mr_mult(addr, F::zero()).val[0]; - writeln!(self.debug_stdout, "PRINTF={f}").unwrap(); + writeln!(self.debug_stdout, "PRINTF={f}") } FieldEltType::Extension => { self.nb_print_e += 1; let ef = self.mr_mult(addr, F::zero()).val; - writeln!(self.debug_stdout, "PRINTEF={ef:?}").unwrap(); + writeln!(self.debug_stdout, "PRINTEF={ef:?}") } - }, + } + .map_err(RuntimeError::DebugPrint)?, Instruction::HintExt2Felts(HintExt2FeltsInstr { output_addrs_mults, input_addr, @@ -552,6 +564,19 @@ where self.record.mem_var_events.push(MemEvent { inner: felt }); } } + Instruction::HintFelts(HintFeltsInstr { output_addrs_mults }) => { + let witness = self + .witness_stream + .pop_front() + .ok_or(RuntimeError::EmptyWitnessStream(FieldEltType::Base))?; + // TODO write to some unconstrained memory table + } + Instruction::HintExts(HintExtsInstr { output_addrs_mults }) => { + let witness = self + .witness_stream + .pop_front() + .ok_or(RuntimeError::EmptyWitnessStream(FieldEltType::Extension))?; + } } self.pc = next_pc; From cb5a2397af325b90846b5460544719150422ddf6 Mon Sep 17 00:00:00 2001 From: Tej Qu Nair Date: Mon, 5 Aug 2024 12:37:43 -0700 Subject: [PATCH 3/7] hints in ir (untested) --- recursion/compiler/src/circuit/compiler.rs | 30 +++++--------------- recursion/core-v2/src/runtime/instruction.rs | 11 ++----- recursion/core-v2/src/runtime/mod.rs | 20 ++++++------- 3 files changed, 17 insertions(+), 44 deletions(-) diff --git a/recursion/compiler/src/circuit/compiler.rs b/recursion/compiler/src/circuit/compiler.rs index 49f988ff6d..526196ded3 100644 --- a/recursion/compiler/src/circuit/compiler.rs +++ b/recursion/compiler/src/circuit/compiler.rs @@ -1,8 +1,6 @@ use chips::poseidon2_skinny::WIDTH; use core::fmt::Debug; -use instruction::{ - FieldEltType, HintBitsInstr, HintExt2FeltsInstr, HintExtsInstr, HintFeltsInstr, PrintInstr, -}; +use instruction::{FieldEltType, HintBitsInstr, HintExt2FeltsInstr, HintInstr, PrintInstr}; use p3_field::{AbstractExtensionField, AbstractField, Field, PrimeField, TwoAdicField}; use sp1_core::utils::SpanBuilder; use sp1_recursion_core::air::Block; @@ -360,18 +358,8 @@ impl AsmCompiler { .into() } - fn hint_felts(&mut self, output: &[impl Reg]) -> CompileOneItem { - Instruction::HintFelts(HintFeltsInstr { - output_addrs_mults: output - .iter() - .map(|r| (r.write(self), C::F::zero())) - .collect(), - }) - .into() - } - - fn hint_exts(&mut self, output: &[impl Reg]) -> CompileOneItem { - Instruction::HintExts(HintExtsInstr { + fn hint(&mut self, output: &[impl Reg]) -> CompileOneItem { + Instruction::Hint(HintInstr { output_addrs_mults: output .iter() .map(|r| (r.write(self), C::F::zero())) @@ -473,8 +461,8 @@ impl AsmCompiler { DslIr::PrintV(dst) => vec![self.print_f(dst)], DslIr::PrintF(dst) => vec![self.print_f(dst)], DslIr::PrintE(dst) => vec![self.print_e(dst)], - DslIr::CircuitV2HintFelts(output) => vec![self.hint_felts(&output)], - DslIr::CircuitV2HintExts(output) => vec![self.hint_exts(&output)], + DslIr::CircuitV2HintFelts(output) => vec![self.hint(&output)], + DslIr::CircuitV2HintExts(output) => vec![self.hint(&output)], DslIr::CircuitExt2Felt(felts, ext) => vec![self.ext2felts(felts, ext)], DslIr::CycleTrackerV2Enter(name) => vec![CompileOneItem::CycleTrackerEnter(name)], DslIr::CycleTrackerV2Exit => vec![CompileOneItem::CycleTrackerExit], @@ -557,10 +545,7 @@ impl AsmCompiler { Instruction::HintBits(HintBitsInstr { output_addrs_mults, .. }) - | Instruction::HintFelts(HintFeltsInstr { - output_addrs_mults, .. - }) - | Instruction::HintExts(HintExtsInstr { + | Instruction::Hint(HintInstr { output_addrs_mults, .. }) => output_addrs_mults .iter_mut() @@ -634,8 +619,7 @@ const fn instr_name(instr: &Instruction) -> &'static str { Instruction::FriFold(_) => "FriFold", Instruction::Print(_) => "Print", Instruction::HintExt2Felts(_) => "HintExt2Felts", - Instruction::HintFelts(_) => "HintFelts", - Instruction::HintExts(_) => "HintExts", + Instruction::Hint(_) => "Hint", } } diff --git a/recursion/core-v2/src/runtime/instruction.rs b/recursion/core-v2/src/runtime/instruction.rs index 52f8773fcf..34a57ead9b 100644 --- a/recursion/core-v2/src/runtime/instruction.rs +++ b/recursion/core-v2/src/runtime/instruction.rs @@ -15,8 +15,7 @@ pub enum Instruction { FriFold(FriFoldInstr), Print(PrintInstr), HintExt2Felts(HintExt2FeltsInstr), - HintFelts(HintFeltsInstr), - HintExts(HintExtsInstr), + Hint(HintInstr), } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -34,17 +33,11 @@ pub struct PrintInstr { } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct HintFeltsInstr { +pub struct HintInstr { /// Addresses and mults of the output felts. pub output_addrs_mults: Vec<(Address, F)>, } -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct HintExtsInstr { - /// Addresses and mults of the output exts. - pub output_addrs_mults: Vec<(Address, F)>, -} - #[derive(Clone, Debug, Serialize, Deserialize)] pub struct HintExt2FeltsInstr { /// Addresses and mults of the output bits. diff --git a/recursion/core-v2/src/runtime/mod.rs b/recursion/core-v2/src/runtime/mod.rs index f97c38e96d..2582fef74c 100644 --- a/recursion/core-v2/src/runtime/mod.rs +++ b/recursion/core-v2/src/runtime/mod.rs @@ -6,9 +6,7 @@ mod record; // Avoid triggering annoying branch of thiserror derive macro. use backtrace::Backtrace as Trace; pub use instruction::Instruction; -use instruction::{ - FieldEltType, HintBitsInstr, HintExt2FeltsInstr, HintExtsInstr, HintFeltsInstr, PrintInstr, -}; +use instruction::{FieldEltType, HintBitsInstr, HintExt2FeltsInstr, HintInstr, PrintInstr}; pub use opcode::*; pub use program::*; pub use record::*; @@ -17,6 +15,7 @@ use std::{ collections::VecDeque, fmt::Debug, io::{stdout, Write}, + iter::zip, marker::PhantomData, sync::Arc, }; @@ -159,7 +158,7 @@ pub enum RuntimeError { }, #[error("failed to print to `debug_stdout`: {0}")] DebugPrint(#[from] std::io::Error), - #[error("attempted to read `Vec<{0:?}>` from empty witness tream")] + #[error("attempted to read vec of {0:?} from empty witness tream")] EmptyWitnessStream(FieldEltType), } @@ -564,18 +563,15 @@ where self.record.mem_var_events.push(MemEvent { inner: felt }); } } - Instruction::HintFelts(HintFeltsInstr { output_addrs_mults }) => { + Instruction::Hint(HintInstr { output_addrs_mults }) => { let witness = self .witness_stream .pop_front() .ok_or(RuntimeError::EmptyWitnessStream(FieldEltType::Base))?; - // TODO write to some unconstrained memory table - } - Instruction::HintExts(HintExtsInstr { output_addrs_mults }) => { - let witness = self - .witness_stream - .pop_front() - .ok_or(RuntimeError::EmptyWitnessStream(FieldEltType::Extension))?; + for ((addr, mult), val) in zip(output_addrs_mults, witness) { + self.mw(addr, val, mult); + self.record.mem_events.push(MemEvent { inner: val }); + } } } From 79d244d557c00d61d61bb619cefe0523a96f7bac Mon Sep 17 00:00:00 2001 From: Tej Qu Nair Date: Mon, 5 Aug 2024 17:08:18 -0700 Subject: [PATCH 4/7] length checking, misc --- recursion/core-v2/src/chips/mem/variable.rs | 5 +++-- recursion/core-v2/src/runtime/mod.rs | 16 ++++++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/recursion/core-v2/src/chips/mem/variable.rs b/recursion/core-v2/src/chips/mem/variable.rs index 18540cf14b..96b5fe37d5 100644 --- a/recursion/core-v2/src/chips/mem/variable.rs +++ b/recursion/core-v2/src/chips/mem/variable.rs @@ -1,5 +1,5 @@ use core::borrow::Borrow; -use instruction::{HintBitsInstr, HintExt2FeltsInstr}; +use instruction::{HintBitsInstr, HintExt2FeltsInstr, HintInstr}; use itertools::Itertools; use p3_air::{Air, BaseAir, PairBuilder}; use p3_field::PrimeField32; @@ -61,7 +61,8 @@ impl MachineAir for MemoryChip { .instructions .iter() .flat_map(|instruction| match instruction { - Instruction::HintBits(HintBitsInstr { + Instruction::Hint(HintInstr { output_addrs_mults }) + | Instruction::HintBits(HintBitsInstr { output_addrs_mults, input_addr: _, // No receive interaction for the hint operation }) => output_addrs_mults diff --git a/recursion/core-v2/src/runtime/mod.rs b/recursion/core-v2/src/runtime/mod.rs index 2582fef74c..cd2ffc6be3 100644 --- a/recursion/core-v2/src/runtime/mod.rs +++ b/recursion/core-v2/src/runtime/mod.rs @@ -158,8 +158,13 @@ pub enum RuntimeError { }, #[error("failed to print to `debug_stdout`: {0}")] DebugPrint(#[from] std::io::Error), - #[error("attempted to read vec of {0:?} from empty witness tream")] + #[error("attempted to read vec of {0:?} from empty witness stream")] EmptyWitnessStream(FieldEltType), + #[error("attempted to write to memory vec of len {mem_vec_len} witness of size {witness_len}")] + WitnessLenMismatch { + mem_vec_len: usize, + witness_len: usize, + }, } impl<'a, F: PrimeField32, EF: ExtensionField, Diffusion> Runtime<'a, F, EF, Diffusion> @@ -568,9 +573,16 @@ where .witness_stream .pop_front() .ok_or(RuntimeError::EmptyWitnessStream(FieldEltType::Base))?; + // Check the lengths are the same. + if output_addrs_mults.len() != witness.len() { + return Err(RuntimeError::WitnessLenMismatch { + mem_vec_len: output_addrs_mults.len(), + witness_len: witness.len(), + }); + } for ((addr, mult), val) in zip(output_addrs_mults, witness) { self.mw(addr, val, mult); - self.record.mem_events.push(MemEvent { inner: val }); + self.record.mem_var_events.push(MemEvent { inner: val }); } } } From 8db1343a7e0e3512fc45d9506a83c8825d9a2023 Mon Sep 17 00:00:00 2001 From: Tej Qu Nair Date: Mon, 5 Aug 2024 17:38:40 -0700 Subject: [PATCH 5/7] tests --- recursion/compiler/src/circuit/mod.rs | 106 +++++++++++++++++++++++++- recursion/core-v2/src/runtime/mod.rs | 6 +- 2 files changed, 106 insertions(+), 6 deletions(-) diff --git a/recursion/compiler/src/circuit/mod.rs b/recursion/compiler/src/circuit/mod.rs index 6098320b30..eb1a350d86 100644 --- a/recursion/compiler/src/circuit/mod.rs +++ b/recursion/compiler/src/circuit/mod.rs @@ -7,7 +7,7 @@ pub use compiler::*; #[cfg(test)] mod tests { use p3_baby_bear::DiffusionMatrixBabyBear; - use p3_field::AbstractExtensionField; + use p3_field::{AbstractExtensionField, AbstractField}; use rand::{rngs::StdRng, Rng, SeedableRng}; use sp1_core::{ stark::{Chip, StarkGenericConfig, StarkMachine, PROOF_MAX_NUM_PVS}, @@ -23,10 +23,14 @@ mod tests { poseidon2_wide::Poseidon2WideChip, }, machine::RecursionAir, - Runtime, + Runtime, RuntimeError, }; - use crate::{asm::AsmBuilder, circuit::AsmCompiler, ir::*}; + use crate::{ + asm::AsmBuilder, + circuit::{AsmCompiler, CircuitV2Builder}, + ir::*, + }; const DEGREE: usize = 3; @@ -119,4 +123,100 @@ mod tests { tracing::info!("num shard proofs: {}", result.shard_proofs.len()); } + + #[test] + fn test_io() { + let mut builder = AsmBuilder::::default(); + + let felts = builder.hint_felts_v2(3); + assert_eq!(felts.len(), 3); + let sum: Felt<_> = builder.eval(felts[0] + felts[1]); + builder.assert_felt_eq(sum, felts[2]); + + let exts = builder.hint_exts_v2(3); + assert_eq!(exts.len(), 3); + let sum: Ext<_, _> = builder.eval(exts[0] + exts[1]); + builder.assert_ext_ne(sum, exts[2]); + + let x = builder.hint_ext_v2(); + builder.assert_ext_eq(x, exts[0] + felts[0]); + + let y = builder.hint_felt_v2(); + let zero: Felt<_> = builder.constant(F::zero()); + builder.assert_felt_eq(y, zero); + + let operations = builder.operations; + let mut compiler = AsmCompiler::default(); + let program = compiler.compile(operations); + let mut runtime = Runtime::::new(&program, SC::new().perm); + runtime.witness_stream = [ + vec![F::one().into(), F::one().into(), F::two().into()], + vec![F::zero().into(), F::one().into(), F::two().into()], + vec![F::one().into()], + vec![F::zero().into()], + ] + .into(); + runtime.run().unwrap(); + + let machine = A::machine(SC::new()); + + let (pk, vk) = machine.setup(&program); + let result = + run_test_machine(vec![runtime.record], machine, pk, vk.clone()).expect("should verify"); + + tracing::info!("num shard proofs: {}", result.shard_proofs.len()); + } + + #[test] + fn test_empty_witness_stream() { + let mut builder = AsmBuilder::::default(); + + let felts = builder.hint_felts_v2(3); + assert_eq!(felts.len(), 3); + let sum: Felt<_> = builder.eval(felts[0] + felts[1]); + builder.assert_felt_eq(sum, felts[2]); + + let exts = builder.hint_exts_v2(3); + assert_eq!(exts.len(), 3); + let sum: Ext<_, _> = builder.eval(exts[0] + exts[1]); + builder.assert_ext_ne(sum, exts[2]); + + let operations = builder.operations; + let mut compiler = AsmCompiler::default(); + let program = compiler.compile(operations); + let mut runtime = Runtime::::new(&program, SC::new().perm); + runtime.witness_stream = [vec![F::one().into(), F::one().into(), F::two().into()]].into(); + + match runtime.run() { + Err(RuntimeError::EmptyWitnessStream) => (), + Ok(_) => panic!("should not succeed"), + Err(x) => panic!("should not yield error variant: {}", x), + } + } + + #[test] + fn test_mismatched_witness_size() { + const MEM_VEC_LEN: usize = 3; + const WITNESS_LEN: usize = 5; + + let mut builder = AsmBuilder::::default(); + + let felts = builder.hint_felts_v2(MEM_VEC_LEN); + assert_eq!(felts.len(), MEM_VEC_LEN); + + let operations = builder.operations; + let mut compiler = AsmCompiler::default(); + let program = compiler.compile(operations); + let mut runtime = Runtime::::new(&program, SC::new().perm); + runtime.witness_stream = [vec![F::zero().into(); WITNESS_LEN]].into(); + + match runtime.run() { + Err(RuntimeError::WitnessLenMismatch { + mem_vec_len: MEM_VEC_LEN, + witness_len: WITNESS_LEN, + }) => (), + Ok(_) => panic!("should not succeed"), + Err(x) => panic!("should not yield error variant: {}", x), + } + } } diff --git a/recursion/core-v2/src/runtime/mod.rs b/recursion/core-v2/src/runtime/mod.rs index cd2ffc6be3..5dca4059a7 100644 --- a/recursion/core-v2/src/runtime/mod.rs +++ b/recursion/core-v2/src/runtime/mod.rs @@ -158,8 +158,8 @@ pub enum RuntimeError { }, #[error("failed to print to `debug_stdout`: {0}")] DebugPrint(#[from] std::io::Error), - #[error("attempted to read vec of {0:?} from empty witness stream")] - EmptyWitnessStream(FieldEltType), + #[error("attempted to read from empty witness stream")] + EmptyWitnessStream, #[error("attempted to write to memory vec of len {mem_vec_len} witness of size {witness_len}")] WitnessLenMismatch { mem_vec_len: usize, @@ -572,7 +572,7 @@ where let witness = self .witness_stream .pop_front() - .ok_or(RuntimeError::EmptyWitnessStream(FieldEltType::Base))?; + .ok_or(RuntimeError::EmptyWitnessStream)?; // Check the lengths are the same. if output_addrs_mults.len() != witness.len() { return Err(RuntimeError::WitnessLenMismatch { From fb3baca3308fe9db40fca5fa03d5e48413be4fb5 Mon Sep 17 00:00:00 2001 From: Tej Qu Nair Date: Wed, 7 Aug 2024 11:19:41 -0700 Subject: [PATCH 6/7] spaces between functions --- recursion/compiler/src/circuit/builder.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/recursion/compiler/src/circuit/builder.rs b/recursion/compiler/src/circuit/builder.rs index 6b0dc7bb55..0ced17dc50 100644 --- a/recursion/compiler/src/circuit/builder.rs +++ b/recursion/compiler/src/circuit/builder.rs @@ -44,6 +44,7 @@ impl CircuitV2Builder for Builder { } num } + /// Converts a felt to bits inside a circuit. fn num2bits_v2_f(&mut self, num: Felt, num_bits: usize) -> Vec> { let output = std::iter::from_fn(|| Some(self.uninit())) @@ -64,6 +65,7 @@ impl CircuitV2Builder for Builder { output } + /// A version of `exp_reverse_bits_len` that uses the ExpReverseBitsLen precompile. fn exp_reverse_bits_v2( &mut self, @@ -75,6 +77,7 @@ impl CircuitV2Builder for Builder { .push(DslIr::CircuitV2ExpReverseBits(output, input, power_bits)); output } + /// Applies the Poseidon2 permutation to the given array. fn poseidon2_permute_v2_skinny(&mut self, array: [Felt; WIDTH]) -> [Felt; WIDTH] { let output: [Felt; WIDTH] = core::array::from_fn(|_| self.uninit()); @@ -91,6 +94,7 @@ impl CircuitV2Builder for Builder { .push(DslIr::CircuitV2Poseidon2PermuteBabyBearWide(output, array)); output } + /// Applies the Poseidon2 permutation to the given array. /// /// Reference: [p3_symmetric::PaddingFreeSponge] @@ -104,6 +108,7 @@ impl CircuitV2Builder for Builder { let state: [Felt; DIGEST_SIZE] = state[..DIGEST_SIZE].try_into().unwrap(); state } + /// Applies the Poseidon2 compression function to the given array. /// /// Reference: [p3_symmetric::TruncatedPermutation] @@ -118,6 +123,7 @@ impl CircuitV2Builder for Builder { let post: [Felt; DIGEST_SIZE] = post[..DIGEST_SIZE].try_into().unwrap(); post } + /// Runs FRI fold. fn fri_fold_v2(&mut self, input: CircuitV2FriFoldInput) -> CircuitV2FriFoldOutput { let mut uninit_vec = |len| { @@ -133,6 +139,7 @@ impl CircuitV2Builder for Builder { .push(DslIr::CircuitV2FriFold(output.clone(), input)); output } + /// Decomposes an ext into its felt coordinates. fn ext2felt_v2(&mut self, ext: Ext) -> [Felt; D] { let felts = core::array::from_fn(|_| self.uninit()); @@ -149,12 +156,15 @@ impl CircuitV2Builder for Builder { felts } + fn cycle_tracker_v2_enter(&mut self, name: String) { self.operations.push(DslIr::CycleTrackerV2Enter(name)); } + fn cycle_tracker_v2_exit(&mut self) { self.operations.push(DslIr::CycleTrackerV2Exit); } + /// Hint a single felt. fn hint_felt_v2(&mut self) -> Felt { self.hint_felts_v2(1)[0] @@ -164,6 +174,7 @@ impl CircuitV2Builder for Builder { fn hint_ext_v2(&mut self) -> Ext { self.hint_exts_v2(1)[0] } + /// Hint a vector of felts. fn hint_felts_v2(&mut self, len: usize) -> Vec> { let arr = std::iter::from_fn(|| Some(self.uninit())) @@ -172,6 +183,7 @@ impl CircuitV2Builder for Builder { self.operations.push(DslIr::CircuitV2HintFelts(arr.clone())); arr } + /// Hint a vector of exts. fn hint_exts_v2(&mut self, len: usize) -> Vec> { let arr = std::iter::from_fn(|| Some(self.uninit())) From f60b2c6ac0e6347009011b25a79eda634d3b836a Mon Sep 17 00:00:00 2001 From: Tej Qu Nair Date: Wed, 7 Aug 2024 11:20:14 -0700 Subject: [PATCH 7/7] no spaces in chip names --- recursion/core-v2/src/chips/alu_base.rs | 2 +- recursion/core-v2/src/chips/alu_ext.rs | 2 +- recursion/core-v2/src/chips/dummy.rs | 2 +- recursion/core-v2/src/chips/mem/constant.rs | 2 +- recursion/core-v2/src/chips/mem/variable.rs | 2 +- recursion/core-v2/src/chips/poseidon2_skinny/trace.rs | 2 +- recursion/core-v2/src/chips/poseidon2_wide/trace.rs | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/recursion/core-v2/src/chips/alu_base.rs b/recursion/core-v2/src/chips/alu_base.rs index 9583f8b3fe..ce448c2e3d 100644 --- a/recursion/core-v2/src/chips/alu_base.rs +++ b/recursion/core-v2/src/chips/alu_base.rs @@ -65,7 +65,7 @@ impl MachineAir for BaseAluChip { type Program = crate::RecursionProgram; fn name(&self) -> String { - "Base field Alu".to_string() + "BaseAlu".to_string() } fn preprocessed_width(&self) -> usize { diff --git a/recursion/core-v2/src/chips/alu_ext.rs b/recursion/core-v2/src/chips/alu_ext.rs index f318ce9042..b167b7d0ca 100644 --- a/recursion/core-v2/src/chips/alu_ext.rs +++ b/recursion/core-v2/src/chips/alu_ext.rs @@ -50,7 +50,7 @@ impl> MachineAir for ExtAluChip { type Program = crate::RecursionProgram; fn name(&self) -> String { - "Extension field Alu".to_string() + "ExtAlu".to_string() } fn preprocessed_width(&self) -> usize { diff --git a/recursion/core-v2/src/chips/dummy.rs b/recursion/core-v2/src/chips/dummy.rs index 1078933da4..4db65a9c18 100644 --- a/recursion/core-v2/src/chips/dummy.rs +++ b/recursion/core-v2/src/chips/dummy.rs @@ -45,7 +45,7 @@ impl MachineAir for DummyChip; fn name(&self) -> String { - "Dummy wide".to_string() + "DummyWide".to_string() } fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { diff --git a/recursion/core-v2/src/chips/mem/constant.rs b/recursion/core-v2/src/chips/mem/constant.rs index dbfb2e9c03..feef160cde 100644 --- a/recursion/core-v2/src/chips/mem/constant.rs +++ b/recursion/core-v2/src/chips/mem/constant.rs @@ -49,7 +49,7 @@ impl MachineAir for MemoryChip { type Program = crate::RecursionProgram; fn name(&self) -> String { - "Memory Constants".to_string() + "MemoryConst".to_string() } fn preprocessed_width(&self) -> usize { NUM_MEM_PREPROCESSED_INIT_COLS diff --git a/recursion/core-v2/src/chips/mem/variable.rs b/recursion/core-v2/src/chips/mem/variable.rs index 96b5fe37d5..f1cb44b492 100644 --- a/recursion/core-v2/src/chips/mem/variable.rs +++ b/recursion/core-v2/src/chips/mem/variable.rs @@ -50,7 +50,7 @@ impl MachineAir for MemoryChip { type Program = crate::RecursionProgram; fn name(&self) -> String { - "Memory Variables".to_string() + "MemoryVar".to_string() } fn preprocessed_width(&self) -> usize { NUM_MEM_PREPROCESSED_INIT_COLS diff --git a/recursion/core-v2/src/chips/poseidon2_skinny/trace.rs b/recursion/core-v2/src/chips/poseidon2_skinny/trace.rs index 82638ae3cd..0768828a3c 100644 --- a/recursion/core-v2/src/chips/poseidon2_skinny/trace.rs +++ b/recursion/core-v2/src/chips/poseidon2_skinny/trace.rs @@ -32,7 +32,7 @@ impl MachineAir for Poseidon2SkinnyChip type Program = RecursionProgram; fn name(&self) -> String { - format!("Poseidon2Skinny {}", DEGREE) + format!("Poseidon2SkinnyDeg{}", DEGREE) } #[instrument(name = "generate poseidon2 skinny trace", level = "debug", skip_all, fields(rows = input.poseidon2_skinny_events.len()))] diff --git a/recursion/core-v2/src/chips/poseidon2_wide/trace.rs b/recursion/core-v2/src/chips/poseidon2_wide/trace.rs index 4fb58536a0..f07f4404c9 100644 --- a/recursion/core-v2/src/chips/poseidon2_wide/trace.rs +++ b/recursion/core-v2/src/chips/poseidon2_wide/trace.rs @@ -33,7 +33,7 @@ impl MachineAir for Poseidon2WideChip; fn name(&self) -> String { - format!("Poseidon2Wide {}", DEGREE) + format!("Poseidon2WideDeg{}", DEGREE) } #[instrument(name = "generate poseidon2 wide trace", level = "debug", skip_all, fields(rows = input.poseidon2_wide_events.len()))]