Skip to content

Commit

Permalink
refactor: change memory-effects theorem to a quantifier-free statement (
Browse files Browse the repository at this point in the history
#224)

### Description:

Changes the memory effect proof to be of type `<currentState>.mem =
<trace of memory writes>.mem`, instead of the quantified statement that
the result of reading from memory at any bytes and any address of either
state agrees.

* To keep aggregation working with the new statement, we had to add
`memory_rules` to the simpsets that are used by sym_n.
* This meant we had to enhance `memory_rules` to do, e.g.,
read-over-write reasoning, and
* We had to change the `s[base, n]` notation to desugar into
`s.mem.read_bytes ..`

### Testing:

What tests have been run? Did `make all` succeed for your changes? Was
conformance testing successful on an Aarch64 machine?

Yes

### License:

By submitting this pull request, I confirm that my contribution is
made under the terms of the Apache 2.0 license.

---------

Co-authored-by: Siddharth <[email protected]>
Co-authored-by: Shilpi Goel <[email protected]>
  • Loading branch information
3 people authored Oct 16, 2024
1 parent c7829f1 commit a47a266
Show file tree
Hide file tree
Showing 12 changed files with 167 additions and 141 deletions.
155 changes: 89 additions & 66 deletions Arm/Memory/MemoryProofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,27 @@ section MemoryProofs

open BitVec

/-! ## One byte read/write lemmas-/
namespace Memory

theorem read_write_same :
read addr (write addr v mem) = v := by
simp [read, write, store_read_over_write_same]

theorem read_write_different (h : addr1 ≠ addr2) :
read addr1 (write addr2 v s) = read addr1 s := by
simp [read, write, store_read_over_write_different (h := h)]

theorem write_write_shadow :
write addr val2 (write addr val1 s) = write addr val2 s := by
unfold write write_store; simp_all

theorem write_irrelevant :
write addr (read addr s) s = s := by
simp [read, write, store_write_irrelevant]

end Memory

----------------------------------------------------------------------
-- Key theorem: read_mem_bytes_of_write_mem_bytes_same

Expand All @@ -34,32 +55,39 @@ theorem mem_separate_preserved_second_start_addr_add_one
apply BitVec.val_nat_le 1 m 64 h0 (_ : 1 < 2^64) h1
decide

theorem read_mem_of_write_mem_bytes_different (hn1 : n <= 2^64)
(h : mem_separate addr1 addr1 addr2 (addr2 + (BitVec.ofNat 64 (n - 1)))) :
read_mem addr1 (write_mem_bytes n addr2 v s) = read_mem addr1 s := by
by_cases hn0 : n = 0
case pos => -- n = 0
subst n; simp only [write_mem_bytes]
case neg => -- n ≠ 0
have hn0' : 0 < n := by omega
induction n, hn0' using Nat.le_induction generalizing addr2 s
case base =>
have h' : addr1 ≠ addr2 := by apply mem_separate_starting_addresses_neq h
simp only [write_mem_bytes]
apply read_mem_of_write_mem_different h'
case succ =>
have h' : addr1 ≠ addr2 := by refine mem_separate_starting_addresses_neq h
rename_i m hn n_ih
simp_all only [Nat.succ_sub_succ_eq_sub, Nat.sub_zero,
Nat.succ_ne_zero, not_false_eq_true, ne_eq,
write_mem_bytes, Nat.add_eq, Nat.add_zero]
rw [n_ih]
· rw [read_mem_of_write_mem_different h']
· omega
· rw [addr_add_one_add_m_sub_one m addr2 hn hn1]
rw [mem_separate_preserved_second_start_addr_add_one hn hn1 h]
· omega
done
theorem Memory.read_write_bytes_different (hn1 : n ≤ 2^64)
(h : mem_separate addr1 addr1 addr2 (addr2 + (BitVec.ofNat 64 (n - 1)))) :
read addr1 (write_bytes n addr2 v mem) = read addr1 mem := by
induction n generalizing mem addr1 addr2
case zero => simp only [write_bytes]
case succ n ih =>
have h_neq : addr1 ≠ addr2 :=
mem_separate_starting_addresses_neq h
rw [Nat.add_one_sub_one] at h
cases n
case zero =>
simp [write_bytes, read_write_different h_neq]
case succ n =>
have h_sep : mem_separate addr1 addr1 (addr2 + 1#64)
(addr2 + 1#64 + BitVec.ofNat 64 n) := by
unfold mem_separate mem_overlap at h ⊢
simp only [BitVec.sub_self, ofNat_add, Bool.or_self_right, Bool.not_or,
Bool.and_eq_true, Bool.not_eq_eq_eq_not, Bool.not_true,
decide_eq_false_iff_not, BitVec.not_le] at h ⊢
generalize hn' : BitVec.ofNat 64 n = n' at *
have : n' ≠ -1 := by bv_omega
clear hn1 ih
bv_decide
have h_neq : addr1 ≠ addr2 :=
mem_separate_starting_addresses_neq h
rw [write_bytes, ih (by omega) h_sep, Memory.read_write_different h_neq]

theorem read_mem_of_write_mem_bytes_different (hn1 : n ≤ 2^64)
(h : mem_separate addr1 addr1 addr2 (addr2 + (BitVec.ofNat 64 (n - 1)))) :
read_mem addr1 (write_mem_bytes n addr2 v s) = read_mem addr1 s := by
simp only [ArmState.read_mem_eq_mem_read,
Memory.write_mem_bytes_eq_mem_write_bytes]
exact Memory.read_write_bytes_different hn1 h

theorem append_byte_of_extract_rest_same_cast (n : Nat) (v : BitVec ((n + 1) * 8))
(hn0 : Nat.succ 0 ≤ n)
Expand All @@ -69,47 +97,42 @@ theorem append_byte_of_extract_rest_same_cast (n : Nat) (v : BitVec ((n + 1) * 8
· omega
done

example (s : ArmState) :
read_mem_bytes n addr s = s.mem.read_bytes n addr := by
exact Memory.State.read_mem_bytes_eq_mem_read_bytes s

@[state_simp_rules]
theorem read_mem_bytes_of_write_mem_bytes_same (hn1 : n <= 2^64) :
read_mem_bytes n addr (write_mem_bytes n addr v s) = v := by
by_cases hn0 : n = 0
case pos =>
subst n
unfold read_mem_bytes
simp only [of_length_zero]
case neg => -- n ≠ 0
have hn0' : 0 < n := by omega
induction n, hn0' using Nat.le_induction generalizing addr s
case base =>
simp only [read_mem_bytes, write_mem_bytes,
read_mem_of_write_mem_same, BitVec.cast_eq]
have l1 := BitVec.extractLsb'_eq v
simp only [Nat.reduceSucc, Nat.one_mul, Nat.succ_sub_succ_eq_sub,
Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq,
forall_const] at l1
rw [l1]
have l2 := BitVec.empty_bitvector_append_left v
simp only [Nat.reduceSucc, Nat.one_mul, Nat.zero_add,
BitVec.cast_eq, forall_const] at l2
exact l2
case succ =>
rename_i n hn n_ih
simp only [read_mem_bytes, Nat.add_eq, Nat.add_zero, write_mem_bytes]
rw [n_ih]
rw [read_mem_of_write_mem_bytes_different]
· simp only [Nat.add_eq, Nat.add_zero, read_mem_of_write_mem_same]
rw [append_byte_of_extract_rest_same_cast n v hn]
· omega
· have := mem_separate_contiguous_regions addr 0#64 (BitVec.ofNat 64 (n - 1))
simp only [Nat.reducePow, Nat.succ_sub_succ_eq_sub, Nat.sub_zero,
BitVec.sub_zero, ofNat_lt_ofNat, Nat.reduceMod,
BitVec.add_zero] at this
apply this
simp only [Nat.reducePow] at hn1
omega
· omega
· omega
done
theorem Memory.read_bytes_write_bytes_same (hn1 : n ≤ 2^64) :
read_bytes n addr (write_bytes n addr v mem) = v := by
induction n generalizing addr mem
case zero =>
simp [read_bytes, of_length_zero]
case succ n ih =>
simp only [read_bytes, write_bytes]
rw [ih (by omega)]
have h_sep :
let m := BitVec.ofNat 64 (n - 1)
mem_separate addr addr (addr + 1#64) (addr + 1#64 + m) := by
rw [← mem_separate_contiguous_regions addr 0#64 _]
· simp; rfl
· bv_omega
rw [read_write_bytes_different (by omega) h_sep, read_write_same]
apply BitVec.eq_of_getLsbD_eq
intro i
simp only [getLsbD_cast, getLsbD_append]
by_cases hi : i.val < 8
· simp [hi]
· have h₁ : i.val - 8 < n * 8 := by omega
have h₂ : 8 + (i.val - 8) = i.val := by omega
simp [hi, h₁, h₂]

@[state_simp_rules, memory_rules]
theorem read_mem_bytes_of_write_mem_bytes_same (hn1 : n ≤ 2^64) :
read_mem_bytes n addr (write_mem_bytes n addr v s) = v := by
open Memory in
rw [State.read_mem_bytes_eq_mem_read_bytes,
write_mem_bytes_eq_mem_write_bytes,
Memory.read_bytes_write_bytes_same hn1]

----------------------------------------------------------------------
-- Key theorem: read_mem_bytes_of_write_mem_bytes_different
Expand Down
20 changes: 13 additions & 7 deletions Arm/State.lean
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,12 @@ theorem read_mem_bytes_w_of_read_mem_eq
= read_mem_bytes n₁ addr₁ s₂ := by
simp only [read_mem_bytes_of_w, h]

@[state_simp_rules]
theorem mem_w_of_mem_eq {s₁ s₂ : ArmState} (h : s₁.mem = s₂.mem) (fld val) :
(w fld val s₁).mem = s₂.mem := by
unfold w;
cases fld <;> exact h

@[state_simp_rules]
theorem write_mem_bytes_program {n : Nat} (addr : BitVec 64) (bytes : BitVec (n * 8)):
(write_mem_bytes n addr bytes s).program = s.program := by
Expand Down Expand Up @@ -838,6 +844,9 @@ def read_bytes (n : Nat) (addr : BitVec 64) (m : Memory) : BitVec (n * 8) :=
have h : n' * 8 + 8 = (n' + 1) * 8 := by simp_arith
BitVec.cast h (rest ++ byte)

-- TODO (@bollu): we should drop the `State` namespace here, given that
-- this namespace is used nowhere else. Also, `ArmState.read_mem_eq_mem_read`
-- should probably live under the `Memory` namespace.
@[memory_rules]
theorem State.read_mem_bytes_eq_mem_read_bytes (s : ArmState) :
read_mem_bytes n addr s = s.mem.read_bytes n addr := by
Expand Down Expand Up @@ -1163,13 +1172,10 @@ theorem Memory.mem_eq_iff_read_mem_bytes_eq {s₁ s₂ : ArmState} :
· intro h _ _; rw[h]
· exact Memory.eq_of_read_mem_bytes_eq

theorem read_mem_bytes_write_mem_bytes_of_read_mem_eq
(h : ∀ n addr, read_mem_bytes n addr s₁ = read_mem_bytes n addr s₂)
(n₂ addr₂ val n₁ addr₁) :
read_mem_bytes n₁ addr₁ (write_mem_bytes n₂ addr₂ val s₁)
= read_mem_bytes n₁ addr₁ (write_mem_bytes n₂ addr₂ val s₂) := by
revert n₁ addr₁
simp only [← Memory.mem_eq_iff_read_mem_bytes_eq] at h ⊢
theorem mem_write_mem_bytes_of_mem_eq
(h : s₁.mem = s₂.mem) (n addr val) :
(write_mem_bytes n addr val s₁).mem
= (write_mem_bytes n addr val s₂).mem := by
simp only [memory_rules, h]

/- Helper lemma for `state_eq_iff_components_eq` -/
Expand Down
4 changes: 2 additions & 2 deletions Arm/Syntax.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import Arm.Memory.Separate

namespace ArmStateNotation

/-! We build a notation for `read_mem_bytes $n $base $s` as `$s[$base, $n]` -/
/-! We build a notation for `$s.mem.read_bytes $n $base $s` as `$s[$base, $n]` -/
@[inherit_doc read_mem_bytes]
syntax:max term noWs "[" withoutPosition(term) "," withoutPosition(term) noWs "]" : term
macro_rules | `($s[$base,$n]) => `(read_mem_bytes $n $base $s)
macro_rules | `($s[$base,$n]) => `(Memory.read_bytes $n $base (ArmState.mem $s))


/-! Notation to specify the frame condition for non-memory state components. E.g.,
Expand Down
26 changes: 2 additions & 24 deletions Proofs/AES-GCM/GCMGmultV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -89,37 +89,15 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
simp (config := {ground := true}) only at h_s0_pc
-- ^^ Still needed, because `gcm_gmult_v8_program.min` is somehow
-- unable to be reflected

sym_n 27
-- Epilogue
simp only [←Memory.mem_eq_iff_read_mem_bytes_eq] at *
simp only [memory_rules] at *
sym_aggregate
-- Split conjunction
repeat' apply And.intro
· -- Aggregate the memory (non)effects.
-- (FIXME) This will be tackled by `sym_aggregate` when `sym_n` and `simp_mem`
-- are merged.
simp only [*]
/-
(FIXME @bollu) `simp_mem; rfl` creates a malformed proof here. The tactic produces
no goals, but we get the following error message:
application type mismatch
Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset'
(Eq.mp (congrArg (Eq HTable) (Memory.State.read_mem_bytes_eq_mem_read_bytes s0))
(Eq.mp (congrArg (fun x => HTable = read_mem_bytes 256 x s0) zeroExtend_eq_of_r_gpr) h_HTable))
argument has type
HTable = Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem
but function has type
Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem = HTable →
mem_subset' (r (StateField.GPR 1#5) s0) 256 (r (StateField.GPR 1#5) s0) 256 →
Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem =
HTable.extractLsBytes (BitVec.toNat (r (StateField.GPR 1#5) s0) - BitVec.toNat (r (StateField.GPR 1#5) s0)) 256
simp_mem; rfl
-/
rw [Memory.read_bytes_write_bytes_eq_read_bytes_of_mem_separate']
simp_mem
· simp_mem; rfl
· simp only [List.mem_cons, List.mem_singleton, not_or, and_imp]
sym_aggregate
· intro n addr h_separate
Expand Down
Loading

0 comments on commit a47a266

Please sign in to comment.