Skip to content

Commit

Permalink
nits, etc
Browse files Browse the repository at this point in the history
  • Loading branch information
eliottrosenberg committed Jan 17, 2024
1 parent 325a70f commit 82fa43b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 14 deletions.
22 changes: 9 additions & 13 deletions cirq-core/cirq/experiments/qubit_characterizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,11 @@ def _fit_exponential(self) -> Tuple[np.ndarray, np.ndarray]:
)


@dataclasses.dataclass(frozen=True)
class ParallelRandomizedBenchmarkingResult:
"""Results from a parallel randomized benchmarking experiment."""

def __init__(self, results_dictionary: Mapping['cirq.Qid', 'RandomizedBenchMarkResult']):
"""Inits ParallelRandomizedBenchmarkingResult.
Args:
results_dictionary: A dictionary containing the results for each qubit.
"""
self._results_dictionary = results_dictionary
results_dictionary: Mapping['cirq.Qid', 'RandomizedBenchMarkResult']

def plot_single_qubit(
self, qubit: 'cirq.Qid', ax: Optional[plt.Axes] = None, **plot_kwargs: Any
Expand All @@ -173,7 +168,7 @@ def plot_single_qubit(
The plt.Axes containing the plot.
"""

return self._results_dictionary[qubit].plot(ax, **plot_kwargs)
return self.results_dictionary[qubit].plot(ax, **plot_kwargs)

def pauli_error(self) -> Mapping['cirq.Qid', float]:
"""Return a dictionary of Pauli errors.
Expand All @@ -182,8 +177,8 @@ def pauli_error(self) -> Mapping['cirq.Qid', float]:
"""

return {
qubit: self._results_dictionary[qubit].pauli_error()
for qubit in self._results_dictionary
qubit: self.results_dictionary[qubit].pauli_error()
for qubit in self.results_dictionary
}

def plot_heatmap(
Expand All @@ -206,21 +201,22 @@ def plot_heatmap(
"""

pauli_errors = self.pauli_error()
pauli_errors_with_grid_qubit_keys = {}
for qubit in pauli_errors:
assert type(qubit) == grid_qubit.GridQubit, "qubits must be cirq.GridQubits"
pauli_errors_with_grid_qubit_keys[qubit] = pauli_errors[qubit] # just for typecheck

if ax is None:
_, ax = plt.subplots(dpi=200, facecolor='white')

ax, _ = cirq_heatmap.Heatmap(pauli_errors).plot( # type: ignore
ax, _ = cirq_heatmap.Heatmap(pauli_errors_with_grid_qubit_keys).plot(
ax, annotation_format=annotation_format, title=title, **plot_kwargs
)
return ax

def plot_integrated_histogram(
self,
ax: Optional[plt.Axes] = None,
*,
cdf_on_x: bool = False,
axis_label: str = 'Pauli error',
semilog: bool = True,
Expand Down Expand Up @@ -395,7 +391,7 @@ def single_qubit_randomized_benchmarking(
num_circuits=num_circuits,
repetitions=repetitions,
)
return result._results_dictionary[qubit]
return result.results_dictionary[qubit]


def parallel_single_qubit_randomized_benchmarking(
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/experiments/qubit_characterizations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_parallel_single_qubit_randomized_benchmarking():
simulator, num_clifford_range=num_cfds, repetitions=100, qubits=qubits
)
for qubit in qubits:
g_pops = np.asarray(results._results_dictionary[qubit].data)[:, 1]
g_pops = np.asarray(results.results_dictionary[qubit].data)[:, 1]
assert np.isclose(np.mean(g_pops), 1.0)
_ = results.plot_single_qubit(qubit)
pauli_errors = results.pauli_error()
Expand Down

0 comments on commit 82fa43b

Please sign in to comment.