Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: offset constraints support for the grind tactic #6603

Merged
merged 16 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 41 additions & 156 deletions src/Init/Grind/Offset.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,159 +7,44 @@ prelude
import Init.Core
import Init.Omega

namespace Lean.Grind.Offset

abbrev Var := Nat
abbrev Context := Lean.RArray Nat

def fixedVar := 100000000 -- Any big number should work here

def Var.denote (ctx : Context) (v : Var) : Nat :=
bif v == fixedVar then 1 else ctx.get v

structure Cnstr where
x : Var
y : Var
k : Nat := 0
l : Bool := true
deriving Repr, DecidableEq, Inhabited

def Cnstr.denote (c : Cnstr) (ctx : Context) : Prop :=
if c.l then
c.x.denote ctx + c.k ≤ c.y.denote ctx
else
c.x.denote ctx ≤ c.y.denote ctx + c.k

def trivialCnstr : Cnstr := { x := 0, y := 0, k := 0, l := true }

@[simp] theorem denote_trivial (ctx : Context) : trivialCnstr.denote ctx := by
simp [Cnstr.denote, trivialCnstr]

def Cnstr.trans (c₁ c₂ : Cnstr) : Cnstr :=
if c₁.y = c₂.x then
let { x, k := k₁, l := l₁, .. } := c₁
let { y, k := k₂, l := l₂, .. } := c₂
match l₁, l₂ with
| false, false =>
{ x, y, k := k₁ + k₂, l := false }
| false, true =>
if k₁ < k₂ then
{ x, y, k := k₂ - k₁, l := true }
else
{ x, y, k := k₁ - k₂, l := false }
| true, false =>
if k₁ < k₂ then
{ x, y, k := k₂ - k₁, l := false }
else
{ x, y, k := k₁ - k₂, l := true }
| true, true =>
{ x, y, k := k₁ + k₂, l := true }
else
trivialCnstr

@[simp] theorem Cnstr.denote_trans_easy (ctx : Context) (c₁ c₂ : Cnstr) (h : c₁.y ≠ c₂.x) : (c₁.trans c₂).denote ctx := by
simp [*, Cnstr.trans]

@[simp] theorem Cnstr.denote_trans (ctx : Context) (c₁ c₂ : Cnstr) : c₁.denote ctx → c₂.denote ctx → (c₁.trans c₂).denote ctx := by
by_cases c₁.y = c₂.x
case neg => simp [*]
simp [trans, *]
let { x, k := k₁, l := l₁, .. } := c₁
let { y, k := k₂, l := l₂, .. } := c₂
simp_all; split
· simp [denote]; omega
· split <;> simp [denote] <;> omega
· split <;> simp [denote] <;> omega
· simp [denote]; omega

def Cnstr.isTrivial (c : Cnstr) : Bool := c.x == c.y && c.k == 0

theorem Cnstr.of_isTrivial (ctx : Context) (c : Cnstr) : c.isTrivial = true → c.denote ctx := by
cases c; simp [isTrivial]; intros; simp [*, denote]

def Cnstr.isFalse (c : Cnstr) : Bool := c.x == c.y && c.k != 0 && c.l == true

theorem Cnstr.of_isFalse (ctx : Context) {c : Cnstr} : c.isFalse = true → ¬c.denote ctx := by
cases c; simp [isFalse]; intros; simp [*, denote]; omega

def Cnstrs := List Cnstr

def Cnstrs.denoteAnd' (ctx : Context) (c₁ : Cnstr) (c₂ : Cnstrs) : Prop :=
match c₂ with
| [] => c₁.denote ctx
| c::cs => c₁.denote ctx ∧ Cnstrs.denoteAnd' ctx c cs

