Skip to content

Commit

Permalink
Implement get_ctrl_system for swap bloqs (#1320)
Browse files Browse the repository at this point in the history
* `get_ctrl_system` for swap bloqs

* rename `cbloq` to `cswap`
  • Loading branch information
anurudhp authored Aug 21, 2024
1 parent f614ce1 commit 84324f9
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
46 changes: 44 additions & 2 deletions qualtran/bloqs/basic_gates/swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,18 @@
# limitations under the License.

from functools import cached_property
from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union
from typing import (
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
)

import cirq
import numpy as np
Expand All @@ -27,6 +38,7 @@
BloqBuilder,
BloqDocSpec,
ConnectionT,
CtrlSpec,
DecomposeTypeError,
GateWithRegisters,
Register,
Expand All @@ -46,7 +58,7 @@
if TYPE_CHECKING:
import quimb.tensor as qtn

from qualtran import CompositeBloq
from qualtran import AddControlledT, CompositeBloq
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT

Expand Down Expand Up @@ -103,6 +115,21 @@ def on_classical_vals(
def adjoint(self) -> 'Bloq':
return self

def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']:
if ctrl_spec != CtrlSpec():
return super().get_ctrl_system(ctrl_spec=ctrl_spec)

cswap = TwoBitCSwap()

def adder(
bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: Dict[str, 'SoquetT']
) -> Tuple[Iterable['SoquetT'], Iterable['SoquetT']]:
(ctrl,) = ctrl_soqs
ctrl, x, y = bb.add(cswap, ctrl=ctrl, x=in_soqs['x'], y=in_soqs['y'])
return [ctrl], [x, y]

return cswap, adder


@frozen
class TwoBitCSwap(Bloq):
Expand Down Expand Up @@ -243,6 +270,21 @@ def wire_symbol(self, reg: Optional['Register'], idx: Tuple[int, ...] = ()) -> '
def adjoint(self) -> 'Bloq':
return self

def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']:
if ctrl_spec != CtrlSpec():
return super().get_ctrl_system(ctrl_spec=ctrl_spec)

cswap = CSwap(self.bitsize)

def adder(
bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: Dict[str, 'SoquetT']
) -> Tuple[Iterable['SoquetT'], Iterable['SoquetT']]:
(ctrl,) = ctrl_soqs
ctrl, x, y = bb.add(cswap, ctrl=ctrl, x=in_soqs['x'], y=in_soqs['y'])
return [ctrl], [x, y]

return cswap, adder


@bloq_example(generalizer=ignore_split_join)
def _swap_small() -> Swap:
Expand Down
10 changes: 10 additions & 0 deletions qualtran/bloqs/basic_gates/swap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
_cswap_symb,
_swap_matrix,
_swap_small,
Swap,
)
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity
from qualtran.resource_counting.generalizers import ignore_split_join
Expand Down Expand Up @@ -83,6 +84,10 @@ def test_two_bit_swap_call_classically():
assert y == 0


def test_two_bit_swap_controlled():
assert TwoBitSwap().controlled() == TwoBitCSwap()


def _set_ctrl_two_bit_swap(ctrl_bit):
states = [ZeroState(), OneState()]
effs = [ZeroEffect(), OneEffect()]
Expand Down Expand Up @@ -212,6 +217,11 @@ def test_cswap_symbolic():
cswap.decompose_bloq()


def test_swap_controlled():
bitsize = 4
assert Swap(bitsize).controlled() == CSwap(bitsize)


def test_swap_small(bloq_autotester):
bloq_autotester(_swap_small)

Expand Down

0 comments on commit 84324f9

Please sign in to comment.