Skip to content

Commit

Permalink
Test hash consistency of cirq objects loaded from a pickle (#6677)
Browse files Browse the repository at this point in the history
Ensure cirq objects passed as pickles in multiprocessing calls work
consistently in dictionaries, just like objects created in-process.

Partially resolves #6674
  • Loading branch information
pavoljuhas authored Aug 7, 2024
1 parent 51e8c3d commit 5377fff
Show file tree
Hide file tree
Showing 21 changed files with 261 additions and 14 deletions.
11 changes: 11 additions & 0 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import math
from functools import cached_property
from typing import (
Any,
Callable,
cast,
Dict,
Expand Down Expand Up @@ -508,6 +509,16 @@ def _hash(self) -> int:
def __hash__(self) -> int:
return self._hash

def __getstate__(self) -> Dict[str, Any]:
    """Return the instance ``__dict__`` for pickling without the cached hash.

    The ``_hash`` cached_property keeps its computed value in the instance
    dictionary under the property name.  That entry is left out of the
    pickled state (see #6674); the instance itself is not modified.
    """
    cached_name = "_hash"
    attrs = self.__dict__
    if cached_name not in attrs:
        # hash was never computed, the live dictionary can be pickled as is
        return attrs
    trimmed = attrs.copy()
    trimmed.pop(cached_name)
    return trimmed

def _json_dict_(self):
resp = {
'circuit': self.circuit,
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ def __eq__(self, other):
def __getstate__(self):
# Don't save hash when pickling; see #3777.
state = self.__dict__
hash_cache = _compat._method_cache_name(self.__hash__)
if hash_cache in state:
hash_attr = _compat._method_cache_name(self.__hash__)
if hash_attr in state:
state = state.copy()
del state[hash_cache]
del state[hash_attr]
return state

@_compat.cached_method
Expand Down
9 changes: 9 additions & 0 deletions cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,15 @@ def __ne__(self, other) -> bool:
def __hash__(self):
return hash((Moment, self._sorted_operations_()))

def __getstate__(self) -> Dict[str, Any]:
    """Return the instance ``__dict__`` for pickling without the cached hash.

    The entry stored under the attribute name that
    ``_compat._method_cache_name`` reports for ``__hash__`` is left out, so
    the cached hash value is not part of the pickled state; see #6674.
    The instance itself is not modified.
    """
    cache_key = _compat._method_cache_name(self.__hash__)
    attrs = self.__dict__
    if cache_key not in attrs:
        # nothing cached yet, the live dictionary can be pickled as is
        return attrs
    return {name: value for name, value in attrs.items() if name != cache_key}

def __iter__(self) -> Iterator['cirq.Operation']:
return iter(self.operations)

Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/devices/grid_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ def __getnewargs_ex__(self):
"""Returns a tuple of (args, kwargs) to pass to __new__ when unpickling."""
return (self._row, self._col), {"dimension": self._dimension}

# Pickle an empty instance state so that the _hash value is never stored;
# row, column and dimension are already passed to __new__ via __getnewargs_ex__.
def __getstate__(self) -> Dict[str, Any]:
return {}

def _with_row_col(self, row: int, col: int) -> 'GridQid':
return GridQid(row, col, dimension=self._dimension)

Expand Down Expand Up @@ -387,6 +391,10 @@ def __getnewargs__(self):
"""Returns a tuple of args to pass to __new__ when unpickling."""
return (self._row, self._col)

# Pickle an empty instance state so that the _hash value is never stored;
# row and column are already passed to __new__ via __getnewargs__.
def __getstate__(self) -> Dict[str, Any]:
return {}

def _with_row_col(self, row: int, col: int) -> 'GridQubit':
return GridQubit(row, col)

Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/devices/line_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,10 @@ def __getnewargs__(self):
"""Returns a tuple of args to pass to __new__ when unpickling."""
return (self._x, self._dimension)

# Pickle an empty instance state so that the _hash value is never stored;
# x and dimension are already passed to __new__ via __getnewargs__.
def __getstate__(self) -> Dict[str, Any]:
return {}

def _with_x(self, x: int) -> 'LineQid':
return LineQid(x, dimension=self._dimension)

Expand Down Expand Up @@ -308,6 +312,10 @@ def __getnewargs__(self):
"""Returns a tuple of args to pass to __new__ when unpickling."""
return (self._x,)

# Pickle an empty instance state so that the _hash value is never stored;
# x is already passed to __new__ via __getnewargs__.
def __getstate__(self) -> Dict[str, Any]:
return {}

def _with_x(self, x: int) -> 'LineQubit':
return LineQubit(x)

Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/boolean_hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _qid_shape_(self) -> Tuple[int, ...]:
return (2,) * len(self._parameter_names)

def _value_equality_values_(self) -> Any:
return self._parameter_names, self._boolean_strs, self._theta
return tuple(self._parameter_names), tuple(self._boolean_strs), self._theta

def _json_dict_(self) -> Dict[str, Any]:
return {
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/common_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _has_mixture_(self) -> bool:
return True

def _value_equality_values_(self):
return self._num_qubits, hash(tuple(sorted(self._error_probabilities.items())))
return self._num_qubits, tuple(sorted(self._error_probabilities.items()))

def __repr__(self) -> str:
return 'cirq.asymmetric_depolarize(' + f"error_probabilities={self._error_probabilities})"
Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/ops/named_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ def __getnewargs__(self):
"""Returns a tuple of args to pass to __new__ when unpickling."""
return (self._name, self._dimension)

# Pickle an empty instance state so that the _hash value is never stored;
# name and dimension are already passed to __new__ via __getnewargs__.
def __getstate__(self) -> Dict[str, Any]:
return {}

def __repr__(self) -> str:
return f'cirq.NamedQid({self._name!r}, dimension={self._dimension})'

Expand Down Expand Up @@ -202,6 +206,10 @@ def __getnewargs__(self):
"""Returns a tuple of args to pass to __new__ when unpickling."""
return (self._name,)

# Pickle an empty instance state so that the _hash value is never stored;
# the name is already passed to __new__ via __getnewargs__.
def __getstate__(self) -> Dict[str, Any]:
return {}

def __str__(self) -> str:
return self._name

Expand Down
11 changes: 10 additions & 1 deletion cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

from cirq import protocols, value
from cirq._import import LazyLoader
from cirq._compat import __cirq_debug__, cached_method
from cirq._compat import __cirq_debug__, _method_cache_name, cached_method
from cirq.type_workarounds import NotImplementedType
from cirq.ops import control_values as cv

Expand Down Expand Up @@ -115,6 +115,15 @@ def _cmp_tuple(self):
def __hash__(self) -> int:
return hash((Qid, self._comparison_key()))

def __getstate__(self) -> Dict[str, Any]:
    """Return the instance ``__dict__`` for pickling without the cached hash.

    The entry stored under the attribute name that ``_method_cache_name``
    reports for ``__hash__`` is dropped from the pickled state; see #6674.
    The instance itself is not modified.
    """
    cache_key = _method_cache_name(self.__hash__)
    attrs = self.__dict__
    if cache_key in attrs:
        # work on a copy so the live object keeps its cached hash
        attrs = attrs.copy()
        attrs.pop(cache_key)
    return attrs

def __eq__(self, other):
if not isinstance(other, Qid):
return NotImplemented
Expand Down
111 changes: 111 additions & 0 deletions cirq-core/cirq/protocols/hash_from_pickle_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright 2024 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import multiprocessing
import os
import pathlib
import pickle
from collections.abc import Iterator
from typing import Any, Hashable

import pytest

import cirq
from cirq.protocols.json_serialization_test import MODULE_TEST_SPECS

# JSON test-data files that are left out of the pickle-hash tests.
# Entries are matched as path suffixes of the POSIX-style file path.
_EXCLUDE_JSON_FILES = (
# sympy - related objects
"cirq/protocols/json_test_data/sympy.Add.json",
"cirq/protocols/json_test_data/sympy.E.json",
"cirq/protocols/json_test_data/sympy.Equality.json",
"cirq/protocols/json_test_data/sympy.EulerGamma.json",
"cirq/protocols/json_test_data/sympy.Float.json",
"cirq/protocols/json_test_data/sympy.GreaterThan.json",
"cirq/protocols/json_test_data/sympy.Integer.json",
"cirq/protocols/json_test_data/sympy.LessThan.json",
"cirq/protocols/json_test_data/sympy.Mul.json",
"cirq/protocols/json_test_data/sympy.Pow.json",
"cirq/protocols/json_test_data/sympy.Rational.json",
"cirq/protocols/json_test_data/sympy.StrictGreaterThan.json",
"cirq/protocols/json_test_data/sympy.StrictLessThan.json",
"cirq/protocols/json_test_data/sympy.Symbol.json",
"cirq/protocols/json_test_data/sympy.Unequality.json",
"cirq/protocols/json_test_data/sympy.pi.json",
# RigettiQCSAspenDevice does not pickle
"cirq_rigetti/json_test_data/RigettiQCSAspenDevice.json",
# TODO(#6674,pavoljuhas) - fix pickling of ProjectorSum
"cirq/protocols/json_test_data/ProjectorSum.json",
)


def _is_included(json_filename: str) -> bool:
    """Return True if `json_filename` should take part in the pickle-hash test.

    A file is included when its POSIX-style path does not end with any entry
    of `_EXCLUDE_JSON_FILES` and the file exists on disk.
    """
    posix_path = pathlib.PurePath(json_filename).as_posix()
    # str.endswith accepts a tuple of suffixes and matches any of them
    is_excluded = posix_path.endswith(_EXCLUDE_JSON_FILES)
    return not is_excluded and os.path.isfile(json_filename)


@pytest.fixture(scope='module')
def pool() -> Iterator[multiprocessing.pool.Pool]:
    """Yield a module-scoped, single-process pool using the "spawn" context."""
    spawn_ctx = multiprocessing.get_context("spawn")
    with spawn_ctx.Pool(processes=1) as worker_pool:
        yield worker_pool


def _read_json(json_filename: str) -> Any:
    """Load an object from a JSON file and hash it once if it is hashable.

    When the file holds a list, only its first item is returned.
    """
    loaded = cirq.read_json(json_filename)
    if isinstance(loaded, list):
        loaded = loaded[0]
    if isinstance(loaded, Hashable):
        # trigger possible caching of the hash value
        hash(loaded)
    return loaded


def test_exclude_json_files_has_valid_entries() -> None:
    """Verify _EXCLUDE_JSON_FILES has valid entries."""
    # do not check rigetti files if not installed
    rigetti_missing = not any(m.name == "cirq_rigetti" for m in MODULE_TEST_SPECS)

    def is_valid_entry(entry: str) -> bool:
        # an entry is valid if a file with its basename exists in some test data directory
        basename = os.path.basename(entry)
        if any(spec.test_data_path.joinpath(basename).is_file() for spec in MODULE_TEST_SPECS):
            return True
        return rigetti_missing and entry.startswith("cirq_rigetti/")

    invalid_json_paths = [entry for entry in _EXCLUDE_JSON_FILES if not is_valid_entry(entry)]
    assert invalid_json_paths == []


@pytest.mark.parametrize(
    'json_filename',
    [
        json_path
        for spec in MODULE_TEST_SPECS
        for json_path in (f"{key}.json" for key in spec.all_test_data_keys())
        if _is_included(json_path)
    ],
)
def test_hash_from_pickle(json_filename: str, pool: multiprocessing.pool.Pool):
    """Check that equality and hash survive pickling, locally and via a worker."""
    local_obj = _read_json(json_filename)
    if not isinstance(local_obj, Hashable):
        return
    # check if pickling works in the main process for the sake of debugging
    pickled_copy = pickle.loads(pickle.dumps(local_obj))
    assert pickled_copy == local_obj
    assert hash(pickled_copy) == hash(local_obj)
    # Read and hash the object in a separate worker process and then
    # send it back which requires pickling and unpickling.
    worker_obj = pool.apply(_read_json, (json_filename,))
    assert worker_obj == local_obj
    assert hash(worker_obj) == hash(local_obj)
11 changes: 10 additions & 1 deletion cirq-core/cirq/qis/clifford_tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np

from cirq import protocols
from cirq._compat import proper_repr, cached_method
from cirq._compat import proper_repr, _method_cache_name, cached_method
from cirq.qis import quantum_state_representation
from cirq.value import big_endian_int_to_digits, linear_dict, random_state

Expand Down Expand Up @@ -658,3 +658,12 @@ def measure(
@cached_method
def __hash__(self) -> int:
return hash(self.matrix().tobytes() + self.rs.tobytes())

def __getstate__(self) -> Dict[str, Any]:
    """Return the instance ``__dict__`` for pickling without the cached hash.

    The entry stored under the attribute name that ``_method_cache_name``
    reports for ``__hash__`` is dropped from the pickled state; see #6674.
    The instance itself is not modified.
    """
    cache_key = _method_cache_name(self.__hash__)
    attrs = self.__dict__
    if cache_key not in attrs:
        # hash was never cached, the live dictionary can be pickled as is
        return attrs
    return {name: value for name, value in attrs.items() if name != cache_key}
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/clifford/clifford_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def state(self):
return self._clifford_state


@value.value_equality
@value.value_equality(unhashable=True)
class CliffordState:
"""A state of the Clifford simulation.
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from cirq.value import big_endian_int_to_digits, random_state


@value.value_equality
@value.value_equality(unhashable=True)
class StabilizerStateChForm(qis.StabilizerState):
r"""A representation of stabilizer states using the CH form,
Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,14 @@ def __hash__(self) -> int:
self._param_hash = hash(frozenset(self._param_dict.items()))
return self._param_hash

def __getstate__(self) -> Dict[str, Any]:
    """Return the instance ``__dict__`` for pickling with the cached hash reset.

    If ``_param_hash`` holds a value, a copy of the state with
    ``_param_hash`` set to None is returned instead (see #6674); the
    instance itself is not modified.
    """
    attrs = self.__dict__
    if attrs["_param_hash"] is None:
        # nothing cached, the live dictionary can be pickled as is
        return attrs
    return {**attrs, "_param_hash": None}

def __eq__(self, other):
if not isinstance(other, ParamResolver):
return NotImplemented
Expand Down
10 changes: 9 additions & 1 deletion cirq-core/cirq/value/measurement_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import FrozenSet, Mapping, Optional, Tuple
from typing import Any, Dict, FrozenSet, Mapping, Optional, Tuple

import dataclasses

Expand Down Expand Up @@ -77,6 +77,14 @@ def __hash__(self):
object.__setattr__(self, '_hash', hash(str(self)))
return self._hash

def __getstate__(self) -> Dict[str, Any]:
    """Return the instance ``__dict__`` for pickling without the ``_hash`` entry.

    The cached hash value is left out of the pickled state (see #6674);
    the instance itself is not modified.
    """
    attrs = self.__dict__
    if "_hash" not in attrs:
        # hash was never cached, the live dictionary can be pickled as is
        return attrs
    return {name: value for name, value in attrs.items() if name != "_hash"}

def __lt__(self, other):
if isinstance(other, MeasurementKey):
if self.path != other.path:
Expand Down
14 changes: 13 additions & 1 deletion cirq-core/cirq/value/value_equality_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Defines `@cirq.value_equality`, for easy __eq__/__hash__ methods."""

from typing import Any, Callable, Optional, overload, Union
from typing import Any, Callable, Dict, Optional, overload, Union

from typing_extensions import Protocol

Expand Down Expand Up @@ -110,6 +110,16 @@ def _value_equality_approx_eq(
)


def _value_equality_getstate(self: _SupportsValueEquality) -> Dict[str, Any]:
    """Return the instance ``__dict__`` for pickling without the cached hash.

    The entry stored under the attribute name that
    ``_compat._method_cache_name`` reports for ``__hash__`` is dropped from
    the pickled state; see #6674.  The instance itself is not modified.
    """
    cache_key = _compat._method_cache_name(self.__hash__)
    attrs = self.__dict__
    if cache_key in attrs:
        # work on a copy so the live object keeps its cached hash
        attrs = attrs.copy()
        attrs.pop(cache_key)
    return attrs


# pylint: disable=function-redefined
@overload
def value_equality(
Expand Down Expand Up @@ -228,6 +238,8 @@ class return the existing class' type.
cached_values_getter = values_getter if unhashable else _compat.cached_method(values_getter)
setattr(cls, '_value_equality_values_', cached_values_getter)
setattr(cls, '__hash__', None if unhashable else _compat.cached_method(_value_equality_hash))
if not unhashable:
setattr(cls, '__getstate__', _value_equality_getstate)
setattr(cls, '__eq__', _value_equality_eq)
setattr(cls, '__ne__', _value_equality_ne)

Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/work/observable_measurement_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ class ObservableMeasuredResult:
repetitions: int
circuit_params: Mapping[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]]

# unhashable because of the mapping-type circuit_params attribute
__hash__ = None # type: ignore

def __repr__(self):
# I wish we could use the default dataclass __repr__ but
# we need to prefix our class name with `cirq.work.`
Expand Down
Loading

0 comments on commit 5377fff

Please sign in to comment.