Skip to content

Commit

Permalink
Move id field out of ProductionReference into TypeInferencer (#…
Browse files Browse the repository at this point in the history
…3828)

The `id` field of `ProductionReference` is only used by
`TypeInferencer`, so this PR removes the field and instead stores it as
a `Map<ProductionReference, Integer>` within `TypeInferencer`.

This is useful for testing the new type inference algorithm - both
algorithms can be run back-to-back on the same term to compare results,
with them each tracking the `id` internally rather than mutating the
input and storing conflicting `id`s.

Additionally, I formatted `treeNodes.scala` and fixed all IntelliJ
warnings.

---------

Co-authored-by: rv-jenkins <[email protected]>
  • Loading branch information
Scott-Guest and rv-jenkins authored Nov 27, 2023
1 parent 61de88b commit 60e3dec
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public enum Status {
private final boolean debug;
private final Module mod;
private final java.util.Set<SortHead> sorts;
private final Map<ProductionReference, Integer> prIds = new IdentityHashMap<>();

// logic QF_DT is best if it exists as it will be faster than ALL. However, some z3 versions do
// not have this logic.
Expand Down Expand Up @@ -593,15 +594,15 @@ public String apply(Term t) {
boolean isTopSort =
expectedSort.equals(Sorts.RuleContent()) || expectedSort.name().equals("#RuleBody");
int id = nextId;
boolean shared = pr.id().isPresent() && variablesById.size() > pr.id().get();
boolean shared = prIds.containsKey(pr) && variablesById.size() > prIds.get(pr);
if (!shared) {
// if this is the first time reaching this term, initialize data structures with the
// variables associated with
// this term.
nextId++;
variablesById.add(new ArrayList<>());
cacheById.add(new HashSet<>());
pr.setId(Optional.of(id));
prIds.put(pr, id);
for (Sort param : iterable(pr.production().params())) {
String name = "FreshVar" + param.name() + locStr(pr);
if (!variables.contains(name)) {
Expand All @@ -612,7 +613,7 @@ public String apply(Term t) {
}
} else {
// get cached id
id = pr.id().get();
id = prIds.get(pr);
}
if (pr instanceof TermCons tc) {
boolean wasStrict = isStrictEquality;
Expand Down Expand Up @@ -682,7 +683,7 @@ public String apply(Term t) {
nextId++;
variablesById.add(new ArrayList<>());
cacheById.add(new HashSet<>());
pr.setId(Optional.of(id));
prIds.put(pr, id);
if (isAnonVar(c)) {
name = "Var" + c.value() + locStr(c);
isStrictEquality = true;
Expand Down Expand Up @@ -799,7 +800,7 @@ private String printSort(Sort s, Optional<ProductionReference> t, boolean isIncr
Map<Sort, String> params = new HashMap<>();
if (t.isPresent()) {
if (t.get().production().params().nonEmpty()) {
int id = t.get().id().get();
int id = prIds.get(t.get());
List<String> names = variablesById.get(id);
Seq<Sort> formalParams = t.get().production().params();
assert (names.size() == formalParams.size());
Expand Down Expand Up @@ -945,8 +946,8 @@ void pushNotModel() {
}

public Seq<Sort> getArgs(ProductionReference pr) {
if (pr.id().isPresent()) {
int id = pr.id().get();
if (prIds.containsKey(pr)) {
int id = prIds.get(pr);
List<String> names = variablesById.get(id);
return names.stream().map(this::getValue).collect(Collections.toList());
} else {
Expand Down Expand Up @@ -1023,7 +1024,7 @@ private void reset() {
}
}

private static String locStr(ProductionReference pr) {
private String locStr(ProductionReference pr) {
String suffix = "";
if (pr.production().klabel().isDefined()) {
suffix = "_" + pr.production().klabel().get().name().replace("|", "");
Expand All @@ -1040,7 +1041,7 @@ private static String locStr(ProductionReference pr) {
+ l.endColumn()
+ suffix;
}
return pr.id().get() + suffix;
return prIds.get(pr) + suffix;
}

private StringBuilder sb = new StringBuilder();
Expand Down
96 changes: 57 additions & 39 deletions kore/src/main/scala/org/kframework/parser/treeNodes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@

package org.kframework.parser

import org.kframework.attributes.{Source, Location, HasLocation}
import org.kframework.attributes.{HasLocation, Location, Source}
import org.kframework.definition.Production
import org.kframework.kore.KORE.Sort
import java.util._
import java.lang.Iterable
import org.pcollections.{ConsPStack, PStack}
import collection.JavaConverters._
import org.kframework.utils.StringUtil
import org.pcollections.{ConsPStack, PStack}

import scala.collection.mutable;
import java.util._
import scala.collection.JavaConverters._
import scala.collection.mutable

trait Term extends HasLocation {
var location: Optional[Location] = Optional.empty()
Expand All @@ -20,50 +19,54 @@ trait Term extends HasLocation {

trait ProductionReference extends Term {
val production: Production
var id: Optional[Integer] = Optional.empty()

def setId(id: Optional[Integer]) {
this.id = id
}
}

trait HasChildren {
def items: Iterable[Term]
def items: java.lang.Iterable[Term]

def replaceChildren(newChildren: Collection[Term]): Term
def replaceChildren(newChildren: java.util.Collection[Term]): Term
}

case class Constant private(value: String, production: Production) extends ProductionReference {
override def toString = "#token(" + production.sort + "," + StringUtil.enquoteKString(value) + ")"
case class Constant private (value: String, production: Production) extends ProductionReference {
override def toString: String =
"#token(" + production.sort + "," + StringUtil.enquoteKString(value) + ")"

override lazy val hashCode: Int = scala.runtime.ScalaRunTime._hashCode(Constant.this);
override lazy val hashCode: Int = scala.runtime.ScalaRunTime._hashCode(Constant.this)
}

// note that items is reversed because it is more efficient to generate it this way during parsing
case class TermCons private(items: PStack[Term], production: Production)
extends ProductionReference with HasChildren {
def get(i: Int) = items.get(items.size() - 1 - i)
case class TermCons private (items: PStack[Term], production: Production)
extends ProductionReference
with HasChildren {
def get(i: Int): Term = items.get(items.size() - 1 - i)

def `with`(i: Int, e: Term) = TermCons(items.`with`(items.size() - 1 - i, e), production, location, source, id)
def `with`(i: Int, e: Term): TermCons =
TermCons(items.`with`(items.size() - 1 - i, e), production, location, source)

def replaceChildren(newChildren: Collection[Term]) = TermCons(ConsPStack.from(newChildren), production, location, source, id)
def replaceChildren(newChildren: java.util.Collection[Term]): TermCons =
TermCons(ConsPStack.from(newChildren), production, location, source)

override def toString() = new TreeNodesToKORE(s => Sort(s)).apply(this).toString()
override def toString: String = new TreeNodesToKORE(s => Sort(s)).apply(this).toString()

override lazy val hashCode: Int = scala.runtime.ScalaRunTime._hashCode(TermCons.this);
override lazy val hashCode: Int = scala.runtime.ScalaRunTime._hashCode(TermCons.this)
}

case class Ambiguity(items: Set[Term])
extends Term with HasChildren {
def replaceChildren(newChildren: Collection[Term]) = Ambiguity(new HashSet[Term](newChildren), location, source)
case class Ambiguity(items: java.util.Set[Term]) extends Term with HasChildren {
def replaceChildren(newChildren: java.util.Collection[Term]): Ambiguity =
Ambiguity(new java.util.HashSet[Term](newChildren), location, source)

override def toString() = "amb(" + (items.asScala mkString ",") + ")"
override def toString: String = "amb(" + (items.asScala mkString ",") + ")"

override lazy val hashCode: Int = scala.runtime.ScalaRunTime._hashCode(Ambiguity.this);
override lazy val hashCode: Int = scala.runtime.ScalaRunTime._hashCode(Ambiguity.this)
}

object Constant {
def apply(value: String, production: Production, location: Optional[Location], source: Optional[Source]): Constant = {
def apply(
value: String,
production: Production,
location: Optional[Location],
source: Optional[Source]
): Constant = {
val res = Constant(value, production)
res.location = location
res.source = source
Expand All @@ -72,25 +75,40 @@ object Constant {
}

object TermCons {
def apply(items: PStack[Term], production: Production, location: Optional[Location], source: Optional[Source], id: Optional[Integer]): TermCons = {
def apply(
items: PStack[Term],
production: Production,
location: Optional[Location],
source: Optional[Source]
): TermCons = {
val res = TermCons(items, production)
res.location = location
res.source = source
res.id = id
res
}

def apply(items: PStack[Term], production: Production, location: Optional[Location], source: Optional[Source]): TermCons = {
TermCons(items, production, location, source, Optional.empty())
}

def apply(items: PStack[Term], production: Production, location: Location, source: Source): TermCons = TermCons(items, production, Optional.of(location), Optional.of(source), Optional.empty())
def apply(
items: PStack[Term],
production: Production,
location: Location,
source: Source
): TermCons = TermCons(
items,
production,
Optional.of(location),
Optional.of(source)
)
}

object Ambiguity {
@annotation.varargs def apply(items: Term*): Ambiguity = Ambiguity(items.to[mutable.Set].asJava)

def apply(items: Set[Term], location: Optional[Location], source: Optional[Source]): Ambiguity = {
@annotation.varargs
def apply(items: Term*): Ambiguity = Ambiguity(items.to[mutable.Set].asJava)

def apply(
items: java.util.Set[Term],
location: Optional[Location],
source: Optional[Source]
): Ambiguity = {
val res = Ambiguity(items)
res.location = location
res.source = source
Expand Down

0 comments on commit 60e3dec

Please sign in to comment.