From 829a324fb71be6b7150b898a065e44b52d6ef7db Mon Sep 17 00:00:00 2001 From: Aaron Tomb Date: Mon, 8 Apr 2024 10:38:50 -0700 Subject: [PATCH] Add support for SMT triggers This is implemented by allowing arbitrary attributes to be propagated throughout the pipeline. For the SMT backend, these are translated into SMT attributes. Currently, the only attribute supported is the `:pattern` attribute used to indicate triggers. Adding other attributes will be easy now that this plumbing is in place, however. Attributes currently aren't supported for TPTP, but it should be possible to allow them eventually if there's a use for them. Note that the initial construction of attributes is very general, and allows attributes of any base type supported by the translator. The set supported by SMT is more restricted, but other prover input languages may use entirely different attributes. There's no attempt to ensure that attributes go in places that are allowed. So, for instance, adding a trigger outside of a quantified expression will lead to an error from the solver. Similarly, adding a trigger to a goal, rather than an axiom, will typically also lead to an error. --- Auto/Embedding/LamBase.lean | 103 +++++++++++++++++++++++++++++++ Auto/IR/SMT.lean | 16 +++++ Auto/IR/TPTP_TH0.lean | 2 + Auto/Translation/Lam2D.lean | 1 + Auto/Translation/LamFOL2SMT.lean | 1 + Auto/Translation/LamReif.lean | 9 +++ Test/SmtTranslation/Trigger.lean | 18 ++++++ 7 files changed, 150 insertions(+) create mode 100644 Test/SmtTranslation/Trigger.lean diff --git a/Auto/Embedding/LamBase.lean b/Auto/Embedding/LamBase.lean index d7c3e55..c707acb 100644 --- a/Auto/Embedding/LamBase.lean +++ b/Auto/Embedding/LamBase.lean @@ -1418,6 +1418,90 @@ def BitVecConst.LamWF.ofCheck (H : b.lamCheck = s) : LamWF b s := by case bvcmp n prop? op => cases prop? <;> constructor case bvshOp n smt? op => cases smt? <;> constructor +inductive OtherConst + | attribute : String -> LamSort -> OtherConst +deriving Inhabited, Hashable, Lean.ToExpr + +def OtherConst.reprAux : OtherConst → String +| .attribute n s => s!"attribute {n} {s}" + +def OtherConst.reprPrec (p : OtherConst) (n : Nat) := + match n with + | 0 => f!"Auto.Embedding.Lam.OtherConst.{p.reprAux}" + | _ + 1 => f!"(.{p.reprAux})" + +inductive OtherConst.LamWF : OtherConst → LamSort → Type + | ofAttribute n s : LamWF (.attribute n s) (.func s (.func (.base .prop) (.base .prop))) + +def OtherConst.LamWF.ofOtherConst : (oc : OtherConst) → (s : LamSort) × OtherConst.LamWF oc s +| .attribute n s => ⟨.func s (.func (.base .prop) (.base .prop)), .ofAttribute n s⟩ + +def OtherConst.LamWF.interp (tyVal : Nat → Type u) : (lwf : LamWF p s) → s.interp tyVal +| .ofAttribute _ _ => fun _ => fun term => term + +def OtherConst.toString : OtherConst → String +| .attribute n s => s!"attr[{n} : {s}]" + +instance : ToString OtherConst where + toString := OtherConst.toString + +def OtherConst.beq : OtherConst → OtherConst → Bool +| .attribute n1 s1, .attribute n2 s2 => + n1 == n2 && s1.beq s2 + +instance : BEq OtherConst where + beq := OtherConst.beq + +theorem OtherConst.beq_def {x y : OtherConst} : (x == y) = x.beq y := rfl + +def OtherConst.beq_refl {o : OtherConst} : (o.beq o) = true := by + cases o <;> unfold OtherConst.beq <;> simp [LamSort.beq_refl] + +def OtherConst.eq_of_beq_eq_true {o₁ o₂ : OtherConst} (H : o₁.beq o₂) : o₁ = o₂ := + match o₁, o₂ with + | .attribute n1 s1, .attribute n2 s2 => by + simp [OtherConst.beq] at H + match H with + | And.intro neq sbeq => + have seq := LamSort.eq_of_beq_eq_true sbeq + rw [neq, seq] + +instance : LawfulBEq OtherConst where + eq_of_beq := OtherConst.eq_of_beq_eq_true + rfl := OtherConst.beq_refl + +def OtherConst.lamCheck : OtherConst → LamSort +| .attribute _ as1 => .func as1 (.func (.base .prop) (.base .prop)) + +-- Note: add other attributes as needed +def trigger {a : Type} (_ : a) (term : Prop) := term + +def OtherConst.interp (tyVal : Nat → Type u) : (o : OtherConst) → o.lamCheck.interp tyVal +| .attribute _ _ => fun _ => fun term => term + +def OtherConst.interp_equiv (tyVal : Nat → Type u) (ocwf : LamWF p s) : + HEq (LamWF.interp tyVal ocwf) (interp tyVal p) := by + cases ocwf <;> rfl + +def OtherConst.LamWF.unique {o : OtherConst} {s₁ s₂ : LamSort} + (ocwf₁ : LamWF o s₁) (ocwf₂ : LamWF o s₂) : s₁ = s₂ ∧ HEq ocwf₁ ocwf₂ := by + cases ocwf₁ <;> cases ocwf₂ <;> trivial + +theorem OtherConst.LamWF.interp_lvalIrrelevance + (tyVal₁ tyVal₂ : Nat → Type u) (ocwf₁ : LamWF b₁ s₁) (ocwf₂ : LamWF b₂ s₂) + (HBeq : b₁ = b₂) (hTyVal : tyVal₁ = tyVal₂) : + HEq (ocwf₁.interp tyVal₁) (ocwf₂.interp tyVal₂) := by + cases HBeq; cases hTyVal; rcases OtherConst.LamWF.unique ocwf₁ ocwf₂ with ⟨⟨⟩, ⟨⟩⟩; rfl + +def OtherConst.lamWF_complete (wf : LamWF sc s) : LamWF.ofOtherConst sc = ⟨s, wf⟩ := by + cases wf <;> rfl + +def OtherConst.lamCheck_of_LamWF (H : LamWF sc s) : sc.lamCheck = s := by + cases H <;> rfl + +def OtherConst.LamWF.ofCheck (H : sc.lamCheck = s) : LamWF sc s := by + cases H; cases sc <;> constructor + /-- Interpreted constants Note that `eq`, `forallE`, `existE` have `ilVal/lamILTy` @@ -1442,6 +1526,7 @@ inductive LamBaseTerm | icst : IntConst → LamBaseTerm | scst : StringConst → LamBaseTerm | bvcst : BitVecConst → LamBaseTerm + | ocst : OtherConst → LamBaseTerm -- Versions of `eq, ∀, ∃, ite'` when we're importing external facts -- Note that the [import versions] of `eq, ∀, ∃, ite'` should only be used when -- we're importing external facts. When facts are imported, we call @@ -1604,6 +1689,7 @@ def LamBaseTerm.reprPrec (l : LamBaseTerm) (n : Nat) := | .icst ic => f!"icst {IntConst.reprPrec ic 1}" | .scst sc => f!"scst {StringConst.reprPrec sc 1}" | .bvcst bvc => f!"bvcst {BitVecConst.reprPrec bvc 1}" + | .ocst oc => f!"ocst {OtherConst.reprPrec oc 1}" | .eqI n => f!"eqI {n}" | .forallEI n => f!"forallEI {n}" | .existEI n => f!"existEI {n}" @@ -1627,6 +1713,7 @@ def LamBaseTerm.toString : LamBaseTerm → String | .icst ic => s!"{ic}" | .scst sc => s!"{sc}" | .bvcst bvc => s!"{bvc}" +| .ocst oc => s!"{oc}" | .eqI _ => "=" | .forallEI _ => "∀" | .existEI _ => "∃" @@ -1646,6 +1733,7 @@ def LamBaseTerm.beq : LamBaseTerm → LamBaseTerm → Bool | .icst ic₁, .icst ic₂ => IntConst.beq ic₁ ic₂ | .scst sc₁, .scst sc₂ => StringConst.beq sc₁ sc₂ | .bvcst l₁, .bvcst l₂ => BitVecConst.beq l₁ l₂ +| .ocst o₁, .ocst o₂ => OtherConst.beq o₁ o₂ | .eqI n₁, .eqI n₂ => n₁.beq n₂ | .forallEI n₁, .forallEI n₂ => n₁.beq n₂ | .existEI n₁, .existEI n₂ => n₁.beq n₂ @@ -1667,6 +1755,7 @@ def LamBaseTerm.beq_refl {b : LamBaseTerm} : (b.beq b) = true := by case icst i => apply LawfulBEq.rfl (α := IntConst) case scst s => apply LawfulBEq.rfl (α := StringConst) case bvcst s => apply LawfulBEq.rfl (α := BitVecConst) + case ocst o => apply LawfulBEq.rfl (α := OtherConst) def LamBaseTerm.eq_of_beq_eq_true {b₁ b₂ : LamBaseTerm} (H : b₁.beq b₂) : b₁ = b₂ := by cases b₁ <;> cases b₂ <;> (first | contradiction | rfl | apply congrArg) <;> @@ -1677,6 +1766,7 @@ def LamBaseTerm.eq_of_beq_eq_true {b₁ b₂ : LamBaseTerm} (H : b₁.beq b₂) case icst.icst.h n₁ n₂ => apply LawfulBEq.eq_of_beq (α := IntConst) H case scst.scst.h s₁ s₂ => apply LawfulBEq.eq_of_beq (α := StringConst) H case bvcst.bvcst.h v₁ v₂ => apply LawfulBEq.eq_of_beq (α := BitVecConst) H + case ocst.ocst.h o₁ o₂ => apply LawfulBEq.eq_of_beq (α := OtherConst) H instance : LawfulBEq LamBaseTerm where eq_of_beq := LamBaseTerm.eq_of_beq_eq_true @@ -1690,6 +1780,7 @@ def LamBaseTerm.containsSort (b : LamBaseTerm) (s : LamSort) : Bool := | .icst _ => false | .scst _ => false | .bvcst _ => false + | .ocst _ => false | .eqI _ => false | .forallEI _ => false | .existEI _ => false @@ -1714,6 +1805,7 @@ def LamBaseTerm.lamCheck (ltv : LamTyVal) : LamBaseTerm → LamSort | .icst ic => ic.lamCheck | .scst sc => sc.lamCheck | .bvcst bvc => bvc.lamCheck +| .ocst oc => oc.lamCheck | .eqI n => let s := ltv.lamILTy n .func s (.func s (.base .prop)) @@ -1738,6 +1830,7 @@ inductive LamBaseTerm.LamWF (ltv : LamTyVal) : LamBaseTerm → LamSort → Type | ofIcst : (icwf : IntConst.LamWF ic s) → LamWF ltv (.icst ic) s | ofScst : (scwf : StringConst.LamWF sc s) → LamWF ltv (.scst sc) s | ofBvcst : (bvcwf : BitVecConst.LamWF bvc s) → LamWF ltv (.bvcst bvc) s + | ofOcst : (ocwf : OtherConst.LamWF oc s) → LamWF ltv (.ocst oc) s | ofEqI n : LamWF ltv (.eqI n) (.func (ltv.lamILTy n) (.func (ltv.lamILTy n) (.base .prop))) | ofForallEI n : LamWF ltv (.forallEI n) (.func (.func (ltv.lamILTy n) (.base .prop)) (.base .prop)) | ofExistEI n : LamWF ltv (.existEI n) (.func (.func (ltv.lamILTy n) (.base .prop)) (.base .prop)) @@ -1762,6 +1855,8 @@ def LamBaseTerm.LamWF.unique {ltv : LamTyVal} {b : LamBaseTerm} {s₁ s₂ : Lam rcases StringConst.LamWF.unique wf₁ wf₂ with ⟨⟨⟩, ⟨⟩⟩; trivial case ofBvcst.ofBvcst bvc wf₁ wf₂ => rcases BitVecConst.LamWF.unique wf₁ wf₂ with ⟨⟨⟩, ⟨⟩⟩; trivial + case ofOcst.ofOcst oc wf₁ wf₂ => + rcases OtherConst.LamWF.unique wf₁ wf₂ with ⟨⟨⟩, ⟨⟩⟩; trivial def LamBaseTerm.LamWF.eVarIrrelevance (hLamVarTy : ltv₁.lamVarTy = ltv₂.lamVarTy) @@ -1876,6 +1971,7 @@ def LamBaseTerm.LamWF.ofLamBaseTerm (ltv : LamTyVal) : (b : LamBaseTerm) → (s | .icst ic => have ⟨s, wf⟩ := IntConst.LamWF.ofIntConst ic; ⟨s, .ofIcst wf⟩ | .scst sc => have ⟨s, wf⟩ := StringConst.LamWF.ofStringConst sc; ⟨s, .ofScst wf⟩ | .bvcst bvc => have ⟨s, wf⟩ := BitVecConst.LamWF.ofBitVecConst bvc; ⟨s, .ofBvcst wf⟩ +| .ocst oc => have ⟨s, wf⟩ := OtherConst.LamWF.ofOtherConst oc; ⟨s, .ofOcst wf⟩ | .eqI n => ⟨.func _ (.func _ (.base .prop)), .ofEqI n⟩ | .forallEI n => ⟨.func (.func _ (.base .prop)) (.base .prop), .ofForallEI n⟩ | .existEI n => ⟨.func (.func _ (.base .prop)) (.base .prop), .ofExistEI n⟩ @@ -1893,6 +1989,7 @@ def LamBaseTerm.lamWF_complete (wf : LamWF ltv b s) : LamWF.ofLamBaseTerm ltv b case ofIcst ic wf => dsimp [LamWF.ofLamBaseTerm]; rw [IntConst.lamWF_complete] case ofScst bc wf => dsimp [LamWF.ofLamBaseTerm]; rw [StringConst.lamWF_complete wf] case ofBvcst bc wf => dsimp [LamWF.ofLamBaseTerm]; rw [BitVecConst.lamWF_complete wf] + case ofOcst oc wf => dsimp [LamWF.ofLamBaseTerm]; rw [OtherConst.lamWF_complete wf] def LamBaseTerm.lamCheck_of_LamWF (H : LamWF ltv b s) : b.lamCheck ltv = s := by cases H <;> try rfl @@ -1902,6 +1999,7 @@ def LamBaseTerm.lamCheck_of_LamWF (H : LamWF ltv b s) : b.lamCheck ltv = s := by case ofIcst bc wf => apply IntConst.lamCheck_of_LamWF wf case ofScst sc wf => apply StringConst.lamCheck_of_LamWF wf case ofBvcst sc wf => apply BitVecConst.lamCheck_of_LamWF wf + case ofOcst oc wf => apply OtherConst.lamCheck_of_LamWF wf def LamBaseTerm.LamWF.ofCheck (H : b.lamCheck ltv = s) : LamWF ltv b s := by cases H; cases b <;> constructor @@ -1911,6 +2009,7 @@ def LamBaseTerm.LamWF.ofCheck (H : b.lamCheck ltv = s) : LamWF ltv b s := by case refl.icst.icwf => apply IntConst.LamWF.ofCheck; rfl case refl.scst.scwf => apply StringConst.LamWF.ofCheck; rfl case refl.bvcst.bvcwf => apply BitVecConst.LamWF.ofCheck; rfl + case refl.ocst.ocwf => apply OtherConst.LamWF.ofCheck; rfl structure ILLift (β : Type u) where eqL : EqLift.{u + 1, u} β @@ -2192,6 +2291,7 @@ noncomputable def LamBaseTerm.interp (lval : LamValuation.{u}) : (b : LamBaseTer | .icst ic => ic.interp lval.tyVal | .scst sc => sc.interp lval.tyVal | .bvcst bvc => bvc.interp lval.tyVal +| .ocst oc => oc.interp lval.tyVal | .eqI n => (lval.ilVal n).eqL.eqF | .forallEI n => (lval.ilVal n).forallL.forallF | .existEI n => (lval.ilVal n).existL.existF @@ -2208,6 +2308,7 @@ noncomputable def LamBaseTerm.LamWF.interp (lval : LamValuation.{u}) : (lwf : La | .ofIcst wf => wf.interp lval.tyVal | .ofScst wf => wf.interp lval.tyVal | .ofBvcst wf => wf.interp lval.tyVal +| .ofOcst wf => wf.interp lval.tyVal | .ofEqI n => (lval.ilVal n).eqL.eqF | .ofForallEI n => (lval.ilVal n).forallL.forallF | .ofExistEI n => (lval.ilVal n).existL.existF @@ -2254,6 +2355,7 @@ theorem LamBaseTerm.LamWF.interp_lvalIrrelevance case ofIcst => apply IntConst.LamWF.interp_lvalIrrelevance <;> rfl case ofScst => apply StringConst.LamWF.interp_lvalIrrelevance <;> rfl case ofBvcst => apply BitVecConst.LamWF.interp_lvalIrrelevance <;> rfl + case ofOcst => apply OtherConst.LamWF.interp_lvalIrrelevance <;> rfl def LamBaseTerm.interp_equiv (lval : LamValuation.{u}) (lwf : LamWF lval.toLamTyVal b s) : @@ -2265,6 +2367,7 @@ def LamBaseTerm.interp_equiv (lval : LamValuation.{u}) case ofIcst => apply IntConst.interp_equiv case ofScst => apply StringConst.interp_equiv case ofBvcst => apply BitVecConst.interp_equiv + case ofOcst => apply OtherConst.interp_equiv def LamValuation.insertEVarAt (lval : LamValuation.{u}) (ty : LamSort) (val : ty.interp lval.tyVal) (pos : Nat) := diff --git a/Auto/IR/SMT.lean b/Auto/IR/SMT.lean index 852e01c..ada46cf 100644 --- a/Auto/IR/SMT.lean +++ b/Auto/IR/SMT.lean @@ -118,6 +118,22 @@ end def STerm.qStrApp (s : String) (arr : Array STerm) := STerm.qIdApp (.ofString s) arr +def STerm.attrApp (name : String) (attrTerm : STerm) (term : STerm) : STerm := + match term with + | .attr term' attrs' => .attr term' (#[attr] ++ attrs') + | _ => .attr term #[attr] + where + attr := + match attrTerm with + -- An empty string constant indicates an attribute with no arguments + | .sConst (.str "") => .none name + -- Other string constants are always symbols here. + | .sConst (.str sym) => .symb name sym + -- Other constants are constant arguments. + | .sConst c => .spec name c + -- Non-constant arguments are terms. + | t => .sexpr name #[t] + private partial def STerm.toStringAux : STerm → List SIdent → String | .sConst c, _ => SpecConst.toString c | .bvar i, binders => diff --git a/Auto/IR/TPTP_TH0.lean b/Auto/IR/TPTP_TH0.lean index 8103644..bd4bd89 100644 --- a/Auto/IR/TPTP_TH0.lean +++ b/Auto/IR/TPTP_TH0.lean @@ -82,6 +82,8 @@ def transLamBaseTerm : LamBaseTerm → Except String String | .icst ic => .ok (transIntConst ic) | .scst sc => .ok (transStringConst sc) | .bvcst bvc => .ok (transBitVecConst bvc) +-- TODO: translate to λx => x? +| .ocst _ => .error "transLamBaseTerm :: attributes not supported in TPTP" | .eqI _ => .error "transLamBaseTerm :: eqI should not occur here" | .forallEI _ => .error "transLamBaseTerm :: forallEI should not occur here" | .existEI _ => .error "transLamBaseTerm :: existEI should not occur here" diff --git a/Auto/Translation/Lam2D.lean b/Auto/Translation/Lam2D.lean index 95a8ed7..7ad4d0f 100644 --- a/Auto/Translation/Lam2D.lean +++ b/Auto/Translation/Lam2D.lean @@ -253,6 +253,7 @@ def interpLamBaseTermAsUnlifted : LamBaseTerm → ExternM Expr | .icst ic => return interpIntConstAsUnlifted ic | .scst sc => return interpStringConstAsUnlifted sc | .bvcst bvc => return interpBitVecConstAsUnlifted bvc +| .ocst _ => throwError ("interpLamTermAsUnlifted :: Attributes not supported") | .eqI _ => throwError ("interpLamTermAsUnlifted :: " ++ exportError.ImpPolyLog) | .forallEI _ => throwError ("interpLamTermAsUnlifted :: " ++ exportError.ImpPolyLog) | .existEI _ => throwError ("interpLamTermAsUnlifted :: " ++ exportError.ImpPolyLog) diff --git a/Auto/Translation/LamFOL2SMT.lean b/Auto/Translation/LamFOL2SMT.lean index 74e0477..993f3e0 100644 --- a/Auto/Translation/LamFOL2SMT.lean +++ b/Auto/Translation/LamFOL2SMT.lean @@ -164,6 +164,7 @@ private def lamBaseTerm2STerm_Arity2 (arg1 arg2 : STerm) : LamBaseTerm → Trans | .lshr => return .qStrApp "bvlshr" #[arg1, arg2] | .ashr => return .qStrApp "bvashr" #[arg1, arg2] | .bvcst (.bvappend _ _) => return .qStrApp "concat" #[arg1, arg2] +| .ocst (.attribute s _) => return .attrApp s arg1 arg2 | t => throwError "lamTerm2STerm :: The arity of {repr t} is not 2" private def lamBaseTerm2STerm_Arity1 (arg : STerm) : LamBaseTerm → TransM LamAtom STerm diff --git a/Auto/Translation/LamReif.lean b/Auto/Translation/LamReif.lean index d7cfb72..f6cb02e 100644 --- a/Auto/Translation/LamReif.lean +++ b/Auto/Translation/LamReif.lean @@ -1271,6 +1271,11 @@ def reifMapBVConst3 : HashMap Name (Nat → Nat → Nat → LamTerm) := (``BitVec.extractLsb, fun n h l => .base (.bvextract n h l)) ] +def reifMapAttributes : HashMap Name String := + HashMap.ofList [ + (``trigger, "pattern") + ] + def processSimpleLit (l : Literal) : LamTerm := match l with | .natVal n => .base (.natVal n) @@ -1307,6 +1312,10 @@ def processSimpleApp (fn arg : Expr) : ReifM (Option LamTerm) := do -- `arg` is the original (un-lifted) type if let .some tcon := reifMapIL.find? name then return .some (.base (tcon (← reifType arg))) + if let .some attrName := reifMapAttributes.find? name then + if lvls.length != 0 then + throwError "processSimpleApp :: attribute should have nil level list" + return .some (.base (.ocst (.attribute attrName (← reifType arg)))) return .none | [arg₁, arg₂] => if let .some tcon := reifMapBVConst2.find? name then diff --git a/Test/SmtTranslation/Trigger.lean b/Test/SmtTranslation/Trigger.lean new file mode 100644 index 0000000..045f3e1 --- /dev/null +++ b/Test/SmtTranslation/Trigger.lean @@ -0,0 +1,18 @@ +import Auto.Tactic +import Auto.Embedding.LamBase +open Auto.Embedding.Lam + +set_option auto.smt.trust true +set_option trace.auto.smt.printCommands true +set_option trace.auto.smt.result true +set_option trace.auto.smt.unsatCore true +set_option auto.smt.save true +set_option auto.smt.savepath "output.smt2" + +set_option auto.smt true + +axiom f : Int -> Int + +axiom fGreater : forall x, trigger (f x) (f x > x) + +theorem fPlusOneGreater : forall x, (f x) + 1 > x := by auto [fGreater] u[]