Skip to content

Commit

Permalink
feat: new implementation for simp (config := { ground := true }) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
leodemoura authored Jan 18, 2024
1 parent 27b7002 commit ec30da8
Show file tree
Hide file tree
Showing 11 changed files with 280 additions and 95 deletions.
8 changes: 8 additions & 0 deletions src/Init/Data/UInt/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ instance : Min UInt16 := minOfLe
def UInt32.ofNat (n : @& Nat) : UInt32 := ⟨Fin.ofNat n⟩
@[extern "lean_uint32_of_nat"]
def UInt32.ofNat' (n : Nat) (h : n < UInt32.size) : UInt32 := ⟨⟨n, h⟩⟩
/--
Converts the given natural number to `UInt32`, but returns `2^32 - 1` for natural numbers `>= 2^32`.
-/
def UInt32.ofNatTruncate (n : Nat) : UInt32 :=
if h : n < UInt32.size then
UInt32.ofNat' n h
else
UInt32.ofNat' (UInt32.size - 1) (by decide)
abbrev Nat.toUInt32 := UInt32.ofNat
@[extern "lean_uint32_add"]
def UInt32.add (a b : UInt32) : UInt32 := ⟨a.val + b.val⟩
Expand Down
5 changes: 5 additions & 0 deletions src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Fin.lean
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,9 @@ builtin_simproc reduceLE (( _ : Fin _) ≤ _) := reduceBinPred ``LE.le 4 (.
builtin_simproc reduceGT (( _ : Fin _) > _) := reduceBinPred ``GT.gt 4 (. > .)
builtin_simproc reduceGE (( _ : Fin _) ≥ _) := reduceBinPred ``GE.ge 4 (. ≥ .)

/-- Return `.done` for Fin values. We don't want to unfold them when `ground := true`. -/
builtin_simproc isValue ((OfNat.ofNat _ : Fin _)) := fun e => OptionT.run do
guard (e.isAppOfArity ``OfNat.ofNat 3)
return .done { expr := e }

end Fin
19 changes: 16 additions & 3 deletions src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Int.lean
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,22 @@ If they do, they must disable the following `simprocs`.

builtin_simproc reduceNeg ((- _ : Int)) := fun e => OptionT.run do
guard (e.isAppOfArity ``Neg.neg 3)
let v ← fromExpr? e.appArg!
guard (v < 0)
return .done { expr := toExpr (- v) }
let arg := e.appArg!
if arg.isAppOfArity ``OfNat.ofNat 3 then
-- We return .done to ensure `Neg.neg` is not unfolded even when `ground := true`.
guard (← getContext).unfoldGround
return .done { expr := e }
else
let v ← fromExpr? arg
if v < 0 then
return .done { expr := toExpr (- v) }
else
return .done { expr := toExpr v }

/-- Return `.done` for positive Int values. We don't want to unfold them when `ground := true`. -/
builtin_simproc isPosValue ((OfNat.ofNat _ : Int)) := fun e => OptionT.run do
guard (e.isAppOfArity ``OfNat.ofNat 3)
return .done { expr := e }

builtin_simproc reduceAdd ((_ + _ : Int)) := reduceBin ``HAdd.hAdd 6 (· + ·)
builtin_simproc reduceMul ((_ * _ : Int)) := reduceBin ``HMul.hMul 6 (· * ·)
Expand Down
6 changes: 6 additions & 0 deletions src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Nat.lean
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,10 @@ builtin_simproc reduceLE (( _ : Nat) ≤ _) := reduceBinPred ``LE.le 4 (. ≤
builtin_simproc reduceGT (( _ : Nat) > _) := reduceBinPred ``GT.gt 4 (. > .)
builtin_simproc reduceGE (( _ : Nat) ≥ _) := reduceBinPred ``GE.ge 4 (. ≥ .)

/-- Return `.done` for Nat values. We don't want to unfold them when `ground := true`. -/
builtin_simproc isValue ((OfNat.ofNat _ : Nat)) := fun e => OptionT.run do
guard (← getContext).unfoldGround
guard (e.isAppOfArity ``OfNat.ofNat 3)
return .done { expr := e }

end Nat
6 changes: 6 additions & 0 deletions src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/UInt.lean
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ builtin_simproc $(mkIdent `reduceLE):ident (( _ : $typeName) ≤ _) := reduceB
builtin_simproc $(mkIdent `reduceGT):ident (( _ : $typeName) > _) := reduceBinPred ``GT.gt 4 (. > .)
builtin_simproc $(mkIdent `reduceGE):ident (( _ : $typeName) ≥ _) := reduceBinPred ``GE.ge 4 (. ≥ .)

/-- Return `.done` for UInt values. We don't want to unfold them when `ground := true`. -/
builtin_simproc isValue ((OfNat.ofNat _ : $typeName)) := fun e => OptionT.run do
guard (← getContext).unfoldGround
guard (e.isAppOfArity ``OfNat.ofNat 3)
return .done { expr := e }

end $typeName
)

Expand Down
86 changes: 30 additions & 56 deletions src/Lean/Meta/Tactic/Simp/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,6 @@ where
let f := e.getAppFn
f.isConst && isMatcherCore env f.constName!

/--
Auxiliary function for implementing `ctx.config.ground`: evaluate ground terms eagerly.
We currently use `whnf` to implement this feature, but we want to stop ground evaluation at symbols marked with the `-` modifier.
For example, `simp (config := { ground := true }) [-f]` should not unfold `f` even if the goal contains a ground term such as `f 2`.
-/
private def canUnfoldAtSimpGround (erased : SimpTheoremsArray) (_ : Meta.Config) (info : ConstantInfo) : CoreM Bool := do
return !erased.isErased (.decl info.name)

/--
Try to unfold `e`.
-/
Expand All @@ -124,37 +116,6 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do
return none
let fName := f.constName!
let ctx ← getContext
-- TODO: remove `rec` after we switch to new code generator
let rec unfoldGround? : SimpM (Option Expr) := do
unless ctx.config.ground do return none
-- We are assuming that assigned metavariables are going to be instantiated by the main simp loop.
if e.hasExprMVar || e.hasFVar then return none
if ctx.simpTheorems.isErased (.decl fName) then return none
-- TODO: check whether we need more filters
if (← isType e) then return none -- we don't unfold types
if (← isProof e) then return none -- we don't unfold proofs
if (← isInstance fName) then return none -- we don't unfold instances
-- TODO: we must have a notion of `simp` value, or more general solution for Lean
if Meta.isMatchValue e || isOfNatNatLit e then return none
if e.isConst then
-- We don't unfold constants that take arguments
-- TODO: add support for skipping partial applications too.
if let .forallE .. ← whnfD (← inferType e) then
return none
/-
We are currently using `whnf` with a custom `canUnfold?` predicate to reduce ground terms.
This can be inefficient, and produce proofs that are too expensive to type check in the Kernel. Some reasons:
- Functions defined by Well-founded recursion are expensive to reduce here and in the kernel.
- The kernel does not know we may have controlled reduction using `canUnfold?`.
It would be great to reduce the ground term using a to-be-implemented `cbv` tactic which produces a
proof that can be efficiently checked by the kernel.
-/
let eNew ← withDefault <|
withTheReader Meta.Context (fun c => { c with canUnfold? := canUnfoldAtSimpGround ctx.simpTheorems }) <| whnf e
if eNew == e then return none
trace[Meta.Tactic.simp.ground] "{e}\n---->\n{eNew}"
return some eNew
let rec unfoldDeclToUnfold? : SimpM (Option Expr) := do
let options ← getOptions
let cfg ← getConfig
Expand All @@ -172,9 +133,7 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do
-- Partially applied function, return `none`. See issue #2042
if arity > e.getAppNumArgs then return none
withDefault <| unfoldDefinition? e
if let some eNew ← unfoldGround? then
return some eNew
else if (← isProjectionFn fName) then
if (← isProjectionFn fName) then
return none -- should be reduced by `reduceProjFn?`
else if ctx.config.autoUnfold then
if ctx.simpTheorems.isErased (.decl fName) then
Expand Down Expand Up @@ -449,7 +408,7 @@ private partial def dsimpImpl (e : Expr) : SimpM Expr := do
return .visit r.expr
return .continue
let post (e : Expr) : SimpM TransformStep := do
if let Step.visit r ← rewritePost e (fun _ => pure none) (rflOnly := true) then
if let some r ← rewritePost? e (fun _ => pure none) (rflOnly := true) then
if r.expr != e then
return .visit r.expr
let mut eNew ← reduce e
Expand Down Expand Up @@ -600,8 +559,12 @@ def simpStep (e : Expr) : SimpM Result := do

def cacheResult (e : Expr) (cfg : Config) (r : Result) : SimpM Result := do
if cfg.memoize then
let dischargeDepth := (← readThe Simp.Context).dischargeDepth
modify fun s => { s with cache := s.cache.insert e { r with dischargeDepth } }
let ctx ← readThe Simp.Context
let dischargeDepth := ctx.dischargeDepth
if ctx.unfoldGround then
modify fun s => { s with cacheGround := s.cacheGround.insert e { r with dischargeDepth } }
else
modify fun s => { s with cache := s.cache.insert e { r with dischargeDepth } }
return r

partial def simpLoop (e : Expr) (r : Result) : SimpM Result := do
Expand All @@ -628,19 +591,30 @@ partial def simpLoop (e : Expr) (r : Result) : SimpM Result := do
@[export lean_simp]
def simpImpl (e : Expr) : SimpM Result := withIncRecDepth do
checkSystem "simp"
let cfg ← getConfig
if (← isProof e) then
return { expr := e }
if cfg.memoize then
if let some result := (← get).cache.find? e then
/-
If the result was cached at a dischargeDepth > the current one, it may not be valid.
See issue #1234
-/
if result.dischargeDepth ≤ (← readThe Simp.Context).dischargeDepth then
return result
trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}"
simpLoop e { expr := e }
let ctx ← getContext
trace[Meta.debug] "visit [{ctx.unfoldGround}]: {e}"
if ctx.unfoldGround then
if (← isType e) then
unless (← isProp e) do
-- Recall that we set `unfoldGround := false` if `e` is a type that is not a proposition.
return (← withTheReader Context (fun ctx => { ctx with unfoldGround := false }) go)
go
where
go : SimpM Result := do
let cfg ← getConfig
if cfg.memoize then
let cache ← if (← getContext).unfoldGround then pure ((← get).cacheGround) else pure ((← get).cache)
if let some result := cache.find? e then
/-
If the result was cached at a dischargeDepth > the current one, it may not be valid.
See issue #1234
-/
if result.dischargeDepth ≤ (← readThe Simp.Context).dischargeDepth then
return result
trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}"
simpLoop e { expr := e }

@[inline] def withSimpConfig (ctx : Context) (x : MetaM α) : MetaM α :=
withConfig (fun c => { c with etaStruct := ctx.config.etaStruct }) <| withReducible x
Expand Down
83 changes: 74 additions & 9 deletions src/Lean/Meta/Tactic/Simp/Rewrite.lean
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,6 @@ def rewritePre (e : Expr) (discharge? : Expr → SimpM (Option Expr)) (rflOnly :
return Step.visit r
return Step.visit { expr := e }

def rewritePost (e : Expr) (discharge? : Expr → SimpM (Option Expr)) (rflOnly := false) : SimpM Step := do
for thms in (← getContext).simpTheorems do
if let some r ← rewrite? e thms.post thms.erased discharge? (tag := "post") (rflOnly := rflOnly) then
return Step.visit r
return Step.visit { expr := e }

partial def preDefault (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM Step := do
let s ← rewritePre e discharge?
let s ← andThen s (simpMatch? discharge?)
Expand All @@ -309,10 +303,81 @@ partial def preDefault (e : Expr) (discharge? : Expr → SimpM (Option Expr)) :
else
andThen s (preDefault · discharge?)

def rewritePost? (e : Expr) (discharge? : Expr → SimpM (Option Expr)) (rflOnly := false) : SimpM (Option Result) := do
for thms in (← getContext).simpTheorems do
if let some r ← rewrite? e thms.post thms.erased discharge? (tag := "post") (rflOnly := rflOnly) then
return r
return none

/--
Try to unfold ground term when `Context.unfoldGround := true`.
-/
def unfoldGround? (discharge? : Expr → SimpM (Option Expr)) (e : Expr) : SimpM (Option Step) := do
-- Ground term unfolding is disabled.
unless (← getContext).unfoldGround do return none
-- `e` is not a ground term.
unless !e.hasExprMVar && !e.hasFVar do return none
trace[Meta.debug] "unfoldGround? {e}"
-- Check whether `e` is a constant application
let f := e.getAppFn
let .const declName lvls := f | return none
-- If declaration has been marked to not be unfolded, return none.
let ctx ← getContext
if ctx.simpTheorems.isErased (.decl declName) then return none
-- Matcher applications should have been reduced before we get here.
if (← isMatcher declName) then return none
if let some eqns ← withDefault <| getEqnsFor? declName then
-- `declName` has equation theorems associated with it.
for eqn in eqns do
-- TODO: cache SimpTheorem to avoid calls to `isRflTheorem`
if let some result ← Simp.tryTheorem? e { origin := .decl eqn, proof := mkConst eqn, rfl := (← isRflTheorem eqn) } discharge? then
trace[Meta.Tactic.simp.ground] "unfolded, {e} => {result.expr}"
return some (.visit result)
return none
-- `declName` does not have equation theorems associated with it.
if e.isConst then
-- We don't unfold constants that take arguments
if let .forallE .. ← whnfD (← inferType e) then
return none
let info ← getConstInfo declName
unless info.hasValue && info.levelParams.length == lvls.length do return none
let fBody ← instantiateValueLevelParams info lvls
let eNew := fBody.betaRev e.getAppRevArgs (useZeta := true)
trace[Meta.Tactic.simp.ground] "delta, {e} => {eNew}"
return some (.visit { expr := eNew })

def postDefault (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM Step := do
let s ← rewritePost e discharge?
/-
Remark 1:
`rewritePost?` used to return a `Step`, and we would try other methods even if it succeeded in rewriting the term.
This behavior was problematic, especially when `ground := true`, because we have rewriting rules such as
`List.append as bs = as ++ bs`, which are rules for folding polymorphic functions.
This type of rule can trigger nontermination in the context of `ground := true`.
For example, the method `unfoldGround?` would reduce `[] ++ [1]` to `List.append [] [1]`, and
`rewritePost` would refold it back to `[] ++ [1]`, leading to an endless loop.
Initially, we considered always reducing ground terms first. However, this approach would
prevent us from adding auxiliary lemmas that could short-circuit the evaluation.
Ultimately, we settled on the following compromise: if a `rewritePost?` succeeds and produces a result `r`,
we return with `.visit r`. This allows pre-methods to be applied again along with other rewriting rules.
This strategy helps avoid non-termination, as we have `[simp]` theorems specifically for reducing `List.append`
```lean
@[simp] theorem nil_append (as : List α) : [] ++ as = as := ...
@[simp] theorem cons_append (a : α) (as bs : List α) : (a::as) ++ bs = a::(as ++ bs) := ...
```
Remark 2:
In the simplifier, the ground value for some inductive types is *not* a constructor application.
Examples: `Nat`, `Int`, `Fin _`, `UInt?`. These types are represented using `OfNat.ofNat`.
To ensure `unfoldGround?` does not unfold `OfNat.ofNat` applications for these types, we
have `simproc` that return `.done ..` for these ground values. Thus, `unfoldGround?` is not
even tried. Alternative design: we could add an extensible ground value predicate.
-/
if let some r ← rewritePost? e discharge? then
return .visit r
let s ← andThen (.visit { expr := e }) postSimproc?
let s ← andThen s (unfoldGround? discharge?)
let s ← andThen s simpArith?
let s ← andThen s postSimproc?
let s ← andThen s tryRewriteUsingDecide?
andThen s tryRewriteCtorEq?

Expand Down Expand Up @@ -388,7 +453,7 @@ def dischargeDefault? (e : Expr) : SimpM (Option Expr) := do
return some r
let ctx ← getContext
trace[Meta.Tactic.simp.discharge] ">> discharge?: {e}"
if ctx.dischargeDepth >= ctx.config.maxDischargeDepth then
if ctx.dischargeDepth >= ctx.maxDischargeDepth then
trace[Meta.Tactic.simp.discharge] "maximum discharge depth has been reached"
return none
else
Expand Down
Loading

0 comments on commit ec30da8

Please sign in to comment.