Skip to content

Commit

Permalink
feat: New type representation with parameters (#174)
Browse files Browse the repository at this point in the history
* Prepare support for types and functions that are generic over constant
values (e.g. bounded nats):
- Generic types and functions are now defined in terms of `Parameter`s
and `Argument`s that can be either types or constants (see
`tys/param.py` and `tys/arg.py`).
- Implementations for `ConstParam` and `ConstArg` will follow in the
future
* Improved pretty printing of types (see `tys/printing.py`)
* Add a notion of `TypeDefinition` (see `tys/definition.py`) that
replaces the ad-hoc creation of Python classes to define new types
* `BoolType` is no longer a `SumType`. This was a hugr implementation
detail and not relevant for the Guppy type system.

Drive by renaming:
* Rename function type members from `args/return` to `inputs/output`
since `args` is already used now
* Rename `GuppyType` to `Type`
  • Loading branch information
mark-koch authored Mar 19, 2024
1 parent 043b2db commit 73e29f2
Show file tree
Hide file tree
Showing 38 changed files with 1,863 additions and 1,108 deletions.
12 changes: 6 additions & 6 deletions guppylang/ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast

if TYPE_CHECKING:
from guppylang.gtypes import GuppyType
from guppylang.tys.ty import Type

AstNode = (
ast.AST
Expand Down Expand Up @@ -286,24 +286,24 @@ def with_loc(loc: ast.AST, node: A) -> A:
return node


def with_type(ty: "GuppyType", node: A) -> A:
def with_type(ty: "Type", node: A) -> A:
"""Annotates an AST node with a type."""
node.type = ty # type: ignore[attr-defined]
return node


def get_type_opt(node: AstNode) -> Optional["GuppyType"]:
def get_type_opt(node: AstNode) -> Optional["Type"]:
"""Tries to retrieve a type annotation from an AST node."""
from guppylang.gtypes import GuppyType
from guppylang.tys.ty import Type, TypeBase

try:
ty = node.type # type: ignore[union-attr]
return ty if isinstance(ty, GuppyType) else None
return cast(Type, ty) if isinstance(ty, TypeBase) else None
except AttributeError:
return None


def get_type(node: AstNode) -> "GuppyType":
def get_type(node: AstNode) -> "Type":
"""Retrieve a type annotation from an AST node.
Fails if the node is not annotated.
Expand Down
4 changes: 2 additions & 2 deletions guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from guppylang.cfg.cfg import CFG
from guppylang.checker.core import Globals
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.gtypes import NoneType
from guppylang.nodes import (
DesugaredGenerator,
DesugaredListComp,
Expand All @@ -26,6 +25,7 @@
NestedFunctionDef,
PyExpr,
)
from guppylang.tys.ty import NoneType

# In order to build expressions, need an endless stream of unique temporary variables
# to store intermediate results
Expand Down Expand Up @@ -213,7 +213,7 @@ def visit_FunctionDef(
from guppylang.checker.func_checker import check_signature

func_ty = check_signature(node, self.globals)
returns_none = isinstance(func_ty.returns, NoneType)
returns_none = isinstance(func_ty.output, NoneType)
cfg = CFGBuilder().build(node.body, returns_none, self.globals)

new_node = NestedFunctionDef(
Expand Down
12 changes: 6 additions & 6 deletions guppylang/checker/cfg_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from guppylang.checker.expr_checker import ExprSynthesizer, to_bool
from guppylang.checker.stmt_checker import StmtChecker
from guppylang.error import GuppyError
from guppylang.gtypes import GuppyType
from guppylang.tys.ty import Type

VarRow = Sequence[Variable]

Expand Down Expand Up @@ -44,17 +44,17 @@ class CheckedBB(BB):


class CheckedCFG(BaseCFG[CheckedBB]):
input_tys: list[GuppyType]
output_ty: GuppyType
input_tys: list[Type]
output_ty: Type

def __init__(self, input_tys: list[GuppyType], output_ty: GuppyType) -> None:
def __init__(self, input_tys: list[Type], output_ty: Type) -> None:
super().__init__([])
self.input_tys = input_tys
self.output_ty = output_ty


def check_cfg(
cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals
cfg: CFG, inputs: VarRow, return_ty: Type, globals: Globals
) -> CheckedCFG:
"""Type checks a control-flow graph.
Expand Down Expand Up @@ -121,7 +121,7 @@ def check_bb(
bb: BB,
checked_cfg: CheckedCFG,
inputs: VarRow,
return_ty: GuppyType,
return_ty: Type,
globals: Globals,
) -> CheckedBB:
cfg = bb.containing_cfg
Expand Down
80 changes: 53 additions & 27 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,29 @@
from dataclasses import dataclass
from typing import Any, NamedTuple

from typing_extensions import assert_never

from guppylang.ast_util import AstNode, name_nodes_in_ast
from guppylang.gtypes import (
BoolType,
from guppylang.tys.definition import (
TypeDef,
bool_type_def,
callable_type_def,
linst_type_def,
list_type_def,
none_type_def,
tuple_type_def,
)
from guppylang.tys.param import Parameter
from guppylang.tys.subst import Subst
from guppylang.tys.ty import (
BoundTypeVar,
ExistentialTypeVar,
FunctionType,
GuppyType,
LinstType,
ListType,
NoneType,
Subst,
OpaqueType,
SumType,
TupleType,
Type,
)


Expand All @@ -25,7 +37,7 @@ class Variable:
"""Class holding data associated with a variable."""

name: str
ty: GuppyType
ty: Type
defined_at: AstNode | None
used: AstNode | None

Expand All @@ -38,14 +50,14 @@ class CallableVariable(ABC, Variable):

@abstractmethod
def check_call(
self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context"
self, args: list[ast.expr], ty: Type, node: AstNode, ctx: "Context"
) -> tuple[ast.expr, Subst]:
"""Checks the return type of a function call against a given type."""

@abstractmethod
def synthesize_call(
self, args: list[ast.expr], node: AstNode, ctx: "Context"
) -> tuple[ast.expr, GuppyType]:
) -> tuple[ast.expr, Type]:
"""Synthesizes the return type of a function call."""


Expand All @@ -68,30 +80,44 @@ class Globals(NamedTuple):
"""

values: dict[str, Variable]
types: dict[str, type[GuppyType]]
type_vars: dict[str, TypeVarDecl]
type_defs: dict[str, TypeDef]
param_vars: dict[str, Parameter]
python_scope: PyScope

@staticmethod
def default() -> "Globals":
"""Generates a `Globals` instance that is populated with all core types"""
tys: dict[str, type[GuppyType]] = {
FunctionType.name: FunctionType,
TupleType.name: TupleType,
SumType.name: SumType,
NoneType.name: NoneType,
BoolType.name: BoolType,
ListType.name: ListType,
LinstType.name: LinstType,
type_defs = {
"Callable": callable_type_def,
"tuple": tuple_type_def,
"None": none_type_def,
"bool": bool_type_def,
"list": list_type_def,
"linst": linst_type_def,
}
return Globals({}, tys, {}, {})
return Globals({}, type_defs, {}, {})

def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None:
def get_instance_func(self, ty: Type, name: str) -> CallableVariable | None:
"""Looks up an instance function with a given name for a type.
Returns `None` if the name doesn't exist or isn't a function.
"""
qualname = qualified_name(ty.__class__, name)
defn: TypeDef
match ty:
case BoundTypeVar() | ExistentialTypeVar() | SumType():
return None
case FunctionType():
defn = callable_type_def
case OpaqueType() as ty:
defn = ty.defn
case TupleType():
defn = tuple_type_def
case NoneType():
defn = none_type_def
case _:
assert_never(ty)

qualname = qualified_name(defn.name, name)
if qualname in self.values:
val = self.values[qualname]
if isinstance(val, CallableVariable):
Expand All @@ -101,15 +127,15 @@ def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None
def __or__(self, other: "Globals") -> "Globals":
return Globals(
self.values | other.values,
self.types | other.types,
self.type_vars | other.type_vars,
self.type_defs | other.type_defs,
self.param_vars | other.param_vars,
self.python_scope | other.python_scope,
)

def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034
self.values.update(other.values)
self.types.update(other.types)
self.type_vars.update(other.type_vars)
self.type_defs.update(other.type_defs)
self.param_vars.update(other.param_vars)
return self


Expand Down Expand Up @@ -203,7 +229,7 @@ def __contains__(self, key: object) -> bool:
return super().__contains__(key)


def qualified_name(ty: type[GuppyType] | str, name: str) -> str:
def qualified_name(ty: TypeDef | str, name: str) -> str:
"""Returns a qualified name for an instance function on a type."""
ty_name = ty if isinstance(ty, str) else ty.name
return f"{ty_name}.{name}"
Loading

0 comments on commit 73e29f2

Please sign in to comment.