Skip to content

Commit

Permalink
Address Matt's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tanujkhattar committed Sep 26, 2023
1 parent f7773aa commit de47c29
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 1,055 deletions.
191 changes: 102 additions & 89 deletions qualtran/cirq_interop/_cirq_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,25 @@
"""Functionality for the `Bloq.as_cirq_op(...)` protocol"""
import itertools
from functools import cached_property
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)

import attr
import cirq
import cirq_ft
import networkx as nx
import numpy as np
import quimb.tensor as qtn
from attrs import frozen
from attrs import field, frozen
from numpy.typing import NDArray

from qualtran import (
Expand Down Expand Up @@ -144,33 +154,71 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
begin = 0
symbol: str = soq.pretty()
for reg in self.signature:
finish = begin + int(np.product(reg.shape))
finish = begin + int(np.prod(reg.shape))
if reg == soq.reg:
symbol = np.array(wire_symbols[begin:finish]).reshape(reg.shape)[soq.idx]
begin = finish
return directional_text_box(text=symbol, side=soq.reg.side)


def _split_qvars_for_regs(
qvars: Sequence[Soquet], signature: Signature
) -> Dict[str, NDArray[Soquet]]:
"""Split a flat list of soquets into a dictionary corresponding to `signature`."""
qvars_regs = {}
base = 0
for reg in signature:
assert reg.bitsize == 1
qvars_regs[reg.name] = np.array(qvars[base : base + reg.total_bits()]).reshape(reg.shape)
base += reg.total_bits()
return qvars_regs


@attr.frozen
class QReg:
qubits: Tuple[cirq.Qid, ...] = attr.field(
@frozen
class _QReg:
"""Used as a container for qubits that form a `cirq_ft.Register` of a given bitsize.
Each instance of `_QReg` would correspond to a `Soquet` in Bloqs and represents and opaque collection
of qubits that together form a quantum register.
"""

qubits: Tuple[cirq.Qid, ...] = field(
converter=lambda v: (v,) if isinstance(v, cirq.Qid) else tuple(v)
)


def _ensure_in_reg_exists(
bb: BloqBuilder, in_reg: _QReg, qreg_to_qvar: Dict[_QReg, Soquet]
) -> None:
"""Takes care of splits and joins to make sure `qreg_to_qvar[in_reg]` exists."""

if in_reg in qreg_to_qvar:
# This is the easy case when no split / joins are needed.
return

# a. Split all registers containing at-least one qubit corresponding to `in_reg`.
in_reg_qubits = set(in_reg.qubits)

new_qreg_to_qvar: Dict[_QReg, Soquet] = {}
for qreg, soq in qreg_to_qvar.items():
if len(qreg.qubits) > 1 and any(q in qreg.qubits for q in in_reg_qubits):
new_qreg_to_qvar |= {_QReg(q): s for q, s in zip(qreg.qubits, bb.split(soq=soq))}
else:
new_qreg_to_qvar[qreg] = soq
qreg_to_qvar.clear()

# b. Join all 1-bit registers, corresponding to individual qubits, that make up `in_reg`.
soqs_to_join = []
for qreg, soq in new_qreg_to_qvar.items():
if len(in_reg_qubits) > 1 and qreg.qubits and qreg.qubits[0] in in_reg_qubits:
assert len(qreg.qubits) == 1, "Individual qubits should have been split by now."
soqs_to_join.append(soq)
else:
qreg_to_qvar[qreg] = soq
if soqs_to_join:
qreg_to_qvar[in_reg] = bb.join(np.array(soqs_to_join))


def _gather_input_soqs(
bb: BloqBuilder, op_quregs: Dict[str, NDArray[_QReg]], qreg_to_qvar: Dict[_QReg, Soquet]
) -> Dict[str, NDArray[Soquet]]:
qvars_in: Dict[str, NDArray[Soquet]] = {}
for reg_name, quregs in op_quregs.items():
flat_soqs: List[Soquet] = []
for qureg in quregs.flatten():
_ensure_in_reg_exists(bb, qureg, qreg_to_qvar)
flat_soqs.append(qreg_to_qvar[qureg])
qvars_in[reg_name] = np.array(flat_soqs).reshape(quregs.shape)
return qvars_in


def cirq_optree_to_cbloq(
optree: cirq.OP_TREE,
*,
Expand All @@ -179,93 +227,58 @@ def cirq_optree_to_cbloq(
) -> CompositeBloq:
"""Convert a Cirq OP-TREE into a `CompositeBloq` with signature `signature`.
Each `cirq.Operation` will be wrapped into a `CirqGateAsBloq` wrapper.
Each `cirq.Operation` will be wrapped into a `CirqGateAsBloq` wrapper.
The signature of the resultant CompositeBloq is `signature`, if provided. Otherwise, use
one thru-register named "qubits" of shape `(n_qubits,)`.
For multi-dimensional registers and registers with bitsize>1, this function automatically
splits the input soquets and joins the output soquets to ensure compatibility with the
flat-list-of-qubits expected by Cirq.
When specifying a signature, users must also specify the `cirq_quregs` argument, which is a
mapping of cirq qubits used in the OP-TREE corresponding to the `signature`. If `signature`
has registers with entry
If `signature` is not None, the signature of the resultant CompositeBloq is `signature`. For
multi-dimensional registers and registers with > 1 bitsize, this function automatically
splits the input soquets into a flat list and joins the output soquets into the correct shape
to ensure compatibility with the flat API expected by Cirq. When specifying a signature, users
must also specify the `cirq_quregs` argument, which is a mapping of cirq qubits used in the
OP-TREE corresponding to the `signature`. If `signature` has registers with entry
- `Register('x', bitsize=2, shape=(3, 4))` and
- `Register('y', bitsize=1, shape=(10, 20))`
then `cirq_quregs` should have one entry corresponding to each register as follows:
- key='x'; value=`np.array(cirq_qubits_used_in_optree, shape=(3, 4, 2))` and
- key='y'; value=`np.array(cirq_qubits_used_in_optree, shape=(10, 20, 1))`.
If `signature` is None, the resultant composite bloq will have one thru-register named "qubits"
of shape `(n_qubits,)`.
- key='x'; value=`np.array(cirq_qubits_used_for_x, shape=(3, 4, 2))` and
- key='y'; value=`np.array(cirq_qubits_used_for_y, shape=(10, 20, 1))`.
"""
circuit = cirq.Circuit(optree)
# "qubits" means cirq qubits | "qvars" means bloq Soquets
if signature is None:
assert cirq_quregs is None
if cirq_quregs is not None:
raise ValueError("`cirq_quregs` requires specifying `signature`.")
all_qubits = sorted(circuit.all_qubits())
signature = Signature([Register('qubits', 1, shape=(len(all_qubits),))])
cirq_quregs = {'qubits': np.array(all_qubits).reshape(len(all_qubits), 1)}
elif cirq_quregs is None:
raise ValueError("`signature` requires specifying `cirq_quregs`.")

cirq_quregs = {k: np.apply_along_axis(QReg, -1, v) for k, v in cirq_quregs.items()}

assert signature is not None and cirq_quregs is not None
cirq_quregs = {k: np.apply_along_axis(_QReg, -1, v) for k, v in cirq_quregs.items()}

bb, initial_soqs = BloqBuilder.from_signature(signature, add_registers_allowed=False)

# 0. Helper functions
def _update_qreg_to_qvar(in_reg: QReg, qreg_to_qvar: Dict[QReg, Soquet]) -> None:
"""Takes care of splits and joins to make sure `qreg_to_qvar[in_reg]` exists."""
if in_reg in qreg_to_qvar:
# This is the easy case when no split / joins are needed.
return

# Split everything and join s.t. joined register corresponds to `in_reg`.
# 1. Split all registers containing at-least one qubit corresponding to `in_reg`.
in_reg_qubits = set(in_reg.qubits)

new_qreg_to_qvar: Dict[QReg, Soquet] = {}
for qreg, soq in qreg_to_qvar.items():
if len(qreg.qubits) > 1 and any(q in qreg.qubits for q in in_reg_qubits):
new_qreg_to_qvar |= {QReg(q): s for q, s in zip(qreg.qubits, bb.split(soq=soq))}
else:
new_qreg_to_qvar[qreg] = soq
qreg_to_qvar.clear()
soqs_to_join = []
for qreg, soq in new_qreg_to_qvar.items():
if len(in_reg_qubits) > 1 and qreg.qubits and qreg.qubits[0] in in_reg_qubits:
assert len(qreg.qubits) == 1, "Individual qubits should have been split by now."
soqs_to_join.append(soq)
else:
qreg_to_qvar[qreg] = soq
if soqs_to_join:
qreg_to_qvar[in_reg] = bb.join(np.array(soqs_to_join))

def _find_input_soqs(
op_quregs: Dict[str, NDArray[QReg]], qreg_to_qvar: Dict[QReg, Soquet]
) -> Dict[str, NDArray[Soquet]]:
qvars_in: Dict[str, NDArray[Soquet]] = {}
for reg_name, quregs in op_quregs.items():
flat_soqs: List[Soquet] = []
for qureg in quregs.flatten():
_update_qreg_to_qvar(qureg, qreg_to_qvar)
flat_soqs.append(qreg_to_qvar[qureg])
qvars_in[reg_name] = np.array(flat_soqs).reshape(quregs.shape)
return qvars_in

# 1. Compute qreg_to_qvar.
qreg_to_qvar: Dict[QReg, Soquet] = {}
qreg_to_qvar: Dict[_QReg, Soquet] = {}
for reg in signature.lefts():
assert reg.name in cirq_quregs
if reg.name not in cirq_quregs:
raise ValueError(f"Register {reg.name} from signature must be present in cirq_quregs.")
soqs = initial_soqs[reg.name]
if isinstance(soqs, Soquet):
soqs = np.asarray(soqs)
assert cirq_quregs[reg.name].shape == soqs.shape, (
f'{reg.name=}, {cirq_quregs[reg.name]=}, {soqs=},'
f'{cirq_quregs[reg.name].shape=}, {soqs.shape=}'
)
soqs = np.array(soqs)
if cirq_quregs[reg.name].shape != soqs.shape:
raise ValueError(
f"Shape {cirq_quregs[reg.name].shape} of cirq register "
f"{reg.name} should be {soqs.shape}."
)
qreg_to_qvar |= zip(cirq_quregs[reg.name].flatten(), soqs.flatten())

# 2. Add allocated qubits to qreg_to_qvar
all_qubits = set(q for qreg in qreg_to_qvar for q in qreg.qubits)
allocated_qubits = QReg(sorted(circuit.all_qubits() - all_qubits))
allocated_qubits = _QReg(sorted(circuit.all_qubits() - all_qubits))
if allocated_qubits.qubits:
qreg_to_qvar |= {allocated_qubits: bb.allocate(len(allocated_qubits.qubits))}

Expand All @@ -276,11 +289,11 @@ def _find_input_soqs(

bloq = CirqGateAsBloq(op.gate)
# 3.1 Find input soquets.
op_quregs: Dict[str, NDArray[QReg]] = {
k: np.apply_along_axis(QReg, -1, v)
op_quregs: Dict[str, NDArray[_QReg]] = {
k: np.apply_along_axis(_QReg, -1, v)
for k, v in cirq_ft.infra.split_qubits(bloq.cirq_registers, op.qubits).items()
}
qvars_in = _find_input_soqs(op_quregs, qreg_to_qvar)
qvars_in = _gather_input_soqs(bb, op_quregs, qreg_to_qvar)
# 3.2 Add Bloq
qvars_out = bb.add_d(bloq, **qvars_in)

Expand All @@ -290,11 +303,11 @@ def _find_input_soqs(

# 4. Deallocated newly allocated qubits.
if allocated_qubits.qubits:
_update_qreg_to_qvar(allocated_qubits, qreg_to_qvar)
_ensure_in_reg_exists(bb, allocated_qubits, qreg_to_qvar)
bb.free(qreg_to_qvar.pop(allocated_qubits))

# 5. Combine Soquets to match the right signature.
final_soqs = _find_input_soqs(cirq_quregs, qreg_to_qvar)
final_soqs = _gather_input_soqs(bb, cirq_quregs, qreg_to_qvar)
return bb.finalize(**final_soqs)


Expand Down
Loading

0 comments on commit de47c29

Please sign in to comment.