Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add readout plotting tools #6425

Merged
merged 19 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions cirq-core/cirq/experiments/single_qubit_readout_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@

import sympy
import numpy as np
import matplotlib.pyplot as plt
import cirq.vis.heatmap as cirq_heatmap
import cirq.vis.histogram as cirq_histogram
from cirq.devices import grid_qubit
from cirq import circuits, ops, study

if TYPE_CHECKING:
Expand Down Expand Up @@ -51,6 +55,111 @@ def _json_dict_(self) -> Dict[str, Any]:
'timestamp': self.timestamp,
}

def plot_heatmap(
self,
axs: Optional[tuple[plt.Axes, plt.Axes]] = None,
annotation_format: str = '0.1%',
**plot_kwargs: Any,
) -> tuple[plt.Axes, plt.Axes]:
"""Plot a heatmap of the readout errors. If qubits are not cirq.GridQubits, throws an error.

Args:
axs: a tuple of the plt.Axes to plot on. If not given, a new figure is created,
plotted on, and shown.
annotation_format: The format string for the numbers in the heatmap.
**plot_kwargs: Arguments to be passed to 'cirq.Heatmap.plot()'.
Returns:
The two plt.Axes containing the plot.
"""

if axs is None:
_, axs = plt.subplots(1, 2, dpi=200, facecolor='white', figsize=(12, 4))

for ax, title, data in zip(
axs,
['$|0\\rangle$ errors', '$|1\\rangle$ errors'],
[self.zero_state_errors, self.one_state_errors],
):
data_with_grid_qubit_keys = {}
for qubit in data:
assert type(qubit) == grid_qubit.GridQubit, "qubits must be cirq.GridQubits"
data_with_grid_qubit_keys[qubit] = data[qubit] # just for typecheck
_, _ = cirq_heatmap.Heatmap(data_with_grid_qubit_keys).plot(
ax, annotation_format=annotation_format, title=title, **plot_kwargs
)
return axs[0], axs[1]

def plot_integrated_histogram(
self,
ax: Optional[plt.Axes] = None,
cdf_on_x: bool = False,
axis_label: str = 'Readout error rate',
semilog: bool = True,
median_line: bool = True,
median_label: Optional[str] = 'median',
mean_line: bool = False,
mean_label: Optional[str] = 'mean',
show_zero: bool = False,
title: Optional[str] = None,
**kwargs,
) -> plt.Axes:
"""Plot the readout errors using cirq.integrated_histogram().

Args:
ax: The axis to plot on. If None, we generate one.
cdf_on_x: If True, flip the axes compared the above example.
axis_label: Label for x axis (y-axis if cdf_on_x is True).
semilog: If True, force the x-axis to be logarithmic.
median_line: If True, draw a vertical line on the median value.
median_label: If drawing median line, optional label for it.
mean_line: If True, draw a vertical line on the mean value.
mean_label: If drawing mean line, optional label for it.
title: Title of the plot. If None, we assign "N={len(data)}".
show_zero: If True, moves the step plot up by one unit by prepending 0
to the data.
**kwargs: Kwargs to forward to `ax.step()`. Some examples are
color: Color of the line.
linestyle: Linestyle to use for the plot.
lw: linewidth for integrated histogram.
ms: marker size for a histogram trace.
Returns:
The axis that was plotted on.
"""

ax = cirq_histogram.integrated_histogram(
data=self.zero_state_errors,
ax=ax,
cdf_on_x=cdf_on_x,
semilog=semilog,
median_line=median_line,
median_label=median_label,
mean_line=mean_line,
mean_label=mean_label,
show_zero=show_zero,
color='C0',
label='$|0\\rangle$ errors',
**kwargs,
)
ax = cirq_histogram.integrated_histogram(
data=self.one_state_errors,
ax=ax,
cdf_on_x=cdf_on_x,
axis_label=axis_label,
semilog=semilog,
median_line=median_line,
median_label=median_label,
mean_line=mean_line,
mean_label=mean_label,
show_zero=show_zero,
title=title,
color='C1',
label='$|1\\rangle$ errors',
**kwargs,
)
ax.legend(loc='best')
ax.set_ylabel('Percentile')
return ax

@classmethod
def _from_json_dict_(
cls, zero_state_errors, one_state_errors, repetitions, timestamp, **kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_estimate_single_qubit_readout_errors_with_noise():


def test_estimate_parallel_readout_errors_no_noise():
qubits = cirq.LineQubit.range(10)
qubits = [cirq.GridQubit(i, 0) for i in range(10)]
sampler = cirq.Simulator()
repetitions = 1000
result = cirq.estimate_parallel_single_qubit_readout_errors(
Expand All @@ -97,6 +97,8 @@ def test_estimate_parallel_readout_errors_no_noise():
assert result.one_state_errors == {q: 0 for q in qubits}
assert result.repetitions == repetitions
assert isinstance(result.timestamp, float)
_ = result.plot_integrated_histogram()
_, _ = result.plot_heatmap()


def test_estimate_parallel_readout_errors_all_zeros():
Expand Down
Loading