Skip to content

Commit

Permalink
feat: high(-er) level rule, find, apply API
Browse files Browse the repository at this point in the history
Only returns first found match for now
  • Loading branch information
ss2165 committed Oct 9, 2023
1 parent 39f5280 commit 507dd95
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 9 deletions.
81 changes: 79 additions & 2 deletions pyrs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,88 @@
//! Python bindings for TKET2.
#![warn(missing_docs)]
use circuit::{add_circuit_module, try_with_hugr};
use circuit::{add_circuit_module, to_hugr, try_with_hugr};
use hugr::Hugr;
use optimiser::add_optimiser_module;
use pyo3::prelude::*;
use tket2::{json::TKETDecode, passes::apply_greedy_commutation};
use tket2::{
json::TKETDecode,
passes::apply_greedy_commutation,
portmatching::{pyo3::PyPatternMatch, CircuitPattern, PatternMatcher},
rewrite::CircuitRewrite,
};
use tket_json_rs::circuit_json::SerialCircuit;

mod circuit;
mod optimiser;

#[derive(Clone)]
#[pyclass]
/// A rewrite rule defined by a left hand side and right hand side of an equation.
pub struct Rule(pub [Hugr; 2]);

#[pymethods]
impl Rule {
#[new]
fn new_rule(l: PyObject, r: PyObject) -> PyResult<Rule> {
let l = to_hugr(l)?;
let r = to_hugr(r)?;

Ok(Rule([l, r]))
}
}
#[pyclass]
struct RuleMatcher {
matcher: PatternMatcher,
rights: Vec<Hugr>,
}

#[pymethods]
impl RuleMatcher {
#[new]
pub fn from_rules(rules: Vec<Rule>) -> PyResult<Self> {
let (lefts, rights): (Vec<_>, Vec<_>) =
rules.into_iter().map(|Rule([l, r])| (l, r)).unzip();
let patterns: Result<Vec<CircuitPattern>, _> =
lefts.iter().map(CircuitPattern::try_from_circuit).collect();
let matcher = PatternMatcher::from_patterns(patterns?);

Ok(Self { matcher, rights })
}

pub fn find_match(&self, target: &T2Circuit) -> PyResult<Option<CircuitRewrite>> {
let h = &target.0;
let p_match = self.matcher.find_matches_iter(h).next();
if let Some(m) = p_match {
let py_match = PyPatternMatch::try_from_rust(m, h, &self.matcher)?;
let r = self.rights.get(py_match.pattern_id).unwrap().clone();
let rw = py_match.to_rewrite(h, r)?;
Ok(Some(rw))
} else {
Ok(None)
}
}
}

#[pyclass]
/// A manager for tket 2 operations on a tket 1 Circuit.
pub struct T2Circuit(Hugr);

#[pymethods]
impl T2Circuit {
#[new]
fn from_circuit(circ: PyObject) -> PyResult<Self> {
Ok(Self(to_hugr(circ)?))
}

fn finish(&self) -> PyResult<PyObject> {
SerialCircuit::encode(&self.0)?.to_tket1()
}

fn apply_match(&mut self, rw: CircuitRewrite) {
rw.apply(&mut self.0).expect("Apply error.");
}
}

#[pyfunction]
fn greedy_depth_reduce(py_c: PyObject) -> PyResult<(PyObject, u32)> {
try_with_hugr(py_c, |mut h| {
Expand All @@ -25,6 +99,9 @@ fn pyrs(py: Python, m: &PyModule) -> PyResult<()> {
add_pattern_module(py, m)?;
add_pass_module(py, m)?;
add_optimiser_module(py, m)?;
m.add_class::<Rule>()?;
m.add_class::<RuleMatcher>()?;
m.add_class::<T2Circuit>()?;
Ok(())
}

Expand Down
37 changes: 37 additions & 0 deletions pyrs/test/test_bindings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from pyrs.pyrs import passes, circuit
from pyrs.pyrs import Rule, RuleMatcher, T2Circuit
from pytket.circuit import Circuit


Expand All @@ -19,6 +20,7 @@ def test_depth_optimise():

assert c.depth() == 2


def test_chunks():
c = Circuit(4).CX(0, 2).CX(1, 3).CX(1, 2).CX(0, 3).CX(1, 3)

Expand All @@ -31,6 +33,41 @@ def test_chunks():

assert c2.depth() == 3


def test_cx_rule():
c = T2Circuit(Circuit(4).CX(0, 2).CX(1, 2).CX(1, 2))

rule = Rule(Circuit(2).CX(0, 1).CX(0, 1), Circuit(2))
matcher = RuleMatcher([rule])

mtch = matcher.find_match(c)

c.apply_match(mtch)

coms = c.finish().get_commands()
print(coms)
assert len(coms) == 1


def test_multiple_rules():
circuit = T2Circuit(Circuit(3).CX(0, 1).H(0).H(1).H(2).Z(0).H(0).H(1).H(2))

rule1 = Rule(Circuit(1).H(0).Z(0).H(0), Circuit(1).X(0))
rule2 = Rule(Circuit(1).H(0).H(0), Circuit(1))
matcher = RuleMatcher([rule1, rule2])

match_count = 0
while match := matcher.find_match(circuit):
match_count += 1
circuit.apply_match(match)

assert match_count == 3

coms = circuit.finish().get_commands()
print(coms)
assert len(coms) == 2


# from dataclasses import dataclass
# from typing import Callable, Iterable
# import time
Expand Down
13 changes: 10 additions & 3 deletions src/portmatching/matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,19 @@ impl PatternMatcher {
}

/// Find all convex pattern matches in a circuit.
pub fn find_matches<C: Circuit + Clone>(&self, circuit: &C) -> Vec<PatternMatch> {
pub fn find_matches_iter<'a, 'c: 'a, C: Circuit + Clone>(
&'a self,
circuit: &'c C,
) -> impl Iterator<Item = PatternMatch> + 'a {
let mut checker = ConvexChecker::new(circuit);
circuit
.commands()
.flat_map(|cmd| self.find_rooted_matches(circuit, cmd.node(), &mut checker))
.collect()
.flat_map(move |cmd| self.find_rooted_matches(circuit, cmd.node(), &mut checker))
}

/// Find all convex pattern matches in a circuit.and collect in to a vector
pub fn find_matches<C: Circuit + Clone>(&self, circuit: &C) -> Vec<PatternMatch> {
self.find_matches_iter(circuit).collect()
}

/// Find all convex pattern matches in a circuit rooted at a given node.
Expand Down
7 changes: 3 additions & 4 deletions src/portmatching/pyo3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,7 @@ impl PyPatternMatch {
}

/// Convert the pattern into a [`CircuitRewrite`].
pub fn to_rewrite(&self, circ: PyObject, replacement: PyObject) -> PyResult<CircuitRewrite> {
let circ = pyobj_as_hugr(circ)?;
pub fn to_rewrite(&self, circ: &Hugr, replacement: Hugr) -> PyResult<CircuitRewrite> {
let inputs = self
.inputs
.iter()
Expand All @@ -166,12 +165,12 @@ impl PyPatternMatch {
let rewrite = PatternMatch::try_from_io(
self.root.0,
PatternID(self.pattern_id),
&circ,
circ,
inputs,
outputs,
)
.expect("Invalid PyCircuitMatch object")
.to_rewrite(&circ, pyobj_as_hugr(replacement)?)?;
.to_rewrite(circ, replacement)?;
Ok(rewrite)
}
}
Expand Down

0 comments on commit 507dd95

Please sign in to comment.