Skip to content

Commit

Permalink
chore: Fix mypy lints
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Jun 3, 2024
1 parent 4b63d81 commit cc1b095
Show file tree
Hide file tree
Showing 12 changed files with 77 additions and 37 deletions.
13 changes: 5 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,11 @@ repos:
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format

# The bindings for `tket2-py` do not define their types, so we need to ignore mypy for them.
# This should be re-enabled once we add .pyi files.
#
#- repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.9.0
# hooks:
# - id: mypy
# additional_dependencies: [pydantic]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
hooks:
- id: mypy
additional_dependencies: [pydantic]

- repo: local
hooks:
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,8 @@ filterwarnings = "ignore::DeprecationWarning:lark.*"
[tool.pyright]
# Rust bindings have typing stubs but no python source code.
reportMissingModuleSource = "none"

[[tool.mypy.overrides]]
# Ignore errors in tikv-jemalloc.
module = "gen_run_tests.*"
ignore_errors = true
7 changes: 5 additions & 2 deletions tket2-py/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,11 @@ impl PyPauli {
}

/// Check if two Pauli matrices are equal.
pub fn __eq__(&self, other: &PyPauli) -> bool {
self.p == other.p
pub fn __eq__(&self, other: &Bound<PyAny>) -> bool {
let Ok(other): Result<&Bound<PyPauli>, _> = other.downcast() else {
return false;
};
self.p == other.borrow().p
}
}

Expand Down
24 changes: 19 additions & 5 deletions tket2-py/src/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::utils::{create_py_exception, ConvertPyErr};

use hugr::Hugr;
use pyo3::prelude::*;
use tket2::portmatching::{CircuitPattern, PatternMatcher};
use tket2::portmatching::{CircuitPattern, PatternMatch, PatternMatcher};

