diff --git a/Arm/BitVec.lean b/Arm/BitVec.lean index 2f59c63e..a7218c5b 100644 --- a/Arm/BitVec.lean +++ b/Arm/BitVec.lean @@ -20,6 +20,10 @@ attribute [bitvec_rules] BitVec.extractLsb_ofFin attribute [bitvec_rules] BitVec.zeroExtend_eq attribute [bitvec_rules] BitVec.ofFin.injEq attribute [bitvec_rules] BitVec.extractLsb_toNat +attribute [bitvec_rules] BitVec.truncate_or +attribute [bitvec_rules] BitVec.zeroExtend_zeroExtend_of_le +attribute [bitvec_rules] BitVec.zeroExtend_eq + -- attribute [bitvec_rules] add_ofFin ---------------------------------------------------------------------- diff --git a/Arm/Cosim.lean b/Arm/Cosim.lean index 7f2139cc..bc2d4867 100644 --- a/Arm/Cosim.lean +++ b/Arm/Cosim.lean @@ -16,7 +16,7 @@ def init_cosim_state : ArmState := pc := 0#64, pstate := zero_pstate, mem := (fun (_ : BitVec 64) => 0#8), - program := (fun (_ : BitVec 64) => none), + program := Map.empty, error := StateError.None } /-- A structure to hold both the input and output values for a @@ -73,8 +73,7 @@ def init_flags (flags : BitVec 4) (s : ArmState) : ArmState := Id.run do /-- Initialize an ArmState for cosimulation from a given regState. -/ def regState_to_armState (r : regState) : ArmState := let s := init_gprs r.gpr (init_flags r.nzcv (init_sfps r.sfp init_cosim_state)) - let s := { s with program := - (fun (a : BitVec 64) => if a == some 0#64 then r.inst else none) } + let s := { s with program := def_program [(0x0#64, r.inst)] } s def bitvec_to_hex (x : BitVec n) : String := diff --git a/Arm/Map.lean b/Arm/Map.lean index e65edfb8..b6428f98 100644 --- a/Arm/Map.lean +++ b/Arm/Map.lean @@ -10,6 +10,9 @@ A simple Map-like type based on lists def Map (α : Type u) (β : Type v) := List (α × β) +instance [x : Repr (List (α × β))] : Repr (Map α β) where + reprPrec := x.reprPrec + def Map.empty : Map α β := [] def Map.find? [DecidableEq α] (m : Map α β) (a' : α) : Option β := diff --git a/Arm/MinTheory.lean b/Arm/MinTheory.lean index 691bcc86..e2ed32e1 100644 --- a/Arm/MinTheory.lean +++ b/Arm/MinTheory.lean @@ -83,14 +83,16 @@ attribute [minimal_theory] bne_self_eq_false' attribute [minimal_theory] decide_False attribute [minimal_theory] decide_True attribute [minimal_theory] bne_iff_ne -attribute [minimal_theory] Nat.le_zero_eq +attribute [minimal_theory] Nat.le_zero_eq attribute [minimal_theory] Nat.zero_add attribute [minimal_theory] Nat.zero_eq attribute [minimal_theory] Nat.succ.injEq attribute [minimal_theory] Nat.succ_ne_zero attribute [minimal_theory] Nat.sub_zero +attribute [minimal_theory] Nat.le_refl + @[minimal_theory] theorem option_get_bang_of_some [Inhabited α] (v : α) : Option.get! (some v) = v := by rfl diff --git a/Arm/State.lean b/Arm/State.lean index c0b09d43..46518826 100644 --- a/Arm/State.lean +++ b/Arm/State.lean @@ -118,11 +118,10 @@ structure ArmState where pstate : PState -- Memory: maps 64-bit addresses to bytes mem : Store (BitVec 64) (BitVec 8) - -- Program: maps 64-bit addresses to 32-bit instructions. None is - -- returned when no instruction is present at a specified address. + -- Program: maps 64-bit addresses to 32-bit instructions. -- Note that we have the following assumption baked into our machine model: -- the program is always disjoint from the rest of the memory. - program : Store (BitVec 64) (Option (BitVec 32)) + program : Map (BitVec 64) (BitVec 32) -- The error field is an artifact of this model; it is set to a -- non-None value when some irrecoverable error is encountered @@ -199,7 +198,7 @@ def write_base_flag (flag : PFlag) (val : BitVec 1) (s : ArmState) : ArmState := -- Fetch the instruction at address addr. @[irreducible] def fetch_inst (addr : BitVec 64) (s : ArmState) : Option (BitVec 32) := - read_store addr s.program + s.program.find? addr -- Error -- @@ -259,8 +258,7 @@ def w (fld : StateField) (v : (state_value fld)) (s : ArmState) : ArmState := | ERR => write_base_error v s @[state_simp_rules] -theorem r_of_w_same (fld : StateField) (v : (state_value fld)) (s : ArmState) : - r fld (w fld v s) = v := by +theorem r_of_w_same : r fld (w fld v s) = v := by unfold r w unfold read_base_gpr write_base_gpr unfold read_base_sfp write_base_sfp @@ -270,8 +268,7 @@ theorem r_of_w_same (fld : StateField) (v : (state_value fld)) (s : ArmState) : split <;> (repeat (split <;> simp_all!)) @[state_simp_rules] -theorem r_of_w_different (fld1 fld2 : StateField) (v : (state_value fld2)) (s : ArmState) - (h : fld1 ≠ fld2) : +theorem r_of_w_different (h : fld1 ≠ fld2) : r fld1 (w fld2 v s) = r fld1 s := by unfold r w unfold read_base_gpr write_base_gpr @@ -283,8 +280,7 @@ theorem r_of_w_different (fld1 fld2 : StateField) (v : (state_value fld2)) (s : split <;> (repeat (split <;> simp_all!)) @[state_simp_rules] -theorem w_of_w_shadow (fld : StateField) (v1 v2 : (state_value fld)) (s : ArmState) : - w fld v2 (w fld v1 s) = w fld v2 s := by +theorem w_of_w_shadow : w fld v2 (w fld v1 s) = w fld v2 s := by unfold w unfold write_base_gpr unfold write_base_sfp @@ -294,8 +290,7 @@ theorem w_of_w_shadow (fld : StateField) (v1 v2 : (state_value fld)) (s : ArmSta (repeat (split <;> simp_all!)) @[state_simp_rules] -theorem w_irrelevant (fld : StateField) (v1 v2 : (state_value fld)) (s : ArmState) : - w fld (r fld s) s = s := by +theorem w_irrelevant : w fld (r fld s) s = s := by unfold r w unfold read_base_gpr write_base_gpr unfold read_base_sfp write_base_sfp @@ -305,8 +300,7 @@ theorem w_irrelevant (fld : StateField) (v1 v2 : (state_value fld)) (s : ArmStat repeat (split <;> simp_all) @[state_simp_rules] -theorem fetch_inst_of_w (addr : BitVec 64) (fld : StateField) (val : (state_value fld)) (s : ArmState) : - fetch_inst addr (w fld val s) = fetch_inst addr s := by +theorem fetch_inst_of_w : fetch_inst addr (w fld val s) = fetch_inst addr s := by unfold fetch_inst w unfold write_base_gpr unfold write_base_sfp @@ -317,10 +311,9 @@ theorem fetch_inst_of_w (addr : BitVec 64) (fld : StateField) (val : (state_valu -- There is no StateField that overwrites the program. @[state_simp_rules] -theorem w_program (sf : StateField) (v : state_value sf) (s : ArmState): - (w sf v s).program = s.program := by +theorem w_program : (w fld v s).program = s.program := by intros - cases sf <;> unfold w <;> simp + cases fld <;> unfold w <;> simp · unfold write_base_gpr; simp · unfold write_base_sfp; simp · unfold write_base_pc; simp @@ -464,10 +457,9 @@ where | (addr, _) :: p, some max => if addr > max then loop p (some addr) else loop p (some max) theorem fetch_inst_from_program - {address: BitVec 64} {program : program} - (h_program : s.program = program.find?) : - fetch_inst address s = program.find? address := by - unfold fetch_inst read_store + {address: BitVec 64} : + fetch_inst address s = s.program.find? address := by + unfold fetch_inst simp_all! end Load_program_and_fetch_inst diff --git a/Proofs/MultiInsts.lean b/Proofs/MultiInsts.lean index 94cb5a5a..779a04bb 100644 --- a/Proofs/MultiInsts.lean +++ b/Proofs/MultiInsts.lean @@ -25,31 +25,31 @@ def test_program : program := (0x126514#64 , 0x4ea21c5c#32), -- mov v28.16b, v2.16b (0x126518#64 , 0x4ea31c7d#32)] -- mov v29.16b, v3.16b -theorem one_asm_snippet_sym_helper1 (q0_var : BitVec 128) : - zeroExtend 128 (zeroExtend 128 (zeroExtend 128 q0_var ||| zeroExtend 128 q0_var)) = zeroExtend 128 q0_var := by - sorry -- auto - theorem one_asm_snippet_sym_helper2 (q0_var : BitVec 128) : q0_var ||| q0_var = q0_var := by sorry -- auto --- Todo: use sym_n to prove this theorem. -theorem small_asm_snippet_sym (s : ArmState) - (h_pc : read_pc s = 0x12650c#64) - (h_program : s.program = test_program.find?) - (h_s_ok : read_err s = StateError.None) - (h_s' : s' = run 4 s) : - read_sfp 128 26#5 s' = read_sfp 128 0#5 s ∧ - read_err s' = StateError.None := by - FIXME - -- iterate 4 (sym1 [h_program]) - sym1 [h_program] - sym1 [h_program] - sym1 [h_program] - sym1 [h_program] - -- Wrapping up the result: - -- generalize (r (StateField.SFP 0#5) s) = q0_var; unfold state_value at q0_var; simp at q0_var - -- try (simp [one_asm_snippet_sym_helper1]) +theorem small_asm_snippet_sym (s0 s_final : ArmState) + (h_s0_pc : read_pc s0 = 0x12650c#64) + (h_s0_program : s0.program = test_program) + (h_s0_ok : read_err s0 = StateError.None) + (h_run : s_final = run 4 s0) : + read_sfp 128 26#5 s_final = read_sfp 128 0#5 s0 ∧ + read_err s_final = StateError.None := by + -- Prelude + simp_all only [state_simp_rules, -h_run] + -- Symbolic Simulation + sym_n 4 h_s0_program + -- Wrapping up the result: + unfold run at h_run + subst s_final s_4 + apply And.intro + · simp_all only [state_simp_rules, minimal_theory, bitvec_rules] simp only [one_asm_snippet_sym_helper2] - done + -- FIXME: Why does state_simp_rules not work here? Why do we need + -- an explicit rw? + (try (repeat (rw [r_of_w_different (by decide)]))) + (try (rw [r_of_w_same])) + · simp_all only [state_simp_rules, minimal_theory, bitvec_rules] + done end multi_insts_proofs diff --git a/Proofs/Sha512_block_armv8.lean b/Proofs/Sha512_block_armv8.lean index 7ee0b470..6ee69fad 100644 --- a/Proofs/Sha512_block_armv8.lean +++ b/Proofs/Sha512_block_armv8.lean @@ -30,7 +30,7 @@ def sha512_program_test_1 : program := theorem sha512_program_test_1_sym (s0 s_final : ArmState) (h_s0_pc : read_pc s0 = 0x126538#64) (h_s0_sp_aligned : CheckSPAlignment s0 = true) - (h_s0_program : s0.program = sha512_program_test_1.find?) + (h_s0_program : s0.program = sha512_program_test_1) (h_s0_ok : read_err s0 = StateError.None) (h_run : s_final = run 4 s0) : read_err s_final = StateError.None := by @@ -61,7 +61,7 @@ def sha512_program_test_2 : program := -- set_option profiler true in theorem sha512_program_test_2_sym (s0 s_final : ArmState) (h_s0_pc : read_pc s0 = 0x126538#64) - (h_s0_program : s0.program = sha512_program_test_2.find?) + (h_s0_program : s0.program = sha512_program_test_2) (h_s0_ok : read_err s0 = StateError.None) (h_run : s_final = run 6 s0) : read_err s_final = StateError.None := by @@ -79,16 +79,6 @@ theorem sha512_program_test_2_sym (s0 s_final : ArmState) -- Test 3: -variable (write : (n : Nat) → BitVec 64 → BitVec (n * 8) → Type → Type) - -theorem write_simplify_test_0 (a x y : BitVec 64) - (h : ((8 * 8) + 8 * 8) = 2 * ((8 * 8) / 8) * 8) : - write (2 * ((8 * 8) / 8)) a (BitVec.cast h (zeroExtend (8 * 8) x ++ (zeroExtend (8 * 8) y))) s - = - write 16 a (x ++ y) s := by - simp only [zeroExtend_eq, BitVec.cast_eq] - - def sha512_program_test_3 : program := def_program [(0x1264c0#64 , 0xa9bf7bfd#32), -- stp x29, x30, [sp, #-16]! @@ -105,7 +95,7 @@ theorem sha512_block_armv8_test_3_sym (s0 s_final : ArmState) (h_s0_ok : read_err s0 = StateError.None) (h_s0_sp_aligned : CheckSPAlignment s0 = true) (h_s0_pc : read_pc s0 = 0x1264c0#64) - (h_s0_program : s0.program = sha512_program_test_3.find?) + (h_s0_program : s0.program = sha512_program_test_3) (h_run : s_final = run 4 s0) : read_err s_final = StateError.None := by -- Prelude @@ -157,7 +147,7 @@ theorem sha512_block_armv8_test_4_sym (s0 s_final : ArmState) (h_s0_ok : read_err s0 = StateError.None) (h_s0_sp_aligned : CheckSPAlignment s0 = true) (h_s0_pc : read_pc s0 = 0x1264c0#64) - (h_s0_program : s0.program = sha512_program_map.find?) + (h_s0_program : s0.program = sha512_program_map) (h_run : s_final = run 32 s0) : read_err s_final = StateError.None := by -- Prelude diff --git a/Proofs/Test.lean b/Proofs/Test.lean index a5c54d75..bf7281e2 100644 --- a/Proofs/Test.lean +++ b/Proofs/Test.lean @@ -16,7 +16,7 @@ def test_s (inst : BitVec 32): ArmState := pstate := zero_pstate, mem := (fun (_ : BitVec 64) => 0#8), -- Program: BitVec 64 → Option (BitVec 32) - program := (fun (a : BitVec 64) => if a == some 0#64 then inst else none), + program := [(0#64, inst)], error := StateError.None } -- 0x91000421#32: add x1, x1, 1 diff --git a/Tactics/Sym.lean b/Tactics/Sym.lean index 7c24a636..3f92b9a6 100644 --- a/Tactics/Sym.lean +++ b/Tactics/Sym.lean @@ -11,24 +11,6 @@ import Lean.Expr open BitVec --- sym1 tactic symbolically simulates a single instruction. -syntax "sym1" "[" term "]" : tactic -macro_rules - | `(tactic| sym1 [$h_program:term]) => - `(tactic| - (try simp_all (config := {decide := true, ground := true}) only [state_simp_rules]); - unfold run; - simp_all only [stepi, state_simp_rules]; - (try rw [fetch_inst_from_program $h_program]); - (try simp (config := {decide := true, ground := true}) only); - -- After exec_inst is opened up, the exec functions of the - -- instructions which are tagged with simp will also open up - -- here. - simp only [exec_inst, state_simp_rules]; - -- (try simp_all (config := {decide := true, ground := true}) only); - -- (try simp only [ne_eq, r_of_w_different, r_of_w_same, w_of_w_shadow, w_irrelevant]) - (try simp_all (config := {decide := true, ground := true}) only [state_simp_rules])) - -- init_next_step breaks 'h_s: s_next = run n s' into 'run (n-1) s' and one step. macro "init_next_step" h_s:ident : tactic => `(tactic| @@ -43,48 +25,24 @@ macro "init_next_step" h_s:ident : tactic => -- simplifies fetch_inst and decode_raw_inst. macro "fetch_and_decode_inst" h_step:ident h_program:ident : tactic => `(tactic| - (-- unfold stepi at $h_step:ident - -- rw [$h_s_ok:ident] at $h_step:ident - -- dsimp at $h_step:ident -- reduce let and match - -- rw [$h_pc:ident] at $h_step:ident - -- rw [fetch_inst_from_program $h_program:ident] at $h_step:ident - -- -- Note: this often times out. It tries to evaluate, e.g., - -- -- Std.RBMap.find? sha512_program_test_2 1205560#64 - -- -- which easily becomes hard. - -- simp (config := {ground := true}) at $h_step:ident - -- repeat (rw [BitVec.ofFin_eq_ofNat] at $h_step:ident) - simp only [*, stepi, state_simp_rules, minimal_theory, bitvec_rules] at $h_step:ident - rw [fetch_inst_from_program $h_program:ident] at $h_step:ident - conv at $h_step:ident => + (simp only [*, stepi, state_simp_rules, minimal_theory, bitvec_rules] at $h_step:ident + rw [fetch_inst_from_program] at $h_step:ident + simp only [$h_program:ident] at $h_step:ident + conv at $h_step:ident => pattern Map.find? _ _ simp (config := {ground := true}) only - -- simp/ground leaves bitvecs' structure exposed, so we use - -- BitVec.ofFin_eq_ofNat to fold them back into their canonical - -- form. - simp only [BitVec.ofFin_eq_ofNat] - (try dsimp only at $h_step:ident); - conv at $h_step:ident => + (try dsimp only at $h_step:ident); + conv at $h_step:ident => pattern decode_raw_inst _ simp (config := {ground := true}) only - simp only [BitVec.ofFin_eq_ofNat] - (try dsimp only at $h_step:ident))) + (try dsimp only at $h_step:ident))) -- Given hstep which is the result of fetch_and_decode_inst, exec_inst executes -- an instruction and generates 's_next = w .. (w .. (... s))'. macro "exec_inst" h_step:ident : tactic => `(tactic| - (-- unfold exec_inst at $h_step:ident - -- -- A simple case where simp works (e.g., Arm.DPI) - -- try (simp (config := {ground := true, decide := true}) at $h_step:ident) - -- -- A complicated case (e.g., Arm.LDST) - -- try (simp at $h_step:ident; (conv at $h_step:ident => - -- arg 2 - -- apply if_true - -- apply $st_next:ident); simp [$h_sp_aligned:ident] at $h_step:ident) - simp only [*, exec_inst, state_simp_rules, minimal_theory, bitvec_rules] at $h_step:ident; - (try simp (config := {ground := true}) only [↓reduceIte, state_simp_rules, minimal_theory, bitvec_rules] at $h_step:ident) - -- Fold back any exposed bitvecs into canonical forms. - (try simp only [BitVec.ofFin_eq_ofNat] at $h_step:ident))) + (simp only [*, exec_inst, state_simp_rules, minimal_theory, bitvec_rules] at $h_step:ident; + (try (repeat simp (config := {ground := true}) only [↓reduceIte, state_simp_rules, minimal_theory, bitvec_rules] at $h_step:ident)))) -- Given h_step which is 's_next = w .. (w .. (... s))', it creates assumptions -- 'read .. s_next = value'. diff --git a/Tests/LDSTTest.lean b/Tests/LDSTTest.lean index bd743fdd..70e2faf7 100644 --- a/Tests/LDSTTest.lean +++ b/Tests/LDSTTest.lean @@ -14,13 +14,13 @@ open BitVec -- The cosimulation tests do not cover instructions related to memory access -- TODO: use macros to simplify the tests -def set_init_state (find? : Store (BitVec 64) (Option (BitVec 32))) : ArmState := +def set_init_state (program : Map (BitVec 64) (BitVec 32)) : ArmState := let s := { gpr := (fun (_ : BitVec 5) => 0#64), sfp := (fun (_ : BitVec 5) => 0#128), pc := 0#64, pstate := zero_pstate, mem := (fun (_ : BitVec 64) => 0#8), - program := find?, + program := program, error := StateError.None} s @@ -30,7 +30,7 @@ def ldr_gpr_unsigned_offset : program := def_program [ (0x0#64, 0xb9400fe0#32) ] -- ldr w0, [sp, #12] def ldr_gpr_unsigned_offset_state : ArmState := - let s := set_init_state ldr_gpr_unsigned_offset.find? + let s := set_init_state ldr_gpr_unsigned_offset -- write 20 in 4 bytes to address 12 let s := write_mem_bytes 4 12#64 20#32 s s @@ -46,7 +46,7 @@ def str_gpr_post_index : program := def_program [ (0x0#64, 0xf8003420#32) ] -- str x0, [x1], #3 def str_gpr_post_index_state : ArmState := - let s := set_init_state str_gpr_post_index.find? + let s := set_init_state str_gpr_post_index -- write 20 in gpr x0 let s := write_gpr 64 0#5 20#64 s -- write 0 in gpr x1 @@ -64,7 +64,7 @@ def ldr_sfp_post_index : program := def_program [ (0x0#64, 0xfc408420#32) ] -- ldr d0, [x1], #8 def ldr_sfp_post_index_state : ArmState := - let s := set_init_state ldr_sfp_post_index.find? + let s := set_init_state ldr_sfp_post_index let s := write_mem_bytes 8 0#64 20#64 s s @@ -79,7 +79,7 @@ def str_stp_unsigned_offset : program := def_program [ (0x0#64, 0x3d800420#32) ] -- str q0, [x1, #1] def str_sfp_unsigned_offset_state : ArmState := - let s := set_init_state str_stp_unsigned_offset.find? + let s := set_init_state str_stp_unsigned_offset write_sfp 128 0#5 123#128 s def str_sfp_unsigned_offset_final_state : ArmState := run 1 str_sfp_unsigned_offset_state @@ -92,7 +92,7 @@ def ldrb_unsigned_offset : program := def_program [ (0x0#64, 0x39401020#32) ] -- ldrb x0, [x1, #4] def ldrb_unsigned_offset_state: ArmState := - let s := set_init_state ldrb_unsigned_offset.find? + let s := set_init_state ldrb_unsigned_offset write_mem_bytes 1 4#64 20#8 s def ldrb_unsigned_offset_final_state : ArmState := run 1 ldrb_unsigned_offset_state @@ -105,7 +105,7 @@ def strb_post_index : program := def_program [ (0x0#64, 0x381fc420#32) ] -- strb x0, [x1], #-4 def strb_post_index_state : ArmState := - let s := set_init_state strb_post_index.find? + let s := set_init_state strb_post_index let s := write_gpr 64 1#5 5#64 s write_gpr 64 0#5 20#64 s @@ -120,7 +120,7 @@ def ldp_gpr_pre_index : program := def_program [ (0x0#64, 0xa9c00820#32) ] -- ldp x0, x2, [x1]! def ldp_gpr_pre_index_state : ArmState := - let s := set_init_state ldp_gpr_pre_index.find? + let s := set_init_state ldp_gpr_pre_index write_mem_bytes 16 0#64 0x1234000000000000ABCD#128 s def ldp_gpr_pre_index_final_state : ArmState := run 1 ldp_gpr_pre_index_state @@ -134,7 +134,7 @@ def stp_sfp_signed_offset : program := def_program [ (0x0#64, 0xad008820#32) ] -- stp q0, q2, [q1,#1] def stp_sfp_signed_offset_state : ArmState := - let s := set_init_state stp_sfp_signed_offset.find? + let s := set_init_state stp_sfp_signed_offset let s := write_sfp 128 0#5 0x1234#128 s write_sfp 128 2#5 0xabcd#128 s diff --git a/Tests/SHA512ProgramTest.lean b/Tests/SHA512ProgramTest.lean index d03637d6..32c28e64 100644 --- a/Tests/SHA512ProgramTest.lean +++ b/Tests/SHA512ProgramTest.lean @@ -587,7 +587,7 @@ def init_sha512_test : ArmState := pc := 0x1264c0#64, pstate := zero_pstate, mem := (fun (_ : BitVec 64) => 0#8), - program := sha512_program_map.find?, + program := sha512_program_map, error := StateError.None } have h_input : 1024 = 1024 / 8 * 8 := by decide let s := write_mem_bytes (1024 / 8) input_address (h_input ▸ asm_input) s