From 84324f9ffe1ca79eb04a8360f62c356e20908caf Mon Sep 17 00:00:00 2001 From: Anurudh Peduri <7265746+anurudhp@users.noreply.github.com> Date: Tue, 20 Aug 2024 17:26:42 -0700 Subject: [PATCH] Implement `get_ctrl_system` for swap bloqs (#1320) * `get_ctrl_system` for swap bloqs * rename `cbloq` to `cswap` --- qualtran/bloqs/basic_gates/swap.py | 46 +++++++++++++++++++++++-- qualtran/bloqs/basic_gates/swap_test.py | 10 ++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/qualtran/bloqs/basic_gates/swap.py b/qualtran/bloqs/basic_gates/swap.py index 67ac4040f..fece5642d 100644 --- a/qualtran/bloqs/basic_gates/swap.py +++ b/qualtran/bloqs/basic_gates/swap.py @@ -13,7 +13,18 @@ # limitations under the License. from functools import cached_property -from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union +from typing import ( + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, + Union, +) import cirq import numpy as np @@ -27,6 +38,7 @@ BloqBuilder, BloqDocSpec, ConnectionT, + CtrlSpec, DecomposeTypeError, GateWithRegisters, Register, @@ -46,7 +58,7 @@ if TYPE_CHECKING: import quimb.tensor as qtn - from qualtran import CompositeBloq + from qualtran import AddControlledT, CompositeBloq from qualtran.resource_counting import BloqCountT, SympySymbolAllocator from qualtran.simulation.classical_sim import ClassicalValT @@ -103,6 +115,21 @@ def on_classical_vals( def adjoint(self) -> 'Bloq': return self + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']: + if ctrl_spec != CtrlSpec(): + return super().get_ctrl_system(ctrl_spec=ctrl_spec) + + cswap = TwoBitCSwap() + + def adder( + bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: Dict[str, 'SoquetT'] + ) -> Tuple[Iterable['SoquetT'], Iterable['SoquetT']]: + (ctrl,) = ctrl_soqs + ctrl, x, y = bb.add(cswap, ctrl=ctrl, x=in_soqs['x'], y=in_soqs['y']) + return [ctrl], [x, y] + + return cswap, adder + @frozen class TwoBitCSwap(Bloq): @@ -243,6 +270,21 @@ def wire_symbol(self, reg: Optional['Register'], idx: Tuple[int, ...] = ()) -> ' def adjoint(self) -> 'Bloq': return self + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']: + if ctrl_spec != CtrlSpec(): + return super().get_ctrl_system(ctrl_spec=ctrl_spec) + + cswap = CSwap(self.bitsize) + + def adder( + bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: Dict[str, 'SoquetT'] + ) -> Tuple[Iterable['SoquetT'], Iterable['SoquetT']]: + (ctrl,) = ctrl_soqs + ctrl, x, y = bb.add(cswap, ctrl=ctrl, x=in_soqs['x'], y=in_soqs['y']) + return [ctrl], [x, y] + + return cswap, adder + @bloq_example(generalizer=ignore_split_join) def _swap_small() -> Swap: diff --git a/qualtran/bloqs/basic_gates/swap_test.py b/qualtran/bloqs/basic_gates/swap_test.py index 7a46ce621..60a61994b 100644 --- a/qualtran/bloqs/basic_gates/swap_test.py +++ b/qualtran/bloqs/basic_gates/swap_test.py @@ -36,6 +36,7 @@ _cswap_symb, _swap_matrix, _swap_small, + Swap, ) from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity from qualtran.resource_counting.generalizers import ignore_split_join @@ -83,6 +84,10 @@ def test_two_bit_swap_call_classically(): assert y == 0 +def test_two_bit_swap_controlled(): + assert TwoBitSwap().controlled() == TwoBitCSwap() + + def _set_ctrl_two_bit_swap(ctrl_bit): states = [ZeroState(), OneState()] effs = [ZeroEffect(), OneEffect()] @@ -212,6 +217,11 @@ def test_cswap_symbolic(): cswap.decompose_bloq() +def test_swap_controlled(): + bitsize = 4 + assert Swap(bitsize).controlled() == CSwap(bitsize) + + def test_swap_small(bloq_autotester): bloq_autotester(_swap_small)