diff --git a/cirq-google/cirq_google/api/v2/results.py b/cirq-google/cirq_google/api/v2/results.py index 1beaa43c227..ff467fe64b7 100644 --- a/cirq-google/cirq_google/api/v2/results.py +++ b/cirq-google/cirq_google/api/v2/results.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast, Dict, Hashable, Iterable, List, Optional, Sequence +from typing import cast, Dict, Hashable, Iterable, List, Optional, Sequence, Union from collections import OrderedDict import dataclasses import numpy as np @@ -20,6 +20,7 @@ from cirq_google.api import v2 from cirq_google.api.v2 import result_pb2 +GridQid = Union[cirq.GridQubit, cirq.GridQid] @dataclasses.dataclass class MeasureInfo: @@ -35,7 +36,7 @@ class MeasureInfo: """ key: str - qubits: List[cirq.GridQubit] + qubits: List[GridQid] instances: int invert_mask: List[bool] tags: List[Hashable] @@ -193,7 +194,7 @@ def _trial_sweep_from_proto( records: Dict[str, np.ndarray] = {} for mr in pr.measurement_results: instances = max(mr.instances, 1) - qubit_results: OrderedDict[cirq.GridQubit, np.ndarray] = OrderedDict() + qubit_results: OrderedDict[GridQid, np.ndarray] = OrderedDict() for qmr in mr.qubit_measurement_results: qubit = v2.grid_qubit_from_proto_id(qmr.qubit.id) if qubit in qubit_results: