diff --git a/Auto/Parser/SMTParser.lean b/Auto/Parser/SMTParser.lean index c2a1f5f..4523ba0 100644 --- a/Auto/Parser/SMTParser.lean +++ b/Auto/Parser/SMTParser.lean @@ -270,8 +270,8 @@ def smtSymbolToLeanName (s : String) : List (Name × SymbolInput) := | "=" => [(``Eq, TwoExactEq)] | _ => [] -def builtInSymbolMap : HashMap String Expr := - let map := HashMap.empty +def builtInSymbolMap : Std.HashMap String Expr := + let map := Std.HashMap.empty let map := map.insert "Nat" (mkConst ``Nat) let map := map.insert "Int" (mkConst ``Int) let map := map.insert "Bool" (.sort .zero) @@ -290,7 +290,7 @@ partial def getForallArgumentTypes (e : Expr) : List Expr := partial def getExplicitForallArgumentTypes (e : Expr) : List Expr := match e.consumeMData with | Expr.forallE _ t b .default => t :: (getExplicitForallArgumentTypes b) - | Expr.forallE _ t b _ => getExplicitForallArgumentTypes b -- Skip over t because this binder is implicit + | Expr.forallE _ _t b _ => getExplicitForallArgumentTypes b -- Skip over t because this binder is implicit | _ => [] inductive ParseTermConstraint @@ -339,7 +339,7 @@ def getNextSortedVars (originalSortedVars : Array (String × Expr)) (curPropBool mutual /-- Given a sorted var of the form `(symbol type)`, returns the string of the symbol and the type as an Expr -/ -partial def parseSortedVar (sortedVar : Term) (symbolMap : HashMap String Expr) : MetaM (String × Expr) := do +partial def parseSortedVar (sortedVar : Term) (symbolMap : Std.HashMap String Expr) : MetaM (String × Expr) := do match sortedVar with | app sortedVar => match sortedVar with @@ -352,7 +352,7 @@ partial def parseSortedVar (sortedVar : Term) (symbolMap : HashMap String Expr) | _ => throwError "parseSortedVar :: {sortedVar} is supposed to be a sortedVar, not an atom" partial def parseForallBodyWithSortedVars (vs : List Term) (sortedVars : Array (String × Expr)) - (symbolMap : HashMap String Expr) (forallBody : Term) : MetaM Expr := do + (symbolMap : Std.HashMap String Expr) (forallBody : Term) : MetaM Expr := do withLocalDeclsD (sortedVars.map fun (n, ty) => (n.toName, fun _ => pure ty)) fun _ => do let lctx ← getLCtx let mut symbolMap := symbolMap @@ -365,7 +365,7 @@ partial def parseForallBodyWithSortedVars (vs : List Term) (sortedVars : Array ( let body ← parseTerm forallBody symbolMap mustBeProp Meta.mkForallFVars (sortedVarDecls.map (fun decl => mkFVar decl.fvarId)) body -partial def parseForall (vs : List Term) (symbolMap : HashMap String Expr) : MetaM Expr := do +partial def parseForall (vs : List Term) (symbolMap : Std.HashMap String Expr) : MetaM Expr := do let [app sortedVars, forallBody] := vs | throwError "parseForall :: Unexpected input list {vs}" let sortedVars ← sortedVars.mapM (fun sv => parseSortedVar sv symbolMap) @@ -384,7 +384,7 @@ partial def parseForall (vs : List Term) (symbolMap : HashMap String Expr) : Met throwError "parseForall :: Failed to parse for all expression with vs: {vs}" partial def parseExistsBodyWithSortedVars (vs : List Term) (sortedVars : Array (String × Expr)) - (symbolMap : HashMap String Expr) (existsBody : Term) : MetaM Expr := do + (symbolMap : Std.HashMap String Expr) (existsBody : Term) : MetaM Expr := do withLocalDeclsD (sortedVars.map fun (n, ty) => (n.toName, fun _ => pure ty)) fun _ => do let lctx ← getLCtx let mut symbolMap := symbolMap @@ -401,7 +401,7 @@ partial def parseExistsBodyWithSortedVars (vs : List Term) (sortedVars : Array ( res ← Meta.mkAppM ``Exists #[res] return res -partial def parseExists (vs : List Term) (symbolMap : HashMap String Expr) : MetaM Expr := do +partial def parseExists (vs : List Term) (symbolMap : Std.HashMap String Expr) : MetaM Expr := do let [app sortedVars, existsBody] := vs | throwError "parseExists :: Unexpected input list {vs}" let sortedVars ← sortedVars.mapM (fun sv => parseSortedVar sv symbolMap) @@ -420,7 +420,7 @@ partial def parseExists (vs : List Term) (symbolMap : HashMap String Expr) : Met throwError "parseExists :: Failed to parse exists expression with vs: {vs}" /-- Given a varBinding of the form `(symbol value)` returns the string of the symbol, the type of the value, and the value itself -/ -partial def parseVarBinding (varBinding : Term) (symbolMap : HashMap String Expr) : MetaM (String × Expr × Expr) := do +partial def parseVarBinding (varBinding : Term) (symbolMap : Std.HashMap String Expr) : MetaM (String × Expr × Expr) := do match varBinding with | app varBinding => match varBinding with @@ -433,7 +433,7 @@ partial def parseVarBinding (varBinding : Term) (symbolMap : HashMap String Expr | _ => throwError "parseVarBinding :: Failed to parse {varBinding} as a var binding" | _ => throwError "parseVarBinding :: {varBinding} is supposed to be a varBinding, not an atom" -partial def parseLet (vs : List Term) (symbolMap : HashMap String Expr) (parseTermConstraint : ParseTermConstraint) : MetaM Expr := do +partial def parseLet (vs : List Term) (symbolMap : Std.HashMap String Expr) (parseTermConstraint : ParseTermConstraint) : MetaM Expr := do let [app varBindings, letBody] := vs | throwError "parsseLet :: Unexpected input list {vs}" let varBindings ← varBindings.mapM (fun vb => parseVarBinding vb symbolMap) @@ -453,7 +453,7 @@ partial def parseLet (vs : List Term) (symbolMap : HashMap String Expr) (parseTe res := .letE varBinding.1.toName varBinding.2.1 varBinding.2.2 res true return res -partial def parseLeftAssocAppAux (headSymbol : Name) (args : List Term) (symbolMap : HashMap String Expr) +partial def parseLeftAssocAppAux (headSymbol : Name) (args : List Term) (symbolMap : Std.HashMap String Expr) (acc : Expr) (parseTermConstraint : ParseTermConstraint) : MetaM Expr := do match args with | [] => return acc @@ -462,7 +462,7 @@ partial def parseLeftAssocAppAux (headSymbol : Name) (args : List Term) (symbolM let acc ← mkAppM headSymbol #[acc, arg] parseLeftAssocAppAux headSymbol restArgs symbolMap acc parseTermConstraint -partial def parseLeftAssocApp (headSymbol : Name) (args : List Term) (symbolMap : HashMap String Expr) +partial def parseLeftAssocApp (headSymbol : Name) (args : List Term) (symbolMap : Std.HashMap String Expr) (parseTermConstraint : ParseTermConstraint) : MetaM Expr := do match args with | arg1 :: (arg2 :: restArgs) => @@ -474,7 +474,7 @@ partial def parseLeftAssocApp (headSymbol : Name) (args : List Term) (symbolMap /-- Note: parseImplicationAux expects to receive args in reverse order (meaining if args = `[x, y, z]`, this should become `z => y => x`) -/ -partial def parseImplicationAux (args : List Term) (symbolMap : HashMap String Expr) (acc : Expr) : MetaM Expr := do +partial def parseImplicationAux (args : List Term) (symbolMap : Std.HashMap String Expr) (acc : Expr) : MetaM Expr := do match args with | [] => return acc | arg :: restArgs => @@ -483,7 +483,7 @@ partial def parseImplicationAux (args : List Term) (symbolMap : HashMap String E parseImplicationAux restArgs symbolMap acc /-- SMT implication is right associative -/ -partial def parseImplication (args : List Term) (symbolMap : HashMap String Expr) : MetaM Expr := do +partial def parseImplication (args : List Term) (symbolMap : Std.HashMap String Expr) : MetaM Expr := do match args.reverse with | lastArg :: (lastArg2 :: restArgs) => let lastArg ← parseTerm lastArg symbolMap mustBeProp @@ -492,7 +492,7 @@ partial def parseImplication (args : List Term) (symbolMap : HashMap String Expr /-- The entry function for the variety of mutually recursive functions used to parse SMT terms. `symbolMap` is used to map smt constants to the original Lean expressions they are meant to represent. `parseTermConstraint` is used to indicate whether the output expression must be a particular type. -/ -partial def parseTerm (e : Term) (symbolMap : HashMap String Expr) (parseTermConstraint : ParseTermConstraint) : MetaM Expr := do +partial def parseTerm (e : Term) (symbolMap : Std.HashMap String Expr) (parseTermConstraint : ParseTermConstraint) : MetaM Expr := do match e with | atom (num n) => match parseTermConstraint with @@ -512,7 +512,7 @@ partial def parseTerm (e : Term) (symbolMap : HashMap String Expr) (parseTermCon | mustBeProp => throwError "parseTerm :: {e} can be parsed but not as a Prop" | mustBeBool => throwError "parseTerm :: {e} can be parsed but not as a Bool" | atom (symb s) => - match symbolMap.find? s with + match symbolMap.get? s with | some v => match parseTermConstraint with | noConstraint => return v @@ -533,7 +533,7 @@ partial def parseTerm (e : Term) (symbolMap : HashMap String Expr) (parseTermCon else throwError "parseTerm :: {e} is parsed as {v} which is not a Bool" | none => - match builtInSymbolMap.find? s with + match builtInSymbolMap.get? s with | some v => match parseTermConstraint with | noConstraint => return v @@ -619,7 +619,7 @@ partial def parseTerm (e : Term) (symbolMap : HashMap String Expr) (parseTermCon whnf $ ← mkAppOptM ``decide #[some res, none] else throwError "parseTerm :: {e} is parsed as {res} which is not a Bool" - | arg1 :: (arg2 :: restArgs) => + | _arg1 :: (_arg2 :: _restArgs) => -- **TODO**: Interpret `(< a b c)` as `(and (< a b) (< b c))` throwError "parseTerm :: TwoExact symbol with more than two arguments not implemented yet (e: {e})" | _ => throwError "parseTerm :: Invalid application {e}" @@ -642,7 +642,7 @@ partial def parseTerm (e : Term) (symbolMap : HashMap String Expr) (parseTermCon | noConstraint => return res | mustBeProp => return res | mustBeBool => whnf $ ← mkAppOptM ``decide #[some res, none] - | arg1 :: (arg2 :: restArgs) => + | _arg1 :: (_arg2 :: _restArgs) => -- **TODO**: Interpret `(= a b c)` as `(and (= a b) (= b c))` throwError "parseTerm :: TwoExact symbol with more than two arguments not implemented yet (e: {e})" | _ => throwError "parseTerm :: Invalid application {e}" @@ -682,7 +682,7 @@ partial def parseTerm (e : Term) (symbolMap : HashMap String Expr) (parseTermCon | mustBeProp => throwError "parseTerm :: {e} has minus as a head symbol which cannot yield a result of type Prop" | mustBeBool => throwError "parseTerm :: {e} has minus as a head symbol which cannot yield a result of type Bool" | [] => - match symbolMap.find? s with + match symbolMap.get? s with | some symbolExp => let symbolExpType ← inferType symbolExp let expectedArgTypes := getExplicitForallArgumentTypes symbolExpType @@ -724,13 +724,13 @@ initialize /-- Calls `parseTerm` on `e` and then abstracts all of the metavariables corresponding to selectors given by `selMVars` (replacing the first metavariable with `selMVars` with `Expr.bvar 0` and so on) -/ -def parseTermAndAbstractSelectors (e : Term) (symbolMap : HashMap String Expr) (selMVars : Array Expr) : MetaM Expr := do +def parseTermAndAbstractSelectors (e : Term) (symbolMap : Std.HashMap String Expr) (selMVars : Array Expr) : MetaM Expr := do let res ← parseTerm e symbolMap noConstraint res.abstractM selMVars /-- Calls `parseTerm` on `e` and then abstracts all of the metavariables corresponding to selectors given by `selMVars` (replacing the first metavariable with `selMVars` with `Expr.bvar 0` and so on). Returns `none` if any error occurs. -/ -def tryParseTermAndAbstractSelectors (e : Term) (symbolMap : HashMap String Expr) (selMVars : Array Expr) : MetaM (Option Expr) := do +def tryParseTermAndAbstractSelectors (e : Term) (symbolMap : Std.HashMap String Expr) (selMVars : Array Expr) : MetaM (Option Expr) := do try let res ← parseTerm e symbolMap noConstraint res.abstractM selMVars diff --git a/Auto/Solver/SMT.lean b/Auto/Solver/SMT.lean index 383b152..c094ae7 100644 --- a/Auto/Solver/SMT.lean +++ b/Auto/Solver/SMT.lean @@ -222,6 +222,7 @@ def querySolverWithHints (query : Array IR.SMT.Command) emitCommand solver .checkSat let stdout ← solver.stdout.getLine trace[auto.smt.result] "checkSatResponse: {stdout}" + -- **TODO** When checkSatResponse is sat, the below getTerm call can throw an error let (checkSatResponse, _) ← getTerm stdout match checkSatResponse with | .atom (.symb "sat") => @@ -229,12 +230,15 @@ def querySolverWithHints (query : Array IR.SMT.Command) let (_, solver) ← solver.takeStdin let stdout ← solver.stdout.readToEnd let stderr ← solver.stderr.readToEnd - let (model, _) ← getTerm stdout solver.kill - trace[auto.smt.result] "{name} says Sat" - trace[auto.smt.model] "Model:\n{model}" - trace[auto.smt.stderr] "stderr:\n{stderr}" - return .none + try + let (model, _) ← getTerm stdout + trace[auto.smt.result] "{name} says Sat" + trace[auto.smt.model] "Model:\n{model}" + trace[auto.smt.stderr] "stderr:\n{stderr}" + return .none + catch _ => -- Don't let a failure to parse the model prevent `querySolverWithHints` from returning `none` + return .none | .atom (.symb "unsat") => emitCommand solver (.echo "Unsat core:") emitCommand solver .getUnsatCore diff --git a/Auto/Tactic.lean b/Auto/Tactic.lean index 904dc0d..065e323 100644 --- a/Auto/Tactic.lean +++ b/Auto/Tactic.lean @@ -172,7 +172,7 @@ def collectUserLemmas (terms : Array Term) : TacticM (Array Lemma) := return lemmas def collectHintDBLemmas (names : Array Name) : TacticM (Array Lemma) := do - let mut hs : HashSet Name := HashSet.empty + let mut hs : Std.HashSet Name := Std.HashSet.empty let mut ret : Array Lemma := #[] for name in names do let .some db ← findLemDB name @@ -427,7 +427,7 @@ def querySMTForHints (exportFacts : Array REntry) (exportInds : Array MutualIndI let lamEVarTy ← LamReif.getLamEVarTy let exportLamTerms ← exportFacts.mapM (fun re => do match re with - | .valid [] t => return t + | .valid [] t => pure t | _ => throwError "runAuto :: Unexpected error") let sni : SMT.SMTNamingInfo := {tyVal := (← LamReif.getTyVal), varVal := (← LamReif.getVarVal), lamEVarTy := (← LamReif.getLamEVarTy)} @@ -436,8 +436,8 @@ def querySMTForHints (exportFacts : Array REntry) (exportInds : Array MutualIndI trace[auto.smt.printCommands] "{cmd}" if (auto.smt.save.get (← getOptions)) then Solver.SMT.saveQuery commands - let .some (unsatCore, solverHints, proof) ← Solver.SMT.querySolverWithHints commands - | return .none + let some (unsatCore, solverHints, _proof) ← Solver.SMT.querySolverWithHints commands + | return none let unsatCoreIds ← Solver.SMT.validFactOfUnsatCore unsatCore -- **Print valuation of SMT atoms** SMT.withExprValuation sni state.h2lMap (fun tyValMap varValMap etomValMap => do @@ -447,13 +447,13 @@ def querySMTForHints (exportFacts : Array REntry) (exportInds : Array MutualIndI ) -- **Print STerms corresponding to `validFacts` in unsatCore** for id in unsatCoreIds do - let .some sterm := validFacts[id]? + let some sterm := validFacts[id]? | throwError "runAuto :: Index {id} of `validFacts` out of range" trace[auto.smt.unsatCore.smtTerms] "|valid_fact_{id}| : {sterm}" -- **Print Lean expressions correesponding to `validFacts` in unsatCore** SMT.withExprValuation sni state.h2lMap (fun tyValMap varValMap etomValMap => do for id in unsatCoreIds do - let .some t := exportLamTerms[id]? + let some t := exportLamTerms[id]? | throwError "runAuto :: Index {id} of `exportLamTerms` out of range" let e ← Lam2D.interpLamTermAsUnlifted tyValMap varValMap etomValMap 0 t trace[auto.smt.unsatCore.leanExprs] "|valid_fact_{id}| : {← Core.betaReduce e}" @@ -462,14 +462,14 @@ def querySMTForHints (exportFacts : Array REntry) (exportInds : Array MutualIndI -- `unsatCoreDerivLeafStrings` contains all of the strings that appear as leaves in any derivation for any fact in the unsat core let mut unsatCoreDerivLeafStrings := #[] for id in unsatCoreIds do - let .some t := exportLamTerms[id]? + let some t := exportLamTerms[id]? | throwError "runAuto :: Index {id} of `exportLamTerm` out of range" let vderiv ← LamReif.collectDerivFor (.valid [] t) unsatCoreDerivLeafStrings := unsatCoreDerivLeafStrings ++ vderiv.collectLeafStrings trace[auto.smt.unsatCore.deriv] "|valid_fact_{id}| : {vderiv}" -- **Build symbolPrecMap using l2hMap and selInfos** let (preprocessFacts, theoryLemmas, instantiations, computationLemmas, polynomialLemmas, rewriteFacts) := solverHints - let mut symbolMap : HashMap String Expr := HashMap.empty + let mut symbolMap : Std.HashMap String Expr := Std.HashMap.empty for (varName, varAtom) in l2hMap.toArray do let varLeanExp ← SMT.withExprValuation sni state.h2lMap (fun tyValMap varValMap etomValMap => do @@ -488,7 +488,7 @@ def querySMTForHints (exportFacts : Array REntry) (exportInds : Array MutualIndI let selOutputType ← SMT.withExprValuation sni state.h2lMap (fun tyValMap _ _ => Lam2D.interpLamSortAsUnlifted tyValMap selOutputType) let selDatatype ← - match symbolMap.find? datatypeName with + match symbolMap.get? datatypeName with | some selDatatype => pure selDatatype | none => throwError "querySMTForHints :: Could not find the datatype {datatypeName} corresponding to selector {selName}" let selType := Expr.forallE `x selDatatype selOutputType .default @@ -501,28 +501,31 @@ def querySMTForHints (exportFacts : Array REntry) (exportInds : Array MutualIndI (fun (selName, selCtor, argIdx, selMVar) => return (selName, selCtor, argIdx, ← Meta.inferType selMVar)) -- **Extract solverLemmas from solverHints** if ← auto.getHints.getFailOnParseErrorM then - let preprocessFacts ← preprocessFacts.mapM - (fun lemTerm => Parser.SMTTerm.parseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) - let theoryLemmas ← theoryLemmas.mapM - (fun lemTerm => Parser.SMTTerm.parseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) - let instantiations ← instantiations.mapM - (fun lemTerm => Parser.SMTTerm.parseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) - let computationLemmas ← computationLemmas.mapM - (fun lemTerm => Parser.SMTTerm.parseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) - let polynomialLemmas ← polynomialLemmas.mapM - (fun lemTerm => Parser.SMTTerm.parseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) - let rewriteFacts ← rewriteFacts.mapM - (fun rwFacts => do - match rwFacts with - | [] => return [] - | rwRule :: ruleInstances => - /- Try to parse `rwRule`. If succesful, just return that. If unsuccessful (e.g. because the rule contains approximate types), - then parse each quantifier-free instance of `rwRule` in `ruleInstances` and return all of those. -/ - match ← Parser.SMTTerm.tryParseTermAndAbstractSelectors rwRule symbolMap selectorMVars with - | some parsedRule => return [parsedRule] - | none => ruleInstances.mapM (fun lemTerm => Parser.SMTTerm.parseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) - ) - return some (unsatCoreDerivLeafStrings, selectorArr, preprocessFacts, theoryLemmas, instantiations, computationLemmas, polynomialLemmas, rewriteFacts) + try + let preprocessFacts ← preprocessFacts.mapM + (fun lemTerm => Parser.SMTTerm.parseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) + let theoryLemmas ← theoryLemmas.mapM + (fun lemTerm => Parser.SMTTerm.parseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) + let instantiations ← instantiations.mapM + (fun lemTerm => Parser.SMTTerm.parseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) + let computationLemmas ← computationLemmas.mapM + (fun lemTerm => Parser.SMTTerm.parseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) + let polynomialLemmas ← polynomialLemmas.mapM + (fun lemTerm => Parser.SMTTerm.parseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) + let rewriteFacts ← rewriteFacts.mapM + (fun rwFacts => do + match rwFacts with + | [] => return [] + | rwRule :: ruleInstances => + /- Try to parse `rwRule`. If succesful, just return that. If unsuccessful (e.g. because the rule contains approximate types), + then parse each quantifier-free instance of `rwRule` in `ruleInstances` and return all of those. -/ + match ← Parser.SMTTerm.tryParseTermAndAbstractSelectors rwRule symbolMap selectorMVars with + | some parsedRule => return [parsedRule] + | none => ruleInstances.mapM (fun lemTerm => Parser.SMTTerm.parseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) + ) + return some (unsatCoreDerivLeafStrings, selectorArr, preprocessFacts, theoryLemmas, instantiations, computationLemmas, polynomialLemmas, rewriteFacts) + catch e => + throwError "querySMTForHints :: Encountered error trying to parse SMT solver's hints. Error: {e.toMessageData}" else let preprocessFacts ← preprocessFacts.mapM (fun lemTerm => Parser.SMTTerm.tryParseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) let theoryLemmas ← theoryLemmas.mapM (fun lemTerm => Parser.SMTTerm.tryParseTermAndAbstractSelectors lemTerm symbolMap selectorMVars) @@ -751,7 +754,7 @@ def runAutoGetHints (lemmas : Array Lemma) (inhFacts : Array Lemma) : MetaM solv if let .some solverHints ← querySMTForHints exportFacts exportInds then return solverHints else - throwError "runAutoGetHints :: querySMTForHints failed to return solverHints" + throwError "runAutoGetHints :: SMT solver was unable to find a proof" else throwError "runAutoGetHints :: Either auto.smt or auto.tptp must be enabled" )