From cc1b095ae2461669bae3410c465e33d968fbffdf Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Mon, 3 Jun 2024 13:30:51 +0100 Subject: [PATCH] chore: Fix mypy lints --- .pre-commit-config.yaml | 13 +++++-------- pyproject.toml | 5 +++++ tket2-py/src/ops.rs | 7 +++++-- tket2-py/src/pattern.rs | 24 +++++++++++++++++++----- tket2-py/src/pattern/portmatching.rs | 7 +++++++ tket2-py/test/test_pass.py | 2 +- tket2-py/test/test_pauli_prop.py | 4 ++-- tket2-py/tket2/_tket2/ops.pyi | 4 ++-- tket2-py/tket2/_tket2/optimiser.pyi | 7 +++++-- tket2-py/tket2/_tket2/passes.pyi | 11 +++++++---- tket2-py/tket2/_tket2/pattern.pyi | 12 +++++++++--- tket2-py/tket2/passes.py | 18 ++++++++++-------- 12 files changed, 77 insertions(+), 37 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1736c69a3..d1b66f5c9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,14 +39,11 @@ repos: args: [--fix, --exit-non-zero-on-fix] - id: ruff-format - # The bindings for `tket2-py` do not define their types, so we need to ignore mypy for them. - # This should be re-enabled once we add .pyi files. - # - #- repo: https://github.com/pre-commit/mirrors-mypy - # rev: v1.9.0 - # hooks: - # - id: mypy - # additional_dependencies: [pydantic] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.9.0 + hooks: + - id: mypy + additional_dependencies: [pydantic] - repo: local hooks: diff --git a/pyproject.toml b/pyproject.toml index 67f83d2b7..0f9b8061a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,3 +88,8 @@ filterwarnings = "ignore::DeprecationWarning:lark.*" [tool.pyright] # Rust bindings have typing stubs but no python source code. reportMissingModuleSource = "none" + +[[tool.mypy.overrides]] +# Ignore errors in tikv-jemalloc. +module = "gen_run_tests.*" +ignore_errors = true diff --git a/tket2-py/src/ops.rs b/tket2-py/src/ops.rs index fca3176c1..a973bce00 100644 --- a/tket2-py/src/ops.rs +++ b/tket2-py/src/ops.rs @@ -173,8 +173,11 @@ impl PyPauli { } /// Check if two Pauli matrices are equal. - pub fn __eq__(&self, other: &PyPauli) -> bool { - self.p == other.p + pub fn __eq__(&self, other: &Bound) -> bool { + let Ok(other): Result<&Bound, _> = other.downcast() else { + return false; + }; + self.p == other.borrow().p } } diff --git a/tket2-py/src/pattern.rs b/tket2-py/src/pattern.rs index 028c228eb..61d433b12 100644 --- a/tket2-py/src/pattern.rs +++ b/tket2-py/src/pattern.rs @@ -8,7 +8,7 @@ use crate::utils::{create_py_exception, ConvertPyErr}; use hugr::Hugr; use pyo3::prelude::*; -use tket2::portmatching::{CircuitPattern, PatternMatcher}; +use tket2::portmatching::{CircuitPattern, PatternMatch, PatternMatcher}; /// The module definition pub fn module(py: Python<'_>) -> PyResult> { @@ -80,12 +80,26 @@ impl RuleMatcher { pub fn find_match(&self, target: &Tk2Circuit) -> PyResult> { let h = &target.hugr; - if let Some(p_match) = self.matcher.find_matches_iter(h).next() { - let r = self.rights.get(p_match.pattern_id().0).unwrap().clone(); - let rw = p_match.to_rewrite(h, r).convert_pyerrs()?; - Ok(Some(rw.into())) + if let Some(pmatch) = self.matcher.find_matches_iter(h).next() { + Ok(Some(self.match_to_rewrite(pmatch, h)?)) } else { Ok(None) } } + + pub fn find_matches(&self, target: &Tk2Circuit) -> PyResult> { + let h = &target.hugr; + self.matcher + .find_matches_iter(h) + .map(|m| self.match_to_rewrite(m, h)) + .collect() + } +} + +impl RuleMatcher { + fn match_to_rewrite(&self, pmatch: PatternMatch, target: &Hugr) -> PyResult { + let r = self.rights.get(pmatch.pattern_id().0).unwrap().clone(); + let rw = pmatch.to_rewrite(target, r).convert_pyerrs()?; + Ok(rw.into()) + } } diff --git a/tket2-py/src/pattern/portmatching.rs b/tket2-py/src/pattern/portmatching.rs index 4ffc314db..4ed90b2e2 100644 --- a/tket2-py/src/pattern/portmatching.rs +++ b/tket2-py/src/pattern/portmatching.rs @@ -78,6 +78,13 @@ impl PyPatternMatcher { Ok(format!("{:?}", self.matcher)) } + /// Find one convex match in a circuit. + pub fn find_match(&self, circ: &Bound) -> PyResult> { + with_hugr(circ, |circ, _| { + self.matcher.find_matches_iter(&circ).next().map(Into::into) + }) + } + /// Find all convex matches in a circuit. pub fn find_matches(&self, circ: &Bound) -> PyResult> { with_hugr(circ, |circ, _| { diff --git a/tket2-py/test/test_pass.py b/tket2-py/test/test_pass.py index 6022b2695..0ce16c6e5 100644 --- a/tket2-py/test/test_pass.py +++ b/tket2-py/test/test_pass.py @@ -42,7 +42,7 @@ def circuits( def test_simple_badger_pass_no_opt(): c = Circuit(3).CCX(0, 1, 2) - badger = badger_pass(max_threads=1, timeout=0) + badger = badger_pass(max_threads=1, timeout=0, rebase=True) badger.apply(c) assert c.n_gates_of_type(OpType.CX) == 6 diff --git a/tket2-py/test/test_pauli_prop.py b/tket2-py/test/test_pauli_prop.py index a41d84853..f33d69b6b 100644 --- a/tket2-py/test/test_pauli_prop.py +++ b/tket2-py/test/test_pauli_prop.py @@ -84,8 +84,8 @@ def measure_rules() -> list[Rule]: r_build = Dfg([QB_T], [QB_T, BOOL_T]) qs = r_build.inputs() qs = r_build.add_op(PauliZ.op(), qs).outs(1) - qs, b = r_build.add_op(Measure, qs).outs(2) - ltk = r_build.finish([qs, b]) + q, b = r_build.add_op(Measure, qs).outs(2) + ltk = r_build.finish([q, b]) r_build = Dfg([QB_T], [QB_T, BOOL_T]) qs = r_build.inputs() diff --git a/tket2-py/tket2/_tket2/ops.pyi b/tket2-py/tket2/_tket2/ops.pyi index 7bb2ce659..81bca46a2 100644 --- a/tket2-py/tket2/_tket2/ops.pyi +++ b/tket2-py/tket2/_tket2/ops.pyi @@ -1,5 +1,5 @@ from enum import Enum -from typing import Iterable +from typing import Any, Iterable class Tk2Op(Enum): """A rust-backed Tket2 built-in operation.""" @@ -57,4 +57,4 @@ class Pauli(Enum): def __str__(self) -> str: """Get the string name of the Pauli.""" - def __eq__(self, value: Pauli) -> bool: ... + def __eq__(self, value: Any) -> bool: ... diff --git a/tket2-py/tket2/_tket2/optimiser.pyi b/tket2-py/tket2/_tket2/optimiser.pyi index c511b9e10..aef94ad0c 100644 --- a/tket2-py/tket2/_tket2/optimiser.pyi +++ b/tket2-py/tket2/_tket2/optimiser.pyi @@ -1,8 +1,11 @@ +from typing import TypeVar from .circuit import Tk2Circuit from pytket._tket.circuit import Circuit from pathlib import Path +CircuitClass = TypeVar("CircuitClass", Circuit, Tk2Circuit) + class BadgerOptimiser: @staticmethod def load_precompiled(filename: Path) -> BadgerOptimiser: @@ -14,14 +17,14 @@ class BadgerOptimiser: def optimise( self, - circ: Tk2Circuit | Circuit, + circ: CircuitClass, timeout: int | None = None, progress_timeout: int | None = None, n_threads: int | None = None, split_circ: bool = False, queue_size: int | None = None, log_progress: Path | None = None, - ) -> Tk2Circuit | Circuit: + ) -> CircuitClass: """Optimise a circuit. :param circ: The circuit to optimise. diff --git a/tket2-py/tket2/_tket2/passes.pyi b/tket2-py/tket2/_tket2/passes.pyi index a334a17d1..1e618e081 100644 --- a/tket2-py/tket2/_tket2/passes.pyi +++ b/tket2-py/tket2/_tket2/passes.pyi @@ -1,9 +1,12 @@ from pathlib import Path +from typing import TypeVar from .optimiser import BadgerOptimiser from .circuit import Tk2Circuit from pytket._tket.circuit import Circuit +CircuitClass = TypeVar("CircuitClass", Circuit, Tk2Circuit) + class CircuitChunks: def reassemble(self) -> Circuit | Tk2Circuit: """Reassemble the circuit from its chunks.""" @@ -17,21 +20,21 @@ class CircuitChunks: class PullForwardError(Exception): """Error from a `PullForward` operation.""" -def greedy_depth_reduce(circ: Circuit | Tk2Circuit) -> tuple[Circuit | Tk2Circuit, int]: +def greedy_depth_reduce(circ: CircuitClass) -> tuple[CircuitClass, int]: """Greedy depth reduction of a circuit. Returns the reduced circuit and the depth reduction. """ def badger_optimise( - circ: Circuit | Tk2Circuit, + circ: CircuitClass, optimiser: BadgerOptimiser, max_threads: int | None = None, timeout: int | None = None, progress_timeout: int | None = None, log_dir: Path | None = None, - rebase: bool = False, -) -> Circuit | Tk2Circuit: + rebase: bool | None = False, +) -> CircuitClass: """Optimise a circuit using the Badger optimiser. HyperTKET's best attempt at optimising a circuit using circuit rewriting diff --git a/tket2-py/tket2/_tket2/pattern.pyi b/tket2-py/tket2/_tket2/pattern.pyi index 662552400..a6e62d8dd 100644 --- a/tket2-py/tket2/_tket2/pattern.pyi +++ b/tket2-py/tket2/_tket2/pattern.pyi @@ -1,4 +1,4 @@ -from typing import Iterator, Optional +from typing import Iterator from .circuit import Node, Tk2Circuit from .rewrite import CircuitRewrite from pytket._tket.circuit import Circuit @@ -19,9 +19,12 @@ class RuleMatcher: def __init__(self, rules: list[Rule]) -> None: """Create a new rule matcher.""" - def find_matches(self, circ: Tk2Circuit) -> Optional[CircuitRewrite]: + def find_match(self, circ: Tk2Circuit) -> CircuitRewrite | None: """Find a match of the rules in the circuit.""" + def find_matches(self, circ: Tk2Circuit) -> list[CircuitRewrite]: + """Find all matches of the rules in the circuit.""" + class CircuitPattern: """A pattern that matches a circuit exactly.""" @@ -34,9 +37,12 @@ class PatternMatcher: def __init__(self, patterns: Iterator[CircuitPattern]) -> None: """Create a new pattern matcher.""" - def find_matches(self, circ: Circuit | Tk2Circuit) -> list[PatternMatch]: + def find_match(self, circ: Tk2Circuit) -> PatternMatch | None: """Find a match of the patterns in the circuit.""" + def find_matches(self, circ: Tk2Circuit) -> list[PatternMatch]: + """Find all matches of the patterns in the circuit.""" + class PatternMatch: """A convex pattern match in a circuit""" diff --git a/tket2-py/tket2/passes.py b/tket2-py/tket2/passes.py index 392d351a7..ac545fa12 100644 --- a/tket2-py/tket2/passes.py +++ b/tket2-py/tket2/passes.py @@ -3,7 +3,7 @@ from importlib import resources from pytket import Circuit -from pytket.passes import CustomPass +from pytket.passes import CustomPass, BasePass from tket2 import optimiser @@ -32,9 +32,10 @@ def badger_pass( rewriter: Optional[Path] = None, max_threads: Optional[int] = None, timeout: Optional[int] = None, + progress_timeout: Optional[int] = None, log_dir: Optional[Path] = None, - rebase: Optional[bool] = None, -) -> CustomPass: + rebase: bool = False, +) -> BasePass: """Construct a Badger pass. The Badger optimiser requires a pre-compiled rewriter produced by the @@ -54,11 +55,12 @@ def apply(circuit: Circuit) -> Circuit: """Apply Badger optimisation to the circuit.""" return badger_optimise( circuit, - opt, - max_threads, - timeout, - log_dir, - rebase, + optimiser=opt, + max_threads=max_threads, + timeout=timeout, + progress_timeout=progress_timeout, + log_dir=log_dir, + rebase=rebase, ) return CustomPass(apply)