Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use attribute command to add and erase simprocs #3511

Merged
merged 5 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/Lean/Elab/Declaration.lean
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,21 @@ def elabMutual : CommandElab := fun stx => do
let attrs ← elabAttrs attrInsts
let idents := stx[4].getArgs
for ident in idents do withRef ident <| liftTermElabM do
let declName ← resolveGlobalConstNoOverloadWithInfo ident
/-
HACK to allow `attribute` command to disable builtin simprocs.
TODO: find a better solution. Example: have some "fake" declaration
for builtin simprocs.
-/
let declNames ←
try
resolveGlobalConst ident
catch _ =>
let name := ident.getId.eraseMacroScopes
if (← Simp.isBuiltinSimproc name) then
pure [name]
else
throwUnknownConstant name
let declName ← ensureNonAmbiguous ident declNames
Term.applyAttributes declName attrs
for attrName in toErase do
Attribute.erase declName attrName
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Tactic/NormCast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Authors: Paul-Nicolas Madelaine, Robert Y. Lewis, Mario Carneiro, Gabriel Ebner
-/
prelude
import Lean.Meta.CongrTheorems
import Lean.Meta.Tactic.Simp.SimpTheorems
import Lean.Meta.Tactic.Simp.Attr
import Lean.Meta.CoeAttr

namespace Lean.Meta.NormCast
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import Lean.Meta.Tactic.Simp.SimpAll
import Lean.Meta.Tactic.Simp.Simproc
import Lean.Meta.Tactic.Simp.BuiltinSimprocs
import Lean.Meta.Tactic.Simp.RegisterCommand
import Lean.Meta.Tactic.Simp.Attr

namespace Lean

Expand Down
74 changes: 74 additions & 0 deletions src/Lean/Meta/Tactic/Simp/Attr.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Simp.Types
import Lean.Meta.Tactic.Simp.SimpTheorems
import Lean.Meta.Tactic.Simp.Simproc

namespace Lean.Meta
open Simp

def mkSimpAttr (attrName : Name) (attrDescr : String) (ext : SimpExtension)
(ref : Name := by exact decl_name%) : IO Unit :=
registerBuiltinAttribute {
ref := ref
name := attrName
descr := attrDescr
applicationTime := AttributeApplicationTime.afterCompilation
add := fun declName stx attrKind => do
if (← isSimproc declName <||> isBuiltinSimproc declName) then
let simprocAttrName := simpAttrNameToSimprocAttrName attrName
Attribute.add declName simprocAttrName stx attrKind
else
let go : MetaM Unit := do
let info ← getConstInfo declName
let post := if stx[1].isNone then true else stx[1][0].getKind == ``Lean.Parser.Tactic.simpPost
let prio ← getAttrParamOptPrio stx[2]
if (← isProp info.type) then
addSimpTheorem ext declName post (inv := false) attrKind prio
else if info.hasValue then
if let some eqns ← getEqnsFor? declName then
for eqn in eqns do
addSimpTheorem ext eqn post (inv := false) attrKind prio
ext.add (SimpEntry.toUnfoldThms declName eqns) attrKind
if hasSmartUnfoldingDecl (← getEnv) declName then
ext.add (SimpEntry.toUnfold declName) attrKind
else
ext.add (SimpEntry.toUnfold declName) attrKind
else
throwError "invalid 'simp', it is not a proposition nor a definition (to unfold)"
discard <| go.run {} {}
erase := fun declName => do
if (← isSimproc declName <||> isBuiltinSimproc declName) then
let simprocAttrName := simpAttrNameToSimprocAttrName attrName
Attribute.erase declName simprocAttrName
else
let s := ext.getState (← getEnv)
let s ← s.erase (.decl declName)
modifyEnv fun env => ext.modifyState env fun _ => s
}

