Skip to content

Commit

Permalink
Merge pull request #24 from atomb/smt-triggers
Browse files Browse the repository at this point in the history
Add support for SMT triggers
  • Loading branch information
PratherConid authored Apr 16, 2024
2 parents 27f1b26 + 829a324 commit 5cafe34
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 0 deletions.
103 changes: 103 additions & 0 deletions Auto/Embedding/LamBase.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,90 @@ def BitVecConst.LamWF.ofCheck (H : b.lamCheck = s) : LamWF b s := by
case bvcmp n prop? op => cases prop? <;> constructor
case bvshOp n smt? op => cases smt? <;> constructor

inductive OtherConst
| attribute : String -> LamSort -> OtherConst
deriving Inhabited, Hashable, Lean.ToExpr

def OtherConst.reprAux : OtherConst → String
| .attribute n s => s!"attribute {n} {s}"

def OtherConst.reprPrec (p : OtherConst) (n : Nat) :=
match n with
| 0 => f!"Auto.Embedding.Lam.OtherConst.{p.reprAux}"
| _ + 1 => f!"(.{p.reprAux})"

inductive OtherConst.LamWF : OtherConst → LamSort → Type
| ofAttribute n s : LamWF (.attribute n s) (.func s (.func (.base .prop) (.base .prop)))

def OtherConst.LamWF.ofOtherConst : (oc : OtherConst) → (s : LamSort) × OtherConst.LamWF oc s
| .attribute n s => ⟨.func s (.func (.base .prop) (.base .prop)), .ofAttribute n s⟩

def OtherConst.LamWF.interp (tyVal : Nat → Type u) : (lwf : LamWF p s) → s.interp tyVal
| .ofAttribute _ _ => fun _ => fun term => term

def OtherConst.toString : OtherConst → String
| .attribute n s => s!"attr[{n} : {s}]"

instance : ToString OtherConst where
toString := OtherConst.toString

def OtherConst.beq : OtherConst → OtherConst → Bool
| .attribute n1 s1, .attribute n2 s2 =>
n1 == n2 && s1.beq s2

instance : BEq OtherConst where
beq := OtherConst.beq

theorem OtherConst.beq_def {x y : OtherConst} : (x == y) = x.beq y := rfl

def OtherConst.beq_refl {o : OtherConst} : (o.beq o) = true := by
cases o <;> unfold OtherConst.beq <;> simp [LamSort.beq_refl]

def OtherConst.eq_of_beq_eq_true {o₁ o₂ : OtherConst} (H : o₁.beq o₂) : o₁ = o₂ :=
match o₁, o₂ with
| .attribute n1 s1, .attribute n2 s2 => by
simp [OtherConst.beq] at H
match H with
| And.intro neq sbeq =>
have seq := LamSort.eq_of_beq_eq_true sbeq
rw [neq, seq]

instance : LawfulBEq OtherConst where
eq_of_beq := OtherConst.eq_of_beq_eq_true
rfl := OtherConst.beq_refl

def OtherConst.lamCheck : OtherConst → LamSort
| .attribute _ as1 => .func as1 (.func (.base .prop) (.base .prop))

-- Note: add other attributes as needed
def trigger {a : Type} (_ : a) (term : Prop) := term

def OtherConst.interp (tyVal : Nat → Type u) : (o : OtherConst) → o.lamCheck.interp tyVal
| .attribute _ _ => fun _ => fun term => term

def OtherConst.interp_equiv (tyVal : Nat → Type u) (ocwf : LamWF p s) :
HEq (LamWF.interp tyVal ocwf) (interp tyVal p) := by
cases ocwf <;> rfl

def OtherConst.LamWF.unique {o : OtherConst} {s₁ s₂ : LamSort}
(ocwf₁ : LamWF o s₁) (ocwf₂ : LamWF o s₂) : s₁ = s₂ ∧ HEq ocwf₁ ocwf₂ := by
cases ocwf₁ <;> cases ocwf₂ <;> trivial

theorem OtherConst.LamWF.interp_lvalIrrelevance
(tyVal₁ tyVal₂ : Nat → Type u) (ocwf₁ : LamWF b₁ s₁) (ocwf₂ : LamWF b₂ s₂)
(HBeq : b₁ = b₂) (hTyVal : tyVal₁ = tyVal₂) :
HEq (ocwf₁.interp tyVal₁) (ocwf₂.interp tyVal₂) := by
cases HBeq; cases hTyVal; rcases OtherConst.LamWF.unique ocwf₁ ocwf₂ with ⟨⟨⟩, ⟨⟩⟩; rfl

