Skip to content

Commit

Permalink
Remove short_name method (#934)
Browse files Browse the repository at this point in the history
* Remove short_name method

- Replace with wire_symbol(reg=None)
- Gates with no title label will now return Text('') and not print out a title.
  • Loading branch information
dstrain115 authored May 14, 2024
1 parent a4a0f92 commit 03d09a0
Show file tree
Hide file tree
Showing 77 changed files with 459 additions and 334 deletions.
15 changes: 9 additions & 6 deletions qualtran/_infra/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import cached_property
from typing import Dict, List, Set, Tuple, TYPE_CHECKING
from typing import cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING

import cirq
from attrs import frozen
Expand Down Expand Up @@ -170,10 +170,6 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
"""The call graph takes the adjoint of each of the bloqs in `subbloq`'s call graph."""
return {(bloq.adjoint(), n) for bloq, n in self.subbloq.build_call_graph(ssa=ssa)}

def short_name(self) -> str:
"""The subbloq's short_name with a dagger."""
return self.subbloq.short_name() + '†'

def pretty_name(self) -> str:
"""The subbloq's pretty_name with a dagger."""
return self.subbloq.pretty_name() + '†'
Expand All @@ -182,10 +178,17 @@ def __str__(self) -> str:
"""Delegate to subbloq's `__str__` method."""
return f'Adjoint(subbloq={str(self.subbloq)})'

def wire_symbol(self, reg: 'Register', idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
# Note: since we pass are passed a soquet which has the 'new' side, we flip it before
# delegating and then flip back. Subbloqs only have to answer this protocol
# if the provided soquet is facing the correct direction.
from qualtran.drawing import Text

if reg is None:
return Text(cast(Text, self.subbloq.wire_symbol(reg=None)).text + '†')

return self.subbloq.wire_symbol(reg=reg.adjoint(), idx=idx).adjoint()

def _t_complexity_(self):
Expand Down
8 changes: 4 additions & 4 deletions qualtran/_infra/adjoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Dict, TYPE_CHECKING
from typing import cast, Dict, TYPE_CHECKING

import pytest
import sympy
Expand All @@ -25,7 +25,7 @@
from qualtran.bloqs.for_testing.with_call_graph import TestBloqWithCallGraph
from qualtran.bloqs.for_testing.with_decomposition import TestParallelCombo, TestSerialCombo
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.drawing import LarrowTextBox, RarrowTextBox
from qualtran.drawing import LarrowTextBox, RarrowTextBox, Text

if TYPE_CHECKING:
from qualtran import BloqBuilder, SoquetT
Expand Down Expand Up @@ -149,11 +149,11 @@ def test_call_graph():
def test_names():
atom = TestAtom()
assert atom.pretty_name() == "TestAtom"
assert atom.short_name() == "Atom"
assert cast(Text, atom.wire_symbol(reg=None)).text == "TestAtom"

adj_atom = Adjoint(atom)
assert adj_atom.pretty_name() == "TestAtom†"
assert adj_atom.short_name() == "Atom†"
assert cast(Text, adj_atom.wire_symbol(reg=None)).text == "TestAtom†"
assert str(adj_atom) == "Adjoint(subbloq=TestAtom())"


Expand Down
27 changes: 17 additions & 10 deletions qualtran/_infra/bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,6 @@ def signature(self) -> 'Signature':
def pretty_name(self) -> str:
return self.__class__.__name__

def short_name(self) -> str:
name = self.pretty_name()
if len(name) <= 10:
return name

return name[:8] + '..'

def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
"""Override this method to define a Bloq in terms of its constituent parts.
Expand Down Expand Up @@ -282,7 +275,9 @@ def add_my_tensors(
from qualtran.simulation.tensor import cbloq_as_contracted_tensor

cbloq = self.decompose_bloq()
tn.add(cbloq_as_contracted_tensor(cbloq, incoming, outgoing, tags=[self.short_name(), tag]))
tn.add(
cbloq_as_contracted_tensor(cbloq, incoming, outgoing, tags=[self.pretty_name(), tag])
)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
"""Override this method to build the bloq call graph.
Expand Down Expand Up @@ -508,18 +503,30 @@ def on_registers(

return self.on(*merge_qubits(self.signature, **qubit_regs))

def wire_symbol(self, reg: 'Register', idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
"""On a musical score visualization, use this `WireSymbol` to represent `soq`.
By default, we use a "directional text box", which is a text box that is either
rectangular for thru-registers or facing to the left or right for non-thru-registers.
If reg is specified as `None`, this should return a Text label for the title of
the gate. If no title is needed (as the wire_symbols are self-explanatory),
this should return `Text('')`.
Override this method to provide a more relevant `WireSymbol` for the provided soquet.
This method can access bloq attributes. For example: you may want to draw either
a filled or empty circle for a control register depending on a control value bloq
attribute.
"""
from qualtran.drawing import directional_text_box
from qualtran.drawing import directional_text_box, Text

if reg is None:
name = self.pretty_name()
if len(name) <= 10:
return Text(name)
return Text(name[:8] + '..')

label = reg.name
if len(idx) > 0:
Expand Down
13 changes: 6 additions & 7 deletions qualtran/_infra/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def add_my_tensors(
subbloq_shape = tensor_shape_from_signature(self.subbloq.signature)
data[active_idx] = self.subbloq.tensor_contract().reshape(subbloq_shape)
# Add the data to the tensor network.
tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.short_name(), tag]))
tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.pretty_name(), tag]))

