Skip to content

Commit

Permalink
feat: omega handles shift operators, and normalises ground term expon…
Browse files Browse the repository at this point in the history
…entials (#3433)

This is a preliminary to a BitVec frontend for `omega`.
  • Loading branch information
kim-em authored Feb 21, 2024
1 parent 89490f6 commit f76bb24
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 11 deletions.
8 changes: 8 additions & 0 deletions src/Init/Omega/Int.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ Authors: Scott Morrison
-/
prelude
import Init.Data.Int.Order
import Init.Data.Int.DivModLemmas
import Init.Data.Nat.Lemmas

/-!
# Lemmas about `Nat` and `Int` needed internally by `omega`.
Expand Down Expand Up @@ -43,6 +45,12 @@ theorem ofNat_lt_of_lt {x y : Nat} (h : x < y) : (x : Int) < (y : Int) :=
theorem ofNat_le_of_le {x y : Nat} (h : x ≤ y) : (x : Int) ≤ (y : Int) :=
Int.ofNat_le.mpr h

theorem ofNat_shiftLeft_eq {x y : Nat} : (x <<< y : Int) = (x : Int) * (2 ^ y : Nat) := by
simp [Nat.shiftLeft_eq]

theorem ofNat_shiftRight_eq_div_pow {x y : Nat} : (x >>> y : Int) = (x : Int) / (2 ^ y : Nat) := by
simp [Nat.shiftRight_eq_div_pow]

-- FIXME these are insane:
theorem lt_of_not_ge {x y : Int} (h : ¬ (x ≤ y)) : y < x := Int.not_le.mp h
theorem lt_of_not_le {x y : Int} (h : ¬ (x ≤ y)) : y < x := Int.not_le.mp h
Expand Down
42 changes: 35 additions & 7 deletions src/Lean/Elab/Tactic/Omega/Frontend.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ Allow elaboration of `OmegaConfig` arguments to tactics.
declare_config_elab elabOmegaConfig Lean.Meta.Omega.OmegaConfig




/--
The current `ToExpr` instance for `Int` is bad,
so we roll our own here.
-/
def mkInt (i : Int) : Expr :=
if 0 ≤ i then
mkNat i.toNat
else
mkApp3 (.const ``Neg.neg [0]) (.const ``Int []) (mkNat (-i).toNat)
(.const ``Int.instNegInt [])
where
mkNat (n : Nat) : Expr :=
let r := mkRawNatLit n
mkApp3 (.const ``OfNat.ofNat [0]) (.const ``Int []) r
(.app (.const ``instOfNat []) r)

/--
A partially processed `omega` context.
Expand Down Expand Up @@ -114,7 +132,7 @@ We also transform the expression as we descend into it:
-/
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
trace[omega] "processing {e}"
match e.int? with
match groundInt? e with
| some i =>
let lc := {const := i}
return ⟨lc, mkEvalRflProof e lc, ∅⟩
Expand Down Expand Up @@ -177,17 +195,20 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
| some r => pure r
| none => mkAtomLinearCombo e
| (``HMod.hMod, #[_, _, _, _, n, k]) =>
match natCast? k with
| some _ => rewrite e (mkApp2 (.const ``Int.emod_def []) n k)
match groundNat? k with
| some k' => do
let k' := mkInt k'
rewrite (← mkAppM ``HMod.hMod #[n, k']) (mkApp2 (.const ``Int.emod_def []) n k')
| none => mkAtomLinearCombo e
| (``HDiv.hDiv, #[_, _, _, _, x, z]) =>
match intCast? z with
match groundInt? z with
| some 0 => rewrite e (mkApp (.const ``Int.ediv_zero []) x)
| some i =>
| some i => do
let e' ← mkAppM ``HDiv.hDiv #[x, mkInt i]
if i < 0 then
rewrite e (mkApp2 (.const ``Int.ediv_neg []) x (toExpr (-i)))
rewrite e' (mkApp2 (.const ``Int.ediv_neg []) x (mkInt (-i)))
else
mkAtomLinearCombo e
mkAtomLinearCombo e'
| _ => mkAtomLinearCombo e
| (``Min.min, #[_, _, a, b]) =>
if (← cfg).splitMinMax then
Expand Down Expand Up @@ -216,6 +237,9 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
| (``HMod.hMod, #[_, _, _, _, a, b]) => rewrite e (mkApp2 (.const ``Int.ofNat_emod []) a b)
| (``HSub.hSub, #[_, _, _, _, mkApp6 (.const ``HSub.hSub _) _ _ _ _ a b, c]) =>
rewrite e (mkApp3 (.const ``Int.ofNat_sub_sub []) a b c)
| (``HPow.hPow, #[_, _, _, _, a, b]) => match groundNat? a, groundNat? b with
| some _, some _ => rewrite e (mkApp2 (.const ``Int.ofNat_pow []) a b)
| _, _ => mkAtomLinearCombo e
| (``Prod.fst, #[_, β, p]) => match p with
| .app (.app (.app (.app (.const ``Prod.mk [0, v]) _) _) x) y =>
rewrite e (mkApp3 (.const ``Int.ofNat_fst_mk [v]) β x y)
Expand All @@ -226,6 +250,10 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
| _ => mkAtomLinearCombo e
| (``Min.min, #[_, _, a, b]) => rewrite e (mkApp2 (.const ``Int.ofNat_min []) a b)
| (``Max.max, #[_, _, a, b]) => rewrite e (mkApp2 (.const ``Int.ofNat_max []) a b)
| (``HShiftLeft.hShiftLeft, #[_, _, _, _, a, b]) =>
rewrite e (mkApp2 (.const ``Int.ofNat_shiftLeft_eq []) a b)
| (``HShiftRight.hShiftRight, #[_, _, _, _, a, b]) =>
rewrite e (mkApp2 (.const ``Int.ofNat_shiftRight_eq_div_pow []) a b)
| (``Int.natAbs, #[n]) =>
if (← cfg).splitNatAbs then
rewrite e (mkApp (.const ``Int.ofNat_natAbs []) n)
Expand Down
39 changes: 39 additions & 0 deletions src/Lean/Elab/Tactic/Omega/OmegaM.lean
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,45 @@ def intCast? (n : Expr) : Option Int :=
| (``Nat.cast, #[_, _, n]) => n.nat?
| _ => n.int?

/--
If `groundNat? e = some n`, then `e` is definitionally equal to `OfNat.ofNat n`.
-/
-- We may want to replace this with an implementation using
-- the internals of `simp (config := {ground := true})`
partial def groundNat? (e : Expr) : Option Nat :=
match e.getAppFnArgs with
| (``Nat.cast, #[_, _, n]) => groundNat? n
| (``HAdd.hAdd, #[_, _, _, _, x, y]) => op (· + ·) x y
| (``HMul.hMul, #[_, _, _, _, x, y]) => op (· * ·) x y
| (``HSub.hSub, #[_, _, _, _, x, y]) => op (· - ·) x y
| (``HDiv.hDiv, #[_, _, _, _, x, y]) => op (· / ·) x y
| (``HPow.hPow, #[_, _, _, _, x, y]) => op (· ^ ·) x y
| _ => e.nat?
where op (f : Nat → Nat → Nat) (x y : Expr) : Option Nat :=
match groundNat? x, groundNat? y with
| some x', some y' => some (f x' y')
| _, _ => none

/--
If `groundInt? e = some i`,
then `e` is definitionally equal to the standard expression for `i`.
-/
partial def groundInt? (e : Expr) : Option Int :=
match e.getAppFnArgs with
| (``Nat.cast, #[_, _, n]) => groundNat? n
| (``HAdd.hAdd, #[_, _, _, _, x, y]) => op (· + ·) x y
| (``HMul.hMul, #[_, _, _, _, x, y]) => op (· * ·) x y
| (``HSub.hSub, #[_, _, _, _, x, y]) => op (· - ·) x y
| (``HDiv.hDiv, #[_, _, _, _, x, y]) => op (· / ·) x y
| (``HPow.hPow, #[_, _, _, _, x, y]) => match groundInt? x, groundNat? y with
| some x', some y' => some (x' ^ y')
| _, _ => none
| _ => e.int?
where op (f : Int → Int → Int) (x y : Expr) : Option Int :=
match groundNat? x, groundNat? y with
| some x', some y' => some (f x' y')
| _, _ => none

/-- Construct the term with type hint `(Eq.refl a : a = b)`-/
def mkEqReflWithExpectedType (a b : Expr) : MetaM Expr := do
mkExpectedTypeHint (← mkEqRefl a) (← mkEq a b)
Expand Down
14 changes: 10 additions & 4 deletions tests/lean/run/omega.lean
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,11 @@ example (i : Fin 7) : (i : Nat) < 8 := by omega

example (x y z i : Nat) (hz : z ≤ 1) : x % 2 ^ i + y % 2 ^ i + z < 2 * 2^ i := by omega

/-! ### Ground terms -/

example : 2^7 < 165 := by omega
example (_ : x % 2^7 < 3) : x % 128 < 5 := by omega

/-! ### BitVec -/
-- Currently these tests require calling `simp` with many lemmas,
-- and sometimes adding `toNat_lt` as a hypothesis.
Expand All @@ -392,15 +397,16 @@ example (x y : BitVec 8) (hx : x < 16) (hy : y < 16) : x + y < 31 := by
simp [BitVec.lt_def] at *
omega

example (x y z : BitVec 8) (hx : x >>> 1 < 16) (hy : y < 16) (hz : z = x + 2 * y) : z ≤ 64 := by
simp [BitVec.lt_def, BitVec.le_def, BitVec.toNat_eq, Nat.shiftRight_eq_div_pow, BitVec.toNat_mul] at *
example (x y z : BitVec 8)
(hx : x >>> 1 < 16) (hy : y < 16) (hz : z = x + 2 * y) : z ≤ 64 := by
simp [BitVec.lt_def, BitVec.le_def, BitVec.toNat_eq, BitVec.toNat_mul] at *
omega

example (x : BitVec 8) (hx : (x + 1) <<< 1 = 3) : False := by
simp [BitVec.toNat_eq, Nat.shiftLeft_eq] at *
simp [BitVec.toNat_eq] at *
omega

example (x : BitVec 8) (hx : (x + 1) <<< 1 = 4) : x = 1 ∨ x = 129 := by
have := toNat_lt x
simp [BitVec.toNat_eq, Nat.shiftLeft_eq, BitVec.lt_def] at *
simp [BitVec.toNat_eq, BitVec.lt_def] at *
omega

0 comments on commit f76bb24

Please sign in to comment.