Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tatiana-s committed Dec 6, 2024
1 parent a56ab6b commit 99ee621
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
4 changes: 3 additions & 1 deletion guppylang/checker/errors/py_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,7 @@ class InstallInstruction(Help):
@dataclass(frozen=True)
class PytketSignatureMismatch(Error):
title: ClassVar[str] = "Signature mismatch"
span_label: ClassVar[str] = "Signature {name} doesn't match provided pytket circuit"
span_label: ClassVar[str] = (
"Signature `{name}` doesn't match provided pytket circuit"
)
name: str
21 changes: 10 additions & 11 deletions guppylang/definition/pytket_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, cast

import hugr.build.function as hf
from hugr import Hugr, OutPort, Wire, ops, tys, val
from hugr import Hugr, Wire, val
from hugr.build.dfg import DefinitionBuilder, OpVar

from guppylang.ast_util import AstNode, has_empty_body, with_loc
Expand Down Expand Up @@ -158,21 +158,20 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef"

# Initialise every input bit in the circuit as false.
# TODO: Provide the option for the user to pass this input as well.
bool_wires: list[OutPort] = []
for child in module.hugr.children(hugr_func):
node_data = module.hugr.get(child)
if node_data and isinstance(node_data.op, ops.Input):
input_types = node_data.op.types
num_bools = input_types.count(tys.Bool)
for _ in range(num_bools):
bool_node = outer_func.load(val.FALSE)
bool_wires.append(*bool_node.outputs())
bool_wires = [
outer_func.load(val.FALSE) for _ in range(self.input_circuit.n_bits)
]

call_node = outer_func.call(
hugr_func, *(list(outer_func.inputs()) + bool_wires)
)
# Pytket circuit hugr has qubit and bool wires in the opposite order.
outer_func.set_outputs(*reversed(list(call_node.outputs())))
output_list = list(call_node.outputs())
wires = (
output_list[self.input_circuit.n_qubits :]
+ output_list[: self.input_circuit.n_qubits]
)
outer_func.set_outputs(*wires)

except ImportError:
pass
Expand Down

0 comments on commit 99ee621

Please sign in to comment.