Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ModelValueCache rework #2675

Merged
merged 8 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches

import at.forsyte.apalache.tla.bmcmt.smt.SolverContext
import at.forsyte.apalache.tla.bmcmt.types.{CellT, CellTFrom}
import at.forsyte.apalache.tla.bmcmt.{ArenaCell, PureArena}
import at.forsyte.apalache.tla.lir.{ConstT1, StrT1, TlaType1}
import at.forsyte.apalache.tla.types.tla

/**
* A cache for uninterpreted literals, that are translated to uninterpreted SMT constants, with a unique sort per
* uninterpreted type. Since two values are equal iff they are literally the same literal, we force inequality between
* all the respective SMT constants.
*
* Note that Strings are just a special kind of uninterpreted type.
*
* @author
* Jure Kukovec
*/
class UninterpretedLiteralCache extends Cache[PureArena, (TlaType1, String), ArenaCell] {

/**
* Given a pair `(utype,idx)`, where `utype` represents an uninterpreted type name (possibly "Str") and `idx` some
* unique index within that type, returns an extension of `arena`, containing a cell, which represents "idx_OF_utype"
* (or "idx", if utype = "Str"), and said cell.
*
* Note that two values are equal (and get cached to the same cell) iff they have the same type and the same index, so
* e.g. "1_OF_A" and "1_OF_B" (passed here as ("A", "1") and ("B", "1")) get cached to different, incomparable cells,
* despite having the same index "1".
*/
protected def create(
arena: PureArena,
typeAndIndex: (TlaType1, String)): (PureArena, ArenaCell) = {
val (utype, _) = typeAndIndex
require(utype == StrT1 || utype.isInstanceOf[ConstT1], "Type must be Str, or an uninterpreted type.")
// introduce a new cell
val newArena = arena.appendCell(CellT.fromType1(utype))
(newArena, newArena.topCell)
}

/**
* The UninterpretedLiteralCache maintains that a cell cache for a value `idx` of type `tp` is distinct from all other
* values of type `tp` (defined so far).
*
* Whenever possible, try to use [[addAllConstraints]] instead of this method, for performance reasons instead:
*
* If we consider a naive implementation of `distinct(a1,..., an)` as `a1 != a2 /\ a1 != a3 /\ ... /\ a{n-1} != an`, a
* `distinct` with `n` elements is equivalent to `dn = n(n-1)/2` disequalities. Suppose we end up with a collection of
* `N` cache values (of a given type). If we called `addConstaintsForElem` after each addition, we'd end up with `d1 +
* d2 + ... + dN` disequalities, i.e. {{{\sum_{n=1}^N n(n-1)/2 = N(N^2 -1)/6}}} In contrast, `addAllConstraints`
* produces `dN = N(N-1)/2` disequalities, which is `O(N^2)`, instead of `O(N^3)`.
*/
override def addConstraintsForElem(ctx: SolverContext): (((TlaType1, String), ArenaCell)) => Unit = {
case ((utype, _), v) =>
require(utype == StrT1 || utype.isInstanceOf[ConstT1], "Type must be Str, or an uninterpreted type.")
val others = values().withFilter { c => c.cellType == CellTFrom(utype) && c != v }.map(_.toBuilder).toSeq
// The cell should differ from the previously created cells.
// We use the SMT constraint (distinct ...).
ctx.assertGroundExpr(tla.distinct(v.toBuilder +: others: _*))
shonfeder marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* A more efficient implementation, compared to the default one, as it introduces exactly one SMT `distinct` for each
* uninterpreted type instead of one `distinct` per cell.
*/
override def addAllConstraints(ctx: SolverContext): Unit = {
val utypes = cache.keySet.map { _._1 }

val initMap = utypes.map { _ -> Set.empty[ArenaCell] }.toMap

val cellsByUtype = cache.foldLeft(initMap) { case (map, ((utype, _), (cell, _))) =>
map + (utype -> (map(utype) + cell))
}

// For each utype, all cells of that type are distinct
cellsByUtype.foreach { case (_, cells) =>
ctx.assertGroundExpr(tla.distinct(cells.toSeq.map { _.toBuilder }: _*))
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux

import at.forsyte.apalache.tla.bmcmt.PureArena
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches.UninterpretedLiteralCache
import at.forsyte.apalache.tla.lir.{StrT1, TlaType1}
import at.forsyte.apalache.tla.types.{tla, ModelValueHandler}
import org.junit.runner.RunWith
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.scalatestplus.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class UninterpretedLiteralCacheTest extends AnyFunSuite with BeforeAndAfterEach {

var cache: UninterpretedLiteralCache = new UninterpretedLiteralCache

def tpAndIdx(s: String): (TlaType1, String) = {
val (utype, idx) = ModelValueHandler.typeAndIndex(s).getOrElse((StrT1, s))
(utype, idx)
}

override def beforeEach(): Unit = {
cache = new UninterpretedLiteralCache
}

test("Cache returns stored values after the first call to getOrCreate") {
val str: String = "idx"

val utypeAndIdx = tpAndIdx(str)

val arena = PureArena.empty

// No cached value for the pair
assert(cache.get(utypeAndIdx).isEmpty)

val (newArena, iCell) = cache.getOrCreate(arena, utypeAndIdx)

// pair now cached, arena has changed
assert(cache.get(utypeAndIdx).nonEmpty && newArena != arena)

val (newArena2, iCell2) = cache.getOrCreate(newArena, utypeAndIdx)

// 2nd call returns the _same_ arena and the previously computed cell
assert(newArena == newArena2 && iCell == iCell2)
}

test("Same index of different types is cached separately") {
val str1: String = "idx"
val str2: String = "idx_OF_A"
val str3: String = "idx_OF_B"

val pa1 = tpAndIdx(str1)
val pa2 = tpAndIdx(str2)
val pa3 = tpAndIdx(str3)

val arena = PureArena.empty

val (newArena1, cell1) = cache.getOrCreate(arena, pa1)

assert(arena != newArena1)

val (newArena2, cell2) = cache.getOrCreate(newArena1, pa2)

assert(newArena2 != newArena1 && cell2 != cell1)

val (newArena3, cell3) = cache.getOrCreate(newArena2, pa3)

assert(newArena3 != newArena2 && cell3 != cell2)
}

test("Constraints are only added when addAllConstraints is explicitly called, and only once per value") {
val mockCtx: MockZ3SolverContext = new MockZ3SolverContext

val str1: String = "1_OF_A"
val str2: String = "2_OF_A"
val str3: String = "3_OF_A"

val pa1 = tpAndIdx(str1)
val pa2 = tpAndIdx(str2)
val pa3 = tpAndIdx(str3)

val a0 = PureArena.empty
val (a1, c1) = cache.getOrCreate(a0, pa1)
// Some extra calls, which shouldn't affect constraint generation
cache.getOrCreate(a0, pa1)
cache.getOrCreate(a0, pa1)
val (a2, c2) = cache.getOrCreate(a1, pa2)
// Some extra calls, which shouldn't affect constraint generation
cache.getOrCreate(a1, pa2)
cache.getOrCreate(a1, pa2)
val (_, c3) = cache.getOrCreate(a2, pa3)
// Some extra calls, which shouldn't affect constraint generation
cache.getOrCreate(a2, pa3)
cache.getOrCreate(a2, pa3)

assert(mockCtx.constraints.isEmpty)

cache.addAllConstraints(mockCtx)

// Due to the optimized `addAllConstraints` override, we only have 1 "distinct"
assert(mockCtx.constraints == Seq(
tla.distinct(c3.toBuilder, c2.toBuilder, c1.toBuilder).build
))
}

test("Constraints are only added when addConstraintsForElem is explicitly called, and only once per value") {
val mockCtx: MockZ3SolverContext = new MockZ3SolverContext

val str1: String = "1_OF_A"
val str2: String = "2_OF_A"
val str3: String = "3_OF_A"

val pa1 = tpAndIdx(str1)
val pa2 = tpAndIdx(str2)
val pa3 = tpAndIdx(str3)

val a0 = PureArena.empty
val (a1, c1) = cache.getOrCreate(a0, pa1)
// Some extra calls, which shouldn't affect constraint generation
cache.getOrCreate(a0, pa1)
cache.getOrCreate(a0, pa1)

cache.addConstraintsForElem(mockCtx)(pa1, c1)

val (a2, c2) = cache.getOrCreate(a1, pa2)
// Some extra calls, which shouldn't affect constraint generation
cache.getOrCreate(a1, pa2)
cache.getOrCreate(a1, pa2)

cache.addConstraintsForElem(mockCtx)(pa2, c2)

val (_, c3) = cache.getOrCreate(a2, pa3)
// Some extra calls, which shouldn't affect constraint generation
cache.getOrCreate(a2, pa3)
cache.getOrCreate(a2, pa3)

cache.addConstraintsForElem(mockCtx)(pa3, c3)

// -ForElem creates 3 "distinct" constraints
assert(mockCtx.constraints == Seq(
tla.distinct(c1.toBuilder).build,
tla.distinct(c2.toBuilder, c1.toBuilder).build,
tla.distinct(c3.toBuilder, c1.toBuilder, c2.toBuilder).build,
))
}

}