Skip to content

Commit

Permalink
feat: theorem patterns for heuristic instantiation in grind (#6472)
Browse files Browse the repository at this point in the history
This PR implements the command `grind_pattern`. The new command allows
users to associate patterns with theorems. These patterns are used for
performing heuristic instantiation with e-matching. In the future, we
will add the attributes `@[grind_eq]`, `@[grind_fwd]`, and
`@[grind_bwd]` to compute the patterns automatically for theorems.
  • Loading branch information
leodemoura authored Dec 29, 2024
1 parent 11eea84 commit 7433e74
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/Lean/Elab/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,29 @@ Authors: Leonardo de Moura
prelude
import Init.Grind.Tactics
import Lean.Meta.Tactic.Grind
import Lean.Elab.Command
import Lean.Elab.Tactic.Basic


namespace Lean.Elab.Tactic
open Meta

open Command Term in
@[builtin_command_elab Lean.Parser.Command.grindPattern]
def elabGrindPattern : CommandElab := fun stx => do
match stx with
| `(grind_pattern $thmName:ident => $terms,*) => do
liftTermElabM do
let declName ← resolveGlobalConstNoOverload thmName
let info ← getConstInfo declName
forallTelescope info.type fun xs _ => do
let patterns ← terms.getElems.mapM fun term => do
let pattern ← instantiateMVars (← elabTerm term none)
let pattern ← Grind.unfoldReducible pattern
return pattern.abstract xs
Grind.addTheoremPattern declName xs.size patterns.toList
| _ => throwUnsupportedSyntax

def grind (mvarId : MVarId) (mainDeclName : Name) : MetaM Unit := do
let mvarIds ← Grind.main mvarId mainDeclName
unless mvarIds.isEmpty do
Expand Down
2 changes: 2 additions & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import Lean.Meta.Tactic.Grind.PP
import Lean.Meta.Tactic.Grind.Simp
import Lean.Meta.Tactic.Grind.Ctor
import Lean.Meta.Tactic.Grind.Parser
import Lean.Meta.Tactic.Grind.TheoremPatterns

namespace Lean

Expand All @@ -35,5 +36,6 @@ builtin_initialize registerTraceClass `grind.simp
builtin_initialize registerTraceClass `grind.congr
builtin_initialize registerTraceClass `grind.proof
builtin_initialize registerTraceClass `grind.proof.detail
builtin_initialize registerTraceClass `grind.pattern

end Lean
174 changes: 174 additions & 0 deletions src/Lean/Meta/Tactic/Grind/TheoremPatterns.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.HeadIndex
import Lean.Util.FoldConsts
import Lean.Meta.Basic
import Lean.Meta.InferType

namespace Lean.Meta.Grind

inductive Origin where
/-- A global declaration in the environment. -/
| decl (declName : Name)
/-- A local hypothesis. -/
| fvar (fvarId : FVarId)
/--
A proof term provided directly to a call to `grind` where `ref`
is the provided grind argument. The `id` is a unique identifier for the call.
-/
| stx (id : Name) (ref : Syntax)
| other
deriving Inhabited, Repr

structure TheoremPattern where
proof : Expr
numParams : Nat
patterns : List Expr
/-- Contains all symbols used in `pattterns`. -/
symbols : List HeadIndex
origin : Origin
deriving Inhabited

abbrev TheoremPatterns := SMap Name (List TheoremPattern)

builtin_initialize theoremPatternsExt : SimpleScopedEnvExtension TheoremPattern TheoremPatterns ←
registerSimpleScopedEnvExtension {
addEntry := fun s t => Id.run do
let .const declName :: _ := t.symbols | unreachable!
if let some ts := s.find? declName then
s.insert declName (t::ts)
else
s.insert declName [t]
initial := .empty
}

-- TODO: create attribute?
private def forbiddenDeclNames := #[``Eq, ``HEq, ``Iff, ``And, ``Or, ``Not]

private def isForbidden (declName : Name) := forbiddenDeclNames.contains declName

private def dontCare := mkConst (Name.mkSimple "[grind_dontcare]")

private def mkGroundPattern (e : Expr) : Expr :=
mkAnnotation `grind.ground_pat e

private def groundPattern? (e : Expr) : Option Expr :=
annotation? `grind.ground_pat e

private def isGroundPattern (e : Expr) : Bool :=
groundPattern? e |>.isSome

private def isAtomicPattern (e : Expr) : Bool :=
e.isBVar || e == dontCare || isGroundPattern e

partial def ppPattern (pattern : Expr) : MessageData := Id.run do
if let some e := groundPattern? pattern then
return m!"`[{e}]"
else if pattern == dontCare then
return m!"?"
else match pattern with
| .bvar idx => return m!"#{idx}"
| _ =>
let mut r := m!"{pattern.getAppFn}"
for arg in pattern.getAppArgs do
let mut argFmt ← ppPattern arg
if !isAtomicPattern arg then
argFmt := MessageData.paren argFmt
r := r ++ " " ++ argFmt
return r

namespace NormalizePattern

structure State where
symbols : Array HeadIndex := #[]
symbolSet : Std.HashSet HeadIndex := {}
bvarsFound : Std.HashSet Nat := {}

abbrev M := StateRefT State MetaM

private def saveSymbol (h : HeadIndex) : M Unit := do
unless (← get).symbolSet.contains h do
modify fun s => { s with symbols := s.symbols.push h, symbolSet := s.symbolSet.insert h }

private def foundBVar (idx : Nat) : M Bool :=
return (← get).bvarsFound.contains idx

private def saveBVar (idx : Nat) : M Unit := do
modify fun s => { s with bvarsFound := s.bvarsFound.insert idx }

private def getPatternFn? (pattern : Expr) : Option Expr :=
if !pattern.isApp then
none
else match pattern.getAppFn with
| f@(.const declName _) => if isForbidden declName then none else some f
| f@(.fvar _) => some f
| _ => none

private structure PatternFunInfo where
instImplicitMask : Array Bool
typeMask : Array Bool

private def getPatternFunInfo (f : Expr) (numArgs : Nat) : MetaM PatternFunInfo := do
forallBoundedTelescope (← inferType f) numArgs fun xs _ => do
let typeMask ← xs.mapM fun x => isTypeFormer x
let instImplicitMask ← xs.mapM fun x => return (← x.fvarId!.getDecl).binderInfo matches .instImplicit
return { typeMask, instImplicitMask }

private partial def go (pattern : Expr) (root := false) : M Expr := do
if root && !pattern.hasLooseBVars then
throwError "invalid pattern, it does not have pattern variables"
let some f := getPatternFn? pattern
| throwError "invalid pattern, (non-forbidden) application expected"
assert! f.isConst || f.isFVar
saveSymbol f.toHeadIndex
let mut args := pattern.getAppArgs
let { instImplicitMask, typeMask } ← getPatternFunInfo f args.size
for i in [:args.size] do
let arg := args[i]!
let isType := typeMask[i]?.getD false
let isInstImplicit := instImplicitMask[i]?.getD false
let arg ← if !arg.hasLooseBVars then
if arg.hasMVar then
pure dontCare
else
pure <| mkGroundPattern arg
else match arg with
| .bvar idx =>
if (isType || isInstImplicit) && (← foundBVar idx) then
pure dontCare
else
saveBVar idx
pure arg
| _ =>
if isType || isInstImplicit then
pure dontCare
else if let some _ := getPatternFn? arg then
go arg
else
pure dontCare
args := args.set! i arg
return mkAppN f args

def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex) := do
let (patterns, s) ← patterns.mapM go |>.run {}
return (patterns, s.symbols.toList)

