From 7c817a4e77afac8407cc527bab49bb7d943f19e7 Mon Sep 17 00:00:00 2001 From: Scott Guest Date: Thu, 19 Oct 2023 15:26:00 -0400 Subject: [PATCH 1/3] Address IntelliJ warnings in POSet.scala --- .../src/main/scala/org/kframework/POSet.scala | 55 ++++++++----------- 1 file changed, 24 insertions(+), 31 deletions(-) diff --git a/kore/src/main/scala/org/kframework/POSet.scala b/kore/src/main/scala/org/kframework/POSet.scala index b8f0173574f..232670bc2cb 100644 --- a/kore/src/main/scala/org/kframework/POSet.scala +++ b/kore/src/main/scala/org/kframework/POSet.scala @@ -4,8 +4,8 @@ package org.kframework import org.kframework.utils.errorsystem.KEMException import java.util -import java.util.Optional import collection._ +import scala.annotation.tailrec /** * A partially ordered set based on an initial set of direct relations. @@ -26,6 +26,7 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { * The implementation is simple. It links each element to the successors of its successors. * TODO: there may be a more efficient algorithm (low priority) */ + @tailrec private def transitiveClosure(relations: Map[T, Set[T]]): Map[T, Set[T]] = { val newRelations = relations map { case (start, succ) => @@ -44,11 +45,11 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { * @param current element * @param path so far */ - private def constructAndThrowCycleException(start: T, current: T, path: Seq[T]) { + private def constructAndThrowCycleException(start: T, current: T, path: Seq[T]): Unit = { val currentPath = path :+ current val succs = directRelationsMap.getOrElse(current, Set()) if (succs.contains(start)) { - throw KEMException.compilerError("Illegal circular relation: " + (currentPath :+ start).mkString(" < ")); + throw KEMException.compilerError("Illegal circular relation: " + (currentPath :+ start).mkString(" < ")) } succs foreach { constructAndThrowCycleException(start, _, currentPath) } } @@ -56,11 +57,11 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { /** * All the relations of the POSet, including the transitive ones. */ - val relations = transitiveClosure(directRelationsMap) + val relations: Map[T, Set[T]] = transitiveClosure(directRelationsMap) def <(x: T, y: T): Boolean = relations.get(x).exists(_.contains(y)) def >(x: T, y: T): Boolean = relations.get(y).exists(_.contains(x)) - def ~(x: T, y: T) = <(x, y) || <(y, x) + def ~(x: T, y: T): Boolean = <(x, y) || <(y, x) /** * Returns true if x < y @@ -77,21 +78,13 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { /** * Returns true if y < x or y < x */ - def inSomeRelation(x: T, y: T) = this.~(x, y) - def inSomeRelationEq(x: T, y: T) = x == y || this.~(x, y) - - /** - * Returns an Optional of the least upper bound if it exists, or an empty Optional otherwise. - */ - lazy val leastUpperBound: Optional[T] = lub match { - case Some(x) => Optional.of(x) - case None => Optional.empty() - } + def inSomeRelation(x: T, y: T): Boolean = this.~(x, y) + def inSomeRelationEq(x: T, y: T): Boolean = x == y || this.~(x, y) lazy val lub: Option[T] = { val candidates = relations.values reduce { (a, b) => a & b } - if (candidates.size == 0) + if (candidates.isEmpty) None else if (candidates.size == 1) Some(candidates.head) @@ -102,7 +95,7 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { else Some( candidates.min(new Ordering[T]() { - def compare(x: T, y: T) = if (x < y) -1 else if (x > y) 1 else 0 + def compare(x: T, y: T): Int = if (x < y) -1 else if (x > y) 1 else 0 })) } } @@ -113,33 +106,33 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { * Return the subset of items from the argument which are not * less than any other item. */ - def maximal(sorts : Iterable[T]) : Set[T] = + def maximal(sorts: Iterable[T]): Set[T] = sorts.filter(s1 => !sorts.exists(s2 => lessThan(s1,s2))).toSet - def maximal(sorts : util.Collection[T]) : util.Set[T] = { - import scala.collection.JavaConversions._ - maximal(sorts : Iterable[T]) + def maximal(sorts: util.Collection[T]): util.Set[T] = { + import scala.collection.JavaConverters._ + maximal(sorts.asScala).asJava } /** * Return the subset of items from the argument which are not * greater than any other item. */ - def minimal(sorts : Iterable[T]) : Set[T] = + def minimal(sorts: Iterable[T]): Set[T] = sorts.filter(s1 => !sorts.exists(s2 => >(s1,s2))).toSet - def minimal(sorts : util.Collection[T]) : util.Set[T] = { - import scala.collection.JavaConversions._ - minimal(sorts : Iterable[T]) + def minimal(sorts: util.Collection[T]): util.Set[T] = { + import scala.collection.JavaConverters._ + minimal(sorts.asScala).asJava } - override def toString() = { - "POSet(" + (relations flatMap { case (from, tos) => tos map { case to => from + "<" + to } }).mkString(",") + ")" + override def toString: String = { + "POSet(" + (relations flatMap { case (from, tos) => tos map { to => from + "<" + to } }).mkString(",") + ")" } - override def hashCode = relations.hashCode() + override def hashCode: Int = relations.hashCode() - override def equals(that: Any) = that match { + override def equals(that: Any): Boolean = that match { case that: POSet[_] => relations == that.relations case _ => false } @@ -153,7 +146,7 @@ object POSet { * Import this for Scala syntactic sugar. */ implicit class PO[T](x: T)(implicit val po: POSet[T]) { - def <(y: T) = po.<(x, y) - def >(y: T) = po.>(x, y) + def <(y: T): Boolean = po.<(x, y) + def >(y: T): Boolean = po.>(x, y) } } From 2a396c8d5aff06555ea39b490909d8bbed9a87df Mon Sep 17 00:00:00 2001 From: Scott Guest Date: Thu, 19 Oct 2023 15:28:39 -0400 Subject: [PATCH 2/3] POSet.scala: Implement lowerBounds and upperBounds methods --- .../src/main/scala/org/kframework/POSet.scala | 52 +++++++++++++------ .../test/scala/org/kframework/POSetTest.scala | 8 +++ 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/kore/src/main/scala/org/kframework/POSet.scala b/kore/src/main/scala/org/kframework/POSet.scala index 232670bc2cb..f39fa63421f 100644 --- a/kore/src/main/scala/org/kframework/POSet.scala +++ b/kore/src/main/scala/org/kframework/POSet.scala @@ -11,7 +11,6 @@ import scala.annotation.tailrec * A partially ordered set based on an initial set of direct relations. */ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { - // convert the input set of relations to Map form for performance private val directRelationsMap: Map[T, Set[T]] = directRelations groupBy { _._1 } mapValues { _ map { _._2 } toSet } map identity @@ -56,9 +55,17 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { /** * All the relations of the POSet, including the transitive ones. + * + * Concretely, a map from each element of the poset to the set of elements greater than it. */ val relations: Map[T, Set[T]] = transitiveClosure(directRelationsMap) + /** + * A map from each element of the poset to the set of elements less than it. + */ + lazy val relationsOp: Map[T, Set[T]] = + relations.toSet[(T, Set[T])].flatMap { case (x, ys) => ys.map(_ -> x) }.groupBy(_._1).mapValues(_.map(_._2)) + def <(x: T, y: T): Boolean = relations.get(x).exists(_.contains(y)) def >(x: T, y: T): Boolean = relations.get(y).exists(_.contains(x)) def ~(x: T, y: T): Boolean = <(x, y) || <(y, x) @@ -81,23 +88,26 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { def inSomeRelation(x: T, y: T): Boolean = this.~(x, y) def inSomeRelationEq(x: T, y: T): Boolean = x == y || this.~(x, y) + /** + * Return the set of all upper bounds of the input. + */ + def upperBounds(sorts: Iterable[T]): Set[T] = + if (sorts.isEmpty) elements else POSet.upperBounds(sorts, relations) + + /** + * Return the set of all lower bounds of the input. + */ + def lowerBounds(sorts: Iterable[T]): Set[T] = + if (sorts.isEmpty) elements else POSet.upperBounds(sorts, relationsOp) + lazy val lub: Option[T] = { - val candidates = relations.values reduce { (a, b) => a & b } - - if (candidates.isEmpty) - None - else if (candidates.size == 1) - Some(candidates.head) - else { - val allPairs = for (a <- candidates; b <- candidates) yield { (a, b) } - if (allPairs exists { case (a, b) => ! ~(a, b) }) - None - else - Some( - candidates.min(new Ordering[T]() { - def compare(x: T, y: T): Int = if (x < y) -1 else if (x > y) 1 else 0 - })) - } + val mins = minimal(upperBounds(elements)) + if (mins.size == 1) Some(mins.head) else None + } + + lazy val glb: Option[T] = { + val maxs = maximal(lowerBounds(elements)) + if (maxs.size == 1) Some(maxs.head) else None } lazy val asOrdering: Ordering[T] = (x: T, y: T) => if (lessThanEq(x, y)) -1 else if (lessThanEq(y, x)) 1 else 0 @@ -149,4 +159,12 @@ object POSet { def <(y: T): Boolean = po.<(x, y) def >(y: T): Boolean = po.>(x, y) } + + /** + * Return the set of all elements which are greater than or equal to each element of the input, + * using the provided relations map. Input must be non-empty. + */ + private def upperBounds[T](sorts: Iterable[T], relations: Map[T, Set[T]]): Set[T] = + (((sorts filterNot relations.keys.toSet[T]) map {Set.empty + _}) ++ + ((relations filterKeys sorts.toSet) map { case (k, v) => v + k })) reduce { (a, b) => a & b } } diff --git a/kore/src/test/scala/org/kframework/POSetTest.scala b/kore/src/test/scala/org/kframework/POSetTest.scala index 07c56e58d23..c2a5dfaf6ff 100644 --- a/kore/src/test/scala/org/kframework/POSetTest.scala +++ b/kore/src/test/scala/org/kframework/POSetTest.scala @@ -45,4 +45,12 @@ class POSetTest { assertEquals(None, POSet(b1 -> b2, b2 -> b3, b4 -> b5).lub) assertEquals(None, POSet(b1 -> b2, b2 -> b3, b2 -> b4).lub) } + + @Test def glb() { + assertEquals(Some(b2), POSet(b2 -> b1).glb) + assertEquals(Some(b3), POSet(b3 -> b1, b3 -> b2).glb) + assertEquals(Some(b4), POSet(b3 -> b1, b3 -> b2, b4 -> b3).glb) + assertEquals(None, POSet(b2 -> b1, b3 -> b2, b5 -> b4).glb) + assertEquals(None, POSet(b2 -> b1, b3 -> b2, b4 -> b2).glb) + } } From 6f9f33500c45d93d56b58ad73d60f0c08f0182a2 Mon Sep 17 00:00:00 2001 From: Scott Guest Date: Thu, 19 Oct 2023 15:29:40 -0400 Subject: [PATCH 3/3] AddSortInjections.java: Refactor LUB computation to use new POSet methods --- .../kframework/compile/AddSortInjections.java | 47 ++++++------------- 1 file changed, 15 insertions(+), 32 deletions(-) diff --git a/kernel/src/main/java/org/kframework/compile/AddSortInjections.java b/kernel/src/main/java/org/kframework/compile/AddSortInjections.java index 6acc7ef344d..a4bfddabbe7 100644 --- a/kernel/src/main/java/org/kframework/compile/AddSortInjections.java +++ b/kernel/src/main/java/org/kframework/compile/AddSortInjections.java @@ -453,10 +453,24 @@ private static Sort lub(Collection entries, Sort expectedSort, HasLocation if (filteredEntries.isEmpty()) { // if all sorts are parameters, take the first return entries.iterator().next(); } - Set bounds = upperBounds(filteredEntries, mod); + + Set nonParametric = + filteredEntries.stream().filter(s -> s.params().isEmpty()).collect(Collectors.toSet()); + Set bounds = mutable(mod.subsorts().upperBounds(immutable(nonParametric))); + // Anything less than KBott or greater than K is a syntactic sort from kast.md which should not be considered + bounds.removeIf(s -> mod.subsorts().lessThanEq(s, Sorts.KBott()) || mod.subsorts().greaterThan(s, Sorts.K())); if (expectedSort != null && !expectedSort.name().equals(SORTPARAM_NAME)) { bounds.removeIf(s -> !mod.subsorts().lessThanEq(s, expectedSort)); } + + // For parametric sorts, each bound must bound at least one instantiation + Set parametric = + filteredEntries.stream().filter(s -> ! s.params().isEmpty()).collect(Collectors.toSet()); + bounds.removeIf(bound -> + parametric.stream().anyMatch(param -> + stream(mod.definedInstantiations().apply(param.head())) + .noneMatch(inst -> mod.subsorts().lessThanEq(inst, bound)))); + Set lub = mod.subsorts().minimal(bounds); if (lub.size() != 1) { throw KEMException.internalError("Could not compute least upper bound for rewrite sort. Possible candidates: " + lub, loc); @@ -464,37 +478,6 @@ private static Sort lub(Collection entries, Sort expectedSort, HasLocation return lub.iterator().next(); } - private static Set upperBounds(Collection bounds, Module mod) { - Set maxs = new HashSet<>(); - nextsort: - for (Sort sort : iterable(mod.allSorts())) { // for every declared sort - // Sorts at or below KBott, or above K, are assumed to be - // sorts from kast.k representing meta-syntax that is not a real sort. - // This is done to prevent variables from being inferred as KBott or - // as KList. - if (mod.subsorts().lessThanEq(sort, Sorts.KBott())) - continue; - if (mod.subsorts().greaterThan(sort, Sorts.K())) - continue; - for (Sort bound : bounds) - if (bound.params().isEmpty()) { - if (!mod.subsorts().lessThanEq(bound, sort)) - continue nextsort; - } else { - boolean any = false; - for (Sort instantiation : iterable(mod.definedInstantiations().apply(bound.head()))) { - if (mod.subsorts().lessThanEq(instantiation, sort)) { - any = true; - } - } - if (!any) - continue nextsort; - } - maxs.add(sort); - } - return maxs; - } - private Sort freshSortParam() { return Sort(SORTPARAM_NAME, Sort("Q" + freshSortParamCounter++)); }