Skip to content

Commit

Permalink
Support cirq registers in CirqGateAsBloq conversion (#355)
Browse files Browse the repository at this point in the history
* Support cirq registers in CirqGateAsBloq conversion

* Fix failing test

* Address nits, add tests and fix a couple of bugs

* Fix typo in test and add comment
  • Loading branch information
tanujkhattar authored Aug 25, 2023
1 parent 72248b0 commit cc232f6
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 52 deletions.
165 changes: 121 additions & 44 deletions qualtran/cirq_interop/_cirq_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Functionality for the `Bloq.as_cirq_op(...)` protocol"""

import itertools
from functools import cached_property
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

Expand All @@ -23,8 +24,6 @@
import numpy as np
import quimb.tensor as qtn
from attrs import frozen
from cirq_ft import Register as LegacyRegister
from cirq_ft import Registers as LegacyRegisters
from numpy.typing import NDArray

from qualtran import (
Expand All @@ -43,11 +42,16 @@
SoquetT,
)
from qualtran._infra.composite_bloq import _binst_to_cxns
from qualtran.bloqs.util_bloqs import Allocate, Free

CirqQuregT = NDArray[cirq.Qid]
CirqQuregInT = Union[NDArray[cirq.Qid], Sequence[cirq.Qid]]


def signature_from_cirq_registers(registers: Iterable[cirq_ft.Register]) -> 'Signature':
return Signature([Register(reg.name, bitsize=1, shape=reg.shape) for reg in registers])


@frozen
class CirqGateAsBloq(Bloq):
"""A Bloq wrapper around a `cirq.Gate`.
Expand All @@ -67,11 +71,24 @@ def short_name(self) -> str:

@cached_property
def signature(self) -> 'Signature':
return Signature([Register('qubits', 1, shape=(self.n_qubits,))])
return signature_from_cirq_registers(self.cirq_registers)

@cached_property
def n_qubits(self):
return cirq.num_qubits(self.gate)
def cirq_registers(self) -> cirq_ft.Registers:
if isinstance(self.gate, cirq_ft.GateWithRegisters):
return self.gate.registers
else:
return cirq_ft.Registers.build(qubits=cirq.num_qubits(self.gate))

def decompose_bloq(self) -> 'CompositeBloq':
quregs = self.signature.get_cirq_quregs()
qubit_manager = cirq.ops.SimpleQubitManager()
cirq_op, quregs = self.as_cirq_op(qubit_manager, **quregs)
context = cirq.DecompositionContext(qubit_manager=qubit_manager)
decomposed_optree = cirq.decompose_once(cirq_op, context=context, default=None)
if decomposed_optree is None:
raise NotImplementedError(f"{self} does not support decomposition.")
return cirq_optree_to_cbloq(decomposed_optree, signature=self.signature, cirq_quregs=quregs)

def add_my_tensors(
self,
Expand All @@ -81,28 +98,68 @@ def add_my_tensors(
incoming: Dict[str, 'SoquetT'],
outgoing: Dict[str, 'SoquetT'],
):
unitary = cirq.unitary(self.gate).reshape((2,) * 2 * self.n_qubits)
unitary = cirq.unitary(self.gate).reshape((2,) * 2 * self.cirq_registers.total_bits())
incoming_list = [
*itertools.chain.from_iterable(
[np.array(incoming[reg.name]).flatten() for reg in self.signature.lefts()]
)
]
outgoing_list = [
*itertools.chain.from_iterable(
[np.array(outgoing[reg.name]).flatten() for reg in self.signature.rights()]
)
]

tn.add(
qtn.Tensor(
data=unitary,
inds=outgoing['qubits'].tolist() + incoming['qubits'].tolist(),
tags=[self.short_name(), tag],
data=unitary, inds=outgoing_list + incoming_list, tags=[self.short_name(), tag]
)
)

def as_cirq_op(
self, qubit_manager: 'cirq.QubitManager', qubits: 'CirqQuregT'
self, qubit_manager: 'cirq.QubitManager', **cirq_quregs: 'CirqQuregT'
) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]:
assert qubits.shape == (self.n_qubits, 1)
return self.gate.on(*qubits[:, 0]), {'qubits': qubits}
merged_qubits = np.concatenate(
[cirq_quregs[reg.name].flatten() for reg in self.signature.lefts()]
)
assert len(merged_qubits) == cirq.num_qubits(self.gate)
return self.gate.on(*merged_qubits), cirq_quregs

def t_complexity(self) -> 'cirq_ft.TComplexity':
return cirq_ft.t_complexity(self.gate)

def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
from qualtran.drawing import directional_text_box

wire_symbols = cirq.circuit_diagram_info(self.gate).wire_symbols
begin = 0
symbol: str = soq.pretty()
for reg in self.signature:
finish = begin + np.product(reg.shape)
if reg == soq.reg:
symbol = np.array(wire_symbols[begin:finish]).reshape(reg.shape)[soq.idx]
begin = finish
return directional_text_box(text=symbol, side=soq.reg.side)


def _split_qvars_for_regs(
qvars: Sequence[Soquet], signature: Signature
) -> Dict[str, NDArray[Soquet]]:
"""Split a flat list of soquets into a dictionary corresponding to `signature`."""
qvars_regs = {}
base = 0
for reg in signature:
assert reg.bitsize == 1
qvars_regs[reg.name] = np.array(qvars[base : base + reg.total_bits()]).reshape(reg.shape)
base += reg.total_bits()
return qvars_regs


def cirq_optree_to_cbloq(
optree: cirq.OP_TREE, *, signature: Optional[Signature] = None
optree: cirq.OP_TREE,
*,
signature: Optional[Signature] = None,
cirq_quregs: Optional[Dict[str, 'NDArray[cirq.Qid]']] = None,
) -> CompositeBloq:
"""Convert a Cirq OP-TREE into a `CompositeBloq` with signature `signature`.
Expand All @@ -111,67 +168,87 @@ def cirq_optree_to_cbloq(
If `signature` is not None, the signature of the resultant CompositeBloq is `signature`. For
multi-dimensional registers and registers with > 1 bitsize, this function automatically
splits the input soquets into a flat list and joins the output soquets into the correct shape
to ensure compatibility with the flat API expected by Cirq.
to ensure compatibility with the flat API expected by Cirq. When specifying a signature, users
must also specify the `cirq_quregs` argument, which is a mapping of cirq qubits used in the
OP-TREE corresponding to the `signature`. If `signature` has registers with entry
- `Register('x', bitsize=2, shape=(3, 4))` and
- `Register('y', bitsize=1, shape=(10, 20))`
then `cirq_quregs` should have one entry corresponding to each register as follows:
- key='x'; value=`np.array(cirq_qubits_used_in_optree, shape=(3, 4, 2))` and
- key='y'; value=`np.array(cirq_qubits_used_in_optree, shape=(10, 20, 1))`.
If `signature` is None, the resultant composite bloq will have one thru-register named "qubits"
of shape `(n_qubits,)`.
"""
# "qubits" means cirq qubits | "qvars" means bloq Soquets
circuit = cirq.Circuit(optree)
all_qubits = sorted(circuit.all_qubits())
# "qubits" means cirq qubits | "qvars" means bloq Soquets
if signature is None:
assert cirq_quregs is None
all_qubits = sorted(circuit.all_qubits())
signature = Signature([Register('qubits', 1, shape=(len(all_qubits),))])
cirq_quregs = {'qubits': np.array(all_qubits).reshape(len(all_qubits), 1)}

assert signature is not None and cirq_quregs is not None

bb, initial_soqs = BloqBuilder.from_signature(signature, add_registers_allowed=False)

# Magic to make sure signature of the CompositeBloq matches `Signature`.
qvars = {}
qubit_to_qvar = {}
for reg in signature.lefts():
assert reg.name in cirq_quregs
soqs = initial_soqs[reg.name]
if isinstance(soqs, Soquet):
soqs = np.asarray(soqs)[np.newaxis, ...]
if reg.bitsize > 1:
# Need to split all soquets here.
if isinstance(soqs, Soquet):
qvars[reg.name] = bb.split(soqs)
else:
qvars[reg.name] = np.concatenate([bb.split(soq) for soq in soqs.reshape(-1)])
else:
if isinstance(soqs, Soquet):
qvars[reg.name] = [soqs]
else:
qvars[reg.name] = soqs.reshape(-1)

qubit_to_qvar = dict(zip(all_qubits, np.concatenate([*qvars.values()])))
soqs = np.array([bb.split(soq) for soq in soqs.flatten()])
soqs = soqs.reshape(reg.shape + (reg.bitsize,))
assert cirq_quregs[reg.name].shape == soqs.shape
qubit_to_qvar |= zip(cirq_quregs[reg.name].flatten(), soqs.flatten())

allocated_qubits = set()
for op in circuit.all_operations():
if op.gate is None:
raise ValueError(f"Only gate operations are supported, not {op}.")

bloq = CirqGateAsBloq(op.gate)
qvars_for_op = np.array([qubit_to_qvar[qubit] for qubit in op.qubits])
qvars_for_op_out = bb.add(bloq, qubits=qvars_for_op)
qubit_to_qvar |= zip(op.qubits, qvars_for_op_out)
for q in op.qubits:
if q not in qubit_to_qvar:
qubit_to_qvar[q] = bb.add(Allocate(1))
allocated_qubits.add(q)

qvars_in = [qubit_to_qvar[qubit] for qubit in op.qubits]
qvars_out = bb.add_t(bloq, **_split_qvars_for_regs(qvars_in, bloq.signature))
qubit_to_qvar |= zip(
op.qubits, itertools.chain.from_iterable([arr.flatten() for arr in qvars_out])
)

qvar_vals_out = np.array([qubit_to_qvar[qubit] for qubit in all_qubits])
for q in allocated_qubits:
bb.add(Free(1), free=qubit_to_qvar[q])

qvars = np.array([*qubit_to_qvar.values()])
final_soqs = {}
idx = 0
for reg in signature.rights():
name = reg.name
soqs = qvar_vals_out[idx : idx + len(qvars[name])]
idx = idx + len(qvars[name])
assert name in cirq_quregs
soqs = qvars[idx : idx + np.product(cirq_quregs[name].shape)]
idx = idx + np.product(cirq_quregs[name].shape)
if reg.bitsize > 1:
# Need to combine the soquets here.
if len(soqs) == reg.bitsize:
final_soqs[name] = bb.join(soqs)
else:
final_soqs[name] = np.array(
bb.join(subsoqs) for subsoqs in soqs[:: reg.bitsize]
[
bb.join(soqs[st : st + reg.bitsize])
for st in range(0, len(soqs), reg.bitsize)
]
).reshape(reg.shape)
else:
if len(soqs) == 1:
if len(soqs) == 1 and reg.shape == ():
final_soqs[name] = soqs[0]
else:
final_soqs[name] = soqs.reshape(reg.shape)

return bb.finalize(**final_soqs)


Expand Down Expand Up @@ -280,7 +357,7 @@ def decompose_from_cirq_op(bloq: 'Bloq') -> 'CompositeBloq':
):
raise NotImplementedError(f"{bloq} does not support decomposition.")

