Skip to content

Commit

Permalink
More mypy fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Mar 19, 2024
1 parent a80418b commit 338fc6e
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 57 deletions.
4 changes: 2 additions & 2 deletions .github/pre-commit
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ fi
if ! poetry run mypy .
then
echo ""
echo "There are some python code style issues."
echo "Fix the warnings returned by `poetry run mypy .` first."
echo "There are some typing issues."
echo "Fix the warnings returned by 'poetry run mypy .' first."
exit 1
fi

Expand Down
101 changes: 49 additions & 52 deletions quantinuum-hugr-py/src/quantinuum_hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,20 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
# the variant data is appended to successor input. Thus, `predicate_variants`
# will only contain empty rows.
num_cases = len(out_types)
self.tuple_sum_rows = [[] for _ in range(num_cases)]
self.sum_rows = [[] for _ in range(num_cases)]

def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None:
self.inputs = inputs
pred = outputs[0]
assert isinstance(pred, tys.UnitSum | tys.GeneralSum)
if isinstance(pred, tys.UnitSum):
self.tuple_sum_rows = [[] for _ in range(cast(tys.UnitSum, pred).size)]
self.sum_rows = [[] for _ in range(cast(tys.UnitSum, pred).size)]
else:
assert isinstance(pred, tys.GeneralSum)
self.tuple_sum_rows = []
for variant in pred.row:
self.sum_rows = []
for variant in pred.rows:
assert isinstance(variant, tys.TupleType)
self.tuple_sum_rows.append(variant.inner)
self.sum_rows.append(variant.inner)
self.other_outputs = outputs[1:]

class Config:
Expand Down Expand Up @@ -255,9 +255,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
# The constE edge comes after the value inputs
fun_ty = in_types[-1]
assert isinstance(fun_ty, PolyFuncType)
fun_ty = cast(PolyFuncType, fun_ty)
assert len(fun_ty.params) == 0
self.signature = fun_ty.body
poly_func = cast(PolyFuncType, fun_ty)
assert len(poly_func.params) == 0
self.signature = poly_func.body

class Config:
# Need to avoid random '\n's in the pydantic description
Expand All @@ -277,11 +277,11 @@ class CallIndirect(DataflowOp):
def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
fun_ty = in_types[0]
assert isinstance(fun_ty, PolyFuncType)
fun_ty = cast(PolyFuncType, fun_ty)
assert len(fun_ty.params) == 0
assert len(fun_ty.body.input) == len(in_types) - 1
assert len(fun_ty.body.output) == len(out_types)
self.signature = fun_ty.body
poly_func = cast(PolyFuncType, fun_ty)
assert len(poly_func.params) == 0
assert len(poly_func.body.input) == len(in_types) - 1
assert len(poly_func.body.output) == len(out_types)
self.signature = poly_func.body


class LoadConstant(DataflowOp):
Expand All @@ -306,7 +306,7 @@ class DFG(DataflowOp):

def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None:
self.signature = FunctionType(
input=list(inputs), output=list(outputs), extension_reqs=[]
input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([])
)


Expand All @@ -321,6 +321,7 @@ class Conditional(DataflowOp):
op: Literal["Conditional"] = "Conditional"
other_inputs: TypeRow = Field(default_factory=list) # Remaining input types
outputs: TypeRow = Field(default_factory=list) # Output types
sum_rows: list[TypeRow] = Field(description="The possible rows of the Sum input")
# Extensions used to produce the outputs
extension_delta: ExtensionSet = Field(default_factory=list)

Expand All @@ -329,13 +330,13 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
# those into a list of type rows
pred = in_types[0]
if isinstance(pred, tys.UnitSum):
self.tuple_sum_rows = [[] for _ in range(cast(tys.UnitSum, pred).size)]
self.sum_rows = [[] for _ in range(cast(tys.UnitSum, pred).size)]
else:
assert isinstance(pred, tys.GeneralSum)
self.tuple_sum_rows = []
for ty in pred.row:
self.sum_rows = []
for ty in pred.rows:
assert isinstance(ty, tys.TupleType)
self.tuple_sum_rows.append(ty.inner)
self.sum_rows.append(ty.inner)
self.other_inputs = list(in_types[1:])
self.outputs = list(out_types)

Expand All @@ -349,7 +350,7 @@ class Case(BaseOp):

def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None:
self.signature = tys.FunctionType(
input=list(inputs), output=list(outputs), extension_reqs=[]
input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([])
)


Expand Down Expand Up @@ -377,7 +378,7 @@ class CFG(DataflowOp):

def insert_port_types(self, inputs: TypeRow, outputs: TypeRow) -> None:
self.signature = FunctionType(
input=list(inputs), output=list(outputs), extension_reqs=[]
input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([])
)


Expand Down Expand Up @@ -483,40 +484,36 @@ class Config:
}


LeafOp = TypeAliasType(
"LeafOp",
Annotated[
(CustomOp | Noop | MakeTuple | UnpackTuple | Tag | TypeApply),
Field(discriminator="lop"),
],
)
class LeafOp(RootModel):
"""A constant operation."""

root: CustomOp | Noop | MakeTuple | UnpackTuple | Tag | TypeApply = Field(
discriminator="lop"
)

OpType = TypeAliasType(
"OpType",
Annotated[
(
Module
| Case
| FuncDefn
| FuncDecl
| Const
| DataflowBlock
| ExitBlock
| Conditional
| TailLoop
| CFG
| Input
| Output
| Call
| CallIndirect
| LoadConstant
| LeafOp
| DFG
),
Field(discriminator="op"),
],
)

class OpType(RootModel):
"""A constant operation."""

root: (
Module
| Case
| FuncDefn
| FuncDecl
| Const
| DataflowBlock
| ExitBlock
| Conditional
| TailLoop
| CFG
| Input
| Output
| Call
| CallIndirect
| LoadConstant
| LeafOp
| DFG
) = Field(discriminator="op")


# --------------------------------------
Expand Down
22 changes: 19 additions & 3 deletions specification/schema/hugr_schema_v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,24 @@
"title": "Outputs",
"type": "array"
},
"sum_rows": {
"description": "The possible rows of the Sum input",
"items": {
"items": {
"$ref": "#/$defs/Type"
},
"type": "array"
},
"title": "Sum Rows",
"type": "array"
},
"extension_delta": {
"$ref": "#/$defs/ExtensionSet"
}
},
"required": [
"parent"
"parent",
"sum_rows"
],
"title": "Conditional",
"type": "object"
Expand Down Expand Up @@ -653,6 +665,7 @@
"type": "object"
},
"LeafOp": {
"description": "A constant operation.",
"discriminator": {
"mapping": {
"CustomOp": "#/$defs/CustomOp",
Expand Down Expand Up @@ -683,7 +696,8 @@
{
"$ref": "#/$defs/TypeApply"
}
]
],
"title": "LeafOp"
},
"ListParam": {
"properties": {
Expand Down Expand Up @@ -816,6 +830,7 @@
"type": "object"
},
"OpType": {
"description": "A constant operation.",
"discriminator": {
"mapping": {
"CFG": "#/$defs/CFG",
Expand Down Expand Up @@ -890,7 +905,8 @@
{
"$ref": "#/$defs/DFG"
}
]
],
"title": "OpType"
},
"Opaque": {
"description": "An opaque operation that can be downcasted by the extensions that define it.",
Expand Down

0 comments on commit 338fc6e

Please sign in to comment.