Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: prove specification of Max program that manipulates memory via new VCG #136

Merged
merged 24 commits into from
Sep 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6030448
Add another proof method to obtain partial correctness, and use it to…
shigoel Aug 26, 2024
ad37fb1
Redo AddLoop proof using partial_correctness_from_assertions
shigoel Aug 26, 2024
6a685ae
Merge branch 'main' into vcg_assertions
shigoel Aug 26, 2024
7bac006
Testing a strategy for termination proofs that can borrow from the sa…
shigoel Aug 27, 2024
d47b199
chore: first commit, porting code over
bollu Sep 3, 2024
b9a2efb
feat: add projections from state to common registers
bollu Sep 3, 2024
873f2e8
feat: AddWithCarry lemmas
bollu Sep 3, 2024
83c06e4
feat: ArmState.mem_w_eq_mem
bollu Sep 3, 2024
2648d30
feat: Aligned_BitVecSub_64_4
bollu Sep 3, 2024
0da4d40
chore: eqn lemmas for (cassert s0 si i)
bollu Sep 3, 2024
b20e6de
chore: add proofs that `x > y` iff `(N = V) ∧ Z = 0`
bollu Sep 3, 2024
4242cda
chore: add MaxTandem example
bollu Sep 3, 2024
0c1da2e
Merge remote-tracking branch 'origin/main' into mem-max-take-3-vcg-as…
bollu Sep 4, 2024
aabff62
chore: cleanup
bollu Sep 4, 2024
45203d3
Update Arm/State.lean
bollu Sep 4, 2024
8d88628
Update Arm/State.lean
bollu Sep 4, 2024
5dafafc
chore: cleanups suggested
bollu Sep 4, 2024
82ef7d6
chore: check that we have all licenses
bollu Sep 4, 2024
f2e0000
Merge branch 'main' into mem-max-take-3-vcg-assertions
shigoel Sep 5, 2024
6898cae
Merge branch 'main' into mem-max-take-3-vcg-assertions
shigoel Sep 5, 2024
c28c73a
Merge branch 'main' into mem-max-take-3-vcg-assertions
shigoel Sep 6, 2024
7b789b8
chore: move to BitVec
bollu Sep 6, 2024
291b802
chore: add TODO to upstream.
bollu Sep 6, 2024
619073b
Merge branch 'main' into mem-max-take-3-vcg-assertions
shigoel Sep 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Arm/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,16 @@ theorem extractLsBytes_ge (h : a ≥ n) (x : BitVec n) :
apply BitVec.getLsb_ge
omega

/-- TODO: upstream -/
theorem not_slt {w} (a b : BitVec w) : ¬ (a.slt b) ↔ (b.sle a) := by
simp only [BitVec.slt, BitVec.sle]
by_cases h : a.toInt < b.toInt
· simp [h]
exact Int.not_le.mpr h
· simp [h]
exact Int.not_lt.mp h


/-! ## `Quote` instance -/

