Skip to content

Commit

Permalink
Nicer docstrings and code structure improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
tanujkhattar committed Sep 28, 2023
1 parent 5fd4fe5 commit e19e5c5
Showing 1 changed file with 81 additions and 65 deletions.
146 changes: 81 additions & 65 deletions qualtran/cirq_interop/_cirq_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functionality for the `Bloq.as_cirq_op(...)` protocol"""
"""Bi-directional interop between Qualtran & Cirq using Cirq-FT."""
import itertools
from functools import cached_property
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -53,6 +53,9 @@ def signature_from_cirq_registers(registers: Iterable[cirq_ft.Register]) -> 'Sig
)


# Part-I: Cirq to Bloq conversion.


@frozen
class CirqGateAsBloq(Bloq):
"""A Bloq wrapper around a `cirq.Gate`, preserving signature if gate is a `GateWithRegisters`."""
Expand Down Expand Up @@ -130,8 +133,7 @@ def as_cirq_op(
return self.gate.on(*cirq_quregs['qubits'].flatten()), cirq_quregs
return _construct_op_from_gate(
self.gate,
signature=self.signature,
cirq_quregs={k: np.array(v) for k, v in cirq_quregs.items()},
in_quregs={k: np.array(v) for k, v in cirq_quregs.items()},
qubit_manager=qubit_manager,
)

Expand All @@ -156,7 +158,7 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
class _QReg:
"""Used as a container for qubits that form a `cirq_ft.Register` of a given bitsize.
Each instance of `_QReg` would correspond to a `Soquet` in Bloqs and represents and opaque collection
Each instance of `_QReg` would correspond to a `Soquet` in Bloqs and represents an opaque collection
of qubits that together form a quantum register.
"""

Expand Down Expand Up @@ -218,8 +220,8 @@ def cirq_optree_to_cbloq(
optree: cirq.OP_TREE,
*,
signature: Optional[Signature] = None,
in_quregs: Optional[Dict[str, 'NDArray[cirq.Qid]']] = None,
out_quregs: Optional[Dict[str, 'NDArray[cirq.Qid]']] = None,
in_quregs: Optional[Dict[str, 'CirqQuregT']] = None,
out_quregs: Optional[Dict[str, 'CirqQuregT']] = None,
) -> CompositeBloq:
"""Convert a Cirq OP-TREE into a `CompositeBloq` with signature `signature`.
Expand All @@ -231,17 +233,25 @@ def cirq_optree_to_cbloq(
splits the input soquets and joins the output soquets to ensure compatibility with the
flat-list-of-qubits expected by Cirq.
When specifying a signature, users must also specify the `cirq_quregs` argument, which is a
mapping of cirq qubits used in the OP-TREE corresponding to the `signature`. If `signature`
has registers with entry
When specifying a signature, users must also specify the `in_quregs` & `out_quregs` arguments,
which are mappings of cirq qubits used in the OP-TREE corresponding to the `LEFT` & `RIGHT`
registers in `signature`. If `signature` has registers with entry
- `Register('x', bitsize=2, shape=(3, 4))` and
- `Register('y', bitsize=1, shape=(10, 20))`
- `Register('x', bitsize=2, shape=(3, 4), side=Side.THRU)`
- `Register('y', bitsize=1, shape=(10, 20), side=Side.LEFT)`
- `Register('z', bitsize=1, shape=(10, 20), side=Side.RIGHT)`
then `cirq_quregs` should have one entry corresponding to each register as follows:
then `in_quregs` should have one entry corresponding to registers `x` and `y` as follows:
- key='x'; value=`np.array(cirq_qubits_used_for_x, shape=(3, 4, 2))` and
- key='y'; value=`np.array(cirq_qubits_used_for_y, shape=(10, 20, 1))`.
and `out_quregs` should have one entry corresponding to registers `x` and `z` as follows:
- key='x'; value=`np.array(cirq_qubits_used_for_x, shape=(3, 4, 2))` and
- key='z'; value=`np.array(cirq_qubits_used_for_z, shape=(10, 20, 1))`.
Any qubit in `optree` which is not part of `in_quregs` and `out_quregs` is considered to be
allocated & deallocated inside the CompositeBloq and does not show up in it's signature.
"""
circuit = cirq.Circuit(optree)
if signature is None:
Expand Down Expand Up @@ -295,18 +305,17 @@ def cirq_optree_to_cbloq(
qvars_out = bb.add_d(bloq, **qvars_in)

# 3.4 Update `qreg_to_qvar` mapping using output soquets `qvars_out`.
for _, regs in bloq.signature.groups():
for reg in regs:
# all_op_quregs should exist for both LEFT & RIGHT registers.
assert reg.name in all_op_quregs
quregs = all_op_quregs[reg.name]
if reg.side == Side.LEFT:
# This register got de-allocated, update the `qreg_to_qvar` mapping.
for q in quregs.flatten():
_ = qreg_to_qvar.pop(q)
else:
assert quregs.shape == np.array(qvars_out[reg.name]).shape
qreg_to_qvar |= zip(quregs.flatten(), np.array(qvars_out[reg.name]).flatten())
for reg in bloq.signature:
# all_op_quregs should exist for both LEFT & RIGHT registers.
assert reg.name in all_op_quregs
quregs = all_op_quregs[reg.name]
if reg.side == Side.LEFT:
# This register got de-allocated, update the `qreg_to_qvar` mapping.
for q in quregs.flatten():
_ = qreg_to_qvar.pop(q)
else:
assert quregs.shape == np.array(qvars_out[reg.name]).shape
qreg_to_qvar |= zip(quregs.flatten(), np.array(qvars_out[reg.name]).flatten())

# 4. Combine Soquets to match the right signature.
final_soqs_dict = _gather_input_soqs(
Expand Down Expand Up @@ -350,19 +359,22 @@ def decompose_from_cirq_op(bloq: 'Bloq') -> 'CompositeBloq':
# Part-II: Bloq to Cirq conversion.


def _track_soq_name_changes(cxns: Iterable[Connection], qvar_to_qreg: Dict[Soquet, _QReg]):
"""Track inter-Bloq name changes across the two ends of a connection."""
for cxn in cxns:
qvar_to_qreg[cxn.right] = qvar_to_qreg[cxn.left]
del qvar_to_qreg[cxn.left]


def _bloq_to_cirq_op(
bloq: Bloq,
pred_cxns: Iterable[Connection],
succ_cxns: Iterable[Connection],
qvar_to_qreg: Dict[Soquet, _QReg],
qubit_manager: cirq.QubitManager,
) -> cirq.Operation:
# Track inter-Bloq name changes. The Soquets are of the same shape.
for cxn in pred_cxns:
qvar_to_qreg[cxn.right] = qvar_to_qreg[cxn.left]
del qvar_to_qreg[cxn.left]

in_quregs: Dict[str, NDArray[cirq.Qid]] = {
_track_soq_name_changes(pred_cxns, qvar_to_qreg)
in_quregs: Dict[str, CirqQuregT] = {
reg.name: np.empty((*reg.shape, reg.bitsize), dtype=object)
for reg in bloq.signature.lefts()
}
Expand Down Expand Up @@ -420,9 +432,7 @@ def _cbloq_to_cirq_circuit(
continue
pred_cxns, succ_cxns = _binst_to_cxns(binst, binst_graph=binst_graph)
if binst == RightDangle:
for cxn in pred_cxns:
qvar_to_qreg[cxn.right] = qvar_to_qreg[cxn.left]
del qvar_to_qreg[cxn.left]
_track_soq_name_changes(pred_cxns, qvar_to_qreg)
continue

op = _bloq_to_cirq_op(binst.bloq, pred_cxns, succ_cxns, qvar_to_qreg, qubit_manager)
Expand All @@ -431,8 +441,8 @@ def _cbloq_to_cirq_circuit(
if moment:
moments.append(cirq.Moment(moment))

# # Track bloq-to-dangle name changes
def _f_quregs(reg: Register):
# Find output Cirq quregs using `qvar_to_qreg` mapping for registers in `signature.rights()`.
def _f_quregs(reg: Register) -> CirqQuregT:
ret = np.empty(reg.shape + (reg.bitsize,), dtype=object)
for idx in reg.all_idxs():
soq = Soquet(RightDangle, idx=idx, reg=reg)
Expand All @@ -446,35 +456,44 @@ def _f_quregs(reg: Register):

def _construct_op_from_gate(
gate: cirq_ft.GateWithRegisters,
signature: Signature,
cirq_quregs: Dict[str, 'CirqQuregT'],
in_quregs: Dict[str, 'CirqQuregT'],
qubit_manager: cirq.QubitManager,
):
in_quregs: Dict[str, 'CirqQuregT'] = {}
for reg in signature.lefts():
) -> Tuple[cirq.Operation, Dict[str, 'CirqQuregT']]:
"""Allocates / Deallocates qubits for RIGHT / LEFT only registers to construct a Cirq operation
Args:
gate: A `cirq_ft.GateWithRegisters` which specifies a signature.
in_quregs: Mapping from LEFT register names of `gate` and corresponding cirq qubits.
qubit_manager: For allocating / deallocating qubits for RIGHT / LEFT only registers.
Returns:
A cirq operation constructed using `gate` and a mapping from RIGHT register names to
corresponding Cirq qubits.
"""
all_quregs: Dict[str, 'CirqQuregT'] = {}
out_quregs: Dict[str, 'CirqQuregT'] = {}
for reg in gate.signature:
full_shape = reg.shape + (reg.bitsize,)
if reg.name not in cirq_quregs or cirq_quregs[reg.name].shape != full_shape:
# Left registers should exist as input to `as_cirq_op`.
raise ValueError(f'Compatible {reg=} must exist in {cirq_quregs=}')
in_quregs[reg.name] = cirq_quregs[reg.name]

for reg in signature.rights():
# Right only registers will get allocated as part of `as_cirq_op`.
if not reg.side & Side.LEFT:
if reg.name in cirq_quregs:
raise ValueError(f"RIGHT register {reg=} shouldn't exist in {cirq_quregs=}.")
in_quregs[reg.name] = np.array(qubit_manager.qalloc(reg.total_bits())).reshape(
reg.shape + (reg.bitsize,)
if reg.side & cirq_ft.infra.Side.LEFT:
if reg.name not in in_quregs or in_quregs[reg.name].shape != full_shape:
# Left registers should exist as input to `as_cirq_op`.
raise ValueError(f'Compatible {reg=} must exist in {in_quregs=}')
all_quregs[reg.name] = in_quregs[reg.name]
if reg.side == cirq_ft.infra.Side.RIGHT:
# Right only registers will get allocated as part of `as_cirq_op`.
if reg.name in in_quregs:
raise ValueError(f"RIGHT register {reg=} shouldn't exist in {in_quregs=}.")
all_quregs[reg.name] = np.array(qubit_manager.qalloc(reg.total_bits())).reshape(
full_shape
)
if reg.side == cirq_ft.infra.Side.LEFT:
# LEFT only registers should be de-allocated and not be part of output.
qubit_manager.qfree(in_quregs[reg.name].flatten())

op = gate.on_registers(**in_quregs)

out_quregs = in_quregs
for reg in signature.lefts():
if not reg.side & Side.RIGHT:
# LEFT only registers should be de-allocated and removed from output.
qubit_manager.qfree(out_quregs.pop(reg.name).flatten())
return op, out_quregs
if reg.side & cirq_ft.infra.Side.RIGHT:
# Right registers should be part of the output.
out_quregs[reg.name] = all_quregs[reg.name]
return gate.on_registers(**all_quregs), out_quregs


class BloqAsCirqGate(cirq_ft.GateWithRegisters):
Expand Down Expand Up @@ -536,14 +555,11 @@ def bloq_on(
cirq_quregs: The output cirq qubit registers.
"""
return _construct_op_from_gate(
BloqAsCirqGate(bloq=bloq),
signature=bloq.signature,
cirq_quregs=cirq_quregs,
qubit_manager=qubit_manager,
BloqAsCirqGate(bloq=bloq), in_quregs=cirq_quregs, qubit_manager=qubit_manager
)

def decompose_from_registers(
self, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
self, context: cirq.DecompositionContext, **quregs: CirqQuregT
) -> cirq.OP_TREE:
"""Implementation of the GatesWithRegisters decompose method.
Expand Down

0 comments on commit e19e5c5

Please sign in to comment.