Skip to content

Commit

Permalink
Change build_call_graph in bloqs to return dict (#1392)
Browse files Browse the repository at this point in the history
* Change build_call_graph in bloqs to return dict

- This changes the build_call_graph function within bloqs
in qualtran to return a dictionary of cost counts rather than
a set.
- This will allow the ordering of cost counts to be deterministic

Note that this requires some slight code changes for bloqs
that have multiple set items since (Toffoli(), 1) and (Toffoli(), 2)
would have two different items in a set, but share an index in the
dictionary.

This also may alter counts (i.e. fix a bug) where set items clobber
each other.  For instance, adding (Toffoli(), self.bits_a) and
(Toffoli(), self.bits_b) will previously give the wrong count
if bits_a == bits_b since the two items would be the same in the
set.
  • Loading branch information
dstrain115 authored Sep 7, 2024
1 parent 098f7ea commit 488b9a5
Show file tree
Hide file tree
Showing 100 changed files with 768 additions and 801 deletions.
4 changes: 2 additions & 2 deletions qualtran/_infra/composite_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

from qualtran.bloqs.bookkeeping.auto_partition import Unused
from qualtran.cirq_interop._cirq_to_bloq import CirqQuregInT, CirqQuregT
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT

# NDArrays must be bound to np.generic
Expand Down Expand Up @@ -237,7 +237,7 @@ def decompose_bloq(self) -> 'CompositeBloq':
"Consider using the composite bloq directly or using `.flatten()`."
)

def build_call_graph(self, ssa: Optional['SympySymbolAllocator']) -> Set['BloqCountT']:
def build_call_graph(self, ssa: Optional['SympySymbolAllocator']) -> 'BloqCountDictT':
"""Return the bloq counts by counting up all the subbloqs."""
from qualtran.resource_counting import build_cbloq_call_graph

