diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 662a51f942a3..1a7e582dcf2c 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -79,8 +79,9 @@ private theorem mod_two_pow_lt (x i : Nat) : x % 2 ^ i < 2^i := Nat.mod_lt _ (Na /-! ### Addition -/ -/-- carry w x y c returns true if the `w` carry bit is true when computing `x + y + c`. -/ -def carry (w x y : Nat) (c : Bool) : Bool := decide (x % 2^w + y % 2^w + c.toNat ≥ 2^w) +/-- carry i x y c returns true if the `i` carry bit is true when computing `x + y + c`. -/ +def carry (i : Nat) (x y : BitVec w) (c : Bool) : Bool := + decide (x.toNat % 2^i + y.toNat % 2^i + c.toNat ≥ 2^i) @[simp] theorem carry_zero : carry 0 x y c = c := by cases c <;> simp [carry, mod_one] @@ -100,20 +101,18 @@ theorem adc_overflow_limit (x y i : Nat) (c : Bool) : x % 2^i + (y % 2^i + c.toN rw [Nat.pow_succ] omega -theorem carry_succ (w x y : Nat) (c : Bool) : - carry (succ w) x y c = atLeastTwo (x.testBit w) (y.testBit w) (carry w x y c) := by - simp only [carry, mod_two_pow_succ, atLeastTwo] +theorem carry_succ (i : Nat) (x y : BitVec w) (c : Bool) : + carry (i+1) x y c = atLeastTwo (x.getLsb i) (y.getLsb i) (carry i x y c) := by + simp only [carry, mod_two_pow_succ, atLeastTwo, getLsb] simp only [Nat.pow_succ'] - generalize testBit x w = xh - generalize testBit y w = yh - have sum_bnd : x%2^w + (y%2^w + c.toNat) < 2*2^w := by - simp only [← Nat.pow_succ'] - exact adc_overflow_limit x y w c - cases xh <;> cases yh <;> (simp; omega) + have sum_bnd : x.toNat%2^i + (y.toNat%2^i + c.toNat) < 2*2^i := by + simp only [← Nat.pow_succ'] + exact adc_overflow_limit .. + cases x.toNat.testBit i <;> cases y.toNat.testBit i <;> (simp; omega) theorem getLsb_add_add_bool {i : Nat} (i_lt : i < w) (x y : BitVec w) (c : Bool) : getLsb (x + y + zeroExtend w (ofBool c)) i = - Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x.toNat y.toNat c)) := by + Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x y c)) := by let ⟨x, x_lt⟩ := x let ⟨y, y_lt⟩ := y simp only [getLsb, toNat_add, toNat_zeroExtend, i_lt, toNat_ofFin, toNat_ofBool, @@ -134,27 +133,21 @@ theorem getLsb_add_add_bool {i : Nat} (i_lt : i < w) (x y : BitVec w) (c : Bool) theorem getLsb_add {i : Nat} (i_lt : i < w) (x y : BitVec w) : getLsb (x + y) i = - Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x.toNat y.toNat false)) := by + Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x y false)) := by simpa using getLsb_add_add_bool i_lt x y false theorem adc_spec (x y : BitVec w) (c : Bool) : - adc x y c = (carry w x.toNat y.toNat c, x + y + zeroExtend w (ofBool c)) := by + adc x y c = (carry w x y c, x + y + zeroExtend w (ofBool c)) := by simp only [adc] apply iunfoldr_replace - (fun i => carry i x.toNat y.toNat c) + (fun i => carry i x y c) (x + y + zeroExtend w (ofBool c)) c case init => simp [carry, Nat.mod_one] cases c <;> rfl case step => - intro ⟨i, lt⟩ - simp only [adcb, Prod.mk.injEq, carry_succ] - apply And.intro - case left => - rw [testBit_toNat, testBit_toNat] - case right => - simp [getLsb_add_add_bool lt] + simp [adcb, Prod.mk.injEq, carry_succ, getLsb_add_add_bool] theorem add_eq_adc (w : Nat) (x y : BitVec w) : x + y = (adc x y false).snd := by simp [adc_spec]