From 2ba0a4549b0714966895ec53045b9f5843a56c5e Mon Sep 17 00:00:00 2001 From: Joe Hendrix Date: Thu, 11 Apr 2024 17:26:45 +0200 Subject: [PATCH] feat: add BitVec Int add & mul lemmas (#3880) This adds some basic lemmas to support commuting ofInt/toInt and add/mul. It also removes the simp annotation on `ofNat_add_ofNat` as in some contexts the other direction or conversion to Int may be desired. --- src/Init/Data/BitVec/Lemmas.lean | 24 +++++++++++++++++++++++- src/Init/Data/Int/DivModLemmas.lean | 23 ++++++++++++++++++++--- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index d4e759addafd..7aaf8ad3ada2 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -817,9 +817,13 @@ Definition of bitvector addition as a nat. .ofFin x + y = .ofFin (x + y.toFin) := rfl @[simp] theorem add_ofFin (x : BitVec n) (y : Fin (2^n)) : x + .ofFin y = .ofFin (x.toFin + y) := rfl -@[simp] theorem ofNat_add_ofNat {n} (x y : Nat) : x#n + y#n = (x + y)#n := by + +theorem ofNat_add {n} (x y : Nat) : (x + y)#n = x#n + y#n := by apply eq_of_toNat_eq ; simp [BitVec.ofNat] +theorem ofNat_add_ofNat {n} (x y : Nat) : x#n + y#n = (x + y)#n := + (ofNat_add x y).symm + protected theorem add_assoc (x y z : BitVec n) : x + y + z = x + (y + z) := by apply eq_of_toNat_eq ; simp [Nat.add_assoc] @@ -835,6 +839,15 @@ theorem truncate_add (x y : BitVec w) (h : i ≤ w) : have dvd : 2^i ∣ 2^w := Nat.pow_dvd_pow _ h simp [bv_toNat, h, Nat.mod_mod_of_dvd _ dvd] +@[simp, bv_toNat] theorem toInt_add (x y : BitVec w) : + (x + y).toInt = (x.toInt + y.toInt).bmod (2^w) := by + simp [toInt_eq_toNat_bmod] + +theorem ofInt_add {n} (x y : Int) : BitVec.ofInt n (x + y) = + BitVec.ofInt n x + BitVec.ofInt n y := by + apply eq_of_toInt_eq + simp + /-! ### sub/neg -/ theorem sub_def {n} (x y : BitVec n) : x - y = .ofNat n (x.toNat + (2^n - y.toNat)) := by rfl @@ -911,6 +924,15 @@ instance : Std.Associative (fun (x y : BitVec w) => x * y) := ⟨BitVec.mul_asso instance : Std.LawfulCommIdentity (fun (x y : BitVec w) => x * y) (1#w) where right_id := BitVec.mul_one +@[simp, bv_toNat] theorem toInt_mul (x y : BitVec w) : + (x * y).toInt = (x.toInt * y.toInt).bmod (2^w) := by + simp [toInt_eq_toNat_bmod] + +theorem ofInt_mul {n} (x y : Int) : BitVec.ofInt n (x * y) = + BitVec.ofInt n x * BitVec.ofInt n y := by + apply eq_of_toInt_eq + simp + /-! ### le and lt -/ @[bv_toNat] theorem le_def (x y : BitVec n) : diff --git a/src/Init/Data/Int/DivModLemmas.lean b/src/Init/Data/Int/DivModLemmas.lean index b17baa8d5cfe..1885ead9a4a4 100644 --- a/src/Init/Data/Int/DivModLemmas.lean +++ b/src/Init/Data/Int/DivModLemmas.lean @@ -1054,20 +1054,37 @@ theorem emod_add_bmod_congr (x : Int) (n : Nat) : Int.bmod (x%n + y) n = Int.bmo simp [Int.emod_def, Int.sub_eq_add_neg] rw [←Int.mul_neg, Int.add_right_comm, Int.bmod_add_mul_cancel] +@[simp] +theorem emod_mul_bmod_congr (x : Int) (n : Nat) : Int.bmod (x%n * y) n = Int.bmod (x * y) n := by + simp [Int.emod_def, Int.sub_eq_add_neg] + rw [←Int.mul_neg, Int.add_mul, Int.mul_assoc, Int.bmod_add_mul_cancel] + @[simp] theorem bmod_add_bmod_congr : Int.bmod (Int.bmod x n + y) n = Int.bmod (x + y) n := by rw [bmod_def x n] split case inl p => - simp + simp only [emod_add_bmod_congr] case inr p => rw [Int.sub_eq_add_neg, Int.add_right_comm, ←Int.sub_eq_add_neg] simp -@[simp] -theorem add_bmod_bmod : Int.bmod (x + Int.bmod y n) n = Int.bmod (x + y) n := by +@[simp] theorem add_bmod_bmod : Int.bmod (x + Int.bmod y n) n = Int.bmod (x + y) n := by rw [Int.add_comm x, Int.bmod_add_bmod_congr, Int.add_comm y] +@[simp] +theorem bmod_mul_bmod : Int.bmod (Int.bmod x n * y) n = Int.bmod (x * y) n := by + rw [bmod_def x n] + split + case inl p => + simp + case inr p => + rw [Int.sub_mul, Int.sub_eq_add_neg, ← Int.mul_neg] + simp + +@[simp] theorem mul_bmod_bmod : Int.bmod (x * Int.bmod y n) n = Int.bmod (x * y) n := by + rw [Int.mul_comm x, bmod_mul_bmod, Int.mul_comm x] + theorem emod_bmod {x : Int} {m : Nat} : bmod (x % m) m = bmod x m := by simp [bmod]