Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lowerBounds and upperBounds methods to POSet #3733

Merged
merged 6 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 15 additions & 32 deletions kernel/src/main/java/org/kframework/compile/AddSortInjections.java
Original file line number Diff line number Diff line change
Expand Up @@ -453,48 +453,31 @@ private static Sort lub(Collection<Sort> entries, Sort expectedSort, HasLocation
if (filteredEntries.isEmpty()) { // if all sorts are parameters, take the first
return entries.iterator().next();
}
Set<Sort> bounds = upperBounds(filteredEntries, mod);

Set<Sort> nonParametric =
filteredEntries.stream().filter(s -> s.params().isEmpty()).collect(Collectors.toSet());
Set<Sort> bounds = mutable(mod.subsorts().upperBounds(immutable(nonParametric)));
// Anything less than KBott or greater than K is a syntactic sort from kast.md which should not be considered
bounds.removeIf(s -> mod.subsorts().lessThanEq(s, Sorts.KBott()) || mod.subsorts().greaterThan(s, Sorts.K()));
if (expectedSort != null && !expectedSort.name().equals(SORTPARAM_NAME)) {
bounds.removeIf(s -> !mod.subsorts().lessThanEq(s, expectedSort));
}

// For parametric sorts, each bound must bound at least one instantiation
Set<Sort> parametric =
filteredEntries.stream().filter(s -> ! s.params().isEmpty()).collect(Collectors.toSet());
bounds.removeIf(bound ->
parametric.stream().anyMatch(param ->
stream(mod.definedInstantiations().apply(param.head()))
.noneMatch(inst -> mod.subsorts().lessThanEq(inst, bound))));

Set<Sort> lub = mod.subsorts().minimal(bounds);
if (lub.size() != 1) {
throw KEMException.internalError("Could not compute least upper bound for rewrite sort. Possible candidates: " + lub, loc);
}
return lub.iterator().next();
}

private static Set<Sort> upperBounds(Collection<Sort> bounds, Module mod) {
Set<Sort> maxs = new HashSet<>();
nextsort:
for (Sort sort : iterable(mod.allSorts())) { // for every declared sort
// Sorts at or below KBott, or above K, are assumed to be
// sorts from kast.k representing meta-syntax that is not a real sort.
// This is done to prevent variables from being inferred as KBott or
// as KList.
if (mod.subsorts().lessThanEq(sort, Sorts.KBott()))
continue;
if (mod.subsorts().greaterThan(sort, Sorts.K()))
continue;
for (Sort bound : bounds)
if (bound.params().isEmpty()) {
if (!mod.subsorts().lessThanEq(bound, sort))
continue nextsort;
} else {
boolean any = false;
for (Sort instantiation : iterable(mod.definedInstantiations().apply(bound.head()))) {
if (mod.subsorts().lessThanEq(instantiation, sort)) {
any = true;
}
}
if (!any)
continue nextsort;
}
maxs.add(sort);
}
return maxs;
}

private Sort freshSortParam() {
return Sort(SORTPARAM_NAME, Sort("Q" + freshSortParamCounter++));
}
Expand Down
97 changes: 54 additions & 43 deletions kore/src/main/scala/org/kframework/POSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@ package org.kframework
import org.kframework.utils.errorsystem.KEMException

import java.util
import java.util.Optional
import collection._
import scala.annotation.tailrec

