Skip to content

Commit

Permalink
refactor: make BitVec.carry take bitvector arguments (#3461)
Browse files Browse the repository at this point in the history
Every usage of `carry` followed the pattern: `carry _ x.toNat y.toNat`,
so we've refactorod `carry` to take the `BitVec`s as arguments, and made
the `toNat` part of its definition.
  • Loading branch information
alexkeizer authored Feb 22, 2024
1 parent 8193af3 commit 815200e
Showing 1 changed file with 15 additions and 22 deletions.
37 changes: 15 additions & 22 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ private theorem mod_two_pow_lt (x i : Nat) : x % 2 ^ i < 2^i := Nat.mod_lt _ (Na

/-! ### Addition -/

/-- carry w x y c returns true if the `w` carry bit is true when computing `x + y + c`. -/
def carry (w x y : Nat) (c : Bool) : Bool := decide (x % 2^w + y % 2^w + c.toNat ≥ 2^w)
/-- carry i x y c returns true if the `i` carry bit is true when computing `x + y + c`. -/
def carry (i : Nat) (x y : BitVec w) (c : Bool) : Bool :=
decide (x.toNat % 2^i + y.toNat % 2^i + c.toNat ≥ 2^i)

@[simp] theorem carry_zero : carry 0 x y c = c := by
cases c <;> simp [carry, mod_one]
Expand All @@ -100,20 +101,18 @@ theorem adc_overflow_limit (x y i : Nat) (c : Bool) : x % 2^i + (y % 2^i + c.toN
rw [Nat.pow_succ]
omega

theorem carry_succ (w x y : Nat) (c : Bool) :
carry (succ w) x y c = atLeastTwo (x.testBit w) (y.testBit w) (carry w x y c) := by
simp only [carry, mod_two_pow_succ, atLeastTwo]
theorem carry_succ (i : Nat) (x y : BitVec w) (c : Bool) :
carry (i+1) x y c = atLeastTwo (x.getLsb i) (y.getLsb i) (carry i x y c) := by
simp only [carry, mod_two_pow_succ, atLeastTwo, getLsb]
simp only [Nat.pow_succ']
generalize testBit x w = xh
generalize testBit y w = yh
have sum_bnd : x%2^w + (y%2^w + c.toNat) < 2*2^w := by
simp only [← Nat.pow_succ']
exact adc_overflow_limit x y w c
cases xh <;> cases yh <;> (simp; omega)
have sum_bnd : x.toNat%2^i + (y.toNat%2^i + c.toNat) < 2*2^i := by
simp only [← Nat.pow_succ']
exact adc_overflow_limit ..
cases x.toNat.testBit i <;> cases y.toNat.testBit i <;> (simp; omega)

theorem getLsb_add_add_bool {i : Nat} (i_lt : i < w) (x y : BitVec w) (c : Bool) :
getLsb (x + y + zeroExtend w (ofBool c)) i =
Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x.toNat y.toNat c)) := by
Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x y c)) := by
let ⟨x, x_lt⟩ := x
let ⟨y, y_lt⟩ := y
simp only [getLsb, toNat_add, toNat_zeroExtend, i_lt, toNat_ofFin, toNat_ofBool,
Expand All @@ -134,27 +133,21 @@ theorem getLsb_add_add_bool {i : Nat} (i_lt : i < w) (x y : BitVec w) (c : Bool)

theorem getLsb_add {i : Nat} (i_lt : i < w) (x y : BitVec w) :
getLsb (x + y) i =
Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x.toNat y.toNat false)) := by
Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x y false)) := by
simpa using getLsb_add_add_bool i_lt x y false

theorem adc_spec (x y : BitVec w) (c : Bool) :
adc x y c = (carry w x.toNat y.toNat c, x + y + zeroExtend w (ofBool c)) := by
adc x y c = (carry w x y c, x + y + zeroExtend w (ofBool c)) := by
simp only [adc]
apply iunfoldr_replace
(fun i => carry i x.toNat y.toNat c)
(fun i => carry i x y c)
(x + y + zeroExtend w (ofBool c))
c
case init =>
simp [carry, Nat.mod_one]
cases c <;> rfl
case step =>
intro ⟨i, lt⟩
simp only [adcb, Prod.mk.injEq, carry_succ]
apply And.intro
case left =>
rw [testBit_toNat, testBit_toNat]
case right =>
simp [getLsb_add_add_bool lt]
simp [adcb, Prod.mk.injEq, carry_succ, getLsb_add_add_bool]

theorem add_eq_adc (w : Nat) (x y : BitVec w) : x + y = (adc x y false).snd := by
simp [adc_spec]
Expand Down

0 comments on commit 815200e

Please sign in to comment.