Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: make BitVec.carry take bitvector arguments #3461

Merged
merged 1 commit into from
Feb 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 15 additions & 22 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
@@ -79,8 +79,9 @@ private theorem mod_two_pow_lt (x i : Nat) : x % 2 ^ i < 2^i := Nat.mod_lt _ (Na

/-! ### Addition -/

/-- carry w x y c returns true if the `w` carry bit is true when computing `x + y + c`. -/
def carry (w x y : Nat) (c : Bool) : Bool := decide (x % 2^w + y % 2^w + c.toNat ≥ 2^w)
/-- carry i x y c returns true if the `i` carry bit is true when computing `x + y + c`. -/
def carry (i : Nat) (x y : BitVec w) (c : Bool) : Bool :=
decide (x.toNat % 2^i + y.toNat % 2^i + c.toNat ≥ 2^i)

@[simp] theorem carry_zero : carry 0 x y c = c := by
cases c <;> simp [carry, mod_one]
@@ -100,20 +101,18 @@ theorem adc_overflow_limit (x y i : Nat) (c : Bool) : x % 2^i + (y % 2^i + c.toN
rw [Nat.pow_succ]
omega

theorem carry_succ (w x y : Nat) (c : Bool) :
carry (succ w) x y c = atLeastTwo (x.testBit w) (y.testBit w) (carry w x y c) := by
simp only [carry, mod_two_pow_succ, atLeastTwo]
theorem carry_succ (i : Nat) (x y : BitVec w) (c : Bool) :
carry (i+1) x y c = atLeastTwo (x.getLsb i) (y.getLsb i) (carry i x y c) := by
simp only [carry, mod_two_pow_succ, atLeastTwo, getLsb]
simp only [Nat.pow_succ']
generalize testBit x w = xh
generalize testBit y w = yh
have sum_bnd : x%2^w + (y%2^w + c.toNat) < 2*2^w := by
simp only [← Nat.pow_succ']
exact adc_overflow_limit x y w c
cases xh <;> cases yh <;> (simp; omega)
have sum_bnd : x.toNat%2^i + (y.toNat%2^i + c.toNat) < 2*2^i := by
simp only [← Nat.pow_succ']
exact adc_overflow_limit ..
cases x.toNat.testBit i <;> cases y.toNat.testBit i <;> (simp; omega)

theorem getLsb_add_add_bool {i : Nat} (i_lt : i < w) (x y : BitVec w) (c : Bool) :
getLsb (x + y + zeroExtend w (ofBool c)) i =
Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x.toNat y.toNat c)) := by
Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x y c)) := by
let ⟨x, x_lt⟩ := x
let ⟨y, y_lt⟩ := y
simp only [getLsb, toNat_add, toNat_zeroExtend, i_lt, toNat_ofFin, toNat_ofBool,
@@ -134,27 +133,21 @@ theorem getLsb_add_add_bool {i : Nat} (i_lt : i < w) (x y : BitVec w) (c : Bool)

theorem getLsb_add {i : Nat} (i_lt : i < w) (x y : BitVec w) :
getLsb (x + y) i =
Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x.toNat y.toNat false)) := by
Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x y false)) := by
simpa using getLsb_add_add_bool i_lt x y false

theorem adc_spec (x y : BitVec w) (c : Bool) :
adc x y c = (carry w x.toNat y.toNat c, x + y + zeroExtend w (ofBool c)) := by
adc x y c = (carry w x y c, x + y + zeroExtend w (ofBool c)) := by
simp only [adc]
apply iunfoldr_replace
(fun i => carry i x.toNat y.toNat c)
(fun i => carry i x y c)
(x + y + zeroExtend w (ofBool c))
c
case init =>
simp [carry, Nat.mod_one]
cases c <;> rfl
case step =>
intro ⟨i, lt⟩
simp only [adcb, Prod.mk.injEq, carry_succ]
apply And.intro
case left =>
rw [testBit_toNat, testBit_toNat]
case right =>
simp [getLsb_add_add_bool lt]
simp [adcb, Prod.mk.injEq, carry_succ, getLsb_add_add_bool]

theorem add_eq_adc (w : Nat) (x y : BitVec w) : x + y = (adc x y false).snd := by
simp [adc_spec]