From fe45ddd6105078a0a3bd855e5d94673e794f6b88 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 28 Dec 2024 00:50:58 +0100 Subject: [PATCH] feat: projections in `grind` (#6465) This PR adds support for projection functions to the (WIP) `grind` tactic. --- src/Lean/Meta/Tactic/Grind/Core.lean | 48 +--------- src/Lean/Meta/Tactic/Grind/Internalize.lean | 60 ++++++++++++ src/Lean/Meta/Tactic/Grind/Preprocessor.lean | 1 + src/Lean/Meta/Tactic/Grind/Proj.lean | 35 +++++++ src/Lean/Meta/Tactic/Grind/Propagate.lean | 1 + .../Meta/Tactic/Grind/PropagatorAttr.lean | 61 ++++++++++++ src/Lean/Meta/Tactic/Grind/Run.lean | 53 +++++++++++ src/Lean/Meta/Tactic/Grind/Types.lean | 95 ++----------------- tests/lean/run/grind_t1.lean | 20 ++++ 9 files changed, 238 insertions(+), 136 deletions(-) create mode 100644 src/Lean/Meta/Tactic/Grind/Internalize.lean create mode 100644 src/Lean/Meta/Tactic/Grind/Proj.lean create mode 100644 src/Lean/Meta/Tactic/Grind/PropagatorAttr.lean create mode 100644 src/Lean/Meta/Tactic/Grind/Run.lean diff --git a/src/Lean/Meta/Tactic/Grind/Core.lean b/src/Lean/Meta/Tactic/Grind/Core.lean index e1230973c369..5c4755345e71 100644 --- a/src/Lean/Meta/Tactic/Grind/Core.lean +++ b/src/Lean/Meta/Tactic/Grind/Core.lean @@ -10,56 +10,10 @@ import Lean.Meta.Tactic.Grind.Types import Lean.Meta.Tactic.Grind.Inv import Lean.Meta.Tactic.Grind.PP import Lean.Meta.Tactic.Grind.Ctor +import Lean.Meta.Tactic.Grind.Internalize namespace Lean.Meta.Grind -/-- Adds `e` to congruence table. -/ -private def addCongrTable (e : Expr) : GoalM Unit := do - if let some { e := e' } := (← get).congrTable.find? { e } then - trace[grind.congr] "{e} = {e'}" - pushEqHEq e e' congrPlaceholderProof - -- TODO: we must check whether the types of the functions are the same - -- TODO: update cgRoot for `e` - else - modify fun s => { s with congrTable := s.congrTable.insert { e } } - -partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do - if (← alreadyInternalized e) then return () - match e with - | .bvar .. => unreachable! - | .sort .. => return () - | .fvar .. | .letE .. | .lam .. | .forallE .. => - mkENodeCore e (ctor := false) (interpreted := false) (generation := generation) - | .lit .. | .const .. => - mkENode e generation - | .mvar .. - | .mdata .. - | .proj .. => - trace[grind.issues] "unexpected term during internalization{indentExpr e}" - mkENodeCore e (ctor := false) (interpreted := false) (generation := generation) - | .app .. => - if (← isLitValue e) then - -- We do not want to internalize the components of a literal value. - mkENode e generation - else e.withApp fun f args => do - if f.isConstOf ``Lean.Grind.nestedProof && args.size == 2 then - -- We only internalize the proposition. We can skip the proof because of - -- proof irrelevance - let c := args[0]! - internalize c generation - registerParent e c - else - unless f.isConst do - internalize f generation - registerParent e f - for h : i in [: args.size] do - let arg := args[i] - internalize arg generation - registerParent e arg - mkENode e generation - addCongrTable e - propagateUp e - /-- The fields `target?` and `proof?` in `e`'s `ENode` are encoding a transitivity proof from `e` to the root of the equivalence class diff --git a/src/Lean/Meta/Tactic/Grind/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Internalize.lean new file mode 100644 index 000000000000..2114b1dd9a7f --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Internalize.lean @@ -0,0 +1,60 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Init.Grind.Util +import Lean.Meta.LitValues +import Lean.Meta.Tactic.Grind.Types + +namespace Lean.Meta.Grind + +/-- Adds `e` to congruence table. -/ +def addCongrTable (e : Expr) : GoalM Unit := do + if let some { e := e' } := (← get).congrTable.find? { e } then + trace[grind.congr] "{e} = {e'}" + pushEqHEq e e' congrPlaceholderProof + -- TODO: we must check whether the types of the functions are the same + -- TODO: update cgRoot for `e` + else + modify fun s => { s with congrTable := s.congrTable.insert { e } } + +partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do + if (← alreadyInternalized e) then return () + match e with + | .bvar .. => unreachable! + | .sort .. => return () + | .fvar .. | .letE .. | .lam .. | .forallE .. => + mkENodeCore e (ctor := false) (interpreted := false) (generation := generation) + | .lit .. | .const .. => + mkENode e generation + | .mvar .. + | .mdata .. + | .proj .. => + trace[grind.issues] "unexpected term during internalization{indentExpr e}" + mkENodeCore e (ctor := false) (interpreted := false) (generation := generation) + | .app .. => + if (← isLitValue e) then + -- We do not want to internalize the components of a literal value. + mkENode e generation + else e.withApp fun f args => do + if f.isConstOf ``Lean.Grind.nestedProof && args.size == 2 then + -- We only internalize the proposition. We can skip the proof because of + -- proof irrelevance + let c := args[0]! + internalize c generation + registerParent e c + else + unless f.isConst do + internalize f generation + registerParent e f + for h : i in [: args.size] do + let arg := args[i] + internalize arg generation + registerParent e arg + mkENode e generation + addCongrTable e + propagateUp e + +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Preprocessor.lean b/src/Lean/Meta/Tactic/Grind/Preprocessor.lean index 2f532c20f6bb..2e663fb59d1b 100644 --- a/src/Lean/Meta/Tactic/Grind/Preprocessor.lean +++ b/src/Lean/Meta/Tactic/Grind/Preprocessor.lean @@ -17,6 +17,7 @@ import Lean.Meta.Tactic.Grind.Cases import Lean.Meta.Tactic.Grind.Injection import Lean.Meta.Tactic.Grind.Core import Lean.Meta.Tactic.Grind.Simp +import Lean.Meta.Tactic.Grind.Run namespace Lean.Meta.Grind namespace Preprocessor diff --git a/src/Lean/Meta/Tactic/Grind/Proj.lean b/src/Lean/Meta/Tactic/Grind/Proj.lean new file mode 100644 index 000000000000..b742642b2f5f --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Proj.lean @@ -0,0 +1,35 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.ProjFns +import Lean.Meta.Tactic.Grind.Types +import Lean.Meta.Tactic.Grind.Internalize + +namespace Lean.Meta.Grind + +/-- +If `parent` is a projection-application `proj_i c`, +check whether the root of the equivalence class containing `c` is a constructor-application `ctor ... a_i ...`. +If so, internalize the term `proj_i (ctor ... a_i ...)` and add the equality `proj_i (ctor ... a_i ...) = a_i`. +-/ +def propagateProjEq (parent : Expr) : GoalM Unit := do + let .const declName _ := parent.getAppFn | return () + let some info ← getProjectionFnInfo? declName | return () + unless info.numParams + 1 == parent.getAppNumArgs do return () + let arg := parent.appArg! + let ctor ← getRoot arg + unless ctor.isAppOf info.ctorName do return () + if isSameExpr arg ctor then + let idx := info.numParams + info.i + unless idx < ctor.getAppNumArgs do return () + let v := ctor.getArg! idx + pushEq parent v (← mkEqRefl v) + else + let newProj := mkApp parent.appFn! ctor + let newProj ← shareCommon newProj + internalize newProj (← getGeneration parent) + +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Propagate.lean b/src/Lean/Meta/Tactic/Grind/Propagate.lean index f9bbfda4c477..00dadb8f1c13 100644 --- a/src/Lean/Meta/Tactic/Grind/Propagate.lean +++ b/src/Lean/Meta/Tactic/Grind/Propagate.lean @@ -6,6 +6,7 @@ Authors: Leonardo de Moura prelude import Init.Grind import Lean.Meta.Tactic.Grind.Proof +import Lean.Meta.Tactic.Grind.PropagatorAttr namespace Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/PropagatorAttr.lean b/src/Lean/Meta/Tactic/Grind/PropagatorAttr.lean new file mode 100644 index 000000000000..6cd98068bf13 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/PropagatorAttr.lean @@ -0,0 +1,61 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Init.Grind +import Lean.Meta.Tactic.Grind.Proof + +namespace Lean.Meta.Grind + +/-- Builtin propagators. -/ +structure BuiltinPropagators where + up : Std.HashMap Name Propagator := {} + down : Std.HashMap Name Propagator := {} + deriving Inhabited + +builtin_initialize builtinPropagatorsRef : IO.Ref BuiltinPropagators ← IO.mkRef {} + +private def registerBuiltinPropagatorCore (declName : Name) (up : Bool) (proc : Propagator) : IO Unit := do + unless (← initializing) do + throw (IO.userError s!"invalid builtin `grind` propagator declaration, it can only be registered during initialization") + if up then + if (← builtinPropagatorsRef.get).up.contains declName then + throw (IO.userError s!"invalid builtin `grind` upward propagator `{declName}`, it has already been declared") + builtinPropagatorsRef.modify fun { up, down } => { up := up.insert declName proc, down } + else + if (← builtinPropagatorsRef.get).down.contains declName then + throw (IO.userError s!"invalid builtin `grind` downward propagator `{declName}`, it has already been declared") + builtinPropagatorsRef.modify fun { up, down } => { up, down := down.insert declName proc } + +def registerBuiltinUpwardPropagator (declName : Name) (proc : Propagator) : IO Unit := + registerBuiltinPropagatorCore declName true proc + +def registerBuiltinDownwardPropagator (declName : Name) (proc : Propagator) : IO Unit := + registerBuiltinPropagatorCore declName false proc + +private def addBuiltin (propagatorName : Name) (stx : Syntax) : AttrM Unit := do + let go : MetaM Unit := do + let up := stx[1].getKind == ``Lean.Parser.Tactic.simpPost + let addDeclName := if up then + ``registerBuiltinUpwardPropagator + else + ``registerBuiltinDownwardPropagator + let declName ← resolveGlobalConstNoOverload stx[2] + let val := mkAppN (mkConst addDeclName) #[toExpr declName, mkConst propagatorName] + let initDeclName ← mkFreshUserName (propagatorName ++ `declare) + declareBuiltin initDeclName val + go.run' {} + +builtin_initialize + registerBuiltinAttribute { + ref := by exact decl_name% + name := `grindPropagatorBuiltinAttr + descr := "Builtin `grind` propagator procedure" + applicationTime := AttributeApplicationTime.afterCompilation + erase := fun _ => throwError "Not implemented yet, [-builtin_simproc]" + add := fun declName stx _ => addBuiltin declName stx + } + +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Run.lean b/src/Lean/Meta/Tactic/Grind/Run.lean new file mode 100644 index 000000000000..f2b534b4557e --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Run.lean @@ -0,0 +1,53 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Init.Grind.Lemmas +import Lean.Meta.Tactic.Grind.Types +import Lean.Meta.Tactic.Grind.PropagatorAttr +import Lean.Meta.Tactic.Grind.Proj + +namespace Lean.Meta.Grind + +def mkMethods : CoreM Methods := do + let builtinPropagators ← builtinPropagatorsRef.get + return { + propagateUp := fun e => do + let .const declName _ := e.getAppFn | return () + propagateProjEq e + if let some prop := builtinPropagators.up[declName]? then + prop e + propagateDown := fun e => do + let .const declName _ := e.getAppFn | return () + if let some prop := builtinPropagators.down[declName]? then + prop e + } + +def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do + let scState := ShareCommon.State.mk _ + let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False) + let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True) + let thms ← grindNormExt.getTheorems + let simprocs := #[(← grindNormSimprocExt.getSimprocs)] + let simp ← Simp.mkContext + (config := { arith := true }) + (simpTheorems := #[thms]) + (congrTheorems := (← getSimpCongrTheorems)) + x (← mkMethods).toMethodsRef { mainDeclName, simprocs, simp } |>.run' { scState, trueExpr, falseExpr } + +@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) := + goal.mvarId.withContext do StateRefT'.run x goal + +@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal := + goal.mvarId.withContext do StateRefT'.run' (x *> get) goal + +def mkGoal (mvarId : MVarId) : GrindM Goal := do + let trueExpr ← getTrueExpr + let falseExpr ← getFalseExpr + GoalM.run' { mvarId } do + mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0) + mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0) + +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index b33e7dfcdd19..c072f94a7fdc 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -296,6 +296,10 @@ def getENode (e : Expr) : GoalM ENode := do | throwError "internal `grind` error, term has not been internalized{indentExpr e}" return n +/-- Returns the generation of the given term. Is assumes it has been internalized -/ +def getGeneration (e : Expr) : GoalM Nat := + return (← getENode e).generation + /-- Returns `true` if `e` is in the equivalence class of `True`. -/ def isEqTrue (e : Expr) : GoalM Bool := do let n ← getENode e @@ -508,12 +512,12 @@ def forEachEqc (f : ENode → GoalM Unit) : GoalM Unit := do if isSameExpr n.self n.root then f n -private structure Methods where +structure Methods where propagateUp : Propagator := fun _ => return () propagateDown : Propagator := fun _ => return () deriving Inhabited -private def Methods.toMethodsRef (m : Methods) : MethodsRef := +def Methods.toMethodsRef (m : Methods) : MethodsRef := unsafe unsafeCast m private def MethodsRef.toMethods (m : MethodsRef) : Methods := @@ -528,93 +532,6 @@ def propagateUp (e : Expr) : GoalM Unit := do def propagateDown (e : Expr) : GoalM Unit := do (← getMethods).propagateDown e -/-- Builtin propagators. -/ -structure BuiltinPropagators where - up : Std.HashMap Name Propagator := {} - down : Std.HashMap Name Propagator := {} - deriving Inhabited - -builtin_initialize builtinPropagatorsRef : IO.Ref BuiltinPropagators ← IO.mkRef {} - -private def registerBuiltinPropagatorCore (declName : Name) (up : Bool) (proc : Propagator) : IO Unit := do - unless (← initializing) do - throw (IO.userError s!"invalid builtin `grind` propagator declaration, it can only be registered during initialization") - if up then - if (← builtinPropagatorsRef.get).up.contains declName then - throw (IO.userError s!"invalid builtin `grind` upward propagator `{declName}`, it has already been declared") - builtinPropagatorsRef.modify fun { up, down } => { up := up.insert declName proc, down } - else - if (← builtinPropagatorsRef.get).down.contains declName then - throw (IO.userError s!"invalid builtin `grind` downward propagator `{declName}`, it has already been declared") - builtinPropagatorsRef.modify fun { up, down } => { up, down := down.insert declName proc } - -def registerBuiltinUpwardPropagator (declName : Name) (proc : Propagator) : IO Unit := - registerBuiltinPropagatorCore declName true proc - -def registerBuiltinDownwardPropagator (declName : Name) (proc : Propagator) : IO Unit := - registerBuiltinPropagatorCore declName false proc - -private def addBuiltin (propagatorName : Name) (stx : Syntax) : AttrM Unit := do - let go : MetaM Unit := do - let up := stx[1].getKind == ``Lean.Parser.Tactic.simpPost - let addDeclName := if up then - ``registerBuiltinUpwardPropagator - else - ``registerBuiltinDownwardPropagator - let declName ← resolveGlobalConstNoOverload stx[2] - let val := mkAppN (mkConst addDeclName) #[toExpr declName, mkConst propagatorName] - let initDeclName ← mkFreshUserName (propagatorName ++ `declare) - declareBuiltin initDeclName val - go.run' {} - -builtin_initialize - registerBuiltinAttribute { - ref := by exact decl_name% - name := `grindPropagatorBuiltinAttr - descr := "Builtin `grind` propagator procedure" - applicationTime := AttributeApplicationTime.afterCompilation - erase := fun _ => throwError "Not implemented yet, [-builtin_simproc]" - add := fun declName stx _ => addBuiltin declName stx - } - -def mkMethods : CoreM Methods := do - let builtinPropagators ← builtinPropagatorsRef.get - return { - propagateUp := fun e => do - let .const declName _ := e.getAppFn | return () - if let some prop := builtinPropagators.up[declName]? then - prop e - propagateDown := fun e => do - let .const declName _ := e.getAppFn | return () - if let some prop := builtinPropagators.down[declName]? then - prop e - } - -def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do - let scState := ShareCommon.State.mk _ - let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False) - let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True) - let thms ← grindNormExt.getTheorems - let simprocs := #[(← grindNormSimprocExt.getSimprocs)] - let simp ← Simp.mkContext - (config := { arith := true }) - (simpTheorems := #[thms]) - (congrTheorems := (← getSimpCongrTheorems)) - x (← mkMethods).toMethodsRef { mainDeclName, simprocs, simp } |>.run' { scState, trueExpr, falseExpr } - -@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) := - goal.mvarId.withContext do StateRefT'.run x goal - -@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal := - goal.mvarId.withContext do StateRefT'.run' (x *> get) goal - -def mkGoal (mvarId : MVarId) : GrindM Goal := do - let trueExpr ← getTrueExpr - let falseExpr ← getFalseExpr - GoalM.run' { mvarId } do - mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0) - mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0) - /-- Returns expressions in the given expression equivalence class. -/ partial def getEqc (e : Expr) : GoalM (List Expr) := go e e [] diff --git a/tests/lean/run/grind_t1.lean b/tests/lean/run/grind_t1.lean index 0d4b9ada9813..59f8749a8775 100644 --- a/tests/lean/run/grind_t1.lean +++ b/tests/lean/run/grind_t1.lean @@ -55,3 +55,23 @@ example (a b c : BitVec 32) : a = c → a = 1#32 → c = 2#32 → c = b → Fals example (a b c : UInt32) : a = c → a = 1 → c = 200 → c = b → False := by grind + +structure Boo (α : Type) where + a : α + b : α + c : α + +example (a b d : Nat) (f : Nat → Boo Nat) : (f d).1 ≠ a → f d = ⟨b, v₁, v₂⟩ → b = a → False := by + grind + +def ex (a b c d : Nat) (f : Nat → Boo Nat) : (f d).2 ≠ a → f d = ⟨b, c, v₂⟩ → c = a → False := by + grind + +example (a b c : Nat) (f : Nat → Nat) : { a := f b, c, b := 4 : Boo Nat }.1 ≠ f a → f b = f c → a = c → False := by + grind + +example (a b c : Nat) (f : Nat → Nat) : p = { a := f b, c, b := 4 : Boo Nat } → p.1 ≠ f a → f b = f c → a = c → False := by + grind + +example (a b c : Nat) (f : Nat → Nat) : p.1 ≠ f a → p = { a := f b, c, b := 4 : Boo Nat } → f b = f c → a = c → False := by + grind