From 24a8561ec4e302f4e0cba07632fddacd6f6e0323 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 30 Dec 2024 04:40:43 +0100 Subject: [PATCH] feat: check pattern coverage in the `grind_pattern` command (#6474) This PR adds pattern validation to the `grind_pattern` command. The new `checkCoverage` function will also be used to implement the attributes `@[grind_eq]`, `@[grind_fwd]`, and `@[grind_bwd]`. --- src/Lean/Elab/Tactic/Grind.lean | 2 +- .../Meta/Tactic/Grind/TheoremPatterns.lean | 132 +++++++++++++++++- tests/lean/run/grind_pattern1.lean | 88 ++++++++++++ 3 files changed, 218 insertions(+), 4 deletions(-) diff --git a/src/Lean/Elab/Tactic/Grind.lean b/src/Lean/Elab/Tactic/Grind.lean index cea6d7d13c40..2ceef07fc832 100644 --- a/src/Lean/Elab/Tactic/Grind.lean +++ b/src/Lean/Elab/Tactic/Grind.lean @@ -9,7 +9,6 @@ import Lean.Meta.Tactic.Grind import Lean.Elab.Command import Lean.Elab.Tactic.Basic - namespace Lean.Elab.Tactic open Meta @@ -20,6 +19,7 @@ def elabGrindPattern : CommandElab := fun stx => do | `(grind_pattern $thmName:ident => $terms,*) => do liftTermElabM do let declName ← resolveGlobalConstNoOverload thmName + discard <| addTermInfo thmName (← mkConstWithLevelParams declName) let info ← getConstInfo declName forallTelescope info.type fun xs _ => do let patterns ← terms.getElems.mapM fun term => do diff --git a/src/Lean/Meta/Tactic/Grind/TheoremPatterns.lean b/src/Lean/Meta/Tactic/Grind/TheoremPatterns.lean index 4a9d2c9a20e5..82833aad3671 100644 --- a/src/Lean/Meta/Tactic/Grind/TheoremPatterns.lean +++ b/src/Lean/Meta/Tactic/Grind/TheoremPatterns.lean @@ -6,6 +6,7 @@ Authors: Leonardo de Moura prelude import Lean.HeadIndex import Lean.Util.FoldConsts +import Lean.Util.CollectFVars import Lean.Meta.Basic import Lean.Meta.InferType @@ -153,19 +154,144 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do args := args.set! i arg return mkAppN f args -def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex) := do +def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex × Std.HashSet Nat) := do let (patterns, s) ← patterns.mapM go |>.run {} - return (patterns, s.symbols.toList) + return (patterns, s.symbols.toList, s.bvarsFound) end NormalizePattern +/-- +Returns `true` if free variables in `type` are not in `thmVars` or are in `fvarsFound`. +We use this function to check whether `type` is fully instantiated. +-/ +private def checkTypeFVars (thmVars : FVarIdSet) (fvarsFound : FVarIdSet) (type : Expr) : Bool := + let typeFVars := (collectFVars {} type).fvarIds + typeFVars.all fun fvarId => !thmVars.contains fvarId || fvarsFound.contains fvarId + +/-- +Given an type class instance type `instType`, returns true if free variables in input parameters +1- are not in `thmVars`, or +2- are in `fvarsFound`. +Remark: `fvarsFound` is a subset of `thmVars` +-/ +private def canBeSynthesized (thmVars : FVarIdSet) (fvarsFound : FVarIdSet) (instType : Expr) : MetaM Bool := do + forallTelescopeReducing instType fun xs type => type.withApp fun classFn classArgs => do + for x in xs do + unless checkTypeFVars thmVars fvarsFound (← inferType x) do return false + forallBoundedTelescope (← inferType classFn) type.getAppNumArgs fun params _ => do + for param in params, classArg in classArgs do + let paramType ← inferType param + if !paramType.isAppOf ``semiOutParam && !paramType.isAppOf ``outParam then + unless checkTypeFVars thmVars fvarsFound classArg do + return false + return true + +/-- +Auxiliary type for the `checkCoverage` function. +-/ +inductive CheckCoverageResult where + | /-- `checkCoverage` succeeded -/ + ok + | /-- + `checkCoverage` failed because some of the theorem parameters are missing, + `pos` contains their positions + -/ + missing (pos : List Nat) + +/-- +After we process a set of patterns, we obtain the set of de Bruijn indices in these patterns. +We say they are pattern variables. This function checks whether the set of pattern variables is sufficient for +instantiating the theorem with proof `thmProof`. The theorem has `numParams` parameters. +The missing parameters: +1- we may be able to infer them using type inference or type class synthesis, or +2- they are propositions, and may become hypotheses of the instantiated theorem. + +For type class instance parameters, we must check whether the free variables in class input parameters are available. +-/ +private def checkCoverage (thmProof : Expr) (numParams : Nat) (bvarsFound : Std.HashSet Nat) : MetaM CheckCoverageResult := do + if bvarsFound.size == numParams then return .ok + forallBoundedTelescope (← inferType thmProof) numParams fun xs _ => do + assert! numParams == xs.size + let patternVars := bvarsFound.toList.map fun bidx => xs[numParams - bidx - 1]!.fvarId! + -- `xs` as a `FVarIdSet`. + let thmVars : FVarIdSet := RBTree.ofList <| xs.toList.map (·.fvarId!) + -- Collect free variables occurring in `e`, and insert the ones that are in `thmVars` into `fvarsFound` + let update (fvarsFound : FVarIdSet) (e : Expr) : FVarIdSet := + (collectFVars {} e).fvarIds.foldl (init := fvarsFound) fun s fvarId => + if thmVars.contains fvarId then s.insert fvarId else s + -- Theorem variables found so far. We initialize with the variables occurring in patterns + -- Remark: fvarsFound is a subset of thmVars + let mut fvarsFound : FVarIdSet := RBTree.ofList patternVars + for patternVar in patternVars do + let type ← patternVar.getType + fvarsFound := update fvarsFound type + if fvarsFound.size == numParams then return .ok + -- Now, we keep traversing remaining variables and collecting + -- `processed` contains the variables we have already processed. + let mut processed : FVarIdSet := RBTree.ofList patternVars + let mut modified := false + repeat + modified := false + for x in xs do + let fvarId := x.fvarId! + unless processed.contains fvarId do + let xType ← inferType x + if fvarsFound.contains fvarId then + -- Collect free vars in `x`s type and mark as processed + fvarsFound := update fvarsFound xType + processed := processed.insert fvarId + modified := true + else if (← isProp xType) then + -- If `x` is a proposition, and all theorem variables in `x`s type have already been found + -- add it to `fvarsFound` and mark it as processed. + if checkTypeFVars thmVars fvarsFound xType then + fvarsFound := fvarsFound.insert fvarId + processed := processed.insert fvarId + modified := true + else if (← fvarId.getDecl).binderInfo matches .instImplicit then + -- If `x` is instance implicit, check whether + -- we have found all free variables needed to synthesize instance + if (← canBeSynthesized thmVars fvarsFound xType) then + fvarsFound := fvarsFound.insert fvarId + fvarsFound := update fvarsFound xType + processed := processed.insert fvarId + modified := true + if fvarsFound.size == numParams then + return .ok + if !modified then + break + let mut pos := #[] + for h : i in [:xs.size] do + let fvarId := xs[i].fvarId! + unless fvarsFound.contains fvarId do + pos := pos.push i + return .missing pos.toList + +/-- +Given a theorem with proof `proof` and `numParams` parameters, returns a message +containing the parameters at positions `paramPos`. +-/ +private def ppParamsAt (proof : Expr) (numParms : Nat) (paramPos : List Nat) : MetaM MessageData := do + forallBoundedTelescope (← inferType proof) numParms fun xs _ => do + let mut msg := m!"" + let mut first := true + for h : i in [:xs.size] do + if paramPos.contains i then + let x := xs[i] + if first then first := false else msg := msg ++ "\n" + msg := msg ++ m!"{x} : {← inferType x}" + addMessageContextFull msg + def addTheoremPattern (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do let .thmInfo info ← getConstInfo declName | throwError "`{declName}` is not a theorem, you cannot assign patterns to non-theorems for the `grind` tactic" let us := info.levelParams.map mkLevelParam let proof := mkConst declName us - let (patterns, symbols) ← NormalizePattern.main patterns + let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns trace[grind.pattern] "{declName}: {patterns.map ppPattern}" + if let .missing pos ← checkCoverage proof numParams bvarFound then + let pats : MessageData := m!"{patterns.map ppPattern}" + throwError "invalid pattern(s) for `{declName}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}" theoremPatternsExt.add { proof, patterns, numParams, symbols origin := .decl declName diff --git a/tests/lean/run/grind_pattern1.lean b/tests/lean/run/grind_pattern1.lean index 79c3a564eb58..0eb1750bdcbe 100644 --- a/tests/lean/run/grind_pattern1.lean +++ b/tests/lean/run/grind_pattern1.lean @@ -26,3 +26,91 @@ error: `foo` is not a theorem, you cannot assign patterns to non-theorems for th -/ #guard_msgs in grind_pattern foo => x + x + +/-- +error: invalid pattern(s) for `Array.getElem_push_lt` + [@Array.push #4 #3 #2] +the following theorem parameters cannot be instantiated: + i : Nat + h : i < a.size +--- +info: [grind.pattern] Array.getElem_push_lt: [@Array.push #4 #3 #2] +-/ +#guard_msgs in +grind_pattern Array.getElem_push_lt => (a.push x) + +class Foo (α : Type) (β : outParam Type) where + a : Unit + +class Boo (α : Type) (β : Type) where + b : β + +def f [Foo α β] [Boo α β] (a : α) : (α × β) := + (a, Boo.b α) + +instance [Foo α β] : Foo (List α) (Array β) where + a := () + +instance [Boo α β] : Boo (List α) (Array β) where + b := #[Boo.b α] + +theorem fEq [Foo α β] [Boo α β] (a : List α) : (f a).1 = a := rfl + +/-- info: [grind.pattern] fEq: [@f ? ? ? ? #0] -/ +#guard_msgs in +grind_pattern fEq => f a + +theorem fEq2 [Foo α β] [Boo α β] (a : List α) (_h : a.length > 5) : (f a).1 = a := rfl + +/-- info: [grind.pattern] fEq2: [@f ? ? ? ? #1] -/ +#guard_msgs in +grind_pattern fEq2 => f a + +def g [Boo α β] (a : α) : (α × β) := + (a, Boo.b α) + +theorem gEq [Boo α β] (a : List α) : (g (β := Array β) a).1 = a := rfl + +/-- +error: invalid pattern(s) for `gEq` + [@g ? ? ? #0] +the following theorem parameters cannot be instantiated: + β : Type + inst✝ : Boo α β +--- +info: [grind.pattern] gEq: [@g ? ? ? #0] +-/ +#guard_msgs in +grind_pattern gEq => g a + +def plus (a : Nat) (b : Nat) := a + b + +theorem hThm1 (h : b > 10) : plus a b + plus a c > 10 := by + unfold plus; omega + +/-- +error: invalid pattern(s) for `hThm1` + [plus #2 #3] +the following theorem parameters cannot be instantiated: + c : Nat +--- +info: [grind.pattern] hThm1: [plus #2 #3] +-/ +#guard_msgs in +grind_pattern hThm1 => plus a b + +/-- +error: invalid pattern(s) for `hThm1` + [plus #2 #1] +the following theorem parameters cannot be instantiated: + b : Nat + h : b > 10 +--- +info: [grind.pattern] hThm1: [plus #2 #1] +-/ +#guard_msgs in +grind_pattern hThm1 => plus a c + +/-- info: [grind.pattern] hThm1: [plus #2 #1, plus #2 #3] -/ +#guard_msgs in +grind_pattern hThm1 => plus a c, plus a b