Skip to content

Commit

Permalink
refactor: remove OffsetM
Browse files Browse the repository at this point in the history
  • Loading branch information
leodemoura committed Jan 12, 2025
1 parent 131af27 commit b6ecd21
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 47 deletions.
8 changes: 4 additions & 4 deletions src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@ def internalizeTerm (_e : Expr) (_a : Expr) (_k : Nat) : GoalM Unit := do
-- TODO
return ()

def internalizeCnstr (e : Expr) (c : Cnstr Expr) : GoalM Unit := OffsetM.run do
def internalizeCnstr (e : Expr) : GoalM Unit := do
let some c := isNatOffsetCnstr? e | return ()
let c := { c with
a := (← mkNode c.a)
b := (← mkNode c.b)
}
trace[grind.offset.internalize] "{e} ↦ {c}"
modify fun s => { s with
modify' fun s => { s with
cnstrs := s.cnstrs.insert { expr := e } c
}

end Offset

def internalize (e : Expr) : GoalM Unit := do
if let some c := isNatOffsetCnstr? e then
Offset.internalizeCnstr e c
Offset.internalizeCnstr e

end Lean.Meta.Grind.Arith
6 changes: 3 additions & 3 deletions src/Lean/Meta/Tactic/Grind/Arith/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ namespace Offset
def isCnstr? (e : Expr) : GoalM (Option (Cnstr NodeId)) :=
return (← get).arith.offset.cnstrs.find? { expr := e }

def assertTrue (c : Cnstr NodeId) (p : Expr) : GoalM Unit := OffsetM.run do
def assertTrue (c : Cnstr NodeId) (p : Expr) : GoalM Unit := do
addEdge c.a c.b c.k (← mkOfEqTrue p)

