From 324f2ee0cc844a507a6bed6232251933a2984c3c Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Mon, 13 May 2024 15:28:53 +0100 Subject: [PATCH] chore: Pass globals to check_instantiate (#206) --- guppylang/checker/expr_checker.py | 8 +++++--- guppylang/definition/ty.py | 8 ++++++-- guppylang/prelude/_internal.py | 4 +++- guppylang/tys/builtin.py | 15 +++++++++------ guppylang/tys/parsing.py | 4 ++-- 5 files changed, 25 insertions(+), 14 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 8c98aa04..f11950b0 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -953,9 +953,9 @@ def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type case bool(): return bool_type() case int(): - return cast(TypeDef, globals["int"]).check_instantiate([]) + return cast(TypeDef, globals["int"]).check_instantiate([], globals) case float(): - return cast(TypeDef, globals["float"]).check_instantiate([]) + return cast(TypeDef, globals["float"]).check_instantiate([], globals) case tuple(elts): tys = [python_value_to_guppy_type(elt, node, globals) for elt in elts] if any(ty is None for ty in tys): @@ -973,7 +973,9 @@ def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type try: import tket2 # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401 - qubit = cast(TypeDef, globals["qubit"]).check_instantiate([]) + qubit = cast(TypeDef, globals["qubit"]).check_instantiate( + [], globals + ) return FunctionType( [qubit] * v.n_qubits, row_to_type( diff --git a/guppylang/definition/ty.py b/guppylang/definition/ty.py index 0e8f23db..36fc4c35 100644 --- a/guppylang/definition/ty.py +++ b/guppylang/definition/ty.py @@ -1,6 +1,7 @@ from abc import abstractmethod from collections.abc import Callable, Sequence from dataclasses import dataclass, field +from typing import TYPE_CHECKING from guppylang.ast_util import AstNode from guppylang.definition.common import CompiledDef, Definition @@ -11,6 +12,9 @@ from guppylang.tys.param import Parameter from guppylang.tys.ty import OpaqueType +if TYPE_CHECKING: + from guppylang.checker.core import Globals + @dataclass(frozen=True) class TypeDef(Definition): @@ -20,7 +24,7 @@ class TypeDef(Definition): @abstractmethod def check_instantiate( - self, args: Sequence[Argument], loc: AstNode | None = None + self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None ) -> Type: """Checks if the type definition can be instantiated with the given arguments. @@ -39,7 +43,7 @@ class OpaqueTypeDef(TypeDef, CompiledDef): bound: tys.TypeBound | None = None def check_instantiate( - self, args: Sequence[Argument], loc: AstNode | None = None + self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None ) -> OpaqueType: """Checks if the type definition can be instantiated with the given arguments. diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index db0e0aee..f4bc1ee2 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -119,7 +119,9 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: ) float_defn = self.ctx.globals["float"] assert isinstance(float_defn, TypeDef) - args[i] = with_type(float_defn.check_instantiate([]), call) + args[i] = with_type( + float_defn.check_instantiate([], self.ctx.globals), call + ) return super().synthesize(args) diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index ecd4ddf6..54461471 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from dataclasses import dataclass, field -from typing import Literal +from typing import TYPE_CHECKING, Literal from guppylang.ast_util import AstNode from guppylang.definition.common import DefId @@ -11,6 +11,9 @@ from guppylang.tys.param import TypeParam from guppylang.tys.ty import FunctionType, NoneType, OpaqueType, TupleType, Type +if TYPE_CHECKING: + from guppylang.checker.core import Globals + @dataclass(frozen=True) class _CallableTypeDef(TypeDef): @@ -22,7 +25,7 @@ class _CallableTypeDef(TypeDef): name: Literal["Callable"] = field(default="Callable", init=False) def check_instantiate( - self, args: Sequence[Argument], loc: AstNode | None = None + self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None ) -> FunctionType: # We get the inputs/output as a flattened list: `args = [*inputs, output]`. if not args: @@ -46,7 +49,7 @@ class _TupleTypeDef(TypeDef): name: Literal["tuple"] = field(default="tuple", init=False) def check_instantiate( - self, args: Sequence[Argument], loc: AstNode | None = None + self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None ) -> TupleType: # We accept any number of arguments. If users just write `tuple`, we give them # the empty tuple type. We just have to make sure that the args are of kind type @@ -68,7 +71,7 @@ class _NoneTypeDef(TypeDef): name: Literal["None"] = field(default="None", init=False) def check_instantiate( - self, args: Sequence[Argument], loc: AstNode | None = None + self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None ) -> NoneType: if args: raise GuppyError("Type `None` is not parameterized", loc) @@ -84,7 +87,7 @@ class _ListTypeDef(OpaqueTypeDef): """ def check_instantiate( - self, args: Sequence[Argument], loc: AstNode | None = None + self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None ) -> OpaqueType: if len(args) == 1: [arg] = args @@ -92,7 +95,7 @@ def check_instantiate( raise GuppyError( "Type `list` cannot store linear data, use `linst` instead", loc ) - return super().check_instantiate(args, loc) + return super().check_instantiate(args, globals, loc) def _list_to_hugr(args: Sequence[Argument]) -> tys.Opaque: diff --git a/guppylang/tys/parsing.py b/guppylang/tys/parsing.py index 1eb6d7e2..4124d6b9 100644 --- a/guppylang/tys/parsing.py +++ b/guppylang/tys/parsing.py @@ -26,7 +26,7 @@ def arg_from_ast( match globals[x]: # Either a defined type (e.g. `int`, `bool`, ...) case TypeDef() as defn: - return TypeArg(defn.check_instantiate([], node)) + return TypeArg(defn.check_instantiate([], globals, node)) # Or a parameter (e.g. `T`, `n`, ...) case ParamDef() as defn: if param_var_mapping is None: @@ -65,7 +65,7 @@ def arg_from_ast( arg_from_ast(arg_node, globals, param_var_mapping) for arg_node in arg_nodes ] - ty = defn.check_instantiate(args, node) + ty = defn.check_instantiate(args, globals, node) return TypeArg(ty) # We don't allow parametrised variables like `T[int]` if isinstance(defn, ParamDef):