Implement a decision procedure for memory reads/write via omega (#56)
This reduces goals about memory to goals about natural numbers, with the
hope of allowing the use of `omega` to dispatch such goals.
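
As a rough illustration of the reduction (a toy model over `Nat`, not the definitions used in this repository): the names `Toy.separate` and `Toy.subset` below are hypothetical stand-ins for `mem_separate'` and `mem_subset'`. Once such predicates are unfolded to inequalities, `omega` can close goals like "subregions of separate regions are separate".

```lean
namespace Toy

-- Hypothetical toy predicates: a region is a base address `a` and a length `an`.
def separate (a an b bn : Nat) : Prop := a + an ≤ b ∨ b + bn ≤ a
def subset (a an b bn : Nat) : Prop := b ≤ a ∧ a + an ≤ b + bn

theorem separate_of_subset (a an b bn a' an' b' bn' : Nat)
    (hsep : separate a' an' b' bn')
    (ha : subset a an a' an') (hb : subset b bn b' bn') :
    separate a an b bn := by
  -- Unfold everything to linear arithmetic over `Nat`, then let `omega` decide it.
  simp only [separate, subset] at hsep ha hb ⊢
  omega

end Toy
```

The sketch ignores address wraparound, which the real predicates over `BitVec 64` have to account for.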

---------

Co-authored-by: Alex Keizer <[email protected]>
Co-authored-by: Shilpi Goel <[email protected]>
3 people authored Aug 16, 2024
1 parent 9f261d1 commit 6d1ef32
Showing 4 changed files with 261 additions and 25 deletions.
9 changes: 9 additions & 0 deletions Arm/Memory/Attr.lean
@@ -0,0 +1,9 @@
import Lean

open Lean

/-- Provides tracing for the `simp_mem` tactic. -/
initialize Lean.registerTraceClass `simp_mem

/-- Provides extremely verbose tracing for the `simp_mem` tactic. -/
initialize Lean.registerTraceClass `simp_mem.info
204 changes: 204 additions & 0 deletions Arm/Memory/SeparateAutomation.lean
@@ -0,0 +1,204 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Siddharth Bhat
In this file, we define proof automation for separation conditions of memory.
References:
- https://github.com/leanprover/lean4/blob/240ebff549a2cf557f9abe9568f5de885f13e50d/src/Lean/Elab/Tactic/Omega/OmegaM.lean
- https://github.com/leanprover/lean4/blob/240ebff549a2cf557f9abe9568f5de885f13e50d/src/Lean/Elab/Tactic/Omega/Frontend.lean
-/
import Arm
import Arm.Memory.MemoryProofs
import Arm.BitVec
import Arm.Memory.Attr
import Lean
import Lean.Meta.Tactic.Rewrite
import Lean.Meta.Tactic.Rewrites
import Lean.Elab.Tactic.Conv
import Lean.Elab.Tactic.Conv.Basic

open Lean Meta Elab Tactic


/-! ## Memory Separation Automation
##### A Note on Notation
- `[a..an)`: a range of memory starting with base address `a` of length `an`.
aka `mem_legal' a an`.
- `[a..an) ⟂ [b..bn)`: `mem_disjoint' a an b bn`.
- `[a..an) ⊆ [b..bn)`: `mem_subset' a an b bn`.
##### Tactic Loop
The core tactic tries to simplify expressions of the form:
`mem.write_bytes [b..bn) val |>. read_bytes [a..an)`
by case splitting:
1. If `[a..an) ⟂ [b..bn)`, the write does not alias the read,
and can be replaced with ` mem.read_bytes [a..an) `
2. If `[a..an) ⊆ [b..bn)`, the write aliases the read, and can be replaced with
`val.extractLsBs adjust([a..an), [b..bn))`. Here, `adjust` is a function that
adjusts the read indices `[a..an)` with respect to the write indices `[b..bn)`,
to convert a read from `mem` into a read from `val`.
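
The two cases can be sketched on a toy memory model. Everything here is an assumption made for illustration: memory and the written value are modelled as functions `Nat → Nat` on byte indices, and `writeBytes` is a hypothetical helper, not `write_mem_bytes`. In the subset case, the `adjust` step is just the offset `i - b`.

```lean
namespace ReadOverWriteSketch

-- Hypothetical toy write: bytes in `[b..bn)` come from `val`, the rest from `mem`.
def writeBytes (mem : Nat → Nat) (b bn : Nat) (val : Nat → Nat) (i : Nat) : Nat :=
  if b ≤ i ∧ i < b + bn then val (i - b) else mem i

-- Case 1: the read region `[a..an)` is separate from the write region `[b..bn)`,
-- so reading byte `i` of `[a..an)` sees the old memory.
theorem read_write_of_separate (mem val : Nat → Nat) (a an b bn i : Nat)
    (hsep : a + an ≤ b ∨ b + bn ≤ a) (hi : a ≤ i ∧ i < a + an) :
    writeBytes mem b bn val i = mem i := by
  have hout : ¬ (b ≤ i ∧ i < b + bn) := by omega
  unfold writeBytes
  rw [if_neg hout]

-- Case 2: the read region is contained in the write region,
-- so reading byte `i` sees the written value at the adjusted index `i - b`.
theorem read_write_of_subset (mem val : Nat → Nat) (a an b bn i : Nat)
    (hsub : b ≤ a ∧ a + an ≤ b + bn) (hi : a ≤ i ∧ i < a + an) :
    writeBytes mem b bn val i = val (i - b) := by
  have hin : b ≤ i ∧ i < b + bn := by omega
  unfold writeBytes
  rw [if_pos hin]

end ReadOverWriteSketch
```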
The tactic shall be implemented as follows:
1. Search the goal state for `mem.write_bytes [b..bn) val |>.read_bytes [a..an)`.
2. Try to prove that either `[a..an) ⟂ [b..bn)`, or `[a..an) ⊆ [b..bn)`.
2a. First search the local context for assumptions of this type.
2b. Try to deduce `[a..an) ⟂ [b..bn)` from the fact that
subsets of disjoint sets are disjoint.
So try to find `[a'..an')`, `[b'..bn')` such that:
(i) `[a..an) ⊆ [a'..an')`.
(ii) `[b..bn) ⊆ [b'..bn')`.
(iii) and `[a'..an') ⟂ [b'..bn')`.
2c. Try to deduce `[a..an) ⊆ [b..bn)` from transitivity of subset.
So try to find `[c..cn)` such that:
(i) `[a..an) ⊆ [c..cn)`
(ii) `[c..cn) ⊆ [b..bn)`
2d. If this also fails, then reduce all hypotheses to
linear integer arithmetic, and try to invoke `omega` to prove either
`[a..an) ⟂ [b..bn)` or `[a..an) ⊆ [b..bn)`.
3. Given a proof of either `[a..an) ⟂ [b..bn)` or `[a..an) ⊆ [b..bn)`,
simplify using the appropriate lemma from `Mem/Separate.lean`.
4. If we manage to prove *both* `[a..an) ⟂ [b..bn)` *and* `[a..an) ⊆ [b..bn)`,
declare victory as this is a contradiction. This may look useless,
but it may be useful for proving that certain memory states are impossible.
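
A minimal sketch of step 4, again on a toy model over `Nat` where separation and subset are written directly as inequalities (an assumption for illustration, not the actual `mem_separate'` / `mem_subset'` definitions). In this toy model the contradiction needs the read region to be nonempty, since an empty region is both a subset of and disjoint from any region.

```lean
-- `hsep` plays the role of `[a..an) ⟂ [b..bn)`, `hsub` the role of `[a..an) ⊆ [b..bn)`.
example (a an b bn : Nat) (hpos : 0 < an)
    (hsep : a + an ≤ b ∨ b + bn ≤ a)
    (hsub : b ≤ a ∧ a + an ≤ b + bn) : False := by
  omega
```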
##### Usability
- If no mem separate/subset assumptions are present,
then throw an error to tell the user that we expect them to
specify such assumptions for all memory regions of interest.
LNSym doesn't support automated verification of programs that
do dynamic memory allocation.
- If any non-primed separate/subset assumptions are detected,
error out to tell the user that no automation is supported in this case.
-/

namespace SeparateAutomation

structure SimpMemConfig where

/-- Context for the `SimpMemM` monad, containing the user configurable options. -/
structure Context where
  /-- User configurable options for `simp_mem`. -/
  cfg : SimpMemConfig

def Context.init (cfg : SimpMemConfig) : Context where
  cfg := cfg

inductive Hypothesis
| separate (h : Expr) (a na b nb : Expr)
| subset (h : Expr)

def Hypothesis.expr : Hypothesis → Expr
| .separate h .. => h
| .subset h .. => h

instance : ToMessageData Hypothesis where
  toMessageData
  | .subset h => toMessageData h
  | .separate h _a _na _b _nb => toMessageData h

/-- The internal state for the `SimpMemM` monad, recording the memory hypotheses found so far. -/
structure State where
  hypotheses : Array Hypothesis := #[]

def State.init : State := {}

abbrev SimpMemM := StateRefT State (ReaderT Context TacticM)

def SimpMemM.run (m : SimpMemM α) (cfg : SimpMemConfig) : TacticM α :=
  m.run' State.init |>.run (Context.init cfg)

/-- Add a `Hypothesis` to our hypothesis cache. -/
def SimpMemM.addHypothesis (h : Hypothesis) : SimpMemM Unit :=
  modify fun s => { s with hypotheses := s.hypotheses.push h }

def processingEmoji : String := "⚙️"

/-- Match an expression `h` to see if it's a useful hypothesis. -/
def processHypothesis (h : Expr) : MetaM (Option Hypothesis) := do
  let ht ← inferType h
  trace[simp_mem.info] "{processingEmoji} Processing '{h}' : '{toString ht}'"
  match_expr ht with
  | mem_separate' a ha b hb => return .some (.separate h a ha b hb)
  | _ => return .none

/--
info: read_mem_bytes_write_mem_bytes_eq_read_mem_bytes_of_mem_separate' {x : BitVec 64} {xn : Nat} {y : BitVec 64} {yn : Nat}
{mem : ArmState} (hsep : mem_separate' x xn y yn) (val : BitVec (yn * 8)) :
read_mem_bytes xn x (write_mem_bytes yn y val mem) = read_mem_bytes xn x mem
-/
#guard_msgs in #check read_mem_bytes_write_mem_bytes_eq_read_mem_bytes_of_mem_separate'

partial def SimpMemM.rewrite (g : MVarId) : SimpMemM Unit := do
  trace[simp_mem.info] "{processingEmoji} Matching on ⊢ {← g.getType}"
  let some (_, _lhs, _rhs) ← matchEq? (← g.getType) | throwError "invalid goal, expected 'lhs = rhs'."
  -- TODO: do this till fixpoint.
  for h in (← get).hypotheses do
    let x ← mkFreshExprMVar .none
    let xn ← mkFreshExprMVar .none
    let y ← mkFreshExprMVar .none
    let yn ← mkFreshExprMVar .none
    let state ← mkFreshExprMVar .none
    let f := (Expr.const ``read_mem_bytes_write_mem_bytes_eq_read_mem_bytes_of_mem_separate' [])
    let result : Option RewriteResult ←
      try
        pure <| some (← g.rewrite (← g.getType) (mkAppN f #[x, xn, y, yn, state, h.expr]) false)
      catch _ =>
        pure <| none
    match result with
    | .none =>
      trace[simp_mem.info] "{crossEmoji} rewrite did not fire"
    | .some r =>
      let mvarId' ← g.replaceTargetEq r.eNew r.eqProof
      -- TODO: dispatch other goals that occur during proof automation.
      Tactic.setGoals <| mvarId' :: r.mvarIds

def SimpMemM.analyzeLoop : SimpMemM Unit := do
  (← getMainGoal).withContext do
    let hyps := (← getLocalHyps)
    trace[simp_mem] "analyzing {hyps.size} hypotheses:\n{← hyps.mapM (liftMetaM ∘ inferType)}"
    for h in hyps do
      if let some hyp ← processHypothesis h then
        trace[simp_mem.info] "{checkEmoji} Found '{h}'"
        SimpMemM.addHypothesis hyp
      else
        trace[simp_mem.info] "{crossEmoji} Rejecting '{h}'"
    SimpMemM.rewrite (← getMainGoal)

/--
Collect the memory separation hypotheses in the local context,
and use them to rewrite reads over writes in the main goal.
-/
def simpMem (cfg : SimpMemConfig := {}) : TacticM Unit :=
  SimpMemM.run SimpMemM.analyzeLoop cfg


/-- The `simp_mem` tactic, for simplifying away statements about memory. -/
def simpMemTactic (cfg : SimpMemConfig) : TacticM Unit := simpMem cfg

end SeparateAutomation

/--
Allow elaboration of `SimpMemConfig` arguments to tactics.
-/
declare_config_elab elabSimpMemConfig SeparateAutomation.SimpMemConfig

/--
Implement the simp_mem tactic frontend.
-/
syntax (name := simp_mem) "simp_mem" (Lean.Parser.Tactic.config)? : tactic

@[tactic simp_mem]
def evalSimpMem : Tactic := fun
  | `(tactic| simp_mem $[$cfg]?) => do
    let cfg ← elabSimpMemConfig (mkOptionalNode cfg)
    SeparateAutomation.simpMemTactic cfg
  | _ => throwUnsupportedSyntax
43 changes: 30 additions & 13 deletions Proofs/Experiments/MemoryAliasing.lean
@@ -7,23 +7,40 @@ The goal is to eliminate the sorry, and to simplify the proof to a tactic invoca
 -/
 import Arm
 import Arm.Memory.MemoryProofs
+import Arm.BitVec
+import Arm.Memory.SeparateAutomation
 
+set_option trace.simp_mem true in
+set_option trace.simp_mem.info true in
 theorem mem_automation_test
+  (h_s0_src_dest_separate : mem_separate' src_addr 16 dest_addr 16) :
+  read_mem_bytes 16 src_addr (write_mem_bytes 16 dest_addr blah s0) =
+  read_mem_bytes 16 src_addr s0 := by
+  -- ⊢ read_mem_bytes 16 src_addr (write_mem_bytes 16 dest_addr blah s0) = read_mem_bytes 16 src_addr s0
+  simp_mem
+  -- ⊢ read_mem_bytes 16 src_addr s0 = read_mem_bytes 16 src_addr s0
+  rfl
+
+/-- info: 'mem_automation_test' depends on axioms: [propext, Classical.choice, Quot.sound] -/
+#guard_msgs in #print axioms mem_automation_test
+
+theorem mem_automation_test_2
   (h_n0 : n0 ≠ 0)
-  (h_no_wrap_src_region : mem_legal src_addr (src_addr + ((n0 <<< 4) - 1)))
-  (h_no_wrap_dest_region : mem_legal dest_addr (dest_addr + ((n0 <<< 4) - 1)))
+  (h_no_wrap_src_region : mem_legal' src_addr (n0 <<< 4))
+  (h_no_wrap_dest_region : mem_legal' dest_addr (n0 <<< 4))
   (h_s0_src_dest_separate :
-    mem_separate src_addr (src_addr + ((n0 <<< 4) - 1))
-                 dest_addr (dest_addr + ((n0 <<< 4) - 1))) :
+    mem_separate' src_addr (n0 <<< 4)
+                  dest_addr (n0 <<< 4)) :
   read_mem_bytes 16 src_addr (write_mem_bytes 16 dest_addr blah s0) =
   read_mem_bytes 16 src_addr s0 := by
-  rw [read_mem_bytes_of_write_mem_bytes_different (by decide) (by decide)]
-  rwa [@mem_separate_for_subset_general
-    src_addr (src_addr + (n0 <<< 4 - 1))
-    dest_addr (dest_addr + (n0 <<< 4 - 1))
-    src_addr (src_addr + 15#64)
-    dest_addr (dest_addr + 15#64)]
-  repeat sorry
+  sorry
+  -- rw [read_mem_bytes_of_write_mem_bytes_different (by decide) (by decide)]
+  -- rwa [@mem_separate_for_subset_general
+  --   src_addr (src_addr + (n0 <<< 4 - 1))
+  --   dest_addr (dest_addr + (n0 <<< 4 - 1))
+  --   src_addr (src_addr + 15#64)
+  --   dest_addr (dest_addr + 15#64)]
+  -- repeat sorry
 
-/-- info: 'mem_automation_test' depends on axioms: [propext, sorryAx, Classical.choice, Lean.ofReduceBool, Quot.sound] -/
-#guard_msgs in #print axioms mem_automation_test
+/-- info: 'mem_automation_test_2' depends on axioms: [propext, sorryAx, Quot.sound] -/
+#guard_msgs in #print axioms mem_automation_test_2
30 changes: 18 additions & 12 deletions Proofs/Experiments/SHA512MemoryAliasing.lean
@@ -67,15 +67,18 @@ theorem sha512_block_armv8_prelude_sym_ctx_access (s0 : ArmState)
   (h_s0_ctx : read_mem_bytes 64 (ctx_addr s0) s0 = SHA2.h0_512.toBitVec)
   (h_s0_ktbl : read_mem_bytes (SHA2.k_512.length * 8) ktbl_addr s0 = BitVec.flatten SHA2.k_512)
   -- (FIXME) Add separateness invariants for the stack's memory region.
+  -- @bollu: can we assume that `h_s1_ctx_input_separate`
+  -- will be given as ((num_blocks s1).toNat * 128)?
+  -- This is much more harmonious since we do not need to worry about overflow.
   (h_s0_ctx_input_separate :
-    mem_separate (ctx_addr s0) (ctx_addr s0 + 63)
-                 (input_addr s0) (input_addr s0 + (num_blocks s0 * 128)))
+    mem_separate' (ctx_addr s0) 64
+                  (input_addr s0) ((num_blocks s0).toNat * 128))
   (h_s0_ktbl_ctx_separate :
-    mem_separate (ctx_addr s0) (ctx_addr s0 + 63)
-                 ktbl_addr (ktbl_addr + (SHA2.k_512.length * 8 - 1)))
+    mem_separate' (ctx_addr s0) 64
+                  ktbl_addr (SHA2.k_512.length * 8))
   (h_s0_ktbl_input_separate :
-    mem_separate (input_addr s0) (input_addr s0 + (num_blocks s0 * 128))
-                 ktbl_addr (ktbl_addr + (SHA2.k_512.length * 8 - 1)))
+    mem_separate' (input_addr s0) ((num_blocks s0).toNat * 128)
+                  ktbl_addr (SHA2.k_512.length * 8))
   -- (h_run : sf = run 4 s0)
   :
   read_mem_bytes 16 (ctx_addr s0 + 48#64) s0 = xxxx := by
@@ -104,15 +107,18 @@ theorem sha512_block_armv8_loop_sym_ktbl_access (s1 : ArmState)
   (h_s1_ctx : read_mem_bytes 64 (ctx_addr s1) s1 = SHA2.h0_512.toBitVec)
   (h_s1_ktbl : read_mem_bytes (SHA2.k_512.length * 8) ktbl_addr s1 = BitVec.flatten SHA2.k_512)
   -- (FIXME) Add separateness invariants for the stack's memory region.
+  -- @bollu: can we assume that `h_s1_ctx_input_separate`
+  -- will be given as ((num_blocks s1).toNat * 128)?
+  -- This is much more harmonious since we do not need to worry about overflow.
   (h_s1_ctx_input_separate :
-    mem_separate (ctx_addr s1) (ctx_addr s1 + 63)
-                 (input_addr s1) (input_addr s1 + (num_blocks s1 * 128)))
+    mem_separate' (ctx_addr s1) 64
+                  (input_addr s1) ((num_blocks s1).toNat * 128))
   (h_s1_ktbl_ctx_separate :
-    mem_separate (ctx_addr s1) (ctx_addr s1 + 63)
-                 ktbl_addr (ktbl_addr + (SHA2.k_512.length * 8 - 1)))
+    mem_separate' (ctx_addr s1) 64
+                  ktbl_addr ((SHA2.k_512.length * 8 )))
   (h_s1_ktbl_input_separate :
-    mem_separate (input_addr s1) (input_addr s1 + (num_blocks s1 * 128))
-                 ktbl_addr (ktbl_addr + (SHA2.k_512.length * 8 - 1))) :
+    mem_separate' (input_addr s1) ((num_blocks s1).toNat * 128)
+                  ktbl_addr (SHA2.k_512.length * 8)) :
   read_mem_bytes 16 ktbl_addr s1 = xxxx := by
   sorry
