Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add specification for AES-GCM #28

Merged
merged 6 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 5 additions & 102 deletions Arm/Insts/DPSFP/Crypto_aes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,59 +8,14 @@ Author(s): Shilpi Goel, Yan Peng
import Arm.Decode
import Arm.Insts.Common
import Arm.BitVec
import Specs.AESCommon

----------------------------------------------------------------------

namespace DPSFP

open BitVec

def SBox :=
-- F E D C B A 9 8 7 6 5 4 3 2 1 0
[ 0x16bb54b00f2d99416842e6bf0d89a18c#128, -- F
0xdf2855cee9871e9b948ed9691198f8e1#128, -- E
0x9e1dc186b95735610ef6034866b53e70#128, -- D
0x8a8bbd4b1f74dde8c6b4a61c2e2578ba#128, -- C
0x08ae7a65eaf4566ca94ed58d6d37c8e7#128, -- B
0x79e4959162acd3c25c2406490a3a32e0#128, -- A
0xdb0b5ede14b8ee4688902a22dc4f8160#128, -- 9
0x73195d643d7ea7c41744975fec130ccd#128, -- 8
0xd2f3ff1021dab6bcf5389d928f40a351#128, -- 7
0xa89f3c507f02f94585334d43fbaaefd0#128, -- 6
0xcf584c4a39becb6a5bb1fc20ed00d153#128, -- 5
0x842fe329b3d63b52a05a6e1b1a2c8309#128, -- 4
0x75b227ebe28012079a059618c323c704#128, -- 3
0x1531d871f1e5a534ccf73f362693fdb7#128, -- 2
0xc072a49cafa2d4adf04759fa7dc982ca#128, -- 1
0x76abd7fe2b670130c56f6bf27b777c63#128 -- 0
]

def AESShiftRows (op : BitVec 128) : BitVec 128 :=
extractLsb 95 88 op ++ extractLsb 55 48 op ++
extractLsb 15 8 op ++ extractLsb 103 96 op ++
extractLsb 63 56 op ++ extractLsb 23 16 op ++
extractLsb 111 104 op ++ extractLsb 71 64 op ++
extractLsb 31 24 op ++ extractLsb 119 112 op ++
extractLsb 79 72 op ++ extractLsb 39 32 op ++
extractLsb 127 120 op ++ extractLsb 87 80 op ++
extractLsb 47 40 op ++ extractLsb 7 0 op

def AESSubBytes_aux (i : Nat) (op : BitVec 128) (out : BitVec 128)
: BitVec 128 :=
if h₀ : 16 <= i then
out
else
let idx := (extractLsb (i * 8 + 7) (i * 8) op).toNat
let val := extractLsb (idx * 8 + 7) (idx * 8) $ BitVec.flatten SBox
have h₁ : idx * 8 + 7 - idx * 8 = i * 8 + 7 - i * 8 := by omega
let out := BitVec.partInstall (i * 8 + 7) (i * 8) (h₁ ▸ val) out
have _ : 15 - i < 16 - i := by omega
AESSubBytes_aux (i + 1) op out
termination_by (16 - i)

def AESSubBytes (op : BitVec 128) : BitVec 128 :=
AESSubBytes_aux 0 op (BitVec.zero 128)

@[state_simp_rules]
def exec_aese
(inst : Crypto_aes_cls) (s : ArmState) : ArmState :=
Expand All @@ -69,7 +24,7 @@ def exec_aese
let operand1 := read_sfp 128 inst.Rd s
let operand2 := read_sfp 128 inst.Rn s
let result := operand1 ^^^ operand2
let result := AESSubBytes $ AESShiftRows result
let result := AESCommon.SubBytes $ AESCommon.ShiftRows result
-- State Updates
let s := write_sfp 128 inst.Rd result s
let s := write_pc ((read_pc s) + 4#64) s
Expand Down Expand Up @@ -97,8 +52,7 @@ def FFmul02 (b : BitVec 8) : BitVec 8 :=
]
let lo := b.toNat * 8
let hi := lo + 7
have h : hi - lo + 1 = 8 := by omega
h ▸ extractLsb hi lo $ BitVec.flatten FFmul_02
BitVec.cast (by omega) $ extractLsb hi lo $ BitVec.flatten FFmul_02

def FFmul03 (b : BitVec 8) : BitVec 8 :=
let FFmul_03 :=
Expand All @@ -122,61 +76,10 @@ def FFmul03 (b : BitVec 8) : BitVec 8 :=
]
let lo := b.toNat * 8
let hi := lo + 7
have h : hi - lo + 1 = 8 := by omega
h ▸ extractLsb hi lo $ BitVec.flatten FFmul_03

