From 247349cd4e7b27cb637c7a246e8ef634166ef49b Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Mon, 3 Jun 2024 01:58:07 +0100 Subject: [PATCH 01/64] chore: start writing recurrences for mul --- src/Init/Data/BitVec/Lemmas.lean | 40 ++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 4b205933df73..c79dd1061d93 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1085,6 +1085,46 @@ theorem ofInt_mul {n} (x y : Int) : BitVec.ofInt n (x * y) = apply eq_of_toInt_eq simp +def mulRec (l r : BitVec w) (s : Nat) : BitVec w := + let cur := if r.getLsb s then (l <<< s) else 0 + match s with + | 0 => cur + | s + 1 => mulRec l r s + cur + +@[simp] +theorem shiftLeft_zero_eq (x : BitVec w) : x <<< 0 = x := by + apply eq_of_toNat_eq + simp + +theorem mulRec_zero_eq (l r : BitVec w) : + mulRec l r 0 = if r.getLsb 0 then l else 0 := by + simp [mulRec] + +theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) : + mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := by + simp [mulRec] + +theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : + (mulRec l r s) = l * ((r.truncate s).signExtend w) := by + induction w generalizing s + case zero => apply Subsingleton.elim + case succ w' hw => + induction s + case zero => + simp [mulRec, mulRec_zero_eq, signExtend, truncate] + sorry + case succ s' hs => sorry + +-- Provable with sign extend theory. +@[simp] +theorem signExtend_eq_self (x : BitVec w) : x.signExtend w = x := sorry + +theorem getLsb_mul (x y : BitVec w) (i : Nat) : + (x * y).getLsb i = (mulRec x y w).getLsb i := by + rw [mulRec_eq_mul_signExtend_truncate] + simp [zeroExtend_eq] + + /-! ### le and lt -/ @[bv_toNat] theorem le_def (x y : BitVec n) : From 55cfc6dfd6c0a3a6d573ec03b89f6ec7bdf7aec4 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Mon, 3 Jun 2024 02:14:21 +0100 Subject: [PATCH 02/64] chore: write udiv recursion eqn --- src/Init/Data/BitVec/Lemmas.lean | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index c79dd1061d93..bec6c0b74b61 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -726,6 +726,32 @@ theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) : Nat.not_lt, decide_eq_true_eq] omega +/-! ### udiv -/ + +theorem udiv_eq {x y : BitVec n} : + x.udiv y = BitVec.ofNat n (x.toNat / y.toNat) := by + apply BitVec.eq_of_toNat_eq + simp only [udiv, toNat_ofNatLt, toNat_ofNat] + rw [Nat.mod_eq_of_lt] + exact Nat.lt_of_le_of_lt (Nat.div_le_self ..) (by omega) + +/-- The remainder `rem` obeys the euclidean algorithm equation on computing `l.udiv r`. -/ +def udivDivisor (l r rem : BitVec w) : Prop := + rem < r ∧ + let l' := l.signExtend (2*w) + let r' := r.signExtend (2*w) + let rem' := rem.signExtend (2*w) + l' = (l' / r') * r' + rem' + +/-- Such a remainder always exists. -/ +theorem udiv_euclid_eqn_exists (l r : BitVec w) : + ∃ (rem : BitVec w), udivDivisor l r rem := sorry + +/-- Such a remainder is unique. -/ +theorem udiv_euclid_eqn_unique (l r rem rem' : BitVec w) + (hrem : udivDivisor l r rem) (hrem' : udivDivisor l r rem') : + rem = rem' := sorry + /-! ### append -/ theorem append_def (x : BitVec v) (y : BitVec w) : From 2e5ea9fe930b3a86a5263bacdbf1cedb9174f25a Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Mon, 3 Jun 2024 17:31:37 +0100 Subject: [PATCH 03/64] chore: add mul verified statement from z3 --- src/Init/Data/BitVec/Lemmas.lean | 38 +++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index bec6c0b74b61..3d0a304eef96 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1130,8 +1130,42 @@ theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) : mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := by simp [mulRec] +/-! +#!/usr/bin/env python3 +# Check that 'hargonix-recurrences-statements' actually has the right statements. +# https://github.com/opencompl/lean4/pull/6 +# Theorems from: https://www21.in.tum.de/teaching/sar/SS20/7.pdf +from z3 import * + +# Define the `mulRec` function in Z3py +def mulRec(l : BitVecRef, r : BitVecRef, s : int): + # import pudb; pudb.set_trace() + assert isinstance(s, int) + assert isinstance(l, BitVecRef) + assert isinstance(r, BitVecRef) + cur = If(Extract(s, s, r) == 1, l << s, BitVecVal(0, w)) + if s == 0: return cur + else: return mulRec(l, r, s-1) + cur + +# Define BitVecs +w = 8 # Example width, you can adjust it as necessary +l = BitVec('l', w) +r = BitVec('r', w) + +mul_circuit = mulRec(l, r, w-1) +print(mul_circuit) + +# Define assertion +mul_circuit_correct = mul_circuit == l * r +s = Solver() +s.add(mul_circuit_correct) + +out = s.check() +print(out) +-/ + theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : - (mulRec l r s) = l * ((r.truncate s).signExtend w) := by + (mulRec l r s) = l * r := by induction w generalizing s case zero => apply Subsingleton.elim case succ w' hw => @@ -1148,8 +1182,6 @@ theorem signExtend_eq_self (x : BitVec w) : x.signExtend w = x := sorry theorem getLsb_mul (x y : BitVec w) (i : Nat) : (x * y).getLsb i = (mulRec x y w).getLsb i := by rw [mulRec_eq_mul_signExtend_truncate] - simp [zeroExtend_eq] - /-! ### le and lt -/ From 910ce3b22f206e9395a282353543d37d1fce627c Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Tue, 4 Jun 2024 17:28:37 +0100 Subject: [PATCH 04/64] chore: check in WIP --- src/Init/Data/BitVec/Lemmas.lean | 159 ++++++++++++++++++++++++------- 1 file changed, 122 insertions(+), 37 deletions(-) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 3d0a304eef96..6d170e465671 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1132,56 +1132,141 @@ theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) : /-! #!/usr/bin/env python3 -# Check that 'hargonix-recurrences-statements' actually has the right statements. -# https://github.com/opencompl/lean4/pull/6 -# Theorems from: https://www21.in.tum.de/teaching/sar/SS20/7.pdf +##Check that 'hargonix-recurrences-statements' actually has the right statements. from z3 import * -# Define the `mulRec` function in Z3py -def mulRec(l : BitVecRef, r : BitVecRef, s : int): - # import pudb; pudb.set_trace() - assert isinstance(s, int) - assert isinstance(l, BitVecRef) - assert isinstance(r, BitVecRef) - cur = If(Extract(s, s, r) == 1, l << s, BitVecVal(0, w)) - if s == 0: return cur - else: return mulRec(l, r, s-1) + cur - -# Define BitVecs -w = 8 # Example width, you can adjust it as necessary -l = BitVec('l', w) -r = BitVec('r', w) - -mul_circuit = mulRec(l, r, w-1) -print(mul_circuit) - -# Define assertion -mul_circuit_correct = mul_circuit == l * r -s = Solver() -s.add(mul_circuit_correct) - -out = s.check() -print(out) +def mulExample(): + # Define the `mulRec` function in Z3py + def mulRec(l : BitVecRef, r : BitVecRef, s : int): + # import pudb; pudb.set_trace() + assert isinstance(s, int) + assert isinstance(l, BitVecRef) + assert isinstance(r, BitVecRef) + cur = If(Extract(s, s, r) == 1, l << s, BitVecVal(0, w)) + if s == 0: return cur + else: return mulRec(l, r, s-1) + cur + + # Define BitVecs + w = 8 # Example width, you can adjust it as necessary + l = BitVec('l', w) + r = BitVec('r', w) + + mul_circuit = mulRec(l, r, w-1) + print(mul_circuit) + + # Define assertion + mul_circuit_correct = mul_circuit == l * r + s = Solver() + s.add(ForAll(l, ForAll(r, mul_circuit_correct))) + + assert bool(s.check()) + + # verify what happens in mulRec for all 's' + for nbits_keep in range(1, w): + s = Solver() + s.add(ForAll(l, ForAll(r, mulRec(l, r, nbits_keep) == ZeroExt(w - nbits_keep - 1, Extract(nbits_keep, 0, l * r))))) + print(f"* checking mul eqn for width:'{nbits_keep}': '{s}'.") + assert bool(s.check()) +mulExample() -/ +@[simp] +theorem getLsb_ofBool (b : Bool) (i : Nat) : (BitVec.ofBool b).getLsb i = ((i = 0) && b) := by + rcases b with rfl | rfl + · simp [ofBool] + · simp [ofBool, getLsb_ofNat] + by_cases hi : (i = 0) + · simp [hi] + · simp [hi] + omega + +/-- zero extending a bitvector to width 1 equals the boolean of the lsb. -/ +theorem zeroExtend_one_eq_ofBool_getLsb_zero (x : BitVec w) : + x.zeroExtend 1 = BitVec.ofBool (x.getLsb 0) := by + ext i + simp [getLsb_zeroExtend, Fin.fin_one_eq_zero i] + +/-- `testBit 1 i` is true iff the index `i` equals 0. -/ +private theorem Nat.testBit_one_eq_true_iff_self_eq_zero {i : Nat} : + Nat.testBit 1 i = true ↔ i = 0 := by + cases i <;> simp + +/-- Zero extending `1#v` to `1#w` equals `1#w` when `v > 0`. -/ +theorem zeroExtend_ofNat_one_eq_ofNat_one_of_lt {v w : Nat} (hv : 0 < v): + (BitVec.ofNat v 1).zeroExtend w = BitVec.ofNat w 1 := by + ext i + obtain ⟨i, hilt⟩ := i + simp only [getLsb_zeroExtend, hilt, decide_True, getLsb_ofNat, Bool.true_and, + Bool.and_iff_right_iff_imp, decide_eq_true_eq] + intros hi1 + have hv := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi1 + omega + +@[simp] +theorem BitVec.mul_one {x : BitVec w} : x * (1#w) = x := by + apply eq_of_toNat_eq + simp [toNat_mul, Nat.mod_eq_of_lt x.isLt] + +@[simp] +theorem BitVec.mul_zero {x : BitVec w} : x * (0#w) = (0#w) := by + apply eq_of_toNat_eq + simp [toNat_mul] + +theorem BitVec.mul_add {x y z : BitVec w} : + x * (y + z) = x * y + x * z := by + apply eq_of_toNat_eq + simp + rw [Nat.mul_mod, Nat.mod_mod (y.toNat + z.toNat), + ← Nat.mul_mod, Nat.mul_add] + +/-- The Bitvector that is equal to `2^i % 2^w`, the power of 2 (`pot`). -/ +def pot {w : Nat} (i : Nat) : BitVec w := (1#w) <<< i + +theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot (x : BitVec w) (i : Nat) : + zeroExtend w (x.truncate (i + 1)) = + zeroExtend w (x.truncate i) + (x &&& (BitVec.pot i)) := by + apply eq_of_toNat_eq + sorry + theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : - (mulRec l r s) = l * r := by + (mulRec l r s) = l * ((r.truncate (s + 1)).zeroExtend w) := by induction w generalizing s case zero => apply Subsingleton.elim case succ w' hw => induction s case zero => - simp [mulRec, mulRec_zero_eq, signExtend, truncate] - sorry - case succ s' hs => sorry - --- Provable with sign extend theory. -@[simp] -theorem signExtend_eq_self (x : BitVec w) : x.signExtend w = x := sorry + simp [mulRec_zero_eq] + by_cases r.getLsb 0 + case pos hr => + simp only [hr, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero, + hr, ofBool_true, ofNat_eq_ofNat] + rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]; simp + case neg hr => + simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero] + case succ s' hs => + rw [mulRec_succ_eq] + rw [hs]; + have heq : + (if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) = + (l * (r &&& (BitVec.pot (s' + 1)))) := by sorry + rw [heq, ← BitVec.mul_add] + rw [← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot] + +theorem zeroExtend_zeroExtend_of_lt (x : BitVec w) + (u v : Nat) (hi : i ≤ j) : + (x.zeroExtend i |>.zeroExtend j) = x.zeroExtend j := by + ext k + simp + intros hx; + have hk : k < j := by omega + sorry + -- omega theorem getLsb_mul (x y : BitVec w) (i : Nat) : (x * y).getLsb i = (mulRec x y w).getLsb i := by - rw [mulRec_eq_mul_signExtend_truncate] + simp [mulRec_eq_mul_signExtend_truncate] + rw [truncate] + sorry /-! ### le and lt -/ From 5ae0995be5e98981fc4991caeffd42ebaec55910 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Tue, 4 Jun 2024 17:38:40 +0100 Subject: [PATCH 05/64] chore: make progress --- src/Init/Data/BitVec/Lemmas.lean | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 6d170e465671..ea97875220e4 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1252,21 +1252,24 @@ theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : rw [heq, ← BitVec.mul_add] rw [← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot] -theorem zeroExtend_zeroExtend_of_lt (x : BitVec w) - (u v : Nat) (hi : i ≤ j) : - (x.zeroExtend i |>.zeroExtend j) = x.zeroExtend j := by +/-- Zero extending by number of bits larger than the bitwidth has no effect. -/ +theorem zeroExtend_of_ge {x : BitVec w} {i j : Nat} (hi : i ≥ w) : + (x.zeroExtend i).zeroExtend j = x.zeroExtend j := by ext k simp intros hx; - have hk : k < j := by omega - sorry - -- omega + have hi' : k < w := BitVec.lt_of_getLsb _ _ hx + omega + +/-- Zero extending by the bitwidth has no effect. -/ +theorem zeroExtend_eq_self {x : BitVec w} : x.zeroExtend w = x := by + ext i + simp [getLsb_zeroExtend] theorem getLsb_mul (x y : BitVec w) (i : Nat) : (x * y).getLsb i = (mulRec x y w).getLsb i := by simp [mulRec_eq_mul_signExtend_truncate] - rw [truncate] - sorry + rw [truncate, zeroExtend_of_ge (by omega), zeroExtend_eq_self] /-! ### le and lt -/ From 203b34b4b46337487b5891c06f8195f1766022b7 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Tue, 4 Jun 2024 18:43:37 +0100 Subject: [PATCH 06/64] chore: continue writing theorems --- src/Init/Data/BitVec/Bitblast.lean | 13 +++++++++ src/Init/Data/BitVec/Lemmas.lean | 43 ++++++++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 1ca892057551..24b013e8df69 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -159,6 +159,19 @@ theorem add_eq_adc (w : Nat) (x y : BitVec w) : x + y = (adc x y false).snd := b theorem allOnes_sub_eq_not (x : BitVec w) : allOnes w - x = ~~~x := by rw [← add_not_self x, BitVec.add_comm, add_sub_cancel] +/-- Adding two bitvectors equals or-ing them if they are 1 in mutually exclusive locations -/ +theorem add_eq_or_of_and_eq_zero (x y : BitVec w) (h : x &&& y = 0#w) : x + y = x ||| y := by + rw [add_eq_adc, adc, iunfoldr_replace (fun _ => false) (x ||| y)] + · rfl + · simp [adcb, atLeastTwo, h] + intros i + replace h : (x &&& y).getLsb i = (0#w).getLsb i := by rw [h] + simp only [getLsb_and, getLsb_zero, and_eq_false_imp] at h + constructor + · intros hx + simp_all [hx] + · by_cases hx : x.getLsb i <;> simp_all [hx] + /-! ### Negation -/ theorem bit_not_testBit (x : BitVec w) (i : Fin w) : diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index ea97875220e4..efd5ab331127 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1191,6 +1191,7 @@ private theorem Nat.testBit_one_eq_true_iff_self_eq_zero {i : Nat} : Nat.testBit 1 i = true ↔ i = 0 := by cases i <;> simp + /-- Zero extending `1#v` to `1#w` equals `1#w` when `v > 0`. -/ theorem zeroExtend_ofNat_one_eq_ofNat_one_of_lt {v w : Nat} (hv : 0 < v): (BitVec.ofNat v 1).zeroExtend w = BitVec.ofNat w 1 := by @@ -1222,11 +1223,49 @@ theorem BitVec.mul_add {x y z : BitVec w} : /-- The Bitvector that is equal to `2^i % 2^w`, the power of 2 (`pot`). -/ def pot {w : Nat} (i : Nat) : BitVec w := (1#w) <<< i +@[simp] +theorem getLsb_pot (i j : Nat) : (pot i : BitVec w).getLsb j = ((i < w) && (i = j)) := by + rcases w with rfl | w + · simp only [pot, BitVec.reduceOfNat, Nat.zero_le, getLsb_ge, Bool.false_eq, + Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] + omega + · simp [pot, getLsb_shiftLeft, getLsb_ofNat, decide_eq_true_eq] + by_cases hi : Nat.testBit 1 (j - i) + · simp [hi] + obtain hi' := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi + simp [hi'] + have hi'' : i = j := by omega + omega + · simp at hi + rw [hi] + have hij : i ≠ j := by + intro h; subst h + simp at hi + simp [hij] + +/-- This is proven in BitBlast.lean, but it's a dependency that needs an import cycle to be broken. -/ +theorem add_eq_or_of_and_eq_zero (x y : BitVec w) (h : x &&& y = 0#w) : x + y = x ||| y := by sorry + +theorem BitVec.toNat_pot (w : Nat) (i : Nat) : (pot i : BitVec w).toNat = 2^i % 2^w := by + rcases w with rfl | w + · simp [Nat.mod_one] + · simp [pot, toNat_shiftLeft] + have hone : 1 < 2 ^ (w + 1) := by + rw [show 1 = 2^0 by simp[Nat.pow_zero]] + exact Nat.pow_lt_pow_of_lt (by omega) (by omega) + simp [Nat.mod_eq_of_lt hone, Nat.shiftLeft_eq] + theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot (x : BitVec w) (i : Nat) : zeroExtend w (x.truncate (i + 1)) = zeroExtend w (x.truncate i) + (x &&& (BitVec.pot i)) := by - apply eq_of_toNat_eq - sorry + rw [add_eq_or_of_and_eq_zero] + · ext k + simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] + by_cases hk:k = i + · sorry + · sorry + · ext k + sorry theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : (mulRec l r s) = l * ((r.truncate (s + 1)).zeroExtend w) := by From 6df5fd5f6555a74578ba4a439bd49dadb516dfde Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Tue, 4 Jun 2024 19:02:16 +0100 Subject: [PATCH 07/64] chore: push --- src/Init/Data/BitVec/Lemmas.lean | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index efd5ab331127..c07b4b2f6ebf 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1229,19 +1229,19 @@ theorem getLsb_pot (i j : Nat) : (pot i : BitVec w).getLsb j = ((i < w) && (i = · simp only [pot, BitVec.reduceOfNat, Nat.zero_le, getLsb_ge, Bool.false_eq, Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] omega - · simp [pot, getLsb_shiftLeft, getLsb_ofNat, decide_eq_true_eq] - by_cases hi : Nat.testBit 1 (j - i) - · simp [hi] - obtain hi' := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi - simp [hi'] - have hi'' : i = j := by omega + · simp only [pot, getLsb_shiftLeft, getLsb_ofNat] + by_cases hj : j < i + · simp only [hj, decide_True, Bool.not_true, Bool.and_false, Bool.false_and, Bool.false_eq, + Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] omega - · simp at hi - rw [hi] - have hij : i ≠ j := by - intro h; subst h - simp at hi - simp [hij] + · by_cases hi : Nat.testBit 1 (j - i) + · obtain hi' := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi + have hij : j = i := by omega + simp_all + · have hij : i ≠ j := by + intro h; subst h + simp at hi + simp_all /-- This is proven in BitBlast.lean, but it's a dependency that needs an import cycle to be broken. -/ theorem add_eq_or_of_and_eq_zero (x y : BitVec w) (h : x &&& y = 0#w) : x + y = x ||| y := by sorry From bec5b36d731ae0cb44f8fe8410362a8dbfbbcf12 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Tue, 4 Jun 2024 22:42:15 +0100 Subject: [PATCH 08/64] chore: finish proof of recurrence --- src/Init/Data/BitVec/Lemmas.lean | 52 +++++++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index c07b4b2f6ebf..0b3d537a5d83 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1223,6 +1223,17 @@ theorem BitVec.mul_add {x y z : BitVec w} : /-- The Bitvector that is equal to `2^i % 2^w`, the power of 2 (`pot`). -/ def pot {w : Nat} (i : Nat) : BitVec w := (1#w) <<< i +@[simp] +theorem toNat_pot (w : Nat) (i : Nat) : (pot i : BitVec w).toNat = 2^i % 2^w := by + rcases w with rfl | w + · simp [Nat.mod_one] + · simp [pot, toNat_shiftLeft] + have h1 : 1 < 2 ^ (w + 1) := Nat.one_lt_two_pow (by omega) + rw [Nat.mod_eq_of_lt h1] + rw [Nat.shiftLeft_eq, Nat.one_mul] + + + @[simp] theorem getLsb_pot (i j : Nat) : (pot i : BitVec w).getLsb j = ((i < w) && (i = j)) := by rcases w with rfl | w @@ -1243,6 +1254,26 @@ theorem getLsb_pot (i j : Nat) : (pot i : BitVec w).getLsb j = ((i < w) && (i = simp at hi simp_all +theorem and_pot_eq_getLsb (x : BitVec w) (i : Nat) : + x &&& (pot i : BitVec w) = if x.getLsb i then pot i else 0#w := by + ext j + simp only [getLsb_and, getLsb_pot] + by_cases hj : i = j <;> by_cases hx : x.getLsb i <;> simp_all + +@[simp] +theorem mul_pot_eq_shiftLeft (x : BitVec w) (i : Nat) : + x * (pot i : BitVec w) = x <<< i := by + apply eq_of_toNat_eq + simp only [toNat_mul, toNat_pot, toNat_shiftLeft, Nat.shiftLeft_eq] + by_cases hi : i < w + · have hpow : 2^i < 2^w := Nat.pow_lt_pow_of_lt (by omega) (by omega) + rw [Nat.mod_eq_of_lt hpow] + · have hpow : 2 ^ i % 2 ^ w = 0 := by + rw [Nat.mod_eq_zero_of_dvd] + apply Nat.pow_dvd_pow 2 (by omega) + simp [Nat.mul_mod, hpow] + + /-- This is proven in BitBlast.lean, but it's a dependency that needs an import cycle to be broken. -/ theorem add_eq_or_of_and_eq_zero (x y : BitVec w) (h : x &&& y = 0#w) : x + y = x ||| y := by sorry @@ -1261,11 +1292,20 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot (x : BitVec w) ( rw [add_eq_or_of_and_eq_zero] · ext k simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] - by_cases hk:k = i - · sorry - · sorry + by_cases hik:i = k + · subst hik + simp + · simp [hik] + /- Really, 'omega' should be able to do this-/ + by_cases hik' : k < (i + 1) + · have hik'' : k < i := by omega + simp [hik', hik''] + · have hik'' : ¬ (k < i) := by omega + simp [hik', hik''] · ext k - sorry + simp + intros h₁ _ _ _ + omega theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : (mulRec l r s) = l * ((r.truncate (s + 1)).zeroExtend w) := by @@ -1287,7 +1327,9 @@ theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : rw [hs]; have heq : (if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) = - (l * (r &&& (BitVec.pot (s' + 1)))) := by sorry + (l * (r &&& (BitVec.pot (s' + 1)))) := by + simp only [ofNat_eq_ofNat, and_pot_eq_getLsb] + by_cases hr : r.getLsb (s' + 1) <;> simp [hr] rw [heq, ← BitVec.mul_add] rw [← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot] From e6cba224f636a48b338c8eac17c06f1d7a9fc094 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Sat, 8 Jun 2024 07:55:34 +0100 Subject: [PATCH 09/64] chore: rejigger files, move to bitvec/bitblast --- src/Init/Data/BitVec/Bitblast.lean | 139 +++++++++++++++++++++++++- src/Init/Data/BitVec/Lemmas.lean | 150 ----------------------------- 2 files changed, 138 insertions(+), 151 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 24b013e8df69..91d614e83248 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -159,8 +159,10 @@ theorem add_eq_adc (w : Nat) (x y : BitVec w) : x + y = (adc x y false).snd := b theorem allOnes_sub_eq_not (x : BitVec w) : allOnes w - x = ~~~x := by rw [← add_not_self x, BitVec.add_comm, add_sub_cancel] +#check BitVec.ofNat /-- Adding two bitvectors equals or-ing them if they are 1 in mutually exclusive locations -/ -theorem add_eq_or_of_and_eq_zero (x y : BitVec w) (h : x &&& y = 0#w) : x + y = x ||| y := by +theorem add_eq_or_of_and_eq_zero {w : Nat} (x y : BitVec w) + (h : x &&& y = (0#w)) : x + y = x ||| y := by rw [add_eq_adc, adc, iunfoldr_replace (fun _ => false) (x ||| y)] · rfl · simp [adcb, atLeastTwo, h] @@ -248,4 +250,139 @@ theorem sle_eq_carry (x y : BitVec w) : x.sle y = !((x.msb == y.msb).xor (carry w y (~~~x) true)) := by rw [sle_eq_not_slt, slt_eq_not_carry, beq_comm] +/-! ### mul recurrence for bitblasting -/ + +open BitVec in +/-- The Bitvector that is equal to `2^i % 2^w`, the power of 2 (`pot`). -/ +def pot {w : Nat} (i : Nat) : BitVec w := (1#w) <<< i + +@[simp] +theorem toNat_pot (w : Nat) (i : Nat) : (pot i : BitVec w).toNat = 2^i % 2^w := by + rcases w with rfl | w + · simp [Nat.mod_one] + · simp [pot, toNat_shiftLeft] + have h1 : 1 < 2 ^ (w + 1) := Nat.one_lt_two_pow (by omega) + rw [Nat.mod_eq_of_lt h1] + rw [Nat.shiftLeft_eq, Nat.one_mul] + +/-- `testBit 1 i` is true iff the index `i` equals 0. -/ +private theorem Nat.testBit_one_eq_true_iff_self_eq_zero {i : Nat} : + Nat.testBit 1 i = true ↔ i = 0 := by + cases i <;> simp + +@[simp] +theorem getLsb_pot (i j : Nat) : (pot i : BitVec w).getLsb j = ((i < w) && (i = j)) := by + rcases w with rfl | w + · simp only [pot, BitVec.reduceOfNat, Nat.zero_le, getLsb_ge, Bool.false_eq, + Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] + omega + · simp only [pot, getLsb_shiftLeft, getLsb_ofNat] + by_cases hj : j < i + · simp only [hj, decide_True, Bool.not_true, Bool.and_false, Bool.false_and, Bool.false_eq, + Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] + omega + · by_cases hi : Nat.testBit 1 (j - i) + · obtain hi' := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi + have hij : j = i := by omega + simp_all + · have hij : i ≠ j := by + intro h; subst h + simp at hi + simp_all + +theorem and_pot_eq_getLsb (x : BitVec w) (i : Nat) : + x &&& (pot i : BitVec w) = if x.getLsb i then pot i else 0#w := by + ext j + simp only [getLsb_and, getLsb_pot] + by_cases hj : i = j <;> by_cases hx : x.getLsb i <;> simp_all + +@[simp] +theorem mul_pot_eq_shiftLeft (x : BitVec w) (i : Nat) : + x * (pot i : BitVec w) = x <<< i := by + apply eq_of_toNat_eq + simp only [toNat_mul, toNat_pot, toNat_shiftLeft, Nat.shiftLeft_eq] + by_cases hi : i < w + · have hpow : 2^i < 2^w := Nat.pow_lt_pow_of_lt (by omega) (by omega) + rw [Nat.mod_eq_of_lt hpow] + · have hpow : 2 ^ i % 2 ^ w = 0 := by + rw [Nat.mod_eq_zero_of_dvd] + apply Nat.pow_dvd_pow 2 (by omega) + simp [Nat.mul_mod, hpow] + +theorem BitVec.toNat_pot (w : Nat) (i : Nat) : (pot i : BitVec w).toNat = 2^i % 2^w := by + rcases w with rfl | w + · simp [Nat.mod_one] + · simp [pot, toNat_shiftLeft] + have hone : 1 < 2 ^ (w + 1) := by + rw [show 1 = 2^0 by simp[Nat.pow_zero]] + exact Nat.pow_lt_pow_of_lt (by omega) (by omega) + simp [Nat.mod_eq_of_lt hone, Nat.shiftLeft_eq] + +theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot (x : BitVec w) (i : Nat) : + zeroExtend w (x.truncate (i + 1)) = + zeroExtend w (x.truncate i) + (x &&& (BitVec.pot i)) := by + rw [add_eq_or_of_and_eq_zero] + · ext k + simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] + by_cases hik:i = k + · subst hik + simp + · simp [hik] + /- Really, 'omega' should be able to do this-/ + by_cases hik' : k < (i + 1) + · have hik'' : k < i := by omega + simp [hik', hik''] + · have hik'' : ¬ (k < i) := by omega + simp [hik', hik''] + · ext k + simp + intros h₁ _ _ _ + omega + +theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : + (mulRec l r s) = l * ((r.truncate (s + 1)).zeroExtend w) := by + induction w generalizing s + case zero => apply Subsingleton.elim + case succ w' hw => + induction s + case zero => + simp [mulRec_zero_eq] + by_cases r.getLsb 0 + case pos hr => + simp only [hr, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero, + hr, ofBool_true, ofNat_eq_ofNat] + rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]; simp + case neg hr => + simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero] + case succ s' hs => + rw [mulRec_succ_eq] + rw [hs]; + have heq : + (if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) = + (l * (r &&& (BitVec.pot (s' + 1)))) := by + simp only [ofNat_eq_ofNat, and_pot_eq_getLsb] + by_cases hr : r.getLsb (s' + 1) <;> simp [hr] + rw [heq, ← BitVec.mul_add] + rw [← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot] + + +/-- Zero extending by number of bits larger than the bitwidth has no effect. -/ +theorem zeroExtend_of_ge {x : BitVec w} {i j : Nat} (hi : i ≥ w) : + (x.zeroExtend i).zeroExtend j = x.zeroExtend j := by + ext k + simp + intros hx; + have hi' : k < w := BitVec.lt_of_getLsb _ _ hx + omega + +/-- Zero extending by the bitwidth has no effect. -/ +theorem zeroExtend_eq_self {x : BitVec w} : x.zeroExtend w = x := by + ext i + simp [getLsb_zeroExtend] + +theorem getLsb_mul (x y : BitVec w) (i : Nat) : + (x * y).getLsb i = (mulRec x y w).getLsb i := by + simp [mulRec_eq_mul_signExtend_truncate] + rw [truncate, zeroExtend_of_ge (by omega), zeroExtend_eq_self] + end BitVec diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 0b3d537a5d83..90a5ef94c29e 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -735,23 +735,6 @@ theorem udiv_eq {x y : BitVec n} : rw [Nat.mod_eq_of_lt] exact Nat.lt_of_le_of_lt (Nat.div_le_self ..) (by omega) -/-- The remainder `rem` obeys the euclidean algorithm equation on computing `l.udiv r`. -/ -def udivDivisor (l r rem : BitVec w) : Prop := - rem < r ∧ - let l' := l.signExtend (2*w) - let r' := r.signExtend (2*w) - let rem' := rem.signExtend (2*w) - l' = (l' / r') * r' + rem' - -/-- Such a remainder always exists. -/ -theorem udiv_euclid_eqn_exists (l r : BitVec w) : - ∃ (rem : BitVec w), udivDivisor l r rem := sorry - -/-- Such a remainder is unique. -/ -theorem udiv_euclid_eqn_unique (l r rem rem' : BitVec w) - (hrem : udivDivisor l r rem) (hrem' : udivDivisor l r rem') : - rem = rem' := sorry - /-! ### append -/ theorem append_def (x : BitVec v) (y : BitVec w) : @@ -1191,7 +1174,6 @@ private theorem Nat.testBit_one_eq_true_iff_self_eq_zero {i : Nat} : Nat.testBit 1 i = true ↔ i = 0 := by cases i <;> simp - /-- Zero extending `1#v` to `1#w` equals `1#w` when `v > 0`. -/ theorem zeroExtend_ofNat_one_eq_ofNat_one_of_lt {v w : Nat} (hv : 0 < v): (BitVec.ofNat v 1).zeroExtend w = BitVec.ofNat w 1 := by @@ -1220,138 +1202,6 @@ theorem BitVec.mul_add {x y z : BitVec w} : rw [Nat.mul_mod, Nat.mod_mod (y.toNat + z.toNat), ← Nat.mul_mod, Nat.mul_add] -/-- The Bitvector that is equal to `2^i % 2^w`, the power of 2 (`pot`). -/ -def pot {w : Nat} (i : Nat) : BitVec w := (1#w) <<< i - -@[simp] -theorem toNat_pot (w : Nat) (i : Nat) : (pot i : BitVec w).toNat = 2^i % 2^w := by - rcases w with rfl | w - · simp [Nat.mod_one] - · simp [pot, toNat_shiftLeft] - have h1 : 1 < 2 ^ (w + 1) := Nat.one_lt_two_pow (by omega) - rw [Nat.mod_eq_of_lt h1] - rw [Nat.shiftLeft_eq, Nat.one_mul] - - - -@[simp] -theorem getLsb_pot (i j : Nat) : (pot i : BitVec w).getLsb j = ((i < w) && (i = j)) := by - rcases w with rfl | w - · simp only [pot, BitVec.reduceOfNat, Nat.zero_le, getLsb_ge, Bool.false_eq, - Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] - omega - · simp only [pot, getLsb_shiftLeft, getLsb_ofNat] - by_cases hj : j < i - · simp only [hj, decide_True, Bool.not_true, Bool.and_false, Bool.false_and, Bool.false_eq, - Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] - omega - · by_cases hi : Nat.testBit 1 (j - i) - · obtain hi' := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi - have hij : j = i := by omega - simp_all - · have hij : i ≠ j := by - intro h; subst h - simp at hi - simp_all - -theorem and_pot_eq_getLsb (x : BitVec w) (i : Nat) : - x &&& (pot i : BitVec w) = if x.getLsb i then pot i else 0#w := by - ext j - simp only [getLsb_and, getLsb_pot] - by_cases hj : i = j <;> by_cases hx : x.getLsb i <;> simp_all - -@[simp] -theorem mul_pot_eq_shiftLeft (x : BitVec w) (i : Nat) : - x * (pot i : BitVec w) = x <<< i := by - apply eq_of_toNat_eq - simp only [toNat_mul, toNat_pot, toNat_shiftLeft, Nat.shiftLeft_eq] - by_cases hi : i < w - · have hpow : 2^i < 2^w := Nat.pow_lt_pow_of_lt (by omega) (by omega) - rw [Nat.mod_eq_of_lt hpow] - · have hpow : 2 ^ i % 2 ^ w = 0 := by - rw [Nat.mod_eq_zero_of_dvd] - apply Nat.pow_dvd_pow 2 (by omega) - simp [Nat.mul_mod, hpow] - - -/-- This is proven in BitBlast.lean, but it's a dependency that needs an import cycle to be broken. -/ -theorem add_eq_or_of_and_eq_zero (x y : BitVec w) (h : x &&& y = 0#w) : x + y = x ||| y := by sorry - -theorem BitVec.toNat_pot (w : Nat) (i : Nat) : (pot i : BitVec w).toNat = 2^i % 2^w := by - rcases w with rfl | w - · simp [Nat.mod_one] - · simp [pot, toNat_shiftLeft] - have hone : 1 < 2 ^ (w + 1) := by - rw [show 1 = 2^0 by simp[Nat.pow_zero]] - exact Nat.pow_lt_pow_of_lt (by omega) (by omega) - simp [Nat.mod_eq_of_lt hone, Nat.shiftLeft_eq] - -theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot (x : BitVec w) (i : Nat) : - zeroExtend w (x.truncate (i + 1)) = - zeroExtend w (x.truncate i) + (x &&& (BitVec.pot i)) := by - rw [add_eq_or_of_and_eq_zero] - · ext k - simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] - by_cases hik:i = k - · subst hik - simp - · simp [hik] - /- Really, 'omega' should be able to do this-/ - by_cases hik' : k < (i + 1) - · have hik'' : k < i := by omega - simp [hik', hik''] - · have hik'' : ¬ (k < i) := by omega - simp [hik', hik''] - · ext k - simp - intros h₁ _ _ _ - omega - -theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : - (mulRec l r s) = l * ((r.truncate (s + 1)).zeroExtend w) := by - induction w generalizing s - case zero => apply Subsingleton.elim - case succ w' hw => - induction s - case zero => - simp [mulRec_zero_eq] - by_cases r.getLsb 0 - case pos hr => - simp only [hr, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero, - hr, ofBool_true, ofNat_eq_ofNat] - rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]; simp - case neg hr => - simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero] - case succ s' hs => - rw [mulRec_succ_eq] - rw [hs]; - have heq : - (if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) = - (l * (r &&& (BitVec.pot (s' + 1)))) := by - simp only [ofNat_eq_ofNat, and_pot_eq_getLsb] - by_cases hr : r.getLsb (s' + 1) <;> simp [hr] - rw [heq, ← BitVec.mul_add] - rw [← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot] - -/-- Zero extending by number of bits larger than the bitwidth has no effect. -/ -theorem zeroExtend_of_ge {x : BitVec w} {i j : Nat} (hi : i ≥ w) : - (x.zeroExtend i).zeroExtend j = x.zeroExtend j := by - ext k - simp - intros hx; - have hi' : k < w := BitVec.lt_of_getLsb _ _ hx - omega - -/-- Zero extending by the bitwidth has no effect. -/ -theorem zeroExtend_eq_self {x : BitVec w} : x.zeroExtend w = x := by - ext i - simp [getLsb_zeroExtend] - -theorem getLsb_mul (x y : BitVec w) (i : Nat) : - (x * y).getLsb i = (mulRec x y w).getLsb i := by - simp [mulRec_eq_mul_signExtend_truncate] - rw [truncate, zeroExtend_of_ge (by omega), zeroExtend_eq_self] - /-! ### le and lt -/ @[bv_toNat] theorem le_def (x y : BitVec n) : From 2f0cb91b724c8df96258edf8503c8a978ceeff14 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Sat, 8 Jun 2024 08:11:09 +0100 Subject: [PATCH 10/64] chore: cleaup proofs, move them to the right locations --- src/Init/Data/BitVec/Basic.lean | 7 ++ src/Init/Data/BitVec/Bitblast.lean | 114 ++++++----------------------- src/Init/Data/BitVec/Lemmas.lean | 59 +++++++++++++++ 3 files changed, 88 insertions(+), 92 deletions(-) diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index 1d12641f0abb..508ce9012475 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -616,6 +616,13 @@ theorem ofBool_append (msb : Bool) (lsbs : BitVec w) : end bitwise +section twoPow + +/-- `twoPow i` is the bitvector `2^i` if `i < w`, and `0` otherwise. That is, 2 to the power `i`. -/ +def twoPow {w : Nat} (i : Nat) : BitVec w := (1#w) <<< i + +end twoPow + section normalization_eqs /-! We add simp-lemmas that rewrite bitvector operations into the equivalent notation -/ @[simp] theorem append_eq (x : BitVec w) (y : BitVec v) : BitVec.append x y = x ++ y := rfl diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 91d614e83248..b60b23ca90e7 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -252,75 +252,11 @@ theorem sle_eq_carry (x y : BitVec w) : /-! ### mul recurrence for bitblasting -/ -open BitVec in -/-- The Bitvector that is equal to `2^i % 2^w`, the power of 2 (`pot`). -/ -def pot {w : Nat} (i : Nat) : BitVec w := (1#w) <<< i - -@[simp] -theorem toNat_pot (w : Nat) (i : Nat) : (pot i : BitVec w).toNat = 2^i % 2^w := by - rcases w with rfl | w - · simp [Nat.mod_one] - · simp [pot, toNat_shiftLeft] - have h1 : 1 < 2 ^ (w + 1) := Nat.one_lt_two_pow (by omega) - rw [Nat.mod_eq_of_lt h1] - rw [Nat.shiftLeft_eq, Nat.one_mul] - -/-- `testBit 1 i` is true iff the index `i` equals 0. -/ -private theorem Nat.testBit_one_eq_true_iff_self_eq_zero {i : Nat} : - Nat.testBit 1 i = true ↔ i = 0 := by - cases i <;> simp - -@[simp] -theorem getLsb_pot (i j : Nat) : (pot i : BitVec w).getLsb j = ((i < w) && (i = j)) := by - rcases w with rfl | w - · simp only [pot, BitVec.reduceOfNat, Nat.zero_le, getLsb_ge, Bool.false_eq, - Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] - omega - · simp only [pot, getLsb_shiftLeft, getLsb_ofNat] - by_cases hj : j < i - · simp only [hj, decide_True, Bool.not_true, Bool.and_false, Bool.false_and, Bool.false_eq, - Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] - omega - · by_cases hi : Nat.testBit 1 (j - i) - · obtain hi' := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi - have hij : j = i := by omega - simp_all - · have hij : i ≠ j := by - intro h; subst h - simp at hi - simp_all - -theorem and_pot_eq_getLsb (x : BitVec w) (i : Nat) : - x &&& (pot i : BitVec w) = if x.getLsb i then pot i else 0#w := by - ext j - simp only [getLsb_and, getLsb_pot] - by_cases hj : i = j <;> by_cases hx : x.getLsb i <;> simp_all - -@[simp] -theorem mul_pot_eq_shiftLeft (x : BitVec w) (i : Nat) : - x * (pot i : BitVec w) = x <<< i := by - apply eq_of_toNat_eq - simp only [toNat_mul, toNat_pot, toNat_shiftLeft, Nat.shiftLeft_eq] - by_cases hi : i < w - · have hpow : 2^i < 2^w := Nat.pow_lt_pow_of_lt (by omega) (by omega) - rw [Nat.mod_eq_of_lt hpow] - · have hpow : 2 ^ i % 2 ^ w = 0 := by - rw [Nat.mod_eq_zero_of_dvd] - apply Nat.pow_dvd_pow 2 (by omega) - simp [Nat.mul_mod, hpow] - -theorem BitVec.toNat_pot (w : Nat) (i : Nat) : (pot i : BitVec w).toNat = 2^i % 2^w := by - rcases w with rfl | w - · simp [Nat.mod_one] - · simp [pot, toNat_shiftLeft] - have hone : 1 < 2 ^ (w + 1) := by - rw [show 1 = 2^0 by simp[Nat.pow_zero]] - exact Nat.pow_lt_pow_of_lt (by omega) (by omega) - simp [Nat.mod_eq_of_lt hone, Nat.shiftLeft_eq] - -theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot (x : BitVec w) (i : Nat) : +/-- Recurrence lemma that saus that truncating to `i+1` bits and then zero extending to `w` +equals truncating upto `i` bits `[0..i-1]`, and then adding the `i`th bit of `x`. -/ +theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w) (i : Nat) : zeroExtend w (x.truncate (i + 1)) = - zeroExtend w (x.truncate i) + (x &&& (BitVec.pot i)) := by + zeroExtend w (x.truncate i) + (x &&& twoPow i) := by rw [add_eq_or_of_and_eq_zero] · ext k simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] @@ -341,30 +277,24 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot (x : BitVec w) ( theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : (mulRec l r s) = l * ((r.truncate (s + 1)).zeroExtend w) := by - induction w generalizing s - case zero => apply Subsingleton.elim - case succ w' hw => - induction s - case zero => - simp [mulRec_zero_eq] - by_cases r.getLsb 0 - case pos hr => - simp only [hr, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero, - hr, ofBool_true, ofNat_eq_ofNat] - rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]; simp - case neg hr => - simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero] - case succ s' hs => - rw [mulRec_succ_eq] - rw [hs]; - have heq : - (if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) = - (l * (r &&& (BitVec.pot (s' + 1)))) := by - simp only [ofNat_eq_ofNat, and_pot_eq_getLsb] - by_cases hr : r.getLsb (s' + 1) <;> simp [hr] - rw [heq, ← BitVec.mul_add] - rw [← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot] - + induction s + case zero => + simp [mulRec_zero_eq] + by_cases r.getLsb 0 + case pos hr => + simp only [hr, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero, + hr, ofBool_true, ofNat_eq_ofNat] + rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]; simp + case neg hr => + simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero] + case succ s' hs => + rw [mulRec_succ_eq, hs] + have heq : + (if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) = + (l * (r &&& (BitVec.twoPow (s' + 1)))) := by + simp only [ofNat_eq_ofNat, and_twoPow_eq_getLsb] + by_cases hr : r.getLsb (s' + 1) <;> simp [hr] + rw [heq, ← BitVec.mul_add, ← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow] /-- Zero extending by number of bits larger than the bitwidth has no effect. -/ theorem zeroExtend_of_ge {x : BitVec w} {i j : Nat} (hi : i ≥ w) : diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 90a5ef94c29e..62af1a19a289 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1428,4 +1428,63 @@ theorem getLsb_rotateRight {x : BitVec w} {r i : Nat} : · simp · rw [← rotateRight_mod_eq_rotateRight, getLsb_rotateRight_of_le (Nat.mod_lt _ (by omega))] +/- ## twoPow -/ + +@[simp] +theorem toNat_twoPow (w : Nat) (i : Nat) : (twoPow i : BitVec w).toNat = 2^i % 2^w := by + rcases w with rfl | w + · simp [Nat.mod_one] + · simp [twoPow, toNat_shiftLeft] + have h1 : 1 < 2 ^ (w + 1) := Nat.one_lt_two_pow (by omega) + rw [Nat.mod_eq_of_lt h1] + rw [Nat.shiftLeft_eq, Nat.one_mul] + +@[simp] +theorem getLsb_twoPow (i j : Nat) : (twoPow i : BitVec w).getLsb j = ((i < w) && (i = j)) := by + rcases w with rfl | w + · simp only [twoPow, BitVec.reduceOfNat, Nat.zero_le, getLsb_ge, Bool.false_eq, + Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] + omega + · simp only [twoPow, getLsb_shiftLeft, getLsb_ofNat] + by_cases hj : j < i + · simp only [hj, decide_True, Bool.not_true, Bool.and_false, Bool.false_and, Bool.false_eq, + Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] + omega + · by_cases hi : Nat.testBit 1 (j - i) + · obtain hi' := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi + have hij : j = i := by omega + simp_all + · have hij : i ≠ j := by + intro h; subst h + simp at hi + simp_all + +theorem and_twoPow_eq_getLsb (x : BitVec w) (i : Nat) : + x &&& (twoPow i : BitVec w) = if x.getLsb i then twoPow i else 0#w := by + ext j + simp only [getLsb_and, getLsb_twoPow] + by_cases hj : i = j <;> by_cases hx : x.getLsb i <;> simp_all + +@[simp] +theorem mul_twoPow_eq_shiftLeft (x : BitVec w) (i : Nat) : + x * (twoPow i : BitVec w) = x <<< i := by + apply eq_of_toNat_eq + simp only [toNat_mul, toNat_twoPow, toNat_shiftLeft, Nat.shiftLeft_eq] + by_cases hi : i < w + · have hpow : 2^i < 2^w := Nat.pow_lt_pow_of_lt (by omega) (by omega) + rw [Nat.mod_eq_of_lt hpow] + · have hpow : 2 ^ i % 2 ^ w = 0 := by + rw [Nat.mod_eq_zero_of_dvd] + apply Nat.pow_dvd_pow 2 (by omega) + simp [Nat.mul_mod, hpow] + +theorem BitVec.toNat_twoPow (w : Nat) (i : Nat) : (twoPow i : BitVec w).toNat = 2^i % 2^w := by + rcases w with rfl | w + · simp [Nat.mod_one] + · simp [twoPow, toNat_shiftLeft] + have hone : 1 < 2 ^ (w + 1) := by + rw [show 1 = 2^0 by simp[Nat.pow_zero]] + exact Nat.pow_lt_pow_of_lt (by omega) (by omega) + simp [Nat.mod_eq_of_lt hone, Nat.shiftLeft_eq] + end BitVec From 863284992ee03e1435ead6522fed6d4ef1deb5b3 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Sat, 8 Jun 2024 08:23:10 +0100 Subject: [PATCH 11/64] chore: drop large z3 comment --- src/Init/Data/BitVec/Lemmas.lean | 40 -------------------------------- 1 file changed, 40 deletions(-) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 62af1a19a289..1b61ff6dee67 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1113,46 +1113,6 @@ theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) : mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := by simp [mulRec] -/-! -#!/usr/bin/env python3 -##Check that 'hargonix-recurrences-statements' actually has the right statements. -from z3 import * - -def mulExample(): - # Define the `mulRec` function in Z3py - def mulRec(l : BitVecRef, r : BitVecRef, s : int): - # import pudb; pudb.set_trace() - assert isinstance(s, int) - assert isinstance(l, BitVecRef) - assert isinstance(r, BitVecRef) - cur = If(Extract(s, s, r) == 1, l << s, BitVecVal(0, w)) - if s == 0: return cur - else: return mulRec(l, r, s-1) + cur - - # Define BitVecs - w = 8 # Example width, you can adjust it as necessary - l = BitVec('l', w) - r = BitVec('r', w) - - mul_circuit = mulRec(l, r, w-1) - print(mul_circuit) - - # Define assertion - mul_circuit_correct = mul_circuit == l * r - s = Solver() - s.add(ForAll(l, ForAll(r, mul_circuit_correct))) - - assert bool(s.check()) - - # verify what happens in mulRec for all 's' - for nbits_keep in range(1, w): - s = Solver() - s.add(ForAll(l, ForAll(r, mulRec(l, r, nbits_keep) == ZeroExt(w - nbits_keep - 1, Extract(nbits_keep, 0, l * r))))) - print(f"* checking mul eqn for width:'{nbits_keep}': '{s}'.") - assert bool(s.check()) -mulExample() --/ - @[simp] theorem getLsb_ofBool (b : Bool) (i : Nat) : (BitVec.ofBool b).getLsb i = ((i = 0) && b) := by rcases b with rfl | rfl From 351a7ce34c16d2e93ee0fa83091a9e37704d91de Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Sat, 8 Jun 2024 08:28:23 +0100 Subject: [PATCH 12/64] chore: move theorems around to proper location --- src/Init/Data/BitVec/Lemmas.lean | 75 ++++++++++++++++---------------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 1b61ff6dee67..ed26deaa8f15 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -163,6 +163,16 @@ theorem toNat_zero (n : Nat) : (0#n).toNat = 0 := by trivial private theorem lt_two_pow_of_le {x m n : Nat} (lt : x < 2 ^ m) (le : m ≤ n) : x < 2 ^ n := Nat.lt_of_lt_of_le lt (Nat.pow_le_pow_of_le_right (by trivial : 0 < 2) le) +@[simp] +theorem getLsb_ofBool (b : Bool) (i : Nat) : (BitVec.ofBool b).getLsb i = ((i = 0) && b) := by + rcases b with rfl | rfl + · simp [ofBool] + · simp [ofBool, getLsb_ofNat] + by_cases hi : (i = 0) + · simp [hi] + · simp [hi] + omega + /-! ### msb -/ @[simp] theorem msb_zero : (0#w).msb = false := by simp [BitVec.msb, getMsb] @@ -408,6 +418,29 @@ theorem msb_zeroExtend (x : BitVec w) : (x.zeroExtend v).msb = (decide (0 < v) & theorem msb_zeroExtend' (x : BitVec w) (h : w ≤ v) : (x.zeroExtend' h).msb = (decide (0 < v) && x.getLsb (v - 1)) := by rw [zeroExtend'_eq, msb_zeroExtend] +/-- zero extending a bitvector to width 1 equals the boolean of the lsb. -/ +theorem zeroExtend_one_eq_ofBool_getLsb_zero (x : BitVec w) : + x.zeroExtend 1 = BitVec.ofBool (x.getLsb 0) := by + ext i + simp [getLsb_zeroExtend, Fin.fin_one_eq_zero i] + +/-- `testBit 1 i` is true iff the index `i` equals 0. -/ +private theorem Nat.testBit_one_eq_true_iff_self_eq_zero {i : Nat} : + Nat.testBit 1 i = true ↔ i = 0 := by + cases i <;> simp + +/-- Zero extending `1#v` to `1#w` equals `1#w` when `v > 0`. -/ +theorem zeroExtend_ofNat_one_eq_ofNat_one_of_lt {v w : Nat} (hv : 0 < v): + (BitVec.ofNat v 1).zeroExtend w = BitVec.ofNat w 1 := by + ext i + obtain ⟨i, hilt⟩ := i + simp only [getLsb_zeroExtend, hilt, decide_True, getLsb_ofNat, Bool.true_and, + Bool.and_iff_right_iff_imp, decide_eq_true_eq] + intros hi1 + have hv := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi1 + omega + + /-! ## extractLsb -/ @[simp] @@ -593,6 +626,11 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl @[simp] theorem toFin_shiftLeft {n : Nat} (x : BitVec w) : BitVec.toFin (x <<< n) = Fin.ofNat' (x.toNat <<< n) (Nat.two_pow_pos w) := rfl +@[simp] +theorem shiftLeft_zero_eq (x : BitVec w) : x <<< 0 = x := by + apply eq_of_toNat_eq + simp + @[simp] theorem getLsb_shiftLeft (x : BitVec m) (n) : getLsb (x <<< n) i = (decide (i < m) && !decide (i < n) && getLsb x (i - n)) := by rw [← testBit_toNat, getLsb] @@ -1100,11 +1138,6 @@ def mulRec (l r : BitVec w) (s : Nat) : BitVec w := | 0 => cur | s + 1 => mulRec l r s + cur -@[simp] -theorem shiftLeft_zero_eq (x : BitVec w) : x <<< 0 = x := by - apply eq_of_toNat_eq - simp - theorem mulRec_zero_eq (l r : BitVec w) : mulRec l r 0 = if r.getLsb 0 then l else 0 := by simp [mulRec] @@ -1113,38 +1146,6 @@ theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) : mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := by simp [mulRec] -@[simp] -theorem getLsb_ofBool (b : Bool) (i : Nat) : (BitVec.ofBool b).getLsb i = ((i = 0) && b) := by - rcases b with rfl | rfl - · simp [ofBool] - · simp [ofBool, getLsb_ofNat] - by_cases hi : (i = 0) - · simp [hi] - · simp [hi] - omega - -/-- zero extending a bitvector to width 1 equals the boolean of the lsb. -/ -theorem zeroExtend_one_eq_ofBool_getLsb_zero (x : BitVec w) : - x.zeroExtend 1 = BitVec.ofBool (x.getLsb 0) := by - ext i - simp [getLsb_zeroExtend, Fin.fin_one_eq_zero i] - -/-- `testBit 1 i` is true iff the index `i` equals 0. -/ -private theorem Nat.testBit_one_eq_true_iff_self_eq_zero {i : Nat} : - Nat.testBit 1 i = true ↔ i = 0 := by - cases i <;> simp - -/-- Zero extending `1#v` to `1#w` equals `1#w` when `v > 0`. -/ -theorem zeroExtend_ofNat_one_eq_ofNat_one_of_lt {v w : Nat} (hv : 0 < v): - (BitVec.ofNat v 1).zeroExtend w = BitVec.ofNat w 1 := by - ext i - obtain ⟨i, hilt⟩ := i - simp only [getLsb_zeroExtend, hilt, decide_True, getLsb_ofNat, Bool.true_and, - Bool.and_iff_right_iff_imp, decide_eq_true_eq] - intros hi1 - have hv := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi1 - omega - @[simp] theorem BitVec.mul_one {x : BitVec w} : x * (1#w) = x := by apply eq_of_toNat_eq From 75867dd164778e046ade6dd1102e9668723e1d5e Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Sat, 8 Jun 2024 08:37:33 +0100 Subject: [PATCH 13/64] chore: more cleanup --- src/Init/Data/BitVec/Bitblast.lean | 17 +++++++++++++++-- src/Init/Data/BitVec/Lemmas.lean | 14 -------------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index b60b23ca90e7..80a0206f7001 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -159,7 +159,6 @@ theorem add_eq_adc (w : Nat) (x y : BitVec w) : x + y = (adc x y false).snd := b theorem allOnes_sub_eq_not (x : BitVec w) : allOnes w - x = ~~~x := by rw [← add_not_self x, BitVec.add_comm, add_sub_cancel] -#check BitVec.ofNat /-- Adding two bitvectors equals or-ing them if they are 1 in mutually exclusive locations -/ theorem add_eq_or_of_and_eq_zero {w : Nat} (x y : BitVec w) (h : x &&& y = (0#w)) : x + y = x ||| y := by @@ -252,7 +251,21 @@ theorem sle_eq_carry (x y : BitVec w) : /-! ### mul recurrence for bitblasting -/ -/-- Recurrence lemma that saus that truncating to `i+1` bits and then zero extending to `w` +def mulRec (l r : BitVec w) (s : Nat) : BitVec w := + let cur := if r.getLsb s then (l <<< s) else 0 + match s with + | 0 => cur + | s + 1 => mulRec l r s + cur + +theorem mulRec_zero_eq (l r : BitVec w) : + mulRec l r 0 = if r.getLsb 0 then l else 0 := by + simp [mulRec] + +theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) : + mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := by + simp [mulRec] + +/-- Recurrence lemma: truncating to `i+1` bits and then zero extending to `w` equals truncating upto `i` bits `[0..i-1]`, and then adding the `i`th bit of `x`. -/ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w) (i : Nat) : zeroExtend w (x.truncate (i + 1)) = diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index ed26deaa8f15..a70186083f6e 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1132,20 +1132,6 @@ theorem ofInt_mul {n} (x y : Int) : BitVec.ofInt n (x * y) = apply eq_of_toInt_eq simp -def mulRec (l r : BitVec w) (s : Nat) : BitVec w := - let cur := if r.getLsb s then (l <<< s) else 0 - match s with - | 0 => cur - | s + 1 => mulRec l r s + cur - -theorem mulRec_zero_eq (l r : BitVec w) : - mulRec l r 0 = if r.getLsb 0 then l else 0 := by - simp [mulRec] - -theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) : - mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := by - simp [mulRec] - @[simp] theorem BitVec.mul_one {x : BitVec w} : x * (1#w) = x := by apply eq_of_toNat_eq From bea6a61225bf06fe5a61ce203b12a4896c79dcec Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Sat, 8 Jun 2024 08:51:24 +0100 Subject: [PATCH 14/64] chore: move twoPow into bitwise --- src/Init/Data/BitVec/Basic.lean | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index 508ce9012475..1f66133b40e0 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -614,14 +614,14 @@ theorem ofBool_append (msb : Bool) (lsbs : BitVec w) : ofBool msb ++ lsbs = (cons msb lsbs).cast (Nat.add_comm ..) := rfl -end bitwise - -section twoPow - -/-- `twoPow i` is the bitvector `2^i` if `i < w`, and `0` otherwise. That is, 2 to the power `i`. -/ +/-- +`twoPow i` is the bitvector `2^i` if `i < w`, and `0` otherwise. +That is, 2 to the power `i`. +For the bitwise point of view, it has the `i`th bit as `1` and all other bits as `0`. +-/ def twoPow {w : Nat} (i : Nat) : BitVec w := (1#w) <<< i -end twoPow +end bitwise section normalization_eqs /-! We add simp-lemmas that rewrite bitvector operations into the equivalent notation -/ From 33be7e5c2c11cfe385c19626372d40fa9c44339b Mon Sep 17 00:00:00 2001 From: Siddharth Date: Sat, 8 Jun 2024 08:55:00 +0100 Subject: [PATCH 15/64] Apply suggestions from code review chore: remove parens around 0, 1 Co-authored-by: Tobias Grosser --- src/Init/Data/BitVec/Bitblast.lean | 2 +- src/Init/Data/BitVec/Lemmas.lean | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 80a0206f7001..bb36e31b3cd5 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -161,7 +161,7 @@ theorem allOnes_sub_eq_not (x : BitVec w) : allOnes w - x = ~~~x := by /-- Adding two bitvectors equals or-ing them if they are 1 in mutually exclusive locations -/ theorem add_eq_or_of_and_eq_zero {w : Nat} (x y : BitVec w) - (h : x &&& y = (0#w)) : x + y = x ||| y := by + (h : x &&& y = 0#w) : x + y = x ||| y := by rw [add_eq_adc, adc, iunfoldr_replace (fun _ => false) (x ||| y)] · rfl · simp [adcb, atLeastTwo, h] diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index a70186083f6e..e334cfd28108 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1133,7 +1133,7 @@ theorem ofInt_mul {n} (x y : Int) : BitVec.ofInt n (x * y) = simp @[simp] -theorem BitVec.mul_one {x : BitVec w} : x * (1#w) = x := by +theorem BitVec.mul_one {x : BitVec w} : x * 1#w = x := by apply eq_of_toNat_eq simp [toNat_mul, Nat.mod_eq_of_lt x.isLt] From 9bc8d7dd74b74571f2ed21a13d2d6a6a02ec2d46 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Sat, 8 Jun 2024 08:55:56 +0100 Subject: [PATCH 16/64] chore: delete more parens around 0#w, 1#w --- src/Init/Data/BitVec/Basic.lean | 2 +- src/Init/Data/BitVec/Lemmas.lean | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index 1f66133b40e0..ed209717b9ac 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -619,7 +619,7 @@ theorem ofBool_append (msb : Bool) (lsbs : BitVec w) : That is, 2 to the power `i`. For the bitwise point of view, it has the `i`th bit as `1` and all other bits as `0`. -/ -def twoPow {w : Nat} (i : Nat) : BitVec w := (1#w) <<< i +def twoPow {w : Nat} (i : Nat) : BitVec w := 1#w <<< i end bitwise diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index e334cfd28108..8dbb9d3804dd 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1138,7 +1138,7 @@ theorem BitVec.mul_one {x : BitVec w} : x * 1#w = x := by simp [toNat_mul, Nat.mod_eq_of_lt x.isLt] @[simp] -theorem BitVec.mul_zero {x : BitVec w} : x * (0#w) = (0#w) := by +theorem BitVec.mul_zero {x : BitVec w} : x * 0#w = 0#w := by apply eq_of_toNat_eq simp [toNat_mul] From fc6725610acd64fc023c75559915da1b3b09128e Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Sat, 8 Jun 2024 08:57:18 +0100 Subject: [PATCH 17/64] chore: drop newline --- src/Init/Data/BitVec/Lemmas.lean | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 8dbb9d3804dd..c3fbfe09f200 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -440,7 +440,6 @@ theorem zeroExtend_ofNat_one_eq_ofNat_one_of_lt {v w : Nat} (hv : 0 < v): have hv := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi1 omega - /-! ## extractLsb -/ @[simp] From 71a6c23dd2241573a185ebb2d38dfe1a2237e59f Mon Sep 17 00:00:00 2001 From: Siddharth Date: Sat, 8 Jun 2024 08:57:56 +0100 Subject: [PATCH 18/64] Apply suggestions from code review chore: period to end of docstring. Co-authored-by: Tobias Grosser --- src/Init/Data/BitVec/Bitblast.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index bb36e31b3cd5..6df167f16c34 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -159,7 +159,7 @@ theorem add_eq_adc (w : Nat) (x y : BitVec w) : x + y = (adc x y false).snd := b theorem allOnes_sub_eq_not (x : BitVec w) : allOnes w - x = ~~~x := by rw [← add_not_self x, BitVec.add_comm, add_sub_cancel] -/-- Adding two bitvectors equals or-ing them if they are 1 in mutually exclusive locations -/ +/-- Adding two bitvectors equals or-ing them if they are 1 in mutually exclusive locations. -/ theorem add_eq_or_of_and_eq_zero {w : Nat} (x y : BitVec w) (h : x &&& y = 0#w) : x + y = x ||| y := by rw [add_eq_adc, adc, iunfoldr_replace (fun _ => false) (x ||| y)] From c91edf1383ec14282cefa3b1a0f21a123c62903b Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Sat, 8 Jun 2024 09:10:33 +0100 Subject: [PATCH 19/64] chore: drop paren --- src/Init/Data/BitVec/Bitblast.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 6df167f16c34..56f73563df0e 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -289,7 +289,7 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w omega theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : - (mulRec l r s) = l * ((r.truncate (s + 1)).zeroExtend w) := by + mulRec l r s = l * ((r.truncate (s + 1)).zeroExtend w) := by induction s case zero => simp [mulRec_zero_eq] From ecde706a6ab1d433125b7fabd85806ff0c56061a Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Mon, 10 Jun 2024 13:36:55 +0100 Subject: [PATCH 20/64] wip: bitblast shifts --- src/Init/Data/BitVec/Basic.lean | 2 +- src/Init/Data/BitVec/Bitblast.lean | 145 ++++++++++++++++++++++++++++- src/Init/Data/BitVec/Lemmas.lean | 16 +++- 3 files changed, 153 insertions(+), 10 deletions(-) diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index ed209717b9ac..eb33296d1e26 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -619,7 +619,7 @@ theorem ofBool_append (msb : Bool) (lsbs : BitVec w) : That is, 2 to the power `i`. For the bitwise point of view, it has the `i`th bit as `1` and all other bits as `0`. -/ -def twoPow {w : Nat} (i : Nat) : BitVec w := 1#w <<< i +def twoPow (w : Nat) (i : Nat) : BitVec w := 1#w <<< i end bitwise diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 56f73563df0e..d58baace5328 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -265,11 +265,33 @@ theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) : mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := by simp [mulRec] +theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false + {x : BitVec w} {i : Nat} {hx : x.getLsb i = false} : + zeroExtend w (x.truncate (i + 1)) = + zeroExtend w (x.truncate i) := by + ext k + simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] + by_cases hik:i = k + · subst hik + simp [hx] + · by_cases hik' : k < i + 1 <;> simp [hik'] <;> omega + +theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true + (x : BitVec w) (i : Nat) (hx : x.getLsb i = true) : + zeroExtend w (x.truncate (i + 1)) = + zeroExtend w (x.truncate i) ||| (twoPow w i) := by + ext k + simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] + by_cases hik : i = k + · subst hik + simp [hx] + · by_cases hik' : k < i + 1 <;> simp [hik, hik'] <;> omega + /-- Recurrence lemma: truncating to `i+1` bits and then zero extending to `w` equals truncating upto `i` bits `[0..i-1]`, and then adding the `i`th bit of `x`. -/ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w) (i : Nat) : zeroExtend w (x.truncate (i + 1)) = - zeroExtend w (x.truncate i) + (x &&& twoPow i) := by + zeroExtend w (x.truncate i) + (x &&& twoPow w i) := by rw [add_eq_or_of_and_eq_zero] · ext k simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] @@ -285,8 +307,8 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w simp [hik', hik''] · ext k simp - intros h₁ _ _ _ - omega + intros h₁ h₂ + by_cases hi : x.getLsb i <;> simp [hi] <;> omega theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : mulRec l r s = l * ((r.truncate (s + 1)).zeroExtend w) := by @@ -304,7 +326,7 @@ theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : rw [mulRec_succ_eq, hs] have heq : (if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) = - (l * (r &&& (BitVec.twoPow (s' + 1)))) := by + (l * (r &&& (BitVec.twoPow w (s' + 1)))) := by simp only [ofNat_eq_ofNat, and_twoPow_eq_getLsb] by_cases hr : r.getLsb (s' + 1) <;> simp [hr] rw [heq, ← BitVec.mul_add, ← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow] @@ -327,5 +349,120 @@ theorem getLsb_mul (x y : BitVec w) (i : Nat) : (x * y).getLsb i = (mulRec x y w).getLsb i := by simp [mulRec_eq_mul_signExtend_truncate] rw [truncate, zeroExtend_of_ge (by omega), zeroExtend_eq_self] +/- ## Shift left for arbitrary bit width -/ + +@[simp] +theorem shiftLeft_zero (x : BitVec w) : x <<< 0 = x := by + simp [bv_toNat] + +@[simp] +theorem zero_shiftLeft (n : Nat) : (0#w) <<< n = 0 := by + simp [bv_toNat] + +@[simp] +theorem truncate_one_eq_ofBool_getLsb (x : BitVec w) : + x.truncate 1 = ofBool (x.getLsb 0) := by + ext i + simp [show i = 0 by omega] + +-- x << 3 = x << 2 << 1 +def shiftLeftRec (x : BitVec w) (y : BitVec w) (n : Nat) : BitVec w := + let shiftAmt := (y &&& (twoPow w n)) + match n with + | 0 => x <<< shiftAmt + | n + 1 => (shiftLeftRec x y n) <<< shiftAmt + +@[simp] +theorem shiftLeftRec_zero (x y : BitVec w) : + shiftLeftRec x y 0 = x <<< (y &&& twoPow w 0) := by + simp [shiftLeftRec] + +@[simp] +theorem shiftLeftRec_succ (x y : BitVec w) : + shiftLeftRec x y (n + 1) = + (shiftLeftRec x y n) <<< (y &&& twoPow w (n + 1)) := by + simp [shiftLeftRec] + +-- | TODO: should this be a simp-lemma? Probably not. +theorem shiftLeft_eq' (x y : BitVec w) : + x <<< y = x <<< y.toNat := by rfl + +-- | TODO: what to name these theorems? +@[simp] +theorem shiftLeft_zero' (x : BitVec w) : + x <<< (0#w) = x := by + simp [shiftLeft_eq'] + +@[simp] +theorem getLsb_ofNat_one (w i : Nat) : + (1#w).getLsb i = (decide (i = 0) && decide (i < w)) := by + rcases w with rfl | w + · simp; + · simp [getLsb] + by_cases hi : i = 0 + · simp [hi] + · simp [hi] + intros _; simp [testBit, shiftRight_eq_div_pow]; + suffices 1 / 2^i = 0 by simp [this] + apply Nat.div_eq_of_lt; + exact Nat.one_lt_two_pow_iff.mpr hi + +theorem shiftLeft'_shiftLeft' {x y z : BitVec w} : + x <<< y <<< z = x <<< (y.toNat + z.toNat) := by + simp [shiftLeft_eq', shiftLeft_shiftLeft] + +theorem shiftLeft_or_eq_shiftLeft_shiftLeft_of_and_eq_zero {x y z : BitVec w} + (h : y &&& z = 0#w) (h' : y.toNat + z.toNat < 2^w): + x <<< (y ||| z) = x <<< y <<< z := by + simp [← add_eq_or_of_and_eq_zero _ _ h, shiftLeft_eq', shiftLeft_shiftLeft, + toNat_add, Nat.mod_eq_of_lt h'] + +theorem shiftLeftRec_eq (x y : BitVec w) (n : Nat) (hn : n + 1 ≤ w) : + shiftLeftRec x y n = x <<< (y.truncate (n + 1)).zeroExtend w := by + induction n generalizing x y + case zero => + --- TODO: find the nicest way to state this. + simp only [shiftLeftRec_zero, twoPow_zero_eq_one, Nat.reduceAdd, truncate_one_eq_ofBool_getLsb] + congr + ext i + simp only [getLsb_and, getLsb_ofNat_one, Fin.is_lt, decide_True, Bool.and_true, + getLsb_zeroExtend, getLsb_ofBool, Bool.true_and] + by_cases h : ↑i = (0 : Nat) <;> simp [h] + case succ n ih => + simp + by_cases h : y.getLsb (n + 1) <;> simp [h] + · rw [ih] + rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true _ _ h] + rw [shiftLeft_or_eq_shiftLeft_shiftLeft_of_and_eq_zero] + · ext i <;> simp + · rcases w with rfl | w + · simp + /- disgusting proof. -/ + · apply Nat.add_lt_add_of_lt_of_le + · simp + rw [Nat.mod_eq_of_lt] + · apply Nat.lt_of_lt_of_le + · apply Nat.mod_lt + · apply Nat.pow_pos (by decide) + · apply Nat.pow_le_pow_of_le_right (by decide) (by omega) + · apply Nat.lt_of_lt_of_le + · apply Nat.mod_lt + · apply Nat.pow_pos (by decide) + · apply Nat.pow_le_pow_of_le_right (by decide) (by omega) + · simp + rw [Nat.mod_eq_of_lt (by apply Nat.pow_lt_pow_of_lt (by decide) (by omega))] + apply Nat.pow_le_pow_of_le (by decide) (by omega) + · omega + · rw [ih]; + rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)] + simp [h] + omega + +theorem shiftLeft_eq_shiftLeft_rec (x y : BitVec w) : + x <<< y = shiftLeftRec x y (w - 1) := by + rcases w with rfl | w + · apply Subsingleton.elim + · simp [shiftLeftRec_eq x y w (by omega)] + end BitVec diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index c3fbfe09f200..045d425b3387 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1377,7 +1377,7 @@ theorem getLsb_rotateRight {x : BitVec w} {r i : Nat} : /- ## twoPow -/ @[simp] -theorem toNat_twoPow (w : Nat) (i : Nat) : (twoPow i : BitVec w).toNat = 2^i % 2^w := by +theorem toNat_twoPow (w : Nat) (i : Nat) : (twoPow w i).toNat = 2^i % 2^w := by rcases w with rfl | w · simp [Nat.mod_one] · simp [twoPow, toNat_shiftLeft] @@ -1386,7 +1386,7 @@ theorem toNat_twoPow (w : Nat) (i : Nat) : (twoPow i : BitVec w).toNat = 2^i % 2 rw [Nat.shiftLeft_eq, Nat.one_mul] @[simp] -theorem getLsb_twoPow (i j : Nat) : (twoPow i : BitVec w).getLsb j = ((i < w) && (i = j)) := by +theorem getLsb_twoPow (i j : Nat) : (twoPow w i).getLsb j = ((i < w) && (i = j)) := by rcases w with rfl | w · simp only [twoPow, BitVec.reduceOfNat, Nat.zero_le, getLsb_ge, Bool.false_eq, Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] @@ -1405,15 +1405,16 @@ theorem getLsb_twoPow (i j : Nat) : (twoPow i : BitVec w).getLsb j = ((i < w) && simp at hi simp_all +@[simp] theorem and_twoPow_eq_getLsb (x : BitVec w) (i : Nat) : - x &&& (twoPow i : BitVec w) = if x.getLsb i then twoPow i else 0#w := by + x &&& (twoPow w i) = if x.getLsb i then twoPow w i else 0#w := by ext j simp only [getLsb_and, getLsb_twoPow] by_cases hj : i = j <;> by_cases hx : x.getLsb i <;> simp_all @[simp] theorem mul_twoPow_eq_shiftLeft (x : BitVec w) (i : Nat) : - x * (twoPow i : BitVec w) = x <<< i := by + x * (twoPow w i) = x <<< i := by apply eq_of_toNat_eq simp only [toNat_mul, toNat_twoPow, toNat_shiftLeft, Nat.shiftLeft_eq] by_cases hi : i < w @@ -1424,7 +1425,7 @@ theorem mul_twoPow_eq_shiftLeft (x : BitVec w) (i : Nat) : apply Nat.pow_dvd_pow 2 (by omega) simp [Nat.mul_mod, hpow] -theorem BitVec.toNat_twoPow (w : Nat) (i : Nat) : (twoPow i : BitVec w).toNat = 2^i % 2^w := by +theorem BitVec.toNat_twoPow (w : Nat) (i : Nat) : (twoPow w i).toNat = 2^i % 2^w := by rcases w with rfl | w · simp [Nat.mod_one] · simp [twoPow, toNat_shiftLeft] @@ -1433,4 +1434,9 @@ theorem BitVec.toNat_twoPow (w : Nat) (i : Nat) : (twoPow i : BitVec w).toNat = exact Nat.pow_lt_pow_of_lt (by omega) (by omega) simp [Nat.mod_eq_of_lt hone, Nat.shiftLeft_eq] +@[simp] +theorem twoPow_zero_eq_one (w : Nat) : twoPow w 0 = 1#w := by + apply eq_of_toNat_eq + simp + end BitVec From a61c43eab157f8bd6281f76618f77b3b6420e03a Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Sat, 15 Jun 2024 14:20:52 +0100 Subject: [PATCH 21/64] chore: start shift right --- src/Init/Data/BitVec/Bitblast.lean | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index d58baace5328..6d727c07ea94 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -365,7 +365,8 @@ theorem truncate_one_eq_ofBool_getLsb (x : BitVec w) : ext i simp [show i = 0 by omega] --- x << 3 = x << 2 << 1 +/-## shiftLeft recurrence -/ + def shiftLeftRec (x : BitVec w) (y : BitVec w) (n : Nat) : BitVec w := let shiftAmt := (y &&& (twoPow w n)) match n with @@ -465,4 +466,15 @@ theorem shiftLeft_eq_shiftLeft_rec (x y : BitVec w) : · simp [shiftLeftRec_eq x y w (by omega)] +/-## logical shift right recurrence -/ +def shiftRightRec (x : BitVec w) (y : BitVec w) (n : Nat) : BitVec w := + let shiftAmt := y &&& twoPow w n + match n with + | 0 => x >>> shiftAmt + | n + 1 => (shiftLeftRec x y n) >>> shiftAmt + +theorem shiftRightRec_eq (x y : BitVec w) (n : Nat) (hn : n + 1 ≤ w) : + shiftRightRec x y n = x >>> (y.truncate (n + 1)).zeroExtend w := by + sorry + end BitVec From 5cba64a247cf27f43fa906320e0e9e4627e9c609 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Wed, 26 Jun 2024 21:48:00 +0100 Subject: [PATCH 22/64] chore: update shift left to different widths --- src/Init/Data/BitVec/Bitblast.lean | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 6d727c07ea94..f8cdd34d9bb5 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -367,21 +367,21 @@ theorem truncate_one_eq_ofBool_getLsb (x : BitVec w) : /-## shiftLeft recurrence -/ -def shiftLeftRec (x : BitVec w) (y : BitVec w) (n : Nat) : BitVec w := - let shiftAmt := (y &&& (twoPow w n)) +def shiftLeftRec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ := + let shiftAmt := (y &&& (twoPow w₂ n)) match n with | 0 => x <<< shiftAmt | n + 1 => (shiftLeftRec x y n) <<< shiftAmt @[simp] -theorem shiftLeftRec_zero (x y : BitVec w) : - shiftLeftRec x y 0 = x <<< (y &&& twoPow w 0) := by +theorem shiftLeftRec_zero (x : BitVec w₁) (y : BitVec w₂) : + shiftLeftRec x y 0 = x <<< (y &&& twoPow w₂ 0) := by simp [shiftLeftRec] @[simp] -theorem shiftLeftRec_succ (x y : BitVec w) : +theorem shiftLeftRec_succ (x : BitVec w₁) (y : BitVec w₂) : shiftLeftRec x y (n + 1) = - (shiftLeftRec x y n) <<< (y &&& twoPow w (n + 1)) := by + (shiftLeftRec x y n) <<< (y &&& twoPow w₂ (n + 1)) := by simp [shiftLeftRec] -- | TODO: should this be a simp-lemma? Probably not. From 0b4a59f6d9e7d0f5ac62a5db597b2f16bc588f57 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 27 Jun 2024 02:46:47 +0100 Subject: [PATCH 23/64] chore: more bitblast --- src/Init/Data/BitVec/Bitblast.lean | 95 +++++++++++++++--------------- 1 file changed, 47 insertions(+), 48 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index f8cdd34d9bb5..6f32e1c895cf 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -267,8 +267,8 @@ theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) : theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false {x : BitVec w} {i : Nat} {hx : x.getLsb i = false} : - zeroExtend w (x.truncate (i + 1)) = - zeroExtend w (x.truncate i) := by + zeroExtend w₂ (x.truncate (i + 1)) = + zeroExtend w₂ (x.truncate i) := by ext k simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] by_cases hik:i = k @@ -278,8 +278,8 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true (x : BitVec w) (i : Nat) (hx : x.getLsb i = true) : - zeroExtend w (x.truncate (i + 1)) = - zeroExtend w (x.truncate i) ||| (twoPow w i) := by + zeroExtend w₂ (x.truncate (i + 1)) = + zeroExtend w₂ (x.truncate i) ||| (twoPow w₂ i) := by ext k simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] by_cases hik : i = k @@ -307,7 +307,6 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w simp [hik', hik''] · ext k simp - intros h₁ h₂ by_cases hi : x.getLsb i <;> simp [hi] <;> omega theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : @@ -385,13 +384,13 @@ theorem shiftLeftRec_succ (x : BitVec w₁) (y : BitVec w₂) : simp [shiftLeftRec] -- | TODO: should this be a simp-lemma? Probably not. -theorem shiftLeft_eq' (x y : BitVec w) : +theorem shiftLeft_eq' (x : BitVec w) (y : BitVec w₂) : x <<< y = x <<< y.toNat := by rfl -- | TODO: what to name these theorems? @[simp] theorem shiftLeft_zero' (x : BitVec w) : - x <<< (0#w) = x := by + x <<< (0#w₂) = x := by simp [shiftLeft_eq'] @[simp] @@ -412,52 +411,63 @@ theorem shiftLeft'_shiftLeft' {x y z : BitVec w} : x <<< y <<< z = x <<< (y.toNat + z.toNat) := by simp [shiftLeft_eq', shiftLeft_shiftLeft] -theorem shiftLeft_or_eq_shiftLeft_shiftLeft_of_and_eq_zero {x y z : BitVec w} - (h : y &&& z = 0#w) (h' : y.toNat + z.toNat < 2^w): +theorem shiftLeft_or_eq_shiftLeft_shiftLeft_of_and_eq_zero {x : BitVec w} {y z : BitVec w₂} + (h : y &&& z = 0#w₂) (h' : y.toNat + z.toNat < 2^w₂): x <<< (y ||| z) = x <<< y <<< z := by simp [← add_eq_or_of_and_eq_zero _ _ h, shiftLeft_eq', shiftLeft_shiftLeft, toNat_add, Nat.mod_eq_of_lt h'] -theorem shiftLeftRec_eq (x y : BitVec w) (n : Nat) (hn : n + 1 ≤ w) : - shiftLeftRec x y n = x <<< (y.truncate (n + 1)).zeroExtend w := by + +theorem getLsb_shiftLeft' (x : BitVec w) (y : BitVec w₂) (i : Nat) : + (x <<< y).getLsb i = (decide (i < w) && !decide (i < y.toNat) && x.getLsb (i - y.toNat)) := by + simp [shiftLeft_eq', getLsb_shiftLeft] + +theorem shiftLeftRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) (hn : n + 1 ≤ w₂) : + shiftLeftRec x y n = x <<< (y.truncate (n + 1)).zeroExtend w₂ := by induction n generalizing x y case zero => - --- TODO: find the nicest way to state this. - simp only [shiftLeftRec_zero, twoPow_zero_eq_one, Nat.reduceAdd, truncate_one_eq_ofBool_getLsb] - congr ext i - simp only [getLsb_and, getLsb_ofNat_one, Fin.is_lt, decide_True, Bool.and_true, - getLsb_zeroExtend, getLsb_ofBool, Bool.true_and] - by_cases h : ↑i = (0 : Nat) <;> simp [h] + simp only [shiftLeftRec_zero, twoPow_zero_eq_one, Nat.reduceAdd, truncate_one_eq_ofBool_getLsb] + have heq : (y &&& 1#w₂) = zeroExtend w₂ (ofBool (y.getLsb 0)) := by + ext i + by_cases h : (↑i : Nat) = 0 <;> simp [h, Bool.and_comm] + simp [heq] case succ n ih => simp by_cases h : y.getLsb (n + 1) <;> simp [h] - · rw [ih] + · rw [ih (hn := by omega)] rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true _ _ h] rw [shiftLeft_or_eq_shiftLeft_shiftLeft_of_and_eq_zero] - · ext i <;> simp - · rcases w with rfl | w - · simp - /- disgusting proof. -/ + · simp + · simp; + have hpow : 2 ^ (n + 1) < 2 ^ w₂ := by + apply Nat.pow_lt_pow_of_lt (by decide) (by omega) + have h₂ : 2 ^ (n + 1) % 2 ^ w₂ = 2 ^ (n + 1) := Nat.mod_eq_of_lt (by omega) + have h₁ : y.toNat % 2 ^ (n + 1) % 2 ^ w₂ = y.toNat % 2 ^ (n + 1) := by + apply Nat.mod_eq_of_lt + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) + apply Nat.mod_lt + apply Nat.pow_pos (by decide); omega + obtain h₁ : y.toNat % 2 ^ (n + 1) % 2 ^ w₂ = y.toNat % 2 ^ (n + 1) := by + apply Nat.mod_eq_of_lt + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) <;> omega + rw [h₁, h₂] + rcases w₂ with rfl | w₂ + · omega · apply Nat.add_lt_add_of_lt_of_le + · simp only [pow_eq, Nat.mul_eq, Nat.mul_one] + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) + · apply Nat.mod_lt + · apply Nat.pow_pos (by decide) + · apply Nat.pow_le_pow_of_le_right (by decide) (by omega) · simp - rw [Nat.mod_eq_of_lt] - · apply Nat.lt_of_lt_of_le - · apply Nat.mod_lt - · apply Nat.pow_pos (by decide) - · apply Nat.pow_le_pow_of_le_right (by decide) (by omega) - · apply Nat.lt_of_lt_of_le - · apply Nat.mod_lt - · apply Nat.pow_pos (by decide) - · apply Nat.pow_le_pow_of_le_right (by decide) (by omega) - · simp - rw [Nat.mod_eq_of_lt (by apply Nat.pow_lt_pow_of_lt (by decide) (by omega))] - apply Nat.pow_le_pow_of_le (by decide) (by omega) - · omega - · rw [ih]; + apply Nat.pow_le_pow_of_le_right (by decide) (by omega) + · rw [ih (hn := by omega)] rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)] simp [h] - omega + +/-- info: 'BitVec.shiftLeftRec_eq' depends on axioms: [propext, Quot.sound, Classical.choice] -/ +#guard_msgs in #print axioms shiftLeftRec_eq theorem shiftLeft_eq_shiftLeft_rec (x y : BitVec w) : x <<< y = shiftLeftRec x y (w - 1) := by @@ -466,15 +476,4 @@ theorem shiftLeft_eq_shiftLeft_rec (x y : BitVec w) : · simp [shiftLeftRec_eq x y w (by omega)] -/-## logical shift right recurrence -/ -def shiftRightRec (x : BitVec w) (y : BitVec w) (n : Nat) : BitVec w := - let shiftAmt := y &&& twoPow w n - match n with - | 0 => x >>> shiftAmt - | n + 1 => (shiftLeftRec x y n) >>> shiftAmt - -theorem shiftRightRec_eq (x y : BitVec w) (n : Nat) (hn : n + 1 ≤ w) : - shiftRightRec x y n = x >>> (y.truncate (n + 1)).zeroExtend w := by - sorry - end BitVec From cbf80f80a86af0cc2919b60fed599bf65848f689 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 27 Jun 2024 10:08:48 +0100 Subject: [PATCH 24/64] chore: fixup final theorem --- src/Init/Data/BitVec/Bitblast.lean | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 6f32e1c895cf..59eb42fa6c02 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -469,11 +469,10 @@ theorem shiftLeftRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) (hn : n + /-- info: 'BitVec.shiftLeftRec_eq' depends on axioms: [propext, Quot.sound, Classical.choice] -/ #guard_msgs in #print axioms shiftLeftRec_eq -theorem shiftLeft_eq_shiftLeft_rec (x y : BitVec w) : - x <<< y = shiftLeftRec x y (w - 1) := by - rcases w with rfl | w - · apply Subsingleton.elim - · simp [shiftLeftRec_eq x y w (by omega)] - +theorem shiftLeft_eq_shiftLeft_rec (x : BitVec ℘) (y : BitVec w₂) : + x <<< y = shiftLeftRec x y (w₂ - 1) := by + rcases w₂ with rfl | w₂ + · simp [of_length_zero] + · simp [shiftLeftRec_eq x y w₂ (by omega)] end BitVec From b14ba43d9f7787ee6eee9c390ed447d06161bfc7 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 27 Jun 2024 13:42:03 +0100 Subject: [PATCH 25/64] chore: shiftRight recurrence theorem --- src/Init/Data/BitVec/Bitblast.lean | 102 +++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 59eb42fa6c02..f87ced214770 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -475,4 +475,106 @@ theorem shiftLeft_eq_shiftLeft_rec (x : BitVec ℘) (y : BitVec w₂) : · simp [of_length_zero] · simp [shiftLeftRec_eq x y w₂ (by omega)] + +/-## sshiftRight recurrence -/ + +def sshiftRight_rec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ := + let shiftAmt := (y &&& (twoPow w₂ n)) + match n with + | 0 => x >>> shiftAmt + | n + 1 => (sshiftRight_rec x y n) >>> shiftAmt + +@[simp] +theorem sshiftRight_rec_zero (x : BitVec w₁) (y : BitVec w₂) : + sshiftRight_rec x y 0 = x >>> (y &&& twoPow w₂ 0) := by + simp [sshiftRight_rec] + +@[simp] +theorem sshiftRight_rec_succ (x : BitVec w₁) (y : BitVec w₂) : + sshiftRight_rec x y (n + 1) = + (sshiftRight_rec x y n) >>> (y &&& twoPow w₂ (n + 1)) := by + simp [sshiftRight_rec] + +-- | TODO: should this be a simp-lemma? Probably not. +theorem sshiftRight_eq' (x : BitVec w) (y : BitVec w₂) : + x >>> y = x >>> y.toNat := by rfl + + +@[simp] +theorem BitVec.sshiftRight_zero (x : BitVec w) : x >>> 0 = x := by + simp [bv_toNat] + +-- | TODO: what to name these theorems? +@[simp] +theorem sshiftRight_zero' (x : BitVec w) : + x >>> (0#w₂) = x := by + simp [sshiftRight_eq'] + +theorem sshiftRight'_sshiftRight' {x y z : BitVec w} : + x >>> y >>> z = x >>> (y.toNat + z.toNat) := by + simp [sshiftRight_eq', shiftRight_shiftRight] + +theorem sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero {x : BitVec w} {y z : BitVec w₂} + (h : y &&& z = 0#w₂) (h' : y.toNat + z.toNat < 2^w₂): + x >>> (y ||| z) = x >>> y >>> z := by + simp [← add_eq_or_of_and_eq_zero _ _ h, sshiftRight_eq', shiftRight_shiftRight, + toNat_add, Nat.mod_eq_of_lt h'] + +theorem getLsb_sshiftRight' (x : BitVec w) (y : BitVec w₂) (i : Nat) : + (x >>> y).getLsb i = x.getLsb (y.toNat + i) := by + simp [sshiftRight_eq', getLsb_sshiftRight] + +theorem sshiftRight_rec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) (hn : n + 1 ≤ w₂) : + sshiftRight_rec x y n = x >>> (y.truncate (n + 1)).zeroExtend w₂ := by + induction n generalizing x y + case zero => + ext i + simp only [sshiftRight_rec_zero, twoPow_zero_eq_one, Nat.reduceAdd, truncate_one_eq_ofBool_getLsb] + have heq : (y &&& 1#w₂) = zeroExtend w₂ (ofBool (y.getLsb 0)) := by + ext i + by_cases h : (↑i : Nat) = 0 <;> simp [h, Bool.and_comm] + simp [heq] + case succ n ih => + simp + by_cases h : y.getLsb (n + 1) <;> simp [h] + · rw [ih (hn := by omega)] + rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true _ _ h] + rw [sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero] + · simp + · simp; + have hpow : 2 ^ (n + 1) < 2 ^ w₂ := by + apply Nat.pow_lt_pow_of_lt (by decide) (by omega) + have h₂ : 2 ^ (n + 1) % 2 ^ w₂ = 2 ^ (n + 1) := Nat.mod_eq_of_lt (by omega) + have h₁ : y.toNat % 2 ^ (n + 1) % 2 ^ w₂ = y.toNat % 2 ^ (n + 1) := by + apply Nat.mod_eq_of_lt + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) + apply Nat.mod_lt + apply Nat.pow_pos (by decide); omega + obtain h₁ : y.toNat % 2 ^ (n + 1) % 2 ^ w₂ = y.toNat % 2 ^ (n + 1) := by + apply Nat.mod_eq_of_lt + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) <;> omega + rw [h₁, h₂] + rcases w₂ with rfl | w₂ + · omega + · apply Nat.add_lt_add_of_lt_of_le + · simp only [pow_eq, Nat.mul_eq, Nat.mul_one] + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) + · apply Nat.mod_lt + · apply Nat.pow_pos (by decide) + · apply Nat.pow_le_pow_of_le_right (by decide) (by omega) + · simp + apply Nat.pow_le_pow_of_le_right (by decide) (by omega) + · rw [ih (hn := by omega)] + rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)] + simp [h] + +/-- info: 'BitVec.sshiftRight_rec_eq' depends on axioms: [propext, Quot.sound, Classical.choice] -/ +#guard_msgs in #print axioms sshiftRight_rec_eq + +theorem shiftRight_eq_shiftRight_rec (x : BitVec ℘) (y : BitVec w₂) : + x >>> y = sshiftRight_rec x y (w₂ - 1) := by + rcases w₂ with rfl | w₂ + · simp [of_length_zero] + · simp [sshiftRight_rec_eq x y w₂ (by omega)] + end BitVec From 2e6ff6fad8fbf663809d1ed0b96714c14891ea87 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 27 Jun 2024 13:47:45 +0100 Subject: [PATCH 26/64] chore: that was logical shift right. Skip arithmetic shift right for now The arithmetic shift right lacks a `BitVec` implementation anyway, so we move on to div/rem. --- src/Init/Data/BitVec/Bitblast.lean | 57 +++++++++++++++--------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index f87ced214770..846b58a4dc7f 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -476,60 +476,60 @@ theorem shiftLeft_eq_shiftLeft_rec (x : BitVec ℘) (y : BitVec w₂) : · simp [shiftLeftRec_eq x y w₂ (by omega)] -/-## sshiftRight recurrence -/ +/-## (Arithmetic) sshiftRight recurrence -/ -def sshiftRight_rec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ := +def ushiftRight_rec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ := let shiftAmt := (y &&& (twoPow w₂ n)) match n with | 0 => x >>> shiftAmt - | n + 1 => (sshiftRight_rec x y n) >>> shiftAmt + | n + 1 => (ushiftRight_rec x y n) >>> shiftAmt @[simp] -theorem sshiftRight_rec_zero (x : BitVec w₁) (y : BitVec w₂) : - sshiftRight_rec x y 0 = x >>> (y &&& twoPow w₂ 0) := by - simp [sshiftRight_rec] +theorem ushiftRight_rec_zero (x : BitVec w₁) (y : BitVec w₂) : + ushiftRight_rec x y 0 = x >>> (y &&& twoPow w₂ 0) := by + simp [ushiftRight_rec] @[simp] -theorem sshiftRight_rec_succ (x : BitVec w₁) (y : BitVec w₂) : - sshiftRight_rec x y (n + 1) = - (sshiftRight_rec x y n) >>> (y &&& twoPow w₂ (n + 1)) := by - simp [sshiftRight_rec] +theorem ushiftRight_rec_succ (x : BitVec w₁) (y : BitVec w₂) : + ushiftRight_rec x y (n + 1) = + (ushiftRight_rec x y n) >>> (y &&& twoPow w₂ (n + 1)) := by + simp [ushiftRight_rec] -- | TODO: should this be a simp-lemma? Probably not. -theorem sshiftRight_eq' (x : BitVec w) (y : BitVec w₂) : +theorem ushiftRight_eq' (x : BitVec w) (y : BitVec w₂) : x >>> y = x >>> y.toNat := by rfl @[simp] -theorem BitVec.sshiftRight_zero (x : BitVec w) : x >>> 0 = x := by +theorem BitVec.ushiftRight_zero (x : BitVec w) : x >>> 0 = x := by simp [bv_toNat] -- | TODO: what to name these theorems? @[simp] -theorem sshiftRight_zero' (x : BitVec w) : +theorem ushiftRight_zero' (x : BitVec w) : x >>> (0#w₂) = x := by - simp [sshiftRight_eq'] + simp [ushiftRight_eq'] -theorem sshiftRight'_sshiftRight' {x y z : BitVec w} : +theorem ushiftRight'_ushiftRight' {x y z : BitVec w} : x >>> y >>> z = x >>> (y.toNat + z.toNat) := by - simp [sshiftRight_eq', shiftRight_shiftRight] + simp [ushiftRight_eq', shiftRight_shiftRight] -theorem sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero {x : BitVec w} {y z : BitVec w₂} +theorem ushiftRight_or_eq_ushiftRight_ushiftRight_of_and_eq_zero {x : BitVec w} {y z : BitVec w₂} (h : y &&& z = 0#w₂) (h' : y.toNat + z.toNat < 2^w₂): x >>> (y ||| z) = x >>> y >>> z := by - simp [← add_eq_or_of_and_eq_zero _ _ h, sshiftRight_eq', shiftRight_shiftRight, + simp [← add_eq_or_of_and_eq_zero _ _ h, ushiftRight_eq', shiftRight_shiftRight, toNat_add, Nat.mod_eq_of_lt h'] -theorem getLsb_sshiftRight' (x : BitVec w) (y : BitVec w₂) (i : Nat) : +theorem getLsb_ushiftRight' (x : BitVec w) (y : BitVec w₂) (i : Nat) : (x >>> y).getLsb i = x.getLsb (y.toNat + i) := by - simp [sshiftRight_eq', getLsb_sshiftRight] + simp [ushiftRight_eq', getLsb_ushiftRight] -theorem sshiftRight_rec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) (hn : n + 1 ≤ w₂) : - sshiftRight_rec x y n = x >>> (y.truncate (n + 1)).zeroExtend w₂ := by +theorem ushiftRight_rec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) (hn : n + 1 ≤ w₂) : + ushiftRight_rec x y n = x >>> (y.truncate (n + 1)).zeroExtend w₂ := by induction n generalizing x y case zero => ext i - simp only [sshiftRight_rec_zero, twoPow_zero_eq_one, Nat.reduceAdd, truncate_one_eq_ofBool_getLsb] + simp only [ushiftRight_rec_zero, twoPow_zero_eq_one, Nat.reduceAdd, truncate_one_eq_ofBool_getLsb] have heq : (y &&& 1#w₂) = zeroExtend w₂ (ofBool (y.getLsb 0)) := by ext i by_cases h : (↑i : Nat) = 0 <;> simp [h, Bool.and_comm] @@ -539,7 +539,7 @@ theorem sshiftRight_rec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) (hn : n by_cases h : y.getLsb (n + 1) <;> simp [h] · rw [ih (hn := by omega)] rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true _ _ h] - rw [sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero] + rw [ushiftRight_or_eq_ushiftRight_ushiftRight_of_and_eq_zero] · simp · simp; have hpow : 2 ^ (n + 1) < 2 ^ w₂ := by @@ -568,13 +568,14 @@ theorem sshiftRight_rec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) (hn : n rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)] simp [h] -/-- info: 'BitVec.sshiftRight_rec_eq' depends on axioms: [propext, Quot.sound, Classical.choice] -/ -#guard_msgs in #print axioms sshiftRight_rec_eq +/-- info: 'BitVec.ushiftRight_rec_eq' depends on axioms: [propext, Quot.sound, Classical.choice] -/ +#guard_msgs in #print axioms ushiftRight_rec_eq theorem shiftRight_eq_shiftRight_rec (x : BitVec ℘) (y : BitVec w₂) : - x >>> y = sshiftRight_rec x y (w₂ - 1) := by + x >>> y = ushiftRight_rec x y (w₂ - 1) := by rcases w₂ with rfl | w₂ · simp [of_length_zero] - · simp [sshiftRight_rec_eq x y w₂ (by omega)] + · simp [ushiftRight_rec_eq x y w₂ (by omega)] + end BitVec From af5b90d51c15f02bfddb6ef1d2ba1762459871c1 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 28 Jun 2024 17:30:49 +0100 Subject: [PATCH 27/64] feat: characterize div and mod via arithmetic --- src/Init/Data/BitVec/Bitblast.lean | 75 ++++++++++++++++++++++++++++++ src/Init/Data/BitVec/Lemmas.lean | 22 +++++++++ 2 files changed, 97 insertions(+) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 846b58a4dc7f..3e3f60425344 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -577,5 +577,80 @@ theorem shiftRight_eq_shiftRight_rec (x : BitVec ℘) (y : BitVec w₂) : · simp [of_length_zero] · simp [ushiftRight_rec_eq x y w₂ (by omega)] +/- ## udiv/urem bitblasting -/ + +/- +r = n - d * q +r = n - d * (∑ i, 2^i * q.getLsb i) + +-/ + +/-- TODO: This theorem surely exists somewhere. -/ +theorem Nat.div_add_eq_left_of_lt {x y z : Nat} (hx : z ∣ x) (hy : y < z) (hz : 0 < z): + (x + y) / z = x / z := by + refine Nat.div_eq_of_lt_le ?lo ?hi + · apply Nat.le_trans + · exact div_mul_le_self x z + · omega + · simp only [succ_eq_add_one, Nat.add_mul, Nat.one_mul] + apply Nat.add_lt_add_of_le_of_lt + · apply Nat.le_of_eq + exact (Nat.div_eq_iff_eq_mul_left hz hx).mp rfl + · exact hy + +theorem div_characterized {d n q r : BitVec w} {hd : 0 < d} + (hrd : r < d) + (hdqnr : d.toNat * q.toNat + r.toNat = n.toNat) : + (n.udiv d = q ∧ n.umod d = r) := by + constructor + · apply BitVec.eq_of_toNat_eq + rw [toNat_udiv hd] + replace hdqnr : (d.toNat * q.toNat + r.toNat) / d.toNat = n.toNat / d.toNat := by + simp [hdqnr] + rw [Nat.div_add_eq_left_of_lt] at hdqnr + · rw [← hdqnr] + exact mul_div_right q.toNat hd + · exact Nat.dvd_mul_right d.toNat q.toNat + · exact hrd + · exact hd + · apply BitVec.eq_of_toNat_eq + rw [toNat_umod] + replace hdqnr : (d.toNat * q.toNat + r.toNat) % d.toNat = n.toNat % d.toNat := by + simp [hdqnr] + rw [Nat.add_mod, Nat.mul_mod_right] at hdqnr + simp at hdqnr + replace hrd : r.toNat < d.toNat := by + rw [BitVec.lt_def] at hrd + exact hrd -- TODO: golf + rw [Nat.mod_eq_of_lt hrd] at hdqnr + simp [hdqnr] + +theorem div_characterized' {d n q r : BitVec w} {hd : 0 < d} + (hqr : n.udiv d = q ∧ n.umod d = r) : + (d.toNat * q.toNat + r.toNat = n.toNat) := by + obtain ⟨hq, hr⟩ := hqr + have hdiv : n.toNat / d.toNat = q.toNat := by + rw [← toNat_udiv hd] -- TODO: squeeze + rw [(toNat_eq _ _).mp hq] + have hmod : n.toNat % d.toNat = r.toNat := by + rw [← toNat_umod] -- TODO: squeeze + rw [(toNat_eq _ _).mp hr] + rw [← hdiv, ← hmod] -- TODO: flip + rw [div_add_mod] + +/- Given d, R(j + 1), (calculate R(j), q.getLsb j)-/ +-- def divremi (d : BitVec w) (rjsucc : BitVec w) (j : Nat) : BitVec w × Bool := +-- -- optimistically assume (q.getLsb j = 1) and perform the subtraction. +-- let rj? := rjsucc - d * twoPow w j +-- if rj? ≥ 0 -- yay, this subtraction is allowed. +-- then (rj?, true) -- confirm the results. +-- else (rjsucc, false) -- discard the results. + +-- def divrem_rec (d : BitVec w) (n : BitVec w) (j : Nat) : BitVec w × BitVec w := +-- match j with +-- | 0 => divremi d n 0 +-- | j + 1 => +-- let (b, rj') := divrem_rec d n j +-- divremi d (if b then rj' else n) (j + 1) end BitVec diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 045d425b3387..9900628ccee4 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -772,6 +772,28 @@ theorem udiv_eq {x y : BitVec n} : rw [Nat.mod_eq_of_lt] exact Nat.lt_of_le_of_lt (Nat.div_le_self ..) (by omega) +theorem toNat_udiv {x y : BitVec n} (hy : 0 < y): + (x.udiv y).toNat = x.toNat / y.toNat := by + rw [udiv_eq] + simp only [toNat_ofNat] + rw [Nat.mod_eq_of_lt] + rw [Nat.div_lt_iff_lt_mul hy] + apply Nat.lt_of_lt_of_le x.isLt + apply Nat.le_mul_of_pos_right _ hy + +/-! ### umod -/ + +theorem umod_eq {x y : BitVec n} : + x.umod y = BitVec.ofNat n (x.toNat % y.toNat) := by + apply BitVec.eq_of_toNat_eq + simp only [umod, toNat_ofNatLt, toNat_ofNat] + rw [Nat.mod_eq_of_lt (b := 2^n)] + apply Nat.lt_of_le_of_lt (Nat.mod_le _ _) x.isLt + +@[simp] +theorem toNat_umod {x y : BitVec n} : + (x.umod y).toNat = x.toNat % y.toNat := by rfl + /-! ### append -/ theorem append_def (x : BitVec v) (y : BitVec w) : From f967caf541a2aad3f9b74d1eaf644f3754fc1380 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 28 Jun 2024 17:31:20 +0100 Subject: [PATCH 28/64] chore: cleanup hypotheses --- src/Init/Data/BitVec/Bitblast.lean | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 3e3f60425344..034213ee09d7 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -626,9 +626,8 @@ theorem div_characterized {d n q r : BitVec w} {hd : 0 < d} simp [hdqnr] theorem div_characterized' {d n q r : BitVec w} {hd : 0 < d} - (hqr : n.udiv d = q ∧ n.umod d = r) : + (hq : n.udiv d = q) (hr : n.umod d = r) : (d.toNat * q.toNat + r.toNat = n.toNat) := by - obtain ⟨hq, hr⟩ := hqr have hdiv : n.toNat / d.toNat = q.toNat := by rw [← toNat_udiv hd] -- TODO: squeeze rw [(toNat_eq _ _).mp hq] @@ -638,6 +637,7 @@ theorem div_characterized' {d n q r : BitVec w} {hd : 0 < d} rw [← hdiv, ← hmod] -- TODO: flip rw [div_add_mod] + /- Given d, R(j + 1), (calculate R(j), q.getLsb j)-/ -- def divremi (d : BitVec w) (rjsucc : BitVec w) (j : Nat) : BitVec w × Bool := -- -- optimistically assume (q.getLsb j = 1) and perform the subtraction. From b9f78b5f1a7d77322132673fbbac1be8b6f90490 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 28 Jun 2024 17:54:32 +0100 Subject: [PATCH 29/64] chore: prove another equivalence given an lt hypothesis --- src/Init/Data/BitVec/Bitblast.lean | 36 ++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 034213ee09d7..e21f00a566fc 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -585,6 +585,23 @@ r = n - d * (∑ i, 2^i * q.getLsb i) -/ + +/-! +Let us study an instructive counterexample to the claim that + `n = d * q + r` for (`0 ≤ r < d`) uniquely determining q and r *over bitvectors*. + +- Let `bitwidth = 3` +- Let `n = 0, d = 3` +- If we choose `q = 2, r = 2`, then d * q + r = 6 + 2 = 8 ≃ 0 (mod 8) so satisfies. +- But see that `q = 0, r = 0` also satisfies, as 0 * 3 + 0 = 0. +- So for (`n = 0, d = 3`), both: + `q = 2, r = 2` as well as + `q = 0, r = 0` are solutions! + +It's easy to cook up such examples, by chosing `(q, r)` for a fixed `(d, n)` +such that `(d * q + r)` overflows. +-/ + /-- TODO: This theorem surely exists somewhere. -/ theorem Nat.div_add_eq_left_of_lt {x y z : Nat} (hx : z ∣ x) (hy : y < z) (hz : 0 < z): (x + y) / z = x / z := by @@ -598,7 +615,7 @@ theorem Nat.div_add_eq_left_of_lt {x y z : Nat} (hx : z ∣ x) (hy : y < z) (hz exact (Nat.div_eq_iff_eq_mul_left hz hx).mp rfl · exact hy -theorem div_characterized {d n q r : BitVec w} {hd : 0 < d} +theorem div_characterized_of_mul_add_toNat {d n q r : BitVec w} {hd : 0 < d} (hrd : r < d) (hdqnr : d.toNat * q.toNat + r.toNat = n.toNat) : (n.udiv d = q ∧ n.umod d = r) := by @@ -625,7 +642,22 @@ theorem div_characterized {d n q r : BitVec w} {hd : 0 < d} rw [Nat.mod_eq_of_lt hrd] at hdqnr simp [hdqnr] -theorem div_characterized' {d n q r : BitVec w} {hd : 0 < d} +theorem div_characterized_of_mul_add_of_lt {d n q r : BitVec w} {hd : 0 < d} + (hrd : r < d) + (hdqnr : d * q + r = n) + (hlt : d.toNat * q.toNat + r.toNat < 2^w) : + (n.udiv d = q ∧ n.umod d = r) := by + apply div_characterized_of_mul_add_toNat <;> try assumption + apply Eq.symm + have hlt' : d.toNat * q.toNat < 2^w := by omega + calc + n.toNat = (d * q + r).toNat := by rw [← hdqnr] + _ = ((d * q).toNat + r.toNat) % 2^w := by simp [BitVec.toNat_add] + _ = ((d.toNat * q.toNat) % 2^w + r.toNat) % 2^w := by simp [BitVec.toNat_mul] + _ = ((d.toNat * q.toNat) + r.toNat) % 2^w := by simp [Nat.mod_eq_of_lt hlt'] + _ = ((d.toNat * q.toNat) + r.toNat) := by simp [Nat.mod_eq_of_lt hlt] + +theorem div_characterized_toNat_of_eq_udiv_of_eq_umod {d n q r : BitVec w} {hd : 0 < d} (hq : n.udiv d = q) (hr : n.umod d = r) : (d.toNat * q.toNat + r.toNat = n.toNat) := by have hdiv : n.toNat / d.toNat = q.toNat := by From 904a01bd30c7336000e83ede4a4057150ced14e4 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 28 Jun 2024 18:01:52 +0100 Subject: [PATCH 30/64] chore: prove necessary equivalence with sufficiently weak hypothesis to embark on divrem quest --- src/Init/Data/BitVec/Bitblast.lean | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index e21f00a566fc..aae8b1241f5a 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -615,7 +615,7 @@ theorem Nat.div_add_eq_left_of_lt {x y z : Nat} (hx : z ∣ x) (hy : y < z) (hz exact (Nat.div_eq_iff_eq_mul_left hz hx).mp rfl · exact hy -theorem div_characterized_of_mul_add_toNat {d n q r : BitVec w} {hd : 0 < d} +theorem div_characterized_of_mul_add_toNat {d n q r : BitVec w} (hd : 0 < d) (hrd : r < d) (hdqnr : d.toNat * q.toNat + r.toNat = n.toNat) : (n.udiv d = q ∧ n.umod d = r) := by @@ -642,7 +642,7 @@ theorem div_characterized_of_mul_add_toNat {d n q r : BitVec w} {hd : 0 < d} rw [Nat.mod_eq_of_lt hrd] at hdqnr simp [hdqnr] -theorem div_characterized_of_mul_add_of_lt {d n q r : BitVec w} {hd : 0 < d} +theorem div_characterized_of_mul_add_of_lt {d n q r : BitVec w} (hd : 0 < d) (hrd : r < d) (hdqnr : d * q + r = n) (hlt : d.toNat * q.toNat + r.toNat < 2^w) : @@ -657,7 +657,7 @@ theorem div_characterized_of_mul_add_of_lt {d n q r : BitVec w} {hd : 0 < d} _ = ((d.toNat * q.toNat) + r.toNat) % 2^w := by simp [Nat.mod_eq_of_lt hlt'] _ = ((d.toNat * q.toNat) + r.toNat) := by simp [Nat.mod_eq_of_lt hlt] -theorem div_characterized_toNat_of_eq_udiv_of_eq_umod {d n q r : BitVec w} {hd : 0 < d} +theorem div_characterized_toNat_of_eq_udiv_of_eq_umod {d n q r : BitVec w} (hd : 0 < d) (hq : n.udiv d = q) (hr : n.umod d = r) : (d.toNat * q.toNat + r.toNat = n.toNat) := by have hdiv : n.toNat / d.toNat = q.toNat := by @@ -669,6 +669,24 @@ theorem div_characterized_toNat_of_eq_udiv_of_eq_umod {d n q r : BitVec w} {hd : rw [← hdiv, ← hmod] -- TODO: flip rw [div_add_mod] +theorem div_characterized_toNat_of_eq_udiv_of_eq_umod_of_lt {d n q r : BitVec w} (hd : 0 < d) + (hq : n.udiv d = q) (hr : n.umod d = r) + (hlt : d.toNat * q.toNat + r.toNat < 2^w) : + d * q + r = n := by + apply eq_of_toNat_eq + simp [toNat_add, toNat_mul] + rw [Nat.mod_eq_of_lt hlt] + apply div_characterized_toNat_of_eq_udiv_of_eq_umod hd hq hr + +theorem div_iff_add_mod_of_lt {d n q r : BitVec w} (hd : 0 < d) + (hrd : r < d) + (hlt : d.toNat * q.toNat + r.toNat < 2^w) : + (n.udiv d = q ∧ n.umod d = r) ↔ (d * q + r = n) := by + constructor + · intros h; obtain ⟨h₁, h₂⟩ := h + apply div_characterized_toNat_of_eq_udiv_of_eq_umod_of_lt <;> assumption + · intros h + apply div_characterized_of_mul_add_of_lt <;> assumption /- Given d, R(j + 1), (calculate R(j), q.getLsb j)-/ -- def divremi (d : BitVec w) (rjsucc : BitVec w) (j : Nat) : BitVec w × Bool := From 36711b01994bdd48468566af296ceb0935fc17a2 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 28 Jun 2024 19:33:46 +0100 Subject: [PATCH 31/64] chore: write theorem statement for DivRem recurrence --- src/Init/Data/BitVec/Bitblast.lean | 126 +++++++++++++++++++++++++++-- 1 file changed, 119 insertions(+), 7 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index aae8b1241f5a..7a5fac0fd97e 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -602,6 +602,13 @@ It's easy to cook up such examples, by chosing `(q, r)` for a fixed `(d, n)` such that `(d * q + r)` overflows. -/ +/-! +References: +- Fast 32-bit Division on the DSP56800E: Minimized nonrestoring division algorithm by David Baca +- Bitwuzla sources for bitblasting.h +-/ + + /-- TODO: This theorem surely exists somewhere. -/ theorem Nat.div_add_eq_left_of_lt {x y z : Nat} (hx : z ∣ x) (hy : y < z) (hz : 0 < z): (x + y) / z = x / z := by @@ -688,13 +695,118 @@ theorem div_iff_add_mod_of_lt {d n q r : BitVec w} (hd : 0 < d) · intros h apply div_characterized_of_mul_add_of_lt <;> assumption -/- Given d, R(j + 1), (calculate R(j), q.getLsb j)-/ --- def divremi (d : BitVec w) (rjsucc : BitVec w) (j : Nat) : BitVec w × Bool := --- -- optimistically assume (q.getLsb j = 1) and perform the subtraction. --- let rj? := rjsucc - d * twoPow w j --- if rj? ≥ 0 -- yay, this subtraction is allowed. --- then (rj?, true) -- confirm the results. --- else (rjsucc, false) -- discard the results. +/- # Division Recurrence for Bitblasting -/ + +/-- A bundle of the quotient and remainder for the intermediate steps when computing n.div d -/ +structure DivRecQuotRem (w : Nat) (n : BitVec w) (d : BitVec w) where + r : BitVec w + q : BitVec w + + +/- Given d, R(j + 1), (calculate R(j), q.getLsb j). -/ +def divremi (qr : DivRecQuotRem w n d) (j : Nat) : BitVec w × Bool := + if d * twoPow w j ≤ qr.r then + let rj := qr.r - d * twoPow w j -- remainder is legal since it's positive, accept it. + (rj, true) + else + (qr.r, false) -- remainder is illegal, so quotient must be '0' at this bit. + +theorem divremi_eq_of_le (qr : DivRecQuotRem w n d) (h : d <<< j ≤ qr.r) : + divremi qr j = (qr.r - d * twoPow w j, true) := by + simp [divremi, h] + +theorem divremi_eq_of_not_le (qr : DivRecQuotRem w n d) (h : ¬ d <<< j ≤ qr.r) : + divremi qr j = (qr.r, false) := by + simp [divremi, h] + +def DivRecQuotRem.Lawful {n d : BitVec w} (qr : DivRecQuotRem w n d) : Prop := + d.toNat * qr.q.toNat + qr.r.toNat = n.toNat + +theorem DivRecQuotRem.Lawful.def {n d : BitVec w} {qr : DivRecQuotRem w n d} + (h : qr.Lawful) : d.toNat * qr.q.toNat + qr.r.toNat = n.toNat := by + simp [DivRecQuotRem.Lawful] at h + assumption + +def DivRecQuotRem.initialize (n d : BitVec w) : DivRecQuotRem w n d := + { r := n, q := 0 } + +theorem DivRecQuotRem.lawful_initialize {n d : BitVec w} : + (DivRecQuotRem.initialize n d).Lawful := by + simp [DivRecQuotRem.Lawful, DivRecQuotRem.initialize] + +/-- Recurrence for division in terms of `divremi`. -/ +def divRec (qr : DivRecQuotRem w n d) (j : Nat) : DivRecQuotRem w n d := + let (r, qj) := divremi qr j + let q := setBit qr.q j qj + match j with + | 0 => { r := r, q := q } + | j + 1 => divRec { r := r, q := q } j +where + /-- Set the `i`th bit of `v` to `b`.-/ + setBit (v : BitVec w) (i : Nat) (b : Bool) := + if b then v ||| twoPow w i else v + +@[simp] +theorem divRec_zero {qr : DivRecQuotRem w n d} : + divRec qr 0 = + { r := (divremi qr 0).fst , q := divRec.setBit qr.q 0 (divremi qr 0).snd } := by + unfold divRec + simp + +@[simp] +theorem divRec_succ {qr : DivRecQuotRem w n d} {j : Nat} : + divRec qr (j + 1) = + divRec { + r := (divremi qr (j + 1) |>.fst), + q := (divRec.setBit qr.q (j + 1) (divremi qr (j + 1) |>.snd)) + } j := by + conv => + lhs + unfold divRec + +/-- +Clear all low bits in the range `(j..0]`, +keeping high bits in the range `[w..j]` +-/ +abbrev clearLowBitsAfter (x : BitVec w) (j : Nat) : BitVec w := (x >>> j) <<< j + +@[simp] +theorem getLsb_clearLowBitsAfter {x : BitVec w} {j i : Nat} : + (clearLowBitsAfter x j).getLsb i = if j ≤ i then x.getLsb i else false := by + unfold clearLowBitsAfter + simp + by_cases hij : i < j + · simp only [hij, decide_True, Bool.not_true, Bool.and_false, Bool.false_and, false_eq, + and_eq_false_imp, decide_eq_true_eq] + intros hcontra + omega + · simp [hij, show j + (i - j) = i by omega, show j ≤ i by omega] + apply lt_of_getLsb + +-- d * j + q = n. +theorem divRec_lawful {qr : DivRecQuotRem w n d} {j : Nat} + (hqr : qr.Lawful) : (divRec qr j).Lawful := by sorry + +theorem divRec_remainder_inbounds {qr : DivRecQuotRem w n d} {j : Nat} + (hqr : qr.Lawful) : + (divRec qr j).r < d := by sorry + +theorem divRec_eq_div_of_lawful {qr : DivRecQuotRem w n d} (hd : 0 < d) + {hqr' : qr' = (divRec (DivRecQuotRem.initialize n d) w)} : + qr'.q = udiv n d := by + have hlawful : qr'.Lawful := by + rw [hqr'] + apply divRec_lawful + apply DivRecQuotRem.lawful_initialize + have hremainder : qr'.r < d := by + rw [hqr'] + apply divRec_remainder_inbounds + apply DivRecQuotRem.lawful_initialize + have this := div_characterized_of_mul_add_toNat + (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful.def + simp [this.1] + +-- theorem div_rec_corret {d n q j : BitVec w} -- def divrem_rec (d : BitVec w) (n : BitVec w) (j : Nat) : BitVec w × BitVec w := -- match j with From f267fe3ad52cbe8e93d24ca0a7f47aa9e978b9ea Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 28 Jun 2024 19:37:41 +0100 Subject: [PATCH 32/64] chore: cleanup theorem statement to be cripser --- src/Init/Data/BitVec/Bitblast.lean | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 7a5fac0fd97e..e4afca024600 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -791,28 +791,13 @@ theorem divRec_remainder_inbounds {qr : DivRecQuotRem w n d} {j : Nat} (hqr : qr.Lawful) : (divRec qr j).r < d := by sorry -theorem divRec_eq_div_of_lawful {qr : DivRecQuotRem w n d} (hd : 0 < d) - {hqr' : qr' = (divRec (DivRecQuotRem.initialize n d) w)} : +theorem divRec_eq_div_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd : 0 < d) + {hqr' : qr' = (divRec qr w)} : qr'.q = udiv n d := by - have hlawful : qr'.Lawful := by - rw [hqr'] - apply divRec_lawful - apply DivRecQuotRem.lawful_initialize - have hremainder : qr'.r < d := by - rw [hqr'] - apply divRec_remainder_inbounds - apply DivRecQuotRem.lawful_initialize + have hlawful : qr'.Lawful := by simp [hqr', divRec_lawful hqr] + have hremainder : qr'.r < d := by simp [hqr', divRec_remainder_inbounds hqr] have this := div_characterized_of_mul_add_toNat (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful.def simp [this.1] --- theorem div_rec_corret {d n q j : BitVec w} - --- def divrem_rec (d : BitVec w) (n : BitVec w) (j : Nat) : BitVec w × BitVec w := --- match j with --- | 0 => divremi d n 0 --- | j + 1 => --- let (b, rj') := divrem_rec d n j --- divremi d (if b then rj' else n) (j + 1) - end BitVec From 0ed0d4487c2d306f0032be8fbe88a83355a39782 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 28 Jun 2024 20:30:23 +0100 Subject: [PATCH 33/64] chore: write proof sketch of why the div, rem sorries are true. Now fill them up! --- src/Init/Data/BitVec/Bitblast.lean | 145 +++++++++++++++++++++++++++-- 1 file changed, 138 insertions(+), 7 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index e4afca024600..825d56c785db 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -702,7 +702,6 @@ structure DivRecQuotRem (w : Nat) (n : BitVec w) (d : BitVec w) where r : BitVec w q : BitVec w - /- Given d, R(j + 1), (calculate R(j), q.getLsb j). -/ def divremi (qr : DivRecQuotRem w n d) (j : Nat) : BitVec w × Bool := if d * twoPow w j ≤ qr.r then @@ -711,11 +710,11 @@ def divremi (qr : DivRecQuotRem w n d) (j : Nat) : BitVec w × Bool := else (qr.r, false) -- remainder is illegal, so quotient must be '0' at this bit. -theorem divremi_eq_of_le (qr : DivRecQuotRem w n d) (h : d <<< j ≤ qr.r) : +theorem divremi_eq_of_le {qr : DivRecQuotRem w n d} (h : d <<< j ≤ qr.r) : divremi qr j = (qr.r - d * twoPow w j, true) := by simp [divremi, h] -theorem divremi_eq_of_not_le (qr : DivRecQuotRem w n d) (h : ¬ d <<< j ≤ qr.r) : +theorem divremi_eq_of_not_le {qr : DivRecQuotRem w n d} (h : ¬ d <<< j ≤ qr.r) : divremi qr j = (qr.r, false) := by simp [divremi, h] @@ -727,6 +726,7 @@ theorem DivRecQuotRem.Lawful.def {n d : BitVec w} {qr : DivRecQuotRem w n d} simp [DivRecQuotRem.Lawful] at h assumption + def DivRecQuotRem.initialize (n d : BitVec w) : DivRecQuotRem w n d := { r := n, q := 0 } @@ -764,7 +764,17 @@ theorem divRec_succ {qr : DivRecQuotRem w n d} {j : Nat} : lhs unfold divRec -/-- +@[simp] +theorem divRec.setBit_false {v : BitVec w} {i : Nat} : + divRec.setBit v i false = v := by + simp [divRec.setBit] + +@[simp] +theorem divRec.setBit_true {v : BitVec w} {i : Nat} : + divRec.setBit v i true = v ||| twoPow w i := by + simp [divRec.setBit] + +/- Clear all low bits in the range `(j..0]`, keeping high bits in the range `[w..j]` -/ @@ -785,11 +795,131 @@ theorem getLsb_clearLowBitsAfter {x : BitVec w} {j i : Nat} : -- d * j + q = n. theorem divRec_lawful {qr : DivRecQuotRem w n d} {j : Nat} - (hqr : qr.Lawful) : (divRec qr j).Lawful := by sorry + (hqr : qr.Lawful) (hq : ∀ {i : Nat} (hi : i ≤ j), qr.q.getLsb j = false): (divRec qr j).Lawful := by + induction j generalizing qr + case zero => + /- + w : Nat + n d : BitVec w + qr : DivRecQuotRem w n d + hqr : qr.Lawful + ⊢ d.toNat * (divRec.setBit qr.q 0 (divremi qr 0).snd).toNat + (divremi qr 0).fst.toNat = n.toNat + -/ + simp [DivRecQuotRem.Lawful] + by_cases h : d <<< 0 ≤ qr.r + · simp [divremi_eq_of_le h] + simp at h + rcases w with rfl | w + · simp -- get rid of corner case with 1 % 2^0 = 1 % 1 = 0 + · have h1 : 1 % 2^(w + 1) = 1 := by + sorry + -- TODO: refactor instead of having toNat goal, have two goals, + -- one as bitvec equality, other as nat bound. + simp [h1] + specialize (hq (i := 0) (by omega)) + /- + case pos.succ + w : Nat + n d : BitVec (w + 1) + qr : DivRecQuotRem (w + 1) n d + hqr : qr.Lawful + hq : ∀ {i : Nat}, i ≤ 0 → qr.q.getLsb 0 = false + h : d ≤ qr.r + ⊢ d.toNat * (qr.q.toNat ||| 1 % 2 ^ (w + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) = n.toNat + -/ + -- this is equal, since it simplifies to: + -- (d * (q ||| 1) + (r - 2^(w + 1) - d)) % (2^(w + 1)) = n + -- = (d * (q + 1) + (r - 2^(w + 1) - d)) % (2^(w + 1)) = n [q, 1 are mutex by hq] + -- = (d * q + d + (r + (2^(w + 1) - d))) % 2 ^(w + 1) = n + -- now rearrange, and cancel the d by using (h : d ≤ r): + -- = (d * q + d + -d + (r + (2^(w + 1))) % 2 ^(w + 1) = n + -- = (d * q + (r + (2^(w + 1))) % 2 ^(w + 1) = n + -- Now cancel the modulo, giving us + -- d * q + r =indeed= n by [lawful]. + sorry -- HERE + · simp [divremi_eq_of_not_le h] + simp at h + exact hqr.def + case succ j ih => + simp + -- TODO: split below into a separate lemma + by_cases h : d <<< (j + 1) ≤ qr.r + · simp [divremi_eq_of_le h] + simp at h + sorry -- NEXT + · simp [divremi_eq_of_not_le h] + simp at h + apply ih hqr + theorem divRec_remainder_inbounds {qr : DivRecQuotRem w n d} {j : Nat} - (hqr : qr.Lawful) : - (divRec qr j).r < d := by sorry + (hqr : qr.Lawful) (hr : qr.r.toNat < 2^j) (hd : 0 < d): + (divRec qr j).r < d := by + induction j generalizing qr + case zero => + simp + by_cases h : d <<< 0 ≤ qr.r + · simp [divremi_eq_of_le h] + simp at h + simp at hr + simp [lt_def] + simp [show qr.r.toNat = 0 by omega] + simp [le_def] at h + have hd' : d.toNat = 0 := by omega + simp [lt_def] at hd + omega + · simp [divremi_eq_of_not_le h] + simp at h + simp [lt_def] at hr + simp [lt_def] + simp [lt_def] at hd + omega + -- exact h + case succ j ih => + simp + by_cases h : d <<< (j + 1) ≤ qr.r + · simp [divremi_eq_of_le h] + simp at h + /- + case pos + w : Nat + n d : BitVec w + hd : 0 < d + j : Nat + ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → qr.r.toNat < 2 ^ j → (divRec qr j).r < d + qr : DivRecQuotRem w n d + hqr : qr.Lawful + hr : qr.r.toNat < 2 ^ (j + 1) + h : d <<< (j + 1) ≤ qr.r + ⊢ (divRec { r := qr.r - d <<< (j + 1), q := qr.q ||| twoPow w (j + 1) } j).r < d + -/ + sorry + -- exact h + · simp [divremi_eq_of_not_le h] + simp at h + apply ih + · exact hqr + · simp; + /- + case neg.hr + w : Nat + n d : BitVec w + hd : 0 < d + j : Nat + ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → qr.r.toNat < 2 ^ j → (divRec qr j).r < d + qr : DivRecQuotRem w n d + hqr : qr.Lawful + hr : qr.r.toNat < 2 ^ (j + 1) + h : ¬d <<< (j + 1) ≤ qr.r + ⊢ qr.r.toNat < 2 ^ j + -/ + -- I get a contradiction, because I know the following facts: + -- - hr : qr < 2^(j + 1) + -- hd : 0 < d => d >= 1 + -- and thus, d <<< (j + 1) ≥ 1 >>> (j + 1) ≥ 2 ^(j + 1) + -- plus, h tells me that qr.r > d <<< (j + 1) => qr.r > 2^(j + 1) + -- contradiction! + sorry theorem divRec_eq_div_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd : 0 < d) {hqr' : qr' = (divRec qr w)} : @@ -800,4 +930,5 @@ theorem divRec_eq_div_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful.def simp [this.1] + end BitVec From 0e23bbdf0e607b6caf6d421a356f5ec603b4e080 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 28 Jun 2024 22:27:58 +0100 Subject: [PATCH 34/64] chore: I believe the proof, now I need to encode it --- src/Init/Data/BitVec/Bitblast.lean | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 825d56c785db..dacf553b217c 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -925,10 +925,13 @@ theorem divRec_eq_div_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd {hqr' : qr' = (divRec qr w)} : qr'.q = udiv n d := by have hlawful : qr'.Lawful := by simp [hqr', divRec_lawful hqr] - have hremainder : qr'.r < d := by simp [hqr', divRec_remainder_inbounds hqr] + have hremainder : qr'.r < d := by + simp [hqr'] + apply divRec_remainder_inbounds hqr + apply qr.r.isLt + assumption have this := div_characterized_of_mul_add_toNat (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful.def simp [this.1] - end BitVec From 211fe3e5c744d30d74dd695b035b685a96fe3113 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 28 Jun 2024 22:56:08 +0100 Subject: [PATCH 35/64] chore: see that we have an overflow, can't write it this way. --- src/Init/Data/BitVec/Bitblast.lean | 80 ++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 25 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index dacf553b217c..f0a901207dd9 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -704,7 +704,7 @@ structure DivRecQuotRem (w : Nat) (n : BitVec w) (d : BitVec w) where /- Given d, R(j + 1), (calculate R(j), q.getLsb j). -/ def divremi (qr : DivRecQuotRem w n d) (j : Nat) : BitVec w × Bool := - if d * twoPow w j ≤ qr.r then + if d * twoPow w j ≤ qr.r then -- This overflows, I need to do this differently :( let rj := qr.r - d * twoPow w j -- remainder is legal since it's positive, accept it. (rj, true) else @@ -719,13 +719,16 @@ theorem divremi_eq_of_not_le {qr : DivRecQuotRem w n d} (h : ¬ d <<< j ≤ qr.r simp [divremi, h] def DivRecQuotRem.Lawful {n d : BitVec w} (qr : DivRecQuotRem w n d) : Prop := - d.toNat * qr.q.toNat + qr.r.toNat = n.toNat + (d.toNat * qr.q.toNat + qr.r.toNat < 2^w) ∧ (d * qr.q + qr.r = n) theorem DivRecQuotRem.Lawful.def {n d : BitVec w} {qr : DivRecQuotRem w n d} (h : qr.Lawful) : d.toNat * qr.q.toNat + qr.r.toNat = n.toNat := by simp [DivRecQuotRem.Lawful] at h - assumption - + have h' : (d * qr.q + qr.r).toNat = n.toNat := by rw [h.2] + simp at h' + rw [Nat.mod_eq_of_lt] at h' + · exact h' + · exact h.1 def DivRecQuotRem.initialize (n d : BitVec w) : DivRecQuotRem w n d := { r := n, q := 0 } @@ -880,26 +883,14 @@ theorem divRec_remainder_inbounds {qr : DivRecQuotRem w n d} {j : Nat} by_cases h : d <<< (j + 1) ≤ qr.r · simp [divremi_eq_of_le h] simp at h - /- - case pos - w : Nat - n d : BitVec w - hd : 0 < d - j : Nat - ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → qr.r.toNat < 2 ^ j → (divRec qr j).r < d - qr : DivRecQuotRem w n d - hqr : qr.Lawful - hr : qr.r.toNat < 2 ^ (j + 1) - h : d <<< (j + 1) ≤ qr.r - ⊢ (divRec { r := qr.r - d <<< (j + 1), q := qr.q ||| twoPow w (j + 1) } j).r < d - -/ - sorry - -- exact h - · simp [divremi_eq_of_not_le h] - simp at h - apply ih - · exact hqr - · simp; + have hcontra : (d <<< (j + 1)).toNat < 2^(j + 1) := by + simp [BitVec.le_def] at h + calc + d <<< (j + 1) ≤ qr.r := by exact h + (d <<< (j + 1)).toNat ≤ qr.r.toNat := by simp + + + /- case neg.hr w : Nat @@ -919,9 +910,35 @@ theorem divRec_remainder_inbounds {qr : DivRecQuotRem w n d} {j : Nat} -- and thus, d <<< (j + 1) ≥ 1 >>> (j + 1) ≥ 2 ^(j + 1) -- plus, h tells me that qr.r > d <<< (j + 1) => qr.r > 2^(j + 1) -- contradiction! + + apply ih + simp [DivRecQuotRem.Lawful] + constructor + · sorry + · simp + -- exact h + · simp [divremi_eq_of_not_le h] + simp at h + apply ih + · exact hqr + · simp; + + /- + case pos + w : Nat + n d : BitVec w + hd : 0 < d + j : Nat + ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → qr.r.toNat < 2 ^ j → (divRec qr j).r < d + qr : DivRecQuotRem w n d + hqr : qr.Lawful + hr : qr.r.toNat < 2 ^ (j + 1) + h : d <<< (j + 1) ≤ qr.r + ⊢ (divRec { r := qr.r - d <<< (j + 1), q := qr.q ||| twoPow w (j + 1) } j).r < d + -/ sorry -theorem divRec_eq_div_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd : 0 < d) +theorem divRec_eq_udiv_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd : 0 < d) {hqr' : qr' = (divRec qr w)} : qr'.q = udiv n d := by have hlawful : qr'.Lawful := by simp [hqr', divRec_lawful hqr] @@ -934,4 +951,17 @@ theorem divRec_eq_div_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful.def simp [this.1] +theorem divRec_eq_umod_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd : 0 < d) + {hqr' : qr' = (divRec qr w)} : + qr'.r = umod n d := by + have hlawful : qr'.Lawful := by simp [hqr', divRec_lawful hqr] + have hremainder : qr'.r < d := by + simp [hqr'] + apply divRec_remainder_inbounds hqr + apply qr.r.isLt + assumption + have this := div_characterized_of_mul_add_toNat + (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful.def + simp [this.2] + end BitVec From 5b3bbc1c83eb72f205d4450ef33274151712e25d Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Sat, 29 Jun 2024 15:45:50 +0100 Subject: [PATCH 36/64] chore: write udiv proof, need to prove the other branch of the proof as well --- src/Init/Data/BitVec/Bitblast.lean | 368 ++++++++++++++++++----------- 1 file changed, 228 insertions(+), 140 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index f0a901207dd9..856970b2b624 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -704,28 +704,48 @@ structure DivRecQuotRem (w : Nat) (n : BitVec w) (d : BitVec w) where /- Given d, R(j + 1), (calculate R(j), q.getLsb j). -/ def divremi (qr : DivRecQuotRem w n d) (j : Nat) : BitVec w × Bool := - if d * twoPow w j ≤ qr.r then -- This overflows, I need to do this differently :( - let rj := qr.r - d * twoPow w j -- remainder is legal since it's positive, accept it. + if d ≤ qr.r >>> j then -- This overflows, I need to do this differently :( + -- let rj := qr.r - d * twoPow w j -- remainder is legal since it's positive, accept it. + -- | this is the same as checking qr - d * 2^j >= 0. + -- | If qr - d * 2^j ≥ 0, then it's legal for the quotient to have the jth bit 1. + let rj := qr.r >>> j - d -- remainder is legal since it's positive, accept it. (rj, true) else (qr.r, false) -- remainder is illegal, so quotient must be '0' at this bit. -theorem divremi_eq_of_le {qr : DivRecQuotRem w n d} (h : d <<< j ≤ qr.r) : - divremi qr j = (qr.r - d * twoPow w j, true) := by +theorem divremi_eq_of_le {qr : DivRecQuotRem w n d} (h : d ≤ qr.r >>> j) : + divremi qr j = (qr.r >>> j - d, true) := by simp [divremi, h] -theorem divremi_eq_of_not_le {qr : DivRecQuotRem w n d} (h : ¬ d <<< j ≤ qr.r) : +theorem divremi_eq_of_not_le {qr : DivRecQuotRem w n d} (h : ¬ d ≤ qr.r >>> j) : divremi qr j = (qr.r, false) := by simp [divremi, h] def DivRecQuotRem.Lawful {n d : BitVec w} (qr : DivRecQuotRem w n d) : Prop := - (d.toNat * qr.q.toNat + qr.r.toNat < 2^w) ∧ (d * qr.q + qr.r = n) + (d.toNat * qr.q.toNat + qr.r.toNat = n.toNat) -theorem DivRecQuotRem.Lawful.def {n d : BitVec w} {qr : DivRecQuotRem w n d} +theorem DivRecQuotRem.Lawful.toNat_inbounds {n d : BitVec w} {qr : DivRecQuotRem w n d} + (h : qr.Lawful) : d.toNat * qr.q.toNat + qr.r.toNat < 2^w := by + rw [h] + omega + +theorem DivRecQuotRem.inbounds {n d : BitVec w} {qr : DivRecQuotRem w n d} : + (d * qr.q + qr.r).toNat < 2^w := by omega + + +theorem DivRecQuotRem.Lawful.eq {n d : BitVec w} {qr : DivRecQuotRem w n d} + (h : qr.Lawful) : (d * qr.q + qr.r = n) := by + apply eq_of_toNat_eq + simp + rw [h] + rw [Nat.mod_eq_of_lt] + omega + +theorem DivRecQuotRem.Lawful.eq_nat {n d : BitVec w} {qr : DivRecQuotRem w n d} (h : qr.Lawful) : d.toNat * qr.q.toNat + qr.r.toNat = n.toNat := by - simp [DivRecQuotRem.Lawful] at h + simp only [Lawful] at h have h' : (d * qr.q + qr.r).toNat = n.toNat := by rw [h.2] - simp at h' + simp only [toNat_add, toNat_mul, mod_add_mod] at h' rw [Nat.mod_eq_of_lt] at h' · exact h' · exact h.1 @@ -796,152 +816,135 @@ theorem getLsb_clearLowBitsAfter {x : BitVec w} {j i : Nat} : · simp [hij, show j + (i - j) = i by omega, show j ≤ i by omega] apply lt_of_getLsb +/-- +A bitvector can be broken down into the low bits (by truncate) and the high +bits (by left shift followed by right shift). +-/ +theorem BitVec.zeroExtend_truncate_or_shiftRight_shiftLeft_eq_self {x : BitVec w} {i : Nat} : + (x.truncate i).zeroExtend w ||| (x >>> i) <<< i = x := by + ext j + by_cases h : j < i + · simp [h] + · simp only [getLsb_or, getLsb_zeroExtend, Fin.is_lt, decide_True, h, decide_False, + Bool.false_and, Bool.and_false, getLsb_shiftLeft, Bool.not_false, Bool.and_self, + getLsb_ushiftRight, Bool.true_and, Bool.false_or] + congr + omega + +/-- Key relationship that establishes the loop invariant after one iteration. -/ +theorem DivRecQuotRem.rec_lawful_of_lawful + {qr : DivRecQuotRem w n d} + (hqr : qr.Lawful) + {hq : ∀ {i : Nat}, i ≤ j + 1 → qr.q.getLsb i = false} + {hd : d ≤ qr.r >>> (j + 1)} : + { r := qr.r >>> (j + 1) - d, q := qr.q ||| twoPow w (j + 1) : DivRecQuotRem w n d }.Lawful := by + simp only [DivRecQuotRem.Lawful] + rw [← BitVec.add_eq_or_of_and_eq_zero] + sorry + -- d * j + q = n. theorem divRec_lawful {qr : DivRecQuotRem w n d} {j : Nat} - (hqr : qr.Lawful) (hq : ∀ {i : Nat} (hi : i ≤ j), qr.q.getLsb j = false): (divRec qr j).Lawful := by + (hqr : qr.Lawful) + -- | We start the reucrrence at j=w-1, so all bits are zero. + -- | At one step of j=w-2, only the low bit maybe set, and all + -- bits at [1..w) are zero. + -- We end the recurrence at j=0, when no bit is forced to be zero. + -- | the quotient is correct, if the quotient is zero for all bits + -- | in the range `(j..0]`. + (hq : ∀ {i : Nat} (hi : i ≤ j), qr.q.getLsb i = false) : + (divRec qr j).Lawful := by induction j generalizing qr case zero => - /- - w : Nat - n d : BitVec w - qr : DivRecQuotRem w n d - hqr : qr.Lawful - ⊢ d.toNat * (divRec.setBit qr.q 0 (divremi qr 0).snd).toNat + (divremi qr 0).fst.toNat = n.toNat - -/ simp [DivRecQuotRem.Lawful] - by_cases h : d <<< 0 ≤ qr.r + by_cases h : d ≤ qr.r >>> 0 · simp [divremi_eq_of_le h] simp at h rcases w with rfl | w · simp -- get rid of corner case with 1 % 2^0 = 1 % 1 = 0 · have h1 : 1 % 2^(w + 1) = 1 := by - sorry - -- TODO: refactor instead of having toNat goal, have two goals, - -- one as bitvec equality, other as nat bound. + rw [Nat.mod_eq_of_lt] + apply Nat.one_lt_two_pow (by omega) simp [h1] specialize (hq (i := 0) (by omega)) - /- - case pos.succ - w : Nat - n d : BitVec (w + 1) - qr : DivRecQuotRem (w + 1) n d - hqr : qr.Lawful - hq : ∀ {i : Nat}, i ≤ 0 → qr.q.getLsb 0 = false - h : d ≤ qr.r - ⊢ d.toNat * (qr.q.toNat ||| 1 % 2 ^ (w + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) = n.toNat - -/ - -- this is equal, since it simplifies to: - -- (d * (q ||| 1) + (r - 2^(w + 1) - d)) % (2^(w + 1)) = n - -- = (d * (q + 1) + (r - 2^(w + 1) - d)) % (2^(w + 1)) = n [q, 1 are mutex by hq] - -- = (d * q + d + (r + (2^(w + 1) - d))) % 2 ^(w + 1) = n - -- now rearrange, and cancel the d by using (h : d ≤ r): - -- = (d * q + d + -d + (r + (2^(w + 1))) % 2 ^(w + 1) = n - -- = (d * q + (r + (2^(w + 1))) % 2 ^(w + 1) = n - -- Now cancel the modulo, giving us - -- d * q + r =indeed= n by [lawful]. - sorry -- HERE + have hqr_q_or_1_to_Nat : (qr.q.toNat ||| 1) = (qr.q ||| 1).toNat := by + simp + rw [Nat.mod_eq_of_lt] + repeat omega + rw [hqr_q_or_1_to_Nat] + have hqr_or_1_eq_hqr_add_1 : (qr.q ||| 1) = (qr.q + twoPow (w+1) 0) := by + rw [add_eq_or_of_and_eq_zero] + · simp + · rw [and_twoPow_eq_getLsb, hq] + simp + rw [hqr_or_1_eq_hqr_add_1] + simp -- here we get a % 2^(w + 1) that we wish to avoid + calc + d.toNat * ((qr.q.toNat + 1) % 2 ^ (w + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) = d.toNat * ((qr.q.toNat + 1) % 2 ^ (w + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) := by rfl + _ = d.toNat * ((qr.q.toNat + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) := by + rw [Nat.mod_eq_of_lt] + simp [getLsb] at hq + omega + _ = d.toNat * (qr.q.toNat + 1) + (qr.r.toNat - d.toNat + (2 ^ (w + 1))) % 2 ^ (w + 1) := by + congr 2 + /- Note: omega needs this. Is the preprocessor supposed to pick this up? -/ + have h' : d.toNat ≤ qr.r.toNat := by simp [BitVec.le_def] at h; omega + omega + _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat + (2 ^ (w + 1))) % 2 ^ (w + 1) := by simp [Nat.mul_add] + _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) % 2^(w + 1) + (2^(w+1) % 2 ^ (w + 1)) := by simp [Nat.add_mod] + _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) % 2^(w + 1) + 0 := by simp + _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) % 2^(w + 1) := by simp + _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) := by rw [Nat.mod_eq_of_lt]; omega + _ = d.toNat * qr.q.toNat + (d.toNat - d.toNat + qr.r.toNat) := by + have h' : d.toNat ≤ qr.r.toNat := by simp [BitVec.le_def] at h; omega + omega + _ = d.toNat * qr.q.toNat + qr.r.toNat := by simp + _ = n.toNat := hqr · simp [divremi_eq_of_not_le h] - simp at h - exact hqr.def + exact hqr case succ j ih => simp -- TODO: split below into a separate lemma - by_cases h : d <<< (j + 1) ≤ qr.r + by_cases h : d ≤ qr.r >>> (j + 1) · simp [divremi_eq_of_le h] simp at h - sorry -- NEXT - · simp [divremi_eq_of_not_le h] - simp at h - apply ih hqr - - -theorem divRec_remainder_inbounds {qr : DivRecQuotRem w n d} {j : Nat} - (hqr : qr.Lawful) (hr : qr.r.toNat < 2^j) (hd : 0 < d): - (divRec qr j).r < d := by - induction j generalizing qr - case zero => - simp - by_cases h : d <<< 0 ≤ qr.r - · simp [divremi_eq_of_le h] - simp at h - simp at hr - simp [lt_def] - simp [show qr.r.toNat = 0 by omega] - simp [le_def] at h - have hd' : d.toNat = 0 := by omega - simp [lt_def] at hd - omega - · simp [divremi_eq_of_not_le h] - simp at h - simp [lt_def] at hr - simp [lt_def] - simp [lt_def] at hd - omega - -- exact h - case succ j ih => - simp - by_cases h : d <<< (j + 1) ≤ qr.r - · simp [divremi_eq_of_le h] - simp at h - have hcontra : (d <<< (j + 1)).toNat < 2^(j + 1) := by - simp [BitVec.le_def] at h - calc - d <<< (j + 1) ≤ qr.r := by exact h - (d <<< (j + 1)).toNat ≤ qr.r.toNat := by simp - - - - /- - case neg.hr - w : Nat - n d : BitVec w - hd : 0 < d - j : Nat - ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → qr.r.toNat < 2 ^ j → (divRec qr j).r < d - qr : DivRecQuotRem w n d - hqr : qr.Lawful - hr : qr.r.toNat < 2 ^ (j + 1) - h : ¬d <<< (j + 1) ≤ qr.r - ⊢ qr.r.toNat < 2 ^ j - -/ - -- I get a contradiction, because I know the following facts: - -- - hr : qr < 2^(j + 1) - -- hd : 0 < d => d >= 1 - -- and thus, d <<< (j + 1) ≥ 1 >>> (j + 1) ≥ 2 ^(j + 1) - -- plus, h tells me that qr.r > d <<< (j + 1) => qr.r > 2^(j + 1) - -- contradiction! - apply ih - simp [DivRecQuotRem.Lawful] - constructor + /- + case pos.hqr + w : Nat + n d : BitVec w + j : Nat + ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → (∀ {i : Nat}, i ≤ j → qr.q.getLsb i = false) → (divRec qr j).Lawful + qr : DivRecQuotRem w n d + hqr : qr.Lawful + hq : ∀ {i : Nat}, i ≤ j + 1 → qr.q.getLsb i = false + h : d ≤ qr.r >>> (j + 1) + ⊢ { r := qr.r >>> (j + 1) - d, q := qr.q ||| twoPow w (j + 1) }.Lawful + -/ · sorry - · simp - -- exact h + · intros i hi + simp + constructor + · apply hq + omega + · intros h + omega · simp [divremi_eq_of_not_le h] simp at h - apply ih - · exact hqr - · simp; + apply ih hqr + intros i hi + apply hq + omega - /- - case pos - w : Nat - n d : BitVec w - hd : 0 < d - j : Nat - ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → qr.r.toNat < 2 ^ j → (divRec qr j).r < d - qr : DivRecQuotRem w n d - hqr : qr.Lawful - hr : qr.r.toNat < 2 ^ (j + 1) - h : d <<< (j + 1) ≤ qr.r - ⊢ (divRec { r := qr.r - d <<< (j + 1), q := qr.q ||| twoPow w (j + 1) } j).r < d - -/ - sorry +/-- +info: 'BitVec.divRec_lawful' depends on axioms: [propext, Quot.sound, Classical.choice] +-/ +#guard_msgs in #print axioms divRec_lawful theorem divRec_eq_udiv_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd : 0 < d) {hqr' : qr' = (divRec qr w)} : qr'.q = udiv n d := by - have hlawful : qr'.Lawful := by simp [hqr', divRec_lawful hqr] + have hlawful : qr'.Lawful := by + simp [hqr', divRec_lawful hqr] have hremainder : qr'.r < d := by simp [hqr'] apply divRec_remainder_inbounds hqr @@ -951,17 +954,102 @@ theorem divRec_eq_udiv_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (h (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful.def simp [this.1] -theorem divRec_eq_umod_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd : 0 < d) - {hqr' : qr' = (divRec qr w)} : - qr'.r = umod n d := by - have hlawful : qr'.Lawful := by simp [hqr', divRec_lawful hqr] - have hremainder : qr'.r < d := by - simp [hqr'] - apply divRec_remainder_inbounds hqr - apply qr.r.isLt - assumption - have this := div_characterized_of_mul_add_toNat - (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful.def - simp [this.2] + + +-- theorem divRec_remainder_inbounds {qr : DivRecQuotRem w n d} {j : Nat} +-- (hqr : qr.Lawful) (hr : qr.r.toNat < 2^j) (hd : 0 < d): +-- (divRec qr j).r < d := by +-- induction j generalizing qr +-- case zero => +-- simp +-- by_cases h : d <<< 0 ≤ qr.r +-- · simp [divremi_eq_of_le h] +-- simp at h +-- simp at hr +-- simp [lt_def] +-- simp [show qr.r.toNat = 0 by omega] +-- simp [le_def] at h +-- have hd' : d.toNat = 0 := by omega +-- simp [lt_def] at hd +-- omega +-- · simp [divremi_eq_of_not_le h] +-- simp at h +-- simp [lt_def] at hr +-- simp [lt_def] +-- simp [lt_def] at hd +-- omega +-- -- exact h +-- case succ j ih => +-- simp +-- by_cases h : d <<< (j + 1) ≤ qr.r +-- · simp [divremi_eq_of_le h] +-- simp at h +-- have hcontra : (d <<< (j + 1)).toNat < 2^(j + 1) := by +-- simp [BitVec.le_def] at h +-- calc +-- d <<< (j + 1) ≤ qr.r := by exact h +-- (d <<< (j + 1)).toNat ≤ qr.r.toNat := by simp + + + +-- /- +-- case neg.hr +-- w : Nat +-- n d : BitVec w +-- hd : 0 < d +-- j : Nat +-- ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → qr.r.toNat < 2 ^ j → (divRec qr j).r < d +-- qr : DivRecQuotRem w n d +-- hqr : qr.Lawful +-- hr : qr.r.toNat < 2 ^ (j + 1) +-- h : ¬d <<< (j + 1) ≤ qr.r +-- ⊢ qr.r.toNat < 2 ^ j +-- -/ +-- -- I get a contradiction, because I know the following facts: +-- -- - hr : qr < 2^(j + 1) +-- -- hd : 0 < d => d >= 1 +-- -- and thus, d <<< (j + 1) ≥ 1 >>> (j + 1) ≥ 2 ^(j + 1) +-- -- plus, h tells me that qr.r > d <<< (j + 1) => qr.r > 2^(j + 1) +-- -- contradiction! + +-- apply ih +-- simp [DivRecQuotRem.Lawful] +-- constructor +-- · sorry +-- · simp +-- -- exact h +-- · simp [divremi_eq_of_not_le h] +-- simp at h +-- apply ih +-- · exact hqr +-- · simp; + +-- /- +-- case pos +-- w : Nat +-- n d : BitVec w +-- hd : 0 < d +-- j : Nat +-- ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → qr.r.toNat < 2 ^ j → (divRec qr j).r < d +-- qr : DivRecQuotRem w n d +-- hqr : qr.Lawful +-- hr : qr.r.toNat < 2 ^ (j + 1) +-- h : d <<< (j + 1) ≤ qr.r +-- ⊢ (divRec { r := qr.r - d <<< (j + 1), q := qr.q ||| twoPow w (j + 1) } j).r < d +-- -/ +-- sorry + +-- theorem divRec_eq_umod_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd : 0 < d) +-- {hqr' : qr' = (divRec qr w)} : +-- qr'.r = umod n d := by +-- have hlawful : qr'.Lawful := by simp [hqr', divRec_lawful hqr] +-- have hremainder : qr'.r < d := by +-- simp [hqr'] +-- apply divRec_remainder_inbounds hqr +-- apply qr.r.isLt +-- assumption +-- have this := div_characterized_of_mul_add_toNat +-- (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful.def +-- simp [this.2] end BitVec From 949c9b65b7e91fedc88819a457ba34ae2c2610ce Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Mon, 1 Jul 2024 14:41:16 +0100 Subject: [PATCH 37/64] chore: stash --- src/Init/Data/BitVec/Bitblast.lean | 40 ++++++++++++++++-------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 856970b2b624..b4e58dd1bf96 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -713,6 +713,19 @@ def divremi (qr : DivRecQuotRem w n d) (j : Nat) : BitVec w × Bool := else (qr.r, false) -- remainder is illegal, so quotient must be '0' at this bit. +/-- Recurrence for division in terms of `divremi`. -/ +def divRec (qr : DivRecQuotRem w n d) (j : Nat) : DivRecQuotRem w n d := + let (r, qj) := divremi qr j + let q := setBit qr.q j qj + match j with + | 0 => { r := r, q := q } + | j + 1 => divRec { r := r, q := q } j +where + /-- Set the `i`th bit of `v` to `b`.-/ + setBit (v : BitVec w) (i : Nat) (b : Bool) := + if b then v ||| twoPow w i else v + + theorem divremi_eq_of_le {qr : DivRecQuotRem w n d} (h : d ≤ qr.r >>> j) : divremi qr j = (qr.r >>> j - d, true) := by simp [divremi, h] @@ -741,14 +754,14 @@ theorem DivRecQuotRem.Lawful.eq {n d : BitVec w} {qr : DivRecQuotRem w n d} rw [Nat.mod_eq_of_lt] omega -theorem DivRecQuotRem.Lawful.eq_nat {n d : BitVec w} {qr : DivRecQuotRem w n d} - (h : qr.Lawful) : d.toNat * qr.q.toNat + qr.r.toNat = n.toNat := by - simp only [Lawful] at h - have h' : (d * qr.q + qr.r).toNat = n.toNat := by rw [h.2] - simp only [toNat_add, toNat_mul, mod_add_mod] at h' - rw [Nat.mod_eq_of_lt] at h' - · exact h' - · exact h.1 +-- theorem DivRecQuotRem.Lawful.eq_nat {n d : BitVec w} {qr : DivRecQuotRem w n d} +-- (h : qr.Lawful) : d.toNat * qr.q.toNat + qr.r.toNat = n.toNat := by +-- simp only [Lawful] at h +-- have h' : (d * qr.q + qr.r).toNat = n.toNat := by rw [h.2] +-- simp only [toNat_add, toNat_mul, mod_add_mod] at h' +-- rw [Nat.mod_eq_of_lt] at h' +-- · exact h' +-- · exact h.1 def DivRecQuotRem.initialize (n d : BitVec w) : DivRecQuotRem w n d := { r := n, q := 0 } @@ -757,17 +770,6 @@ theorem DivRecQuotRem.lawful_initialize {n d : BitVec w} : (DivRecQuotRem.initialize n d).Lawful := by simp [DivRecQuotRem.Lawful, DivRecQuotRem.initialize] -/-- Recurrence for division in terms of `divremi`. -/ -def divRec (qr : DivRecQuotRem w n d) (j : Nat) : DivRecQuotRem w n d := - let (r, qj) := divremi qr j - let q := setBit qr.q j qj - match j with - | 0 => { r := r, q := q } - | j + 1 => divRec { r := r, q := q } j -where - /-- Set the `i`th bit of `v` to `b`.-/ - setBit (v : BitVec w) (i : Nat) (b : Bool) := - if b then v ||| twoPow w i else v @[simp] theorem divRec_zero {qr : DivRecQuotRem w n d} : From 5c10b83efe3aeb4e186cfb3b9d51f1f849ed4841 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Mon, 1 Jul 2024 18:27:38 +0100 Subject: [PATCH 38/64] chore: comment out broken impl --- src/Init/Data/BitVec/Bitblast.lean | 113 ++++++++++++++++------------- src/Init/Data/BitVec/Lemmas.lean | 17 +++++ 2 files changed, 80 insertions(+), 50 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index b4e58dd1bf96..0803f6a4d280 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -466,8 +466,7 @@ theorem shiftLeftRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) (hn : n + rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)] simp [h] -/-- info: 'BitVec.shiftLeftRec_eq' depends on axioms: [propext, Quot.sound, Classical.choice] -/ -#guard_msgs in #print axioms shiftLeftRec_eq +#print axioms shiftLeftRec_eq theorem shiftLeft_eq_shiftLeft_rec (x : BitVec ℘) (y : BitVec w₂) : x <<< y = shiftLeftRec x y (w₂ - 1) := by @@ -476,7 +475,7 @@ theorem shiftLeft_eq_shiftLeft_rec (x : BitVec ℘) (y : BitVec w₂) : · simp [shiftLeftRec_eq x y w₂ (by omega)] -/-## (Arithmetic) sshiftRight recurrence -/ +/-## (Arithmetic) ushiftRight recurrence -/ def ushiftRight_rec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ := let shiftAmt := (y &&& (twoPow w₂ n)) @@ -568,8 +567,7 @@ theorem ushiftRight_rec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) (hn : n rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)] simp [h] -/-- info: 'BitVec.ushiftRight_rec_eq' depends on axioms: [propext, Quot.sound, Classical.choice] -/ -#guard_msgs in #print axioms ushiftRight_rec_eq +#print axioms ushiftRight_rec_eq theorem shiftRight_eq_shiftRight_rec (x : BitVec ℘) (y : BitVec w₂) : x >>> y = ushiftRight_rec x y (w₂ - 1) := by @@ -842,8 +840,22 @@ theorem DivRecQuotRem.rec_lawful_of_lawful { r := qr.r >>> (j + 1) - d, q := qr.q ||| twoPow w (j + 1) : DivRecQuotRem w n d }.Lawful := by simp only [DivRecQuotRem.Lawful] rw [← BitVec.add_eq_or_of_and_eq_zero] - sorry - + · simp + /- + w : Nat + n d : BitVec w + j : Nat + qr : DivRecQuotRem w n d + hqr : qr.Lawful + hq : ∀ {i : Nat}, i ≤ j + 1 → qr.q.getLsb i = false + hd : d ≤ qr.r >>> (j + 1) + ⊢ d.toNat * ((qr.q.toNat + 2 ^ (j + 1)) % 2 ^ w) + (qr.r.toNat >>> (j + 1) + (2 ^ w - d.toNat)) % 2 ^ w = n.toNat + -/ + sorry + · simp + specialize hq (i := j + 1) (by omega) + rw [hq] + simp -- d * j + q = n. theorem divRec_lawful {qr : DivRecQuotRem w n d} {j : Nat} (hqr : qr.Lawful) @@ -937,10 +949,49 @@ theorem divRec_lawful {qr : DivRecQuotRem w n d} {j : Nat} apply hq omega -/-- -info: 'BitVec.divRec_lawful' depends on axioms: [propext, Quot.sound, Classical.choice] --/ -#guard_msgs in #print axioms divRec_lawful +#print axioms divRec_lawful + +theorem divRec_remainder_inbounds {qr : DivRecQuotRem w n d} {j : Nat} + (hqr : qr.Lawful) (hr : qr.r.toNat < 2^j) (hd : 0 < d): + (divRec qr j).r < d := by + induction j generalizing qr + case zero => + simp + by_cases h : d <<< 0 ≤ qr.r + · simp at h + simp at hr + simp [lt_def] + simp [le_def] at h + have hd' : d.toNat = 0 := by omega + simp [lt_def] at hd + omega + · simp at h + simp [lt_def] at hr + simp [lt_def] + simp [lt_def] at hd + simp [divremi, h] + omega + case succ j ih => + simp + by_cases h : d ≤ qr.r >>> (j + 1) + · simp [divremi_eq_of_le h] + simp at h + have hqr' : qr.r >>> (j + 1) = 0#w := by + apply BitVec.eq_of_toNat_eq + simp only [toNat_ushiftRight, toNat_ofNat, zero_mod] + rw [Nat.shiftRight_eq_div_pow] + rw[Nat.div_eq_of_lt (by omega)] + rw [hqr'] + rw [hqr'] at h + have hdcontra : d = 0#w := by + simp only [le_def, toNat_ofNat, zero_mod, le_zero_eq] at h + apply BitVec.eq_of_toNat_eq + simp [h] + simp [hdcontra] at hd + · simp [divremi_eq_of_not_le h] + apply ih hqr + simp [BitVec.le_def] at h + sorry theorem divRec_eq_udiv_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd : 0 < d) {hqr' : qr' = (divRec qr w)} : @@ -953,47 +1004,9 @@ theorem divRec_eq_udiv_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (h apply qr.r.isLt assumption have this := div_characterized_of_mul_add_toNat - (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful.def + (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful simp [this.1] - - --- theorem divRec_remainder_inbounds {qr : DivRecQuotRem w n d} {j : Nat} --- (hqr : qr.Lawful) (hr : qr.r.toNat < 2^j) (hd : 0 < d): --- (divRec qr j).r < d := by --- induction j generalizing qr --- case zero => --- simp --- by_cases h : d <<< 0 ≤ qr.r --- · simp [divremi_eq_of_le h] --- simp at h --- simp at hr --- simp [lt_def] --- simp [show qr.r.toNat = 0 by omega] --- simp [le_def] at h --- have hd' : d.toNat = 0 := by omega --- simp [lt_def] at hd --- omega --- · simp [divremi_eq_of_not_le h] --- simp at h --- simp [lt_def] at hr --- simp [lt_def] --- simp [lt_def] at hd --- omega --- -- exact h --- case succ j ih => --- simp --- by_cases h : d <<< (j + 1) ≤ qr.r --- · simp [divremi_eq_of_le h] --- simp at h --- have hcontra : (d <<< (j + 1)).toNat < 2^(j + 1) := by --- simp [BitVec.le_def] at h --- calc --- d <<< (j + 1) ≤ qr.r := by exact h --- (d <<< (j + 1)).toNat ≤ qr.r.toNat := by simp - - - -- /- -- case neg.hr -- w : Nat diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 9900628ccee4..edffaaddd38c 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -763,6 +763,23 @@ theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) : Nat.not_lt, decide_eq_true_eq] omega +/-- The arithmetic shift right equals the msb when `s + i ≥ w`, and equals the logical shift right when `s + i < w. -/ +theorem getLsb_sshiftRight_eq_getLsb_ushiftRight (x : BitVec w) (s i : Nat) : + getLsb (x.sshiftRight s) i = (!decide (w ≤ i) && if s + i < w then (x >>> s).getLsb i else x.msb) := by + have h : (x >>> s).getLsb i = x.getLsb (s + i) := by + simp only [getLsb_ushiftRight] + rw [h] + simp [getLsb_sshiftRight] + +/-- A version of `BitVec.sshiftRight` with both arguments as bitvectors. -/ +def sshiftRight' (x y : BitVec w) : BitVec w := x.sshiftRight y.toNat + +theorem getLsb_sshift'_eq_getLsb_sshiftRight (x y : BitVec w) (i : Nat) : + getLsb (sshiftRight' x y) i = getLsb (x.sshiftRight y.toNat) i := by + simp [sshiftRight'] + +-- theorem getLsb_sshiftRight'_ + /-! ### udiv -/ theorem udiv_eq {x y : BitVec n} : From 5beb72b5b376c1f3cddd7cffac959028f7f63f6e Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Mon, 1 Jul 2024 19:34:39 +0100 Subject: [PATCH 39/64] chore: start implement sshiftRight' and arith shift right recurrence. --- src/Init/Data/BitVec/Bitblast.lean | 17 ++++++++++++++++- src/Init/Data/BitVec/Lemmas.lean | 5 +++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 0803f6a4d280..97d15ce1d0eb 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -475,7 +475,7 @@ theorem shiftLeft_eq_shiftLeft_rec (x : BitVec ℘) (y : BitVec w₂) : · simp [shiftLeftRec_eq x y w₂ (by omega)] -/-## (Arithmetic) ushiftRight recurrence -/ +/-## (Logical) ushiftRight recurrence -/ def ushiftRight_rec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ := let shiftAmt := (y &&& (twoPow w₂ n)) @@ -575,6 +575,20 @@ theorem shiftRight_eq_shiftRight_rec (x : BitVec ℘) (y : BitVec w₂) : · simp [of_length_zero] · simp [ushiftRight_rec_eq x y w₂ (by omega)] + +/- ### Arithmetic (sshiftRight) recurrence -/ + +def sshiftRightRec (x : BitVec w) (y : BitVec w₂) (n : Nat) : BitVec w := + let shiftAmt := (y &&& (twoPow w₂ n)) + match n with + | 0 => x.sshiftRight' shiftAmt + | n + 1 => (sshiftRightRec x y n) >>> shiftAmt + +theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) : + (x >>> y).getLsb i = (sshiftRightRec x y w).getLsb i := sorry + + + /- ## udiv/urem bitblasting -/ /- @@ -1067,4 +1081,5 @@ theorem divRec_eq_udiv_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (h -- (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful.def -- simp [this.2] + end BitVec diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index edffaaddd38c..e5cf7f566c87 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -772,9 +772,10 @@ theorem getLsb_sshiftRight_eq_getLsb_ushiftRight (x : BitVec w) (s i : Nat) : simp [getLsb_sshiftRight] /-- A version of `BitVec.sshiftRight` with both arguments as bitvectors. -/ -def sshiftRight' (x y : BitVec w) : BitVec w := x.sshiftRight y.toNat +def sshiftRight' (x : BitVec w₁) (y : BitVec w₂) : BitVec w₁ := + x.sshiftRight y.toNat -theorem getLsb_sshift'_eq_getLsb_sshiftRight (x y : BitVec w) (i : Nat) : +theorem getLsb_sshift'_eq_getLsb_sshiftRight : getLsb (sshiftRight' x y) i = getLsb (x.sshiftRight y.toNat) i := by simp [sshiftRight'] From 9720ab59f8714b6c4af9104aa6a034d05c874688 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Tue, 2 Jul 2024 07:41:31 +0100 Subject: [PATCH 40/64] chore: stash sshiftRight that hargonix wants next --- src/Init/Data/BitVec/Bitblast.lean | 666 ++++++++++++++++------------- 1 file changed, 363 insertions(+), 303 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 97d15ce1d0eb..d6e18891ce8e 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -709,317 +709,377 @@ theorem div_iff_add_mod_of_lt {d n q r : BitVec w} (hd : 0 < d) /- # Division Recurrence for Bitblasting -/ +-- n = d * q + r +-- Two-stage subtraction: +-- For each bit of the dividend(n) starting from the MSB: +-- +-- 1) Add ith bit of dividend as MSB of remainder `rem`. +-- +-- 1) Compute carry bits when subtracting divisor `d` from current +-- remainder `rem`, which determines the current quotient bit. +-- 2) Perform subtraction operation based on current quotient bit and shift +-- remainder by one. +-- +-- For example, n = 0111 (7 in base 10), d = 0010 (2 in base 10) +-- +-- i rem d q +-- 0 0000 -- insert n.msb [0] +-- 0010 0 -- subtract d, not successful +-- 0000 -- result [unchanged] +-- 0000 -- shift +-- +-- 1 0001 -- insert n.msb [1] +-- 0010 0 -- subtract d, not successful +-- 0001 -- result [unchanged] +-- 0010 -- shift +-- +-- 2 0011 -- insert n.msb [2] +-- 0010 1 -- subtract d, successful +-- 0001 -- result [CHANGED] +-- 0010 -- shift +-- +-- 3 0011 -- insert n.msb [3] +-- 0010 1 -- subtract d, successful +-- 0001 -- remainder [CHANGED] +-- +-- remainder: 0001 (1 in base 10) +-- quotient: 0011 (3 in base 10) /-- A bundle of the quotient and remainder for the intermediate steps when computing n.div d -/ structure DivRecQuotRem (w : Nat) (n : BitVec w) (d : BitVec w) where r : BitVec w q : BitVec w +deriving DecidableEq, Repr -/- Given d, R(j + 1), (calculate R(j), q.getLsb j). -/ -def divremi (qr : DivRecQuotRem w n d) (j : Nat) : BitVec w × Bool := - if d ≤ qr.r >>> j then -- This overflows, I need to do this differently :( - -- let rj := qr.r - d * twoPow w j -- remainder is legal since it's positive, accept it. - -- | this is the same as checking qr - d * 2^j >= 0. - -- | If qr - d * 2^j ≥ 0, then it's legal for the quotient to have the jth bit 1. - let rj := qr.r >>> j - d -- remainder is legal since it's positive, accept it. - (rj, true) - else - (qr.r, false) -- remainder is illegal, so quotient must be '0' at this bit. - -/-- Recurrence for division in terms of `divremi`. -/ def divRec (qr : DivRecQuotRem w n d) (j : Nat) : DivRecQuotRem w n d := - let (r, qj) := divremi qr j - let q := setBit qr.q j qj + let rj := (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb j)).zeroExtend w + let qr' := if rj ≤ d + then { r := rj, q := qr.q <<< 1} + else {r := rj - d, q := qr.q <<< 1 ||| 1 } match j with - | 0 => { r := r, q := q } - | j + 1 => divRec { r := r, q := q } j -where - /-- Set the `i`th bit of `v` to `b`.-/ - setBit (v : BitVec w) (i : Nat) (b : Bool) := - if b then v ||| twoPow w i else v - - -theorem divremi_eq_of_le {qr : DivRecQuotRem w n d} (h : d ≤ qr.r >>> j) : - divremi qr j = (qr.r >>> j - d, true) := by - simp [divremi, h] - -theorem divremi_eq_of_not_le {qr : DivRecQuotRem w n d} (h : ¬ d ≤ qr.r >>> j) : - divremi qr j = (qr.r, false) := by - simp [divremi, h] - -def DivRecQuotRem.Lawful {n d : BitVec w} (qr : DivRecQuotRem w n d) : Prop := - (d.toNat * qr.q.toNat + qr.r.toNat = n.toNat) - -theorem DivRecQuotRem.Lawful.toNat_inbounds {n d : BitVec w} {qr : DivRecQuotRem w n d} - (h : qr.Lawful) : d.toNat * qr.q.toNat + qr.r.toNat < 2^w := by - rw [h] - omega - -theorem DivRecQuotRem.inbounds {n d : BitVec w} {qr : DivRecQuotRem w n d} : - (d * qr.q + qr.r).toNat < 2^w := by omega - - -theorem DivRecQuotRem.Lawful.eq {n d : BitVec w} {qr : DivRecQuotRem w n d} - (h : qr.Lawful) : (d * qr.q + qr.r = n) := by - apply eq_of_toNat_eq - simp - rw [h] - rw [Nat.mod_eq_of_lt] - omega - --- theorem DivRecQuotRem.Lawful.eq_nat {n d : BitVec w} {qr : DivRecQuotRem w n d} --- (h : qr.Lawful) : d.toNat * qr.q.toNat + qr.r.toNat = n.toNat := by --- simp only [Lawful] at h --- have h' : (d * qr.q + qr.r).toNat = n.toNat := by rw [h.2] --- simp only [toNat_add, toNat_mul, mod_add_mod] at h' --- rw [Nat.mod_eq_of_lt] at h' --- · exact h' --- · exact h.1 - -def DivRecQuotRem.initialize (n d : BitVec w) : DivRecQuotRem w n d := - { r := n, q := 0 } - -theorem DivRecQuotRem.lawful_initialize {n d : BitVec w} : - (DivRecQuotRem.initialize n d).Lawful := by - simp [DivRecQuotRem.Lawful, DivRecQuotRem.initialize] - - -@[simp] -theorem divRec_zero {qr : DivRecQuotRem w n d} : - divRec qr 0 = - { r := (divremi qr 0).fst , q := divRec.setBit qr.q 0 (divremi qr 0).snd } := by - unfold divRec - simp - -@[simp] -theorem divRec_succ {qr : DivRecQuotRem w n d} {j : Nat} : - divRec qr (j + 1) = - divRec { - r := (divremi qr (j + 1) |>.fst), - q := (divRec.setBit qr.q (j + 1) (divremi qr (j + 1) |>.snd)) - } j := by - conv => - lhs - unfold divRec - -@[simp] -theorem divRec.setBit_false {v : BitVec w} {i : Nat} : - divRec.setBit v i false = v := by - simp [divRec.setBit] - -@[simp] -theorem divRec.setBit_true {v : BitVec w} {i : Nat} : - divRec.setBit v i true = v ||| twoPow w i := by - simp [divRec.setBit] - -/- -Clear all low bits in the range `(j..0]`, -keeping high bits in the range `[w..j]` --/ -abbrev clearLowBitsAfter (x : BitVec w) (j : Nat) : BitVec w := (x >>> j) <<< j + | 0 => qr' + | j + 1 => divRec qr' j + +-- invariants: +-- 1) r < d. +-- 2) +theorem div_rec_7_2 : + (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 3) = + { r := 1, q := 3 } := by + simp [divRec] + +-- invariant 2 +-- n.toNat % 2^j = d.toNat * q.toNat + r.toNat +#reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 3) -- r = 1, q = 3 +#reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 2) -- r = 1, q = 3 +#reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 1) -- r = 1, q = 1 +#reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 0) -- r = 1, q = 0 + +-- /- Given d, R(j + 1), (calculate R(j), q.getLsb j). -/ +-- def divremi (qr : DivRecQuotRem w n d) (j : Nat) : BitVec w × Bool := +-- if d ≤ qr.r >>> j then -- This overflows, I need to do this differently :( +-- -- let rj := qr.r - d * twoPow w j -- remainder is legal since it's positive, accept it. +-- -- | this is the same as checking qr - d * 2^j >= 0. +-- -- | If qr - d * 2^j ≥ 0, then it's legal for the quotient to have the jth bit 1. +-- let rj := qr.r >>> j - d -- remainder is legal since it's positive, accept it. +-- (rj, true) +-- else +-- (qr.r, false) -- remainder is illegal, so quotient must be '0' at this bit. + +-- /-- Recurrence for division in terms of `divremi`. -/ +-- def divRec (qr : DivRecQuotRem w n d) (j : Nat) : DivRecQuotRem w n d := +-- let (r, qj) := divremi qr j +-- let q := setBit qr.q j qj +-- match j with +-- | 0 => { r := r, q := q } +-- | j + 1 => divRec { r := r, q := q } j +-- where +-- /-- Set the `i`th bit of `v` to `b`.-/ +-- setBit (v : BitVec w) (i : Nat) (b : Bool) := +-- if b then v ||| twoPow w i else v + + +-- theorem divremi_eq_of_le {qr : DivRecQuotRem w n d} (h : d ≤ qr.r >>> j) : +-- divremi qr j = (qr.r >>> j - d, true) := by +-- simp [divremi, h] + +-- theorem divremi_eq_of_not_le {qr : DivRecQuotRem w n d} (h : ¬ d ≤ qr.r >>> j) : +-- divremi qr j = (qr.r, false) := by +-- simp [divremi, h] + +-- def DivRecQuotRem.Lawful {n d : BitVec w} (qr : DivRecQuotRem w n d) : Prop := +-- (d.toNat * qr.q.toNat + qr.r.toNat = n.toNat) + +-- theorem DivRecQuotRem.Lawful.toNat_inbounds {n d : BitVec w} {qr : DivRecQuotRem w n d} +-- (h : qr.Lawful) : d.toNat * qr.q.toNat + qr.r.toNat < 2^w := by +-- rw [h] +-- omega + +-- theorem DivRecQuotRem.inbounds {n d : BitVec w} {qr : DivRecQuotRem w n d} : +-- (d * qr.q + qr.r).toNat < 2^w := by omega + + +-- theorem DivRecQuotRem.Lawful.eq {n d : BitVec w} {qr : DivRecQuotRem w n d} +-- (h : qr.Lawful) : (d * qr.q + qr.r = n) := by +-- apply eq_of_toNat_eq +-- simp +-- rw [h] +-- rw [Nat.mod_eq_of_lt] +-- omega + +-- -- theorem DivRecQuotRem.Lawful.eq_nat {n d : BitVec w} {qr : DivRecQuotRem w n d} +-- -- (h : qr.Lawful) : d.toNat * qr.q.toNat + qr.r.toNat = n.toNat := by +-- -- simp only [Lawful] at h +-- -- have h' : (d * qr.q + qr.r).toNat = n.toNat := by rw [h.2] +-- -- simp only [toNat_add, toNat_mul, mod_add_mod] at h' +-- -- rw [Nat.mod_eq_of_lt] at h' +-- -- · exact h' +-- -- · exact h.1 + +-- def DivRecQuotRem.initialize (n d : BitVec w) : DivRecQuotRem w n d := +-- { r := n, q := 0 } + +-- theorem DivRecQuotRem.lawful_initialize {n d : BitVec w} : +-- (DivRecQuotRem.initialize n d).Lawful := by +-- simp [DivRecQuotRem.Lawful, DivRecQuotRem.initialize] + + +-- @[simp] +-- theorem divRec_zero {qr : DivRecQuotRem w n d} : +-- divRec qr 0 = +-- { r := (divremi qr 0).fst , q := divRec.setBit qr.q 0 (divremi qr 0).snd } := by +-- unfold divRec +-- simp + +-- @[simp] +-- theorem divRec_succ {qr : DivRecQuotRem w n d} {j : Nat} : +-- divRec qr (j + 1) = +-- divRec { +-- r := (divremi qr (j + 1) |>.fst), +-- q := (divRec.setBit qr.q (j + 1) (divremi qr (j + 1) |>.snd)) +-- } j := by +-- conv => +-- lhs +-- unfold divRec + +-- @[simp] +-- theorem divRec.setBit_false {v : BitVec w} {i : Nat} : +-- divRec.setBit v i false = v := by +-- simp [divRec.setBit] + +-- @[simp] +-- theorem divRec.setBit_true {v : BitVec w} {i : Nat} : +-- divRec.setBit v i true = v ||| twoPow w i := by +-- simp [divRec.setBit] + +-- /- +-- Clear all low bits in the range `(j..0]`, +-- keeping high bits in the range `[w..j]` +-- -/ +-- abbrev clearLowBitsAfter (x : BitVec w) (j : Nat) : BitVec w := (x >>> j) <<< j + +-- @[simp] +-- theorem getLsb_clearLowBitsAfter {x : BitVec w} {j i : Nat} : +-- (clearLowBitsAfter x j).getLsb i = if j ≤ i then x.getLsb i else false := by +-- unfold clearLowBitsAfter +-- simp +-- by_cases hij : i < j +-- · simp only [hij, decide_True, Bool.not_true, Bool.and_false, Bool.false_and, false_eq, +-- and_eq_false_imp, decide_eq_true_eq] +-- intros hcontra +-- omega +-- · simp [hij, show j + (i - j) = i by omega, show j ≤ i by omega] +-- apply lt_of_getLsb + +-- /-- +-- A bitvector can be broken down into the low bits (by truncate) and the high +-- bits (by left shift followed by right shift). +-- -/ +-- theorem BitVec.zeroExtend_truncate_or_shiftRight_shiftLeft_eq_self {x : BitVec w} {i : Nat} : +-- (x.truncate i).zeroExtend w ||| (x >>> i) <<< i = x := by +-- ext j +-- by_cases h : j < i +-- · simp [h] +-- · simp only [getLsb_or, getLsb_zeroExtend, Fin.is_lt, decide_True, h, decide_False, +-- Bool.false_and, Bool.and_false, getLsb_shiftLeft, Bool.not_false, Bool.and_self, +-- getLsb_ushiftRight, Bool.true_and, Bool.false_or] +-- congr +-- omega + +-- /-- Key relationship that establishes the loop invariant after one iteration. -/ +-- theorem DivRecQuotRem.rec_lawful_of_lawful +-- {qr : DivRecQuotRem w n d} +-- (hqr : qr.Lawful) +-- {hq : ∀ {i : Nat}, i ≤ j + 1 → qr.q.getLsb i = false} +-- {hd : d ≤ qr.r >>> (j + 1)} : +-- { r := qr.r >>> (j + 1) - d, q := qr.q ||| twoPow w (j + 1) : DivRecQuotRem w n d }.Lawful := by +-- simp only [DivRecQuotRem.Lawful] +-- rw [← BitVec.add_eq_or_of_and_eq_zero] +-- · simp +-- /- +-- w : Nat +-- n d : BitVec w +-- j : Nat +-- qr : DivRecQuotRem w n d +-- hqr : qr.Lawful +-- hq : ∀ {i : Nat}, i ≤ j + 1 → qr.q.getLsb i = false +-- hd : d ≤ qr.r >>> (j + 1) +-- ⊢ d.toNat * ((qr.q.toNat + 2 ^ (j + 1)) % 2 ^ w) + (qr.r.toNat >>> (j + 1) + (2 ^ w - d.toNat)) % 2 ^ w = n.toNat +-- -/ +-- sorry +-- · simp +-- specialize hq (i := j + 1) (by omega) +-- rw [hq] +-- simp +-- -- d * j + q = n. +-- theorem divRec_lawful {qr : DivRecQuotRem w n d} {j : Nat} +-- (hqr : qr.Lawful) +-- -- | We start the reucrrence at j=w-1, so all bits are zero. +-- -- | At one step of j=w-2, only the low bit maybe set, and all +-- -- bits at [1..w) are zero. +-- -- We end the recurrence at j=0, when no bit is forced to be zero. +-- -- | the quotient is correct, if the quotient is zero for all bits +-- -- | in the range `(j..0]`. +-- (hq : ∀ {i : Nat} (hi : i ≤ j), qr.q.getLsb i = false) : +-- (divRec qr j).Lawful := by +-- induction j generalizing qr +-- case zero => +-- simp [DivRecQuotRem.Lawful] +-- by_cases h : d ≤ qr.r >>> 0 +-- · simp [divremi_eq_of_le h] +-- simp at h +-- rcases w with rfl | w +-- · simp -- get rid of corner case with 1 % 2^0 = 1 % 1 = 0 +-- · have h1 : 1 % 2^(w + 1) = 1 := by +-- rw [Nat.mod_eq_of_lt] +-- apply Nat.one_lt_two_pow (by omega) +-- simp [h1] +-- specialize (hq (i := 0) (by omega)) +-- have hqr_q_or_1_to_Nat : (qr.q.toNat ||| 1) = (qr.q ||| 1).toNat := by +-- simp +-- rw [Nat.mod_eq_of_lt] +-- repeat omega +-- rw [hqr_q_or_1_to_Nat] +-- have hqr_or_1_eq_hqr_add_1 : (qr.q ||| 1) = (qr.q + twoPow (w+1) 0) := by +-- rw [add_eq_or_of_and_eq_zero] +-- · simp +-- · rw [and_twoPow_eq_getLsb, hq] +-- simp +-- rw [hqr_or_1_eq_hqr_add_1] +-- simp -- here we get a % 2^(w + 1) that we wish to avoid +-- calc +-- d.toNat * ((qr.q.toNat + 1) % 2 ^ (w + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) = d.toNat * ((qr.q.toNat + 1) % 2 ^ (w + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) := by rfl +-- _ = d.toNat * ((qr.q.toNat + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) := by +-- rw [Nat.mod_eq_of_lt] +-- simp [getLsb] at hq +-- omega +-- _ = d.toNat * (qr.q.toNat + 1) + (qr.r.toNat - d.toNat + (2 ^ (w + 1))) % 2 ^ (w + 1) := by +-- congr 2 +-- /- Note: omega needs this. Is the preprocessor supposed to pick this up? -/ +-- have h' : d.toNat ≤ qr.r.toNat := by simp [BitVec.le_def] at h; omega +-- omega +-- _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat + (2 ^ (w + 1))) % 2 ^ (w + 1) := by simp [Nat.mul_add] +-- _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) % 2^(w + 1) + (2^(w+1) % 2 ^ (w + 1)) := by simp [Nat.add_mod] +-- _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) % 2^(w + 1) + 0 := by simp +-- _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) % 2^(w + 1) := by simp +-- _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) := by rw [Nat.mod_eq_of_lt]; omega +-- _ = d.toNat * qr.q.toNat + (d.toNat - d.toNat + qr.r.toNat) := by +-- have h' : d.toNat ≤ qr.r.toNat := by simp [BitVec.le_def] at h; omega +-- omega +-- _ = d.toNat * qr.q.toNat + qr.r.toNat := by simp +-- _ = n.toNat := hqr +-- · simp [divremi_eq_of_not_le h] +-- exact hqr +-- case succ j ih => +-- simp +-- -- TODO: split below into a separate lemma +-- by_cases h : d ≤ qr.r >>> (j + 1) +-- · simp [divremi_eq_of_le h] +-- simp at h +-- apply ih +-- /- +-- case pos.hqr +-- w : Nat +-- n d : BitVec w +-- j : Nat +-- ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → (∀ {i : Nat}, i ≤ j → qr.q.getLsb i = false) → (divRec qr j).Lawful +-- qr : DivRecQuotRem w n d +-- hqr : qr.Lawful +-- hq : ∀ {i : Nat}, i ≤ j + 1 → qr.q.getLsb i = false +-- h : d ≤ qr.r >>> (j + 1) +-- ⊢ { r := qr.r >>> (j + 1) - d, q := qr.q ||| twoPow w (j + 1) }.Lawful +-- -/ +-- · sorry +-- · intros i hi +-- simp +-- constructor +-- · apply hq +-- omega +-- · intros h +-- omega +-- · simp [divremi_eq_of_not_le h] +-- simp at h +-- apply ih hqr +-- intros i hi +-- apply hq +-- omega + +-- #print axioms divRec_lawful + +-- theorem divRec_remainder_inbounds {qr : DivRecQuotRem w n d} {j : Nat} +-- (hqr : qr.Lawful) (hr : qr.r.toNat < 2^j) (hd : 0 < d): +-- (divRec qr j).r < d := by +-- induction j generalizing qr +-- case zero => +-- simp +-- by_cases h : d <<< 0 ≤ qr.r +-- · simp at h +-- simp at hr +-- simp [lt_def] +-- simp [le_def] at h +-- have hd' : d.toNat = 0 := by omega +-- simp [lt_def] at hd +-- omega +-- · simp at h +-- simp [lt_def] at hr +-- simp [lt_def] +-- simp [lt_def] at hd +-- simp [divremi, h] +-- omega +-- case succ j ih => +-- simp +-- by_cases h : d ≤ qr.r >>> (j + 1) +-- · simp [divremi_eq_of_le h] +-- simp at h +-- have hqr' : qr.r >>> (j + 1) = 0#w := by +-- apply BitVec.eq_of_toNat_eq +-- simp only [toNat_ushiftRight, toNat_ofNat, zero_mod] +-- rw [Nat.shiftRight_eq_div_pow] +-- rw[Nat.div_eq_of_lt (by omega)] +-- rw [hqr'] +-- rw [hqr'] at h +-- have hdcontra : d = 0#w := by +-- simp only [le_def, toNat_ofNat, zero_mod, le_zero_eq] at h +-- apply BitVec.eq_of_toNat_eq +-- simp [h] +-- simp [hdcontra] at hd +-- · simp [divremi_eq_of_not_le h] +-- apply ih hqr +-- simp [BitVec.le_def] at h +-- sorry -@[simp] -theorem getLsb_clearLowBitsAfter {x : BitVec w} {j i : Nat} : - (clearLowBitsAfter x j).getLsb i = if j ≤ i then x.getLsb i else false := by - unfold clearLowBitsAfter - simp - by_cases hij : i < j - · simp only [hij, decide_True, Bool.not_true, Bool.and_false, Bool.false_and, false_eq, - and_eq_false_imp, decide_eq_true_eq] - intros hcontra - omega - · simp [hij, show j + (i - j) = i by omega, show j ≤ i by omega] - apply lt_of_getLsb - -/-- -A bitvector can be broken down into the low bits (by truncate) and the high -bits (by left shift followed by right shift). --/ -theorem BitVec.zeroExtend_truncate_or_shiftRight_shiftLeft_eq_self {x : BitVec w} {i : Nat} : - (x.truncate i).zeroExtend w ||| (x >>> i) <<< i = x := by - ext j - by_cases h : j < i - · simp [h] - · simp only [getLsb_or, getLsb_zeroExtend, Fin.is_lt, decide_True, h, decide_False, - Bool.false_and, Bool.and_false, getLsb_shiftLeft, Bool.not_false, Bool.and_self, - getLsb_ushiftRight, Bool.true_and, Bool.false_or] - congr - omega - -/-- Key relationship that establishes the loop invariant after one iteration. -/ -theorem DivRecQuotRem.rec_lawful_of_lawful - {qr : DivRecQuotRem w n d} - (hqr : qr.Lawful) - {hq : ∀ {i : Nat}, i ≤ j + 1 → qr.q.getLsb i = false} - {hd : d ≤ qr.r >>> (j + 1)} : - { r := qr.r >>> (j + 1) - d, q := qr.q ||| twoPow w (j + 1) : DivRecQuotRem w n d }.Lawful := by - simp only [DivRecQuotRem.Lawful] - rw [← BitVec.add_eq_or_of_and_eq_zero] - · simp - /- - w : Nat - n d : BitVec w - j : Nat - qr : DivRecQuotRem w n d - hqr : qr.Lawful - hq : ∀ {i : Nat}, i ≤ j + 1 → qr.q.getLsb i = false - hd : d ≤ qr.r >>> (j + 1) - ⊢ d.toNat * ((qr.q.toNat + 2 ^ (j + 1)) % 2 ^ w) + (qr.r.toNat >>> (j + 1) + (2 ^ w - d.toNat)) % 2 ^ w = n.toNat - -/ - sorry - · simp - specialize hq (i := j + 1) (by omega) - rw [hq] - simp --- d * j + q = n. -theorem divRec_lawful {qr : DivRecQuotRem w n d} {j : Nat} - (hqr : qr.Lawful) - -- | We start the reucrrence at j=w-1, so all bits are zero. - -- | At one step of j=w-2, only the low bit maybe set, and all - -- bits at [1..w) are zero. - -- We end the recurrence at j=0, when no bit is forced to be zero. - -- | the quotient is correct, if the quotient is zero for all bits - -- | in the range `(j..0]`. - (hq : ∀ {i : Nat} (hi : i ≤ j), qr.q.getLsb i = false) : - (divRec qr j).Lawful := by - induction j generalizing qr - case zero => - simp [DivRecQuotRem.Lawful] - by_cases h : d ≤ qr.r >>> 0 - · simp [divremi_eq_of_le h] - simp at h - rcases w with rfl | w - · simp -- get rid of corner case with 1 % 2^0 = 1 % 1 = 0 - · have h1 : 1 % 2^(w + 1) = 1 := by - rw [Nat.mod_eq_of_lt] - apply Nat.one_lt_two_pow (by omega) - simp [h1] - specialize (hq (i := 0) (by omega)) - have hqr_q_or_1_to_Nat : (qr.q.toNat ||| 1) = (qr.q ||| 1).toNat := by - simp - rw [Nat.mod_eq_of_lt] - repeat omega - rw [hqr_q_or_1_to_Nat] - have hqr_or_1_eq_hqr_add_1 : (qr.q ||| 1) = (qr.q + twoPow (w+1) 0) := by - rw [add_eq_or_of_and_eq_zero] - · simp - · rw [and_twoPow_eq_getLsb, hq] - simp - rw [hqr_or_1_eq_hqr_add_1] - simp -- here we get a % 2^(w + 1) that we wish to avoid - calc - d.toNat * ((qr.q.toNat + 1) % 2 ^ (w + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) = d.toNat * ((qr.q.toNat + 1) % 2 ^ (w + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) := by rfl - _ = d.toNat * ((qr.q.toNat + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) := by - rw [Nat.mod_eq_of_lt] - simp [getLsb] at hq - omega - _ = d.toNat * (qr.q.toNat + 1) + (qr.r.toNat - d.toNat + (2 ^ (w + 1))) % 2 ^ (w + 1) := by - congr 2 - /- Note: omega needs this. Is the preprocessor supposed to pick this up? -/ - have h' : d.toNat ≤ qr.r.toNat := by simp [BitVec.le_def] at h; omega - omega - _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat + (2 ^ (w + 1))) % 2 ^ (w + 1) := by simp [Nat.mul_add] - _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) % 2^(w + 1) + (2^(w+1) % 2 ^ (w + 1)) := by simp [Nat.add_mod] - _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) % 2^(w + 1) + 0 := by simp - _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) % 2^(w + 1) := by simp - _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) := by rw [Nat.mod_eq_of_lt]; omega - _ = d.toNat * qr.q.toNat + (d.toNat - d.toNat + qr.r.toNat) := by - have h' : d.toNat ≤ qr.r.toNat := by simp [BitVec.le_def] at h; omega - omega - _ = d.toNat * qr.q.toNat + qr.r.toNat := by simp - _ = n.toNat := hqr - · simp [divremi_eq_of_not_le h] - exact hqr - case succ j ih => - simp - -- TODO: split below into a separate lemma - by_cases h : d ≤ qr.r >>> (j + 1) - · simp [divremi_eq_of_le h] - simp at h - apply ih - /- - case pos.hqr - w : Nat - n d : BitVec w - j : Nat - ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → (∀ {i : Nat}, i ≤ j → qr.q.getLsb i = false) → (divRec qr j).Lawful - qr : DivRecQuotRem w n d - hqr : qr.Lawful - hq : ∀ {i : Nat}, i ≤ j + 1 → qr.q.getLsb i = false - h : d ≤ qr.r >>> (j + 1) - ⊢ { r := qr.r >>> (j + 1) - d, q := qr.q ||| twoPow w (j + 1) }.Lawful - -/ - · sorry - · intros i hi - simp - constructor - · apply hq - omega - · intros h - omega - · simp [divremi_eq_of_not_le h] - simp at h - apply ih hqr - intros i hi - apply hq - omega - -#print axioms divRec_lawful - -theorem divRec_remainder_inbounds {qr : DivRecQuotRem w n d} {j : Nat} - (hqr : qr.Lawful) (hr : qr.r.toNat < 2^j) (hd : 0 < d): - (divRec qr j).r < d := by - induction j generalizing qr - case zero => - simp - by_cases h : d <<< 0 ≤ qr.r - · simp at h - simp at hr - simp [lt_def] - simp [le_def] at h - have hd' : d.toNat = 0 := by omega - simp [lt_def] at hd - omega - · simp at h - simp [lt_def] at hr - simp [lt_def] - simp [lt_def] at hd - simp [divremi, h] - omega - case succ j ih => - simp - by_cases h : d ≤ qr.r >>> (j + 1) - · simp [divremi_eq_of_le h] - simp at h - have hqr' : qr.r >>> (j + 1) = 0#w := by - apply BitVec.eq_of_toNat_eq - simp only [toNat_ushiftRight, toNat_ofNat, zero_mod] - rw [Nat.shiftRight_eq_div_pow] - rw[Nat.div_eq_of_lt (by omega)] - rw [hqr'] - rw [hqr'] at h - have hdcontra : d = 0#w := by - simp only [le_def, toNat_ofNat, zero_mod, le_zero_eq] at h - apply BitVec.eq_of_toNat_eq - simp [h] - simp [hdcontra] at hd - · simp [divremi_eq_of_not_le h] - apply ih hqr - simp [BitVec.le_def] at h - sorry - -theorem divRec_eq_udiv_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd : 0 < d) - {hqr' : qr' = (divRec qr w)} : - qr'.q = udiv n d := by - have hlawful : qr'.Lawful := by - simp [hqr', divRec_lawful hqr] - have hremainder : qr'.r < d := by - simp [hqr'] - apply divRec_remainder_inbounds hqr - apply qr.r.isLt - assumption - have this := div_characterized_of_mul_add_toNat - (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful - simp [this.1] +-- theorem divRec_eq_udiv_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd : 0 < d) +-- {hqr' : qr' = (divRec qr w)} : +-- qr'.q = udiv n d := by +-- have hlawful : qr'.Lawful := by +-- simp [hqr', divRec_lawful hqr] +-- have hremainder : qr'.r < d := by +-- simp [hqr'] +-- apply divRec_remainder_inbounds hqr +-- apply qr.r.isLt +-- assumption +-- have this := div_characterized_of_mul_add_toNat +-- (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful +-- simp [this.1] -- /- -- case neg.hr From cc6817d8e0f981edd17ae2a9f52cbd127ce37b56 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Tue, 2 Jul 2024 15:11:28 +0100 Subject: [PATCH 41/64] chore: add division invariant --- src/Init/Data/BitVec/div_invariant.py | 40 +++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 src/Init/Data/BitVec/div_invariant.py diff --git a/src/Init/Data/BitVec/div_invariant.py b/src/Init/Data/BitVec/div_invariant.py new file mode 100644 index 000000000000..de2cb09e99e4 --- /dev/null +++ b/src/Init/Data/BitVec/div_invariant.py @@ -0,0 +1,40 @@ +def get_lsb(n, j): + return int(bool(n & (1 << j))) + +def print_bits(w, n): + return ("{0:0%sb}" % (w)).format(n) + +def shift_subtract(w, n, d, q, r, j): + print(f"shift_subtract> n: '%s' | d: '%s' | q : '%s' | r : '%s' | j : '%s'" % + (print_bits(w, n), print_bits(w, d), print_bits(w, q), print_bits(w, r), j)) + print(f" j[%s] = %s" % (j, get_lsb(n, j))) + r = (r << 1) | get_lsb(n, j) + print(f" r = %s" % print_bits(w, r)) + if r >= d: + print(f" r > d.") + r -= d + q = (q << 1) | 1 + print(f" r.new = %s" % print_bits(w, r)) + print(f" q.new = %s" % print_bits(w, q)) + else: + print(f" r < d.") + q = (q << 1) + print(f" r.new = %s" % print_bits(w, r)) + print(f" q = %s" % print_bits(w, q)) + if j == 0: + return (q, r) + else: + return shift_subtract(w, n, d, q, r, j-1) + +# 10 / 3 = 3 +for n in range(1, 10): + for d in range(1, 10): + w = 4 + (q, r) = shift_subtract(w, n, d, 0, 0, w) + assert n == d * q + r + if n == d * q + r: + print ("verified correct invariant for n: '%s' | d : '%s' | q : '%s' r: '%s'" % + (n, d, q, r)) + else: + raise RuntimeError("verification failed for n: '%s' | d: '%s'" % (n, d)) + From 89f6087bcc6dca27b25ddddf773ae7a26cc54db0 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Wed, 3 Jul 2024 07:42:38 +0100 Subject: [PATCH 42/64] chore: add div invariant --- src/Init/Data/BitVec/Bitblast.lean | 15 +++++++++------ src/Init/Data/BitVec/div_invariant.py | 27 ++++++++++++++++++++++++--- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index d6e18891ce8e..81e75c42b7f6 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -726,7 +726,7 @@ theorem div_iff_add_mod_of_lt {d n q r : BitVec w} (hd : 0 < d) -- 0 0000 -- insert n.msb [0] -- 0010 0 -- subtract d, not successful -- 0000 -- result [unchanged] --- 0000 -- shift +-- 0000 -- shift left -- -- 1 0001 -- insert n.msb [1] -- 0010 0 -- subtract d, not successful @@ -750,18 +750,21 @@ structure DivRecQuotRem (w : Nat) (n : BitVec w) (d : BitVec w) where q : BitVec w deriving DecidableEq, Repr +theorem invariant_qr (r : Nat) (hr : r < d) : 2 * r + 1 - d < d := by + omega + def divRec (qr : DivRecQuotRem w n d) (j : Nat) : DivRecQuotRem w n d := let rj := (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb j)).zeroExtend w - let qr' := if rj ≤ d - then { r := rj, q := qr.q <<< 1} - else {r := rj - d, q := qr.q <<< 1 ||| 1 } + let qr' := + if rj ≤ d + then { r := rj, q := qr.q <<< 1 } + else { r := rj - d, q := qr.q <<< 1 ||| 1 } match j with | 0 => qr' | j + 1 => divRec qr' j -- invariants: --- 1) r < d. --- 2) +-- 1) qr.r < d. theorem div_rec_7_2 : (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 3) = { r := 1, q := 3 } := by diff --git a/src/Init/Data/BitVec/div_invariant.py b/src/Init/Data/BitVec/div_invariant.py index de2cb09e99e4..6c960b8e90a4 100644 --- a/src/Init/Data/BitVec/div_invariant.py +++ b/src/Init/Data/BitVec/div_invariant.py @@ -4,7 +4,26 @@ def get_lsb(n, j): def print_bits(w, n): return ("{0:0%sb}" % (w)).format(n) +def check_pre_invariant(w, n, d, q, r, j): + assert r < d + qright = n // d + rright = n % d + + assert n >> (j + 1) == d * q + r + assert q == qright >> (j + 1) + assert r == ((n >> j) - (d * (qright >> j))) + pass + +def check_post_invariant(w, n, d, q, r, j): + qright = n // d + rright = n % d + assert r < d + assert q == qright >> j + assert n >> j == d * q + r + assert r == ((n >> j) - (d * (qright >> j))) + def shift_subtract(w, n, d, q, r, j): + check_pre_invariant(w, n, d, q, r, j) print(f"shift_subtract> n: '%s' | d: '%s' | q : '%s' | r : '%s' | j : '%s'" % (print_bits(w, n), print_bits(w, d), print_bits(w, q), print_bits(w, r), j)) print(f" j[%s] = %s" % (j, get_lsb(n, j))) @@ -22,15 +41,17 @@ def shift_subtract(w, n, d, q, r, j): print(f" r.new = %s" % print_bits(w, r)) print(f" q = %s" % print_bits(w, q)) if j == 0: - return (q, r) + (qout, rout) = (q, r) else: - return shift_subtract(w, n, d, q, r, j-1) + (qout, rout) = shift_subtract(w, n, d, q, r, j-1) + check_post_invariant(w, n, d, q, r, j) + return (qout, rout) # 10 / 3 = 3 for n in range(1, 10): for d in range(1, 10): w = 4 - (q, r) = shift_subtract(w, n, d, 0, 0, w) + (q, r) = shift_subtract(w, n, d, 0, 0, w-1) assert n == d * q + r if n == d * q + r: print ("verified correct invariant for n: '%s' | d : '%s' | q : '%s' r: '%s'" % From 44127be7ee6c60a158a156c5d80e11f508645c0e Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 4 Jul 2024 00:26:21 +0100 Subject: [PATCH 43/64] chore: add one impl of div_invariant --- src/Init/Data/BitVec/div_invariant.py | 32 +++++++++++++-------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/Init/Data/BitVec/div_invariant.py b/src/Init/Data/BitVec/div_invariant.py index 6c960b8e90a4..e34d1ed2a7ce 100644 --- a/src/Init/Data/BitVec/div_invariant.py +++ b/src/Init/Data/BitVec/div_invariant.py @@ -5,28 +5,26 @@ def print_bits(w, n): return ("{0:0%sb}" % (w)).format(n) def check_pre_invariant(w, n, d, q, r, j): - assert r < d qright = n // d rright = n % d + assert r < d - assert n >> (j + 1) == d * q + r - assert q == qright >> (j + 1) - assert r == ((n >> j) - (d * (qright >> j))) - pass - +# n / d <-> n = q * d + r def check_post_invariant(w, n, d, q, r, j): qright = n // d rright = n % d assert r < d - assert q == qright >> j - assert n >> j == d * q + r - assert r == ((n >> j) - (d * (qright >> j))) + nhigh = n >> j + print(" n >> j = %s | q(%s) * d(%s) + r(%s) = (%s)" % + (print_bits(w, nhigh), print_bits(w, q), print_bits(w, d), print_bits(w, r), print_bits(w, q * d + r))) + assert nhigh == d * q + r def shift_subtract(w, n, d, q, r, j): - check_pre_invariant(w, n, d, q, r, j) print(f"shift_subtract> n: '%s' | d: '%s' | q : '%s' | r : '%s' | j : '%s'" % (print_bits(w, n), print_bits(w, d), print_bits(w, q), print_bits(w, r), j)) - print(f" j[%s] = %s" % (j, get_lsb(n, j))) + print(f" n[%s] = %s" % (j, get_lsb(n, j))) + check_pre_invariant(w, n, d, q, r, j) + r = (r << 1) | get_lsb(n, j) print(f" r = %s" % print_bits(w, r)) if r >= d: @@ -40,22 +38,24 @@ def shift_subtract(w, n, d, q, r, j): q = (q << 1) print(f" r.new = %s" % print_bits(w, r)) print(f" q = %s" % print_bits(w, q)) + check_post_invariant(w, n, d, q, r, j) if j == 0: (qout, rout) = (q, r) else: (qout, rout) = shift_subtract(w, n, d, q, r, j-1) - check_post_invariant(w, n, d, q, r, j) return (qout, rout) # 10 / 3 = 3 -for n in range(1, 10): - for d in range(1, 10): - w = 4 +for n in range(1, 32): + for d in range(1, 32): + w = 6 (q, r) = shift_subtract(w, n, d, 0, 0, w-1) assert n == d * q + r - if n == d * q + r: + if n == d * q + r and r < d: print ("verified correct invariant for n: '%s' | d : '%s' | q : '%s' r: '%s'" % (n, d, q, r)) else: raise RuntimeError("verification failed for n: '%s' | d: '%s'" % (n, d)) + + From a2c357603219185ffd9cbc26f121c8ec036f3d4d Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 4 Jul 2024 09:00:36 +0100 Subject: [PATCH 44/64] chore: add new division invariant that should be easier to prove --- src/Init/Data/BitVec/div_invariant.py | 1 - src/Init/Data/BitVec/div_new_invariant.py | 68 +++++++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100755 src/Init/Data/BitVec/div_new_invariant.py diff --git a/src/Init/Data/BitVec/div_invariant.py b/src/Init/Data/BitVec/div_invariant.py index e34d1ed2a7ce..999eb3f56ac3 100644 --- a/src/Init/Data/BitVec/div_invariant.py +++ b/src/Init/Data/BitVec/div_invariant.py @@ -58,4 +58,3 @@ def shift_subtract(w, n, d, q, r, j): raise RuntimeError("verification failed for n: '%s' | d: '%s'" % (n, d)) - diff --git a/src/Init/Data/BitVec/div_new_invariant.py b/src/Init/Data/BitVec/div_new_invariant.py new file mode 100755 index 000000000000..ff88dae5b9bd --- /dev/null +++ b/src/Init/Data/BitVec/div_new_invariant.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 + +def get_lsb(n, j): + return int(bool(n & (1 << j))) + +def print_bits(w, n): + return ("{0:0%sb}" % (w)).format(n) + +def check_pre_invariant(w, n, d, q, r, j): + qright = n // d + rright = n % d + assert r < d + +# n / d <-> n = q * d + r +def check_post_rec_invariant(w, n, d, q, r, j): + qright = n // d + rright = n % d + assert r < d + assert n >> (w - j) == d * q + r + +# n / d <-> n = q * d + r +def check_final_invariant(w, n, d, q, r, j): + qright = n // d + rright = n % d + assert r < d + assert n >> ((w - 1) - j) == d * q + r + +def shift_subtract(w, n, d, q, r, j): + print(f"shift_subtract> n: '%s' | d: '%s' | q : '%s' | r : '%s' | j : '%s'" % + (print_bits(w, n), print_bits(w, d), print_bits(w, q), print_bits(w, r), j)) + print(f" n[%s] = %s" % (j, get_lsb(n, j))) + check_pre_invariant(w, n, d, q, r, j) + if j > 0: + (q, r) = shift_subtract(w, n, d, q, r, j-1) + check_post_rec_invariant(w, n, d, q, r, j) + + # do the last bit. + ix = (w - 1) - j + assert ix >= 0 + r = (r << 1) | get_lsb(n, ix) + print(f" r = %s" % print_bits(w, r)) + if r >= d: + print(f" r > d.") + r -= d + q = (q << 1) | 1 + print(f" r.new = %s" % print_bits(w, r)) + print(f" q.new = %s" % print_bits(w, q)) + else: + print(f" r < d.") + q = (q << 1) + print(f" r.new = %s" % print_bits(w, r)) + print(f" q = %s" % print_bits(w, q)) + check_final_invariant(w, n, d, q, r, j) + return (q, r) + +# 10 / 3 = 3 +for n in range(1, 32): + for d in range(1, 32): + w = 6 + (q, r) = shift_subtract(w, n, d, 0, 0, w-1) + assert n == d * q + r + if n == d * q + r and r < d: + print ("verified correct invariant for n: '%s' | d : '%s' | q : '%s' r: '%s'" % + (n, d, q, r)) + else: + raise RuntimeError("verification failed for n: '%s' | d: '%s'" % (n, d)) + + From 56ea083f2ab0a48d3fe5ee4d4221a761aaae8ef1 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 4 Jul 2024 10:30:36 +0100 Subject: [PATCH 45/64] chore: prove remiander property, still a sorry left --- src/Init/Data/BitVec/Bitblast.lean | 140 +++++++++++++++++++++++++---- 1 file changed, 124 insertions(+), 16 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 81e75c42b7f6..472168f90910 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -750,32 +750,140 @@ structure DivRecQuotRem (w : Nat) (n : BitVec w) (d : BitVec w) where q : BitVec w deriving DecidableEq, Repr -theorem invariant_qr (r : Nat) (hr : r < d) : 2 * r + 1 - d < d := by + +theorem BitVec.shiftLeft_eq_mul_twoPow (x : BitVec w) (n : Nat) : + x <<< n = x * (BitVec.twoPow w n) := by + ext i + simp + +/-- One round of the division algorithm, that tries to perform a subtract shift. -/ +def tryDivSubtractShift (qr : DivRecQuotRem w n d) (ix : Nat) : DivRecQuotRem w n d := + let r' := (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb ix)).zeroExtend w + if r' < d + then { r := r', q := qr.q <<< 1 } + else { + r := r' - d, + q := qr.q <<< 1 ||| 1 + } + +/- Surely this exists somewhere, I remember proving this even -/ +theorem Nat.sub_mod_self_eq_sub {x n : Nat} (hx₀ : 0 < x := by omega) (hxn : x < n := by omega) : (n - x) % n = n - x := by + rw [Nat.mod_eq_of_lt] omega -def divRec (qr : DivRecQuotRem w n d) (j : Nat) : DivRecQuotRem w n d := - let rj := (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb j)).zeroExtend w +theorem tryDivSubtractShift_lt_of_lt {qr : DivRecQuotRem w n d} {ix : Nat} (hrlt : qr.r < d) : + (tryDivSubtractShift qr ix).r < d := by + simp only [tryDivSubtractShift, ofNat_eq_ofNat] + generalize hr₂ : qr.r <<< 1 ||| zeroExtend w (ofBool (n.getLsb ix)) = r₂ + by_cases hr₂lt : r₂ < d + · simp [hr₂lt] + · simp [hr₂lt] + rw [← BitVec.add_eq_or_of_and_eq_zero] at hr₂ + rw [BitVec.shiftLeft_eq_mul_twoPow] at hr₂ + · simp only [BitVec.lt_def] at hr₂ hr₂lt ⊢ + simp only [toNat_sub, toNat_add, toNat_shiftLeft, toNat_truncate, + toNat_ofBool, add_mod_mod, mod_add_mod, toNat_mul] at hr₂ hr₂lt ⊢ + -- simp only [toNat_twoPow, Nat.pow_one] at hr₂ + rcases w with rfl | rfl | w + · have hr : qr.r = 0#0 := by apply Subsingleton.elim + have hd : d = 0#0 := by apply Subsingleton.elim + rw [hr, hd] at hrlt + simp at hrlt -- TODO: golf this with simpa, ask alex. + · simp only [Nat.reduceAdd, Nat.zero_add, Nat.pow_one, mod_self, Nat.mul_zero] at hrlt hr₂ ⊢ + simp only [Nat.reduceAdd, lt_def] at hrlt hr₂ + simp only [Nat.reduceAdd, zeroExtend_eq, lt_def, toNat_or, toNat_shiftLeft, Nat.pow_one, + toNat_ofBool, Nat.not_lt] at hr₂ + have hd : d.toNat < 2 := d.isLt + generalize hb : (n.getLsb ix) = b + rw [hb] at hr₂ + replace hd : d.toNat = 0 ∨ d.toNat = 1 := by omega; + rcases hd with hd | hd + · omega -- d ≠ 0 + · rw [hd] at hr₂lt hrlt + rcases b with rfl | rfl + · replace hrlt : qr.r.toNat = 0 := by omega + rw [← hr₂] at hr₂lt + simp at hr₂lt + rw [hrlt] at hr₂lt + simp at hr₂lt + · simp; omega + · have hr₂lt₂ : r₂.toNat - d.toNat < d.toNat := by sorry + calc + _ = (r₂.toNat + (2 ^ (w + 1 + 1) - d.toNat)) % 2 ^ (w + 1 + 1) := by rfl + _ = ((r₂.toNat + (2 ^ (w + 1 + 1)) - d.toNat)) % 2 ^ (w + 1 + 1) := by + rw [Nat.add_sub_assoc] + have := d.isLt + omega + _ = (((2 ^ (w + 1 + 1) + r₂.toNat) - d.toNat)) % 2 ^ (w + 1 + 1) := by + rw [Nat.add_comm] + _ = (2 ^ (w + 1 + 1) + (r₂.toNat - d.toNat)) % 2 ^ (w + 1 + 1) := by + congr 1 + rw [Nat.add_sub_assoc] + omega + _ = ((2 ^ (w + 1 + 1) % 2 ^ (w + 1 + 1)) + ((r₂.toNat - d.toNat) % 2 ^ (w + 1 + 1))) % (2 ^ (w + 1 + 1)) := by + rw [Nat.add_mod] + _ = (r₂.toNat - d.toNat) % 2 ^ (w + 1 + 1) := by + simp + _ = (r₂.toNat - d.toNat) := by + rw [Nat.mod_eq_of_lt] + omega + _ < d.toNat := by omega + · ext i + simp + intros hi _ hi' + omega + +/-- repeatedly apply `tryDivSubtractShift`. -/ +def divRec (qr : DivRecQuotRem w n d) (j : Nat) : + DivRecQuotRem w n d := let qr' := - if rj ≤ d - then { r := rj, q := qr.q <<< 1 } - else { r := rj - d, q := qr.q <<< 1 ||| 1 } - match j with - | 0 => qr' - | j + 1 => divRec qr' j + match j with + | 0 => qr + | j + 1 => divRec qr j + tryDivSubtractShift qr' (w - 1 - j) + + def checkDivRec : Bool × Array String := Id.run do + let w := 4 + let max := (Nat.pow 2 w) + let mut outputs := #[] + let mut wrong := false + for n in (List.range max) do + for d in (List.range (max - 1)).map (fun n => Nat.add n 1) do + have hd : d > 0 := by sorry + let qr := divRec (w := w) (n := n) (d := d) { r := 0, q := 0, hr := by sorry } (w - 1) + if qr.q * d + qr.r != n then + outputs := outputs.push s!"ERROR: n = {n}, d = {d}, q = {qr.q}, r = {qr.r}, n = {n}, d = {d}, q = {qr.q}, r = {qr.r}" + wrong := true + (wrong, outputs) + + +theorem divRec_postcondition {w : Nat} {n d : BitVec w} (qr : DivRecQuotRem w n d) (j : Nat) (hj : j ≤ w - 1) : + let qr' := divRec qr j + (qr'.q * d + qr'.r) = n >>> ((w - 1) - j) := by + induction j generalizing qr + · sorry + · sorry + +/-- info: (false, { data := [] }) -/ +#guard_msgs in #reduce checkDivRec + +-- theorem divRec_n (qr : DivRecQuotRem w n d) : +-- d * (tryDivSubtractShift qr j).q + (tryDivSubtractShift q j).r = n >>> (j + 1) -- invariants: -- 1) qr.r < d. theorem div_rec_7_2 : - (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 3) = - { r := 1, q := 3 } := by - simp [divRec] + (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0, hr := by sorry } 3) = + { r := 1, q := 3, hr := by sorry } := by + simp [divRec, tryDivSubtractShift] -- invariant 2 -- n.toNat % 2^j = d.toNat * q.toNat + r.toNat -#reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 3) -- r = 1, q = 3 -#reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 2) -- r = 1, q = 3 -#reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 1) -- r = 1, q = 1 -#reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 0) -- r = 1, q = 0 +-- set_option maxHeartbeats 99999 in +-- #reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 3) -- r = 1, q = 3 +-- #reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 2) -- r = 1, q = 3 +-- #reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 1) -- r = 1, q = 1 +-- #reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 0) -- r = 1, q = 0 -- /- Given d, R(j + 1), (calculate R(j), q.getLsb j). -/ -- def divremi (qr : DivRecQuotRem w n d) (j : Nat) : BitVec w × Bool := From 9ca69d491fdcd4cc3c2b44f4363bfd8516fb51e9 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 4 Jul 2024 11:09:54 +0100 Subject: [PATCH 46/64] chore: add new invariant code. This is more subtle than I thought for the second time --- src/Init/Data/BitVec/Bitblast.lean | 72 +++++++++++++++++------ src/Init/Data/BitVec/div_new_invariant.py | 12 ++++ 2 files changed, 67 insertions(+), 17 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 472168f90910..5ae8c79b5135 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -756,9 +756,15 @@ theorem BitVec.shiftLeft_eq_mul_twoPow (x : BitVec w) (n : Nat) : ext i simp +theorem foo (a b : Nat) (hr : r * 2 + 1 < d) : r < (d - 1) / 2 := by + sorry + /-- One round of the division algorithm, that tries to perform a subtract shift. -/ def tryDivSubtractShift (qr : DivRecQuotRem w n d) (ix : Nat) : DivRecQuotRem w n d := let r' := (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb ix)).zeroExtend w + -- r * 2 + 1 < d. -- Why does this not overflow? + -- r < d - 1 + -- r < d / 2 if r' < d then { r := r', q := qr.q <<< 1 } else { @@ -771,7 +777,13 @@ theorem Nat.sub_mod_self_eq_sub {x n : Nat} (hx₀ : 0 < x := by omega) (hxn : x rw [Nat.mod_eq_of_lt] omega -theorem tryDivSubtractShift_lt_of_lt {qr : DivRecQuotRem w n d} {ix : Nat} (hrlt : qr.r < d) : +@[simp] +theorem Bool.toNat_lt (b : Bool) : b.toNat < 2 := by + have h := Bool.toNat_le b + omega + +/-- TODO: This shows that the remainer is always going to be below 'd', and does not overflow. -/ +theorem tryDivSubtractShift_lt_of_lt {qr : DivRecQuotRem w n d} {ix : Nat} (hrlt : qr.r < d) (hrltTwoPow : qr.r.toNat * 2 + 1 < 2 ^ w): (tryDivSubtractShift qr ix).r < d := by simp only [tryDivSubtractShift, ofNat_eq_ofNat] generalize hr₂ : qr.r <<< 1 ||| zeroExtend w (ofBool (n.getLsb ix)) = r₂ @@ -807,7 +819,23 @@ theorem tryDivSubtractShift_lt_of_lt {qr : DivRecQuotRem w n d} {ix : Nat} (hrlt rw [hrlt] at hr₂lt simp at hr₂lt · simp; omega - · have hr₂lt₂ : r₂.toNat - d.toNat < d.toNat := by sorry + · have hr₂lt₂ : r₂.toNat - d.toNat < d.toNat := by + rw [← hr₂] + simp only [mul_twoPow_eq_shiftLeft, toNat_add, toNat_shiftLeft, toNat_truncate, + toNat_ofBool, add_mod_mod, mod_add_mod] + rw [Nat.shiftLeft_eq] + simp only [Nat.pow_one] + have hd : d.toNat < 2^(w + 1 + 1) := d.isLt + have hb : (n.getLsb ix).toNat < 2 := by simp + simp only [lt_def] at hrlt + rw [Nat.mod_eq_of_lt] + · -- r < d [integers] + -- r - 1 <= d + -- 2(r - 2) <= 2d + -- 2r - 2 - d <= d + -- 2r - 1 - d < d + omega + · omega -- here is the use of hrltTwoPow calc _ = (r₂.toNat + (2 ^ (w + 1 + 1) - d.toNat)) % 2 ^ (w + 1 + 1) := by rfl _ = ((r₂.toNat + (2 ^ (w + 1 + 1)) - d.toNat)) % 2 ^ (w + 1 + 1) := by @@ -842,19 +870,29 @@ def divRec (qr : DivRecQuotRem w n d) (j : Nat) : | j + 1 => divRec qr j tryDivSubtractShift qr' (w - 1 - j) - def checkDivRec : Bool × Array String := Id.run do - let w := 4 - let max := (Nat.pow 2 w) - let mut outputs := #[] - let mut wrong := false - for n in (List.range max) do - for d in (List.range (max - 1)).map (fun n => Nat.add n 1) do - have hd : d > 0 := by sorry - let qr := divRec (w := w) (n := n) (d := d) { r := 0, q := 0, hr := by sorry } (w - 1) - if qr.q * d + qr.r != n then - outputs := outputs.push s!"ERROR: n = {n}, d = {d}, q = {qr.q}, r = {qr.r}, n = {n}, d = {d}, q = {qr.q}, r = {qr.r}" - wrong := true - (wrong, outputs) +theorem divRec_remainder_lt_twoPow (qr : DivRecQuotRem w n d) (j : Nat) (hj : j < w) (hr : qr.r < d) (hr₂ : qr.r.toNat * 2 + 1 < 2 ^ w) : + (divRec qr j).r.toNat * 2 + 1 < 2 ^ w := by + induction j generalizing qr + case zero => + simp [divRec, tryDivSubtractShift] + exact hr + case succ j ih => + simp [divRec] + apply tryDivSubtractShift_lt_of_lt hr hr₂ + +def checkDivRec : Bool × Array String := Id.run do + let w := 4 + let max := (Nat.pow 2 w) + let mut outputs := #[] + let mut wrong := false + for n in (List.range max) do + for d in (List.range (max - 1)).map (fun n => Nat.add n 1) do + have hd : d > 0 := by sorry + let qr := divRec (w := w) (n := n) (d := d) { r := 0, q := 0 } (w - 1) + if qr.q * d + qr.r != n then + outputs := outputs.push s!"ERROR: n = {n}, d = {d}, q = {qr.q}, r = {qr.r}, n = {n}, d = {d}, q = {qr.q}, r = {qr.r}" + wrong := true + (wrong, outputs) theorem divRec_postcondition {w : Nat} {n d : BitVec w} (qr : DivRecQuotRem w n d) (j : Nat) (hj : j ≤ w - 1) : @@ -873,8 +911,8 @@ theorem divRec_postcondition {w : Nat} {n d : BitVec w} (qr : DivRecQuotRem w n -- invariants: -- 1) qr.r < d. theorem div_rec_7_2 : - (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0, hr := by sorry } 3) = - { r := 1, q := 3, hr := by sorry } := by + (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 3) = + { r := 1, q := 3 } := by simp [divRec, tryDivSubtractShift] -- invariant 2 diff --git a/src/Init/Data/BitVec/div_new_invariant.py b/src/Init/Data/BitVec/div_new_invariant.py index ff88dae5b9bd..151ae6b3b9bd 100755 --- a/src/Init/Data/BitVec/div_new_invariant.py +++ b/src/Init/Data/BitVec/div_new_invariant.py @@ -38,6 +38,7 @@ def shift_subtract(w, n, d, q, r, j): ix = (w - 1) - j assert ix >= 0 r = (r << 1) | get_lsb(n, ix) + assert r < 2 ** w # how is this loop invariant upheld, right after doing weird operations? Very weird. print(f" r = %s" % print_bits(w, r)) if r >= d: print(f" r > d.") @@ -53,6 +54,17 @@ def shift_subtract(w, n, d, q, r, j): check_final_invariant(w, n, d, q, r, j) return (q, r) + + +w = 4 +d = 10 # d * 2 will overflow. +(q, r) = shift_subtract(w, n, d, 0, 0, w-1) +assert n == d * q + r +if n == d * q + r and r < d: + print ("verified correct invariant for n: '%s' | d : '%s' | q : '%s' r: '%s'" % + (n, d, q, r)) +else: + raise RuntimeError("verification failed for n: '%s' | d: '%s'" % (n, d)) # 10 / 3 = 3 for n in range(1, 32): for d in range(1, 32): From 89ce200332e72f48a89e49c12670b4e9580f2cfb Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 4 Jul 2024 15:34:07 +0100 Subject: [PATCH 47/64] chore: stash --- src/Init/Data/BitVec/Bitblast.lean | 178 ++++++++++++++++++++++++++--- 1 file changed, 161 insertions(+), 17 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 5ae8c79b5135..c22ee3a89265 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -762,9 +762,6 @@ theorem foo (a b : Nat) (hr : r * 2 + 1 < d) : r < (d - 1) / 2 := by /-- One round of the division algorithm, that tries to perform a subtract shift. -/ def tryDivSubtractShift (qr : DivRecQuotRem w n d) (ix : Nat) : DivRecQuotRem w n d := let r' := (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb ix)).zeroExtend w - -- r * 2 + 1 < d. -- Why does this not overflow? - -- r < d - 1 - -- r < d / 2 if r' < d then { r := r', q := qr.q <<< 1 } else { @@ -772,6 +769,64 @@ def tryDivSubtractShift (qr : DivRecQuotRem w n d) (ix : Nat) : DivRecQuotRem w q := qr.q <<< 1 ||| 1 } +/-- Same as tryDivSubtractShift, with if-then-else pushed into the record, -/ +def tryDivSubtractShift' (qr : DivRecQuotRem w n d) (ix : Nat) : DivRecQuotRem w n d := + let r' := (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb ix)).zeroExtend w + { r := if r' < d then r' else r' - d, q := qr.q <<< 1 ||| (if r' < d then 0 else 1) } + +@[simp] +theorem BitVec.or_zero (x : BitVec w) : x ||| 0#w = x := by + ext i + simp + +theorem tryDivSubtractShift_eq_tryDivSubtractShift' (qr : DivRecQuotRem w n d) (ix : Nat) : + tryDivSubtractShift qr ix = tryDivSubtractShift' qr ix := by + simp [tryDivSubtractShift, tryDivSubtractShift'] + generalize qr.r <<< 1 ||| zeroExtend w (ofBool (n.getLsb ix)) = s + by_cases hslt : s < d + · simp [hslt] + · simp [hslt] + +theorem BitVec.sub_le_self_of_le {x y : BitVec w} (hx : y ≤ x) : x - y ≤ x := by + simp [BitVec.lt_def, BitVec.le_def] at hx ⊢ + rw [← Nat.add_sub_assoc (by omega)] + rw [Nat.add_comm] + rw [Nat.add_sub_assoc (by omega)] + rw [Nat.add_mod] + simp only [mod_self, Nat.zero_add, mod_mod] + rw [Nat.mod_eq_of_lt] <;> omega + +theorem BitVec.sub_lt_self_of_lt_of_lt {x y : BitVec w} (hx : y < x) (hy : 0 < y): x - y < x := by + simp [BitVec.lt_def] at hx hy ⊢ + rw [← Nat.add_sub_assoc (by omega)] + rw [Nat.add_comm] + rw [Nat.add_sub_assoc (by omega)] + rw [Nat.add_mod] + simp only [mod_self, Nat.zero_add, mod_mod] + rw [Nat.mod_eq_of_lt] <;> omega + +theorem BitVec.le_iff_not_lt {x y : BitVec w} : (¬ x < y) ↔ y ≤ x := by + constructor <;> + (intro h; simp [BitVec.lt_def, BitVec.le_def] at h ⊢; omega) + +@[simp] +theorem BitVec.le_refl (x : BitVec w) : x ≤ x := by + simp [BitVec.le_def] + + +/-- The tryDivSubtractShift's remainder is upper bounded by `r << 1 | 1`. -/ +theorem tryDivSubtractShift_remainder_lt_shiftLeft_one_or_one {qr : DivRecQuotRem w n d} {ix : Nat} : + (tryDivSubtractShift qr ix).r ≤ (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb ix)).zeroExtend w := by + rw [tryDivSubtractShift_eq_tryDivSubtractShift'] + simp only [tryDivSubtractShift'] + generalize qr.r <<< 1 ||| zeroExtend w (ofBool (n.getLsb ix)) = s + by_cases hslt : s < d + · simp [hslt] + · simp [hslt] + apply BitVec.sub_le_self_of_le + apply BitVec.le_iff_not_lt.mp hslt + + /- Surely this exists somewhere, I remember proving this even -/ theorem Nat.sub_mod_self_eq_sub {x n : Nat} (hx₀ : 0 < x := by omega) (hxn : x < n := by omega) : (n - x) % n = n - x := by rw [Nat.mod_eq_of_lt] @@ -787,7 +842,7 @@ theorem tryDivSubtractShift_lt_of_lt {qr : DivRecQuotRem w n d} {ix : Nat} (hrlt (tryDivSubtractShift qr ix).r < d := by simp only [tryDivSubtractShift, ofNat_eq_ofNat] generalize hr₂ : qr.r <<< 1 ||| zeroExtend w (ofBool (n.getLsb ix)) = r₂ - by_cases hr₂lt : r₂ < d + by_cases hr₂lt : r₂ < d · simp [hr₂lt] · simp [hr₂lt] rw [← BitVec.add_eq_or_of_and_eq_zero] at hr₂ @@ -861,6 +916,11 @@ theorem tryDivSubtractShift_lt_of_lt {qr : DivRecQuotRem w n d} {ix : Nat} (hrlt intros hi _ hi' omega +/-- +info: 'BitVec.tryDivSubtractShift_lt_of_lt' depends on axioms: [propext, Quot.sound, Classical.choice] +-/ +#guard_msgs in #print axioms tryDivSubtractShift_lt_of_lt + /-- repeatedly apply `tryDivSubtractShift`. -/ def divRec (qr : DivRecQuotRem w n d) (j : Nat) : DivRecQuotRem w n d := @@ -870,15 +930,106 @@ def divRec (qr : DivRecQuotRem w n d) (j : Nat) : | j + 1 => divRec qr j tryDivSubtractShift qr' (w - 1 - j) -theorem divRec_remainder_lt_twoPow (qr : DivRecQuotRem w n d) (j : Nat) (hj : j < w) (hr : qr.r < d) (hr₂ : qr.r.toNat * 2 + 1 < 2 ^ w) : - (divRec qr j).r.toNat * 2 + 1 < 2 ^ w := by +theorem BitVec.shiftLeft_mul_comm (x y : BitVec w) (n : Nat) : + x <<< n * y = x * y <<< n := by + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.mul_assoc] + congr 1 + apply BitVec.mul_comm + +theorem BitVec.shiftLeft_mul_assoc (x y : BitVec w) (n : Nat) : + x * y <<< n = (x * y) <<< n := by + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.mul_assoc] + +theorem BitVec.add_mul (x y z : BitVec w) : (y + z) * x = y * x + z * x := by + conv => + lhs + rw [BitVec.mul_comm, BitVec.mul_add] + congr 1 <;> rw [BitVec.mul_comm] + + +/-- +TODO: what's a good theorem name? +If the LSB is false, then shifting to (w - 1) is the same as shifting to w and then right shifting 1. +-/ +private theorem BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_of_getLsb_false + {x : BitVec w} (hx : x.getLsb (w - 1) = false) : + x >>> (w - 1) = x >>> w <<< 1 := by + ext i + simp only [getLsb_ushiftRight, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and] + by_cases (i : Nat) < 1 + case pos h => + simp only [h, decide_True, Bool.not_true, Bool.false_and] + have hi : (i : Nat) = 0 := by omega + simp [hi, hx] + case neg h => + simp only [h, decide_False, Bool.not_false, Bool.true_and] + congr 1 + omega + +theorem BitVec.add_assoc {x y z : BitVec w} : x + y + z = x + (y + z) := by + apply eq_of_toNat_eq + simp + rw [Nat.add_assoc] + +theorem divRec_correct {w : Nat} {n d : BitVec w} {qr : DivRecQuotRem w n d} {j : Nat} (hj : j ≤ w - 1) + (hqrd : qr.r < d) + (hrn : qr.r < n >>> (j + 1)) + (hqrn : n >>> (w - j) == qr.q * d + qr.r) : + n >>> ((w - 1) - j) == (divRec qr j).q * d + (divRec qr j).r := by induction j generalizing qr case zero => - simp [divRec, tryDivSubtractShift] - exact hr - case succ j ih => simp [divRec] - apply tryDivSubtractShift_lt_of_lt hr hr₂ + simp at hqrn + -- simp [tryDivSubtractShift] + simp [tryDivSubtractShift_eq_tryDivSubtractShift'] + simp [tryDivSubtractShift'] + generalize hb : n.getLsb (w - 1) = b + generalize hs : qr.r <<< 1 ||| zeroExtend w (ofBool b) = s + by_cases hslt : s < d + · simp [hslt] + rcases b with rfl | rfl + · simp_all + rw [← hs] + have qd : qr.q <<< 1 * d = (qr.q * d) <<< 1 := by + rw [BitVec.shiftLeft_mul_comm] + rw [BitVec.shiftLeft_mul_assoc] + rw [qd] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [← BitVec.add_mul] + rw [← hqrn] + rw [← BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_of_getLsb_false hb] + · simp_all + rw [← hs] + have qd : qr.q <<< 1 * d = (qr.q * d) <<< 1 := by + rw [BitVec.shiftLeft_mul_comm] + rw [BitVec.shiftLeft_mul_assoc] + rw [qd] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [← add_eq_or_of_and_eq_zero] + · rw [← BitVec.add_assoc] + rw [← BitVec.add_mul] + rw [← hqrn] + rw [← BitVec.shiftLeft_eq_mul_twoPow] + + · ext i + simp + intros i _ hi' + omega + · simp [hslt] + sorry + sorry + case succ j' ih => + simp [divRec] + simp at hqrn + simp [tryDivSubtractShift_eq_tryDivSubtractShift'] + sorry def checkDivRec : Bool × Array String := Id.run do let w := 4 @@ -895,13 +1046,6 @@ def checkDivRec : Bool × Array String := Id.run do (wrong, outputs) -theorem divRec_postcondition {w : Nat} {n d : BitVec w} (qr : DivRecQuotRem w n d) (j : Nat) (hj : j ≤ w - 1) : - let qr' := divRec qr j - (qr'.q * d + qr'.r) = n >>> ((w - 1) - j) := by - induction j generalizing qr - · sorry - · sorry - /-- info: (false, { data := [] }) -/ #guard_msgs in #reduce checkDivRec From 53fbd3a5c388a1abcb94584dc7d476c9b4c9f519 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 4 Jul 2024 15:39:18 +0100 Subject: [PATCH 48/64] chore: one half of base case --- src/Init/Data/BitVec/Bitblast.lean | 67 ++++++++++++++---------------- 1 file changed, 32 insertions(+), 35 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index c22ee3a89265..8663b5961b9b 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -955,21 +955,34 @@ theorem BitVec.add_mul (x y z : BitVec w) : (y + z) * x = y * x + z * x := by TODO: what's a good theorem name? If the LSB is false, then shifting to (w - 1) is the same as shifting to w and then right shifting 1. -/ -private theorem BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_of_getLsb_false - {x : BitVec w} (hx : x.getLsb (w - 1) = false) : - x >>> (w - 1) = x >>> w <<< 1 := by +private theorem BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_or_zeroExtend_getLsb + {x : BitVec w} : + x >>> (w - 1) = ((x >>> w <<< 1) ||| (BitVec.ofBool (x.getLsb (w - 1))).zeroExtend w) := by ext i - simp only [getLsb_ushiftRight, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and] + simp only [getLsb_ushiftRight, getLsb_or, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, + getLsb_zeroExtend, getLsb_ofBool] by_cases (i : Nat) < 1 case pos h => simp only [h, decide_True, Bool.not_true, Bool.false_and] have hi : (i : Nat) = 0 := by omega - simp [hi, hx] + simp [hi] case neg h => simp only [h, decide_False, Bool.not_false, Bool.true_and] + have hi : (i : Nat) ≠ 0 := by omega + simp only [hi, decide_False, Bool.false_and, Bool.or_false] congr 1 omega +private theorem BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_add_zeroExtend_getLsb + {x : BitVec w} : + x >>> (w - 1) = ((x >>> w <<< 1) + (BitVec.ofBool (x.getLsb (w - 1))).zeroExtend w) := by + rw [BitVec.add_eq_or_of_and_eq_zero] + · apply BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_or_zeroExtend_getLsb + · ext i + simp + intros i _ hi' + omega + theorem BitVec.add_assoc {x y z : BitVec w} : x + y + z = x + (y + z) := by apply eq_of_toNat_eq simp @@ -984,47 +997,31 @@ theorem divRec_correct {w : Nat} {n d : BitVec w} {qr : DivRecQuotRem w n d} {j case zero => simp [divRec] simp at hqrn - -- simp [tryDivSubtractShift] simp [tryDivSubtractShift_eq_tryDivSubtractShift'] simp [tryDivSubtractShift'] generalize hb : n.getLsb (w - 1) = b generalize hs : qr.r <<< 1 ||| zeroExtend w (ofBool b) = s by_cases hslt : s < d · simp [hslt] - rcases b with rfl | rfl - · simp_all - rw [← hs] - have qd : qr.q <<< 1 * d = (qr.q * d) <<< 1 := by - rw [BitVec.shiftLeft_mul_comm] - rw [BitVec.shiftLeft_mul_assoc] - rw [qd] - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [← hs] + have qd : qr.q <<< 1 * d = (qr.q * d) <<< 1 := by + rw [BitVec.shiftLeft_mul_comm] + rw [BitVec.shiftLeft_mul_assoc] + rw [qd] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [← add_eq_or_of_and_eq_zero] + · rw [← BitVec.add_assoc] rw [← BitVec.add_mul] rw [← hqrn] rw [← BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_of_getLsb_false hb] - · simp_all - rw [← hs] - have qd : qr.q <<< 1 * d = (qr.q * d) <<< 1 := by - rw [BitVec.shiftLeft_mul_comm] - rw [BitVec.shiftLeft_mul_assoc] - rw [qd] - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [← add_eq_or_of_and_eq_zero] - · rw [← BitVec.add_assoc] - rw [← BitVec.add_mul] - rw [← hqrn] - rw [← BitVec.shiftLeft_eq_mul_twoPow] - - · ext i - simp - intros i _ hi' - omega + rw [BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_add_zeroExtend_getLsb, hb] + · ext i + simp + intros i _ hi' + omega · simp [hslt] sorry - sorry case succ j' ih => simp [divRec] simp at hqrn From ac9233141ae4ed90143326ef744f1f3cd57b9693 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 4 Jul 2024 15:58:07 +0100 Subject: [PATCH 49/64] chore: second half of base case, half done --- src/Init/Data/BitVec/Bitblast.lean | 39 ++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 8663b5961b9b..f62ea38a2a89 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -988,6 +988,39 @@ theorem BitVec.add_assoc {x y z : BitVec w} : x + y + z = x + (y + z) := by simp rw [Nat.add_assoc] +theorem BitVec.add_sub_assoc {m k : BitVec w} (h : k ≤ m) (n : BitVec w) : n + m - k = n + (m - k) := sorry + +/-- +Bitwise or of (x <<< 1) with 1 is the same as addition. +This is useful to reason in mixed-arithmetic bitwise contexts. +-/ +private theorem BitVec.shiftLeft_one_or_one_eq_shiftLeft_one_add_one {x : BitVec w} : + x <<< 1 ||| 1#w = (x <<< 1) + 1#w := by + rw [BitVec.add_eq_or_of_and_eq_zero] + ext i + simp + intro i _ hi' + omega + +theorem BitVec.add_sub_self_left {x y : BitVec w} : x + y - x = y := by + apply eq_of_toNat_eq + simp + calc + (x.toNat + y.toNat + (2 ^ w - x.toNat)) % 2 ^ w = (x.toNat + y.toNat + 2 ^ w - x.toNat) % 2 ^ w := by + rw [Nat.add_sub_assoc (Nat.le_of_lt x.isLt)] + _ = (x.toNat + y.toNat - x.toNat + 2 ^ w) % 2 ^ w := by rw [Nat.sub_add_comm]; omega + _ = (y.toNat + 2 ^ w) % 2 ^ w := by rw [Nat.add_sub_self_left] + _ = y.toNat % 2 ^ w := by simp + _ = y.toNat := by simp [Nat.mod_eq_of_lt] + +theorem BitVec.add_sub_self_right {x y : BitVec w} : x + y - y = x := by + rw [BitVec.add_comm] + rw [BitVec.add_sub_self_left] + +@[simp] +theorem BitVec.le_of_not_lt {x y : BitVec w} : ¬ x < y → y ≤ x := by + simp [BitVec.lt_def, BitVec.le_def] + theorem divRec_correct {w : Nat} {n d : BitVec w} {qr : DivRecQuotRem w n d} {j : Nat} (hj : j ≤ w - 1) (hqrd : qr.r < d) (hrn : qr.r < n >>> (j + 1)) @@ -1021,6 +1054,12 @@ theorem divRec_correct {w : Nat} {n d : BitVec w} {qr : DivRecQuotRem w n d} {j intros i _ hi' omega · simp [hslt] + rw [BitVec.shiftLeft_one_or_one_eq_shiftLeft_one_add_one] + rw [BitVec.add_mul] + simp only [BitVec.one_mul] + rw [BitVec.add_assoc] + rw [← BitVec.add_sub_assoc (by simp [hslt])] + rw [BitVec.add_sub_self_left] sorry case succ j' ih => simp [divRec] From 80dedc21b90650fcb605c976fc646cf2ecba64f0 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 4 Jul 2024 16:03:17 +0100 Subject: [PATCH 50/64] chore: base case done --- src/Init/Data/BitVec/Bitblast.lean | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index f62ea38a2a89..e304d45ebfa0 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -1034,12 +1034,12 @@ theorem divRec_correct {w : Nat} {n d : BitVec w} {qr : DivRecQuotRem w n d} {j simp [tryDivSubtractShift'] generalize hb : n.getLsb (w - 1) = b generalize hs : qr.r <<< 1 ||| zeroExtend w (ofBool b) = s - by_cases hslt : s < d + have qd : qr.q <<< 1 * d = (qr.q * d) <<< 1 := by + rw [BitVec.shiftLeft_mul_comm] + rw [BitVec.shiftLeft_mul_assoc] + by_cases hslt : s < d -- Note that the proof is identical on both sides of the case split. · simp [hslt] rw [← hs] - have qd : qr.q <<< 1 * d = (qr.q * d) <<< 1 := by - rw [BitVec.shiftLeft_mul_comm] - rw [BitVec.shiftLeft_mul_assoc] rw [qd] rw [BitVec.shiftLeft_eq_mul_twoPow] rw [BitVec.shiftLeft_eq_mul_twoPow] @@ -1060,7 +1060,20 @@ theorem divRec_correct {w : Nat} {n d : BitVec w} {qr : DivRecQuotRem w n d} {j rw [BitVec.add_assoc] rw [← BitVec.add_sub_assoc (by simp [hslt])] rw [BitVec.add_sub_self_left] - sorry + rw [← hs] + rw [qd] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [← add_eq_or_of_and_eq_zero] + · rw [← BitVec.add_assoc] + rw [← BitVec.add_mul] + rw [← hqrn] + rw [← BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_add_zeroExtend_getLsb, hb] + · ext i + simp + intros i _ hi' + omega case succ j' ih => simp [divRec] simp at hqrn From c9b1227c514d9a21edca70b6ab869ff79b4d741d Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 5 Jul 2024 09:36:03 +0100 Subject: [PATCH 51/64] chore: show how to derive final theorem from recurrence --- src/Init/Data/BitVec/Bitblast.lean | 158 ++++++++++++++++++----------- 1 file changed, 97 insertions(+), 61 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index e304d45ebfa0..3f7bb1d64121 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -756,9 +756,6 @@ theorem BitVec.shiftLeft_eq_mul_twoPow (x : BitVec w) (n : Nat) : ext i simp -theorem foo (a b : Nat) (hr : r * 2 + 1 < d) : r < (d - 1) / 2 := by - sorry - /-- One round of the division algorithm, that tries to perform a subtract shift. -/ def tryDivSubtractShift (qr : DivRecQuotRem w n d) (ix : Nat) : DivRecQuotRem w n d := let r' := (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb ix)).zeroExtend w @@ -979,16 +976,20 @@ private theorem BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_add_zeroExtend_getL rw [BitVec.add_eq_or_of_and_eq_zero] · apply BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_or_zeroExtend_getLsb · ext i - simp + simp only [getLsb_and, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, + getLsb_ushiftRight, getLsb_zeroExtend, getLsb_ofBool, getLsb_zero, and_eq_false_imp, + and_eq_true, not_eq_true', decide_eq_false_iff_not, Nat.not_lt, decide_eq_true_eq, and_imp] intros i _ hi' omega theorem BitVec.add_assoc {x y z : BitVec w} : x + y + z = x + (y + z) := by apply eq_of_toNat_eq - simp - rw [Nat.add_assoc] + simp[Nat.add_assoc] -theorem BitVec.add_sub_assoc {m k : BitVec w} (h : k ≤ m) (n : BitVec w) : n + m - k = n + (m - k) := sorry +theorem BitVec.add_sub_assoc {m k : BitVec w} (h : k ≤ m) (n : BitVec w) : + n + m - k = n + (m - k) := by + apply BitVec.eq_of_toNat_eq + simp only [toNat_sub, toNat_add, mod_add_mod, add_mod_mod, Nat.add_assoc] /-- Bitwise or of (x <<< 1) with 1 is the same as addition. @@ -1021,65 +1022,101 @@ theorem BitVec.add_sub_self_right {x y : BitVec w} : x + y - y = x := by theorem BitVec.le_of_not_lt {x y : BitVec w} : ¬ x < y → y ≤ x := by simp [BitVec.lt_def, BitVec.le_def] -theorem divRec_correct {w : Nat} {n d : BitVec w} {qr : DivRecQuotRem w n d} {j : Nat} (hj : j ≤ w - 1) +-- theorem div_iff_add_mod_of_lt {d n q r : BitVec w} (hd : 0 < d) +-- (hrd : r < d) +-- (hlt : d.toNat * q.toNat + r.toNat < 2^w) : +-- (n.udiv d = q ∧ n.umod d = r) ↔ (d * q + r = n) := by + +theorem divRec_correct {w : Nat} {n d : BitVec w} {qr : DivRecQuotRem w n d} {j : Nat} + (hj : j ≤ w - 1) (hqrd : qr.r < d) - (hrn : qr.r < n >>> (j + 1)) - (hqrn : n >>> (w - j) == qr.q * d + qr.r) : - n >>> ((w - 1) - j) == (divRec qr j).q * d + (divRec qr j).r := by + (hrn : qr.r < d) + (hqrn : n >>> (w - j) = qr.q * d + qr.r) : + ((n >>> ((w - 1) - j) = (divRec qr j).q * d + (divRec qr j).r)) ∧ + (d.toNat * (divRec qr j).q.toNat + (divRec qr j).r.toNat < 2^w) ∧ + (divRec qr j).r < d := by induction j generalizing qr case zero => - simp [divRec] - simp at hqrn - simp [tryDivSubtractShift_eq_tryDivSubtractShift'] - simp [tryDivSubtractShift'] - generalize hb : n.getLsb (w - 1) = b - generalize hs : qr.r <<< 1 ||| zeroExtend w (ofBool b) = s - have qd : qr.q <<< 1 * d = (qr.q * d) <<< 1 := by - rw [BitVec.shiftLeft_mul_comm] - rw [BitVec.shiftLeft_mul_assoc] - by_cases hslt : s < d -- Note that the proof is identical on both sides of the case split. - · simp [hslt] - rw [← hs] - rw [qd] - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [← add_eq_or_of_and_eq_zero] - · rw [← BitVec.add_assoc] - rw [← BitVec.add_mul] - rw [← hqrn] - rw [← BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_add_zeroExtend_getLsb, hb] - · ext i - simp - intros i _ hi' - omega - · simp [hslt] - rw [BitVec.shiftLeft_one_or_one_eq_shiftLeft_one_add_one] - rw [BitVec.add_mul] - simp only [BitVec.one_mul] - rw [BitVec.add_assoc] - rw [← BitVec.add_sub_assoc (by simp [hslt])] - rw [BitVec.add_sub_self_left] - rw [← hs] - rw [qd] - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [← add_eq_or_of_and_eq_zero] - · rw [← BitVec.add_assoc] - rw [← BitVec.add_mul] - rw [← hqrn] - rw [← BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_add_zeroExtend_getLsb, hb] - · ext i - simp - intros i _ hi' - omega + constructor + · simp [divRec] + simp at hqrn + simp [tryDivSubtractShift_eq_tryDivSubtractShift'] + simp [tryDivSubtractShift'] + generalize hb : n.getLsb (w - 1) = b + generalize hs : qr.r <<< 1 ||| zeroExtend w (ofBool b) = s + have qd : qr.q <<< 1 * d = (qr.q * d) <<< 1 := by + rw [BitVec.shiftLeft_mul_comm] + rw [BitVec.shiftLeft_mul_assoc] + by_cases hslt : s < d -- Note that the proof is identical on both sides of the case split. + · simp [hslt] + rw [← hs] + rw [qd] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [← add_eq_or_of_and_eq_zero] + · rw [← BitVec.add_assoc] + rw [← BitVec.add_mul] + rw [← hqrn] + rw [← BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_add_zeroExtend_getLsb, hb] + · ext i + simp + intros i _ hi' + omega + · simp [hslt] + rw [BitVec.shiftLeft_one_or_one_eq_shiftLeft_one_add_one] + rw [BitVec.add_mul] + simp only [BitVec.one_mul] + rw [BitVec.add_assoc] + rw [← BitVec.add_sub_assoc (by simp [hslt])] + rw [BitVec.add_sub_self_left] + rw [← hs] + rw [qd] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [← add_eq_or_of_and_eq_zero] + · rw [← BitVec.add_assoc] + rw [← BitVec.add_mul] + rw [← hqrn] + rw [← BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_add_zeroExtend_getLsb, hb] + · ext i + simp + intros i _ hi' + omega + · constructor + · simp [divRec] + simp at hqrn + sorry + · -- r < d + simp [divRec] + apply tryDivSubtractShift_lt_of_lt + apply hqrd + sorry case succ j' ih => - simp [divRec] - simp at hqrn - simp [tryDivSubtractShift_eq_tryDivSubtractShift'] sorry + +theorem div_eq_divRec (n d : BitVec w) (hd : d > 0) : + n.udiv d = (divRec (w := w) (n := n) (d := d) { r := 0, q := 0 } (w - 1)).q ∧ + n.umod d = (divRec (w := w) (n := n) (d := d) { r := 0, q := 0 } (w - 1)).r := by + obtain ⟨h₁, h₂, h₃⟩ := divRec_correct (w := w) (n := n) (d := d) (j := w - 1) (qr := { r := 0, q := 0}) + (by omega) + (by simpa using hd) + (by simpa using hd) + (by sorry) + simp at h₃ + simp at h₂ + simp at h₁ + have k := div_characterized_of_mul_add_of_lt (d := d) (n := n) + (q := (divRec (w := w) (n := n) (d := d) { r := 0, q := 0 } (w - 1)).q) + (r := (divRec (w := w) (n := n) (d := d) { r := 0, q := 0 } (w - 1)).r) + hd + h₃ + (by rw [BitVec.mul_comm]; simp_all) + (by simp_all) + simp [k] + def checkDivRec : Bool × Array String := Id.run do let w := 4 let max := (Nat.pow 2 w) @@ -1094,7 +1131,6 @@ def checkDivRec : Bool × Array String := Id.run do wrong := true (wrong, outputs) - /-- info: (false, { data := [] }) -/ #guard_msgs in #reduce checkDivRec From 0b6e2e8779ec5e29014f8b2c0c64e839cff9ef12 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 5 Jul 2024 09:36:25 +0100 Subject: [PATCH 52/64] chore: delete dead code --- src/Init/Data/BitVec/Bitblast.lean | 376 ----------------------------- 1 file changed, 376 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 3f7bb1d64121..fe134c2be7a8 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -1144,380 +1144,4 @@ theorem div_rec_7_2 : { r := 1, q := 3 } := by simp [divRec, tryDivSubtractShift] --- invariant 2 --- n.toNat % 2^j = d.toNat * q.toNat + r.toNat --- set_option maxHeartbeats 99999 in --- #reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 3) -- r = 1, q = 3 --- #reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 2) -- r = 1, q = 3 --- #reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 1) -- r = 1, q = 1 --- #reduce (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 0) -- r = 1, q = 0 - --- /- Given d, R(j + 1), (calculate R(j), q.getLsb j). -/ --- def divremi (qr : DivRecQuotRem w n d) (j : Nat) : BitVec w × Bool := --- if d ≤ qr.r >>> j then -- This overflows, I need to do this differently :( --- -- let rj := qr.r - d * twoPow w j -- remainder is legal since it's positive, accept it. --- -- | this is the same as checking qr - d * 2^j >= 0. --- -- | If qr - d * 2^j ≥ 0, then it's legal for the quotient to have the jth bit 1. --- let rj := qr.r >>> j - d -- remainder is legal since it's positive, accept it. --- (rj, true) --- else --- (qr.r, false) -- remainder is illegal, so quotient must be '0' at this bit. - --- /-- Recurrence for division in terms of `divremi`. -/ --- def divRec (qr : DivRecQuotRem w n d) (j : Nat) : DivRecQuotRem w n d := --- let (r, qj) := divremi qr j --- let q := setBit qr.q j qj --- match j with --- | 0 => { r := r, q := q } --- | j + 1 => divRec { r := r, q := q } j --- where --- /-- Set the `i`th bit of `v` to `b`.-/ --- setBit (v : BitVec w) (i : Nat) (b : Bool) := --- if b then v ||| twoPow w i else v - - --- theorem divremi_eq_of_le {qr : DivRecQuotRem w n d} (h : d ≤ qr.r >>> j) : --- divremi qr j = (qr.r >>> j - d, true) := by --- simp [divremi, h] - --- theorem divremi_eq_of_not_le {qr : DivRecQuotRem w n d} (h : ¬ d ≤ qr.r >>> j) : --- divremi qr j = (qr.r, false) := by --- simp [divremi, h] - --- def DivRecQuotRem.Lawful {n d : BitVec w} (qr : DivRecQuotRem w n d) : Prop := --- (d.toNat * qr.q.toNat + qr.r.toNat = n.toNat) - --- theorem DivRecQuotRem.Lawful.toNat_inbounds {n d : BitVec w} {qr : DivRecQuotRem w n d} --- (h : qr.Lawful) : d.toNat * qr.q.toNat + qr.r.toNat < 2^w := by --- rw [h] --- omega - --- theorem DivRecQuotRem.inbounds {n d : BitVec w} {qr : DivRecQuotRem w n d} : --- (d * qr.q + qr.r).toNat < 2^w := by omega - - --- theorem DivRecQuotRem.Lawful.eq {n d : BitVec w} {qr : DivRecQuotRem w n d} --- (h : qr.Lawful) : (d * qr.q + qr.r = n) := by --- apply eq_of_toNat_eq --- simp --- rw [h] --- rw [Nat.mod_eq_of_lt] --- omega - --- -- theorem DivRecQuotRem.Lawful.eq_nat {n d : BitVec w} {qr : DivRecQuotRem w n d} --- -- (h : qr.Lawful) : d.toNat * qr.q.toNat + qr.r.toNat = n.toNat := by --- -- simp only [Lawful] at h --- -- have h' : (d * qr.q + qr.r).toNat = n.toNat := by rw [h.2] --- -- simp only [toNat_add, toNat_mul, mod_add_mod] at h' --- -- rw [Nat.mod_eq_of_lt] at h' --- -- · exact h' --- -- · exact h.1 - --- def DivRecQuotRem.initialize (n d : BitVec w) : DivRecQuotRem w n d := --- { r := n, q := 0 } - --- theorem DivRecQuotRem.lawful_initialize {n d : BitVec w} : --- (DivRecQuotRem.initialize n d).Lawful := by --- simp [DivRecQuotRem.Lawful, DivRecQuotRem.initialize] - - --- @[simp] --- theorem divRec_zero {qr : DivRecQuotRem w n d} : --- divRec qr 0 = --- { r := (divremi qr 0).fst , q := divRec.setBit qr.q 0 (divremi qr 0).snd } := by --- unfold divRec --- simp - --- @[simp] --- theorem divRec_succ {qr : DivRecQuotRem w n d} {j : Nat} : --- divRec qr (j + 1) = --- divRec { --- r := (divremi qr (j + 1) |>.fst), --- q := (divRec.setBit qr.q (j + 1) (divremi qr (j + 1) |>.snd)) --- } j := by --- conv => --- lhs --- unfold divRec - --- @[simp] --- theorem divRec.setBit_false {v : BitVec w} {i : Nat} : --- divRec.setBit v i false = v := by --- simp [divRec.setBit] - --- @[simp] --- theorem divRec.setBit_true {v : BitVec w} {i : Nat} : --- divRec.setBit v i true = v ||| twoPow w i := by --- simp [divRec.setBit] - --- /- --- Clear all low bits in the range `(j..0]`, --- keeping high bits in the range `[w..j]` --- -/ --- abbrev clearLowBitsAfter (x : BitVec w) (j : Nat) : BitVec w := (x >>> j) <<< j - --- @[simp] --- theorem getLsb_clearLowBitsAfter {x : BitVec w} {j i : Nat} : --- (clearLowBitsAfter x j).getLsb i = if j ≤ i then x.getLsb i else false := by --- unfold clearLowBitsAfter --- simp --- by_cases hij : i < j --- · simp only [hij, decide_True, Bool.not_true, Bool.and_false, Bool.false_and, false_eq, --- and_eq_false_imp, decide_eq_true_eq] --- intros hcontra --- omega --- · simp [hij, show j + (i - j) = i by omega, show j ≤ i by omega] --- apply lt_of_getLsb - --- /-- --- A bitvector can be broken down into the low bits (by truncate) and the high --- bits (by left shift followed by right shift). --- -/ --- theorem BitVec.zeroExtend_truncate_or_shiftRight_shiftLeft_eq_self {x : BitVec w} {i : Nat} : --- (x.truncate i).zeroExtend w ||| (x >>> i) <<< i = x := by --- ext j --- by_cases h : j < i --- · simp [h] --- · simp only [getLsb_or, getLsb_zeroExtend, Fin.is_lt, decide_True, h, decide_False, --- Bool.false_and, Bool.and_false, getLsb_shiftLeft, Bool.not_false, Bool.and_self, --- getLsb_ushiftRight, Bool.true_and, Bool.false_or] --- congr --- omega - --- /-- Key relationship that establishes the loop invariant after one iteration. -/ --- theorem DivRecQuotRem.rec_lawful_of_lawful --- {qr : DivRecQuotRem w n d} --- (hqr : qr.Lawful) --- {hq : ∀ {i : Nat}, i ≤ j + 1 → qr.q.getLsb i = false} --- {hd : d ≤ qr.r >>> (j + 1)} : --- { r := qr.r >>> (j + 1) - d, q := qr.q ||| twoPow w (j + 1) : DivRecQuotRem w n d }.Lawful := by --- simp only [DivRecQuotRem.Lawful] --- rw [← BitVec.add_eq_or_of_and_eq_zero] --- · simp --- /- --- w : Nat --- n d : BitVec w --- j : Nat --- qr : DivRecQuotRem w n d --- hqr : qr.Lawful --- hq : ∀ {i : Nat}, i ≤ j + 1 → qr.q.getLsb i = false --- hd : d ≤ qr.r >>> (j + 1) --- ⊢ d.toNat * ((qr.q.toNat + 2 ^ (j + 1)) % 2 ^ w) + (qr.r.toNat >>> (j + 1) + (2 ^ w - d.toNat)) % 2 ^ w = n.toNat --- -/ --- sorry --- · simp --- specialize hq (i := j + 1) (by omega) --- rw [hq] --- simp --- -- d * j + q = n. --- theorem divRec_lawful {qr : DivRecQuotRem w n d} {j : Nat} --- (hqr : qr.Lawful) --- -- | We start the reucrrence at j=w-1, so all bits are zero. --- -- | At one step of j=w-2, only the low bit maybe set, and all --- -- bits at [1..w) are zero. --- -- We end the recurrence at j=0, when no bit is forced to be zero. --- -- | the quotient is correct, if the quotient is zero for all bits --- -- | in the range `(j..0]`. --- (hq : ∀ {i : Nat} (hi : i ≤ j), qr.q.getLsb i = false) : --- (divRec qr j).Lawful := by --- induction j generalizing qr --- case zero => --- simp [DivRecQuotRem.Lawful] --- by_cases h : d ≤ qr.r >>> 0 --- · simp [divremi_eq_of_le h] --- simp at h --- rcases w with rfl | w --- · simp -- get rid of corner case with 1 % 2^0 = 1 % 1 = 0 --- · have h1 : 1 % 2^(w + 1) = 1 := by --- rw [Nat.mod_eq_of_lt] --- apply Nat.one_lt_two_pow (by omega) --- simp [h1] --- specialize (hq (i := 0) (by omega)) --- have hqr_q_or_1_to_Nat : (qr.q.toNat ||| 1) = (qr.q ||| 1).toNat := by --- simp --- rw [Nat.mod_eq_of_lt] --- repeat omega --- rw [hqr_q_or_1_to_Nat] --- have hqr_or_1_eq_hqr_add_1 : (qr.q ||| 1) = (qr.q + twoPow (w+1) 0) := by --- rw [add_eq_or_of_and_eq_zero] --- · simp --- · rw [and_twoPow_eq_getLsb, hq] --- simp --- rw [hqr_or_1_eq_hqr_add_1] --- simp -- here we get a % 2^(w + 1) that we wish to avoid --- calc --- d.toNat * ((qr.q.toNat + 1) % 2 ^ (w + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) = d.toNat * ((qr.q.toNat + 1) % 2 ^ (w + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) := by rfl --- _ = d.toNat * ((qr.q.toNat + 1)) + (qr.r.toNat + (2 ^ (w + 1) - d.toNat)) % 2 ^ (w + 1) := by --- rw [Nat.mod_eq_of_lt] --- simp [getLsb] at hq --- omega --- _ = d.toNat * (qr.q.toNat + 1) + (qr.r.toNat - d.toNat + (2 ^ (w + 1))) % 2 ^ (w + 1) := by --- congr 2 --- /- Note: omega needs this. Is the preprocessor supposed to pick this up? -/ --- have h' : d.toNat ≤ qr.r.toNat := by simp [BitVec.le_def] at h; omega --- omega --- _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat + (2 ^ (w + 1))) % 2 ^ (w + 1) := by simp [Nat.mul_add] --- _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) % 2^(w + 1) + (2^(w+1) % 2 ^ (w + 1)) := by simp [Nat.add_mod] --- _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) % 2^(w + 1) + 0 := by simp --- _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) % 2^(w + 1) := by simp --- _ = d.toNat * qr.q.toNat + d.toNat + (qr.r.toNat - d.toNat) := by rw [Nat.mod_eq_of_lt]; omega --- _ = d.toNat * qr.q.toNat + (d.toNat - d.toNat + qr.r.toNat) := by --- have h' : d.toNat ≤ qr.r.toNat := by simp [BitVec.le_def] at h; omega --- omega --- _ = d.toNat * qr.q.toNat + qr.r.toNat := by simp --- _ = n.toNat := hqr --- · simp [divremi_eq_of_not_le h] --- exact hqr --- case succ j ih => --- simp --- -- TODO: split below into a separate lemma --- by_cases h : d ≤ qr.r >>> (j + 1) --- · simp [divremi_eq_of_le h] --- simp at h --- apply ih --- /- --- case pos.hqr --- w : Nat --- n d : BitVec w --- j : Nat --- ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → (∀ {i : Nat}, i ≤ j → qr.q.getLsb i = false) → (divRec qr j).Lawful --- qr : DivRecQuotRem w n d --- hqr : qr.Lawful --- hq : ∀ {i : Nat}, i ≤ j + 1 → qr.q.getLsb i = false --- h : d ≤ qr.r >>> (j + 1) --- ⊢ { r := qr.r >>> (j + 1) - d, q := qr.q ||| twoPow w (j + 1) }.Lawful --- -/ --- · sorry --- · intros i hi --- simp --- constructor --- · apply hq --- omega --- · intros h --- omega --- · simp [divremi_eq_of_not_le h] --- simp at h --- apply ih hqr --- intros i hi --- apply hq --- omega - --- #print axioms divRec_lawful - --- theorem divRec_remainder_inbounds {qr : DivRecQuotRem w n d} {j : Nat} --- (hqr : qr.Lawful) (hr : qr.r.toNat < 2^j) (hd : 0 < d): --- (divRec qr j).r < d := by --- induction j generalizing qr --- case zero => --- simp --- by_cases h : d <<< 0 ≤ qr.r --- · simp at h --- simp at hr --- simp [lt_def] --- simp [le_def] at h --- have hd' : d.toNat = 0 := by omega --- simp [lt_def] at hd --- omega --- · simp at h --- simp [lt_def] at hr --- simp [lt_def] --- simp [lt_def] at hd --- simp [divremi, h] --- omega --- case succ j ih => --- simp --- by_cases h : d ≤ qr.r >>> (j + 1) --- · simp [divremi_eq_of_le h] --- simp at h --- have hqr' : qr.r >>> (j + 1) = 0#w := by --- apply BitVec.eq_of_toNat_eq --- simp only [toNat_ushiftRight, toNat_ofNat, zero_mod] --- rw [Nat.shiftRight_eq_div_pow] --- rw[Nat.div_eq_of_lt (by omega)] --- rw [hqr'] --- rw [hqr'] at h --- have hdcontra : d = 0#w := by --- simp only [le_def, toNat_ofNat, zero_mod, le_zero_eq] at h --- apply BitVec.eq_of_toNat_eq --- simp [h] --- simp [hdcontra] at hd --- · simp [divremi_eq_of_not_le h] --- apply ih hqr --- simp [BitVec.le_def] at h --- sorry - --- theorem divRec_eq_udiv_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd : 0 < d) --- {hqr' : qr' = (divRec qr w)} : --- qr'.q = udiv n d := by --- have hlawful : qr'.Lawful := by --- simp [hqr', divRec_lawful hqr] --- have hremainder : qr'.r < d := by --- simp [hqr'] --- apply divRec_remainder_inbounds hqr --- apply qr.r.isLt --- assumption --- have this := div_characterized_of_mul_add_toNat --- (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful --- simp [this.1] - --- /- --- case neg.hr --- w : Nat --- n d : BitVec w --- hd : 0 < d --- j : Nat --- ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → qr.r.toNat < 2 ^ j → (divRec qr j).r < d --- qr : DivRecQuotRem w n d --- hqr : qr.Lawful --- hr : qr.r.toNat < 2 ^ (j + 1) --- h : ¬d <<< (j + 1) ≤ qr.r --- ⊢ qr.r.toNat < 2 ^ j --- -/ --- -- I get a contradiction, because I know the following facts: --- -- - hr : qr < 2^(j + 1) --- -- hd : 0 < d => d >= 1 --- -- and thus, d <<< (j + 1) ≥ 1 >>> (j + 1) ≥ 2 ^(j + 1) --- -- plus, h tells me that qr.r > d <<< (j + 1) => qr.r > 2^(j + 1) --- -- contradiction! - --- apply ih --- simp [DivRecQuotRem.Lawful] --- constructor --- · sorry --- · simp --- -- exact h --- · simp [divremi_eq_of_not_le h] --- simp at h --- apply ih --- · exact hqr --- · simp; - --- /- --- case pos --- w : Nat --- n d : BitVec w --- hd : 0 < d --- j : Nat --- ih : ∀ {qr : DivRecQuotRem w n d}, qr.Lawful → qr.r.toNat < 2 ^ j → (divRec qr j).r < d --- qr : DivRecQuotRem w n d --- hqr : qr.Lawful --- hr : qr.r.toNat < 2 ^ (j + 1) --- h : d <<< (j + 1) ≤ qr.r --- ⊢ (divRec { r := qr.r - d <<< (j + 1), q := qr.q ||| twoPow w (j + 1) } j).r < d --- -/ --- sorry - --- theorem divRec_eq_umod_of_lawful {qr : DivRecQuotRem w n d} (hqr : qr.Lawful) (hd : 0 < d) --- {hqr' : qr' = (divRec qr w)} : --- qr'.r = umod n d := by --- have hlawful : qr'.Lawful := by simp [hqr', divRec_lawful hqr] --- have hremainder : qr'.r < d := by --- simp [hqr'] --- apply divRec_remainder_inbounds hqr --- apply qr.r.isLt --- assumption --- have this := div_characterized_of_mul_add_toNat --- (d := d) (q := qr'.q) (n := n) (r := qr'.r) hd hremainder hlawful.def --- simp [this.2] - - end BitVec From 8ad944aaaa01f7b9cd2c56fdf5ae8391953150a1 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 5 Jul 2024 10:57:40 +0100 Subject: [PATCH 53/64] feat: I now understand why the remainder is wr and dividend is wn, this is actually important for the proof --- src/Init/Data/BitVec/Bitblast.lean | 410 +++++++++++++++++++++-------- 1 file changed, 297 insertions(+), 113 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index fe134c2be7a8..a0018fb8fd81 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -707,7 +707,297 @@ theorem div_iff_add_mod_of_lt {d n q r : BitVec w} (hd : 0 < d) · intros h apply div_characterized_of_mul_add_of_lt <;> assumption -/- # Division Recurrence for Bitblasting -/ +/-# Tons of Lemmas for Proving Bitblasting Correct -/ + + + +theorem BitVec.shiftLeft_eq_mul_twoPow (x : BitVec w) (n : Nat) : + x <<< n = x * (BitVec.twoPow w n) := by + ext i + simp + + +@[simp] +theorem BitVec.or_zero (x : BitVec w) : x ||| 0#w = x := by + ext i + simp + + +theorem BitVec.sub_le_self_of_le {x y : BitVec w} (hx : y ≤ x) : x - y ≤ x := by + simp [BitVec.lt_def, BitVec.le_def] at hx ⊢ + rw [← Nat.add_sub_assoc (by omega)] + rw [Nat.add_comm] + rw [Nat.add_sub_assoc (by omega)] + rw [Nat.add_mod] + simp only [mod_self, Nat.zero_add, mod_mod] + rw [Nat.mod_eq_of_lt] <;> omega + +theorem BitVec.sub_lt_self_of_lt_of_lt {x y : BitVec w} (hx : y < x) (hy : 0 < y): x - y < x := by + simp [BitVec.lt_def] at hx hy ⊢ + rw [← Nat.add_sub_assoc (by omega)] + rw [Nat.add_comm] + rw [Nat.add_sub_assoc (by omega)] + rw [Nat.add_mod] + simp only [mod_self, Nat.zero_add, mod_mod] + rw [Nat.mod_eq_of_lt] <;> omega + +theorem BitVec.le_iff_not_lt {x y : BitVec w} : (¬ x < y) ↔ y ≤ x := by + constructor <;> + (intro h; simp [BitVec.lt_def, BitVec.le_def] at h ⊢; omega) + +@[simp] +theorem BitVec.le_refl (x : BitVec w) : x ≤ x := by + simp [BitVec.le_def] + + +theorem BitVec.shiftLeft_mul_comm (x y : BitVec w) (n : Nat) : + x <<< n * y = x * y <<< n := by + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.mul_assoc] + congr 1 + apply BitVec.mul_comm + +theorem BitVec.shiftLeft_mul_assoc (x y : BitVec w) (n : Nat) : + x * y <<< n = (x * y) <<< n := by + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.mul_assoc] + +theorem BitVec.add_mul (x y z : BitVec w) : (y + z) * x = y * x + z * x := by + conv => + lhs + rw [BitVec.mul_comm, BitVec.mul_add] + congr 1 <;> rw [BitVec.mul_comm] + + + + +theorem BitVec.add_assoc {x y z : BitVec w} : x + y + z = x + (y + z) := by + apply eq_of_toNat_eq + simp[Nat.add_assoc] + +theorem BitVec.add_sub_assoc {m k : BitVec w} (h : k ≤ m) (n : BitVec w) : + n + m - k = n + (m - k) := by + apply BitVec.eq_of_toNat_eq + simp only [toNat_sub, toNat_add, mod_add_mod, add_mod_mod, Nat.add_assoc] + +/-- +Bitwise or of (x <<< 1) with 1 is the same as addition. +This is useful to reason in mixed-arithmetic bitwise contexts. +-/ +private theorem BitVec.shiftLeft_one_or_one_eq_shiftLeft_one_add_one {x : BitVec w} : + x <<< 1 ||| 1#w = (x <<< 1) + 1#w := by + rw [BitVec.add_eq_or_of_and_eq_zero] + ext i + simp + intro i _ hi' + omega + +theorem BitVec.add_sub_self_left {x y : BitVec w} : x + y - x = y := by + apply eq_of_toNat_eq + simp + calc + (x.toNat + y.toNat + (2 ^ w - x.toNat)) % 2 ^ w = (x.toNat + y.toNat + 2 ^ w - x.toNat) % 2 ^ w := by + rw [Nat.add_sub_assoc (Nat.le_of_lt x.isLt)] + _ = (x.toNat + y.toNat - x.toNat + 2 ^ w) % 2 ^ w := by rw [Nat.sub_add_comm]; omega + _ = (y.toNat + 2 ^ w) % 2 ^ w := by rw [Nat.add_sub_self_left] + _ = y.toNat % 2 ^ w := by simp + _ = y.toNat := by simp [Nat.mod_eq_of_lt] + +theorem BitVec.add_sub_self_right {x y : BitVec w} : x + y - y = x := by + rw [BitVec.add_comm] + rw [BitVec.add_sub_self_left] + +@[simp] +theorem BitVec.le_of_not_lt {x y : BitVec w} : ¬ x < y → y ≤ x := by + simp [BitVec.lt_def, BitVec.le_def] + +/-- +if the MSB is false, then the arithmetic value of shifting +is the same as the original value times 2. +That is, if the msb is false, then shifting by 1 does not overflow. +Can be generalized to talk about shifting by `k` if the top `k` bits are false. +-/ +theorem BitVec.toNat_shiftLeft_one_eq_mul_two_of_msb_false + (x : BitVec w) + (h : x.msb = false) : + (x <<< 1).toNat = x.toNat * 2 := by + simp only [toNat_shiftLeft] + have h := (BitVec.msb_eq_false_iff_two_mul_lt x).mp h + rw [Nat.shiftLeft_eq, Nat.mod_eq_of_lt (by omega)] + +/- upon shifting left by one, if times 2 is less than 2^w, then we cannot overflow. -/ +theorem BitVec.toNat_shiftLeft_one_eq_mul_two_of_lt + (x : BitVec w) + (hlt : x.toNat * 2 < 2 ^ w) : + (x <<< 1).toNat = x.toNat * 2 := by + simp only [toNat_shiftLeft] + rw [Nat.shiftLeft_eq, Nat.mod_eq_of_lt (by omega)] + +/- # Division Recurrence for Bitblasting (V2 )-/ + +/-- +One round of the division algorithm, that tries to perform a subtract shift. +Note that this is only called when `r.msb = false`, so we will not overflow. +This means that `r'.toNat = r.toNat *2 + q.toNat` +. +-/ +def divSubtractShift (w : Nat) + (d : BitVec w) + (r : BitVec w) + (nb : Bool) + : + BitVec w × BitVec 1 := + let r' := (r <<< 1) ||| (BitVec.ofBool nb).zeroExtend w + let q := r' < d + ⟨if q then r' - d else r', BitVec.ofBool q⟩ + +/-- +Core divsion recurrence. +We have three widths at play: +- w, the total bitwidth +- wr, the effective bitwidth of the reminder +- wn, the effective bitwidth of the dividend. +- We have the invariant that wn + wr = w. + +See that when it is called, we will know that : + - r < [2^wr = 2^(w - wn)] + which allows us to safely shift left, since it is of length n. + In particular, since 'wn' decreases in the course of the recursion, + will will allow larger and larger values, and at the step where 'wn = 0', + we will have `r < 2^w`, which is no longer sufficient to allow for a shift left. + Thus, at this step, we will stop and return a full remainder. + So, the remainder is morally of length `w - wn`. + - d > 0 + - r < d + - n.toNat >>> wr = +-/ +def divRec' (w wn: Nat) + (d : BitVec w) + (r : BitVec w) -- morally, this of length 'w - wn'. + (n : BitVec wn) + : BitVec w × BitVec wn := + match wn with + | 0 => (r, 0#0) + | wn + 1 => + let (r', qcur) := divSubtractShift w d r (n.getMsb 0) + let (r'', q) := divRec' w wn d r' (n.truncate wn) + let q' := (q.zeroExtend (wn + 1)) ||| (qcur.zeroExtend (wn + 1) <<< wn) + (r'', q') + +def checkDivRec' : Bool × Option (BitVec 4 × BitVec 4 × BitVec 4 × BitVec 4) := Id.run do + let w := 4 + let max := (Nat.pow 2 w) + let mut outputs := .none + let mut wrong := false + for n in (List.range max) do + for d in (List.range (max - 1)).map (fun n => Nat.add n 1) do + have hd : d > 0 := by sorry + let qr := divRec' w w d 0#w n + if qr.2 * d + qr.1 != n then + wrong := true + outputs := .some (n, d, qr.2, qr.1) + (wrong, outputs) + +/-- info: (false, none) -/ +#guard_msgs in #reduce checkDivRec' + +/-- # Tons of helper lemmas about the behaviour of divRec -/ + +theorem mul_two_plus_one_lt_mul_two_of_lt (n m : Nat) (hn : n < m) : + n * 2 + 1 < 2 * m := by + omega + +theorem mul_two_plus_lt_mul_two_of_lt_of_lt_two (n m b : Nat) + (hn : n < m) (hb : b < 2) : + n * 2 + b < m * 2 := by + omega + + +/-- +the LHS of the condition of divSubtractShift, +written as an arithmetic inequality. +-/ +theorem divSubtractShift_toNat_cond_lhs (w : Nat) + (d r : BitVec w) + (b : Bool) + (hk : k < w) + (hr : r.toNat < 2 ^ k) : + (r <<< 1 ||| zeroExtend w (ofBool b)).toNat = + (r.toNat * 2 + (if b then 1 else 0)) := by + have hk' : 2^k < 2^w := by + apply Nat.pow_lt_pow_of_lt (by decide) (by omega) + rcases w with rfl | w + · omega -- contradiction, k < w + · rw [← BitVec.add_eq_or_of_and_eq_zero] + · simp [Bool.toNat] + rw [Nat.shiftLeft_eq] + simp [show (2^1 = 2) by decide] + rw [Nat.mod_eq_of_lt] + · rcases b with rfl | rfl <;> simp + · rw [Nat.pow_succ] + apply mul_two_plus_lt_mul_two_of_lt_of_lt_two + + · ext i + simp only [getLsb_and, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, + getLsb_zeroExtend, getLsb_ofBool, getLsb_zero, and_eq_false_imp, and_eq_true, not_eq_true', + decide_eq_false_iff_not, Nat.not_lt, decide_eq_true_eq, and_imp] + intros hi _ hi' + omega + +/-- +the condition of divSubtractShift is true iff +the arithmetic inequality holds. +-/ +theorem divSubtractShift_cond_iff (w : Nat) + (d r : BitVec w) + (b : Bool) : + (r <<< 1 ||| zeroExtend w (ofBool b) < d) ↔ + (r.toNat * 2 + (if b then 1 else 0) < 3) := by sorry + +theorem divSubtractShift_lt_two_pow_of_lt_two_pow (w : Nat) + (d r : BitVec w) + (b : Bool) + (hk : k < w) + (hr : r.toNat < 2 ^ k) : + (divSubtractShift w d r b).1.toNat < 2 ^ (k + 1) := by + simp [divSubtractShift] + by_cases (r <<< 1 ||| zeroExtend w (ofBool b) < d) + + case pos h => sorry + case neg h => sorry +theorem divSubtractShift_lt_of_lt (w : Nat) (d r : BitVec w) (b : Bool) (hrd : r < d) : + (divSubtractShift w d r b).1 < d := by + simp [divSubtractShift] + by_cases (r <<< 1 ||| zeroExtend w (ofBool b) < d) + case pos h => sorry + case neg h => sorry + +mutual + theorem t1 (w wn : Nat) (d r : BitVec w) (n : BitVec wn) : + (divRec' w wn d r n).fst < d := by sorry + -- theorem t2 (w wn : Nat) (d r : BitVec w) (n : BitVec wn) : + -- d * (divRec' w wn d r n).snd + (divRec' w wn d r n).fst = n + theorem t2 (w : Nat) (d : BitVec w) (n : BitVec w) : + d * (divRec' w w d (0#w) n).snd + (divRec' w w d (0#w) n).fst = n := by sorry + theorem t3 (w : Nat) (d n : BitVec w) : + d.toNat * (divRec' w w d (0#w) n).snd.toNat + (divRec' w w d (0#w) n).fst.toNat < 2 ^ w := by sorry +end + +theorem divRec'_eq_udiv (w : Nat) (n : BitVec w) (d : BitVec w) (hd : 0 < d): + n.umod d = (divRec' w w d 0#w n).1 ∧ + n.udiv d = (divRec' w w d 0#w n).2 := by + have k := div_characterized_of_mul_add_of_lt (d := d) (n := n) + (q := (divRec' w w d 0#w n).2) + (r := (divRec' w w d 0#w n).1) + hd + (by apply t1) + (by apply t2) + (by apply t3) + simp [k] + +/- # (OLD OLD) Division Recurrence for Bitblasting -/ -- n = d * q + r -- Two-stage subtraction: @@ -751,11 +1041,6 @@ structure DivRecQuotRem (w : Nat) (n : BitVec w) (d : BitVec w) where deriving DecidableEq, Repr -theorem BitVec.shiftLeft_eq_mul_twoPow (x : BitVec w) (n : Nat) : - x <<< n = x * (BitVec.twoPow w n) := by - ext i - simp - /-- One round of the division algorithm, that tries to perform a subtract shift. -/ def tryDivSubtractShift (qr : DivRecQuotRem w n d) (ix : Nat) : DivRecQuotRem w n d := let r' := (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb ix)).zeroExtend w @@ -771,11 +1056,6 @@ def tryDivSubtractShift' (qr : DivRecQuotRem w n d) (ix : Nat) : DivRecQuotRem w let r' := (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb ix)).zeroExtend w { r := if r' < d then r' else r' - d, q := qr.q <<< 1 ||| (if r' < d then 0 else 1) } -@[simp] -theorem BitVec.or_zero (x : BitVec w) : x ||| 0#w = x := by - ext i - simp - theorem tryDivSubtractShift_eq_tryDivSubtractShift' (qr : DivRecQuotRem w n d) (ix : Nat) : tryDivSubtractShift qr ix = tryDivSubtractShift' qr ix := by simp [tryDivSubtractShift, tryDivSubtractShift'] @@ -784,33 +1064,6 @@ theorem tryDivSubtractShift_eq_tryDivSubtractShift' (qr : DivRecQuotRem w n d) ( · simp [hslt] · simp [hslt] -theorem BitVec.sub_le_self_of_le {x y : BitVec w} (hx : y ≤ x) : x - y ≤ x := by - simp [BitVec.lt_def, BitVec.le_def] at hx ⊢ - rw [← Nat.add_sub_assoc (by omega)] - rw [Nat.add_comm] - rw [Nat.add_sub_assoc (by omega)] - rw [Nat.add_mod] - simp only [mod_self, Nat.zero_add, mod_mod] - rw [Nat.mod_eq_of_lt] <;> omega - -theorem BitVec.sub_lt_self_of_lt_of_lt {x y : BitVec w} (hx : y < x) (hy : 0 < y): x - y < x := by - simp [BitVec.lt_def] at hx hy ⊢ - rw [← Nat.add_sub_assoc (by omega)] - rw [Nat.add_comm] - rw [Nat.add_sub_assoc (by omega)] - rw [Nat.add_mod] - simp only [mod_self, Nat.zero_add, mod_mod] - rw [Nat.mod_eq_of_lt] <;> omega - -theorem BitVec.le_iff_not_lt {x y : BitVec w} : (¬ x < y) ↔ y ≤ x := by - constructor <;> - (intro h; simp [BitVec.lt_def, BitVec.le_def] at h ⊢; omega) - -@[simp] -theorem BitVec.le_refl (x : BitVec w) : x ≤ x := by - simp [BitVec.le_def] - - /-- The tryDivSubtractShift's remainder is upper bounded by `r << 1 | 1`. -/ theorem tryDivSubtractShift_remainder_lt_shiftLeft_one_or_one {qr : DivRecQuotRem w n d} {ix : Nat} : (tryDivSubtractShift qr ix).r ≤ (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb ix)).zeroExtend w := by @@ -927,27 +1180,6 @@ def divRec (qr : DivRecQuotRem w n d) (j : Nat) : | j + 1 => divRec qr j tryDivSubtractShift qr' (w - 1 - j) -theorem BitVec.shiftLeft_mul_comm (x y : BitVec w) (n : Nat) : - x <<< n * y = x * y <<< n := by - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.mul_assoc] - congr 1 - apply BitVec.mul_comm - -theorem BitVec.shiftLeft_mul_assoc (x y : BitVec w) (n : Nat) : - x * y <<< n = (x * y) <<< n := by - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.mul_assoc] - -theorem BitVec.add_mul (x y z : BitVec w) : (y + z) * x = y * x + z * x := by - conv => - lhs - rw [BitVec.mul_comm, BitVec.mul_add] - congr 1 <;> rw [BitVec.mul_comm] - - /-- TODO: what's a good theorem name? If the LSB is false, then shifting to (w - 1) is the same as shifting to w and then right shifting 1. @@ -982,56 +1214,11 @@ private theorem BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_add_zeroExtend_getL intros i _ hi' omega -theorem BitVec.add_assoc {x y z : BitVec w} : x + y + z = x + (y + z) := by - apply eq_of_toNat_eq - simp[Nat.add_assoc] - -theorem BitVec.add_sub_assoc {m k : BitVec w} (h : k ≤ m) (n : BitVec w) : - n + m - k = n + (m - k) := by - apply BitVec.eq_of_toNat_eq - simp only [toNat_sub, toNat_add, mod_add_mod, add_mod_mod, Nat.add_assoc] - -/-- -Bitwise or of (x <<< 1) with 1 is the same as addition. -This is useful to reason in mixed-arithmetic bitwise contexts. --/ -private theorem BitVec.shiftLeft_one_or_one_eq_shiftLeft_one_add_one {x : BitVec w} : - x <<< 1 ||| 1#w = (x <<< 1) + 1#w := by - rw [BitVec.add_eq_or_of_and_eq_zero] - ext i - simp - intro i _ hi' - omega - -theorem BitVec.add_sub_self_left {x y : BitVec w} : x + y - x = y := by - apply eq_of_toNat_eq - simp - calc - (x.toNat + y.toNat + (2 ^ w - x.toNat)) % 2 ^ w = (x.toNat + y.toNat + 2 ^ w - x.toNat) % 2 ^ w := by - rw [Nat.add_sub_assoc (Nat.le_of_lt x.isLt)] - _ = (x.toNat + y.toNat - x.toNat + 2 ^ w) % 2 ^ w := by rw [Nat.sub_add_comm]; omega - _ = (y.toNat + 2 ^ w) % 2 ^ w := by rw [Nat.add_sub_self_left] - _ = y.toNat % 2 ^ w := by simp - _ = y.toNat := by simp [Nat.mod_eq_of_lt] - -theorem BitVec.add_sub_self_right {x y : BitVec w} : x + y - y = x := by - rw [BitVec.add_comm] - rw [BitVec.add_sub_self_left] - -@[simp] -theorem BitVec.le_of_not_lt {x y : BitVec w} : ¬ x < y → y ≤ x := by - simp [BitVec.lt_def, BitVec.le_def] - --- theorem div_iff_add_mod_of_lt {d n q r : BitVec w} (hd : 0 < d) --- (hrd : r < d) --- (hlt : d.toNat * q.toNat + r.toNat < 2^w) : --- (n.udiv d = q ∧ n.umod d = r) ↔ (d * q + r = n) := by - theorem divRec_correct {w : Nat} {n d : BitVec w} {qr : DivRecQuotRem w n d} {j : Nat} (hj : j ≤ w - 1) (hqrd : qr.r < d) - (hrn : qr.r < d) - (hqrn : n >>> (w - j) = qr.q * d + qr.r) : + (hqrn : n >>> (w - j) = qr.q * d + qr.r) + : ((n >>> ((w - 1) - j) = (divRec qr j).q * d + (divRec qr j).r)) ∧ (d.toNat * (divRec qr j).q.toNat + (divRec qr j).r.toNat < 2^w) ∧ (divRec qr j).r < d := by @@ -1098,13 +1285,13 @@ theorem divRec_correct {w : Nat} {n d : BitVec w} {qr : DivRecQuotRem w n d} {j theorem div_eq_divRec (n d : BitVec w) (hd : d > 0) : - n.udiv d = (divRec (w := w) (n := n) (d := d) { r := 0, q := 0 } (w - 1)).q ∧ - n.umod d = (divRec (w := w) (n := n) (d := d) { r := 0, q := 0 } (w - 1)).r := by + let qr := divRec (w := w) (n := n) (d := d) { r := 0, q := 0 } (w - 1) + n.udiv d = qr.q ∧ + n.umod d = qr.r := by obtain ⟨h₁, h₂, h₃⟩ := divRec_correct (w := w) (n := n) (d := d) (j := w - 1) (qr := { r := 0, q := 0}) (by omega) (by simpa using hd) - (by simpa using hd) - (by sorry) + (by simp [show (w - (w - 1) = 1) by omega];) simp at h₃ simp at h₂ simp at h₁ @@ -1134,9 +1321,6 @@ def checkDivRec : Bool × Array String := Id.run do /-- info: (false, { data := [] }) -/ #guard_msgs in #reduce checkDivRec --- theorem divRec_n (qr : DivRecQuotRem w n d) : --- d * (tryDivSubtractShift qr j).q + (tryDivSubtractShift q j).r = n >>> (j + 1) - -- invariants: -- 1) qr.r < d. theorem div_rec_7_2 : From f1559cd6bbe2fdf7124a7f7fb736a1c894a042af Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 5 Jul 2024 12:32:38 +0100 Subject: [PATCH 54/64] chore: lay everything out as a large structure with all preconditions --- src/Init/Data/BitVec/Bitblast.lean | 208 +++++++++++++++++++++++------ 1 file changed, 170 insertions(+), 38 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index a0018fb8fd81..c31e5d48bd87 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -835,6 +835,115 @@ theorem BitVec.toNat_shiftLeft_one_eq_mul_two_of_lt simp only [toNat_shiftLeft] rw [Nat.shiftLeft_eq, Nat.mod_eq_of_lt (by omega)] +structure DivRemInput (w wr wn : Nat) + (n : BitVec w) + (d : BitVec w) + (q r : BitVec w) : Type where + hwr : wr ≤ w + hwn : wn ≤ w + hwrn : wr + wn = w + hd : 0 < d + hrd : r.toNat < d.toNat + hrwr : r.toNat < 2^wr + hqwr : q.toNat < 2^wr + hdiv : (n >>> wn).toNat = d.toNat * q.toNat + r.toNat + +/-- In a valid DivRemInput, it is implied that `w > 0`. -/ +def DivRemInput.hw (h : DivRemInput w wr wn n d q r) : 0 < w := by + have hd := h.hd + rcases w with rfl | w + · have hcontra : d = 0#0 := by apply Subsingleton.elim + rw [hcontra] at hd + simp at hd + · omega + +/-- +Make an initial state of the DivRemInput, for a given choice of +`n, d, q, r`. -/ +def DivRemInput_init (w : Nat) (n d : BitVec w) (hw : 0 < w) (hd : 0#w < d) : + DivRemInput w 0 w n d 0#w 0#w := { + hwr := by omega, + hwn := by omega, + hwrn := by omega, + hd := by assumption + hrd := by simp [BitVec.lt_def] at hd ⊢; assumption + hrwr := by simp, + hqwr := by simp, + hdiv := by + simp; + rw [Nat.shiftRight_eq_div_pow] + apply Nat.div_eq_of_lt n.isLt +} + +theorem DivRemInput_implies_udiv_urem + (h : DivRemInput w w 0 n d q r) : + n.udiv d = q ∧ n.umod d = r := by + apply div_characterized_of_mul_add_toNat + (n := n) (d := d) (q := q) (r := r) + (h.hd) + (h.hrd) + (by + have hdiv := h.hdiv + simp at hdiv + omega + ) + +structure ShiftSubtractInput (w wr wn : Nat) (n d q r : BitVec w) + extends DivRemInput w wr wn n d q r : Type where + hwn_lt : 0 < wn -- we can only call this function legally if we have dividend bits. + +/-- In the shift subtract input, we have one more bit to spare, +so we do not overflow. -/ +def ShiftSubtractInput.wr_add_one_le_w + (h : ShiftSubtractInput w wr wn n d q r) : wr + 1 ≤ w := by + have hwrn := h.hwrn + have hwn_lt := h.hwn_lt + omega + +/-- In the shift subtract input, we have one more bit to spare, +so we do not overflow. -/ +def ShiftSubtractInput.wr_le_wr_sub_one + (h : ShiftSubtractInput w wr wn n d q r) : wr ≤ w - 1 := by + have hw := h.hw + have hwrn := h.hwrn + have hwn_lt := h.hwn_lt + omega + +/-- If we have extra bits to spare in `n`, +then the div rem input can be converted into a shift subtract input +to run a round of the shift subtracter. -/ +def DivRemInput.toShiftSubtractInput + (h : DivRemInput w wr (wn + 1) n d q r) : + ShiftSubtractInput w wr (wn + 1) n d q r := { + hwr := h.hwr, + hwn := h.hwn, + hwrn := by have := h.hwrn; omega, + hd := h.hd, + hrd := h.hrd, + hrwr := h.hrwr, + hqwr := h.hqwr, + hdiv := h.hdiv, + hwn_lt := by omega + } + +def ShiftSubtractInput.nmsb (_ : ShiftSubtractInput w wr wn n d q r) : + BitVec 1 := + BitVec.ofBool <| n.getMsb wn + +def DivRemInput.wr_eq_w_of_wn_eq_zero + (h : DivRemInput w wr 0 n d q r) : DivRemInput w w 0 n d q r := + { + hwr := by have := h.hwr; omega, + hwn := h.hwn, + hwrn := by have := h.hwrn; omega, + hd := h.hd, + hrd := h.hrd, + hrwr := by have := h.hrwr; omega, + hqwr := by have := h.hqwr; omega, + hdiv := h.hdiv + } + + /- # Division Recurrence for Bitblasting (V2 )-/ /-- @@ -843,13 +952,9 @@ Note that this is only called when `r.msb = false`, so we will not overflow. This means that `r'.toNat = r.toNat *2 + q.toNat` . -/ -def divSubtractShift (w : Nat) - (d : BitVec w) - (r : BitVec w) - (nb : Bool) - : - BitVec w × BitVec 1 := - let r' := (r <<< 1) ||| (BitVec.ofBool nb).zeroExtend w +def divSubtractShift (h : ShiftSubtractInput w wr wn n d q r) : + DivRemInput w (wr + 1) (wn - 1) n d q r := + let r' := (r <<< 1) ||| (h.nmsb).zeroExtend w let q := r' < d ⟨if q then r' - d else r', BitVec.ofBool q⟩ @@ -873,18 +978,13 @@ See that when it is called, we will know that : - r < d - n.toNat >>> wr = -/ -def divRec' (w wn: Nat) - (d : BitVec w) - (r : BitVec w) -- morally, this of length 'w - wn'. - (n : BitVec wn) - : BitVec w × BitVec wn := +def divRec' (h : DivRemInput w wr wn n d q r) : + DivRemOutput w n d := match wn with - | 0 => (r, 0#0) - | wn + 1 => - let (r', qcur) := divSubtractShift w d r (n.getMsb 0) - let (r'', q) := divRec' w wn d r' (n.truncate wn) - let q' := (q.zeroExtend (wn + 1)) ||| (qcur.zeroExtend (wn + 1) <<< wn) - (r'', q') + | 0 => h.wr_eq_w_of_wn_eq_zero + | _ + 1 => + let new := divSubtractShift h.toShiftSubtractInput + divRec' new def checkDivRec' : Bool × Option (BitVec 4 × BitVec 4 × BitVec 4 × BitVec 4) := Id.run do let w := 4 @@ -894,8 +994,12 @@ def checkDivRec' : Bool × Option (BitVec 4 × BitVec 4 × BitVec 4 × BitVec 4) for n in (List.range max) do for d in (List.range (max - 1)).map (fun n => Nat.add n 1) do have hd : d > 0 := by sorry - let qr := divRec' w w d 0#w n - if qr.2 * d + qr.1 != n then + have hd' : d < 2 ^ w := by sorry + let init := DivRemInput_init w (BitVec.ofNat w n) (BitVec.ofNat w d) + (by omega) + (by simp; rw [Nat.mod_eq_of_lt]; omega; omega) + let qr := divRec' init + if qr.q != n then wrong := true outputs := .some (n, d, qr.2, qr.1) (wrong, outputs) @@ -903,24 +1007,39 @@ def checkDivRec' : Bool × Option (BitVec 4 × BitVec 4 × BitVec 4 × BitVec 4) /-- info: (false, none) -/ #guard_msgs in #reduce checkDivRec' -/-- # Tons of helper lemmas about the behaviour of divRec -/ - -theorem mul_two_plus_one_lt_mul_two_of_lt (n m : Nat) (hn : n < m) : - n * 2 + 1 < 2 * m := by - omega - -theorem mul_two_plus_lt_mul_two_of_lt_of_lt_two (n m b : Nat) - (hn : n < m) (hb : b < 2) : - n * 2 + b < m * 2 := by - omega - +/- # Tons of helper lemmas about the behaviour of divRec -/ +/-- +The arithmetic version of: +If `n : Bitvec w` has only the low `k < w` bits set, +then `(n <<< 1 | b)` does not overflow. +-/ +theorem mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two + (hn : n < 2 ^ k) (hb : b < 2) (hk : k < w) : + n * 2 + b < 2 ^ w := by + have : k + 1 ≤ w := by omega + have : 2^(k + 1) ≤ 2 ^w := by + apply Nat.pow_le_pow_of_le_right (by decide) (by assumption) + have : n ≤ 2 ^k - 1 := by omega + have : n * 2 ≤ 2^k * 2 - 2 := by omega + have : n * 2 + b ≤ 2^k * 2 - 1 := by omega + have : n * 2 + b ≤ 2 ^(k + 1) - 1 := by omega + have : n * 2 + b ≤ 2 ^w - 1 := by omega + have : n * 2 + b < 2^w := by omega + assumption + +-- 0 a | a < 2 +-- a b -- 2a + b < 4 +-- k < w + 1 +-- 2^k ≤ 2^w +-- x ≤ 2^w +-- x * w ≤ 2^w + 1 /-- the LHS of the condition of divSubtractShift, written as an arithmetic inequality. -/ theorem divSubtractShift_toNat_cond_lhs (w : Nat) - (d r : BitVec w) + (r : BitVec w) (b : Bool) (hk : k < w) (hr : r.toNat < 2 ^ k) : @@ -936,9 +1055,10 @@ theorem divSubtractShift_toNat_cond_lhs (w : Nat) simp [show (2^1 = 2) by decide] rw [Nat.mod_eq_of_lt] · rcases b with rfl | rfl <;> simp - · rw [Nat.pow_succ] - apply mul_two_plus_lt_mul_two_of_lt_of_lt_two - + · apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two + · exact hr + · rcases b <;> decide + · assumption · ext i simp only [getLsb_and, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, getLsb_zeroExtend, getLsb_ofBool, getLsb_zero, and_eq_false_imp, and_eq_true, not_eq_true', @@ -950,11 +1070,23 @@ theorem divSubtractShift_toNat_cond_lhs (w : Nat) the condition of divSubtractShift is true iff the arithmetic inequality holds. -/ -theorem divSubtractShift_cond_iff (w : Nat) - (d r : BitVec w) +theorem divSubtractShift_cond_iff (w wr wn: Nat) + (hwr : wr < w) + (hwn : wn < w) + (hwrn : wr + wn = w) + (d n q r : BitVec w) + (hr : r.toNat < 2^wr) -- r : BitVec wr + (hq : q.toNat < 2^wr) -- q : BitVec wr + (hn : n.toNat < 2^wn) -- n : BitVec wn + (hdiv : (n >>> wr).toNat = r.toNat + d.toNat * q.toNat) (b : Bool) : (r <<< 1 ||| zeroExtend w (ofBool b) < d) ↔ - (r.toNat * 2 + (if b then 1 else 0) < 3) := by sorry + (r.toNat * 2 + (if b then 1 else 0) < d.toNat) := by + constructor + · intros h + sorry + · intros h + sorry theorem divSubtractShift_lt_two_pow_of_lt_two_pow (w : Nat) (d r : BitVec w) From 809263c33c70546c4f6dcf1c7be3336a1f0043d4 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 5 Jul 2024 13:19:10 +0100 Subject: [PATCH 55/64] chore: more progress, first settle on using Bool.toNat everywhere --- src/Init/Data/BitVec/Bitblast.lean | 274 ++++++++++++++++++++--------- 1 file changed, 192 insertions(+), 82 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index c31e5d48bd87..00d39fb563cd 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -835,10 +835,65 @@ theorem BitVec.toNat_shiftLeft_one_eq_mul_two_of_lt simp only [toNat_shiftLeft] rw [Nat.shiftLeft_eq, Nat.mod_eq_of_lt (by omega)] +/-- +The arithmetic version of: +If `n : Bitvec w` has only the low `k < w` bits set, +then `(n <<< 1 | b)` does not overflow. +-/ +theorem mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two + (hn : n < 2 ^ k) (hb : b < 2) (hk : k < w) : + n * 2 + b < 2 ^ w := by + have : k + 1 ≤ w := by omega + have : 2^(k + 1) ≤ 2 ^w := by + apply Nat.pow_le_pow_of_le_right (by decide) (by assumption) + have : n ≤ 2 ^k - 1 := by omega + have : n * 2 ≤ 2^k * 2 - 2 := by omega + have : n * 2 + b ≤ 2^k * 2 - 1 := by omega + have : n * 2 + b ≤ 2 ^(k + 1) - 1 := by omega + have : n * 2 + b ≤ 2 ^w - 1 := by omega + have : n * 2 + b < 2^w := by omega + assumption + +/-- +If `n : Bitvec w` has only the low `k < w` bits set, +then `(n <<< 1 | b)` does not overflow, and we can compute its value +as a multiply and add. +-/ +theorem toNat_shiftLeft_or_zeroExtend_ofBool_eq (w : Nat) + (r : BitVec w) + (b : Bool) + (hk : k < w) + (hr : r.toNat < 2 ^ k) : + (r <<< 1 ||| zeroExtend w (ofBool b)).toNat = + (r.toNat * 2 + (if b then 1 else 0)) := by + have hk' : 2^k < 2^w := by + apply Nat.pow_lt_pow_of_lt (by decide) (by omega) + rcases w with rfl | w + · omega -- contradiction, k < w + · rw [← BitVec.add_eq_or_of_and_eq_zero] + · simp [Bool.toNat] + rw [Nat.shiftLeft_eq] + simp [show (2^1 = 2) by decide] + rw [Nat.mod_eq_of_lt] + · rcases b with rfl | rfl <;> simp + · apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two + · exact hr + · rcases b <;> decide + · assumption + · ext i + simp only [getLsb_and, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, + getLsb_zeroExtend, getLsb_ofBool, getLsb_zero, and_eq_false_imp, and_eq_true, not_eq_true', + decide_eq_false_iff_not, Nat.not_lt, decide_eq_true_eq, and_imp] + intros hi _ hi' + omega + + +/- # DivRem, V3 -/ structure DivRemInput (w wr wn : Nat) (n : BitVec w) - (d : BitVec w) - (q r : BitVec w) : Type where + (d : BitVec w) : Type where + q : BitVec w + r : BitVec w hwr : wr ≤ w hwn : wn ≤ w hwrn : wr + wn = w @@ -849,7 +904,7 @@ structure DivRemInput (w wr wn : Nat) hdiv : (n >>> wn).toNat = d.toNat * q.toNat + r.toNat /-- In a valid DivRemInput, it is implied that `w > 0`. -/ -def DivRemInput.hw (h : DivRemInput w wr wn n d q r) : 0 < w := by +def DivRemInput.hw (h : DivRemInput w wr wn n d) : 0 < w := by have hd := h.hd rcases w with rfl | w · have hcontra : d = 0#0 := by apply Subsingleton.elim @@ -861,7 +916,9 @@ def DivRemInput.hw (h : DivRemInput w wr wn n d q r) : 0 < w := by Make an initial state of the DivRemInput, for a given choice of `n, d, q, r`. -/ def DivRemInput_init (w : Nat) (n d : BitVec w) (hw : 0 < w) (hd : 0#w < d) : - DivRemInput w 0 w n d 0#w 0#w := { + DivRemInput w 0 w n d:= { + q := 0#w + r := 0#w hwr := by omega, hwn := by omega, hwrn := by omega, @@ -876,10 +933,10 @@ def DivRemInput_init (w : Nat) (n d : BitVec w) (hw : 0 < w) (hd : 0#w < d) : } theorem DivRemInput_implies_udiv_urem - (h : DivRemInput w w 0 n d q r) : - n.udiv d = q ∧ n.umod d = r := by + (h : DivRemInput w w 0 n d) : + n.udiv d = h.q ∧ n.umod d = h.r := by apply div_characterized_of_mul_add_toNat - (n := n) (d := d) (q := q) (r := r) + (n := n) (d := d) (q := h.q) (r := h.r) (h.hd) (h.hrd) (by @@ -888,14 +945,14 @@ theorem DivRemInput_implies_udiv_urem omega ) -structure ShiftSubtractInput (w wr wn : Nat) (n d q r : BitVec w) - extends DivRemInput w wr wn n d q r : Type where +structure ShiftSubtractInput (w wr wn : Nat) (n d: BitVec w) + extends DivRemInput w wr wn n d : Type where hwn_lt : 0 < wn -- we can only call this function legally if we have dividend bits. /-- In the shift subtract input, we have one more bit to spare, so we do not overflow. -/ def ShiftSubtractInput.wr_add_one_le_w - (h : ShiftSubtractInput w wr wn n d q r) : wr + 1 ≤ w := by + (h : ShiftSubtractInput w wr wn n d) : wr + 1 ≤ w := by have hwrn := h.hwrn have hwn_lt := h.hwn_lt omega @@ -903,7 +960,7 @@ def ShiftSubtractInput.wr_add_one_le_w /-- In the shift subtract input, we have one more bit to spare, so we do not overflow. -/ def ShiftSubtractInput.wr_le_wr_sub_one - (h : ShiftSubtractInput w wr wn n d q r) : wr ≤ w - 1 := by + (h : ShiftSubtractInput w wr wn n d) : wr ≤ w - 1 := by have hw := h.hw have hwrn := h.hwrn have hwn_lt := h.hwn_lt @@ -913,8 +970,10 @@ def ShiftSubtractInput.wr_le_wr_sub_one then the div rem input can be converted into a shift subtract input to run a round of the shift subtracter. -/ def DivRemInput.toShiftSubtractInput - (h : DivRemInput w wr (wn + 1) n d q r) : - ShiftSubtractInput w wr (wn + 1) n d q r := { + (h : DivRemInput w wr (wn + 1) n d) : + ShiftSubtractInput w wr (wn + 1) n d := { + q := h.q, + r := h.r hwr := h.hwr, hwn := h.hwn, hwrn := by have := h.hwrn; omega, @@ -926,13 +985,14 @@ def DivRemInput.toShiftSubtractInput hwn_lt := by omega } -def ShiftSubtractInput.nmsb (_ : ShiftSubtractInput w wr wn n d q r) : - BitVec 1 := - BitVec.ofBool <| n.getMsb wn +def ShiftSubtractInput.nmsb (_ : ShiftSubtractInput w wr wn n d) : + Bool := n.getLsb (wn - 1) def DivRemInput.wr_eq_w_of_wn_eq_zero - (h : DivRemInput w wr 0 n d q r) : DivRemInput w w 0 n d q r := + (h : DivRemInput w wr 0 n d) : DivRemInput w w 0 n d := { + q := h.q, + r := h.r, hwr := by have := h.hwr; omega, hwn := h.hwn, hwrn := by have := h.hwrn; omega, @@ -946,17 +1006,117 @@ def DivRemInput.wr_eq_w_of_wn_eq_zero /- # Division Recurrence for Bitblasting (V2 )-/ +def concatBit' (x : BitVec w) (b : Bool) : BitVec w := + x <<< 1 ||| (BitVec.ofBool b).zeroExtend w + +theorem concatBit'_lt (x : BitVec w) (b : Bool) : + (concatBit' x b).toNat < 2 ^ w := (concatBit' x b).isLt + +theorem toNat_concatBit'_eq (x : BitVec w) (b : Bool) {k : Nat} + (hkw : k < w) (hkx : x.toNat < 2 ^ k) : + (concatBit' x b).toNat = x.toNat * 2 + (if b then 1 else 0) := by + simp only [concatBit'] + rw [toNat_shiftLeft_or_zeroExtend_ofBool_eq (k := k)] + · omega + · omega + +#print axioms toNat_concatBit'_eq + +theorem toNat_concatBit'_lt (x : BitVec w) (b : Bool) {k : Nat} + (hkw : k < w) (hkx : x.toNat < 2 ^ k) : + (concatBit' x b).toNat < 2 ^ (k + 1) := by + rw [toNat_concatBit'_eq x b hkw hkx] + apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two hkx + · rcases b with rfl | rfl <;> decide + · omega + +private theorem BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_or_zeroExtend_getLsb + {x : BitVec w} {k : Nat} (hk' : 0 < k) : + x >>> (k - 1) = ((x >>> k <<< 1) ||| ((BitVec.ofBool (x.getLsb (k - 1))).zeroExtend w)) := by + ext i + simp only [getLsb_ushiftRight, getLsb_or, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, + getLsb_zeroExtend, getLsb_ofBool] + by_cases (i : Nat) < 1 + case pos h => + simp only [h, decide_True, Bool.not_true, Bool.false_and] + have hi : (i : Nat) = 0 := by omega + simp [hi] + case neg h => + simp only [h, decide_False, Bool.not_false, Bool.true_and] + have hi : (i : Nat) ≠ 0 := by omega + simp only [hi, decide_False, Bool.false_and, Bool.or_false] + congr 1 + omega + +theorem ShiftSubtractInput.n_shiftr_wl_minus_one_eq_n_shiftr_wl_or_nmsb + (h : ShiftSubtractInput w wr wn n d) : + n >>> (wn - 1) = (n >>> wn).concatBit' (ShiftSubtractInput.nmsb h) := by + rw [concatBit'] + rw [ShiftSubtractInput.nmsb] + rw [BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_or_zeroExtend_getLsb] + have hwn_lt := h.hwn_lt + omega + +theorem ShiftSubtractInput.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb + (h : ShiftSubtractInput w wr wn n d) {k : Nat} : + n.toNat >>> (wn - 1) = (n.toNat >>> wn) * 2 + if (n.getMsb wn) then 1 else 0 := by + have hn := ShiftSubtractInput.n_shiftr_wl_minus_one_eq_n_shiftr_wl_or_nmsb h + obtain hn : (n >>> (wn - 1)).toNat = ((n >>> wn).concatBit' h.nmsb).toNat := by + simp [hn] + simp at hn + rw [toNat_concatBit'_eq] at hn + repeat sorry + -- rw [BitVec.toNat_shiftRight] + /-- One round of the division algorithm, that tries to perform a subtract shift. Note that this is only called when `r.msb = false`, so we will not overflow. This means that `r'.toNat = r.toNat *2 + q.toNat` . + +TODO: think of isolating the pattern as `concatBit'`. -/ -def divSubtractShift (h : ShiftSubtractInput w wr wn n d q r) : - DivRemInput w (wr + 1) (wn - 1) n d q r := - let r' := (r <<< 1) ||| (h.nmsb).zeroExtend w - let q := r' < d - ⟨if q then r' - d else r', BitVec.ofBool q⟩ +def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : + DivRemInput w (wr + 1) (wn - 1) n d := + let r' := concatBit' h.r h.nmsb + let qlo := r' < d + let q := h.q.concatBit' qlo + if hqlo : qlo then { + q := q, + r := r', + hwr := by + have := h.hwr + have := h.wr_add_one_le_w + omega, + hwn := by + have := h.hwn + omega, + hwrn := by + have := h.hwrn + have := h.wr_add_one_le_w + omega, + hd := h.hd, + hrd := by + simp [qlo] at hqlo + simp [BitVec.lt_def] at hqlo + assumption, + hrwr := by + simp [r'] + apply toNat_concatBit'_lt + · exact h.wr_add_one_le_w + · exact h.hrwr, + hqwr := by + simp [q] + apply toNat_concatBit'_lt + · exact h.wr_add_one_le_w + · exact h.hqwr, + hdiv := by + simp [r'] + rw [toNat_concatBit'_eq] + · simp only [hqlo, decide_True, ↓reduceIte] + sorry + repeat sorry -- HERE + } else sorry /-- Core divsion recurrence. @@ -978,14 +1138,20 @@ See that when it is called, we will know that : - r < d - n.toNat >>> wr = -/ -def divRec' (h : DivRemInput w wr wn n d q r) : - DivRemOutput w n d := +def divRec' (h : DivRemInput w wr wn n d) : + DivRemInput w w 0 n d := match wn with | 0 => h.wr_eq_w_of_wn_eq_zero | _ + 1 => let new := divSubtractShift h.toShiftSubtractInput divRec' new +theorem divRec'_correct (n d : BitVec w) (hw : 0 < w) (hd : 0 < d) : + let out := divRec' (DivRemInput_init w n d hw hd) + n.udiv d = out.q ∧ n.umod d = out.r := by + simp + apply DivRemInput_implies_udiv_urem + def checkDivRec' : Bool × Option (BitVec 4 × BitVec 4 × BitVec 4 × BitVec 4) := Id.run do let w := 4 let max := (Nat.pow 2 w) @@ -1004,67 +1170,11 @@ def checkDivRec' : Bool × Option (BitVec 4 × BitVec 4 × BitVec 4 × BitVec 4) outputs := .some (n, d, qr.2, qr.1) (wrong, outputs) -/-- info: (false, none) -/ -#guard_msgs in #reduce checkDivRec' +-- /-- info: (false, none) -/ +-- #guard_msgs in #reduce checkDivRec' /- # Tons of helper lemmas about the behaviour of divRec -/ -/-- -The arithmetic version of: -If `n : Bitvec w` has only the low `k < w` bits set, -then `(n <<< 1 | b)` does not overflow. --/ -theorem mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two - (hn : n < 2 ^ k) (hb : b < 2) (hk : k < w) : - n * 2 + b < 2 ^ w := by - have : k + 1 ≤ w := by omega - have : 2^(k + 1) ≤ 2 ^w := by - apply Nat.pow_le_pow_of_le_right (by decide) (by assumption) - have : n ≤ 2 ^k - 1 := by omega - have : n * 2 ≤ 2^k * 2 - 2 := by omega - have : n * 2 + b ≤ 2^k * 2 - 1 := by omega - have : n * 2 + b ≤ 2 ^(k + 1) - 1 := by omega - have : n * 2 + b ≤ 2 ^w - 1 := by omega - have : n * 2 + b < 2^w := by omega - assumption - --- 0 a | a < 2 --- a b -- 2a + b < 4 --- k < w + 1 --- 2^k ≤ 2^w --- x ≤ 2^w --- x * w ≤ 2^w + 1 -/-- -the LHS of the condition of divSubtractShift, -written as an arithmetic inequality. --/ -theorem divSubtractShift_toNat_cond_lhs (w : Nat) - (r : BitVec w) - (b : Bool) - (hk : k < w) - (hr : r.toNat < 2 ^ k) : - (r <<< 1 ||| zeroExtend w (ofBool b)).toNat = - (r.toNat * 2 + (if b then 1 else 0)) := by - have hk' : 2^k < 2^w := by - apply Nat.pow_lt_pow_of_lt (by decide) (by omega) - rcases w with rfl | w - · omega -- contradiction, k < w - · rw [← BitVec.add_eq_or_of_and_eq_zero] - · simp [Bool.toNat] - rw [Nat.shiftLeft_eq] - simp [show (2^1 = 2) by decide] - rw [Nat.mod_eq_of_lt] - · rcases b with rfl | rfl <;> simp - · apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two - · exact hr - · rcases b <;> decide - · assumption - · ext i - simp only [getLsb_and, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, - getLsb_zeroExtend, getLsb_ofBool, getLsb_zero, and_eq_false_imp, and_eq_true, not_eq_true', - decide_eq_false_iff_not, Nat.not_lt, decide_eq_true_eq, and_imp] - intros hi _ hi' - omega /-- the condition of divSubtractShift is true iff From 25fa2b46551ddd11f6d5404ce3a1710dc39219f9 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 5 Jul 2024 13:42:19 +0100 Subject: [PATCH 56/64] wrap up one branch of proof, now have second branch left --- src/Init/Data/BitVec/Bitblast.lean | 150 +++++++++-------------------- 1 file changed, 48 insertions(+), 102 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 00d39fb563cd..222e706a7fd8 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -865,15 +865,18 @@ theorem toNat_shiftLeft_or_zeroExtend_ofBool_eq (w : Nat) (hk : k < w) (hr : r.toNat < 2 ^ k) : (r <<< 1 ||| zeroExtend w (ofBool b)).toNat = - (r.toNat * 2 + (if b then 1 else 0)) := by + (r.toNat * 2 + b.toNat) := by + have : b.toNat = if b then 1 else 0 := by rcases b <;> rfl + rw [this] have hk' : 2^k < 2^w := by apply Nat.pow_lt_pow_of_lt (by decide) (by omega) rcases w with rfl | w · omega -- contradiction, k < w · rw [← BitVec.add_eq_or_of_and_eq_zero] - · simp [Bool.toNat] + · simp only [toNat_add, toNat_shiftLeft, toNat_truncate, toNat_ofBool, toNat, add_mod_mod, + mod_add_mod] rw [Nat.shiftLeft_eq] - simp [show (2^1 = 2) by decide] + simp only [show (2 ^ 1 = 2) by decide] rw [Nat.mod_eq_of_lt] · rcases b with rfl | rfl <;> simp · apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two @@ -887,7 +890,6 @@ theorem toNat_shiftLeft_or_zeroExtend_ofBool_eq (w : Nat) intros hi _ hi' omega - /- # DivRem, V3 -/ structure DivRemInput (w wr wn : Nat) (n : BitVec w) @@ -901,7 +903,7 @@ structure DivRemInput (w wr wn : Nat) hrd : r.toNat < d.toNat hrwr : r.toNat < 2^wr hqwr : q.toNat < 2^wr - hdiv : (n >>> wn).toNat = d.toNat * q.toNat + r.toNat + hdiv : n.toNat >>> wn = d.toNat * q.toNat + r.toNat /-- In a valid DivRemInput, it is implied that `w > 0`. -/ def DivRemInput.hw (h : DivRemInput w wr wn n d) : 0 < w := by @@ -1012,9 +1014,9 @@ def concatBit' (x : BitVec w) (b : Bool) : BitVec w := theorem concatBit'_lt (x : BitVec w) (b : Bool) : (concatBit' x b).toNat < 2 ^ w := (concatBit' x b).isLt -theorem toNat_concatBit'_eq (x : BitVec w) (b : Bool) {k : Nat} +theorem toNat_concatBit'_eq (x : BitVec w) (b : Bool) (k : Nat) (hkw : k < w) (hkx : x.toNat < 2 ^ k) : - (concatBit' x b).toNat = x.toNat * 2 + (if b then 1 else 0) := by + (concatBit' x b).toNat = x.toNat * 2 + b.toNat:= by simp only [concatBit'] rw [toNat_shiftLeft_or_zeroExtend_ofBool_eq (k := k)] · omega @@ -1022,10 +1024,10 @@ theorem toNat_concatBit'_eq (x : BitVec w) (b : Bool) {k : Nat} #print axioms toNat_concatBit'_eq -theorem toNat_concatBit'_lt (x : BitVec w) (b : Bool) {k : Nat} +theorem toNat_concatBit'_lt (x : BitVec w) (b : Bool) (k : Nat) (hkw : k < w) (hkx : x.toNat < 2 ^ k) : (concatBit' x b).toNat < 2 ^ (k + 1) := by - rw [toNat_concatBit'_eq x b hkw hkx] + rw [toNat_concatBit'_eq x b k hkw hkx] apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two hkx · rcases b with rfl | rfl <;> decide · omega @@ -1057,16 +1059,38 @@ theorem ShiftSubtractInput.n_shiftr_wl_minus_one_eq_n_shiftr_wl_or_nmsb have hwn_lt := h.hwn_lt omega +/-- +Shifting right by `n < w` yields a bitvector whose value +is less than `2^(w - n)` +-/ +theorem BitVec.ushiftRight_lt (x : BitVec w) (n : Nat) (hn : n ≤ w) : + (x >>> n).toNat < 2 ^ (w - n) := by + rw [toNat_ushiftRight] + rw [shiftRight_eq_div_pow] + rw [Nat.div_lt_iff_lt_mul] + · rw [Nat.pow_sub_mul_pow] + · apply x.isLt + · apply hn + · apply Nat.pow_pos (by decide) + +/-- The value of shifting by `wn - 1` equals +shifting by `wn` and grabbing the lsb at (wn - 1) -/ theorem ShiftSubtractInput.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb - (h : ShiftSubtractInput w wr wn n d) {k : Nat} : - n.toNat >>> (wn - 1) = (n.toNat >>> wn) * 2 + if (n.getMsb wn) then 1 else 0 := by + (h : ShiftSubtractInput w wr wn n d) : + n.toNat >>> (wn - 1) = (n.toNat >>> wn) * 2 + h.nmsb.toNat := by have hn := ShiftSubtractInput.n_shiftr_wl_minus_one_eq_n_shiftr_wl_or_nmsb h obtain hn : (n >>> (wn - 1)).toNat = ((n >>> wn).concatBit' h.nmsb).toNat := by simp [hn] simp at hn - rw [toNat_concatBit'_eq] at hn - repeat sorry - -- rw [BitVec.toNat_shiftRight] + rw [toNat_concatBit'_eq (k := w - wn)] at hn + · rw [hn] + rw [toNat_ushiftRight] + · have := h.hwn_lt + have := h.hw + omega + · apply BitVec.ushiftRight_lt + have := h.hwrn + omega /-- One round of the division algorithm, that tries to perform a subtract shift. @@ -1112,10 +1136,16 @@ def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : · exact h.hqwr, hdiv := by simp [r'] - rw [toNat_concatBit'_eq] - · simp only [hqlo, decide_True, ↓reduceIte] - sorry - repeat sorry -- HERE + sorry -- HERE HERE + -- rw [toNat_concatBit'_eq] + -- · simp only [hqlo, decide_True, ↓reduceIte] + -- sorry + -- rw [h.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb] + -- rw [h.hdiv] + -- rw [toNat_concatBit'_eq (k := wr)] + -- · sorry + -- · sorry + -- · sorry } else sorry /-- @@ -1176,69 +1206,6 @@ def checkDivRec' : Bool × Option (BitVec 4 × BitVec 4 × BitVec 4 × BitVec 4) /- # Tons of helper lemmas about the behaviour of divRec -/ -/-- -the condition of divSubtractShift is true iff -the arithmetic inequality holds. --/ -theorem divSubtractShift_cond_iff (w wr wn: Nat) - (hwr : wr < w) - (hwn : wn < w) - (hwrn : wr + wn = w) - (d n q r : BitVec w) - (hr : r.toNat < 2^wr) -- r : BitVec wr - (hq : q.toNat < 2^wr) -- q : BitVec wr - (hn : n.toNat < 2^wn) -- n : BitVec wn - (hdiv : (n >>> wr).toNat = r.toNat + d.toNat * q.toNat) - (b : Bool) : - (r <<< 1 ||| zeroExtend w (ofBool b) < d) ↔ - (r.toNat * 2 + (if b then 1 else 0) < d.toNat) := by - constructor - · intros h - sorry - · intros h - sorry - -theorem divSubtractShift_lt_two_pow_of_lt_two_pow (w : Nat) - (d r : BitVec w) - (b : Bool) - (hk : k < w) - (hr : r.toNat < 2 ^ k) : - (divSubtractShift w d r b).1.toNat < 2 ^ (k + 1) := by - simp [divSubtractShift] - by_cases (r <<< 1 ||| zeroExtend w (ofBool b) < d) - - case pos h => sorry - case neg h => sorry -theorem divSubtractShift_lt_of_lt (w : Nat) (d r : BitVec w) (b : Bool) (hrd : r < d) : - (divSubtractShift w d r b).1 < d := by - simp [divSubtractShift] - by_cases (r <<< 1 ||| zeroExtend w (ofBool b) < d) - case pos h => sorry - case neg h => sorry - -mutual - theorem t1 (w wn : Nat) (d r : BitVec w) (n : BitVec wn) : - (divRec' w wn d r n).fst < d := by sorry - -- theorem t2 (w wn : Nat) (d r : BitVec w) (n : BitVec wn) : - -- d * (divRec' w wn d r n).snd + (divRec' w wn d r n).fst = n - theorem t2 (w : Nat) (d : BitVec w) (n : BitVec w) : - d * (divRec' w w d (0#w) n).snd + (divRec' w w d (0#w) n).fst = n := by sorry - theorem t3 (w : Nat) (d n : BitVec w) : - d.toNat * (divRec' w w d (0#w) n).snd.toNat + (divRec' w w d (0#w) n).fst.toNat < 2 ^ w := by sorry -end - -theorem divRec'_eq_udiv (w : Nat) (n : BitVec w) (d : BitVec w) (hd : 0 < d): - n.umod d = (divRec' w w d 0#w n).1 ∧ - n.udiv d = (divRec' w w d 0#w n).2 := by - have k := div_characterized_of_mul_add_of_lt (d := d) (n := n) - (q := (divRec' w w d 0#w n).2) - (r := (divRec' w w d 0#w n).1) - hd - (by apply t1) - (by apply t2) - (by apply t3) - simp [k] - /- # (OLD OLD) Division Recurrence for Bitblasting -/ -- n = d * q + r @@ -1422,27 +1389,6 @@ def divRec (qr : DivRecQuotRem w n d) (j : Nat) : | j + 1 => divRec qr j tryDivSubtractShift qr' (w - 1 - j) -/-- -TODO: what's a good theorem name? -If the LSB is false, then shifting to (w - 1) is the same as shifting to w and then right shifting 1. --/ -private theorem BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_or_zeroExtend_getLsb - {x : BitVec w} : - x >>> (w - 1) = ((x >>> w <<< 1) ||| (BitVec.ofBool (x.getLsb (w - 1))).zeroExtend w) := by - ext i - simp only [getLsb_ushiftRight, getLsb_or, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, - getLsb_zeroExtend, getLsb_ofBool] - by_cases (i : Nat) < 1 - case pos h => - simp only [h, decide_True, Bool.not_true, Bool.false_and] - have hi : (i : Nat) = 0 := by omega - simp [hi] - case neg h => - simp only [h, decide_False, Bool.not_false, Bool.true_and] - have hi : (i : Nat) ≠ 0 := by omega - simp only [hi, decide_False, Bool.false_and, Bool.or_false] - congr 1 - omega private theorem BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_add_zeroExtend_getLsb {x : BitVec w} : From f9f8b09c4f76962341c5325294149eb16a64ed5e Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 5 Jul 2024 15:11:20 +0100 Subject: [PATCH 57/64] feat: finish then branch, now do else branch --- src/Init/Data/BitVec/Bitblast.lean | 115 ++++++++++++++++++++++------- 1 file changed, 87 insertions(+), 28 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 222e706a7fd8..963c1da31a09 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -770,9 +770,6 @@ theorem BitVec.add_mul (x y z : BitVec w) : (y + z) * x = y * x + z * x := by rw [BitVec.mul_comm, BitVec.mul_add] congr 1 <;> rw [BitVec.mul_comm] - - - theorem BitVec.add_assoc {x y z : BitVec w} : x + y + z = x + (y + z) := by apply eq_of_toNat_eq simp[Nat.add_assoc] @@ -951,16 +948,26 @@ structure ShiftSubtractInput (w wr wn : Nat) (n d: BitVec w) extends DivRemInput w wr wn n d : Type where hwn_lt : 0 < wn -- we can only call this function legally if we have dividend bits. -/-- In the shift subtract input, we have one more bit to spare, -so we do not overflow. -/ + +/-- +In the shift subtract input, we have one more bit to spare, +so we do not overflow. +-/ def ShiftSubtractInput.wr_add_one_le_w (h : ShiftSubtractInput w wr wn n d) : wr + 1 ≤ w := by have hwrn := h.hwrn have hwn_lt := h.hwn_lt omega -/-- In the shift subtract input, we have one more bit to spare, -so we do not overflow. -/ +def ShiftSubtractInput.wr_lt_w + (h : ShiftSubtractInput w wr wn n d) : wr < w := by + have hwr := h.wr_add_one_le_w + omega + +/-- +In the shift subtract input, we have one more bit to spare, +so we do not overflow. +-/ def ShiftSubtractInput.wr_le_wr_sub_one (h : ShiftSubtractInput w wr wn n d) : wr ≤ w - 1 := by have hw := h.hw @@ -1015,20 +1022,24 @@ theorem concatBit'_lt (x : BitVec w) (b : Bool) : (concatBit' x b).toNat < 2 ^ w := (concatBit' x b).isLt theorem toNat_concatBit'_eq (x : BitVec w) (b : Bool) (k : Nat) - (hkw : k < w) (hkx : x.toNat < 2 ^ k) : + (hk : k < w) (hx : x.toNat < 2 ^ k) : (concatBit' x b).toNat = x.toNat * 2 + b.toNat:= by simp only [concatBit'] rw [toNat_shiftLeft_or_zeroExtend_ofBool_eq (k := k)] · omega · omega -#print axioms toNat_concatBit'_eq +theorem toNat_concatBit'_false_eq (x : BitVec w) (k : Nat) + (hk : k < w) (hx : x.toNat < 2 ^ k) : + (concatBit' x false).toNat = x.toNat * 2 := by + rw [toNat_concatBit'_eq (k := k) (hk := hk) (hx := hx)] + simp theorem toNat_concatBit'_lt (x : BitVec w) (b : Bool) (k : Nat) - (hkw : k < w) (hkx : x.toNat < 2 ^ k) : + (hk : k < w) (hx : x.toNat < 2 ^ k) : (concatBit' x b).toNat < 2 ^ (k + 1) := by - rw [toNat_concatBit'_eq x b k hkw hkx] - apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two hkx + rw [toNat_concatBit'_eq x b k hk hx] + apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two hx · rcases b with rfl | rfl <;> decide · omega @@ -1103,9 +1114,10 @@ TODO: think of isolating the pattern as `concatBit'`. def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : DivRemInput w (wr + 1) (wn - 1) n d := let r' := concatBit' h.r h.nmsb - let qlo := r' < d - let q := h.q.concatBit' qlo - if hqlo : qlo then { + let rltd : Bool := r' < d + let q := h.q.concatBit' !rltd -- if r ≥ d, then we have a quotient bit. + if hrltd : rltd + then { q := q, r := r', hwr := by @@ -1121,8 +1133,8 @@ def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : omega, hd := h.hd, hrd := by - simp [qlo] at hqlo - simp [BitVec.lt_def] at hqlo + simp [rltd] at hrltd + simp [BitVec.lt_def] at hrltd assumption, hrwr := by simp [r'] @@ -1135,18 +1147,65 @@ def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : · exact h.wr_add_one_le_w · exact h.hqwr, hdiv := by + rw [h.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb] + simp only [r'] + rw [h.hdiv] + rw [toNat_concatBit'_eq (x := h.r) + (k := wr) + (hk := h.wr_lt_w) + (hx := h.hrwr)] + simp only [q] + simp only [hrltd, Bool.not_true] + have hq' := toNat_concatBit'_false_eq h.q wr h.wr_lt_w h.hqwr + rw [hq'] + rw [← Nat.mul_assoc] + rw [Nat.add_mul] + rw [Nat.add_assoc] + } + else { + q := q, + r := r', + hwr := by + have := h.hwr + have := h.wr_add_one_le_w + omega, + hwn := by + have := h.hwn + omega, + hwrn := by + have := h.hwrn + have := h.wr_add_one_le_w + omega, + hd := h.hd, + hrd := by + simp [rltd] at hrltd + simp [BitVec.lt_def] at hrltd + sorry + hrwr := by simp [r'] - sorry -- HERE HERE - -- rw [toNat_concatBit'_eq] - -- · simp only [hqlo, decide_True, ↓reduceIte] - -- sorry - -- rw [h.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb] - -- rw [h.hdiv] - -- rw [toNat_concatBit'_eq (k := wr)] - -- · sorry - -- · sorry - -- · sorry - } else sorry + apply toNat_concatBit'_lt + · exact h.wr_add_one_le_w + · exact h.hrwr, + hqwr := by + simp [q] + apply toNat_concatBit'_lt + · exact h.wr_add_one_le_w + · exact h.hqwr, + hdiv := by + rw [h.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb] + simp only [r'] + rw [h.hdiv] + rw [toNat_concatBit'_eq (x := h.r) + (k := wr) + (hk := h.wr_lt_w) + (hx := h.hrwr)] + simp only [q] + simp only [hrltd, Bool.not_true] + have hq' := toNat_concatBit'_false_eq h.q wr h.wr_lt_w h.hqwr + rw [Nat.add_mul] + rw [Nat.add_assoc] + sorry + } /-- Core divsion recurrence. From e15e05ccbaa1a818923f3e4f07a55c001e7057b2 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 5 Jul 2024 15:57:49 +0100 Subject: [PATCH 58/64] chore: finish another sorry on the else branch --- src/Init/Data/BitVec/Bitblast.lean | 85 ++++++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 11 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 963c1da31a09..c37a4b95e84f 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -851,6 +851,33 @@ theorem mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two have : n * 2 + b < 2^w := by omega assumption +/-- +This is used when proving the correctness of the divison algorithm, +where we know that `r < d`. +We then want to show that `r <<< 1 | b - d < d` as the loop invariant. +In arithmethic, this is the same as showing that +`r * 2 + 1 - d < d`, which this theorem establishes. +-/ +theorem two_mul_add_sub_lt_of_lt_of_lt_two -- HERE HERE + (h : a < x) (hy : y < 2): + 2 * a + y - x < x := by omega + +/-- +Variant of `BitVec.toNat_sub` that does not introduce a modulo. +-/ +theorem BitVec.toNat_sub_of_lt {x y : BitVec w} (hy : y ≤ x) : + (x - y).toNat = x.toNat - y.toNat := by + simp only [toNat_sub] + rw [← Nat.add_sub_assoc] + · rw [Nat.sub_add_comm] + · rw [Nat.add_mod] + simp only [mod_self, Nat.add_zero, mod_mod] + rw [Nat.mod_eq_of_lt] + omega + · simp only [le_def] at hy + omega + · omega + /-- If `n : Bitvec w` has only the low `k < w` bits set, then `(n <<< 1 | b)` does not overflow, and we can compute its value @@ -1114,7 +1141,7 @@ TODO: think of isolating the pattern as `concatBit'`. def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : DivRemInput w (wr + 1) (wn - 1) n d := let r' := concatBit' h.r h.nmsb - let rltd : Bool := r' < d + let rltd : Bool := r' < d -- true if r' < d. In this case, we don't have a quotient bit. let q := h.q.concatBit' !rltd -- if r ≥ d, then we have a quotient bit. if hrltd : rltd then { @@ -1164,7 +1191,7 @@ def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : } else { q := q, - r := r', + r := r' - d, hwr := by have := h.hwr have := h.wr_add_one_le_w @@ -1180,12 +1207,25 @@ def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : hrd := by simp [rltd] at hrltd simp [BitVec.lt_def] at hrltd + have hr := h.hrd + -- | TODO: make this a field. + have hr' : h.r < d := by simp [BitVec.lt_def]; exact hr + simp only [r'] sorry hrwr := by - simp [r'] - apply toNat_concatBit'_lt - · exact h.wr_add_one_le_w - · exact h.hrwr, + simp only [r'] + /- TODO: this proof is repeated, lift it to above the structure building. -/ + have hdr' : ¬ (r' < d) := by + simp [rltd] at hrltd + assumption + have hdr' : d ≤ r' := BitVec.le_iff_not_lt.mp hdr' + rw [BitVec.toNat_sub_of_lt hdr'] + have hr' : r'.toNat < 2 ^ (wr + 1) := by + simp [r'] + apply toNat_concatBit'_lt + · exact h.wr_add_one_le_w + · exact h.hrwr + omega hqwr := by simp [q] apply toNat_concatBit'_lt @@ -1193,6 +1233,11 @@ def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : · exact h.hqwr, hdiv := by rw [h.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb] + have hdr' : ¬ (r' < d) := by + simp [rltd] at hrltd + assumption + have hdr' : d ≤ r' := BitVec.le_iff_not_lt.mp hdr' + rw [BitVec.toNat_sub_of_lt hdr'] simp only [r'] rw [h.hdiv] rw [toNat_concatBit'_eq (x := h.r) @@ -1200,11 +1245,29 @@ def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : (hk := h.wr_lt_w) (hx := h.hrwr)] simp only [q] - simp only [hrltd, Bool.not_true] - have hq' := toNat_concatBit'_false_eq h.q wr h.wr_lt_w h.hqwr - rw [Nat.add_mul] - rw [Nat.add_assoc] - sorry + rw [toNat_concatBit'_eq (x := h.q) + (k := wr) + (hk := h.wr_lt_w) + (hx := h.hqwr)] + simp only [hrltd, Bool.not_false, toNat_true] + simp [Nat.mul_add] + apply Eq.symm + calc + _ = d.toNat * (h.q.toNat * 2) + d.toNat + (h.r.toNat * 2 + h.nmsb.toNat - d.toNat) := + by rfl + _ = d.toNat * (h.q.toNat * 2) + d.toNat - d.toNat + (h.r.toNat * 2 + h.nmsb.toNat) := by + simp + rw [Nat.add_assoc] + congr 1 + rw [Nat.add_sub_cancel'] + sorry + _ = d.toNat * (h.q.toNat * 2) + (h.r.toNat * 2 + h.nmsb.toNat) := by + rw [Nat.add_sub_cancel] + _ = (d.toNat * h.q.toNat + h.r.toNat) * 2 + h.nmsb.toNat := by + rw [← Nat.add_assoc] + rw [← Nat.mul_assoc] + rw [Nat.add_mul] + _ = (d.toNat * h.q.toNat + h.r.toNat) * 2 + h.nmsb.toNat := rfl } /-- From faea676a610c378793fd3fbd1ec3091168e71674 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 5 Jul 2024 20:28:01 +0100 Subject: [PATCH 59/64] feat: finish udiv/urem bitblast --- src/Init/Data/BitVec/Bitblast.lean | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index c37a4b95e84f..9a5a84f9603b 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -1210,8 +1210,16 @@ def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : have hr := h.hrd -- | TODO: make this a field. have hr' : h.r < d := by simp [BitVec.lt_def]; exact hr + rw [BitVec.toNat_sub_of_lt hrltd] simp only [r'] - sorry + rw [toNat_concatBit'_eq (x := h.r) + (k := wr) + (hk := h.wr_lt_w) + (hx := h.hrwr)] + rw [Nat.mul_comm] -- TODO: canonicalize an order between w*2 and 2*w + apply two_mul_add_sub_lt_of_lt_of_lt_two + · exact hr + · apply Bool.toNat_lt hrwr := by simp only [r'] /- TODO: this proof is repeated, lift it to above the structure building. -/ @@ -1260,7 +1268,15 @@ def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : rw [Nat.add_assoc] congr 1 rw [Nat.add_sub_cancel'] - sorry + simp only [r'] at hdr' + simp only [BitVec.le_def] at hdr' + rw [BitVec.toNat_concatBit'_eq + (x := h.r) + (b := h.nmsb) + (k := wr) + (hk := h.wr_lt_w) + (hx := h.hrwr)] at hdr' + assumption _ = d.toNat * (h.q.toNat * 2) + (h.r.toNat * 2 + h.nmsb.toNat) := by rw [Nat.add_sub_cancel] _ = (d.toNat * h.q.toNat + h.r.toNat) * 2 + h.nmsb.toNat := by @@ -1270,6 +1286,9 @@ def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : _ = (d.toNat * h.q.toNat + h.r.toNat) * 2 + h.nmsb.toNat := rfl } +/-- info: 'BitVec.divSubtractShift' depends on axioms: [propext, Classical.choice, Quot.sound] -/ +#guard_msgs in #print axioms divSubtractShift + /-- Core divsion recurrence. We have three widths at play: @@ -1298,12 +1317,18 @@ def divRec' (h : DivRemInput w wr wn n d) : let new := divSubtractShift h.toShiftSubtractInput divRec' new +/-- info: 'BitVec.divRec'' depends on axioms: [propext, Classical.choice, Quot.sound] -/ +#guard_msgs in #print axioms divRec' + theorem divRec'_correct (n d : BitVec w) (hw : 0 < w) (hd : 0 < d) : let out := divRec' (DivRemInput_init w n d hw hd) n.udiv d = out.q ∧ n.umod d = out.r := by simp apply DivRemInput_implies_udiv_urem +/-- info: 'BitVec.divRec'_correct' depends on axioms: [propext, Classical.choice, Quot.sound] -/ +#guard_msgs in #print axioms divRec'_correct + def checkDivRec' : Bool × Option (BitVec 4 × BitVec 4 × BitVec 4 × BitVec 4) := Id.run do let w := 4 let max := (Nat.pow 2 w) @@ -1322,6 +1347,7 @@ def checkDivRec' : Bool × Option (BitVec 4 × BitVec 4 × BitVec 4 × BitVec 4) outputs := .some (n, d, qr.2, qr.1) (wrong, outputs) + -- /-- info: (false, none) -/ -- #guard_msgs in #reduce checkDivRec' From ccf7628cc4d623ec4215982ceb3e3fca12e9725c Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 5 Jul 2024 20:33:56 +0100 Subject: [PATCH 60/64] chore: throw away incorrect implementations --- src/Init/Data/BitVec/Bitblast.lean | 338 ----------------------------- 1 file changed, 338 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 9a5a84f9603b..152435e08fc9 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -1326,342 +1326,4 @@ theorem divRec'_correct (n d : BitVec w) (hw : 0 < w) (hd : 0 < d) : simp apply DivRemInput_implies_udiv_urem -/-- info: 'BitVec.divRec'_correct' depends on axioms: [propext, Classical.choice, Quot.sound] -/ -#guard_msgs in #print axioms divRec'_correct - -def checkDivRec' : Bool × Option (BitVec 4 × BitVec 4 × BitVec 4 × BitVec 4) := Id.run do - let w := 4 - let max := (Nat.pow 2 w) - let mut outputs := .none - let mut wrong := false - for n in (List.range max) do - for d in (List.range (max - 1)).map (fun n => Nat.add n 1) do - have hd : d > 0 := by sorry - have hd' : d < 2 ^ w := by sorry - let init := DivRemInput_init w (BitVec.ofNat w n) (BitVec.ofNat w d) - (by omega) - (by simp; rw [Nat.mod_eq_of_lt]; omega; omega) - let qr := divRec' init - if qr.q != n then - wrong := true - outputs := .some (n, d, qr.2, qr.1) - (wrong, outputs) - - --- /-- info: (false, none) -/ --- #guard_msgs in #reduce checkDivRec' - -/- # Tons of helper lemmas about the behaviour of divRec -/ - - -/- # (OLD OLD) Division Recurrence for Bitblasting -/ - --- n = d * q + r --- Two-stage subtraction: --- For each bit of the dividend(n) starting from the MSB: --- --- 1) Add ith bit of dividend as MSB of remainder `rem`. --- --- 1) Compute carry bits when subtracting divisor `d` from current --- remainder `rem`, which determines the current quotient bit. --- 2) Perform subtraction operation based on current quotient bit and shift --- remainder by one. --- --- For example, n = 0111 (7 in base 10), d = 0010 (2 in base 10) --- --- i rem d q --- 0 0000 -- insert n.msb [0] --- 0010 0 -- subtract d, not successful --- 0000 -- result [unchanged] --- 0000 -- shift left --- --- 1 0001 -- insert n.msb [1] --- 0010 0 -- subtract d, not successful --- 0001 -- result [unchanged] --- 0010 -- shift --- --- 2 0011 -- insert n.msb [2] --- 0010 1 -- subtract d, successful --- 0001 -- result [CHANGED] --- 0010 -- shift --- --- 3 0011 -- insert n.msb [3] --- 0010 1 -- subtract d, successful --- 0001 -- remainder [CHANGED] --- --- remainder: 0001 (1 in base 10) --- quotient: 0011 (3 in base 10) -/-- A bundle of the quotient and remainder for the intermediate steps when computing n.div d -/ -structure DivRecQuotRem (w : Nat) (n : BitVec w) (d : BitVec w) where - r : BitVec w - q : BitVec w -deriving DecidableEq, Repr - - -/-- One round of the division algorithm, that tries to perform a subtract shift. -/ -def tryDivSubtractShift (qr : DivRecQuotRem w n d) (ix : Nat) : DivRecQuotRem w n d := - let r' := (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb ix)).zeroExtend w - if r' < d - then { r := r', q := qr.q <<< 1 } - else { - r := r' - d, - q := qr.q <<< 1 ||| 1 - } - -/-- Same as tryDivSubtractShift, with if-then-else pushed into the record, -/ -def tryDivSubtractShift' (qr : DivRecQuotRem w n d) (ix : Nat) : DivRecQuotRem w n d := - let r' := (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb ix)).zeroExtend w - { r := if r' < d then r' else r' - d, q := qr.q <<< 1 ||| (if r' < d then 0 else 1) } - -theorem tryDivSubtractShift_eq_tryDivSubtractShift' (qr : DivRecQuotRem w n d) (ix : Nat) : - tryDivSubtractShift qr ix = tryDivSubtractShift' qr ix := by - simp [tryDivSubtractShift, tryDivSubtractShift'] - generalize qr.r <<< 1 ||| zeroExtend w (ofBool (n.getLsb ix)) = s - by_cases hslt : s < d - · simp [hslt] - · simp [hslt] - -/-- The tryDivSubtractShift's remainder is upper bounded by `r << 1 | 1`. -/ -theorem tryDivSubtractShift_remainder_lt_shiftLeft_one_or_one {qr : DivRecQuotRem w n d} {ix : Nat} : - (tryDivSubtractShift qr ix).r ≤ (qr.r <<< 1) ||| (BitVec.ofBool (n.getLsb ix)).zeroExtend w := by - rw [tryDivSubtractShift_eq_tryDivSubtractShift'] - simp only [tryDivSubtractShift'] - generalize qr.r <<< 1 ||| zeroExtend w (ofBool (n.getLsb ix)) = s - by_cases hslt : s < d - · simp [hslt] - · simp [hslt] - apply BitVec.sub_le_self_of_le - apply BitVec.le_iff_not_lt.mp hslt - - -/- Surely this exists somewhere, I remember proving this even -/ -theorem Nat.sub_mod_self_eq_sub {x n : Nat} (hx₀ : 0 < x := by omega) (hxn : x < n := by omega) : (n - x) % n = n - x := by - rw [Nat.mod_eq_of_lt] - omega - -@[simp] -theorem Bool.toNat_lt (b : Bool) : b.toNat < 2 := by - have h := Bool.toNat_le b - omega - -/-- TODO: This shows that the remainer is always going to be below 'd', and does not overflow. -/ -theorem tryDivSubtractShift_lt_of_lt {qr : DivRecQuotRem w n d} {ix : Nat} (hrlt : qr.r < d) (hrltTwoPow : qr.r.toNat * 2 + 1 < 2 ^ w): - (tryDivSubtractShift qr ix).r < d := by - simp only [tryDivSubtractShift, ofNat_eq_ofNat] - generalize hr₂ : qr.r <<< 1 ||| zeroExtend w (ofBool (n.getLsb ix)) = r₂ - by_cases hr₂lt : r₂ < d - · simp [hr₂lt] - · simp [hr₂lt] - rw [← BitVec.add_eq_or_of_and_eq_zero] at hr₂ - rw [BitVec.shiftLeft_eq_mul_twoPow] at hr₂ - · simp only [BitVec.lt_def] at hr₂ hr₂lt ⊢ - simp only [toNat_sub, toNat_add, toNat_shiftLeft, toNat_truncate, - toNat_ofBool, add_mod_mod, mod_add_mod, toNat_mul] at hr₂ hr₂lt ⊢ - -- simp only [toNat_twoPow, Nat.pow_one] at hr₂ - rcases w with rfl | rfl | w - · have hr : qr.r = 0#0 := by apply Subsingleton.elim - have hd : d = 0#0 := by apply Subsingleton.elim - rw [hr, hd] at hrlt - simp at hrlt -- TODO: golf this with simpa, ask alex. - · simp only [Nat.reduceAdd, Nat.zero_add, Nat.pow_one, mod_self, Nat.mul_zero] at hrlt hr₂ ⊢ - simp only [Nat.reduceAdd, lt_def] at hrlt hr₂ - simp only [Nat.reduceAdd, zeroExtend_eq, lt_def, toNat_or, toNat_shiftLeft, Nat.pow_one, - toNat_ofBool, Nat.not_lt] at hr₂ - have hd : d.toNat < 2 := d.isLt - generalize hb : (n.getLsb ix) = b - rw [hb] at hr₂ - replace hd : d.toNat = 0 ∨ d.toNat = 1 := by omega; - rcases hd with hd | hd - · omega -- d ≠ 0 - · rw [hd] at hr₂lt hrlt - rcases b with rfl | rfl - · replace hrlt : qr.r.toNat = 0 := by omega - rw [← hr₂] at hr₂lt - simp at hr₂lt - rw [hrlt] at hr₂lt - simp at hr₂lt - · simp; omega - · have hr₂lt₂ : r₂.toNat - d.toNat < d.toNat := by - rw [← hr₂] - simp only [mul_twoPow_eq_shiftLeft, toNat_add, toNat_shiftLeft, toNat_truncate, - toNat_ofBool, add_mod_mod, mod_add_mod] - rw [Nat.shiftLeft_eq] - simp only [Nat.pow_one] - have hd : d.toNat < 2^(w + 1 + 1) := d.isLt - have hb : (n.getLsb ix).toNat < 2 := by simp - simp only [lt_def] at hrlt - rw [Nat.mod_eq_of_lt] - · -- r < d [integers] - -- r - 1 <= d - -- 2(r - 2) <= 2d - -- 2r - 2 - d <= d - -- 2r - 1 - d < d - omega - · omega -- here is the use of hrltTwoPow - calc - _ = (r₂.toNat + (2 ^ (w + 1 + 1) - d.toNat)) % 2 ^ (w + 1 + 1) := by rfl - _ = ((r₂.toNat + (2 ^ (w + 1 + 1)) - d.toNat)) % 2 ^ (w + 1 + 1) := by - rw [Nat.add_sub_assoc] - have := d.isLt - omega - _ = (((2 ^ (w + 1 + 1) + r₂.toNat) - d.toNat)) % 2 ^ (w + 1 + 1) := by - rw [Nat.add_comm] - _ = (2 ^ (w + 1 + 1) + (r₂.toNat - d.toNat)) % 2 ^ (w + 1 + 1) := by - congr 1 - rw [Nat.add_sub_assoc] - omega - _ = ((2 ^ (w + 1 + 1) % 2 ^ (w + 1 + 1)) + ((r₂.toNat - d.toNat) % 2 ^ (w + 1 + 1))) % (2 ^ (w + 1 + 1)) := by - rw [Nat.add_mod] - _ = (r₂.toNat - d.toNat) % 2 ^ (w + 1 + 1) := by - simp - _ = (r₂.toNat - d.toNat) := by - rw [Nat.mod_eq_of_lt] - omega - _ < d.toNat := by omega - · ext i - simp - intros hi _ hi' - omega - -/-- -info: 'BitVec.tryDivSubtractShift_lt_of_lt' depends on axioms: [propext, Quot.sound, Classical.choice] --/ -#guard_msgs in #print axioms tryDivSubtractShift_lt_of_lt - -/-- repeatedly apply `tryDivSubtractShift`. -/ -def divRec (qr : DivRecQuotRem w n d) (j : Nat) : - DivRecQuotRem w n d := - let qr' := - match j with - | 0 => qr - | j + 1 => divRec qr j - tryDivSubtractShift qr' (w - 1 - j) - - -private theorem BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_add_zeroExtend_getLsb - {x : BitVec w} : - x >>> (w - 1) = ((x >>> w <<< 1) + (BitVec.ofBool (x.getLsb (w - 1))).zeroExtend w) := by - rw [BitVec.add_eq_or_of_and_eq_zero] - · apply BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_or_zeroExtend_getLsb - · ext i - simp only [getLsb_and, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, - getLsb_ushiftRight, getLsb_zeroExtend, getLsb_ofBool, getLsb_zero, and_eq_false_imp, - and_eq_true, not_eq_true', decide_eq_false_iff_not, Nat.not_lt, decide_eq_true_eq, and_imp] - intros i _ hi' - omega - -theorem divRec_correct {w : Nat} {n d : BitVec w} {qr : DivRecQuotRem w n d} {j : Nat} - (hj : j ≤ w - 1) - (hqrd : qr.r < d) - (hqrn : n >>> (w - j) = qr.q * d + qr.r) - : - ((n >>> ((w - 1) - j) = (divRec qr j).q * d + (divRec qr j).r)) ∧ - (d.toNat * (divRec qr j).q.toNat + (divRec qr j).r.toNat < 2^w) ∧ - (divRec qr j).r < d := by - induction j generalizing qr - case zero => - constructor - · simp [divRec] - simp at hqrn - simp [tryDivSubtractShift_eq_tryDivSubtractShift'] - simp [tryDivSubtractShift'] - generalize hb : n.getLsb (w - 1) = b - generalize hs : qr.r <<< 1 ||| zeroExtend w (ofBool b) = s - have qd : qr.q <<< 1 * d = (qr.q * d) <<< 1 := by - rw [BitVec.shiftLeft_mul_comm] - rw [BitVec.shiftLeft_mul_assoc] - by_cases hslt : s < d -- Note that the proof is identical on both sides of the case split. - · simp [hslt] - rw [← hs] - rw [qd] - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [← add_eq_or_of_and_eq_zero] - · rw [← BitVec.add_assoc] - rw [← BitVec.add_mul] - rw [← hqrn] - rw [← BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_add_zeroExtend_getLsb, hb] - · ext i - simp - intros i _ hi' - omega - · simp [hslt] - rw [BitVec.shiftLeft_one_or_one_eq_shiftLeft_one_add_one] - rw [BitVec.add_mul] - simp only [BitVec.one_mul] - rw [BitVec.add_assoc] - rw [← BitVec.add_sub_assoc (by simp [hslt])] - rw [BitVec.add_sub_self_left] - rw [← hs] - rw [qd] - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.shiftLeft_eq_mul_twoPow] - rw [← add_eq_or_of_and_eq_zero] - · rw [← BitVec.add_assoc] - rw [← BitVec.add_mul] - rw [← hqrn] - rw [← BitVec.shiftLeft_eq_mul_twoPow] - rw [BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_add_zeroExtend_getLsb, hb] - · ext i - simp - intros i _ hi' - omega - · constructor - · simp [divRec] - simp at hqrn - sorry - · -- r < d - simp [divRec] - apply tryDivSubtractShift_lt_of_lt - apply hqrd - sorry - case succ j' ih => - sorry - - -theorem div_eq_divRec (n d : BitVec w) (hd : d > 0) : - let qr := divRec (w := w) (n := n) (d := d) { r := 0, q := 0 } (w - 1) - n.udiv d = qr.q ∧ - n.umod d = qr.r := by - obtain ⟨h₁, h₂, h₃⟩ := divRec_correct (w := w) (n := n) (d := d) (j := w - 1) (qr := { r := 0, q := 0}) - (by omega) - (by simpa using hd) - (by simp [show (w - (w - 1) = 1) by omega];) - simp at h₃ - simp at h₂ - simp at h₁ - have k := div_characterized_of_mul_add_of_lt (d := d) (n := n) - (q := (divRec (w := w) (n := n) (d := d) { r := 0, q := 0 } (w - 1)).q) - (r := (divRec (w := w) (n := n) (d := d) { r := 0, q := 0 } (w - 1)).r) - hd - h₃ - (by rw [BitVec.mul_comm]; simp_all) - (by simp_all) - simp [k] - -def checkDivRec : Bool × Array String := Id.run do - let w := 4 - let max := (Nat.pow 2 w) - let mut outputs := #[] - let mut wrong := false - for n in (List.range max) do - for d in (List.range (max - 1)).map (fun n => Nat.add n 1) do - have hd : d > 0 := by sorry - let qr := divRec (w := w) (n := n) (d := d) { r := 0, q := 0 } (w - 1) - if qr.q * d + qr.r != n then - outputs := outputs.push s!"ERROR: n = {n}, d = {d}, q = {qr.q}, r = {qr.r}, n = {n}, d = {d}, q = {qr.q}, r = {qr.r}" - wrong := true - (wrong, outputs) - -/-- info: (false, { data := [] }) -/ -#guard_msgs in #reduce checkDivRec - --- invariants: --- 1) qr.r < d. -theorem div_rec_7_2 : - (divRec (w := 4) (n := 7) (d := 2) { r := 0, q := 0 } 3) = - { r := 1, q := 3 } := by - simp [divRec, tryDivSubtractShift] - end BitVec From 935d35f5dfd37705a6797dfeab800459f726ac4e Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 5 Jul 2024 22:22:32 +0100 Subject: [PATCH 61/64] chore: establish correctness of non-dependent statement --- src/Init/Data/BitVec/Bitblast.lean | 112 ++++++++++++++++++++++++++++- 1 file changed, 109 insertions(+), 3 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 152435e08fc9..0e45b2e5e8a2 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -958,6 +958,16 @@ def DivRemInput_init (w : Nat) (n d : BitVec w) (hw : 0 < w) (hd : 0#w < d) : apply Nat.div_eq_of_lt n.isLt } +@[simp] +theorem DivRemInput_init_q (w : Nat) (n d : BitVec w) (hw : 0 < w) (hd : 0#w < d) : + (DivRemInput_init w n d hw hd).q = 0#w := by + rfl + +@[simp] +theorem DivRemInput_init_r (w : Nat) (n d : BitVec w) (hw : 0 < w) (hd : 0#w < d) : + (DivRemInput_init w n d hw hd).r = 0#w := by + rfl + theorem DivRemInput_implies_udiv_urem (h : DivRemInput w w 0 n d) : n.udiv d = h.q ∧ n.umod d = h.r := by @@ -1134,9 +1144,6 @@ theorem ShiftSubtractInput.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb One round of the division algorithm, that tries to perform a subtract shift. Note that this is only called when `r.msb = false`, so we will not overflow. This means that `r'.toNat = r.toNat *2 + q.toNat` -. - -TODO: think of isolating the pattern as `concatBit'`. -/ def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : DivRemInput w (wr + 1) (wn - 1) n d := @@ -1326,4 +1333,103 @@ theorem divRec'_correct (n d : BitVec w) (hw : 0 < w) (hd : 0 < d) : simp apply DivRemInput_implies_udiv_urem +def divSubtractShiftNonDep (n q r d : BitVec w) (wn : Nat) : BitVec w × BitVec w := + let r' := concatBit' r (n.getLsb (wn - 1)) + let rltd : Bool := r' < d + let q := q.concatBit' !rltd + if rltd + then (q, r') + else (q, r' - d) + +@[simp] +theorem DivRemInput.toShiftSubtractInput_r_eq_r + (h : DivRemInput w wr (wn + 1) n d) : + (h.toShiftSubtractInput).r = h.r := by + simp [toShiftSubtractInput] + +@[simp] +theorem DivRemInput.toShiftSubtractInput_q_eq_q + (h : DivRemInput w wr (wn + 1) n d) : + (h.toShiftSubtractInput).q = h.q := by + simp only [toShiftSubtractInput] + +theorem divSubtractShift_eq_divSubtractShiftNonDep + (h : ShiftSubtractInput w wr wn n d) : + ((divSubtractShift h).q, (divSubtractShift h).r) = divSubtractShiftNonDep n h.q h.r d wn := by + simp [divSubtractShift, divSubtractShiftNonDep, ShiftSubtractInput.nmsb] + by_cases h : h.r.concatBit' (n.getLsb (wn - 1)) < d <;> + simp only [h, ↓reduceDite, decide_True, Bool.not_true, ↓reduceIte] + +@[simp] +theorem q_divSubtractShift_eq_fst_divSubtractShiftNonDep' + (h : DivRemInput w wr (wn + 1) n d) : + (divSubtractShift h.toShiftSubtractInput).q = + (divSubtractShiftNonDep n h.q h.r d (wn + 1)).fst := by + simp [divSubtractShift, + divSubtractShiftNonDep, + ShiftSubtractInput.nmsb] + by_cases cond : h.r.concatBit' (n.getLsb wn) < d <;> + simp only [cond, ↓reduceDite, decide_True, Bool.not_true, ↓reduceIte] + +@[simp] +theorem r_divSubtractShift_eq_snd_divSubtractShiftNonDep' + (h : DivRemInput w wr (wn + 1) n d) : + (divSubtractShift h.toShiftSubtractInput).r = + (divSubtractShiftNonDep n h.q h.r d (wn + 1)).snd := by + simp [divSubtractShift, + divSubtractShiftNonDep, + ShiftSubtractInput.nmsb] + by_cases cond : h.r.concatBit' (n.getLsb wn) < d <;> + simp only [cond, ↓reduceDite, decide_True, Bool.not_true, ↓reduceIte] + +theorem divSubtractShift_eq_divSubtractShiftNonDep' + (h : DivRemInput w wr (wn + 1) n d) : + ((divSubtractShift h.toShiftSubtractInput).q, (divSubtractShift h.toShiftSubtractInput).r) = + divSubtractShiftNonDep n h.q h.r d (wn + 1) := by + simp [divSubtractShift, divSubtractShiftNonDep, ShiftSubtractInput.nmsb] + by_cases h : h.r.concatBit' (n.getLsb wn) < d <;> + simp only [h, ↓reduceDite, decide_True, Bool.not_true, ↓reduceIte] + +def divRecNondep (n q r d : BitVec w) (wn : Nat) : + BitVec w × BitVec w := + match wn with + | 0 => (q, r) + | wn + 1 => + let (q', r') := divSubtractShiftNonDep n q r d (wn + 1) + divRecNondep n q' r' d wn + +theorem divRec_eq_divRecNonDep (h h' : DivRemInput w wr wn n d) + (hh' : h.q = h'.q ∧ h.r = h'.r): + ((divRec' h).q, (divRec' h).r) = divRecNondep n h'.q h'.r d wn := by + induction wn generalizing w wr n d + case zero => + simp [divRec', divRecNondep, DivRemInput.wr_eq_w_of_wn_eq_zero] + simp [hh'.1, hh'.2] + case succ wn ih => + simp [divRecNondep, divRec'] + rw[← divSubtractShift_eq_divSubtractShiftNonDep'] + apply ih <;> + simp [q_divSubtractShift_eq_fst_divSubtractShiftNonDep', + r_divSubtractShift_eq_snd_divSubtractShiftNonDep', + hh'.1, hh'.2] + + theorem divRecNonDep_correct (n d : BitVec w) (hw : 0 < w) (hd : 0 < d) : + let out := divRecNondep n 0#w 0#w d w + n.udiv d = out.fst ∧ n.umod d = out.snd := by + simp + have heq := divRec_eq_divRecNonDep (DivRemInput_init w n d hw hd) (DivRemInput_init w n d hw hd) + (by simp) + simp at heq + have hcorrect := divRec'_correct n d hw hd + obtain ⟨hqcorrect, hrcorrect⟩ := hcorrect + rw [hqcorrect, hrcorrect] + have heq_q : (divRec' (DivRemInput_init w n d hw hd)).q = + (n.divRecNondep (0#w) (0#w) d w).fst := by + rw [← heq] + have heq_r : (divRec' (DivRemInput_init w n d hw hd)).r = + (n.divRecNondep (0#w) (0#w) d w).snd := by + rw [← heq] + rw [heq_q, heq_r] + simp + end BitVec From c83cf0360724156313529f5664409c1d542b35e2 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 5 Jul 2024 22:36:09 +0100 Subject: [PATCH 62/64] chore: write _zero and _succ theorems for @hargonix to be able to bitblast --- src/Init/Data/BitVec/Bitblast.lean | 34 +++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 0e45b2e5e8a2..a604394196fa 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -1413,7 +1413,32 @@ theorem divRec_eq_divRecNonDep (h h' : DivRemInput w wr wn n d) r_divSubtractShift_eq_snd_divSubtractShiftNonDep', hh'.1, hh'.2] - theorem divRecNonDep_correct (n d : BitVec w) (hw : 0 < w) (hd : 0 < d) : +-- def concatBit' (x : BitVec w) (b : Bool) : BitVec w := +-- x <<< 1 ||| (BitVec.ofBool b).zeroExtend w + +theorem divSubtractShiftNonDep_fst (n q r d : BitVec w) (wn : Nat) : + (divSubtractShiftNonDep n q r d wn).fst = + q.concatBit' !decide (r.concatBit' (n.getLsb (wn - 1)) < d) := by + simp [divSubtractShiftNonDep] + by_cases h : r.concatBit' (n.getLsb (wn - 1)) < d <;> + simp [h] + +theorem divSubtractShiftNonDep_snd (n q r d : BitVec w) (wn : Nat) : + (divSubtractShiftNonDep n q r d wn).snd = + if r.concatBit' (n.getLsb (wn - 1)) < d then r.concatBit' (n.getLsb (wn - 1)) + else r.concatBit' (n.getLsb (wn - 1)) - d := by + simp [divSubtractShiftNonDep] + by_cases h : r.concatBit' (n.getLsb (wn - 1)) < d <;> simp [h] + +theorem divRecNonDep_zero (n q r d : BitVec w) : divRecNondep n q r d 0 = (q, r) := by simp [divRecNondep] + +theorem divRecNonDep_succ (n q r d : BitVec w) (wn : Nat) : + (divRecNondep n q r d (wn + 1) = + divRecNondep n (divSubtractShiftNonDep n q r d (wn + 1)).1 + (divSubtractShiftNonDep n q r d (wn + 1)).2 d wn) := by + simp [divRecNondep, divSubtractShiftNonDep] + +theorem divRecNonDep_correct (n d : BitVec w) (hw : 0 < w) (hd : 0 < d) : let out := divRecNondep n 0#w 0#w d w n.udiv d = out.fst ∧ n.umod d = out.snd := by simp @@ -1429,7 +1454,10 @@ theorem divRec_eq_divRecNonDep (h h' : DivRemInput w wr wn n d) have heq_r : (divRec' (DivRemInput_init w n d hw hd)).r = (n.divRecNondep (0#w) (0#w) d w).snd := by rw [← heq] - rw [heq_q, heq_r] - simp + simp [heq_q, heq_r] +/-- +info: 'BitVec.divRecNonDep_correct' depends on axioms: [propext, Classical.choice, Quot.sound] +-/ +#guard_msgs in #print axioms divRecNonDep_correct end BitVec From a20b14dd291a76705da50396cdc0ed5c0d250c65 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Wed, 10 Jul 2024 13:03:47 +0100 Subject: [PATCH 63/64] chore: delete new file --- src/Init/Data/BitVec/div_new_invariant.py | 80 ----------------------- 1 file changed, 80 deletions(-) delete mode 100755 src/Init/Data/BitVec/div_new_invariant.py diff --git a/src/Init/Data/BitVec/div_new_invariant.py b/src/Init/Data/BitVec/div_new_invariant.py deleted file mode 100755 index 151ae6b3b9bd..000000000000 --- a/src/Init/Data/BitVec/div_new_invariant.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python3 - -def get_lsb(n, j): - return int(bool(n & (1 << j))) - -def print_bits(w, n): - return ("{0:0%sb}" % (w)).format(n) - -def check_pre_invariant(w, n, d, q, r, j): - qright = n // d - rright = n % d - assert r < d - -# n / d <-> n = q * d + r -def check_post_rec_invariant(w, n, d, q, r, j): - qright = n // d - rright = n % d - assert r < d - assert n >> (w - j) == d * q + r - -# n / d <-> n = q * d + r -def check_final_invariant(w, n, d, q, r, j): - qright = n // d - rright = n % d - assert r < d - assert n >> ((w - 1) - j) == d * q + r - -def shift_subtract(w, n, d, q, r, j): - print(f"shift_subtract> n: '%s' | d: '%s' | q : '%s' | r : '%s' | j : '%s'" % - (print_bits(w, n), print_bits(w, d), print_bits(w, q), print_bits(w, r), j)) - print(f" n[%s] = %s" % (j, get_lsb(n, j))) - check_pre_invariant(w, n, d, q, r, j) - if j > 0: - (q, r) = shift_subtract(w, n, d, q, r, j-1) - check_post_rec_invariant(w, n, d, q, r, j) - - # do the last bit. - ix = (w - 1) - j - assert ix >= 0 - r = (r << 1) | get_lsb(n, ix) - assert r < 2 ** w # how is this loop invariant upheld, right after doing weird operations? Very weird. - print(f" r = %s" % print_bits(w, r)) - if r >= d: - print(f" r > d.") - r -= d - q = (q << 1) | 1 - print(f" r.new = %s" % print_bits(w, r)) - print(f" q.new = %s" % print_bits(w, q)) - else: - print(f" r < d.") - q = (q << 1) - print(f" r.new = %s" % print_bits(w, r)) - print(f" q = %s" % print_bits(w, q)) - check_final_invariant(w, n, d, q, r, j) - return (q, r) - - - -w = 4 -d = 10 # d * 2 will overflow. -(q, r) = shift_subtract(w, n, d, 0, 0, w-1) -assert n == d * q + r -if n == d * q + r and r < d: - print ("verified correct invariant for n: '%s' | d : '%s' | q : '%s' r: '%s'" % - (n, d, q, r)) -else: - raise RuntimeError("verification failed for n: '%s' | d: '%s'" % (n, d)) -# 10 / 3 = 3 -for n in range(1, 32): - for d in range(1, 32): - w = 6 - (q, r) = shift_subtract(w, n, d, 0, 0, w-1) - assert n == d * q + r - if n == d * q + r and r < d: - print ("verified correct invariant for n: '%s' | d : '%s' | q : '%s' r: '%s'" % - (n, d, q, r)) - else: - raise RuntimeError("verification failed for n: '%s' | d: '%s'" % (n, d)) - - From 09016e7e66822b1957476906405a8f3bf726b88c Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Wed, 10 Jul 2024 15:51:48 +0100 Subject: [PATCH 64/64] chore: added sshiftRight as well --- src/Init/Data/BitVec/Basic.lean | 2 + src/Init/Data/BitVec/Bitblast.lean | 112 ++++++++++++++++++++++++++++- src/Init/Data/BitVec/Lemmas.lean | 6 -- 3 files changed, 111 insertions(+), 9 deletions(-) diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index eb33296d1e26..33106dbccfd1 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -531,6 +531,8 @@ SMT-Lib name: `bvashr` except this operator uses a `Nat` shift value. -/ def sshiftRight (a : BitVec n) (s : Nat) : BitVec n := .ofInt n (a.toInt >>> s) +def sshiftRight' (a : BitVec n) (s : BitVec m) : BitVec n := a.sshiftRight s.toNat + instance {n} : HShiftLeft (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x <<< y.toNat⟩ instance {n} : HShiftRight (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x >>> y.toNat⟩ diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index a604394196fa..5f142f389ce3 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -578,16 +578,122 @@ theorem shiftRight_eq_shiftRight_rec (x : BitVec ℘) (y : BitVec w₂) : /- ### Arithmetic (sshiftRight) recurrence -/ +@[simp] +theorem sshiftRight'_zero (x : BitVec w) : + x.sshiftRight' (0#w₂) = x := by + ext i + rw [sshiftRight', getLsb_sshiftRight] + simp + def sshiftRightRec (x : BitVec w) (y : BitVec w₂) (n : Nat) : BitVec w := let shiftAmt := (y &&& (twoPow w₂ n)) match n with | 0 => x.sshiftRight' shiftAmt - | n + 1 => (sshiftRightRec x y n) >>> shiftAmt + | n + 1 => (sshiftRightRec x y n).sshiftRight' shiftAmt -theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) : - (x >>> y).getLsb i = (sshiftRightRec x y w).getLsb i := sorry +@[simp] +theorem sshiftRightRec_zero_eq (x : BitVec w) (y : BitVec w₂) : + sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by + simp [sshiftRightRec] + +@[simp] +theorem sshiftRightRec_succ_eq (x : BitVec w) (y : BitVec w₂) (n : Nat) : + sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by + simp [sshiftRightRec] + +/-- The msb after arithmetic shifting right equals the original msb. -/ +theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} : + (x.sshiftRight n).msb = x.msb := by + rw [msb_eq_getLsb_last, getLsb_sshiftRight] + rcases w with rfl | w + · simp + · simp only [Nat.add_sub_cancel] + simp [show ¬ (w + 1 ≤ w) by omega] + intros h + rw [msb_eq_getLsb_last] + simp only [Nat.add_sub_cancel] + simp [show n + w = w by omega] + +theorem sshiftRight_sshiftRight {x : BitVec w} {m n : Nat} : + (x.sshiftRight m).sshiftRight n = x.sshiftRight (m + n) := by + ext i + simp only [getLsb_sshiftRight] + simp only [Nat.add_assoc] + by_cases h₁ : w ≤ (i : Nat) + · simp [h₁] + · simp only [h₁, decide_False, Bool.not_false, Bool.true_and] + by_cases h₂ : n + ↑i < w + · simp [h₂] + · simp only [h₂, ↓reduceIte] + by_cases h₃ : m + (n + ↑i) < w + · simp [h₃] + omega + · simp [h₃] + apply sshiftRight_msb_eq_msb + +theorem sshiftRight'_sshiftRight' {x : BitVec w₁} {y : BitVec w₂} {z : BitVec w₃} : + (x.sshiftRight' y).sshiftRight' z = x.sshiftRight (y.toNat + z.toNat) := by + simp [sshiftRight', shiftRight_shiftRight, sshiftRight_sshiftRight] + + +theorem sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero {x : BitVec w} {y z : BitVec w₂} + (h : y &&& z = 0#w₂) (h' : y.toNat + z.toNat < 2^w₂): + x.sshiftRight' (y ||| z) = (x.sshiftRight' y).sshiftRight' z := by + simp [← add_eq_or_of_and_eq_zero _ _ h] + simp [BitVec.sshiftRight'] + simp [sshiftRight_sshiftRight] + rw [Nat.mod_eq_of_lt h'] + +theorem sshiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) (hn : n + 1 ≤ w₂) : + sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by + induction n generalizing x y + case zero => + ext i + simp only [ushiftRight_rec_zero, twoPow_zero_eq_one, Nat.reduceAdd, truncate_one_eq_ofBool_getLsb] + have heq : (y &&& 1#w₂) = zeroExtend w₂ (ofBool (y.getLsb 0)) := by + ext i + by_cases h : (↑i : Nat) = 0 <;> simp [h, Bool.and_comm] + simp [heq] + case succ n ih => + simp + by_cases h : y.getLsb (n + 1) <;> simp [h] + · rw [ih (hn := by omega)] + rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true _ _ h] + rw [sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero] + · simp + · simp; + have hpow : 2 ^ (n + 1) < 2 ^ w₂ := by + apply Nat.pow_lt_pow_of_lt (by decide) (by omega) + have h₂ : 2 ^ (n + 1) % 2 ^ w₂ = 2 ^ (n + 1) := Nat.mod_eq_of_lt (by omega) + have h₁ : y.toNat % 2 ^ (n + 1) % 2 ^ w₂ = y.toNat % 2 ^ (n + 1) := by + apply Nat.mod_eq_of_lt + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) + apply Nat.mod_lt + apply Nat.pow_pos (by decide); omega + obtain h₁ : y.toNat % 2 ^ (n + 1) % 2 ^ w₂ = y.toNat % 2 ^ (n + 1) := by + apply Nat.mod_eq_of_lt + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) <;> omega + rw [h₁, h₂] + rcases w₂ with rfl | w₂ + · omega + · apply Nat.add_lt_add_of_lt_of_le + · simp only [pow_eq, Nat.mul_eq, Nat.mul_one] + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) + · apply Nat.mod_lt + · apply Nat.pow_pos (by decide) + · apply Nat.pow_le_pow_of_le_right (by decide) (by omega) + · simp + apply Nat.pow_le_pow_of_le_right (by decide) (by omega) + · rw [ih (hn := by omega)] + rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)] + simp [h] +theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) : + (x.sshiftRight' y).getLsb i = (sshiftRightRec x y (w₂ - 1)).getLsb i := by + rcases w₂ with rfl | w₂ + · simp [of_length_zero] + · simp [sshiftRightRec_eq x y w₂ (by omega)] /- ## udiv/urem bitblasting -/ diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index e5cf7f566c87..48afd622d602 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -771,16 +771,10 @@ theorem getLsb_sshiftRight_eq_getLsb_ushiftRight (x : BitVec w) (s i : Nat) : rw [h] simp [getLsb_sshiftRight] -/-- A version of `BitVec.sshiftRight` with both arguments as bitvectors. -/ -def sshiftRight' (x : BitVec w₁) (y : BitVec w₂) : BitVec w₁ := - x.sshiftRight y.toNat - theorem getLsb_sshift'_eq_getLsb_sshiftRight : getLsb (sshiftRight' x y) i = getLsb (x.sshiftRight y.toNat) i := by simp [sshiftRight'] --- theorem getLsb_sshiftRight'_ - /-! ### udiv -/ theorem udiv_eq {x y : BitVec n} :