Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for Fin in omega #3427

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/Init/Data/Fin/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,19 @@ def natAdd (n) (i : Fin m) : Fin (n + m) := ⟨n + i, Nat.add_lt_add_left i.2 _
@[inline] def pred {n : Nat} (i : Fin (n + 1)) (h : i ≠ 0) : Fin n :=
subNat 1 i <| Nat.pos_of_ne_zero <| mt (Fin.eq_of_val_eq (j := 0)) h

theorem val_inj {a b : Fin n} : a.1 = b.1 ↔ a = b := ⟨Fin.eq_of_val_eq, Fin.val_eq_of_eq⟩

theorem val_congr {n : Nat} {a b : Fin n} (h : a = b) : (a : Nat) = (b : Nat) :=
Fin.val_inj.mpr h

theorem val_le_of_le {n : Nat} {a b : Fin n} (h : a ≤ b) : (a : Nat) ≤ (b : Nat) := h

theorem val_le_of_ge {n : Nat} {a b : Fin n} (h : a ≥ b) : (b : Nat) ≤ (a : Nat) := h

theorem val_add_one_le_of_lt {n : Nat} {a b : Fin n} (h : a < b) : (a : Nat) + 1 ≤ (b : Nat) := h

theorem val_add_one_le_of_gt {n : Nat} {a b : Fin n} (h : a > b) : (b : Nat) + 1 ≤ (a : Nat) := h

end Fin

instance [GetElem cont Nat elem dom] : GetElem cont (Fin n) elem fun xs i => dom xs i where
Expand Down
2 changes: 0 additions & 2 deletions src/Init/Data/Fin/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ theorem pos_iff_nonempty {n : Nat} : 0 < n ↔ Nonempty (Fin n) :=

@[ext] theorem ext {a b : Fin n} (h : (a : Nat) = b) : a = b := eq_of_val_eq h

theorem val_inj {a b : Fin n} : a.1 = b.1 ↔ a = b := ⟨Fin.eq_of_val_eq, Fin.val_eq_of_eq⟩

theorem ext_iff {a b : Fin n} : a = b ↔ a.1 = b.1 := val_inj.symm

theorem val_ne_iff {a b : Fin n} : a.1 ≠ b.1 ↔ a ≠ b := not_congr val_inj
Expand Down
34 changes: 33 additions & 1 deletion src/Init/Omega/Int.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Init.Data.Int.DivModLemmas
import Init.Data.Nat.Lemmas

/-!
# Lemmas about `Nat` and `Int` needed internally by `omega`.
# Lemmas about `Nat`, `Int`, and `Fin` needed internally by `omega`.

These statements are useful for constructing proof expressions,
but unlikely to be widely useful, so are inside the `Std.Tactic.Omega` namespace.
Expand Down Expand Up @@ -163,6 +163,38 @@ theorem le_of_ge {x y : Nat} (h : x ≥ y) : y ≤ x := ge_iff_le.mp h

end Nat

namespace Fin

theorem ne_iff_lt_or_gt {i j : Fin n} : i ≠ j ↔ i < j ∨ i > j := by
cases i; cases j; simp only [ne_eq, Fin.mk.injEq, Nat.ne_iff_lt_or_gt, gt_iff_lt]; rfl

protected theorem lt_or_gt_of_ne {i j : Fin n} (h : i ≠ j) : i < j ∨ i > j := Fin.ne_iff_lt_or_gt.mp h

theorem not_le {i j : Fin n} : ¬ i ≤ j ↔ j < i := by
cases i; cases j; exact Nat.not_le

theorem not_lt {i j : Fin n} : ¬ i < j ↔ j ≤ i := by
cases i; cases j; exact Nat.not_lt

protected theorem lt_of_not_le {i j : Fin n} (h : ¬ i ≤ j) : j < i := Fin.not_le.mp h
protected theorem le_of_not_lt {i j : Fin n} (h : ¬ i < j) : j ≤ i := Fin.not_lt.mp h

theorem ofNat_val_add {x y : Fin n} :
(((x + y : Fin n)) : Int) = ((x : Int) + (y : Int)) % n := rfl

theorem ofNat_val_sub {x y : Fin n} :
(((x - y : Fin n)) : Int) = ((x : Int) + ((n - y : Nat) : Int)) % n := rfl

theorem ofNat_val_mul {x y : Fin n} :
(((x * y : Fin n)) : Int) = ((x : Int) * (y : Int)) % n := rfl

theorem ofNat_val_natCast {n x y : Nat} (h : y = x % (n + 1)):
@Nat.cast Int instNatCastInt (@Fin.val (n + 1) (OfNat.ofNat x)) = OfNat.ofNat y := by
rw [h]
rfl

end Fin

namespace Prod