return cirq_optree_to_cbloq(cirq_op, signature=bloq.signature)
return cirq_optree_to_cbloq(cirq_op, signature=bloq.signature, cirq_quregs=cirq_quregs)


def _cbloq_to_cirq_circuit(
Expand Down Expand Up @@ -353,12 +430,12 @@ def bloq(self) -> Bloq:
return self._bloq

@property
def registers(self) -> LegacyRegisters:
def registers(self) -> cirq_ft.Registers:
"""`cirq_ft.GateWithRegisters` registers."""
return self._legacy_regs

@staticmethod
def _init_legacy_regs(bloq: Bloq) -> Tuple[LegacyRegisters, Mapping[str, Register]]:
def _init_legacy_regs(bloq: Bloq) -> Tuple[cirq_ft.Registers, Mapping[str, Register]]:
"""Initialize legacy registers.
We flatten multidimensional registers and annotate non-thru registers with
Expand All @@ -369,15 +446,15 @@ def _init_legacy_regs(bloq: Bloq) -> Tuple[LegacyRegisters, Mapping[str, Registe
compat_name_map: A mapping from the compatability-shim string names of the legacy
registers back to the original (register, idx) pair.
"""
legacy_regs: List[LegacyRegister] = []
legacy_regs: List[cirq_ft.Register] = []
side_suffixes = {Side.LEFT: '_l', Side.RIGHT: '_r', Side.THRU: ''}
compat_name_map = {}
for reg in bloq.signature:
compat_name = f'{reg.name}{side_suffixes[reg.side]}'
compat_name_map[compat_name] = reg
full_shape = reg.shape + (reg.bitsize,)
legacy_regs.append(LegacyRegister(name=compat_name, shape=full_shape))
return LegacyRegisters(legacy_regs), compat_name_map
legacy_regs.append(cirq_ft.Register(name=compat_name, shape=full_shape))
return cirq_ft.Registers(legacy_regs), compat_name_map

@classmethod
def bloq_on(
Expand Down
63 changes: 63 additions & 0 deletions qualtran/cirq_interop/_cirq_interop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,20 @@
# limitations under the License.
from typing import Dict, Tuple

import attr
import cirq
import cirq_ft
import numpy as np
import pytest
import sympy
from attrs import frozen

import qualtran
from qualtran import Bloq, BloqBuilder, CompositeBloq, Side, Signature, Soquet, SoquetT
from qualtran.bloqs.and_bloq import MultiAnd
from qualtran.bloqs.basic_gates import XGate
from qualtran.bloqs.swap_network import SwapWithZero
from qualtran.bloqs.util_bloqs import Allocate, Free, Join, Split
from qualtran.cirq_interop import (
BloqAsCirqGate,
cirq_optree_to_cbloq,
Expand Down Expand Up @@ -85,6 +88,66 @@ def test_cbloq_to_cirq_circuit():
assert circuit == circuit2


def test_cirq_optree_to_cbloq():
@attr.frozen
class CirqGateWithRegisters(cirq_ft.GateWithRegisters):
reg: cirq_ft.Register

@property
def registers(self) -> cirq_ft.Registers:
return cirq_ft.Registers([self.reg])

reg1 = cirq_ft.Register('x', shape=(3, 4, 2))
reg2 = cirq_ft.Register('y', shape=(12, 2))
anc_reg = cirq_ft.Register('anc', shape=(2, 3))
qubits = cirq.LineQubit.range(24)
anc_qubits = cirq.NamedQubit.range(3, prefix='anc')
circuit = cirq.Circuit(
CirqGateWithRegisters(reg1).on(*qubits),
CirqGateWithRegisters(anc_reg).on(*anc_qubits, *qubits[:3]),
CirqGateWithRegisters(reg2).on(*qubits),
)
# Test-1: When no signature is specified, the method uses a default signature. Ancilla qubits
# are also included in the signature itself, so no allocations / deallocations are needed.
cbloq = cirq_optree_to_cbloq(circuit)
assert cbloq.signature == qualtran.Signature(
[qualtran.Register(name='qubits', bitsize=1, shape=(27,))]
)
bloq_instances = [binst for binst, _, _ in cbloq.iter_bloqnections()]
assert bloq_instances[0].bloq == CirqGateAsBloq(CirqGateWithRegisters(reg1))
assert bloq_instances[0].bloq.signature == qualtran.Signature(
[qualtran.Register(name='x', bitsize=1, shape=(3, 4, 2))]
)
assert bloq_instances[1].bloq == CirqGateAsBloq(CirqGateWithRegisters(anc_reg))
assert bloq_instances[1].bloq.signature == qualtran.Signature(
[qualtran.Register(name='anc', bitsize=1, shape=(2, 3))]
)
assert bloq_instances[2].bloq == CirqGateAsBloq(CirqGateWithRegisters(reg2))
assert bloq_instances[2].bloq.signature == qualtran.Signature(
[qualtran.Register(name='y', bitsize=1, shape=(12, 2))]
)
# Test-2: If you provide an explicit signature, you must also provide a mapping of cirq qubits
# matching the signature. The additional ancilla allocations are automatically handled.
new_signature = qualtran.Signature(
[
qualtran.Register('xx', bitsize=3, shape=(3, 2)),
qualtran.Register('yy', bitsize=1, shape=(2, 3)),
]
)
cirq_quregs = {
'xx': np.asarray(qubits[:18]).reshape((3, 2, 3)),
'yy': np.asarray(qubits[18:]).reshape((2, 3, 1)),
}
cbloq = cirq_optree_to_cbloq(circuit, signature=new_signature, cirq_quregs=cirq_quregs)
assert cbloq.signature == new_signature
# Splits, joins, Alloc, Free are automatically inserted.
bloqs_list = [binst.bloq for binst in cbloq.bloq_instances]
assert bloqs_list.count(Split(3)) == 6
assert bloqs_list.count(Join(3)) == 6
assert bloqs_list.count(Allocate(1)) == 3
assert bloqs_list.count(Free(1)) == 3


@frozen
class SwapTwoBitsTest(Bloq):
@property
Expand Down
Loading

0 comments on commit cc232f6

Please sign in to comment.