/**
* A partially ordered set based on an initial set of direct relations.
*/
class POSet[T](val directRelations: Set[(T, T)]) extends Serializable {

// convert the input set of relations to Map form for performance
private val directRelationsMap: Map[T, Set[T]] = directRelations groupBy { _._1 } mapValues { _ map { _._2 } toSet } map identity

Expand All @@ -26,6 +25,7 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable {
* The implementation is simple. It links each element to the successors of its successors.
* TODO: there may be a more efficient algorithm (low priority)
*/
@tailrec
private def transitiveClosure(relations: Map[T, Set[T]]): Map[T, Set[T]] = {
val newRelations = relations map {
case (start, succ) =>
Expand All @@ -44,23 +44,31 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable {
* @param current element
* @param path so far
*/
private def constructAndThrowCycleException(start: T, current: T, path: Seq[T]) {
private def constructAndThrowCycleException(start: T, current: T, path: Seq[T]): Unit = {
val currentPath = path :+ current
val succs = directRelationsMap.getOrElse(current, Set())
if (succs.contains(start)) {
throw KEMException.compilerError("Illegal circular relation: " + (currentPath :+ start).mkString(" < "));
throw KEMException.compilerError("Illegal circular relation: " + (currentPath :+ start).mkString(" < "))
}
succs foreach { constructAndThrowCycleException(start, _, currentPath) }
}

/**
* All the relations of the POSet, including the transitive ones.
*
* Concretely, a map from each element of the poset to the set of elements greater than it.
*/
val relations: Map[T, Set[T]] = transitiveClosure(directRelationsMap)

/**
* A map from each element of the poset to the set of elements less than it.
*/
val relations = transitiveClosure(directRelationsMap)
lazy val relationsOp: Map[T, Set[T]] =
relations.toSet[(T, Set[T])].flatMap { case (x, ys) => ys.map(_ -> x) }.groupBy(_._1).mapValues(_.map(_._2))

def <(x: T, y: T): Boolean = relations.get(x).exists(_.contains(y))
def >(x: T, y: T): Boolean = relations.get(y).exists(_.contains(x))
def ~(x: T, y: T) = <(x, y) || <(y, x)
def ~(x: T, y: T): Boolean = <(x, y) || <(y, x)

/**
* Returns true if x < y
Expand All @@ -77,34 +85,29 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable {
/**
* Returns true if y < x or y < x
*/
def inSomeRelation(x: T, y: T) = this.~(x, y)
def inSomeRelationEq(x: T, y: T) = x == y || this.~(x, y)
def inSomeRelation(x: T, y: T): Boolean = this.~(x, y)
def inSomeRelationEq(x: T, y: T): Boolean = x == y || this.~(x, y)

/**
* Returns an Optional of the least upper bound if it exists, or an empty Optional otherwise.
* Return the set of all upper bounds of the input.
*/
lazy val leastUpperBound: Optional[T] = lub match {
case Some(x) => Optional.of(x)
case None => Optional.empty()
}
def upperBounds(sorts: Iterable[T]): Set[T] =
if (sorts.isEmpty) elements else POSet.upperBounds(sorts, relations)

/**
* Return the set of all lower bounds of the input.
*/
def lowerBounds(sorts: Iterable[T]): Set[T] =
if (sorts.isEmpty) elements else POSet.upperBounds(sorts, relationsOp)

lazy val lub: Option[T] = {
val candidates = relations.values reduce { (a, b) => a & b }

if (candidates.size == 0)
None
else if (candidates.size == 1)
Some(candidates.head)
else {
val allPairs = for (a <- candidates; b <- candidates) yield { (a, b) }
if (allPairs exists { case (a, b) => ! ~(a, b) })
None
else
Some(
candidates.min(new Ordering[T]() {
def compare(x: T, y: T) = if (x < y) -1 else if (x > y) 1 else 0
}))
}
val mins = minimal(upperBounds(elements))
if (mins.size == 1) Some(mins.head) else None
}

lazy val glb: Option[T] = {
val maxs = maximal(lowerBounds(elements))
if (maxs.size == 1) Some(maxs.head) else None
}

lazy val asOrdering: Ordering[T] = (x: T, y: T) => if (lessThanEq(x, y)) -1 else if (lessThanEq(y, x)) 1 else 0
Expand All @@ -113,33 +116,33 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable {
* Return the subset of items from the argument which are not
* less than any other item.
*/
def maximal(sorts : Iterable[T]) : Set[T] =
def maximal(sorts: Iterable[T]): Set[T] =
sorts.filter(s1 => !sorts.exists(s2 => lessThan(s1,s2))).toSet

def maximal(sorts : util.Collection[T]) : util.Set[T] = {
import scala.collection.JavaConversions._
maximal(sorts : Iterable[T])
def maximal(sorts: util.Collection[T]): util.Set[T] = {
import scala.collection.JavaConverters._
maximal(sorts.asScala).asJava
}

/**
* Return the subset of items from the argument which are not
* greater than any other item.
*/
def minimal(sorts : Iterable[T]) : Set[T] =
def minimal(sorts: Iterable[T]): Set[T] =
sorts.filter(s1 => !sorts.exists(s2 => >(s1,s2))).toSet

def minimal(sorts : util.Collection[T]) : util.Set[T] = {
import scala.collection.JavaConversions._
minimal(sorts : Iterable[T])
def minimal(sorts: util.Collection[T]): util.Set[T] = {
import scala.collection.JavaConverters._
minimal(sorts.asScala).asJava
}

override def toString() = {
"POSet(" + (relations flatMap { case (from, tos) => tos map { case to => from + "<" + to } }).mkString(",") + ")"
override def toString: String = {
"POSet(" + (relations flatMap { case (from, tos) => tos map { to => from + "<" + to } }).mkString(",") + ")"
}

override def hashCode = relations.hashCode()
override def hashCode: Int = relations.hashCode()

override def equals(that: Any) = that match {
override def equals(that: Any): Boolean = that match {
case that: POSet[_] => relations == that.relations
case _ => false
}
Expand All @@ -153,7 +156,15 @@ object POSet {
* Import this for Scala syntactic sugar.
*/
implicit class PO[T](x: T)(implicit val po: POSet[T]) {
def <(y: T) = po.<(x, y)
def >(y: T) = po.>(x, y)
def <(y: T): Boolean = po.<(x, y)
def >(y: T): Boolean = po.>(x, y)
}

/**
* Return the set of all elements which are greater than or equal to each element of the input,
* using the provided relations map. Input must be non-empty.
*/
private def upperBounds[T](sorts: Iterable[T], relations: Map[T, Set[T]]): Set[T] =
(((sorts filterNot relations.keys.toSet[T]) map {Set.empty + _}) ++
((relations filterKeys sorts.toSet) map { case (k, v) => v + k })) reduce { (a, b) => a & b }
}
8 changes: 8 additions & 0 deletions kore/src/test/scala/org/kframework/POSetTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,12 @@ class POSetTest {
assertEquals(None, POSet(b1 -> b2, b2 -> b3, b4 -> b5).lub)
assertEquals(None, POSet(b1 -> b2, b2 -> b3, b2 -> b4).lub)
}

@Test def glb() {
assertEquals(Some(b2), POSet(b2 -> b1).glb)
assertEquals(Some(b3), POSet(b3 -> b1, b3 -> b2).glb)
assertEquals(Some(b4), POSet(b3 -> b1, b3 -> b2, b4 -> b3).glb)
assertEquals(None, POSet(b2 -> b1, b3 -> b2, b5 -> b4).glb)
assertEquals(None, POSet(b2 -> b1, b3 -> b2, b4 -> b2).glb)
}
}