diff --git a/cirq-core/cirq/experiments/single_qubit_readout_calibration.py b/cirq-core/cirq/experiments/single_qubit_readout_calibration.py index 882d1fcee34..003ab34f365 100644 --- a/cirq-core/cirq/experiments/single_qubit_readout_calibration.py +++ b/cirq-core/cirq/experiments/single_qubit_readout_calibration.py @@ -70,11 +70,18 @@ def plot_heatmap( **plot_kwargs: Arguments to be passed to 'cirq.Heatmap.plot()'. Returns: The two plt.Axes containing the plot. + + Raises: + ValueError if axs does not contain two plt.Axes + TypeError if qubits are not cirq.GridQubits """ if axs is None: _, axs = plt.subplots(1, 2, dpi=200, facecolor='white', figsize=(12, 4)) + else: + if not len(axes) != 2 or type(axs[0]) != plt.Axes or type(axs[1]) != plt.Axes: + raise ValueError('axs should be a length-2 tuple of plt.Axes') for ax, title, data in zip( axs, ['$|0\\rangle$ errors', '$|1\\rangle$ errors'], @@ -82,7 +89,9 @@ def plot_heatmap( ): data_with_grid_qubit_keys = {} for qubit in data: - assert type(qubit) == grid_qubit.GridQubit, "qubits must be cirq.GridQubits" + if type(qubit) != grid_qubit.GridQubit: + raise TypeError(f'{qubit} must be of type cirq.GridQubit') + cast(grid_qubit.GridQubit, qubit) data_with_grid_qubit_keys[qubit] = data[qubit] # just for typecheck _, _ = cirq_heatmap.Heatmap(data_with_grid_qubit_keys).plot( ax, annotation_format=annotation_format, title=title, **plot_kwargs