From 7f764cddacab9ea00913aa0cb8e66ed3ebbbda49 Mon Sep 17 00:00:00 2001 From: Martin Blicha Date: Fri, 3 May 2024 09:02:34 +0200 Subject: [PATCH] Preprocessing: Fix incremental preprocessing for theory combination After recent changes in preprocessing, theory-specific rewriting no longer sees the whole formula, only the parts that have not been processed yet. This, however, broke an assumption made by theory combination. Previously, theory combination assumed it would see the whole formula, from which it extracts interface variables. To fix this problem, we remember all processed formulas and notify Theory about them at the end of each preprocessing pass. This means we do not change how we extract and set interface variables, but we need to keep this extra information about preprocessed formulas. We also start storing information related to preprocessing in a new Preprocessor class. We should eventually move all preprocessing-related matters out of MainSolver and into this helper class. --- src/api/MainSolver.cc | 36 ++++++++++++++++++++---- src/api/MainSolver.h | 59 ++++++++++++++++++++++++--------------- src/common/CMakeLists.txt | 2 +- src/logics/Theory.h | 1 + src/logics/UFLATheory.cc | 9 ++++-- src/logics/UFLATheory.h | 3 +- 6 files changed, 77 insertions(+), 33 deletions(-) diff --git a/src/api/MainSolver.cc b/src/api/MainSolver.cc index aea1af696..a531a5a73 100644 --- a/src/api/MainSolver.cc +++ b/src/api/MainSolver.cc @@ -60,7 +60,7 @@ MainSolver::MainSolver(std::unique_ptr th, std::unique_ptr t void MainSolver::initialize() { frames.push(); frameTerms.push(logic.getTerm_true()); - substitutions.push(); + preprocessor.initialize(); smt_solver->initialize(); opensmt::pair iorefs{CRef_Undef, CRef_Undef}; smt_solver->addOriginalSMTClause({term_mapper->getOrCreateLit(logic.getTerm_true())}, iorefs); @@ -74,7 +74,7 @@ void MainSolver::push() { bool alreadyUnsat = isLastFrameUnsat(); frames.push(); - substitutions.push(); + preprocessor.push(); 
frameTerms.push(newFrameTerm(frames.last().getId())); if (alreadyUnsat) { rememberLastFrameUnsat(); } } @@ -93,7 +93,7 @@ bool MainSolver::pop() pmanager.invalidatePartitions(mask); } frames.pop(); - substitutions.pop(); + preprocessor.pop(); firstNotSimplifiedFrame = std::min(firstNotSimplifiedFrame, frames.frameCount()); if (not isLastFrameUnsat()) { getSMTSolver().restoreOK(); @@ -135,10 +135,12 @@ sstat MainSolver::simplifyFormulas() { PTRef processed = theory->preprocessAfterSubstitutions(fla, context); pmanager.transferPartitionMembership(fla, processed); frameFormulas.push(processed); + preprocessor.addPreprocessedFormula(processed); } if (frameFormulas.size() == 0 or std::all_of(frameFormulas.begin(), frameFormulas.end(), [&](PTRef fla) { return fla == logic.getTerm_true(); })) { continue; } + theory->afterPreprocessing(preprocessor.getPreprocessedFormulas()); for (PTRef fla : frameFormulas) { if (fla == logic.getTerm_true()) { continue; } assert(pmanager.getPartitionIndex(fla) != -1); @@ -167,6 +169,8 @@ sstat MainSolver::simplifyFormulas() { status = s_False; break; } + preprocessor.addPreprocessedFormula(frameFormula); + theory->afterPreprocessing(preprocessor.getPreprocessedFormulas()); // Optimize the dag for cnfization if (logic.isBooleanOperator(frameFormula)) { frameFormula = rewriteMaxArity(frameFormula); @@ -481,7 +485,7 @@ std::unique_ptr MainSolver::createTheory(Logic & logic, SMTConfig & conf } PTRef MainSolver::applyLearntSubstitutions(PTRef fla) { - Logic::SubstMap knownSubst = substitutions.current(); + Logic::SubstMap knownSubst = preprocessor.getCurrentSubstitutions(); PTRef res = Substitutor(getLogic(), knownSubst).rewrite(fla); return res; } @@ -498,7 +502,7 @@ PTRef MainSolver::substitutionPass(PTRef fla, PreprocessingContext const& contex args.push(res.result); PTRef withSubs = logic.mkAnd(std::move(args)); - substitutions.set(context.frameCount, std::move(res.usedSubstitution)); + preprocessor.setSubstitutions(context.frameCount, 
std::move(res.usedSubstitution)); return withSubs; } @@ -550,3 +554,25 @@ MainSolver::SubstitutionResult MainSolver::computeSubstitutions(PTRef fla) { result.usedSubstitution = std::move(allsubsts); return result; } + +void MainSolver::Preprocessor::initialize() { + substitutions.push(); +} + +void MainSolver::Preprocessor::push() { + substitutions.push(); + preprocessedFormulas.pushScope(); +} + +void MainSolver::Preprocessor::pop() { + substitutions.pop(); + preprocessedFormulas.popScope(); +} + +void MainSolver::Preprocessor::addPreprocessedFormula(PTRef fla) { + preprocessedFormulas.push(fla); +} + +opensmt::span MainSolver::Preprocessor::getPreprocessedFormulas() const { + return {preprocessedFormulas.data(), static_cast(preprocessedFormulas.size())}; +} \ No newline at end of file diff --git a/src/api/MainSolver.h b/src/api/MainSolver.h index 70b48d84b..88758061e 100644 --- a/src/api/MainSolver.h +++ b/src/api/MainSolver.h @@ -15,6 +15,7 @@ #include "Model.h" #include "PartitionManager.h" #include "InterpolationContext.h" +#include "ScopedVector.h" #include @@ -157,32 +158,44 @@ class MainSolver { uint32_t frameId = 0; }; - class Substitutions { - public: - void push() { perFrameSubst.emplace_back(); } - void pop() { perFrameSubst.pop_back(); } + struct SubstitutionResult { + Logic::SubstMap usedSubstitution; + PTRef result {PTRef_Undef}; + }; - void set(std::size_t level, Logic::SubstMap && subs) { - perFrameSubst.at(level) = std::move(subs); - } + class Preprocessor { + public: + void push(); + void pop(); + void initialize(); + void addPreprocessedFormula(PTRef); + [[nodiscard]] opensmt::span getPreprocessedFormulas() const; + [[nodiscard]] Logic::SubstMap getCurrentSubstitutions() const { return substitutions.current(); } + void setSubstitutions(std::size_t level, Logic::SubstMap && subs) { substitutions.set(level, std::move(subs)); } - Logic::SubstMap current() { - Logic::SubstMap allSubst; - for (auto const & subs : perFrameSubst) { - for (PTRef key : 
subs.getKeys()) { - assert(not allSubst.has(key)); - allSubst.insert(key, subs[key]); + private: + class Substitutions { + public: + void push() { perFrameSubst.emplace_back(); } + void pop() { perFrameSubst.pop_back(); } + + void set(std::size_t level, Logic::SubstMap && subs) { perFrameSubst.at(level) = std::move(subs); } + + [[nodiscard]] Logic::SubstMap current() const { + Logic::SubstMap allSubst; + for (auto const & subs : perFrameSubst) { + for (PTRef key : subs.getKeys()) { + assert(not allSubst.has(key)); + allSubst.insert(key, subs[key]); + } } + return allSubst; } - return allSubst; - } - private: - std::vector perFrameSubst; - }; - - struct SubstitutionResult { - Logic::SubstMap usedSubstitution; - PTRef result {PTRef_Undef}; + private: + std::vector perFrameSubst; + }; + Substitutions substitutions; + opensmt::ScopedVector preprocessedFormulas; }; Theory & getTheory() { return *theory; } @@ -242,12 +255,12 @@ class MainSolver { PartitionManager pmanager; SMTConfig & config; Tseitin ts; + Preprocessor preprocessor; opensmt::OSMTTimeVal query_timer; // How much time we spend solving. std::string solver_name; // Name for the solver int check_called = 0; // A counter on how many times check was called. 
- Substitutions substitutions; vec frameTerms; std::size_t firstNotSimplifiedFrame = 0; unsigned int insertedFormulasCount = 0; diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index 57bc56054..20447fdc9 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -24,7 +24,7 @@ target_sources(common install(FILES Integer.h Number.h FastRational.h XAlloc.h Alloc.h StringMap.h Timer.h osmtinttypes.h TreeOps.h Real.h FlaPartitionMap.h PartitionInfo.h OsmtApiException.h TypeUtils.h - NumberUtils.h NatSet.h + NumberUtils.h NatSet.h ScopedVector.h DESTINATION ${INSTALL_HEADERS_DIR}) diff --git a/src/logics/Theory.h b/src/logics/Theory.h index ac9bc5dfc..3d52f526f 100644 --- a/src/logics/Theory.h +++ b/src/logics/Theory.h @@ -57,6 +57,7 @@ class Theory virtual PTRef preprocessBeforeSubstitutions(PTRef fla, PreprocessingContext const &) { return fla; } virtual PTRef preprocessAfterSubstitutions(PTRef, PreprocessingContext const &) = 0; + virtual void afterPreprocessing(opensmt::span) {} virtual ~Theory() = default; }; diff --git a/src/logics/UFLATheory.cc b/src/logics/UFLATheory.cc index 714068233..58ac0db3f 100644 --- a/src/logics/UFLATheory.cc +++ b/src/logics/UFLATheory.cc @@ -17,10 +17,13 @@ PTRef UFLATheory::preprocessAfterSubstitutions(PTRef fla, PreprocessingContext c purified = instantiateReadOverStore(logic, purified); } PTRef noArithmeticEqualities = splitArithmeticEqualities(purified); - this->getTSolverHandler().setInterfaceVars(getInterfaceVars(noArithmeticEqualities)); return noArithmeticEqualities; } +void UFLATheory::afterPreprocessing(opensmt::span preprocessedFormulas) { + this->getTSolverHandler().setInterfaceVars(getInterfaceVars(preprocessedFormulas)); +} + namespace { bool isArithmeticSymbol(ArithLogic const & logic, SymRef sym) { return logic.isPlus(sym) or logic.isTimes(sym) or logic.isLeq(sym); @@ -201,9 +204,9 @@ class CollectInterfaceVariablesConfig : public DefaultVisitorConfig { -vec 
UFLATheory::getInterfaceVars(PTRef fla) { +vec UFLATheory::getInterfaceVars(opensmt::span flas) { CollectInterfaceVariablesConfig config(logic); - TermVisitor(logic, config).visit(fla); + TermVisitor(logic, config).visit(flas); vec const & interfaceVars = config.getInterfaceVars(); vec ret; interfaceVars.copyTo(ret); diff --git a/src/logics/UFLATheory.h b/src/logics/UFLATheory.h index eabc410c1..6b7cba848 100644 --- a/src/logics/UFLATheory.h +++ b/src/logics/UFLATheory.h @@ -46,11 +46,12 @@ class UFLATheory : public Theory virtual UFLATHandler& getTSolverHandler() override { return uflatshandler; } virtual PTRef preprocessAfterSubstitutions(PTRef, PreprocessingContext const &) override; + virtual void afterPreprocessing(opensmt::span preprocessedFormulas) override; protected: PTRef purify(PTRef fla); PTRef splitArithmeticEqualities(PTRef fla); - vec getInterfaceVars(PTRef fla); + vec getInterfaceVars(opensmt::span flas); }; #endif