Skip to content

Commit

Permalink
refactor: replace rewrite with mkEqNDRec
Browse files Browse the repository at this point in the history
  • Loading branch information
alexkeizer committed Oct 4, 2024
1 parent 366d482 commit 943ee99
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 10 deletions.
26 changes: 26 additions & 0 deletions Tactics/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,32 @@ def Lean.Expr.eqReadField? (e : Expr) : Option (Expr × Expr × Expr) := do
| none
some (field, state, value)

/-- Return `ArmState.program <state> = <program>` -/
def mkEqProgram (state program : Expr) : Expr :=
mkApp3 (.const ``Eq [1]) (mkConst ``Program)
(mkApp (mkConst ``ArmState.program) state)
program

/-- Return `x = y`, given expressions `x, y : BitVec <n>` -/
def mkEqBitVec (n x y : Expr) : Expr :=
let ty := mkApp (mkConst ``BitVec) n
mkApp3 (.const ``Eq [1]) ty x y

/-- Return `read_mem_bytes <n> <addr> <state>` -/
def mkReadMemBytes (n addr state : Expr) : Expr :=
mkApp3 (mkConst ``read_mem_bytes) n addr state

/-- Return `read_mem_bytes <n> <addr> <state> = <value>`, given expressions
`n : Nat`, `addr : BitVec 64`, `state : ArmState` and `value : BitVec (n*8)` -/
def mkEqReadMemBytes (n addr state value : Expr) : Expr :=
let n8 := mkNatMul n (toExpr 8)
mkEqBitVec n8 (mkReadMemBytes n addr state) value

-- def mkForallReadMemBytesEqReadMemBytes (leftState rightState : Expr) : Expr :=
-- TODO

-- def mkForallEqReadMem

/-! ## Tracing helpers -/

def traceHeartbeats (cls : Name) (header : Option String := none) :
Expand Down
49 changes: 40 additions & 9 deletions Tactics/Sym/AxEffects.lean
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ structure AxEffects where
= read_mem_bytes n addr <memoryEffect>
``` -/
memoryEffectProof : Expr
/-- A proof that `<currentState>.program = <initialState>.program` -/
/-- The program of the current state, see `programProof` -/
program : Expr
/-- A proof that `<currentState>.program = <program>` -/
programProof : Expr
/-- An optional proof of `CheckSPAlignment <currentState>`.
Expand Down Expand Up @@ -141,6 +143,8 @@ def initial (state : Expr) : AxEffects where
mkApp2 (.const ``Eq.refl [1])
(mkApp (mkConst ``BitVec) <| mkNatMul (.bvar 1) (toExpr 8))
(mkApp3 (mkConst ``read_mem_bytes) (.bvar 1) (.bvar 0) state)
program := -- `ArmState.program <initialState>`
mkApp (mkConst ``ArmState.program) state
programProof :=
-- `rfl`
mkAppN (.const ``Eq.refl [1]) #[
Expand Down Expand Up @@ -503,19 +507,46 @@ def adjustCurrentStateWithEq (eff : AxEffects) (s eq : Expr) :
let fields ← eff.fields.toList.mapM fun (field, fieldEff) => do
withTraceNode m!"rewriting field {field}" (tag := "rewriteField") do
trace[Tactic.sym] "original proof: {fieldEff.proof}"
let proof ← rewriteType fieldEff.proof eq
let motive : Expr ← withLocalDeclD `s mkArmState <| fun s => do
let eq := mkEqReadField (toExpr field) s fieldEff.value
mkLambdaFVars #[s] eq
let proof ← mkEqNDRec motive fieldEff.proof eq
trace[Tactic.sym] "new proof: {proof}"
pure (field, {fieldEff with proof})
let fields := .ofList fields

withTraceNode m!"rewriting other proofs" (tag := "rewriteMisc") <| do
let nonEffectProof ← rewriteType eff.nonEffectProof eq
let memoryEffectProof ← rewriteType eff.memoryEffectProof eq
-- ^^ TODO: what happens if `memoryEffect` is the same as `currentState`?
-- Presumably, we would *not* want to encapsulate `memoryEffect` here
let programProof ← rewriteType eff.programProof eq
let stackAlignmentProof? ← eff.stackAlignmentProof?.mapM
(rewriteType · eq)
let nonEffectProof ← lambdaTelescope eff.nonEffectProof fun args proof => do
let f := args[0]!
let motive ← -- `fun s => r <f> s = r <f> <eff.initialState>`
withLocalDeclD `s mkArmState <| fun s =>
let eq := mkEqStateValue f
(mkApp2 (mkConst ``r) f s)
(mkApp2 (mkConst ``r) f eff.initialState)
mkLambdaFVars #[s] eq
mkLambdaFVars args <|← mkEqNDRec motive proof eq

let memoryMotive : Expr ←
withLocalDeclD `s mkArmState <| fun s =>
withLocalDeclD `n (mkConst ``Nat) <| fun n =>
withLocalDeclD `addr (mkApp (mkConst ``BitVec) (toExpr 64)) <| fun addr => do
let lhs := mkReadMemBytes n addr s
let rhs := mkReadMemBytes n addr eff.memoryEffect
let eq := mkEqBitVec (mkNatMul n (toExpr 8)) lhs rhs
let eq ← mkForallFVars #[n, addr] eq
mkLambdaFVars #[s] eq
let memoryEffectProof ← mkEqNDRec memoryMotive eff.memoryEffectProof eq

let programMotive : Expr ←
withLocalDeclD `s mkArmState <| fun s =>
let eq := mkEqProgram s eff.program
mkLambdaFVars #[s] eq
let programProof ← mkEqNDRec programMotive eff.programProof eq

let stackAlignmentProof? ← eff.stackAlignmentProof?.mapM fun proof => do
let motive ← withLocalDeclD `s mkArmState <| fun s =>
mkLambdaFVars #[s] <| mkApp (mkConst ``CheckSPAlignment) s
mkEqNDRec motive proof eq

return { eff with
currentState, fields, nonEffectProof, memoryEffectProof, programProof,
Expand Down
4 changes: 3 additions & 1 deletion Tactics/Sym/Context.lean
Original file line number Diff line number Diff line change
Expand Up @@ -333,12 +333,14 @@ protected def searchFor : SearchLCtxForM SymM Unit := do
searchLCtxForOnce (h_program_type currentState program)
(whenNotFound := throwNotFound)
(whenFound := fun decl _ => do
let program ← instantiateMVars program
-- Register the program proof
modifyThe AxEffects ({· with
program
programProof := decl.toExpr
})
-- Assert that `program` is a(n application of a) constant
let program := (← instantiateMVars program).getAppFn
let program := program.getAppFn
let .const program _ := program
| throwError "Expected a constant, found:\n\t{program}"
-- Retrieve the programInfo from the environment
Expand Down

0 comments on commit 943ee99

Please sign in to comment.