Skip to content

Commit

Permalink
Preprocessing: Fix incremental preprocessing for theory combination
Browse files Browse the repository at this point in the history
After recent changes in preprocessing, theory specific rewriting now
don't see the whole formula, only parts that have not been processed
yet. This, however, broke an assumption made by theory combination.
Previously, it assumed it will see the whole formula, and it extracts
interface variables from it.

To fix this problem, we remember all processed formulas and notify
Theory about them at the end of each preprocessing. This means, we do
not change how we extract and set interface variables, but we need to
keep this extra information about preprocessed formulas.

We also start storing information related to preprocessing in a new
Preprocessor class. We should eventually move all preprocessing related
matters out of MainSolver and into this helper class.
  • Loading branch information
blishko committed May 7, 2024
1 parent d3c57fa commit 7f764cd
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 33 deletions.
36 changes: 31 additions & 5 deletions src/api/MainSolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ MainSolver::MainSolver(std::unique_ptr<Theory> th, std::unique_ptr<TermMapper> t
void MainSolver::initialize() {
frames.push();
frameTerms.push(logic.getTerm_true());
substitutions.push();
preprocessor.initialize();
smt_solver->initialize();
opensmt::pair<CRef, CRef> iorefs{CRef_Undef, CRef_Undef};
smt_solver->addOriginalSMTClause({term_mapper->getOrCreateLit(logic.getTerm_true())}, iorefs);
Expand All @@ -74,7 +74,7 @@ void MainSolver::push()
{
bool alreadyUnsat = isLastFrameUnsat();
frames.push();
substitutions.push();
preprocessor.push();
frameTerms.push(newFrameTerm(frames.last().getId()));
if (alreadyUnsat) { rememberLastFrameUnsat(); }
}
Expand All @@ -93,7 +93,7 @@ bool MainSolver::pop()
pmanager.invalidatePartitions(mask);
}
frames.pop();
substitutions.pop();
preprocessor.pop();
firstNotSimplifiedFrame = std::min(firstNotSimplifiedFrame, frames.frameCount());
if (not isLastFrameUnsat()) {
getSMTSolver().restoreOK();
Expand Down Expand Up @@ -135,10 +135,12 @@ sstat MainSolver::simplifyFormulas() {
PTRef processed = theory->preprocessAfterSubstitutions(fla, context);
pmanager.transferPartitionMembership(fla, processed);
frameFormulas.push(processed);
preprocessor.addPreprocessedFormula(processed);
}
if (frameFormulas.size() == 0 or std::all_of(frameFormulas.begin(), frameFormulas.end(), [&](PTRef fla) { return fla == logic.getTerm_true(); })) {
continue;
}
theory->afterPreprocessing(preprocessor.getPreprocessedFormulas());
for (PTRef fla : frameFormulas) {
if (fla == logic.getTerm_true()) { continue; }
assert(pmanager.getPartitionIndex(fla) != -1);
Expand Down Expand Up @@ -167,6 +169,8 @@ sstat MainSolver::simplifyFormulas() {
status = s_False;
break;
}
preprocessor.addPreprocessedFormula(frameFormula);
theory->afterPreprocessing(preprocessor.getPreprocessedFormulas());
// Optimize the dag for cnfization
if (logic.isBooleanOperator(frameFormula)) {
frameFormula = rewriteMaxArity(frameFormula);
Expand Down Expand Up @@ -481,7 +485,7 @@ std::unique_ptr<Theory> MainSolver::createTheory(Logic & logic, SMTConfig & conf
}

PTRef MainSolver::applyLearntSubstitutions(PTRef fla) {
Logic::SubstMap knownSubst = substitutions.current();
Logic::SubstMap knownSubst = preprocessor.getCurrentSubstitutions();
PTRef res = Substitutor(getLogic(), knownSubst).rewrite(fla);
return res;
}
Expand All @@ -498,7 +502,7 @@ PTRef MainSolver::substitutionPass(PTRef fla, PreprocessingContext const& contex
args.push(res.result);
PTRef withSubs = logic.mkAnd(std::move(args));

substitutions.set(context.frameCount, std::move(res.usedSubstitution));
preprocessor.setSubstitutions(context.frameCount, std::move(res.usedSubstitution));
return withSubs;
}

Expand Down Expand Up @@ -550,3 +554,25 @@ MainSolver::SubstitutionResult MainSolver::computeSubstitutions(PTRef fla) {
result.usedSubstitution = std::move(allsubsts);
return result;
}

void MainSolver::Preprocessor::initialize() {
substitutions.push();
}

void MainSolver::Preprocessor::push() {
substitutions.push();
preprocessedFormulas.pushScope();
}

void MainSolver::Preprocessor::pop() {
substitutions.pop();
preprocessedFormulas.popScope();
}

void MainSolver::Preprocessor::addPreprocessedFormula(PTRef fla) {
preprocessedFormulas.push(fla);
}

opensmt::span<const PTRef> MainSolver::Preprocessor::getPreprocessedFormulas() const {
return {preprocessedFormulas.data(), static_cast<uint32_t>(preprocessedFormulas.size())};
}
59 changes: 36 additions & 23 deletions src/api/MainSolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "Model.h"
#include "PartitionManager.h"
#include "InterpolationContext.h"
#include "ScopedVector.h"

#include <memory>

Expand Down Expand Up @@ -157,32 +158,44 @@ class MainSolver {
uint32_t frameId = 0;
};

class Substitutions {
public:
void push() { perFrameSubst.emplace_back(); }
void pop() { perFrameSubst.pop_back(); }
struct SubstitutionResult {
Logic::SubstMap usedSubstitution;
PTRef result {PTRef_Undef};
};

void set(std::size_t level, Logic::SubstMap && subs) {
perFrameSubst.at(level) = std::move(subs);
}
class Preprocessor {
public:
void push();
void pop();
void initialize();
void addPreprocessedFormula(PTRef);
[[nodiscard]] opensmt::span<const PTRef> getPreprocessedFormulas() const;
[[nodiscard]] Logic::SubstMap getCurrentSubstitutions() const { return substitutions.current(); }
void setSubstitutions(std::size_t level, Logic::SubstMap && subs) { substitutions.set(level, std::move(subs)); }

Logic::SubstMap current() {
Logic::SubstMap allSubst;
for (auto const & subs : perFrameSubst) {
for (PTRef key : subs.getKeys()) {
assert(not allSubst.has(key));
allSubst.insert(key, subs[key]);
private:
class Substitutions {
public:
void push() { perFrameSubst.emplace_back(); }
void pop() { perFrameSubst.pop_back(); }

void set(std::size_t level, Logic::SubstMap && subs) { perFrameSubst.at(level) = std::move(subs); }

[[nodiscard]] Logic::SubstMap current() const {
Logic::SubstMap allSubst;
for (auto const & subs : perFrameSubst) {
for (PTRef key : subs.getKeys()) {
assert(not allSubst.has(key));
allSubst.insert(key, subs[key]);
}
}
return allSubst;
}
return allSubst;
}
private:
std::vector<Logic::SubstMap> perFrameSubst;
};

struct SubstitutionResult {
Logic::SubstMap usedSubstitution;
PTRef result {PTRef_Undef};
private:
std::vector<Logic::SubstMap> perFrameSubst;
};
Substitutions substitutions;
opensmt::ScopedVector<PTRef> preprocessedFormulas;
};

Theory & getTheory() { return *theory; }
Expand Down Expand Up @@ -242,12 +255,12 @@ class MainSolver {
PartitionManager pmanager;
SMTConfig & config;
Tseitin ts;
Preprocessor preprocessor;

opensmt::OSMTTimeVal query_timer; // How much time we spend solving.
std::string solver_name; // Name for the solver
int check_called = 0; // A counter on how many times check was called.

Substitutions substitutions;
vec<PTRef> frameTerms;
std::size_t firstNotSimplifiedFrame = 0;
unsigned int insertedFormulasCount = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ target_sources(common
install(FILES
Integer.h Number.h FastRational.h XAlloc.h Alloc.h StringMap.h Timer.h osmtinttypes.h
TreeOps.h Real.h FlaPartitionMap.h PartitionInfo.h OsmtApiException.h TypeUtils.h
NumberUtils.h NatSet.h
NumberUtils.h NatSet.h ScopedVector.h
DESTINATION ${INSTALL_HEADERS_DIR})


1 change: 1 addition & 0 deletions src/logics/Theory.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class Theory

virtual PTRef preprocessBeforeSubstitutions(PTRef fla, PreprocessingContext const &) { return fla; }
virtual PTRef preprocessAfterSubstitutions(PTRef, PreprocessingContext const &) = 0;
virtual void afterPreprocessing(opensmt::span<const PTRef>) {}

virtual ~Theory() = default;
};
Expand Down
9 changes: 6 additions & 3 deletions src/logics/UFLATheory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ PTRef UFLATheory::preprocessAfterSubstitutions(PTRef fla, PreprocessingContext c
purified = instantiateReadOverStore(logic, purified);
}
PTRef noArithmeticEqualities = splitArithmeticEqualities(purified);
this->getTSolverHandler().setInterfaceVars(getInterfaceVars(noArithmeticEqualities));
return noArithmeticEqualities;
}

void UFLATheory::afterPreprocessing(opensmt::span<const PTRef> preprocessedFormulas) {
this->getTSolverHandler().setInterfaceVars(getInterfaceVars(preprocessedFormulas));
}

namespace {
bool isArithmeticSymbol(ArithLogic const & logic, SymRef sym) {
return logic.isPlus(sym) or logic.isTimes(sym) or logic.isLeq(sym);
Expand Down Expand Up @@ -201,9 +204,9 @@ class CollectInterfaceVariablesConfig : public DefaultVisitorConfig {



vec<PTRef> UFLATheory::getInterfaceVars(PTRef fla) {
vec<PTRef> UFLATheory::getInterfaceVars(opensmt::span<const PTRef> flas) {
CollectInterfaceVariablesConfig config(logic);
TermVisitor(logic, config).visit(fla);
TermVisitor(logic, config).visit(flas);
vec<PTRef> const & interfaceVars = config.getInterfaceVars();
vec<PTRef> ret;
interfaceVars.copyTo(ret);
Expand Down
3 changes: 2 additions & 1 deletion src/logics/UFLATheory.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ class UFLATheory : public Theory
virtual UFLATHandler& getTSolverHandler() override { return uflatshandler; }

virtual PTRef preprocessAfterSubstitutions(PTRef, PreprocessingContext const &) override;
virtual void afterPreprocessing(opensmt::span<const PTRef> preprocessedFormulas) override;

protected:
PTRef purify(PTRef fla);
PTRef splitArithmeticEqualities(PTRef fla);
vec<PTRef> getInterfaceVars(PTRef fla);
vec<PTRef> getInterfaceVars(opensmt::span<const PTRef> flas);
};

#endif

0 comments on commit 7f764cd

Please sign in to comment.