Expand Down
11 changes: 5 additions & 6 deletions qualtran/bloqs/arithmetic/_shims.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@
will be fleshed out and moved to their final organizational location soon (written: 2024-05-06).
"""
from functools import cached_property
from typing import Set

from attrs import frozen

from qualtran import Bloq, QBit, QUInt, Register, Signature
from qualtran.bloqs.basic_gates import Toffoli
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator


@frozen
Expand All @@ -36,8 +35,8 @@ class MultiCToffoli(Bloq):
def signature(self) -> 'Signature':
return Signature([Register('ctrl', QBit(), shape=(self.n,)), Register('target', QBit())])

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(Toffoli(), self.n - 2)}
def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
return {Toffoli(): self.n - 2}


@frozen
Expand All @@ -51,9 +50,9 @@ def signature(self) -> 'Signature':
[Register('x', QUInt(self.n)), Register('y', QUInt(self.n)), Register('out', QBit())]
)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
# litinski
return {(Toffoli(), self.n)}
return {Toffoli(): self.n}


@frozen
Expand Down
25 changes: 15 additions & 10 deletions qualtran/bloqs/arithmetic/addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,12 @@

if TYPE_CHECKING:
from qualtran.drawing import WireSymbol
from qualtran.resource_counting import BloqCountDictT, BloqCountT, SympySymbolAllocator
from qualtran.resource_counting import (
BloqCountDictT,
BloqCountT,
MutableBloqCountDictT,
SympySymbolAllocator,
)
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics import SymbolicInt

Expand Down Expand Up @@ -209,10 +214,10 @@ def decompose_from_registers(
yield CNOT().on(input_bits[0], output_bits[0])
context.qubit_manager.qfree(ancillas)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
n = self.b_dtype.bitsize
n_cnot = (n - 2) * 6 + 3
return {(And(), n - 1), (And().adjoint(), n - 1), (CNOT(), n_cnot)}
return {And(): n - 1, And().adjoint(): n - 1, CNOT(): n_cnot}


@bloq_example(generalizer=ignore_split_join)
Expand Down Expand Up @@ -327,8 +332,8 @@ def decompose_from_registers(
]
return cirq.inverse(optree) if self.is_adjoint else optree

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(And(uncompute=self.is_adjoint), self.bitsize), (CNOT(), 5 * self.bitsize)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {And(uncompute=self.is_adjoint): self.bitsize, CNOT(): 5 * self.bitsize}

def __pow__(self, power: int):
if power == 1:
Expand Down Expand Up @@ -505,16 +510,16 @@ def build_composite_bloq(
def build_call_graph(
self, ssa: 'SympySymbolAllocator'
) -> Union['BloqCountDictT', Set['BloqCountT']]:
loading_cost: Tuple[Bloq, SymbolicInt]
loading_cost: MutableBloqCountDictT
if len(self.cvs) == 0:
loading_cost = (XGate(), self.bitsize) # upper bound; depends on the data.
loading_cost = {XGate(): self.bitsize} # upper bound; depends on the data.
elif len(self.cvs) == 1:
loading_cost = (CNOT(), self.bitsize) # upper bound; depends on the data.
loading_cost = {CNOT(): self.bitsize} # upper bound; depends on the data.
else:
# Otherwise, use the decomposition
return super().build_call_graph(ssa=ssa)

return {loading_cost, (Add(QUInt(self.bitsize)), 1)}
loading_cost[Add(QUInt(self.bitsize))] = 1
return loading_cost

def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']:
if self.cvs:
Expand Down
10 changes: 5 additions & 5 deletions qualtran/bloqs/arithmetic/bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from qualtran.symbolics import is_symbolic, SymbolicInt

if TYPE_CHECKING:
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT


Expand Down Expand Up @@ -90,9 +90,9 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: 'Soquet') -> dict[str, 'Soq

return {'x': x}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> set['BloqCountT']:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
num_flips = self.bitsize if self.is_symbolic() else sum(self._bits_k)
return {(XGate(), num_flips)}
return {XGate(): num_flips}

def on_classical_vals(self, x: 'ClassicalValT') -> dict[str, 'ClassicalValT']:
if isinstance(self.k, sympy.Expr):
Expand Down Expand Up @@ -156,8 +156,8 @@ def build_composite_bloq(self, bb: BloqBuilder, x: Soquet, y: Soquet) -> dict[st

return {'x': bb.join(xs, dtype=self.dtype), 'y': bb.join(ys, dtype=self.dtype)}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> set['BloqCountT']:
return {(CNOT(), self.dtype.num_qubits)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {CNOT(): self.dtype.num_qubits}

def on_classical_vals(
self, x: 'ClassicalValT', y: 'ClassicalValT'
Expand Down
108 changes: 51 additions & 57 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,7 @@

from collections import defaultdict
from functools import cached_property
from typing import (
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from typing import Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import attrs
import cirq
Expand Down Expand Up @@ -65,7 +54,11 @@

if TYPE_CHECKING:
from qualtran import BloqBuilder
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.resource_counting import (
BloqCountDictT,
MutableBloqCountDictT,
SympySymbolAllocator,
)
from qualtran.simulation.classical_sim import ClassicalValT


Expand Down Expand Up @@ -183,22 +176,22 @@ def decompose_from_registers(
def _has_unitary_(self):
return True

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
if (
not is_symbolic(self.less_than_val, self.bitsize)
and self.less_than_val >= 2**self.bitsize
):
return {(XGate(), 1)}
return {XGate(): 1}
num_set_bits = (
int(self.less_than_val).bit_count()
if not is_symbolic(self.less_than_val)
else self.bitsize
)
return {
(And(), self.bitsize),
(And().adjoint(), self.bitsize),
(CNOT(), num_set_bits + 2 * self.bitsize),
(XGate(), 2 * (1 + num_set_bits)),
And(): self.bitsize,
And().adjoint(): self.bitsize,
CNOT(): num_set_bits + 2 * self.bitsize,
XGate(): 2 * (1 + num_set_bits),
}


Expand Down Expand Up @@ -307,8 +300,8 @@ def __pow__(self, power: int) -> 'BiQubitsMixer':
return self.adjoint()
return NotImplemented # pragma: no cover

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(XGate(), 1), (CNOT(), 9), (And(uncompute=self.is_adjoint), 2)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {XGate(): 1, CNOT(): 9, And(uncompute=self.is_adjoint): 2}

def _has_unitary_(self):
return not self.is_adjoint
Expand Down Expand Up @@ -380,8 +373,8 @@ def __pow__(self, power: int) -> Union['SingleQubitCompare', cirq.Gate]:
return self.adjoint()
return self

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(XGate(), 1), (CNOT(), 4), (And(uncompute=self.is_adjoint), 1)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {XGate(): 1, CNOT(): 4, And(uncompute=self.is_adjoint): 1}


@bloq_example
Expand Down Expand Up @@ -575,13 +568,13 @@ def decompose_from_registers(
all_ancilla = set([q for op in adjoint for q in op.qubits if q not in input_qubits])
context.qubit_manager.qfree(all_ancilla)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
if is_symbolic(self.x_bitsize, self.y_bitsize):
return {
(BiQubitsMixer(), self.x_bitsize),
(BiQubitsMixer().adjoint(), self.x_bitsize),
(SingleQubitCompare(), 1),
(SingleQubitCompare().adjoint(), 1),
BiQubitsMixer(): self.x_bitsize,
BiQubitsMixer().adjoint(): self.x_bitsize,
SingleQubitCompare(): 1,
SingleQubitCompare().adjoint(): 1,
}

n = min(self.x_bitsize, self.y_bitsize)
Expand Down Expand Up @@ -613,7 +606,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
ret[And(1, 0).adjoint()] += 1
ret[CNOT()] += 1

return set(ret.items())
return ret

def _has_unitary_(self):
return True
Expand Down Expand Up @@ -691,8 +684,8 @@ def build_composite_bloq(
target = bb.add(XGate(), q=target)
return {'a': a, 'b': b, 'target': target}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(LessThanEqual(self.a_bitsize, self.b_bitsize), 1), (XGate(), 1)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {LessThanEqual(self.a_bitsize, self.b_bitsize): 1, XGate(): 1}


@bloq_example
Expand Down Expand Up @@ -885,23 +878,23 @@ def wire_symbol(
return TextBox('t⨁(a>b)')
raise ValueError(f'Unknown register name {reg.name}')

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
if self.bitsize == 1:
return {(MultiControlX(cvs=(1, 0)), 1)}
return {MultiControlX(cvs=(1, 0)): 1}

if self.signed:
return {
(CNOT(), 6 * self.bitsize - 7),
(XGate(), 2 * self.bitsize + 2),
(And(), self.bitsize - 1),
(And(uncompute=True), self.bitsize - 1),
CNOT(): 6 * self.bitsize - 7,
XGate(): 2 * self.bitsize + 2,
And(): self.bitsize - 1,
And(uncompute=True): self.bitsize - 1,
}

return {
(CNOT(), 6 * self.bitsize - 1),
(XGate(), 2 * self.bitsize + 4),
(And(), self.bitsize),
(And(uncompute=True), self.bitsize),
CNOT(): 6 * self.bitsize - 1,
XGate(): 2 * self.bitsize + 4,
And(): self.bitsize,
And(uncompute=True): self.bitsize,
}


Expand Down Expand Up @@ -941,8 +934,8 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
return TextBox(f"⨁(x > {self.val})")
raise ValueError(f'Unknown register symbol {reg.name}')

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(LessThanConstant(self.bitsize, less_than_val=self.val), 1)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {LessThanConstant(self.bitsize, less_than_val=self.val): 1}


@bloq_example
Expand Down Expand Up @@ -1007,8 +1000,8 @@ def build_composite_bloq(
x = bb.join(xs)
return {'x': x, 'target': target}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(MultiControlX(self.bits_k), 1)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {MultiControlX(self.bits_k): 1}


def _make_equals_a_constant():
Expand Down Expand Up @@ -1134,21 +1127,22 @@ def on_classical_vals(
return {'ctrl': ctrl, 'a': a, 'b': b, 'target': target ^ (a > b)}
return {'ctrl': ctrl, 'a': a, 'b': b, 'target': target}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
signed_ops = []
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
signed_ops: 'MutableBloqCountDictT' = {}
if isinstance(self.dtype, QInt):
signed_ops = [
(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), 2),
(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(), 2),
]
signed_ops = {
SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)): 2,
SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(): 2,
}
dtype = attrs.evolve(self.dtype, bitsize=self.dtype.bitsize + 1)
return {
(BitwiseNot(dtype), 2),
(BitwiseNot(QUInt(dtype.bitsize + 1)), 2),
(OutOfPlaceAdder(self.dtype.bitsize + 1).adjoint(), 1),
(OutOfPlaceAdder(self.dtype.bitsize + 1), 1),
(MultiControlX((self.cv, 1)), 1),
}.union(signed_ops)
BitwiseNot(dtype): 2,
BitwiseNot(QUInt(dtype.bitsize + 1)): 2,
OutOfPlaceAdder(self.dtype.bitsize + 1).adjoint(): 1,
OutOfPlaceAdder(self.dtype.bitsize + 1): 1,
MultiControlX((self.cv, 1)): 1,
**signed_ops,
}


@bloq_example(generalizer=ignore_split_join)
Expand Down
12 changes: 6 additions & 6 deletions qualtran/bloqs/arithmetic/controlled_addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Set, TYPE_CHECKING, Union
from typing import Dict, TYPE_CHECKING, Union

import numpy as np
import sympy
Expand Down Expand Up @@ -42,7 +42,7 @@
import quimb.tensor as qtn

from qualtran.drawing import WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT


Expand Down Expand Up @@ -155,11 +155,11 @@ def build_composite_bloq(
ctrl = bb.join(np.array([ctrl_q]))
return {'ctrl': ctrl, 'a': a, 'b': b}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {
(And(self.cv, 1), self.a_dtype.bitsize),
(Add(self.a_dtype, self.b_dtype), 1),
(And(self.cv, 1).adjoint(), self.a_dtype.bitsize),
And(self.cv, 1): self.a_dtype.bitsize,
Add(self.a_dtype, self.b_dtype): 1,
And(self.cv, 1).adjoint(): self.a_dtype.bitsize,
}


Expand Down
Loading

0 comments on commit 488b9a5

Please sign in to comment.