diff --git a/kernel/src/main/java/org/kframework/parser/inner/ParseInModule.java b/kernel/src/main/java/org/kframework/parser/inner/ParseInModule.java
index 0b3b2d39b44..869a54bd894 100644
--- a/kernel/src/main/java/org/kframework/parser/inner/ParseInModule.java
+++ b/kernel/src/main/java/org/kframework/parser/inner/ParseInModule.java
@@ -22,6 +22,7 @@
import org.kframework.parser.Term;
import org.kframework.parser.TreeNodesToKORE;
import org.kframework.parser.inner.disambiguation.*;
+import org.kframework.parser.inner.disambiguation.inference.SortInferencer;
import org.kframework.parser.inner.kernel.EarleyParser;
import org.kframework.parser.inner.kernel.Scanner;
import org.kframework.parser.outer.Outer;
diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/SortInferencer.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/SortInferencer.java
deleted file mode 100644
index c86062d58da..00000000000
--- a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/SortInferencer.java
+++ /dev/null
@@ -1,707 +0,0 @@
-// Copyright (c) K Team. All Rights Reserved.
-package org.kframework.parser.inner.disambiguation;
-
-import static org.kframework.Collections.*;
-import static org.kframework.kore.KORE.*;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.IdentityHashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Map.Entry;
-import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-import org.kframework.attributes.Att;
-import org.kframework.builtin.KLabels;
-import org.kframework.builtin.Sorts;
-import org.kframework.compile.ResolveAnonVar;
-import org.kframework.definition.Module;
-import org.kframework.definition.NonTerminal;
-import org.kframework.definition.Production;
-import org.kframework.kore.KLabel;
-import org.kframework.kore.Sort;
-import org.kframework.kore.SortHead;
-import org.kframework.parser.Ambiguity;
-import org.kframework.parser.BoundedSort;
-import org.kframework.parser.CompactSort;
-import org.kframework.parser.Constant;
-import org.kframework.parser.InferenceResult;
-import org.kframework.parser.InferenceState;
-import org.kframework.parser.ProductionReference;
-import org.kframework.parser.Term;
-import org.kframework.parser.TermCons;
-import org.kframework.parser.VariableId;
-import org.kframework.utils.errorsystem.KEMException;
-import org.pcollections.ConsPStack;
-import scala.Tuple2;
-import scala.util.Either;
-import scala.util.Left;
-import scala.util.Right;
-
-/**
- * Disambiguation transformer which performs type checking and infers the sorts of variables.
- *
- * <p>The overall design is based on the algorithm described in "The Simple Essence of Algebraic
- * Subtyping: Principal Type Inference with Subtyping Made Easy" by Lionel Parreaux.
- *
- * <p>Specifically, we can straightforwardly treat any (non-ambiguous) term in our language as a
- * function in the SimpleSub
- *
- * <p>- Constants are treated as built-ins - TermCons are treated as primitive functions
- */
-public class SortInferencer {
- private final Module mod;
-
- private int id = 0;
-
- private final Map prIds = new IdentityHashMap<>();
-
- public SortInferencer(Module mod) {
- this.mod = mod;
- }
-
- /**
- * @param t - A term
- * @return Whether t is a Term which the sort inference engine can currently handle. Specifically,
- */
- public static boolean isSupported(Term t) {
- return !hasAmbiguity(t) && !hasStrictCast(t) && !hasParametricSorts(t);
- }
-
- private static boolean hasAmbiguity(Term t) {
- if (t instanceof Ambiguity) {
- return true;
- }
- if (t instanceof Constant) {
- return false;
- }
- return ((TermCons) t).items().stream().anyMatch(SortInferencer::hasAmbiguity);
- }
-
- private static boolean hasStrictCast(Term t) {
- if (t instanceof Ambiguity) {
- return ((Ambiguity) t).items().stream().anyMatch(SortInferencer::hasStrictCast);
- }
- ProductionReference pr = (ProductionReference) t;
- if (pr.production().klabel().isDefined()) {
- KLabel klabel = pr.production().klabel().get();
- String label = klabel.name();
- if (label.equals("#SyntacticCast") || label.equals("#InnerCast")) {
- return true;
- }
- }
- if (t instanceof Constant) {
- return false;
- }
- return ((TermCons) t).items().stream().anyMatch(SortInferencer::hasStrictCast);
- }
-
- private static boolean hasParametricSorts(Term t) {
- if (t instanceof Ambiguity) {
- return ((Ambiguity) t).items().stream().anyMatch(SortInferencer::hasParametricSorts);
- }
- ProductionReference pr = (ProductionReference) t;
- if (stream(pr.production().items())
- .filter(pi -> pi instanceof NonTerminal)
- .map(pi -> ((NonTerminal) pi).sort())
- .anyMatch(s -> !s.params().isEmpty())) {
- return true;
- }
- if (pr instanceof Constant) {
- return false;
- }
- return ((TermCons) t).items().stream().anyMatch(SortInferencer::hasParametricSorts);
- }
-
- public Either, Term> apply(Term t, Sort topSort, boolean isAnywhere) {
- Set> monoRes;
- try {
- InferenceState inferState =
- new InferenceState(new HashMap<>(), new HashMap<>(), new HashSet<>());
- BoundedSort itemSort = infer(t, isAnywhereRule(t, isAnywhere), inferState);
- BoundedSort topBoundedSort = sortWithoutSortVariablesToBoundedSort(topSort);
- constrain(itemSort, topBoundedSort, inferState, (ProductionReference) t);
- InferenceResult unsimplifiedRes =
- new InferenceResult<>(topBoundedSort, inferState.varSorts());
- InferenceResult res = simplify(compact(unsimplifiedRes));
- monoRes = monomorphize(res, t);
- } catch (SortInferenceError e) {
- Set errs = new HashSet<>();
- errs.add(e.asInnerParseError(t));
- return Left.apply(errs);
- }
-
- Set items = new HashSet<>();
- for (InferenceResult mono : monoRes) {
- items.add(insertCasts(t, mono, false));
- }
- if (items.size() == 1) {
- return Right.apply(items.iterator().next());
- } else {
- return Right.apply(Ambiguity.apply(items));
- }
- }
-
- private static boolean isAnywhereRule(Term t, boolean isAnywhere) {
- if (t instanceof Ambiguity) {
- throw new AssertionError("Ambiguities are not yet supported!");
- }
- t = stripBrackets(t);
- if (t instanceof Constant) {
- return false;
- }
- TermCons tc = (TermCons) t;
- // For every #RuleContent production, the first non-terminal holds a #RuleBody
- if (tc.production().sort().equals(Sorts.RuleContent())) {
- assert tc.production().nonterminals().size() >= 1
- && tc.production().nonterminal(0).sort().equals(Sorts.RuleBody());
- return isAnywhereRule(tc.get(0), isAnywhere);
- }
- // For every #RuleBody production, the first non-terminal holds the actual K term
- if (tc.production().sort().equals(Sorts.RuleBody())) {
- assert tc.production().nonterminals().size() >= 1
- && tc.production().nonterminal(0).sort().equals(Sorts.K());
- return isAnywhereRule(tc.get(0), isAnywhere);
- }
- // This is the first actual K term we encounter after stripping away rule syntax
- if (tc.production().klabel().filter(k -> k.head().equals(KLabels.KREWRITE)).isDefined()) {
- Term lhs = stripBrackets(tc.get(0));
- if (lhs instanceof Ambiguity) {
- throw new AssertionError("Ambiguities are not yet supported!");
- }
- ProductionReference lhsPr = (ProductionReference) lhs;
- return isAnywhere
- || lhsPr.production().att().contains(Att.FUNCTION())
- || lhsPr.production().att().getMacro().isDefined();
- }
- return false;
- }
-
- private static class SortInferenceError extends Exception {
- private final Optional term;
-
- SortInferenceError(String message, Optional term) {
- super(message);
- this.term = term;
- }
-
- public KEMException asInnerParseError(Term defaultTerm) {
- return KEMException.innerParserError(getMessage(), term.orElse(defaultTerm));
- }
-
- private static SortInferenceError constrainError(Sort lhs, Sort rhs, ProductionReference pr) {
- String msg =
- "Unexpected sort "
- + lhs
- + " for term parsed as production "
- + pr.production()
- + ". Expected: "
- + rhs;
- return new SortInferenceError(msg, Optional.of(pr));
- }
-
- private static SortInferenceError latticeOpError(
- LatticeOpError err, Term t, Optional name) {
-
- String msg =
- "Sort"
- + name.map(n -> " of " + n + " ").orElse(" ")
- + "inferred as "
- + (err.polarity ? "least upper bound" : "greatest lower bound")
- + " of "
- + err.sorts
- + ", but ";
- if (err.candidates.isEmpty()) {
- msg += "no such bound exists.";
- }
- if (!err.candidates.isEmpty()) {
- msg += "candidate bounds are incomparable: " + err.candidates + ".";
- }
- return new SortInferenceError(msg, Optional.of(t));
- }
- }
-
- /**
- * @param t - The term we want to infer the type of
- * @param isAnywhereRule - Whether t is a rule which can be applied anywhere in a configuration
- * @param inferState - All state maintained during inference, which will be updated throughout
- * with sorts for all contained variables
- * @return The unsimplified sort of the input term
- * @throws SortInferenceError - an exception indicating that the term is not well-typed
- */
- private BoundedSort infer(Term t, boolean isAnywhereRule, InferenceState inferState)
- throws SortInferenceError {
- if (t instanceof Ambiguity) {
- throw new AssertionError("Ambiguities are not yet supported!");
- }
-
- ProductionReference pr = (ProductionReference) t;
- if (!prIds.containsKey(pr)) {
- prIds.put(pr, id);
- id++;
- }
- addParamsFromProduction(inferState, pr);
-
- if (pr instanceof Constant c) {
- if (c.production().sort().equals(Sorts.KVariable())
- || c.production().sort().equals(Sorts.KConfigVar())) {
- VariableId varId = varId(c);
- if (!inferState.varSorts().containsKey(varId)) {
- inferState
- .varSorts()
- .put(varId, new BoundedSort.Variable(new ArrayList<>(), new ArrayList<>()));
- }
- return inferState.varSorts().get(varId);
- }
- return sortToBoundedSort(c.production().sort(), pr, inferState.params());
- }
-
- TermCons tc = (TermCons) pr;
- if (isAnywhereRule
- && tc.production().klabel().filter(k -> k.head().equals(KLabels.KREWRITE)).isDefined()) {
- // For rules which apply anywhere, the overall sort cannot be wider than the LHS
- BoundedSort lhsSort = infer(tc.get(0), false, inferState);
- // To prevent widening, we constrain RHS's inferred sort <: LHS's declared sort.
- //
- // Note that we do actually need the LHS's declared sort. The LHS's inferred sort
- // is a variable X with a bound L <: X, and constraining against X would just add a
- // new lower bound aka permit widening.
- BoundedSort lhsDeclaredSort =
- sortToBoundedSort(
- ((ProductionReference) stripBrackets(tc.get(0))).production().sort(),
- pr,
- inferState.params());
- BoundedSort rhsSort = infer(tc.get(1), false, inferState);
- constrain(rhsSort, lhsDeclaredSort, inferState, (ProductionReference) tc.get(1));
- return lhsSort;
- }
-
- for (int prodI = 0, tcI = 0; prodI < tc.production().items().size(); prodI++) {
- if (!(tc.production().items().apply(prodI) instanceof NonTerminal nt)) {
- continue;
- }
- BoundedSort expectedSort = sortToBoundedSort(nt.sort(), pr, inferState.params());
- BoundedSort childSort = infer(tc.get(tcI), isAnywhereRule, inferState);
- constrain(childSort, expectedSort, inferState, pr);
- tcI++;
- }
- BoundedSort resSort = new BoundedSort.Variable(new ArrayList<>(), new ArrayList<>());
- constrain(
- sortToBoundedSort(tc.production().sort(), pr, inferState.params()),
- resSort,
- inferState,
- pr);
- return resSort;
- }
-
- private void addParamsFromProduction(InferenceState inferState, ProductionReference pr) {
- for (Sort param : iterable(pr.production().params())) {
- inferState
- .params()
- .put(
- Tuple2.apply(pr, param),
- new BoundedSort.Variable(new ArrayList<>(), new ArrayList<>()));
- }
- }
-
- private void constrain(
- BoundedSort lhs, BoundedSort rhs, InferenceState inferState, ProductionReference pr)
- throws SortInferenceError {
- if (lhs.equals(rhs) || inferState.constraintCache().contains(Tuple2.apply(lhs, rhs))) {
- return;
- }
-
- if (lhs instanceof BoundedSort.Variable lhsVar) {
- inferState.constraintCache().add(Tuple2.apply(lhs, rhs));
- lhsVar.upperBounds().add(rhs);
- for (BoundedSort lhsLower : lhsVar.lowerBounds()) {
- constrain(lhsLower, rhs, inferState, pr);
- }
- return;
- }
-
- if (rhs instanceof BoundedSort.Variable rhsVar) {
- inferState.constraintCache().add(Tuple2.apply(lhs, rhs));
- rhsVar.lowerBounds().add(lhs);
- for (BoundedSort rhsUpper : rhsVar.upperBounds()) {
- constrain(lhs, rhsUpper, inferState, pr);
- }
- return;
- }
-
- // If they are primitive sorts, we can check the sort poset directly
- BoundedSort.Constructor lhsCtor = (BoundedSort.Constructor) lhs;
- BoundedSort.Constructor rhsCtor = (BoundedSort.Constructor) rhs;
- if (lhsCtor.head().params() == 0 && rhsCtor.head().params() == 0) {
- Sort lhsSort = new org.kframework.kore.ADT.Sort(lhsCtor.head().name(), Seq());
- Sort rhsSort = new org.kframework.kore.ADT.Sort(rhsCtor.head().name(), Seq());
- if (mod.subsorts().lessThanEq(lhsSort, rhsSort)) {
- return;
- }
- throw SortInferenceError.constrainError(lhsSort, rhsSort, pr);
- }
-
- throw new AssertionError("Parametric sorts are not yet supported!");
- }
-
- private CompactSort compact(BoundedSort sort, boolean polarity) {
- if (sort instanceof BoundedSort.Constructor ctor) {
- if (ctor.head().params() == 0) {
- Set ctors = new HashSet<>();
- ctors.add(ctor.head());
- return new CompactSort(new HashSet<>(), ctors);
- }
- throw new AssertionError("Parametric sorts are not yet supported!");
- }
- BoundedSort.Variable var = (BoundedSort.Variable) sort;
- List bounds = polarity ? var.lowerBounds() : var.upperBounds();
-
- Set vars = new HashSet<>();
- Set ctors = new HashSet<>();
- vars.add(var);
- for (BoundedSort bound : bounds) {
- CompactSort compactBound = compact(bound, polarity);
- vars.addAll(compactBound.vars());
- ctors.addAll(compactBound.ctors());
- }
- return new CompactSort(vars, ctors);
- }
-
- private InferenceResult compact(InferenceResult res) {
- CompactSort sort = compact(res.sort(), true);
-
- Map varSorts = new HashMap<>();
- for (Map.Entry entry : res.varSorts().entrySet()) {
- varSorts.put(entry.getKey(), compact(entry.getValue(), false));
- }
-
- return new InferenceResult<>(sort, varSorts);
- }
-
- private InferenceResult simplify(InferenceResult res)
- throws SortInferenceError {
-
- Map, Set> coOccurrences =
- analyzeCoOccurrences(res, CoOccurMode.ALWAYS);
- Map> varSubst = new HashMap<>();
- // Simplify away all those variables that only occur in negative (resp. positive) position.
- Set allVars =
- coOccurrences.keySet().stream().map(Tuple2::_1).collect(Collectors.toSet());
- allVars.forEach(
- (v) -> {
- boolean negative = coOccurrences.containsKey(Tuple2.apply(v, false));
- boolean positive = coOccurrences.containsKey(Tuple2.apply(v, true));
- if ((negative && !positive) || (!negative && positive)) {
- varSubst.put(v, Optional.empty());
- }
- });
-
- List pols = new ArrayList<>();
- pols.add(false);
- pols.add(true);
-
- for (Boolean pol : pols) {
- for (BoundedSort.Variable v : allVars) {
- for (BoundedSort co : coOccurrences.getOrDefault(Tuple2.apply(v, pol), new HashSet<>())) {
- if (co instanceof BoundedSort.Variable w) {
- if (v.equals(w) || varSubst.containsKey(w)) {
- continue;
- }
- if (coOccurrences.getOrDefault(Tuple2.apply(w, pol), new HashSet<>()).contains(v)) {
- // v and w co-occur in the given polarity, so we unify w into v
- varSubst.put(w, Optional.of(v));
- // we also need to update v's co-occurrences correspondingly
- // (intersecting with w's)
- coOccurrences
- .get(Tuple2.apply(v, !pol))
- .retainAll(coOccurrences.get(Tuple2.apply(w, !pol)));
- coOccurrences.get(Tuple2.apply(v, !pol)).add(v);
- }
- continue;
- }
- // This is not a variable, so check if we have a sandwich co <: v <: co
- // and can thus simplify away v
- if (coOccurrences.getOrDefault(Tuple2.apply(v, !pol), new HashSet<>()).contains(co)) {
- varSubst.put(v, Optional.empty());
- }
- }
- }
- }
-
- CompactSort newSort = applySubstitutions(res.sort(), varSubst);
- Map newVarSorts =
- res.varSorts().entrySet().stream()
- .collect(
- Collectors.toMap(
- (Entry::getKey), (e) -> applySubstitutions(e.getValue(), varSubst)));
- return new InferenceResult<>(newSort, newVarSorts);
- }
-
- /** Modes for the co-occurrence analysis. */
- private enum CoOccurMode {
- /** Record only those sorts which always co-occur with a given variable and polarity. */
- ALWAYS,
- /** Record any sort that ever co-occurs with a given variable and polarity. */
- EVER
- }
-
- private Map, Set> analyzeCoOccurrences(
- InferenceResult res, CoOccurMode mode) {
- Map, Set> coOccurrences = new HashMap<>();
-
- // Boolean is used to represent polarity - true is positive, false is negative
- List> compactSorts = new ArrayList<>();
- // The sort of the overall term is positive
- compactSorts.add(Tuple2.apply(res.sort(), true));
- // The sorts of variables are negative
- compactSorts.addAll(
- res.varSorts().values().stream().map((v) -> Tuple2.apply(v, false)).toList());
- compactSorts.forEach(
- polSort -> updateCoOccurrences(polSort._1, polSort._2, mode, coOccurrences));
-
- return coOccurrences;
- }
-
- /**
- * Update the co-occurrence analysis results so-far to account for the occurrences within sort
- *
- * @param sort - The sort which we are processing
- * @param polarity - The polarity of the provided sort
- * @param coOccurrences - mutated to record all co-occurrences in each variable occurring in sort
- */
- private void updateCoOccurrences(
- CompactSort sort,
- boolean polarity,
- CoOccurMode mode,
- Map, Set> coOccurrences) {
- for (BoundedSort.Variable var : sort.vars()) {
- Set newOccurs =
- Stream.concat(
- sort.vars().stream().map(v -> (BoundedSort) v),
- sort.ctors().stream().map(BoundedSort.Constructor::new))
- .collect(Collectors.toSet());
- Tuple2 polVar = Tuple2.apply(var, polarity);
- if (coOccurrences.containsKey(polVar)) {
- switch (mode) {
- case ALWAYS -> coOccurrences.get(polVar).retainAll(newOccurs);
- case EVER -> coOccurrences.get(polVar).addAll(newOccurs);
- }
- } else {
- coOccurrences.put(polVar, newOccurs);
- }
- }
- }
-
- /**
- * Apply a substitution to a sort
- *
- * @param sort - The input sort
- * @param varSubst - A map describing the substitution - a -> None indicates that a should be
- * removed form the sort - a -> Some(b) indicates that a should be replaced by b
- * @return sort with the substitution applied
- */
- private static CompactSort applySubstitutions(
- CompactSort sort, Map> varSubst) {
- Set vars = new HashSet<>();
- for (BoundedSort.Variable var : sort.vars()) {
- if (!varSubst.containsKey(var)) {
- vars.add(var);
- continue;
- }
- varSubst.get(var).ifPresent(vars::add);
- }
- Set ctors = new HashSet<>(sort.ctors());
- return new CompactSort(vars, ctors);
- }
-
- private Set> monomorphize(InferenceResult res, Term t)
- throws SortInferenceError {
- Map, Set> bounds =
- analyzeCoOccurrences(res, CoOccurMode.EVER);
- Set allVars =
- bounds.keySet().stream().map(Tuple2::_1).collect(Collectors.toSet());
- Set