diff --git a/kernel/src/main/java/org/kframework/compile/AddSortInjections.java b/kernel/src/main/java/org/kframework/compile/AddSortInjections.java index 6acc7ef344d..a4bfddabbe7 100644 --- a/kernel/src/main/java/org/kframework/compile/AddSortInjections.java +++ b/kernel/src/main/java/org/kframework/compile/AddSortInjections.java @@ -453,10 +453,24 @@ private static Sort lub(Collection entries, Sort expectedSort, HasLocation if (filteredEntries.isEmpty()) { // if all sorts are parameters, take the first return entries.iterator().next(); } - Set bounds = upperBounds(filteredEntries, mod); + + Set nonParametric = + filteredEntries.stream().filter(s -> s.params().isEmpty()).collect(Collectors.toSet()); + Set bounds = mutable(mod.subsorts().upperBounds(immutable(nonParametric))); + // Anything less than KBott or greater than K is a syntactic sort from kast.md which should not be considered + bounds.removeIf(s -> mod.subsorts().lessThanEq(s, Sorts.KBott()) || mod.subsorts().greaterThan(s, Sorts.K())); if (expectedSort != null && !expectedSort.name().equals(SORTPARAM_NAME)) { bounds.removeIf(s -> !mod.subsorts().lessThanEq(s, expectedSort)); } + + // For parametric sorts, each bound must bound at least one instantiation + Set parametric = + filteredEntries.stream().filter(s -> ! s.params().isEmpty()).collect(Collectors.toSet()); + bounds.removeIf(bound -> + parametric.stream().anyMatch(param -> + stream(mod.definedInstantiations().apply(param.head())) + .noneMatch(inst -> mod.subsorts().lessThanEq(inst, bound)))); + Set lub = mod.subsorts().minimal(bounds); if (lub.size() != 1) { throw KEMException.internalError("Could not compute least upper bound for rewrite sort. Possible candidates: " + lub, loc); @@ -464,37 +478,6 @@ private static Sort lub(Collection entries, Sort expectedSort, HasLocation return lub.iterator().next(); } - private static Set upperBounds(Collection bounds, Module mod) { - Set maxs = new HashSet<>(); - nextsort: - for (Sort sort : iterable(mod.allSorts())) { // for every declared sort - // Sorts at or below KBott, or above K, are assumed to be - // sorts from kast.k representing meta-syntax that is not a real sort. - // This is done to prevent variables from being inferred as KBott or - // as KList. - if (mod.subsorts().lessThanEq(sort, Sorts.KBott())) - continue; - if (mod.subsorts().greaterThan(sort, Sorts.K())) - continue; - for (Sort bound : bounds) - if (bound.params().isEmpty()) { - if (!mod.subsorts().lessThanEq(bound, sort)) - continue nextsort; - } else { - boolean any = false; - for (Sort instantiation : iterable(mod.definedInstantiations().apply(bound.head()))) { - if (mod.subsorts().lessThanEq(instantiation, sort)) { - any = true; - } - } - if (!any) - continue nextsort; - } - maxs.add(sort); - } - return maxs; - } - private Sort freshSortParam() { return Sort(SORTPARAM_NAME, Sort("Q" + freshSortParamCounter++)); } diff --git a/kore/src/main/scala/org/kframework/POSet.scala b/kore/src/main/scala/org/kframework/POSet.scala index b8f0173574f..f39fa63421f 100644 --- a/kore/src/main/scala/org/kframework/POSet.scala +++ b/kore/src/main/scala/org/kframework/POSet.scala @@ -4,14 +4,13 @@ package org.kframework import org.kframework.utils.errorsystem.KEMException import java.util -import java.util.Optional import collection._ +import scala.annotation.tailrec /** * A partially ordered set based on an initial set of direct relations. */ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { - // convert the input set of relations to Map form for performance private val directRelationsMap: Map[T, Set[T]] = directRelations groupBy { _._1 } mapValues { _ map { _._2 } toSet } map identity @@ -26,6 +25,7 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { * The implementation is simple. It links each element to the successors of its successors. * TODO: there may be a more efficient algorithm (low priority) */ + @tailrec private def transitiveClosure(relations: Map[T, Set[T]]): Map[T, Set[T]] = { val newRelations = relations map { case (start, succ) => @@ -44,23 +44,31 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { * @param current element * @param path so far */ - private def constructAndThrowCycleException(start: T, current: T, path: Seq[T]) { + private def constructAndThrowCycleException(start: T, current: T, path: Seq[T]): Unit = { val currentPath = path :+ current val succs = directRelationsMap.getOrElse(current, Set()) if (succs.contains(start)) { - throw KEMException.compilerError("Illegal circular relation: " + (currentPath :+ start).mkString(" < ")); + throw KEMException.compilerError("Illegal circular relation: " + (currentPath :+ start).mkString(" < ")) } succs foreach { constructAndThrowCycleException(start, _, currentPath) } } /** * All the relations of the POSet, including the transitive ones. + * + * Concretely, a map from each element of the poset to the set of elements greater than it. + */ + val relations: Map[T, Set[T]] = transitiveClosure(directRelationsMap) + + /** + * A map from each element of the poset to the set of elements less than it. */ - val relations = transitiveClosure(directRelationsMap) + lazy val relationsOp: Map[T, Set[T]] = + relations.toSet[(T, Set[T])].flatMap { case (x, ys) => ys.map(_ -> x) }.groupBy(_._1).mapValues(_.map(_._2)) def <(x: T, y: T): Boolean = relations.get(x).exists(_.contains(y)) def >(x: T, y: T): Boolean = relations.get(y).exists(_.contains(x)) - def ~(x: T, y: T) = <(x, y) || <(y, x) + def ~(x: T, y: T): Boolean = <(x, y) || <(y, x) /** * Returns true if x < y @@ -77,34 +85,29 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { /** * Returns true if y < x or y < x */ - def inSomeRelation(x: T, y: T) = this.~(x, y) - def inSomeRelationEq(x: T, y: T) = x == y || this.~(x, y) + def inSomeRelation(x: T, y: T): Boolean = this.~(x, y) + def inSomeRelationEq(x: T, y: T): Boolean = x == y || this.~(x, y) /** - * Returns an Optional of the least upper bound if it exists, or an empty Optional otherwise. + * Return the set of all upper bounds of the input. */ - lazy val leastUpperBound: Optional[T] = lub match { - case Some(x) => Optional.of(x) - case None => Optional.empty() - } + def upperBounds(sorts: Iterable[T]): Set[T] = + if (sorts.isEmpty) elements else POSet.upperBounds(sorts, relations) + + /** + * Return the set of all lower bounds of the input. + */ + def lowerBounds(sorts: Iterable[T]): Set[T] = + if (sorts.isEmpty) elements else POSet.upperBounds(sorts, relationsOp) lazy val lub: Option[T] = { - val candidates = relations.values reduce { (a, b) => a & b } - - if (candidates.size == 0) - None - else if (candidates.size == 1) - Some(candidates.head) - else { - val allPairs = for (a <- candidates; b <- candidates) yield { (a, b) } - if (allPairs exists { case (a, b) => ! ~(a, b) }) - None - else - Some( - candidates.min(new Ordering[T]() { - def compare(x: T, y: T) = if (x < y) -1 else if (x > y) 1 else 0 - })) - } + val mins = minimal(upperBounds(elements)) + if (mins.size == 1) Some(mins.head) else None + } + + lazy val glb: Option[T] = { + val maxs = maximal(lowerBounds(elements)) + if (maxs.size == 1) Some(maxs.head) else None } lazy val asOrdering: Ordering[T] = (x: T, y: T) => if (lessThanEq(x, y)) -1 else if (lessThanEq(y, x)) 1 else 0 @@ -113,33 +116,33 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { * Return the subset of items from the argument which are not * less than any other item. */ - def maximal(sorts : Iterable[T]) : Set[T] = + def maximal(sorts: Iterable[T]): Set[T] = sorts.filter(s1 => !sorts.exists(s2 => lessThan(s1,s2))).toSet - def maximal(sorts : util.Collection[T]) : util.Set[T] = { - import scala.collection.JavaConversions._ - maximal(sorts : Iterable[T]) + def maximal(sorts: util.Collection[T]): util.Set[T] = { + import scala.collection.JavaConverters._ + maximal(sorts.asScala).asJava } /** * Return the subset of items from the argument which are not * greater than any other item. */ - def minimal(sorts : Iterable[T]) : Set[T] = + def minimal(sorts: Iterable[T]): Set[T] = sorts.filter(s1 => !sorts.exists(s2 => >(s1,s2))).toSet - def minimal(sorts : util.Collection[T]) : util.Set[T] = { - import scala.collection.JavaConversions._ - minimal(sorts : Iterable[T]) + def minimal(sorts: util.Collection[T]): util.Set[T] = { + import scala.collection.JavaConverters._ + minimal(sorts.asScala).asJava } - override def toString() = { - "POSet(" + (relations flatMap { case (from, tos) => tos map { case to => from + "<" + to } }).mkString(",") + ")" + override def toString: String = { + "POSet(" + (relations flatMap { case (from, tos) => tos map { to => from + "<" + to } }).mkString(",") + ")" } - override def hashCode = relations.hashCode() + override def hashCode: Int = relations.hashCode() - override def equals(that: Any) = that match { + override def equals(that: Any): Boolean = that match { case that: POSet[_] => relations == that.relations case _ => false } @@ -153,7 +156,15 @@ object POSet { * Import this for Scala syntactic sugar. */ implicit class PO[T](x: T)(implicit val po: POSet[T]) { - def <(y: T) = po.<(x, y) - def >(y: T) = po.>(x, y) + def <(y: T): Boolean = po.<(x, y) + def >(y: T): Boolean = po.>(x, y) } + + /** + * Return the set of all elements which are greater than or equal to each element of the input, + * using the provided relations map. Input must be non-empty. + */ + private def upperBounds[T](sorts: Iterable[T], relations: Map[T, Set[T]]): Set[T] = + (((sorts filterNot relations.keys.toSet[T]) map {Set.empty + _}) ++ + ((relations filterKeys sorts.toSet) map { case (k, v) => v + k })) reduce { (a, b) => a & b } } diff --git a/kore/src/test/scala/org/kframework/POSetTest.scala b/kore/src/test/scala/org/kframework/POSetTest.scala index 07c56e58d23..c2a5dfaf6ff 100644 --- a/kore/src/test/scala/org/kframework/POSetTest.scala +++ b/kore/src/test/scala/org/kframework/POSetTest.scala @@ -45,4 +45,12 @@ class POSetTest { assertEquals(None, POSet(b1 -> b2, b2 -> b3, b4 -> b5).lub) assertEquals(None, POSet(b1 -> b2, b2 -> b3, b2 -> b4).lub) } + + @Test def glb() { + assertEquals(Some(b2), POSet(b2 -> b1).glb) + assertEquals(Some(b3), POSet(b3 -> b1, b3 -> b2).glb) + assertEquals(Some(b4), POSet(b3 -> b1, b3 -> b2, b4 -> b3).glb) + assertEquals(None, POSet(b2 -> b1, b3 -> b2, b5 -> b4).glb) + assertEquals(None, POSet(b2 -> b1, b3 -> b2, b4 -> b2).glb) + } }