diff --git a/dev_tools/bloq-method-overrides-report.py b/dev_tools/bloq-method-overrides-report.py new file mode 100644 index 000000000..b80c4159e --- /dev/null +++ b/dev_tools/bloq-method-overrides-report.py @@ -0,0 +1,60 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import ForwardRef, Set, Type + +from qualtran_dev_tools.bloq_finder import get_bloq_classes + +from qualtran import Bloq + + +def _call_graph(bc: Type[Bloq]): + """Check that a bloq class overrides the right call graph methods. + + - Override `build_call_graph` with canonical type annotations. + - Don't override `call_graph` or `bloq_counts`. + """ + call_graph = getattr(bc, 'call_graph') + if call_graph.__qualname__ != 'Bloq.call_graph': + print(f'{bc}.call_graph should not be overridden.') + raise ValueError(str(bc)) + + bloq_counts = getattr(bc, 'bloq_counts') + if bloq_counts.__qualname__ != 'Bloq.bloq_counts': + print(f'{bc}.bloq_counts should not be overriden.') + + bcg = getattr(bc, 'build_call_graph') + annot = bcg.__annotations__ + if set(annot.keys()) != {'ssa', 'return'}: + print( + f'{bc}.build_call_graph should have one argument named `ssa` ' + f'and a return type annotation' + ) + if annot['ssa'] != 'SympySymbolAllocator': + print(f"{bc}.build_call_graph `ssa: 'SympySymbolAllocator'`") + if annot['return'] != Set[ForwardRef('BloqCountT')]: + print(f"{bc}.build_call_graph -> 'BloqCountT'") + + +def report_call_graph_methods(): + bcs = get_bloq_classes() + for bc in bcs: + _call_graph(bc) + + +def main(): + report_call_graph_methods() + + +if __name__ == '__main__': + report_call_graph_methods() diff --git a/dev_tools/qualtran_dev_tools/bloq_finder.py b/dev_tools/qualtran_dev_tools/bloq_finder.py new file mode 100644 index 000000000..50612e389 --- /dev/null +++ b/dev_tools/qualtran_dev_tools/bloq_finder.py @@ -0,0 +1,94 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import inspect +import subprocess +from pathlib import Path +from typing import Callable, Iterable, List, Type + +from qualtran import Bloq + +from .git_tools import get_git_root + + +def _get_paths(bloqs_root: Path, filter_func: Callable[[Path], bool]) -> List[Path]: + """Get *.py files based on `filter_func`.""" + cp = subprocess.run( + ['git', 'ls-files', '*.py'], capture_output=True, universal_newlines=True, cwd=bloqs_root + ) + outs = cp.stdout.splitlines() + paths = [Path(out) for out in outs] + + paths = [path for path in paths if filter_func(path)] + return paths + + +def get_bloq_module_paths(bloqs_root: Path) -> List[Path]: + """Get *.py files for non-test, non-init modules under `bloqs_root`.""" + + def is_module_path(path: Path) -> bool: + if path.name.endswith('_test.py'): + return False + + if path.name == '__init__.py': + return False + + return True + + return _get_paths(bloqs_root, is_module_path) + + +def get_bloq_test_module_paths(bloqs_root: Path) -> List[Path]: + """Get *_test.py files under `bloqs_root`.""" + + def is_test_module_path(path: Path) -> bool: + if not path.name.endswith('_test.py'): + return False + + return True + + return _get_paths(bloqs_root, is_test_module_path) + + +def _bloq_modpath_to_modname(path: Path) -> str: + """Get the canonical, full module name given a module path.""" + return 'qualtran.bloqs.' + str(path)[: -len('.py')].replace('/', '.') + + +def modpath_to_bloqs(path: Path) -> Iterable[Type[Bloq]]: + """Given a module path, return all the `Bloq` classes defined within.""" + modname = _bloq_modpath_to_modname(path) + mod = importlib.import_module(modname) + for name, cls in inspect.getmembers(mod, inspect.isclass): + if cls.__module__ != modname: + # Perhaps from an import + continue + + if not issubclass(cls, Bloq): + continue + + if cls.__name__.startswith('_'): + continue + + yield cls + + +def get_bloq_classes(): + reporoot = get_git_root() + bloqs_root = reporoot / 'qualtran/bloqs' + paths = get_bloq_module_paths(bloqs_root) + bloq_clss: List[Type[Bloq]] = [] + for path in paths: + bloq_clss.extend(modpath_to_bloqs(path)) + return bloq_clss diff --git a/qualtran/__init__.py b/qualtran/__init__.py index 70ff0e7bf..10b022e16 100644 --- a/qualtran/__init__.py +++ b/qualtran/__init__.py @@ -41,11 +41,7 @@ DidNotFlattenAnythingError, SoquetT, ) - -# Internal imports: none -# External: -# - numpy: multiplying bitsizes, making cirq quregs -from ._infra.registers import Register, SelectionRegister, Signature, Side +from ._infra.gate_with_registers import GateWithRegisters # Internal imports: none # External imports: none @@ -58,6 +54,9 @@ Soquet, ) -from ._infra.gate_with_registers import GateWithRegisters +# Internal imports: none +# External: +# - numpy: multiplying bitsizes, making cirq quregs +from ._infra.registers import Register, SelectionRegister, Signature, Side # -------------------------------------------------------------------------------------------------- diff --git a/qualtran/_infra/bloq.py b/qualtran/_infra/bloq.py index 99b132b65..d2aa381f0 100644 --- a/qualtran/_infra/bloq.py +++ b/qualtran/_infra/bloq.py @@ -16,11 +16,13 @@ """Contains the main interface for defining `Bloq`s.""" import abc -from typing import Any, Dict, Optional, Set, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union if TYPE_CHECKING: import cirq + import networkx as nx import quimb.tensor as qtn + import sympy from numpy.typing import NDArray from qualtran import BloqBuilder, CompositeBloq, Signature, Soquet, SoquetT @@ -231,16 +233,36 @@ def add_my_tensors( ) tn.add(qtn.Tensor(data=data, inds=inds, tags=[self.short_name(), tag])) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']: - """Return a set of `(n, bloq)` tuples where bloq is used `n` times in the decomposition. + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: + """Override this method to build the bloq call graph. + + This method must return a set of `(n, bloq)` tuples where bloq is called `n` times in + the decomposition. By default, this method will use `self.decompose_bloq()` to count up bloqs. - However, you can override this if you don't want to provide a complete decomposition, - if you know symbolic expressions for the counts, or if you need to "generalize" - the subbloqs by overwriting bloq attributes that do not affect its cost with generic - sympy symbols (perhaps with the aid of the provided `SympySymbolAllocator`). + However, you can provide specific callees directly by overriding this method if + 1) you don't want to provide a complete decomposition, 2) you know symbolic expressions + for the counts, 3) or you need to "generalize" the subbloqs by overwriting bloq + attributes that do not affect its cost with generic sympy symbols using + the provided `SympySymbolAllocator`. """ - return self.decompose_bloq().bloq_counts(ssa) + return self.decompose_bloq().build_call_graph(ssa) + + def call_graph( + self, + generalizer: Callable[['Bloq'], Optional['Bloq']] = None, + keep: Optional[Sequence['Bloq']] = None, + max_depth: Optional[int] = None, + ) -> Tuple['nx.DiGraph', Dict['Bloq', Union[int, 'sympy.Expr']]]: + from qualtran.resource_counting.bloq_counts import get_bloq_call_graph + + return get_bloq_call_graph(self, generalizer=generalizer, keep=keep, max_depth=max_depth) + + def bloq_counts( + self, generalizer: Callable[['Bloq'], Optional['Bloq']] = None + ) -> Dict['Bloq', Union[int, 'sympy.Expr']]: + _, sigma = self.call_graph(generalizer=generalizer, max_depth=1) + return sigma def t_complexity(self) -> 'TComplexity': """The `TComplexity` for this bloq. diff --git a/qualtran/_infra/composite_bloq.py b/qualtran/_infra/composite_bloq.py index 076b141d1..30af0fac3 100644 --- a/qualtran/_infra/composite_bloq.py +++ b/qualtran/_infra/composite_bloq.py @@ -203,11 +203,11 @@ def as_composite_bloq(self) -> 'CompositeBloq': def decompose_bloq(self) -> 'CompositeBloq': raise NotImplementedError("Come back later.") - def bloq_counts(self, _: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']: + def build_call_graph(self, ssa: Optional['SympySymbolAllocator']) -> Set['BloqCountT']: """Return the bloq counts by counting up all the subbloqs.""" - from qualtran.resource_counting import get_cbloq_bloq_counts + from qualtran.resource_counting.bloq_counts import _build_cbloq_counts_graph - return get_cbloq_bloq_counts(self) + return _build_cbloq_counts_graph(self) def iter_bloqnections( self, diff --git a/qualtran/bloqs/and_bloq.ipynb b/qualtran/bloqs/and_bloq.ipynb index 20160597f..7cbb9fc20 100644 --- a/qualtran/bloqs/and_bloq.ipynb +++ b/qualtran/bloqs/and_bloq.ipynb @@ -69,12 +69,11 @@ "metadata": {}, "outputs": [], "source": [ - "from qualtran.resource_counting import get_bloq_counts_graph\n", - "from qualtran.drawing import show_counts_graph\n", + "from qualtran.drawing import show_call_graph\n", "import attrs\n", "\n", - "graph, sigma = get_bloq_counts_graph(bloq)\n", - "show_counts_graph(graph)" + "graph, sigma = bloq.call_graph()\n", + "show_call_graph(graph)" ] }, { @@ -166,8 +165,8 @@ "metadata": {}, "outputs": [], "source": [ - "graph, sigma = get_bloq_counts_graph(bloq)\n", - "show_counts_graph(graph)" + "graph, sigma = bloq.call_graph()\n", + "show_call_graph(graph)" ] }, { @@ -353,7 +352,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/qualtran/bloqs/and_bloq.py b/qualtran/bloqs/and_bloq.py index ef0fa1819..8d09f8273 100644 --- a/qualtran/bloqs/and_bloq.py +++ b/qualtran/bloqs/and_bloq.py @@ -14,7 +14,7 @@ import itertools from functools import cached_property -from typing import Any, Dict, Optional, Set, Tuple +from typing import Any, Dict, Set, Tuple import cirq import numpy as np @@ -23,12 +23,12 @@ from attrs import field, frozen from numpy.typing import NDArray -from qualtran import Bloq, GateWithRegisters, Register, Side, Signature, Soquet, SoquetT +from qualtran import GateWithRegisters, Register, Side, Signature, Soquet, SoquetT from qualtran.bloqs.basic_gates import TGate from qualtran.bloqs.util_bloqs import ArbitraryClifford from qualtran.cirq_interop.t_complexity_protocol import TComplexity from qualtran.drawing import Circle, directional_text_box, WireSymbol -from qualtran.resource_counting import big_O, SympySymbolAllocator +from qualtran.resource_counting import big_O, BloqCountT, SympySymbolAllocator @frozen @@ -63,7 +63,7 @@ def signature(self) -> Signature: ] ) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set[Tuple[int, Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: if isinstance(self.cv1, sympy.Expr) or isinstance(self.cv2, sympy.Expr): pre_post_cliffords = big_O(1) else: diff --git a/qualtran/bloqs/arithmetic.py b/qualtran/bloqs/arithmetic.py index f553b0bc5..a9c96f20c 100644 --- a/qualtran/bloqs/arithmetic.py +++ b/qualtran/bloqs/arithmetic.py @@ -13,21 +13,9 @@ # limitations under the License. from functools import cached_property -from typing import ( - Dict, - Iterable, - Iterator, - List, - Optional, - Sequence, - Set, - Tuple, - TYPE_CHECKING, - Union, -) +from typing import Dict, Iterable, Iterator, List, Sequence, Set, Tuple, TYPE_CHECKING, Union import cirq -import sympy from attrs import field, frozen from numpy.typing import NDArray @@ -39,7 +27,7 @@ from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity if TYPE_CHECKING: - from qualtran.resource_counting import SympySymbolAllocator + from qualtran.resource_counting import BloqCountT, SympySymbolAllocator from qualtran.simulation.classical_sim import ClassicalValT @@ -544,7 +532,7 @@ def t_complexity(self): num_t_gates = 4 * self.bitsize - 4 return TComplexity(t=num_t_gates, clifford=num_clifford) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set[Tuple[int, Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: num_clifford = (self.bitsize - 2) * 19 + 16 num_t_gates = 4 * self.bitsize - 4 return {(num_t_gates, TGate()), (num_clifford, ArbitraryClifford(n=1))} @@ -586,7 +574,7 @@ def t_complexity(self): num_t_gates = 4 * self.bitsize - 4 return TComplexity(t=num_t_gates, clifford=num_clifford) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set[Tuple[int, Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: return {(1, Add(self.bitsize)), (self.bitsize, ArbitraryClifford(n=2))} @@ -693,7 +681,7 @@ def t_complexity(self): num_toff = self.bitsize * (self.bitsize - 1) return TComplexity(t=4 * num_toff) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set[Tuple[int, Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: num_toff = self.bitsize * (self.bitsize - 1) return {(4 * num_toff, TGate())} @@ -748,7 +736,7 @@ def t_complexity(self): num_toff -= 1 return TComplexity(t=4 * num_toff) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set[Tuple[int, Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: num_toff = self.k * self.bitsize**2 - self.bitsize if self.k % 3 == 0: num_toff -= 1 @@ -800,7 +788,7 @@ def t_complexity(self): num_toff = 2 * self.a_bitsize * self.b_bitsize - max(self.a_bitsize, self.b_bitsize) return TComplexity(t=4 * num_toff) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set[Tuple[int, Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: num_toff = 2 * self.a_bitsize * self.b_bitsize - max(self.a_bitsize, self.b_bitsize) return {(4 * num_toff, TGate())} @@ -853,7 +841,7 @@ def t_complexity(self): num_toff = self.r_bitsize * (2 * self.i_bitsize - 1) - self.i_bitsize**2 return TComplexity(t=4 * num_toff) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set[Tuple[int, Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: # Eq. D8, we are assuming dA(r_bitsize) and dB(i_bitsize) are inputs and # the user has ensured these are large enough for their desired # precision. @@ -906,7 +894,7 @@ def t_complexity(self): num_toff = self.bitsize**2 - self.bitsize - 1 return TComplexity(t=4 * num_toff) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set[Tuple[int, Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: # Eq. D13, there it is suggested keeping both registers the same size is optimal. num_toff = self.bitsize**2 - self.bitsize - 1 return {(4 * num_toff, TGate())} @@ -960,7 +948,7 @@ def t_complexity(self): num_toff = self.bitsize**2 // 2 - 4 return TComplexity(t=4 * num_toff) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set[Tuple[int, Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: # Bottom of page 74 num_toff = self.bitsize**2 // 2 - 4 return {(4 * num_toff, TGate())} @@ -999,9 +987,7 @@ def pretty_name(self) -> str: def t_complexity(self) -> 'TComplexity': return t_complexity(LessThanEqual(self.bitsize, self.bitsize)) - def bloq_counts( - self, ssa: Optional['SympySymbolAllocator'] = None - ) -> Set[Tuple[Union[int, sympy.Expr], Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: # TODO Determine precise clifford count and/or ignore. # See: https://github.com/quantumlib/cirq-qubitization/issues/219 # See: https://github.com/quantumlib/cirq-qubitization/issues/217 @@ -1039,9 +1025,7 @@ def signature(self) -> Signature: def t_complexity(self) -> TComplexity: return t_complexity(LessThanConstant(self.bitsize, val=self.val)) - def bloq_counts( - self, ssa: Optional['SympySymbolAllocator'] = None - ) -> Set[Tuple[Union[int, sympy.Expr], Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: # TODO Determine precise clifford count and/or ignore. # See: https://github.com/quantumlib/cirq-qubitization/issues/219 # See: https://github.com/quantumlib/cirq-qubitization/issues/217 @@ -1075,9 +1059,7 @@ def signature(self) -> Signature: def t_complexity(self) -> 'TComplexity': return TComplexity(t=4 * (self.bitsize - 1)) - def bloq_counts( - self, ssa: Optional['SympySymbolAllocator'] = None - ) -> Set[Tuple[Union[int, sympy.Expr], Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: # See: https://github.com/quantumlib/cirq-qubitization/issues/219 # See: https://github.com/quantumlib/cirq-qubitization/issues/217 return {(4 * (self.bitsize - 1), TGate())} @@ -1130,7 +1112,5 @@ def t_complexity(self) -> 'TComplexity': num_toffoli = self.bitsize**2 + self.bitsize - 1 return TComplexity(t=4 * num_toffoli) - def bloq_counts( - self, ssa: Optional['SympySymbolAllocator'] = None - ) -> Set[Tuple[Union[int, sympy.Expr], Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: return {(4 * (self.bitsize**2 + self.bitsize - 1), TGate())} diff --git a/qualtran/bloqs/basic_gates.ipynb b/qualtran/bloqs/basic_gates.ipynb index dc156c829..ace7ce8d3 100644 --- a/qualtran/bloqs/basic_gates.ipynb +++ b/qualtran/bloqs/basic_gates.ipynb @@ -418,11 +418,10 @@ "metadata": {}, "outputs": [], "source": [ - "from qualtran.resource_counting import get_bloq_counts_graph\n", - "from qualtran.drawing import show_counts_graph, show_counts_sigma\n", + "from qualtran.drawing import show_call_graph, show_counts_sigma\n", "\n", - "g, sigma = get_bloq_counts_graph(bloq)\n", - "show_counts_graph(g)\n", + "g, sigma = bloq.call_graph()\n", + "show_call_graph(g)\n", "show_counts_sigma(sigma)" ] }, diff --git a/qualtran/bloqs/basic_gates/rotation.py b/qualtran/bloqs/basic_gates/rotation.py index cb410e252..5c010eaa0 100644 --- a/qualtran/bloqs/basic_gates/rotation.py +++ b/qualtran/bloqs/basic_gates/rotation.py @@ -14,7 +14,7 @@ import abc from functools import cached_property -from typing import Dict, Optional, Set, Tuple, TYPE_CHECKING +from typing import Dict, Set, Tuple, TYPE_CHECKING import numpy as np from attrs import frozen @@ -48,7 +48,7 @@ def t_complexity(self): num_t = int(np.ceil(1.149 * np.log2(1.0 / self.eps) + 9.2)) return TComplexity(t=num_t) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set[Tuple[int, Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: num_t = int(np.ceil(1.149 * np.log2(1.0 / self.eps) + 9.2)) return {(num_t, TGate())} diff --git a/qualtran/bloqs/basic_gates/swap.py b/qualtran/bloqs/basic_gates/swap.py index f45402f24..c66f1ddab 100644 --- a/qualtran/bloqs/basic_gates/swap.py +++ b/qualtran/bloqs/basic_gates/swap.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import cached_property -from typing import Any, Dict, Optional, Set, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, Set, Tuple, TYPE_CHECKING, Union import cirq import numpy as np @@ -23,8 +23,11 @@ from numpy.typing import NDArray from qualtran import Bloq, BloqBuilder, Signature, SoquetT +from qualtran.bloqs.util_bloqs import ArbitraryClifford from qualtran.cirq_interop.t_complexity_protocol import TComplexity +from .t_gate import TGate + if TYPE_CHECKING: from qualtran.cirq_interop import CirqQuregT from qualtran.resource_counting import BloqCountT, SympySymbolAllocator @@ -147,6 +150,9 @@ def t_complexity(self) -> 'TComplexity': # https://arxiv.org/abs/1308.4134 return TComplexity(t=7, clifford=10) + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: + return {(7, TGate()), (10, ArbitraryClifford(n=3))} + @frozen class CSwap(Bloq): @@ -183,7 +189,7 @@ def build_composite_bloq( return {'ctrl': ctrl, 'x': bb.join(xs), 'y': bb.join(ys)} - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: return {(self.bitsize, TwoBitCSwap())} def on_classical_vals( diff --git a/qualtran/bloqs/basic_gates/swap_test.py b/qualtran/bloqs/basic_gates/swap_test.py index d0c6c80a8..e7518c041 100644 --- a/qualtran/bloqs/basic_gates/swap_test.py +++ b/qualtran/bloqs/basic_gates/swap_test.py @@ -30,7 +30,6 @@ ) from qualtran.bloqs.basic_gates.swap import _controlled_swap_matrix, _swap_matrix, CSwap from qualtran.bloqs.util_bloqs import Join, Split -from qualtran.resource_counting import get_cbloq_bloq_counts, SympySymbolAllocator from qualtran.testing import assert_valid_bloq_decomposition @@ -154,16 +153,15 @@ def generalize(b: Bloq) -> Optional[Bloq]: return return b - counts2 = get_cbloq_bloq_counts(bloq.decompose_bloq(), generalizer=generalize) - + counts2 = bloq.decompose_bloq().bloq_counts(generalizer=generalize) assert counts1 == counts2 def test_cswap_symbolic(): n = sympy.symbols('n') cswap = CSwap(bitsize=n) - counts = cswap.bloq_counts(SympySymbolAllocator()) + counts = cswap.bloq_counts() assert len(counts) == 1 - assert counts.pop() == (n, TwoBitCSwap()) + assert counts[TwoBitCSwap()] == n with pytest.raises(ValueError): cswap.decompose_bloq() diff --git a/qualtran/bloqs/basic_gates/t_gate_test.py b/qualtran/bloqs/basic_gates/t_gate_test.py index 91d79a42e..c8997e222 100644 --- a/qualtran/bloqs/basic_gates/t_gate_test.py +++ b/qualtran/bloqs/basic_gates/t_gate_test.py @@ -15,7 +15,6 @@ from qualtran import BloqBuilder from qualtran.bloqs.basic_gates import PlusState, TGate -from qualtran.resource_counting import get_bloq_counts_graph def _make_t_gate(): @@ -24,8 +23,8 @@ def _make_t_gate(): return TGate() -def test_bloq_counts(): - g, simga = get_bloq_counts_graph(TGate()) +def test_call_graph(): + g, simga = TGate().call_graph() assert simga == {TGate(): 1} diff --git a/qualtran/bloqs/basic_gates/toffoli.py b/qualtran/bloqs/basic_gates/toffoli.py index 8bc61107f..dd87c314c 100644 --- a/qualtran/bloqs/basic_gates/toffoli.py +++ b/qualtran/bloqs/basic_gates/toffoli.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import cached_property -from typing import Dict, Optional, Set, Tuple, TYPE_CHECKING, Union +from typing import Dict, Set, Tuple, TYPE_CHECKING, Union from attrs import frozen @@ -47,7 +47,7 @@ class Toffoli(Bloq): def signature(self) -> Signature: return Signature([Register('ctrl', 1, shape=(2,)), Register('target', 1)]) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: return {(4, TGate())} def on_classical_vals( diff --git a/qualtran/bloqs/basic_gates/toffoli_test.py b/qualtran/bloqs/basic_gates/toffoli_test.py index 5dd2cda59..ce95f9e11 100644 --- a/qualtran/bloqs/basic_gates/toffoli_test.py +++ b/qualtran/bloqs/basic_gates/toffoli_test.py @@ -17,7 +17,6 @@ from qualtran import BloqBuilder from qualtran.bloqs.basic_gates import TGate, Toffoli, ZeroState -from qualtran.resource_counting import get_bloq_counts_graph def _make_Toffoli(): @@ -28,9 +27,9 @@ def _make_Toffoli(): def test_toffoli_t_count(): counts = Toffoli().bloq_counts() - assert counts == {(4, TGate())} + assert counts == {TGate(): 4} - _, sigma = get_bloq_counts_graph(Toffoli()) + _, sigma = Toffoli().call_graph() assert sigma == {TGate(): 4} diff --git a/qualtran/bloqs/basic_gates/y_gate.py b/qualtran/bloqs/basic_gates/y_gate.py index e4838f257..6cc5cec2e 100644 --- a/qualtran/bloqs/basic_gates/y_gate.py +++ b/qualtran/bloqs/basic_gates/y_gate.py @@ -25,7 +25,6 @@ import cirq from qualtran.cirq_interop import CirqQuregT - from qualtran.simulation.classical_sim import ClassicalValT _PAULIY = np.array([[0, -1j], [1j, 0]], dtype=np.complex128) diff --git a/qualtran/bloqs/basic_gates/z_basis.py b/qualtran/bloqs/basic_gates/z_basis.py index 5a2012bd1..1212e1ca8 100644 --- a/qualtran/bloqs/basic_gates/z_basis.py +++ b/qualtran/bloqs/basic_gates/z_basis.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import cached_property -from typing import Any, Dict, Optional, Set, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, Set, Tuple, TYPE_CHECKING, Union import attrs import numpy as np @@ -291,7 +291,7 @@ def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, int]: def t_complexity(self) -> 'TComplexity': return TComplexity() - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: return {(big_O(1), ArbitraryClifford(self.bitsize))} def short_name(self) -> str: diff --git a/qualtran/bloqs/factoring/mod_add.py b/qualtran/bloqs/factoring/mod_add.py index 14393d227..78acfd7b8 100644 --- a/qualtran/bloqs/factoring/mod_add.py +++ b/qualtran/bloqs/factoring/mod_add.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import cached_property -from typing import Dict, Optional, Set, Union +from typing import Dict, Set, Union import sympy from attrs import frozen @@ -54,14 +54,12 @@ def signature(self) -> 'Signature': ] ) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']: - if ssa is None: - raise ValueError(f"{self} requires a SympySymbolAllocator") + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: k = ssa.new_symbol('k') return {(self.bitsize, CtrlModAddK(k=k, bitsize=self.bitsize, mod=self.mod))} def t_complexity(self) -> 'TComplexity': - ((n, bloq),) = self.bloq_counts(SympySymbolAllocator()) + ((bloq, n),) = self.bloq_counts().items() return n * bloq.t_complexity() def on_classical_vals( @@ -100,14 +98,12 @@ class CtrlModAddK(Bloq): def signature(self) -> 'Signature': return Signature([Register('ctrl', bitsize=1), Register('x', bitsize=self.bitsize)]) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']: - if ssa is None: - raise ValueError(f"{self} requires a SympySymbolAllocator") + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: k = ssa.new_symbol('k') return {(5, CtrlAddK(k=k, bitsize=self.bitsize))} def t_complexity(self) -> 'TComplexity': - ((n, bloq),) = self.bloq_counts(SympySymbolAllocator()) + ((bloq, n),) = self.bloq_counts().items() return n * bloq.t_complexity() def short_name(self) -> str: @@ -137,7 +133,7 @@ def short_name(self) -> str: def signature(self) -> 'Signature': return Signature([Register('ctrl', bitsize=1), Register('x', bitsize=self.bitsize)]) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: return {(2 * self.bitsize, TGate())} def t_complexity(self) -> 'TComplexity': diff --git a/qualtran/bloqs/factoring/mod_add_test.py b/qualtran/bloqs/factoring/mod_add_test.py index 9924774b1..628e88ad3 100644 --- a/qualtran/bloqs/factoring/mod_add_test.py +++ b/qualtran/bloqs/factoring/mod_add_test.py @@ -13,15 +13,14 @@ # limitations under the License. from qualtran.bloqs.factoring.mod_add import CtrlModAddK, CtrlScaleModAdd -from qualtran.resource_counting import SympySymbolAllocator def test_ctrl_scale_mod_add(): bloq = CtrlScaleModAdd(k=123, mod=13 * 17, bitsize=8) assert bloq.short_name() == 'y += x*123 % 221' - counts = bloq.bloq_counts(SympySymbolAllocator()) - ((n, bloq),) = counts + counts = bloq.bloq_counts() + ((bloq, n),) = counts.items() assert n == 8 @@ -29,6 +28,6 @@ def test_ctrl_mod_add_k(): bloq = CtrlModAddK(k=123, mod=13 * 17, bitsize=8) assert bloq.short_name() == 'x += 123 % 221' - counts = bloq.bloq_counts(SympySymbolAllocator()) - ((n, bloq),) = counts + counts = bloq.bloq_counts() + ((bloq, n),) = counts.items() assert n == 5 diff --git a/qualtran/bloqs/factoring/mod_exp.py b/qualtran/bloqs/factoring/mod_exp.py index b1e061eb7..e51128575 100644 --- a/qualtran/bloqs/factoring/mod_exp.py +++ b/qualtran/bloqs/factoring/mod_exp.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import cached_property -from typing import Dict, Optional, Set, Union +from typing import Dict, Set, Union import numpy as np import sympy @@ -97,9 +97,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'SoquetT') -> Dict[s return {'exponent': bb.join(exponent), 'x': x} - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']: - if ssa is None: - raise ValueError(f"{self} requires a SympySymbolAllocator") + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: k = ssa.new_symbol('k') return { (1, IntState(val=1, bitsize=self.x_bitsize)), diff --git a/qualtran/bloqs/factoring/mod_exp_test.py b/qualtran/bloqs/factoring/mod_exp_test.py index c403989e1..4ccb2400f 100644 --- a/qualtran/bloqs/factoring/mod_exp_test.py +++ b/qualtran/bloqs/factoring/mod_exp_test.py @@ -22,7 +22,7 @@ from qualtran.bloqs.factoring.mod_exp import ModExp from qualtran.bloqs.factoring.mod_mul import CtrlModMul from qualtran.bloqs.util_bloqs import Join, Split -from qualtran.resource_counting import get_cbloq_bloq_counts, SympySymbolAllocator +from qualtran.resource_counting import SympySymbolAllocator from qualtran.testing import execute_notebook @@ -60,8 +60,8 @@ def test_mod_exp_symbolic(): g, N, n_e, n_x = sympy.symbols('g N n_e, n_x') modexp = ModExp(base=g, mod=N, exp_bitsize=n_e, x_bitsize=n_x) assert modexp.short_name() == 'g^e % N' - counts = modexp.bloq_counts(SympySymbolAllocator()) - counts_by_bloq = {bloq.pretty_name(): n for n, bloq in counts} + counts = modexp.bloq_counts() + counts_by_bloq = {bloq.pretty_name(): n for bloq, n in counts.items()} assert counts_by_bloq['|1>'] == 1 assert counts_by_bloq['CtrlModMul'] == n_e @@ -72,7 +72,7 @@ def test_mod_exp_symbolic(): def test_mod_exp_consistent_counts(): bloq = ModExp(base=8, exp_bitsize=3, x_bitsize=10, mod=50) - counts1 = bloq.bloq_counts(SympySymbolAllocator()) + counts1 = bloq.bloq_counts() ssa = SympySymbolAllocator() my_k = ssa.new_symbol('k') @@ -86,7 +86,7 @@ def generalize(b: Bloq) -> Optional[Bloq]: return return b - counts2 = get_cbloq_bloq_counts(bloq.decompose_bloq(), generalizer=generalize) + counts2 = bloq.decompose_bloq().bloq_counts(generalizer=generalize) assert counts1 == counts2 diff --git a/qualtran/bloqs/factoring/mod_mul.py b/qualtran/bloqs/factoring/mod_mul.py index 3720af70c..f9cf0f675 100644 --- a/qualtran/bloqs/factoring/mod_mul.py +++ b/qualtran/bloqs/factoring/mod_mul.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import cached_property -from typing import Dict, Optional, Set, Union +from typing import Dict, Set, Union import sympy from attrs import frozen @@ -83,9 +83,7 @@ def build_composite_bloq( bb.free(y) return {'ctrl': ctrl, 'x': x} - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']: - if ssa is None: - raise ValueError(f"{self} requires a SympySymbolAllocator") + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: k = ssa.new_symbol('k') return {(2, self._Add(k=k)), (1, CSwap(self.bitsize))} diff --git a/qualtran/bloqs/factoring/mod_mul_test.py b/qualtran/bloqs/factoring/mod_mul_test.py index ab1855ec0..c6ee3ac90 100644 --- a/qualtran/bloqs/factoring/mod_mul_test.py +++ b/qualtran/bloqs/factoring/mod_mul_test.py @@ -22,7 +22,7 @@ from qualtran.bloqs.factoring.mod_add import CtrlScaleModAdd from qualtran.bloqs.factoring.mod_mul import CtrlModMul from qualtran.bloqs.util_bloqs import Allocate, Free -from qualtran.resource_counting import get_cbloq_bloq_counts, SympySymbolAllocator +from qualtran.resource_counting import SympySymbolAllocator def _make_modmul(): @@ -99,7 +99,7 @@ def test_symbolic(): assert bloq.short_name() == 'x *= k % N' # it's all fixed constants, but check it works anyways - counts = bloq.bloq_counts(SympySymbolAllocator()) + counts = bloq.bloq_counts() assert len(counts) > 0 ctrl, x = bloq.call_classically(ctrl=1, x=sympy.Symbol('x')) @@ -112,7 +112,7 @@ def test_symbolic(): def test_consistent_counts(): bloq = CtrlModMul(k=123, mod=13 * 17, bitsize=8) - counts1 = bloq.bloq_counts(SympySymbolAllocator()) + counts1 = bloq.bloq_counts() ssa = SympySymbolAllocator() my_k = ssa.new_symbol('k') @@ -125,6 +125,6 @@ def generalize(b: Bloq) -> Optional[Bloq]: return return b - counts2 = get_cbloq_bloq_counts(bloq.decompose_bloq(), generalizer=generalize) + counts2 = bloq.decompose_bloq().bloq_counts(generalizer=generalize) assert counts1 == counts2 diff --git a/qualtran/bloqs/factoring/ref-factoring.ipynb b/qualtran/bloqs/factoring/ref-factoring.ipynb index 6a3d071dc..898380cdb 100644 --- a/qualtran/bloqs/factoring/ref-factoring.ipynb +++ b/qualtran/bloqs/factoring/ref-factoring.ipynb @@ -207,11 +207,10 @@ }, "outputs": [], "source": [ - "from qualtran.resource_counting import get_bloq_counts_graph\n", - "from qualtran.drawing import show_counts_graph, show_counts_sigma\n", + "from qualtran.drawing import show_call_graph, show_counts_sigma\n", "\n", - "g, sigma = get_bloq_counts_graph(bloq)\n", - "show_counts_graph(g)\n", + "g, sigma = bloq.call_graph()\n", + "show_call_graph(g)\n", "show_counts_sigma(sigma)" ] }, diff --git a/qualtran/bloqs/swap_network.py b/qualtran/bloqs/swap_network.py index 7df8ae0e3..526a47741 100644 --- a/qualtran/bloqs/swap_network.py +++ b/qualtran/bloqs/swap_network.py @@ -13,11 +13,10 @@ # limitations under the License. from functools import cached_property -from typing import Dict, Optional, Set, Tuple, TYPE_CHECKING, Union +from typing import Dict, Set, Tuple, TYPE_CHECKING, Union import cirq import numpy as np -import sympy from attrs import frozen from cirq_ft import MultiTargetCSwapApprox from numpy.typing import NDArray @@ -31,7 +30,7 @@ if TYPE_CHECKING: from qualtran import CompositeBloq from qualtran.cirq_interop import CirqQuregT - from qualtran.resource_counting import SympySymbolAllocator + from qualtran.resource_counting import BloqCountT, SympySymbolAllocator from qualtran.simulation.classical_sim import ClassicalValT @@ -101,9 +100,7 @@ def t_complexity(self) -> TComplexity: # 2 * n - 1: CNOTs from 1 MultiTargetCNOT return TComplexity(t=4 * n, clifford=22 * n - 1) - def bloq_counts( - self, ssa: Optional['SympySymbolAllocator'] = None - ) -> Set[Tuple[Union[int, sympy.Expr], Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: n = self.bitsize # 4 * n: G gates, each wth 1 T and 4 single qubit cliffords # 4 * n: CNOTs @@ -156,9 +153,7 @@ def build_composite_bloq( return {'selection': bb.join(selection), 'targets': targets} - def bloq_counts( - self, ssa: Optional['SympySymbolAllocator'] = None - ) -> Set[Tuple[Union[int, sympy.Expr], Bloq]]: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: num_swaps = np.floor( sum([self.n_target_registers / (2 ** (j + 1)) for j in range(self.selection_bitsize)]) ) diff --git a/qualtran/bloqs/swap_network_test.py b/qualtran/bloqs/swap_network_test.py index d2644871b..087646020 100644 --- a/qualtran/bloqs/swap_network_test.py +++ b/qualtran/bloqs/swap_network_test.py @@ -13,13 +13,14 @@ # limitations under the License. import random -from typing import Set, Tuple +from typing import Dict, Tuple import cirq import cirq_ft import cirq_ft.infra.testing as cq_testing import numpy as np import pytest +import sympy from qualtran import Bloq, BloqBuilder from qualtran.bloqs.basic_gates import TGate @@ -104,10 +105,10 @@ def test_swap_with_zero_classically(): print(sel, out_data) -def get_t_count_and_clifford(bc: Set[Tuple[int, Bloq]]) -> Tuple[int, int]: +def get_t_count_and_clifford(bc: Dict[Bloq, int]) -> Tuple[int, int]: """Get the t count and clifford cost from bloq count.""" - cliff_cost = sum([x[0] for x in bc if isinstance(x[1], ArbitraryClifford)]) - t_cost = sum([x[0] for x in bc if isinstance(x[1], TGate)]) + cliff_cost = sum([v for k, v in bc.items() if isinstance(k, ArbitraryClifford)]) + t_cost = sum([v for k, v in bc.items() if isinstance(k, TGate)]) return t_cost, cliff_cost @@ -141,10 +142,17 @@ def test_cswap_approx_bloq_counts(n): ) def test_swap_with_zero_bloq_counts(selection_bitsize, target_bitsize, n_target_registers, want): gate = SwapWithZero(selection_bitsize, target_bitsize, n_target_registers) - bc = list(gate.bloq_counts())[0] - t_cost, cliff_cost = get_t_count_and_clifford(bc[1].bloq_counts()) - assert bc[0] * t_cost == want.t - assert bc[0] * cliff_cost == want.clifford + + n = sympy.Symbol('n') + + def _gen_clif(bloq: Bloq) -> Bloq: + if isinstance(bloq, ArbitraryClifford): + return ArbitraryClifford(n) + return bloq + + _, sigma = gate.call_graph(generalizer=_gen_clif) + assert sigma[TGate()] == want.t + assert sigma[ArbitraryClifford(n)] == want.clifford @pytest.mark.parametrize( diff --git a/qualtran/cirq_interop/__init__.py b/qualtran/cirq_interop/__init__.py index 4c84e8569..552013d05 100644 --- a/qualtran/cirq_interop/__init__.py +++ b/qualtran/cirq_interop/__init__.py @@ -17,6 +17,5 @@ isort:skip_file """ -from ._cirq_to_bloq import CirqQuregT, CirqGateAsBloq, cirq_optree_to_cbloq, decompose_from_cirq_op - from ._bloq_to_cirq import BloqAsCirqGate +from ._cirq_to_bloq import CirqQuregT, CirqGateAsBloq, cirq_optree_to_cbloq, decompose_from_cirq_op diff --git a/qualtran/drawing/__init__.py b/qualtran/drawing/__init__.py index 44d687c36..56ec0c129 100644 --- a/qualtran/drawing/__init__.py +++ b/qualtran/drawing/__init__.py @@ -17,6 +17,9 @@ isort:skip_file """ +from ._show_funcs import show_bloq, show_bloqs, show_call_graph, show_counts_sigma +from .bloq_counts_graph import GraphvizCounts, format_counts_sigma, format_counts_graph_markdown +from .classical_sim_graph import ClassicalSimGraphDrawer from .graphviz import GraphDrawer, PrettyGraphDrawer from .musical_score import ( RegPosition, @@ -36,9 +39,3 @@ draw_musical_score, dump_musical_score, ) - -from .classical_sim_graph import ClassicalSimGraphDrawer - -from .bloq_counts_graph import GraphvizCounts, format_counts_sigma, format_counts_graph_markdown - -from ._show_funcs import show_bloq, show_bloqs, show_counts_graph, show_counts_sigma diff --git a/qualtran/drawing/_show_funcs.py b/qualtran/drawing/_show_funcs.py index fa7ce68ce..0983cf8f0 100644 --- a/qualtran/drawing/_show_funcs.py +++ b/qualtran/drawing/_show_funcs.py @@ -53,7 +53,7 @@ def show_bloqs(bloqs: Sequence['Bloq'], labels: Sequence[str] = None): IPython.display.display(box) -def show_counts_graph(g: 'nx.DiGraph') -> None: +def show_call_graph(g: 'nx.DiGraph') -> None: """Display a graph representation of the counts graph `g`.""" IPython.display.display(GraphvizCounts(g).get_svg()) diff --git a/qualtran/drawing/bloq_counts_graph_test.py b/qualtran/drawing/bloq_counts_graph_test.py index 00f3b09bc..fd6a43ce2 100644 --- a/qualtran/drawing/bloq_counts_graph_test.py +++ b/qualtran/drawing/bloq_counts_graph_test.py @@ -15,11 +15,11 @@ from qualtran.bloqs.and_bloq import MultiAnd from qualtran.drawing import format_counts_graph_markdown, format_counts_sigma, GraphvizCounts -from qualtran.resource_counting import get_bloq_counts_graph +from qualtran.resource_counting import get_bloq_call_graph def test_format_counts_sigma(): - graph, sigma = get_bloq_counts_graph(MultiAnd(cvs=(1,) * 6)) + graph, sigma = get_bloq_call_graph(MultiAnd(cvs=(1,) * 6)) ret = format_counts_sigma(sigma) assert ( ret @@ -31,7 +31,7 @@ def test_format_counts_sigma(): def test_format_counts_graph_markdown(): - graph, sigma = get_bloq_counts_graph(MultiAnd(cvs=(1,) * 6)) + graph, sigma = get_bloq_call_graph(MultiAnd(cvs=(1,) * 6)) ret = format_counts_graph_markdown(graph) assert ( ret @@ -45,7 +45,7 @@ def test_format_counts_graph_markdown(): def test_graphviz_counts(): - graph, sigma = get_bloq_counts_graph(MultiAnd(cvs=(1,) * 6)) + graph, sigma = get_bloq_call_graph(MultiAnd(cvs=(1,) * 6)) gvc = GraphvizCounts(graph) # The main test is in the drawing notebook, so please spot check that. diff --git a/qualtran/resource_counting/__init__.py b/qualtran/resource_counting/__init__.py index 694a48985..9b95ee62b 100644 --- a/qualtran/resource_counting/__init__.py +++ b/qualtran/resource_counting/__init__.py @@ -21,7 +21,6 @@ BloqCountT, big_O, SympySymbolAllocator, - get_cbloq_bloq_counts, - get_bloq_counts_graph, + get_bloq_call_graph, print_counts_graph, ) diff --git a/qualtran/resource_counting/bloq_counts.py b/qualtran/resource_counting/bloq_counts.py index 876f35c17..31e7db81e 100644 --- a/qualtran/resource_counting/bloq_counts.py +++ b/qualtran/resource_counting/bloq_counts.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Functionality for the `Bloq.bloq_counts()` protocol.""" +"""Functionality for the `Bloq.call_graph()` protocol.""" from collections import defaultdict -from typing import Callable, Dict, Optional, Sequence, Set, Tuple, Union +from typing import Callable, Dict, Optional, Set, Tuple, Union import networkx as nx import sympy @@ -51,9 +51,7 @@ def new_symbol(self, prefix: str) -> sympy.Symbol: return s -def get_cbloq_bloq_counts( - cbloq: CompositeBloq, generalizer: Callable[[Bloq], Optional[Bloq]] = None -) -> Set[BloqCountT]: +def _build_cbloq_counts_graph(cbloq: CompositeBloq) -> Set[BloqCountT]: """Count all the subbloqs in a composite bloq. `CompositeBloq.resource_counting` calls this with no generalizer. @@ -63,27 +61,22 @@ def get_cbloq_bloq_counts( generalizer: A function that replaces bloq attributes that do not affect resource costs with sympy placeholders. """ - if generalizer is None: - generalizer = lambda b: b counts: Dict[Bloq, int] = defaultdict(lambda: 0) for binst in cbloq.bloq_instances: - bloq = binst.bloq - bloq = generalizer(bloq) - if bloq is None: - continue - - counts[bloq] += 1 + counts[binst.bloq] += 1 return {(n, bloq) for bloq, n in counts.items()} -def _descend_counts( +def _recurse_call_graph( parent: Bloq, g: nx.DiGraph, ssa: SympySymbolAllocator, generalizer: Callable[[Bloq], Optional[Bloq]], - keep: Sequence[Bloq], + keep: Callable[[Bloq], bool], + max_depth: Optional[int], + depth: int, ) -> Dict[Bloq, Union[int, sympy.Expr]]: """Recursive counting function. @@ -103,12 +96,17 @@ def _descend_counts( g.add_node(parent) # Base case 1: This node is requested by the user to be a leaf node via the `keep` parameter. - if parent in keep: + if keep(parent): return {parent: 1} + + # Base case 2: Max depth exceeded + if max_depth is not None and depth >= max_depth: + return {parent: 1} + try: - count_decomp = parent.bloq_counts(ssa) - except NotImplementedError: - # Base case 2: Decomposition (or `bloq_counts`) is not implemented. This is left as a + count_decomp = parent.build_call_graph(ssa) + except NotImplementedError: # TODO: DecomposeNotImplementedError + # Base case 3: Decomposition (or `bloq_counts`) is not implemented. This is left as a # leaf node. return {parent: 1} @@ -125,7 +123,7 @@ def _descend_counts( g.add_edge(parent, child, n=n) # Do the recursive step, which will continue to mutate `g` - child_counts = _descend_counts(child, g, ssa, generalizer, keep) + child_counts = _recurse_call_graph(child, g, ssa, generalizer, keep, max_depth, depth + 1) # Update `sigma` with the recursion results. for k in child_counts.keys(): @@ -134,24 +132,30 @@ def _descend_counts( return dict(sigma) -def get_bloq_counts_graph( +def get_bloq_call_graph( bloq: Bloq, generalizer: Callable[[Bloq], Optional[Bloq]] = None, ssa: Optional[SympySymbolAllocator] = None, - keep: Optional[Sequence[Bloq]] = None, + keep: Optional[Callable[[Bloq], bool]] = None, + max_depth: Optional[int] = None, ) -> Tuple[nx.DiGraph, Dict[Bloq, Union[int, sympy.Expr]]]: - """Recursively gather bloq counts. + """Recursively build the bloq call graph. + + We stop recursing and keep a bloq as a leaf in the call graph if 1) `keep` is provided + and evaluates to True on the given bloq, 2) `max_depth` is provided and recursing would + exceed the maximum, or 3) if a bloq cannot be decomposed. Args: bloq: The bloq to count sub-bloqs. generalizer: If provided, run this function on each (sub)bloq to replace attributes - that do not affect resource estimates with generic sympy symbols. If this function - returns `None`, the bloq is ommitted from the counts graph. + that do not affect resource estimates with generic sympy symbols. If the function + returns `None`, the bloq is omitted from the counts graph. ssa: a `SympySymbolAllocator` that will be passed to the `Bloq.bloq_counts` methods. If your `generalizer` function closes over a `SympySymbolAllocator`, provide it here as well. Otherwise, we will create a new allocator. - keep: Stop recursing and keep these bloqs as leaf nodes in the counts graph. Otherwise, - leaf nodes are those without a decomposition. + keep: If this function evaluates to True for the current bloq, keep the bloq as a leaf + node in the call graph and stop recursing. + max_depth: If provided, stop recursing after the given depth. Returns: g: A directed graph where nodes are (generalized) bloqs and edge attribute 'n' counts @@ -163,7 +167,7 @@ def get_bloq_counts_graph( if ssa is None: ssa = SympySymbolAllocator() if keep is None: - keep = [] + keep = lambda b: False if generalizer is None: generalizer = lambda b: b @@ -171,7 +175,7 @@ def get_bloq_counts_graph( bloq = generalizer(bloq) if bloq is None: raise ValueError("You can't generalize away the root bloq.") - sigma = _descend_counts(bloq, g, ssa, generalizer, keep) + sigma = _recurse_call_graph(bloq, g, ssa, generalizer, keep, max_depth, depth=0) return g, sigma diff --git a/qualtran/resource_counting/bloq_counts_test.py b/qualtran/resource_counting/bloq_counts_test.py index 32065e1e0..81dd2e7b6 100644 --- a/qualtran/resource_counting/bloq_counts_test.py +++ b/qualtran/resource_counting/bloq_counts_test.py @@ -23,7 +23,7 @@ from qualtran import Bloq, BloqBuilder, Signature, SoquetT from qualtran.bloqs.basic_gates import TGate from qualtran.bloqs.util_bloqs import ArbitraryClifford, Join, Split -from qualtran.resource_counting import BloqCountT, get_bloq_counts_graph, SympySymbolAllocator +from qualtran.resource_counting import BloqCountT, get_bloq_call_graph, SympySymbolAllocator @frozen @@ -34,7 +34,7 @@ class BigBloq(Bloq): def signature(self) -> 'Signature': return Signature.build(x=self.bitsize) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']: + def build_call_graph(self, ssa: Optional['SympySymbolAllocator']) -> Set['BloqCountT']: return {(sympy.log(self.bitsize), SubBloq(unrelated_param=0.5))} @@ -62,7 +62,7 @@ class SubBloq(Bloq): def signature(self) -> 'Signature': return Signature.build(q=1) - def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']: + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: return {(3, TGate())} @@ -76,7 +76,7 @@ def generalize(bloq: Bloq) -> Optional[Bloq]: return bloq - return get_bloq_counts_graph(bloq, generalize, ss) + return get_bloq_call_graph(bloq, generalize, ss) def test_bloq_counts_method(): @@ -87,7 +87,7 @@ def test_bloq_counts_method(): def test_bloq_counts_decomp(): - graph, sigma = get_bloq_counts_graph(DecompBloq(10)) + graph, sigma = get_bloq_call_graph(DecompBloq(10)) assert len(sigma) == 3 # includes split and join expr = sigma[TGate()] assert str(expr) == '30' @@ -97,7 +97,7 @@ def generalize(bloq): return None return bloq - graph, sigma = get_bloq_counts_graph(DecompBloq(10), generalize) + graph, sigma = get_bloq_call_graph(DecompBloq(10), generalize) assert len(sigma) == 1 expr = sigma[TGate()] assert str(expr) == '30' diff --git a/qualtran/serialization/bloq.py b/qualtran/serialization/bloq.py index 2e1fa3c29..36de92204 100644 --- a/qualtran/serialization/bloq.py +++ b/qualtran/serialization/bloq.py @@ -34,7 +34,6 @@ from qualtran.bloqs.util_bloqs import Allocate, ArbitraryClifford, Free, Join, Split from qualtran.cirq_interop import CirqGateAsBloq from qualtran.protos import args_pb2, bloq_pb2 -from qualtran.resource_counting.bloq_counts import SympySymbolAllocator from qualtran.serialization import annotations, args, registers RESOLVER_DICT = { @@ -177,7 +176,7 @@ def bloqs_to_proto( try: bloq_counts = { bloq_to_idx[b]: args.int_or_sympy_to_proto(c) - for c, b in sorted(bloq.bloq_counts(SympySymbolAllocator()), key=lambda x: x[0]) + for b, c in sorted(bloq.bloq_counts().items(), key=lambda x: x[1]) } except (NotImplementedError, KeyError): # NotImplementedError is raised if `bloq` does not implement bloq_counts. @@ -279,7 +278,7 @@ def _populate_bloq_to_idx( # Approximately decompose the current Bloq and its decomposed Bloqs. try: - for _, subbloq in bloq.bloq_counts(SympySymbolAllocator()): + for subbloq, _ in bloq.bloq_counts().items(): _add_bloq_to_dict(subbloq, bloq_to_idx) _populate_bloq_to_idx(subbloq, bloq_to_idx, pred, 0)