-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Upgrade Hugr and start using the shared Pydantic model #201
Changes from 3 commits
5c47f40
07d38b2
db8f3ad
11a1619
9177358
e58ac85
99525e3
d0481a4
d270f87
34818c0
9812c32
648027b
5e1cac2
17cdd37
d1720af
11f18c3
5f9dab3
e8166b2
99d77c4
fbf1571
39b4380
e4d570b
95c1095
37c5293
7083706
7b534eb
0dc5484
553d11f
7ff1734
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,9 @@ | |
import json | ||
from collections.abc import Iterator | ||
from contextlib import contextmanager | ||
from typing import Any | ||
from typing import Any, TypeGuard, TypeVar | ||
|
||
from hugr.serialization import ops | ||
|
||
from guppylang.ast_util import AstVisitor, get_type, with_loc, with_type | ||
from guppylang.cfg.builder import tmp_vars | ||
|
@@ -13,8 +15,13 @@ | |
) | ||
from guppylang.definition.value import CompiledCallableDef, CompiledValueDef | ||
from guppylang.error import GuppyError, InternalGuppyError | ||
from guppylang.hugr import ops, val | ||
from guppylang.hugr.hugr import DFContainingNode, OutPortV, VNode | ||
from guppylang.hugr_builder.hugr import ( | ||
UNDEFINED, | ||
DFContainingNode, | ||
DummyOp, | ||
OutPortV, | ||
VNode, | ||
) | ||
from guppylang.nodes import ( | ||
DesugaredGenerator, | ||
DesugaredListComp, | ||
|
@@ -148,6 +155,12 @@ def visit_LocalName(self, node: LocalName) -> OutPortV: | |
def visit_GlobalName(self, node: GlobalName) -> OutPortV: | ||
defn = self.globals[node.def_id] | ||
assert isinstance(defn, CompiledValueDef) | ||
if isinstance(defn, CompiledCallableDef) and defn.ty.parametrized: | ||
raise GuppyError( | ||
"Usage of polymorphic functions as dynamic higher-order values is not " | ||
"supported yet", | ||
node, | ||
) | ||
return defn.load(self.dfg, self.graph, self.globals, node) | ||
|
||
def visit_Name(self, node: ast.Name) -> OutPortV: | ||
|
@@ -161,7 +174,7 @@ def visit_Tuple(self, node: ast.Tuple) -> OutPortV: | |
def visit_List(self, node: ast.List) -> OutPortV: | ||
# Note that this is a list literal (i.e. `[e1, e2, ...]`), not a comprehension | ||
return self.graph.add_node( | ||
ops.DummyOp(name="MakeList"), inputs=[self.visit(e) for e in node.elts] | ||
DummyOp("MakeList"), inputs=[self.visit(e) for e in node.elts] | ||
).add_out_port(get_type(node)) | ||
|
||
def _pack_returns(self, returns: list[OutPortV]) -> OutPortV: | ||
|
@@ -193,9 +206,11 @@ def visit_Call(self, node: ast.Call) -> OutPortV: | |
raise InternalGuppyError("Node should have been removed during type checking.") | ||
|
||
def visit_TypeApply(self, node: TypeApply) -> OutPortV: | ||
func = self.visit(node.value) | ||
assert isinstance(func.ty, FunctionType) | ||
ta = self.graph.add_type_apply(func, node.inst, self.dfg.node).out_port(0) | ||
# For now, we can only TypeApply global FunctionDefs/Decls. | ||
if not isinstance(node.value, GlobalName): | ||
raise InternalGuppyError("Dynamic TypeApply not supported yet!") | ||
defn = self.globals[node.value.def_id] | ||
assert isinstance(defn, CompiledCallableDef) | ||
|
||
# We have to be very careful here: If we instantiate `foo: forall T. T -> T` | ||
# with a tuple type `tuple[A, B]`, we get the type `tuple[A, B] -> tuple[A, B]`. | ||
|
@@ -204,22 +219,25 @@ def visit_TypeApply(self, node: TypeApply) -> OutPortV: | |
# function with a single output port typed `tuple[A, B]`. | ||
# TODO: We would need to do manual monomorphisation in that case to obtain a | ||
# function that returns two ports as expected | ||
if instantiation_needs_unpacking(func.ty, node.inst): | ||
if instantiation_needs_unpacking(defn.ty, node.inst): | ||
raise GuppyError( | ||
"Generic function instantiations returning rows are not supported yet", | ||
node, | ||
) | ||
|
||
return ta | ||
return defn.load_with_args(node.inst, self.dfg, self.graph, self.globals, node) | ||
|
||
def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: | ||
# The only case that is not desugared by the type checker is the `not` operation | ||
# since it is not implemented via a dunder method | ||
if isinstance(node.op, ast.Not): | ||
arg = self.visit(node.operand) | ||
return self.graph.add_node( | ||
ops.CustomOp(extension="logic", op_name="Not", args=[]), inputs=[arg] | ||
).add_out_port(bool_type()) | ||
op = ops.CustomOp( | ||
extension="logic", op_name="Not", args=[], parent=UNDEFINED | ||
) | ||
return self.graph.add_node(ops.OpType(op), inputs=[arg]).add_out_port( | ||
bool_type() | ||
) | ||
|
||
raise InternalGuppyError("Node should have been removed during type checking.") | ||
|
||
|
@@ -231,7 +249,7 @@ def visit_DesugaredListComp(self, node: DesugaredListComp) -> OutPortV: | |
# Make up a name for the list under construction and bind it to an empty list | ||
list_ty = get_type(node) | ||
list_name = with_type(list_ty, with_loc(node, LocalName(id=next(tmp_vars)))) | ||
empty_list = self.graph.add_node(ops.DummyOp(name="MakeList")) | ||
empty_list = self.graph.add_node(DummyOp("MakeList")) | ||
self.dfg[list_name.id] = PortVariable( | ||
list_name.id, empty_list.add_out_port(list_ty), node, None | ||
) | ||
|
@@ -242,7 +260,7 @@ def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None: | |
if not gens: | ||
list_port, elt_port = self.visit(list_name), self.visit(elt) | ||
push = self.graph.add_node( | ||
ops.DummyOp(name="Push"), inputs=[list_port, elt_port] | ||
DummyOp("Push"), inputs=[list_port, elt_port] | ||
) | ||
self.dfg[list_name.id].port = push.add_out_port(list_port.ty) | ||
return | ||
|
@@ -298,7 +316,7 @@ def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool: | |
return False | ||
|
||
|
||
def python_value_to_hugr(v: Any, exp_ty: Type) -> val.Value | None: | ||
def python_value_to_hugr(v: Any, exp_ty: Type) -> ops.Value | None: | ||
"""Turns a Python value into a Hugr value. | ||
|
||
Returns None if the Python value cannot be represented in Guppy. | ||
|
@@ -323,15 +341,13 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> val.Value | None: | |
python_value_to_hugr(elt, ty) | ||
for elt, ty in zip(elts, exp_ty.element_types) | ||
] | ||
if any(value is None for value in vs): | ||
return None | ||
return val.Tuple(vs=vs) | ||
if doesnt_contain_none(vs): | ||
return ops.Value(ops.TupleValue(vs=vs)) | ||
Comment on lines
-376
to
+395
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mypy started complaining about this, so I added a |
||
case list(elts): | ||
assert is_list_type(exp_ty) | ||
return list_value( | ||
[python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts], | ||
get_element_type(exp_ty).to_hugr(), | ||
) | ||
vs = [python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts] | ||
if doesnt_contain_none(vs): | ||
return list_value(vs, get_element_type(exp_ty).to_hugr()) | ||
case _: | ||
# Pytket conversion is an optional feature | ||
try: | ||
|
@@ -343,7 +359,15 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> val.Value | None: | |
) | ||
|
||
hugr = json.loads(Tk2Circuit(v).to_hugr_json()) | ||
return val.FunctionVal(hugr=hugr) | ||
return ops.Value(ops.FunctionValue(hugr=hugr)) | ||
except ImportError: | ||
pass | ||
return None | ||
return None | ||
|
||
|
||
T = TypeVar("T") | ||
|
||
|
||
def doesnt_contain_none(xs: list[T | None]) -> TypeGuard[list[T]]: | ||
"""Checks if a list contains `None`.""" | ||
return all(x is not None for x in xs) |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -2,6 +2,8 @@ | |||||||||||||||||||||||||||||||||||||
from abc import ABC, abstractmethod | ||||||||||||||||||||||||||||||||||||||
from dataclasses import dataclass, field | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
from hugr.serialization import ops | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
from guppylang.ast_util import AstNode, get_type, with_loc, with_type | ||||||||||||||||||||||||||||||||||||||
from guppylang.checker.core import Context, Globals | ||||||||||||||||||||||||||||||||||||||
from guppylang.checker.expr_checker import check_call, synthesize_call | ||||||||||||||||||||||||||||||||||||||
|
@@ -10,8 +12,7 @@ | |||||||||||||||||||||||||||||||||||||
from guppylang.definition.common import ParsableDef | ||||||||||||||||||||||||||||||||||||||
from guppylang.definition.value import CompiledCallableDef | ||||||||||||||||||||||||||||||||||||||
from guppylang.error import GuppyError, InternalGuppyError | ||||||||||||||||||||||||||||||||||||||
from guppylang.hugr import ops | ||||||||||||||||||||||||||||||||||||||
from guppylang.hugr.hugr import Hugr, Node, OutPortV | ||||||||||||||||||||||||||||||||||||||
from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV | ||||||||||||||||||||||||||||||||||||||
from guppylang.nodes import GlobalCall | ||||||||||||||||||||||||||||||||||||||
from guppylang.tys.subst import Inst, Subst | ||||||||||||||||||||||||||||||||||||||
from guppylang.tys.ty import FunctionType, NoneType, Type, type_to_row | ||||||||||||||||||||||||||||||||||||||
|
@@ -123,8 +124,13 @@ def synthesize_call( | |||||||||||||||||||||||||||||||||||||
new_node, ty = self.call_checker.synthesize(args) | ||||||||||||||||||||||||||||||||||||||
return with_type(ty, with_loc(node, new_node)), ty | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def load( | ||||||||||||||||||||||||||||||||||||||
self, dfg: "DFContainer", graph: Hugr, globals: CompiledGlobals, node: AstNode | ||||||||||||||||||||||||||||||||||||||
def load_with_args( | ||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||
type_args: Inst, | ||||||||||||||||||||||||||||||||||||||
dfg: "DFContainer", | ||||||||||||||||||||||||||||||||||||||
graph: Hugr, | ||||||||||||||||||||||||||||||||||||||
globals: CompiledGlobals, | ||||||||||||||||||||||||||||||||||||||
node: AstNode, | ||||||||||||||||||||||||||||||||||||||
) -> OutPortV: | ||||||||||||||||||||||||||||||||||||||
"""Loads the custom function as a value into a local dataflow graph. | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
@@ -138,12 +144,7 @@ def load( | |||||||||||||||||||||||||||||||||||||
"This function does not support usage in a higher-order context", | ||||||||||||||||||||||||||||||||||||||
node, | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self.ty.parametrized: | ||||||||||||||||||||||||||||||||||||||
raise InternalGuppyError( | ||||||||||||||||||||||||||||||||||||||
"Can't yet generate higher-order versions of custom functions. This " | ||||||||||||||||||||||||||||||||||||||
"requires generic function *definitions*" | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
Comment on lines
-142
to
-146
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A nice benefit of having the type args available here is that we can properly handle this case now! |
||||||||||||||||||||||||||||||||||||||
assert len(self.ty.params) == len(type_args) | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or maybe in the context of a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Or....why actually do we need to extend RawCustomFunctionDef at all? Is it to inherit There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are non-polymorphic custum Guppy functions whose type cannot be represented in the type system. One example are functions that are implemented via dunder methods (e.g. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is a hack to allow stuff like guppylang/guppylang/prelude/_internal.py Lines 257 to 274 in 11a1619
since To be fair, reflecting on this, it would probably be better to look these methods up in |
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# Find the module node by walking up the hierarchy | ||||||||||||||||||||||||||||||||||||||
module: Node = dfg.node | ||||||||||||||||||||||||||||||||||||||
|
@@ -159,7 +160,7 @@ def load( | |||||||||||||||||||||||||||||||||||||
def_node = graph.add_def(self.ty, module, self.name) | ||||||||||||||||||||||||||||||||||||||
_, inp_ports = graph.add_input_with_ports(list(self.ty.inputs), def_node) | ||||||||||||||||||||||||||||||||||||||
returns = self.compile_call( | ||||||||||||||||||||||||||||||||||||||
inp_ports, [], DFContainer(def_node, {}), graph, globals, node | ||||||||||||||||||||||||||||||||||||||
inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
graph.add_output(returns, parent=def_node) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
@@ -251,14 +252,14 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: | |||||||||||||||||||||||||||||||||||||
class OpCompiler(CustomCallCompiler): | ||||||||||||||||||||||||||||||||||||||
"""Call compiler for functions that are directly implemented via Hugr ops.""" | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
op: ops.BaseOp | ||||||||||||||||||||||||||||||||||||||
op: ops.OpType | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def __init__(self, op: ops.BaseOp) -> None: | ||||||||||||||||||||||||||||||||||||||
def __init__(self, op: ops.OpType) -> None: | ||||||||||||||||||||||||||||||||||||||
self.op = op | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def compile(self, args: list[OutPortV]) -> list[OutPortV]: | ||||||||||||||||||||||||||||||||||||||
node = self.graph.add_node( | ||||||||||||||||||||||||||||||||||||||
self.op.model_copy(), inputs=args, parent=self.dfg.node | ||||||||||||||||||||||||||||||||||||||
self.op.model_copy(deep=True), inputs=args, parent=self.dfg.node | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
return_ty = get_type(self.node) | ||||||||||||||||||||||||||||||||||||||
return [node.add_out_port(ty) for ty in type_to_row(return_ty)] | ||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I moved the call to
sort_vars
fromchoose_vars_for_tuple_sum()
below to here as a drive-by since this is more consistent with the other sorting logic in this method