From 7a9c641759003a549e33b6333f8a4fe109af516d Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Sat, 16 Sep 2023 10:06:38 -0700 Subject: [PATCH 1/4] Fix matplotlib typing matplotlib 3.8.0 was released this week and included typing hints. This fixes the resulting CI breakages. --- cirq-core/cirq/contrib/svg/svg.py | 4 +++- cirq-core/cirq/devices/named_topologies.py | 2 +- .../cirq/experiments/qubit_characterizations.py | 11 ++++++----- cirq-core/cirq/linalg/decompositions.py | 11 ++++++----- cirq-core/cirq/vis/heatmap.py | 11 +++++++---- cirq-core/cirq/vis/heatmap_test.py | 6 ++++++ cirq-core/cirq/vis/histogram.py | 8 ++++---- cirq-core/cirq/vis/state_histogram.py | 16 ++++++++++------ cirq-core/cirq/vis/state_histogram_test.py | 2 ++ cirq-google/cirq_google/engine/calibration.py | 5 +++-- .../cirq_google/engine/virtual_engine_factory.py | 2 +- examples/two_qubit_gate_compilation.py | 2 +- 12 files changed, 50 insertions(+), 30 deletions(-) diff --git a/cirq-core/cirq/contrib/svg/svg.py b/cirq-core/cirq/contrib/svg/svg.py index be7a1d60c56..3a9be84bdeb 100644 --- a/cirq-core/cirq/contrib/svg/svg.py +++ b/cirq-core/cirq/contrib/svg/svg.py @@ -2,12 +2,14 @@ from typing import TYPE_CHECKING, List, Tuple, cast, Dict import matplotlib.textpath +import matplotlib.font_manager + if TYPE_CHECKING: import cirq QBLUE = '#1967d2' -FONT = "Arial" +FONT = matplotlib.font_manager.FontProperties(family="Arial") EMPTY_MOMENT_COLWIDTH = float(21) # assumed default column width diff --git a/cirq-core/cirq/devices/named_topologies.py b/cirq-core/cirq/devices/named_topologies.py index 5f32d8b1d5d..6aa46e19e94 100644 --- a/cirq-core/cirq/devices/named_topologies.py +++ b/cirq-core/cirq/devices/named_topologies.py @@ -74,7 +74,7 @@ def _node_and_coordinates( def draw_gridlike( - graph: nx.Graph, ax: plt.Axes = None, tilted: bool = True, **kwargs + graph: nx.Graph, ax: Optional[plt.Axes] = None, tilted: bool = True, **kwargs ) -> Dict[Any, Tuple[int, int]]: """Draw a grid-like graph using Matplotlib. diff --git a/cirq-core/cirq/experiments/qubit_characterizations.py b/cirq-core/cirq/experiments/qubit_characterizations.py index 114e2e28659..aeedf123e50 100644 --- a/cirq-core/cirq/experiments/qubit_characterizations.py +++ b/cirq-core/cirq/experiments/qubit_characterizations.py @@ -15,13 +15,13 @@ import dataclasses import itertools -from typing import Any, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING +from typing import Any, cast, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING import numpy as np from matplotlib import pyplot as plt # this is for older systems with matplotlib <3.2 otherwise 3d projections fail -from mpl_toolkits import mplot3d # pylint: disable=unused-import +from mpl_toolkits import mplot3d from cirq import circuits, ops, protocols if TYPE_CHECKING: @@ -89,8 +89,9 @@ def plot(self, ax: Optional[plt.Axes] = None, **plot_kwargs: Any) -> plt.Axes: """ show_plot = not ax if not ax: - fig, ax = plt.subplots(1, 1, figsize=(8, 8)) - ax.set_ylim([0, 1]) + fig, ax = plt.subplots(1, 1, figsize=(8, 8)) # pragma: no cover + ax = cast(plt.Axes, ax) # pragma: no cover + ax.set_ylim((0.0, 1.0)) # pragma: no cover ax.plot(self._num_cfds_seq, self._gnd_state_probs, 'ro-', **plot_kwargs) ax.set_xlabel(r"Number of Cliffords") ax.set_ylabel('Ground State Probability') @@ -541,7 +542,7 @@ def _find_inv_matrix(mat: np.ndarray, mat_sequence: np.ndarray) -> int: def _matrix_bar_plot( mat: np.ndarray, z_label: str, - ax: plt.Axes, + ax: mplot3d.axes3d.Axes3D, kets: Optional[Sequence[str]] = None, title: Optional[str] = None, ylim: Tuple[int, int] = (-1, 1), diff --git a/cirq-core/cirq/linalg/decompositions.py b/cirq-core/cirq/linalg/decompositions.py index 60dc0123640..43434ff4d1b 100644 --- a/cirq-core/cirq/linalg/decompositions.py +++ b/cirq-core/cirq/linalg/decompositions.py @@ -20,6 +20,7 @@ from typing import ( Any, Callable, + cast, Iterable, List, Optional, @@ -33,7 +34,7 @@ import matplotlib.pyplot as plt # this is for older systems with matplotlib <3.2 otherwise 3d projections fail -from mpl_toolkits import mplot3d # pylint: disable=unused-import +from mpl_toolkits import mplot3d import numpy as np from cirq import value, protocols @@ -554,7 +555,7 @@ def scatter_plot_normalized_kak_interaction_coefficients( interactions: Iterable[Union[np.ndarray, 'cirq.SupportsUnitary', 'KakDecomposition']], *, include_frame: bool = True, - ax: Optional[plt.Axes] = None, + ax: Optional[mplot3d.axes3d.Axes3D] = None, **kwargs, ): r"""Plots the interaction coefficients of many two-qubit operations. @@ -633,13 +634,13 @@ def scatter_plot_normalized_kak_interaction_coefficients( show_plot = not ax if not ax: fig = plt.figure() - ax = fig.add_subplot(1, 1, 1, projection='3d') + ax = cast(mplot3d.axes3d.Axes3D, fig.add_subplot(1, 1, 1, projection='3d')) def coord_transform( pts: Union[List[Tuple[int, int, int]], np.ndarray] - ) -> Tuple[Iterable[float], Iterable[float], Iterable[float]]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: if len(pts) == 0: - return [], [], [] + return np.array([]), np.array([]), np.array([]) xs, ys, zs = np.transpose(pts) return xs, zs, ys diff --git a/cirq-core/cirq/vis/heatmap.py b/cirq-core/cirq/vis/heatmap.py index e5598f59450..e672a2b8c27 100644 --- a/cirq-core/cirq/vis/heatmap.py +++ b/cirq-core/cirq/vis/heatmap.py @@ -15,6 +15,7 @@ from dataclasses import astuple, dataclass from typing import ( Any, + cast, Dict, List, Mapping, @@ -217,7 +218,7 @@ def _plot_colorbar( ) position = self._config['colorbar_position'] orien = 'vertical' if position in ('left', 'right') else 'horizontal' - colorbar = ax.figure.colorbar( + colorbar = cast(plt.Figure, ax.figure).colorbar( mappable, colorbar_ax, ax, orientation=orien, **self._config.get("colorbar_options", {}) ) colorbar_ax.tick_params(axis='y', direction='out') @@ -230,15 +231,15 @@ def _write_annotations( ax: plt.Axes, ) -> None: """Writes annotations to the center of cells. Internal.""" - for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolors()): + for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolor()): # Calculate the center of the cell, assuming that it is a square # centered at (x=col, y=row). if not annotation: continue x, y = center - face_luminance = vis_utils.relative_luminance(facecolor) + face_luminance = vis_utils.relative_luminance(facecolor) # type: ignore text_color = 'black' if face_luminance > 0.4 else 'white' - text_kwargs = dict(color=text_color, ha="center", va="center") + text_kwargs: Dict[str, Any] = dict(color=text_color, ha="center", va="center") text_kwargs.update(self._config.get('annotation_text_kwargs', {})) ax.text(x, y, annotation, **text_kwargs) @@ -295,6 +296,7 @@ def plot( show_plot = not ax if not ax: fig, ax = plt.subplots(figsize=(8, 8)) + ax = cast(plt.Axes, ax) original_config = copy.deepcopy(self._config) self.update_config(**kwargs) collection = self._plot_on_axis(ax) @@ -381,6 +383,7 @@ def plot( show_plot = not ax if not ax: fig, ax = plt.subplots(figsize=(8, 8)) + ax = cast(plt.Axes, ax) original_config = copy.deepcopy(self._config) self.update_config(**kwargs) qubits = set([q for qubits in self._value_map.keys() for q in qubits]) diff --git a/cirq-core/cirq/vis/heatmap_test.py b/cirq-core/cirq/vis/heatmap_test.py index 1ca493386f5..cd2dff4712a 100644 --- a/cirq-core/cirq/vis/heatmap_test.py +++ b/cirq-core/cirq/vis/heatmap_test.py @@ -33,6 +33,10 @@ def ax(): figure = mpl.figure.Figure() return figure.add_subplot(111) +def test_default_ax(): + row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8)) + test_value_map = {grid_qubit.GridQubit(row, col):np.random.random() for (row, col) in row_col_list} + _, _ = heatmap.Heatmap(test_value_map).plot() @pytest.mark.parametrize('tuple_keys', [True, False]) def test_cells_positions(ax, tuple_keys): @@ -61,6 +65,8 @@ def test_two_qubit_heatmap(ax): title = "Two Qubit Interaction Heatmap" heatmap.TwoQubitInteractionHeatmap(value_map, title=title).plot(ax) assert ax.get_title() == title + # Test default axis + heatmap.TwoQubitInteractionHeatmap(value_map, title=title).plot() def test_invalid_args(): diff --git a/cirq-core/cirq/vis/histogram.py b/cirq-core/cirq/vis/histogram.py index f3b0a8047bc..88349097a97 100644 --- a/cirq-core/cirq/vis/histogram.py +++ b/cirq-core/cirq/vis/histogram.py @@ -100,9 +100,9 @@ def integrated_histogram( plot_options.update(kwargs) if cdf_on_x: - ax.step(bin_values, parameter_values, **plot_options) + ax.step(bin_values, parameter_values, **plot_options) # type: ignore else: - ax.step(parameter_values, bin_values, **plot_options) + ax.step(parameter_values, bin_values, **plot_options) # type: ignore set_semilog = ax.semilogy if cdf_on_x else ax.semilogx set_lim = ax.set_xlim if cdf_on_x else ax.set_ylim @@ -128,7 +128,7 @@ def integrated_histogram( if median_line: set_line( - np.median(float_data), + float(np.median(float_data)), linestyle='--', color=plot_options['color'], alpha=0.5, @@ -136,7 +136,7 @@ def integrated_histogram( ) if mean_line: set_line( - np.mean(float_data), + float(np.mean(float_data)), linestyle='-.', color=plot_options['color'], alpha=0.5, diff --git a/cirq-core/cirq/vis/state_histogram.py b/cirq-core/cirq/vis/state_histogram.py index 51ccfc5f073..3a3706cf04f 100644 --- a/cirq-core/cirq/vis/state_histogram.py +++ b/cirq-core/cirq/vis/state_histogram.py @@ -14,7 +14,7 @@ """Tool to visualize the results of a study.""" -from typing import Union, Optional, Sequence, SupportsFloat +from typing import cast, Optional, Sequence, SupportsFloat, Union import collections import numpy as np import matplotlib.pyplot as plt @@ -51,13 +51,13 @@ def get_state_histogram(result: 'result.Result') -> np.ndarray: def plot_state_histogram( data: Union['result.Result', collections.Counter, Sequence[SupportsFloat]], - ax: Optional['plt.Axis'] = None, + ax: Optional[plt.Axes] = None, *, tick_label: Optional[Sequence[str]] = None, xlabel: Optional[str] = 'qubit state', ylabel: Optional[str] = 'result count', title: Optional[str] = 'Result State Histogram', -) -> 'plt.Axis': +) -> plt.Axes: """Plot the state histogram from either a single result with repetitions or a histogram computed using `result.histogram()` or a flattened histogram of measurement results computed using `get_state_histogram`. @@ -87,6 +87,7 @@ def plot_state_histogram( show_fig = not ax if not ax: fig, ax = plt.subplots(1, 1) + ax = cast(plt.Axes, ax) if isinstance(data, result.Result): values = get_state_histogram(data) elif isinstance(data, collections.Counter): @@ -96,9 +97,12 @@ def plot_state_histogram( if tick_label is None: tick_label = [str(i) for i in range(len(values))] ax.bar(np.arange(len(values)), values, tick_label=tick_label) - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - ax.set_title(title) + if xlabel: + ax.set_xlabel(xlabel) + if ylabel: + ax.set_ylabel(ylabel) + if title: + ax.set_title(title) if show_fig: fig.show() return ax diff --git a/cirq-core/cirq/vis/state_histogram_test.py b/cirq-core/cirq/vis/state_histogram_test.py index 220030d0e81..a922b12b1ff 100644 --- a/cirq-core/cirq/vis/state_histogram_test.py +++ b/cirq-core/cirq/vis/state_histogram_test.py @@ -78,6 +78,8 @@ def test_plot_state_histogram_result(): for r1, r2 in zip(ax1.get_children(), ax2.get_children()): if isinstance(r1, mpl.patches.Rectangle) and isinstance(r2, mpl.patches.Rectangle): assert str(r1) == str(r2) + # Test default axis + state_histogram.plot_state_histogram(expected_values) @pytest.mark.usefixtures('closefigures') diff --git a/cirq-google/cirq_google/engine/calibration.py b/cirq-google/cirq_google/engine/calibration.py index d28434da6c0..8e0ac4c1560 100644 --- a/cirq-google/cirq_google/engine/calibration.py +++ b/cirq-google/cirq_google/engine/calibration.py @@ -17,7 +17,7 @@ from collections import abc, defaultdict import datetime from itertools import cycle -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Sequence +from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union, Sequence import matplotlib as mpl import matplotlib.pyplot as plt @@ -277,6 +277,7 @@ def plot_histograms( show_plot = not ax if not ax: fig, ax = plt.subplots(1, 1) + ax = cast(plt.Axes, ax) if isinstance(keys, str): keys = [keys] @@ -322,7 +323,7 @@ def plot( show_plot = not fig if not fig: fig = plt.figure() - axs = fig.subplots(1, 2) + axs = cast(List[plt.Axes], fig.subplots(1, 2)) self.heatmap(key).plot(axs[0]) self.plot_histograms(key, axs[1]) if show_plot: diff --git a/cirq-google/cirq_google/engine/virtual_engine_factory.py b/cirq-google/cirq_google/engine/virtual_engine_factory.py index 451db1e00fe..79c02535565 100644 --- a/cirq-google/cirq_google/engine/virtual_engine_factory.py +++ b/cirq-google/cirq_google/engine/virtual_engine_factory.py @@ -402,7 +402,7 @@ def create_default_noisy_quantum_virtual_machine( if simulator_class is None: try: # pragma: no cover - import qsimcirq # type: ignore + import qsimcirq simulator_class = qsimcirq.QSimSimulator # pragma: no cover except ImportError: diff --git a/examples/two_qubit_gate_compilation.py b/examples/two_qubit_gate_compilation.py index 2dd1a9e3260..31d360ec79f 100644 --- a/examples/two_qubit_gate_compilation.py +++ b/examples/two_qubit_gate_compilation.py @@ -88,7 +88,7 @@ def main(samples: int = 1000, max_infidelity: float = 0.01): print(f'Maximum infidelity of "failed" compilation: {np.max(failed_infidelities_arr)}') plt.figure() - plt.hist(infidelities_arr, bins=25, range=[0, max_infidelity * 1.1]) + plt.hist(infidelities_arr, bins=25, range=(0.0, max_infidelity * 1.1)) # pragma: no cover ylim = plt.ylim() plt.plot([max_infidelity] * 2, ylim, '--', label='Maximum tabulation infidelity') plt.xlabel('Compiled gate infidelity vs target') From 62623b99796e21cd0c93b651f868542a07891820 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Sat, 16 Sep 2023 10:16:06 -0700 Subject: [PATCH 2/4] Fix issues. --- cirq-core/cirq/experiments/qubit_characterizations.py | 6 +++--- cirq-core/cirq/vis/heatmap_test.py | 6 +++++- cirq-google/cirq_google/engine/virtual_engine_factory.py | 2 +- examples/two_qubit_gate_compilation.py | 2 +- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/experiments/qubit_characterizations.py b/cirq-core/cirq/experiments/qubit_characterizations.py index aeedf123e50..ed12b311e22 100644 --- a/cirq-core/cirq/experiments/qubit_characterizations.py +++ b/cirq-core/cirq/experiments/qubit_characterizations.py @@ -89,9 +89,9 @@ def plot(self, ax: Optional[plt.Axes] = None, **plot_kwargs: Any) -> plt.Axes: """ show_plot = not ax if not ax: - fig, ax = plt.subplots(1, 1, figsize=(8, 8)) # pragma: no cover - ax = cast(plt.Axes, ax) # pragma: no cover - ax.set_ylim((0.0, 1.0)) # pragma: no cover + fig, ax = plt.subplots(1, 1, figsize=(8, 8)) # pragma: no cover + ax = cast(plt.Axes, ax) # pragma: no cover + ax.set_ylim((0.0, 1.0)) # pragma: no cover ax.plot(self._num_cfds_seq, self._gnd_state_probs, 'ro-', **plot_kwargs) ax.set_xlabel(r"Number of Cliffords") ax.set_ylabel('Ground State Probability') diff --git a/cirq-core/cirq/vis/heatmap_test.py b/cirq-core/cirq/vis/heatmap_test.py index cd2dff4712a..dceb00cff1c 100644 --- a/cirq-core/cirq/vis/heatmap_test.py +++ b/cirq-core/cirq/vis/heatmap_test.py @@ -33,11 +33,15 @@ def ax(): figure = mpl.figure.Figure() return figure.add_subplot(111) + def test_default_ax(): row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8)) - test_value_map = {grid_qubit.GridQubit(row, col):np.random.random() for (row, col) in row_col_list} + test_value_map = { + grid_qubit.GridQubit(row, col): np.random.random() for (row, col) in row_col_list + } _, _ = heatmap.Heatmap(test_value_map).plot() + @pytest.mark.parametrize('tuple_keys', [True, False]) def test_cells_positions(ax, tuple_keys): row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8)) diff --git a/cirq-google/cirq_google/engine/virtual_engine_factory.py b/cirq-google/cirq_google/engine/virtual_engine_factory.py index 79c02535565..d5fe9a469bf 100644 --- a/cirq-google/cirq_google/engine/virtual_engine_factory.py +++ b/cirq-google/cirq_google/engine/virtual_engine_factory.py @@ -402,7 +402,7 @@ def create_default_noisy_quantum_virtual_machine( if simulator_class is None: try: # pragma: no cover - import qsimcirq + import qsimcirq # type: ignore simulator_class = qsimcirq.QSimSimulator # pragma: no cover except ImportError: diff --git a/examples/two_qubit_gate_compilation.py b/examples/two_qubit_gate_compilation.py index 31d360ec79f..9362ce9c12c 100644 --- a/examples/two_qubit_gate_compilation.py +++ b/examples/two_qubit_gate_compilation.py @@ -88,7 +88,7 @@ def main(samples: int = 1000, max_infidelity: float = 0.01): print(f'Maximum infidelity of "failed" compilation: {np.max(failed_infidelities_arr)}') plt.figure() - plt.hist(infidelities_arr, bins=25, range=(0.0, max_infidelity * 1.1)) # pragma: no cover + plt.hist(infidelities_arr, bins=25, range=(0.0, max_infidelity * 1.1)) # pragma: no cover ylim = plt.ylim() plt.plot([max_infidelity] * 2, ylim, '--', label='Maximum tabulation infidelity') plt.xlabel('Compiled gate infidelity vs target') From 57144dcbdf89a00cc56951a2221df8b0d21ff65c Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Sat, 16 Sep 2023 10:20:57 -0700 Subject: [PATCH 3/4] formatting --- cirq-google/cirq_google/engine/virtual_engine_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-google/cirq_google/engine/virtual_engine_factory.py b/cirq-google/cirq_google/engine/virtual_engine_factory.py index d5fe9a469bf..451db1e00fe 100644 --- a/cirq-google/cirq_google/engine/virtual_engine_factory.py +++ b/cirq-google/cirq_google/engine/virtual_engine_factory.py @@ -402,7 +402,7 @@ def create_default_noisy_quantum_virtual_machine( if simulator_class is None: try: # pragma: no cover - import qsimcirq # type: ignore + import qsimcirq # type: ignore simulator_class = qsimcirq.QSimSimulator # pragma: no cover except ImportError: From 4a9de60b035e7a3d7829320d6aa2b79863761112 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Sat, 16 Sep 2023 18:59:27 -0700 Subject: [PATCH 4/4] Change to seaborn v0_8 --- docs/experiments/textbook_algorithms.ipynb | 2 +- docs/start/intro.ipynb | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/experiments/textbook_algorithms.ipynb b/docs/experiments/textbook_algorithms.ipynb index 182a91e5ff2..9bec52408b1 100644 --- a/docs/experiments/textbook_algorithms.ipynb +++ b/docs/experiments/textbook_algorithms.ipynb @@ -1010,7 +1010,7 @@ "outputs": [], "source": [ "\"\"\"Plot the results.\"\"\"\n", - "plt.style.use(\"seaborn-whitegrid\")\n", + "plt.style.use(\"seaborn-v0_8-whitegrid\")\n", "\n", "plt.plot(nvals, estimates, \"--o\", label=\"Phase estimation\")\n", "plt.axhline(theta, label=\"True value\", color=\"black\")\n", diff --git a/docs/start/intro.ipynb b/docs/start/intro.ipynb index 6929b08fce3..42599d0cfe2 100644 --- a/docs/start/intro.ipynb +++ b/docs/start/intro.ipynb @@ -1453,7 +1453,7 @@ " probs.append(prob[0])\n", "\n", "# Plot the probability of the ground state at each simulation step.\n", - "plt.style.use('seaborn-whitegrid')\n", + "plt.style.use('seaborn-v0_8-whitegrid')\n", "plt.plot(probs, 'o')\n", "plt.xlabel(\"Step\")\n", "plt.ylabel(\"Probability of ground state\");" @@ -1490,7 +1490,7 @@ "\n", "\n", "# Plot the probability of the ground state at each simulation step.\n", - "plt.style.use('seaborn-whitegrid')\n", + "plt.style.use('seaborn-v0_8-whitegrid')\n", "plt.plot(sampled_probs, 'o')\n", "plt.xlabel(\"Step\")\n", "plt.ylabel(\"Probability of ground state\");"