Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reference implementation of prune_updates #246

Merged
merged 29 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
bcec32b
Experimental method to aggregate state effects
shigoel Oct 14, 2024
5ac43cc
Merge branch 'main' into aggr_dsl
shigoel Oct 14, 2024
1f8df56
Fix comments
shigoel Oct 14, 2024
7687069
Merge branch 'main' into aggr_dsl
shigoel Oct 14, 2024
85ea0e6
Update Tactics/ArmConstr.lean
shigoel Oct 14, 2024
c67f1ca
Add a test for 30 steps
shigoel Oct 14, 2024
a16325c
Add a python script to generate theorems for ArmConstr method
shigoel Oct 14, 2024
61fa5c3
chore: cleanup to use stdlib API
bollu Oct 14, 2024
7de7f7b
Cherry-pick @bollu's commit 7947bbf (https://github.com/leanprover/LN…
shigoel Oct 15, 2024
e61e5d6
Minor edits
shigoel Oct 15, 2024
cf0d8a6
Sort the updates when doing the aggregation
shigoel Oct 15, 2024
ad449b0
Add note about why decide doesn't do a complete reduction (thanks, @a…
shigoel Oct 15, 2024
f002039
Merge branch 'main' into aggr_dsl
shigoel Oct 17, 2024
c23db62
Prove Expr.eq_true_of_denote
shigoel Oct 21, 2024
8ab5f2a
Merge branch 'main' into aggr_dsl
shigoel Oct 21, 2024
421b76a
Minor comments
shigoel Oct 21, 2024
8295aa5
Add sym_block; manually aggregate basic blocks for SHA512; simulate S…
shigoel Oct 23, 2024
45f4742
Clean up sym_block; account for blockSize in some Sym/Context functions
shigoel Oct 24, 2024
9f83f49
Merge branch 'main' into aggr_dsl
shigoel Oct 30, 2024
d4429bb
prune_updates ready for GPR writes
shigoel Nov 5, 2024
b7a71b3
Merge branch 'main' into aggr_dsl
shigoel Nov 5, 2024
1488732
Merge branch 'aggr_dsl' into more_preprocessing
shigoel Nov 5, 2024
67df45d
Merge branch 'aggr_dsl' into more_preprocessing
shigoel Nov 5, 2024
156c712
Finished reference implementation of prune_updates
shigoel Nov 8, 2024
32619ab
Merge branch 'main' into more_preprocessing
shigoel Nov 12, 2024
c59a393
Add missing copyright header
shigoel Nov 12, 2024
dd971b2
Fix a bug in the application of mkNeProofOfNotMemAndMem
shigoel Nov 12, 2024
b477dc5
Add capability to sym_block to specify a list of possibly different b…
shigoel Nov 13, 2024
b1378c4
Add an example where prune_updates times out
shigoel Nov 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Arm/Exec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ theorem run_onestep {s s': ArmState} {n : Nat} :
(s' = run (n + 1) s) → ∃ s'', stepi s = s'' ∧ s' = run n s'' := by
simp only [run, exists_eq_left', imp_self]

theorem run_oneblock {s s' : ArmState} {n1 n2 : Nat} :
(s' = run (n1 + n2) s) →
∃ s'', run n1 s = s'' ∧ s' = run n2 s'' := by
simp only [run_plus, exists_eq_left', imp_self]

/-- helper lemma for automation -/
theorem stepi_eq_of_fetch_inst_of_decode_raw_inst
(s : ArmState) (addr : BitVec 64) (rawInst : BitVec 32) (inst : ArmInst)
Expand Down
22 changes: 22 additions & 0 deletions Arm/State.lean
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ theorem store_write_irrelevant [DecidableEq α]
unfold write_store read_store
simp

@[local simp]
theorem write_store_commute [DecidableEq α] (store : Store α β)
(h : i ≠ j) :
write_store i x (write_store j y store) =
write_store j y (write_store i x store) := by
apply funext; intro idx
simp [write_store]
split <;> simp_all

instance [Repr β]: Repr (Store (BitVec n) β) where
reprPrec store _ :=
let rec helper (a : Nat) (acc : Lean.Format) :=
Expand Down Expand Up @@ -451,6 +460,19 @@ theorem w_irrelevant : w fld (r fld s) s = s := by
unfold read_base_error write_base_error
repeat (split <;> simp_all)

theorem w_of_w_commute (h : fld1 ≠ fld2) :
w fld1 v1 (w fld2 v2 s) = w fld2 v2 (w fld1 v1 s) := by
unfold w
unfold write_base_gpr
unfold write_base_sfp
unfold write_base_pc
unfold write_base_flag
unfold write_base_error
split <;> split <;> (simp_all; try rwa [write_store_commute])
rename_i fld1 t1 i1 v1 fld2 t2 i2 v2
cases i1 <;> (cases i2 <;> (split <;> simp_all))
done

@[state_simp_rules]
theorem fetch_inst_of_w : fetch_inst addr (w fld val s) = fetch_inst addr s := by
unfold fetch_inst w
Expand Down
149 changes: 90 additions & 59 deletions Proofs/AES-GCM/GCMGmultV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -134,31 +134,80 @@ multiplication into four constituent ones, for normalization.
example :
let p := 0b11#2
let q := 0b10#2
let w := 0b01#2
let z := 0b01#2
let x := 0b01#2
let y := 0b01#2
(DPSFP.polynomial_mult
(p ++ q)
(w ++ z))
(x ++ y))
=
((DPSFP.polynomial_mult p w) ++ 0#4) ^^^
(0#4 ++ (DPSFP.polynomial_mult q z)) ^^^
(0#2 ++ (DPSFP.polynomial_mult p z) ++ 0#2) ^^^
(0#2 ++ (DPSFP.polynomial_mult q w) ++ 0#2) := by native_decide
((DPSFP.polynomial_mult p x) ++ 0#4) ^^^
(0#4 ++ (DPSFP.polynomial_mult q y)) ^^^
(0#2 ++ (DPSFP.polynomial_mult p y) ++ 0#2) ^^^
(0#2 ++ (DPSFP.polynomial_mult q x) ++ 0#2) := by native_decide

def pmult_test_1 : IO Bool := do
let p ← BitVec.rand 64
let q ← BitVec.rand 64
let x ← BitVec.rand 64
let y ← BitVec.rand 64
pure
(DPSFP.polynomial_mult (p ++ q) (x ++ y) ==
((DPSFP.polynomial_mult p x) ++ 0#128) ^^^
(0#128 ++ (DPSFP.polynomial_mult q y)) ^^^
(0#64 ++ (DPSFP.polynomial_mult p y) ++ 0#64) ^^^
(0#64 ++ (DPSFP.polynomial_mult q x) ++ 0#64))

/--
info: true
-/
#guard_msgs in
#eval pmult_test_1

theorem DPSFP.polynomial_mult_append {p q x y : BitVec 64} :
DPSFP.polynomial_mult (p ++ q) (x ++ y) =
((DPSFP.polynomial_mult p x) ++ 0#128) ^^^
(0#128 ++ (DPSFP.polynomial_mult q y)) ^^^
(0#64 ++ (DPSFP.polynomial_mult p y) ++ 0#64) ^^^
(0#64 ++ (DPSFP.polynomial_mult q x) ++ 0#64) := by
sorry

/-
Source: Function `GCMInitV8` in `Specs/GCMV8.lean`:
Note that `H0` is the 128-bit HTable input to `gcm_gmult_v8`.

let H2 := GCMV8.gcm_polyval H0 H0
let H1 := ((hi H2) ^^^ (lo H2)) ++ ((hi H0) ^^^ (lo H0))
-/

set_option maxRecDepth 8000 in
set_option maxHeartbeats 500000 in
set_option pp.deepTerms false in
set_option pp.deepTerms.threshold 50 in
-- set_option trace.simp_mem.info true in
#time theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
{H1_hi H1_lo H0_hi H0_lo : BitVec 64}
(h_s0_program : s0.program = gcm_gmult_v8_program)
(h_s0_err : read_err s0 = .None)
(h_s0_pc : read_pc s0 = gcm_gmult_v8_program.min)
(h_s0_sp_aligned : CheckSPAlignment s0)
(h_Xi : Xi = s0[read_gpr 64 0#5 s0, 16])
(h_HTable : HTable = s0[read_gpr 64 1#5 s0, 32])
(h_HTable_lo : H0_hi ++ H0_lo = s0[read_gpr 64 1#5 s0, 16])
(h_HTable_hi : H1_hi ++ H1_lo = s0[read_gpr 64 1#5 s0 + 16#64, 16])
-- (h_HTable : HTable = s0[read_gpr 64 1#5 s0, 32])
-- (h_HTable_alt : HTable = H1_hi ++ H1_lo ++ H0_hi ++ H0_lo)
(h_H1_low_64 : H1_lo = H0_hi ^^^ H0_lo)
-- (h_H1 : HTable.extractLsb' 128 128 =
-- let H0 := HTable.extractLsb' 0 128
-- let H2 := GCMV8.gcm_polyval H0 H0
-- let H0_hi := H0.extractLsb' 64 64
-- let H0_lo := H0.extractLsb' 0 64
-- let H2_hi := H2.extractLsb' 64 64
-- let H2_lo := H2.extractLsb' 0 64
-- ((H2_hi) ^^^ (H2_lo)) ++ ((H0_hi) ^^^ (H0_lo)))
(h_mem_sep : Memory.Region.pairwiseSeparate
[(read_gpr 64 0#5 s0, 16),
(read_gpr 64 1#5 s0, 32)])
(read_gpr 64 1#5 s0, 16),
(read_gpr 64 1#5 s0 + 16#64, 16)])
(h_run : sf = run gcm_gmult_v8_program.length s0) :
-- The final state is error-free.
read_err sf = .None ∧
Expand All @@ -168,11 +217,6 @@ set_option pp.deepTerms.threshold 50 in
CheckSPAlignment sf ∧
-- The final state returns to the address in register `x30` in `s0`.
read_pc sf = r (StateField.GPR 30#5) s0 ∧
-- (TODO) Delete the following conjunct because it is covered by the
-- MEM_UNCHANGED_EXCEPT frame condition. We keep it around because it
-- exposes the issue with `simp_mem` that @bollu will fix.
-- HTable is unmodified.
sf[read_gpr 64 1#5 s0, 32] = HTable ∧
-- Frame conditions.
-- Note that the following also covers that the Xi address in .GPR 0
-- is unmodified.
Expand All @@ -185,24 +229,24 @@ set_option pp.deepTerms.threshold 50 in
sf[r (.GPR 0) s0, 16] =
rev_elems 128 8
(GCMV8.GCMGmultV8_alt
(HTable.extractLsb' 0 128)
(H0_hi ++ H0_lo)
(rev_elems 128 8 Xi (by decide) (by decide)))
(by decide) (by decide) := by
-- Prelude
simp_all only [state_simp_rules, -h_run]
simp only [Nat.reduceMul] at Xi HTable
-- simp only [Nat.reduceMul] at Xi HTable
simp only [Nat.reduceMul] at Xi
simp (config := {ground := true}) only at h_s0_pc
-- ^^ Still needed, because `gcm_gmult_v8_program.min` is somehow
-- unable to be reflected

sym_n 27
-- Epilogue
simp only [←Memory.mem_eq_iff_read_mem_bytes_eq] at *
-- simp only [←Memory.mem_eq_iff_read_mem_bytes_eq] at *
simp only [memory_rules] at *
sym_aggregate
-- Split conjunction
repeat' apply And.intro
· simp_mem; rfl
-- · simp_mem; rfl
· simp only [List.mem_cons, List.mem_singleton, not_or, and_imp] at *
sym_aggregate
· intro n addr h_separate
Expand All @@ -211,27 +255,9 @@ set_option pp.deepTerms.threshold 50 in
simp_mem sep with [h_separate]
· clear_named [h_s, stepi_]
clear s1 s2 s3 s4 s5 s6 s7 s8 s9 s10 s11 s12 s13 s14 s15 s16 s17 s18 s19 s20 s21 s22 s23 s24 s25 s26

-- Simplifying the LHS
have h_HTable_low :
Memory.read_bytes 16 (r (StateField.GPR 1#5) s0) s0.mem = HTable.extractLsb' 0 128 := by
-- (FIXME @bollu) use `simp_mem` instead of the rw below.
-- conv =>
-- lhs
-- simp_mem sub r
rw [@Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset'
32 (r (StateField.GPR 1#5) s0) HTable (r (StateField.GPR 1#5) s0) 16 _ h_HTable.symm]
· simp only [Nat.reduceMul, BitVec.extractLsBytes, Nat.sub_self, Nat.zero_mul]
· mem_omega!
have h_HTable_high :
(Memory.read_bytes 16 (r (StateField.GPR 1#5) s0 + 16#64) s0.mem) = HTable.extractLsb' 128 128 := by
-- (FIXME @bollu) use `simp_mem` instead of the rw below.
-- conv =>
-- lhs
-- simp_mem sub r
rw [@Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset'
32 (r (StateField.GPR 1#5) s0) HTable (r (StateField.GPR 1#5) s0 + 16#64) 16 _ h_HTable.symm]
repeat sorry
simp only [h_HTable_high, h_HTable_low, ←h_Xi]
simp only [←h_Xi]
clear h_mem_sep h_run
/-
simp/ground below to reduce
Expand All @@ -251,11 +277,12 @@ set_option pp.deepTerms.threshold 50 in
rw [@vrev128_64_8_in_terms_of_rev_elems (by decide) (by decide) (by decide) (by decide)] at h_Xi_rev
generalize h_Xi_upper_rev : rev_elems 64 8 (BitVec.extractLsb' 64 64 Xi) (by decide) (by decide) = Xi_upper_rev
generalize h_Xi_lower_rev : rev_elems 64 8 (BitVec.extractLsb' 0 64 Xi) (by decide) (by decide) = Xi_lower_rev
-- Simplifying the RHS

simp only [GCMV8.GCMGmultV8_alt,
GCMV8.lo, GCMV8.hi,
GCMV8.gcm_polyval,
←h_HTable, ←h_Xi_rev, h_Xi_lower_rev, h_Xi_upper_rev]
←h_HTable_lo, ←h_HTable_hi,
←h_Xi_rev, h_Xi_lower_rev, h_Xi_upper_rev]
simp only [pmull_op_e_0_eize_64_elements_1_size_128_eq, gcm_polyval_mul_eq_polynomial_mult]
simp only [zeroExtend_allOnes_lsh_64, zeroExtend_allOnes_lsh_0]
rw [BitVec.extractLsb'_64_128_of_appends]
Expand All @@ -265,32 +292,36 @@ set_option pp.deepTerms.threshold 50 in
repeat rw [BitVec.extractLsb'_zero_extractLsb'_of_le (by decide)]
repeat rw [BitVec.extractLsb'_extractLsb'_zero_of_le (by decide)]
rw [BitVec.and_high_to_extractLsb'_concat]
generalize h_HTable_upper : (BitVec.extractLsb' 64 64 HTable) = HTable_upper
generalize h_HTable_lower : (BitVec.extractLsb' 0 64 HTable) = HTable_lower
generalize h_term_u0u1 : (DPSFP.polynomial_mult HTable_upper Xi_upper_rev) = u0u1 at *
generalize h_term_l0l1 : (DPSFP.polynomial_mult HTable_lower Xi_lower_rev) = l0l1 at *
generalize h_term_1 : (DPSFP.polynomial_mult (BitVec.extractLsb' 128 64 HTable) (Xi_lower_rev ^^^ Xi_upper_rev) ^^^
BitVec.extractLsb' 64 128 (l0l1 ++ u0u1) ^^^
(u0u1 ^^^ l0l1)) = term_1
generalize h_term_2 : ((term_1 &&& 0xffffffffffffffff#128 ||| BitVec.zeroExtend 128 (BitVec.setWidth 64 u0u1) <<< 64) ^^^
DPSFP.polynomial_mult (BitVec.extractLsb' 0 64 u0u1) 0xc200000000000000#64)
= term_2
generalize h_term_3 : (BitVec.extractLsb' 64 128 (term_2 ++ term_2) ^^^
(BitVec.extractLsb' 64 64 l0l1 ++ 0x0#64 |||
BitVec.zeroExtend 128 (BitVec.extractLsb' 64 64 term_1) <<< 0))
= term_3

rw [@vrev128_64_8_in_terms_of_rev_elems (by decide) (by decide) (by decide) (by decide)]
rw [BitVec.extractLsb'_64_128_of_appends]
rw [@rev_elems_64_8_append_eq_rev_elems_128_8 _ _ (by decide) (by decide) (by decide) (by decide)]
apply eq_of_rev_elems_eq
rw [@rev_elems_128_8_eq_rev_elems_64_8_extractLsb' _ (by decide) (by decide) (by decide) (by decide) (by decide)]
rw [h_Xi_upper_rev, h_Xi_lower_rev]
rw [BitVec.extractLsb'_append_eq]
simp [GCMV8.gcm_polyval_red]
-- have h_reduce : (GCMV8.reduce 0x100000000000000000000000000000087#129 0x1#129) = 1#129 := by native_decide
-- simp [GCMV8.gcm_polyval_red, GCMV8.irrepoly, GCMV8.pmod, h_reduce]
-- repeat (unfold GCMV8.pmod.pmodTR; simp)
simp only [BitVec.truncate_eq_setWidth, Nat.reduceAdd, BitVec.shiftLeft_zero_eq]

simp [DPSFP.polynomial_mult_append]
simp [GCMV8.gcm_polyval_red, GCMV8.irrepoly]

generalize h_term_1 : DPSFP.polynomial_mult H0_lo Xi_lower_rev = term1
generalize h_term_2 : DPSFP.polynomial_mult H0_hi Xi_upper_rev = term2

-- (TODO) Can we remove `reverse` from `pmod` in the RHS?

-- have h_reduce : (GCMV8.reduce 0x100000000000000000000000000000087#129 0x1#129) = 1#129 := by native_decide
--
-- simp only [GCMV8.gcm_polyval_red, GCMV8.irrepoly,
-- GCMV8.pmod, GCMV8.pmod.pmodTR,
-- GCMV8.reduce, GCMV8.degree, GCMV8.degree.degreeTR]
-- simp only [Nat.reduceAdd, BitVec.ushiftRight_eq, BitVec.reduceExtracLsb',
-- BitVec.reduceHShiftLeft, BitVec.reduceAppend, BitVec.reduceHShiftRight, BitVec.ofNat_eq_ofNat,
-- BitVec.reduceEq, ↓reduceIte, Nat.sub_self, BitVec.ushiftRight_zero_eq, BitVec.reduceAnd,
-- BitVec.toNat_ofNat, Nat.pow_one, Nat.reduceMod, Nat.mul_zero, Nat.add_zero, Nat.zero_mod,
-- Nat.zero_add, Nat.sub_zero, Nat.mul_one, Nat.zero_mul, Nat.one_mul, Nat.reduceSub,
-- BitVec.and_self, BitVec.zero_and, BitVec.reduceMul, BitVec.xor_zero, BitVec.mul_one,
-- BitVec.zero_xor, Nat.add_one_sub_one, BitVec.one_mul, BitVec.reduceXOr]
sorry
done

Expand Down
2 changes: 1 addition & 1 deletion Proofs/Bit_twiddling.lean
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ theorem power_of_two (x : BitVec 32) (i : BitVec 5)

-- ============================================================

set_option sat.solver "cadical" in
/--
6. Conditionally set or clear bits without branching
(https://graphics.stanford.edu/~seander/bithacks.html#ConditionalSetOrClearBitsWithoutBranching)
Expand All @@ -159,7 +160,6 @@ unsigned int w; // the word to modify: if (f) w |= m; else w &= ~m;

w ^= (-f ^ w) & m;
-/

theorem set_clear_no_branch (x : BitVec 32) (f : Bool) (mask : BitVec 32) :
(if f then (x ||| mask) else (x &&& ~~~mask)) =
(x ^^^ (((-(BitVec.zeroExtend 32 (BitVec.ofBool f))) ^^^ x) &&& mask)) := by
Expand Down
4 changes: 3 additions & 1 deletion Proofs/Experiments/MemoryAliasing.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import Arm.Memory.MemoryProofs
import Arm.BitVec
import Arm.Memory.SeparateAutomation



-- set_option trace.simp_mem true
-- set_option trace.simp_mem.info true
-- set_option trace.Meta.Tactic.simp true
Expand Down Expand Up @@ -170,7 +172,7 @@ set_option linter.all false in
mem_omega

set_option linter.all false in
set_option trace.simp_mem.info true in
-- set_option trace.simp_mem.info true in
#time theorem mem_separate_11 (h : mem_separate' a 100 b 100)
(h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1)
(h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1)
Expand Down
2 changes: 1 addition & 1 deletion Proofs/Experiments/SHA512MemoryAliasing.lean
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ work for `16#64 + ktbl_addr`?
#time theorem sha512_block_armv8_loop_sym_ktbl_access (s1 : ArmState)
(_h_s1_err : read_err s1 = StateError.None)
(_h_s1_sp_aligned : CheckSPAlignment s1)
(h_s1_pc : read_pc s1 = 0x126500#64)
(_h_s1_pc : read_pc s1 = 0x126500#64)
(_h_s1_program : s1.program = sha512_program)
(h_s1_num_blocks : num_blocks s1 = 1)
(_h_s1_x3 : r (StateField.GPR 3#5) s1 = ktbl_addr)
Expand Down
2 changes: 1 addition & 1 deletion Proofs/Proofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import «Proofs».«SHA512».SHA512
import Proofs.«AES-GCM».GCM
import Proofs.Popcount32

/- Experiments we use to test proof strategies and automation ideas. -/
-- /- Experiments we use to test proof strategies and automation ideas. -/
import Proofs.Experiments.Summary1
import Proofs.Experiments.MemoryAliasing
import Proofs.Experiments.SHA512MemoryAliasing
Expand Down
1 change: 1 addition & 0 deletions Proofs/SHA512/SHA512.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ Author(s): Shilpi Goel
-/
import Proofs.SHA512.SHA512_block_armv8_rules
import Proofs.SHA512.SHA512Sym
import Proofs.SHA512.SHA512BlockSym
Loading
Loading