From d143f05bd9af310bbc75cba69b875e12f8088f46 Mon Sep 17 00:00:00 2001 From: Athena Caesura Date: Fri, 29 Mar 2024 12:15:19 -0400 Subject: [PATCH] fix: pyright issues --- src/benchq/algorithms/gsee/qpe_gsee.py | 9 +++------ src/benchq/visualization_tools/plot_graph_state.py | 8 +++++--- .../visualization_tools/plot_substrate_scheduling.py | 2 +- tests/benchq/compilation/test_rbs_with_pauli_tracking.py | 3 ++- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/benchq/algorithms/gsee/qpe_gsee.py b/src/benchq/algorithms/gsee/qpe_gsee.py index 4ff42197..0dfe1bf3 100644 --- a/src/benchq/algorithms/gsee/qpe_gsee.py +++ b/src/benchq/algorithms/gsee/qpe_gsee.py @@ -1,14 +1,10 @@ import warnings import numpy as np -from orquestra.integrations.cirq.conversions import ( - to_openfermion, # pyright: ignore[reportPrivateImportUsage] -) -from orquestra.quantum.operators import PauliRepresentation from pyLIQTR.QSP.Hamiltonian import Hamiltonian from ...algorithms.data_structures import AlgorithmImplementation, ErrorBudget -from ...conversions import openfermion_to_pyliqtr +from ...conversions import SUPPORTED_OPERATORS, get_pyliqtr_operator from ...problem_embeddings.qsp import get_qsp_program @@ -17,9 +13,10 @@ def _n_block_encodings(hamiltonian: Hamiltonian, precision: float) -> int: def qpe_gsee_algorithm( - hamiltonian: PauliRepresentation, precision: float, failure_tolerance: float + hamiltonian: SUPPORTED_OPERATORS, precision: float, failure_tolerance: float ) -> AlgorithmImplementation: warnings.warn("This is experimental implementation, use at your own risk.") + hamiltonian = get_pyliqtr_operator(hamiltonian) n_block_encodings = _n_block_encodings(hamiltonian, precision) program = get_qsp_program(hamiltonian, n_block_encodings) error_budget = ErrorBudget.from_even_split(failure_tolerance) diff --git a/src/benchq/visualization_tools/plot_graph_state.py b/src/benchq/visualization_tools/plot_graph_state.py index 7749f8e4..96e96297 100644 --- a/src/benchq/visualization_tools/plot_graph_state.py +++ b/src/benchq/visualization_tools/plot_graph_state.py @@ -1,3 +1,5 @@ +from typing import List + import matplotlib.patches as mpatches import matplotlib.pyplot as plt import networkx as nx @@ -133,7 +135,7 @@ def plot_graph_state(asg, pauli_tracker): lambda x: x[1]["shape"] == aShape, graph.nodes(data=True) ) ] - colors_for_nodes_with_this_shape = [ + colors_for_nodes_with_this_shape: List[mpatches.Patch] = [ color_map[i] for i in nodes_with_this_shape if 0 <= i < len(color_map) ] @@ -213,10 +215,10 @@ def plot_graph_state(asg, pauli_tracker): plt.tight_layout() # Create a legend - red_patch = plt.Line2D( + red_patch = plt.Line2D( # pyright: ignore[reportPrivateImportUsage] [0], [0], marker="o", color="w", markerfacecolor="red", markersize=8, label="X" ) - blue_patch = plt.Line2D( + blue_patch = plt.Line2D( # pyright: ignore[reportPrivateImportUsage] [0], [0], marker="o", color="w", markerfacecolor="blue", markersize=8, label="Z" ) diff --git a/src/benchq/visualization_tools/plot_substrate_scheduling.py b/src/benchq/visualization_tools/plot_substrate_scheduling.py index ef51a7fb..6f7fa0fb 100644 --- a/src/benchq/visualization_tools/plot_substrate_scheduling.py +++ b/src/benchq/visualization_tools/plot_substrate_scheduling.py @@ -12,7 +12,7 @@ def plot_graph_state_with_measurement_steps( asg, measurement_steps, - cmap=plt.cm.rainbow, + cmap=plt.cm.rainbow, # pyright: ignore[reportAttributeAccessIssue] name="extrapolation_plot", ): """Plot a graph state with the measurement steps highlighted in different diff --git a/tests/benchq/compilation/test_rbs_with_pauli_tracking.py b/tests/benchq/compilation/test_rbs_with_pauli_tracking.py index 3efcb83a..b97034ce 100644 --- a/tests/benchq/compilation/test_rbs_with_pauli_tracking.py +++ b/tests/benchq/compilation/test_rbs_with_pauli_tracking.py @@ -98,7 +98,8 @@ def check_correctness_for_single_init( n = len(pdf) all_bitstrings = [ - format(i, f"0{full_circuit.n_qubits}b") for i in range(2**full_circuit.n_qubits) + format(i, f"0{full_circuit.n_qubits}b") + for i in range(2**full_circuit.n_qubits) ] for i in range(n):