diff --git a/Batteries/Data/Fin.lean b/Batteries/Data/Fin.lean index 7a5b9c16e8..cd1e813c05 100644 --- a/Batteries/Data/Fin.lean +++ b/Batteries/Data/Fin.lean @@ -1,3 +1,5 @@ import Batteries.Data.Fin.Basic +import Batteries.Data.Fin.Coding +import Batteries.Data.Fin.Enum import Batteries.Data.Fin.Fold import Batteries.Data.Fin.Lemmas diff --git a/Batteries/Data/Fin/Basic.lean b/Batteries/Data/Fin/Basic.lean index f63e38ab37..db8fa7e3b5 100644 --- a/Batteries/Data/Fin/Basic.lean +++ b/Batteries/Data/Fin/Basic.lean @@ -17,3 +17,35 @@ alias enum := Array.finRange @[deprecated (since := "2024-11-15")] alias list := List.finRange + +/-- Sum of a list indexed by `Fin n`. -/ +protected def sum [OfNat α (nat_lit 0)] [Add α] (x : Fin n → α) : α := + foldr n (x · + ·) 0 + +/-- Product of a list indexed by `Fin n`. -/ +protected def prod [OfNat α (nat_lit 1)] [Mul α] (x : Fin n → α) : α := + foldr n (x · * ·) 1 + +/-- Count the number of true values of a decidable predicate on `Fin n`. -/ +protected def count (P : Fin n → Prop) [DecidablePred P] : Nat := + Fin.sum (if P · then 1 else 0) + +/-- Find the first true value of a decidable predicate on `Fin n`, if there is one. -/ +protected def find? (P : Fin n → Prop) [DecidablePred P] : Option (Fin n) := + foldr n (fun i v => if P i then some i else v) none + +/-- Custom recursor for `Fin (n+1)`. -/ +def recZeroSuccOn {motive : Fin (n+1) → Sort _} (x : Fin (n+1)) + (zero : motive 0) (succ : (x : Fin n) → motive x.castSucc → motive x.succ) : motive x := + match x with + | 0 => zero + | ⟨x+1, hx⟩ => + let x : Fin n := ⟨x, Nat.lt_of_succ_lt_succ hx⟩ + succ x <| recZeroSuccOn x.castSucc zero succ + +/-- Custom recursor for `Fin (n+1)`. -/ +def casesZeroSuccOn {motive : Fin (n+1) → Sort _} (x : Fin (n+1)) + (zero : motive 0) (succ : (x : Fin n) → motive x.succ) : motive x := + match x with + | 0 => zero + | ⟨x+1, hx⟩ => succ ⟨x, Nat.lt_of_succ_lt_succ hx⟩ diff --git a/Batteries/Data/Fin/Coding.lean b/Batteries/Data/Fin/Coding.lean new file mode 100644 index 0000000000..6facb49006 --- /dev/null +++ b/Batteries/Data/Fin/Coding.lean @@ -0,0 +1,531 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: François G. Dorais +-/ +import Batteries.Data.Fin.Lemmas +import Batteries.Tactic.Basic +import Batteries.Tactic.Trans +import Batteries.Tactic.Lint + +namespace Fin + +/-- Encode a unit value as a `Fin` type. -/ +@[nolint unusedArguments] +def encodePUnit : PUnit → Fin 1 + | .unit => 0 + +/-- Decode a unit value as a `Fin` type. -/ +@[pp_nodot] def decodePUnit : Fin 1 → PUnit + | 0 => .unit + +/-- Encode a unit value as a `Fin` type. -/ +abbrev encodeUnit : Unit → Fin 1 := encodePUnit + +/-- Decode a unit value as a `Fin` type. -/ +@[pp_nodot] abbrev decodeUnit : Fin 1 → Unit := decodePUnit + +@[simp] theorem encodePUnit_decodePUnit : (x : Fin 1) → encodePUnit (decodePUnit x) = x + | 0 => rfl + +@[simp] theorem decodePUnit_encodePUnit : (x : PUnit) → decodePUnit (encodePUnit x) = x + | .unit => rfl + +/-- Encode a boolean value as a `Fin` type. -/ +def encodeBool : Bool → Fin 2 + | false => 0 + | true => 1 + +/-- Decode a boolean value as a `Fin` type. -/ +@[pp_nodot] def decodeBool : Fin 2 → Bool + | 0 => false + | 1 => true + +@[simp] theorem encodeBool_decodeBool : (x : Fin 2) → encodeBool (decodeBool x) = x + | 0 => rfl + | 1 => rfl + +@[simp] theorem decodeBool_encodeBool : (x : Bool) → decodeBool (encodeBool x) = x + | false => rfl + | true => rfl + +/-- Encode a character value as a `Fin` type. -/ +def encodeChar (c : Char) : Fin 1112064 := + have : c.toNat < 1114112 := + match c.valid with + | .inl h => Nat.lt_trans h (by decide) + | .inr h => h.right + if _ : c.toNat < 55296 then + ⟨c.toNat, by omega⟩ + else + ⟨c.toNat - 2048, by omega⟩ + +/-- Decode a character value as a `Fin` type. -/ +@[pp_nodot] def decodeChar (i : Fin 1112064) : Char := + if h : i.val < 55296 then + Char.ofNatAux i.val (by omega) + else + Char.ofNatAux (i.val + 2048) (by omega) + +@[simp] theorem encodeChar_decodeChar (x) : encodeChar (decodeChar x) = x := by + simp only [decodeChar, encodeChar] + split + · simp [Char.ofNatAux, Char.toNat, UInt32.toNat, *] + · have : ¬ x.val + 2048 < 55296 := by omega + simp [Char.ofNatAux, Char.toNat, UInt32.toNat, *] + +@[simp] theorem decodeChar_encodeChar (x) : decodeChar (encodeChar x) = x := by + ext; simp only [decodeChar, encodeChar] + split + · simp only [Char.ofNatAux, Char.toNat]; rfl + · have h0 : 57344 ≤ x.toNat ∧ x.toNat < 1114112 := by + match x.valid with + | .inl h => contradiction + | .inr h => + constructor + · exact Nat.add_one_le_of_lt h.left + · exact h.right + have h1 : ¬ x.toNat - 2048 < 55296 := by omega + have h2 : 2048 ≤ x.toNat := by omega + simp only [dif_neg h1, Char.ofNatAux, Nat.sub_add_cancel h2]; rfl + +/-- Decode an optional `Fin` types. -/ +def encodeOption : Option (Fin n) → Fin (n+1) + | none => 0 + | some ⟨i, h⟩ => ⟨i+1, Nat.succ_lt_succ h⟩ + +/-- Decode an optional `Fin` types. -/ +@[pp_nodot] def decodeOption : Fin (n+1) → Option (Fin n) + | 0 => none + | ⟨i+1, h⟩ => some ⟨i, Nat.lt_of_succ_lt_succ h⟩ + +@[simp] theorem encodeOption_decodeOption (x : Fin (n+1)) : encodeOption (decodeOption x) = x := by + simp only [encodeOption, decodeOption] + split + · next hd => + split at hd + · next he => cases he; rfl + · contradiction + · next hd => + split at hd + · contradiction + · next he => cases hd; simp + +@[simp] theorem decodeOption_encodeOption (x : Option (Fin n)) : + decodeOption (encodeOption x) = x := by + simp only [encodeOption, decodeOption] + split + · next he => + split at he + · rfl + · cases he + · next he => + split at he + · cases he + · cases he; rfl + +/-- Encode a sum of `Fin` types. -/ +def encodeSum : Sum (Fin n) (Fin m) → Fin (n + m) + | .inl x => x.castLE (Nat.le_add_right _ _) + | .inr x => x.natAdd n + +/-- Decode a sum of `Fin` types. -/ +@[pp_nodot] def decodeSum (x : Fin (n + m)) : Sum (Fin n) (Fin m) := + if h : x < n then + .inl ⟨x, h⟩ + else + .inr ⟨x - n, by omega⟩ + +@[simp] theorem encodeSum_decodeSum (x : Fin (n + m)) : encodeSum (decodeSum x) = x := by + simp only [encodeSum, decodeSum] + split + · next hd => + split at hd + · cases hd; rfl + · contradiction + · next hd => + split at hd + · contradiction + · next he => + cases x; cases hd + simp at he ⊢; omega + +@[simp] theorem decodeSum_encodeSum (x : Sum (Fin n) (Fin m)) : decodeSum (encodeSum x) = x := by + simp only [encodeSum, decodeSum] + split + · split + · simp + · next h => simp only [coe_castLE] at h; omega + · split + · next h => simp only [coe_natAdd] at h; omega + · next x _ => cases x; simp only [natAdd_mk, Sum.inr.injEq, mk.injEq]; omega + +/-- Encode a product of `Fin` types. -/ +def encodeProd : Fin m × Fin n → Fin (m * n) + | (⟨i, hi⟩, ⟨j, hj⟩) => Fin.mk (n * i + j) <| calc + _ < n * i + n := Nat.add_lt_add_left hj .. + _ = n * (i + 1) := Nat.mul_succ .. + _ ≤ n * m := Nat.mul_le_mul_left n (Nat.succ_le_of_lt hi) + _ = m * n := Nat.mul_comm .. + +/-- Decode a product of `Fin` types. -/ +@[pp_nodot] def decodeProd (z : Fin (m * n)) : Fin m × Fin n := + ⟨left, right⟩ +where + hn : 0 < n := by + apply Nat.zero_lt_of_ne_zero + intro h + absurd z.is_lt + simp [h] + /-- Left case of `decodeProd`. -/ + left := ⟨z / n, by rw [Nat.div_lt_iff_lt_mul hn]; exact z.is_lt⟩ + /-- Right case of `decodeProd`. -/ + right := ⟨z % n, Nat.mod_lt _ hn⟩ + +@[simp] theorem encodeProd_decodeProd (x : Fin (m * n)) : encodeProd (decodeProd x) = x := by + simp [encodeProd, decodeProd, decodeProd.left, decodeProd.right, Nat.div_add_mod] + +@[simp] theorem decodeProd_encodeProd (x : Fin m × Fin n) : decodeProd (encodeProd x) = x := by + match x with + | ⟨⟨_, _⟩, ⟨_, h⟩⟩ => simp [encodeProd, decodeProd, decodeProd.left, decodeProd.right, + Nat.mul_add_div (Nat.zero_lt_of_lt h), Nat.div_eq_of_lt h, Nat.mul_add_mod, Nat.mod_eq_of_lt h] + +/-- Encode a dependent sum of `Fin` types. -/ +def encodeSigma (f : Fin n → Nat) (x : (i : Fin n) × Fin (f i)) : Fin (Fin.sum f) := + match n, f, x with + | _+1, _, ⟨⟨0, _⟩, ⟨j, hj⟩⟩ => + ⟨j, Nat.lt_of_lt_of_le hj (sum_succ .. ▸ Nat.le_add_right ..)⟩ + | _+1, f, ⟨⟨i+1, hi⟩, ⟨j, hj⟩⟩ => + match encodeSigma ((f ∘ succ)) ⟨⟨i, Nat.lt_of_succ_lt_succ hi⟩, ⟨j, hj⟩⟩ with + | ⟨k, hk⟩ => ⟨f 0 + k, sum_succ .. ▸ Nat.add_lt_add_left hk ..⟩ + +/-- Decode a dependent sum of `Fin` types. -/ +@[pp_nodot] def decodeSigma (f : Fin n → Nat) (x : Fin (Fin.sum f)) : (i : Fin n) × Fin (f i) := + match n, f, x with + | 0, _, ⟨_, h⟩ => False.elim (by simp at h) + | n+1, f, ⟨x, hx⟩ => + if hx0 : x < f 0 then + ⟨0, ⟨x, hx0⟩⟩ + else + have hxf : x - f 0 < Fin.sum (f ∘ succ) := by + apply Nat.sub_lt_left_of_lt_add + · exact Nat.le_of_not_gt hx0 + · rw [← sum_succ]; exact hx + match decodeSigma ((f ∘ succ)) ⟨x - f 0, hxf⟩ with + | ⟨⟨i, hi⟩, y⟩ => ⟨⟨i+1, Nat.succ_lt_succ hi⟩, y⟩ + +@[simp] theorem encodeSigma_decodeSigma (f : Fin n → Nat) (x : Fin (Fin.sum f)) : + encodeSigma f (decodeSigma f x) = x := by + induction n with + | zero => absurd x.is_lt; simp + | succ n ih => + simp only [decodeSigma] + split + · simp [encodeSigma] + · next h1 => + have h2 : x - f 0 < Fin.sum (f ∘ succ) := by + apply Nat.sub_lt_left_of_lt_add + · exact Nat.le_of_not_gt h1 + · rw [← sum_succ]; exact x.is_lt + have : encodeSigma (f ∘ succ) (decodeSigma (f ∘ succ) ⟨x - f 0, h2⟩) = ↑x - f 0 := by + rw [ih] + ext; simp only [encodeSigma] + conv => rhs; rw [← Nat.add_sub_of_le (Nat.le_of_not_gt h1)] + congr + +@[simp] theorem decodeSigma_encodeSigma (f : Fin n → Nat) (x : (i : Fin n) × Fin (f i)) : + decodeSigma f (encodeSigma f x) = x := by + induction n with + | zero => nomatch x + | succ n ih => + simp only [decodeSigma] + match x with + | ⟨0, x⟩ => simp [encodeSigma] + | ⟨⟨i+1, hi⟩, ⟨x, hx⟩⟩ => + have : ¬ encodeSigma f ⟨⟨i+1, hi⟩, ⟨x, hx⟩⟩ < f 0 := by simp [encodeSigma] + rw [dif_neg this] + have : (encodeSigma f ⟨⟨i+1, hi⟩, ⟨x, hx⟩⟩).1 - f 0 = + (encodeSigma (f ∘ Fin.succ) ⟨⟨i, Nat.lt_of_succ_lt_succ hi⟩, ⟨x, hx⟩⟩).1 := by + simp [encodeSigma, Nat.add_sub_cancel_left] + have h := ih (f ∘ Fin.succ) ⟨⟨i, Nat.lt_of_succ_lt_succ hi⟩, ⟨x, hx⟩⟩ + simp [Sigma.ext_iff, Fin.ext_iff] at h ⊢ + constructor + · conv => rhs; rw [← h.1] + apply congrArg + apply congrArg Sigma.fst + congr + · apply HEq.trans _ h.2; congr + +/-- Encode a function between `Fin` types. -/ +def encodeFun : {m : Nat} → (Fin m → Fin n) → Fin (n ^ m) + | 0, _ => ⟨0, by simp⟩ + | m+1, f => Fin.mk (n * (encodeFun fun k => f k.succ).val + (f 0).val) <| calc + _ < n * (encodeFun fun k => f k.succ).val + n := Nat.add_lt_add_left (f 0).isLt _ + _ = n * ((encodeFun fun k => f k.succ).val + 1) := Nat.mul_succ .. + _ ≤ n * n ^ m := Nat.mul_le_mul_left n (Nat.succ_le_of_lt (encodeFun fun k => f k.succ).isLt) + _ = n ^ m * n := Nat.mul_comm .. + _ = n ^ (m+1) := Nat.pow_succ .. + +/-- Decode a function between `Fin` types. -/ +@[pp_nodot] def decodeFun : {m : Nat} → Fin (n ^ m) → Fin m → Fin n + | 0, _ => (nomatch .) + | m+1, ⟨k, hk⟩ => + have hn : n > 0 := by + apply Nat.zero_lt_of_ne_zero + intro h + rw [h, Nat.pow_succ, Nat.mul_zero] at hk + contradiction + fun + | 0 => ⟨k % n, Nat.mod_lt k hn⟩ + | ⟨i+1, hi⟩ => + have h : k / n < n ^ m := by rw [Nat.div_lt_iff_lt_mul hn]; exact hk + decodeFun ⟨k / n, h⟩ ⟨i, Nat.lt_of_succ_lt_succ hi⟩ + +@[simp] theorem encodeFun_decodeFun (x : Fin (n ^ m)) : encodeFun (decodeFun x) = x := by + induction m with simp only [encodeFun, decodeFun, Fin.succ] + | zero => simp; omega + | succ m ih => cases x; simp [ih]; rw [← Fin.zero_eta, Nat.div_add_mod] + +@[simp] theorem decodeFun_encodeFun (x : Fin m → Fin n) : decodeFun (encodeFun x) = x := by + funext i; induction m with simp only [encodeFun, decodeFun] + | zero => nomatch i + | succ m ih => + have hn : 0 < n := Nat.zero_lt_of_lt (x 0).is_lt + split + · ext; simp [Nat.mul_add_mod, Nat.mod_eq_of_lt] + · next i hi => + have : decodeFun (encodeFun fun k => x k.succ) ⟨i, Nat.lt_of_succ_lt_succ hi⟩ + = x ⟨i+1, hi⟩ := by rw [ih]; rfl + simp [← this, Nat.mul_add_div hn, Nat.div_eq_of_lt] + +/-- Encode a dependent product of `Fin` types. -/ +def encodePi (f : Fin n → Nat) (x : (i : Fin n) → Fin (f i)) : Fin (Fin.prod f) := + match n, f, x with + | 0, _, _ => ⟨0, by simp [Fin.prod]⟩ + | _+1, f, x => + match encodePi ((f ∘ succ)) (fun ⟨i, hi⟩ => x ⟨i+1, Nat.succ_lt_succ hi⟩) with + | ⟨k, hk⟩ => Fin.mk (f 0 * k + (x 0).val) $ calc + _ < f 0 * k + f 0 := Nat.add_lt_add_left (x 0).isLt .. + _ = f 0 * (k + 1) := Nat.mul_succ .. + _ ≤ f 0 * Fin.prod (f ∘ succ) := Nat.mul_le_mul_left _ (Nat.succ_le_of_lt hk) + _ = Fin.prod f := Eq.symm <| prod_succ .. + +/-- Decode a dependent product of `Fin` types. -/ +def decodePi (f : Fin n → Nat) (x : Fin (Fin.prod f)) : (i : Fin n) → Fin (f i) := + match n, f, x with + | 0, _, _ => (nomatch ·) + | n+1, f, ⟨x, hx⟩ => + have h : f 0 > 0 := by + apply Nat.zero_lt_of_ne_zero + intro h + rw [prod_succ, h, Nat.zero_mul] at hx + contradiction + have : x / f 0 < Fin.prod (f ∘ succ) := by + rw [Nat.div_lt_iff_lt_mul h, Nat.mul_comm, ← prod_succ] + exact hx + match decodePi ((f ∘ succ)) ⟨x / f 0, this⟩ with + | t => fun + | ⟨0, _⟩ => ⟨x % f 0, Nat.mod_lt x h⟩ + | ⟨i+1, hi⟩ => t ⟨i, Nat.lt_of_succ_lt_succ hi⟩ + +@[simp] theorem encodePi_decodePi (f : Fin n → Nat) (x : Fin (Fin.prod f)) : + encodePi f (decodePi f x) = x := by + induction n with + | zero => + match x with + | ⟨0, _⟩ => rfl + | ⟨_+1, h⟩ => simp at h + | succ n ih => + simp only [encodePi, decodePi, ih] + ext + conv => rhs; rw [← Nat.div_add_mod x (f 0)] + congr + +@[simp] theorem decodePi_encodePi (f : Fin n → Nat) (x : (i : Fin n) → Fin (f i)) : + decodePi f (encodePi f x) = x := by + induction n with + | zero => funext i; nomatch i + | succ n ih => + have h0 : 0 < f 0 := Nat.zero_lt_of_lt (x 0).is_lt + funext i + simp only [decodePi] + split + · simp [Nat.mul_add_mod, Nat.mod_eq_of_lt (x 0).is_lt] + · next i hi => + have h : decodePi (f ∘ succ) (encodePi (f ∘ succ) fun i => x i.succ) + ⟨i, Nat.lt_of_succ_lt_succ hi⟩ = x ⟨i+1, hi⟩ := by rw [ih]; rfl + conv => rhs; rw [← h] + congr; simp [Nat.mul_add_div h0, Nat.div_eq_of_lt (x 0).is_lt]; rfl + +/-- Encode a decidable subtype of a `Fin` type. -/ +def encodeSubtype (P : Fin n → Prop) [inst : DecidablePred P] (i : { i // P i }) : + Fin (Fin.count P) := + match n, P, inst, i with + | n+1, P, inst, ⟨0, hp⟩ => + have : Fin.count P > 0 := by simp [count_succ, hp] + ⟨0, this⟩ + | n+1, P, inst, ⟨⟨i+1, hi⟩, hp⟩ => + match encodeSubtype (fun i => P i.succ) ⟨⟨i, Nat.lt_of_succ_lt_succ hi⟩, hp⟩ with + | ⟨k, hk⟩ => + if h0 : P 0 then + have : Fin.count P = (Fin.count fun i => P i.succ) + 1 := by + simp_arith only [count_succ, Function.comp_def, if_pos h0] + this ▸ ⟨k+1, Nat.succ_lt_succ hk⟩ + else + have : Fin.count P = Fin.count fun i => P i.succ := by simp [count_succ, h0] + this ▸ ⟨k, hk⟩ + +/-- Decode a decidable subtype of a `Fin` type. -/ +def decodeSubtype (p : Fin n → Prop) [inst : DecidablePred p] (k : Fin (Fin.count p)) : + { i // p i } := + match n, p, inst, k with + | 0, _, _, ⟨_, h⟩ => False.elim (by simp at h) + | n+1, p, inst, ⟨k, hk⟩ => + if h0 : p 0 then + have : Fin.count p = (Fin.count fun i => p i.succ) + 1 := by simp [count_succ, h0] + match k with + | 0 => ⟨0, h0⟩ + | k + 1 => + match decodeSubtype (fun i => p i.succ) ⟨k, Nat.lt_of_add_lt_add_right (this ▸ hk)⟩ with + | ⟨⟨i, hi⟩, hp⟩ => ⟨⟨i+1, Nat.succ_lt_succ hi⟩, hp⟩ + else + have : Fin.count p = Fin.count fun i => p i.succ := by simp [count_succ, h0] + match decodeSubtype (fun i => p (succ i)) ⟨k, this ▸ hk⟩ with + | ⟨⟨i, hi⟩, hp⟩ => ⟨⟨i+1, Nat.succ_lt_succ hi⟩, hp⟩ + +@[simp] theorem encodeSubtype_decodeSubtype (P : Fin n → Prop) [DecidablePred P] + (x : Fin (Fin.count P)) : encodeSubtype P (decodeSubtype P x) = x := by + induction n with + | zero => absurd x.is_lt; simp + | succ n ih => + simp only [decodeSubtype] + split + · ext; split <;> simp [encodeSubtype, count_succ, *] + · ext; simp [encodeSubtype, count_succ, *] + +theorem encodeSubtype_zero_pos {P : Fin (n+1) → Prop} [DecidablePred P] (h₀ : P 0) : + encodeSubtype P ⟨0, h₀⟩ = ⟨0, by simp [count_succ, *]⟩ := by + ext; simp [encodeSubtype] + +theorem encodeSubtype_succ_pos {P : Fin (n+1) → Prop} [DecidablePred P] (h₀ : P 0) {i : Fin n} + (h : P i.succ) : encodeSubtype P ⟨i.succ, h⟩ = + (encodeSubtype (fun i => P i.succ) ⟨i, h⟩).succ.cast (by simp [count_succ, *]) := by + ext; simp [encodeSubtype, count_succ, *] + +theorem encodeSubtype_succ_neg {P : Fin (n+1) → Prop} [DecidablePred P] (h₀ : ¬ P 0) {i : Fin n} + (h : P i.succ) : encodeSubtype P ⟨i.succ, h⟩ = + (encodeSubtype (fun i => P i.succ) ⟨i, h⟩).cast (by simp [count_succ, *]) := by + ext; simp [encodeSubtype, count_succ, *] + +@[simp] theorem decodeSubtype_encodeSubtype (P : Fin n → Prop) [DecidablePred P] (x : { x // P x}) : + decodeSubtype P (encodeSubtype P x) = x := by + match x with + | ⟨i, h⟩ => + induction n with + | zero => absurd x.val.is_lt; simp + | succ n ih => + if h₀ : P 0 then + simp only [decodeSubtype, dif_pos h₀] + cases i using Fin.casesZeroSuccOn with + | zero => rw [encodeSubtype_zero_pos h₀] + | succ i => + rw [encodeSubtype_succ_pos h₀] + simp only [coe_cast, val_succ, Subtype.mk.injEq] + congr + rw [ih (fun i => P i.succ) ⟨i, h⟩] + else + simp only [decodeSubtype, dif_neg h₀] + cases i using Fin.casesZeroSuccOn with + | zero => contradiction + | succ i => + rw [encodeSubtype_succ_neg h₀] + simp only [coe_cast, Subtype.mk.injEq] + congr + rw [ih (fun i => P i.succ) ⟨i, h⟩] + +/-- Get representative for the equivalence class of `x`. -/ +abbrev getRepr (s : Setoid (Fin n)) [DecidableRel s.r] (x : Fin n) : Fin n := + match h : Fin.find? (s.r x) with + | some y => y + | none => False.elim <| by + have : Fin.find? (s.r x) |>.isSome := by + rw [find?_isSome_iff_exists] + exists x + exact Setoid.refl x + simp [h] at this + +@[simp] theorem equiv_getRepr (s : Setoid (Fin n)) [DecidableRel s.r] (x : Fin n) : + s.r x (getRepr s x) := by + apply find?_prop + simp only [getRepr] + split + · assumption + · next h => + have : Fin.find? (s.r x) |>.isSome := by + rw [find?_isSome_iff_exists] + exists x + exact Setoid.refl x + simp [h] at this + +@[simp] theorem getRepr_equiv (s : Setoid (Fin n)) [DecidableRel s.r] (x : Fin n) : + s.r (getRepr s x) x := Setoid.symm (equiv_getRepr ..) + +theorem getRepr_eq_getRepr_of_equiv (s : Setoid (Fin n)) [DecidableRel s.r] (h : s.r x y) : + getRepr s x = getRepr s y := by + have hfind : Fin.find? (s.r x) = Fin.find? (s.r y) := by + congr + funext z + apply propext + constructor + · exact Setoid.trans (Setoid.symm h) + · exact Setoid.trans h + simp only [getRepr] + split + · next hx => + rw [hfind] at hx + split + · next hy => + rwa [hx, Option.some_inj] at hy + · next hy => + rw [hx] at hy + contradiction + · next hx => + rw [hfind] at hx + split + · next hy => + rw [hx] at hy + contradiction + · rfl + +@[simp] theorem getRepr_getRepr (s : Setoid (Fin n)) [DecidableRel s.r] (x : Fin n) : + getRepr s (getRepr s x) = getRepr s x := by + apply getRepr_eq_getRepr_of_equiv + exact getRepr_equiv .. + +/-- Encode decidable quotient of a `Fin` type. -/ +def encodeQuotient (s : Setoid (Fin n)) [DecidableRel s.r] (x : Quotient s) : + Fin (Fin.count fun i => getRepr s i = i) := + encodeSubtype _ <| Quotient.liftOn x (fun i => ⟨getRepr s i, getRepr_getRepr s i⟩) <| by + intro _ _ h + simp only [Subtype.mk.injEq] + exact getRepr_eq_getRepr_of_equiv s h + +/-- Decode decidable quotient of a `Fin ` type. -/ +def decodeQuotient (s : Setoid (Fin n)) [DecidableRel s.r] + (i : Fin (Fin.count fun i => getRepr s i = i)) : Quotient s := + Quotient.mk s (decodeSubtype _ i) + +@[simp] theorem encodeQuotient_decodeQuotient (s : Setoid (Fin n)) [DecidableRel s.r] (x) : + encodeQuotient s (decodeQuotient s x) = x := by + simp only [decodeQuotient, encodeQuotient, Quotient.liftOn, Quotient.mk, Quot.liftOn] + conv => rhs; rw [← encodeSubtype_decodeSubtype _ x] + congr + exact (decodeSubtype _ x).property + +@[simp] theorem decodeQuotient_encodeQuotient (s : Setoid (Fin n)) [DecidableRel s.r] (x) : + decodeQuotient s (encodeQuotient s x) = x := by + induction x using Quotient.inductionOn with + | _ x => + simp only [decodeQuotient, encodeQuotient, Quotient.liftOn, Quotient.mk, Quot.liftOn] + apply Quot.sound + simp diff --git a/Batteries/Data/Fin/Enum.lean b/Batteries/Data/Fin/Enum.lean new file mode 100644 index 0000000000..299cabcb1c --- /dev/null +++ b/Batteries/Data/Fin/Enum.lean @@ -0,0 +1,195 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: François G. Dorais +-/ +import Batteries.Data.Fin.Coding + +namespace Fin + +/-- Class of types that are in bijection with a `Fin` type. -/ +protected class Enum (α : Type _) where + /-- Size of type. -/ + size : Nat + /-- Enumeration of the type elements. -/ + decode : Fin size → α + /-- Find the index of a type element. -/ + encode : α → Fin size + /-- Inverse relation for `decode` and `encode`. -/ + decode_encode (x) : decode (encode x) = x + /-- Inverse relation for `decode` and `encode`. -/ + encode_decode (i) : encode (decode i) = i + +attribute [simp] Enum.decode_encode Enum.encode_decode + +namespace Enum + +instance : Fin.Enum Empty where + size := 0 + decode := nofun + encode := nofun + decode_encode := nofun + encode_decode := nofun + +instance : Fin.Enum PUnit where + size := 1 + decode := decodePUnit + encode := encodePUnit + decode_encode := decodePUnit_encodePUnit + encode_decode := encodePUnit_decodePUnit + +instance : Fin.Enum Bool where + size := 2 + decode := decodeBool + encode := encodeBool + decode_encode := decodeBool_encodeBool + encode_decode := encodeBool_decodeBool + +instance : Fin.Enum Char where + size := 1112064 + decode := decodeChar + encode := encodeChar + decode_encode := decodeChar_encodeChar + encode_decode := encodeChar_decodeChar + +instance [Fin.Enum α] : Fin.Enum (Option α) where + size := size α + 1 + decode i := decodeOption i |>.map decode + encode x := encodeOption <| x.map encode + decode_encode := by simp [Function.comp_def] + encode_decode := by simp [Function.comp_def] + +instance [Fin.Enum α] [Fin.Enum β] : Fin.Enum (α ⊕ β) where + size := size α + size β + decode i := + match decodeSum i with + | .inl i => .inl <| decode i + | .inr i => .inr <| decode i + encode x := + match x with + | .inl x => encodeSum <| .inl (encode x) + | .inr x => encodeSum <| .inr (encode x) + decode_encode _ := by + simp only; split + · next h => + split at h; + · simp at h; cases h; simp + · simp at h + · next h => + split at h; + · simp at h + · simp at h; cases h; simp + encode_decode _ := by + simp only; split + · next h => + split at h + · next h' => cases h; simp [← h'] + · simp at h + · next h => + split at h + · simp at h + · next h' => cases h; simp [← h'] + +instance [Fin.Enum α] [Fin.Enum β] : Fin.Enum (α × β) where + size := size α * size β + decode i := (decode (decodeProd i).fst, decode (decodeProd i).snd) + encode x := encodeProd (encode x.fst, encode x.snd) + encode_decode := by simp [Prod.eta] + decode_encode := by simp + +instance [Fin.Enum α] [Fin.Enum β] : Fin.Enum (β → α) where + size := size α ^ size β + decode i x := decode (decodeFun i (encode x)) + encode f := encodeFun fun x => encode (f (decode x)) + encode_decode := by simp + decode_encode := by simp + +instance (β : α → Type _) [Fin.Enum α] [(x : α) → Fin.Enum (β x)] : Fin.Enum ((x : α) × β x) where + size := Fin.sum fun i => size (β (decode i)) + decode i := + match decodeSigma _ i with + | ⟨i, j⟩ => ⟨decode i, decode j⟩ + encode | ⟨x, y⟩ => encodeSigma _ ⟨encode x, encode (_root_.cast (by simp) y)⟩ + encode_decode i := by + simp only [] + conv => rhs; rw [← encodeSigma_decodeSigma _ i] + congr 1 + ext + · simp + · simp only [] + conv => rhs; rw [← encode_decode (decodeSigma _ i).snd] + congr <;> simp + decode_encode + | ⟨x, y⟩ => by + ext + simp only [cast, decodeSigma_encodeSigma, decode_encode] + simp [] + rw [decodeSigma_encodeSigma] + simp + +instance (β : α → Type _) [Fin.Enum α] [(x : α) → Fin.Enum (β x)] : Fin.Enum ((x : α) → β x) where + size := Fin.prod fun i => size (β (decode i)) + decode i x := decode <| (decodePi _ i (encode x)).cast (by rw [decode_encode]) + encode f := encodePi _ fun i => encode (f (decode i)) + encode_decode i := by + simp only [encode_decode] + conv => rhs; rw [← encodePi_decodePi _ i] + congr + ext + simp only [cast] + rw [encode_decode] + decode_encode f := by + funext x + simp only [] + conv => rhs; rw [← decode_encode (f x)] + congr 1 + ext + simp only [cast, decodePi_encodePi] + rw [decode_encode] + done + +instance (P : α → Prop) [DecidablePred P] [Fin.Enum α] : Fin.Enum { x // P x} where + size := Fin.count fun i => P (decode i) + decode i := ⟨decode (decodeSubtype _ i).val, (decodeSubtype _ i).property⟩ + encode x := encodeSubtype _ ⟨encode x.val, (decode_encode x.val).symm ▸ x.property⟩ + encode_decode i := by simp [Subtype.eta] + decode_encode := by simp + +private def decodeSetoid (s : Setoid α) [DecidableRel s.r] [Fin.Enum α] : + Setoid (Fin (size α)) where + r i j := s.r (decode i) (decode j) + iseqv := { + refl := fun i => Setoid.refl (decode i) + symm := Setoid.symm + trans := Setoid.trans + } + +private instance (s : Setoid α) [DecidableRel s.r] [Fin.Enum α] : DecidableRel (decodeSetoid s).r := + fun _ _ => inferInstanceAs (Decidable (s.r _ _)) + +instance (s : Setoid α) [DecidableRel s.r] [Fin.Enum α] : Fin.Enum (Quotient s) where + size := Fin.count fun i => getRepr (decodeSetoid s) i = i + decode i := Quotient.liftOn (decodeQuotient (decodeSetoid s) i) + (fun i => Quotient.mk s (decode i)) (fun _ _ h => Quotient.sound h) + encode x := Quotient.liftOn x + (fun x => encodeQuotient (decodeSetoid s) (Quotient.mk _ (encode x))) <| by + intro _ _ h + simp only [] + congr 1 + apply Quotient.sound + simp only [HasEquiv.Equiv, decodeSetoid, decode_encode] + exact h + decode_encode x := by + induction x using Quotient.inductionOn with + | _ x => + simp only [decodeQuotient, encodeQuotient, decodeSubtype_encodeSubtype, Quotient.liftOn, + Quotient.mk, Quot.liftOn] + apply Quot.sound + conv => rhs; rw [← decode_encode x] + exact getRepr_equiv (decodeSetoid s) .. + encode_decode i := by + conv => rhs; rw [← encodeSubtype_decodeSubtype _ i] + simp only [Quotient.liftOn, Quot.liftOn, encodeQuotient, Quotient.mk, decodeQuotient, + encode_decode] + congr + rw [(decodeSubtype _ i).property] diff --git a/Batteries/Data/Fin/Lemmas.lean b/Batteries/Data/Fin/Lemmas.lean index ddc7bbf1cf..699d1ba925 100644 --- a/Batteries/Data/Fin/Lemmas.lean +++ b/Batteries/Data/Fin/Lemmas.lean @@ -5,11 +5,132 @@ Authors: Mario Carneiro -/ import Batteries.Data.Fin.Basic import Batteries.Data.List.Lemmas +import Batteries.Tactic.Lint.Simp namespace Fin attribute [norm_cast] val_last +@[nolint simpNF, simp] +theorem val_ndrec (x : Fin n) (h : m = n) : (h ▸ x).val = x.val := by + cases h; rfl + /-! ### clamp -/ @[simp] theorem coe_clamp (n m : Nat) : (clamp n m : Nat) = min n m := rfl + +/-! ### foldr -/ + +theorem map_foldr {g : α → β} {f : Fin n → α → α} {f' : Fin n → β → β} + (h : ∀ i x, g (f i x) = f' i (g x)) (x) : g (foldr n f x) = foldr n f' (g x) := by + induction n generalizing x with + | zero => simp + | succ n ih => simp [foldr_succ, ih, h] + +/-! ### sum -/ + +@[simp] theorem sum_zero [OfNat α (nat_lit 0)] [Add α] (x : Fin 0 → α) : + Fin.sum x = 0 := by + simp [Fin.sum] + +theorem sum_succ [OfNat α (nat_lit 0)] [Add α] (x : Fin (n + 1) → α) : + Fin.sum x = x 0 + Fin.sum (x ∘ Fin.succ) := by + simp [Fin.sum, foldr_succ] + +/-! ### prod -/ + +@[simp] theorem prod_zero [OfNat α (nat_lit 1)] [Mul α] (x : Fin 0 → α) : + Fin.prod x = 1 := by + simp [Fin.prod] + +theorem prod_succ [OfNat α (nat_lit 1)] [Mul α] (x : Fin (n + 1) → α) : + Fin.prod x = x 0 * Fin.prod (x ∘ Fin.succ) := by + simp [Fin.prod, foldr_succ] + +/-! ### count -/ + +@[simp] theorem count_zero (P : Fin 0 → Prop) [DecidablePred P] : Fin.count P = 0 := by + simp [Fin.count] + +theorem count_succ (P : Fin (n + 1) → Prop) [DecidablePred P] : Fin.count P = + if P 0 then Fin.count (fun i => P i.succ) + 1 else Fin.count (fun i => P i.succ) := by + split <;> simp [Fin.count, Fin.sum_succ, Nat.one_add, Function.comp_def, *] + +theorem count_le (P : Fin n → Prop) [DecidablePred P] : Fin.count P ≤ n := by + induction n with + | zero => simp + | succ n ih => + rw [count_succ] + split + · simp [ih] + · apply Nat.le_trans _ (Nat.le_succ n); simp [ih] + +/-! ### find? -/ + +@[simp] theorem find?_zero {P : Fin 0 → Prop} [DecidablePred P] : Fin.find? P = none := by + simp [Fin.find?] + +theorem find?_succ (P : Fin (n+1) → Prop) [DecidablePred P] : + Fin.find? P = if P 0 then some 0 else (Fin.find? fun i => P i.succ).map Fin.succ := by + have h (i : Fin n) (v : Option (Fin n)) : + (if P i.succ then some i else v).map Fin.succ = + if P i.succ then some i.succ else v.map Fin.succ := by + intros; split <;> simp + simp [Fin.find?, foldr_succ, map_foldr h] + +theorem find?_prop {P : Fin n → Prop} [DecidablePred P] (h : Fin.find? P = some x) : P x := by + induction n with + | zero => simp at h + | succ n ih => + simp [find?_succ] at h + split at h + · cases h; assumption + · simp [Option.map_eq_some] at h + match h with + | ⟨i, h', hi⟩ => cases hi; exact ih h' + +theorem find?_isSome_of_prop {P : Fin n → Prop} [DecidablePred P] (h : P x) : + (Fin.find? P).isSome := by + induction n with + | zero => nomatch x + | succ n ih => + rw [find?_succ] + split + · rfl + · have hx : x ≠ 0 := by + intro hx + rw [hx] at h + contradiction + have h : P (x.pred hx).succ := by simp [h] + rw [Option.isSome_map'] + exact ih h + +theorem find?_isSome_iff_exists {P : Fin n → Prop} [DecidablePred P] : + (Fin.find? P).isSome ↔ ∃ x, P x := by + constructor + · intro h + match heq : Fin.find? P with + | some x => exists x; exact find?_prop heq + | none => rw [heq] at h; contradiction + · intro ⟨_, h⟩ + exact find?_isSome_of_prop h + +/-! ### recZeroSuccOn -/ + +unseal Fin.recZeroSuccOn in +@[simp] theorem recZeroSuccOn_zero {motive : Fin (n+1) → Sort _} (zero : motive 0) + (succ : (x : Fin n) → motive x.castSucc → motive x.succ) : + Fin.recZeroSuccOn 0 zero succ = zero := rfl + +unseal Fin.recZeroSuccOn in +theorem recZeroSuccOn_succ {motive : Fin (n+1) → Sort _} (x : Fin n) (zero : motive 0) + (succ : (x : Fin n) → motive x.castSucc → motive x.succ) : + Fin.recZeroSuccOn x.succ zero succ = succ x (Fin.recZeroSuccOn x.castSucc zero succ) := rfl + +/-! ### casesZeroSuccOn -/ + +@[simp] theorem casesZeroSuccOn_zero {motive : Fin (n+1) → Sort _} (zero : motive 0) + (succ : (x : Fin n) → motive x.succ) : Fin.casesZeroSuccOn 0 zero succ = zero := rfl + +@[simp] theorem casesZeroSuccOn_succ {motive : Fin (n+1) → Sort _} (x : Fin n) (zero : motive 0) + (succ : (x : Fin n) → motive x.succ) : Fin.casesZeroSuccOn x.succ zero succ = succ x := rfl