Skip to content

Commit

Permalink
Add MCMT Pauli Bloqs from Cirq-FT (#406)
Browse files Browse the repository at this point in the history
* Add MCMT Pauli Bloqs from Cirq-FT

* Fix pylint

---------

Co-authored-by: Fionn Malone <[email protected]>
  • Loading branch information
tanujkhattar and fdmalone authored Oct 16, 2023
1 parent 810cf48 commit 980b8e7
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 0 deletions.
117 changes: 117 additions & 0 deletions qualtran/bloqs/multi_control_multi_target_pauli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import cirq
import numpy as np
from attrs import field, frozen
from cirq._compat import cached_property
from numpy.typing import NDArray

from qualtran import GateWithRegisters, Signature
from qualtran.bloqs.and_bloq import And, MultiAnd
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity


@frozen
class MultiTargetCNOT(GateWithRegisters):
"""Implements single control, multi-target CNOT_{n} gate in 2*log(n) + 1 CNOT depth.
Implements CNOT_{n} = |0><0| I + |1><1| X^{n} using a circuit of depth 2*log(n) + 1
containing only CNOT gates. See Appendix B.1 of https://arxiv.org/abs/1812.00954 for
reference.
"""

bitsize: int

@cached_property
def signature(self) -> Signature:
return Signature.build(control=1, targets=self.bitsize)

def decompose_from_registers(
self,
*,
context: cirq.DecompositionContext,
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
):
control, targets = quregs['control'], quregs['targets']

def cnots_for_depth_i(i: int, q: NDArray[cirq.Qid]) -> cirq.OP_TREE:
for c, t in zip(q[: 2**i], q[2**i : min(len(q), 2 ** (i + 1))]):
yield cirq.CNOT(c, t)

depth = len(targets).bit_length()
for i in range(depth):
yield cirq.Moment(cnots_for_depth_i(depth - i - 1, targets))
yield cirq.CNOT(*control, targets[0])
for i in range(depth):
yield cirq.Moment(cnots_for_depth_i(i, targets))

def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
return cirq.CircuitDiagramInfo(wire_symbols=["@"] + ["X"] * self.bitsize)


@frozen
class MultiControlPauli(GateWithRegisters):
"""Implements multi-control, single-target C^{n}P gate.
Implements $C^{n}P = (1 - |1^{n}><1^{n}|) I + |1^{n}><1^{n}| P^{n}$ using $n-1$
clean ancillas using a multi-controlled `AND` gate.
References:
[Constructing Large Controlled Nots]
(https://algassert.com/circuits/2015/06/05/Constructing-Large-Controlled-Nots.html)
"""

cvs: Tuple[int, ...] = field(converter=lambda v: (v,) if isinstance(v, int) else tuple(v))
target_gate: cirq.Pauli = cirq.X

@cached_property
def signature(self) -> Signature:
return Signature.build(controls=len(self.cvs), target=1)

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray['cirq.Qid']
) -> cirq.OP_TREE:
controls, target = quregs['controls'], quregs['target']
qm = context.qubit_manager
and_ancilla, and_target = np.array(qm.qalloc(len(self.cvs) - 2)), qm.qalloc(1)
ctrl, junk = controls[:, np.newaxis], and_ancilla[:, np.newaxis]
if len(self.cvs) == 2:
and_op = And(*self.cvs).on_registers(ctrl=ctrl, target=and_target)
else:
and_op = MultiAnd(self.cvs).on_registers(ctrl=ctrl, junk=junk, target=and_target)
yield and_op
yield self.target_gate.on(*target).controlled_by(*and_target)
yield and_op**-1
qm.qfree([*and_ancilla, *and_target])

def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
wire_symbols = ["@" if b else "@(0)" for b in self.cvs]
wire_symbols += [str(self.target_gate)]
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

def _t_complexity_(self) -> TComplexity:
and_gate = And(*self.cvs) if len(self.cvs) == 2 else MultiAnd(self.cvs)
and_cost = t_complexity(and_gate)
controlled_pauli_cost = t_complexity(self.target_gate.controlled(1))
and_inv_cost = t_complexity(and_gate**-1)
return and_cost + controlled_pauli_cost + and_inv_cost

def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs') -> np.ndarray:
return cirq.apply_unitary(self.target_gate.controlled(control_values=self.cvs), args)

def _has_unitary_(self) -> bool:
return True
43 changes: 43 additions & 0 deletions qualtran/bloqs/multi_control_multi_target_pauli_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cirq
import numpy as np
import pytest

from qualtran.bloqs.multi_control_multi_target_pauli import MultiControlPauli, MultiTargetCNOT
from qualtran.cirq_interop.testing import assert_decompose_is_consistent_with_t_complexity
from qualtran.testing import assert_valid_bloq_decomposition


@pytest.mark.parametrize("num_targets", [3, 4, 6, 8, 10])
def test_multi_target_cnot(num_targets):
qubits = cirq.LineQubit.range(num_targets + 1)
naive_circuit = cirq.Circuit(cirq.CNOT(qubits[0], q) for q in qubits[1:])
op = MultiTargetCNOT(num_targets).on(*qubits)
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
cirq.Circuit(op), naive_circuit, atol=1e-6
)
optimal_circuit = cirq.Circuit(cirq.decompose_once(op))
assert len(optimal_circuit) == 2 * np.ceil(np.log2(num_targets)) + 1
assert_valid_bloq_decomposition(op.gate)


@pytest.mark.parametrize("num_controls", [*range(7, 17)])
@pytest.mark.parametrize("pauli", [cirq.X, cirq.Y, cirq.Z])
@pytest.mark.parametrize('cv', [0, 1])
def test_t_complexity_mcp(num_controls: int, pauli: cirq.Pauli, cv: int):
gate = MultiControlPauli([cv] * num_controls, target_gate=pauli)
assert_valid_bloq_decomposition(gate)
assert_decompose_is_consistent_with_t_complexity(gate)

0 comments on commit 980b8e7

Please sign in to comment.