def _unitary_(self):
if isinstance(self.subbloq, GateWithRegisters):
Expand All @@ -433,11 +433,13 @@ def _unitary_(self):
# Unable to determine the unitary effect.
return NotImplemented

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
from qualtran.drawing import Text

if reg is None:
return Text(f'C[{self.subbloq.wire_symbol(reg=None)}]')
if reg.name not in self.ctrl_reg_names:
# Delegate to subbloq
print(self.subbloq)
print(type(self.subbloq))
return self.subbloq.wire_symbol(reg, idx)

# Otherwise, it's part of the control register.
Expand All @@ -450,9 +452,6 @@ def adjoint(self) -> 'Bloq':
def pretty_name(self) -> str:
return f'C[{self.subbloq.pretty_name()}]'

def short_name(self) -> str:
return f'C[{self.subbloq.short_name()}]'

def __str__(self) -> str:
return f'C[{self.subbloq}]'

Expand Down
14 changes: 6 additions & 8 deletions qualtran/_infra/gate_with_registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,12 @@ def as_cirq_op(
)
return self.on_registers(**all_quregs), out_quregs

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
from qualtran.cirq_interop._cirq_to_bloq import _wire_symbol_from_gate
from qualtran.drawing import Text

if reg is None:
return Text(self.pretty_name())

return _wire_symbol_from_gate(self, self.signature, reg, idx)

