Skip to content

Commit

Permalink
Flatten nested controls in ControlledViaAnd (#1373)
Browse files Browse the repository at this point in the history
* flatten nested controls in `ControlledViaAnd`

* make `_get_ctrl_spec` a class method

* cleanup comment

* add tests

* add example to notebook

* `self` -> `GateWithRegisters`

---------

Co-authored-by: Matthew Harrigan <[email protected]>
  • Loading branch information
anurudhp and mpharrigan authored Sep 30, 2024
1 parent f95c868 commit 70dd9f5
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 9 deletions.
5 changes: 3 additions & 2 deletions qualtran/_infra/gate_with_registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,9 @@ def __pow__(self, power: int) -> 'GateWithRegisters':
return Power(bloq, abs(power))
raise NotImplementedError(f"{self} does not implemented __pow__ for {power=}.")

@classmethod
def _get_ctrl_spec(
self,
cls,
num_controls: Union[Optional[int], 'CtrlSpec'] = None,
control_values=None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
Expand Down Expand Up @@ -498,7 +499,7 @@ def controlled(
Returns:
A controlled version of the bloq.
"""
ctrl_spec = self._get_ctrl_spec(
ctrl_spec = GateWithRegisters._get_ctrl_spec(
num_controls, control_values, control_qid_shape, ctrl_spec=ctrl_spec
)
controlled_bloq, _ = self.get_ctrl_system(ctrl_spec=ctrl_spec)
Expand Down
30 changes: 30 additions & 0 deletions qualtran/bloqs/mcmt/controlled_via_and.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,36 @@
"show_call_graph(controlled_via_and_ints_g)\n",
"show_counts_sigma(controlled_via_and_ints_sigma)"
]
},
{
"cell_type": "markdown",
"id": "11",
"metadata": {},
"source": [
"## Nested Controls\n",
"Calling `controlled` on a `ControlledViaAnd` returns another `ControlledViaAnd` by combining the existing and new controls into a single control specification."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"metadata": {},
"outputs": [],
"source": [
"nested_ctrl_bloq = controlled_via_and_qbits.controlled(CtrlSpec(cvs=[1, 1]))\n",
"show_bloqs([nested_ctrl_bloq])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13",
"metadata": {},
"outputs": [],
"source": [
"show_call_graph(nested_ctrl_bloq)"
]
}
],
"metadata": {
Expand Down
26 changes: 24 additions & 2 deletions qualtran/bloqs/mcmt/controlled_via_and.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from collections import Counter
from functools import cached_property
from typing import TYPE_CHECKING
from typing import Iterable, Sequence, TYPE_CHECKING

import numpy as np
from attrs import frozen
Expand All @@ -23,7 +23,7 @@
from qualtran.bloqs.mcmt.ctrl_spec_and import CtrlSpecAnd

if TYPE_CHECKING:
from qualtran import BloqBuilder, SoquetT
from qualtran import AddControlledT, BloqBuilder, SoquetT
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator


Expand Down Expand Up @@ -126,6 +126,28 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':

return counts

def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']:
ctrl_spec_combined = CtrlSpec(
qdtypes=ctrl_spec.qdtypes + self.ctrl_spec.qdtypes,
cvs=ctrl_spec.cvs + self.ctrl_spec.cvs,
)
ctrl_bloq = ControlledViaAnd(subbloq=self.subbloq, ctrl_spec=ctrl_spec_combined)

def _adder(
bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: dict[str, 'SoquetT']
) -> tuple[Iterable['SoquetT'], Iterable['SoquetT']]:
rhs_ctrl_soqs_t = tuple(in_soqs.pop(name) for name in self.ctrl_reg_names)
all_ctrl_soqs_t = tuple([*ctrl_soqs, *rhs_ctrl_soqs_t])

all_ctrl_soqs_d = dict(zip(ctrl_bloq.ctrl_reg_names, all_ctrl_soqs_t))
all_soqs = all_ctrl_soqs_d | in_soqs
all_soqs = bb.add_t(ctrl_bloq, **all_soqs)

n_ctrl_lhs = ctrl_spec.num_ctrl_reg
return all_soqs[:n_ctrl_lhs], all_soqs[n_ctrl_lhs:]

return ctrl_bloq, _adder


@bloq_example
def _controlled_via_and_qbits() -> ControlledViaAnd:
Expand Down
39 changes: 35 additions & 4 deletions qualtran/bloqs/mcmt/controlled_via_and_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
import pytest

from qualtran import Controlled, CtrlSpec, QInt, QUInt
from qualtran.bloqs.basic_gates import XGate
from qualtran.bloqs.for_testing.matrix_gate import MatrixGate
from qualtran.bloqs.mcmt.controlled_via_and import (
_controlled_via_and_ints,
_controlled_via_and_qbits,
ControlledViaAnd,
)
from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost


def test_examples(bloq_autotester):
Expand All @@ -40,10 +42,39 @@ def test_tensor_against_naive_controlled(ctrl_spec: CtrlSpec):
rs = np.random.RandomState(42)
subbloq = MatrixGate.random(2, random_state=rs)

cbloq = ControlledViaAnd(subbloq, ctrl_spec)
naive_cbloq = Controlled(subbloq, ctrl_spec)
ctrl_bloq = ControlledViaAnd(subbloq, ctrl_spec)
naive_ctrl_bloq = Controlled(subbloq, ctrl_spec)

expected_tensor = naive_cbloq.tensor_contract()
actual_tensor = cbloq.tensor_contract()
expected_tensor = naive_ctrl_bloq.tensor_contract()
actual_tensor = ctrl_bloq.tensor_contract()

np.testing.assert_allclose(expected_tensor, actual_tensor)


def test_nested_controls():
spec1 = CtrlSpec(QUInt(4), [2, 3])
spec2 = CtrlSpec(QInt(4), [1, 2])
spec = CtrlSpec((QInt(4), QUInt(4)), ([1, 2], [2, 3]))

rs = np.random.RandomState(42)
bloq = MatrixGate.random(2, random_state=rs)

ctrl_bloq = ControlledViaAnd(bloq, spec1).controlled(ctrl_spec=spec2)
assert ctrl_bloq == ControlledViaAnd(bloq, spec)


def test_nested_controlled_x():
bloq = XGate()

ctrl_bloq = ControlledViaAnd(bloq, CtrlSpec(cvs=[1, 1])).controlled(
ctrl_spec=CtrlSpec(cvs=[1, 1])
)
cost = get_cost_value(ctrl_bloq, QECGatesCost())

n_ands = 3
assert cost == GateCounts(and_bloq=n_ands, clifford=n_ands + 1, measurement=n_ands)

np.testing.assert_allclose(
ctrl_bloq.tensor_contract(),
XGate().controlled(CtrlSpec(cvs=[1, 1, 1, 1])).tensor_contract(),
)
4 changes: 3 additions & 1 deletion qualtran/bloqs/mcmt/ctrl_spec_and.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def __attrs_post_init__(self):
if not is_symbolic(self.n_ctrl_qubits) and self.n_ctrl_qubits <= 1:
raise ValueError(f"Expected at least 2 controls, got {self.n_ctrl_qubits}")
for qdtype in self.ctrl_spec.qdtypes:
if not isinstance(qdtype, (QBit, QInt, QUInt, BQUInt, QIntOnesComp, QMontgomeryUInt)):
if not isinstance(
qdtype, (QBit, QAny, QInt, QUInt, BQUInt, QIntOnesComp, QMontgomeryUInt)
):
raise NotImplementedError("CtrlSpecAnd currently only supports integer types")

@cached_property
Expand Down

0 comments on commit 70dd9f5

Please sign in to comment.