From 24c504426209be74ed886ccf9b563b7b5c189c3d Mon Sep 17 00:00:00 2001 From: Scott Morrison Date: Fri, 2 Feb 2024 10:27:58 +1100 Subject: [PATCH] fix: revert OmegaM state when not multiplying out (#570) --- Std/Tactic/Omega/Frontend.lean | 39 ++++++++++++++++++++-------------- Std/Tactic/Omega/OmegaM.lean | 19 +++++++++++------ test/omega/test.lean | 4 ++++ 3 files changed, 39 insertions(+), 23 deletions(-) diff --git a/Std/Tactic/Omega/Frontend.lean b/Std/Tactic/Omega/Frontend.lean index 16b64475f4..19a94faac0 100644 --- a/Std/Tactic/Omega/Frontend.lean +++ b/Std/Tactic/Omega/Frontend.lean @@ -156,21 +156,28 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × (← mkEqSymm neg_eval) pure (-l, prf', facts) | (``HMul.hMul, #[_, _, _, _, x, y]) => - let (xl, xprf, xfacts) ← asLinearCombo x - let (yl, yprf, yfacts) ← asLinearCombo y - if xl.coeffs.isZero ∨ yl.coeffs.isZero then - let prf : OmegaM Expr := do - let h ← mkDecideProof (mkApp2 (.const ``Or []) - (.app (.const ``Coeffs.isZero []) (toExpr xl.coeffs)) - (.app (.const ``Coeffs.isZero []) (toExpr yl.coeffs))) - let mul_eval := - mkApp4 (.const ``LinearCombo.mul_eval []) (toExpr xl) (toExpr yl) (← atomsCoeffs) h - mkEqTrans - (← mkAppM ``Int.mul_congr #[← xprf, ← yprf]) - (← mkEqSymm mul_eval) - pure (LinearCombo.mul xl yl, prf, xfacts.merge yfacts) - else - mkAtomLinearCombo e + -- If we decide not to expand out the multiplication, + -- we have to revert the `OmegaM` state so that any new facts about the factors + -- can still be reported when they are visited elsewhere. + let r? ← commitWhen do + let (xl, xprf, xfacts) ← asLinearCombo x + let (yl, yprf, yfacts) ← asLinearCombo y + if xl.coeffs.isZero ∨ yl.coeffs.isZero then + let prf : OmegaM Expr := do + let h ← mkDecideProof (mkApp2 (.const ``Or []) + (.app (.const ``Coeffs.isZero []) (toExpr xl.coeffs)) + (.app (.const ``Coeffs.isZero []) (toExpr yl.coeffs))) + let mul_eval := + mkApp4 (.const ``LinearCombo.mul_eval []) (toExpr xl) (toExpr yl) (← atomsCoeffs) h + mkEqTrans + (← mkAppM ``Int.mul_congr #[← xprf, ← yprf]) + (← mkEqSymm mul_eval) + pure (some (LinearCombo.mul xl yl, prf, xfacts.merge yfacts), true) + else + pure (none, false) + match r? with + | some r => pure r + | none => mkAtomLinearCombo e | (``HMod.hMod, #[_, _, _, _, n, k]) => rewrite e (mkApp2 (.const ``Int.emod_def []) n k) | (``HDiv.hDiv, #[_, _, _, _, x, z]) => match intCast? z with @@ -434,7 +441,7 @@ partial def splitDisjunction (m : MetaProblem) (g : MVarId) : OmegaM Unit := g.w let (⟨g₁, h₁⟩, ⟨g₂, h₂⟩) ← cases₂ g h trace[omega] "Adding facts:\n{← g₁.withContext <| inferType (.fvar h₁)}" let m₁ := { m with facts := [.fvar h₁], disjunctions := t } - let r ← savingState do + let r ← withoutModifyingState do let (m₁, n) ← g₁.withContext m₁.processFacts if 0 < n then omegaImpl m₁ g₁ diff --git a/Std/Tactic/Omega/OmegaM.lean b/Std/Tactic/Omega/OmegaM.lean index 59c6edc72b..093e4c851c 100644 --- a/Std/Tactic/Omega/OmegaM.lean +++ b/Std/Tactic/Omega/OmegaM.lean @@ -81,16 +81,21 @@ def atomsList : OmegaM Expr := do mkListLit (.const ``Int []) (← atoms) def atomsCoeffs : OmegaM Expr := do return .app (.const ``Coeffs.ofList []) (← atomsList) +/-- Run an `OmegaM` computation, restoring the state afterwards depending on the result. -/ +def commitWhen (t : OmegaM (α × Bool)) : OmegaM α := do + let state ← getThe State + let cache ← getThe Cache + let (a, r) ← t + if !r then do + modifyThe State fun _ => state + modifyThe Cache fun _ => cache + pure a + /-- Run an `OmegaM` computation, restoring the state afterwards. -/ -def savingState (t : OmegaM α) : OmegaM α := do - let state ← getThe State - let cache ← getThe Cache - let r ← t - modifyThe State fun _ => state - modifyThe Cache fun _ => cache - pure r +def withoutModifyingState (t : OmegaM α) : OmegaM α := + commitWhen (do pure (← t, false)) /-- Wrapper around `Expr.nat?` that also allows `Nat.cast`. -/ def natCast? (n : Expr) : Option Nat := diff --git a/test/omega/test.lean b/test/omega/test.lean index 45c49753e6..354a62aae4 100644 --- a/test/omega/test.lean +++ b/test/omega/test.lean @@ -343,3 +343,7 @@ example (a : Nat) (h : a < 0) : Nat → Nat := by omega example {a₁ a₂ p₁ p₂ : Nat} (h₁ : a₁ = a₂ → ¬p₁ = p₂) : (a₁ < a₂ ∨ a₁ = a₂ ∧ p₁ < p₂) ∨ a₂ < a₁ ∨ a₂ = a₁ ∧ p₂ < p₁ := by omega + +-- From https://github.com/leanprover/std4/issues/562 +example {i : Nat} (h1 : i < 330) (_h2 : 7 ∣ (660 + i) * (1319 - i)) : 1319 - i < 1979 := by + omega