From 93f6a715b71b062cd52a0163d6114b997561b1c8 Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Mon, 19 Aug 2024 16:04:52 -0500 Subject: [PATCH] Rework step theorem generation to be faster and cache intermediate results to the environment (#92) ### Description: Forewarning: the diff of this one is rather large. I've implemented a new version of step theorem generation, culminating in the `genStepEqTheorem` function. In particular: - We get rid of the different fetch/decode/exec intermediate lemma, instead opting to generate the final step theorem in one go. The bottleneck seems to be kernel checking, so by reducing the number of theorems sent to the kernel we achieve a good speedup (SHA512 used to take about 55 seconds for the three generation commands combined, now it's about 35 seconds). - In the process, we build a `ProgramInfo` struct, which holds a bunch of interesting expressions that further proof automation could exploit. This programInfo is stored in a persistent environment extension, so that it is persisted in the olean files (hence, making it available to downstream files). - While building the previous, I've moved some definitions around and did other refactors: happy to split those off into their own PR if that makes reviewing easier. EDIT: split off #93 for the `#time` command ### Testing: `make all` succeeded ### License: By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. --------- Co-authored-by: Siddharth Co-authored-by: Shilpi Goel --- Arm/BitVec.lean | 8 + Arm/Exec.lean | 10 + Arm/State.lean | 8 + Proofs/AES-GCM/GCMGmultV8Sym.lean | 4 +- Proofs/Experiments/Abs.lean | 11 +- Proofs/Popcount32.lean | 24 +-- Proofs/SHA512/Sha512StepLemmas.lean | 12 +- Tactics/Common.lean | 72 +++++-- Tactics/Reflect/FetchAndDecode.lean | 14 +- Tactics/Reflect/ProgramInfo.lean | 306 +++++++++++++++++++++++++--- Tactics/StepThms.lean | 207 ++++++++++++++++++- Tactics/Sym.lean | 9 +- Tactics/SymContext.lean | 22 +- 13 files changed, 596 insertions(+), 111 deletions(-) diff --git a/Arm/BitVec.lean b/Arm/BitVec.lean index 0930269a..bde49cf4 100644 --- a/Arm/BitVec.lean +++ b/Arm/BitVec.lean @@ -339,6 +339,14 @@ example : split 0xabcd1234#32 8 (by omega) = [0xab#8, 0xcd#8, 0x12#8, 0x34#8] := /-- Get the width of a bitvector. -/ protected def width (_ : BitVec n) : Nat := n +/-- Convert a bitvector into its hex representation, without leading zeroes. + +See `BitVec.toHex` if you do want the leading zeroes. + +NOTE: returns only the digits, without a `0x` prefix -/ +def toHexWithoutLeadingZeroes {w} (x : BitVec w) : String := + (Nat.toDigits 16 x.toNat).asString + ---------------------------------------------------------------------- attribute [ext] BitVec diff --git a/Arm/Exec.lean b/Arm/Exec.lean index f9bc7ccf..c8dd3db4 100644 --- a/Arm/Exec.lean +++ b/Arm/Exec.lean @@ -166,3 +166,13 @@ theorem run_onestep (s s': ArmState) (n : Nat) (h_nonneg : 0 < n): · cases h_nonneg · rename_i n simp [run] + +/-- helper lemma for automation -/ +theorem stepi_eq_of_fetch_inst_of_decode_raw_inst + (s : ArmState) (addr : BitVec 64) (rawInst : BitVec 32) (inst : ArmInst) + (h_err : r .ERR s = .None) + (h_pc : r .PC s = addr) + (h_fetch : fetch_inst addr s = some rawInst) + (h_decode : decode_raw_inst rawInst = some inst) : + stepi s = exec_inst inst s := by + simp only [stepi, h_err, h_pc, h_fetch, h_decode, read_err, read_pc] diff --git a/Arm/State.lean b/Arm/State.lean index 8315d24a..be289ecb 100644 --- a/Arm/State.lean +++ b/Arm/State.lean @@ -521,6 +521,14 @@ theorem fetch_inst_from_program unfold fetch_inst simp only +theorem fetch_inst_eq_of_prgram_eq_of_map_find + {state : ArmState} {program : Program} + {addr : BitVec 64} {inst? : Option (BitVec 32)} + (h_program : state.program = program) + (h_map : program.find? addr = inst?) : + fetch_inst addr state = inst? := by + rw [fetch_inst, h_program, h_map] + end Load_program_and_fetch_inst ---------------------------------------------------------------------- diff --git a/Proofs/AES-GCM/GCMGmultV8Sym.lean b/Proofs/AES-GCM/GCMGmultV8Sym.lean index 4fc97235..8cee03c3 100644 --- a/Proofs/AES-GCM/GCMGmultV8Sym.lean +++ b/Proofs/AES-GCM/GCMGmultV8Sym.lean @@ -4,9 +4,7 @@ import Tactics.StepThms namespace GCMGmultV8Program -#genStepTheorems gcm_gmult_v8_program thmType:="fetch" `state_simp_rules -#genStepTheorems gcm_gmult_v8_program thmType:="decodeExec" `state_simp_rules -#genStepTheorems gcm_gmult_v8_program thmType:="step" `state_simp_rules +#genStepEqTheorems gcm_gmult_v8_program theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState) (h_s0_program : s0.program = gcm_gmult_v8_program) diff --git a/Proofs/Experiments/Abs.lean b/Proofs/Experiments/Abs.lean index f2e599b6..69ee53e8 100644 --- a/Proofs/Experiments/Abs.lean +++ b/Proofs/Experiments/Abs.lean @@ -6,6 +6,8 @@ Author(s): Shilpi Goel, Siddharth Bhat The goal is to prove that this program implements absolute value correctly. -/ import Arm +import Tactics.StepThms +import Tactics.Sym namespace Abs @@ -19,14 +21,21 @@ def program : Program := def spec (x : BitVec 32) : BitVec 32 := BitVec.ofNat 32 x.toInt.natAbs +#genStepEqTheorems program + theorem correct {s0 sf : ArmState} (h_s0_pc : read_pc s0 = 0x4005d0#64) (h_s0_program : s0.program = program) (h_s0_err : read_err s0 = StateError.None) + (h_s0_sp : CheckSPAlignment s0) (h_run : sf = run program.length s0) : read_gpr 32 0 sf = spec (read_gpr 32 0 s0) ∧ - read_err sf = StateError.None := by sorry + read_err sf = StateError.None := by + simp (config := {ground := true}) at h_run + + sym1_n 5 + sorry /-- info: 'Abs.correct' depends on axioms: [propext, sorryAx, Classical.choice, Quot.sound] -/ #guard_msgs in #print axioms correct diff --git a/Proofs/Popcount32.lean b/Proofs/Popcount32.lean index 4ceb358b..fcead4eb 100644 --- a/Proofs/Popcount32.lean +++ b/Proofs/Popcount32.lean @@ -66,14 +66,7 @@ def popcount32_program : Program := (0x40061c#64 , 0xd65f03c0#32)] -- ret -#genStepTheorems popcount32_program thmType:="fetch" - --- #guard_msgs in --- #check popcount32_fetch_0x4005b4 - -#genStepTheorems popcount32_program thmType:="decodeExec" - -#genStepTheorems popcount32_program thmType:="step" `state_simp_rules +#genStepEqTheorems popcount32_program theorem popcount32_sym_no_error (s0 s_final : ArmState) (h_s0_pc : read_pc s0 = 0x4005b4#64) @@ -130,16 +123,15 @@ theorem popcount32_sym_no_error (s0 s_final : ArmState) section Tests /-- -info: popcount32_program.stepi_0x4005c0 (s sn : ArmState) (h_program : s.program = popcount32_program) +info: popcount32_program.stepi_eq_0x4005c0 {s : ArmState} (h_program : s.program = popcount32_program) (h_pc : r StateField.PC s = 4195776#64) (h_err : r StateField.ERR s = StateError.None) : - (sn = stepi s) = - (sn = - w StateField.PC (4195780#64) - (w (StateField.GPR 0#5) - (zeroExtend 64 ((zeroExtend 32 (r (StateField.GPR 0#5) s)).rotateRight 1) &&& 4294967295#64 &&& 2147483647#64) - s)) + stepi s = + w StateField.PC (4195780#64) + (w (StateField.GPR 0#5) + (zeroExtend 64 ((zeroExtend 32 (r (StateField.GPR 0#5) s)).rotateRight 1) &&& 4294967295#64 &&& 2147483647#64) + s) -/ -#guard_msgs in #check popcount32_program.stepi_0x4005c0 +#guard_msgs in #check popcount32_program.stepi_eq_0x4005c0 end Tests diff --git a/Proofs/SHA512/Sha512StepLemmas.lean b/Proofs/SHA512/Sha512StepLemmas.lean index 0613490b..4b246fbd 100644 --- a/Proofs/SHA512/Sha512StepLemmas.lean +++ b/Proofs/SHA512/Sha512StepLemmas.lean @@ -1,17 +1,15 @@ -import Proofs.SHA512.Sha512FetchLemmas -import Proofs.SHA512.Sha512DecodeExecLemmas import Proofs.SHA512.Sha512Program --- import Tests.SHA2.SHA512ProgramTest +import Tactics.StepThms -- set_option trace.gen_step.debug.heartBeats true in -- set_option trace.gen_step.print_names true in set_option maxHeartbeats 2000000 in -#genStepTheorems sha512_program thmType:="step" `state_simp_rules +#genStepEqTheorems sha512_program /-- -info: sha512_program.stepi_0x126c90 (s sn : ArmState) (h_program : s.program = sha512_program) +info: sha512_program.stepi_eq_0x126c90 {s : ArmState} (h_program : s.program = sha512_program) (h_pc : r StateField.PC s = 1207440#64) (h_err : r StateField.ERR s = StateError.None) : - (sn = stepi s) = (sn = w StateField.PC (if ¬r (StateField.GPR 2#5) s = 0#64 then 1205504#64 else 1207444#64) s) + stepi s = w StateField.PC (if ¬r (StateField.GPR 2#5) s = 0#64 then 1205504#64 else 1207444#64) s -/ #guard_msgs in -#check sha512_program.stepi_0x126c90 +#check sha512_program.stepi_eq_0x126c90 diff --git a/Tactics/Common.lean b/Tactics/Common.lean index 88c0d06b..8b5442e8 100644 --- a/Tactics/Common.lean +++ b/Tactics/Common.lean @@ -55,7 +55,7 @@ def getBitVecString? (e : Expr) (hex : Bool := false): MetaM (Option String) := | some ⟨_, value⟩ => if hex then -- We don't want leading zeroes here. - return some (Nat.toDigits 16 value.toNat).asString + return some value.toHexWithoutLeadingZeroes else return some (ToString.toString value.toNat) | none => return none @@ -88,13 +88,13 @@ that additionally recognizes: -- TODO: should this be upstreamed to core? def getBitVecValue? (e : Expr) : MetaM (Option ((n : Nat) × BitVec n)) := match_expr e with - | BitVec.ofFin _ i => OptionT.run do - let ⟨n, i⟩ ← getFinValue? i - let n' := Nat.log2 n - if h : n = 2^n' then - return ⟨n', .ofFin (Fin.cast h i)⟩ - else - failure + | BitVec.ofFin w i => OptionT.run do + let w ← getNatValue? w + let v ← do + match_expr i with + | Fin.mk _n v _h => getNatValue? v + | _ => pure (← getFinValue? i).2.val + return ⟨w, BitVec.ofNat w v⟩ | _ => Lean.Meta.getBitVecValue? e /-- Given a ground term `e` of type `Nat`, fully reduce it, @@ -115,21 +115,53 @@ which was obtained by reducing:\n\t{e}" reduce an expression `e` (of type `BitVec w`) to be of the form `?n#w`, and then reflect `?n` to build the meta-level bitvector -/ def reflectBitVecLiteral (w : Nat) (e : Expr) : MetaM (BitVec w) := do - if e.hasFVar then + if e.hasFVar || e.hasMVar then throwError "Expected a ground term, but {e} has free variables" - if let some ⟨n, x⟩ ← _root_.getBitVecValue? e then - if h : n = w then - return x.cast h - else - throwError "Expected a bitvector of width {w}, but\n\t{e}\nhas width {n}" + let some ⟨n, x⟩ ← _root_.getBitVecValue? e + | throwError "Failed to reflect:\n\t{e}\ninto a BitVec" - let x ← mkFreshExprMVar (Expr.const ``Nat []) - let e' ← mkAppM ``BitVec.ofNat #[toExpr w, x] - if (←isDefEq e e') then - return BitVec.ofNat w (← reflectNatLiteral x) + if h : n = w then + return x.cast h else - throwError "Failed to unify, expected:\n\t{e'}\nbut found:\n\t{e'}" + throwError "Expected a bitvector of width {w}, but\n\t{e}\nhas width {n}" + +/-! ## Hypothesis types -/ +namespace SymContext + +/-- `h_err_type state` returns an Expr for `r .ERR = .None`, +the expected type of `h_err` -/ +def h_err_type (state : Expr) : Expr := + mkAppN (mkConst ``Eq [1]) #[ + mkConst ``StateError, + mkApp2 (.const ``r []) (.const ``StateField.ERR []) state, + .const ``StateError.None [] + ] + +/-- `h_sp_type state` returns an Expr for `CheckSPAlignment `, +the expected type of `h_sp` -/ +def h_sp_type (state : Expr) : Expr := + mkApp (.const ``CheckSPAlignment []) state + +/-- `h_sp_type state` returns an Expr for `.program = `, +the expected type of `h_program` -/ +def h_program_type (state program : Expr) : Expr := + mkAppN (mkConst ``Eq [1]) #[ + mkConst ``Program, + mkApp (mkConst ``ArmState.program) state, + program + ] + +/-- `h_pc_type state` returns an Expr for `r .PC =
`, +the expected type of `h_pc` -/ +def h_pc_type (state address : Expr) : Expr := + mkAppN (mkConst ``Eq [1]) #[ + mkApp (mkConst ``BitVec) (toExpr 64), + mkApp2 (mkConst ``r) (mkConst ``StateField.PC) state, + address + ] + +end SymContext /-! ## Local Context Search -/ @@ -162,7 +194,7 @@ Throws an error if no such hypothesis could. -/ def findProgramHyp (state : Expr) : MetaM (LocalDecl × Name) := do -- Try to find `h_program`, and infer `program` from it let program ← mkFreshExprMVar none - let h_program_type ← mkEq (← mkAppM ``ArmState.program #[state]) program + let h_program_type := SymContext.h_program_type state program let h_program ← findLocalDeclOfTypeOrError h_program_type -- Assert that `program` is a(n application of a) constant, and find its name diff --git a/Tactics/Reflect/FetchAndDecode.lean b/Tactics/Reflect/FetchAndDecode.lean index 04c6f2e7..c9583f20 100644 --- a/Tactics/Reflect/FetchAndDecode.lean +++ b/Tactics/Reflect/FetchAndDecode.lean @@ -13,24 +13,15 @@ open Elab.Tactic Elab.Term initialize Lean.registerTraceClass `Sym.reduceFetchInst -theorem fetch_inst_eq_of_prgram_eq_of_map_find - {state : ArmState} {program : Program} - {addr : BitVec 64} {inst? : Option (BitVec 32)} - (h_program : state.program = program) - (h_map : program.find? addr = inst?) : - fetch_inst addr state = inst? := by - rw [fetch_inst, h_program, h_map] - -def reduceFetchInst? (addr : Expr) (s : Expr) : +def reduceFetchInst? (addr : BitVec 64) (s : Expr) : MetaM (BitVec 32 × Expr) := do - let addr ← reflectBitVecLiteral 64 addr let ⟨programHyp, program⟩ ← findProgramHyp s let programInfo ← try ProgramInfo.lookupOrGenerate program catch err => throwErrorAt err.getRef "Could not generate ProgramInfo for {program}:\n\n{err.toMessageData}" - let some rawInst := programInfo.getRawInstrAt? addr + let some rawInst := programInfo.getRawInstAt? addr | throwError "No instruction found at address {addr}" trace[Sym.reduceFetchInst] "{Lean.checkEmoji} reduced to: {rawInst}" @@ -54,6 +45,7 @@ simproc reduceFetchInst (fetch_inst _ _) := fun e => do trace[Sym.reduceFetchInst] "⚙️ simplifying {e}" let_expr fetch_inst addr s := e | return .continue + let addr ← reflectBitVecLiteral 64 addr try let ⟨x, proof?⟩ ← reduceFetchInst? addr s diff --git a/Tactics/Reflect/ProgramInfo.lean b/Tactics/Reflect/ProgramInfo.lean index 368b74ca..897ec2a3 100644 --- a/Tactics/Reflect/ProgramInfo.lean +++ b/Tactics/Reflect/ProgramInfo.lean @@ -17,81 +17,335 @@ Furthermore, we define a persistent env extension to store `ProgramInfo` in. open Lean Meta Elab.Term +initialize + registerTraceClass `ProgramInfo + +/-- `OnDemand α` is morally an `Option α`, +we use it for values that are computed, and cached, on demand. -/ +inductive InstInfo.OnDemand (α : Type) + /-- a value has not yet been cached, + you should run the relevant computation -/ + | notYetComputed + /-- a value has been cached -/ + | value (value : α) +open InstInfo (OnDemand) + +structure InstInfo where + /-- the raw instruction, as a bitvector -/ + rawInst : BitVec 32 + + /-- the decoded instruction, as a normalized(!) `Expr` of type `ArmInst`. + That is, `decode_raw_inst ` should be def-eq to `some `. + -/ + decodedInst? : OnDemand Expr := + .notYetComputed + + /-- if `instSemantics?` is `⟨sem, type, proof⟩`, then + - `sem` is the instruction semantics, as a simplified expression of type + `ArmState → ArmState`. + + That is, we've ran `simp` on `sem` with our dedicated simp-sets in the hopes + of obtaining only a sequence of `w` and `write_mem`s to the initial state. + However, note that some instructions might have conditional behaviour, + in which case `sem` might still contain `if`s + - `type` is the expression + ```lean + ∀ s (h_program : s.program = ) (h_pc : read_pc s = ) + (h_err : read_err s = .None), + exec_inst s = s + ``` + - `proof` is a proof of type `type` + -/ + instSemantics? : OnDemand (Expr × Expr × Expr) := + .notYetComputed + structure ProgramInfo where - rawProgram : HashMap (BitVec 64) (BitVec 32) + name : Name + instructions : HashMap (BitVec 64) InstInfo + +-------------------------------------------------------------------------------- + +/-! ## InstInfoT -/ + +/-- A monad transformer with `InstInfo` state -/ +abbrev InstInfoT := StateT InstInfo + +namespace InstInfoT +variable {m} [Monad m] + +/-- Return `InstInfo.rawInst` from the state -/ +def getRawInst : InstInfoT m (BitVec 32) := do + return (← get).rawInst + +/-- Return `InstInfo.decodedInst?` from the state if it is `some _`, +or use `f` to compute the relevant expression if it is missing -/ +def getDecodedInst (f : Unit → InstInfoT m Expr) : InstInfoT m Expr := do + let info ← get + match info.decodedInst? with + | .value val => return val + | .notYetComputed => + let val ← f () + set {info with decodedInst? := .value val} + return val + +/-- Return `InstInfo.instSemantics?` from the state if it is `some _`, +or use `f` to compute the relevant expressions if they are missing -/ +def getInstSemantics (f : Unit → InstInfoT m (Expr × Expr × Expr)) : + InstInfoT m (Expr × Expr × Expr) := do + let info ← get + match info.instSemantics? with + | .value val => return val + | .notYetComputed => + let val ← f () + set {info with instSemantics? := .value val} + return val -def ProgramInfo.getRawInstrAt? (pi : ProgramInfo) (addr : BitVec 64) : +end InstInfoT + +def InstInfo.ofRawInst (rawInst : BitVec 32) : InstInfo := + { rawInst } + +-------------------------------------------------------------------------------- + +namespace ProgramInfo + +/-- The expression `mkConst pi.name`, +i.e., an expression of this program referred to by name -/ +def expr (pi : ProgramInfo) : Expr := mkConst pi.name + +def getInstInfoAt? (pi : ProgramInfo) (addr : BitVec 64) : + Option InstInfo := + pi.instructions.find? addr + +def getRawInstAt? (pi : ProgramInfo) (addr : BitVec 64) : Option (BitVec 32) := - pi.rawProgram.find? addr + (·.rawInst) <$> pi.getInstInfoAt? addr -/-- Given an `Expr` of type `Program`, generate the basic `ProgramInfo` -/ -partial def ProgramInfo.generateFromExpr (e : Expr) : MetaM ProgramInfo := do +-- TODO: this instance could be upstreamed (after cleaning it up) +instance [BEq α] [Hashable α] : ForIn m (HashMap α β) (α × β) where + forIn map acc f := do + let f := fun (acc : ForInStep _) key val => do + match acc with + | .yield acc => f ⟨key, val⟩ acc + | .done _ => return acc + match ← map.foldM f (ForInStep.yield acc) with + | .done x | .yield x => return x + +/-! ## ProgramInfo Generation -/ + +/-- Given the name and defining expression of a `Program`, +generate the basic `ProgramInfo` -/ +partial def generateFromExpr (name : Name) (e : Expr) : MetaM ProgramInfo := do + trace[ProgramInfo] "Generating program info for `{name}` from definition:\n\t{e}" let type ← inferType e if !(←isDefEq type (mkConst ``Program)) then throwError "type mismatch: {e} {← mkHasTypeButIsExpectedMsg type (mkConst ``Program)}" - let rec go (rawProgram : HashMap _ _) (e : Expr) : MetaM (HashMap _ _) := do + let rec go (instructions : HashMap _ _) (e : Expr) : MetaM (HashMap _ _) := do let e ← whnfD e match_expr e with | List.cons _ hd tl => do + trace[ProgramInfo] "found address/instruction pair: {hd}" + let hd' ← reduce hd let_expr Prod.mk _ _ addr inst := hd' | throwError "expected `{hd}` to reduce to an application of `Prod.mk`, found:\n\t{hd'}" - let addr ← reflectBitVecLiteral 64 addr - let inst ← reflectBitVecLiteral 32 inst + let addr ← reflectBitVecLiteral 64 (← instantiateMVars addr) + let rawInst ← reflectBitVecLiteral 32 (← instantiateMVars inst) + let rawProgram := + let info := InstInfo.ofRawInst rawInst + instructions.insert addr info - let rawProgram := rawProgram.insert addr inst go rawProgram tl - | List.nil _ => return rawProgram + | List.nil _ => return instructions | _ => throwError "expected `List.cons _ _` or `List.nil`, found:\n\t{e}" return { - rawProgram := ← go ∅ e + name, + instructions := ← go ∅ e } /-- Given the `Name` of a constant of type `Program`, generate the basic `ProgramInfo` -/ -def ProgramInfo.generateFromConstName (program : Name) : MetaM ProgramInfo := do +def generateFromConstName (program : Name) : MetaM ProgramInfo := do let .defnInfo defnInfo ← getConstInfo program | throwError "expected a definition, but {program} is not" - generateFromExpr defnInfo.value + generateFromExpr program defnInfo.value /-! ## Env Extension -/ -initialize programInfoExt : PersistentEnvExtension (Name × ProgramInfo) (Name × ProgramInfo) (NameMap ProgramInfo) ← +initialize programInfoExt : PersistentEnvExtension (ProgramInfo) (ProgramInfo) (NameMap ProgramInfo) ← registerPersistentEnvExtension { name := `programInfo mkInitial := pure {} addImportedFn := fun _ _ => pure {} - addEntryFn := fun s p => s.insert p.1 p.2 + addEntryFn := fun s p => s.insert p.name p exportEntriesFn := fun m => - let r : Array (Name × _) := m.fold (fun a n p => a.push (n, p)) #[] - r.qsort (fun a b => Name.quickLt a.1 b.1) + let r : Array (ProgramInfo) := + m.fold (fun a name p => a.push {p with name}) #[] + r.qsort (fun a b => Name.quickLt a.name b.name) statsFn := fun s => "program info extension" ++ Format.line ++ "number of local entries: " ++ format s.size } -/-- store a `PogramInfo` for the given `program` in the environment -/ -private def ProgramInfo.store [Monad m] [MonadEnv m] - (program : Name) (pi : ProgramInfo) : m Unit := do - modifyEnv (programInfoExt.addEntry · ⟨program, pi⟩) +/-- persistently store a `ProgramInfo` in the environment -/ +def persistToEnv [Monad m] [MonadEnv m] (pi : ProgramInfo) : m Unit := do + modifyEnv (programInfoExt.addEntry · pi) /-- look up the `ProgramInfo` for a given `program` in the environment, returns `None` if not found -/ -def ProgramInfo.lookup? [Monad m] [MonadEnv m] (program : Name) : +def lookup? [Monad m] [MonadEnv m] (program : Name) : m (Option ProgramInfo) := do let env ← getEnv let state := programInfoExt.getState env return state.find? program /-- look up the `ProgramInfo` for a given `program` in the environment, -or, if none was found, generate (and cache) new program info -/ -def ProgramInfo.lookupOrGenerate (program : Name) : MetaM ProgramInfo := do +or, if none was found, generate new program info. + +If you pass in a value for `expr?`, that is assumed to be the definition for +`program` when generating new program info. +If you don't pass in an expr, the definition is found in the environment + +If `persist` is set to true (the default), then the newly generated program info +will be persistently cached in the environment (see `persistToEnv`) -/ +def lookupOrGenerate (program : Name) (expr? : Option Expr := none) + (persist : Bool := true) : + MetaM ProgramInfo := do if let some pi ← lookup? program then return pi else - let pi ← generateFromConstName program - store program pi + let pi ← match expr? with + | some expr => generateFromExpr program expr + | none => generateFromConstName program + if persist then + persistToEnv pi return pi + +end ProgramInfo + +/-! ## `ProgramInfoT` Monad Transformer -/ + +/-- A monad transformer with `ProgramInfo` state -/ +abbrev ProgramInfoT (m : Type → Type) := StateT ProgramInfo m + +namespace ProgramInfoT +variable [Monad m] [MonadEnv m] [MonadError m] + +/-! ### run -/ +section Run +variable [MonadLiftT MetaM m] + +protected def run' (programName : Name) (expr? : Option Expr) (persist : Bool) + (k : ProgramInfoT m α) : m α := do + let pi ← ProgramInfo.lookupOrGenerate programName expr? + let ⟨a, pi⟩ ← StateT.run k pi + if persist then + pi.persistToEnv + return a + +/-- run a `ProgramInfoT m` by looking up, or generating new program info, +by name. + +If `persist` is set to true, then the program info state after +executing `k` will be persistently cached in the environment +(see `persistToEnv`). -/ +def run (programName : Name) (k : ProgramInfoT m α) + (persist : Bool := false) : + m α := + ProgramInfoT.run' programName none persist k + +/-- run a `ProgramInfoT m` by looking up, or generating new program info. +The passed expression is assumed to be the definition of the program. + +If `persist` is set to true (the default), then the program info state after +executing `k` will be persistently cached in the environment +(see `persistToEnv`). -/ +def runE (programName : Name) (expr : Expr) (k : ProgramInfoT m α) + (persist : Bool := false) + : m α := + ProgramInfoT.run' programName expr persist k + +end Run + +/-! ### MonadError instance -/ + +instance [Monad m] [i : MonadError m] : MonadError (ProgramInfoT m) where + throw e := i.throw e + tryCatch k f := fun s => i.tryCatch (k s) (fun e => f e s) + getRef := i.getRef + withRef stx k := fun s => i.withRef stx (k s) + add stx msg := i.add stx msg + +/-! ### Wrappers -/ + +/-- Access the info for the instruction at a given address, +or throw an error if none is found -/ +def getInstInfoAt (addr : BitVec 64) : ProgramInfoT m InstInfo := do + let some x := (← StateT.get).getInstInfoAt? addr + | let addr := addr.toHexWithoutLeadingZeroes + throwError "No instruction found at address {addr}" + return x + +/-- Set the instruction info for a particular address -/ +def setInstInfoAt (addr : BitVec 64) (info : InstInfo) : + ProgramInfoT m Unit := do + let pi ← StateT.get + StateT.set {pi with instructions := pi.instructions.insert addr info} + +/-- Run `k` with the instruction info for the given address as initial state, +and store the resulting state at that same address. +Returns the value produced by `k`. Throws an error if the address is invalid -/ +def modifyInstInfoAt (addr : BitVec 64) (k : InstInfoT m α) : + ProgramInfoT m α := do + let info ← getInstInfoAt addr + let ⟨val, info⟩ ← monadLift (StateT.run k info) + setInstInfoAt addr info + return val + +/-! ### InstInfo Accessors -/ + +def getRawInstAt (addr : BitVec 64) : ProgramInfoT m (BitVec 32) := do + return (← getInstInfoAt addr).rawInst + +/-- if `decodedInst?` is `some _` for the instruction info at the given address, +return the cached value. +Otherwise, use `k` to compute the decoded instruction, then +cache and return that new value. +See `InstInfo.decodInst?` for the meaning of this field. + +NOTE: the computed value is only cached in the `ProgramInfoT` monad state, +not yet in the environment. -/ +def getDecodedInstAt (addr : BitVec 64) (k : InstInfo → ProgramInfoT m Expr) : + ProgramInfoT m Expr := do + let info ← getInstInfoAt addr + match info.decodedInst? with + | .value e => return e + | .notYetComputed => + let decodedInst ← k info + setInstInfoAt addr {info with decodedInst? := .value decodedInst} + return decodedInst + +/-- if `instSemantics?` is `some _` for the instruction info at the given address, +return the cached value. +Otherwise, use `k` to compute the decoded instruction, then +cache and return that new value. +See `InstInfo.instSemantics?` for the meaning of this field. + +NOTE: the computed value is only cached in the `ProgramInfoT` monad state, +not yet in the environment. -/ +def getInstSemanticsAt (addr : BitVec 64) + (k : InstInfo → ProgramInfoT m (Expr × Expr × Expr)) : + ProgramInfoT m (Expr × Expr × Expr) := do + let info ← getInstInfoAt addr + match info.instSemantics? with + | .value e => return e + | .notYetComputed => + let instSemantics ← k info + setInstInfoAt addr {info with instSemantics? := .value instSemantics} + return instSemantics + + +end ProgramInfoT diff --git a/Tactics/StepThms.lean b/Tactics/StepThms.lean index 2e2f14b0..3f9eb8df 100644 --- a/Tactics/StepThms.lean +++ b/Tactics/StepThms.lean @@ -1,10 +1,18 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Shilpi Goel, Alex Keizer +-/ import Lean import Arm.Map import Arm.Decode import Tactics.Common import Tactics.Simp import Tactics.ChangeHyps +import Tactics.Reflect.ProgramInfo + open Lean Lean.Expr Lean.Meta Lean.Elab Lean.Elab.Command +open SymContext (h_pc_type h_program_type h_err_type) -- NOTE: This is an experimental and probably quite shoddy method of autogenerating -- `stepi` theorems from a program under verification, and things may change @@ -33,11 +41,194 @@ initialize registerTraceClass `gen_step.print_names initialize registerTraceClass `gen_step.debug /- When true, prints the number of heartbeats taken per theorem. -/ initialize registerTraceClass `gen_step.debug.heartBeats +/- When true, prints the time taken at various steps of generation. -/ +initialize registerTraceClass `gen_step.debug.timing + +/-- Assuming that `rawInst` is indeed the right result, construct a proof that + `fetch_inst addr state = some rawInst` +given that `state.program = program` -/ +private def fetchLemma (state program h_program : Expr) + (addr : BitVec 64) (rawInst : BitVec 32) : Expr := + let someRawInst := toExpr (some rawInst) + mkAppN (mkConst ``fetch_inst_eq_of_prgram_eq_of_map_find) #[ + state, + program, + toExpr addr, + someRawInst, + h_program, + mkApp2 (.const ``Eq.refl [1]) + (mkApp (.const ``Option [0]) <| + mkApp (.const ``BitVec []) (toExpr 32)) + someRawInst + ] + +-- /-! ## `reduceDecodeInst` -/ + +/-- `canonicalizeBitVec e` recursively walks over expression `e` to convert any +occurrences of: + `BitVec.ofFin w (Fin.mk x _)` +to the canonical form: + `BitVec.ofNat w x` (i.e., `x#w`) + +Such expressions tend to result from using `reduce` or +`simp` with `{ground := true}`. +You can call `canonicalizeBitVec` after these functions to ensure you don't +needlessly expose `BitVec` internal details -/ +-- TODO: should this canonicalize to `BitVec.ofNatLt` instead, +-- as the current transformation loses information? +partial def canonicalizeBitVec (e : Expr) : MetaM Expr := do + match_expr e with + | BitVec.ofFin w i => + let_expr Fin.mk _ x _h := i | fallback + let w ← + if w.hasFVar || w.hasMVar then + pure w + else + withTransparency .all <| reduce w + -- ^^ NOTE: potentially expensive reduction + return mkApp2 (mkConst ``BitVec.ofNat) w x + | _ => fallback + where + fallback : MetaM Expr := do + let fn := e.getAppFn + let args ← e.getAppArgs.mapM canonicalizeBitVec + return mkAppN fn args + +/-- Given an expr `rawInst` of type `BitVec 32`, +return an expr of type `Option ArmInst` representing what `rawInst` decodes to. +The resulting expr is guaranteed to be def-eq to `decode_raw_inst $rawInst` -/ +def reduceDecodeInstExpr (rawInst : Expr) : MetaM Expr := do + let expr := mkApp (mkConst ``decode_raw_inst) rawInst + let expr ← withTransparency .all <| reduce expr + -- ^^ NOTE: possibly expensive reduction + canonicalizeBitVec expr + +/-! ## SymM Monad -/ + +abbrev SymM.CacheKey := BitVec 32 +abbrev SymM.CacheM := MonadCacheT CacheKey Expr MetaM +abbrev SymM := ProgramInfoT <| MonadCacheT SymM.CacheKey Expr MetaM + +@[inherit_doc ProgramInfoT.run] +abbrev SymM.run (name : Name) (k : SymM α) (persist : Bool := true) : MetaM α := + MonadCacheT.run <| ProgramInfoT.run name k persist + +open SymM in +/-- Given a (reflected) raw instruction, +return an expr of type `Option ArmInst` representing what `rawInst` decodes to. +The resulting expr is guaranteed to be def-eq to `decode_raw_inst $rawInst`. + +Results are cached so that the same instruction is not reduced multiple times -/ +def reduceDecodeInst (rawInst : BitVec 32) : CacheM Expr := + checkCache (rawInst) fun _ => + reduceDecodeInstExpr (toExpr rawInst) + +open ProgramInfoT InstInfoT + +/-! ## reduceStepiToExecInst -/ + +/-- Given a program and an address, and optionally the corresponding +raw and decoded instructions, construct and return first the expression: +``` +∀ {s} (h_program : s.program = ) (h_pc : r .PC s = ) + (h_err : r .ERR s = .None), + stepi s = s`> +``` +and then a proof of this fact. +That is, in + `let ⟨type, value⟩ ← reduceStepi ...` +`value` is an expr whose type is `type` -/ +def reduceStepi (addr : BitVec 64) : SymM (Expr × Expr) := do + let pi : ProgramInfo ← get + let ⟨_, type, proof⟩ ← modifyInstInfoAt addr <| getInstSemantics fun _ => do + let rawInst ← getRawInst + + let inst ← getDecodedInst <| fun _ => do + let optInst ← reduceDecodeInst rawInst + let_expr some _ inst := optInst + | let some := mkConst ``Option.some [1] + throwError "Expected an application of {some}, found:\n\t{optInst}" + pure inst + + withLocalDecl `s .implicit (mkConst ``ArmState) <| fun s => + withLocalDeclD `h_program (h_program_type s pi.expr) <| fun h_program => + withLocalDeclD `h_pc (h_pc_type s (toExpr addr)) <| fun h_pc => + withLocalDeclD `h_err (h_err_type s) <| fun h_err => do + let h_fetch := fetchLemma s pi.expr h_program addr rawInst + let h_decode := + let armInstTy := mkConst ``ArmInst + mkApp2 (mkConst ``Eq.refl [1]) + (mkApp (mkConst ``Option [0]) armInstTy) + (mkApp2 (mkConst ``Option.some [0]) armInstTy inst) + + let proof := -- stepi s = exec_inst s + mkAppN (mkConst ``stepi_eq_of_fetch_inst_of_decode_raw_inst) #[ + s, toExpr addr, toExpr rawInst, inst, + h_err, h_pc, h_fetch, h_decode + ] + let type ← inferType proof + + let (ctx, simprocs) ← do + let localDecls ← do + let hs := #[h_pc, h_err] + pure <| hs.filterMap (← getLCtx).findFVar? + LNSymSimpContext + (config := {decide := true, ground := false}) + (simp_attrs := #[`minimal_theory, `bitvec_rules, `state_simp_rules]) + (decls := localDecls) + (decls_to_unfold := #[``exec_inst]) + + let ⟨simpRes, _⟩ ← simp type ctx simprocs + + let_expr Eq _ _ sem := simpRes.expr + | let eq ← mkEq (← mkFreshExprMVar none) (← mkFreshExprMVar none) + throwError "Failed to normalize instruction semantics. Expected {eq}, but found:\n\t{simpRes.expr}" + let sem ← mkLambdaFVars #[s] sem + + let proof ← simpRes.mkCast proof -- stepi s = + let hs := #[s, h_program, h_pc, h_err] + let proof ← mkLambdaFVars hs proof + let type ← mkForallFVars hs simpRes.expr + return ⟨sem, type, proof⟩ + return ⟨type, proof⟩ + +def genStepEqTheorems : SymM Unit := do + let pi ← get + for ⟨addr, instInfo⟩ in pi.instructions do + let startTime ← IO.monoMsNow + let inst := instInfo.rawInst + + trace[gen_step.debug] "[genStepEqTheorems] Generating theorem for address {addr.toHex}\ + with instruction {inst.toHex}" + let name := let addr_str := addr.toHexWithoutLeadingZeroes + Name.str pi.name ("stepi_eq_0x" ++ addr_str) + let ⟨type, value⟩ ← reduceStepi addr + + trace[gen_step.debug.timing] "[genStepEqTheorems] reduced in: {(← IO.monoMsNow) - startTime}ms" + addDecl <| Declaration.thmDecl { + name, type, value, + levelParams := [] + } + trace[gen_step.debug.timing] "[genStepEqTheorems] added to environment in: {(← IO.monoMsNow) - startTime}ms" + +/-- `#genProgramInfo program` ensures the `ProgramInfo` for `program` +has been generated and persistently cached in the enviroment -/ +elab "#genProgramInfo" program:ident : command => liftTermElabM do + let _ ← ProgramInfo.lookupOrGenerate program.getId + + +elab "#genStepEqTheorems" program:term : command => liftTermElabM do + let .const name _ ← Elab.Term.elabTerm program (mkConst ``Program) + | throwError "Expected a constant, found: {program}" + + SymM.run name (persist := true) <| + genStepEqTheorems + /- Generate and prove a fetch theorem of the following form: ``` -theorem ( ++ "fetch_0x" ++ ) (s : ArmState) - (h : s.program = ) : fetch_inst s = some +theorem .("fetch_0x" ++
) (s : ArmState) + (h : s.program = ) : fetch_inst
s = some ``` -/ def genFetchTheorem (program_name : Name) (address_str : String) @@ -110,7 +301,7 @@ def genFetchTheorem (program_name : Name) (address_str : String) /- Generate and prove an exec theorem of the following form: ``` -theorem ( ++ "exec_0x" ++ ) (s : ArmState) : +theorem .("exec_0x" ++ ) (s : ArmState) : exec_inst s = ``` -/ @@ -359,6 +550,8 @@ def test_program : Program := (0x126514#64 , 0x4ea21c5c#32), -- mov v28.16b, v2.16b (0x126518#64 , 0x4ea31c7d#32)] -- mov v29.16b, v3.16b +#genStepEqTheorems test_program + #genStepTheorems test_program thmType:="fetch" /-- info: test_program.fetch_0x126510 (s : ArmState) (h : s.program = test_program) : @@ -402,6 +595,14 @@ info: test_program.stepi_0x126510 (s sn : ArmState) (h_program : s.program = tes #guard_msgs in #check test_program.stepi_0x126510 +/-- +info: test_program.stepi_eq_0x126510 {s : ArmState} (h_program : s.program = test_program) + (h_pc : r StateField.PC s = 1205520#64) (h_err : r StateField.ERR s = StateError.None) : + stepi s = w StateField.PC (1205524#64) (w (StateField.SFP 27#5) (r (StateField.SFP 1#5) s) s) +-/ +#guard_msgs in +#check test_program.stepi_eq_0x126510 + -- Here's the theorem that we'd actually like to obtain instead of the -- erstwhile test_stepi_0x126510. theorem test_stepi_0x126510_desired (s sn : ArmState) diff --git a/Tactics/Sym.lean b/Tactics/Sym.lean index fd147875..4ea216ec 100644 --- a/Tactics/Sym.lean +++ b/Tactics/Sym.lean @@ -103,18 +103,15 @@ def stepiTac (h_step : Ident) (ctx : SymContext) : TacticM Unit := withMainContext do let pc := (Nat.toDigits 16 ctx.pc.toNat).asString -- ^^ The PC in hex - let step_lemma := mkIdent <| Name.str ctx.program s!"stepi_0x{pc}" + let step_lemma := mkIdent <| Name.str ctx.program s!"stepi_eq_0x{pc}" evalTacticAndTrace <|← `(tactic| ( replace $h_step := - (propext_iff.mp + _root_.Eq.trans $h_step ($step_lemma:ident - _ - $ctx.next_state_ident:ident $ctx.h_program_ident:ident $ctx.h_pc_ident:ident - $ctx.h_err_ident:ident)).mp - $h_step + $ctx.h_err_ident:ident) )) elab "stepi_tac" h_step:ident : tactic => do diff --git a/Tactics/SymContext.lean b/Tactics/SymContext.lean index f5cf1b42..42b040bd 100644 --- a/Tactics/SymContext.lean +++ b/Tactics/SymContext.lean @@ -80,18 +80,6 @@ structure SymContext where curr_state_number : Nat := 0 deriving Repr -/-- `h_err_type state` returns an Expr representing `r state = .None`, -the expected type of `h_err` -/ -private def h_err_type (state : Expr) : MetaM Expr := - mkEq - (mkApp2 (.const ``r []) (.const ``StateField.ERR []) state) - (.const ``StateError.None []) - -/-- `h_sp_type state` returns an Expr representing `CheckSPAlignment state`, -the expected type of `h_sp` -/ -private def h_sp_type (state : Expr) : Expr := - mkApp (.const ``CheckSPAlignment []) state - namespace SymContext /-! ## Creating initial contexts -/ @@ -131,7 +119,7 @@ def addGoalsForMissingHypotheses (ctx : SymContext) : TacticM SymContext := let newGoal ← mkFreshMVarId goal := ← do - let goalType ← h_err_type stateExpr + let goalType := h_err_type stateExpr let newGoalExpr ← mkFreshExprMVarWithId newGoal goalType let goal' ← goal.assert h_err? goalType newGoalExpr let ⟨_, goal'⟩ ← goal'.intro1P @@ -225,9 +213,7 @@ def fromLocalContext (state? : Option Name) : MetaM SymContext := do -- Then, try to find `h_pc` let pc ← mkFreshExprMVar (← mkAppM ``BitVec #[toExpr 64]) - let h_pc_type ← do - let lhs ← mkAppM ``r #[(.const ``StateField.PC []), stateExpr] - mkEq lhs pc + let h_pc_type := h_pc_type stateExpr pc let h_pc ← findLocalDeclUsernameOfTypeOrError h_pc_type -- Unwrap and reflect `pc` @@ -235,9 +221,9 @@ def fromLocalContext (state? : Option Name) : MetaM SymContext := do let pc ← withErrorContext h_pc h_pc_type <| reflectBitVecLiteral 64 pc -- Attempt to find `h_err` and `h_sp` - let h_err? ← findLocalDeclUsernameOfType? (←h_err_type stateExpr) + let h_err? ← findLocalDeclUsernameOfType? (h_err_type stateExpr) if h_err?.isNone then - trace[Sym] "Could not find local hypothesis of type {←h_err_type stateExpr}" + trace[Sym] "Could not find local hypothesis of type {h_err_type stateExpr}" let h_sp? ← findLocalDeclUsernameOfType? (h_sp_type stateExpr) if h_sp?.isNone then trace[Sym] "Could not find local hypothesis of type {h_sp_type stateExpr}"