Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MCMT Pauli Bloqs from Cirq-FT #406

Merged
merged 4 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions qualtran/bloqs/multi_control_multi_target_pauli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import cirq
import numpy as np
from attrs import field, frozen
from cirq._compat import cached_property
from numpy.typing import NDArray

from qualtran import GateWithRegisters, Signature
from qualtran.bloqs.and_bloq import And, MultiAnd
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity


@frozen
class MultiTargetCNOT(GateWithRegisters):
"""Implements single control, multi-target CNOT_{n} gate in 2*log(n) + 1 CNOT depth.

Implements CNOT_{n} = |0><0| I + |1><1| X^{n} using a circuit of depth 2*log(n) + 1
containing only CNOT gates. See Appendix B.1 of https://arxiv.org/abs/1812.00954 for
reference.
"""

bitsize: int

@cached_property
def signature(self) -> Signature:
return Signature.build(control=1, targets=self.bitsize)

def decompose_from_registers(
self,
*,
context: cirq.DecompositionContext,
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
):
control, targets = quregs['control'], quregs['targets']

def cnots_for_depth_i(i: int, q: NDArray[cirq.Qid]) -> cirq.OP_TREE:
for c, t in zip(q[: 2**i], q[2**i : min(len(q), 2 ** (i + 1))]):
yield cirq.CNOT(c, t)

depth = len(targets).bit_length()
for i in range(depth):
yield cirq.Moment(cnots_for_depth_i(depth - i - 1, targets))
yield cirq.CNOT(*control, targets[0])
for i in range(depth):
yield cirq.Moment(cnots_for_depth_i(i, targets))

def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
return cirq.CircuitDiagramInfo(wire_symbols=["@"] + ["X"] * self.bitsize)


@frozen
class MultiControlPauli(GateWithRegisters):
"""Implements multi-control, single-target C^{n}P gate.

Implements $C^{n}P = (1 - |1^{n}><1^{n}|) I + |1^{n}><1^{n}| P^{n}$ using $n-1$
clean ancillas using a multi-controlled `AND` gate.

References:
[Constructing Large Controlled Nots]
(https://algassert.com/circuits/2015/06/05/Constructing-Large-Controlled-Nots.html)
"""

cvs: Tuple[int, ...] = field(converter=lambda v: (v,) if isinstance(v, int) else tuple(v))
target_gate: cirq.Pauli = cirq.X

@cached_property
def signature(self) -> Signature:
return Signature.build(controls=len(self.cvs), target=1)

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray['cirq.Qid']
) -> cirq.OP_TREE:
controls, target = quregs['controls'], quregs['target']
qm = context.qubit_manager
and_ancilla, and_target = np.array(qm.qalloc(len(self.cvs) - 2)), qm.qalloc(1)
ctrl, junk = controls[:, np.newaxis], and_ancilla[:, np.newaxis]
if len(self.cvs) == 2:
and_op = And(*self.cvs).on_registers(ctrl=ctrl, target=and_target)
else:
and_op = MultiAnd(self.cvs).on_registers(ctrl=ctrl, junk=junk, target=and_target)
yield and_op
yield self.target_gate.on(*target).controlled_by(*and_target)
yield and_op**-1
qm.qfree([*and_ancilla, *and_target])

def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
wire_symbols = ["@" if b else "@(0)" for b in self.cvs]
wire_symbols += [str(self.target_gate)]
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

def _t_complexity_(self) -> TComplexity:
and_gate = And(*self.cvs) if len(self.cvs) == 2 else MultiAnd(self.cvs)
and_cost = t_complexity(and_gate)
controlled_pauli_cost = t_complexity(self.target_gate.controlled(1))
and_inv_cost = t_complexity(and_gate**-1)
return and_cost + controlled_pauli_cost + and_inv_cost

def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs') -> np.ndarray:
return cirq.apply_unitary(self.target_gate.controlled(control_values=self.cvs), args)

def _has_unitary_(self) -> bool:
return True
43 changes: 43 additions & 0 deletions qualtran/bloqs/multi_control_multi_target_pauli_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cirq
import numpy as np
import pytest

from qualtran.bloqs.multi_control_multi_target_pauli import MultiControlPauli, MultiTargetCNOT
from qualtran.cirq_interop.testing import assert_decompose_is_consistent_with_t_complexity
from qualtran.testing import assert_valid_bloq_decomposition


@pytest.mark.parametrize("num_targets", [3, 4, 6, 8, 10])
def test_multi_target_cnot(num_targets):
qubits = cirq.LineQubit.range(num_targets + 1)
naive_circuit = cirq.Circuit(cirq.CNOT(qubits[0], q) for q in qubits[1:])
op = MultiTargetCNOT(num_targets).on(*qubits)
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
cirq.Circuit(op), naive_circuit, atol=1e-6
)
optimal_circuit = cirq.Circuit(cirq.decompose_once(op))
assert len(optimal_circuit) == 2 * np.ceil(np.log2(num_targets)) + 1
assert_valid_bloq_decomposition(op.gate)


@pytest.mark.parametrize("num_controls", [*range(7, 17)])
@pytest.mark.parametrize("pauli", [cirq.X, cirq.Y, cirq.Z])
@pytest.mark.parametrize('cv', [0, 1])
def test_t_complexity_mcp(num_controls: int, pauli: cirq.Pauli, cv: int):
gate = MultiControlPauli([cv] * num_controls, target_gate=pauli)
assert_valid_bloq_decomposition(gate)
assert_decompose_is_consistent_with_t_complexity(gate)
Loading