From a47a2665edcb0f6191e87a261aa8a3b76f5a573a Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Wed, 16 Oct 2024 17:50:33 -0500 Subject: [PATCH] refactor: change memory-effects theorem to a quantifier-free statement (#224) ### Description: Changes the memory effect proof to be of type `.mem = .mem`, instead of the quantified statement that the result of reading from memory at any bytes and any address of either state agrees. * To keep aggregation working with the new statement, we had to add `memory_rules` to the simpsets that are used by sym_n. * This meant we had to enhance `memory_rules` to do, e.g., read-over-write reasoning, and * We had to change the `s[base, n]` notation to desugar into `s.mem.read_bytes ..` ### Testing: What tests have been run? Did `make all` succeed for your changes? Was conformance testing successful on an Aarch64 machine? Yes ### License: By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. --------- Co-authored-by: Siddharth Co-authored-by: Shilpi Goel --- Arm/Memory/MemoryProofs.lean | 155 +++++++++++++---------- Arm/State.lean | 20 ++- Arm/Syntax.lean | 4 +- Proofs/AES-GCM/GCMGmultV8Sym.lean | 26 +--- Proofs/Experiments/Max/MaxTandem.lean | 60 +++++---- Proofs/Experiments/Memcpy/MemCpyVCG.lean | 8 +- Proofs/Popcount32.lean | 1 + Proofs/SHA512/SHA512Prelude.lean | 1 - Tactics/Aggregate.lean | 3 + Tactics/Common.lean | 4 + Tactics/Sym/Context.lean | 4 +- Tactics/Sym/MemoryEffects.lean | 22 ++-- 12 files changed, 167 insertions(+), 141 deletions(-) diff --git a/Arm/Memory/MemoryProofs.lean b/Arm/Memory/MemoryProofs.lean index 29d54254..6a1a9ee6 100644 --- a/Arm/Memory/MemoryProofs.lean +++ b/Arm/Memory/MemoryProofs.lean @@ -16,6 +16,27 @@ section MemoryProofs open BitVec +/-! ## One byte read/write lemmas-/ +namespace Memory + +theorem read_write_same : + read addr (write addr v mem) = v := by + simp [read, write, store_read_over_write_same] + +theorem read_write_different (h : addr1 ≠ addr2) : + read addr1 (write addr2 v s) = read addr1 s := by + simp [read, write, store_read_over_write_different (h := h)] + +theorem write_write_shadow : + write addr val2 (write addr val1 s) = write addr val2 s := by + unfold write write_store; simp_all + +theorem write_irrelevant : + write addr (read addr s) s = s := by + simp [read, write, store_write_irrelevant] + +end Memory + ---------------------------------------------------------------------- -- Key theorem: read_mem_bytes_of_write_mem_bytes_same @@ -34,32 +55,39 @@ theorem mem_separate_preserved_second_start_addr_add_one apply BitVec.val_nat_le 1 m 64 h0 (_ : 1 < 2^64) h1 decide -theorem read_mem_of_write_mem_bytes_different (hn1 : n <= 2^64) - (h : mem_separate addr1 addr1 addr2 (addr2 + (BitVec.ofNat 64 (n - 1)))) : - read_mem addr1 (write_mem_bytes n addr2 v s) = read_mem addr1 s := by - by_cases hn0 : n = 0 - case pos => -- n = 0 - subst n; simp only [write_mem_bytes] - case neg => -- n ≠ 0 - have hn0' : 0 < n := by omega - induction n, hn0' using Nat.le_induction generalizing addr2 s - case base => - have h' : addr1 ≠ addr2 := by apply mem_separate_starting_addresses_neq h - simp only [write_mem_bytes] - apply read_mem_of_write_mem_different h' - case succ => - have h' : addr1 ≠ addr2 := by refine mem_separate_starting_addresses_neq h - rename_i m hn n_ih - simp_all only [Nat.succ_sub_succ_eq_sub, Nat.sub_zero, - Nat.succ_ne_zero, not_false_eq_true, ne_eq, - write_mem_bytes, Nat.add_eq, Nat.add_zero] - rw [n_ih] - · rw [read_mem_of_write_mem_different h'] - · omega - · rw [addr_add_one_add_m_sub_one m addr2 hn hn1] - rw [mem_separate_preserved_second_start_addr_add_one hn hn1 h] - · omega - done +theorem Memory.read_write_bytes_different (hn1 : n ≤ 2^64) + (h : mem_separate addr1 addr1 addr2 (addr2 + (BitVec.ofNat 64 (n - 1)))) : + read addr1 (write_bytes n addr2 v mem) = read addr1 mem := by + induction n generalizing mem addr1 addr2 + case zero => simp only [write_bytes] + case succ n ih => + have h_neq : addr1 ≠ addr2 := + mem_separate_starting_addresses_neq h + rw [Nat.add_one_sub_one] at h + cases n + case zero => + simp [write_bytes, read_write_different h_neq] + case succ n => + have h_sep : mem_separate addr1 addr1 (addr2 + 1#64) + (addr2 + 1#64 + BitVec.ofNat 64 n) := by + unfold mem_separate mem_overlap at h ⊢ + simp only [BitVec.sub_self, ofNat_add, Bool.or_self_right, Bool.not_or, + Bool.and_eq_true, Bool.not_eq_eq_eq_not, Bool.not_true, + decide_eq_false_iff_not, BitVec.not_le] at h ⊢ + generalize hn' : BitVec.ofNat 64 n = n' at * + have : n' ≠ -1 := by bv_omega + clear hn1 ih + bv_decide + have h_neq : addr1 ≠ addr2 := + mem_separate_starting_addresses_neq h + rw [write_bytes, ih (by omega) h_sep, Memory.read_write_different h_neq] + +theorem read_mem_of_write_mem_bytes_different (hn1 : n ≤ 2^64) + (h : mem_separate addr1 addr1 addr2 (addr2 + (BitVec.ofNat 64 (n - 1)))) : + read_mem addr1 (write_mem_bytes n addr2 v s) = read_mem addr1 s := by + simp only [ArmState.read_mem_eq_mem_read, + Memory.write_mem_bytes_eq_mem_write_bytes] + exact Memory.read_write_bytes_different hn1 h theorem append_byte_of_extract_rest_same_cast (n : Nat) (v : BitVec ((n + 1) * 8)) (hn0 : Nat.succ 0 ≤ n) @@ -69,47 +97,42 @@ theorem append_byte_of_extract_rest_same_cast (n : Nat) (v : BitVec ((n + 1) * 8 · omega done +example (s : ArmState) : + read_mem_bytes n addr s = s.mem.read_bytes n addr := by + exact Memory.State.read_mem_bytes_eq_mem_read_bytes s + @[state_simp_rules] -theorem read_mem_bytes_of_write_mem_bytes_same (hn1 : n <= 2^64) : - read_mem_bytes n addr (write_mem_bytes n addr v s) = v := by - by_cases hn0 : n = 0 - case pos => - subst n - unfold read_mem_bytes - simp only [of_length_zero] - case neg => -- n ≠ 0 - have hn0' : 0 < n := by omega - induction n, hn0' using Nat.le_induction generalizing addr s - case base => - simp only [read_mem_bytes, write_mem_bytes, - read_mem_of_write_mem_same, BitVec.cast_eq] - have l1 := BitVec.extractLsb'_eq v - simp only [Nat.reduceSucc, Nat.one_mul, Nat.succ_sub_succ_eq_sub, - Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq, - forall_const] at l1 - rw [l1] - have l2 := BitVec.empty_bitvector_append_left v - simp only [Nat.reduceSucc, Nat.one_mul, Nat.zero_add, - BitVec.cast_eq, forall_const] at l2 - exact l2 - case succ => - rename_i n hn n_ih - simp only [read_mem_bytes, Nat.add_eq, Nat.add_zero, write_mem_bytes] - rw [n_ih] - rw [read_mem_of_write_mem_bytes_different] - · simp only [Nat.add_eq, Nat.add_zero, read_mem_of_write_mem_same] - rw [append_byte_of_extract_rest_same_cast n v hn] - · omega - · have := mem_separate_contiguous_regions addr 0#64 (BitVec.ofNat 64 (n - 1)) - simp only [Nat.reducePow, Nat.succ_sub_succ_eq_sub, Nat.sub_zero, - BitVec.sub_zero, ofNat_lt_ofNat, Nat.reduceMod, - BitVec.add_zero] at this - apply this - simp only [Nat.reducePow] at hn1 - omega - · omega - · omega - done +theorem Memory.read_bytes_write_bytes_same (hn1 : n ≤ 2^64) : + read_bytes n addr (write_bytes n addr v mem) = v := by + induction n generalizing addr mem + case zero => + simp [read_bytes, of_length_zero] + case succ n ih => + simp only [read_bytes, write_bytes] + rw [ih (by omega)] + have h_sep : + let m := BitVec.ofNat 64 (n - 1) + mem_separate addr addr (addr + 1#64) (addr + 1#64 + m) := by + rw [← mem_separate_contiguous_regions addr 0#64 _] + · simp; rfl + · bv_omega + rw [read_write_bytes_different (by omega) h_sep, read_write_same] + apply BitVec.eq_of_getLsbD_eq + intro i + simp only [getLsbD_cast, getLsbD_append] + by_cases hi : i.val < 8 + · simp [hi] + · have h₁ : i.val - 8 < n * 8 := by omega + have h₂ : 8 + (i.val - 8) = i.val := by omega + simp [hi, h₁, h₂] + +@[state_simp_rules, memory_rules] +theorem read_mem_bytes_of_write_mem_bytes_same (hn1 : n ≤ 2^64) : + read_mem_bytes n addr (write_mem_bytes n addr v s) = v := by + open Memory in + rw [State.read_mem_bytes_eq_mem_read_bytes, + write_mem_bytes_eq_mem_write_bytes, + Memory.read_bytes_write_bytes_same hn1] ---------------------------------------------------------------------- -- Key theorem: read_mem_bytes_of_write_mem_bytes_different diff --git a/Arm/State.lean b/Arm/State.lean index 9d1c3e65..4f264183 100644 --- a/Arm/State.lean +++ b/Arm/State.lean @@ -747,6 +747,12 @@ theorem read_mem_bytes_w_of_read_mem_eq = read_mem_bytes n₁ addr₁ s₂ := by simp only [read_mem_bytes_of_w, h] +@[state_simp_rules] +theorem mem_w_of_mem_eq {s₁ s₂ : ArmState} (h : s₁.mem = s₂.mem) (fld val) : + (w fld val s₁).mem = s₂.mem := by + unfold w; + cases fld <;> exact h + @[state_simp_rules] theorem write_mem_bytes_program {n : Nat} (addr : BitVec 64) (bytes : BitVec (n * 8)): (write_mem_bytes n addr bytes s).program = s.program := by @@ -838,6 +844,9 @@ def read_bytes (n : Nat) (addr : BitVec 64) (m : Memory) : BitVec (n * 8) := have h : n' * 8 + 8 = (n' + 1) * 8 := by simp_arith BitVec.cast h (rest ++ byte) +-- TODO (@bollu): we should drop the `State` namespace here, given that +-- this namespace is used nowhere else. Also, `ArmState.read_mem_eq_mem_read` +-- should probably live under the `Memory` namespace. @[memory_rules] theorem State.read_mem_bytes_eq_mem_read_bytes (s : ArmState) : read_mem_bytes n addr s = s.mem.read_bytes n addr := by @@ -1163,13 +1172,10 @@ theorem Memory.mem_eq_iff_read_mem_bytes_eq {s₁ s₂ : ArmState} : · intro h _ _; rw[h] · exact Memory.eq_of_read_mem_bytes_eq -theorem read_mem_bytes_write_mem_bytes_of_read_mem_eq - (h : ∀ n addr, read_mem_bytes n addr s₁ = read_mem_bytes n addr s₂) - (n₂ addr₂ val n₁ addr₁) : - read_mem_bytes n₁ addr₁ (write_mem_bytes n₂ addr₂ val s₁) - = read_mem_bytes n₁ addr₁ (write_mem_bytes n₂ addr₂ val s₂) := by - revert n₁ addr₁ - simp only [← Memory.mem_eq_iff_read_mem_bytes_eq] at h ⊢ +theorem mem_write_mem_bytes_of_mem_eq + (h : s₁.mem = s₂.mem) (n addr val) : + (write_mem_bytes n addr val s₁).mem + = (write_mem_bytes n addr val s₂).mem := by simp only [memory_rules, h] /- Helper lemma for `state_eq_iff_components_eq` -/ diff --git a/Arm/Syntax.lean b/Arm/Syntax.lean index 0c1ca562..99658f43 100644 --- a/Arm/Syntax.lean +++ b/Arm/Syntax.lean @@ -10,10 +10,10 @@ import Arm.Memory.Separate namespace ArmStateNotation -/-! We build a notation for `read_mem_bytes $n $base $s` as `$s[$base, $n]` -/ +/-! We build a notation for `$s.mem.read_bytes $n $base $s` as `$s[$base, $n]` -/ @[inherit_doc read_mem_bytes] syntax:max term noWs "[" withoutPosition(term) "," withoutPosition(term) noWs "]" : term -macro_rules | `($s[$base,$n]) => `(read_mem_bytes $n $base $s) +macro_rules | `($s[$base,$n]) => `(Memory.read_bytes $n $base (ArmState.mem $s)) /-! Notation to specify the frame condition for non-memory state components. E.g., diff --git a/Proofs/AES-GCM/GCMGmultV8Sym.lean b/Proofs/AES-GCM/GCMGmultV8Sym.lean index 937f80ea..d1c453c2 100644 --- a/Proofs/AES-GCM/GCMGmultV8Sym.lean +++ b/Proofs/AES-GCM/GCMGmultV8Sym.lean @@ -89,6 +89,7 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState) simp (config := {ground := true}) only at h_s0_pc -- ^^ Still needed, because `gcm_gmult_v8_program.min` is somehow -- unable to be reflected + sym_n 27 -- Epilogue simp only [←Memory.mem_eq_iff_read_mem_bytes_eq] at * @@ -96,30 +97,7 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState) sym_aggregate -- Split conjunction repeat' apply And.intro - · -- Aggregate the memory (non)effects. - -- (FIXME) This will be tackled by `sym_aggregate` when `sym_n` and `simp_mem` - -- are merged. - simp only [*] - /- - (FIXME @bollu) `simp_mem; rfl` creates a malformed proof here. The tactic produces - no goals, but we get the following error message: - - application type mismatch - Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset' - (Eq.mp (congrArg (Eq HTable) (Memory.State.read_mem_bytes_eq_mem_read_bytes s0)) - (Eq.mp (congrArg (fun x => HTable = read_mem_bytes 256 x s0) zeroExtend_eq_of_r_gpr) h_HTable)) - argument has type - HTable = Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem - but function has type - Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem = HTable → - mem_subset' (r (StateField.GPR 1#5) s0) 256 (r (StateField.GPR 1#5) s0) 256 → - Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem = - HTable.extractLsBytes (BitVec.toNat (r (StateField.GPR 1#5) s0) - BitVec.toNat (r (StateField.GPR 1#5) s0)) 256 - - simp_mem; rfl - -/ - rw [Memory.read_bytes_write_bytes_eq_read_bytes_of_mem_separate'] - simp_mem + · simp_mem; rfl · simp only [List.mem_cons, List.mem_singleton, not_or, and_imp] sym_aggregate · intro n addr h_separate diff --git a/Proofs/Experiments/Max/MaxTandem.lean b/Proofs/Experiments/Max/MaxTandem.lean index e75925a0..b3083461 100644 --- a/Proofs/Experiments/Max/MaxTandem.lean +++ b/Proofs/Experiments/Max/MaxTandem.lean @@ -194,14 +194,17 @@ theorem program.stepi_0x898_cut (s sn : ArmState) simp only [minimal_theory] at this simp_all only [run, cut, this, state_simp_rules, bitvec_rules, minimal_theory] - simp only [pcs, List.mem_cons, BitVec.reduceEq, List.mem_singleton, or_self, not_false_eq_true, - true_and, List.not_mem_nil, or_self, not_false_eq_true, true_and] - simp only [memory_rules, state_simp_rules] + simp only [pcs, List.mem_cons, BitVec.reduceEq, List.mem_singleton, or_self, + not_false_eq_true, true_and, List.not_mem_nil, or_self, not_false_eq_true, + true_and] + simp only [Memory.write_mem_bytes_eq_mem_write_bytes] simp_mem - rfl + simp only [Nat.reduceMul, BitVec.toNat_add, BitVec.toNat_ofNat, Nat.reducePow, + Nat.reduceMod, Nat.sub_self, BitVec.extractLsBytes_eq_self, BitVec.cast_eq, + and_true] /-- -info: 'MaxTandem.program.stepi_0x898_cut' depends on axioms: [propext, Classical.choice, Lean.ofReduceBool, Quot.sound] +info: 'MaxTandem.program.stepi_0x898_cut' depends on axioms: [propext, Classical.choice, Quot.sound] -/ #guard_msgs in #print axioms program.stepi_0x898_cut @@ -260,6 +263,7 @@ theorem program.stepi_0x8a0_cut (s sn : ArmState) simp_all only [run, cut, this, state_simp_rules, bitvec_rules, minimal_theory] simp only [pcs, List.mem_cons, BitVec.reduceEq, List.mem_singleton, or_self, not_false_eq_true, true_and, List.not_mem_nil, or_self, not_false_eq_true, true_and] + simp only [Memory.State.read_mem_bytes_eq_mem_read_bytes] done /-- @@ -289,6 +293,7 @@ theorem program.stepi_0x8a4_cut (s sn : ArmState) simp_all only [run, cut, this, state_simp_rules, bitvec_rules, minimal_theory] simp only [pcs, List.mem_cons, BitVec.reduceEq, List.mem_singleton, or_self, not_false_eq_true, true_and, List.not_mem_nil, or_self, not_false_eq_true, true_and] + simp only [Memory.State.read_mem_bytes_eq_mem_read_bytes] /-- info: 'MaxTandem.program.stepi_0x8a4_cut' depends on axioms: [propext, Classical.choice, Quot.sound] @@ -403,6 +408,7 @@ theorem program.stepi_0x8b0_cut (s sn : ArmState) simp_all only [run, cut, this, state_simp_rules, bitvec_rules, minimal_theory] simp only [pcs, List.mem_cons, BitVec.reduceEq, List.mem_singleton, or_self, not_false_eq_true] simp only [List.not_mem_nil, or_self, not_false_eq_true] + simp only [Memory.State.read_mem_bytes_eq_mem_read_bytes, and_true] /-- info: 'MaxTandem.program.stepi_0x8b0_cut' depends on axioms: [propext, Classical.choice, Quot.sound] @@ -430,6 +436,8 @@ theorem program.stepi_0x8b4_cut (s sn : ArmState) simp only [pcs, List.mem_cons, BitVec.reduceEq, List.mem_singleton, or_self, not_false_eq_true] simp only [List.not_mem_nil, or_self, or_false, or_true] simp only [not_false_eq_true] + simp only [Memory.write_mem_bytes_eq_mem_write_bytes, true_and] + rw [Memory.read_bytes_write_bytes_same (by omega)] /-- info: 'MaxTandem.program.stepi_0x8b4_cut' depends on axioms: [propext, Classical.choice, Lean.ofReduceBool, Quot.sound] @@ -483,6 +491,7 @@ theorem program.stepi_0x8bc_cut (s sn : ArmState) simp only [pcs, List.mem_cons, BitVec.reduceEq, List.mem_singleton, or_self, or_false, or_true] simp only [List.not_mem_nil, or_self, or_false, or_true] simp only [not_false_eq_true] + simp only [Memory.State.read_mem_bytes_eq_mem_read_bytes, and_self] /-- info: 'MaxTandem.program.stepi_0x8bc_cut' depends on axioms: [propext, Classical.choice, Quot.sound] @@ -506,6 +515,8 @@ theorem program.stepi_0x8c0_cut (s sn : ArmState) have := program.stepi_eq_0x8c0 h_program h_pc h_err simp only [minimal_theory] at this simp_all only [run, cut, this, state_simp_rules, bitvec_rules, minimal_theory] + simp only [Memory.write_mem_bytes_eq_mem_write_bytes] + rw [Memory.read_bytes_write_bytes_same (by omega)] simp [pcs] /-- @@ -532,6 +543,7 @@ theorem program.stepi_0x8c4_cut (s sn : ArmState) have := program.stepi_eq_0x8c4 h_program h_pc h_err simp only [minimal_theory] at this simp_all only [run, cut, this, state_simp_rules, bitvec_rules, minimal_theory] + simp only [Memory.State.read_mem_bytes_eq_mem_read_bytes] simp [pcs] /-- @@ -632,7 +644,8 @@ theorem partial_correctness : replace h_s2_sp : s2.sp = (s0.sp - 32#64) := by simp_all replace h_s2_x0 : s2.x0 = s0.x0 := by simp_all replace h_s2_x1 : s2.x1 = s0.x1 := by simp_all - replace h_s2_read_sp12 : read_mem_bytes 4 (s2.sp + 12#64) s2 = BitVec.truncate 32 s0.x0 := by simp_all + replace h_s2_read_sp12 : s2.mem.read_bytes 4 (s2.sp + 12#64) = BitVec.truncate 32 s0.x0 := by + simp_all clear_named [h_s1] -- 3/15 @@ -645,8 +658,8 @@ theorem partial_correctness : replace _h_s3_x1 : s3.x1 = s0.x1 := by simp_all replace h_s3_sp : s3.sp = s0.sp - 32 := by simp_all /- TODO: this should be s0.x0-/ - replace h_s3_read_sp12 : read_mem_bytes 4 (s3.sp + 12#64) s3 = BitVec.truncate 32 s0.x0 := by simp_all - replace _h_s3_read_sp8 : read_mem_bytes 4 (s3.sp + 8#64) s3 = BitVec.truncate 32 s0.x1 := by simp_all + replace h_s3_read_sp12 : s3.mem.read_bytes 4 (s3.sp + 12#64) = BitVec.truncate 32 s0.x0 := by simp_all + replace _h_s3_read_sp8 : s3.mem.read_bytes 4 (s3.sp + 8#64) = BitVec.truncate 32 s0.x1 := by simp_all clear_named [h_s2] -- 4/15 @@ -659,8 +672,8 @@ theorem partial_correctness : replace _h_s4_x0 : s4.x0 = s0.x0 := by simp_all replace h_s4_x1 : s4.x1 = BitVec.zeroExtend 64 (BitVec.truncate 32 s0.x0) := by simp_all replace h_s4_sp : s4.sp = s0.sp - 32 := by simp_all - replace h_s4_read_sp12 : read_mem_bytes 4 (s4.sp + 12#64) s4 = BitVec.truncate 32 s0.x0 := by simp_all - replace _h_s4_read_sp8 : read_mem_bytes 4 (s4.sp + 8#64) s4 = BitVec.truncate 32 s0.x1 := by simp_all + replace h_s4_read_sp12 : s4.mem.read_bytes 4 (s4.sp + 12#64) = BitVec.truncate 32 s0.x0 := by simp_all + replace _h_s4_read_sp8 : s4.mem.read_bytes 4 (s4.sp + 8#64) = BitVec.truncate 32 s0.x1 := by simp_all clear_named [h_s3] -- 5/15 @@ -673,8 +686,8 @@ theorem partial_correctness : replace h_s5_x0 : s5.x0 = BitVec.zeroExtend 64 (BitVec.truncate 32 s0.x1) := by simp_all replace h_s5_x1 : s5.x1 = BitVec.zeroExtend 64 (BitVec.truncate 32 s0.x0) := by simp_all replace h_s5_sp : s5.sp = s0.sp - 32 := by simp_all - replace h_s5_read_sp12 : read_mem_bytes 4 (s5.sp + 12#64) s5 = BitVec.truncate 32 s0.x0 := by simp_all - replace _h_s5_read_sp8 : read_mem_bytes 4 (s5.sp + 8#64) s5 = BitVec.truncate 32 s0.x1 := by simp_all + replace h_s5_read_sp12 : s5.mem.read_bytes 4 (s5.sp + 12#64) = BitVec.truncate 32 s0.x0 := by simp_all + replace _h_s5_read_sp8 : s5.mem.read_bytes 4 (s5.sp + 8#64) = BitVec.truncate 32 s0.x1 := by simp_all clear_named [h_s4] -- 6/15 @@ -687,8 +700,8 @@ theorem partial_correctness : replace h_s6_x0 : s6.x0 = BitVec.zeroExtend 64 (BitVec.truncate 32 s0.x1) := by simp_all replace h_s6_x1 : s6.x1 = BitVec.zeroExtend 64 (BitVec.truncate 32 s0.x0) := by simp_all replace h_s6_sp : s6.sp = s0.sp - 32 := by simp_all - replace h_s6_read_sp12 : read_mem_bytes 4 (s6.sp + 12#64) s6 = BitVec.truncate 32 s0.x0 := by simp_all - replace _h_s6_read_sp8 : read_mem_bytes 4 (s6.sp + 8#64) s6 = BitVec.truncate 32 s0.x1 := by simp_all + replace h_s6_read_sp12 : s6.mem.read_bytes 4 (s6.sp + 12#64) = BitVec.truncate 32 s0.x0 := by simp_all + replace _h_s6_read_sp8 : s6.mem.read_bytes 4 (s6.sp + 8#64) = BitVec.truncate 32 s0.x1 := by simp_all replace h_s6_c : s6.C = (AddWithCarry (s0.x0.zeroExtend 32) (~~~s0.x1.zeroExtend 32) 1#1).snd.c := by simp_all replace h_s6_n : s6.N = (AddWithCarry (s0.x0.zeroExtend 32) (~~~s0.x1.zeroExtend 32) 1#1).snd.n := by simp_all replace h_s6_v : s6.V = (AddWithCarry (s0.x0.zeroExtend 32) (~~~s0.x1.zeroExtend 32) 1#1).snd.v := by simp_all @@ -705,8 +718,8 @@ theorem partial_correctness : replace h_s7_x0 : s7.x0 = BitVec.zeroExtend 64 (BitVec.truncate 32 s0.x1) := by simp_all replace h_s7_x1 : s7.x1 = BitVec.zeroExtend 64 (BitVec.truncate 32 s0.x0) := by simp_all replace h_s7_sp : s7.sp = s0.sp - 32 := by simp_all - replace h_s7_read_sp12 : read_mem_bytes 4 ((s0.sp - 32#64) + 12#64) s7 = BitVec.truncate 32 s0.x0 := by simp_all - replace h_s7_read_sp8 : read_mem_bytes 4 ((s0.sp - 32#64) + 8#64) s7 = BitVec.truncate 32 s0.x1 := by simp_all + replace h_s7_read_sp12 : s7.mem.read_bytes 4 ((s0.sp - 32#64) + 12#64) = BitVec.truncate 32 s0.x0 := by simp_all + replace h_s7_read_sp8 : s7.mem.read_bytes 4 ((s0.sp - 32#64) + 8#64) = BitVec.truncate 32 s0.x1 := by simp_all have h_s7_s6_c := h_s6_c have h_s7_s6_n := h_s6_n have h_s7_s6_v := h_s6_v @@ -760,7 +773,7 @@ theorem partial_correctness : obtain ⟨h_s3_cut, h_s3_pc, h_s3_err, h_s3_program, h_s3_x0, h_s3_sp_28, h_s3_sp, h_s3_sp_aliged⟩ := h rw [Correctness.snd_cassert_of_not_cut h_s3_cut]; -- try rw [Correctness.snd_cassert_of_cut h_cut]; simp [show Sys.next _ = run 1 _ by rfl] - replace h_s3_sp_28 : read_mem_bytes 4 (s3.sp + 28#64) s3 = BitVec.zeroExtend 32 (spec s0.x0 s0.x1) := by simp_all + replace h_s3_sp_28 : s3.mem.read_bytes 4 (s3.sp + 28#64) = BitVec.zeroExtend 32 (spec s0.x0 s0.x1) := by simp_all replace h_s3_sp : s3.sp = s0.sp - 32#64 := by simp_all clear_named [h_s2, h_s1] @@ -770,7 +783,7 @@ theorem partial_correctness : rw [Correctness.snd_cassert_of_not_cut (si := s4) (by simp_all [Spec'.cut])]; simp [show Sys.next _ = run 1 _ by rfl] have h_s4_sp : s4.sp = s0.sp - 32#64 := by simp_all - have h_s4_sp_28 : read_mem_bytes 4 (s4.sp + 28#64) s4 = BitVec.zeroExtend 32 (spec s0.x0 s0.x1) := by simp_all + have h_s4_sp_28 : s4.mem.read_bytes 4 (s4.sp + 28#64) = BitVec.zeroExtend 32 (spec s0.x0 s0.x1) := by simp_all clear_named [h_s3] -- 5/15 @@ -778,7 +791,7 @@ theorem partial_correctness : obtain h_s5 := program.stepi_0x8c4_cut s4 s5 (by simp_all) (by simp_all) (by simp_all) (by simp_all) (h_run.symm) rw [Correctness.snd_cassert_of_not_cut (si := s5) (by simp_all [Spec'.cut])]; have h_s5_x0 : s5.x0 = BitVec.zeroExtend 64 (BitVec.zeroExtend 32 (spec s0.x0 s0.x1)) := by - simp only [show s5.x0 = BitVec.zeroExtend 64 (read_mem_bytes 4 (s5.sp + 28#64) s5) by simp_all] + simp only [show s5.x0 = BitVec.zeroExtend 64 (s5.mem.read_bytes 4 (s5.sp + 28#64)) by simp_all] simp only [Nat.reduceMul] /- Damn, that the rewrite system is not confluent really messes me up over here ;_; `simp` winds up rewriting `s5.sp` into `s4.sp` first because of the rule, and @@ -786,10 +799,11 @@ theorem partial_correctness : One might say that this entire proof is stupid, but really, I 'just' want it to build an e-graph and figure it out. -/ - have : (read_mem_bytes 4 (s5.sp + 28#64) s5) = read_mem_bytes 4 (s4.sp + 28#64) s4 := by + have : (s5.mem.read_bytes 4 (s5.sp + 28#64)) = read_mem_bytes 4 (s4.sp + 28#64) s4 := by obtain ⟨_, _, _, _, _, _, h, _⟩ := h_s5 exact h - simp [this] + simp only [this, Memory.State.read_mem_bytes_eq_mem_read_bytes, + BitVec.truncate_eq_setWidth] rw [h_s4_sp_28] simp [show Sys.next _ = run 1 _ by rfl] @@ -822,7 +836,7 @@ theorem partial_correctness : obtain ⟨h_s3_cut, h_s3_pc, h_s3_err, h_s3_program, h_s3_x0, h_s3_sp, h_s3_sp_aliged⟩ := h rw [Correctness.snd_cassert_of_not_cut h_s3_cut]; -- try rw [Correctness.snd_cassert_of_cut h_cut]; simp [show Sys.next _ = run 1 _ by rfl] - replace h_s3_sp_28 : read_mem_bytes 4 (s3.sp + 28#64) s3 = BitVec.zeroExtend 32 (spec s0.x0 s0.x1) := by simp_all + replace h_s3_sp_28 : s3.mem.read_bytes 4 (s3.sp + 28#64) = BitVec.zeroExtend 32 (spec s0.x0 s0.x1) := by simp_all replace h_s3_sp : s3.sp = s0.sp - 32#64 := by simp_all clear_named [h_s2, h_s1] @@ -832,7 +846,7 @@ theorem partial_correctness : rw [Correctness.snd_cassert_of_not_cut (si := s4) (by simp_all [Spec'.cut])]; simp [show Sys.next _ = run 1 _ by rfl] have h_s4_sp : s4.sp = s0.sp - 32#64 := by simp_all - have h_s4_sp_28 : read_mem_bytes 4 (s4.sp + 28#64) s4 = BitVec.zeroExtend 32 (spec s0.x0 s0.x1) := by simp_all + have h_s4_sp_28 : s4.mem.read_bytes 4 (s4.sp + 28#64) = BitVec.zeroExtend 32 (spec s0.x0 s0.x1) := by simp_all clear_named [h_s3] -- 6/15 diff --git a/Proofs/Experiments/Memcpy/MemCpyVCG.lean b/Proofs/Experiments/Memcpy/MemCpyVCG.lean index dfc532c3..b4b001f0 100644 --- a/Proofs/Experiments/Memcpy/MemCpyVCG.lean +++ b/Proofs/Experiments/Memcpy/MemCpyVCG.lean @@ -270,8 +270,10 @@ theorem program.step_8e4_8e8_of_wellformed_of_stepped (scur snext : ArmState) have := program.stepi_eq_0x8e4 h_program h_pc h_err obtain ⟨h_step⟩ := hstep subst h_step - constructor <;> simp only [*, cut, state_simp_rules, minimal_theory, bitvec_rules] - · constructor <;> simp [*, state_simp_rules, minimal_theory] + constructor + <;> simp only [*, cut, state_simp_rules, minimal_theory, bitvec_rules, + memory_rules] + · constructor <;> simp [*, state_simp_rules, minimal_theory, memory_rules] -- 3/7 (0x8e8#64, 0x3c810444#32), /- str q4, [x2], #16 -/ structure Step_8e8_8ec (scur : ArmState) (snext : ArmState) extends WellFormedAtPc snext 0x8ec : Prop where @@ -770,7 +772,7 @@ theorem partial_correctness : rw [step_8e4_8e8.h_q4] rw [h_si_x2] obtain ⟨h_assert_1, h_assert_2, h_assert_3, h_assert_4, h_assert_5, h_assert_6, h_assert_7⟩ := h_assert - simp only [memory_rules] + -- simp only [memory_rules] simp only [step_8f4_8e4.h_mem] simp only [step_8f4_8e4.h_x1] rw [h_si_x1] diff --git a/Proofs/Popcount32.lean b/Proofs/Popcount32.lean index cb71dbf4..d7857b79 100644 --- a/Proofs/Popcount32.lean +++ b/Proofs/Popcount32.lean @@ -70,6 +70,7 @@ def popcount32_program : Program := #genStepEqTheorems popcount32_program +-- set_option trace.simp_mem.info true in theorem popcount32_sym_meets_spec (s0 sf : ArmState) (h_s0_pc : read_pc s0 = 0x4005b4#64) (h_s0_program : s0.program = popcount32_program) diff --git a/Proofs/SHA512/SHA512Prelude.lean b/Proofs/SHA512/SHA512Prelude.lean index 9f8b29fe..25434718 100644 --- a/Proofs/SHA512/SHA512Prelude.lean +++ b/Proofs/SHA512/SHA512Prelude.lean @@ -166,7 +166,6 @@ theorem sha512_block_armv8_prelude (s0 sf : ArmState) -- Only memory-related obligations are left. -- (TODO @alex/@bollu) Remove ∀ in memory (non)effect hyps generated by -- `sym_n`. The user may still state memory properties using quantifiers. - simp only [←Memory.mem_eq_iff_read_mem_bytes_eq] at * -- Rewrite *_mem_bytes (in terms of ArmState) to *_bytes (in terms of Memory). simp only [memory_rules] at * -- (FIXME) Need to aggregate memory effects here automatically. diff --git a/Tactics/Aggregate.lean b/Tactics/Aggregate.lean index 886b9a3b..f4197f0b 100644 --- a/Tactics/Aggregate.lean +++ b/Tactics/Aggregate.lean @@ -30,6 +30,9 @@ def aggregate (axHyps : Array LocalDecl) (location : Location) let config := simpConfig?.getD aggregate.defaultSimpConfig let (ctx, simprocs) ← LNSymSimpContext + -- https://github.com/leanprover/lean4/blob/94b1e512da9df1394350ab81a28deca934271f65/src/Lean/Meta/DiscrTree.lean#L371 + -- refines the discrimination tree to also index applied functions. + (noIndexAtArgs := false) (config := config) (decls := axHyps) diff --git a/Tactics/Common.lean b/Tactics/Common.lean index 9763af33..5d0f78eb 100644 --- a/Tactics/Common.lean +++ b/Tactics/Common.lean @@ -280,6 +280,10 @@ def Lean.Expr.eqReadField? (e : Expr) : Option (Expr × Expr × Expr) := do /-- Return the expression for `Memory` -/ def mkMemory : Expr := mkConst ``Memory +/-- Return a proof of type `x = x`, where `x : Memory` -/ +def mkEqReflMemory (x : Expr) : Expr := + mkApp2 (.const ``Eq.refl [1]) mkMemory x + /-! ## Expr Helpers -/ /-- Throw an error if `e` is not of type `expectedType` -/ diff --git a/Tactics/Sym/Context.lean b/Tactics/Sym/Context.lean index 44d981e0..e25f3d57 100644 --- a/Tactics/Sym/Context.lean +++ b/Tactics/Sym/Context.lean @@ -240,7 +240,9 @@ private def initial (state : Expr) : MetaM SymContext := do let finalState ← mkFreshExprMVar mkArmState /- Get the default simp lemmas & simprocs for aggregation -/ let (aggregateSimpCtx, aggregateSimprocs) ← - LNSymSimpContext (config := {decide := true, failIfUnchanged := false}) + LNSymSimpContext + (config := {decide := true, failIfUnchanged := false}) + (simp_attrs := #[`minimal_theory, `bitvec_rules, `state_simp_rules, `memory_rules]) let aggregateSimpCtx := { aggregateSimpCtx with -- Create a new discrtree for effect hypotheses to be added to. -- TODO(@alexkeizer): I put this here, since the previous version kept diff --git a/Tactics/Sym/MemoryEffects.lean b/Tactics/Sym/MemoryEffects.lean index 45cb7eec..5b8d3719 100644 --- a/Tactics/Sym/MemoryEffects.lean +++ b/Tactics/Sym/MemoryEffects.lean @@ -21,9 +21,8 @@ structure MemoryEffects where effects : Expr /-- An expression that contains the proof of: ```lean - ∀ n addr, - read_mem_bytes n addr - = read_mem_bytes n addr + .mem + = .mem ``` -/ proof : Expr deriving Repr @@ -44,13 +43,8 @@ initial `state` -/ def initial (state : Expr) : MemoryEffects where effects := state proof := - -- `fun n addr => rfl` - mkLambda `n .default (mkConst ``Nat) <| - let bv64 := mkApp (mkConst ``BitVec) (toExpr 64) - mkLambda `addr .default bv64 <| - mkApp2 (.const ``Eq.refl [1]) - (mkApp (mkConst ``BitVec) <| mkNatMul (.bvar 1) (toExpr 8)) - (mkApp3 (mkConst ``read_mem_bytes) (.bvar 1) (.bvar 0) state) + -- `rfl` + mkEqReflMemory (mkApp (mkConst ``ArmState.mem) state) /-- Update the memory effects with a memory write -/ def updateWriteMem (eff : MemoryEffects) (currentState : Expr) @@ -58,8 +52,8 @@ def updateWriteMem (eff : MemoryEffects) (currentState : Expr) MetaM MemoryEffects := do let effects := mkApp4 (mkConst ``write_mem_bytes) n addr val eff.effects let proof := - -- `read_mem_bytes_write_mem_bytes_of_read_mem_eq ...` - mkAppN (mkConst ``read_mem_bytes_write_mem_bytes_of_read_mem_eq) + -- `mem_write_mem_bytes_of_mem_eq ...` + mkAppN (mkConst ``mem_write_mem_bytes_of_mem_eq) #[currentState, eff.effects, eff.proof, n, addr, val] return { effects, proof } @@ -70,8 +64,8 @@ we need to update proofs -/ def updateWrite (eff : MemoryEffects) (currentState : Expr) (fld val : Expr) : MetaM MemoryEffects := do - let proof := -- `read_mem_bytes_w_of_read_mem_eq ...` - mkAppN (mkConst ``read_mem_bytes_w_of_read_mem_eq) + let proof := -- `mem_w_of_mem_eq ...` + mkAppN (mkConst ``mem_w_of_mem_eq) #[currentState, eff.effects, eff.proof, fld, val] return { eff with proof }