Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Merge hugr's schema Value with Const #172

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
TupleType,
type_to_row,
)
from guppylang.hugr import ops, val
from guppylang.hugr import ops
from guppylang.hugr.hugr import DFContainingNode, OutPortV, VNode
from guppylang.nodes import (
DesugaredGenerator,
Expand Down Expand Up @@ -297,8 +297,8 @@ def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool:
return False


def python_value_to_hugr(v: Any, exp_ty: GuppyType) -> val.Value | None:
"""Turns a Python value into a Hugr value.
def python_value_to_hugr(v: Any, exp_ty: GuppyType) -> ops.Const | None:
"""Turns a Python value into a Hugr constant value.

Returns None if the Python value cannot be represented in Guppy.
"""
Expand All @@ -322,14 +322,15 @@ def python_value_to_hugr(v: Any, exp_ty: GuppyType) -> val.Value | None:
python_value_to_hugr(elt, ty)
for elt, ty in zip(elts, exp_ty.element_types)
]
if any(value is None for value in vs):
if None in vs:
return None
return val.Tuple(vs=vs)
return ops.Tuple(vs=vs)
case list(elts):
assert isinstance(exp_ty, ListType)
return list_value(
[python_value_to_hugr(elt, exp_ty.element_type) for elt in elts]
)
consts = [python_value_to_hugr(elt, exp_ty.element_type) for elt in elts]
if None in consts:
return None
return list_value(consts) # type: ignore[arg-type] # Lint warns about passible None elements
case _:
# Pytket conversion is an optional feature
try:
Expand All @@ -341,7 +342,7 @@ def python_value_to_hugr(v: Any, exp_ty: GuppyType) -> val.Value | None:
)

hugr = json.loads(Tk2Circuit(v).to_hugr_json())
return val.FunctionVal(hugr=hugr)
return ops.FunctionConst(hugr=hugr)
except ImportError:
pass
return None
8 changes: 3 additions & 5 deletions guppylang/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
row_to_type,
type_to_row,
)
from guppylang.hugr import tys, val
from guppylang.hugr import tys

NodeIdx = int
PortOffset = int
Expand Down Expand Up @@ -354,12 +354,10 @@ def set_root_name(self, name: str) -> VNode:
return self.root

def add_constant(
self, value: val.Value, ty: GuppyType, parent: Node | None = None
self, const: ops.Const, ty: GuppyType, parent: Node | None = None
) -> VNode:
"""Adds a constant node holding a given value to the graph."""
return self.add_node(
ops.Const(value=value, typ=ty.to_hugr()), [], [ty], parent, None
)
return self.add_node(const, [], [ty], parent, None)

def add_input(
self, output_tys: TypeList | None = None, parent: Node | None = None
Expand Down
49 changes: 42 additions & 7 deletions guppylang/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,14 @@
from typing_extensions import TypeAliasType

import guppylang.hugr.tys as tys

from .tys import (
from guppylang.hugr.tys import (
ExtensionId,
ExtensionSet,
FunctionType,
PolyFuncType,
Type,
TypeRow,
)
from .val import Value

NodeID = int

Expand Down Expand Up @@ -93,12 +91,50 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
self.signature = out


class Const(BaseOp):
"""A constant value definition."""
CustomConst = Any # TODO


class ExtensionConst(BaseOp):
"""An extension constant value, that can check it is of a given [CustomType]."""

op: Literal["Const"] = "Const"
c: Literal["Extension"] = "Extension"
e: CustomConst


class FunctionConst(BaseOp):
"""A higher-order function value."""

op: Literal["Const"] = "Const"
value: Value
c: Literal["Function"] = "Function"
hugr: Any # TODO


class Tuple(BaseOp):
"""A tuple."""

op: Literal["Const"] = "Const"
c: Literal["Tuple"] = "Tuple"
vs: list["Const"]


class Sum(BaseOp):
"""A Sum variant

For any Sum type where this value meets the type of the variant indicated by the tag
"""

op: Literal["Const"] = "Const"
c: Literal["Sum"] = "Sum"
tag: int
typ: Type
vs: list["Const"]


Const = Annotated[
(ExtensionConst | FunctionConst | Tuple | Sum),
Field(discriminator="c"),
]


# -----------------------------------------------
Expand Down Expand Up @@ -422,7 +458,6 @@ class TypeApplication(BaseModel):
(
Module
| Case
| Module
| FuncDefn
| FuncDecl
| Const
Expand Down
60 changes: 0 additions & 60 deletions guppylang/hugr/val.py

This file was deleted.

2 changes: 1 addition & 1 deletion guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def compiled(self) -> bool:
return self._compiled

@pretty_errors
def compile(self) -> Hugr | None:
def compile(self) -> Hugr:
"""Compiles the module and returns the final Hugr."""
if self.compiled:
raise GuppyError("Module has already been compiled")
Expand Down
18 changes: 9 additions & 9 deletions guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from guppylang.error import GuppyError, GuppyTypeError
from guppylang.gtypes import BoolType, FunctionType, GuppyType, Subst, unify
from guppylang.hugr import ops, tys, val
from guppylang.hugr import ops, tys
from guppylang.hugr.hugr import OutPortV
from guppylang.nodes import GlobalCall

Expand Down Expand Up @@ -59,24 +59,24 @@ class ListValue(BaseModel):
value: list[Any]


def bool_value(b: bool) -> val.Value:
def bool_value(b: bool) -> ops.Const:
"""Returns the Hugr representation of a boolean value."""
return val.Sum(tag=int(b), value=val.Tuple(vs=[]))
return ops.Sum(tag=int(b), vs=[], typ=BoolType())


def int_value(i: int) -> val.Value:
def int_value(i: int) -> ops.Const:
"""Returns the Hugr representation of an integer value."""
return val.ExtensionVal(c=(ConstIntS(log_width=INT_WIDTH, value=i),))
return ops.ExtensionConst(e=ConstIntS(log_width=INT_WIDTH, value=i))


def float_value(f: float) -> val.Value:
def float_value(f: float) -> ops.Const:
"""Returns the Hugr representation of a float value."""
return val.ExtensionVal(c=(ConstF64(value=f),))
return ops.ExtensionConst(e=ConstF64(value=f))


def list_value(v: list[val.Value]) -> val.Value:
def list_value(v: list[ops.Const]) -> ops.Const:
"""Returns the Hugr representation of a list value."""
return val.ExtensionVal(c=(ListValue(value=v),))
return ops.ExtensionConst(e=ListValue(value=v))


def logic_op(op_name: str, args: list[tys.TypeArg] | None = None) -> ops.OpType:
Expand Down
Loading
Loading