instance (w : Nat) : Quote (BitVec w) `term where
Expand Down
70 changes: 70 additions & 0 deletions Arm/Insts/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,32 @@ def AddWithCarry (x : BitVec n) (y : BitVec n) (carry_in : BitVec 1) :
let V := if signExtend (n + 1) result = signed_sum then 0#1 else 1#1
(result, (make_pstate N Z C V))

/-- When the carry bit is `0`, `AddWithCarry x y 0 = x + y` -/
shigoel marked this conversation as resolved.
Show resolved Hide resolved
theorem fst_AddWithCarry_eq_add (x : BitVec n) (y : BitVec n) :
shigoel marked this conversation as resolved.
Show resolved Hide resolved
(AddWithCarry x y 0#1).fst = x + y := by
simp [AddWithCarry, zeroExtend_eq, zeroExtend_zero, zeroExtend_zero]
apply BitVec.eq_of_toNat_eq
simp only [toNat_truncate, toNat_add, Nat.add_mod_mod, Nat.mod_add_mod]
have : 2^n < 2^(n + 1) := by
refine Nat.pow_lt_pow_of_lt (by omega) (by omega)
have : x.toNat + y.toNat < 2^(n + 1) := by omega
rw [Nat.mod_eq_of_lt this]

/-- When the carry bit is `1`, `AddWithCarry x y 1 = x - ~~~y` -/
theorem fst_AddWithCarry_eq_sub_neg (x : BitVec n) (y : BitVec n) :
(AddWithCarry x y 1#1).fst = x - ~~~y := by
simp [AddWithCarry, zeroExtend_eq, zeroExtend_zero, zeroExtend_zero]
apply BitVec.eq_of_toNat_eq
simp only [toNat_truncate, toNat_add, Nat.add_mod_mod, Nat.mod_add_mod, toNat_ofNat, Nat.pow_one,
Nat.reduceMod, toNat_sub, toNat_not]
simp only [show 2 ^ n - (2 ^ n - 1 - y.toNat) = 1 + y.toNat by omega]
have : 2^n < 2^(n + 1) := by
refine Nat.pow_lt_pow_of_lt (by omega) (by omega)
have : x.toNat + y.toNat + 1 < 2^(n + 1) := by omega
rw [Nat.mod_eq_of_lt this]
congr 1
omega

-- TODO: Is this rule helpful at all?
@[bitvec_rules]
theorem zeroExtend_eq_of_AddWithCarry :
Expand Down Expand Up @@ -83,6 +109,43 @@ def ConditionHolds (cond : BitVec 4) (s : ArmState) : Bool :=
else
result

/-- `x > y` iff `(N = V) ∧ Z = 0` . -/
theorem sgt_iff_n_eq_v_and_z_eq_0_64 (x y : BitVec 64) :
(((AddWithCarry x (~~~y) 1#1).snd.n = (AddWithCarry x (~~~y) 1#1).snd.v) ∧
(AddWithCarry x (~~~y) 1#1).snd.z = 0#1) ↔ BitVec.slt y x := by
simp [AddWithCarry, make_pstate]
split
· bv_decide
· bv_decide

/-- `x > y` iff `(N = V) ∧ Z = 0` . -/
theorem sgt_iff_n_eq_v_and_z_eq_0_32 (x y : BitVec 32) :
(((AddWithCarry x (~~~y) 1#1).snd.n = (AddWithCarry x (~~~y) 1#1).snd.v) ∧
(AddWithCarry x (~~~y) 1#1).snd.z = 0#1) ↔ BitVec.slt y x := by
simp [AddWithCarry, make_pstate]
split
· bv_decide
· bv_decide

/-- `x ≤ y` iff `¬ ((N = V) ∧ (Z = 0))`. -/
theorem sle_iff_not_n_eq_v_and_z_eq_0_64 (x y : BitVec 64) :
(¬(((AddWithCarry x (~~~y) 1#1).snd.n = (AddWithCarry x (~~~y) 1#1).snd.v) ∧
(AddWithCarry x (~~~y) 1#1).snd.z = 0#1)) ↔ BitVec.sle x y := by
simp [AddWithCarry, make_pstate]
split
· bv_decide
· bv_decide

/-- `x ≤ y` iff `¬ ((N = V) ∧ (Z = 0))`. -/
theorem sle_iff_not_n_eq_v_and_z_eq_0_32 (x y : BitVec 32) :
(¬(((AddWithCarry x (~~~y) 1#1).snd.n = (AddWithCarry x (~~~y) 1#1).snd.v) ∧
(AddWithCarry x (~~~y) 1#1).snd.z = 0#1)) ↔ BitVec.sle x y := by
simp [AddWithCarry, make_pstate]
split
· bv_decide
· bv_decide


/-- `Aligned x a` witnesses that the bitvector `x` is `a`-bit aligned. -/
def Aligned (x : BitVec n) (a : Nat) : Prop :=
match a with
Expand All @@ -93,6 +156,13 @@ def Aligned (x : BitVec n) (a : Nat) : Prop :=
instance : Decidable (Aligned x a) := by
cases a <;> simp [Aligned] <;> infer_instance

theorem Aligned_BitVecSub_64_4 {x : BitVec 64} {y : BitVec 64}
shigoel marked this conversation as resolved.
Show resolved Hide resolved
(x_aligned : Aligned x 4)
(y_aligned : Aligned y 4)
: Aligned (x - y) 4 := by
simp_all only [Aligned, Nat.sub_zero, zero_eq]
bv_decide

theorem Aligned_BitVecAdd_64_4 {x : BitVec 64} {y : BitVec 64}
(x_aligned : Aligned x 4)
(y_aligned : Aligned y 4)
Expand Down
51 changes: 51 additions & 0 deletions Arm/State.lean
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,45 @@ def r (fld : StateField) (s : ArmState) : (state_value fld) :=
| FLAG i => read_base_flag i s
| ERR => read_base_error s

/-!

We define helpers for reading and writing registers on the `ArmState` with the colloquial
names. For example, the stack pointer (`sp`) refers to register 31.
These mnemonics make it much easier to read and write theorems about assembly programs.

-/

@[state_simp_rules] abbrev ArmState.x0 (s : ArmState) : BitVec 64 := r (StateField.GPR 0) s

@[state_simp_rules] abbrev ArmState.x1 (s : ArmState) : BitVec 64 := r (StateField.GPR 1) s

@[state_simp_rules] abbrev ArmState.x2 (s : ArmState) : BitVec 64 := r (StateField.GPR 2) s

@[state_simp_rules] abbrev ArmState.sp (s : ArmState) : BitVec 64 := r (StateField.GPR 31) s

@[state_simp_rules] abbrev ArmState.V (s : ArmState) : BitVec 1 := r (StateField.FLAG PFlag.V) s

@[state_simp_rules] abbrev ArmState.C (s : ArmState) : BitVec 1 := r (StateField.FLAG PFlag.C) s

@[state_simp_rules] abbrev ArmState.Z (s : ArmState) : BitVec 1 := r (StateField.FLAG PFlag.Z) s

@[state_simp_rules] abbrev ArmState.N (s : ArmState) : BitVec 1 := r (StateField.FLAG PFlag.N) s

def ArmState.r_GPR_0_eq_x0 (s : ArmState) : r (StateField.GPR 0) s = s.x0 := by rfl
shigoel marked this conversation as resolved.
Show resolved Hide resolved

def ArmState.r_GPR_1_eq_x1 (s : ArmState) : r (StateField.GPR 1) s = s.x1 := by rfl

def ArmState.r_GPR_31_eq_sp (s : ArmState) : r (StateField.GPR 31) s = s.sp := by rfl

def ArmState.r_FLAG_V_eq_V (s : ArmState) : r (StateField.FLAG PFlag.V) s = s.V := by rfl

def ArmState.r_FLAG_C_eq_C (s : ArmState) : r (StateField.FLAG PFlag.C) s = s.C := by rfl

def ArmState.r_FLAG_Z_eq_Z (s : ArmState) : r (StateField.FLAG PFlag.Z) s = s.Z := by rfl

def ArmState.r_FLAG_N_eq_N (s : ArmState) : r (StateField.FLAG PFlag.N) s = s.N := by rfl


@[irreducible]
def w (fld : StateField) (v : (state_value fld)) (s : ArmState) : ArmState :=
open StateField in
Expand Down Expand Up @@ -753,6 +792,18 @@ def Memory.read (addr : BitVec 64) (m : Memory) : BitVec 8 :=

theorem ArmState.read_mem_eq_mem_read : read_mem addr s = s.mem.read addr := rfl

/-- `w` does not affect memory. -/
@[memory_rules, state_simp_rules]
theorem ArmState.mem_w_eq_mem (fld : StateField) (v : state_value fld) (s : ArmState) :
(w fld v s).mem = s.mem := by
cases fld <;> (
unfold w write_base_error
write_base_gpr
write_base_sfp
write_base_pc
write_base_flag
simp)

/--
A variant of `write_mem` that directly talks about writes to memory, instead of over the entire `ArmState`
-/
Expand Down
14 changes: 12 additions & 2 deletions Correctness/Correctness.lean
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Leonardo de Moura, Shilpi Goel
Author(s): Leonardo de Moura, Shilpi Goel, Siddharth Bhat
-/

/-
Expand Down Expand Up @@ -87,7 +87,6 @@ the next cutpoint) and `nextc` (to characterize the next cutpoint
state). Note that the function `csteps` is partial: if no cutpoint
is reachable from `s`, then the recursion does not terminate.
-/

noncomputable def csteps [Sys σ] [Spec' σ] (s : σ) (i : Nat) : Nat :=
iterate (fun (s, i) => if cut s then .inl i else .inr (next s, i + 1)) (s, i)

Expand Down Expand Up @@ -258,6 +257,17 @@ theorem cassert_cut [Sys σ] [Spec' σ] {s0 si : σ} (h : cut si) (i : Nat) :
simp only [↓reduceIte, and_self, h]
done

/-- If `si` is a cut-point, then `(cassert s0 si i).snd` is to verify the assertion for `si`. -/
theorem snd_cassert_of_cut [Sys σ] [Spec' σ] {s0 si : σ} (h : cut si) (i : Nat) : (cassert s0 si i).snd = assert s0 si := by
rw [cassert_eq]
simp [*]

/-- If `si` is not a cut-point, then `(cassert s0 si i).snd` is to run the next state. -/
theorem snd_cassert_of_not_cut [Sys σ] [Spec' σ] {s0 si : σ} (h : cut si = false) (i : Nat) :
(cassert s0 si i).snd = (cassert s0 (next si) (i + 1)).snd := by
rw [cassert_eq]
simp [*]

theorem cassert_not_cut [Sys σ] [Spec' σ] {s0 si : σ} (h₁ : ¬ cut si)
(h₂ : (cassert s0 (next si) (i+1)).fst = j) :
(cassert s0 si i).fst = j ∧
Expand Down
47 changes: 0 additions & 47 deletions Proofs/Experiments/Max.lean

This file was deleted.

44 changes: 44 additions & 0 deletions Proofs/Experiments/Max/MaxProgram.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Shilpi Goel, Siddharth Bhat

The goal is to prove that this program implements max correctly.
-/
import Arm
import Arm.BitVec
import Tactics.Sym
import Tactics.StepThms


namespace Max

def program : Program :=
def_program [
/- 0x0 -/ (0x894#64, 0xd10083ff#32), -- sub sp, sp, #0x20 ;
/- 0x4 -/ (0x898#64, 0xb9000fe0#32), -- str w0, [sp, #12] ; sp[12] = w0_a
/- 0x8 -/ (0x89c#64, 0xb9000be1#32), -- str w1, [sp, #8] ; sp[8] = w1_a
/- 0xc -/ (0x8a0#64, 0xb9400fe1#32), -- ldr w1, [sp, #12] ; w1_b = sp[12] = w0_a
/- 0x10 -/ (0x8a4#64, 0xb9400be0#32), -- ldr w0, [sp, #8] ; w0_b = sp[8] = w1_a
/- 0x14 -/ (0x8a8#64, 0x6b00003f#32), -- cmp w1, w0 ; w1_b - w0_b = w0_a - w1_a
/- 0x18 -/ (0x8ac#64, 0x5400008d#32), -- b.le 8bc <max+0x28> ; w0_a ≤ w1_a: br ... -- entry end.
-- LOAD FROM sp[8] = w1_a (which is > w0_a) AND STORE IN w0
/- 0x1c -/ (0x8b0#64, 0xb9400fe0#32), -- ldr w0, [sp, #12] ; w0_c = sp[12] = w0_a -- then start
/- 0x20 -/ (0x8b4#64, 0xb9001fe0#32), -- str w0, [sp, #28] ; sp[28] = w0_c = w0_a
/- 0x24 -/ (0x8b8#64, 0x14000003#32), -- b 8c4 <max+0x30> ; - then end
-- LOAD FROM sp[8] = w1_a (which is > w0_a) AND STORE IN w0
/- 0x28 -/ (0x8bc#64, 0xb9400be0#32), -- ldr w0, [sp, #8] ; w0_d = sp[8] = w1_a -- else start
/- 0x2c -/ (0x8c0#64, 0xb9001fe0#32), -- str w0, [sp, #28] ; sp[28] = w0_d = w1_a -- else end
-- LOAD FROM sp[28] AND STORE IN w0
/- 0x30 -/ (0x8c4#64, 0xb9401fe0#32), -- ldr w0, [sp, #28] ; w0 = sp[28] -- merge start
/- 0x34 -/ (0x8c8#64, 0x910083ff#32), -- add sp, sp, #0x20 ; sp = sp + 0x20
/- 0x38 -/ (0x8cc#64, 0xd65f03c0#32) -- ret -- return -- merge end [TODO: should this be ret?]
]

#genStepEqTheorems program

def spec (x y : BitVec 32) : BitVec 32 :=
if BitVec.slt y x then x else y


end Max
Loading