Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Upgrade Hugr and start using the shared Pydantic model #201

Merged
merged 29 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
5c47f40
feat: Upgrade Hugr and start using the shared pydantic model
mark-koch May 2, 2024
07d38b2
Merge remote-tracking branch 'origin/main' into chore/hugr-upgrade
mark-koch May 2, 2024
db8f3ad
Fix ruff and mypy issues
mark-koch May 2, 2024
11a1619
Update Hugr ref
mark-koch May 3, 2024
9177358
Add helper method to construct tys.FunctionType
mark-koch May 9, 2024
e58ac85
Simplify TupleType.to_hugr()
mark-koch May 9, 2024
99525e3
Delete generate_schema.py
mark-koch May 9, 2024
d0481a4
Fix ConstInt docstring
mark-koch May 9, 2024
d270f87
Add comment to update_op explaining function edges
mark-koch May 9, 2024
34818c0
Remove unnecessary order edges
mark-koch May 9, 2024
9812c32
Add comment explaining entry BB detection
mark-koch May 9, 2024
648027b
Clarify DummyOp docstring
mark-koch May 9, 2024
5e1cac2
Move assert and rename types to hugr_variants
mark-koch May 9, 2024
17cdd37
Fix add_def and add_declare docstrings
mark-koch May 9, 2024
d1720af
Fix root node serialisation
mark-koch May 9, 2024
11f18c3
Add rows_to_hugr helper and refactor SumType.to_hugr()
mark-koch May 9, 2024
5f9dab3
Fix add_load_function docstring
mark-koch May 9, 2024
e8166b2
Merge remote-tracking branch 'origin/main' into chore/hugr-upgrade
mark-koch May 9, 2024
99d77c4
Bump to latest Hugr model
mark-koch May 9, 2024
fbf1571
Assert argument type in _list_to_hugr
mark-koch May 9, 2024
39b4380
Bump to latest Hugr main
mark-koch May 9, 2024
e4d570b
Merge remote-tracking branch 'origin/main' into chore/hugr-upgrade
mark-koch May 14, 2024
95c1095
Use hugr 0.4 release candidate
mark-koch May 14, 2024
37c5293
Bump tket2 and load it for validation
mark-koch May 15, 2024
7083706
Use ops.Value instead of ops.Const
mark-koch May 16, 2024
7b534eb
Use match statement in to_raw
mark-koch May 16, 2024
0dc5484
Merge remote-tracking branch 'origin/main' into chore/hugr-upgrade
mark-koch May 16, 2024
553d11f
Upgrade merged changes
mark-koch May 16, 2024
7ff1734
Fix mypy
mark-koch May 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 18 additions & 22 deletions guppylang/compiler/cfg_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
)
from guppylang.compiler.expr_compiler import ExprCompiler
from guppylang.compiler.stmt_compiler import StmtCompiler
from guppylang.hugr.hugr import CFNode, Hugr, Node, OutPortV
from guppylang.hugr_builder.hugr import CFNode, Hugr, Node, OutPortV
from guppylang.tys.builtin import is_bool_type
from guppylang.tys.ty import SumType, TupleType, type_to_row
from guppylang.tys.ty import SumType, row_to_type, type_to_row

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -65,9 +65,8 @@ def compile_bb(
branch_port = ExprCompiler(graph, globals).compile(bb.branch_pred, dfg)
else:
# Even if we don't branch, we still have to add a `Sum(())` predicates
unit = graph.add_make_tuple([], parent=block).out_port(0)
branch_port = graph.add_tag(
variants=[TupleType([])], tag=0, inp=unit, parent=block
variants=[[]], tag=0, inputs=[], parent=block
).out_port(0)

# Finally, we have to add the block output.
Expand All @@ -92,7 +91,11 @@ def compile_bb(
graph=graph,
unit_sum=branch_port,
output_vars=[
[v for v in row if not v.ty.linear or is_return_var(v.name)]
[
v
for v in sort_vars(row)
if not v.ty.linear or is_return_var(v.name)
]
Comment on lines -95 to +98
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved the call to sort_vars from choose_vars_for_tuple_sum() below to here as a drive-by since this is more consistent with the other sorting logic in this method

for row in bb.sig.output_rows
],
dfg=dfg,
Expand Down Expand Up @@ -133,30 +136,23 @@ def choose_vars_for_tuple_sum(
) -> OutPortV:
"""Selects an output based on a TupleSum.

Given `unit_sum: Sum((), (), ...)` and output variable sets `#s1, #s2, ...`,
constructs a TupleSum value of type `Sum(Tuple(#s1), Tuple(#s2), ...)`.
Given `unit_sum: Sum(*(), *(), ...)` and output variable rows `#s1, #s2, ...`,
constructs a TupleSum value of type `Sum(#s1, #s2, ...)`.
"""
assert isinstance(unit_sum.ty, SumType) or is_bool_type(unit_sum.ty)
assert len(output_vars) == (
len(unit_sum.ty.element_types) if isinstance(unit_sum.ty, SumType) else 2
)
tuples = [
graph.add_make_tuple(
inputs=[dfg[v.name].port for v in sort_vars(vs) if v.name in dfg],
parent=dfg.node,
).out_port(0)
for vs in output_vars
]
tys = [t.ty for t in tuples]
conditional = graph.add_conditional(
cond_input=unit_sum, inputs=tuples, parent=dfg.node
)
for i, _ty in enumerate(tys):
assert all(not v.ty.linear for var_row in output_vars for v in var_row)
conditional = graph.add_conditional(cond_input=unit_sum, inputs=[], parent=dfg.node)
tys = [[v.ty for v in var_row] for var_row in output_vars]
for i, var_row in enumerate(output_vars):
case = graph.add_case(conditional)
inp = graph.add_input(output_tys=tys, parent=case).out_port(i)
tag = graph.add_tag(variants=tys, tag=i, inp=inp, parent=case).out_port(0)
graph.add_input(output_tys=[], parent=case)
inputs = [dfg[v.name].port for v in var_row]
tag = graph.add_tag(variants=tys, tag=i, inputs=inputs, parent=case).out_port(0)
graph.add_output(inputs=[tag], parent=case)
return conditional.add_out_port(SumType(tys))
return conditional.add_out_port(SumType([row_to_type(row) for row in tys]))


def compare_var(x: Variable, y: Variable) -> int:
Expand Down
2 changes: 1 addition & 1 deletion guppylang/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from guppylang.ast_util import AstNode
from guppylang.checker.core import Variable
from guppylang.definition.common import CompiledDef, DefId
from guppylang.hugr.hugr import DFContainingNode, Hugr, OutPortV
from guppylang.hugr_builder.hugr import DFContainingNode, Hugr, OutPortV


@dataclass
Expand Down
72 changes: 48 additions & 24 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import json
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
from typing import Any, TypeGuard, TypeVar

from hugr.serialization import ops

from guppylang.ast_util import AstVisitor, get_type, with_loc, with_type
from guppylang.cfg.builder import tmp_vars
Expand All @@ -13,8 +15,13 @@
)
from guppylang.definition.value import CompiledCallableDef, CompiledValueDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.hugr import ops, val
from guppylang.hugr.hugr import DFContainingNode, OutPortV, VNode
from guppylang.hugr_builder.hugr import (
UNDEFINED,
DFContainingNode,
DummyOp,
OutPortV,
VNode,
)
from guppylang.nodes import (
DesugaredGenerator,
DesugaredListComp,
Expand Down Expand Up @@ -149,6 +156,12 @@ def visit_LocalName(self, node: LocalName) -> OutPortV:
def visit_GlobalName(self, node: GlobalName) -> OutPortV:
defn = self.globals[node.def_id]
assert isinstance(defn, CompiledValueDef)
if isinstance(defn, CompiledCallableDef) and defn.ty.parametrized:
raise GuppyError(
"Usage of polymorphic functions as dynamic higher-order values is not "
"supported yet",
node,
)
return defn.load(self.dfg, self.graph, self.globals, node)

def visit_Name(self, node: ast.Name) -> OutPortV:
Expand All @@ -162,7 +175,7 @@ def visit_Tuple(self, node: ast.Tuple) -> OutPortV:
def visit_List(self, node: ast.List) -> OutPortV:
# Note that this is a list literal (i.e. `[e1, e2, ...]`), not a comprehension
return self.graph.add_node(
ops.DummyOp(name="MakeList"), inputs=[self.visit(e) for e in node.elts]
DummyOp("MakeList"), inputs=[self.visit(e) for e in node.elts]
).add_out_port(get_type(node))

def _unpack_tuple(self, wire: OutPortV) -> list[OutPortV]:
Expand Down Expand Up @@ -243,9 +256,11 @@ def visit_Call(self, node: ast.Call) -> OutPortV:
raise InternalGuppyError("Node should have been removed during type checking.")

def visit_TypeApply(self, node: TypeApply) -> OutPortV:
func = self.visit(node.value)
assert isinstance(func.ty, FunctionType)
ta = self.graph.add_type_apply(func, node.inst, self.dfg.node).out_port(0)
# For now, we can only TypeApply global FunctionDefs/Decls.
if not isinstance(node.value, GlobalName):
raise InternalGuppyError("Dynamic TypeApply not supported yet!")
defn = self.globals[node.value.def_id]
assert isinstance(defn, CompiledCallableDef)

# We have to be very careful here: If we instantiate `foo: forall T. T -> T`
# with a tuple type `tuple[A, B]`, we get the type `tuple[A, B] -> tuple[A, B]`.
Expand All @@ -254,22 +269,25 @@ def visit_TypeApply(self, node: TypeApply) -> OutPortV:
# function with a single output port typed `tuple[A, B]`.
# TODO: We would need to do manual monomorphisation in that case to obtain a
# function that returns two ports as expected
if instantiation_needs_unpacking(func.ty, node.inst):
if instantiation_needs_unpacking(defn.ty, node.inst):
raise GuppyError(
"Generic function instantiations returning rows are not supported yet",
node,
)

return ta
return defn.load_with_args(node.inst, self.dfg, self.graph, self.globals, node)

def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV:
# The only case that is not desugared by the type checker is the `not` operation
# since it is not implemented via a dunder method
if isinstance(node.op, ast.Not):
arg = self.visit(node.operand)
return self.graph.add_node(
ops.CustomOp(extension="logic", op_name="Not", args=[]), inputs=[arg]
).add_out_port(bool_type())
op = ops.CustomOp(
extension="logic", op_name="Not", args=[], parent=UNDEFINED
)
return self.graph.add_node(ops.OpType(op), inputs=[arg]).add_out_port(
bool_type()
)

raise InternalGuppyError("Node should have been removed during type checking.")

Expand All @@ -281,7 +299,7 @@ def visit_DesugaredListComp(self, node: DesugaredListComp) -> OutPortV:
# Make up a name for the list under construction and bind it to an empty list
list_ty = get_type(node)
list_name = with_type(list_ty, with_loc(node, LocalName(id=next(tmp_vars))))
empty_list = self.graph.add_node(ops.DummyOp(name="MakeList"))
empty_list = self.graph.add_node(DummyOp("MakeList"))
self.dfg[list_name.id] = PortVariable(
list_name.id, empty_list.add_out_port(list_ty), node, None
)
Expand All @@ -292,7 +310,7 @@ def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None:
if not gens:
list_port, elt_port = self.visit(list_name), self.visit(elt)
push = self.graph.add_node(
ops.DummyOp(name="Push"), inputs=[list_port, elt_port]
DummyOp("Push"), inputs=[list_port, elt_port]
)
self.dfg[list_name.id].port = push.add_out_port(list_port.ty)
return
Expand Down Expand Up @@ -348,7 +366,7 @@ def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool:
return False


def python_value_to_hugr(v: Any, exp_ty: Type) -> val.Value | None:
def python_value_to_hugr(v: Any, exp_ty: Type) -> ops.Value | None:
"""Turns a Python value into a Hugr value.

Returns None if the Python value cannot be represented in Guppy.
Expand All @@ -373,15 +391,13 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> val.Value | None:
python_value_to_hugr(elt, ty)
for elt, ty in zip(elts, exp_ty.element_types)
]
if any(value is None for value in vs):
return None
return val.Tuple(vs=vs)
if doesnt_contain_none(vs):
return ops.Value(ops.TupleValue(vs=vs))
Comment on lines -376 to +395
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mypy started complaining about this, so I added a doesnt_contain_none type guard

case list(elts):
assert is_list_type(exp_ty)
return list_value(
[python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts],
get_element_type(exp_ty).to_hugr(),
)
vs = [python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts]
if doesnt_contain_none(vs):
return list_value(vs, get_element_type(exp_ty))
case _:
# Pytket conversion is an optional feature
try:
Expand All @@ -393,7 +409,15 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> val.Value | None:
)

hugr = json.loads(Tk2Circuit(v).to_hugr_json())
return val.FunctionVal(hugr=hugr)
return ops.Value(ops.FunctionValue(hugr=hugr))
except ImportError:
pass
return None
return None


T = TypeVar("T")


def doesnt_contain_none(xs: list[T | None]) -> TypeGuard[list[T]]:
"""Checks if a list contains `None`."""
return all(x is not None for x in xs)
6 changes: 3 additions & 3 deletions guppylang/compiler/func_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
DFContainer,
PortVariable,
)
from guppylang.hugr.hugr import DFContainingVNode, Hugr
from guppylang.hugr_builder.hugr import DFContainingVNode, Hugr
from guppylang.nodes import CheckedNestedFunctionDef
from guppylang.tys.ty import FunctionType, type_to_row

Expand Down Expand Up @@ -60,7 +60,7 @@ def compile_local_func_def(
# the function itself, then we provide the partially applied function as a local
# variable
if len(captured) > 0 and func.name in func.cfg.live_before[func.cfg.entry_bb]:
loaded = graph.add_load_constant(def_node.out_port(0), def_node).out_port(0)
loaded = graph.add_load_function(def_node.out_port(0), [], def_node).out_port(0)
partial = graph.add_partial(
loaded, [def_input.out_port(i) for i in range(len(captured))], def_node
)
Expand Down Expand Up @@ -93,7 +93,7 @@ def compile_local_func_def(
)

# Finally, load the function into the local data-flow graph
loaded = graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0)
loaded = graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0)
if len(captured) > 0:
loaded = graph.add_partial(
loaded, [dfg[v.name].port for v in captured], dfg.node
Expand Down
2 changes: 1 addition & 1 deletion guppylang/compiler/stmt_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from guppylang.compiler.expr_compiler import ExprCompiler
from guppylang.error import InternalGuppyError
from guppylang.hugr.hugr import Hugr, OutPortV
from guppylang.hugr_builder.hugr import Hugr, OutPortV
from guppylang.nodes import CheckedNestedFunctionDef
from guppylang.tys.ty import TupleType

Expand Down
5 changes: 3 additions & 2 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from types import ModuleType
from typing import Any, TypeVar

from hugr.serialization import ops, tys

from guppylang.ast_util import has_empty_body
from guppylang.definition.common import DefId
from guppylang.definition.custom import (
Expand All @@ -21,8 +23,7 @@
from guppylang.definition.struct import RawStructDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import GuppyError, MissingModuleError, pretty_errors
from guppylang.hugr import ops, tys
from guppylang.hugr.hugr import Hugr
from guppylang.hugr_builder.hugr import Hugr
from guppylang.module import GuppyModule, PyFunc

FuncDefDecorator = Callable[[PyFunc], RawFunctionDef]
Expand Down
2 changes: 1 addition & 1 deletion guppylang/definition/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, ClassVar, TypeAlias

from guppylang.hugr.hugr import Hugr, Node
from guppylang.hugr_builder.hugr import Hugr, Node

if TYPE_CHECKING:
from guppylang.checker.core import Globals
Expand Down
29 changes: 15 additions & 14 deletions guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

from hugr.serialization import ops

from guppylang.ast_util import AstNode, get_type, with_loc, with_type
from guppylang.checker.core import Context, Globals
from guppylang.checker.expr_checker import check_call, synthesize_call
Expand All @@ -10,8 +12,7 @@
from guppylang.definition.common import ParsableDef
from guppylang.definition.value import CompiledCallableDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.hugr import ops
from guppylang.hugr.hugr import Hugr, Node, OutPortV
from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV
from guppylang.nodes import GlobalCall
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import FunctionType, NoneType, Type, type_to_row
Expand Down Expand Up @@ -123,8 +124,13 @@ def synthesize_call(
new_node, ty = self.call_checker.synthesize(args)
return with_type(ty, with_loc(node, new_node)), ty

def load(
self, dfg: "DFContainer", graph: Hugr, globals: CompiledGlobals, node: AstNode
def load_with_args(
self,
type_args: Inst,
dfg: "DFContainer",
graph: Hugr,
globals: CompiledGlobals,
node: AstNode,
) -> OutPortV:
"""Loads the custom function as a value into a local dataflow graph.

Expand All @@ -138,12 +144,7 @@ def load(
"This function does not support usage in a higher-order context",
node,
)

if self.ty.parametrized:
raise InternalGuppyError(
"Can't yet generate higher-order versions of custom functions. This "
"requires generic function *definitions*"
)
Comment on lines -142 to -146
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A nice benefit of having the type args available here is that we can properly handle this case now!

assert len(self.ty.params) == len(type_args)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does self.higher_order_value still make sense? If the function is turned into a runtime value, that must be first-order (with TypeArgs provided to the LoadFunction node that generates the value - you'll need to infer these). If it's statically-called, then the function is higher-order, but the call is not (again, TypeArgs provided)....

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe in the context of a CustomFunctionDef rather than reading a field called self.higher_order_value we should be checking whether self.ty has any binders or not. (If we need to define higher_order_value in the superclass RawCustomFunctionDef then it might be an @abstractproperty or some such?)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Or....why actually do we need to extend RawCustomFunctionDef at all? Is it to inherit compile_call? Does RawCustomFunctionDef actually need compile_call or could that be moved into CustomFuncDef? It feels like parsing Raw -> non-Raw should be a once-and-once-only process, so if you want to compile_call you had probably better parse the type signature (from Raw to non-Raw) first ??)

Copy link
Collaborator Author

@mark-koch mark-koch May 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are non-polymorphic custum Guppy functions whose type cannot be represented in the type system. One example are functions that are implemented via dunder methods (e.g. iter(x), len(x), ...). Therefore, these need self.higher_order_value=False which cannot be inferred from the type.

Copy link
Collaborator Author

@mark-koch mark-koch May 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does RawCustomFunctionDef actually need compile_call or could that be moved into CustomFuncDef?

This is a hack to allow stuff like

class IntTruedivCompiler(CustomCallCompiler):
"""Compiler for the `int.__truediv__` method."""
def compile(self, args: list[OutPortV]) -> list[OutPortV]:
from .builtins import Float, Int
# Compile `truediv` using float arithmetic
[left, right] = args
[left] = Int.__float__.compile_call(
[left], [], self.dfg, self.graph, self.globals, self.node
)
[right] = Int.__float__.compile_call(
[right], [], self.dfg, self.graph, self.globals, self.node
)
[out] = Float.__truediv__.compile_call(
[left, right], [], self.dfg, self.graph, self.globals, self.node
)
return [out]

since Int.__float__ etc are RawCustomFunctionDefs. I agree this is not nice but hopefully only a temporary hack. Once we have proper linking, we can define truediv etc via actual Guppy function defs and no longer need these CustomCallCompilers.

To be fair, reflecting on this, it would probably be better to look these methods up in globals instead of locally importing from .builtins... I'd be happy to make an issue if you agree?


# Find the module node by walking up the hierarchy
module: Node = dfg.node
Expand All @@ -159,7 +160,7 @@ def load(
def_node = graph.add_def(self.ty, module, self.name)
_, inp_ports = graph.add_input_with_ports(list(self.ty.inputs), def_node)
returns = self.compile_call(
inp_ports, [], DFContainer(def_node, {}), graph, globals, node
inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node
)
graph.add_output(returns, parent=def_node)

Expand Down Expand Up @@ -251,14 +252,14 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]:
class OpCompiler(CustomCallCompiler):
"""Call compiler for functions that are directly implemented via Hugr ops."""

op: ops.BaseOp
op: ops.OpType

def __init__(self, op: ops.BaseOp) -> None:
def __init__(self, op: ops.OpType) -> None:
self.op = op

def compile(self, args: list[OutPortV]) -> list[OutPortV]:
node = self.graph.add_node(
self.op.model_copy(), inputs=args, parent=self.dfg.node
self.op.model_copy(deep=True), inputs=args, parent=self.dfg.node
)
return_ty = get_type(self.node)
return [node.add_out_port(ty) for ty in type_to_row(return_ty)]
Expand Down
Loading
Loading