def OtherConst.lamWF_complete (wf : LamWF sc s) : LamWF.ofOtherConst sc = ⟨s, wf⟩ := by
cases wf <;> rfl

def OtherConst.lamCheck_of_LamWF (H : LamWF sc s) : sc.lamCheck = s := by
cases H <;> rfl

def OtherConst.LamWF.ofCheck (H : sc.lamCheck = s) : LamWF sc s := by
cases H; cases sc <;> constructor

/--
Interpreted constants
Note that `eq`, `forallE`, `existE` have `ilVal/lamILTy`
Expand All @@ -1442,6 +1526,7 @@ inductive LamBaseTerm
| icst : IntConst → LamBaseTerm
| scst : StringConst → LamBaseTerm
| bvcst : BitVecConst → LamBaseTerm
| ocst : OtherConst → LamBaseTerm
-- Versions of `eq, ∀, ∃, ite'` when we're importing external facts
-- Note that the [import versions] of `eq, ∀, ∃, ite'` should only be used when
-- we're importing external facts. When facts are imported, we call
Expand Down Expand Up @@ -1604,6 +1689,7 @@ def LamBaseTerm.reprPrec (l : LamBaseTerm) (n : Nat) :=
| .icst ic => f!"icst {IntConst.reprPrec ic 1}"
| .scst sc => f!"scst {StringConst.reprPrec sc 1}"
| .bvcst bvc => f!"bvcst {BitVecConst.reprPrec bvc 1}"
| .ocst oc => f!"ocst {OtherConst.reprPrec oc 1}"
| .eqI n => f!"eqI {n}"
| .forallEI n => f!"forallEI {n}"
| .existEI n => f!"existEI {n}"
Expand All @@ -1627,6 +1713,7 @@ def LamBaseTerm.toString : LamBaseTerm → String
| .icst ic => s!"{ic}"
| .scst sc => s!"{sc}"
| .bvcst bvc => s!"{bvc}"
| .ocst oc => s!"{oc}"
| .eqI _ => "="
| .forallEI _ => "∀"
| .existEI _ => "∃"
Expand All @@ -1646,6 +1733,7 @@ def LamBaseTerm.beq : LamBaseTerm → LamBaseTerm → Bool
| .icst ic₁, .icst ic₂ => IntConst.beq ic₁ ic₂
| .scst sc₁, .scst sc₂ => StringConst.beq sc₁ sc₂
| .bvcst l₁, .bvcst l₂ => BitVecConst.beq l₁ l₂
| .ocst o₁, .ocst o₂ => OtherConst.beq o₁ o₂
| .eqI n₁, .eqI n₂ => n₁.beq n₂
| .forallEI n₁, .forallEI n₂ => n₁.beq n₂
| .existEI n₁, .existEI n₂ => n₁.beq n₂
Expand All @@ -1667,6 +1755,7 @@ def LamBaseTerm.beq_refl {b : LamBaseTerm} : (b.beq b) = true := by
case icst i => apply LawfulBEq.rfl (α := IntConst)
case scst s => apply LawfulBEq.rfl (α := StringConst)
case bvcst s => apply LawfulBEq.rfl (α := BitVecConst)
case ocst o => apply LawfulBEq.rfl (α := OtherConst)

def LamBaseTerm.eq_of_beq_eq_true {b₁ b₂ : LamBaseTerm} (H : b₁.beq b₂) : b₁ = b₂ := by
cases b₁ <;> cases b₂ <;> (first | contradiction | rfl | apply congrArg) <;>
Expand All @@ -1677,6 +1766,7 @@ def LamBaseTerm.eq_of_beq_eq_true {b₁ b₂ : LamBaseTerm} (H : b₁.beq b₂)
case icst.icst.h n₁ n₂ => apply LawfulBEq.eq_of_beq (α := IntConst) H
case scst.scst.h s₁ s₂ => apply LawfulBEq.eq_of_beq (α := StringConst) H
case bvcst.bvcst.h v₁ v₂ => apply LawfulBEq.eq_of_beq (α := BitVecConst) H
case ocst.ocst.h o₁ o₂ => apply LawfulBEq.eq_of_beq (α := OtherConst) H

