From 64861ee86e87995112fbfc98e932d9b69d02746e Mon Sep 17 00:00:00 2001 From: Cameron Booker Date: Fri, 26 Jul 2024 09:25:43 +0100 Subject: [PATCH 1/2] feature/add_moment_commuting_function --- cirq-core/cirq/circuits/moment.py | 101 +++++++++++++++++++++++++----- 1 file changed, 85 insertions(+), 16 deletions(-) diff --git a/cirq-core/cirq/circuits/moment.py b/cirq-core/cirq/circuits/moment.py index 5128021bcb6..efc2fbbda4f 100644 --- a/cirq-core/cirq/circuits/moment.py +++ b/cirq-core/cirq/circuits/moment.py @@ -36,11 +36,12 @@ from typing_extensions import Self import numpy as np +from scipy.cluster.hierarchy import DisjointSet from cirq import protocols, ops, qis, _compat from cirq._import import LazyLoader from cirq.ops import raw_types, op_tree -from cirq.protocols import circuit_diagram_info_protocol +from cirq.protocols import circuit_diagram_info_protocol, apply_unitary, ApplyUnitaryArgs from cirq.type_workarounds import NotImplementedType if TYPE_CHECKING: @@ -648,10 +649,11 @@ def cleanup_key(key: Any) -> Any: return diagram.render() def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]: - """Determines whether Moment commutes with the Operation. + """Determines whether Moment commutes with either another Moment or + an Operation. Args: - other: An Operation object. Other types are not implemented yet. + other: An Operation or Moment object. Other types are not implemented yet. In case a different type is specified, NotImplemented is returned. atol: Absolute error tolerance. If all entries in v1@v2 - v2@v1 @@ -660,25 +662,92 @@ def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplem Returns: True: The Moment and Operation commute OR they don't have shared - quibits. + qubits. False: The two values do not commute. NotImplemented: In case we don't know how to check this, e.g. the parameter type is not supported yet. """ - if not isinstance(other, ops.Operation): - return NotImplemented - - other_qubits = set(other.qubits) - for op in self.operations: - if not other_qubits.intersection(set(op.qubits)): - continue - - commutes = protocols.commutes(op, other, atol=atol, default=NotImplemented) + if isinstance(other, ops.Operation): + # If an Operation is provided, convert this to a Moment consisting only + # of the given Operation + return self._commutes_(Moment(other), atol=atol) + + if isinstance(other, Moment): + # Check if sets of qubits overlap. If not, then no need to go any further. + if not set(self.qubits) & set(other.qubits): + return True + + # Check pairwise commuting between all pairs of + # operations. If they all commute then no + # need to go any further + if all( + cirq.definitely_commutes(op_1, op_2, atol=atol) + for op_1, op_2 in itertools.product(self.operations, other.operations) + ): + return True + + # Decompose into disjoint overlapping sets of qubits + qubit_subsets = [list(op.qubits) for op in self.operations + other.operations] + disjoint_set = DisjointSet(itertools.chain.from_iterable(qubit_subsets)) + for subset in qubit_subsets: + if len(subset) < 2: + continue + for k in range(len(subset) - 1): + disjoint_set.merge(subset[k], subset[k + 1]) + disjoint_qubit_subsets = disjoint_set.subsets() + + # Decompose both moments onto each disjoint set of qubits and + # check for commutation using the unitary representation + if all( + cirq.definitely_commutes( + self._unitary_on_qubits(list(disjoint_set)), + other._unitary_on_qubits(list(disjoint_set)), + atol=atol, + ) + for disjoint_set in disjoint_qubit_subsets + ): + return True + + return False + + return NotImplemented + + def _unitary_on_qubits(self, target_qubits: list['cirq.Qid']) -> np.ndarray: + """Returns the unitary representation of the given moment when acting + on the target qubits. + + .. note:: + + The :code:`target_qubits` must contain all the qubits that the + moment acts on. - if not commutes or commutes is NotImplemented: - return commutes + Args: + moment: The moment to decompose. + target_basis: The target qubits. - return True + Returns: + np.ndarray: The unitary representation of the Moment on the + target qubits. + """ + # Check moment has support on subset of target qubits and that there + # are no duplicates + current_qubits = self.qubits + assert all(qubit in target_qubits for qubit in current_qubits) + assert len(set(target_qubits)) == len(target_qubits) + # Define dims + total_qubits = len(target_qubits) + dim = 2**total_qubits + # Get the indices of the target qubit that the moment has support on + qubit_indices = [target_qubits.index(qubit) for qubit in current_qubits] + + # Get the tensor operation corresponding to the moment acting on the + # target qubits. + id_tensor = cirq.qis.eye_tensor((2,) * total_qubits, dtype=np.complex128) + unitary = apply_unitary( + self, args=ApplyUnitaryArgs(id_tensor, np.empty_like(id_tensor), qubit_indices) + ) + # Reshape into a square unitary matrix + return unitary.reshape(dim, dim) class _SortByValFallbackToType: From acebb77026932f5ead7863b94b56658a68baacde Mon Sep 17 00:00:00 2001 From: Cameron Booker Date: Fri, 26 Jul 2024 12:12:29 +0100 Subject: [PATCH 2/2] feature: Add unit tests --- cirq-core/cirq/circuits/moment.py | 17 +++++++++++------ cirq-core/cirq/circuits/moment_test.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/circuits/moment.py b/cirq-core/cirq/circuits/moment.py index efc2fbbda4f..8c4f4285ca9 100644 --- a/cirq-core/cirq/circuits/moment.py +++ b/cirq-core/cirq/circuits/moment.py @@ -41,7 +41,12 @@ from cirq import protocols, ops, qis, _compat from cirq._import import LazyLoader from cirq.ops import raw_types, op_tree -from cirq.protocols import circuit_diagram_info_protocol, apply_unitary, ApplyUnitaryArgs +from cirq.protocols import ( + circuit_diagram_info_protocol, + apply_unitary, + ApplyUnitaryArgs, + definitely_commutes, +) from cirq.type_workarounds import NotImplementedType if TYPE_CHECKING: @@ -681,7 +686,7 @@ def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplem # operations. If they all commute then no # need to go any further if all( - cirq.definitely_commutes(op_1, op_2, atol=atol) + definitely_commutes(op_1, op_2, atol=atol) for op_1, op_2 in itertools.product(self.operations, other.operations) ): return True @@ -699,9 +704,9 @@ def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplem # Decompose both moments onto each disjoint set of qubits and # check for commutation using the unitary representation if all( - cirq.definitely_commutes( - self._unitary_on_qubits(list(disjoint_set)), - other._unitary_on_qubits(list(disjoint_set)), + definitely_commutes( + self[disjoint_set]._unitary_on_qubits(list(disjoint_set)), + other[disjoint_set]._unitary_on_qubits(list(disjoint_set)), atol=atol, ) for disjoint_set in disjoint_qubit_subsets @@ -742,7 +747,7 @@ def _unitary_on_qubits(self, target_qubits: list['cirq.Qid']) -> np.ndarray: # Get the tensor operation corresponding to the moment acting on the # target qubits. - id_tensor = cirq.qis.eye_tensor((2,) * total_qubits, dtype=np.complex128) + id_tensor = qis.eye_tensor((2,) * total_qubits, dtype=np.complex128) unitary = apply_unitary( self, args=ApplyUnitaryArgs(id_tensor, np.empty_like(id_tensor), qubit_indices) ) diff --git a/cirq-core/cirq/circuits/moment_test.py b/cirq-core/cirq/circuits/moment_test.py index b31457223d9..08704066a33 100644 --- a/cirq-core/cirq/circuits/moment_test.py +++ b/cirq-core/cirq/circuits/moment_test.py @@ -693,6 +693,18 @@ def test_commutes(): assert not cirq.commutes(moment, cirq.X(c)) +def test_commutes_multiqubit_gates(): + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + c = cirq.NamedQubit("c") + + moment = cirq.Moment([cirq.Z(a), cirq.Z(b)]) + assert cirq.commutes(moment, cirq.XXPowGate(exponent=1 / 2)(a, b)) + + moment = cirq.Moment([cirq.XXPowGate(exponent=1 / 2)(a, b), cirq.Z(c)]) + assert not cirq.commutes(moment, cirq.Z(b)) + + def test_transform_qubits(): a, b = cirq.LineQubit.range(2) x, y = cirq.GridQubit.rect(2, 1, 10, 20)