Expand Down Expand Up @@ -515,13 +519,7 @@ def add_my_tensors(
from qualtran.cirq_interop._cirq_to_bloq import _add_my_tensors_from_gate

_add_my_tensors_from_gate(
self,
self.signature,
self.short_name(),
tn,
tag,
incoming=incoming,
outgoing=outgoing,
self, self.signature, str(self), tn, tag, incoming=incoming, outgoing=outgoing
)
else:
return super().add_my_tensors(tn, tag, incoming=incoming, outgoing=outgoing)
Expand Down
15 changes: 7 additions & 8 deletions qualtran/bloqs/arithmetic/addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def add_my_tensors(
for a, b in itertools.product(range(N_a), range(N_b)):
unitary[a, b, a, int(math.fmod(a + b, N_b))] = 1

tn.add(qtn.Tensor(data=unitary, inds=inds, tags=[self.short_name(), tag]))
tn.add(qtn.Tensor(data=unitary, inds=inds, tags=[self.pretty_name(), tag]))

def decompose_bloq(self) -> 'CompositeBloq':
return decompose_from_cirq_style_method(self)
Expand All @@ -155,17 +155,16 @@ def on_classical_vals(
N = 2**b_bitsize if unsigned else 2 ** (b_bitsize - 1)
return {'a': a, 'b': int(math.fmod(a + b, N))}

def short_name(self) -> str:
return "a+b"

def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
wire_symbols = ["In(x)"] * int(self.a_dtype.bitsize)
wire_symbols += ["In(y)/Out(x+y)"] * int(self.b_dtype.bitsize)
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
from qualtran.drawing import directional_text_box
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
from qualtran.drawing import directional_text_box, Text

if reg is None:
return Text("a+b")
if reg.name == 'a':
return directional_text_box('a', side=reg.side)
elif reg.name == 'b':
Expand Down Expand Up @@ -318,7 +317,7 @@ def on_classical_vals(
def with_registers(self, *new_registers: Union[int, Sequence[int]]):
raise NotImplementedError("no need to implement with_registers.")

def short_name(self) -> str:
def pretty_name(self) -> str:
return "c = a + b"

def decompose_from_registers(
Expand Down Expand Up @@ -501,7 +500,7 @@ def build_composite_bloq(
else:
return {'x': x}

def short_name(self) -> str:
def pretty_name(self) -> str:
return f'x += {self.k}'


Expand Down
60 changes: 40 additions & 20 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,18 @@
# limitations under the License.

from functools import cached_property
from typing import Dict, Iterable, Iterator, List, Sequence, Set, Tuple, TYPE_CHECKING, Union
from typing import (
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
)

import attrs
import cirq
Expand Down Expand Up @@ -42,7 +53,7 @@
from qualtran.cirq_interop.bit_tools import iter_bits
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity
from qualtran.drawing import WireSymbol
from qualtran.drawing.musical_score import TextBox
from qualtran.drawing.musical_score import Text, TextBox

if TYPE_CHECKING:
from qualtran import BloqBuilder
Expand All @@ -62,8 +73,12 @@ class LessThanConstant(GateWithRegisters, cirq.ArithmeticGate): # type: ignore[
def signature(self) -> Signature:
return Signature.build_from_dtypes(x=QUInt(self.bitsize), target=QBit())

def short_name(self) -> str:
return f'x<{self.less_than_val}'
def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
if reg is None:
return Text(f'x<{self.less_than_val}')
return super().wire_symbol(reg, idx)

def registers(self) -> Sequence[Union[int, Sequence[int]]]:
return [2] * self.bitsize, self.less_than_val, [2]
Expand Down Expand Up @@ -428,8 +443,12 @@ def apply(self, *register_vals: int) -> Union[int, int, Iterable[int]]:
x_val, y_val, target_val = register_vals
return x_val, y_val, target_val ^ (x_val <= y_val)

def short_name(self) -> str:
return 'x <= y'
def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
if reg is None:
return Text('x <= y')
return super().wire_symbol(reg, idx)

def on_classical_vals(self, *, x: int, y: int, target: int) -> Dict[str, 'ClassicalValT']:
return {'x': x, 'y': y, 'target': target ^ (x <= y)}
Expand Down Expand Up @@ -599,16 +618,15 @@ def signature(self):
a=QUInt(self.a_bitsize), b=QUInt(self.b_bitsize), target=QBit()
)

def short_name(self) -> str:
return "a>b"

def _t_complexity_(self) -> 'TComplexity':
# TODO Determine precise clifford count and/or ignore.
# See: https://github.com/quantumlib/Qualtran/issues/219
# See: https://github.com/quantumlib/Qualtran/issues/217
return t_complexity(LessThanEqual(self.a_bitsize, self.b_bitsize))

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> WireSymbol:
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> WireSymbol:
if reg is None:
return Text("a>b")
if reg.name == 'a':
return TextBox("In(a)")
if reg.name == 'b':
Expand Down Expand Up @@ -799,8 +817,12 @@ def build_composite_bloq(
# Return the output registers.
return {'a': a, 'b': b, 'target': target}

def short_name(self) -> str:
return "a > b"
def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
if reg is None:
return Text('a > b')
return super().wire_symbol(reg, idx)


@frozen
Expand Down Expand Up @@ -836,10 +858,9 @@ def _t_complexity_(self) -> TComplexity:
# See: https://github.com/quantumlib/Qualtran/issues/217
return t_complexity(LessThanConstant(self.bitsize, less_than_val=self.val))

def short_name(self) -> str:
return f"x > {self.val}"

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> WireSymbol:
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> WireSymbol:
if reg is None:
return Text(f"x > {self.val}")
if reg.name == 'x':
return TextBox("In(x)")
elif reg.name == 'target':
Expand Down Expand Up @@ -889,10 +910,9 @@ def signature(self) -> Signature:
def _t_complexity_(self) -> 'TComplexity':
return TComplexity(t=4 * (self.bitsize - 1))

def short_name(self) -> str:
return f"x == {self.val}"

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> WireSymbol:
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> WireSymbol:
if reg is None:
return Text(f"x == {self.val}")
if reg.name == 'x':
return TextBox("In(x)")
elif reg.name == 'target':
Expand Down
Loading

0 comments on commit 03d09a0

Please sign in to comment.