diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean index d7493178314f..96fac2547f80 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean @@ -14,20 +14,20 @@ def internalizeTerm (_e : Expr) (_a : Expr) (_k : Nat) : GoalM Unit := do -- TODO return () -def internalizeCnstr (e : Expr) (c : Cnstr Expr) : GoalM Unit := OffsetM.run do +def internalizeCnstr (e : Expr) : GoalM Unit := do + let some c := isNatOffsetCnstr? e | return () let c := { c with a := (← mkNode c.a) b := (← mkNode c.b) } trace[grind.offset.internalize] "{e} ↦ {c}" - modify fun s => { s with + modify' fun s => { s with cnstrs := s.cnstrs.insert { expr := e } c } end Offset def internalize (e : Expr) : GoalM Unit := do - if let some c := isNatOffsetCnstr? e then - Offset.internalizeCnstr e c + Offset.internalizeCnstr e end Lean.Meta.Grind.Arith diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Main.lean b/src/Lean/Meta/Tactic/Grind/Arith/Main.lean index 7f008d4c9b4b..d46388c3ae0b 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Main.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Main.lean @@ -13,11 +13,11 @@ namespace Offset def isCnstr? (e : Expr) : GoalM (Option (Cnstr NodeId)) := return (← get).arith.offset.cnstrs.find? { expr := e } -def assertTrue (c : Cnstr NodeId) (p : Expr) : GoalM Unit := OffsetM.run do +def assertTrue (c : Cnstr NodeId) (p : Expr) : GoalM Unit := do addEdge c.a c.b c.k (← mkOfEqTrue p) -def assertFalse (c : Cnstr NodeId) (p : Expr) : GoalM Unit := OffsetM.run do - let p := mkOfNegEqFalse (← get).nodes c p +def assertFalse (c : Cnstr NodeId) (p : Expr) : GoalM Unit := do + let p := mkOfNegEqFalse (← get').nodes c p let c := c.neg addEdge c.a c.b c.k p diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Offset.lean b/src/Lean/Meta/Tactic/Grind/Arith/Offset.lean index 4a4e6f1bf9e1..95fec2f7edb3 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Offset.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Offset.lean @@ -30,20 +30,18 @@ The main advantage of this module over a full linear integer arithmetic procedur its ability to efficiently detect all implied equalities and inequalities. -/ -abbrev OffsetM := StateT State GoalM +def get' : GoalM State := do + return (← get).arith.offset -def OffsetM.run (x : OffsetM α) : GoalM α := do - let os ← modifyGet fun s => (s.arith.offset, { s with arith.offset := {} }) - let (a, os') ← StateT.run x os - modify fun s => { s with arith.offset := os' } - return a +@[inline] def modify' (f : State → State) : GoalM Unit := do + modify fun s => { s with arith.offset := f s.arith.offset } -def mkNode (expr : Expr) : OffsetM NodeId := do - if let some nodeId := (← get).nodeMap.find? { expr } then +def mkNode (expr : Expr) : GoalM NodeId := do + if let some nodeId := (← get').nodeMap.find? { expr } then return nodeId - let nodeId : NodeId := (← get).nodes.size + let nodeId : NodeId := (← get').nodes.size trace[grind.offset.internalize.term] "{expr} ↦ #{nodeId}" - modify fun s => { s with + modify' fun s => { s with nodes := s.nodes.push expr nodeMap := s.nodeMap.insert { expr } nodeId sources := s.sources.push {} @@ -52,62 +50,63 @@ def mkNode (expr : Expr) : OffsetM NodeId := do } return nodeId -def isUnsat : OffsetM Bool := - return (← get).unsat.isSome +def isUnsat : GoalM Bool := + return (← get').unsat.isSome -private def getDist? (u v : NodeId) : OffsetM (Option Int) := do - return (← get).targets[u]!.find? v +private def getDist? (u v : NodeId) : GoalM (Option Int) := do + return (← get').targets[u]!.find? v -private def getProof? (u v : NodeId) : OffsetM (Option ProofInfo) := do - return (← get).proofs[u]!.find? v +private def getProof? (u v : NodeId) : GoalM (Option ProofInfo) := do + return (← get').proofs[u]!.find? v -partial def extractProof (u v : NodeId) : OffsetM Expr := do +partial def extractProof (u v : NodeId) : GoalM Expr := do go (← getProof? u v).get! where - go (p : ProofInfo) : OffsetM Expr := do + go (p : ProofInfo) : GoalM Expr := do if u == p.w then return p.proof else let p' := (← getProof? u p.w).get! - go (mkTrans (← get).nodes p' p v) + go (mkTrans (← get').nodes p' p v) -private def setUnsat (_u _v : NodeId) (_k : Int) (p : Expr) : OffsetM Unit := do - modify fun s => { s with +private def setUnsat (u v : NodeId) (k : Int) (p : Expr) : GoalM Unit := do + trace[Meta.debug] "unsat #{u}-({k})->#{v}" + modify' fun s => { s with unsat := p -- TODO } -private def setDist (u v : NodeId) (k : Int) : OffsetM Unit := do +private def setDist (u v : NodeId) (k : Int) : GoalM Unit := do trace[grind.offset.dist] "{({ a := u, b := v, k : Cnstr NodeId})}" - modify fun s => { s with + modify' fun s => { s with targets := s.targets.modify u fun es => es.insert v k sources := s.sources.modify v fun es => es.insert u k } -private def setProof (u v : NodeId) (p : ProofInfo) : OffsetM Unit := do - modify fun s => { s with +private def setProof (u v : NodeId) (p : ProofInfo) : GoalM Unit := do + modify' fun s => { s with proofs := s.proofs.modify u fun es => es.insert v p } @[inline] -private def forEachSourceOf (u : NodeId) (f : NodeId → Int → OffsetM Unit) : OffsetM Unit := do - (← get).sources[u]!.forM f +private def forEachSourceOf (u : NodeId) (f : NodeId → Int → GoalM Unit) : GoalM Unit := do + (← get').sources[u]!.forM f @[inline] -private def forEachTargetOf (u : NodeId) (f : NodeId → Int → OffsetM Unit) : OffsetM Unit := do - (← get).targets[u]!.forM f +private def forEachTargetOf (u : NodeId) (f : NodeId → Int → GoalM Unit) : GoalM Unit := do + (← get').targets[u]!.forM f -private def isShorter (u v : NodeId) (k : Int) : OffsetM Bool := do +private def isShorter (u v : NodeId) (k : Int) : GoalM Bool := do if let some k' ← getDist? u v then return k < k' else return true -private def updateIfShorter (u v : NodeId) (k : Int) (w : NodeId) : OffsetM Unit := do +private def updateIfShorter (u v : NodeId) (k : Int) (w : NodeId) : GoalM Unit := do if (← isShorter u v k) then setDist u v k setProof u v (← getProof? w v).get! -def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : OffsetM Unit := do +def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : GoalM Unit := do if (← isUnsat) then return () if let some k' ← getDist? v u then if k'+k < 0 then @@ -118,7 +117,7 @@ def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : OffsetM Unit := do setProof u v { w := u, k, proof := p } update where - update : OffsetM Unit := do + update : GoalM Unit := do forEachTargetOf v fun j k₂ => do /- Check whether new path: `u -(k)-> v -(k₂)-> j` is shorter -/ updateIfShorter u j (k+k₂) v @@ -129,15 +128,15 @@ where /- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/ updateIfShorter i j (k₁+k+k₂) v -def traceDists : OffsetM Unit := do - let s ← get +def traceDists : GoalM Unit := do + let s ← get' for u in [:s.targets.size], es in s.targets.toArray do for (v, k) in es do trace[grind.offset.dist] "#{u} -({k})-> #{v}" -def Cnstr.toExpr (c : Cnstr NodeId) : OffsetM Expr := do - let a := (← get).nodes[c.a]! - let b := (← get).nodes[c.b]! +def Cnstr.toExpr (c : Cnstr NodeId) : GoalM Expr := do + let a := (← get').nodes[c.a]! + let b := (← get').nodes[c.b]! let mk := if c.le then mkNatLE else mkNatEq if c.k == 0 then return mk a b @@ -146,8 +145,8 @@ def Cnstr.toExpr (c : Cnstr NodeId) : OffsetM Expr := do else return mk a (mkNatAdd b (Lean.toExpr c.k.toNat)) -def checkInvariants : GoalM Unit := OffsetM.run do - let s ← get +def checkInvariants : GoalM Unit := do + let s ← get' for u in [:s.targets.size], es in s.targets.toArray do for (v, k) in es do let c : Cnstr NodeId := { a := u, b := v, k }