From 7b34e89519b07bbaa660fe7cdc731c931065946e Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 20 Dec 2023 13:05:06 +0000 Subject: [PATCH] wip: Add list types and nodes --- guppy/gtypes.py | 112 +++++++++++++++++++++++++++++++++++++++++++++--- guppy/nodes.py | 91 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 194 insertions(+), 9 deletions(-) diff --git a/guppy/gtypes.py b/guppy/gtypes.py index 344c8855..37c84e83 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -70,6 +70,11 @@ def linear(self) -> bool: def to_hugr(self) -> tys.SimpleType: pass + def hugr_bound(self) -> tys.TypeBound: + if self.linear: + return tys.TypeBound.Any + return tys.TypeBound.join(*(ty.hugr_bound() for ty in self.type_args)) + @abstractmethod def transform(self, transformer: "TypeTransformer") -> "GuppyType": pass @@ -100,6 +105,11 @@ def build(*rgs: GuppyType, node: AstNode | None = None) -> GuppyType: def type_args(self) -> Iterator["GuppyType"]: return iter(()) + def hugr_bound(self) -> tys.TypeBound: + # We shouldn't make variables equatable, since we also want to substitute types + # like `float` + return tys.TypeBound.Any if self.linear else tys.TypeBound.Copyable + def transform(self, transformer: "TypeTransformer") -> GuppyType: return transformer.transform(self) or self @@ -107,7 +117,7 @@ def __str__(self) -> str: return self.display_name def to_hugr(self) -> tys.SimpleType: - return tys.Variable(i=self.idx, b=tys.TypeBound.from_linear(self.linear)) + return tys.Variable(i=self.idx, b=self.hugr_bound()) @dataclass(frozen=True) @@ -194,13 +204,14 @@ def to_hugr(self) -> tys.PolyFuncType: outs = [t.to_hugr() for t in type_to_row(self.returns)] func_ty = tys.FunctionType(input=ins, output=outs, extension_reqs=[]) return tys.PolyFuncType( - params=[ - tys.TypeParam(b=tys.TypeBound.from_linear(v.linear)) - for v in self.quantified - ], + params=[tys.TypeParam(b=v.hugr_bound()) for v in self.quantified], body=func_ty, ) + def hugr_bound(self) -> tys.TypeBound: + # Functions are not equatable, only copyable + return tys.TypeBound.Copyable + def transform(self, transformer: "TypeTransformer") -> GuppyType: return transformer.transform(self) or FunctionType( [ty.transform(transformer) for ty in self.args], @@ -312,6 +323,90 @@ def transform(self, transformer: "TypeTransformer") -> GuppyType: ) +@dataclass(frozen=True) +class ListType(GuppyType): + element_type: GuppyType + + name: ClassVar[Literal["list"]] = "list" + linear: bool = field(default=False, init=False) + + @staticmethod + def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: + from guppy.error import GuppyError + + if len(args) == 0: + raise GuppyError("Missing type parameter for generic type `list`", node) + if len(args) > 1: + raise GuppyError("Too many type arguments for generic type `list`", node) + (arg,) = args + if arg.linear: + raise GuppyError( + "Type `list` cannot store linear data, use `linst` instead", node + ) + return ListType(arg) + + def __str__(self) -> str: + return f"list[{self.element_type}]" + + @property + def type_args(self) -> Iterator[GuppyType]: + return iter((self.element_type,)) + + def to_hugr(self) -> tys.SimpleType: + return tys.Opaque( + extension="Collections", + id="List", + args=[tys.TypeArg(ty=self.element_type.to_hugr())], + bound=self.hugr_bound(), + ) + + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or ListType( + self.element_type.transform(transformer) + ) + + +@dataclass(frozen=True) +class LinstType(GuppyType): + element_type: GuppyType + + name: ClassVar[Literal["linst"]] = "linst" + + @staticmethod + def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: + from guppy.error import GuppyError + + if len(args) == 0: + raise GuppyError("Missing type parameter for generic type `linst`", node) + if len(args) > 1: + raise GuppyError("Too many type arguments for generic type `linst`", node) + return LinstType(args[0]) + + def __str__(self) -> str: + return f"linst[{self.element_type}]" + + @property + def linear(self) -> bool: + return self.element_type.linear + + @property + def type_args(self) -> Iterator[GuppyType]: + return iter((self.element_type,)) + + def to_hugr(self) -> tys.SimpleType: + return tys.Opaque( + extension="Collections", + id="List", + args=[tys.TypeArg(ty=self.element_type.to_hugr())], + bound=self.hugr_bound(), + ) + + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or LinstType( + self.element_type.transform(transformer) + ) + + @dataclass(frozen=True) class NoneType(GuppyType): name: ClassVar[Literal["None"]] = "None" @@ -482,8 +577,11 @@ def type_from_ast( return NoneType() if isinstance(v, str): try: - return type_from_ast(ast.parse(v), globals, type_var_mapping) - except SyntaxError: + [stmt] = ast.parse(v).body + if not isinstance(stmt, ast.Expr): + raise GuppyError("Invalid Guppy type", node) + return type_from_ast(stmt.value, globals, type_var_mapping) + except (SyntaxError, ValueError): raise GuppyError("Invalid Guppy type", node) from None raise GuppyError(f"Constant `{v}` is not a valid type", node) diff --git a/guppy/nodes.py b/guppy/nodes.py index a08d30da..db62b86d 100644 --- a/guppy/nodes.py +++ b/guppy/nodes.py @@ -12,13 +12,13 @@ from guppy.checker.core import CallableVariable, Variable -class LocalName(ast.expr): +class LocalName(ast.Name): id: str _fields = ("id",) -class GlobalName(ast.expr): +class GlobalName(ast.Name): id: str value: "Variable" @@ -60,6 +60,93 @@ class TypeApply(ast.expr): ) +class MakeIter(ast.expr): + """Creates an iterator using the `__iter__` magic method. + + This node is inserted in `for` loops and list comprehensions. + """ + + value: ast.expr + + # Node that triggered the creation of this iterator. For example, a for loop stmt. + # It is not mentioned in `_fields` so that it is not visible to AST visitors + origin_node: ast.AST + + _fields = ("value",) + + +class IterHasNext(ast.expr): + """Checks if an iterator has a next element using the `__hasnext__` magic method. + + This node is inserted in `for` loops and list comprehensions. + """ + + value: ast.expr + + _fields = ("value",) + + +class IterNext(ast.expr): + """Obtains the next element of an iterator using the `__next__` magic method. + + This node is inserted in `for` loops and list comprehensions. + """ + + value: ast.expr + + _fields = ("value",) + + +class IterEnd(ast.expr): + """Finalises an iterator using the `__end__` magic method. + + This node is inserted in `for` loops and list comprehensions. It is needed to + consume linear iterators once they are finished. + """ + + value: ast.expr + + _fields = ("value",) + + +class DesugaredGenerator(ast.expr): + """A single desugared generator in a list comprehension. + + Stores assignments of the original generator targets as well as dummy variables for + the iterator and hasnext test. + """ + + iter_assign: ast.Assign + hasnext_assign: ast.Assign + next_assign: ast.Assign + iterend: ast.expr + iter: ast.Name + hasnext: ast.Name + ifs: list[ast.expr] + + _fields = ( + "iter_assign", + "hasnext_assign", + "next_assign", + "iterend", + "iter", + "hasnext", + "ifs", + ) + + +class DesugaredListComp(ast.expr): + """A desugared list comprehension.""" + + elt: ast.expr + generators: list[DesugaredGenerator] + + _fields = ( + "elt", + "generators", + ) + + class NestedFunctionDef(ast.FunctionDef): cfg: "CFG" ty: FunctionType