From 876a5dec34f75f7321cdf9ddc6da240475fad061 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 9 Jan 2024 12:06:13 +0100 Subject: [PATCH] Rename FreeTypeVar to ExistentialTypeVar and free_vars to unsolved_vars --- guppy/checker/expr_checker.py | 32 ++++++++++++------------- guppy/checker/stmt_checker.py | 4 ++-- guppy/error.py | 4 ++-- guppy/gtypes.py | 44 ++++++++++++++++++----------------- 4 files changed, 43 insertions(+), 41 deletions(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index d3f9756c..9e799399 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -34,7 +34,7 @@ ) from guppy.gtypes import ( BoolType, - FreeTypeVar, + ExistentialTypeVar, FunctionType, GuppyType, Inst, @@ -119,7 +119,7 @@ def check( resolved type variables. """ # When checking against a variable, we have to synthesize - if isinstance(ty, FreeTypeVar): + if isinstance(ty, ExistentialTypeVar): expr, syn_ty = self._synthesize(expr, allow_free_vars=False) return with_type(syn_ty, expr), {ty: syn_ty} @@ -199,7 +199,7 @@ def synthesize( if ty := get_type_opt(node): return node, ty node, ty = self.visit(node) - if ty.free_vars and not allow_free_vars: + if ty.unsolved_vars and not allow_free_vars: raise GuppyTypeError( f"Cannot infer type variable in expression of type `{ty}`", node ) @@ -355,7 +355,7 @@ def check_type_against( be quantified and the actual type may not contain free unification variables. """ assert not isinstance(exp, FunctionType) or not exp.quantified - assert not act.free_vars + assert not act.unsolved_vars # The actual type may be quantified. In that case, we have to find an instantiation # to avoid higher-rank types. @@ -374,7 +374,7 @@ def check_type_against( "rank polymorphic types are not supported)", node, ) - if subst[v].free_vars: + if subst[v].unsolved_vars: raise GuppyTypeError( f"Expected {kind} of type `{exp}`, got `{act}`. Can't instantiate " f"type variable `{act.quantified[i]}` with type `{subst[v]}` " @@ -382,15 +382,15 @@ def check_type_against( node, ) inst = [subst[v] for v in free_vars] - subst = {v: t for v, t in subst.items() if v in exp.free_vars} + subst = {v: t for v, t in subst.items() if v in exp.unsolved_vars} # Finally, check that the instantiation respects the linearity requirements check_inst(act, inst, node) return subst, inst - # Otherwise, we know that `act` has no free type vars, so unification is trivial - assert not act.free_vars + # Otherwise, we know that `act` has no unsolved type vars, so unification is trivial + assert not act.unsolved_vars subst = unify(exp, act, {}) if subst is None: raise GuppyTypeError(f"Expected {kind} of type `{exp}`, got `{act}`", node) @@ -434,10 +434,10 @@ def type_check_args( # If the argument check succeeded, this means that we must have found instantiations # for all unification variables occurring in the argument types - assert all(set.issubset(arg.free_vars, subst.keys()) for arg in func_ty.args) + assert all(set.issubset(arg.unsolved_vars, subst.keys()) for arg in func_ty.args) # We also have to check that we found instantiations for all vars in the return type - if not set.issubset(func_ty.returns.free_vars, subst.keys()): + if not set.issubset(func_ty.returns.unsolved_vars, subst.keys()): raise GuppyTypeInferenceError( f"Cannot infer type variable in expression of type " f"`{func_ty.returns.substitute(subst)}`", @@ -455,7 +455,7 @@ def synthesize_call( Returns an annotated argument list, the synthesized return type, and an instantiation for the quantifiers in the function type. """ - assert not func_ty.free_vars + assert not func_ty.unsolved_vars check_num_args(len(func_ty.args), len(args), node) # Replace quantified variables with free unification variables and try to infer an @@ -464,7 +464,7 @@ def synthesize_call( args, subst = type_check_args(args, unquantified, {}, ctx, node) # Success implies that the substitution is closed - assert all(not t.free_vars for t in subst.values()) + assert all(not t.unsolved_vars for t in subst.values()) inst = [subst[v] for v in free_vars] # Finally, check that the instantiation respects the linearity requirements @@ -486,7 +486,7 @@ def check_call( Returns an annotated argument list, a substitution for the free variables in the expected type, and an instantiation for the quantifiers in the function type. """ - assert not func_ty.free_vars + assert not func_ty.unsolved_vars check_num_args(len(func_ty.args), len(args), node) # When checking, we can use the information from the expected return type to infer @@ -536,7 +536,7 @@ def check_call( # Also make sure we found an instantiation for all free vars in the type we're # checking against - if not set.issubset(ty.free_vars, subst.keys()): + if not set.issubset(ty.unsolved_vars, subst.keys()): raise GuppyTypeInferenceError( f"Expected expression of type `{ty}`, got " f"`{func_ty.returns.substitute(subst)}`. Couldn't infer type variables", @@ -544,9 +544,9 @@ def check_call( ) # Success implies that the substitution is closed - assert all(not t.free_vars for t in subst.values()) + assert all(not t.unsolved_vars for t in subst.values()) inst = [subst[v] for v in free_vars] - subst = {v: t for v, t in subst.items() if v in ty.free_vars} + subst = {v: t for v, t in subst.items() if v in ty.unsolved_vars} # Finally, check that the instantiation respects the linearity requirements check_inst(func_ty, inst, node) diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 3959801e..0fa46608 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -26,7 +26,7 @@ class StmtChecker(AstVisitor[BBStatement]): return_ty: GuppyType def __init__(self, ctx: Context, bb: BB, return_ty: GuppyType) -> None: - assert not return_ty.free_vars + assert not return_ty.unsolved_vars self.ctx = ctx self.bb = bb self.return_ty = return_ty @@ -93,7 +93,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: ) ty = type_from_ast(node.annotation, self.ctx.globals) node.value, subst = self._check_expr(node.value, ty) - assert not ty.free_vars # `ty` must be closed! + assert not ty.unsolved_vars # `ty` must be closed! assert len(subst) == 0 self._check_assign(node.target, ty, node) return node diff --git a/guppy/error.py b/guppy/error.py index 5c434901..14bb5813 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -7,7 +7,7 @@ from typing import Any, TypeVar, cast from guppy.ast_util import AstNode, get_file, get_line_offset, get_source -from guppy.gtypes import BoundTypeVar, FreeTypeVar, FunctionType, GuppyType +from guppy.gtypes import BoundTypeVar, ExistentialTypeVar, FunctionType, GuppyType from guppy.hugr.hugr import Node, OutPortV # Whether the interpreter should exit when a Guppy error occurs @@ -128,7 +128,7 @@ def quantified(self) -> Sequence[BoundTypeVar]: raise InternalGuppyError("Tried to access unknown function type") @property - def free_vars(self) -> set[FreeTypeVar]: + def unsolved_vars(self) -> set[ExistentialTypeVar]: return set() diff --git a/guppy/gtypes.py b/guppy/gtypes.py index 344c8855..9046e2cb 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -16,7 +16,7 @@ from guppy.checker.core import Globals -Subst = dict["FreeTypeVar", "GuppyType"] +Subst = dict["ExistentialTypeVar", "GuppyType"] Inst = Sequence["GuppyType"] @@ -30,7 +30,7 @@ class GuppyType(ABC): name: ClassVar[str] # Cache for free variables - _free_vars: set["FreeTypeVar"] = field(init=False, repr=False) + _unsolved_vars: set["ExistentialTypeVar"] = field(init=False, repr=False) def __post_init__(self) -> None: # Make sure that we don't have higher-rank polymorphic types @@ -43,13 +43,13 @@ def __post_init__(self) -> None: ) # Compute free variables - if isinstance(self, FreeTypeVar): + if isinstance(self, ExistentialTypeVar): vs = {self} else: vs = set() for arg in self.type_args: - vs |= arg.free_vars - object.__setattr__(self, "_free_vars", vs) + vs |= arg.unsolved_vars + object.__setattr__(self, "_unsolved_vars", vs) @staticmethod @abstractmethod @@ -75,8 +75,8 @@ def transform(self, transformer: "TypeTransformer") -> "GuppyType": pass @property - def free_vars(self) -> set["FreeTypeVar"]: - return self._free_vars + def unsolved_vars(self) -> set["ExistentialTypeVar"]: + return self._unsolved_vars def substitute(self, s: Subst) -> "GuppyType": return self.transform(Substituter(s)) @@ -111,23 +111,23 @@ def to_hugr(self) -> tys.SimpleType: @dataclass(frozen=True) -class FreeTypeVar(GuppyType): - """Free type variable, identified with a globally unique id. +class ExistentialTypeVar(GuppyType): + """Existential type variable, identified with a globally unique id. - Serves as an existential variable for unification. + Is solved during type checking. """ id: int display_name: str linear: bool = False - name: ClassVar[Literal["FreeTypeVar"]] = "FreeTypeVar" + name: ClassVar[Literal["ExistentialTypeVar"]] = "ExistentialTypeVar" _id_generator: ClassVar[Iterator[int]] = itertools.count() @classmethod - def new(cls, display_name: str, linear: bool) -> "FreeTypeVar": - return FreeTypeVar(next(cls._id_generator), display_name, linear) + def new(cls, display_name: str, linear: bool) -> "ExistentialTypeVar": + return ExistentialTypeVar(next(cls._id_generator), display_name, linear) @staticmethod def build(*rgs: GuppyType, node: AstNode | None = None) -> GuppyType: @@ -228,9 +228,11 @@ def instantiate(self, tys: Sequence[GuppyType]) -> "FunctionType": self.arg_names, ) - def unquantified(self) -> tuple["FunctionType", Sequence[FreeTypeVar]]: + def unquantified(self) -> tuple["FunctionType", Sequence[ExistentialTypeVar]]: """Replaces all quantified variables with free type variables.""" - inst = [FreeTypeVar.new(v.display_name, v.linear) for v in self.quantified] + inst = [ + ExistentialTypeVar.new(v.display_name, v.linear) for v in self.quantified + ] return self.instantiate(inst), inst @@ -393,7 +395,7 @@ def __init__(self, subst: Subst) -> None: self.subst = subst def transform(self, ty: GuppyType) -> GuppyType | None: - if isinstance(ty, FreeTypeVar): + if isinstance(ty, ExistentialTypeVar): return self.subst.get(ty, None) return None @@ -427,9 +429,9 @@ def unify(s: GuppyType, t: GuppyType, subst: Subst | None) -> Subst | None: return None if s == t: return subst - if isinstance(s, FreeTypeVar): + if isinstance(s, ExistentialTypeVar): return _unify_var(s, t, subst) - if isinstance(t, FreeTypeVar): + if isinstance(t, ExistentialTypeVar): return _unify_var(t, s, subst) if type(s) == type(t): sargs, targs = list(s.type_args), list(t.type_args) @@ -440,13 +442,13 @@ def unify(s: GuppyType, t: GuppyType, subst: Subst | None) -> Subst | None: return None -def _unify_var(var: FreeTypeVar, t: GuppyType, subst: Subst) -> Subst | None: +def _unify_var(var: ExistentialTypeVar, t: GuppyType, subst: Subst) -> Subst | None: """Helper function for unification of type variables.""" if var in subst: return unify(subst[var], t, subst) - if isinstance(t, FreeTypeVar) and t in subst: + if isinstance(t, ExistentialTypeVar) and t in subst: return unify(var, subst[t], subst) - if var in t.free_vars: + if var in t.unsolved_vars: return None return {var: t, **subst}