Skip to content

Commit

Permalink
Bloq.call_graph
Browse files Browse the repository at this point in the history
  • Loading branch information
mpharrigan committed Oct 17, 2023
1 parent d4612af commit 6869779
Show file tree
Hide file tree
Showing 34 changed files with 337 additions and 192 deletions.
60 changes: 60 additions & 0 deletions dev_tools/bloq-method-overrides-report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ForwardRef, Set, Type

from qualtran_dev_tools.bloq_finder import get_bloq_classes

from qualtran import Bloq


def _call_graph(bc: Type[Bloq]):
"""Check that a bloq class overrides the right call graph methods.
- Override `build_call_graph` with canonical type annotations.
- Don't override `call_graph` or `bloq_counts`.
"""
call_graph = getattr(bc, 'call_graph')
if call_graph.__qualname__ != 'Bloq.call_graph':
print(f'{bc}.call_graph should not be overridden.')
raise ValueError(str(bc))

bloq_counts = getattr(bc, 'bloq_counts')
if bloq_counts.__qualname__ != 'Bloq.bloq_counts':
print(f'{bc}.bloq_counts should not be overriden.')

bcg = getattr(bc, 'build_call_graph')
annot = bcg.__annotations__
if set(annot.keys()) != {'ssa', 'return'}:
print(
f'{bc}.build_call_graph should have one argument named `ssa` '
f'and a return type annotation'
)
if annot['ssa'] != 'SympySymbolAllocator':
print(f"{bc}.build_call_graph `ssa: 'SympySymbolAllocator'`")
if annot['return'] != Set[ForwardRef('BloqCountT')]:
print(f"{bc}.build_call_graph -> 'BloqCountT'")


def report_call_graph_methods():
bcs = get_bloq_classes()
for bc in bcs:
_call_graph(bc)


def main():
report_call_graph_methods()


if __name__ == '__main__':
report_call_graph_methods()
94 changes: 94 additions & 0 deletions dev_tools/qualtran_dev_tools/bloq_finder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import subprocess
from pathlib import Path
from typing import Callable, Iterable, List, Type

from qualtran import Bloq

from .git_tools import get_git_root


def _get_paths(bloqs_root: Path, filter_func: Callable[[Path], bool]) -> List[Path]:
"""Get *.py files based on `filter_func`."""
cp = subprocess.run(
['git', 'ls-files', '*.py'], capture_output=True, universal_newlines=True, cwd=bloqs_root
)
outs = cp.stdout.splitlines()
paths = [Path(out) for out in outs]

paths = [path for path in paths if filter_func(path)]
return paths


def get_bloq_module_paths(bloqs_root: Path) -> List[Path]:
"""Get *.py files for non-test, non-init modules under `bloqs_root`."""

def is_module_path(path: Path) -> bool:
if path.name.endswith('_test.py'):
return False

if path.name == '__init__.py':
return False

return True

return _get_paths(bloqs_root, is_module_path)


def get_bloq_test_module_paths(bloqs_root: Path) -> List[Path]:
"""Get *_test.py files under `bloqs_root`."""

def is_test_module_path(path: Path) -> bool:
if not path.name.endswith('_test.py'):
return False

return True

return _get_paths(bloqs_root, is_test_module_path)


def _bloq_modpath_to_modname(path: Path) -> str:
"""Get the canonical, full module name given a module path."""
return 'qualtran.bloqs.' + str(path)[: -len('.py')].replace('/', '.')


def modpath_to_bloqs(path: Path) -> Iterable[Type[Bloq]]:
"""Given a module path, return all the `Bloq` classes defined within."""
modname = _bloq_modpath_to_modname(path)
mod = importlib.import_module(modname)
for name, cls in inspect.getmembers(mod, inspect.isclass):
if cls.__module__ != modname:
# Perhaps from an import
continue

if not issubclass(cls, Bloq):
continue

if cls.__name__.startswith('_'):
continue

yield cls


def get_bloq_classes():
reporoot = get_git_root()
bloqs_root = reporoot / 'qualtran/bloqs'
paths = get_bloq_module_paths(bloqs_root)
bloq_clss: List[Type[Bloq]] = []
for path in paths:
bloq_clss.extend(modpath_to_bloqs(path))
return bloq_clss
11 changes: 5 additions & 6 deletions qualtran/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@
DidNotFlattenAnythingError,
SoquetT,
)

# Internal imports: none
# External:
# - numpy: multiplying bitsizes, making cirq quregs
from ._infra.registers import Register, SelectionRegister, Signature, Side
from ._infra.gate_with_registers import GateWithRegisters

# Internal imports: none
# External imports: none
Expand All @@ -58,6 +54,9 @@
Soquet,
)

from ._infra.gate_with_registers import GateWithRegisters
# Internal imports: none
# External:
# - numpy: multiplying bitsizes, making cirq quregs
from ._infra.registers import Register, SelectionRegister, Signature, Side

# --------------------------------------------------------------------------------------------------
38 changes: 30 additions & 8 deletions qualtran/_infra/bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
"""Contains the main interface for defining `Bloq`s."""

import abc
from typing import Any, Dict, Optional, Set, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union

if TYPE_CHECKING:
import cirq
import networkx as nx
import quimb.tensor as qtn
import sympy
from numpy.typing import NDArray

