diff --git a/src/Init/Omega/Int.lean b/src/Init/Omega/Int.lean index e0d39c6df1fa..375b9c321ae1 100644 --- a/src/Init/Omega/Int.lean +++ b/src/Init/Omega/Int.lean @@ -5,6 +5,8 @@ Authors: Scott Morrison -/ prelude import Init.Data.Int.Order +import Init.Data.Int.DivModLemmas +import Init.Data.Nat.Lemmas /-! # Lemmas about `Nat` and `Int` needed internally by `omega`. @@ -43,6 +45,12 @@ theorem ofNat_lt_of_lt {x y : Nat} (h : x < y) : (x : Int) < (y : Int) := theorem ofNat_le_of_le {x y : Nat} (h : x ≤ y) : (x : Int) ≤ (y : Int) := Int.ofNat_le.mpr h +theorem ofNat_shiftLeft_eq {x y : Nat} : (x <<< y : Int) = (x : Int) * (2 ^ y : Nat) := by + simp [Nat.shiftLeft_eq] + +theorem ofNat_shiftRight_eq_div_pow {x y : Nat} : (x >>> y : Int) = (x : Int) / (2 ^ y : Nat) := by + simp [Nat.shiftRight_eq_div_pow] + -- FIXME these are insane: theorem lt_of_not_ge {x y : Int} (h : ¬ (x ≤ y)) : y < x := Int.not_le.mp h theorem lt_of_not_le {x y : Int} (h : ¬ (x ≤ y)) : y < x := Int.not_le.mp h diff --git a/src/Lean/Elab/Tactic/Omega/Frontend.lean b/src/Lean/Elab/Tactic/Omega/Frontend.lean index c156cb50f8de..054e13f4ddec 100644 --- a/src/Lean/Elab/Tactic/Omega/Frontend.lean +++ b/src/Lean/Elab/Tactic/Omega/Frontend.lean @@ -24,6 +24,24 @@ Allow elaboration of `OmegaConfig` arguments to tactics. declare_config_elab elabOmegaConfig Lean.Meta.Omega.OmegaConfig + + +/-- +The current `ToExpr` instance for `Int` is bad, +so we roll our own here. +-/ +def mkInt (i : Int) : Expr := + if 0 ≤ i then + mkNat i.toNat + else + mkApp3 (.const ``Neg.neg [0]) (.const ``Int []) (mkNat (-i).toNat) + (.const ``Int.instNegInt []) +where + mkNat (n : Nat) : Expr := + let r := mkRawNatLit n + mkApp3 (.const ``OfNat.ofNat [0]) (.const ``Int []) r + (.app (.const ``instOfNat []) r) + /-- A partially processed `omega` context. @@ -114,7 +132,7 @@ We also transform the expression as we descend into it: -/ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do trace[omega] "processing {e}" - match e.int? with + match groundInt? e with | some i => let lc := {const := i} return ⟨lc, mkEvalRflProof e lc, ∅⟩ @@ -177,17 +195,20 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × | some r => pure r | none => mkAtomLinearCombo e | (``HMod.hMod, #[_, _, _, _, n, k]) => - match natCast? k with - | some _ => rewrite e (mkApp2 (.const ``Int.emod_def []) n k) + match groundNat? k with + | some k' => do + let k' := mkInt k' + rewrite (← mkAppM ``HMod.hMod #[n, k']) (mkApp2 (.const ``Int.emod_def []) n k') | none => mkAtomLinearCombo e | (``HDiv.hDiv, #[_, _, _, _, x, z]) => - match intCast? z with + match groundInt? z with | some 0 => rewrite e (mkApp (.const ``Int.ediv_zero []) x) - | some i => + | some i => do + let e' ← mkAppM ``HDiv.hDiv #[x, mkInt i] if i < 0 then - rewrite e (mkApp2 (.const ``Int.ediv_neg []) x (toExpr (-i))) + rewrite e' (mkApp2 (.const ``Int.ediv_neg []) x (mkInt (-i))) else - mkAtomLinearCombo e + mkAtomLinearCombo e' | _ => mkAtomLinearCombo e | (``Min.min, #[_, _, a, b]) => if (← cfg).splitMinMax then @@ -216,6 +237,9 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × | (``HMod.hMod, #[_, _, _, _, a, b]) => rewrite e (mkApp2 (.const ``Int.ofNat_emod []) a b) | (``HSub.hSub, #[_, _, _, _, mkApp6 (.const ``HSub.hSub _) _ _ _ _ a b, c]) => rewrite e (mkApp3 (.const ``Int.ofNat_sub_sub []) a b c) + | (``HPow.hPow, #[_, _, _, _, a, b]) => match groundNat? a, groundNat? b with + | some _, some _ => rewrite e (mkApp2 (.const ``Int.ofNat_pow []) a b) + | _, _ => mkAtomLinearCombo e | (``Prod.fst, #[_, β, p]) => match p with | .app (.app (.app (.app (.const ``Prod.mk [0, v]) _) _) x) y => rewrite e (mkApp3 (.const ``Int.ofNat_fst_mk [v]) β x y) @@ -226,6 +250,10 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × | _ => mkAtomLinearCombo e | (``Min.min, #[_, _, a, b]) => rewrite e (mkApp2 (.const ``Int.ofNat_min []) a b) | (``Max.max, #[_, _, a, b]) => rewrite e (mkApp2 (.const ``Int.ofNat_max []) a b) + | (``HShiftLeft.hShiftLeft, #[_, _, _, _, a, b]) => + rewrite e (mkApp2 (.const ``Int.ofNat_shiftLeft_eq []) a b) + | (``HShiftRight.hShiftRight, #[_, _, _, _, a, b]) => + rewrite e (mkApp2 (.const ``Int.ofNat_shiftRight_eq_div_pow []) a b) | (``Int.natAbs, #[n]) => if (← cfg).splitNatAbs then rewrite e (mkApp (.const ``Int.ofNat_natAbs []) n) diff --git a/src/Lean/Elab/Tactic/Omega/OmegaM.lean b/src/Lean/Elab/Tactic/Omega/OmegaM.lean index dbc296b23458..f9debc29fba6 100644 --- a/src/Lean/Elab/Tactic/Omega/OmegaM.lean +++ b/src/Lean/Elab/Tactic/Omega/OmegaM.lean @@ -108,6 +108,45 @@ def intCast? (n : Expr) : Option Int := | (``Nat.cast, #[_, _, n]) => n.nat? | _ => n.int? +/-- +If `groundNat? e = some n`, then `e` is definitionally equal to `OfNat.ofNat n`. +-/ +-- We may want to replace this with an implementation using +-- the internals of `simp (config := {ground := true})` +partial def groundNat? (e : Expr) : Option Nat := + match e.getAppFnArgs with + | (``Nat.cast, #[_, _, n]) => groundNat? n + | (``HAdd.hAdd, #[_, _, _, _, x, y]) => op (· + ·) x y + | (``HMul.hMul, #[_, _, _, _, x, y]) => op (· * ·) x y + | (``HSub.hSub, #[_, _, _, _, x, y]) => op (· - ·) x y + | (``HDiv.hDiv, #[_, _, _, _, x, y]) => op (· / ·) x y + | (``HPow.hPow, #[_, _, _, _, x, y]) => op (· ^ ·) x y + | _ => e.nat? +where op (f : Nat → Nat → Nat) (x y : Expr) : Option Nat := + match groundNat? x, groundNat? y with + | some x', some y' => some (f x' y') + | _, _ => none + +/-- +If `groundInt? e = some i`, +then `e` is definitionally equal to the standard expression for `i`. +-/ +partial def groundInt? (e : Expr) : Option Int := + match e.getAppFnArgs with + | (``Nat.cast, #[_, _, n]) => groundNat? n + | (``HAdd.hAdd, #[_, _, _, _, x, y]) => op (· + ·) x y + | (``HMul.hMul, #[_, _, _, _, x, y]) => op (· * ·) x y + | (``HSub.hSub, #[_, _, _, _, x, y]) => op (· - ·) x y + | (``HDiv.hDiv, #[_, _, _, _, x, y]) => op (· / ·) x y + | (``HPow.hPow, #[_, _, _, _, x, y]) => match groundInt? x, groundNat? y with + | some x', some y' => some (x' ^ y') + | _, _ => none + | _ => e.int? +where op (f : Int → Int → Int) (x y : Expr) : Option Int := + match groundNat? x, groundNat? y with + | some x', some y' => some (f x' y') + | _, _ => none + /-- Construct the term with type hint `(Eq.refl a : a = b)`-/ def mkEqReflWithExpectedType (a b : Expr) : MetaM Expr := do mkExpectedTypeHint (← mkEqRefl a) (← mkEq a b) diff --git a/tests/lean/run/omega.lean b/tests/lean/run/omega.lean index 508b6756352d..5bd6dc22d26e 100644 --- a/tests/lean/run/omega.lean +++ b/tests/lean/run/omega.lean @@ -381,6 +381,11 @@ example (i : Fin 7) : (i : Nat) < 8 := by omega example (x y z i : Nat) (hz : z ≤ 1) : x % 2 ^ i + y % 2 ^ i + z < 2 * 2^ i := by omega +/-! ### Ground terms -/ + +example : 2^7 < 165 := by omega +example (_ : x % 2^7 < 3) : x % 128 < 5 := by omega + /-! ### BitVec -/ -- Currently these tests require calling `simp` with many lemmas, -- and sometimes adding `toNat_lt` as a hypothesis. @@ -392,15 +397,16 @@ example (x y : BitVec 8) (hx : x < 16) (hy : y < 16) : x + y < 31 := by simp [BitVec.lt_def] at * omega -example (x y z : BitVec 8) (hx : x >>> 1 < 16) (hy : y < 16) (hz : z = x + 2 * y) : z ≤ 64 := by - simp [BitVec.lt_def, BitVec.le_def, BitVec.toNat_eq, Nat.shiftRight_eq_div_pow, BitVec.toNat_mul] at * +example (x y z : BitVec 8) + (hx : x >>> 1 < 16) (hy : y < 16) (hz : z = x + 2 * y) : z ≤ 64 := by + simp [BitVec.lt_def, BitVec.le_def, BitVec.toNat_eq, BitVec.toNat_mul] at * omega example (x : BitVec 8) (hx : (x + 1) <<< 1 = 3) : False := by - simp [BitVec.toNat_eq, Nat.shiftLeft_eq] at * + simp [BitVec.toNat_eq] at * omega example (x : BitVec 8) (hx : (x + 1) <<< 1 = 4) : x = 1 ∨ x = 129 := by have := toNat_lt x - simp [BitVec.toNat_eq, Nat.shiftLeft_eq, BitVec.lt_def] at * + simp [BitVec.toNat_eq, BitVec.lt_def] at * omega