/// The module definition
pub fn module(py: Python<'_>) -> PyResult<Bound<'_, PyModule>> {
Expand Down Expand Up @@ -80,12 +80,26 @@ impl RuleMatcher {

pub fn find_match(&self, target: &Tk2Circuit) -> PyResult<Option<PyCircuitRewrite>> {
let h = &target.hugr;
if let Some(p_match) = self.matcher.find_matches_iter(h).next() {
let r = self.rights.get(p_match.pattern_id().0).unwrap().clone();
let rw = p_match.to_rewrite(h, r).convert_pyerrs()?;
Ok(Some(rw.into()))
if let Some(pmatch) = self.matcher.find_matches_iter(h).next() {
Ok(Some(self.match_to_rewrite(pmatch, h)?))
} else {
Ok(None)
}
}

pub fn find_matches(&self, target: &Tk2Circuit) -> PyResult<Vec<PyCircuitRewrite>> {
let h = &target.hugr;
self.matcher
.find_matches_iter(h)
.map(|m| self.match_to_rewrite(m, h))
.collect()
}
}

impl RuleMatcher {
fn match_to_rewrite(&self, pmatch: PatternMatch, target: &Hugr) -> PyResult<PyCircuitRewrite> {
let r = self.rights.get(pmatch.pattern_id().0).unwrap().clone();
let rw = pmatch.to_rewrite(target, r).convert_pyerrs()?;
Ok(rw.into())
}
}
7 changes: 7 additions & 0 deletions tket2-py/src/pattern/portmatching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ impl PyPatternMatcher {
Ok(format!("{:?}", self.matcher))
}

/// Find one convex match in a circuit.
pub fn find_match(&self, circ: &Bound<PyAny>) -> PyResult<Option<PyPatternMatch>> {
with_hugr(circ, |circ, _| {
self.matcher.find_matches_iter(&circ).next().map(Into::into)
})
}

/// Find all convex matches in a circuit.
pub fn find_matches(&self, circ: &Bound<PyAny>) -> PyResult<Vec<PyPatternMatch>> {
with_hugr(circ, |circ, _| {
Expand Down
2 changes: 1 addition & 1 deletion tket2-py/test/test_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def circuits(

def test_simple_badger_pass_no_opt():
c = Circuit(3).CCX(0, 1, 2)
badger = badger_pass(max_threads=1, timeout=0)
badger = badger_pass(max_threads=1, timeout=0, rebase=True)
badger.apply(c)
assert c.n_gates_of_type(OpType.CX) == 6

Expand Down
4 changes: 2 additions & 2 deletions tket2-py/test/test_pauli_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def measure_rules() -> list[Rule]:
r_build = Dfg([QB_T], [QB_T, BOOL_T])
qs = r_build.inputs()
qs = r_build.add_op(PauliZ.op(), qs).outs(1)
qs, b = r_build.add_op(Measure, qs).outs(2)
ltk = r_build.finish([qs, b])
q, b = r_build.add_op(Measure, qs).outs(2)
ltk = r_build.finish([q, b])

r_build = Dfg([QB_T], [QB_T, BOOL_T])
qs = r_build.inputs()
Expand Down
4 changes: 2 additions & 2 deletions tket2-py/tket2/_tket2/ops.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Iterable
from typing import Any, Iterable

class Tk2Op(Enum):
"""A rust-backed Tket2 built-in operation."""
Expand Down Expand Up @@ -57,4 +57,4 @@ class Pauli(Enum):
def __str__(self) -> str:
"""Get the string name of the Pauli."""

def __eq__(self, value: Pauli) -> bool: ...
def __eq__(self, value: Any) -> bool: ...
7 changes: 5 additions & 2 deletions tket2-py/tket2/_tket2/optimiser.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import TypeVar
from .circuit import Tk2Circuit
from pytket._tket.circuit import Circuit

from pathlib import Path

CircuitClass = TypeVar("CircuitClass", Circuit, Tk2Circuit)

class BadgerOptimiser:
@staticmethod
def load_precompiled(filename: Path) -> BadgerOptimiser:
Expand All @@ -14,14 +17,14 @@ class BadgerOptimiser:

def optimise(
self,
circ: Tk2Circuit | Circuit,
circ: CircuitClass,
timeout: int | None = None,
progress_timeout: int | None = None,
n_threads: int | None = None,
split_circ: bool = False,
queue_size: int | None = None,
log_progress: Path | None = None,
) -> Tk2Circuit | Circuit:
) -> CircuitClass:
"""Optimise a circuit.
:param circ: The circuit to optimise.
Expand Down
11 changes: 7 additions & 4 deletions tket2-py/tket2/_tket2/passes.pyi
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from pathlib import Path
from typing import TypeVar

from .optimiser import BadgerOptimiser
from .circuit import Tk2Circuit
from pytket._tket.circuit import Circuit

CircuitClass = TypeVar("CircuitClass", Circuit, Tk2Circuit)

class CircuitChunks:
def reassemble(self) -> Circuit | Tk2Circuit:
"""Reassemble the circuit from its chunks."""
Expand All @@ -17,21 +20,21 @@ class CircuitChunks:
class PullForwardError(Exception):
"""Error from a `PullForward` operation."""

def greedy_depth_reduce(circ: Circuit | Tk2Circuit) -> tuple[Circuit | Tk2Circuit, int]:
def greedy_depth_reduce(circ: CircuitClass) -> tuple[CircuitClass, int]:
"""Greedy depth reduction of a circuit.
Returns the reduced circuit and the depth reduction.
"""

def badger_optimise(
circ: Circuit | Tk2Circuit,
circ: CircuitClass,
optimiser: BadgerOptimiser,
max_threads: int | None = None,
timeout: int | None = None,
progress_timeout: int | None = None,
log_dir: Path | None = None,
rebase: bool = False,
) -> Circuit | Tk2Circuit:
rebase: bool | None = False,
) -> CircuitClass:
"""Optimise a circuit using the Badger optimiser.
HyperTKET's best attempt at optimising a circuit using circuit rewriting
Expand Down
12 changes: 9 additions & 3 deletions tket2-py/tket2/_tket2/pattern.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterator, Optional
from typing import Iterator
from .circuit import Node, Tk2Circuit
from .rewrite import CircuitRewrite
from pytket._tket.circuit import Circuit
Expand All @@ -19,9 +19,12 @@ class RuleMatcher:
def __init__(self, rules: list[Rule]) -> None:
"""Create a new rule matcher."""

def find_matches(self, circ: Tk2Circuit) -> Optional[CircuitRewrite]:
def find_match(self, circ: Tk2Circuit) -> CircuitRewrite | None:
"""Find a match of the rules in the circuit."""

def find_matches(self, circ: Tk2Circuit) -> list[CircuitRewrite]:
"""Find all matches of the rules in the circuit."""

class CircuitPattern:
"""A pattern that matches a circuit exactly."""

Expand All @@ -34,9 +37,12 @@ class PatternMatcher:
def __init__(self, patterns: Iterator[CircuitPattern]) -> None:
"""Create a new pattern matcher."""

def find_matches(self, circ: Circuit | Tk2Circuit) -> list[PatternMatch]:
def find_match(self, circ: Tk2Circuit) -> PatternMatch | None:
"""Find a match of the patterns in the circuit."""

def find_matches(self, circ: Tk2Circuit) -> list[PatternMatch]:
"""Find all matches of the patterns in the circuit."""

class PatternMatch:
"""A convex pattern match in a circuit"""

Expand Down
18 changes: 10 additions & 8 deletions tket2-py/tket2/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from importlib import resources

from pytket import Circuit
from pytket.passes import CustomPass
from pytket.passes import CustomPass, BasePass

from tket2 import optimiser

Expand Down Expand Up @@ -32,9 +32,10 @@ def badger_pass(
rewriter: Optional[Path] = None,
max_threads: Optional[int] = None,
timeout: Optional[int] = None,
progress_timeout: Optional[int] = None,
log_dir: Optional[Path] = None,
rebase: Optional[bool] = None,
) -> CustomPass:
rebase: bool = False,
) -> BasePass:
"""Construct a Badger pass.
The Badger optimiser requires a pre-compiled rewriter produced by the
Expand All @@ -54,11 +55,12 @@ def apply(circuit: Circuit) -> Circuit:
"""Apply Badger optimisation to the circuit."""
return badger_optimise(
circuit,
opt,
max_threads,
timeout,
log_dir,
rebase,
optimiser=opt,
max_threads=max_threads,
timeout=timeout,
progress_timeout=progress_timeout,
log_dir=log_dir,
rebase=rebase,
)

return CustomPass(apply)

0 comments on commit cc1b095

Please sign in to comment.