Skip to content

Commit

Permalink
feat: equality proof generation for grind (#6452)
Browse files Browse the repository at this point in the history
This PR adds support for generating (small) proofs for any two
expressions that belong to the same equivalence class in the `grind`
tactic state.
  • Loading branch information
leodemoura authored Dec 26, 2024
1 parent bdcb791 commit 8a1e50f
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 29 deletions.
8 changes: 8 additions & 0 deletions src/Lean/Meta/AppBuilder.lean
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ def mkEqOfHEq (h : Expr) : MetaM Expr := do
| _ =>
throwAppBuilderException ``eq_of_heq m!"heterogeneous equality proof expected{indentExpr h}"

/-- Given `h : Eq a b`, returns a proof of `HEq a b`. -/
def mkHEqOfEq (h : Expr) : MetaM Expr := do
let hType ← infer h
let some (α, a, b) := hType.eq?
| throwAppBuilderException ``heq_of_eq m!"equality proof expected{indentExpr h}"
let u ← getLevel α
return mkApp4 (mkConst ``heq_of_eq [u]) α a b h

/--
If `e` is `@Eq.refl α a`, return `a`.
-/
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ builtin_initialize registerTraceClass `grind.issues
builtin_initialize registerTraceClass `grind.add
builtin_initialize registerTraceClass `grind.pre
builtin_initialize registerTraceClass `grind.debug
builtin_initialize registerTraceClass `grind.debug.proofs
builtin_initialize registerTraceClass `grind.simp
builtin_initialize registerTraceClass `grind.congr

Expand Down
3 changes: 0 additions & 3 deletions src/Lean/Meta/Tactic/Grind/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ import Lean.Meta.Tactic.Grind.PP

namespace Lean.Meta.Grind

/-- We use this auxiliary constant to mark delayed congruence proofs. -/
private def congrPlaceholderProof := mkConst (Name.mkSimple "[congruence]")

/-- Adds `e` to congruence table. -/
private def addCongrTable (e : Expr) : GoalM Unit := do
if let some { e := e' } := (← get).congrTable.find? { e } then
Expand Down
16 changes: 15 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Inv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Proof

namespace Lean.Meta.Grind

Expand Down Expand Up @@ -54,12 +55,23 @@ private def checkPtrEqImpliesStructEq : GoalM Unit := do
for h₁ : i in [: nodes.size] do
let n₁ := nodes[i]
for h₂ : j in [i+1 : nodes.size] do
let n₂ := nodes[i]
let n₂ := nodes[j]
-- We don't have multiple nodes for the same expression
assert! !isSameExpr n₁.self n₂.self
-- and the two expressions must not be structurally equal
assert! !Expr.equal n₁.self n₂.self

private def checkProofs : GoalM Unit := do
let eqcs ← getEqcs
for eqc in eqcs do
for a in eqc do
for b in eqc do
unless isSameExpr a b do
let p ← mkEqProof a b
trace[grind.debug.proofs] "{a} = {b}"
check p
trace[grind.debug.proofs] "checked: {← inferType p}"

/--
Checks basic invariants if `grind.debug` is enabled.
-/
Expand All @@ -71,5 +83,7 @@ def checkInvariants (expensive := false) : GoalM Unit := do
checkEqc node
if expensive then
checkPtrEqImpliesStructEq
if expensive && grind.debug.proofs.get (← getOptions) then
checkProofs

end Lean.Meta.Grind
21 changes: 0 additions & 21 deletions src/Lean/Meta/Tactic/Grind/PP.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,6 @@ def ppENodeRef (e : Expr) : GoalM Format := do
let some n ← getENode? e | return "_"
return f!"#{n.idx}"

/-- Returns expressions in the given expression equivalence class. -/
partial def getEqc (e : Expr) : GoalM (List Expr) :=
go e e []
where
go (first : Expr) (e : Expr) (acc : List Expr) : GoalM (List Expr) := do
let next ← getNext e
let acc := e :: acc
if isSameExpr first next then
return acc
else
go first next acc

/-- Returns all equivalence classes in the current goal. -/
partial def getEqcs : GoalM (List (List Expr)) := do
let mut r := []
let nodes ← getENodes
for node in nodes do
if isSameExpr node.root node.self then
r := (← getEqc node.self) :: r
return r

/-- Helper function for pretty printing the state for debugging purposes. -/
def ppENodeDeclValue (e : Expr) : GoalM Format := do
if e.isApp && !(← isLitValue e) then
Expand Down
120 changes: 116 additions & 4 deletions src/Lean/Meta/Tactic/Grind/Proof.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,128 @@ import Lean.Meta.Tactic.Grind.Types

namespace Lean.Meta.Grind

-- TODO: delete after done
private def mkTodo (a b : Expr) (heq : Bool) : MetaM Expr := do
if heq then
mkSorry (← mkHEq a b) (synthetic := false)
else
mkSorry (← mkEq a b) (synthetic := false)

private def isProtoProof (h : Expr) : Bool :=
isSameExpr h congrPlaceholderProof

private def isEqProof (h : Expr) : MetaM Bool := do
return (← whnfD (← inferType h)).isAppOf ``Eq

private def flipProof (h : Expr) (flipped : Bool) (heq : Bool) : MetaM Expr := do
let mut h' := h
if (← pure heq <&&> isEqProof h') then
h' ← mkHEqOfEq h'
if flipped then
if heq then mkHEqSymm h' else mkEqSymm h'
else
return h'

private def mkRefl (a : Expr) (heq : Bool) : MetaM Expr :=
if heq then mkHEqRefl a else mkEqRefl a

private def mkTrans (h₁ h₂ : Expr) (heq : Bool) : MetaM Expr :=
if heq then
mkHEqTrans h₁ h₂
else
mkEqTrans h₁ h₂

private def mkTrans' (h₁ : Option Expr) (h₂ : Expr) (heq : Bool) : MetaM Expr := do
let some h₁ := h₁ | return h₂
mkTrans h₁ h₂ heq

/--
Given `lhs` and `rhs` that are in the same equivalence class,
find the common expression that are in the paths from `lhs` and `rhs` to
the root of their equivalence class.
Recall that this expression must exist since it is the root itself in the
worst case.
-/
private def findCommon (lhs rhs : Expr) : GoalM Expr := do
let mut visited : RBMap Nat Expr compare := {}
let mut it := lhs
-- Mark elements found following the path from `lhs` to the root.
repeat
let n ← getENode it
visited := visited.insert n.idx n.self
let some target := n.target? | break
it := target
-- Find the marked element from the path from `rhs` to the root.
it := rhs
repeat
let n ← getENode it
if let some common := visited.find? n.idx then
return common
let some target := n.target? | unreachable! --
it := target
unreachable!

mutual
private partial def mkCongrProof (lhs rhs : Expr) (heq : Bool) : GoalM Expr := do
-- TODO: implement
mkTodo lhs rhs heq

private partial def realizeEqProof (lhs rhs : Expr) (h : Expr) (flipped : Bool) (heq : Bool) : GoalM Expr := do
let h ← if h == congrPlaceholderProof then
mkCongrProof lhs rhs heq
else
flipProof h flipped heq

private partial def mkProofTo (lhs : Expr) (common : Expr) (acc : Option Expr) (heq : Bool) : GoalM (Option Expr) := do
if isSameExpr lhs common then
return acc
let n ← getENode lhs
let some target := n.target? | unreachable!
let some h := n.proof? | unreachable!
let h ← realizeEqProof lhs target h n.flipped heq
-- h : lhs = target
let acc ← mkTrans' acc h heq
mkProofTo target common (some acc) heq

/--
Given `lhsEqCommon : lhs = common`, returns a proof for `lhs = rhs`.
-/
private partial def mkProofFrom (rhs : Expr) (common : Expr) (lhsEqCommon? : Option Expr) (heq : Bool) : GoalM (Option Expr) := do
if isSameExpr rhs common then
return lhsEqCommon?
let n ← getENode rhs
let some target := n.target? | unreachable!
let some h := n.proof? | unreachable!
let h ← realizeEqProof target rhs h (!n.flipped) heq
-- `h : target = rhs`
let h' ← mkProofFrom target common lhsEqCommon? heq
-- `h' : lhs = target`
mkTrans' h' h heq

private partial def mkEqProofCore (lhs rhs : Expr) (heq : Bool) : GoalM Expr := do
if isSameExpr lhs rhs then
return (← mkRefl lhs heq)
let n₁ ← getENode lhs
let n₂ ← getENode rhs
assert! isSameExpr n₁.root n₂.root
let common ← findCommon lhs rhs
let lhsEqCommon? ← mkProofTo lhs common none heq
let some lhsEqRhs ← mkProofFrom rhs common lhsEqCommon? heq | unreachable!
return lhsEqRhs
end

/--
Returns a proof that `a = b` (or `HEq a b`).
It assumes `a` and `b` are in the same equivalence class.
-/
def mkEqProof (a b : Expr) : GoalM Expr := do
-- TODO
if (← isDefEq (← inferType a) (← inferType b)) then
mkSorry (← mkEq a b) (synthetic := false)
let n ← getENode a
if !n.heqProofs then
mkEqProofCore a b (heq := false)
else if (← withDefault <| isDefEq (← inferType a) (← inferType b)) then
mkEqProofCore a b (heq := false)
else
mkSorry (← mkHEq a b) (synthetic := false)
mkEqProofCore a b (heq := true)

/--
Returns a proof that `a = True`.
Expand Down
30 changes: 30 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ namespace Lean.Meta.Grind
-- inserted into the E-graph
unsafe ptrEq a b

/-- We use this auxiliary constant to mark delayed congruence proofs. -/
def congrPlaceholderProof := mkConst (Name.mkSimple "[congruence]")

/--
Returns `true` if `e` is `True`, `False`, or a literal value.
See `LitValues` for supported literals.
Expand All @@ -34,6 +37,12 @@ register_builtin_option grind.debug : Bool := {
descr := "check invariants after updates"
}

register_builtin_option grind.debug.proofs : Bool := {
defValue := false
group := "debug"
descr := "check proofs between the elements of all equivalence classes"
}

/-- Context for `GrindM` monad. -/
structure Context where
simp : Simp.Context
Expand Down Expand Up @@ -559,4 +568,25 @@ def mkGoal (mvarId : MVarId) : GrindM Goal := do
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)

