From c8f7a02980b910ee78706d6b3b3793ccaa108bef Mon Sep 17 00:00:00 2001 From: Mark Edmiston <9750597+markedmiston@users.noreply.github.com> Date: Fri, 3 Mar 2023 14:53:15 -0500 Subject: [PATCH] MatrixGate names don't survive serialization (#6026) --- cirq-core/cirq/ops/matrix_gates.py | 10 +++++++--- cirq-core/cirq/ops/matrix_gates_test.py | 22 ++++++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/ops/matrix_gates.py b/cirq-core/cirq/ops/matrix_gates.py index 7ee9cb13cc6..0f7c7232c54 100644 --- a/cirq-core/cirq/ops/matrix_gates.py +++ b/cirq-core/cirq/ops/matrix_gates.py @@ -114,11 +114,15 @@ def with_name(self, name: str) -> 'MatrixGate': return MatrixGate(self._matrix, name=name, qid_shape=self._qid_shape, unitary_check=False) def _json_dict_(self) -> Dict[str, Any]: - return {'matrix': self._matrix.tolist(), 'qid_shape': self._qid_shape} + return { + 'matrix': self._matrix.tolist(), + 'qid_shape': self._qid_shape, + **({'name': self._name} if self._name is not None else {}), + } @classmethod - def _from_json_dict_(cls, matrix, qid_shape, **kwargs): - return cls(matrix=np.array(matrix), qid_shape=qid_shape) + def _from_json_dict_(cls, matrix, qid_shape, name=None, **kwargs): + return cls(matrix=np.array(matrix), qid_shape=qid_shape, name=name) def _qid_shape_(self) -> Tuple[int, ...]: return self._qid_shape diff --git a/cirq-core/cirq/ops/matrix_gates_test.py b/cirq-core/cirq/ops/matrix_gates_test.py index 8993088ce5a..2bc8a7e18b6 100644 --- a/cirq-core/cirq/ops/matrix_gates_test.py +++ b/cirq-core/cirq/ops/matrix_gates_test.py @@ -388,3 +388,25 @@ def test_matrixgate_unitary_tolerance(): # very low atol -> the check never converges with pytest.raises(ValueError): _ = cirq.MatrixGate(np.array([[0.707, 0.707], [-0.707, 0.707]]), unitary_check_rtol=1e-10) + + +def test_matrixgate_name_serialization(): + # https://github.com/quantumlib/Cirq/issues/5999 + + # Test name serialization + gate1 = cirq.MatrixGate(np.eye(2), name='test_name') + gate_after_serialization1 = cirq.read_json(json_text=cirq.to_json(gate1)) + assert gate1._name == 'test_name' + assert gate_after_serialization1._name == 'test_name' + + # Test name backwards compatibility + gate2 = cirq.MatrixGate(np.eye(2)) + gate_after_serialization2 = cirq.read_json(json_text=cirq.to_json(gate2)) + assert gate2._name is None + assert gate_after_serialization2._name is None + + # Test empty name + gate3 = cirq.MatrixGate(np.eye(2), name='') + gate_after_serialization3 = cirq.read_json(json_text=cirq.to_json(gate3)) + assert gate3._name == '' + assert gate_after_serialization3._name == ''