From 96d4f2bea072a7a93a590440477a15337e6d0b29 Mon Sep 17 00:00:00 2001 From: Kevin Jue Date: Sat, 11 May 2024 12:24:38 -0700 Subject: [PATCH] feat(recursion): HALT instruction (#703) --- recursion/compiler/src/asm/compiler.rs | 3 ++ recursion/compiler/src/asm/instruction.rs | 15 ++++++++++ recursion/compiler/src/ir/builder.rs | 4 +++ recursion/compiler/src/ir/instructions.rs | 1 + recursion/core/src/cpu/air/mod.rs | 28 +++++++++++++++++++ recursion/core/src/cpu/air/system.rs | 34 +++++++++++++++++++++++ recursion/core/src/cpu/columns/opcode.rs | 3 ++ recursion/core/src/runtime/mod.rs | 8 ++++-- recursion/core/src/runtime/opcode.rs | 1 + recursion/program/src/challenger.rs | 1 + recursion/program/src/constraints.rs | 1 + recursion/program/src/fri/domain.rs | 1 + recursion/program/src/fri/two_adic_pcs.rs | 1 + recursion/program/src/machine/compress.rs | 2 ++ recursion/program/src/machine/core.rs | 2 ++ recursion/program/src/machine/deferred.rs | 2 ++ recursion/program/src/machine/root.rs | 2 ++ recursion/program/src/stark.rs | 2 ++ 18 files changed, 109 insertions(+), 2 deletions(-) create mode 100644 recursion/core/src/cpu/air/system.rs diff --git a/recursion/compiler/src/asm/compiler.rs b/recursion/compiler/src/asm/compiler.rs index bb7a40503c..eb1ab9820d 100644 --- a/recursion/compiler/src/asm/compiler.rs +++ b/recursion/compiler/src/asm/compiler.rs @@ -538,6 +538,9 @@ impl + TwoAdicField> AsmCo DslIr::CycleTracker(name) => { self.push(AsmInstruction::CycleTracker(name.clone()), trace); } + DslIr::Halt => { + self.push(AsmInstruction::Halt, trace); + } _ => unimplemented!(), } } diff --git a/recursion/compiler/src/asm/instruction.rs b/recursion/compiler/src/asm/instruction.rs index 31c7e1a120..ab89fd4a78 100644 --- a/recursion/compiler/src/asm/instruction.rs +++ b/recursion/compiler/src/asm/instruction.rs @@ -134,6 +134,9 @@ pub enum AsmInstruction { /// Trap. Trap, + /// Halt. + Halt, + /// Break(label) Break(F), @@ -703,6 +706,17 @@ impl> AsmInstruction { false, "".to_string(), ), + AsmInstruction::Halt => Instruction::new( + Opcode::HALT, + F::zero(), + zero, + zero, + F::zero(), + F::zero(), + false, + false, + "".to_string(), + ), AsmInstruction::HintBits(dst, src) => Instruction::new( Opcode::HintBits, i32_f(dst), @@ -1071,6 +1085,7 @@ impl> AsmInstruction { ) } AsmInstruction::Trap => write!(f, "trap"), + AsmInstruction::Halt => write!(f, "halt"), AsmInstruction::HintBits(dst, src) => write!(f, "hint_bits ({})fp, ({})fp", dst, src), AsmInstruction::Poseidon2Permute(dst, src) => { write!(f, "poseidon2_permute ({})fp, ({})fp", dst, src) diff --git a/recursion/compiler/src/ir/builder.rs b/recursion/compiler/src/ir/builder.rs index e4b562cda6..4a06dca938 100644 --- a/recursion/compiler/src/ir/builder.rs +++ b/recursion/compiler/src/ir/builder.rs @@ -456,6 +456,10 @@ impl Builder { pub fn cycle_tracker(&mut self, name: &str) { self.operations.push(DslIr::CycleTracker(name.to_string())); } + + pub fn halt(&mut self) { + self.operations.push(DslIr::Halt); + } } /// A builder for the DSL that handles if statements. diff --git a/recursion/compiler/src/ir/instructions.rs b/recursion/compiler/src/ir/instructions.rs index 190f510012..1e79528029 100644 --- a/recursion/compiler/src/ir/instructions.rs +++ b/recursion/compiler/src/ir/instructions.rs @@ -155,6 +155,7 @@ pub enum DslIr { WitnessFelt(Felt, u32), WitnessExt(Ext, u32), Commit(Felt, Var), + Halt, // Public inputs for circuits. CircuitCommitVkeyHash(Var), diff --git a/recursion/core/src/cpu/air/mod.rs b/recursion/core/src/cpu/air/mod.rs index d8abbae4f5..719d0031da 100644 --- a/recursion/core/src/cpu/air/mod.rs +++ b/recursion/core/src/cpu/air/mod.rs @@ -3,6 +3,7 @@ mod branch; mod jump; mod memory; mod operands; +mod system; use std::borrow::Borrow; @@ -75,6 +76,12 @@ where // Constrain the clk. self.eval_clk(builder, local, next); + + // Constrain the system instructions (TRAP, HALT). + self.eval_system_instructions(builder, local, next); + + // Constrain the is_real_flag. + self.eval_is_real(builder, local, next); } } @@ -101,6 +108,27 @@ impl CpuChip { .assert_eq(local.clk.into() + local.a.value()[0], next.clk); } + /// Eval the is_real flag. + pub fn eval_is_real( + &self, + builder: &mut AB, + local: &CpuCols, + next: &CpuCols, + ) where + AB: SP1RecursionAirBuilder, + { + builder.assert_bool(local.is_real); + + // First row should be real. + builder.when_first_row().assert_one(local.is_real); + + // Once rows transition to not real, then they should stay not real. + builder + .when_transition() + .when_not(local.is_real) + .assert_zero(next.is_real); + } + /// Expr to check for alu instructions. pub fn is_alu_instruction(&self, local: &CpuCols) -> AB::Expr where diff --git a/recursion/core/src/cpu/air/system.rs b/recursion/core/src/cpu/air/system.rs new file mode 100644 index 0000000000..39be8a5fd8 --- /dev/null +++ b/recursion/core/src/cpu/air/system.rs @@ -0,0 +1,34 @@ +use p3_air::AirBuilder; +use p3_field::Field; +use sp1_core::air::BaseAirBuilder; + +use crate::{ + air::SP1RecursionAirBuilder, + cpu::{CpuChip, CpuCols}, +}; + +impl CpuChip { + /// Eval the system instructions (TRAP, HALT). + /// + /// This method will contrain the following: + /// 1) Ensure that none of the instructions are TRAP. + /// 2) Ensure that the last real instruction is a HALT. + pub fn eval_system_instructions( + &self, + builder: &mut AB, + local: &CpuCols, + next: &CpuCols, + ) where + AB: SP1RecursionAirBuilder, + { + builder + .when(local.is_real) + .assert_zero(local.selectors.is_trap); + + builder + .when_transition() + .when(local.is_real) + .when_not(next.is_real) + .assert_one(local.selectors.is_halt); + } +} diff --git a/recursion/core/src/cpu/columns/opcode.rs b/recursion/core/src/cpu/columns/opcode.rs index 8fb7bddd87..faf78d58ab 100644 --- a/recursion/core/src/cpu/columns/opcode.rs +++ b/recursion/core/src/cpu/columns/opcode.rs @@ -35,6 +35,7 @@ pub struct OpcodeSelectorCols { // System instructions. pub is_trap: T, pub is_noop: T, + pub is_halt: T, pub is_poseidon: T, pub is_fri_fold: T, @@ -61,6 +62,7 @@ impl OpcodeSelectorCols { Opcode::JAL => self.is_jal = F::one(), Opcode::JALR => self.is_jalr = F::one(), Opcode::TRAP => self.is_trap = F::one(), + Opcode::HALT => self.is_halt = F::one(), Opcode::FRIFold => self.is_fri_fold = F::one(), Opcode::Poseidon2Compress => self.is_poseidon = F::one(), // TODO: Double-check that `is_noop` is constrained properly in the CPU air. @@ -101,6 +103,7 @@ impl IntoIterator for &OpcodeSelectorCols { self.is_jal, self.is_jalr, self.is_trap, + self.is_halt, self.is_noop, self.is_poseidon, self.is_fri_fold, diff --git a/recursion/core/src/runtime/mod.rs b/recursion/core/src/runtime/mod.rs index 91bb1ae3cb..3d8b95c3a0 100644 --- a/recursion/core/src/runtime/mod.rs +++ b/recursion/core/src/runtime/mod.rs @@ -588,6 +588,10 @@ where } exit(1); } + Opcode::HALT => { + let (a_val, b_val, c_val) = self.all_rr(&instruction); + (a, b, c) = (a_val, b_val, c_val); + } Opcode::Ext2Felt => { let (a_val, b_val, c_val) = self.all_rr(&instruction); let dst = a_val[0].as_canonical_u32() as usize; @@ -808,7 +812,7 @@ where clk: self.clk, pc: self.pc, fp: self.fp, - instruction, + instruction: instruction.clone(), a, a_record: self.access.a, b, @@ -823,7 +827,7 @@ where self.timestamp += 1; self.access = CpuRecord::default(); - if self.timestamp >= early_exit_ts { + if self.timestamp >= early_exit_ts || instruction.opcode == Opcode::HALT { break; } } diff --git a/recursion/core/src/runtime/opcode.rs b/recursion/core/src/runtime/opcode.rs index b2bead5ea7..c12ce58864 100644 --- a/recursion/core/src/runtime/opcode.rs +++ b/recursion/core/src/runtime/opcode.rs @@ -30,6 +30,7 @@ pub enum Opcode { // System instructions. TRAP = 30, + HALT = 31, // Hash instructions. Poseidon2Compress = 39, diff --git a/recursion/program/src/challenger.rs b/recursion/program/src/challenger.rs index 6a05cca8e3..65603ced54 100644 --- a/recursion/program/src/challenger.rs +++ b/recursion/program/src/challenger.rs @@ -333,6 +333,7 @@ mod tests { }; let one: Felt<_> = builder.eval(F::one()); let two: Felt<_> = builder.eval(F::two()); + builder.halt(); challenger.observe(&mut builder, one); challenger.observe(&mut builder, two); challenger.observe(&mut builder, two); diff --git a/recursion/program/src/constraints.rs b/recursion/program/src/constraints.rs index 176abdb33a..c7ab21108a 100644 --- a/recursion/program/src/constraints.rs +++ b/recursion/program/src/constraints.rs @@ -345,6 +345,7 @@ mod tests { } break; } + builder.halt(); let program = builder.compile_program(); run_test_recursion(program, None, TestConfig::All); diff --git a/recursion/program/src/fri/domain.rs b/recursion/program/src/fri/domain.rs index f3f20c7a76..a26acf23ed 100644 --- a/recursion/program/src/fri/domain.rs +++ b/recursion/program/src/fri/domain.rs @@ -267,6 +267,7 @@ pub(crate) mod tests { domain_assertions(&mut builder, &dom, dom_val, zeta_val); } } + builder.halt(); let program = builder.compile_program(); run_test_recursion(program, None, TestConfig::All); diff --git a/recursion/program/src/fri/two_adic_pcs.rs b/recursion/program/src/fri/two_adic_pcs.rs index 0e9c316a8b..71b361f8ef 100644 --- a/recursion/program/src/fri/two_adic_pcs.rs +++ b/recursion/program/src/fri/two_adic_pcs.rs @@ -394,6 +394,7 @@ pub mod tests { challenger.observe(&mut builder, commit); challenger.sample_ext(&mut builder); pcs.verify(&mut builder, rounds, proofvar, &mut challenger); + builder.halt(); let program = builder.compile_program(); let mut witness_stream = VecDeque::new(); diff --git a/recursion/program/src/machine/compress.rs b/recursion/program/src/machine/compress.rs index b3c98ad155..441f648944 100644 --- a/recursion/program/src/machine/compress.rs +++ b/recursion/program/src/machine/compress.rs @@ -471,5 +471,7 @@ where for value in reduce_public_values_stream { builder.commit_public_value(value); } + + builder.halt(); } } diff --git a/recursion/program/src/machine/core.rs b/recursion/program/src/machine/core.rs index 527533835a..2ae14ffba2 100644 --- a/recursion/program/src/machine/core.rs +++ b/recursion/program/src/machine/core.rs @@ -323,5 +323,7 @@ where for value in recursion_public_values_stream { builder.commit_public_value(value); } + + builder.halt(); } } diff --git a/recursion/program/src/machine/deferred.rs b/recursion/program/src/machine/deferred.rs index 206cb6fae4..29708825df 100644 --- a/recursion/program/src/machine/deferred.rs +++ b/recursion/program/src/machine/deferred.rs @@ -287,5 +287,7 @@ where for value in deferred_public_values_stream { builder.commit_public_value(value); } + + builder.halt(); } } diff --git a/recursion/program/src/machine/root.rs b/recursion/program/src/machine/root.rs index 16dfc3993f..57e6f646a5 100644 --- a/recursion/program/src/machine/root.rs +++ b/recursion/program/src/machine/root.rs @@ -135,5 +135,7 @@ where for value in public_values_elements { builder.commit_public_value(value); } + + builder.halt(); } } diff --git a/recursion/program/src/stark.rs b/recursion/program/src/stark.rs index 86fcd11762..cd59a8804d 100644 --- a/recursion/program/src/stark.rs +++ b/recursion/program/src/stark.rs @@ -497,6 +497,7 @@ pub(crate) mod tests { permutation_challenges[i].cons(), ); } + builder.halt(); let program = builder.compile_program(); run_test_recursion(program, Some(witness_stream.into()), TestConfig::All); @@ -521,6 +522,7 @@ pub(crate) mod tests { let a_plus_b_ext = builder.eval(a_ext + b_ext); builder.print_f(a_plus_b); builder.print_e(a_plus_b_ext); + builder.halt(); let program = builder.compile_program(); let elapsed = time.elapsed();