Skip to content

Commit

Permalink
Merge pull request #27 from aqjune-aws/wip
Browse files Browse the repository at this point in the history
Change the program type in `ArmState` to `Map`
  • Loading branch information
shigoel authored Mar 27, 2024
2 parents 857ed82 + 39be0fa commit b148912
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 124 deletions.
4 changes: 4 additions & 0 deletions Arm/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ attribute [bitvec_rules] BitVec.extractLsb_ofFin
attribute [bitvec_rules] BitVec.zeroExtend_eq
attribute [bitvec_rules] BitVec.ofFin.injEq
attribute [bitvec_rules] BitVec.extractLsb_toNat
attribute [bitvec_rules] BitVec.truncate_or
attribute [bitvec_rules] BitVec.zeroExtend_zeroExtend_of_le
attribute [bitvec_rules] BitVec.zeroExtend_eq

-- attribute [bitvec_rules] add_ofFin

----------------------------------------------------------------------
Expand Down
5 changes: 2 additions & 3 deletions Arm/Cosim.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def init_cosim_state : ArmState :=
pc := 0#64,
pstate := zero_pstate,
mem := (fun (_ : BitVec 64) => 0#8),
program := (fun (_ : BitVec 64) => none),
program := Map.empty,
error := StateError.None }

/-- A structure to hold both the input and output values for a
Expand Down Expand Up @@ -73,8 +73,7 @@ def init_flags (flags : BitVec 4) (s : ArmState) : ArmState := Id.run do
/-- Initialize an ArmState for cosimulation from a given regState. -/
def regState_to_armState (r : regState) : ArmState :=
let s := init_gprs r.gpr (init_flags r.nzcv (init_sfps r.sfp init_cosim_state))
let s := { s with program :=
(fun (a : BitVec 64) => if a == some 0#64 then r.inst else none) }
let s := { s with program := def_program [(0x0#64, r.inst)] }
s

def bitvec_to_hex (x : BitVec n) : String :=
Expand Down
3 changes: 3 additions & 0 deletions Arm/Map.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ A simple Map-like type based on lists

def Map (α : Type u) (β : Type v) := List (α × β)

instance [x : Repr (List (α × β))] : Repr (Map α β) where
reprPrec := x.reprPrec

def Map.empty : Map α β := []

def Map.find? [DecidableEq α] (m : Map α β) (a' : α) : Option β :=
Expand Down
4 changes: 3 additions & 1 deletion Arm/MinTheory.lean
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,16 @@ attribute [minimal_theory] bne_self_eq_false'
attribute [minimal_theory] decide_False
attribute [minimal_theory] decide_True
attribute [minimal_theory] bne_iff_ne
attribute [minimal_theory] Nat.le_zero_eq

attribute [minimal_theory] Nat.le_zero_eq
attribute [minimal_theory] Nat.zero_add
attribute [minimal_theory] Nat.zero_eq
attribute [minimal_theory] Nat.succ.injEq
attribute [minimal_theory] Nat.succ_ne_zero
attribute [minimal_theory] Nat.sub_zero

attribute [minimal_theory] Nat.le_refl

@[minimal_theory]
theorem option_get_bang_of_some [Inhabited α] (v : α) :
Option.get! (some v) = v := by rfl
Expand Down
34 changes: 13 additions & 21 deletions Arm/State.lean
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,10 @@ structure ArmState where
pstate : PState
-- Memory: maps 64-bit addresses to bytes
mem : Store (BitVec 64) (BitVec 8)
-- Program: maps 64-bit addresses to 32-bit instructions. None is
-- returned when no instruction is present at a specified address.
-- Program: maps 64-bit addresses to 32-bit instructions.
-- Note that we have the following assumption baked into our machine model:
-- the program is always disjoint from the rest of the memory.
program : Store (BitVec 64) (Option (BitVec 32))
program : Map (BitVec 64) (BitVec 32)

-- The error field is an artifact of this model; it is set to a
-- non-None value when some irrecoverable error is encountered
Expand Down Expand Up @@ -199,7 +198,7 @@ def write_base_flag (flag : PFlag) (val : BitVec 1) (s : ArmState) : ArmState :=
-- Fetch the instruction at address addr.
@[irreducible]
def fetch_inst (addr : BitVec 64) (s : ArmState) : Option (BitVec 32) :=
read_store addr s.program
s.program.find? addr

-- Error --

Expand Down Expand Up @@ -259,8 +258,7 @@ def w (fld : StateField) (v : (state_value fld)) (s : ArmState) : ArmState :=
| ERR => write_base_error v s

@[state_simp_rules]
theorem r_of_w_same (fld : StateField) (v : (state_value fld)) (s : ArmState) :
r fld (w fld v s) = v := by
theorem r_of_w_same : r fld (w fld v s) = v := by
unfold r w
unfold read_base_gpr write_base_gpr
unfold read_base_sfp write_base_sfp
Expand All @@ -270,8 +268,7 @@ theorem r_of_w_same (fld : StateField) (v : (state_value fld)) (s : ArmState) :
split <;> (repeat (split <;> simp_all!))

@[state_simp_rules]
theorem r_of_w_different (fld1 fld2 : StateField) (v : (state_value fld2)) (s : ArmState)
(h : fld1 ≠ fld2) :
theorem r_of_w_different (h : fld1 ≠ fld2) :
r fld1 (w fld2 v s) = r fld1 s := by
unfold r w
unfold read_base_gpr write_base_gpr
Expand All @@ -283,8 +280,7 @@ theorem r_of_w_different (fld1 fld2 : StateField) (v : (state_value fld2)) (s :
split <;> (repeat (split <;> simp_all!))

@[state_simp_rules]
theorem w_of_w_shadow (fld : StateField) (v1 v2 : (state_value fld)) (s : ArmState) :
w fld v2 (w fld v1 s) = w fld v2 s := by
theorem w_of_w_shadow : w fld v2 (w fld v1 s) = w fld v2 s := by
unfold w
unfold write_base_gpr
unfold write_base_sfp
Expand All @@ -294,8 +290,7 @@ theorem w_of_w_shadow (fld : StateField) (v1 v2 : (state_value fld)) (s : ArmSta
(repeat (split <;> simp_all!))

@[state_simp_rules]
theorem w_irrelevant (fld : StateField) (v1 v2 : (state_value fld)) (s : ArmState) :
w fld (r fld s) s = s := by
theorem w_irrelevant : w fld (r fld s) s = s := by
unfold r w
unfold read_base_gpr write_base_gpr
unfold read_base_sfp write_base_sfp
Expand All @@ -305,8 +300,7 @@ theorem w_irrelevant (fld : StateField) (v1 v2 : (state_value fld)) (s : ArmStat
repeat (split <;> simp_all)

@[state_simp_rules]
theorem fetch_inst_of_w (addr : BitVec 64) (fld : StateField) (val : (state_value fld)) (s : ArmState) :
fetch_inst addr (w fld val s) = fetch_inst addr s := by
theorem fetch_inst_of_w : fetch_inst addr (w fld val s) = fetch_inst addr s := by
unfold fetch_inst w
unfold write_base_gpr
unfold write_base_sfp
Expand All @@ -317,10 +311,9 @@ theorem fetch_inst_of_w (addr : BitVec 64) (fld : StateField) (val : (state_valu

-- There is no StateField that overwrites the program.
@[state_simp_rules]
theorem w_program (sf : StateField) (v : state_value sf) (s : ArmState):
(w sf v s).program = s.program := by
theorem w_program : (w fld v s).program = s.program := by
intros
cases sf <;> unfold w <;> simp
cases fld <;> unfold w <;> simp
· unfold write_base_gpr; simp
· unfold write_base_sfp; simp
· unfold write_base_pc; simp
Expand Down Expand Up @@ -464,10 +457,9 @@ where
| (addr, _) :: p, some max => if addr > max then loop p (some addr) else loop p (some max)

theorem fetch_inst_from_program
{address: BitVec 64} {program : program}
(h_program : s.program = program.find?) :
fetch_inst address s = program.find? address := by
unfold fetch_inst read_store
{address: BitVec 64} :
fetch_inst address s = s.program.find? address := by
unfold fetch_inst
simp_all!

end Load_program_and_fetch_inst
Expand Down
44 changes: 22 additions & 22 deletions Proofs/MultiInsts.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,31 @@ def test_program : program :=
(0x126514#64 , 0x4ea21c5c#32), -- mov v28.16b, v2.16b
(0x126518#64 , 0x4ea31c7d#32)] -- mov v29.16b, v3.16b

theorem one_asm_snippet_sym_helper1 (q0_var : BitVec 128) :
zeroExtend 128 (zeroExtend 128 (zeroExtend 128 q0_var ||| zeroExtend 128 q0_var)) = zeroExtend 128 q0_var := by
sorry -- auto

theorem one_asm_snippet_sym_helper2 (q0_var : BitVec 128) :
q0_var ||| q0_var = q0_var := by sorry -- auto

-- Todo: use sym_n to prove this theorem.
theorem small_asm_snippet_sym (s : ArmState)
(h_pc : read_pc s = 0x12650c#64)
(h_program : s.program = test_program.find?)
(h_s_ok : read_err s = StateError.None)
(h_s' : s' = run 4 s) :
read_sfp 128 26#5 s' = read_sfp 128 0#5 s ∧
read_err s' = StateError.None := by
FIXME
-- iterate 4 (sym1 [h_program])
sym1 [h_program]
sym1 [h_program]
sym1 [h_program]
sym1 [h_program]
-- Wrapping up the result:
-- generalize (r (StateField.SFP 0#5) s) = q0_var; unfold state_value at q0_var; simp at q0_var
-- try (simp [one_asm_snippet_sym_helper1])
theorem small_asm_snippet_sym (s0 s_final : ArmState)
(h_s0_pc : read_pc s0 = 0x12650c#64)
(h_s0_program : s0.program = test_program)
(h_s0_ok : read_err s0 = StateError.None)
(h_run : s_final = run 4 s0) :
read_sfp 128 26#5 s_final = read_sfp 128 0#5 s0 ∧
read_err s_final = StateError.None := by
-- Prelude
simp_all only [state_simp_rules, -h_run]
-- Symbolic Simulation
sym_n 4 h_s0_program
-- Wrapping up the result:
unfold run at h_run
subst s_final s_4
apply And.intro
· simp_all only [state_simp_rules, minimal_theory, bitvec_rules]
simp only [one_asm_snippet_sym_helper2]
done
-- FIXME: Why does state_simp_rules not work here? Why do we need
-- an explicit rw?
(try (repeat (rw [r_of_w_different (by decide)])))
(try (rw [r_of_w_same]))
· simp_all only [state_simp_rules, minimal_theory, bitvec_rules]
done

end multi_insts_proofs
18 changes: 4 additions & 14 deletions Proofs/Sha512_block_armv8.lean
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def sha512_program_test_1 : program :=
theorem sha512_program_test_1_sym (s0 s_final : ArmState)
(h_s0_pc : read_pc s0 = 0x126538#64)
(h_s0_sp_aligned : CheckSPAlignment s0 = true)
(h_s0_program : s0.program = sha512_program_test_1.find?)
(h_s0_program : s0.program = sha512_program_test_1)
(h_s0_ok : read_err s0 = StateError.None)
(h_run : s_final = run 4 s0) :
read_err s_final = StateError.None := by
Expand Down Expand Up @@ -61,7 +61,7 @@ def sha512_program_test_2 : program :=
-- set_option profiler true in
theorem sha512_program_test_2_sym (s0 s_final : ArmState)
(h_s0_pc : read_pc s0 = 0x126538#64)
(h_s0_program : s0.program = sha512_program_test_2.find?)
(h_s0_program : s0.program = sha512_program_test_2)
(h_s0_ok : read_err s0 = StateError.None)
(h_run : s_final = run 6 s0) :
read_err s_final = StateError.None := by
Expand All @@ -79,16 +79,6 @@ theorem sha512_program_test_2_sym (s0 s_final : ArmState)

-- Test 3:

variable (write : (n : Nat) → BitVec 64 → BitVec (n * 8) → TypeType)

theorem write_simplify_test_0 (a x y : BitVec 64)
(h : ((8 * 8) + 8 * 8) = 2 * ((8 * 8) / 8) * 8) :
write (2 * ((8 * 8) / 8)) a (BitVec.cast h (zeroExtend (8 * 8) x ++ (zeroExtend (8 * 8) y))) s
=
write 16 a (x ++ y) s := by
simp only [zeroExtend_eq, BitVec.cast_eq]


def sha512_program_test_3 : program :=
def_program
[(0x1264c0#64 , 0xa9bf7bfd#32), -- stp x29, x30, [sp, #-16]!
Expand All @@ -105,7 +95,7 @@ theorem sha512_block_armv8_test_3_sym (s0 s_final : ArmState)
(h_s0_ok : read_err s0 = StateError.None)
(h_s0_sp_aligned : CheckSPAlignment s0 = true)
(h_s0_pc : read_pc s0 = 0x1264c0#64)
(h_s0_program : s0.program = sha512_program_test_3.find?)
(h_s0_program : s0.program = sha512_program_test_3)
(h_run : s_final = run 4 s0) :
read_err s_final = StateError.None := by
-- Prelude
Expand Down Expand Up @@ -157,7 +147,7 @@ theorem sha512_block_armv8_test_4_sym (s0 s_final : ArmState)
(h_s0_ok : read_err s0 = StateError.None)
(h_s0_sp_aligned : CheckSPAlignment s0 = true)
(h_s0_pc : read_pc s0 = 0x1264c0#64)
(h_s0_program : s0.program = sha512_program_map.find?)
(h_s0_program : s0.program = sha512_program_map)
(h_run : s_final = run 32 s0) :
read_err s_final = StateError.None := by
-- Prelude
Expand Down
2 changes: 1 addition & 1 deletion Proofs/Test.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_s (inst : BitVec 32): ArmState :=
pstate := zero_pstate,
mem := (fun (_ : BitVec 64) => 0#8),
-- Program: BitVec 64 → Option (BitVec 32)
program := (fun (a : BitVec 64) => if a == some 0#64 then inst else none),
program := [(0#64, inst)],
error := StateError.None }

-- 0x91000421#32: add x1, x1, 1
Expand Down
60 changes: 9 additions & 51 deletions Tactics/Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,6 @@ import Lean.Expr

open BitVec

-- sym1 tactic symbolically simulates a single instruction.
syntax "sym1" "[" term "]" : tactic
macro_rules
| `(tactic| sym1 [$h_program:term]) =>
`(tactic|
(try simp_all (config := {decide := true, ground := true}) only [state_simp_rules]);
unfold run;
simp_all only [stepi, state_simp_rules];
(try rw [fetch_inst_from_program $h_program]);
(try simp (config := {decide := true, ground := true}) only);
-- After exec_inst is opened up, the exec functions of the
-- instructions which are tagged with simp will also open up
-- here.
simp only [exec_inst, state_simp_rules];
-- (try simp_all (config := {decide := true, ground := true}) only);
-- (try simp only [ne_eq, r_of_w_different, r_of_w_same, w_of_w_shadow, w_irrelevant])
(try simp_all (config := {decide := true, ground := true}) only [state_simp_rules]))

-- init_next_step breaks 'h_s: s_next = run n s' into 'run (n-1) s' and one step.
macro "init_next_step" h_s:ident : tactic =>
`(tactic|
Expand All @@ -43,48 +25,24 @@ macro "init_next_step" h_s:ident : tactic =>
-- simplifies fetch_inst and decode_raw_inst.
macro "fetch_and_decode_inst" h_step:ident h_program:ident : tactic =>
`(tactic|
(-- unfold stepi at $h_step:ident
-- rw [$h_s_ok:ident] at $h_step:ident
-- dsimp at $h_step:ident -- reduce let and match
-- rw [$h_pc:ident] at $h_step:ident
-- rw [fetch_inst_from_program $h_program:ident] at $h_step:ident
-- -- Note: this often times out. It tries to evaluate, e.g.,
-- -- Std.RBMap.find? sha512_program_test_2 1205560#64
-- -- which easily becomes hard.
-- simp (config := {ground := true}) at $h_step:ident
-- repeat (rw [BitVec.ofFin_eq_ofNat] at $h_step:ident)
simp only [*, stepi, state_simp_rules, minimal_theory, bitvec_rules] at $h_step:ident
rw [fetch_inst_from_program $h_program:ident] at $h_step:ident
conv at $h_step:ident =>
(simp only [*, stepi, state_simp_rules, minimal_theory, bitvec_rules] at $h_step:ident
rw [fetch_inst_from_program] at $h_step:ident
simp only [$h_program:ident] at $h_step:ident
conv at $h_step:ident =>
pattern Map.find? _ _
simp (config := {ground := true}) only
-- simp/ground leaves bitvecs' structure exposed, so we use
-- BitVec.ofFin_eq_ofNat to fold them back into their canonical
-- form.
simp only [BitVec.ofFin_eq_ofNat]
(try dsimp only at $h_step:ident);
conv at $h_step:ident =>
(try dsimp only at $h_step:ident);
conv at $h_step:ident =>
pattern decode_raw_inst _
simp (config := {ground := true}) only
simp only [BitVec.ofFin_eq_ofNat]
(try dsimp only at $h_step:ident)))
(try dsimp only at $h_step:ident)))

-- Given hstep which is the result of fetch_and_decode_inst, exec_inst executes
-- an instruction and generates 's_next = w .. (w .. (... s))'.
macro "exec_inst" h_step:ident : tactic =>
`(tactic|
(-- unfold exec_inst at $h_step:ident
-- -- A simple case where simp works (e.g., Arm.DPI)
-- try (simp (config := {ground := true, decide := true}) at $h_step:ident)
-- -- A complicated case (e.g., Arm.LDST)
-- try (simp at $h_step:ident; (conv at $h_step:ident =>
-- arg 2
-- apply if_true
-- apply $st_next:ident); simp [$h_sp_aligned:ident] at $h_step:ident)
simp only [*, exec_inst, state_simp_rules, minimal_theory, bitvec_rules] at $h_step:ident;
(try simp (config := {ground := true}) only [↓reduceIte, state_simp_rules, minimal_theory, bitvec_rules] at $h_step:ident)
-- Fold back any exposed bitvecs into canonical forms.
(try simp only [BitVec.ofFin_eq_ofNat] at $h_step:ident)))
(simp only [*, exec_inst, state_simp_rules, minimal_theory, bitvec_rules] at $h_step:ident;
(try (repeat simp (config := {ground := true}) only [↓reduceIte, state_simp_rules, minimal_theory, bitvec_rules] at $h_step:ident))))

-- Given h_step which is 's_next = w .. (w .. (... s))', it creates assumptions
-- 'read .. s_next = value'.
Expand Down
Loading

0 comments on commit b148912

Please sign in to comment.