From 8b340c313bd8a6132e24526e24bbae2cf93489ad Mon Sep 17 00:00:00 2001 From: Joe Hendrix Date: Thu, 16 Nov 2023 10:26:13 -0800 Subject: [PATCH] Additional BitVec definition changes. --- Std/Data/BitVec/Basic.lean | 17 +++- Std/Data/Nat/Bitwise.lean | 203 +++++++++++++++++++++---------------- Std/Data/Nat/Lemmas.lean | 20 ++++ 3 files changed, 150 insertions(+), 90 deletions(-) diff --git a/Std/Data/BitVec/Basic.lean b/Std/Data/BitVec/Basic.lean index 9c94150801..80f2f55c72 100644 --- a/Std/Data/BitVec/Basic.lean +++ b/Std/Data/BitVec/Basic.lean @@ -44,7 +44,11 @@ namespace BitVec /-- The `BitVec` with value `i mod 2^n`. Treated as an operation on bitvectors, this is truncation of the high bits when downcasting and zero-extension when upcasting. -/ protected def ofNat (n : Nat) (i : Nat) : BitVec n where - toFin := Fin.ofNat' i (Nat.pow_two_pos _) + toFin := + let p : i &&& 2^n-1 < 2^n := by + apply Nat.land_lt_2_pow + exact Nat.sub_lt (Nat.pow_two_pos n) (Nat.le_refl 1) + ⟨i &&& 2^n-1, p⟩ /-- Given a bitvector `a`, return the underlying `Nat`. This is O(1) because `BitVec` is a (zero-cost) wrapper around a `Nat`. -/ @@ -80,7 +84,7 @@ protected def toInt (a : BitVec n) : Int := if a.msb then Int.ofNat a.toNat - Int.ofNat (2^n) else a.toNat /-- Return a bitvector `0` of size `n`. This is the bitvector with all zero bits. -/ -protected def zero (n : Nat) : BitVec n := .ofNat n 0 +protected def zero (n : Nat) : BitVec n := ⟨0, Nat.pow_two_pos n⟩ instance : Inhabited (BitVec n) where default := .zero n @@ -282,7 +286,7 @@ Bitwise AND for bit vectors. SMT-Lib name: `bvand`. -/ protected def and (x y : BitVec n) : BitVec n where toFin := - ⟨x.toNat &&& y.toNat, Nat.land_lt_2_pow x.isLt y.isLt⟩ + ⟨x.toNat &&& y.toNat, Nat.land_lt_2_pow x.toNat y.isLt⟩ instance : AndOp (BitVec w) := ⟨.and⟩ /-- @@ -437,7 +441,12 @@ If `v < w` then it truncates the high bits instead. SMT-Lib name: `zero_extend`. -/ -def zeroExtend (v : Nat) (x : BitVec w) : BitVec v := .ofNat v x.toNat +def zeroExtend (v : Nat) : BitVec w → BitVec v +| ⟨x, x_lt⟩ => + if h : w ≤ v then + ⟨x, Nat.lt_of_lt_of_le x_lt (Nat.pow_le_pow_of_le_right (by trivial : 2 > 0) h)⟩ + else + .ofNat v x /-- Truncate the high bits of bitvector `x` of length `w`, resulting in a vector of length `v`. diff --git a/Std/Data/Nat/Bitwise.lean b/Std/Data/Nat/Bitwise.lean index 95854e28e0..ba5461b7de 100644 --- a/Std/Data/Nat/Bitwise.lean +++ b/Std/Data/Nat/Bitwise.lean @@ -34,25 +34,126 @@ theorem div2InductionOn have p : x/2 < x := Nat.div_lt_self x_pos (Nat.le_refl _) apply induct _ x_pos (ind _ p) +/-! ### bitwise -/ -/-! ### testBit -/ +@[local simp] +private theorem eq_0_of_lt_one (x:Nat) : x < 1 ↔ x = 0 := + Iff.intro + (fun p => + match x with + | 0 => Eq.refl 0 + | _+1 => False.elim (not_lt_zero _ (Nat.lt_of_succ_lt_succ p))) + (fun p => by simp [p, Nat.zero_lt_succ]) -theorem zero_testBit (i:Nat) : testBit 0 i = false := by - unfold testBit - simp [zero_shiftRight] +private theorem eq_0_of_lt (x:Nat) : x < 2^ 0 ↔ x = 0 := eq_0_of_lt_one x -theorem testBit_succ (x:Nat) : testBit x (succ i) = testBit (x >>> 1) i := by - unfold testBit - simp [shiftRight_succ_inside] +@[local simp] +private theorem zero_lt_pow (n:Nat) : 0 < 2^n := by + induction n + case zero => simp [eq_0_of_lt] + case succ n hyp => + simp [pow_succ] + exact (Nat.mul_lt_mul_of_pos_right hyp (by trivial : 2 > 0) : 0 < 2 ^ n * 2) +private +theorem div_2_le_of_lt_two {m n : Nat} (p : m < 2 ^ succ n) : m / 2 < 2^n := by + simp [div_lt_iff_lt_mul (by trivial : 0 < 2)] + exact p + +/-- This provides a bound on bitwise operations. -/ +theorem bitwise_lt_2_pow (left : x < 2^n) (right : y < 2^n) : (Nat.bitwise f x y) < 2^n := by + induction n generalizing x y with + | zero => + simp only [eq_0_of_lt] at left right + unfold bitwise + simp [left, right] + | succ n hyp => + unfold bitwise + if x_zero : x = 0 then + simp only [x_zero, if_true] + by_cases p : f false true = true <;> simp [p, right] + else if y_zero : y = 0 then + simp only [x_zero, y_zero, if_false, if_true] + by_cases p : f true false = true <;> simp [p, left] + else + simp only [x_zero, y_zero, if_false] + have hyp1 := hyp (div_2_le_of_lt_two left) (div_2_le_of_lt_two right) + by_cases p : f (decide (x % 2 = 1)) (decide (y % 2 = 1)) = true <;> + simp [p, pow_succ, mul_succ, Nat.add_assoc] + case pos => + apply lt_of_succ_le + simp only [← Nat.succ_add] + apply Nat.add_le_add <;> exact hyp1 + case neg => + apply Nat.add_lt_add <;> exact hyp1 + +/-! ### land -/ + +@[simp] theorem land_zero (x:Nat) : x &&& 0 = 0 := by simp [HAnd.hAnd, AndOp.and, land] unfold bitwise simp +theorem land_lt_2_pow (x : Nat) {y n : Nat} (right : y < 2^n) : (x &&& y) < 2^n := by + induction n generalizing x y with + | zero => + simp only [eq_0_of_lt] at right + simp [right] + | succ n hyp => + simp [HAnd.hAnd, AndOp.and, land] + unfold bitwise + if x_zero : x = 0 then + simp [x_zero, if_true, if_false] + else if y_zero : y = 0 then + simp [x_zero, y_zero, if_false, if_true] + else + simp only [x_zero, y_zero, if_false] + have hyp1 := hyp (x / 2) (div_2_le_of_lt_two right) + by_cases p : decide (x % 2 = 1) && decide (y % 2 = 1) <;> + simp [p, pow_succ, mul_succ, Nat.add_assoc] + case pos => + apply lt_of_succ_le + simp only [← Nat.succ_add] + apply Nat.add_le_add <;> exact hyp1 + case neg => + apply Nat.add_lt_add <;> exact hyp1 + +/-! ### lor -/ + +theorem lor_lt_2_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : (x ||| y) < 2^n := + bitwise_lt_2_pow left right + +/-! ### xor -/ + +theorem xor_lt_2_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : (x ^^^ y) < 2^n := + bitwise_lt_2_pow left right + +/-! ### shiftLeft -/ + +theorem shiftLeft_lt_2_pow {x m n : Nat} (bound : x < 2^(n-m)) : (x <<< m) < 2^n := by + induction m generalizing x n with + | zero => exact bound + | succ m hyp => + simp [shiftLeft_succ_inside] + apply hyp + revert bound + rw [Nat.sub_succ] + match n - m with + | 0 => + intro bound + simp [eq_0_of_lt_one] at bound + simp [bound] + | d + 1 => + intro bound + simp [Nat.pow_succ, Nat.mul_comm _ 2] + exact Nat.mul_lt_mul_of_pos_left bound (by trivial : 0 < 2) + +/-! ### testBit -/ + theorem testBit_zero_is_mod2 (x:Nat) : testBit x 0 = decide (x % 2 = 1) := by - rw [←div_add_mod x 2] simp [testBit] + rw [←div_add_mod x 2] simp [HAnd.hAnd, AndOp.and, land] unfold bitwise have one_div_2 : 1 / 2 = 0 := by trivial @@ -63,6 +164,14 @@ theorem testBit_zero_is_mod2 (x:Nat) : testBit x 0 = decide (x % 2 = 1) := by intro x_mod simp [x_mod, Nat.succ_add] +theorem zero_testBit (i:Nat) : testBit 0 i = false := by + unfold testBit + simp [zero_shiftRight] + +theorem testBit_succ (x:Nat) : testBit x (succ i) = testBit (x >>> 1) i := by + unfold testBit + simp [shiftRight_succ_inside] + theorem ne_zero_implies_bit_true {x : Nat} (p : x ≠ 0) : ∃ i, testBit x i := by induction x using div2InductionOn with | base => @@ -117,81 +226,3 @@ theorem eq_of_testBit_eq {x y : Nat} (pred : ∀i, testBit x i = testBit y i) : let ⟨i,eq⟩ := ne_implies_bit_diff h have p := pred i contradiction - -/-! ### bitwise and related -/ - -@[local simp] -private theorem eq_0_of_lt_one (x:Nat) : x < 1 ↔ x = 0 := - Iff.intro - (fun p => - match x with - | 0 => Eq.refl 0 - | _+1 => False.elim (not_lt_zero _ (Nat.lt_of_succ_lt_succ p))) - (fun p => by simp [p, Nat.zero_lt_succ]) - -private theorem eq_0_of_lt (x:Nat) : x < 2^ 0 ↔ x = 0 := eq_0_of_lt_one x - -@[local simp] -private theorem zero_lt_pow (n:Nat) : 0 < 2^n := by - induction n - case zero => simp [eq_0_of_lt] - case succ n hyp => - simp [pow_succ] - exact (Nat.mul_lt_mul_of_pos_right hyp (by trivial : 2 > 0) : 0 < 2 ^ n * 2) - -/-- This provides a bound on bitwise operations. -/ -theorem bitwise_lt_2_pow (left : x < 2^n) (right : y < 2^n) : (Nat.bitwise f x y) < 2^n := by - induction n generalizing x y with - | zero => - simp only [eq_0_of_lt] at left right - unfold bitwise - simp [left, right] - | succ n hyp => - unfold bitwise - if x_zero : x = 0 then - simp only [x_zero, if_true] - by_cases p : f false true = true <;> simp [p, right] - else if y_zero : y = 0 then - simp only [x_zero, y_zero, if_false, if_true] - by_cases p : f true false = true <;> simp [p, left] - else - simp only [x_zero, y_zero, if_false] - have lt : 0 < 2 := by trivial - have xlb : x / 2 < 2^n := by simp [div_lt_iff_lt_mul lt]; exact left - have ylb : y / 2 < 2^n := by simp [div_lt_iff_lt_mul lt]; exact right - have hyp1 := hyp xlb ylb - by_cases p : f (decide (x % 2 = 1)) (decide (y % 2 = 1)) = true <;> - simp [p, pow_succ, mul_succ, Nat.add_assoc] - case pos => - apply lt_of_succ_le - simp only [← Nat.succ_add] - apply Nat.add_le_add <;> exact hyp1 - case neg => - apply Nat.add_lt_add <;> exact hyp1 - -theorem lor_lt_2_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : (x ||| y) < 2^n := - bitwise_lt_2_pow left right - -theorem land_lt_2_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : (x &&& y) < 2^n := - bitwise_lt_2_pow left right - -theorem xor_lt_2_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : (x ^^^ y) < 2^n := - bitwise_lt_2_pow left right - -theorem shiftLeft_lt_2_pow {x m n : Nat} (bound : x < 2^(n-m)) : (x <<< m) < 2^n := by - induction m generalizing x n with - | zero => exact bound - | succ m hyp => - simp [shiftLeft_succ_inside] - apply hyp - revert bound - rw [Nat.sub_succ] - match n - m with - | 0 => - intro bound - simp [eq_0_of_lt_one] at bound - simp [bound] - | d + 1 => - intro bound - simp [Nat.pow_succ, Nat.mul_comm _ 2] - exact Nat.mul_lt_mul_of_pos_left bound (by trivial : 0 < 2) diff --git a/Std/Data/Nat/Lemmas.lean b/Std/Data/Nat/Lemmas.lean index 987b0c082d..4bb7bbcc35 100644 --- a/Std/Data/Nat/Lemmas.lean +++ b/Std/Data/Nat/Lemmas.lean @@ -740,6 +740,26 @@ protected theorem mul_self_sub_mul_self_eq (a b : Nat) : a * a - b * b = (a + b) rw [Nat.mul_sub_left_distrib, Nat.right_distrib, Nat.right_distrib, Nat.mul_comm b a, Nat.add_comm (a*a) (a*b), Nat.add_sub_add_left] +protected theorem mul_left_cancel {n m k : Nat} (np : 0 < n) (h:n * m = n * k) : m = k := by + match Nat.lt_trichotomy m k with + | Or.inl p => + have r : n * m < n * k := Nat.mul_lt_mul_of_pos_left p np + simp [h] at r + | Or.inr (Or.inl p) => exact p + | Or.inr (Or.inr p) => + have r : n * k < n * m := Nat.mul_lt_mul_of_pos_left p np + simp [h] at r + +protected theorem mul_right_cancel {n m k : Nat} (mp : 0 < m) (h:n * m = k * m) : n = k := by + simp [Nat.mul_comm _ m] at h + apply Nat.mul_left_cancel mp h + +protected theorem mul_left_cancel_iff {n m k : Nat} (p : 0 < n) : n * m = n * k ↔ m = k := + ⟨Nat.mul_left_cancel p, fun | rfl => rfl⟩ + +protected theorem mul_right_cancel_iff {n m k : Nat} (p : 0 < m) : n * m = k * m ↔ n = k := + ⟨Nat.mul_right_cancel p, fun | rfl => rfl⟩ + /-! ## div/mod -/ -- TODO mod_core_congr, mod_def