def AESMixColumns_aux (c : Nat)
(in0 : BitVec 32) (in1 : BitVec 32) (in2 : BitVec 32) (in3 : BitVec 32)
(out0 : BitVec 32) (out1 : BitVec 32) (out2 : BitVec 32) (out3 : BitVec 32)
: BitVec 32 × BitVec 32 × BitVec 32 × BitVec 32 :=
if h₀ : 4 <= c then
(out0, out1, out2, out3)
else
let lo := c * 8
let hi := lo + 7
have h₁ : hi - lo + 1 = 8 := by omega
let in0_byte := h₁ ▸ extractLsb hi lo in0
let in1_byte := h₁ ▸ extractLsb hi lo in1
let in2_byte := h₁ ▸ extractLsb hi lo in2
let in3_byte := h₁ ▸ extractLsb hi lo in3
let val0 := h₁.symm ▸ (FFmul02 in0_byte ^^^ FFmul03 in1_byte ^^^ in2_byte ^^^ in3_byte)
let out0 := BitVec.partInstall hi lo val0 out0
let val1 := h₁.symm ▸ (FFmul02 in1_byte ^^^ FFmul03 in2_byte ^^^ in3_byte ^^^ in0_byte)
let out1 := BitVec.partInstall hi lo val1 out1
let val2 := h₁.symm ▸ (FFmul02 in2_byte ^^^ FFmul03 in3_byte ^^^ in0_byte ^^^ in1_byte)
let out2 := BitVec.partInstall hi lo val2 out2
let val3 := h₁.symm ▸ (FFmul02 in3_byte ^^^ FFmul03 in0_byte ^^^ in1_byte ^^^ in2_byte)
let out3 := BitVec.partInstall hi lo val3 out3
have _ : 3 - c < 4 - c := by omega
AESMixColumns_aux (c + 1) in0 in1 in2 in3 out0 out1 out2 out3
termination_by (4 - c)
BitVec.cast (by omega) $ extractLsb hi lo $ BitVec.flatten FFmul_03

def AESMixColumns (op : BitVec 128) : BitVec 128 :=
let in0 :=
extractLsb 103 96 op ++ extractLsb 71 64 op ++
extractLsb 39 32 op ++ extractLsb 7 0 op
let in1 :=
extractLsb 111 104 op ++ extractLsb 79 72 op ++
extractLsb 47 40 op ++ extractLsb 15 8 op
let in2 :=
extractLsb 119 112 op ++ extractLsb 87 80 op ++
extractLsb 55 48 op ++ extractLsb 23 16 op
let in3 :=
extractLsb 127 120 op ++ extractLsb 95 88 op ++
extractLsb 63 56 op ++ extractLsb 31 24 op
let (out0, out1, out2, out3) :=
(BitVec.zero 32, BitVec.zero 32,
BitVec.zero 32, BitVec.zero 32)
let (out0, out1, out2, out3) :=
AESMixColumns_aux 0 in0 in1 in2 in3 out0 out1 out2 out3
extractLsb 31 24 out3 ++ extractLsb 31 24 out2 ++
extractLsb 31 24 out1 ++ extractLsb 31 24 out0 ++
extractLsb 23 16 out3 ++ extractLsb 23 16 out2 ++
extractLsb 23 16 out1 ++ extractLsb 23 16 out0 ++
extractLsb 15 8 out3 ++ extractLsb 15 8 out2 ++
extractLsb 15 8 out1 ++ extractLsb 15 8 out0 ++
extractLsb 7 0 out3 ++ extractLsb 7 0 out2 ++
extractLsb 7 0 out1 ++ extractLsb 7 0 out0
AESCommon.MixColumns op FFmul02 FFmul03

@[state_simp_rules]
def exec_aesmc
Expand Down
227 changes: 227 additions & 0 deletions Specs/AES.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Yan Peng
-/
import Arm.BitVec
import Arm.Insts.DPSFP.Crypto_aes
import Specs.AESCommon