from qualtran import BloqBuilder, CompositeBloq, Signature, Soquet, SoquetT
Expand Down Expand Up @@ -231,16 +233,36 @@ def add_my_tensors(
)
tn.add(qtn.Tensor(data=data, inds=inds, tags=[self.short_name(), tag]))

def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']:
"""Return a set of `(n, bloq)` tuples where bloq is used `n` times in the decomposition.
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
"""Override this method to build the bloq call graph.
This method must return a set of `(n, bloq)` tuples where bloq is called `n` times in
the decomposition.
By default, this method will use `self.decompose_bloq()` to count up bloqs.
However, you can override this if you don't want to provide a complete decomposition,
if you know symbolic expressions for the counts, or if you need to "generalize"
the subbloqs by overwriting bloq attributes that do not affect its cost with generic
sympy symbols (perhaps with the aid of the provided `SympySymbolAllocator`).
However, you can provide specific callees directly by overriding this method if
1) you don't want to provide a complete decomposition, 2) you know symbolic expressions
for the counts, 3) or you need to "generalize" the subbloqs by overwriting bloq
attributes that do not affect its cost with generic sympy symbols using
the provided `SympySymbolAllocator`.
"""
return self.decompose_bloq().bloq_counts(ssa)
return self.decompose_bloq().build_call_graph(ssa)

def call_graph(
self,
generalizer: Callable[['Bloq'], Optional['Bloq']] = None,
keep: Optional[Sequence['Bloq']] = None,
max_depth: Optional[int] = None,
) -> Tuple['nx.DiGraph', Dict['Bloq', Union[int, 'sympy.Expr']]]:
from qualtran.resource_counting.bloq_counts import get_bloq_call_graph

return get_bloq_call_graph(self, generalizer=generalizer, keep=keep, max_depth=max_depth)

def bloq_counts(
self, generalizer: Callable[['Bloq'], Optional['Bloq']] = None
) -> Dict['Bloq', Union[int, 'sympy.Expr']]:
_, sigma = self.call_graph(generalizer=generalizer, max_depth=1)
return sigma

def t_complexity(self) -> 'TComplexity':
"""The `TComplexity` for this bloq.
Expand Down
6 changes: 3 additions & 3 deletions qualtran/_infra/composite_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,11 @@ def as_composite_bloq(self) -> 'CompositeBloq':
def decompose_bloq(self) -> 'CompositeBloq':
raise NotImplementedError("Come back later.")

def bloq_counts(self, _: Optional['SympySymbolAllocator'] = None) -> Set['BloqCountT']:
def build_call_graph(self, ssa: Optional['SympySymbolAllocator']) -> Set['BloqCountT']:
"""Return the bloq counts by counting up all the subbloqs."""
from qualtran.resource_counting import get_cbloq_bloq_counts
from qualtran.resource_counting.bloq_counts import _build_cbloq_counts_graph

return get_cbloq_bloq_counts(self)
return _build_cbloq_counts_graph(self)

def iter_bloqnections(
self,
Expand Down
13 changes: 6 additions & 7 deletions qualtran/bloqs/and_bloq.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,11 @@
"metadata": {},
"outputs": [],
"source": [
"from qualtran.resource_counting import get_bloq_counts_graph\n",
"from qualtran.drawing import show_counts_graph\n",
"from qualtran.drawing import show_call_graph\n",
"import attrs\n",
"\n",
"graph, sigma = get_bloq_counts_graph(bloq)\n",
"show_counts_graph(graph)"
"graph, sigma = bloq.call_graph()\n",
"show_call_graph(graph)"
]
},
{
Expand Down Expand Up @@ -166,8 +165,8 @@
"metadata": {},
"outputs": [],
"source": [
"graph, sigma = get_bloq_counts_graph(bloq)\n",
"show_counts_graph(graph)"
"graph, sigma = bloq.call_graph()\n",
"show_call_graph(graph)"
]
},
{
Expand Down Expand Up @@ -353,7 +352,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.9"
}
},
"nbformat": 4,
Expand Down
8 changes: 4 additions & 4 deletions qualtran/bloqs/and_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import itertools
from functools import cached_property
from typing import Any, Dict, Optional, Set, Tuple
from typing import Any, Dict, Set, Tuple

import cirq
import numpy as np
Expand All @@ -23,12 +23,12 @@
from attrs import field, frozen
from numpy.typing import NDArray

from qualtran import Bloq, GateWithRegisters, Register, Side, Signature, Soquet, SoquetT
from qualtran import GateWithRegisters, Register, Side, Signature, Soquet, SoquetT
from qualtran.bloqs.basic_gates import TGate
from qualtran.bloqs.util_bloqs import ArbitraryClifford
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.drawing import Circle, directional_text_box, WireSymbol
from qualtran.resource_counting import big_O, SympySymbolAllocator
from qualtran.resource_counting import big_O, BloqCountT, SympySymbolAllocator


@frozen
Expand Down Expand Up @@ -63,7 +63,7 @@ def signature(self) -> Signature:
]
)

def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set[Tuple[int, Bloq]]:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
if isinstance(self.cv1, sympy.Expr) or isinstance(self.cv2, sympy.Expr):
pre_post_cliffords = big_O(1)
else:
Expand Down
Loading

0 comments on commit 6869779

Please sign in to comment.