Skip to content

Commit

Permalink
Extend functionality for CirqGateAsBloq.tensor_contract and add more …
Browse files Browse the repository at this point in the history
…tests
  • Loading branch information
tanujkhattar committed Sep 30, 2023
1 parent 711acaf commit 2cf9930
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 8 deletions.
5 changes: 5 additions & 0 deletions qualtran/cirq_interop/_bloq_to_cirq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ def test_bloq_as_cirq_gate_for_mod_exp():
# Use Cirq's infrastructure to construct an operation and corresponding decomposition.
quregs = cirq_ft.infra.get_named_qubits(gate.signature)
op = gate.on_registers(**quregs)
# cirq.decompose_once(op) delegates to underlying Bloq's decomposition specified in
# `bloq.decompose_bloq()` and wraps resulting composite bloq in a Cirq op-tree. Note
# how `BloqAsCirqGate.decompose_with_registers()` automatically takes care of mapping
# newly allocated RIGHT registers in the decomposition to the one's specified by the user
# when constructing the original operation (in this case, register `x`).
circuit = cirq.Circuit(op, cirq.decompose_once(op))
assert cirq_ft.t_complexity(circuit) == 2 * mod_exp.t_complexity()
cirq.testing.assert_has_diagram(
Expand Down
42 changes: 36 additions & 6 deletions qualtran/cirq_interop/_cirq_to_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Cirq gates/circuits to Qualtran Bloqs conversion."""
import itertools
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

Expand Down Expand Up @@ -84,13 +85,42 @@ def add_my_tensors(
incoming: Dict[str, 'SoquetT'],
outgoing: Dict[str, 'SoquetT'],
):
new_shape = [
*itertools.chain.from_iterable(
(2**reg.bitsize,) * int(np.prod(reg.shape))
for reg in [*self.signature.rights(), *self.signature.lefts()]
if not cirq.has_unitary(self.gate):
raise NotImplementedError(
f"CirqGateAsBloq.add_my_tensors is currently supported only for unitary gates. "
f"Found {self.gate}."
)
]
unitary = cirq.unitary(self.gate).reshape(new_shape)
unitary_shape = []
reg_to_idx = defaultdict(list)
for reg in self.cirq_registers:
start = len(unitary_shape)
for i in range(int(np.prod(reg.shape))):
reg_to_idx[reg.name].append(start + i)
unitary_shape.append(2**reg.bitsize)

unitary_shape = (*unitary_shape, *unitary_shape)
unitary = cirq.unitary(self.gate).reshape(unitary_shape)
idx: List[Union[int, slice]] = [slice(x) for x in unitary_shape]
n = len(unitary_shape) // 2
for reg in self.signature:
if reg.side == Side.LEFT:
for i in reg_to_idx[reg.name]:
# LEFT register ends, extract right subspace that's equivalent to 0.
idx[i] = 0
if reg.side == Side.RIGHT:
for i in reg_to_idx[reg.name]:
# Right register begins, extract the left subspace that's equivalent to 0.
idx[i + n] = 0
unitary = unitary[tuple(idx)]
new_shape = tuple(
[
*itertools.chain.from_iterable(
(2**reg.bitsize,) * int(np.prod(reg.shape))
for reg in [*self.signature.rights(), *self.signature.lefts()]
)
]
)
assert unitary.shape == new_shape
incoming_list = [
*itertools.chain.from_iterable(
[np.array(incoming[reg.name]).flatten() for reg in self.signature.lefts()]
Expand Down
18 changes: 16 additions & 2 deletions qualtran/cirq_interop/_cirq_to_bloq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from attrs import frozen

import qualtran
from qualtran import Bloq, CompositeBloq, Side, Signature
from qualtran import Bloq, BloqBuilder, CompositeBloq, Side, Signature
from qualtran.bloqs.basic_gates import OneState
from qualtran.bloqs.util_bloqs import Allocate, Free, Join, Split
from qualtran.cirq_interop import (
cirq_optree_to_cbloq,
Expand Down Expand Up @@ -57,7 +58,7 @@ def signature(self) -> Signature:
return Signature.build(control=c, target=t)


def test_cirq_gate():
def test_cirq_gate_as_bloq_for_trivial_gates():
x = CirqGateAsBloq(cirq.X)
rx = CirqGateAsBloq(cirq.Rx(rads=0.123 * np.pi))
toffoli = CirqGateAsBloq(cirq.TOFFOLI)
Expand All @@ -80,6 +81,19 @@ def test_cirq_gate():
assert toffoli.short_name() == 'cirq.TOFFOLI'


def test_cirq_gate_as_bloq_tensor_contract_for_and_gate():
and_gate = cirq_ft.And()
bb = BloqBuilder()
ctrl = [bb.add(OneState()) for _ in range(2)]
ctrl, target = bb.add(CirqGateAsBloq(and_gate), ctrl=ctrl)
cbloq = bb.finalize(ctrl=ctrl, target=target)
state_vector = cbloq.tensor_contract()
assert np.isclose(state_vector[7], 1)

with pytest.raises(NotImplementedError, match="supported only for unitary gates"):
_ = CirqGateAsBloq(cirq_ft.And(adjoint=True)).as_composite_bloq().tensor_contract()


def test_bloq_decompose_from_cirq_op():
tb = TestCNOT()
assert len(tb.signature) == 2
Expand Down

0 comments on commit 2cf9930

Please sign in to comment.