From c5a44e8dd76f4d553295d1b5c3e2c6c90c89cabb Mon Sep 17 00:00:00 2001 From: Scott Guest Date: Mon, 27 Nov 2023 22:26:36 -0500 Subject: [PATCH] Removed scala code and refactored accordingly. --- .../parser/inner/ParseInModule.java | 1 + .../inner/disambiguation/SortInferencer.java | 707 ------------------ .../disambiguation/inference/BoundedSort.java | 30 + .../disambiguation/inference/CompactSort.java | 124 +++ .../inference/InferenceDriver.java | 93 +++ .../disambiguation/inference/ParamId.java | 28 + .../inference/SortInferenceError.java | 50 ++ .../inference/SortInferencer.java | 535 +++++++++++++ .../inference/SortVariable.java | 5 + .../disambiguation/inference/TermSort.java | 22 + .../disambiguation/inference/VariableId.java | 42 ++ .../utils/errorsystem/KEMException.java | 7 +- .../kframework/parser/InferenceSorts.scala | 31 - 13 files changed, 933 insertions(+), 742 deletions(-) delete mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/SortInferencer.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/BoundedSort.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/CompactSort.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/InferenceDriver.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/ParamId.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferenceError.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferencer.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortVariable.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/TermSort.java create mode 100644 kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/VariableId.java delete mode 100644 kore/src/main/scala/org/kframework/parser/InferenceSorts.scala diff --git a/kernel/src/main/java/org/kframework/parser/inner/ParseInModule.java b/kernel/src/main/java/org/kframework/parser/inner/ParseInModule.java index 0b3b2d39b44..869a54bd894 100644 --- a/kernel/src/main/java/org/kframework/parser/inner/ParseInModule.java +++ b/kernel/src/main/java/org/kframework/parser/inner/ParseInModule.java @@ -22,6 +22,7 @@ import org.kframework.parser.Term; import org.kframework.parser.TreeNodesToKORE; import org.kframework.parser.inner.disambiguation.*; +import org.kframework.parser.inner.disambiguation.inference.SortInferencer; import org.kframework.parser.inner.kernel.EarleyParser; import org.kframework.parser.inner.kernel.Scanner; import org.kframework.parser.outer.Outer; diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/SortInferencer.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/SortInferencer.java deleted file mode 100644 index c86062d58da..00000000000 --- a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/SortInferencer.java +++ /dev/null @@ -1,707 +0,0 @@ -// Copyright (c) K Team. All Rights Reserved. -package org.kframework.parser.inner.disambiguation; - -import static org.kframework.Collections.*; -import static org.kframework.kore.KORE.*; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.IdentityHashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.kframework.attributes.Att; -import org.kframework.builtin.KLabels; -import org.kframework.builtin.Sorts; -import org.kframework.compile.ResolveAnonVar; -import org.kframework.definition.Module; -import org.kframework.definition.NonTerminal; -import org.kframework.definition.Production; -import org.kframework.kore.KLabel; -import org.kframework.kore.Sort; -import org.kframework.kore.SortHead; -import org.kframework.parser.Ambiguity; -import org.kframework.parser.BoundedSort; -import org.kframework.parser.CompactSort; -import org.kframework.parser.Constant; -import org.kframework.parser.InferenceResult; -import org.kframework.parser.InferenceState; -import org.kframework.parser.ProductionReference; -import org.kframework.parser.Term; -import org.kframework.parser.TermCons; -import org.kframework.parser.VariableId; -import org.kframework.utils.errorsystem.KEMException; -import org.pcollections.ConsPStack; -import scala.Tuple2; -import scala.util.Either; -import scala.util.Left; -import scala.util.Right; - -/** - * Disambiguation transformer which performs type checking and infers the sorts of variables. - * - *

The overall design is based on the algorithm described in "The Simple Essence of Algebraic - * Subtyping: Principal Type Inference with Subtyping Made Easy" by Lionel Parreaux. - * - *

Specifically, we can straightforwardly treat any (non-ambiguous) term in our language as a - * function in the SimpleSub - * - *

