From a3e762f8144d1ca08bce0bbc746c7daebe1d112c Mon Sep 17 00:00:00 2001 From: Juneyoung Lee <136006969+aqjune-aws@users.noreply.github.com> Date: Wed, 21 Feb 2024 06:37:07 -0600 Subject: [PATCH] Split sym1 into smaller tactics (#14) ### Description: A macro-style tactic is chosen instead of the monad style so that developers can easily extend the tactics. However, this will change for a small set of tactics if they need more fine-grained manipulations. ### License: By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. --------- Authored-by: Juneyoung Lee Reviewed by: Shilpi Goel --- Arm/Memory.lean | 14 ++- Arm/State.lean | 11 +++ Proofs/MultiInsts.lean | 3 +- Proofs/Sha512_block_armv8.lean | 68 ++++++++++----- Tactics/Sym.lean | 153 ++++++++++++++++++++++++++++++++- 5 files changed, 221 insertions(+), 28 deletions(-) diff --git a/Arm/Memory.lean b/Arm/Memory.lean index 500f5abd..ee6588b4 100644 --- a/Arm/Memory.lean +++ b/Arm/Memory.lean @@ -99,6 +99,16 @@ theorem read_mem_bytes_of_w : rw [n_ih] done +theorem write_mem_bytes_program {n : ℕ} (addr : BitVec 64) (bytes : BitVec (n * 8)): + (write_mem_bytes n addr bytes s).program = s.program := by + intros + induction n generalizing addr s + · simp [write_mem_bytes] + · rename_i n h_n + simp [write_mem_bytes] + rw [h_n] + simp [write_mem] + ---- Memory RoW/WoW lemmas ---- theorem read_mem_of_write_mem_same : @@ -115,8 +125,8 @@ theorem write_mem_of_write_mem_shadow : simp [write_mem]; unfold write_store; simp_all; done theorem write_mem_irrelevant : - write_mem addr (read_mem addr s) s = s := by - simp [read_mem, write_mem, store_write_irrelevant] + write_mem addr (read_mem addr s) s = s := by + simp [read_mem, write_mem, store_write_irrelevant] end Memory diff --git a/Arm/State.lean b/Arm/State.lean index db565a8c..d56b0f91 100644 --- a/Arm/State.lean +++ b/Arm/State.lean @@ -294,6 +294,17 @@ theorem fetch_inst_of_w (addr : BitVec 64) (fld : StateField) 
(val : (state_valu unfold write_base_error split <;> simp_all! +-- There is no StateField that overwrites the program. +theorem w_program (sf : StateField) (v : state_value sf) (s : ArmState): + (w sf v s).program = s.program := by + intros + cases sf <;> unfold w <;> simp + · unfold write_base_gpr; simp + · unfold write_base_sfp; simp + · unfold write_base_pc; simp + · unfold write_base_flag; simp + · unfold write_base_error; simp + -- The following functions are defined in terms of r and w, but may be -- simpler to use. diff --git a/Proofs/MultiInsts.lean b/Proofs/MultiInsts.lean index e3276bc1..3f3e6e90 100644 --- a/Proofs/MultiInsts.lean +++ b/Proofs/MultiInsts.lean @@ -31,6 +31,7 @@ theorem one_asm_snippet_sym_helper1 (q0_var : BitVec 128) : theorem one_asm_snippet_sym_helper2 (q0_var : BitVec 128) : q0_var ||| q0_var = q0_var := by auto +-- Todo: use sym_n to prove this theorem. theorem small_asm_snippet_sym (s : ArmState) (h_pc : read_pc s = 0x12650c#64) (h_program : s.program = test_program.find?) @@ -46,7 +47,7 @@ theorem small_asm_snippet_sym (s : ArmState) -- Wrapping up the result: -- generalize (r (StateField.SFP 0#5) s) = q0_var; unfold state_value at q0_var; simp at q0_var -- try (simp [one_asm_snippet_sym_helper1]) - simp only [one_asm_snippet_sym_helper2] + simp only [one_asm_snippet_sym_helper2] done end multi_insts_proofs diff --git a/Proofs/Sha512_block_armv8.lean b/Proofs/Sha512_block_armv8.lean index 938f1b64..f23266bf 100644 --- a/Proofs/Sha512_block_armv8.lean +++ b/Proofs/Sha512_block_armv8.lean @@ -49,14 +49,16 @@ def sha512_program_test_1 : program := ] -- set_option profiler true in -theorem sha512_program_test_1_sym (s : ArmState) - (h_pc : read_pc s = 0x126538#64) - (h_program : s.program = sha512_program_test_1.find?) 
- (h_s_ok : read_err s = StateError.None) - (h_s' : s' = run 4 s) : - read_err s' = StateError.None := by - iterate 4 (sym1 [h_program]) - done +theorem sha512_program_test_1_sym (s0 s_final : ArmState) + (h_s0_pc : read_pc s0 = 0x126538#64) + (h_s0_sp_aligned : CheckSPAlignment s0 = true) + (h_s0_program : s0.program = sha512_program_test_1.find?) + (h_s0_ok : read_err s0 = StateError.None) + (h_run : s_final = run 4 s0) : + read_err s_final = StateError.None := by + unfold read_pc at h_s0_pc + sym_n 4 0x126538 sha512_program_test_1 + rw [h_run,h_s4_ok] ---------------------------------------------------------------------- @@ -74,13 +76,15 @@ def sha512_program_test_2 : program := ] -- set_option profiler true in -theorem sha512_program_test_2_sym (s : ArmState) - (h_pc : read_pc s = 0x126538#64) - (h_program : s.program = sha512_program_test_2.find?) - (h_s_ok : read_err s = StateError.None) - (h_s' : s' = run 4 s) : - read_err s' = StateError.None := by - iterate 4 (sym1 [h_program]) +theorem sha512_program_test_2_sym (s0 s_final : ArmState) + (h_s0_pc : read_pc s0 = 0x126538#64) + (h_s0_program : s0.program = sha512_program_test_2.find?) + (h_s0_ok : read_err s0 = StateError.None) + (h_run : s_final = run 4 s0) : + read_err s_final = StateError.None := by + -- TODO: use sym_n. Using it causes an error at simp due to the + -- large formula size. + iterate 4 (sym1 [h_s0_program]) ---------------------------------------------------------------------- @@ -94,7 +98,21 @@ def sha512_program_test_3 : program := (0x1264cc#64 , 0x4cdf2034#32) -- ld1 {v20.16b-v23.16b}, [x1], #64 ] -theorem sha512_block_armv8_test_3_sym (s : ArmState) + +theorem sha512_block_armv8_test_3_sym (s0 s_final : ArmState) + (h_s0_ok : read_err s0 = StateError.None) + (h_s0_sp_aligned : CheckSPAlignment s0 = true) + (h_s0_pc : read_pc s0 = 0x1264c0#64) + (h_s0_program : s0.program = sha512_program_test_3.find?) 
+ (h_run : s_final = run 4 s0) : + read_err s_final = StateError.None := by + + unfold read_pc at h_s0_pc + sym_n 4 0x1264c0 sha512_program_test_3 + rw [h_run,h_s4_ok] + +-- A record that shows simp fails. +theorem sha512_block_armv8_test_3_sym_fail (s : ArmState) (h_s_ok : read_err s = StateError.None) (h_sp_aligned : CheckSPAlignment s = true) (h_pc : read_pc s = 0x1264c0#64) @@ -133,13 +151,17 @@ theorem sha512_block_armv8_test_3_sym (s : ArmState) -- we'd like to verify). -- set_option profiler true in -theorem sha512_block_armv8_test_4_sym (s : ArmState) - (h_s_ok : read_err s = StateError.None) - (h_sp_aligned : CheckSPAlignment s = true) - (h_pc : read_pc s = 0x1264c0#64) - (h_program : s.program = sha512_program_map.find?) - (h_s' : s' = run 32 s) : - read_err s' = StateError.None := by +theorem sha512_block_armv8_test_4_sym (s0 s_final : ArmState) + (h_s0_ok : read_err s0 = StateError.None) + (h_s0_sp_aligned : CheckSPAlignment s0 = true) + (h_s0_pc : read_pc s0 = 0x1264c0#64) + (h_s0_program : s0.program = sha512_program_map.find?) + (h_run : s_final = run 32 s0) : + read_err s_final = StateError.None := by + + unfold read_pc at h_s0_pc + -- sym_n 32 0x1264c0 + -- ^^ This raises the max recursion depth limit error because the program is too large. :/ sorry end SHA512_proof diff --git a/Tactics/Sym.lean b/Tactics/Sym.lean index d8a49b1c..d2c1b834 100644 --- a/Tactics/Sym.lean +++ b/Tactics/Sym.lean @@ -6,13 +6,15 @@ Author(s): Shilpi Goel import Arm.Exec import Arm.MemoryProofs +open Std.BitVec + -- sym1 tactic symbolically simulates a single instruction. 
syntax "sym1" "[" term "]" : tactic macro_rules | `(tactic| sym1 [$h_program:term]) => `(tactic| - (try simp_all (config := {decide := true, ground := true})); - unfold run; + (try simp_all (config := {decide := true, ground := true})); + unfold run; simp_all [stepi]; (try rw [fetch_inst_from_rbmap_program $h_program]); (try simp (config := {decide := true, ground := true}) only); @@ -23,3 +25,150 @@ macro_rules -- (try simp_all (config := {decide := true, ground := true}) only); -- (try simp only [ne_eq, r_of_w_different, r_of_w_same, w_of_w_shadow, w_irrelevant]) (try simp_all (config := {decide := true, ground := true}))) + + +theorem run_onestep (s s': ArmState) (n: ℕ) (h_nonneg: 0 < n): + (s' = run n s) ↔ ∃ s'', s'' = stepi s ∧ s' = run (n-1) s'' := by + cases n + · cases h_nonneg + · rename_i n + simp [run] + +-- TODO: replace this with an upcoming new lemma in Std +theorem Std.BitVec.foldCtor : { toFin := { val := a, isLt := h } : BitVec n } = BitVec.ofNat n a := by + simp [BitVec.ofNat, Fin.ofNat', h, Nat.mod_eq_of_lt] + +-- init_next_step breaks 'h_s: s_next = run n s' into 'run (n-1) s' and one step. +macro "init_next_step" h_s:ident : tactic => + `(tactic| + (rw [run_onestep] at $h_s:ident <;> try omega + cases $h_s:ident + rename_i h_temp + cases h_temp + rename_i h_s' + simp at h_s')) + +-- Given 'h_step: s_next = stepi s', fetch_and_decode_inst unfolds stepi, +-- simplifies fetch_inst and decode_raw_inst. +macro "fetch_and_decode_inst" h_step:ident h_s_ok:ident h_pc:ident h_program:ident : tactic => + `(tactic| + (unfold stepi at $h_step:ident + rw [$h_s_ok:ident] at $h_step:ident + dsimp at $h_step:ident -- reduce let and match + rw [$h_pc:ident] at $h_step:ident + rw [fetch_inst_from_rbmap_program $h_program:ident] at $h_step:ident + -- Note: this often times out. It tries to evaluate, e.g., + -- Std.RBMap.find? sha512_program_test_2 1205560#64 + -- which easily becomes hard. 
+ simp (config := {ground := true}) at $h_step:ident + repeat (rw [Std.BitVec.foldCtor] at $h_step:ident) + )) + +-- Given h_step which is the result of fetch_and_decode_inst, exec_inst executes +-- an instruction and generates 's_next = w .. (w .. (... s))'. +macro "exec_inst" h_step:ident h_sp_aligned:ident st_next:ident: tactic => + `(tactic| + (unfold exec_inst at $h_step:ident + -- A simple case where simp works (e.g., Arm.DPI) + try (simp (config := {ground := true, decide := true}) at $h_step:ident) + -- A complicated case (e.g., Arm.LDST) + try (simp at $h_step:ident; (conv at $h_step:ident => + arg 2 + apply if_true + apply $st_next:ident); simp [$h_sp_aligned:ident] at $h_step:ident) + )) + +-- Given h_step which is 's_next = w .. (w .. (... s))', it creates assumptions +-- 'read .. s_next = value'. +-- TODO: update_invariants must add all register and memory updates as +-- assumptions. +macro "update_invariants" st_next:ident progname:ident + h_s_ok_new:ident + h_pc:ident h_pc_new:ident + h_sp_aligned:ident h_sp_aligned_new:ident + h_program_new:ident + h_step:ident pc_next:term : tactic => + `(tactic| + (have $h_s_ok_new:ident: read_err $st_next:ident = StateError.None := by + rw [$h_step:ident]; simp_all + -- Q: How can we automatically infer the next PC? + have $h_pc_new:ident: r StateField.PC $st_next:ident = $pc_next:term := by + rw [$h_step:ident,$h_pc:ident]; simp; simp (config := {ground := true}) + have $h_sp_aligned_new:ident: CheckSPAlignment $st_next:ident = true := by + unfold CheckSPAlignment at * + rw [$h_step:ident] + simp + simp at $h_sp_aligned:ident + /- + This sorry will be resolved once a Lean 4 version with an improved + `simp (config := { ground := true })` is used. 
+ See also: + https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Simplifying.20a.20bitvector.20constant/near/422077748 + The goal: + + h_s0_sp_aligned: extractLsb 3 0 (r (StateField.GPR 31#5) s0) &&& 15#4 = 0#4 + ⊢ extractLsb 3 0 (r (StateField.GPR 31#5) s0 + + signExtend 64 126#7 <<< (2 + Std.BitVec.toNat (extractLsb 1 1 2#2))) &&& + 15#4 = + 0#4 + -/ + sorry + have $h_program_new:ident : ($st_next:ident).program = + Std.RBMap.find? ($progname:ident) := by + rw [$h_step:ident] + try (repeat rw [w_program]) + try (rw [write_mem_bytes_program]) + assumption + )) + +def sym_one (curr_state_number:ℕ) (pc_begin:ℕ) (prog:Lean.Ident): + Lean.Elab.Tactic.TacticM Unit := + Lean.Elab.Tactic.withMainContext do + let n_str := toString curr_state_number + let n'_str := toString (curr_state_number+1) + let pcexpr := Lean.mkNatLit (pc_begin + 4 * (curr_state_number + 1)) + let pcbv := ← (Lean.mkApp2 (Lean.mkConst ``Std.BitVec.ofNat) (Lean.mkNatLit 64) + pcexpr).toSyntax + -- Question: how can I convert this pcbv into Syntax? 
+ let mk_name (s:String): Lean.Name := + Lean.Name.append Lean.Name.anonymous s + -- The name of the next state + let st' := Lean.mkIdent (mk_name ("s_" ++ n'_str)) + let h_st_ok := Lean.mkIdent (mk_name ("h_s" ++ n_str ++ "_ok")) + let h_st'_ok := Lean.mkIdent (mk_name ("h_s" ++ n'_str ++ "_ok")) + let h_st_pc := Lean.mkIdent (mk_name ("h_s" ++ n_str ++ "_pc")) + let h_st'_pc := Lean.mkIdent (mk_name ("h_s" ++ n'_str ++ "_pc")) + let h_st_program := Lean.mkIdent (mk_name ("h_s" ++ n_str ++ "_program")) + let h_st'_program := Lean.mkIdent (mk_name ("h_s" ++ n'_str ++ "_program")) + let h_st_sp_aligned := Lean.mkIdent (mk_name ("h_s" ++ n_str ++ "_sp_aligned")) + let h_st'_sp_aligned := Lean.mkIdent (mk_name ("h_s" ++ n'_str ++ "_sp_aligned")) + -- Temporary hypotheses + let h_run := Lean.mkIdent (mk_name "h_run") + Lean.Elab.Tactic.evalTactic (← + `(tactic| + (init_next_step $h_run:ident + rename_i $st':ident h_step $h_run:ident + -- Simulate one instruction + fetch_and_decode_inst h_step $h_st_ok:ident $h_st_pc:ident $h_st_program:ident + exec_inst h_step $h_st_sp_aligned:ident $st':ident + + -- Update invariants + update_invariants $st':ident $prog:ident + $h_st'_ok:ident + $h_st_pc:ident $h_st'_pc:ident + $h_st_sp_aligned $h_st'_sp_aligned:ident + $h_st'_program h_step $pcbv:term + clear $h_st_ok:ident $h_st_sp_aligned:ident $h_st_pc:ident h_step + $h_st_program:ident + ))) + +-- sym_n tactic symbolically simulates n instructions. +elab "sym_n" n:num pc:num prog:ident : tactic => do + for i in List.range n.getNat do + sym_one i pc.getNat prog + +-- sym_i_n tactic symbolically simulates n instructions from +-- state number i. +elab "sym_i_n" i:num n:num pc:num prog:ident : tactic => do + for j in List.range n.getNat do + sym_one (i.getNat + j) pc.getNat prog