Skip to content

Commit

Permalink
Use Cirq-FT's multi-dimensional registers directly in BloqAsCirqGate (#…
Browse files Browse the repository at this point in the history
…353)

* Use Cirq-FT's multi-dimensional registers directly in BloqAsCirqGate

* Fix failing tests

---------

Co-authored-by: Fionn Malone <[email protected]>
  • Loading branch information
tanujkhattar and fdmalone authored Aug 17, 2023
1 parent 005afb8 commit 72248b0
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 49 deletions.
3 changes: 1 addition & 2 deletions qualtran/_infra/composite_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,8 +907,7 @@ def add(self, bloq: Bloq, **in_soqs: SoquetInT) -> Union[None, SoquetT, Tuple[So
unpacking. In this final case, the ordering is according to `bloq.signature`
and irrespective of the order of `**in_soqs`.
"""
binst = BloqInstance(bloq, i=self._new_binst_i())
outs = tuple(soq for _, soq in self._add_binst(binst, in_soqs=in_soqs))
outs = self.add_t(bloq, **in_soqs)
if len(outs) == 0:
return None
if len(outs) == 1:
Expand Down
9 changes: 4 additions & 5 deletions qualtran/bloqs/swap_network_cirq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,11 @@ def test_swap_with_zero_gate(selection_bitsize, target_bitsize, n_target_registe
# Allocate selection and target qubits.
all_qubits = cirq.LineQubit.range(cirq.num_qubits(gate))
selection = all_qubits[:selection_bitsize]
targets = {
f'targets_{i}': all_qubits[st : st + target_bitsize]
for i, st in enumerate(range(selection_bitsize, len(all_qubits), target_bitsize))
}
targets = np.asarray(all_qubits[selection_bitsize:]).reshape(
(n_target_registers, target_bitsize)
)
# Create a circuit.
circuit = cirq.Circuit(gate.on_registers(selection=selection, **targets))
circuit = cirq.Circuit(gate.on_registers(selection=selection, targets=targets))

# Load data[i] in i'th target register; where each register is of size target_bitsize
data = [random.randint(0, 2**target_bitsize - 1) for _ in range(n_target_registers)]
Expand Down
63 changes: 22 additions & 41 deletions qualtran/cirq_interop/_cirq_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ def _update_assign_from_cirq_quregs(
arr = np.asarray(arr)
full_shape = reg.shape + (reg.bitsize,)
if arr.shape != full_shape:
raise ValueError(f"Incorrect shape {arr.shape} received for {binst}.{reg.name}")
raise ValueError(
f"Incorrect shape {arr.shape} received for {binst}.{reg.name}. Expected {full_shape}."
)

for idx in reg.all_idxs():
soq = Soquet(binst, reg, idx=idx)
Expand Down Expand Up @@ -356,9 +358,7 @@ def registers(self) -> LegacyRegisters:
return self._legacy_regs

@staticmethod
def _init_legacy_regs(
bloq: Bloq,
) -> Tuple[LegacyRegisters, Mapping[str, Tuple[Register, Tuple[int, ...]]]]:
def _init_legacy_regs(bloq: Bloq) -> Tuple[LegacyRegisters, Mapping[str, Register]]:
"""Initialize legacy registers.
We flatten multidimensional registers and annotate non-thru registers with
Expand All @@ -373,18 +373,10 @@ def _init_legacy_regs(
side_suffixes = {Side.LEFT: '_l', Side.RIGHT: '_r', Side.THRU: ''}
compat_name_map = {}
for reg in bloq.signature:
if not reg.shape:
compat_name = f'{reg.name}{side_suffixes[reg.side]}'
compat_name_map[compat_name] = (reg, ())
legacy_regs.append(LegacyRegister(name=compat_name, shape=reg.bitsize))
continue

for idx in reg.all_idxs():
idx_str = '_'.join(str(i) for i in idx)
compat_name = f'{reg.name}{side_suffixes[reg.side]}_{idx_str}'
compat_name_map[compat_name] = (reg, idx)
legacy_regs.append(LegacyRegister(name=compat_name, shape=reg.bitsize))

compat_name = f'{reg.name}{side_suffixes[reg.side]}'
compat_name_map[compat_name] = reg
full_shape = reg.shape + (reg.bitsize,)
legacy_regs.append(LegacyRegister(name=compat_name, shape=full_shape))
return LegacyRegisters(legacy_regs), compat_name_map

@classmethod
Expand All @@ -405,27 +397,25 @@ def bloq_on(
op: A cirq operation whose gate is the `BloqAsCirqGate`-wrapped version of `bloq`.
cirq_quregs: The output cirq qubit registers.
"""
flat_qubits: List[cirq.Qid] = []
bloq_quregs: Dict[str, 'CirqQuregT'] = {}
out_quregs: Dict[str, 'CirqQuregT'] = {}
for reg in bloq.signature:
if reg.side is Side.THRU:
for i, q in enumerate(cirq_quregs[reg.name].reshape(-1)):
flat_qubits.append(q)
bloq_quregs[reg.name] = cirq_quregs[reg.name]
out_quregs[reg.name] = cirq_quregs[reg.name]
elif reg.side is Side.LEFT:
for i, q in enumerate(cirq_quregs[reg.name].reshape(-1)):
flat_qubits.append(q)
bloq_quregs[f'{reg.name}_l'] = cirq_quregs[reg.name]
qubit_manager.qfree(cirq_quregs[reg.name].reshape(-1))
del cirq_quregs[reg.name]
elif reg.side is Side.RIGHT:
new_qubits = qubit_manager.qalloc(reg.total_bits())
flat_qubits.extend(new_qubits)
out_quregs[reg.name] = np.array(new_qubits).reshape(reg.shape + (reg.bitsize,))

return BloqAsCirqGate(bloq=bloq).on(*flat_qubits), out_quregs
full_shape = reg.shape + (reg.bitsize,)
out_quregs[reg.name] = np.array(new_qubits).reshape(full_shape)
bloq_quregs[f'{reg.name}_r'] = out_quregs[reg.name]
return BloqAsCirqGate(bloq=bloq).on_registers(**bloq_quregs), out_quregs

def decompose_from_registers(
self, context: cirq.DecompositionContext, **qubit_regs: Sequence[cirq.Qid]
self, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
) -> cirq.OP_TREE:
"""Implementation of the GatesWithRegisters decompose method.
Expand All @@ -434,28 +424,18 @@ def decompose_from_registers(
Args:
context: `cirq.DecompositionContext` stores options for decomposing gates (eg:
cirq.QubitManager).
**qubit_regs: Sequences of cirq qubits as expected for the legacy register shims
**quregs: Sequences of cirq qubits as expected for the legacy register shims
of the bloq's registers.
Returns:
A cirq circuit containing the cirq-exported version of the bloq decomposition.
"""
cbloq = self._bloq.decompose_bloq()

# Initialize shapely qubit registers to pass to bloqs infrastructure
cirq_quregs: Dict[str, CirqQuregT] = {}
for reg in self._bloq.signature:
if reg.shape:
shape = reg.shape + (reg.bitsize,)
cirq_quregs[reg.name] = np.empty(shape, dtype=object)

# Shapefy the provided cirq qubits
for compat_name, qubits in qubit_regs.items():
reg, idx = self._compat_name_map[compat_name]
if idx == ():
cirq_quregs[reg.name] = np.asarray(qubits)
else:
cirq_quregs[reg.name][idx] = np.asarray(qubits)
for compat_name, qubits in quregs.items():
reg = self._compat_name_map[compat_name]
cirq_quregs[reg.name] = np.asarray(qubits)

circuit, _ = cbloq.to_cirq_circuit(qubit_manager=context.qubit_manager, **cirq_quregs)
return circuit
Expand All @@ -481,7 +461,8 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.Circ
symbs = reg_to_wires(reg)
assert len(symbs) == reg.total_bits()
wire_symbols.extend(symbs)

if self._reg_to_wires is None:
wire_symbols[0] = self._bloq.pretty_name()
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

def __eq__(self, other):
Expand Down
75 changes: 74 additions & 1 deletion qualtran/cirq_interop/_cirq_interop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from qualtran import Bloq, BloqBuilder, CompositeBloq, Side, Signature, Soquet, SoquetT
from qualtran.bloqs.and_bloq import MultiAnd
from qualtran.bloqs.basic_gates import XGate
from qualtran.bloqs.swap_network import SwapWithZero
from qualtran.cirq_interop import (
BloqAsCirqGate,
cirq_optree_to_cbloq,
CirqGateAsBloq,
CirqQuregT,
Expand Down Expand Up @@ -178,7 +180,7 @@ def test_bloq_as_cirq_gate_left_register():
bb.free(q)
cbloq = bb.finalize()
circuit, _ = cbloq.to_cirq_circuit()
cirq.testing.assert_has_diagram(circuit, """_c(0): ───alloc───X───free───""")
cirq.testing.assert_has_diagram(circuit, """_c(0): ───Allocate───X───Free───""")


@frozen
Expand Down Expand Up @@ -223,5 +225,76 @@ def test_bloq_decompose_from_cirq_op():
TestCNOTSymbolic().decompose_bloq()


def test_bloq_as_cirq_gate_multi_dimensional_signature():
bloq = SwapWithZero(2, 3, 4)
cirq_quregs = bloq.signature.get_cirq_quregs()
op = BloqAsCirqGate(bloq).on_registers(**cirq_quregs)
cirq.testing.assert_has_diagram(
cirq.Circuit(op),
'''
selection0: ──────SwapWithZero───
selection1: ──────selection──────
targets[0, 0]: ───targets────────
targets[0, 1]: ───targets────────
targets[0, 2]: ───targets────────
targets[1, 0]: ───targets────────
targets[1, 1]: ───targets────────
targets[1, 2]: ───targets────────
targets[2, 0]: ───targets────────
targets[2, 1]: ───targets────────
targets[2, 2]: ───targets────────
targets[3, 0]: ───targets────────
targets[3, 1]: ───targets────────
targets[3, 2]: ───targets────────
''',
)
cbloq = bloq.decompose_bloq()
cirq.testing.assert_has_diagram(
cbloq.to_cirq_circuit(**cirq_quregs)[0],
'''
selection0: ──────────────────────────────@(approx)───
selection1: ──────@(approx)───@(approx)───┼───────────
│ │ │
targets[0, 0]: ───×(x)────────┼───────────×(x)────────
│ │ │
targets[0, 1]: ───×(x)────────┼───────────×(x)────────
│ │ │
targets[0, 2]: ───×(x)────────┼───────────×(x)────────
│ │ │
targets[1, 0]: ───×(y)────────┼───────────┼───────────
│ │ │
targets[1, 1]: ───×(y)────────┼───────────┼───────────
│ │ │
targets[1, 2]: ───×(y)────────┼───────────┼───────────
│ │
targets[2, 0]: ───────────────×(x)────────×(y)────────
│ │
targets[2, 1]: ───────────────×(x)────────×(y)────────
│ │
targets[2, 2]: ───────────────×(x)────────×(y)────────
targets[3, 0]: ───────────────×(y)────────────────────
targets[3, 1]: ───────────────×(y)────────────────────
targets[3, 2]: ───────────────×(y)────────────────────
''',
)


def test_notebook():
execute_notebook('cirq_interop')

0 comments on commit 72248b0

Please sign in to comment.