Skip to content

Commit

Permalink
Define PState as a structure
Browse files Browse the repository at this point in the history
  • Loading branch information
shigoel committed Mar 4, 2024
1 parent aafe462 commit 0a025d4
Show file tree
Hide file tree
Showing 12 changed files with 134 additions and 71 deletions.
10 changes: 5 additions & 5 deletions Arm/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ open BitVec

-- Adding some useful simp lemmas to `bitvec_rules`:
attribute [bitvec_rules] BitVec.ofFin_eq_ofNat
attribute [minimal_theory] BitVec.extractLsb_ofFin
attribute [minimal_theory] zeroExtend_eq
attribute [minimal_theory] add_ofFin
attribute [minimal_theory] ofFin.injEq
attribute [minimal_theory] Fin.mk.injEq
attribute [bitvec_rules] BitVec.extractLsb_ofFin
attribute [bitvec_rules] BitVec.zeroExtend_eq
attribute [bitvec_rules] BitVec.ofFin.injEq
attribute [bitvec_rules] BitVec.extractLsb_toNat
-- attribute [bitvec_rules] add_ofFin

----------------------------------------------------------------------
-- Some BitVec definitions
Expand Down
2 changes: 1 addition & 1 deletion Arm/Cosim.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def init_cosim_state : ArmState :=
{ gpr := (fun (_ : BitVec 5) => 0#64),
sfp := (fun (_ : BitVec 5) => 0#128),
pc := 0#64,
pstate := (fun (_ : PFlag) => 0#1),
pstate := zero_pstate,
mem := (fun (_ : BitVec 64) => 0#8),
program := (fun (_ : BitVec 64) => none),
error := StateError.None }
Expand Down
14 changes: 8 additions & 6 deletions Arm/Insts/BR/Cond_branch_imm.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@ namespace BR
open BitVec

@[state_simp_rules]
def Cond_branch_imm_inst.branch_taken_pc (inst : Cond_branch_imm_inst) (pc : BitVec 64) : BitVec 64 :=
def Cond_branch_imm_inst.branch_taken_pc
(inst : Cond_branch_imm_inst) (pc : BitVec 64) : BitVec 64 :=
let offset := signExtend 64 (inst.imm19 <<< 2)
pc + offset

@[state_simp_rules]
def Cond_branch_imm_inst.condition_holds (inst : Cond_branch_imm_inst) (s : ArmState): Bool :=
let Z := read_store PFlag.Z s.pstate
let C := read_store PFlag.C s.pstate
let N := read_store PFlag.N s.pstate
let V := read_store PFlag.V s.pstate
def Cond_branch_imm_inst.condition_holds
(inst : Cond_branch_imm_inst) (s : ArmState) : Bool :=
let Z := read_flag PFlag.Z s
let C := read_flag PFlag.C s
let N := read_flag PFlag.N s
let V := read_flag PFlag.V s
let result :=
match (extractLsb 3 1 inst.cond) with
| 0b000#3 => Z == 1#1
Expand Down
46 changes: 33 additions & 13 deletions Arm/Insts/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ open BitVec

----------------------------------------------------------------------

def AddWithCarry (x : BitVec n) (y : BitVec n) (carry_in : BitVec 1) : (BitVec n × PState) :=
def AddWithCarry (x : BitVec n) (y : BitVec n) (carry_in : BitVec 1) :
(BitVec n × PState) :=
let carry_in_nat := BitVec.toNat carry_in
let unsigned_sum := BitVec.toNat x + BitVec.toNat y + carry_in_nat
let signed_sum := BitVec.toInt x + BitVec.toInt y + carry_in_nat
Expand All @@ -24,17 +25,21 @@ def AddWithCarry (x : BitVec n) (y : BitVec n) (carry_in : BitVec 1) : (BitVec n
let V := if BitVec.toInt result = signed_sum then 0#1 else 1#1
(result, (make_pstate N Z C V))

def ConditionHolds (cond : BitVec 4) (pstate : PState) : Bool :=
def ConditionHolds (cond : BitVec 4) (s : ArmState) : Bool :=
open PFlag in
let N := read_flag N s
let Z := read_flag Z s
let C := read_flag C s
let V := read_flag V s
let result :=
match (extractLsb 3 1 cond) with
| 0b000#3 => pstate Z == 1#1
| 0b001#3 => pstate C == 1#1
| 0b010#3 => pstate N == 1#1
| 0b011#3 => pstate V == 1#1
| 0b100#3 => pstate C == 1#1 && pstate Z == 0#1
| 0b101#3 => pstate N == pstate V
| 0b110#3 => pstate N == pstate V && pstate Z == 0#1
| 0b000#3 => Z == 1#1
| 0b001#3 => C == 1#1
| 0b010#3 => N == 1#1
| 0b011#3 => V == 1#1
| 0b100#3 => C == 1#1 && Z == 0#1
| 0b101#3 => N == V
| 0b110#3 => N == V && Z == 0#1
| 0b111#3 => true
if (extractLsb 0 0 cond) = 1#1 && cond ≠ 0b1111#4 then
not result
Expand All @@ -51,6 +56,21 @@ def CheckSPAlignment (s : ArmState) : Bool :=
-- 16-aligned.
(extractLsb 3 0 sp) &&& 0xF#4 == 0#4

@[state_simp_rules]
theorem CheckSPAligment_of_w_different (h : StateField.GPR 31#5 ≠ fld) :
CheckSPAlignment (w fld v s) = CheckSPAlignment s := by
simp_all only [CheckSPAlignment, state_simp_rules, minimal_theory, bitvec_rules]

@[state_simp_rules]
theorem CheckSPAligment_of_w_sp :
CheckSPAlignment (w (StateField.GPR 31#5) v s) = ((extractLsb 3 0 v) &&& 0xF#4 == 0#4) := by
simp_all only [CheckSPAlignment, state_simp_rules, minimal_theory, bitvec_rules]

@[state_simp_rules]
theorem CheckSPAligment_of_write_mem_bytes :
CheckSPAlignment (write_mem_bytes n addr v s) = CheckSPAlignment s := by
simp_all only [CheckSPAlignment, state_simp_rules, minimal_theory, bitvec_rules]

----------------------------------------------------------------------

inductive ShiftType where
Expand Down Expand Up @@ -262,7 +282,7 @@ def rev_elems (n esize : Nat) (x : BitVec n) (h₀ : esize ∣ n) (h₁ : 0 < es
else
let element := BitVec.zeroExtend esize x
let rest_x := BitVec.zeroExtend (n - esize) (x >>> esize)
have h1 : esize <= n := by
have h1 : esize <= n := by
simp at h0; exact Nat.le_of_lt h0; done
have h2 : esize ∣ (n - esize) := by
refine Nat.dvd_sub ?H h₀ ?h₂
Expand All @@ -279,12 +299,12 @@ def rev_elems (n esize : Nat) (x : BitVec n) (h₀ : esize ∣ n) (h₁ : 0 < es
example : rev_elems 4 4 0xA#4 (by decide) (by decide) = 0xA#4 := rfl
example : rev_elems 8 4 0xAB#8 (by decide) (by decide) = 0xBA#8 := rfl
example : rev_elems 8 4 (rev_elems 8 4 0xAB#8 (by decide) (by decide))
(by decide) (by decide) = 0xAB#8 := by native_decide
(by decide) (by decide) = 0xAB#8 := by native_decide

theorem rev_elems_base :
rev_elems esize esize x h₀ h₁ = x := by
unfold rev_elems; simp; done

/-- Divide a bv of width `datasize` into containers, each of size
`container_size`, and within a container, reverse the order of `esize`-bit
elements. -/
Expand All @@ -302,7 +322,7 @@ def rev_vector (datasize container_size esize : Nat) (x : BitVec datasize)
have h₄' : container_size ∣ new_datasize := by
have h : container_size ∣ container_size := Nat.dvd_refl _
exact Nat.dvd_sub h₂ h₄ h
have h₂' : container_size <= new_datasize := by
have h₂' : container_size <= new_datasize := by
refine Nat.le_of_dvd ?h h₄'
omega
have h1 : 0 < container_size := by exact Nat.lt_of_lt_of_le h₀ h₁
Expand Down
2 changes: 1 addition & 1 deletion Arm/Insts/DPR/Conditional_select.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def exec_conditional_select (inst : Conditional_select_cls) (s : ArmState) : Arm
match inst.op, inst.S, inst.op2 with
| 0b0#1, 0b0#1, 0b00#2 => -- CSEL
(false,
if ConditionHolds inst.cond (read_pstate s) then
if ConditionHolds inst.cond s then
read_gpr_zr datasize inst.Rn s
else
read_gpr_zr datasize inst.Rm s)
Expand Down
1 change: 1 addition & 0 deletions Arm/MinTheory.lean
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,6 @@ attribute [minimal_theory] Option.isNone_some

attribute [minimal_theory] Fin.isValue
attribute [minimal_theory] Fin.zero_eta
attribute [minimal_theory] Fin.mk.injEq

-- attribute [minimal_theory] ↓reduceIte
65 changes: 37 additions & 28 deletions Arm/State.lean
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,12 @@ inductive PFlag where
| V : PFlag
deriving DecidableEq, Repr

abbrev PState := Store PFlag (BitVec 1)

instance [Repr β]: Repr (Store PFlag β) where
reprPrec store _ :=
let rec helper (ps : List PFlag) :=
match ps with
| [] => ""
| p :: rest => "(" ++ repr p ++ " : " ++ (repr (read_store p store)) ++ ") " ++ helper rest
open PFlag in
helper [N, Z, C, V]

-- def init_store : Store PFlag (BitVec 1) := (fun (_ : PFlag) => 0#1)
-- #eval init_store
structure PState where
n : BitVec 1
z : BitVec 1
c : BitVec 1
v : BitVec 1
deriving DecidableEq, Repr

structure ArmState where
-- General-purpose registers: register 31 is the stack pointer.
Expand Down Expand Up @@ -182,10 +175,23 @@ def write_base_pstate (pstate : PState) (s : ArmState) : ArmState :=
{ s with pstate := pstate }

def read_base_flag (flag : PFlag) (s : ArmState) : BitVec 1 :=
read_store flag s.pstate
open PFlag in
let pstate := s.pstate
match flag with
| N => pstate.n
| Z => pstate.z
| C => pstate.c
| V => pstate.v

def write_base_flag (flag : PFlag) (val : BitVec 1) (s : ArmState) : ArmState :=
let new_pstate := write_store flag val s.pstate
open PFlag in
let pstate := s.pstate
let new_pstate :=
match flag with
| N => { pstate with n := val }
| Z => { pstate with z := val }
| C => { pstate with c := val }
| V => { pstate with v := val }
{ s with pstate := new_pstate }

-- Program --
Expand Down Expand Up @@ -261,7 +267,7 @@ theorem r_of_w_same (fld : StateField) (v : (state_value fld)) (s : ArmState) :
unfold read_base_pc write_base_pc
unfold read_base_flag write_base_flag
unfold read_base_error write_base_error
split <;> split <;> simp_all!
split <;> (repeat (split <;> simp_all!))

@[state_simp_rules]
theorem r_of_w_different (fld1 fld2 : StateField) (v : (state_value fld2)) (s : ArmState)
Expand All @@ -274,7 +280,7 @@ theorem r_of_w_different (fld1 fld2 : StateField) (v : (state_value fld2)) (s :
unfold read_base_flag write_base_flag
unfold read_base_error write_base_error
simp_all!
split <;> split <;> simp_all!
split <;> (repeat (split <;> simp_all!))

@[state_simp_rules]
theorem w_of_w_shadow (fld : StateField) (v1 v2 : (state_value fld)) (s : ArmState) :
Expand All @@ -285,7 +291,7 @@ theorem w_of_w_shadow (fld : StateField) (v1 v2 : (state_value fld)) (s : ArmSta
unfold write_base_pc
unfold write_base_flag
unfold write_base_error
split <;> simp
(repeat (split <;> simp_all!))

@[state_simp_rules]
theorem w_irrelevant (fld : StateField) (v1 v2 : (state_value fld)) (s : ArmState) :
Expand All @@ -296,7 +302,7 @@ theorem w_irrelevant (fld : StateField) (v1 v2 : (state_value fld)) (s : ArmStat
unfold read_base_pc write_base_pc
unfold read_base_flag write_base_flag
unfold read_base_error write_base_error
split <;> simp
repeat (split <;> simp_all)

@[state_simp_rules]
theorem fetch_inst_of_w (addr : BitVec 64) (fld : StateField) (val : (state_value fld)) (s : ArmState) :
Expand Down Expand Up @@ -396,20 +402,23 @@ def write_flag (flag : PFlag) (val : BitVec 1) (s : ArmState) : ArmState :=

@[state_simp_rules]
def read_pstate (s : ArmState) : PState :=
fun p => read_flag p s
s.pstate

@[state_simp_rules]
def write_pstate (pstate : PState) (s : ArmState) : ArmState :=
{ s with pstate := pstate }
open StateField PFlag in
let s := w (FLAG N) pstate.n s
let s := w (FLAG Z) pstate.z s
let s := w (FLAG C) pstate.c s
let s := w (FLAG V) pstate.v s
s

-- (FIXME) Define in terms of write_flag so that we see checkpoints in
-- terms of the w function.
@[state_simp_rules]
def make_pstate (n z c v : BitVec 1) : PState :=
fun (p : PFlag) =>
open PFlag in
match p with
| N => n | Z => z | C => c | V => v
{ n, z, c, v }

def zero_pstate : PState :=
{ n := 0#1, z := 0#1, c := 0#1, v := 0#1 }

@[state_simp_rules]
def read_err (s : ArmState) : StateError :=
Expand Down Expand Up @@ -467,7 +476,7 @@ end Load_program_and_fetch_inst

example :
read_flag flag (write_flag flag val s) = val := by
simp only [state_simp_rules, minimal_theory]
simp only [state_simp_rules, minimal_theory]

example (h : flag1 ≠ flag2) :
read_flag flag1 (write_flag flag2 val s) = read_flag flag1 s := by
Expand Down
51 changes: 41 additions & 10 deletions Proofs/Sha512_block_armv8.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ open BitVec

def sha512_program_test_1 : program :=
def_program
[(0x126538#64 , 0xcec08230#32), -- sha512su0 v16.2d, v17.2d
[(0x126538#64 , 0xcec08230#32), -- sha512su0 v16.2d, v17.2d
(0x12653c#64 , 0x6e154287#32), -- ext v7.16b, v20.16b, v21.16b, #8
(0x126540#64 , 0xce6680a3#32), -- sha512h q3, q5, v6.2d
(0x126544#64 , 0xce678af0#32) -- sha512su1 v16.2d, v23.2d, v7.2d
Expand All @@ -46,12 +46,11 @@ theorem sha512_program_test_1_sym (s0 s_final : ArmState)

----------------------------------------------------------------------

-- Test 2: We have a 5 instruction program, and we attempt to simulate
-- only 4 of them.
-- Test 2: A 6 instruction program.

def sha512_program_test_2 : program :=
def_program
[(0x126538#64 , 0xcec08230#32), -- sha512su0 v16.2d, v17.2d
[(0x126538#64 , 0xcec08230#32), -- sha512su0 v16.2d, v17.2d
(0x12653c#64 , 0x6e154287#32), -- ext v7.16b, v20.16b, v21.16b, #8
(0x126540#64 , 0xce6680a3#32), -- sha512h q3, q5, v6.2d
(0x126544#64 , 0xce678af0#32), -- sha512su1 v16.2d, v23.2d, v7.2d
Expand All @@ -64,15 +63,15 @@ theorem sha512_program_test_2_sym (s0 s_final : ArmState)
(h_s0_pc : read_pc s0 = 0x126538#64)
(h_s0_program : s0.program = sha512_program_test_2.find?)
(h_s0_ok : read_err s0 = StateError.None)
(h_run : s_final = run 4 s0) :
(h_run : s_final = run 6 s0) :
read_err s_final = StateError.None := by
-- Prelude
simp_all only [state_simp_rules, -h_run]
-- Symbolic simulation
sym_n 4 h_s0_program
sym_n 6 h_s0_program
-- Final steps
unfold run at h_run
subst s_final s_4
subst s_final s_6
simp_all only [state_simp_rules, minimal_theory, bitvec_rules]
done

Expand All @@ -82,13 +81,16 @@ theorem sha512_program_test_2_sym (s0 s_final : ArmState)

def sha512_program_test_3 : program :=
def_program
[(0x1264c0#64 , 0xa9bf7bfd#32), -- stp x29, x30, [sp, #-16]!
[(0x1264c0#64 , 0xa9bf7bfd#32), -- stp x29, x30, [sp, #-16]!
(0x1264c4#64 , 0x910003fd#32), -- mov x29, sp
(0x1264c8#64 , 0x4cdf2030#32), -- ld1 {v16.16b-v19.16b}, [x1], #64
(0x1264cc#64 , 0x4cdf2034#32) -- ld1 {v20.16b-v23.16b}, [x1], #64
]


-- set_option profiler true in
-- set_option maxRecDepth 10000 in
set_option pp.deepTerms false in
set_option pp.deepTerms.threshold 10 in
theorem sha512_block_armv8_test_3_sym (s0 s_final : ArmState)
(h_s0_ok : read_err s0 = StateError.None)
(h_s0_sp_aligned : CheckSPAlignment s0 = true)
Expand All @@ -99,7 +101,36 @@ theorem sha512_block_armv8_test_3_sym (s0 s_final : ArmState)
-- Prelude
simp_all only [state_simp_rules, -h_run]
-- Symbolic simulation
-- sym_n 1 h_s0_program

sym_i_n 0 1 h_s0_program

-- sym_i_n 1 1 h_s0_program
init_next_step h_run
rename_i s_2 h_step_2 h_run
fetch_and_decode_inst h_step_2 h_s0_program
clear h_step_1
-- exec_inst h_step_2
-- FIXME: simproc for first args. of write_mem_bytes and zeroExtend
simp only [exec_inst, DPI.exec_add_sub_imm, write_gpr, Nat.sub_zero, read_gpr, ne_eq,
not_false_eq_true, r_of_w_different, r_of_w_same, beq_self_eq_true, ite_true, write_pstate,
write_pc, read_pc, w_of_w_shadow, w_program, write_mem_bytes_program, h_s0_program] at h_step_2
simp only [beq_iff_eq, ofNat_add_ofNat, Nat.reduceAdd] at h_step_2
-- simp (config := {ground := true}) only at h_step_2 -- max. recursion depth is reached.
conv at h_step_2 =>
rhs
arg 3
tactic => simp (config := {ground := true}) only [minimal_theory, bitvec_rules, ↓reduceIte, reduceAdd, reduceSub, Fin.reduceEq, Fin.mk_one]
(try simp only [BitVec.ofFin_eq_ofNat] at h_step_2)

-- sym_i_n 2 1 h_s0_program
init_next_step h_run
rename_i s_3 h_step_3 h_run
fetch_and_decode_inst h_step_3 h_s0_program
clear h_step_2
-- exec_inst h_step_3
-- LDST.exec_advanced_simd_multiple_struct_post_indexed
simp only [exec_inst, state_simp_rules, minimal_theory, bitvec_rules, ↓reduceIte] at h_step_3
-- rw [CheckSPAligment_of_w_different] at h_step_3
sorry

----------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion Proofs/Test.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_s (inst : BitVec 32): ArmState :=
{ gpr := (fun (_ : BitVec 5) => 0#64),
sfp := (fun (_ : BitVec 5) => 0#128),
pc := 0#64,
pstate := (fun (_ : PFlag) => 0#1),
pstate := zero_pstate,
mem := (fun (_ : BitVec 64) => 0#8),
-- Program: BitVec 64 → Option (BitVec 32)
program := (fun (a : BitVec 64) => if a == some 0#64 then inst else none),
Expand Down
Loading

0 comments on commit 0a025d4

Please sign in to comment.