theorem Cnstrs.denote'_trans (ctx : Context) (c₁ c : Cnstr) (cs : Cnstrs) : c₁.denote ctx → denoteAnd' ctx c cs → denoteAnd' ctx (c₁.trans c) cs := by
induction cs
next => simp [denoteAnd', *]; apply Cnstr.denote_trans
next c cs ih => simp [denoteAnd']; intros; simp [*]

def Cnstrs.trans' (c₁ : Cnstr) (c₂ : Cnstrs) : Cnstr :=
match c₂ with
| [] => c₁
| c::c₂ => trans' (c₁.trans c) c₂

@[simp] theorem Cnstrs.denote'_trans' (ctx : Context) (c₁ : Cnstr) (c₂ : Cnstrs) : denoteAnd' ctx c₁ c₂ → (trans' c₁ c₂).denote ctx := by
induction c₂ generalizing c₁
next => intros; simp_all [trans', denoteAnd']
next c cs ih => simp [denoteAnd']; intros; simp [trans']; apply ih; apply denote'_trans <;> assumption

def Cnstrs.denoteAnd (ctx : Context) (c : Cnstrs) : Prop :=
match c with
| [] => True
| c::cs => denoteAnd' ctx c cs

def Cnstrs.trans (c : Cnstrs) : Cnstr :=
match c with
| [] => trivialCnstr
| c::cs => trans' c cs

theorem Cnstrs.of_denoteAnd_trans {ctx : Context} {c : Cnstrs} : c.denoteAnd ctx → c.trans.denote ctx := by
cases c <;> simp [*, trans, denoteAnd] <;> intros <;> simp [*]

def Cnstrs.isFalse (c : Cnstrs) : Bool :=
c.trans.isFalse

theorem Cnstrs.unsat' (ctx : Context) (c : Cnstrs) : c.isFalse = true → ¬ c.denoteAnd ctx := by
simp [isFalse]; intro h₁ h₂
have := of_denoteAnd_trans h₂
have := Cnstr.of_isFalse ctx h₁
contradiction

/-- `denote ctx [c_1, ..., c_n] C` is `c_1.denote ctx → ... → c_n.denote ctx → C` -/
def Cnstrs.denote (ctx : Context) (cs : Cnstrs) (C : Prop) : Prop :=
match cs with
| [] => C
| c::cs => c.denote ctx → denote ctx cs C

theorem Cnstrs.not_denoteAnd'_eq (ctx : Context) (c : Cnstr) (cs : Cnstrs) (C : Prop) : (denoteAnd' ctx c cs → C) = denote ctx (c::cs) C := by
simp [denote]
induction cs generalizing c
next => simp [denoteAnd', denote]
next c' cs ih =>
simp [denoteAnd', denote, *]

theorem Cnstrs.not_denoteAnd_eq (ctx : Context) (cs : Cnstrs) (C : Prop) : (denoteAnd ctx cs → C) = denote ctx cs C := by
cases cs
next => simp [denoteAnd, denote]
next c cs => apply not_denoteAnd'_eq

def Cnstr.isImpliedBy (cs : Cnstrs) (c : Cnstr) : Bool :=
cs.trans == c

/-! Main theorems used by `grind`. -/

/-- Auxiliary theorem used by `grind` to prove that a system of offset inequalities is unsatisfiable. -/
theorem Cnstrs.unsat (ctx : Context) (cs : Cnstrs) : cs.isFalse = true → cs.denote ctx False := by
intro h
rw [← not_denoteAnd_eq]
apply unsat'
assumption

/-- Auxiliary theorem used by `grind` to prove an implied offset inequality. -/
theorem Cnstrs.imp (ctx : Context) (cs : Cnstrs) (c : Cnstr) (h : c.isImpliedBy cs = true) : cs.denote ctx (c.denote ctx) := by
rw [← eq_of_beq h]
rw [← not_denoteAnd_eq]
apply of_denoteAnd_trans

end Lean.Grind.Offset
namespace Lean.Grind
def isLt (x y : Nat) : Bool := x < y

theorem Nat.le_ro (u w v k : Nat) : u ≤ w → w ≤ v + k → u ≤ v + k := by
omega
theorem Nat.le_lo (u w v k : Nat) : u ≤ w → w + k ≤ v → u + k ≤ v := by
omega
theorem Nat.lo_le (u w v k : Nat) : u + k ≤ w → w ≤ v → u + k ≤ v := by
omega
theorem Nat.lo_lo (u w v k₁ k₂ : Nat) : u + k₁ ≤ w → w + k₂ ≤ v → u + (k₁ + k₂) ≤ v := by
omega
theorem Nat.lo_ro_1 (u w v k₁ k₂ : Nat) : isLt k₂ k₁ = true → u + k₁ ≤ w → w ≤ v + k₂ → u + (k₁ - k₂) ≤ v := by
simp [isLt]; omega
theorem Nat.lo_ro_2 (u w v k₁ k₂ : Nat) : u + k₁ ≤ w → w ≤ v + k₂ → u ≤ v + (k₂ - k₁) := by
omega
theorem Nat.ro_le (u w v k : Nat) : u ≤ w + k → w ≤ v → u ≤ v + k := by
omega
theorem Nat.ro_lo_1 (u w v k₁ k₂ : Nat) : u ≤ w + k₁ → w + k₂ ≤ v → u ≤ v + (k₁ - k₂) := by
omega
theorem Nat.ro_lo_2 (u w v k₁ k₂ : Nat) : isLt k₁ k₂ = true → u ≤ w + k₁ → w + k₂ ≤ v → u + (k₂ - k₁) ≤ v := by
simp [isLt]; omega
theorem Nat.ro_ro (u w v k₁ k₂ : Nat) : u ≤ w + k₁ → w ≤ v + k₂ → u ≤ v + (k₁ + k₂) := by
omega

theorem Nat.of_le_eq_false (u v : Nat) : ((u ≤ v) = False) → v + 1 ≤ u := by
simp; omega
theorem Nat.of_lo_eq_false_1 (u v : Nat) : ((u + 1 ≤ v) = False) → v ≤ u := by
simp; omega
theorem Nat.of_lo_eq_false (u v k : Nat) : ((u + k ≤ v) = False) → v ≤ u + (k-1) := by
simp; omega
theorem Nat.of_ro_eq_false (u v k : Nat) : ((u ≤ v + k) = False) → v + (k+1) ≤ u := by
simp; omega

theorem Nat.unsat_le_lo (u v k : Nat) : isLt 0 k = true → u ≤ v → v + k ≤ u → False := by
simp [isLt]; omega
theorem Nat.unsat_lo_lo (u v k₁ k₂ : Nat) : isLt 0 (k₁+k₂) = true → u + k₁ ≤ v → v + k₂ ≤ u → False := by
simp [isLt]; omega
theorem Nat.unsat_lo_ro (u v k₁ k₂ : Nat) : isLt k₂ k₁ = true → u + k₁ ≤ v → v ≤ u + k₂ → False := by
simp [isLt]; omega

end Lean.Grind
8 changes: 7 additions & 1 deletion src/Lean/Data/AssocList.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ abbrev empty : AssocList α β :=

instance : EmptyCollection (AssocList α β) := ⟨empty⟩

abbrev insert (m : AssocList α β) (k : α) (v : β) : AssocList α β :=
abbrev insertNew (m : AssocList α β) (k : α) (v : β) : AssocList α β :=
m.cons k v

def isEmpty : AssocList α β → Bool
Expand Down Expand Up @@ -77,6 +77,12 @@ def replace [BEq α] (a : α) (b : β) : AssocList α β → AssocList α β
| true => cons a b es
| false => cons k v (replace a b es)

def insert [BEq α] (m : AssocList α β) (k : α) (v : β) : AssocList α β :=
if m.contains k then
m.replace k v
else
m.insertNew k v

def erase [BEq α] (a : α) : AssocList α β → AssocList α β
| nil => nil
| cons k v es => match k == a with
Expand Down
12 changes: 8 additions & 4 deletions src/Lean/Meta/AppBuilder.lean
Original file line number Diff line number Diff line change
Expand Up @@ -569,12 +569,16 @@ def mkLetBodyCongr (a h : Expr) : MetaM Expr :=
mkAppM ``let_body_congr #[a, h]

/-- Return `of_eq_true h` -/
def mkOfEqTrue (h : Expr) : MetaM Expr :=
mkAppM ``of_eq_true #[h]
def mkOfEqTrue (h : Expr) : MetaM Expr := do
match_expr h with
| eq_true _ h => return h
| _ => mkAppM ``of_eq_true #[h]

/-- Return `eq_true h` -/
def mkEqTrue (h : Expr) : MetaM Expr :=
mkAppM ``eq_true #[h]
def mkEqTrue (h : Expr) : MetaM Expr := do
match_expr h with
| of_eq_true _ h => return h
| _ => return mkApp2 (mkConst ``eq_true) (← inferType h) h

/--
Return `eq_false h`
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Tactic/FVarSubst.lean
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def insert (s : FVarSubst) (fvarId : FVarId) (v : Expr) : FVarSubst :=
if s.contains fvarId then s
else
let map := s.map.mapVal fun e => e.replaceFVarId fvarId v;
{ map := map.insert fvarId v }
{ map := map.insertNew fvarId v }

def erase (s : FVarSubst) (fvarId : FVarId) : FVarSubst :=
{ map := s.map.erase fvarId }
Expand Down
7 changes: 7 additions & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import Lean.Meta.Tactic.Grind.EMatchTheorem
import Lean.Meta.Tactic.Grind.EMatch
import Lean.Meta.Tactic.Grind.Main
import Lean.Meta.Tactic.Grind.CasesMatch
import Lean.Meta.Tactic.Grind.Arith

namespace Lean

Expand All @@ -42,6 +43,10 @@ builtin_initialize registerTraceClass `grind.simp
builtin_initialize registerTraceClass `grind.split
builtin_initialize registerTraceClass `grind.split.candidate
builtin_initialize registerTraceClass `grind.split.resolved
builtin_initialize registerTraceClass `grind.offset
builtin_initialize registerTraceClass `grind.offset.dist
builtin_initialize registerTraceClass `grind.offset.internalize
builtin_initialize registerTraceClass `grind.offset.internalize.term (inherited := true)

/-! Trace options for `grind` developers -/
builtin_initialize registerTraceClass `grind.debug
Expand All @@ -54,4 +59,6 @@ builtin_initialize registerTraceClass `grind.debug.final
builtin_initialize registerTraceClass `grind.debug.forallPropagator
builtin_initialize registerTraceClass `grind.debug.split
builtin_initialize registerTraceClass `grind.debug.canon
builtin_initialize registerTraceClass `grind.debug.offset
builtin_initialize registerTraceClass `grind.debug.offset.proof
end Lean
10 changes: 10 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Arith.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Arith.Util
import Lean.Meta.Tactic.Grind.Arith.Types
import Lean.Meta.Tactic.Grind.Arith.Offset
import Lean.Meta.Tactic.Grind.Arith.Main
33 changes: 33 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Arith.Offset

namespace Lean.Meta.Grind.Arith

namespace Offset

def internalizeTerm (_e : Expr) (_a : Expr) (_k : Nat) : GoalM Unit := do
-- TODO
return ()

def internalizeCnstr (e : Expr) : GoalM Unit := do
let some c := isNatOffsetCnstr? e | return ()
let c := { c with
a := (← mkNode c.a)
b := (← mkNode c.b)
}
trace[grind.offset.internalize] "{e} ↦ {c}"
modify' fun s => { s with
cnstrs := s.cnstrs.insert { expr := e } c
}

end Offset

def internalize (e : Expr) : GoalM Unit := do
Offset.internalizeCnstr e

end Lean.Meta.Grind.Arith
14 changes: 14 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Arith/Inv.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Arith.Offset

namespace Lean.Meta.Grind.Arith

def checkInvariants : GoalM Unit :=
Offset.checkInvariants

end Lean.Meta.Grind.Arith
34 changes: 34 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Arith/Main.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.PropagatorAttr
import Lean.Meta.Tactic.Grind.Arith.Offset

namespace Lean.Meta.Grind.Arith

namespace Offset
def isCnstr? (e : Expr) : GoalM (Option (Cnstr NodeId)) :=
return (← get).arith.offset.cnstrs.find? { expr := e }

def assertTrue (c : Cnstr NodeId) (p : Expr) : GoalM Unit := do
addEdge c.a c.b c.k (← mkOfEqTrue p)

def assertFalse (c : Cnstr NodeId) (p : Expr) : GoalM Unit := do
let p := mkOfNegEqFalse (← get').nodes c p
let c := c.neg
addEdge c.a c.b c.k p

end Offset

builtin_grind_propagator propagateLE ↓LE.le := fun e => do
if (← isEqTrue e) then
if let some c ← Offset.isCnstr? e then
Offset.assertTrue c (← mkEqTrueProof e)
if (← isEqFalse e) then
if let some c ← Offset.isCnstr? e then
Offset.assertFalse c (← mkEqFalseProof e)

end Lean.Meta.Grind.Arith
Loading
Loading