theorem of_lex (w : Prod.Lex r s p q) : r p.fst q.fst ∨ p.fst = q.fst ∧ s p.snd q.snd :=
Expand Down
93 changes: 71 additions & 22 deletions src/Lean/Elab/Tactic/Omega/Frontend.lean
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,30 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
else
mkAtomLinearCombo e
| (``Nat.cast, #[.const ``Int [], i, n]) =>
handleNatCast e i n
| (``Prod.fst, #[α, β, p]) => match p with
| .app (.app (.app (.app (.const ``Prod.mk [u, v]) _) _) x) y =>
rewrite e (mkApp4 (.const ``Prod.fst_mk [u, v]) α x β y)
| _ => mkAtomLinearCombo e
| (``Prod.snd, #[α, β, p]) => match p with
| .app (.app (.app (.app (.const ``Prod.mk [u, v]) _) _) x) y =>
rewrite e (mkApp4 (.const ``Prod.snd_mk [u, v]) α x β y)
| _ => mkAtomLinearCombo e
| _ => mkAtomLinearCombo e
where
/--
Apply a rewrite rule to an expression, and interpret the result as a `LinearCombo`.
(We're not rewriting any subexpressions here, just the top level, for efficiency.)
-/
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
trace[omega] "rewriting {lhs} via {rw} : {← inferType rw}"
match (← inferType rw).eq? with
| some (_, _lhs', rhs) =>
let (lc, prf, facts) ← asLinearCombo rhs
let prf' : OmegaM Expr := do mkEqTrans rw (← prf)
pure (lc, prf', facts)
| none => panic! "Invalid rewrite rule in 'asLinearCombo'"
handleNatCast (e i n : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
match n with
| .fvar h =>
if let some v ← h.getValue? then
Expand Down Expand Up @@ -259,29 +283,38 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
rewrite e (mkApp (.const ``Int.ofNat_natAbs []) n)
else
mkAtomLinearCombo e
| (``Fin.val, #[n, x]) =>
handleFinVal e i n x
| _ => mkAtomLinearCombo e
| (``Prod.fst, #[α, β, p]) => match p with
| .app (.app (.app (.app (.const ``Prod.mk [u, v]) _) _) x) y =>
rewrite e (mkApp4 (.const ``Prod.fst_mk [u, v]) α x β y)
| _ => mkAtomLinearCombo e
| (``Prod.snd, #[α, β, p]) => match p with
| .app (.app (.app (.app (.const ``Prod.mk [u, v]) _) _) x) y =>
rewrite e (mkApp4 (.const ``Prod.snd_mk [u, v]) α x β y)
| _ => mkAtomLinearCombo e
| _ => mkAtomLinearCombo e
where
/--
Apply a rewrite rule to an expression, and interpret the result as a `LinearCombo`.
(We're not rewriting any subexpressions here, just the top level, for efficiency.)
-/
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
trace[omega] "rewriting {lhs} via {rw} : {← inferType rw}"
match (← inferType rw).eq? with
| some (_, _lhs', rhs) =>
let (lc, prf, facts) ← asLinearCombo rhs
let prf' : OmegaM Expr := do mkEqTrans rw (← prf)
pure (lc, prf', facts)
| none => panic! "Invalid rewrite rule in 'asLinearCombo'"
handleFinVal (e i n x : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
match x with
| .fvar h =>
if let some v ← h.getValue? then
rewrite e (← mkEqReflWithExpectedType e
(mkApp3 (.const ``Nat.cast [0]) (.const ``Int []) i (mkApp2 (.const ``Fin.val []) n v)))
else
mkAtomLinearCombo e
| _ => match x.getAppFnArgs, n.nat? with
| (``HAdd.hAdd, #[_, _, _, _, a, b]), _ =>
rewrite e (mkApp3 (.const ``Fin.ofNat_val_add []) n a b)
| (``HMul.hMul, #[_, _, _, _, a, b]), _ =>
rewrite e (mkApp3 (.const ``Fin.ofNat_val_mul []) n a b)
| (``HSub.hSub, #[_, _, _, _, a, b]), some _ =>
-- Only do this rewrite if `n` is a numeral.
rewrite e (mkApp3 (.const ``Fin.ofNat_val_sub []) n a b)
| (``OfNat.ofNat, #[_, y, _]), some m =>
-- Only do this rewrite if `n` is a nonzero numeral.
if m = 0 then
mkAtomLinearCombo e
else
match y with
| .lit (.natVal y) =>
rewrite e (mkApp4 (.const ``Fin.ofNat_val_natCast [])
(toExpr (m - 1)) (toExpr y) (.lit (.natVal (y % m))) (← mkEqRefl (toExpr (y % m))))
| _ =>
-- This shouldn't happen, we obtained `y` from `OfNat.ofNat`
mkAtomLinearCombo e
| _, _ => mkAtomLinearCombo e

end
namespace MetaProblem
Expand Down Expand Up @@ -344,11 +377,17 @@ def pushNot (h P : Expr) : MetaM (Option Expr) := do
return some (mkApp3 (.const ``Nat.le_of_not_lt []) x y h)
| (``LE.le, #[.const ``Nat [], _, x, y]) =>
return some (mkApp3 (.const ``Nat.lt_of_not_le []) x y h)
| (``LT.lt, #[.app (.const ``Fin []) n, _, x, y]) =>
return some (mkApp4 (.const ``Fin.le_of_not_lt []) n x y h)
| (``LE.le, #[.app (.const ``Fin []) n, _, x, y]) =>
return some (mkApp4 (.const ``Fin.lt_of_not_le []) n x y h)
| (``Eq, #[.const ``Nat [], x, y]) =>
return some (mkApp3 (.const ``Nat.lt_or_gt_of_ne []) x y h)
| (``Eq, #[.const ``Int [], x, y]) =>
return some (mkApp3 (.const ``Int.lt_or_gt_of_ne []) x y h)
| (``Prod.Lex, _) => return some (← mkAppM ``Prod.of_not_lex #[h])
| (``Eq, #[.app (.const ``Fin []) n, x, y]) =>
return some (mkApp4 (.const ``Fin.lt_or_gt_of_ne []) n x y h)
| (``Dvd.dvd, #[.const ``Nat [], _, k, x]) =>
return some (mkApp3 (.const ``Nat.emod_pos_of_not_dvd []) k x h)
| (``Dvd.dvd, #[.const ``Int [], _, k, x]) =>
Expand Down Expand Up @@ -423,6 +462,16 @@ partial def addFact (p : MetaProblem) (h : Expr) : OmegaM (MetaProblem × Nat) :
p.addFact (mkApp3 (.const ``Nat.mod_eq_zero_of_dvd []) k x h)
| (``Dvd.dvd, #[.const ``Int [], _, k, x]) =>
p.addFact (mkApp3 (.const ``Int.emod_eq_zero_of_dvd []) k x h)
| (``Eq, #[.app (.const ``Fin []) n, x, y]) =>
p.addFact (mkApp4 (.const ``Fin.val_congr []) n x y h)
| (``LE.le, #[.app (.const ``Fin []) n, _, x, y]) =>
p.addFact (mkApp4 (.const ``Fin.val_le_of_le []) n x y h)
| (``LT.lt, #[.app (.const ``Fin []) n, _, x, y]) =>
p.addFact (mkApp4 (.const ``Fin.val_add_one_le_of_lt []) n x y h)
| (``GE.ge, #[.app (.const ``Fin []) n, _, x, y]) =>
p.addFact (mkApp4 (.const ``Fin.val_le_of_ge []) n x y h)
| (``GT.gt, #[.app (.const ``Fin []) n, _, x, y]) =>
p.addFact (mkApp4 (.const ``Fin.val_add_one_le_of_gt []) n x y h)
| (``And, #[t₁, t₂]) => do
let (p₁, n₁) ← p.addFact (mkApp3 (.const ``And.left []) t₁ t₂ h)
let (p₂, n₂) ← p₁.addFact (mkApp3 (.const ``And.right []) t₁ t₂ h)
Expand Down
32 changes: 32 additions & 0 deletions tests/lean/run/omega.lean
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,40 @@ example (i j : Nat) (p : i ≥ j) : True := by
have _ : i ≥ k := by omega
trivial

/-! ### Fin -/


-- Test `<`
example (n : Nat) (i j : Fin n) (h : i < j) : (i : Nat) < n - 1 := by omega

-- Test `≤`
example (n : Nat) (i j : Fin n) (h : i < j) : (i : Nat) ≤ n - 2 := by omega

-- Test `>`
example (n : Nat) (i j : Fin n) (h : i < j) : n - 1 > i := by omega

-- Test `≥`
example (n : Nat) (i : Fin n) : n - 1 ≥ i := by omega

-- Test `=`
example (n : Nat) (i j : Fin n) (h : i = j) : (i : Int) = j := by omega

example (i j : Fin n) (w : i < j) : i < j := by omega

example (n m i : Nat) (j : Fin (n - m)) (h : i < j) (h2 : m ≥ 4) :
(i : Int) < n - 5 := by omega

example (x y : Nat) (_ : 2 ≤ x) (_ : x ≤ 3) (_ : 2 ≤ y) (_ : y ≤ 3) :
4 ≤ (x + y) % 8 ∧ (x + y) % 8 ≤ 6 := by
omega

example (x y : Fin 8) (_ : 2 ≤ x) (_ : x ≤ 3) (_ : 2 ≤ y) (_ : y ≤ 3) : 4 ≤ x + y ∧ x + y ≤ 6 := by
omega

example (i : Fin 7) : (i : Nat) < 8 := by omega

/-! ### mod 2^n -/

example (x y z i : Nat) (hz : z ≤ 1) : x % 2 ^ i + y % 2 ^ i + z < 2 * 2^ i := by omega

/-! ### Ground terms -/
Expand Down
Loading