/-- Returns expressions in the given expression equivalence class. -/
partial def getEqc (e : Expr) : GoalM (List Expr) :=
go e e []
where
go (first : Expr) (e : Expr) (acc : List Expr) : GoalM (List Expr) := do
let next ← getNext e
let acc := e :: acc
if isSameExpr first next then
return acc
else
go first next acc

/-- Returns all equivalence classes in the current goal. -/
partial def getEqcs : GoalM (List (List Expr)) := do
let mut r := []
let nodes ← getENodes
for node in nodes do
if isSameExpr node.root node.self then
r := (← getEqc node.self) :: r
return r

end Lean.Meta.Grind
1 change: 1 addition & 0 deletions tests/lean/run/grind_congr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ elab "grind_test" : tactic => withMainContext do
logInfo eqc

set_option grind.debug true
set_option grind.debug.proofs true

/--
info: [d, f b, c, f a]
Expand Down
2 changes: 2 additions & 0 deletions tests/lean/run/grind_nested_proofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ elab "grind_test" : tactic => withMainContext do
logInfo (← getEqc n.self)

set_option grind.debug true
-- TODO: fix nested proof support
-- set_option grind.debug.proofs true

/-
Recall that array access terms, such as `a[i]`, have nested proofs.
Expand Down
1 change: 1 addition & 0 deletions tests/lean/run/grind_propagate_connectives.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ elab "grind_test" : tactic => withMainContext do
logInfo eqc

set_option grind.debug true
set_option grind.debug.proofs true

/--
info: true: [q, w]
Expand Down

0 comments on commit 8a1e50f

Please sign in to comment.