Skip to content

Commit

Permalink
fix: revert OmegaM state when not multiplying out (leanprover-communi…
Browse files Browse the repository at this point in the history
  • Loading branch information
kim-em authored and fgdorais committed Feb 18, 2024
1 parent b485faf commit 24c5044
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 23 deletions.
39 changes: 23 additions & 16 deletions Std/Tactic/Omega/Frontend.lean
Original file line number Diff line number Diff line change
Expand Up @@ -156,21 +156,28 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
(← mkEqSymm neg_eval)
pure (-l, prf', facts)
| (``HMul.hMul, #[_, _, _, _, x, y]) =>
let (xl, xprf, xfacts) ← asLinearCombo x
let (yl, yprf, yfacts) ← asLinearCombo y
if xl.coeffs.isZero ∨ yl.coeffs.isZero then
let prf : OmegaM Expr := do
let h ← mkDecideProof (mkApp2 (.const ``Or [])
(.app (.const ``Coeffs.isZero []) (toExpr xl.coeffs))
(.app (.const ``Coeffs.isZero []) (toExpr yl.coeffs)))
let mul_eval :=
mkApp4 (.const ``LinearCombo.mul_eval []) (toExpr xl) (toExpr yl) (← atomsCoeffs) h
mkEqTrans
(← mkAppM ``Int.mul_congr #[← xprf, ← yprf])
(← mkEqSymm mul_eval)
pure (LinearCombo.mul xl yl, prf, xfacts.merge yfacts)
else
mkAtomLinearCombo e
-- If we decide not to expand out the multiplication,
-- we have to revert the `OmegaM` state so that any new facts about the factors
-- can still be reported when they are visited elsewhere.
let r? ← commitWhen do
let (xl, xprf, xfacts) ← asLinearCombo x
let (yl, yprf, yfacts) ← asLinearCombo y
if xl.coeffs.isZero ∨ yl.coeffs.isZero then
let prf : OmegaM Expr := do
let h ← mkDecideProof (mkApp2 (.const ``Or [])
(.app (.const ``Coeffs.isZero []) (toExpr xl.coeffs))
(.app (.const ``Coeffs.isZero []) (toExpr yl.coeffs)))
let mul_eval :=
mkApp4 (.const ``LinearCombo.mul_eval []) (toExpr xl) (toExpr yl) (← atomsCoeffs) h
mkEqTrans
(← mkAppM ``Int.mul_congr #[← xprf, ← yprf])
(← mkEqSymm mul_eval)
pure (some (LinearCombo.mul xl yl, prf, xfacts.merge yfacts), true)
else
pure (none, false)
match r? with
| some r => pure r
| none => mkAtomLinearCombo e
| (``HMod.hMod, #[_, _, _, _, n, k]) => rewrite e (mkApp2 (.const ``Int.emod_def []) n k)
| (``HDiv.hDiv, #[_, _, _, _, x, z]) =>
match intCast? z with
Expand Down Expand Up @@ -434,7 +441,7 @@ partial def splitDisjunction (m : MetaProblem) (g : MVarId) : OmegaM Unit := g.w
let (⟨g₁, h₁⟩, ⟨g₂, h₂⟩) ← cases₂ g h
trace[omega] "Adding facts:\n{← g₁.withContext <| inferType (.fvar h₁)}"
let m₁ := { m with facts := [.fvar h₁], disjunctions := t }
let r ← savingState do
let r ← withoutModifyingState do
let (m₁, n) ← g₁.withContext m₁.processFacts
if 0 < n then
omegaImpl m₁ g₁
Expand Down
19 changes: 12 additions & 7 deletions Std/Tactic/Omega/OmegaM.lean
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,21 @@ def atomsList : OmegaM Expr := do mkListLit (.const ``Int []) (← atoms)
def atomsCoeffs : OmegaM Expr := do
return .app (.const ``Coeffs.ofList []) (← atomsList)

/-- Run an `OmegaM` computation, restoring the state afterwards depending on the result. -/
def commitWhen (t : OmegaM (α × Bool)) : OmegaM α := do
let state ← getThe State
let cache ← getThe Cache
let (a, r) ← t
if !r then do
modifyThe State fun _ => state
modifyThe Cache fun _ => cache
pure a

/--
Run an `OmegaM` computation, restoring the state afterwards.
-/
def savingState (t : OmegaM α) : OmegaM α := do
let state ← getThe State
let cache ← getThe Cache
let r ← t
modifyThe State fun _ => state
modifyThe Cache fun _ => cache
pure r
def withoutModifyingState (t : OmegaM α) : OmegaM α :=
commitWhen (do pure (← t, false))

/-- Wrapper around `Expr.nat?` that also allows `Nat.cast`. -/
def natCast? (n : Expr) : Option Nat :=
Expand Down
4 changes: 4 additions & 0 deletions test/omega/test.lean
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,7 @@ example (a : Nat) (h : a < 0) : Nat → Nat := by omega
example {a₁ a₂ p₁ p₂ : Nat}
(h₁ : a₁ = a₂ → ¬p₁ = p₂) :
(a₁ < a₂ ∨ a₁ = a₂ ∧ p₁ < p₂) ∨ a₂ < a₁ ∨ a₂ = a₁ ∧ p₂ < p₁ := by omega

-- From https://github.com/leanprover/std4/issues/562
example {i : Nat} (h1 : i < 330) (_h2 : 7 ∣ (660 + i) * (1319 - i)) : 1319 - i < 1979 := by
omega

0 comments on commit 24c5044

Please sign in to comment.