Skip to content

Commit

Permalink
Add reference implementation of prune_updates (#246)
Browse files Browse the repository at this point in the history
### Description:

A detailed description of your contribution. Why is this change
necessary?

### Testing:

What tests have been run? Did `make all` succeed for your changes? Was
conformance testing successful on an Aarch64 machine?

### License:

By submitting this pull request, I confirm that my contribution is
made under the terms of the Apache 2.0 license.

---------

Co-authored-by: Siddharth <[email protected]>
  • Loading branch information
shigoel and bollu authored Nov 13, 2024
1 parent 16a1dd9 commit 6847ebc
Show file tree
Hide file tree
Showing 24 changed files with 17,580 additions and 80 deletions.
5 changes: 5 additions & 0 deletions Arm/Exec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ theorem run_onestep {s s': ArmState} {n : Nat} :
(s' = run (n + 1) s) → ∃ s'', stepi s = s'' ∧ s' = run n s'' := by
simp only [run, exists_eq_left', imp_self]

/--
Split a run into two consecutive blocks: if `s'` equals `run (n1 + n2) s`,
then there is an intermediate state `s''`, equal to `run n1 s`, such that
`s'` equals `run n2 s''`. This is the multi-step counterpart of
`run_onestep`, which peels off a single `stepi`.
-/
theorem run_oneblock {s s' : ArmState} {n1 n2 : Nat} :
(s' = run (n1 + n2) s) →
∃ s'', run n1 s = s'' ∧ s' = run n2 s'' := by
-- NOTE(review): `run_plus` is defined elsewhere; presumably it rewrites
-- `run (n1 + n2) s` into `run n2 (run n1 s)` — confirm against its statement.
simp only [run_plus, exists_eq_left', imp_self]

/-- helper lemma for automation -/
theorem stepi_eq_of_fetch_inst_of_decode_raw_inst
(s : ArmState) (addr : BitVec 64) (rawInst : BitVec 32) (inst : ArmInst)
Expand Down
22 changes: 22 additions & 0 deletions Arm/State.lean
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ theorem store_write_irrelevant [DecidableEq α]
unfold write_store read_store
simp

/--
Writes to distinct indices of a `Store` commute: when `i ≠ j`, writing `x` at
`i` on top of a write of `y` at `j` gives the same store as performing the two
writes in the opposite order.
-/
@[local simp]
theorem write_store_commute [DecidableEq α] (store : Store α β)
(h : i ≠ j) :
write_store i x (write_store j y store) =
write_store j y (write_store i x store) := by
-- Compare the two stores pointwise, at an arbitrary index `idx`.
apply funext; intro idx
simp [write_store]
-- Case-split on the conditional remaining in the goal; `simp_all` closes
-- each case using the hypotheses in context (including `h`).
split <;> simp_all

instance [Repr β]: Repr (Store (BitVec n) β) where
reprPrec store _ :=
let rec helper (a : Nat) (acc : Lean.Format) :=
Expand Down Expand Up @@ -451,6 +460,19 @@ theorem w_irrelevant : w fld (r fld s) s = s := by
unfold read_base_error write_base_error
repeat (split <;> simp_all)

/--
Writes to distinct state fields commute: when `fld1 ≠ fld2`, applying
`w fld1 v1` after `w fld2 v2` to `s` gives the same state as applying them in
the opposite order.
-/
theorem w_of_w_commute (h : fld1 ≠ fld2) :
w fld1 v1 (w fld2 v2 s) = w fld2 v2 (w fld1 v1 s) := by
-- Unfold `w` and each of the `write_base_*` functions.
unfold w
unfold write_base_gpr
unfold write_base_sfp
unfold write_base_pc
unfold write_base_flag
unfold write_base_error
-- Case-split twice (once per write); after `simp_all`, try to finish a
-- case by rewriting with `write_store_commute`.
split <;> split <;> (simp_all; try rwa [write_store_commute])
-- Name the inaccessible variables of the goal that is left, then case on
-- the two indices `i1` and `i2`.
rename_i fld1 t1 i1 v1 fld2 t2 i2 v2
cases i1 <;> (cases i2 <;> (split <;> simp_all))
done

@[state_simp_rules]
theorem fetch_inst_of_w : fetch_inst addr (w fld val s) = fetch_inst addr s := by
unfold fetch_inst w
Expand Down
149 changes: 90 additions & 59 deletions Proofs/AES-GCM/GCMGmultV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -134,31 +134,80 @@ multiplication into four constituent ones, for normalization.
example :
let p := 0b11#2
let q := 0b10#2
let w := 0b01#2
let z := 0b01#2
let x := 0b01#2
let y := 0b01#2
(DPSFP.polynomial_mult
(p ++ q)
(w ++ z))
(x ++ y))
=
((DPSFP.polynomial_mult p w) ++ 0#4) ^^^
(0#4 ++ (DPSFP.polynomial_mult q z)) ^^^
(0#2 ++ (DPSFP.polynomial_mult p z) ++ 0#2) ^^^
(0#2 ++ (DPSFP.polynomial_mult q w) ++ 0#2) := by native_decide
((DPSFP.polynomial_mult p x) ++ 0#4) ^^^
(0#4 ++ (DPSFP.polynomial_mult q y)) ^^^
(0#2 ++ (DPSFP.polynomial_mult p y) ++ 0#2) ^^^
(0#2 ++ (DPSFP.polynomial_mult q x) ++ 0#2) := by native_decide

/--
Randomized check of the four-way decomposition of `DPSFP.polynomial_mult` on
appended operands. Draws `p`, `q`, `x`, `y` with `BitVec.rand 64` and returns
whether `DPSFP.polynomial_mult (p ++ q) (x ++ y)` equals the `^^^` of the four
products `p·x`, `q·y`, `p·y`, `q·x`, each padded with zeros via `++` as written
below. This is the same identity stated by `DPSFP.polynomial_mult_append`.
-/
def pmult_test_1 : IO Bool := do
let p ← BitVec.rand 64
let q ← BitVec.rand 64
let x ← BitVec.rand 64
let y ← BitVec.rand 64
pure
(DPSFP.polynomial_mult (p ++ q) (x ++ y) ==
((DPSFP.polynomial_mult p x) ++ 0#128) ^^^
(0#128 ++ (DPSFP.polynomial_mult q y)) ^^^
(0#64 ++ (DPSFP.polynomial_mult p y) ++ 0#64) ^^^
(0#64 ++ (DPSFP.polynomial_mult q x) ++ 0#64))

-- Evaluate the random check; `#guard_msgs` pins the expected output to `true`.
/--
info: true
-/
#guard_msgs in
#eval pmult_test_1

/--
Decomposition of `DPSFP.polynomial_mult` on appended 64-bit operands: the
product of `p ++ q` and `x ++ y` equals the `^^^` of the four 64-bit products
`p·x`, `q·y`, `p·y` and `q·x`, each padded with zeros via `++` as written in
the statement.

NOTE: the proof is currently `sorry`, so this statement is assumed rather than
proved. The same identity is spot-checked on random inputs by `pmult_test_1`.
-/
theorem DPSFP.polynomial_mult_append {p q x y : BitVec 64} :
DPSFP.polynomial_mult (p ++ q) (x ++ y) =
((DPSFP.polynomial_mult p x) ++ 0#128) ^^^
(0#128 ++ (DPSFP.polynomial_mult q y)) ^^^
(0#64 ++ (DPSFP.polynomial_mult p y) ++ 0#64) ^^^
(0#64 ++ (DPSFP.polynomial_mult q x) ++ 0#64) := by
sorry

/-
Source: Function `GCMInitV8` in `Specs/GCMV8.lean`:
Note that `H0` is the 128-bit HTable input to `gcm_gmult_v8`.
let H2 := GCMV8.gcm_polyval H0 H0
let H1 := ((hi H2) ^^^ (lo H2)) ++ ((hi H0) ^^^ (lo H0))
-/

set_option maxRecDepth 8000 in
set_option maxHeartbeats 500000 in
set_option pp.deepTerms false in
set_option pp.deepTerms.threshold 50 in
-- set_option trace.simp_mem.info true in
#time theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
{H1_hi H1_lo H0_hi H0_lo : BitVec 64}
(h_s0_program : s0.program = gcm_gmult_v8_program)
(h_s0_err : read_err s0 = .None)
(h_s0_pc : read_pc s0 = gcm_gmult_v8_program.min)
(h_s0_sp_aligned : CheckSPAlignment s0)
(h_Xi : Xi = s0[read_gpr 64 0#5 s0, 16])
(h_HTable : HTable = s0[read_gpr 64 1#5 s0, 32])
(h_HTable_lo : H0_hi ++ H0_lo = s0[read_gpr 64 1#5 s0, 16])
(h_HTable_hi : H1_hi ++ H1_lo = s0[read_gpr 64 1#5 s0 + 16#64, 16])
-- (h_HTable : HTable = s0[read_gpr 64 1#5 s0, 32])
-- (h_HTable_alt : HTable = H1_hi ++ H1_lo ++ H0_hi ++ H0_lo)
(h_H1_low_64 : H1_lo = H0_hi ^^^ H0_lo)
-- (h_H1 : HTable.extractLsb' 128 128 =
-- let H0 := HTable.extractLsb' 0 128
-- let H2 := GCMV8.gcm_polyval H0 H0
-- let H0_hi := H0.extractLsb' 64 64
-- let H0_lo := H0.extractLsb' 0 64
-- let H2_hi := H2.extractLsb' 64 64
-- let H2_lo := H2.extractLsb' 0 64
-- ((H2_hi) ^^^ (H2_lo)) ++ ((H0_hi) ^^^ (H0_lo)))
(h_mem_sep : Memory.Region.pairwiseSeparate
[(read_gpr 64 0#5 s0, 16),
(read_gpr 64 1#5 s0, 32)])
(read_gpr 64 1#5 s0, 16),
(read_gpr 64 1#5 s0 + 16#64, 16)])
(h_run : sf = run gcm_gmult_v8_program.length s0) :
-- The final state is error-free.
read_err sf = .None ∧
Expand All @@ -168,11 +217,6 @@ set_option pp.deepTerms.threshold 50 in
CheckSPAlignment sf ∧
-- The final state returns to the address in register `x30` in `s0`.
read_pc sf = r (StateField.GPR 30#5) s0 ∧
-- (TODO) Delete the following conjunct because it is covered by the
-- MEM_UNCHANGED_EXCEPT frame condition. We keep it around because it
-- exposes the issue with `simp_mem` that @bollu will fix.
-- HTable is unmodified.
sf[read_gpr 64 1#5 s0, 32] = HTable ∧
-- Frame conditions.
-- Note that the following also covers that the Xi address in .GPR 0
-- is unmodified.
Expand All @@ -185,24 +229,24 @@ set_option pp.deepTerms.threshold 50 in
sf[r (.GPR 0) s0, 16] =
rev_elems 128 8
(GCMV8.GCMGmultV8_alt
(HTable.extractLsb' 0 128)
(H0_hi ++ H0_lo)
(rev_elems 128 8 Xi (by decide) (by decide)))
(by decide) (by decide) := by
-- Prelude
simp_all only [state_simp_rules, -h_run]
simp only [Nat.reduceMul] at Xi HTable
-- simp only [Nat.reduceMul] at Xi HTable
simp only [Nat.reduceMul] at Xi
simp (config := {ground := true}) only at h_s0_pc
-- ^^ Still needed, because `gcm_gmult_v8_program.min` is somehow
-- unable to be reflected

sym_n 27
-- Epilogue
simp only [←Memory.mem_eq_iff_read_mem_bytes_eq] at *
-- simp only [←Memory.mem_eq_iff_read_mem_bytes_eq] at *
simp only [memory_rules] at *
sym_aggregate
-- Split conjunction
repeat' apply And.intro
· simp_mem; rfl
-- · simp_mem; rfl
· simp only [List.mem_cons, List.mem_singleton, not_or, and_imp] at *
sym_aggregate
· intro n addr h_separate
Expand All @@ -211,27 +255,9 @@ set_option pp.deepTerms.threshold 50 in
simp_mem sep with [h_separate]
· clear_named [h_s, stepi_]
clear s1 s2 s3 s4 s5 s6 s7 s8 s9 s10 s11 s12 s13 s14 s15 s16 s17 s18 s19 s20 s21 s22 s23 s24 s25 s26

-- Simplifying the LHS
have h_HTable_low :
Memory.read_bytes 16 (r (StateField.GPR 1#5) s0) s0.mem = HTable.extractLsb' 0 128 := by
-- (FIXME @bollu) use `simp_mem` instead of the rw below.
-- conv =>
-- lhs
-- simp_mem sub r
rw [@Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset'
32 (r (StateField.GPR 1#5) s0) HTable (r (StateField.GPR 1#5) s0) 16 _ h_HTable.symm]
· simp only [Nat.reduceMul, BitVec.extractLsBytes, Nat.sub_self, Nat.zero_mul]
· mem_omega!
have h_HTable_high :
(Memory.read_bytes 16 (r (StateField.GPR 1#5) s0 + 16#64) s0.mem) = HTable.extractLsb' 128 128 := by
-- (FIXME @bollu) use `simp_mem` instead of the rw below.
-- conv =>
-- lhs
-- simp_mem sub r
rw [@Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset'
32 (r (StateField.GPR 1#5) s0) HTable (r (StateField.GPR 1#5) s0 + 16#64) 16 _ h_HTable.symm]
repeat sorry
simp only [h_HTable_high, h_HTable_low, ←h_Xi]
simp only [←h_Xi]
clear h_mem_sep h_run
/-
simp/ground below to reduce
Expand All @@ -251,11 +277,12 @@ set_option pp.deepTerms.threshold 50 in
rw [@vrev128_64_8_in_terms_of_rev_elems (by decide) (by decide) (by decide) (by decide)] at h_Xi_rev
generalize h_Xi_upper_rev : rev_elems 64 8 (BitVec.extractLsb' 64 64 Xi) (by decide) (by decide) = Xi_upper_rev
generalize h_Xi_lower_rev : rev_elems 64 8 (BitVec.extractLsb' 0 64 Xi) (by decide) (by decide) = Xi_lower_rev
-- Simplifying the RHS

simp only [GCMV8.GCMGmultV8_alt,
GCMV8.lo, GCMV8.hi,
GCMV8.gcm_polyval,
←h_HTable, ←h_Xi_rev, h_Xi_lower_rev, h_Xi_upper_rev]
←h_HTable_lo, ←h_HTable_hi,
←h_Xi_rev, h_Xi_lower_rev, h_Xi_upper_rev]
simp only [pmull_op_e_0_eize_64_elements_1_size_128_eq, gcm_polyval_mul_eq_polynomial_mult]
simp only [zeroExtend_allOnes_lsh_64, zeroExtend_allOnes_lsh_0]
rw [BitVec.extractLsb'_64_128_of_appends]
Expand All @@ -265,32 +292,36 @@ set_option pp.deepTerms.threshold 50 in
repeat rw [BitVec.extractLsb'_zero_extractLsb'_of_le (by decide)]
repeat rw [BitVec.extractLsb'_extractLsb'_zero_of_le (by decide)]
rw [BitVec.and_high_to_extractLsb'_concat]
generalize h_HTable_upper : (BitVec.extractLsb' 64 64 HTable) = HTable_upper
generalize h_HTable_lower : (BitVec.extractLsb' 0 64 HTable) = HTable_lower
generalize h_term_u0u1 : (DPSFP.polynomial_mult HTable_upper Xi_upper_rev) = u0u1 at *
generalize h_term_l0l1 : (DPSFP.polynomial_mult HTable_lower Xi_lower_rev) = l0l1 at *
generalize h_term_1 : (DPSFP.polynomial_mult (BitVec.extractLsb' 128 64 HTable) (Xi_lower_rev ^^^ Xi_upper_rev) ^^^
BitVec.extractLsb' 64 128 (l0l1 ++ u0u1) ^^^
(u0u1 ^^^ l0l1)) = term_1
generalize h_term_2 : ((term_1 &&& 0xffffffffffffffff#128 ||| BitVec.zeroExtend 128 (BitVec.setWidth 64 u0u1) <<< 64) ^^^
DPSFP.polynomial_mult (BitVec.extractLsb' 0 64 u0u1) 0xc200000000000000#64)
= term_2
generalize h_term_3 : (BitVec.extractLsb' 64 128 (term_2 ++ term_2) ^^^
(BitVec.extractLsb' 64 64 l0l1 ++ 0x0#64 |||
BitVec.zeroExtend 128 (BitVec.extractLsb' 64 64 term_1) <<< 0))
= term_3

rw [@vrev128_64_8_in_terms_of_rev_elems (by decide) (by decide) (by decide) (by decide)]
rw [BitVec.extractLsb'_64_128_of_appends]
rw [@rev_elems_64_8_append_eq_rev_elems_128_8 _ _ (by decide) (by decide) (by decide) (by decide)]
apply eq_of_rev_elems_eq
rw [@rev_elems_128_8_eq_rev_elems_64_8_extractLsb' _ (by decide) (by decide) (by decide) (by decide) (by decide)]
rw [h_Xi_upper_rev, h_Xi_lower_rev]
rw [BitVec.extractLsb'_append_eq]
simp [GCMV8.gcm_polyval_red]
-- have h_reduce : (GCMV8.reduce 0x100000000000000000000000000000087#129 0x1#129) = 1#129 := by native_decide
-- simp [GCMV8.gcm_polyval_red, GCMV8.irrepoly, GCMV8.pmod, h_reduce]
-- repeat (unfold GCMV8.pmod.pmodTR; simp)
simp only [BitVec.truncate_eq_setWidth, Nat.reduceAdd, BitVec.shiftLeft_zero_eq]

simp [DPSFP.polynomial_mult_append]
simp [GCMV8.gcm_polyval_red, GCMV8.irrepoly]

generalize h_term_1 : DPSFP.polynomial_mult H0_lo Xi_lower_rev = term1
generalize h_term_2 : DPSFP.polynomial_mult H0_hi Xi_upper_rev = term2

-- (TODO) Can we remove `reverse` from `pmod` in the RHS?

-- have h_reduce : (GCMV8.reduce 0x100000000000000000000000000000087#129 0x1#129) = 1#129 := by native_decide
--
-- simp only [GCMV8.gcm_polyval_red, GCMV8.irrepoly,
-- GCMV8.pmod, GCMV8.pmod.pmodTR,
-- GCMV8.reduce, GCMV8.degree, GCMV8.degree.degreeTR]
-- simp only [Nat.reduceAdd, BitVec.ushiftRight_eq, BitVec.reduceExtracLsb',
-- BitVec.reduceHShiftLeft, BitVec.reduceAppend, BitVec.reduceHShiftRight, BitVec.ofNat_eq_ofNat,
-- BitVec.reduceEq, ↓reduceIte, Nat.sub_self, BitVec.ushiftRight_zero_eq, BitVec.reduceAnd,
-- BitVec.toNat_ofNat, Nat.pow_one, Nat.reduceMod, Nat.mul_zero, Nat.add_zero, Nat.zero_mod,
-- Nat.zero_add, Nat.sub_zero, Nat.mul_one, Nat.zero_mul, Nat.one_mul, Nat.reduceSub,
-- BitVec.and_self, BitVec.zero_and, BitVec.reduceMul, BitVec.xor_zero, BitVec.mul_one,
-- BitVec.zero_xor, Nat.add_one_sub_one, BitVec.one_mul, BitVec.reduceXOr]
sorry
done

Expand Down
2 changes: 1 addition & 1 deletion Proofs/Bit_twiddling.lean
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ theorem power_of_two (x : BitVec 32) (i : BitVec 5)

-- ============================================================

set_option sat.solver "cadical" in
/--
6. Conditionally set or clear bits without branching
(https://graphics.stanford.edu/~seander/bithacks.html#ConditionalSetOrClearBitsWithoutBranching)
Expand All @@ -159,7 +160,6 @@ unsigned int w; // the word to modify: if (f) w |= m; else w &= ~m;
w ^= (-f ^ w) & m;
-/

theorem set_clear_no_branch (x : BitVec 32) (f : Bool) (mask : BitVec 32) :
(if f then (x ||| mask) else (x &&& ~~~mask)) =
(x ^^^ (((-(BitVec.zeroExtend 32 (BitVec.ofBool f))) ^^^ x) &&& mask)) := by
Expand Down
4 changes: 3 additions & 1 deletion Proofs/Experiments/MemoryAliasing.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import Arm.Memory.MemoryProofs
import Arm.BitVec
import Arm.Memory.SeparateAutomation



-- set_option trace.simp_mem true
-- set_option trace.simp_mem.info true
-- set_option trace.Meta.Tactic.simp true
Expand Down Expand Up @@ -170,7 +172,7 @@ set_option linter.all false in
mem_omega

set_option linter.all false in
set_option trace.simp_mem.info true in
-- set_option trace.simp_mem.info true in
#time theorem mem_separate_11 (h : mem_separate' a 100 b 100)
(h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1)
(h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1) (h' : a < b + 1)
Expand Down
2 changes: 1 addition & 1 deletion Proofs/Experiments/SHA512MemoryAliasing.lean
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ work for `16#64 + ktbl_addr`?
#time theorem sha512_block_armv8_loop_sym_ktbl_access (s1 : ArmState)
(_h_s1_err : read_err s1 = StateError.None)
(_h_s1_sp_aligned : CheckSPAlignment s1)
(h_s1_pc : read_pc s1 = 0x126500#64)
(_h_s1_pc : read_pc s1 = 0x126500#64)
(_h_s1_program : s1.program = sha512_program)
(h_s1_num_blocks : num_blocks s1 = 1)
(_h_s1_x3 : r (StateField.GPR 3#5) s1 = ktbl_addr)
Expand Down
2 changes: 1 addition & 1 deletion Proofs/Proofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import «Proofs».«SHA512».SHA512
import Proofs.«AES-GCM».GCM
import Proofs.Popcount32

/- Experiments we use to test proof strategies and automation ideas. -/
-- /- Experiments we use to test proof strategies and automation ideas. -/
import Proofs.Experiments.Summary1
import Proofs.Experiments.MemoryAliasing
import Proofs.Experiments.SHA512MemoryAliasing
Expand Down
1 change: 1 addition & 0 deletions Proofs/SHA512/SHA512.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ Author(s): Shilpi Goel
-/
import Proofs.SHA512.SHA512_block_armv8_rules
import Proofs.SHA512.SHA512Sym
import Proofs.SHA512.SHA512BlockSym
Loading

0 comments on commit 6847ebc

Please sign in to comment.