diff --git a/Batteries/Data/Fin.lean b/Batteries/Data/Fin.lean index 5fe5cc41ca..67c5fdb14a 100644 --- a/Batteries/Data/Fin.lean +++ b/Batteries/Data/Fin.lean @@ -1,2 +1,4 @@ import Batteries.Data.Fin.Basic +import Batteries.Data.Fin.Coding +import Batteries.Data.Fin.Enum import Batteries.Data.Fin.Lemmas diff --git a/Batteries/Data/Fin/Basic.lean b/Batteries/Data/Fin/Basic.lean index b61481e33a..8438128674 100644 --- a/Batteries/Data/Fin/Basic.lean +++ b/Batteries/Data/Fin/Basic.lean @@ -68,3 +68,43 @@ Fin.foldrM n f xₙ = do loop : {i // i ≤ n} → α → m α | ⟨0, _⟩, x => pure x | ⟨i+1, h⟩, x => f ⟨i, h⟩ x >>= loop ⟨i, Nat.le_of_lt h⟩ + +/-- Sum of a list indexed by `Fin n`. -/ +protected def sum [OfNat α (nat_lit 0)] [Add α] (x : Fin n → α) : α := + foldr n (x · + ·) 0 + +/-- Product of a list indexed by `Fin n`. -/ +protected def prod [OfNat α (nat_lit 1)] [Mul α] (x : Fin n → α) : α := + foldr n (x · * ·) 1 + +/-- Count the number of true values of a decidable predicate on `Fin n`. -/ +protected def count (P : Fin n → Prop) [DecidablePred P] : Nat := + Fin.sum (if P · then 1 else 0) + +/-- Find the first true value of a decidable predicate on `Fin n`, if there is one. -/ +protected def find? (P : Fin n → Prop) [inst : DecidablePred P] : Option (Fin n) := + match n, P, inst with + | 0, _, _ => none + | _+1, P, _ => + if P 0 then + some 0 + else + match Fin.find? fun i => P i.succ with + | some i => some i.succ + | none => none + +/-- Custom recursor for `Fin (n+1)`. -/ +def recZeroSuccOn {motive : Fin (n+1) → Sort _} (x : Fin (n+1)) + (zero : motive 0) (succ : (x : Fin n) → motive x.castSucc → motive x.succ) : motive x := + match x with + | 0 => zero + | ⟨x+1, hx⟩ => + let x : Fin n := ⟨x, Nat.lt_of_succ_lt_succ hx⟩ + succ x <| recZeroSuccOn x.castSucc zero succ + +/-- Custom recursor for `Fin (n+1)`. -/ +def casesZeroSuccOn {motive : Fin (n+1) → Sort _} (x : Fin (n+1)) + (zero : motive 0) (succ : (x : Fin n) → motive x.succ) : motive x := + match x with + | 0 => zero + | ⟨x+1, hx⟩ => succ ⟨x, Nat.lt_of_succ_lt_succ hx⟩ diff --git a/Batteries/Data/Fin/Coding.lean b/Batteries/Data/Fin/Coding.lean new file mode 100644 index 0000000000..5cc54ac0e6 --- /dev/null +++ b/Batteries/Data/Fin/Coding.lean @@ -0,0 +1,497 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: François G. Dorais +-/ +import Batteries.Data.Fin.Lemmas +import Batteries.Tactic.Basic +import Batteries.Tactic.Trans +import Batteries.Tactic.Lint + +namespace Fin + +/-- Encode a unit value as a `Fin` type. -/ +@[nolint unusedArguments] +def encodePUnit : PUnit → Fin 1 + | .unit => 0 + +/-- Decode a unit value as a `Fin` type. -/ +@[pp_nodot] def decodePUnit : Fin 1 → PUnit + | 0 => .unit + +/-- Encode a unit value as a `Fin` type. -/ +abbrev encodeUnit : Unit → Fin 1 := encodePUnit + +/-- Decode a unit value as a `Fin` type. -/ +@[pp_nodot] abbrev decodeUnit : Fin 1 → Unit := decodePUnit + +@[simp] theorem encodePUnit_decodePUnit : (x : Fin 1) → encodePUnit (decodePUnit x) = x + | 0 => rfl + +@[simp] theorem decodePUnit_encodePUnit : (x : PUnit) → decodePUnit (encodePUnit x) = x + | .unit => rfl + +/-- Encode a boolean value as a `Fin` type. -/ +def encodeBool : Bool → Fin 2 + | false => 0 + | true => 1 + +/-- Decode a boolean value as a `Fin` type. -/ +@[pp_nodot] def decodeBool : Fin 2 → Bool + | 0 => false + | 1 => true + +@[simp] theorem encodeBool_decodeBool : (x : Fin 2) → encodeBool (decodeBool x) = x + | 0 => rfl + | 1 => rfl + +@[simp] theorem decodeBool_encodeBool : (x : Bool) → decodeBool (encodeBool x) = x + | false => rfl + | true => rfl + +/-- Decode an optional `Fin` types. -/ +def encodeOption : Option (Fin n) → Fin (n+1) + | none => 0 + | some x => x.succ + +/-- Decode an optional `Fin` types. -/ +@[pp_nodot] def decodeOption (x : Fin (n+1)) : Option (Fin n) := + if h : x = 0 then + none + else + some (x.pred h) + +@[simp] theorem encodeOption_decodeOption (x : Fin (n+1)) : encodeOption (decodeOption x) = x := by + simp only [encodeOption, decodeOption] + split + · next hd => + split at hd + · next he => cases he; rfl + · contradiction + · next hd => + split at hd + · contradiction + · next he => cases hd; simp + +@[simp] theorem decodeOption_encodeOption (x : Option (Fin n)) : + decodeOption (encodeOption x) = x := by + simp only [encodeOption, decodeOption] + split + · next he => + split at he + · rfl + · simp only [Fin.succ] at he; cases he + · next he => + split at he + · simp at he + · rfl + +/-- Encode a sum of `Fin` types. -/ +def encodeSum : Sum (Fin n) (Fin m) → Fin (n + m) + | .inl x => x.castLE (Nat.le_add_right _ _) + | .inr x => x.natAdd n + +/-- Decode a sum of `Fin` types. -/ +@[pp_nodot] def decodeSum (x : Fin (n + m)) : Sum (Fin n) (Fin m) := + if h : x < n then + .inl ⟨x, h⟩ + else + .inr ⟨x - n, by omega⟩ + +@[simp] theorem encodeSum_decodeSum (x : Fin (n + m)) : encodeSum (decodeSum x) = x := by + simp only [encodeSum, decodeSum] + split + · next hd => + split at hd + · cases hd; rfl + · contradiction + · next hd => + split at hd + · contradiction + · next he => + cases x; cases hd + simp at he ⊢; omega + +@[simp] theorem decodeSum_encodeSum (x : Sum (Fin n) (Fin m)) : decodeSum (encodeSum x) = x := by + simp only [encodeSum, decodeSum] + split + · next hd => + split + · simp + · simp at hd; omega + · next hd => + split + · simp at hd + · next x => + cases x + simp; omega + +/-- Encode a product of `Fin` types. -/ +def encodeProd : Fin m × Fin n → Fin (m * n) + | (⟨i, hi⟩, ⟨j, hj⟩) => Fin.mk (n * i + j) <| calc + _ < n * i + n := Nat.add_lt_add_left hj .. + _ = n * (i + 1) := Nat.mul_succ .. + _ ≤ n * m := Nat.mul_le_mul_left n (Nat.succ_le_of_lt hi) + _ = m * n := Nat.mul_comm .. + +/-- Decode a product of `Fin` types. -/ +@[pp_nodot] def decodeProd (z : Fin (m * n)) : Fin m × Fin n := + ⟨left, right⟩ +where + hn : 0 < n := by + apply Nat.zero_lt_of_ne_zero + intro h + absurd z.is_lt + simp [h] + /-- Left case of `decodeProd`. -/ + left := ⟨z / n, by rw [Nat.div_lt_iff_lt_mul hn]; exact z.is_lt⟩ + /-- Right case of `decodeProd`. -/ + right := ⟨z % n, Nat.mod_lt _ hn⟩ + +@[simp] theorem encodeProd_decodeProd (x : Fin (m * n)) : encodeProd (decodeProd x) = x := by + simp [encodeProd, decodeProd, decodeProd.left, decodeProd.right, Nat.div_add_mod] + +@[simp] theorem decodeProd_encodeProd (x : Fin m × Fin n) : decodeProd (encodeProd x) = x := by + match x with + | ⟨⟨_, _⟩, ⟨_, h⟩⟩ => simp [encodeProd, decodeProd, decodeProd.left, decodeProd.right, + Nat.mul_add_div (Nat.zero_lt_of_lt h), Nat.div_eq_of_lt h, Nat.mul_add_mod, Nat.mod_eq_of_lt h] + +/-- Encode a dependent sum of `Fin` types. -/ +def encodeSigma (f : Fin n → Nat) (x : (i : Fin n) × Fin (f i)) : Fin (Fin.sum f) := + match n, f, x with + | _+1, _, ⟨⟨0, _⟩, ⟨j, hj⟩⟩ => + ⟨j, Nat.lt_of_lt_of_le hj (sum_succ .. ▸ Nat.le_add_right ..)⟩ + | _+1, f, ⟨⟨i+1, hi⟩, ⟨j, hj⟩⟩ => + match encodeSigma ((f ∘ succ)) ⟨⟨i, Nat.lt_of_succ_lt_succ hi⟩, ⟨j, hj⟩⟩ with + | ⟨k, hk⟩ => ⟨f 0 + k, sum_succ .. ▸ Nat.add_lt_add_left hk ..⟩ + +/-- Decode a dependent sum of `Fin` types. -/ +@[pp_nodot] def decodeSigma (f : Fin n → Nat) (x : Fin (Fin.sum f)) : (i : Fin n) × Fin (f i) := + match n, f, x with + | 0, _, ⟨_, h⟩ => False.elim (by simp at h) + | n+1, f, ⟨x, hx⟩ => + if hx0 : x < f 0 then + ⟨0, ⟨x, hx0⟩⟩ + else + have hxf : x - f 0 < Fin.sum (f ∘ succ) := by + apply Nat.sub_lt_left_of_lt_add + · exact Nat.le_of_not_gt hx0 + · rw [← sum_succ]; exact hx + match decodeSigma ((f ∘ succ)) ⟨x - f 0, hxf⟩ with + | ⟨⟨i, hi⟩, y⟩ => ⟨⟨i+1, Nat.succ_lt_succ hi⟩, y⟩ + +@[simp] theorem encodeSigma_decodeSigma (f : Fin n → Nat) (x : Fin (Fin.sum f)) : + encodeSigma f (decodeSigma f x) = x := by + induction n with + | zero => absurd x.is_lt; simp + | succ n ih => + simp only [decodeSigma] + split + · simp [encodeSigma] + · next h1 => + have h2 : x - f 0 < Fin.sum (f ∘ succ) := by + apply Nat.sub_lt_left_of_lt_add + · exact Nat.le_of_not_gt h1 + · rw [← sum_succ]; exact x.is_lt + have : encodeSigma (f ∘ succ) (decodeSigma (f ∘ succ) ⟨x - f 0, h2⟩) = ↑x - f 0 := by + rw [ih] + ext; simp only [encodeSigma] + conv => rhs; rw [← Nat.add_sub_of_le (Nat.le_of_not_gt h1)] + congr + +@[simp] theorem decodeSigma_encodeSigma (f : Fin n → Nat) (x : (i : Fin n) × Fin (f i)) : + decodeSigma f (encodeSigma f x) = x := by + induction n with + | zero => nomatch x + | succ n ih => + simp only [decodeSigma] + match x with + | ⟨0, x⟩ => simp [encodeSigma] + | ⟨⟨i+1, hi⟩, ⟨x, hx⟩⟩ => + have : ¬ encodeSigma f ⟨⟨i+1, hi⟩, ⟨x, hx⟩⟩ < f 0 := by simp [encodeSigma] + rw [dif_neg this] + have : (encodeSigma f ⟨⟨i+1, hi⟩, ⟨x, hx⟩⟩).1 - f 0 = + (encodeSigma (f ∘ Fin.succ) ⟨⟨i, Nat.lt_of_succ_lt_succ hi⟩, ⟨x, hx⟩⟩).1 := by + simp [encodeSigma, Nat.add_sub_cancel_left] + have h := ih (f ∘ Fin.succ) ⟨⟨i, Nat.lt_of_succ_lt_succ hi⟩, ⟨x, hx⟩⟩ + simp [Sigma.ext_iff, Fin.ext_iff] at h ⊢ + constructor + · conv => rhs; rw [← h.1] + apply congrArg + apply congrArg Sigma.fst + congr + · apply HEq.trans _ h.2; congr + +/-- Encode a function between `Fin` types. -/ +def encodeFun : {m : Nat} → (Fin m → Fin n) → Fin (n ^ m) + | 0, _ => ⟨0, by simp⟩ + | m+1, f => Fin.mk (n * (encodeFun fun k => f k.succ).val + (f 0).val) <| calc + _ < n * (encodeFun fun k => f k.succ).val + n := Nat.add_lt_add_left (f 0).isLt _ + _ = n * ((encodeFun fun k => f k.succ).val + 1) := Nat.mul_succ .. + _ ≤ n * n ^ m := Nat.mul_le_mul_left n (Nat.succ_le_of_lt (encodeFun fun k => f k.succ).isLt) + _ = n ^ m * n := Nat.mul_comm .. + _ = n ^ (m+1) := Nat.pow_succ .. + +/-- Decode a function between `Fin` types. -/ +@[pp_nodot] def decodeFun : {m : Nat} → Fin (n ^ m) → Fin m → Fin n + | 0, _ => (nomatch .) + | m+1, ⟨k, hk⟩ => + have hn : n > 0 := by + apply Nat.zero_lt_of_ne_zero + intro h + rw [h, Nat.pow_succ, Nat.mul_zero] at hk + contradiction + fun + | 0 => ⟨k % n, Nat.mod_lt k hn⟩ + | ⟨i+1, hi⟩ => + have h : k / n < n ^ m := by rw [Nat.div_lt_iff_lt_mul hn]; exact hk + decodeFun ⟨k / n, h⟩ ⟨i, Nat.lt_of_succ_lt_succ hi⟩ + +@[simp] theorem encodeFun_decodeFun (x : Fin (n ^ m)) : encodeFun (decodeFun x) = x := by + induction m with simp only [encodeFun, decodeFun, Fin.succ] + | zero => simp; omega + | succ m ih => cases x; simp [ih]; rw [← Fin.zero_eta, Nat.div_add_mod] + +@[simp] theorem decodeFun_encodeFun (x : Fin m → Fin n) : decodeFun (encodeFun x) = x := by + funext i; induction m with simp only [encodeFun, decodeFun] + | zero => nomatch i + | succ m ih => + have hn : 0 < n := Nat.zero_lt_of_lt (x 0).is_lt + split + · ext; simp [Nat.mul_add_mod, Nat.mod_eq_of_lt] + · next i hi => + have : decodeFun (encodeFun fun k => x k.succ) ⟨i, Nat.lt_of_succ_lt_succ hi⟩ + = x ⟨i+1, hi⟩ := by rw [ih]; rfl + simp [← this, Nat.mul_add_div hn, Nat.div_eq_of_lt] + +/-- Encode a dependent product of `Fin` types. -/ +def encodePi (f : Fin n → Nat) (x : (i : Fin n) → Fin (f i)) : Fin (Fin.prod f) := + match n, f, x with + | 0, _, _ => ⟨0, by simp [Fin.prod]⟩ + | _+1, f, x => + match encodePi ((f ∘ succ)) (fun ⟨i, hi⟩ => x ⟨i+1, Nat.succ_lt_succ hi⟩) with + | ⟨k, hk⟩ => Fin.mk (f 0 * k + (x 0).val) $ calc + _ < f 0 * k + f 0 := Nat.add_lt_add_left (x 0).isLt .. + _ = f 0 * (k + 1) := Nat.mul_succ .. + _ ≤ f 0 * Fin.prod (f ∘ succ) := Nat.mul_le_mul_left _ (Nat.succ_le_of_lt hk) + _ = Fin.prod f := Eq.symm <| prod_succ .. + +/-- Decode a dependent product of `Fin` types. -/ +def decodePi (f : Fin n → Nat) (x : Fin (Fin.prod f)) : (i : Fin n) → Fin (f i) := + match n, f, x with + | 0, _, _ => (nomatch ·) + | n+1, f, ⟨x, hx⟩ => + have h : f 0 > 0 := by + apply Nat.zero_lt_of_ne_zero + intro h + rw [prod_succ, h, Nat.zero_mul] at hx + contradiction + have : x / f 0 < Fin.prod (f ∘ succ) := by + rw [Nat.div_lt_iff_lt_mul h, Nat.mul_comm, ← prod_succ] + exact hx + match decodePi ((f ∘ succ)) ⟨x / f 0, this⟩ with + | t => fun + | ⟨0, _⟩ => ⟨x % f 0, Nat.mod_lt x h⟩ + | ⟨i+1, hi⟩ => t ⟨i, Nat.lt_of_succ_lt_succ hi⟩ + +@[simp] theorem encodePi_decodePi (f : Fin n → Nat) (x : Fin (Fin.prod f)) : + encodePi f (decodePi f x) = x := by + induction n with + | zero => + match x with + | ⟨0, _⟩ => rfl + | ⟨_+1, h⟩ => simp at h + | succ n ih => + simp only [encodePi, decodePi, ih] + ext + conv => rhs; rw [← Nat.div_add_mod x (f 0)] + congr + +@[simp] theorem decodePi_encodePi (f : Fin n → Nat) (x : (i : Fin n) → Fin (f i)) : + decodePi f (encodePi f x) = x := by + induction n with + | zero => funext i; nomatch i + | succ n ih => + have h0 : 0 < f 0 := Nat.zero_lt_of_lt (x 0).is_lt + funext i + simp only [decodePi] + split + · simp [Nat.mul_add_mod, Nat.mod_eq_of_lt (x 0).is_lt] + · next i hi => + have h : decodePi (f ∘ succ) (encodePi (f ∘ succ) fun i => x i.succ) + ⟨i, Nat.lt_of_succ_lt_succ hi⟩ = x ⟨i+1, hi⟩ := by rw [ih]; rfl + conv => rhs; rw [← h] + congr; simp [Nat.mul_add_div h0, Nat.div_eq_of_lt (x 0).is_lt]; rfl + +/-- Encode a decidable subtype of a `Fin` type. -/ +def encodeSubtype (P : Fin n → Prop) [inst : DecidablePred P] (i : { i // P i }) : + Fin (Fin.count P) := + match n, P, inst, i with + | n+1, P, inst, ⟨0, hp⟩ => + have : Fin.count P > 0 := by simp [count_succ, hp] + ⟨0, this⟩ + | n+1, P, inst, ⟨⟨i+1, hi⟩, hp⟩ => + match encodeSubtype (fun i => P i.succ) ⟨⟨i, Nat.lt_of_succ_lt_succ hi⟩, hp⟩ with + | ⟨k, hk⟩ => + if h0 : P 0 then + have : Fin.count P = (Fin.count fun i => P i.succ) + 1 := by + simp_arith only [count_succ, Function.comp_def, if_pos h0] + this ▸ ⟨k+1, Nat.succ_lt_succ hk⟩ + else + have : Fin.count P = Fin.count fun i => P i.succ := by simp [count_succ, h0] + this ▸ ⟨k, hk⟩ + +/-- Decode a decidable subtype of a `Fin` type. -/ +def decodeSubtype (p : Fin n → Prop) [inst : DecidablePred p] (k : Fin (Fin.count p)) : + { i // p i } := + match n, p, inst, k with + | 0, _, _, ⟨_, h⟩ => False.elim (by simp at h) + | n+1, p, inst, ⟨k, hk⟩ => + if h0 : p 0 then + have : Fin.count p = (Fin.count fun i => p i.succ) + 1 := by simp [count_succ, h0] + match k with + | 0 => ⟨0, h0⟩ + | k + 1 => + match decodeSubtype (fun i => p i.succ) ⟨k, Nat.lt_of_add_lt_add_right (this ▸ hk)⟩ with + | ⟨⟨i, hi⟩, hp⟩ => ⟨⟨i+1, Nat.succ_lt_succ hi⟩, hp⟩ + else + have : Fin.count p = Fin.count fun i => p i.succ := by simp [count_succ, h0] + match decodeSubtype (fun i => p (succ i)) ⟨k, this ▸ hk⟩ with + | ⟨⟨i, hi⟩, hp⟩ => ⟨⟨i+1, Nat.succ_lt_succ hi⟩, hp⟩ + +@[simp] theorem encodeSubtype_decodeSubtype (P : Fin n → Prop) [DecidablePred P] + (x : Fin (Fin.count P)) : encodeSubtype P (decodeSubtype P x) = x := by + induction n with + | zero => absurd x.is_lt; simp + | succ n ih => + simp only [decodeSubtype] + split + · ext; split <;> simp [encodeSubtype, count_succ, *] + · ext; simp [encodeSubtype, count_succ, *] + +theorem encodeSubtype_zero_pos {P : Fin (n+1) → Prop} [DecidablePred P] (h₀ : P 0) : + encodeSubtype P ⟨0, h₀⟩ = ⟨0, by simp [count_succ, *]⟩ := by + ext; simp [encodeSubtype] + +theorem encodeSubtype_succ_pos {P : Fin (n+1) → Prop} [DecidablePred P] (h₀ : P 0) {i : Fin n} + (h : P i.succ) : encodeSubtype P ⟨i.succ, h⟩ = + (encodeSubtype (fun i => P i.succ) ⟨i, h⟩).succ.cast (by simp [count_succ, *]) := by + ext; simp [encodeSubtype, count_succ, *] + +theorem encodeSubtype_succ_neg {P : Fin (n+1) → Prop} [DecidablePred P] (h₀ : ¬ P 0) {i : Fin n} + (h : P i.succ) : encodeSubtype P ⟨i.succ, h⟩ = + (encodeSubtype (fun i => P i.succ) ⟨i, h⟩).cast (by simp [count_succ, *]) := by + ext; simp [encodeSubtype, count_succ, *] + +@[simp] theorem decodeSubtype_encodeSubtype (P : Fin n → Prop) [DecidablePred P] (x : { x // P x}) : + decodeSubtype P (encodeSubtype P x) = x := by + match x with + | ⟨i, h⟩ => + induction n with + | zero => absurd x.val.is_lt; simp + | succ n ih => + if h₀ : P 0 then + simp only [decodeSubtype, dif_pos h₀] + cases i using Fin.casesZeroSuccOn with + | zero => rw [encodeSubtype_zero_pos h₀] + | succ i => + rw [encodeSubtype_succ_pos h₀] + simp only [coe_cast, val_succ, Subtype.mk.injEq] + congr + rw [ih (fun i => P i.succ) ⟨i, h⟩] + else + simp only [decodeSubtype, dif_neg h₀] + cases i using Fin.casesZeroSuccOn with + | zero => contradiction + | succ i => + rw [encodeSubtype_succ_neg h₀] + simp only [coe_cast, Subtype.mk.injEq] + congr + rw [ih (fun i => P i.succ) ⟨i, h⟩] + +/-- Get representative for the equivalence class of `x`. -/ +abbrev getRepr (s : Setoid (Fin n)) [DecidableRel s.r] (x : Fin n) : Fin n := + match h : Fin.find? (s.r x) with + | some y => y + | none => False.elim <| by + have : Fin.find? (s.r x) |>.isSome := by + rw [find?_isSome_iff_exists] + exists x + exact Setoid.refl x + simp [h] at this + +@[simp] theorem equiv_getRepr (s : Setoid (Fin n)) [DecidableRel s.r] (x : Fin n) : + s.r x (getRepr s x) := by + apply find?_prop + simp only [getRepr] + split + · assumption + · next h => + have : Fin.find? (s.r x) |>.isSome := by + rw [find?_isSome_iff_exists] + exists x + exact Setoid.refl x + simp [h] at this + +@[simp] theorem getRepr_equiv (s : Setoid (Fin n)) [DecidableRel s.r] (x : Fin n) : + s.r (getRepr s x) x := Setoid.symm (equiv_getRepr ..) + +theorem getRepr_eq_getRepr_of_equiv (s : Setoid (Fin n)) [DecidableRel s.r] (h : s.r x y) : + getRepr s x = getRepr s y := by + have hfind : Fin.find? (s.r x) = Fin.find? (s.r y) := by + congr + funext z + apply propext + constructor + · exact Setoid.trans (Setoid.symm h) + · exact Setoid.trans h + simp only [getRepr] + split + · next hx => + rw [hfind] at hx + split + · next hy => + rwa [hx, Option.some_inj] at hy + · next hy => + rw [hx] at hy + contradiction + · next hx => + rw [hfind] at hx + split + · next hy => + rw [hx] at hy + contradiction + · rfl + +@[simp] theorem getRepr_getRepr (s : Setoid (Fin n)) [DecidableRel s.r] (x : Fin n) : + getRepr s (getRepr s x) = getRepr s x := by + apply getRepr_eq_getRepr_of_equiv + exact getRepr_equiv .. + +/-- Encode decidable quotient of a `Fin` type. -/ +def encodeQuotient (s : Setoid (Fin n)) [DecidableRel s.r] (x : Quotient s) : + Fin (Fin.count fun i => getRepr s i = i) := + encodeSubtype _ <| Quotient.liftOn x (fun i => ⟨getRepr s i, getRepr_getRepr s i⟩) <| by + intro _ _ h + simp only [Subtype.mk.injEq] + exact getRepr_eq_getRepr_of_equiv s h + +/-- Decode decidable quotient of a `Fin ` type. -/ +def decodeQuotient (s : Setoid (Fin n)) [DecidableRel s.r] + (i : Fin (Fin.count fun i => getRepr s i = i)) : Quotient s := + Quotient.mk s (decodeSubtype _ i) + +@[simp] theorem encodeQuotient_decodeQuotient (s : Setoid (Fin n)) [DecidableRel s.r] (x) : + encodeQuotient s (decodeQuotient s x) = x := by + simp only [decodeQuotient, encodeQuotient, Quotient.liftOn, Quotient.mk, Quot.liftOn] + conv => rhs; rw [← encodeSubtype_decodeSubtype _ x] + congr + exact (decodeSubtype _ x).property + +@[simp] theorem decodeQuotient_encodeQuotient (s : Setoid (Fin n)) [DecidableRel s.r] (x) : + decodeQuotient s (encodeQuotient s x) = x := by + induction x using Quotient.inductionOn with + | _ x => + simp only [decodeQuotient, encodeQuotient, Quotient.liftOn, Quotient.mk, Quot.liftOn] + apply Quot.sound + simp diff --git a/Batteries/Data/Fin/Enum.lean b/Batteries/Data/Fin/Enum.lean new file mode 100644 index 0000000000..93a73516da --- /dev/null +++ b/Batteries/Data/Fin/Enum.lean @@ -0,0 +1,188 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: François G. Dorais +-/ +import Batteries.Data.Fin.Coding + +namespace Fin + +/-- Class of types that are in bijection with a `Fin` type. -/ +protected class Enum (α : Type _) where + /-- Size of type. -/ + size : Nat + /-- Enumeration of the type elements. -/ + enum : Fin size → α + /-- Find the index of a type element. -/ + find : α → Fin size + /-- Inverse relation for `enum` and `find`. -/ + enum_find (x) : enum (find x) = x + /-- Inverse relation for `enum` and `find`. -/ + find_enum (i) : find (enum i) = i + +attribute [simp] Enum.enum_find Enum.find_enum + +namespace Enum + +instance : Fin.Enum Empty where + size := 0 + enum := nofun + find := nofun + enum_find := nofun + find_enum := nofun + +instance : Fin.Enum PUnit where + size := 1 + enum := decodePUnit + find := encodePUnit + enum_find := decodePUnit_encodePUnit + find_enum := encodePUnit_decodePUnit + +instance : Fin.Enum Bool where + size := 2 + enum := decodeBool + find := encodeBool + enum_find := decodeBool_encodeBool + find_enum := encodeBool_decodeBool + +instance [Fin.Enum α] : Fin.Enum (Option α) where + size := size α + 1 + enum i := decodeOption i |>.map enum + find x := encodeOption <| x.map find + enum_find := by simp [Function.comp_def] + find_enum := by simp [Function.comp_def] + +instance [Fin.Enum α] [Fin.Enum β] : Fin.Enum (α ⊕ β) where + size := size α + size β + enum i := + match decodeSum i with + | .inl i => .inl <| enum i + | .inr i => .inr <| enum i + find x := + match x with + | .inl x => encodeSum <| .inl (find x) + | .inr x => encodeSum <| .inr (find x) + enum_find _ := by + simp only; split + · next h => + split at h; + · simp at h; cases h; simp + · simp at h + · next h => + split at h; + · simp at h + · simp at h; cases h; simp + find_enum _ := by + simp only; split + · next h => + split at h + · next h' => cases h; simp [← h'] + · simp at h + · next h => + split at h + · simp at h + · next h' => cases h; simp [← h'] + +instance [Fin.Enum α] [Fin.Enum β] : Fin.Enum (α × β) where + size := size α * size β + enum i := (enum (decodeProd i).fst, enum (decodeProd i).snd) + find x := encodeProd (find x.fst, find x.snd) + find_enum := by simp [Prod.eta] + enum_find := by simp + +instance [Fin.Enum α] [Fin.Enum β] : Fin.Enum (β → α) where + size := size α ^ size β + enum i x := enum (decodeFun i (find x)) + find f := encodeFun fun x => find (f (enum x)) + find_enum := by simp + enum_find := by simp + +instance (β : α → Type _) [Fin.Enum α] [(x : α) → Fin.Enum (β x)] : Fin.Enum ((x : α) × β x) where + size := Fin.sum fun i => size (β (enum i)) + enum i := + match decodeSigma _ i with + | ⟨i, j⟩ => ⟨enum i, enum j⟩ + find | ⟨x, y⟩ => encodeSigma _ ⟨find x, find (_root_.cast (by simp) y)⟩ + find_enum i := by + simp only [] + conv => rhs; rw [← encodeSigma_decodeSigma _ i] + congr 1 + ext + · simp + · simp only [] + conv => rhs; rw [← find_enum (decodeSigma _ i).snd] + congr <;> simp + enum_find + | ⟨x, y⟩ => by + ext + simp only [cast, decodeSigma_encodeSigma, enum_find] + simp [] + rw [decodeSigma_encodeSigma] + simp + +instance (β : α → Type _) [Fin.Enum α] [(x : α) → Fin.Enum (β x)] : Fin.Enum ((x : α) → β x) where + size := Fin.prod fun i => size (β (enum i)) + enum i x := enum <| (decodePi _ i (find x)).cast (by rw [enum_find]) + find f := encodePi _ fun i => find (f (enum i)) + find_enum i := by + simp only [find_enum] + conv => rhs; rw [← encodePi_decodePi _ i] + congr + ext + simp only [cast] + rw [find_enum] + enum_find f := by + funext x + simp only [] + conv => rhs; rw [← enum_find (f x)] + congr 1 + ext + simp only [cast, decodePi_encodePi] + rw [enum_find] + done + +instance (P : α → Prop) [DecidablePred P] [Fin.Enum α] : Fin.Enum { x // P x} where + size := Fin.count fun i => P (enum i) + enum i := ⟨enum (decodeSubtype _ i).val, (decodeSubtype _ i).property⟩ + find x := encodeSubtype _ ⟨find x.val, (enum_find x.val).symm ▸ x.property⟩ + find_enum i := by simp [Subtype.eta] + enum_find := by simp + +private def enumSetoid (s : Setoid α) [DecidableRel s.r] [Fin.Enum α] : + Setoid (Fin (size α)) where + r i j := s.r (enum i) (enum j) + iseqv := { + refl := fun i => Setoid.refl (enum i) + symm := Setoid.symm + trans := Setoid.trans + } + +private instance (s : Setoid α) [DecidableRel s.r] [Fin.Enum α] : DecidableRel (enumSetoid s).r := + fun _ _ => inferInstanceAs (Decidable (s.r _ _)) + +instance (s : Setoid α) [DecidableRel s.r] [Fin.Enum α] : Fin.Enum (Quotient s) where + size := Fin.count fun i => getRepr (enumSetoid s) i = i + enum i := Quotient.liftOn (decodeQuotient (enumSetoid s) i) (fun i => Quotient.mk s (enum i)) + (fun _ _ h => Quotient.sound h) + find x := Quotient.liftOn x + (fun x => encodeQuotient (enumSetoid s) (Quotient.mk _ (find x))) <| by + intro _ _ h + simp only [] + congr 1 + apply Quotient.sound + simp only [HasEquiv.Equiv, enumSetoid, enum_find] + exact h + enum_find x := by + induction x using Quotient.inductionOn with + | _ x => + simp only [decodeQuotient, encodeQuotient, decodeSubtype_encodeSubtype, Quotient.liftOn, + Quotient.mk, Quot.liftOn] + apply Quot.sound + conv => rhs; rw [← enum_find x] + exact getRepr_equiv (enumSetoid s) .. + find_enum i := by + conv => rhs; rw [← encodeSubtype_decodeSubtype _ i] + simp only [Quotient.liftOn, Quot.liftOn, encodeQuotient, Quotient.mk, decodeQuotient, + find_enum] + congr + rw [(decodeSubtype _ i).property] diff --git a/Batteries/Data/Fin/Lemmas.lean b/Batteries/Data/Fin/Lemmas.lean index 5010e1310f..e2edb35f03 100644 --- a/Batteries/Data/Fin/Lemmas.lean +++ b/Batteries/Data/Fin/Lemmas.lean @@ -5,11 +5,16 @@ Authors: Mario Carneiro -/ import Batteries.Data.Fin.Basic import Batteries.Data.List.Lemmas +import Batteries.Tactic.Lint.Simp namespace Fin attribute [norm_cast] val_last +@[nolint simpNF, simp] +theorem val_ndrec (x : Fin n) (h : m = n) : (h ▸ x).val = x.val := by + cases h; rfl + /-! ### clamp -/ @[simp] theorem coe_clamp (n m : Nat) : (clamp n m : Nat) = min n m := rfl @@ -214,3 +219,104 @@ theorem foldr_rev (f : α → Fin n → α) (x) : induction n generalizing x with | zero => simp | succ n ih => rw [foldl_succ_last, foldr_succ, ← ih]; simp [rev_succ] + +/-! ### sum -/ + +@[simp] theorem sum_zero [OfNat α (nat_lit 0)] [Add α] (x : Fin 0 → α) : + Fin.sum x = 0 := by + simp [Fin.sum] + +theorem sum_succ [OfNat α (nat_lit 0)] [Add α] (x : Fin (n + 1) → α) : + Fin.sum x = x 0 + Fin.sum (x ∘ Fin.succ) := by + simp [Fin.sum, foldr_succ] + +/-! ### prod -/ + +@[simp] theorem prod_zero [OfNat α (nat_lit 1)] [Mul α] (x : Fin 0 → α) : + Fin.prod x = 1 := by + simp [Fin.prod] + +theorem prod_succ [OfNat α (nat_lit 1)] [Mul α] (x : Fin (n + 1) → α) : + Fin.prod x = x 0 * Fin.prod (x ∘ Fin.succ) := by + simp [Fin.prod, foldr_succ] + +/-! ### count -/ + +@[simp] theorem count_zero (P : Fin 0 → Prop) [DecidablePred P] : Fin.count P = 0 := by + simp [Fin.count] + +theorem count_succ (P : Fin (n + 1) → Prop) [DecidablePred P] : Fin.count P = + if P 0 then Fin.count (fun i => P i.succ) + 1 else Fin.count (fun i => P i.succ) := by + split <;> simp [Fin.count, Fin.sum_succ, Nat.one_add, Function.comp_def, *] + +theorem count_le (P : Fin n → Prop) [DecidablePred P] : Fin.count P ≤ n := by + induction n with + | zero => simp + | succ n ih => + rw [count_succ] + split + · simp [ih] + · apply Nat.le_trans _ (Nat.le_succ n); simp [ih] + +/-! ### find? -/ + +theorem find?_prop {P : Fin n → Prop} [DecidablePred P] (h : Fin.find? P = some x) : P x := by + induction n with + | zero => contradiction + | succ n ih => + simp [Fin.find?] at h + split at h + · cases h; assumption + · split at h + · cases h + apply ih (P := fun i => P i.succ) + assumption + · contradiction + +theorem exists_of_find?_isSome {P : Fin n → Prop} [DecidablePred P] (h : (Fin.find? P).isSome) : + ∃ x, P x := by + match heq : Fin.find? P with + | some x => exists x; exact find?_prop heq + | none => rw [heq] at h; contradiction + +theorem find?_isSome_of_exists {P : Fin n → Prop} [DecidablePred P] (h : ∃ x, P x) : + (Fin.find? P).isSome := by + induction n with + | zero => match h with | ⟨x, _⟩ => nomatch x + | succ n ih => + simp only [Fin.find?] + split + · rfl + · have h : ∃ (x : Fin n), P x.succ := by + match h with + | ⟨0, _⟩ => contradiction + | ⟨⟨i+1, hi⟩, _⟩ => exists ⟨i, Nat.lt_of_succ_lt_succ hi⟩ + have h : (Fin.find? fun x => P x.succ).isSome := ih h + split + · rfl + · next heq => + rw [heq] at h + contradiction + +theorem find?_isSome_iff_exists {P : Fin n → Prop} [DecidablePred P] : + (Fin.find? P).isSome ↔ ∃ x, P x := ⟨exists_of_find?_isSome, find?_isSome_of_exists⟩ + +/-! ### recZeroSuccOn -/ + +unseal Fin.recZeroSuccOn in +@[simp] theorem recZeroSuccOn_zero {motive : Fin (n+1) → Sort _} (zero : motive 0) + (succ : (x : Fin n) → motive x.castSucc → motive x.succ) : + Fin.recZeroSuccOn 0 zero succ = zero := rfl + +unseal Fin.recZeroSuccOn in +theorem recZeroSuccOn_succ {motive : Fin (n+1) → Sort _} (x : Fin n) (zero : motive 0) + (succ : (x : Fin n) → motive x.castSucc → motive x.succ) : + Fin.recZeroSuccOn x.succ zero succ = succ x (Fin.recZeroSuccOn x.castSucc zero succ) := rfl + +/-! ### casesZeroSuccOn -/ + +@[simp] theorem casesZeroSuccOn_zero {motive : Fin (n+1) → Sort _} (zero : motive 0) + (succ : (x : Fin n) → motive x.succ) : Fin.casesZeroSuccOn 0 zero succ = zero := rfl + +@[simp] theorem casesZeroSuccOn_succ {motive : Fin (n+1) → Sort _} (x : Fin n) (zero : motive 0) + (succ : (x : Fin n) → motive x.succ) : Fin.casesZeroSuccOn x.succ zero succ = succ x := rfl