-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add functionalty for determining whether pairs of moments commute #6679
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,11 +36,17 @@ | |
from typing_extensions import Self | ||
|
||
import numpy as np | ||
from scipy.cluster.hierarchy import DisjointSet | ||
|
||
from cirq import protocols, ops, qis, _compat | ||
from cirq._import import LazyLoader | ||
from cirq.ops import raw_types, op_tree | ||
from cirq.protocols import circuit_diagram_info_protocol | ||
from cirq.protocols import ( | ||
circuit_diagram_info_protocol, | ||
apply_unitary, | ||
ApplyUnitaryArgs, | ||
definitely_commutes, | ||
) | ||
from cirq.type_workarounds import NotImplementedType | ||
|
||
if TYPE_CHECKING: | ||
|
@@ -657,10 +663,11 @@ def cleanup_key(key: Any) -> Any: | |
return diagram.render() | ||
|
||
def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]: | ||
"""Determines whether Moment commutes with the Operation. | ||
"""Determines whether Moment commutes with either another Moment or | ||
an Operation. | ||
|
||
Args: | ||
other: An Operation object. Other types are not implemented yet. | ||
other: An Operation or Moment object. Other types are not implemented yet. | ||
In case a different type is specified, NotImplemented is | ||
returned. | ||
atol: Absolute error tolerance. If all entries in v1@v2 - v2@v1 | ||
|
@@ -669,25 +676,92 @@ def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplem | |
|
||
Returns: | ||
True: The Moment and Operation commute OR they don't have shared | ||
quibits. | ||
qubits. | ||
False: The two values do not commute. | ||
NotImplemented: In case we don't know how to check this, e.g. | ||
the parameter type is not supported yet. | ||
""" | ||
if not isinstance(other, ops.Operation): | ||
return NotImplemented | ||
|
||
other_qubits = set(other.qubits) | ||
for op in self.operations: | ||
if not other_qubits.intersection(set(op.qubits)): | ||
continue | ||
|
||
commutes = protocols.commutes(op, other, atol=atol, default=NotImplemented) | ||
if isinstance(other, ops.Operation): | ||
# If an Operation is provided, convert this to a Moment consisting only | ||
# of the given Operation | ||
return self._commutes_(Moment(other), atol=atol) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not preserve the old code? ... perhaps in a |
||
|
||
if isinstance(other, Moment): | ||
# Check if sets of qubits overlap. If not, then no need to go any further. | ||
if not set(self.qubits) & set(other.qubits): | ||
return True | ||
|
||
# Check pairwise commuting between all pairs of | ||
# operations. If they all commute then no | ||
# need to go any further | ||
if all( | ||
definitely_commutes(op_1, op_2, atol=atol) | ||
for op_1, op_2 in itertools.product(self.operations, other.operations) | ||
): | ||
return True | ||
|
||
# Decompose into disjoint overlapping sets of qubits | ||
qubit_subsets = [list(op.qubits) for op in self.operations + other.operations] | ||
disjoint_set = DisjointSet(itertools.chain.from_iterable(qubit_subsets)) | ||
for subset in qubit_subsets: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happens if a qubit appears in both moments in operations with different qubits? |
||
if len(subset) < 2: | ||
continue | ||
for k in range(len(subset) - 1): | ||
disjoint_set.merge(subset[k], subset[k + 1]) | ||
disjoint_qubit_subsets = disjoint_set.subsets() | ||
|
||
# Decompose both moments onto each disjoint set of qubits and | ||
# check for commutation using the unitary representation | ||
if all( | ||
definitely_commutes( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would this code raise an error if the operations are not unitary? |
||
self[disjoint_set]._unitary_on_qubits(list(disjoint_set)), | ||
other[disjoint_set]._unitary_on_qubits(list(disjoint_set)), | ||
atol=atol, | ||
) | ||
for disjoint_set in disjoint_qubit_subsets | ||
): | ||
return True | ||
|
||
return False | ||
|
||
return NotImplemented | ||
|
||
def _unitary_on_qubits(self, target_qubits: list['cirq.Qid']) -> np.ndarray: | ||
"""Returns the unitary representation of the given moment when acting | ||
on the target qubits. | ||
|
||
.. note:: | ||
|
||
The :code:`target_qubits` must contain all the qubits that the | ||
moment acts on. | ||
|
||
if not commutes or commutes is NotImplemented: | ||
return commutes | ||
Args: | ||
moment: The moment to decompose. | ||
target_basis: The target qubits. | ||
|
||
return True | ||
Returns: | ||
np.ndarray: The unitary representation of the Moment on the | ||
target qubits. | ||
""" | ||
# Check moment has support on subset of target qubits and that there | ||
# are no duplicates | ||
current_qubits = self.qubits | ||
assert all(qubit in target_qubits for qubit in current_qubits) | ||
assert len(set(target_qubits)) == len(target_qubits) | ||
# Define dims | ||
total_qubits = len(target_qubits) | ||
dim = 2**total_qubits | ||
# Get the indices of the target qubit that the moment has support on | ||
qubit_indices = [target_qubits.index(qubit) for qubit in current_qubits] | ||
|
||
# Get the tensor operation corresponding to the moment acting on the | ||
# target qubits. | ||
id_tensor = qis.eye_tensor((2,) * total_qubits, dtype=np.complex128) | ||
unitary = apply_unitary( | ||
self, args=ApplyUnitaryArgs(id_tensor, np.empty_like(id_tensor), qubit_indices) | ||
) | ||
# Reshape into a square unitary matrix | ||
return unitary.reshape(dim, dim) | ||
|
||
|
||
class _SortByValFallbackToType: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -693,6 +693,18 @@ def test_commutes(): | |
assert not cirq.commutes(moment, cirq.X(c)) | ||
|
||
|
||
def test_commutes_multiqubit_gates(): | ||
a = cirq.NamedQubit('a') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this test case is very simple and doesn't cover the more complex cases ... see above |
||
b = cirq.NamedQubit('b') | ||
c = cirq.NamedQubit("c") | ||
|
||
moment = cirq.Moment([cirq.Z(a), cirq.Z(b)]) | ||
assert cirq.commutes(moment, cirq.XXPowGate(exponent=1 / 2)(a, b)) | ||
|
||
moment = cirq.Moment([cirq.XXPowGate(exponent=1 / 2)(a, b), cirq.Z(c)]) | ||
assert not cirq.commutes(moment, cirq.Z(b)) | ||
|
||
|
||
def test_transform_qubits(): | ||
a, b = cirq.LineQubit.range(2) | ||
x, y = cirq.GridQubit.rect(2, 1, 10, 20) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you fix the formatting?