-- References : https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.197-upd1.pdf
-- https://csrc.nist.gov/csrc/media/projects/cryptographic-standards-and-guidelines/documents/aes-development/rijndael-ammended.pdf
--
--------------------------------------------------
-- The NIST specification has the following rounds:
--
-- AddRoundKey key0
-- for k in key1 to key9
-- SubBytes
-- ShiftRows
-- MixColumns
-- AddRoundKey
-- SubBytes
-- ShiftRows
-- AddRoundKey key10
--
-- The Arm implementation has an optimization that commute intermediate steps:
--
-- for k in key0 to key8
-- AddRoundKey + ShiftRows + SubBytes (AESE k)
-- MixColumns (AESMC)
-- AddRoundKey + ShiftRows + SubBytes (AESE key9)
-- AddRoundKey key10
--
-- Note: SubBytes and ShiftRows are commutative because
-- SubBytes is a byte-wise operation
--
--------------------------------------------------

namespace AES

open BitVec

def WordSize := 32
def BlockSize := 128

-- General comment: Maybe consider Lists vs Vectors?
-- https://github.com/joehendrix/lean-crypto/blob/323ee9b1323deed5240762f4029700a246ecd9d5/lib/Crypto/Vector.lean#L96

def Rcon : List (BitVec WordSize) :=
[ 0x00000001#32,
0x00000002#32,
0x00000004#32,
0x00000008#32,
0x00000010#32,
0x00000020#32,
0x00000040#32,
0x00000080#32,
0x0000001b#32,
0x00000036#32 ]

-------------------------------------------------------
-- types

-- Key-Block-Round Combinations
structure KBR where
key_len : Nat
block_size : Nat
Nk := key_len / 32
Nb := block_size / 32
Nr : Nat
h : block_size = BlockSize
deriving DecidableEq, Repr

def AES128KBR : KBR :=
{key_len := 128, block_size := BlockSize, Nr := 10, h := by decide}
def AES192KBR : KBR :=
{key_len := 192, block_size := BlockSize, Nr := 12, h := by decide}
def AES256KBR : KBR :=
{key_len := 256, block_size := BlockSize, Nr := 14, h := by decide}

def KeySchedule : Type := List (BitVec WordSize)

-- Declare KeySchedule to be an instance HAppend
-- so we can apply `++` to KeySchedules propertly
instance : HAppend KeySchedule KeySchedule KeySchedule where
hAppend := List.append

-------------------------------------------------------

def sbox (ind : BitVec 8) : BitVec 8 :=
match_bv ind with
| [x:4, y:4] =>
have h : (x.toNat * 128 + y.toNat * 8 + 7) - (x.toNat * 128 + y.toNat * 8) + 1 = 8 :=
by omega
h ▸ extractLsb
(x.toNat * 128 + y.toNat * 8 + 7)
(x.toNat * 128 + y.toNat * 8) $ BitVec.flatten AESCommon.SBOX
| _ => ind -- unreachable case

-- Note: The RotWord function is written in little endian
def RotWord (w : BitVec WordSize) : BitVec WordSize :=
match_bv w with
| [a3:8, a2:8, a1:8, a0:8] => a0 ++ a3 ++ a2 ++ a1
| _ => w -- unreachable case

def SubWord (w : BitVec WordSize) : BitVec WordSize :=
match_bv w with
| [a3:8, a2:8, a1:8, a0:8] => (sbox a3) ++ (sbox a2) ++ (sbox a1) ++ (sbox a0)
| _ => w -- unreachable case

protected def InitKey {Param : KBR} (i : Nat) (key : BitVec Param.key_len)
(acc : KeySchedule) : KeySchedule :=
if h₀ : Param.Nk ≤ i then acc
else
have h₁ : i * 32 + 32 - 1 - i * 32 + 1 = WordSize := by
simp only [WordSize]; omega
let wd := h₁ ▸ extractLsb (i * 32 + 32 - 1) (i * 32) key
let (x:KeySchedule) := [wd]
have _ : Param.Nk - (i + 1) < Param.Nk - i := by omega
AES.InitKey (Param := Param) (i + 1) key (acc ++ x)
termination_by (Param.Nk - i)

protected def KeyExpansion_helper {Param : KBR} (i : Nat) (ks : KeySchedule)
: KeySchedule :=
if h : 4 * Param.Nr + 4 ≤ i then
ks
else
let tmp := List.get! ks (i - 1)
let tmp :=
if i % Param.Nk == 0 then
(SubWord (RotWord tmp)) ^^^ (List.get! Rcon $ (i / Param.Nk) - 1)
else if Param.Nk > 6 && i % Param.Nk == 4 then
SubWord tmp
else
tmp
let res := (List.get! ks (i - Param.Nk)) ^^^ tmp
let ks := List.append ks [ res ]
have _ : 4 * Param.Nr + 4 - (i + 1) < 4 * Param.Nr + 4 - i := by omega
AES.KeyExpansion_helper (Param := Param) (i + 1) ks
termination_by (4 * Param.Nr + 4 - i)

def KeyExpansion {Param : KBR} (key : BitVec Param.key_len)
: KeySchedule :=
let seeded := AES.InitKey (Param := Param) 0 key []
AES.KeyExpansion_helper (Param := Param) Param.Nk seeded

def SubBytes {Param : KBR} (state : BitVec Param.block_size)
: BitVec Param.block_size :=
have h : Param.block_size = 128 := by simp only [Param.h, BlockSize]
h ▸ AESCommon.SubBytes (h ▸ state)

def ShiftRows {Param : KBR} (state : BitVec Param.block_size)
: BitVec Param.block_size :=
have h : Param.block_size = 128 := by simp only [Param.h, BlockSize]
h ▸ AESCommon.ShiftRows (h ▸ state)

def XTimes (bv : BitVec 8) : BitVec 8 :=
let res := truncate 7 bv ++ 0b0#1
if getLsb bv 7 then res ^^^ 0b00011011#8 else res

def MixColumns {Param : KBR} (state : BitVec Param.block_size)
: BitVec Param.block_size :=
have h : Param.block_size = 128 := by simp only [Param.h, BlockSize]
let FFmul02 := fun (x : BitVec 8) => XTimes x
let FFmul03 := fun (x : BitVec 8) => x ^^^ XTimes x
h ▸ AESCommon.MixColumns (h ▸ state) FFmul02 FFmul03

-- TODO: looks like a SAT/SMT problem
protected theorem FFmul02_equiv : (fun x => XTimes x) = DPSFP.FFmul02 := by
funext x
simp only [XTimes, DPSFP.FFmul02]
sorry

-- TODO: looks like a SAT/SMT problem
protected theorem FFmul03_equiv : (fun x => x ^^^ XTimes x) = DPSFP.FFmul03 := by
funext x
simp only [XTimes, DPSFP.FFmul03]
sorry
shigoel marked this conversation as resolved.
Show resolved Hide resolved


theorem MixColumns_table_lookup_equiv {Param : KBR}
(state : BitVec Param.block_size):
have h : Param.block_size = 128 := by simp only [Param.h, BlockSize]
MixColumns (Param := Param) state = h ▸ DPSFP.AESMixColumns (h ▸ state) := by
simp only [MixColumns, DPSFP.AESMixColumns]
rw [AES.FFmul02_equiv, AES.FFmul03_equiv]

def AddRoundKey {Param : KBR} (state : BitVec Param.block_size)
(roundKey : BitVec Param.block_size) : BitVec Param.block_size :=
state ^^^ roundKey

protected def getKey {Param : KBR} (n : Nat) (w : KeySchedule) : BitVec Param.block_size :=
let ind := 4 * n
have h : WordSize + WordSize + WordSize + WordSize = Param.block_size := by
simp only [WordSize, BlockSize, Param.h]
h ▸ ((List.get! w (ind + 3)) ++ (List.get! w (ind + 2)) ++
(List.get! w (ind + 1)) ++ (List.get! w ind))

protected def AES_encrypt_with_ks_loop {Param : KBR} (round : Nat)
(state : BitVec Param.block_size) (w : KeySchedule)
: BitVec Param.block_size :=
if Param.Nr ≤ round then
state
else
let state := SubBytes state
let state := ShiftRows state
let state := MixColumns state
let state := AddRoundKey state $ AES.getKey round w
AES.AES_encrypt_with_ks_loop (Param := Param) (round + 1) state w
termination_by (Param.Nr - round)

def AES_encrypt_with_ks {Param : KBR} (input : BitVec Param.block_size)
(w : KeySchedule) : BitVec Param.block_size :=
have h₀ : WordSize + WordSize + WordSize + WordSize = Param.block_size := by
simp only [WordSize, BlockSize, Param.h]
let state := AddRoundKey input $ (h₀ ▸ AES.getKey 0 w)
let state := AES.AES_encrypt_with_ks_loop (Param := Param) 1 state w
let state := SubBytes (Param := Param) state
let state := ShiftRows (Param := Param) state
AddRoundKey state $ h₀ ▸ AES.getKey Param.Nr w

def AES_encrypt {Param : KBR} (input : BitVec Param.block_size)
(key : BitVec Param.key_len) : BitVec Param.block_size :=
let ks := KeyExpansion (Param := Param) key
AES_encrypt_with_ks (Param := Param) input ks

end AES
Loading
Loading