From d423dbee66e7456b2705343865c50601d34cee66 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 20 Apr 2024 22:06:45 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20pre-commit=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/bench/cli.py | 2 +- tests/test_cli.py | 158 +++++++++++++++++++++++++++++-------------- 2 files changed, 108 insertions(+), 52 deletions(-) diff --git a/src/mqt/bench/cli.py b/src/mqt/bench/cli.py index 537c1d223..ed35f299b 100644 --- a/src/mqt/bench/cli.py +++ b/src/mqt/bench/cli.py @@ -75,7 +75,7 @@ def parse_benchmark_name_and_instance(algorithm: str) -> tuple[str, str | None]: as expected by :func:`get_benchmark`. """ - if algorithm.startswith("shor_") or algorithm.startswith("groundstate_"): + if algorithm.startswith(("shor_", "groundstate_")): as_list = algorithm.split("_", 2) assert len(as_list) == 2 return cast(tuple[str, str], tuple(as_list)) diff --git a/tests/test_cli.py b/tests/test_cli.py index 11ca40605..b73561f9a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,50 +1,94 @@ -from mqt.bench import get_benchmark, CompilerSettings, QiskitSettings -from pytest_console_scripts import ScriptRunner +from __future__ import annotations + +from typing import TYPE_CHECKING + import pytest from qiskit.qasm2 import dumps +from mqt.bench import CompilerSettings, QiskitSettings, get_benchmark + +if TYPE_CHECKING: + from pytest_console_scripts import ScriptRunner + @pytest.mark.parametrize( ("args", "expected_output"), [ - ([ - "--level", "alg", - "--algorithm", "ghz", - "--num-qubits", "10", - ], dumps(get_benchmark(level="alg", benchmark_name="ghz", circuit_size=10))), - ([ - "--level", "alg", - "--algorithm", "shor_xsmall", - "--num-qubits", "10", - ], "OPENQASM 2.0;"), # Note: shor is non-deterministic, so just a basic sanity check - ([ - "--level", "alg", - "--algorithm", "ghz", - "--num-qubits", "20", - ], dumps(get_benchmark(level="alg", benchmark_name="ghz", circuit_size=20))), - ([ - "--level", "indep", - "--algorithm", "ghz", - "--num-qubits", "20", - "--compiler", "qiskit", - ], dumps(get_benchmark(level="indep", benchmark_name="ghz", circuit_size=20, compiler="qiskit"))), - ([ - "--level", "mapped", - "--algorithm", "ghz", - "--num-qubits", "20", - "--compiler", "qiskit", - "--qiskit-optimization-level", "2", - "--native-gate-set", "ibm", - "--device", "ibm_montreal", - ], dumps(get_benchmark( - level="mapped", - benchmark_name="ghz", - circuit_size=20, - compiler="qiskit", - compiler_settings=CompilerSettings(QiskitSettings(optimization_level=2)), - provider_name="ibm", - device_name="ibm_montreal", - ))), + ( + [ + "--level", + "alg", + "--algorithm", + "ghz", + "--num-qubits", + "10", + ], + dumps(get_benchmark(level="alg", benchmark_name="ghz", circuit_size=10)), + ), + ( + [ + "--level", + "alg", + "--algorithm", + "shor_xsmall", + "--num-qubits", + "10", + ], + "OPENQASM 2.0;", + ), # Note: shor is non-deterministic, so just a basic sanity check + ( + [ + "--level", + "alg", + "--algorithm", + "ghz", + "--num-qubits", + "20", + ], + dumps(get_benchmark(level="alg", benchmark_name="ghz", circuit_size=20)), + ), + ( + [ + "--level", + "indep", + "--algorithm", + "ghz", + "--num-qubits", + "20", + "--compiler", + "qiskit", + ], + dumps(get_benchmark(level="indep", benchmark_name="ghz", circuit_size=20, compiler="qiskit")), + ), + ( + [ + "--level", + "mapped", + "--algorithm", + "ghz", + "--num-qubits", + "20", + "--compiler", + "qiskit", + "--qiskit-optimization-level", + "2", + "--native-gate-set", + "ibm", + "--device", + "ibm_montreal", + ], + dumps( + get_benchmark( + level="mapped", + benchmark_name="ghz", + circuit_size=20, + compiler="qiskit", + compiler_settings=CompilerSettings(QiskitSettings(optimization_level=2)), + provider_name="ibm", + device_name="ibm_montreal", + ) + ), + ), ], ) def test_cli(args: list[str], expected_output: str, script_runner: ScriptRunner) -> None: @@ -60,17 +104,29 @@ def test_cli(args: list[str], expected_output: str, script_runner: ScriptRunner) (["asd"], "usage: mqt.bench.cli"), (["--benchmark", "ae"], "usage: mqt.bench.cli"), # Note: We don't care about the actual error messages in most cases - ([ - "--level", "indep", - "--algorithm", "ghz", - "--num-qubits", "20", - # Missing compiler option - ], ""), - ([ - "--level", "alg", - "--algorithm", "not-a-valid-benchmark", - "--num-qubits", "20", - ], ""), + ( + [ + "--level", + "indep", + "--algorithm", + "ghz", + "--num-qubits", + "20", + # Missing compiler option + ], + "", + ), + ( + [ + "--level", + "alg", + "--algorithm", + "not-a-valid-benchmark", + "--num-qubits", + "20", + ], + "", + ), ], ) def test_cli_errors(args: list[str], expected_output: str, script_runner: ScriptRunner) -> None: