Skip to content

Commit

Permalink
wip: Add list types and nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Dec 20, 2023
1 parent 325dd25 commit 7b34e89
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 9 deletions.
112 changes: 105 additions & 7 deletions guppy/gtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def linear(self) -> bool:
def to_hugr(self) -> tys.SimpleType:
pass

def hugr_bound(self) -> tys.TypeBound:
if self.linear:
return tys.TypeBound.Any
return tys.TypeBound.join(*(ty.hugr_bound() for ty in self.type_args))

@abstractmethod
def transform(self, transformer: "TypeTransformer") -> "GuppyType":
pass
Expand Down Expand Up @@ -100,14 +105,19 @@ def build(*rgs: GuppyType, node: AstNode | None = None) -> GuppyType:
def type_args(self) -> Iterator["GuppyType"]:
return iter(())

def hugr_bound(self) -> tys.TypeBound:
# We shouldn't make variables equatable, since we also want to substitute types
# like `float`
return tys.TypeBound.Any if self.linear else tys.TypeBound.Copyable

def transform(self, transformer: "TypeTransformer") -> GuppyType:
return transformer.transform(self) or self

def __str__(self) -> str:
return self.display_name

def to_hugr(self) -> tys.SimpleType:
return tys.Variable(i=self.idx, b=tys.TypeBound.from_linear(self.linear))
return tys.Variable(i=self.idx, b=self.hugr_bound())


@dataclass(frozen=True)
Expand Down Expand Up @@ -194,13 +204,14 @@ def to_hugr(self) -> tys.PolyFuncType:
outs = [t.to_hugr() for t in type_to_row(self.returns)]
func_ty = tys.FunctionType(input=ins, output=outs, extension_reqs=[])
return tys.PolyFuncType(
params=[
tys.TypeParam(b=tys.TypeBound.from_linear(v.linear))
for v in self.quantified
],
params=[tys.TypeParam(b=v.hugr_bound()) for v in self.quantified],
body=func_ty,
)

def hugr_bound(self) -> tys.TypeBound:
# Functions are not equatable, only copyable
return tys.TypeBound.Copyable

def transform(self, transformer: "TypeTransformer") -> GuppyType:
return transformer.transform(self) or FunctionType(
[ty.transform(transformer) for ty in self.args],
Expand Down Expand Up @@ -312,6 +323,90 @@ def transform(self, transformer: "TypeTransformer") -> GuppyType:
)


@dataclass(frozen=True)
class ListType(GuppyType):
element_type: GuppyType

name: ClassVar[Literal["list"]] = "list"
linear: bool = field(default=False, init=False)

@staticmethod
def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType:
from guppy.error import GuppyError

if len(args) == 0:
raise GuppyError("Missing type parameter for generic type `list`", node)
if len(args) > 1:
raise GuppyError("Too many type arguments for generic type `list`", node)
(arg,) = args
if arg.linear:
raise GuppyError(
"Type `list` cannot store linear data, use `linst` instead", node
)
return ListType(arg)

def __str__(self) -> str:
return f"list[{self.element_type}]"

@property
def type_args(self) -> Iterator[GuppyType]:
return iter((self.element_type,))

def to_hugr(self) -> tys.SimpleType:
return tys.Opaque(
extension="Collections",
id="List",
args=[tys.TypeArg(ty=self.element_type.to_hugr())],
bound=self.hugr_bound(),
)

def transform(self, transformer: "TypeTransformer") -> GuppyType:
return transformer.transform(self) or ListType(
self.element_type.transform(transformer)
)


@dataclass(frozen=True)
class LinstType(GuppyType):
element_type: GuppyType

name: ClassVar[Literal["linst"]] = "linst"

@staticmethod
def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType:
from guppy.error import GuppyError

if len(args) == 0:
raise GuppyError("Missing type parameter for generic type `linst`", node)
if len(args) > 1:
raise GuppyError("Too many type arguments for generic type `linst`", node)
return LinstType(args[0])

def __str__(self) -> str:
return f"linst[{self.element_type}]"

@property
def linear(self) -> bool:
return self.element_type.linear

@property
def type_args(self) -> Iterator[GuppyType]:
return iter((self.element_type,))

def to_hugr(self) -> tys.SimpleType:
return tys.Opaque(
extension="Collections",
id="List",
args=[tys.TypeArg(ty=self.element_type.to_hugr())],
bound=self.hugr_bound(),
)

def transform(self, transformer: "TypeTransformer") -> GuppyType:
return transformer.transform(self) or LinstType(
self.element_type.transform(transformer)
)


@dataclass(frozen=True)
class NoneType(GuppyType):
name: ClassVar[Literal["None"]] = "None"
Expand Down Expand Up @@ -482,8 +577,11 @@ def type_from_ast(
return NoneType()
if isinstance(v, str):
try:
return type_from_ast(ast.parse(v), globals, type_var_mapping)
except SyntaxError:
[stmt] = ast.parse(v).body
if not isinstance(stmt, ast.Expr):
raise GuppyError("Invalid Guppy type", node)
return type_from_ast(stmt.value, globals, type_var_mapping)
except (SyntaxError, ValueError):
raise GuppyError("Invalid Guppy type", node) from None
raise GuppyError(f"Constant `{v}` is not a valid type", node)

Expand Down
91 changes: 89 additions & 2 deletions guppy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from guppy.checker.core import CallableVariable, Variable


class LocalName(ast.expr):
class LocalName(ast.Name):
id: str

_fields = ("id",)


class GlobalName(ast.expr):
class GlobalName(ast.Name):
id: str
value: "Variable"

Expand Down Expand Up @@ -60,6 +60,93 @@ class TypeApply(ast.expr):
)


class MakeIter(ast.expr):
"""Creates an iterator using the `__iter__` magic method.
This node is inserted in `for` loops and list comprehensions.
"""

value: ast.expr

# Node that triggered the creation of this iterator. For example, a for loop stmt.
# It is not mentioned in `_fields` so that it is not visible to AST visitors
origin_node: ast.AST

_fields = ("value",)


class IterHasNext(ast.expr):
"""Checks if an iterator has a next element using the `__hasnext__` magic method.
This node is inserted in `for` loops and list comprehensions.
"""

value: ast.expr

_fields = ("value",)


class IterNext(ast.expr):
"""Obtains the next element of an iterator using the `__next__` magic method.
This node is inserted in `for` loops and list comprehensions.
"""

value: ast.expr

_fields = ("value",)


class IterEnd(ast.expr):
"""Finalises an iterator using the `__end__` magic method.
This node is inserted in `for` loops and list comprehensions. It is needed to
consume linear iterators once they are finished.
"""

value: ast.expr

_fields = ("value",)


class DesugaredGenerator(ast.expr):
"""A single desugared generator in a list comprehension.
Stores assignments of the original generator targets as well as dummy variables for
the iterator and hasnext test.
"""

iter_assign: ast.Assign
hasnext_assign: ast.Assign
next_assign: ast.Assign
iterend: ast.expr
iter: ast.Name
hasnext: ast.Name
ifs: list[ast.expr]

_fields = (
"iter_assign",
"hasnext_assign",
"next_assign",
"iterend",
"iter",
"hasnext",
"ifs",
)


class DesugaredListComp(ast.expr):
"""A desugared list comprehension."""

elt: ast.expr
generators: list[DesugaredGenerator]

_fields = (
"elt",
"generators",
)


class NestedFunctionDef(ast.FunctionDef):
cfg: "CFG"
ty: FunctionType
Expand Down

0 comments on commit 7b34e89

Please sign in to comment.