end NormalizePattern

def addTheoremPattern (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
let .thmInfo info ← getConstInfo declName
| throwError "`{declName}` is not a theorem, you cannot assign patterns to non-theorems for the `grind` tactic"
let us := info.levelParams.map mkLevelParam
let proof := mkConst declName us
let (patterns, symbols) ← NormalizePattern.main patterns
trace[grind.pattern] "{declName}: {patterns.map ppPattern}"
theoremPatternsExt.add {
proof, patterns, numParams, symbols
origin := .decl declName
}

end Lean.Meta.Grind
28 changes: 28 additions & 0 deletions tests/lean/run/grind_pattern1.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
set_option trace.grind.pattern true

/--
info: [grind.pattern] Array.getElem_push_lt: [@getElem ? `[Nat] #4 ? ? (@Array.push ? #3 #2) #1 ?]
-/
#guard_msgs in
grind_pattern Array.getElem_push_lt => (a.push x)[i]


/--
info: [grind.pattern] List.getElem_attach: [@getElem ? `[Nat] ? ? ? (@List.attach #3 #2) #1 ?]
-/
#guard_msgs in
grind_pattern List.getElem_attach => xs.attach[i]

/--
info: [grind.pattern] List.mem_concat_self: [@Membership.mem #2 ? ? (@HAppend.hAppend ? ? ? ? #1 (@List.cons ? #0 (@List.nil ?))) #0]
-/
#guard_msgs in
grind_pattern List.mem_concat_self => a ∈ xs ++ [a]

def foo (x : Nat) := x + x

/--
error: `foo` is not a theorem, you cannot assign patterns to non-theorems for the `grind` tactic
-/
#guard_msgs in
grind_pattern foo => x + x

0 comments on commit 7433e74

Please sign in to comment.