Skip to content

Commit

Permalink
Rename FreeTypeVar to ExistentialTypeVar and free_vars to unsolved_vars
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Jan 9, 2024
1 parent 3eedefb commit 586f3ce
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 40 deletions.
32 changes: 16 additions & 16 deletions guppy/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)
from guppy.gtypes import (
BoolType,
FreeTypeVar,
ExistentialTypeVar,
FunctionType,
GuppyType,
Inst,
Expand Down Expand Up @@ -119,7 +119,7 @@ def check(
resolved type variables.
"""
# When checking against a variable, we have to synthesize
if isinstance(ty, FreeTypeVar):
if isinstance(ty, ExistentialTypeVar):
expr, syn_ty = self._synthesize(expr, allow_free_vars=False)
return with_type(syn_ty, expr), {ty: syn_ty}

Expand Down Expand Up @@ -199,7 +199,7 @@ def synthesize(
if ty := get_type_opt(node):
return node, ty
node, ty = self.visit(node)
if ty.free_vars and not allow_free_vars:
if ty.unsolved_vars and not allow_free_vars:
raise GuppyTypeError(
f"Cannot infer type variable in expression of type `{ty}`", node
)
Expand Down Expand Up @@ -355,7 +355,7 @@ def check_type_against(
be quantified and the actual type may not contain free unification variables.
"""
assert not isinstance(exp, FunctionType) or not exp.quantified
assert not act.free_vars
assert not act.unsolved_vars

# The actual type may be quantified. In that case, we have to find an instantiation
# to avoid higher-rank types.
Expand All @@ -374,23 +374,23 @@ def check_type_against(
"rank polymorphic types are not supported)",
node,
)
if subst[v].free_vars:
if subst[v].unsolved_vars:
raise GuppyTypeError(
f"Expected {kind} of type `{exp}`, got `{act}`. Can't instantiate "
f"type variable `{act.quantified[i]}` with type `{subst[v]}` "
"containing free variables",
node,
)
inst = [subst[v] for v in free_vars]
subst = {v: t for v, t in subst.items() if v in exp.free_vars}
subst = {v: t for v, t in subst.items() if v in exp.unsolved_vars}

# Finally, check that the instantiation respects the linearity requirements
check_inst(act, inst, node)

return subst, inst

# Otherwise, we know that `act` has no free type vars, so unification is trivial
assert not act.free_vars
# Otherwise, we know that `act` has no unsolved type vars, so unification is trivial
assert not act.unsolved_vars
subst = unify(exp, act, {})
if subst is None:
raise GuppyTypeError(f"Expected {kind} of type `{exp}`, got `{act}`", node)
Expand Down Expand Up @@ -434,10 +434,10 @@ def type_check_args(

# If the argument check succeeded, this means that we must have found instantiations
# for all unification variables occurring in the argument types
assert all(set.issubset(arg.free_vars, subst.keys()) for arg in func_ty.args)
assert all(set.issubset(arg.unsolved_vars, subst.keys()) for arg in func_ty.args)

# We also have to check that we found instantiations for all vars in the return type
if not set.issubset(func_ty.returns.free_vars, subst.keys()):
if not set.issubset(func_ty.returns.unsolved_vars, subst.keys()):
raise GuppyTypeInferenceError(
f"Cannot infer type variable in expression of type "
f"`{func_ty.returns.substitute(subst)}`",
Expand All @@ -455,7 +455,7 @@ def synthesize_call(
Returns an annotated argument list, the synthesized return type, and an
instantiation for the quantifiers in the function type.
"""
assert not func_ty.free_vars
assert not func_ty.unsolved_vars
check_num_args(len(func_ty.args), len(args), node)

# Replace quantified variables with free unification variables and try to infer an
Expand All @@ -464,7 +464,7 @@ def synthesize_call(
args, subst = type_check_args(args, unquantified, {}, ctx, node)

# Success implies that the substitution is closed
assert all(not t.free_vars for t in subst.values())
assert all(not t.unsolved_vars for t in subst.values())
inst = [subst[v] for v in free_vars]

# Finally, check that the instantiation respects the linearity requirements
Expand All @@ -486,7 +486,7 @@ def check_call(
Returns an annotated argument list, a substitution for the free variables in the
expected type, and an instantiation for the quantifiers in the function type.
"""
assert not func_ty.free_vars
assert not func_ty.unsolved_vars
check_num_args(len(func_ty.args), len(args), node)

# When checking, we can use the information from the expected return type to infer
Expand Down Expand Up @@ -536,17 +536,17 @@ def check_call(

# Also make sure we found an instantiation for all free vars in the type we're
# checking against
if not set.issubset(ty.free_vars, subst.keys()):
if not set.issubset(ty.unsolved_vars, subst.keys()):
raise GuppyTypeInferenceError(
f"Expected expression of type `{ty}`, got "
f"`{func_ty.returns.substitute(subst)}`. Couldn't infer type variables",
node,
)

# Success implies that the substitution is closed
assert all(not t.free_vars for t in subst.values())
assert all(not t.unsolved_vars for t in subst.values())
inst = [subst[v] for v in free_vars]
subst = {v: t for v, t in subst.items() if v in ty.free_vars}
subst = {v: t for v, t in subst.items() if v in ty.unsolved_vars}

# Finally, check that the instantiation respects the linearity requirements
check_inst(func_ty, inst, node)
Expand Down
4 changes: 2 additions & 2 deletions guppy/checker/stmt_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class StmtChecker(AstVisitor[BBStatement]):
return_ty: GuppyType

def __init__(self, ctx: Context, bb: BB, return_ty: GuppyType) -> None:
assert not return_ty.free_vars
assert not return_ty.unsolved_vars
self.ctx = ctx
self.bb = bb
self.return_ty = return_ty
Expand Down Expand Up @@ -93,7 +93,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
)
ty = type_from_ast(node.annotation, self.ctx.globals)
node.value, subst = self._check_expr(node.value, ty)
assert not ty.free_vars # `ty` must be closed!
assert not ty.unsolved_vars # `ty` must be closed!
assert len(subst) == 0
self._check_assign(node.target, ty, node)
return node
Expand Down
4 changes: 2 additions & 2 deletions guppy/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, TypeVar, cast

from guppy.ast_util import AstNode, get_file, get_line_offset, get_source
from guppy.gtypes import BoundTypeVar, FreeTypeVar, FunctionType, GuppyType
from guppy.gtypes import BoundTypeVar, ExistentialTypeVar, FunctionType, GuppyType
from guppy.hugr.hugr import Node, OutPortV

# Whether the interpreter should exit when a Guppy error occurs
Expand Down Expand Up @@ -128,7 +128,7 @@ def quantified(self) -> Sequence[BoundTypeVar]:
raise InternalGuppyError("Tried to access unknown function type")

@property
def free_vars(self) -> set[FreeTypeVar]:
def unsolved_vars(self) -> set[ExistentialTypeVar]:
return set()


Expand Down
42 changes: 22 additions & 20 deletions guppy/gtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from guppy.checker.core import Globals


Subst = dict["FreeTypeVar", "GuppyType"]
Subst = dict["ExistentialTypeVar", "GuppyType"]
Inst = Sequence["GuppyType"]


Expand All @@ -30,7 +30,7 @@ class GuppyType(ABC):
name: ClassVar[str]

# Cache for free variables
_free_vars: set["FreeTypeVar"] = field(init=False, repr=False)
_unsolved_vars: set["ExistentialTypeVar"] = field(init=False, repr=False)

def __post_init__(self) -> None:
# Make sure that we don't have higher-rank polymorphic types
Expand All @@ -43,12 +43,12 @@ def __post_init__(self) -> None:
)

# Compute free variables
if isinstance(self, FreeTypeVar):
if isinstance(self, ExistentialTypeVar):
vs = {self}
else:
vs = set()
for arg in self.type_args:
vs |= arg.free_vars
vs |= arg.unsolved_vars
object.__setattr__(self, "_free_vars", vs)

@staticmethod
Expand All @@ -75,8 +75,8 @@ def transform(self, transformer: "TypeTransformer") -> "GuppyType":
pass

@property
def free_vars(self) -> set["FreeTypeVar"]:
return self._free_vars
def unsolved_vars(self) -> set["ExistentialTypeVar"]:
return self._unsolved_vars

def substitute(self, s: Subst) -> "GuppyType":
return self.transform(Substituter(s))
Expand Down Expand Up @@ -111,23 +111,23 @@ def to_hugr(self) -> tys.SimpleType:


@dataclass(frozen=True)
class FreeTypeVar(GuppyType):
"""Free type variable, identified with a globally unique id.
class ExistentialTypeVar(GuppyType):
"""Existential type variable, identified with a globally unique id.
Serves as an existential variable for unification.
Is solved during type checking.
"""

id: int
display_name: str
linear: bool = False

name: ClassVar[Literal["FreeTypeVar"]] = "FreeTypeVar"
name: ClassVar[Literal["ExistentialTypeVar"]] = "ExistentialTypeVar"

_id_generator: ClassVar[Iterator[int]] = itertools.count()

@classmethod
def new(cls, display_name: str, linear: bool) -> "FreeTypeVar":
return FreeTypeVar(next(cls._id_generator), display_name, linear)
def new(cls, display_name: str, linear: bool) -> "ExistentialTypeVar":
return ExistentialTypeVar(next(cls._id_generator), display_name, linear)

@staticmethod
def build(*rgs: GuppyType, node: AstNode | None = None) -> GuppyType:
Expand Down Expand Up @@ -228,9 +228,11 @@ def instantiate(self, tys: Sequence[GuppyType]) -> "FunctionType":
self.arg_names,
)

def unquantified(self) -> tuple["FunctionType", Sequence[FreeTypeVar]]:
def unquantified(self) -> tuple["FunctionType", Sequence[ExistentialTypeVar]]:
"""Replaces all quantified variables with free type variables."""
inst = [FreeTypeVar.new(v.display_name, v.linear) for v in self.quantified]
inst = [
ExistentialTypeVar.new(v.display_name, v.linear) for v in self.quantified
]
return self.instantiate(inst), inst


Expand Down Expand Up @@ -393,7 +395,7 @@ def __init__(self, subst: Subst) -> None:
self.subst = subst

def transform(self, ty: GuppyType) -> GuppyType | None:
if isinstance(ty, FreeTypeVar):
if isinstance(ty, ExistentialTypeVar):
return self.subst.get(ty, None)
return None

Expand Down Expand Up @@ -427,9 +429,9 @@ def unify(s: GuppyType, t: GuppyType, subst: Subst | None) -> Subst | None:
return None
if s == t:
return subst
if isinstance(s, FreeTypeVar):
if isinstance(s, ExistentialTypeVar):
return _unify_var(s, t, subst)
if isinstance(t, FreeTypeVar):
if isinstance(t, ExistentialTypeVar):
return _unify_var(t, s, subst)
if type(s) == type(t):
sargs, targs = list(s.type_args), list(t.type_args)
Expand All @@ -440,13 +442,13 @@ def unify(s: GuppyType, t: GuppyType, subst: Subst | None) -> Subst | None:
return None


def _unify_var(var: FreeTypeVar, t: GuppyType, subst: Subst) -> Subst | None:
def _unify_var(var: ExistentialTypeVar, t: GuppyType, subst: Subst) -> Subst | None:
"""Helper function for unification of type variables."""
if var in subst:
return unify(subst[var], t, subst)
if isinstance(t, FreeTypeVar) and t in subst:
if isinstance(t, ExistentialTypeVar) and t in subst:
return unify(var, subst[t], subst)
if var in t.free_vars:
if var in t.unsolved_vars:
return None
return {var: t, **subst}

Expand Down

0 comments on commit 586f3ce

Please sign in to comment.