Skip to content

Commit

Permalink
feat: BitVec.{toInt, toFin, msb}_udiv (#6402)
Browse files Browse the repository at this point in the history
This PR adds a `toFin` and `msb` lemma for unsigned bitvector division.
We *don't* have `toInt_udiv`, since the only truly general statement we
can make does no better than unfolding the definition, and it's not
uncontroversially clear how to unfold `toInt` (see
`toInt_eq_msb_cond`/`toInt_eq_toNat_cond`/`toInt_eq_toNat_bmod` for a
few options currently provided). Instead, we do have `toInt_udiv_of_msb`
that's able to provide a more meaningful rewrite given an extra
side-condition (that `x.msb = false`).

This PR also upstreams a minor `Nat` theorem (`Nat.div_le_div_left`)
needed for the above from Mathlib.

---------

Co-authored-by: Kim Morrison <[email protected]>
  • Loading branch information
alexkeizer and kim-em authored Jan 10, 2025
1 parent c07948a commit d2c4471
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 5 deletions.
63 changes: 58 additions & 5 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Init.Data.Bool
import Init.Data.BitVec.Basic
import Init.Data.Fin.Lemmas
import Init.Data.Nat.Lemmas
import Init.Data.Nat.Div.Lemmas
import Init.Data.Nat.Mod
import Init.Data.Int.Bitwise.Lemmas
import Init.Data.Int.Pow
Expand Down Expand Up @@ -442,6 +443,10 @@ theorem toInt_eq_toNat_cond (x : BitVec n) :
(x.toNat : Int) - (2^n : Nat) :=
rfl

theorem toInt_eq_toNat_of_lt {x : BitVec n} (h : 2 * x.toNat < 2^n) :
x.toInt = x.toNat := by
simp [toInt_eq_toNat_cond, h]

theorem msb_eq_false_iff_two_mul_lt {x : BitVec w} : x.msb = false2 * x.toNat < 2^w := by
cases w <;> simp [Nat.pow_succ, Nat.mul_comm _ 2, msb_eq_decide, toNat_of_zero_length]

Expand All @@ -454,6 +459,9 @@ theorem toInt_eq_msb_cond (x : BitVec w) :
simp only [BitVec.toInt, ← msb_eq_false_iff_two_mul_lt]
cases x.msb <;> rfl

theorem toInt_eq_toNat_of_msb {x : BitVec w} (h : x.msb = false) :
x.toInt = x.toNat := by
simp [toInt_eq_msb_cond, h]

theorem toInt_eq_toNat_bmod (x : BitVec n) : x.toInt = Int.bmod x.toNat (2^n) := by
simp only [toInt_eq_toNat_cond]
Expand Down Expand Up @@ -2300,6 +2308,12 @@ theorem ofNat_sub_ofNat {n} (x y : Nat) : BitVec.ofNat n x - BitVec.ofNat n y =
@[simp, bv_toNat] theorem toNat_neg (x : BitVec n) : (- x).toNat = (2^n - x.toNat) % 2^n := by
simp [Neg.neg, BitVec.neg]

theorem toNat_neg_of_pos {x : BitVec n} (h : 0#n < x) :
(- x).toNat = 2^n - x.toNat := by
change 0 < x.toNat at h
rw [toNat_neg, Nat.mod_eq_of_lt]
omega

theorem toInt_neg {x : BitVec w} :
(-x).toInt = (-x.toInt).bmod (2 ^ w) := by
rw [← BitVec.zero_sub, toInt_sub]
Expand Down Expand Up @@ -2586,13 +2600,13 @@ theorem udiv_def {x y : BitVec n} : x / y = BitVec.ofNat n (x.toNat / y.toNat) :
rw [← udiv_eq]
simp [udiv, bv_toNat, h, Nat.mod_eq_of_lt]

@[simp]
theorem toFin_udiv {x y : BitVec n} : (x / y).toFin = x.toFin / y.toFin := by
rfl

@[simp, bv_toNat]
theorem toNat_udiv {x y : BitVec n} : (x / y).toNat = x.toNat / y.toNat := by
rw [udiv_def]
by_cases h : y = 0
· simp [h]
· rw [toNat_ofNat, Nat.mod_eq_of_lt]
exact Nat.lt_of_le_of_lt (Nat.div_le_self ..) (by omega)
rfl

@[simp]
theorem zero_udiv {x : BitVec w} : (0#w) / x = 0#w := by
Expand Down Expand Up @@ -2628,6 +2642,45 @@ theorem udiv_self {x : BitVec w} :
↓reduceIte, toNat_udiv]
rw [Nat.div_self (by omega), Nat.mod_eq_of_lt (by omega)]

theorem msb_udiv (x y : BitVec w) :
(x / y).msb = (x.msb && y == 1#w) := by
cases msb_x : x.msb
· suffices x.toNat / y.toNat < 2 ^ (w - 1) by simpa [msb_eq_decide]
calc
x.toNat / y.toNat ≤ x.toNat := by apply Nat.div_le_self
_ < 2 ^ (w - 1) := by simpa [msb_eq_decide] using msb_x
. rcases w with _|w
· contradiction
· have : (y == 1#_) = decide (y.toNat = 1) := by
simp [(· == ·), toNat_eq]
simp only [this, Bool.true_and]
match hy : y.toNat with
| 0 =>
obtain rfl : y = 0#_ := eq_of_toNat_eq hy
simp
| 1 =>
obtain rfl : y = 1#_ := eq_of_toNat_eq (by simp [hy])
simpa using msb_x
| y + 2 =>
suffices x.toNat / (y + 2) < 2 ^ w by
simp_all [msb_eq_decide, hy]
calc
x.toNat / (y + 2)
≤ x.toNat / 2 := by apply Nat.div_add_le_right (by omega)
_ < 2 ^ w := by omega

theorem msb_udiv_eq_false_of {x : BitVec w} (h : x.msb = false) (y : BitVec w) :
(x / y).msb = false := by
simp [msb_udiv, h]

/--
If `x` is nonnegative (i.e., does not have its msb set),
then `x / y` is nonnegative, thus `toInt` and `toNat` coincide.
-/
theorem toInt_udiv_of_msb {x : BitVec w} (h : x.msb = false) (y : BitVec w) :
(x / y).toInt = x.toNat / y.toNat := by
simp [toInt_eq_msb_cond, msb_udiv_eq_false_of h]

/-! ### umod -/

theorem umod_def {x y : BitVec n} :
Expand Down
8 changes: 8 additions & 0 deletions src/Init/Data/Nat/Div/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,12 @@ theorem lt_div_mul_self (h : 0 < k) (w : k ≤ x) : x - k < x / k * k := by
have : x % k < k := mod_lt x h
omega

theorem div_le_div_left (hcb : c ≤ b) (hc : 0 < c) : a / b ≤ a / c :=
(Nat.le_div_iff_mul_le hc).2 <|
Nat.le_trans (Nat.mul_le_mul_left _ hcb) (Nat.div_mul_le_self a b)

theorem div_add_le_right {z : Nat} (h : 0 < z) (x y : Nat) :
x / (y + z) ≤ x / z :=
div_le_div_left (Nat.le_add_left z y) h

end Nat

0 comments on commit d2c4471

Please sign in to comment.