diff --git a/cirq-core/cirq/circuits/moment.py b/cirq-core/cirq/circuits/moment.py index 923adfbd0de..1d643d4da1a 100644 --- a/cirq-core/cirq/circuits/moment.py +++ b/cirq-core/cirq/circuits/moment.py @@ -36,11 +36,17 @@ from typing_extensions import Self import numpy as np +from scipy.cluster.hierarchy import DisjointSet from cirq import protocols, ops, qis, _compat from cirq._import import LazyLoader from cirq.ops import raw_types, op_tree -from cirq.protocols import circuit_diagram_info_protocol +from cirq.protocols import ( + circuit_diagram_info_protocol, + apply_unitary, + ApplyUnitaryArgs, + definitely_commutes, +) from cirq.type_workarounds import NotImplementedType if TYPE_CHECKING: @@ -657,10 +663,11 @@ def cleanup_key(key: Any) -> Any: return diagram.render() def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]: - """Determines whether Moment commutes with the Operation. + """Determines whether Moment commutes with either another Moment or + an Operation. Args: - other: An Operation object. Other types are not implemented yet. + other: An Operation or Moment object. Other types are not implemented yet. In case a different type is specified, NotImplemented is returned. atol: Absolute error tolerance. If all entries in v1@v2 - v2@v1 @@ -669,25 +676,92 @@ def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplem Returns: True: The Moment and Operation commute OR they don't have shared - quibits. + qubits. False: The two values do not commute. NotImplemented: In case we don't know how to check this, e.g. the parameter type is not supported yet. """ - if not isinstance(other, ops.Operation): - return NotImplemented - - other_qubits = set(other.qubits) - for op in self.operations: - if not other_qubits.intersection(set(op.qubits)): - continue - - commutes = protocols.commutes(op, other, atol=atol, default=NotImplemented) + if isinstance(other, ops.Operation): + # If an Operation is provided, convert this to a Moment consisting only + # of the given Operation + return self._commutes_(Moment(other), atol=atol) + + if isinstance(other, Moment): + # Check if sets of qubits overlap. If not, then no need to go any further. + if not set(self.qubits) & set(other.qubits): + return True + + # Check pairwise commuting between all pairs of + # operations. If they all commute then no + # need to go any further + if all( + definitely_commutes(op_1, op_2, atol=atol) + for op_1, op_2 in itertools.product(self.operations, other.operations) + ): + return True + + # Decompose into disjoint overlapping sets of qubits + qubit_subsets = [list(op.qubits) for op in self.operations + other.operations] + disjoint_set = DisjointSet(itertools.chain.from_iterable(qubit_subsets)) + for subset in qubit_subsets: + if len(subset) < 2: + continue + for k in range(len(subset) - 1): + disjoint_set.merge(subset[k], subset[k + 1]) + disjoint_qubit_subsets = disjoint_set.subsets() + + # Decompose both moments onto each disjoint set of qubits and + # check for commutation using the unitary representation + if all( + definitely_commutes( + self[disjoint_set]._unitary_on_qubits(list(disjoint_set)), + other[disjoint_set]._unitary_on_qubits(list(disjoint_set)), + atol=atol, + ) + for disjoint_set in disjoint_qubit_subsets + ): + return True + + return False + + return NotImplemented + + def _unitary_on_qubits(self, target_qubits: list['cirq.Qid']) -> np.ndarray: + """Returns the unitary representation of the given moment when acting + on the target qubits. + + .. note:: + + The :code:`target_qubits` must contain all the qubits that the + moment acts on. - if not commutes or commutes is NotImplemented: - return commutes + Args: + moment: The moment to decompose. + target_basis: The target qubits. - return True + Returns: + np.ndarray: The unitary representation of the Moment on the + target qubits. + """ + # Check moment has support on subset of target qubits and that there + # are no duplicates + current_qubits = self.qubits + assert all(qubit in target_qubits for qubit in current_qubits) + assert len(set(target_qubits)) == len(target_qubits) + # Define dims + total_qubits = len(target_qubits) + dim = 2**total_qubits + # Get the indices of the target qubit that the moment has support on + qubit_indices = [target_qubits.index(qubit) for qubit in current_qubits] + + # Get the tensor operation corresponding to the moment acting on the + # target qubits. + id_tensor = qis.eye_tensor((2,) * total_qubits, dtype=np.complex128) + unitary = apply_unitary( + self, args=ApplyUnitaryArgs(id_tensor, np.empty_like(id_tensor), qubit_indices) + ) + # Reshape into a square unitary matrix + return unitary.reshape(dim, dim) class _SortByValFallbackToType: diff --git a/cirq-core/cirq/circuits/moment_test.py b/cirq-core/cirq/circuits/moment_test.py index b31457223d9..08704066a33 100644 --- a/cirq-core/cirq/circuits/moment_test.py +++ b/cirq-core/cirq/circuits/moment_test.py @@ -693,6 +693,18 @@ def test_commutes(): assert not cirq.commutes(moment, cirq.X(c)) +def test_commutes_multiqubit_gates(): + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + c = cirq.NamedQubit("c") + + moment = cirq.Moment([cirq.Z(a), cirq.Z(b)]) + assert cirq.commutes(moment, cirq.XXPowGate(exponent=1 / 2)(a, b)) + + moment = cirq.Moment([cirq.XXPowGate(exponent=1 / 2)(a, b), cirq.Z(c)]) + assert not cirq.commutes(moment, cirq.Z(b)) + + def test_transform_qubits(): a, b = cirq.LineQubit.range(2) x, y = cirq.GridQubit.rect(2, 1, 10, 20)