diff --git a/Auto/IR/SMT.lean b/Auto/IR/SMT.lean index ada46cf..5460c05 100644 --- a/Auto/IR/SMT.lean +++ b/Auto/IR/SMT.lean @@ -2,6 +2,9 @@ import Lean import Auto.Lib.MonadUtils open Lean +initialize + registerTraceClass `auto.smt.h2symb + -- smt-lib 2 namespace Auto @@ -277,6 +280,20 @@ def SMTOption.toString : SMTOption → String instance : ToString SMTOption where toString := SMTOption.toString +def SMTReservedWords : HashSet String := + let reserved := #[ + "_", "!", + "as", "let", "exists", "forall", "match", "par", + "assert", "check-sat", "check-sat-assuming", + "declare-const", "declare-datatype", "declare-datatypes", + "declare-fun", "declare-sort", "define-fun", "define-fun-rec", "define-funs-rec", + "define-sort", "echo", "exit", "get-assertions", "get-info", + "get-model", "get-option", "get-proof", "get-unsat-assumptions", + "get-unsat-core", "get-value", "pop", "push", "reset", "reset-assertions", + "set-info", "set-logic", "set-option" + ] + reserved.foldl (fun hs s => hs.insert s) {} + /-- 〈sorted_var〉 ::= ( 〈symbol〉 〈sort〉 ) 〈datatype_dec〉 ::= ( 〈constructor_dec〉+ ) | ( par ( 〈symbol〉+ ) ( 〈constructor_dec〉+ ) ) @@ -353,7 +370,7 @@ def Command.toString : Command → String | .declDtype name ddecl => s!"(declare-datatype {SIdent.symb name} {ddecl.toString})" | .declDtypes infos => - let sort_decs := String.intercalate " " (infos.data.map (fun (name, args, _) => s!"({name} {args})")) + let sort_decs := String.intercalate " " (infos.data.map (fun (name, args, _) => s!"({SIdent.symb name} {args})")) let datatype_decs := String.intercalate " " (infos.data.map (fun (_, _, ddecl) => ddecl.toString)) s!"(declare-datatypes ({sort_decs}) ({datatype_decs}))" | .exit => "(exit)" @@ -379,8 +396,8 @@ section -- Map from symbol to high-level construct l2hMap : HashMap String ω := {} -- State of low-level name generator - -- To avoid collision with keywords, we only - -- generate non-annotated identifiers `smti_` + -- To avoid collisions with other identifiers or keywords, + -- we append identifiers with an unique index, e.g. `forall_` idx : Nat := 0 -- List of commands commands : Array Command := #[] @@ -415,29 +432,48 @@ section def hIn (e : ω) : TransM ω Bool := do return (← getH2lMap).contains e - /-- Used for e.g. bound variables -/ - partial def disposableName : TransM ω String := do + def binderNamePrefixFromSort (sort : SSort) : String := + match sort with + | SSort.bvar _ => "bvar" + | SSort.app (SIdent.symb s) _ => s.take 1 + | SSort.app (SIdent.indexed s _) _ => s.take 1 + + /-- Used for e.g. bound variables. -/ + partial def disposableName (sortHint : SSort): TransM ω String := do let l2hMap ← getL2hMap let idx ← getIdx - let currName := s!"smtd_{idx}" + let mut currName := s!"{binderNamePrefixFromSort sortHint}{idx}" + -- Try to avoid collisions with other identifiers if l2hMap.contains currName then - throwError "disposableName :: Unexpected error" + currName := s!"{currName}_{idx}" + if l2hMap.contains currName then + throwError "disposableName :: Unexpected error - binder disposable name already exists!" setIdx (idx + 1) return currName + def smtNameFromExpr (e : Expr) : TransM ω String := do + let ppSyntax := (← PrettyPrinter.delab e).raw + let ppStr := toString (← PrettyPrinter.formatTerm ppSyntax) + return ppStr.map (fun c => if c.isAlphanum then c else '_') + /-- Turn high-level construct into low-level symbol Note that this function is idempotent + `nameHint` is an expression from which we can extract a name. -/ - partial def h2Symb (cstr : ω) : TransM ω String := do + partial def h2Symb [ToString ω] (cstr : ω) (nameHint : Option Expr := none) : TransM ω String := do let l2hMap ← getL2hMap let h2lMap ← getH2lMap if let .some name := h2lMap.find? cstr then return name let idx ← getIdx - let currName : String := s!"smti_{idx}" - if l2hMap.contains currName then - throwError "h2Symb :: Unexpected error" + let mut currName := (← nameHint.mapM smtNameFromExpr).getD "smti" + -- NOTE: In the case of duplicate names or SMT reserved words, we append the index to the name + if l2hMap.contains currName || SMTReservedWords.contains currName then + currName := s!"{currName}_{idx}" + if l2hMap.contains currName then + throwError "h2Symb :: Unexpected error - identifier {currName} already exists!" + trace[auto.smt.h2symb] "{currName} From LamAtom {cstr})" setL2hMap (l2hMap.insert currName cstr) setH2lMap (h2lMap.insert cstr currName) setIdx (idx + 1) diff --git a/Auto/Tactic.lean b/Auto/Tactic.lean index f0a6d49..eded1a8 100644 --- a/Auto/Tactic.lean +++ b/Auto/Tactic.lean @@ -284,13 +284,18 @@ def queryTPTP (exportFacts : Array REntry) : LamReif.ReifM (Array Embedding.Lam. open Embedding.Lam in def querySMT (exportFacts : Array REntry) (exportInds : Array MutualIndInfo) : LamReif.ReifM (Option Expr) := do - let lamVarTy := (← LamReif.getVarVal).map Prod.snd + -- GEORGE: do we need to pass more of `LamReif:State` to `lamFOL2SMT`? + let lamVarTy ← LamReif.getVarVal + trace[auto.lamReif.printValuation] "lamVarTy: {lamVarTy}" let lamEVarTy ← LamReif.getLamEVarTy + trace[auto.lamReif.printValuation] "lamEVarTy: {lamEVarTy}" + let tyVal ← LamReif.getTyVal + trace[auto.lamReif.printValuation] "tyVal: {tyVal}" let exportLamTerms ← exportFacts.mapM (fun re => do match re with | .valid [] t => return t | _ => throwError "runAuto :: Unexpected error") - let commands ← (lamFOL2SMT lamVarTy lamEVarTy exportLamTerms exportInds).run' + let commands ← (lamFOL2SMT lamVarTy lamEVarTy tyVal exportLamTerms exportInds).run' for cmd in commands do trace[auto.smt.printCommands] "{cmd}" if (auto.smt.save.get (← getOptions)) then diff --git a/Auto/Translation/LamFOL2SMT.lean b/Auto/Translation/LamFOL2SMT.lean index 69bf32f..882f6c8 100644 --- a/Auto/Translation/LamFOL2SMT.lean +++ b/Auto/Translation/LamFOL2SMT.lean @@ -32,6 +32,18 @@ private inductive LamAtom where | compProj : LamTerm → LamAtom deriving Inhabited, Hashable, BEq +private def LamAtom.toString : LamAtom → String +| .sort n => s!"sort {n}" +| .term n => s!"term {n}" +| .etom n => s!"etom {n}" +| .bvOfNat n => s!"bvOfNat {n}" +| .bvToNat n => s!"bvToNat {n}" +| .compCtor t => s!"compCtor {t}" +| .compProj t => s!"compProj {t}" + +instance : ToString LamAtom where + toString := LamAtom.toString + private def lamBaseSort2SSort : LamBaseSort → SSort | .prop => .app (.symb "Bool") #[] | .bool => .app (.symb "Bool") #[] @@ -45,53 +57,55 @@ private def lamBaseSort2SSort : LamBaseSort → SSort | _ => .app (.symb "Empty") #[] | .bv n => .app (.indexed "BitVec" #[.inr n]) #[] -private def lamSortAtom2String (n : Nat) : TransM LamAtom String := do +private def lamSortAtom2String (tyVal : Array (Expr × Level)) (n : Nat) : TransM LamAtom String := do + let .some (e, _) := tyVal[n]? + | throwError "lamSortAtom2String :: Unexpected sort atom {repr (LamSort.atom n)}" if !(← hIn (.sort n)) then - let name ← h2Symb (.sort n) + let name ← h2Symb (.sort n) e addCommand (.declSort name 0) - return ← h2Symb (.sort n) + return ← h2Symb (.sort n) e -private def lamSort2SSortAux : LamSort → TransM LamAtom SSort -| .atom n => do return .app (.symb (← lamSortAtom2String n)) #[] +private def lamSort2SSortAux (tyVal : Array (Expr × Level)) : LamSort → TransM LamAtom SSort +| .atom n => do return .app (.symb (← lamSortAtom2String tyVal n)) #[] | .base b => return lamBaseSort2SSort b | .func _ _ => throwError "lamSort2STermAux :: Unexpected error. Higher order input?" /-- Only translates first-order types -/ -private def lamSort2SSort : LamSort → TransM LamAtom (List SSort × SSort) +private def lamSort2SSort (tyVal : Array (Expr × Level)) : LamSort → TransM LamAtom (List SSort × SSort) | .func argTy resTy => do - let (smargs, smres) ← lamSort2SSort resTy - let smarg ← lamSort2SSortAux argTy + let (smargs, smres) ← lamSort2SSort tyVal resTy + let smarg ← lamSort2SSortAux tyVal argTy return (smarg :: smargs, smres) -| s => return ([], ← lamSort2SSortAux s) +| s => return ([], ← lamSort2SSortAux tyVal s) -private def addNatConstraint? (name : String) (s : LamSort) : TransM LamAtom Unit := do +private def addNatConstraint? (tyVal : Array (Expr × Level)) (name : String) (s : LamSort) : TransM LamAtom Unit := do let resTy := s.getResTy if !(resTy == .base .nat) then return - let args ← (Array.mk s.getArgTys).mapM (fun s => do return (s, ← IR.SMT.disposableName)) + let args ← (Array.mk s.getArgTys).mapM (fun s => do return (s, ← IR.SMT.disposableName (← lamSort2SSortAux tyVal s))) let fnApp := STerm.qStrApp name (args.zipWithIndex.map (fun (_, n) => .bvar (args.size - 1 - n))) let mut fnConstr := STerm.qStrApp ">=" #[fnApp, .sConst (.num 0)] for (argTy, argName) in args.reverse do if argTy == .base .nat then fnConstr := .qStrApp "=>" #[.qStrApp ">=" #[.bvar 0, .sConst (.num 0)], fnConstr] - fnConstr := .forallE argName (← lamSort2SSortAux argTy) fnConstr + fnConstr := .forallE argName (← lamSort2SSortAux tyVal argTy) fnConstr addCommand (.assert fnConstr) private def int2STerm : Int → STerm | .ofNat n => .sConst (.num n) | .negSucc n => .qIdApp (QualIdent.ofString "-") #[.sConst (.num (Nat.succ n))] -private def lamBvOfNat2String (n : Nat) : TransM LamAtom String := do +private def lamBvOfNat2String (tyVal : Array (Expr × Level)) (n : Nat) : TransM LamAtom String := do if !(← hIn (.bvOfNat n)) then - let name ← h2Symb (.bvOfNat n) - let (argSorts, resSort) ← lamSort2SSort (.func (.base .int) (.base (.bv n))) + let name ← h2Symb (.bvOfNat n) (Expr.const ``BitVec.ofNat []) + let (argSorts, resSort) ← lamSort2SSort tyVal (.func (.base .int) (.base (.bv n))) addCommand (.declFun name ⟨argSorts⟩ resSort) return ← h2Symb (.bvOfNat n) -private def lamBvToNat2String (n : Nat) : TransM LamAtom String := do +private def lamBvToNat2String (tyVal : Array (Expr × Level)) (n : Nat) : TransM LamAtom String := do if !(← hIn (.bvToNat n)) then - let name ← h2Symb (.bvToNat n) - let (argSorts, resSort) ← lamSort2SSort (.func (.base (.bv n)) (.base .int)) + let name ← h2Symb (.bvToNat n) (Expr.const ``BitVec.toNat []) + let (argSorts, resSort) ← lamSort2SSort tyVal (.func (.base (.bv n)) (.base .int)) addCommand (.declFun name ⟨argSorts⟩ resSort) return ← h2Symb (.bvToNat n) @@ -167,7 +181,7 @@ private def lamBaseTerm2STerm_Arity2 (arg1 arg2 : STerm) : LamBaseTerm → Trans | .ocst (.smtAttr1T name _ _) => return .attrApp name arg1 arg2 | t => throwError "lamTerm2STerm :: The arity of {repr t} is not 2" -private def lamBaseTerm2STerm_Arity1 (arg : STerm) : LamBaseTerm → TransM LamAtom STerm +private def lamBaseTerm2STerm_Arity1 (tyVal : Array (Expr × Level)) (arg : STerm) : LamBaseTerm → TransM LamAtom STerm | .pcst .not => return .qStrApp "not" #[arg] | .bcst .ofProp => return arg | .bcst .notb => return .qStrApp "not" #[arg] @@ -182,19 +196,19 @@ private def lamBaseTerm2STerm_Arity1 (arg : STerm) : LamBaseTerm → TransM LamA if name == .z3 || name == .cvc5 then return .qIdApp (.ident (.indexed "int2bv" #[.inr n])) #[arg] else - return .qStrApp (← lamBvOfNat2String n) #[arg] + return .qStrApp (← lamBvOfNat2String tyVal n) #[arg] | .bvcst (.bvtoNat n) => do let name ← solverName if name == .z3 || name == .cvc5 then return .qStrApp "bv2nat" #[arg] else - return .qStrApp (← lamBvToNat2String n) #[arg] + return .qStrApp (← lamBvToNat2String tyVal n) #[arg] | .bvcst (.bvofInt n) => do let name ← solverName if name == .z3 || name == .cvc5 then return .qIdApp (.ident (.indexed "int2bv" #[.inr n])) #[arg] else - return .qStrApp (← lamBvOfNat2String n) #[arg] + return .qStrApp (← lamBvOfNat2String tyVal n) #[arg] | .bvcst (.bvtoInt n) => do let name ← solverName let msbExpr := mkSMTMsbExpr n arg @@ -203,8 +217,8 @@ private def lamBaseTerm2STerm_Arity1 (arg : STerm) : LamBaseTerm → TransM LamA let arg2 := .qStrApp "bv2nat" #[arg] return .qStrApp "ite" #[msbExpr, arg1, arg2] else - let arg1 := .qStrApp "-" #[.qStrApp (← lamBvToNat2String n) #[arg], .sConst (.num (2 ^ n))] - let arg2 := .qStrApp (← lamBvToNat2String n) #[arg] + let arg1 := .qStrApp "-" #[.qStrApp (← lamBvToNat2String tyVal n) #[arg], .sConst (.num (2 ^ n))] + let arg2 := .qStrApp (← lamBvToNat2String tyVal n) #[arg] return .qStrApp "ite" #[msbExpr, arg1, arg2] -- @BitVec.msb n a = not ((a &&& (1 <<< (n - 1))) = 0#n) | .bvcst (.bvmsb n) => return mkSMTMsbExpr n arg @@ -243,20 +257,20 @@ private def lamBaseTerm2STerm_Arity0 : LamBaseTerm → TransM LamAtom STerm | .bvcst (.bvVal n i) => return bitVec2STerm n i | t => throwError "lamTerm2STerm :: The arity of {repr t} is not 0" -def lamTermAtom2String (lamVarTy : Array LamSort) (n : Nat) : TransM LamAtom (LamSort × String) := do - let .some s := lamVarTy[n]? +def lamTermAtom2String (lamVarTy : Array (Expr × LamSort)) (tyVal : Array (Expr × Level)) (n : Nat) : TransM LamAtom (LamSort × String) := do + let .some (e, s) := lamVarTy[n]? | throwError "lamTermAtom2String :: Unexpected term atom {repr (LamTerm.atom n)}" -- Empty type is not inhabited if s == .base .empty then addCommand (.assert (.qStrApp "false" #[])) if !(← hIn (.term n)) then - let name ← h2Symb (.term n) - let (argSorts, resSort) ← lamSort2SSort s + let name ← h2Symb (.term n) e + let (argSorts, resSort) ← lamSort2SSort tyVal s addCommand (.declFun name ⟨argSorts⟩ resSort) - addNatConstraint? name s - return (s, ← h2Symb (.term n)) + addNatConstraint? tyVal name s + return (s, ← h2Symb (.term n) e) -def lamTermEtom2String (lamEVarTy : Array LamSort) (n : Nat) : TransM LamAtom (LamSort × String) := do +def lamTermEtom2String (lamEVarTy : Array LamSort) (tyVal : Array (Expr × Level)) (n : Nat) : TransM LamAtom (LamSort × String) := do let .some s := lamEVarTy[n]? | throwError "lamTerm2STerm :: Unexpected etom {repr (LamTerm.etom n)}" -- Empty type is not inhabited @@ -264,38 +278,38 @@ def lamTermEtom2String (lamEVarTy : Array LamSort) (n : Nat) : TransM LamAtom (L addCommand (.assert (.qStrApp "false" #[])) if !(← hIn (.etom n)) then let name ← h2Symb (.etom n) - let (argSorts, resSort) ← lamSort2SSort s + let (argSorts, resSort) ← lamSort2SSort tyVal s addCommand (.declFun name ⟨argSorts⟩ resSort) - addNatConstraint? name s + addNatConstraint? tyVal name s return (s, ← h2Symb (.etom n)) -private def lamTerm2STermAux (lamVarTy lamEVarTy : Array LamSort) (args : Array STerm) : +private def lamTerm2STermAux (lamVarTy : Array (Expr × LamSort)) (lamEVarTy : Array LamSort) (tyVal : Array (Expr × Level)) (args : Array STerm) : LamTerm → TransM LamAtom STerm | .atom n => do - let (s, name) ← lamTermAtom2String lamVarTy n + let (s, name) ← lamTermAtom2String lamVarTy tyVal n if args.size != s.getArgTys.length then throwError "lamTerm2STerm :: Argument number mismatch. Higher order input?" return .qIdApp (QualIdent.ofString name) args | .etom n => do - let (s, name) ← lamTermEtom2String lamEVarTy n + let (s, name) ← lamTermEtom2String lamEVarTy tyVal n if args.size != s.getArgTys.length then throwError "lamTerm2STerm :: Argument number mismatch. Higher order input?" return .qIdApp (QualIdent.ofString name) args | .base b => match args with | #[] => lamBaseTerm2STerm_Arity0 b - | #[u₁] => lamBaseTerm2STerm_Arity1 u₁ b + | #[u₁] => lamBaseTerm2STerm_Arity1 tyVal u₁ b | #[u₁, u₂] => lamBaseTerm2STerm_Arity2 u₁ u₂ b | #[u₁, u₂, u₃] => lamBaseTerm2STerm_Arity3 u₁ u₂ u₃ b | _ => throwError "lamTerm2STerm :: Argument number mismatch. Higher order input?" | t => throwError "lamTerm2STerm :: Unexpected head term {repr t}" -def lamQuantified2STerm (forall? : Bool) (s : LamSort) (body : TransM LamAtom STerm) : TransM LamAtom STerm := do +def lamQuantified2STerm (tyVal : Array (Expr × Level)) (forall? : Bool) (s : LamSort) (body : TransM LamAtom STerm) : TransM LamAtom STerm := do -- Empty type is not inhabited if s == .base .empty then return .qStrApp "true" #[] - let s' ← lamSort2SSortAux s - let dname ← disposableName + let s' ← lamSort2SSortAux tyVal s + let dname ← disposableName s' let mut body' ← body if s == .base .nat then let connective := if forall? then "=>" else "and" @@ -304,7 +318,7 @@ def lamQuantified2STerm (forall? : Bool) (s : LamSort) (body : TransM LamAtom ST | true => return .forallE dname s' body' | false => return .existE dname s' body' -private partial def lamTerm2STerm (lamVarTy lamEVarTy : Array LamSort) : +private partial def lamTerm2STerm (lamVarTy : Array (Expr × LamSort)) (lamEVarTy : Array LamSort) (tyVal : Array (Expr × Level)) : LamTerm → TransM LamAtom STerm | .base b => lamBaseTerm2STerm_Arity0 b | .bvar n => return .bvar n @@ -317,22 +331,22 @@ private partial def lamTerm2STerm (lamVarTy lamEVarTy : Array LamSort) : | .app _ (.app _ (.app _ (.base (.iteI _)) _) _) _ => throwError ("lamTerm2STerm :: " ++ LamReif.exportError.ImpPolyLog) | .app _ (.app _ (.base (.eq _)) arg₁) arg₂ => do - let arg₁' ← lamTerm2STerm lamVarTy lamEVarTy arg₁ - let arg₂' ← lamTerm2STerm lamVarTy lamEVarTy arg₂ + let arg₁' ← lamTerm2STerm lamVarTy lamEVarTy tyVal arg₁ + let arg₂' ← lamTerm2STerm lamVarTy lamEVarTy tyVal arg₂ return .qIdApp (QualIdent.ofString "=") #[arg₁', arg₂'] | .app _ (.base (.forallE _)) (.lam s body) => do - lamQuantified2STerm true s (lamTerm2STerm lamVarTy lamEVarTy body) + lamQuantified2STerm tyVal true s (lamTerm2STerm lamVarTy lamEVarTy tyVal body) | .app _ (.base (.existE _)) (.lam s body) => do - lamQuantified2STerm false s (lamTerm2STerm lamVarTy lamEVarTy body) + lamQuantified2STerm tyVal false s (lamTerm2STerm lamVarTy lamEVarTy tyVal body) | .app _ (.app _ (.app _ (.base (.ite _)) cond) arg₁) arg₂ => do - let cond' ← lamTerm2STerm lamVarTy lamEVarTy cond - let arg₁' ← lamTerm2STerm lamVarTy lamEVarTy arg₁ - let arg₂' ← lamTerm2STerm lamVarTy lamEVarTy arg₂ + let cond' ← lamTerm2STerm lamVarTy lamEVarTy tyVal cond + let arg₁' ← lamTerm2STerm lamVarTy lamEVarTy tyVal arg₁ + let arg₂' ← lamTerm2STerm lamVarTy lamEVarTy tyVal arg₂ return .qStrApp "ite" #[cond', arg₁', arg₂'] | t => do let (ts, t) := splitApp t - let ts' ← ts.mapM (lamTerm2STerm lamVarTy lamEVarTy) - lamTerm2STermAux lamVarTy lamEVarTy ts' t + let ts' ← ts.mapM (lamTerm2STerm lamVarTy lamEVarTy tyVal) + lamTerm2STermAux lamVarTy lamEVarTy tyVal ts' t where splitApp : LamTerm → Array LamTerm × LamTerm | .app _ fn arg => @@ -340,7 +354,7 @@ where (ts.push arg, t) | t => (#[], t) -private def lamMutualIndInfo2STerm (mind : MutualIndInfo) : +private def lamMutualIndInfo2STerm (lamVarTy : Array (Expr × LamSort)) (lamEVarTy : Array LamSort) (tyVal : Array (Expr × Level)) (mind : MutualIndInfo) : TransM LamAtom (IR.SMT.Command × Array (String × LamSort × LamTerm) × Array (String × LamSort × LamTerm)) := do @@ -352,11 +366,15 @@ private def lamMutualIndInfo2STerm (mind : MutualIndInfo) : for ⟨type, _, _⟩ in mind do let .atom sn := type | throwError "lamMutualIndInfo2STerm :: Inductive type {type} is not a sort atom" + let .some (se, _) := tyVal[sn]? + | throwError "lamMutualIndInfo2STerm :: Inductive type {type} is not in tyVal" -- Do not use `lamSortAtom2String` because we don't want to `declare-sort` - let _ ← h2Symb (.sort sn) + let _ ← h2Symb (.sort sn) se for ⟨type, ctors, projs⟩ in mind do let .atom sn := type | throwError "lamMutualIndInfo2STerm :: Unexpected error" + let .some (se, _) := tyVal[sn]? + | throwError "lamMutualIndInfo2STerm :: Inductive type {type} is not in tyVal" let mut projInfos : Array (LamSort × String) := #[] if let .some projs := projs then if ctors.length != 1 then @@ -364,21 +382,27 @@ private def lamMutualIndInfo2STerm (mind : MutualIndInfo) : for (s, t) in projs do let mut projname := "" match t with - | .atom n => projname ← h2Symb (.term n) + | .atom n => + let .some (e, _) := lamVarTy[n]? + | throwError "lamMutualIndInfo2STerm :: {repr t} is not in lamVarTy" + projname ← h2Symb (.term n) e | .etom n => projname ← h2Symb (.etom n) | t => projname ← h2Symb (.compProj t); compProjs := compProjs.push (projname, s, t) projInfos := projInfos.push (s, projname) - let sname ← h2Symb (.sort sn) + let sname ← h2Symb (.sort sn) se let mut cstrDecls : Array ConstrDecl := #[] for (s, t) in ctors do let mut ctorname := "" match t with -- Do not use `lamSortAtom2String` because we don't want to `declare-fun` - | .atom n => ctorname ← h2Symb (.term n) + | .atom n => + let .some (e, _) := lamVarTy[n]? + | throwError "lamMutualIndInfo2STerm :: {repr t} is not in lamVarTy" + ctorname ← h2Symb (.term n) e -- Do not use `lamSortEtom2String` because we don't want to `declare-fun` | .etom n => ctorname ← h2Symb (.etom n) | t => ctorname ← h2Symb (.compCtor t); compCtors := compCtors.push (ctorname, s, t) - let (argTys, _) ← lamSort2SSort s + let (argTys, _) ← lamSort2SSort tyVal s let mut selDecls := #[] if projs.isSome then if argTys.length != projInfos.size then @@ -391,15 +415,15 @@ private def lamMutualIndInfo2STerm (mind : MutualIndInfo) : infos := infos.push (sname, 0, ⟨#[], cstrDecls⟩) return (.declDtypes infos, compCtors, compProjs) -private def compEqn (lamVarTy lamEVarTy : Array LamSort) (compInfo : String × LamSort × LamTerm) : TransM LamAtom IR.SMT.Command := do +private def compEqn (lamVarTy : Array (Expr × LamSort)) (lamEVarTy : Array LamSort) (tyVal : Array (Expr × Level)) (compInfo : String × LamSort × LamTerm) : TransM LamAtom IR.SMT.Command := do let (name, s, t) := compInfo let argTys := s.getArgTys let sbvars := (List.range argTys.length).map (fun n => .bvar (argTys.length - 1 - n)) let slhs := .qStrApp name ⟨sbvars⟩ - let srhs := ← lamTerm2STerm lamVarTy lamEVarTy (LamTerm.bvarAppsRev t argTys).headBeta + let srhs := ← lamTerm2STerm lamVarTy lamEVarTy tyVal (LamTerm.bvarAppsRev t argTys).headBeta let mut eqn := pure (.qStrApp "=" #[slhs, srhs]) for s in argTys.reverse do - eqn := lamQuantified2STerm true s eqn + eqn := lamQuantified2STerm tyVal true s eqn return .assert (← eqn) def sortAuxDecls : Array IR.SMT.Command := @@ -429,21 +453,23 @@ def termAuxDecls : Array IR.SMT.Command := `valid_fact_{i}` corresponds to the `i`-th entry in `facts` -/ def lamFOL2SMT - (lamVarTy lamEVarTy : Array LamSort) + (lamVarTy : Array (Expr × LamSort)) + (lamEVarTy : Array LamSort) + (tyVal : Array (Expr × Level)) (facts : Array LamTerm) (minds : Array MutualIndInfo) : TransM LamAtom (Array IR.SMT.Command) := do let _ ← sortAuxDecls.mapM addCommand let _ ← termAuxDecls.mapM addCommand for mind in minds do - let (dsdecl, compCtors, compProjs) ← lamMutualIndInfo2STerm mind + let (dsdecl, compCtors, compProjs) ← lamMutualIndInfo2STerm lamVarTy lamEVarTy tyVal mind trace[auto.lamFOL2SMT] "MutualIndInfo translated to command {dsdecl}" addCommand dsdecl - let compCtorEqns ← compCtors.mapM (compEqn lamVarTy lamEVarTy) + let compCtorEqns ← compCtors.mapM (compEqn lamVarTy lamEVarTy tyVal) let _ ← compCtorEqns.mapM addCommand - let compProjEqns ← compProjs.mapM (compEqn lamVarTy lamEVarTy) + let compProjEqns ← compProjs.mapM (compEqn lamVarTy lamEVarTy tyVal) let _ ← compProjEqns.mapM addCommand for (t, idx) in facts.zipWithIndex do - let sterm ← lamTerm2STerm lamVarTy lamEVarTy t + let sterm ← lamTerm2STerm lamVarTy lamEVarTy tyVal t trace[auto.lamFOL2SMT] "λ term {repr t} translated to SMT term {sterm}" addCommand (.assert (.attr sterm #[.symb "named" s!"valid_fact_{idx}"])) getCommands diff --git a/Test/SmtTranslation/Names.lean b/Test/SmtTranslation/Names.lean new file mode 100644 index 0000000..0ab68db --- /dev/null +++ b/Test/SmtTranslation/Names.lean @@ -0,0 +1,85 @@ +import Auto.Tactic + +set_option auto.smt true +set_option auto.smt.trust true + +set_option trace.auto.printLemmas true +set_option trace.auto.lamReif.printValuation true +set_option trace.auto.lamReif.printResult true + +set_option trace.auto.smt.printCommands true +set_option trace.auto.smt.result true +set_option trace.auto.smt.unsatCore true +set_option trace.auto.smt.model true + +example : forall (n : Nat), n = n := by + intro n + auto + +class TotalOrder (t : Type) := + -- relation: total order + le (x y : t) : Bool + none : t + -- axioms + le_refl (x : t) : le x x + le_trans (x y z : t) : le x y → le y z → le x z + le_antisymm (x y : t) : le x y → le y x → x = y + le_total (x y : t) : le x y ∨ le y x + +class Quorum (node : Type) (quorum : outParam Type):= + -- relation + member (a : node) (q : quorum) : Bool + -- axioms + quorum_intersection : + ∀ (q1 q2 : quorum), ∃ (a : node), member a q1 ∧ member a q2 + +theorem extracted_paxos_goal {node : Type} [inst : DecidableEq node] {value : Type} [inst_1 : DecidableEq value] + {quorum : Type} [inst_2 : Quorum node quorum] {round : Type} [inst_3 : DecidableEq round] [inst_4 : TotalOrder round] + (st_one_a : round → Bool) (st_one_b_max_vote : node → round → round → value → Bool) + (st_one_b st_leftRound : node → round → Bool) (st_proposal : round → value → Bool) + (st_vote st_decision : node → round → value → Bool) + (hinv : + (∀ (n1 n2 : node) (r1 r2 : round) (v1 v2 : value), + st_decision n1 r1 v1 = true ∧ st_decision n2 r2 v2 = true → r1 = r2 ∧ v1 = v2) ∧ + (∀ (r : round) (v1 v2 : value), st_proposal r v1 = true ∧ st_proposal r v2 = true → v1 = v2) ∧ + (∀ (n : node) (r : round) (v : value), st_vote n r v = true → st_proposal r v = true) ∧ + (∀ (r : round) (v : value), + (∃ n, st_decision n r v = true) → ∃ q, ∀ (n : node), Quorum.member n q = true → st_vote n r v = true) ∧ + (∀ (n : node) (v : value), ¬st_vote n TotalOrder.none v = true) ∧ + (∀ (r1 r2 : round) (v1 v2 : value) (q : quorum), + ¬TotalOrder.le r2 r1 = true ∧ st_proposal r2 v2 = true ∧ v1 ≠ v2 → + ∃ n r3 rmax v, + Quorum.member n q = true ∧ + ¬TotalOrder.le r3 r1 = true ∧ st_one_b_max_vote n r3 rmax v = true ∧ ¬st_vote n r1 v1 = true) ∧ + ∀ (n : node) (r1 r2 : round), + st_one_b n r2 = true ∧ ¬TotalOrder.le r2 r1 = true → st_leftRound n r1 = true) + (st'_one_a : round → Bool) (st'_one_b_max_vote : node → round → round → value → Bool) + (st'_one_b st'_leftRound : node → round → Bool) (st'_proposal : round → value → Bool) + (st'_vote st'_decision : node → round → value → Bool) + (hnext : + ∃ n r max_round max_val, + r ≠ TotalOrder.none ∧ + st_one_a r = true ∧ + ¬st_leftRound n r = true ∧ + ((max_round = TotalOrder.none ∧ + ∀ (MAXR : round) (V : value), ¬(¬TotalOrder.le r MAXR = true ∧ st_vote n MAXR V = true)) ∨ + max_round ≠ TotalOrder.none ∧ + ¬TotalOrder.le r max_round = true ∧ + st_vote n max_round max_val = true ∧ + ∀ (MAXR : round) (V : value), + ¬TotalOrder.le r MAXR = true ∧ st_vote n MAXR V = true → TotalOrder.le MAXR max_round = true) ∧ + st'_one_a = st_one_a ∧ + (st'_one_b_max_vote = fun x x_1 x_2 x_3 => + if (x, x_1, x_2, x_3, ()) = (n, r, max_round, max_val, ()) then true + else st_one_b_max_vote x x_1 x_2 x_3) ∧ + (st'_one_b = fun x x_1 => if (x, x_1, ()) = (n, r, ()) then true else st_one_b x x_1) ∧ + (st'_leftRound = fun N R => decide (st_leftRound N R = true ∨ N = n ∧ ¬TotalOrder.le r R = true)) ∧ + st'_proposal = st_proposal ∧ st'_vote = st_vote ∧ st'_decision = st_decision) + (r1 r2 : round) (v1 v2 : value) (q : quorum) + (h : ¬TotalOrder.le r2 r1 = true ∧ st'_proposal r2 v2 = true ∧ v1 ≠ v2) : + ∃ n r3 rmax v, + Quorum.member n q = true ∧ + ¬TotalOrder.le r3 r1 = true ∧ st'_one_b_max_vote n r3 rmax v = true ∧ ¬st'_vote n r1 v1 = true := by + + auto [hnext, hinv, h] + sorry