instance : LawfulBEq LamBaseTerm where
eq_of_beq := LamBaseTerm.eq_of_beq_eq_true
Expand All @@ -1690,6 +1780,7 @@ def LamBaseTerm.containsSort (b : LamBaseTerm) (s : LamSort) : Bool :=
| .icst _ => false
| .scst _ => false
| .bvcst _ => false
| .ocst _ => false
| .eqI _ => false
| .forallEI _ => false
| .existEI _ => false
Expand All @@ -1714,6 +1805,7 @@ def LamBaseTerm.lamCheck (ltv : LamTyVal) : LamBaseTerm → LamSort
| .icst ic => ic.lamCheck
| .scst sc => sc.lamCheck
| .bvcst bvc => bvc.lamCheck
| .ocst oc => oc.lamCheck
| .eqI n =>
let s := ltv.lamILTy n
.func s (.func s (.base .prop))
Expand All @@ -1738,6 +1830,7 @@ inductive LamBaseTerm.LamWF (ltv : LamTyVal) : LamBaseTerm → LamSort → Type
| ofIcst : (icwf : IntConst.LamWF ic s) → LamWF ltv (.icst ic) s
| ofScst : (scwf : StringConst.LamWF sc s) → LamWF ltv (.scst sc) s
| ofBvcst : (bvcwf : BitVecConst.LamWF bvc s) → LamWF ltv (.bvcst bvc) s
| ofOcst : (ocwf : OtherConst.LamWF oc s) → LamWF ltv (.ocst oc) s
| ofEqI n : LamWF ltv (.eqI n) (.func (ltv.lamILTy n) (.func (ltv.lamILTy n) (.base .prop)))
| ofForallEI n : LamWF ltv (.forallEI n) (.func (.func (ltv.lamILTy n) (.base .prop)) (.base .prop))
| ofExistEI n : LamWF ltv (.existEI n) (.func (.func (ltv.lamILTy n) (.base .prop)) (.base .prop))
Expand All @@ -1762,6 +1855,8 @@ def LamBaseTerm.LamWF.unique {ltv : LamTyVal} {b : LamBaseTerm} {s₁ s₂ : Lam
rcases StringConst.LamWF.unique wf₁ wf₂ with ⟨⟨⟩, ⟨⟩⟩; trivial
case ofBvcst.ofBvcst bvc wf₁ wf₂ =>
rcases BitVecConst.LamWF.unique wf₁ wf₂ with ⟨⟨⟩, ⟨⟩⟩; trivial
case ofOcst.ofOcst oc wf₁ wf₂ =>
rcases OtherConst.LamWF.unique wf₁ wf₂ with ⟨⟨⟩, ⟨⟩⟩; trivial

def LamBaseTerm.LamWF.eVarIrrelevance
(hLamVarTy : ltv₁.lamVarTy = ltv₂.lamVarTy)
Expand Down Expand Up @@ -1876,6 +1971,7 @@ def LamBaseTerm.LamWF.ofLamBaseTerm (ltv : LamTyVal) : (b : LamBaseTerm) → (s
| .icst ic => have ⟨s, wf⟩ := IntConst.LamWF.ofIntConst ic; ⟨s, .ofIcst wf⟩
| .scst sc => have ⟨s, wf⟩ := StringConst.LamWF.ofStringConst sc; ⟨s, .ofScst wf⟩
| .bvcst bvc => have ⟨s, wf⟩ := BitVecConst.LamWF.ofBitVecConst bvc; ⟨s, .ofBvcst wf⟩
| .ocst oc => have ⟨s, wf⟩ := OtherConst.LamWF.ofOtherConst oc; ⟨s, .ofOcst wf⟩
| .eqI n => ⟨.func _ (.func _ (.base .prop)), .ofEqI n⟩
| .forallEI n => ⟨.func (.func _ (.base .prop)) (.base .prop), .ofForallEI n⟩
| .existEI n => ⟨.func (.func _ (.base .prop)) (.base .prop), .ofExistEI n⟩
Expand All @@ -1893,6 +1989,7 @@ def LamBaseTerm.lamWF_complete (wf : LamWF ltv b s) : LamWF.ofLamBaseTerm ltv b
case ofIcst ic wf => dsimp [LamWF.ofLamBaseTerm]; rw [IntConst.lamWF_complete]
case ofScst bc wf => dsimp [LamWF.ofLamBaseTerm]; rw [StringConst.lamWF_complete wf]
case ofBvcst bc wf => dsimp [LamWF.ofLamBaseTerm]; rw [BitVecConst.lamWF_complete wf]
case ofOcst oc wf => dsimp [LamWF.ofLamBaseTerm]; rw [OtherConst.lamWF_complete wf]

def LamBaseTerm.lamCheck_of_LamWF (H : LamWF ltv b s) : b.lamCheck ltv = s := by
cases H <;> try rfl
Expand All @@ -1902,6 +1999,7 @@ def LamBaseTerm.lamCheck_of_LamWF (H : LamWF ltv b s) : b.lamCheck ltv = s := by
case ofIcst bc wf => apply IntConst.lamCheck_of_LamWF wf
case ofScst sc wf => apply StringConst.lamCheck_of_LamWF wf
case ofBvcst sc wf => apply BitVecConst.lamCheck_of_LamWF wf
case ofOcst oc wf => apply OtherConst.lamCheck_of_LamWF wf

def LamBaseTerm.LamWF.ofCheck (H : b.lamCheck ltv = s) : LamWF ltv b s := by
cases H; cases b <;> constructor
Expand All @@ -1911,6 +2009,7 @@ def LamBaseTerm.LamWF.ofCheck (H : b.lamCheck ltv = s) : LamWF ltv b s := by
case refl.icst.icwf => apply IntConst.LamWF.ofCheck; rfl
case refl.scst.scwf => apply StringConst.LamWF.ofCheck; rfl
case refl.bvcst.bvcwf => apply BitVecConst.LamWF.ofCheck; rfl
case refl.ocst.ocwf => apply OtherConst.LamWF.ofCheck; rfl

structure ILLift (β : Type u) where
eqL : EqLift.{u + 1, u} β
Expand Down Expand Up @@ -2192,6 +2291,7 @@ noncomputable def LamBaseTerm.interp (lval : LamValuation.{u}) : (b : LamBaseTer
| .icst ic => ic.interp lval.tyVal
| .scst sc => sc.interp lval.tyVal
| .bvcst bvc => bvc.interp lval.tyVal
| .ocst oc => oc.interp lval.tyVal
| .eqI n => (lval.ilVal n).eqL.eqF
| .forallEI n => (lval.ilVal n).forallL.forallF
| .existEI n => (lval.ilVal n).existL.existF
Expand All @@ -2208,6 +2308,7 @@ noncomputable def LamBaseTerm.LamWF.interp (lval : LamValuation.{u}) : (lwf : La
| .ofIcst wf => wf.interp lval.tyVal
| .ofScst wf => wf.interp lval.tyVal
| .ofBvcst wf => wf.interp lval.tyVal
| .ofOcst wf => wf.interp lval.tyVal
| .ofEqI n => (lval.ilVal n).eqL.eqF
| .ofForallEI n => (lval.ilVal n).forallL.forallF
| .ofExistEI n => (lval.ilVal n).existL.existF
Expand Down Expand Up @@ -2254,6 +2355,7 @@ theorem LamBaseTerm.LamWF.interp_lvalIrrelevance
case ofIcst => apply IntConst.LamWF.interp_lvalIrrelevance <;> rfl
case ofScst => apply StringConst.LamWF.interp_lvalIrrelevance <;> rfl
case ofBvcst => apply BitVecConst.LamWF.interp_lvalIrrelevance <;> rfl
case ofOcst => apply OtherConst.LamWF.interp_lvalIrrelevance <;> rfl

def LamBaseTerm.interp_equiv (lval : LamValuation.{u})
(lwf : LamWF lval.toLamTyVal b s) :
Expand All @@ -2265,6 +2367,7 @@ def LamBaseTerm.interp_equiv (lval : LamValuation.{u})
case ofIcst => apply IntConst.interp_equiv
case ofScst => apply StringConst.interp_equiv
case ofBvcst => apply BitVecConst.interp_equiv
case ofOcst => apply OtherConst.interp_equiv

def LamValuation.insertEVarAt (lval : LamValuation.{u})
(ty : LamSort) (val : ty.interp lval.tyVal) (pos : Nat) :=
Expand Down
16 changes: 16 additions & 0 deletions Auto/IR/SMT.lean
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,22 @@ end

def STerm.qStrApp (s : String) (arr : Array STerm) := STerm.qIdApp (.ofString s) arr

def STerm.attrApp (name : String) (attrTerm : STerm) (term : STerm) : STerm :=
match term with
| .attr term' attrs' => .attr term' (#[attr] ++ attrs')
| _ => .attr term #[attr]
where
attr :=
match attrTerm with
-- An empty string constant indicates an attribute with no arguments
| .sConst (.str "") => .none name
-- Other string constants are always symbols here.
| .sConst (.str sym) => .symb name sym
-- Other constants are constant arguments.
| .sConst c => .spec name c
-- Non-constant arguments are terms.
| t => .sexpr name #[t]

private partial def STerm.toStringAux : STerm → List SIdent → String
| .sConst c, _ => SpecConst.toString c
| .bvar i, binders =>
Expand Down
2 changes: 2 additions & 0 deletions Auto/IR/TPTP_TH0.lean
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def transLamBaseTerm : LamBaseTerm → Except String String
| .icst ic => .ok (transIntConst ic)
| .scst sc => .ok (transStringConst sc)
| .bvcst bvc => .ok (transBitVecConst bvc)
-- TODO: translate to λx => x?
| .ocst _ => .error "transLamBaseTerm :: attributes not supported in TPTP"
| .eqI _ => .error "transLamBaseTerm :: eqI should not occur here"
| .forallEI _ => .error "transLamBaseTerm :: forallEI should not occur here"
| .existEI _ => .error "transLamBaseTerm :: existEI should not occur here"
Expand Down
1 change: 1 addition & 0 deletions Auto/Translation/Lam2D.lean
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def interpLamBaseTermAsUnlifted : LamBaseTerm → ExternM Expr
| .icst ic => return interpIntConstAsUnlifted ic
| .scst sc => return interpStringConstAsUnlifted sc
| .bvcst bvc => return interpBitVecConstAsUnlifted bvc
| .ocst _ => throwError ("interpLamTermAsUnlifted :: Attributes not supported")
| .eqI _ => throwError ("interpLamTermAsUnlifted :: " ++ exportError.ImpPolyLog)
| .forallEI _ => throwError ("interpLamTermAsUnlifted :: " ++ exportError.ImpPolyLog)
| .existEI _ => throwError ("interpLamTermAsUnlifted :: " ++ exportError.ImpPolyLog)
Expand Down
1 change: 1 addition & 0 deletions Auto/Translation/LamFOL2SMT.lean
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ private def lamBaseTerm2STerm_Arity2 (arg1 arg2 : STerm) : LamBaseTerm → Trans
| .lshr => return .qStrApp "bvlshr" #[arg1, arg2]
| .ashr => return .qStrApp "bvashr" #[arg1, arg2]
| .bvcst (.bvappend _ _) => return .qStrApp "concat" #[arg1, arg2]
| .ocst (.attribute s _) => return .attrApp s arg1 arg2
| t => throwError "lamTerm2STerm :: The arity of {repr t} is not 2"

private def lamBaseTerm2STerm_Arity1 (arg : STerm) : LamBaseTerm → TransM LamAtom STerm
Expand Down
9 changes: 9 additions & 0 deletions Auto/Translation/LamReif.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,11 @@ def reifMapBVConst3 : HashMap Name (Nat → Nat → Nat → LamTerm) :=
(``BitVec.extractLsb, fun n h l => .base (.bvextract n h l))
]

def reifMapAttributes : HashMap Name String :=
HashMap.ofList [
(``trigger, "pattern")
]

def processSimpleLit (l : Literal) : LamTerm :=
match l with
| .natVal n => .base (.natVal n)
Expand Down Expand Up @@ -1307,6 +1312,10 @@ def processSimpleApp (fn arg : Expr) : ReifM (Option LamTerm) := do
-- `arg` is the original (un-lifted) type
if let .some tcon := reifMapIL.find? name then
return .some (.base (tcon (← reifType arg)))
if let .some attrName := reifMapAttributes.find? name then
if lvls.length != 0 then
throwError "processSimpleApp :: attribute should have nil level list"
return .some (.base (.ocst (.attribute attrName (← reifType arg))))
return .none
| [arg₁, arg₂] =>
if let .some tcon := reifMapBVConst2.find? name then
Expand Down
18 changes: 18 additions & 0 deletions Test/SmtTranslation/Trigger.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import Auto.Tactic
import Auto.Embedding.LamBase
open Auto.Embedding.Lam

set_option auto.smt.trust true
set_option trace.auto.smt.printCommands true
set_option trace.auto.smt.result true
set_option trace.auto.smt.unsatCore true
set_option auto.smt.save true
set_option auto.smt.savepath "output.smt2"

set_option auto.smt true

axiom f : Int -> Int

axiom fGreater : forall x, trigger (f x) (f x > x)

theorem fPlusOneGreater : forall x, (f x) + 1 > x := by auto [fGreater] u[]

0 comments on commit 5cafe34

Please sign in to comment.