def registerSimpAttr (attrName : Name) (attrDescr : String)
(ref : Name := by exact decl_name%) : IO SimpExtension := do
let ext ← mkSimpExt ref
mkSimpAttr attrName attrDescr ext ref -- Remark: it will fail if it is not performed during initialization
simpExtensionMapRef.modify fun map => map.insert attrName ext
return ext

builtin_initialize simpExtension : SimpExtension ← registerSimpAttr `simp "simplification theorem"

builtin_initialize sevalSimpExtension : SimpExtension ← registerSimpAttr `seval "symbolic evaluator theorem"

def getSimpTheorems : CoreM SimpTheorems :=
simpExtension.getTheorems

def getSEvalTheorems : CoreM SimpTheorems :=
sevalSimpExtension.getTheorems

def Simp.Context.mkDefault : MetaM Context :=
return { config := {}, simpTheorems := #[(← Meta.getSimpTheorems)], congrTheorems := (← Meta.getSimpCongrTheorems) }

end Lean.Meta
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Simp/RegisterCommand.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Lean.Meta.Tactic.Simp.SimpTheorems
import Lean.Meta.Tactic.Simp.Simproc
import Lean.Meta.Tactic.Simp.Attr

namespace Lean.Meta.Simp

Expand Down
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Simp/Rewrite.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import Lean.Meta.Tactic.UnifyEq
import Lean.Meta.Tactic.Simp.Types
import Lean.Meta.Tactic.LinearArith.Simp
import Lean.Meta.Tactic.Simp.Simproc
import Lean.Meta.Tactic.Simp.Attr

namespace Lean.Meta.Simp

Expand Down
48 changes: 0 additions & 48 deletions src/Lean/Meta/Tactic/Simp/SimpTheorems.lean
Original file line number Diff line number Diff line change
Expand Up @@ -362,37 +362,6 @@ def addSimpTheorem (ext : SimpExtension) (declName : Name) (post : Bool) (inv :
for simpThm in simpThms do
ext.add (SimpEntry.thm simpThm) attrKind

def mkSimpAttr (attrName : Name) (attrDescr : String) (ext : SimpExtension)
(ref : Name := by exact decl_name%) : IO Unit :=
registerBuiltinAttribute {
ref := ref
name := attrName
descr := attrDescr
applicationTime := AttributeApplicationTime.afterCompilation
add := fun declName stx attrKind =>
let go : MetaM Unit := do
let info ← getConstInfo declName
let post := if stx[1].isNone then true else stx[1][0].getKind == ``Lean.Parser.Tactic.simpPost
let prio ← getAttrParamOptPrio stx[2]
if (← isProp info.type) then
addSimpTheorem ext declName post (inv := false) attrKind prio
else if info.hasValue then
if let some eqns ← getEqnsFor? declName then
for eqn in eqns do
addSimpTheorem ext eqn post (inv := false) attrKind prio
ext.add (SimpEntry.toUnfoldThms declName eqns) attrKind
if hasSmartUnfoldingDecl (← getEnv) declName then
ext.add (SimpEntry.toUnfold declName) attrKind
else
ext.add (SimpEntry.toUnfold declName) attrKind
else
throwError "invalid 'simp', it is not a proposition nor a definition (to unfold)"
discard <| go.run {} {}
erase := fun declName => do
let s := ext.getState (← getEnv)
let s ← s.erase (.decl declName)
modifyEnv fun env => ext.modifyState env fun _ => s
}

def mkSimpExt (name : Name := by exact decl_name%) : IO SimpExtension :=
registerSimpleScopedEnvExtension {
Expand All @@ -409,26 +378,9 @@ abbrev SimpExtensionMap := HashMap Name SimpExtension

builtin_initialize simpExtensionMapRef : IO.Ref SimpExtensionMap ← IO.mkRef {}

def registerSimpAttr (attrName : Name) (attrDescr : String)
(ref : Name := by exact decl_name%) : IO SimpExtension := do
let ext ← mkSimpExt ref
mkSimpAttr attrName attrDescr ext ref -- Remark: it will fail if it is not performed during initialization
simpExtensionMapRef.modify fun map => map.insert attrName ext
return ext

builtin_initialize simpExtension : SimpExtension ← registerSimpAttr `simp "simplification theorem"

builtin_initialize sevalSimpExtension : SimpExtension ← registerSimpAttr `seval "symbolic evaluator theorem"

def getSimpExtension? (attrName : Name) : IO (Option SimpExtension) :=
return (← simpExtensionMapRef.get).find? attrName

def getSimpTheorems : CoreM SimpTheorems :=
simpExtension.getTheorems

def getSEvalTheorems : CoreM SimpTheorems :=
sevalSimpExtension.getTheorems

/-- Auxiliary method for adding a global declaration to a `SimpTheorems` datastructure. -/
def SimpTheorems.addConst (s : SimpTheorems) (declName : Name) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do
let s := { s with erased := s.erased.erase (.decl declName post inv) }
Expand Down
39 changes: 19 additions & 20 deletions src/Lean/Meta/Tactic/Simp/Simproc.lean
Original file line number Diff line number Diff line change
Expand Up @@ -127,25 +127,29 @@ def toSimprocEntry (e : SimprocOLeanEntry) : ImportM SimprocEntry := do
def eraseSimprocAttr (ext : SimprocExtension) (declName : Name) : AttrM Unit := do
let s := ext.getState (← getEnv)
unless s.simprocNames.contains declName do
throwError "'{declName}' does not have [simproc] attribute"
throwError "'{declName}' does not have a simproc attribute"
modifyEnv fun env => ext.modifyState env fun s => s.erase declName

def addSimprocAttr (ext : SimprocExtension) (declName : Name) (kind : AttributeKind) (post : Bool) : CoreM Unit := do
def addSimprocAttrCore (ext : SimprocExtension) (declName : Name) (kind : AttributeKind) (post : Bool) : CoreM Unit := do
let proc ← getSimprocFromDecl declName
let some keys ← getSimprocDeclKeys? declName |
throwError "invalid [simproc] attribute, '{declName}' is not a simproc"
ext.add { declName, post, keys, proc } kind

def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Name) (post : Bool) (proc : Simproc) : Simprocs :=
let s := { s with simprocNames := s.simprocNames.insert declName, erased := s.erased.erase declName }
if post then
{ s with post := s.post.insertCore keys { declName, keys, post, proc } }
else
{ s with pre := s.pre.insertCore keys { declName, keys, post, proc } }

/--
Implements attributes `builtin_simproc` and `builtin_sevalproc`.
-/
def addSimprocBuiltinAttrCore (ref : IO.Ref Simprocs) (declName : Name) (post : Bool) (proc : Simproc) : IO Unit := do
let some keys := (← builtinSimprocDeclsRef.get).keys.find? declName |
throw (IO.userError "invalid [builtin_simproc] attribute, '{declName}' is not a builtin simproc")
if post then
ref.modify fun s => { s with post := s.post.insertCore keys { declName, keys, post, proc } }
else
ref.modify fun s => { s with pre := s.pre.insertCore keys { declName, keys, post, proc } }
ref.modify fun s => s.addCore keys declName post proc

def addSimprocBuiltinAttr (declName : Name) (post : Bool) (proc : Simproc) : IO Unit :=
addSimprocBuiltinAttrCore builtinSimprocsRef declName post proc
Expand All @@ -166,10 +170,7 @@ def Simprocs.add (s : Simprocs) (declName : Name) (post : Bool) : CoreM Simprocs
throw e
let some keys ← getSimprocDeclKeys? declName |
throwError "invalid [simproc] attribute, '{declName}' is not a simproc"
if post then
return { s with post := s.post.insertCore keys { declName, keys, post, proc } }
else
return { s with pre := s.pre.insertCore keys { declName, keys, post, proc } }
return s.addCore keys declName post proc

def SimprocEntry.try (s : SimprocEntry) (numExtraArgs : Nat) (e : Expr) : SimpM Step := do
let mut extraArgs := #[]
Expand Down Expand Up @@ -276,24 +277,22 @@ def mkSimprocExt (name : Name := by exact decl_name%) (ref? : Option (IO.Ref Sim
return {}
ofOLeanEntry := fun _ => toSimprocEntry
toOLeanEntry := fun e => e.toSimprocOLeanEntry
addEntry := fun s e =>
if e.post then
{ s with post := s.post.insertCore e.keys e }
else
{ s with pre := s.pre.insertCore e.keys e }
addEntry := fun s e => s.addCore e.keys e.declName e.post e.proc
}

def addSimprocAttr (ext : SimprocExtension) (declName : Name) (stx : Syntax) (attrKind : AttributeKind) : AttrM Unit := do
let go : MetaM Unit := do
let post := if stx[1].isNone then true else stx[1][0].getKind == ``Lean.Parser.Tactic.simpPost
addSimprocAttrCore ext declName attrKind post
discard <| go.run {} {}

def mkSimprocAttr (attrName : Name) (attrDescr : String) (ext : SimprocExtension) (name : Name) : IO Unit := do
registerBuiltinAttribute {
ref := name
name := attrName
descr := attrDescr
applicationTime := AttributeApplicationTime.afterCompilation
add := fun declName stx attrKind =>
let go : MetaM Unit := do
let post := if stx[1].isNone then true else stx[1][0].getKind == ``Lean.Parser.Tactic.simpPost
addSimprocAttr ext declName attrKind post
discard <| go.run {} {}
add := addSimprocAttr ext
erase := eraseSimprocAttr ext
}

Expand Down
3 changes: 0 additions & 3 deletions src/Lean/Meta/Tactic/Simp/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ structure Context where
def Context.isDeclToUnfold (ctx : Context) (declName : Name) : Bool :=
ctx.simpTheorems.isDeclToUnfold declName

def Context.mkDefault : MetaM Context :=
return { config := {}, simpTheorems := #[(← getSimpTheorems)], congrTheorems := (← getSimpCongrTheorems) }

abbrev UsedSimps := HashMap Origin Nat

structure State where
Expand Down
12 changes: 12 additions & 0 deletions tests/lean/run/simproc_builtin_erase.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
example (h : 12 = x) : 10 + 2 = x := by
simp
guard_target =ₛ 12 = x
assumption

attribute [-simp] Nat.reduceAdd

example (h : 12 = x) : 10 + 2 = x := by
fail_if_success simp
simp [Nat.reduceAdd]
guard_target =ₛ 12 = x
assumption
49 changes: 49 additions & 0 deletions tests/lean/run/simproc_erase.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import Lean.Meta.Tactic.Simp.BuiltinSimprocs

def foo (x : Nat) : Nat :=
x + 10

open Lean Meta

/-- doc-comment for reduceFoo -/
simproc reduceFoo (foo _) := fun e => do
unless e.isAppOfArity ``foo 1 do return .continue
let some n ← Nat.fromExpr? e.appArg! | return .continue
return .done { expr := mkNatLit (n+10) }

example : x + foo 2 = 12 + x := by
simp
rw [Nat.add_comm]

attribute [-simp] reduceFoo

example : x + foo 2 = 12 + x := by
fail_if_success simp
simp [foo]
rw [Nat.add_comm]

attribute [simp] reduceFoo

example : x + foo 2 = 12 + x := by
simp
rw [Nat.add_comm]

example (h : 12 = x) : 10 + 2 = x := by
simp
guard_target =ₛ 12 = x
assumption

attribute [-simp] Nat.reduceAdd

example (h : 12 = x) : 10 + 2 = x := by
fail_if_success simp
simp [Nat.reduceAdd]
guard_target =ₛ 12 = x
assumption

attribute [simp] Nat.reduceAdd

example (h : 12 = x) : 10 + 2 = x := by
simp
guard_target =ₛ 12 = x
assumption
Loading