def assertFalse (c : Cnstr NodeId) (p : Expr) : GoalM Unit := OffsetM.run do
let p := mkOfNegEqFalse (← get).nodes c p
def assertFalse (c : Cnstr NodeId) (p : Expr) : GoalM Unit := do
let p := mkOfNegEqFalse (← get').nodes c p
let c := c.neg
addEdge c.a c.b c.k p

Expand Down
79 changes: 39 additions & 40 deletions src/Lean/Meta/Tactic/Grind/Arith/Offset.lean
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,18 @@ The main advantage of this module over a full linear integer arithmetic procedur
its ability to efficiently detect all implied equalities and inequalities.
-/

abbrev OffsetM := StateT State GoalM
def get' : GoalM State := do
return (← get).arith.offset

def OffsetM.run (x : OffsetM α) : GoalM α := do
let os ← modifyGet fun s => (s.arith.offset, { s with arith.offset := {} })
let (a, os') ← StateT.run x os
modify fun s => { s with arith.offset := os' }
return a
@[inline] def modify' (f : State → State) : GoalM Unit := do
modify fun s => { s with arith.offset := f s.arith.offset }

def mkNode (expr : Expr) : OffsetM NodeId := do
if let some nodeId := (← get).nodeMap.find? { expr } then
def mkNode (expr : Expr) : GoalM NodeId := do
if let some nodeId := (← get').nodeMap.find? { expr } then
return nodeId
let nodeId : NodeId := (← get).nodes.size
let nodeId : NodeId := (← get').nodes.size
trace[grind.offset.internalize.term] "{expr} ↦ #{nodeId}"
modify fun s => { s with
modify' fun s => { s with
nodes := s.nodes.push expr
nodeMap := s.nodeMap.insert { expr } nodeId
sources := s.sources.push {}
Expand All @@ -52,62 +50,63 @@ def mkNode (expr : Expr) : OffsetM NodeId := do
}
return nodeId

def isUnsat : OffsetM Bool :=
return (← get).unsat.isSome
def isUnsat : GoalM Bool :=
return (← get').unsat.isSome

private def getDist? (u v : NodeId) : OffsetM (Option Int) := do
return (← get).targets[u]!.find? v
private def getDist? (u v : NodeId) : GoalM (Option Int) := do
return (← get').targets[u]!.find? v

private def getProof? (u v : NodeId) : OffsetM (Option ProofInfo) := do
return (← get).proofs[u]!.find? v
private def getProof? (u v : NodeId) : GoalM (Option ProofInfo) := do
return (← get').proofs[u]!.find? v

partial def extractProof (u v : NodeId) : OffsetM Expr := do
partial def extractProof (u v : NodeId) : GoalM Expr := do
go (← getProof? u v).get!
where
go (p : ProofInfo) : OffsetM Expr := do
go (p : ProofInfo) : GoalM Expr := do
if u == p.w then
return p.proof
else
let p' := (← getProof? u p.w).get!
go (mkTrans (← get).nodes p' p v)
go (mkTrans (← get').nodes p' p v)

private def setUnsat (_u _v : NodeId) (_k : Int) (p : Expr) : OffsetM Unit := do
modify fun s => { s with
private def setUnsat (u v : NodeId) (k : Int) (p : Expr) : GoalM Unit := do
trace[Meta.debug] "unsat #{u}-({k})->#{v}"
modify' fun s => { s with
unsat := p -- TODO
}

private def setDist (u v : NodeId) (k : Int) : OffsetM Unit := do
private def setDist (u v : NodeId) (k : Int) : GoalM Unit := do
trace[grind.offset.dist] "{({ a := u, b := v, k : Cnstr NodeId})}"
modify fun s => { s with
modify' fun s => { s with
targets := s.targets.modify u fun es => es.insert v k
sources := s.sources.modify v fun es => es.insert u k
}

private def setProof (u v : NodeId) (p : ProofInfo) : OffsetM Unit := do
modify fun s => { s with
private def setProof (u v : NodeId) (p : ProofInfo) : GoalM Unit := do
modify' fun s => { s with
proofs := s.proofs.modify u fun es => es.insert v p
}

@[inline]
private def forEachSourceOf (u : NodeId) (f : NodeId → Int → OffsetM Unit) : OffsetM Unit := do
(← get).sources[u]!.forM f
private def forEachSourceOf (u : NodeId) (f : NodeId → Int → GoalM Unit) : GoalM Unit := do
(← get').sources[u]!.forM f

@[inline]
private def forEachTargetOf (u : NodeId) (f : NodeId → Int → OffsetM Unit) : OffsetM Unit := do
(← get).targets[u]!.forM f
private def forEachTargetOf (u : NodeId) (f : NodeId → Int → GoalM Unit) : GoalM Unit := do
(← get').targets[u]!.forM f

private def isShorter (u v : NodeId) (k : Int) : OffsetM Bool := do
private def isShorter (u v : NodeId) (k : Int) : GoalM Bool := do
if let some k' ← getDist? u v then
return k < k'
else
return true

private def updateIfShorter (u v : NodeId) (k : Int) (w : NodeId) : OffsetM Unit := do
private def updateIfShorter (u v : NodeId) (k : Int) (w : NodeId) : GoalM Unit := do
if (← isShorter u v k) then
setDist u v k
setProof u v (← getProof? w v).get!

def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : OffsetM Unit := do
def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : GoalM Unit := do
if (← isUnsat) then return ()
if let some k' ← getDist? v u then
if k'+k < 0 then
Expand All @@ -118,7 +117,7 @@ def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : OffsetM Unit := do
setProof u v { w := u, k, proof := p }
update
where
update : OffsetM Unit := do
update : GoalM Unit := do
forEachTargetOf v fun j k₂ => do
/- Check whether new path: `u -(k)-> v -(k₂)-> j` is shorter -/
updateIfShorter u j (k+k₂) v
Expand All @@ -129,15 +128,15 @@ where
/- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/
updateIfShorter i j (k₁+k+k₂) v

def traceDists : OffsetM Unit := do
let s ← get
def traceDists : GoalM Unit := do
let s ← get'
for u in [:s.targets.size], es in s.targets.toArray do
for (v, k) in es do
trace[grind.offset.dist] "#{u} -({k})-> #{v}"

def Cnstr.toExpr (c : Cnstr NodeId) : OffsetM Expr := do
let a := (← get).nodes[c.a]!
let b := (← get).nodes[c.b]!
def Cnstr.toExpr (c : Cnstr NodeId) : GoalM Expr := do
let a := (← get').nodes[c.a]!
let b := (← get').nodes[c.b]!
let mk := if c.le then mkNatLE else mkNatEq
if c.k == 0 then
return mk a b
Expand All @@ -146,8 +145,8 @@ def Cnstr.toExpr (c : Cnstr NodeId) : OffsetM Expr := do
else
return mk a (mkNatAdd b (Lean.toExpr c.k.toNat))

def checkInvariants : GoalM Unit := OffsetM.run do
let s ← get
def checkInvariants : GoalM Unit := do
let s ← get'
for u in [:s.targets.size], es in s.targets.toArray do
for (v, k) in es do
let c : Cnstr NodeId := { a := u, b := v, k }
Expand Down

0 comments on commit b6ecd21

Please sign in to comment.