- Constants are treated as built-ins - TermCons are treated as primitive functions - */ -public class SortInferencer { - private final Module mod; - - private int id = 0; - - private final Map prIds = new IdentityHashMap<>(); - - public SortInferencer(Module mod) { - this.mod = mod; - } - - /** - * @param t - A term - * @return Whether t is a Term which the sort inference engine can currently handle. Specifically, - */ - public static boolean isSupported(Term t) { - return !hasAmbiguity(t) && !hasStrictCast(t) && !hasParametricSorts(t); - } - - private static boolean hasAmbiguity(Term t) { - if (t instanceof Ambiguity) { - return true; - } - if (t instanceof Constant) { - return false; - } - return ((TermCons) t).items().stream().anyMatch(SortInferencer::hasAmbiguity); - } - - private static boolean hasStrictCast(Term t) { - if (t instanceof Ambiguity) { - return ((Ambiguity) t).items().stream().anyMatch(SortInferencer::hasStrictCast); - } - ProductionReference pr = (ProductionReference) t; - if (pr.production().klabel().isDefined()) { - KLabel klabel = pr.production().klabel().get(); - String label = klabel.name(); - if (label.equals("#SyntacticCast") || label.equals("#InnerCast")) { - return true; - } - } - if (t instanceof Constant) { - return false; - } - return ((TermCons) t).items().stream().anyMatch(SortInferencer::hasStrictCast); - } - - private static boolean hasParametricSorts(Term t) { - if (t instanceof Ambiguity) { - return ((Ambiguity) t).items().stream().anyMatch(SortInferencer::hasParametricSorts); - } - ProductionReference pr = (ProductionReference) t; - if (stream(pr.production().items()) - .filter(pi -> pi instanceof NonTerminal) - .map(pi -> ((NonTerminal) pi).sort()) - .anyMatch(s -> !s.params().isEmpty())) { - return true; - } - if (pr instanceof Constant) { - return false; - } - return ((TermCons) t).items().stream().anyMatch(SortInferencer::hasParametricSorts); - } - - public Either, Term> apply(Term t, Sort topSort, boolean isAnywhere) { - Set> monoRes; - try { - InferenceState inferState = - new InferenceState(new HashMap<>(), new HashMap<>(), new HashSet<>()); - BoundedSort itemSort = infer(t, isAnywhereRule(t, isAnywhere), inferState); - BoundedSort topBoundedSort = sortWithoutSortVariablesToBoundedSort(topSort); - constrain(itemSort, topBoundedSort, inferState, (ProductionReference) t); - InferenceResult unsimplifiedRes = - new InferenceResult<>(topBoundedSort, inferState.varSorts()); - InferenceResult res = simplify(compact(unsimplifiedRes)); - monoRes = monomorphize(res, t); - } catch (SortInferenceError e) { - Set errs = new HashSet<>(); - errs.add(e.asInnerParseError(t)); - return Left.apply(errs); - } - - Set items = new HashSet<>(); - for (InferenceResult mono : monoRes) { - items.add(insertCasts(t, mono, false)); - } - if (items.size() == 1) { - return Right.apply(items.iterator().next()); - } else { - return Right.apply(Ambiguity.apply(items)); - } - } - - private static boolean isAnywhereRule(Term t, boolean isAnywhere) { - if (t instanceof Ambiguity) { - throw new AssertionError("Ambiguities are not yet supported!"); - } - t = stripBrackets(t); - if (t instanceof Constant) { - return false; - } - TermCons tc = (TermCons) t; - // For every #RuleContent production, the first non-terminal holds a #RuleBody - if (tc.production().sort().equals(Sorts.RuleContent())) { - assert tc.production().nonterminals().size() >= 1 - && tc.production().nonterminal(0).sort().equals(Sorts.RuleBody()); - return isAnywhereRule(tc.get(0), isAnywhere); - } - // For every #RuleBody production, the first non-terminal holds the actual K term - if (tc.production().sort().equals(Sorts.RuleBody())) { - assert tc.production().nonterminals().size() >= 1 - && tc.production().nonterminal(0).sort().equals(Sorts.K()); - return isAnywhereRule(tc.get(0), isAnywhere); - } - // This is the first actual K term we encounter after stripping away rule syntax - if (tc.production().klabel().filter(k -> k.head().equals(KLabels.KREWRITE)).isDefined()) { - Term lhs = stripBrackets(tc.get(0)); - if (lhs instanceof Ambiguity) { - throw new AssertionError("Ambiguities are not yet supported!"); - } - ProductionReference lhsPr = (ProductionReference) lhs; - return isAnywhere - || lhsPr.production().att().contains(Att.FUNCTION()) - || lhsPr.production().att().getMacro().isDefined(); - } - return false; - } - - private static class SortInferenceError extends Exception { - private final Optional term; - - SortInferenceError(String message, Optional term) { - super(message); - this.term = term; - } - - public KEMException asInnerParseError(Term defaultTerm) { - return KEMException.innerParserError(getMessage(), term.orElse(defaultTerm)); - } - - private static SortInferenceError constrainError(Sort lhs, Sort rhs, ProductionReference pr) { - String msg = - "Unexpected sort " - + lhs - + " for term parsed as production " - + pr.production() - + ". Expected: " - + rhs; - return new SortInferenceError(msg, Optional.of(pr)); - } - - private static SortInferenceError latticeOpError( - LatticeOpError err, Term t, Optional name) { - - String msg = - "Sort" - + name.map(n -> " of " + n + " ").orElse(" ") - + "inferred as " - + (err.polarity ? "least upper bound" : "greatest lower bound") - + " of " - + err.sorts - + ", but "; - if (err.candidates.isEmpty()) { - msg += "no such bound exists."; - } - if (!err.candidates.isEmpty()) { - msg += "candidate bounds are incomparable: " + err.candidates + "."; - } - return new SortInferenceError(msg, Optional.of(t)); - } - } - - /** - * @param t - The term we want to infer the type of - * @param isAnywhereRule - Whether t is a rule which can be applied anywhere in a configuration - * @param inferState - All state maintained during inference, which will be updated throughout - * with sorts for all contained variables - * @return The unsimplified sort of the input term - * @throws SortInferenceError - an exception indicating that the term is not well-typed - */ - private BoundedSort infer(Term t, boolean isAnywhereRule, InferenceState inferState) - throws SortInferenceError { - if (t instanceof Ambiguity) { - throw new AssertionError("Ambiguities are not yet supported!"); - } - - ProductionReference pr = (ProductionReference) t; - if (!prIds.containsKey(pr)) { - prIds.put(pr, id); - id++; - } - addParamsFromProduction(inferState, pr); - - if (pr instanceof Constant c) { - if (c.production().sort().equals(Sorts.KVariable()) - || c.production().sort().equals(Sorts.KConfigVar())) { - VariableId varId = varId(c); - if (!inferState.varSorts().containsKey(varId)) { - inferState - .varSorts() - .put(varId, new BoundedSort.Variable(new ArrayList<>(), new ArrayList<>())); - } - return inferState.varSorts().get(varId); - } - return sortToBoundedSort(c.production().sort(), pr, inferState.params()); - } - - TermCons tc = (TermCons) pr; - if (isAnywhereRule - && tc.production().klabel().filter(k -> k.head().equals(KLabels.KREWRITE)).isDefined()) { - // For rules which apply anywhere, the overall sort cannot be wider than the LHS - BoundedSort lhsSort = infer(tc.get(0), false, inferState); - // To prevent widening, we constrain RHS's inferred sort <: LHS's declared sort. - // - // Note that we do actually need the LHS's declared sort. The LHS's inferred sort - // is a variable X with a bound L <: X, and constraining against X would just add a - // new lower bound aka permit widening. - BoundedSort lhsDeclaredSort = - sortToBoundedSort( - ((ProductionReference) stripBrackets(tc.get(0))).production().sort(), - pr, - inferState.params()); - BoundedSort rhsSort = infer(tc.get(1), false, inferState); - constrain(rhsSort, lhsDeclaredSort, inferState, (ProductionReference) tc.get(1)); - return lhsSort; - } - - for (int prodI = 0, tcI = 0; prodI < tc.production().items().size(); prodI++) { - if (!(tc.production().items().apply(prodI) instanceof NonTerminal nt)) { - continue; - } - BoundedSort expectedSort = sortToBoundedSort(nt.sort(), pr, inferState.params()); - BoundedSort childSort = infer(tc.get(tcI), isAnywhereRule, inferState); - constrain(childSort, expectedSort, inferState, pr); - tcI++; - } - BoundedSort resSort = new BoundedSort.Variable(new ArrayList<>(), new ArrayList<>()); - constrain( - sortToBoundedSort(tc.production().sort(), pr, inferState.params()), - resSort, - inferState, - pr); - return resSort; - } - - private void addParamsFromProduction(InferenceState inferState, ProductionReference pr) { - for (Sort param : iterable(pr.production().params())) { - inferState - .params() - .put( - Tuple2.apply(pr, param), - new BoundedSort.Variable(new ArrayList<>(), new ArrayList<>())); - } - } - - private void constrain( - BoundedSort lhs, BoundedSort rhs, InferenceState inferState, ProductionReference pr) - throws SortInferenceError { - if (lhs.equals(rhs) || inferState.constraintCache().contains(Tuple2.apply(lhs, rhs))) { - return; - } - - if (lhs instanceof BoundedSort.Variable lhsVar) { - inferState.constraintCache().add(Tuple2.apply(lhs, rhs)); - lhsVar.upperBounds().add(rhs); - for (BoundedSort lhsLower : lhsVar.lowerBounds()) { - constrain(lhsLower, rhs, inferState, pr); - } - return; - } - - if (rhs instanceof BoundedSort.Variable rhsVar) { - inferState.constraintCache().add(Tuple2.apply(lhs, rhs)); - rhsVar.lowerBounds().add(lhs); - for (BoundedSort rhsUpper : rhsVar.upperBounds()) { - constrain(lhs, rhsUpper, inferState, pr); - } - return; - } - - // If they are primitive sorts, we can check the sort poset directly - BoundedSort.Constructor lhsCtor = (BoundedSort.Constructor) lhs; - BoundedSort.Constructor rhsCtor = (BoundedSort.Constructor) rhs; - if (lhsCtor.head().params() == 0 && rhsCtor.head().params() == 0) { - Sort lhsSort = new org.kframework.kore.ADT.Sort(lhsCtor.head().name(), Seq()); - Sort rhsSort = new org.kframework.kore.ADT.Sort(rhsCtor.head().name(), Seq()); - if (mod.subsorts().lessThanEq(lhsSort, rhsSort)) { - return; - } - throw SortInferenceError.constrainError(lhsSort, rhsSort, pr); - } - - throw new AssertionError("Parametric sorts are not yet supported!"); - } - - private CompactSort compact(BoundedSort sort, boolean polarity) { - if (sort instanceof BoundedSort.Constructor ctor) { - if (ctor.head().params() == 0) { - Set ctors = new HashSet<>(); - ctors.add(ctor.head()); - return new CompactSort(new HashSet<>(), ctors); - } - throw new AssertionError("Parametric sorts are not yet supported!"); - } - BoundedSort.Variable var = (BoundedSort.Variable) sort; - List bounds = polarity ? var.lowerBounds() : var.upperBounds(); - - Set vars = new HashSet<>(); - Set ctors = new HashSet<>(); - vars.add(var); - for (BoundedSort bound : bounds) { - CompactSort compactBound = compact(bound, polarity); - vars.addAll(compactBound.vars()); - ctors.addAll(compactBound.ctors()); - } - return new CompactSort(vars, ctors); - } - - private InferenceResult compact(InferenceResult res) { - CompactSort sort = compact(res.sort(), true); - - Map varSorts = new HashMap<>(); - for (Map.Entry entry : res.varSorts().entrySet()) { - varSorts.put(entry.getKey(), compact(entry.getValue(), false)); - } - - return new InferenceResult<>(sort, varSorts); - } - - private InferenceResult simplify(InferenceResult res) - throws SortInferenceError { - - Map, Set> coOccurrences = - analyzeCoOccurrences(res, CoOccurMode.ALWAYS); - Map> varSubst = new HashMap<>(); - // Simplify away all those variables that only occur in negative (resp. positive) position. - Set allVars = - coOccurrences.keySet().stream().map(Tuple2::_1).collect(Collectors.toSet()); - allVars.forEach( - (v) -> { - boolean negative = coOccurrences.containsKey(Tuple2.apply(v, false)); - boolean positive = coOccurrences.containsKey(Tuple2.apply(v, true)); - if ((negative && !positive) || (!negative && positive)) { - varSubst.put(v, Optional.empty()); - } - }); - - List pols = new ArrayList<>(); - pols.add(false); - pols.add(true); - - for (Boolean pol : pols) { - for (BoundedSort.Variable v : allVars) { - for (BoundedSort co : coOccurrences.getOrDefault(Tuple2.apply(v, pol), new HashSet<>())) { - if (co instanceof BoundedSort.Variable w) { - if (v.equals(w) || varSubst.containsKey(w)) { - continue; - } - if (coOccurrences.getOrDefault(Tuple2.apply(w, pol), new HashSet<>()).contains(v)) { - // v and w co-occur in the given polarity, so we unify w into v - varSubst.put(w, Optional.of(v)); - // we also need to update v's co-occurrences correspondingly - // (intersecting with w's) - coOccurrences - .get(Tuple2.apply(v, !pol)) - .retainAll(coOccurrences.get(Tuple2.apply(w, !pol))); - coOccurrences.get(Tuple2.apply(v, !pol)).add(v); - } - continue; - } - // This is not a variable, so check if we have a sandwich co <: v <: co - // and can thus simplify away v - if (coOccurrences.getOrDefault(Tuple2.apply(v, !pol), new HashSet<>()).contains(co)) { - varSubst.put(v, Optional.empty()); - } - } - } - } - - CompactSort newSort = applySubstitutions(res.sort(), varSubst); - Map newVarSorts = - res.varSorts().entrySet().stream() - .collect( - Collectors.toMap( - (Entry::getKey), (e) -> applySubstitutions(e.getValue(), varSubst))); - return new InferenceResult<>(newSort, newVarSorts); - } - - /** Modes for the co-occurrence analysis. */ - private enum CoOccurMode { - /** Record only those sorts which always co-occur with a given variable and polarity. */ - ALWAYS, - /** Record any sort that ever co-occurs with a given variable and polarity. */ - EVER - } - - private Map, Set> analyzeCoOccurrences( - InferenceResult res, CoOccurMode mode) { - Map, Set> coOccurrences = new HashMap<>(); - - // Boolean is used to represent polarity - true is positive, false is negative - List> compactSorts = new ArrayList<>(); - // The sort of the overall term is positive - compactSorts.add(Tuple2.apply(res.sort(), true)); - // The sorts of variables are negative - compactSorts.addAll( - res.varSorts().values().stream().map((v) -> Tuple2.apply(v, false)).toList()); - compactSorts.forEach( - polSort -> updateCoOccurrences(polSort._1, polSort._2, mode, coOccurrences)); - - return coOccurrences; - } - - /** - * Update the co-occurrence analysis results so-far to account for the occurrences within sort - * - * @param sort - The sort which we are processing - * @param polarity - The polarity of the provided sort - * @param coOccurrences - mutated to record all co-occurrences in each variable occurring in sort - */ - private void updateCoOccurrences( - CompactSort sort, - boolean polarity, - CoOccurMode mode, - Map, Set> coOccurrences) { - for (BoundedSort.Variable var : sort.vars()) { - Set newOccurs = - Stream.concat( - sort.vars().stream().map(v -> (BoundedSort) v), - sort.ctors().stream().map(BoundedSort.Constructor::new)) - .collect(Collectors.toSet()); - Tuple2 polVar = Tuple2.apply(var, polarity); - if (coOccurrences.containsKey(polVar)) { - switch (mode) { - case ALWAYS -> coOccurrences.get(polVar).retainAll(newOccurs); - case EVER -> coOccurrences.get(polVar).addAll(newOccurs); - } - } else { - coOccurrences.put(polVar, newOccurs); - } - } - } - - /** - * Apply a substitution to a sort - * - * @param sort - The input sort - * @param varSubst - A map describing the substitution - a -> None indicates that a should be - * removed form the sort - a -> Some(b) indicates that a should be replaced by b - * @return sort with the substitution applied - */ - private static CompactSort applySubstitutions( - CompactSort sort, Map> varSubst) { - Set vars = new HashSet<>(); - for (BoundedSort.Variable var : sort.vars()) { - if (!varSubst.containsKey(var)) { - vars.add(var); - continue; - } - varSubst.get(var).ifPresent(vars::add); - } - Set ctors = new HashSet<>(sort.ctors()); - return new CompactSort(vars, ctors); - } - - private Set> monomorphize(InferenceResult res, Term t) - throws SortInferenceError { - Map, Set> bounds = - analyzeCoOccurrences(res, CoOccurMode.EVER); - Set allVars = - bounds.keySet().stream().map(Tuple2::_1).collect(Collectors.toSet()); - Set> instantiations = new HashSet<>(); - instantiations.add(new HashMap<>()); - for (BoundedSort.Variable var : allVars) { - Set> newInstantiations = new HashSet<>(); - for (Map instant : instantiations) { - newInstantiations.addAll(monomorphizeInVar(instant, var, bounds)); - } - // TODO: Throw a nice error message here - assert !newInstantiations.isEmpty(); - instantiations = newInstantiations; - } - - Set> monos = new HashSet<>(); - SortInferenceError lastError = null; - instloop: - for (Map inst : instantiations) { - Either sortRes = compactSortToSort(res.sort(), true, inst); - if (sortRes.isLeft()) { - lastError = SortInferenceError.latticeOpError(sortRes.left().get(), t, Optional.empty()); - continue; - } - Sort sort = sortRes.right().get(); - Map varSorts = new HashMap<>(); - for (Entry entry : res.varSorts().entrySet()) { - Either varRes = compactSortToSort(entry.getValue(), false, inst); - if (varRes.isLeft()) { - LatticeOpError latticeErr = varRes.left().get(); - if (entry.getKey() instanceof VariableId.Anon anon) { - lastError = - SortInferenceError.latticeOpError( - latticeErr, anon.constant(), Optional.of("variable")); - } else if (entry.getKey() instanceof VariableId.Named named) { - lastError = - SortInferenceError.latticeOpError( - latticeErr, t, Optional.of("variable " + named.name())); - } - continue instloop; - } - varSorts.put(entry.getKey(), varRes.right().get()); - } - monos.add(new InferenceResult<>(sort, varSorts)); - } - if (monos.isEmpty()) { - assert lastError != null; - throw lastError; - } - return monos; - } - - private record LatticeOpError(Set sorts, Set candidates, boolean polarity) {} - - /** - * Convert a CompactSort into a Sort - * - * @param sort - A compact sort - * @param polarity - The polarity in which sort occurs. True for positive, false for negative. - * @param instantiation - A map indicating how the variables in sort should be instantiated - * @return An equivalent Sort - */ - private Either compactSortToSort( - CompactSort sort, boolean polarity, Map instantiation) { - Set sorts = sort.vars().stream().map(instantiation::get).collect(Collectors.toSet()); - sorts.addAll( - sort.ctors().stream() - .map(h -> new org.kframework.kore.ADT.Sort(h.name(), Seq())) - .collect(Collectors.toSet())); - Set bounds = - polarity ? mod.subsorts().upperBounds(sorts) : mod.subsorts().lowerBounds(sorts); - bounds.removeIf( - s -> - mod.subsorts().lessThanEq(s, Sorts.KLabel()) - || mod.subsorts().lessThanEq(s, Sorts.KBott()) - || mod.subsorts().greaterThan(s, Sorts.K())); - Set candidates = - polarity ? mod.subsorts().minimal(bounds) : mod.subsorts().maximal(bounds); - if (candidates.size() != 1) { - return Left.apply(new LatticeOpError(sorts, candidates, polarity)); - } - return Right.apply(candidates.iterator().next()); - } - - private Set> monomorphizeInVar( - Map instantiation, - BoundedSort.Variable var, - Map, Set> bounds) { - - Map> polBounds = new HashMap<>(); - polBounds.put(true, new HashSet<>()); - polBounds.put(false, new HashSet<>()); - - for (Entry> polBound : polBounds.entrySet()) { - for (BoundedSort bSort : - bounds.getOrDefault(Tuple2.apply(var, polBound.getKey()), new HashSet<>())) { - if (bSort instanceof BoundedSort.Variable bVar) { - if (instantiation.containsKey(bVar)) { - polBound.getValue().add(instantiation.get(bVar)); - } - } else if (bSort instanceof BoundedSort.Constructor lowerCtor) { - polBound.getValue().add(new org.kframework.kore.ADT.Sort(lowerCtor.head().name(), Seq())); - } - } - } - - Set range = mod.subsorts().upperBounds(polBounds.get(true)); - range.retainAll(mod.subsorts().lowerBounds(polBounds.get(false))); - - Set> insts = new HashSet<>(); - for (Sort sort : range) { - Map inst = new HashMap<>(instantiation); - inst.put(var, sort); - insts.add(inst); - } - return insts; - } - - private Term insertCasts(Term t, InferenceResult sorts, boolean existingCast) { - if (t instanceof Ambiguity) { - throw new AssertionError("Ambiguities are not yet supported!"); - } - - ProductionReference pr = (ProductionReference) t; - if (pr instanceof Constant c) { - if (c.production().sort().equals(Sorts.KVariable()) - || c.production().sort().equals(Sorts.KConfigVar())) { - Sort inferred = sorts.varSorts().get(varId(c)); - if (!existingCast) { - return wrapTermWithCast(c, inferred); - } - } - return c; - } - - TermCons tc = (TermCons) pr; - boolean isCast = - tc.production().klabel().filter(k -> k.name().startsWith("#SemanticCastTo")).isDefined(); - for (int i = 0; i < tc.items().size(); i++) { - tc = tc.with(i, insertCasts(tc.get(i), sorts, isCast)); - } - return tc; - } - - private Term wrapTermWithCast(Term t, Sort sort) { - Production cast = - mod.productionsFor().apply(KLabel("#SemanticCastTo" + sort.toString())).head(); - return TermCons.apply(ConsPStack.singleton(t), cast, t.location(), t.source()); - } - - public VariableId varId(Constant var) { - if (ResolveAnonVar.isAnonVarOrNamedAnonVar(KVariable(var.value()))) { - return new VariableId.Anon(var, prIds.get(var)); - } - return new VariableId.Named(var.value()); - } - - private static Term stripBrackets(Term tc) { - Term child = tc; - while (child instanceof TermCons - && ((TermCons) child).production().att().contains(Att.BRACKET())) { - child = ((TermCons) child).get(0); - } - return child; - } - - private static BoundedSort sortToBoundedSort( - Sort sort, - ProductionReference pr, - Map, BoundedSort.Variable> params) { - if (pr.production().params().contains(sort)) { - return params.get(Tuple2.apply(pr, sort)); - } - return new BoundedSort.Constructor(sort.head()); - } - - /** - * @param sort - The sort to convert, which must not contain any sort variables! - * @return sort as a BoundedSort - */ - private static BoundedSort.Constructor sortWithoutSortVariablesToBoundedSort(Sort sort) { - return new BoundedSort.Constructor(sort.head()); - } -} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/BoundedSort.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/BoundedSort.java new file mode 100644 index 00000000000..d70467e2d29 --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/BoundedSort.java @@ -0,0 +1,30 @@ +package org.kframework.parser.inner.disambiguation.inference; + +import java.util.ArrayList; +import java.util.List; +import org.kframework.kore.SortHead; + +public sealed interface BoundedSort { + record Constructor(SortHead head) implements BoundedSort {} + + // This is a class rather than a record because we want reference equality + final class Variable implements BoundedSort { + private final SortVariable sortVar = new SortVariable(); + private final List lowerBounds = new ArrayList<>(); + private final List upperBounds = new ArrayList<>(); + + public Variable() {} + + public SortVariable sortVar() { + return sortVar; + } + + public List lowerBounds() { + return lowerBounds; + } + + public List upperBounds() { + return upperBounds; + } + } +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/CompactSort.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/CompactSort.java new file mode 100644 index 00000000000..134123766e5 --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/CompactSort.java @@ -0,0 +1,124 @@ +package org.kframework.parser.inner.disambiguation.inference; + +import static org.kframework.Collections.*; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import org.kframework.POSet; +import org.kframework.builtin.Sorts; +import org.kframework.kore.Sort; +import org.kframework.kore.SortHead; +import scala.util.Either; +import scala.util.Left; +import scala.util.Right; + +public record CompactSort(Set vars, Set ctors) { + + public CompactSort(SortVariable var) { + this( + new HashSet<>() { + { + add(var); + } + }, + new HashSet<>()); + } + + /** + * Compactify a BoundedSort, chasing all transitive bounds + * + * @param sort - The BoundedSort to make compact + * @param polarity - The polarity where sort occurs + * @return A CompactSort containing all bounds represented by sort + */ + public static CompactSort makeCompact(BoundedSort sort, boolean polarity) { + if (sort instanceof BoundedSort.Constructor ctor) { + if (ctor.head().params() == 0) { + Set ctors = new HashSet<>(); + ctors.add(ctor.head()); + return new CompactSort(new HashSet<>(), ctors); + } + throw new AssertionError("Parametric sorts are not yet supported!"); + } + BoundedSort.Variable var = (BoundedSort.Variable) sort; + + List bounds = polarity ? var.lowerBounds() : var.upperBounds(); + + Set vars = new HashSet<>(); + Set ctors = new HashSet<>(); + vars.add(var.sortVar()); + for (BoundedSort bound : bounds) { + CompactSort compactBound = makeCompact(bound, polarity); + vars.addAll(compactBound.vars()); + ctors.addAll(compactBound.ctors()); + } + return new CompactSort(vars, ctors); + } + + /** + * Substitute variables for CompactSorts + * + * @param subst - A map where an entry v |-> Optional.of(t) indicates that the variable v should + * be replaced by t, and an entry v |-> Optional.empty() indicates that v should be removed + * entirely (effectively, replacing it with top or bottom depending on polarity). + * @return A new CompactSort with the substitution applied + */ + public CompactSort substitute(Map> subst) { + Set newVars = new HashSet<>(); + Set newCtors = new HashSet<>(ctors); + for (SortVariable var : vars) { + if (!subst.containsKey(var)) { + newVars.add(var); + continue; + } + if (subst.get(var).isPresent()) { + CompactSort newSort = subst.get(var).get(); + newVars.addAll(newSort.vars()); + newCtors.addAll(newSort.ctors()); + } + } + return new CompactSort(newVars, newCtors); + } + + /** + * An error indicating that we could not compute a type meet or join. + * + * @param sorts - The set of sorts we are trying to meet/join. + * @param candidates - The set of minimal upper bounds / maximal lower bounds of sorts. + * @param polarity - True for positive, false for negative + */ + public record LatticeOpError(Set sorts, Set candidates, boolean polarity) {} + + /** + * Convert to an equivalent Sort, instantiating variables and actually computing a type join/meet + * as appropriate. + * + * @param polarity - The polarity where this CompactSort occurs. + * @param instantiation - A map indicating how variables should be instantiated + * @param subsorts - The Sort poset + * @return An equivalent Sort + */ + public Either asSort( + boolean polarity, Map instantiation, POSet subsorts) { + Set sorts = vars.stream().map(instantiation::get).collect(Collectors.toSet()); + sorts.addAll( + ctors.stream() + .map(h -> new org.kframework.kore.ADT.Sort(h.name(), Seq())) + .collect(Collectors.toSet())); + Set bounds = polarity ? subsorts.upperBounds(sorts) : subsorts.lowerBounds(sorts); + bounds.removeIf( + s -> + subsorts.lessThanEq(s, Sorts.KLabel()) + || subsorts.lessThanEq(s, Sorts.KBott()) + || subsorts.greaterThan(s, Sorts.K())); + Set candidates = polarity ? subsorts.minimal(bounds) : subsorts.maximal(bounds); + if (candidates.size() != 1) { + return Left.apply(new LatticeOpError(sorts, candidates, polarity)); + } + return Right.apply(candidates.iterator().next()); + } +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/InferenceDriver.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/InferenceDriver.java new file mode 100644 index 00000000000..e5ddd5517c9 --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/InferenceDriver.java @@ -0,0 +1,93 @@ +package org.kframework.parser.inner.disambiguation.inference; + +import static org.kframework.Collections.*; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import org.kframework.POSet; +import org.kframework.kore.Sort; +import org.kframework.parser.Constant; +import org.kframework.parser.ProductionReference; +import org.kframework.parser.Term; +import scala.Tuple2; + +public final class InferenceDriver { + private final POSet subsorts; + private final Map varSorts = new HashMap<>(); + private final Map paramSorts = new HashMap<>(); + private final Set> constraintCache = new HashSet<>(); + + public InferenceDriver(POSet subsorts) { + this.subsorts = subsorts; + } + + public BoundedSort varSort(Constant var) { + VariableId varId = VariableId.apply(var); + if (!varSorts.containsKey(varId)) { + varSorts.put(varId, new BoundedSort.Variable()); + } + return varSorts.get(varId); + } + + /** + * Convert a Sort to a BoundedSort + * + * @param sort - The Sort to convert + * @return A BoundedSort representing sort + */ + public BoundedSort sortToBoundedSort(Sort sort, ProductionReference prOrNull) { + if (prOrNull != null && prOrNull.production().params().contains(sort)) { + ParamId paramId = new ParamId(prOrNull, sort); + if (!paramSorts.containsKey(paramId)) { + paramSorts.put(paramId, new BoundedSort.Variable()); + } + return paramSorts.get(paramId); + } + return new BoundedSort.Constructor(sort.head()); + } + + public void constrain(BoundedSort lhs, BoundedSort rhs, ProductionReference pr) + throws ConstraintError { + if (lhs.equals(rhs) || constraintCache.contains(Tuple2.apply(lhs, rhs))) { + return; + } + + if (lhs instanceof BoundedSort.Variable lhsVar) { + constraintCache.add(Tuple2.apply(lhs, rhs)); + lhsVar.upperBounds().add(rhs); + for (BoundedSort lhsLower : lhsVar.lowerBounds()) { + constrain(lhsLower, rhs, pr); + } + return; + } + + if (rhs instanceof BoundedSort.Variable rhsVar) { + constraintCache.add(Tuple2.apply(lhs, rhs)); + rhsVar.lowerBounds().add(lhs); + for (BoundedSort rhsUpper : rhsVar.upperBounds()) { + constrain(lhs, rhsUpper, pr); + } + return; + } + + // If they are primitive sorts, we can check the sort poset directly + BoundedSort.Constructor lhsCtor = (BoundedSort.Constructor) lhs; + BoundedSort.Constructor rhsCtor = (BoundedSort.Constructor) rhs; + if (lhsCtor.head().params() == 0 && rhsCtor.head().params() == 0) { + Sort lhsSort = new org.kframework.kore.ADT.Sort(lhsCtor.head().name(), Seq()); + Sort rhsSort = new org.kframework.kore.ADT.Sort(rhsCtor.head().name(), Seq()); + if (subsorts.lessThanEq(lhsSort, rhsSort)) { + return; + } + throw new ConstraintError(lhsSort, rhsSort, pr); + } + + throw new AssertionError("Parametric sorts are not yet supported!"); + } + + public TermSort getResult(Term term, BoundedSort sort) { + return new TermSort<>(term, sort, varSorts); + } +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/ParamId.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/ParamId.java new file mode 100644 index 00000000000..5fa46ea54ad --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/ParamId.java @@ -0,0 +1,28 @@ +package org.kframework.parser.inner.disambiguation.inference; + +import java.util.Objects; +import org.kframework.kore.Sort; +import org.kframework.parser.ProductionReference; + +public final class ParamId { + private final ProductionReference pr; + private final Sort param; + + public ParamId(ProductionReference pr, Sort param) { + this.pr = pr; + this.param = param; + } + + @Override + public boolean equals(Object o) { + if (o instanceof ParamId p) { + return this.pr == p.pr && this.param.equals(p.param); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(System.identityHashCode(pr), param); + } +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferenceError.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferenceError.java new file mode 100644 index 00000000000..f114359b29f --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferenceError.java @@ -0,0 +1,50 @@ +package org.kframework.parser.inner.disambiguation.inference; + +import java.util.Optional; +import org.kframework.attributes.HasLocation; +import org.kframework.kore.Sort; +import org.kframework.parser.ProductionReference; +import org.kframework.utils.errorsystem.KEMException; + +abstract sealed class SortInferenceError extends Exception { + private final Optional loc; + + public SortInferenceError(String message, HasLocation loc) { + super(message); + this.loc = Optional.of(loc); + } + + public KEMException asInnerParseError(HasLocation defaultLoc) { + return KEMException.innerParserError(getMessage(), loc.orElse(defaultLoc)); + } +} + +final class LatticeOpError extends SortInferenceError { + public LatticeOpError(CompactSort.LatticeOpError err, HasLocation loc, Optional name) { + super( + "Sort" + + name.map(n -> " of " + n + " ").orElse(" ") + + "inferred as " + + (err.polarity() ? "least upper bound" : "greatest lower bound") + + " of " + + err.sorts() + + ", but " + + (err.candidates().isEmpty() + ? "no such bound exists." + : ("candidate bounds are " + "incomparable: " + err.candidates() + ".")), + loc); + } +} + +final class ConstraintError extends SortInferenceError { + public ConstraintError(Sort lhs, Sort rhs, ProductionReference pr) { + super( + "Unexpected sort " + + lhs + + " for term parsed as production " + + pr.production() + + ". Expected: " + + rhs, + pr); + } +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferencer.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferencer.java new file mode 100644 index 00000000000..b3730a4f34a --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortInferencer.java @@ -0,0 +1,535 @@ +// Copyright (c) K Team. All Rights Reserved. +package org.kframework.parser.inner.disambiguation.inference; + +import static org.kframework.Collections.*; +import static org.kframework.kore.KORE.*; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import org.kframework.attributes.Att; +import org.kframework.builtin.KLabels; +import org.kframework.builtin.Sorts; +import org.kframework.definition.Module; +import org.kframework.definition.NonTerminal; +import org.kframework.definition.Production; +import org.kframework.kore.KLabel; +import org.kframework.kore.Sort; +import org.kframework.kore.SortHead; +import org.kframework.parser.Ambiguity; +import org.kframework.parser.Constant; +import org.kframework.parser.ProductionReference; +import org.kframework.parser.Term; +import org.kframework.parser.TermCons; +import org.kframework.utils.errorsystem.KEMException; +import org.pcollections.ConsPStack; +import scala.Tuple2; +import scala.util.Either; +import scala.util.Left; +import scala.util.Right; + +/** + * Disambiguation transformer which performs type checking and infers the sorts of variables. + * + *

The overall design is heavily inspired by the algorithm described in "The Simple Essence of + * Algebraic Subtyping: Principal Type Inference with Subtyping Made Easy" by Lionel Parreaux. + * + *

Each Term can be viewed as a SimpleSub-esque term with equivalent subtyping constraints: + * + *

+ * + * Inferring the SimpleSub-esque type is then equivalent to performing SortInference. That is, we + * infer a type a1 -> ... aN -> b telling us that the variables x1, ..., xN have sorts a1, ..., aN + * and the overall Term has sort b. + * + *

Explicitly, the algorithm proceeds as follows + * + *

    + *
  1. Infer a BoundedSort for the input Term as well as all of its variables, recording each + * subtype constraint as lower and upper bounds on sort variables. + *
      + *
    • BoundedSort is directly analogous to SimpleType from SimpleSub, except that we only + * have primitive sorts (BoundedSort.Constructor) and variables (BoundedSort.Variable). + *
    • TermSort represents the "function type" of the overall Term + *
    + *
  2. Constrain the inferred BoundedSort of the overall Term as a subsort of the expected + * topSort. + *
  3. Compactify then simplify the TermSort to produce a CompactSort (analogous to producing the + * CompactType in SimpleSub). + *
  4. Convert the inferred CompactSort into a normal K Sort + *
      + *
    • Monomorphize each sort variable, allowing it to take any value between its recorded + * bounds and possibly producing multiple valid monomorphizations. + *
    • For each type intersection/union, actually compute the corresponding meet/join on the + * subsort poset, erroring if no such meet/join exists. + *
    + *
  5. Insert a SemanticCast around every variable in the Term to record the results. + *
+ */ +public class SortInferencer { + private final Module mod; + + public SortInferencer(Module mod) { + this.mod = mod; + } + + /** + * @param t - A Term + * @return Whether t is a Term which SortInferencer can currently handle. Supported terms can + * contain neither ambiguities, strict casts, nor parametric sorts. + */ + public static boolean isSupported(Term t) { + return !hasAmbiguity(t) && !hasStrictCast(t) && !hasParametricSorts(t); + } + + private static boolean hasAmbiguity(Term t) { + if (t instanceof Ambiguity) { + return true; + } + if (t instanceof Constant) { + return false; + } + return ((TermCons) t).items().stream().anyMatch(SortInferencer::hasAmbiguity); + } + + private static boolean hasStrictCast(Term t) { + if (t instanceof Ambiguity) { + return ((Ambiguity) t).items().stream().anyMatch(SortInferencer::hasStrictCast); + } + ProductionReference pr = (ProductionReference) t; + if (pr.production().klabel().isDefined()) { + KLabel klabel = pr.production().klabel().get(); + String label = klabel.name(); + if (label.equals("#SyntacticCast") || label.equals("#InnerCast")) { + return true; + } + } + if (t instanceof Constant) { + return false; + } + return ((TermCons) t).items().stream().anyMatch(SortInferencer::hasStrictCast); + } + + private static boolean hasParametricSorts(Term t) { + if (t instanceof Ambiguity) { + return ((Ambiguity) t).items().stream().anyMatch(SortInferencer::hasParametricSorts); + } + ProductionReference pr = (ProductionReference) t; + if (stream(pr.production().items()) + .filter(pi -> pi instanceof NonTerminal) + .map(pi -> ((NonTerminal) pi).sort()) + .anyMatch(s -> !s.params().isEmpty())) { + return true; + } + if (!pr.production().sort().params().isEmpty()) { + return true; + } + if (pr instanceof Constant) { + return false; + } + return ((TermCons) t).items().stream().anyMatch(SortInferencer::hasParametricSorts); + } + + /** + * Determine if a term is a rule which can be applied anywhere in a configuration, and thus does + * not permit the RHS sort to be wider than the LHS. + * + * @param t - The Term to inspect + * @param isAnywhere - Whether t was explicitly marked with an attribute such as anywhere, + * simplification, macro, etc. indicating that it is a rule which applies anywhere + * @return Whether t is a rule which applies anywhere + */ + private static boolean isAnywhereRule(Term t, boolean isAnywhere) { + if (t instanceof Ambiguity) { + throw new AssertionError("Ambiguities are not yet supported!"); + } + t = stripBrackets(t); + if (t instanceof Constant) { + return false; + } + TermCons tc = (TermCons) t; + // For every #RuleContent production, the first non-terminal holds a #RuleBody + if (tc.production().sort().equals(Sorts.RuleContent())) { + assert tc.production().nonterminals().size() >= 1 + && tc.production().nonterminal(0).sort().equals(Sorts.RuleBody()); + return isAnywhereRule(tc.get(0), isAnywhere); + } + // For every #RuleBody production, the first non-terminal holds the actual K term + if (tc.production().sort().equals(Sorts.RuleBody())) { + assert tc.production().nonterminals().size() >= 1 + && tc.production().nonterminal(0).sort().equals(Sorts.K()); + return isAnywhereRule(tc.get(0), isAnywhere); + } + // This is the first actual K term we encounter after stripping away rule syntax, + // and should be a rewrite if this is anywhere rule. + if (tc.production().klabel().filter(k -> k.head().equals(KLabels.KREWRITE)).isDefined()) { + Term lhs = stripBrackets(tc.get(0)); + if (lhs instanceof Ambiguity) { + throw new AssertionError("Ambiguities are not yet supported!"); + } + ProductionReference lhsPr = (ProductionReference) lhs; + return isAnywhere + || lhsPr.production().att().contains(Att.FUNCTION()) + || lhsPr.production().att().getMacro().isDefined(); + } + return false; + } + + /** + * The main entry point of SortInferencer, inferring the sort of the input's variables and + * recording the results by inserting casts. + * + * @param t - The Term to infer the sort of + * @param topSort - The expected sort of t + * @param isAnywhere - Whether t is a rule with an attribute indicating that the rule applies + * anywhere in a configuration (e.g. macro, simplification, anywhere, ...). + * @return If t is not well-sorted, then a set of errors. If t is well-sorted, then a new Term + * which is the same as t, but with each variable wrapped in a SemanticCast to its inferred + * type (returning an Ambiguity of all solutions when there are multiple possible sorts). + */ + public Either, Term> apply(Term t, Sort topSort, boolean isAnywhere) { + Set> monoRes; + try { + InferenceDriver driver = new InferenceDriver(mod.subsorts()); + BoundedSort itemSort = infer(t, isAnywhereRule(t, isAnywhere), driver); + BoundedSort topBoundedSort = driver.sortToBoundedSort(topSort, null); + driver.constrain(itemSort, topBoundedSort, (ProductionReference) t); + TermSort unsimplifiedRes = driver.getResult(t, topBoundedSort); + TermSort res = simplify(unsimplifiedRes.mapSorts(CompactSort::makeCompact)); + monoRes = monomorphize(res, t); + } catch (SortInferenceError e) { + Set errs = new HashSet<>(); + errs.add(e.asInnerParseError(t)); + return Left.apply(errs); + } + + Set items = new HashSet<>(); + for (TermSort mono : monoRes) { + items.add(insertCasts(t, mono, false)); + } + if (items.size() == 1) { + return Right.apply(items.iterator().next()); + } else { + return Right.apply(Ambiguity.apply(items)); + } + } + + /** + * @param t - The term we want to infer the type of + * @param isAnywhereRule - Whether t is a rule which can be applied anywhere in a configuration + * @param driver - All state maintained during inference, which will be updated throughout with + * sorts for all contained variables + * @return The unsimplified sort of the input term + * @throws SortInferenceError - an exception indicating that the term is not well-typed + */ + private BoundedSort infer(Term t, boolean isAnywhereRule, InferenceDriver driver) + throws SortInferenceError { + if (t instanceof Ambiguity) { + throw new AssertionError("Ambiguities are not yet supported!"); + } + + ProductionReference pr = (ProductionReference) t; + if (pr instanceof Constant c) { + if (c.production().sort().equals(Sorts.KVariable()) + || c.production().sort().equals(Sorts.KConfigVar())) { + return driver.varSort(c); + } + return driver.sortToBoundedSort(c.production().sort(), pr); + } + + TermCons tc = (TermCons) pr; + if (isAnywhereRule + && tc.production().klabel().filter(k -> k.head().equals(KLabels.KREWRITE)).isDefined()) { + BoundedSort lhsSort = infer(tc.get(0), false, driver); + // To prevent widening, we constrain RHS's inferred sort <: LHS's declared sort. + // + // Note that we do actually need the LHS's declared sort. The LHS's inferred sort + // is a variable X with a bound L <: X, and constraining against X would just add a + // new lower bound aka permit widening. + ProductionReference lhsDeclaredPr = (ProductionReference) stripBrackets(tc.get(0)); + BoundedSort lhsDeclaredSort = + driver.sortToBoundedSort(lhsDeclaredPr.production().sort(), lhsDeclaredPr); + BoundedSort rhsSort = infer(tc.get(1), false, driver); + driver.constrain(rhsSort, lhsDeclaredSort, (ProductionReference) tc.get(1)); + return lhsSort; + } + + for (int prodI = 0, tcI = 0; prodI < tc.production().items().size(); prodI++) { + if (!(tc.production().items().apply(prodI) instanceof NonTerminal nt)) { + continue; + } + BoundedSort expectedSort = driver.sortToBoundedSort(nt.sort(), pr); + BoundedSort childSort = infer(tc.get(tcI), isAnywhereRule, driver); + driver.constrain(childSort, expectedSort, pr); + tcI++; + } + BoundedSort resSort = new BoundedSort.Variable(); + driver.constrain(driver.sortToBoundedSort(tc.production().sort(), pr), resSort, pr); + return resSort; + } + + private TermSort simplify(TermSort res) throws SortInferenceError { + + Map, CompactSort> coOccurrences = + analyzeCoOccurrences(res, CoOccurMode.ALWAYS); + Map> varSubst = new HashMap<>(); + // Simplify away all those variables that only occur in negative (resp. positive) position. + Set allVars = + coOccurrences.keySet().stream().map(Tuple2::_1).collect(Collectors.toSet()); + allVars.forEach( + (v) -> { + boolean negative = coOccurrences.containsKey(Tuple2.apply(v, false)); + boolean positive = coOccurrences.containsKey(Tuple2.apply(v, true)); + if ((negative && !positive) || (!negative && positive)) { + varSubst.put(v, Optional.empty()); + } + }); + + List pols = new ArrayList<>(); + pols.add(false); + pols.add(true); + + for (Boolean pol : pols) { + for (SortVariable v : allVars) { + if (!coOccurrences.containsKey(Tuple2.apply(v, pol))) { + continue; + } + CompactSort vCoOccurs = coOccurrences.get(Tuple2.apply(v, pol)); + CompactSort vOpCoOccurs = coOccurrences.get(Tuple2.apply(v, !pol)); + for (SortVariable w : vCoOccurs.vars()) { + if (v.equals(w) || varSubst.containsKey(w)) { + continue; + } + if (coOccurrences.containsKey(Tuple2.apply(w, pol)) + && coOccurrences.get(Tuple2.apply(w, pol)).vars().contains(v)) { + // v and w always co-occur in the given polarity, so we unify w into v + varSubst.put(w, Optional.of(new CompactSort(v))); + // we also need to update v's co-occurrences correspondingly + // (intersecting with w's) + CompactSort wOpCoOccurs = coOccurrences.get(Tuple2.apply(w, !pol)); + vOpCoOccurs.vars().retainAll(wOpCoOccurs.vars()); + vOpCoOccurs.ctors().retainAll(wOpCoOccurs.ctors()); + vOpCoOccurs.vars().add(v); + } + } + for (SortHead ctor : vCoOccurs.ctors()) { + // This is not a variable, so check if we have a sandwich ctor <: v <: ctor + // and can thus simplify away v + if (coOccurrences.containsKey(Tuple2.apply(v, !pol)) + && coOccurrences.get(Tuple2.apply(v, !pol)).ctors().contains(ctor)) { + varSubst.put(v, Optional.empty()); + } + } + } + } + + return res.mapSorts((c, p) -> c.substitute(varSubst)); + } + + /** Modes for the co-occurrence analysis. */ + private enum CoOccurMode { + /** Record only those sorts which always co-occur with a given variable and polarity. */ + ALWAYS, + /** Record any sort that ever co-occurs with a given variable and polarity. */ + EVER + } + + private Map, CompactSort> analyzeCoOccurrences( + TermSort res, CoOccurMode mode) { + Map, CompactSort> coOccurrences = new HashMap<>(); + res.forEachSort((s, pol) -> updateCoOccurrences(s, pol, mode, coOccurrences)); + return coOccurrences; + } + + /** + * Update the co-occurrence analysis results so-far to account for the occurrences within sort + * + * @param sort - The sort which we are processing + * @param polarity - The polarity of the provided sort + * @param coOccurrences - mutated to record all co-occurrences in each variable occurring in sort + */ + private void updateCoOccurrences( + CompactSort sort, + boolean polarity, + CoOccurMode mode, + Map, CompactSort> coOccurrences) { + for (SortVariable var : sort.vars()) { + Tuple2 polVar = Tuple2.apply(var, polarity); + if (coOccurrences.containsKey(polVar)) { + CompactSort coOccurs = coOccurrences.get(polVar); + switch (mode) { + case ALWAYS -> { + coOccurs.vars().retainAll(sort.vars()); + coOccurs.ctors().retainAll(sort.ctors()); + } + case EVER -> { + coOccurs.vars().addAll(sort.vars()); + coOccurs.ctors().addAll(sort.ctors()); + } + } + } else { + coOccurrences.put( + polVar, new CompactSort(new HashSet<>(sort.vars()), new HashSet<>(sort.ctors()))); + } + } + } + + /** + * @param res - The result to monomorphize + * @param t - The term whose inference result is res. Only used for error reporting + * @return A set of all possible monomorphizations of the input result + * @throws SortInferenceError - An error if there are no monomorphizations which can actually be + * produced from the subsort lattice. + */ + private Set> monomorphize(TermSort res, Term t) + throws SortInferenceError { + Map, CompactSort> bounds = + analyzeCoOccurrences(res, CoOccurMode.EVER); + Set allVars = + bounds.keySet().stream().map(Tuple2::_1).collect(Collectors.toSet()); + Set> instantiations = new HashSet<>(); + instantiations.add(new HashMap<>()); + for (SortVariable var : allVars) { + Set> newInstantiations = new HashSet<>(); + for (Map instant : instantiations) { + newInstantiations.addAll(monomorphizeInVar(instant, var, bounds)); + } + if (newInstantiations.isEmpty()) { + throw new AssertionError(); + } + instantiations = newInstantiations; + } + + Set> monos = new HashSet<>(); + SortInferenceError lastError = null; + for (Map inst : instantiations) { + Either> monoRes = realizeTermSort(res, inst, t); + if (monoRes.isLeft()) { + lastError = monoRes.left().get(); + } else { + monos.add(monoRes.right().get()); + } + } + if (monos.isEmpty()) { + assert lastError != null; + throw lastError; + } + return monos; + } + + private Set> monomorphizeInVar( + Map instantiation, + SortVariable var, + Map, CompactSort> bounds) { + + Map> polBounds = new HashMap<>(); + polBounds.put(true, new HashSet<>()); + polBounds.put(false, new HashSet<>()); + + for (Entry> polBound : polBounds.entrySet()) { + Tuple2 polVar = Tuple2.apply(var, polBound.getKey()); + if (!bounds.containsKey(polVar)) { + continue; + } + CompactSort bound = bounds.get(polVar); + for (SortVariable bVar : bound.vars()) { + if (instantiation.containsKey(bVar)) { + polBound.getValue().add(instantiation.get(bVar)); + } + } + for (SortHead bCtor : bound.ctors()) { + polBound.getValue().add(new org.kframework.kore.ADT.Sort(bCtor.name(), Seq())); + } + } + + Set range = mod.subsorts().upperBounds(polBounds.get(true)); + range.retainAll(mod.subsorts().lowerBounds(polBounds.get(false))); + + Set> insts = new HashSet<>(); + for (Sort sort : range) { + Map inst = new HashMap<>(instantiation); + inst.put(var, sort); + insts.add(inst); + } + return insts; + } + + private Either> realizeTermSort( + TermSort res, Map instantiation, Term t) { + Either sortRes = + res.sort().asSort(true, instantiation, mod.subsorts()); + if (sortRes.isLeft()) { + return Left.apply(new LatticeOpError(sortRes.left().get(), t, Optional.empty())); + } + Sort sort = sortRes.right().get(); + Map varSorts = new HashMap<>(); + for (Entry entry : res.varSorts().entrySet()) { + Either varRes = + entry.getValue().asSort(false, instantiation, mod.subsorts()); + if (varRes.isLeft()) { + CompactSort.LatticeOpError latticeErr = varRes.left().get(); + if (entry.getKey() instanceof VariableId.Anon anon) { + return Left.apply( + new LatticeOpError(latticeErr, anon.constant(), Optional.of("variable"))); + } + if (entry.getKey() instanceof VariableId.Named named) { + return Left.apply( + new LatticeOpError(latticeErr, t, Optional.of("variable " + named.name()))); + } + throw new AssertionError("VariableId should be either Anon or Named"); + } + varSorts.put(entry.getKey(), varRes.right().get()); + } + return Right.apply(new TermSort<>(t, sort, varSorts)); + } + + private Term insertCasts(Term t, TermSort sorts, boolean existingCast) { + if (t instanceof Ambiguity) { + throw new AssertionError("Ambiguities are not yet supported!"); + } + + ProductionReference pr = (ProductionReference) t; + if (pr instanceof Constant c) { + if (c.production().sort().equals(Sorts.KVariable()) + || c.production().sort().equals(Sorts.KConfigVar())) { + Sort inferred = sorts.varSorts().get(VariableId.apply(c)); + if (!existingCast) { + Production cast = + mod.productionsFor().apply(KLabel("#SemanticCastTo" + inferred.toString())).head(); + return TermCons.apply(ConsPStack.singleton(t), cast, t.location(), t.source()); + } + } + return c; + } + + TermCons tc = (TermCons) pr; + boolean isCast = + tc.production().klabel().filter(k -> k.name().startsWith("#SemanticCastTo")).isDefined(); + for (int i = 0; i < tc.items().size(); i++) { + tc = tc.with(i, insertCasts(tc.get(i), sorts, isCast)); + } + return tc; + } + + private static Term stripBrackets(Term tc) { + Term child = tc; + while (child instanceof TermCons + && ((TermCons) child).production().att().contains(Att.BRACKET())) { + child = ((TermCons) child).get(0); + } + return child; + } +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortVariable.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortVariable.java new file mode 100644 index 00000000000..61d523aecc7 --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/SortVariable.java @@ -0,0 +1,5 @@ +package org.kframework.parser.inner.disambiguation.inference; + +public class SortVariable { + public SortVariable() {} +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/TermSort.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/TermSort.java new file mode 100644 index 00000000000..393216835f2 --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/TermSort.java @@ -0,0 +1,22 @@ +package org.kframework.parser.inner.disambiguation.inference; + +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.stream.Collectors; +import org.kframework.parser.Term; + +public record TermSort(Term term, S sort, Map varSorts) { + public TermSort mapSorts(BiFunction func) { + T newSort = func.apply(sort, true); + Map newVarSorts = + varSorts().entrySet().stream() + .collect(Collectors.toMap((Map.Entry::getKey), (e) -> func.apply(e.getValue(), false))); + return new TermSort<>(term, newSort, newVarSorts); + } + + public void forEachSort(BiConsumer action) { + action.accept(sort, true); + varSorts().values().forEach((v) -> action.accept(v, false)); + } +} diff --git a/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/VariableId.java b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/VariableId.java new file mode 100644 index 00000000000..e7d89eb3760 --- /dev/null +++ b/kernel/src/main/java/org/kframework/parser/inner/disambiguation/inference/VariableId.java @@ -0,0 +1,42 @@ +package org.kframework.parser.inner.disambiguation.inference; + +import org.kframework.attributes.Att; +import org.kframework.compile.ResolveAnonVar; +import org.kframework.kore.ADT.KVariable; +import org.kframework.parser.Constant; + +public sealed interface VariableId { + static VariableId apply(Constant var) { + if (ResolveAnonVar.isAnonVarOrNamedAnonVar(new KVariable(var.value(), Att.empty()))) { + return new Anon(var); + } + return new Named(var.value()); + } + + record Named(String name) implements VariableId {} + + final class Anon implements VariableId { + private final Constant constant; + + public Anon(Constant constant) { + this.constant = constant; + } + + public Constant constant() { + return constant; + } + + @Override + public boolean equals(Object o) { + if (o instanceof Anon a) { + return this.constant == a.constant; + } + return false; + } + + @Override + public int hashCode() { + return System.identityHashCode(constant); + } + } +} diff --git a/kore/src/main/java/org/kframework/utils/errorsystem/KEMException.java b/kore/src/main/java/org/kframework/utils/errorsystem/KEMException.java index fcb85e7cd25..745aed6fb64 100644 --- a/kore/src/main/java/org/kframework/utils/errorsystem/KEMException.java +++ b/kore/src/main/java/org/kframework/utils/errorsystem/KEMException.java @@ -5,7 +5,6 @@ import org.kframework.attributes.HasLocation; import org.kframework.attributes.Location; import org.kframework.attributes.Source; -import org.kframework.parser.Term; import org.kframework.utils.errorsystem.KException.ExceptionType; import org.kframework.utils.errorsystem.KException.KExceptionGroup; @@ -138,14 +137,14 @@ public static KEMException innerParserError(String message, Source source, Locat ExceptionType.ERROR, KExceptionGroup.INNER_PARSER, message, null, location, source); } - public static KEMException innerParserError(String message, Term t) { + public static KEMException innerParserError(String message, HasLocation node) { return create( ExceptionType.ERROR, KExceptionGroup.INNER_PARSER, message, null, - t.location().orElse(null), - t.source().orElse(null)); + node.location().orElse(null), + node.source().orElse(null)); } public static KEMException innerParserError( diff --git a/kore/src/main/scala/org/kframework/parser/InferenceSorts.scala b/kore/src/main/scala/org/kframework/parser/InferenceSorts.scala deleted file mode 100644 index 60f394403ab..00000000000 --- a/kore/src/main/scala/org/kframework/parser/InferenceSorts.scala +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) K Team. All Rights Reserved. -package org.kframework.parser - -import org.kframework.kore.Sort -import org.kframework.kore.SortHead - -sealed abstract class BoundedSort - -object BoundedSort { - final case class Constructor(head: SortHead) extends BoundedSort - - final class Variable(val lowerBounds: java.util.List[BoundedSort], - val upperBounds: java.util.List[BoundedSort]) extends BoundedSort -} - -final case class CompactSort(vars: java.util.Set[BoundedSort.Variable], - ctors: java.util.Set[SortHead]) - -final case class InferenceState(varSorts: java.util.Map[VariableId, BoundedSort], - params: java.util.Map[(ProductionReference, Sort), BoundedSort.Variable], - constraintCache: java.util.Set[(BoundedSort, BoundedSort)]) - -final case class InferenceResult[T](sort: T, - varSorts: java.util.Map[VariableId, T]) - - -sealed abstract class VariableId -object VariableId { - final case class Named(name: String) extends VariableId - final case class Anon(constant: Constant, id: Integer) extends VariableId -}