diff --git a/Cargo.toml b/Cargo.toml index b83242149..d3753aa50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,9 +17,10 @@ license-file = "LICENCE" [workspace.dependencies] tket2 = { path = "./tket2" } -quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "b71cae6" } +quantinuum-hugr = { git = "https://github.com/CQCL/hugr", rev = "b71cae6" } portgraph = { version = "0.10" } pyo3 = { version = "0.20" } itertools = { version = "0.11.0" } -tket-json-rs = "0.2.0" +tket-json-rs = { version = "0.3.0" } tracing = "0.1.37" +portmatching = { git = "https://github.com/lmondada/portmatching", rev = "738c91c" } diff --git a/tket2-py/Cargo.toml b/tket2-py/Cargo.toml index b6de354aa..0dc9b67d5 100644 --- a/tket2-py/Cargo.toml +++ b/tket2-py/Cargo.toml @@ -22,3 +22,4 @@ pyo3 = { workspace = true, features = ["extension-module"] } num_cpus = "1.16.0" derive_more = "0.99.17" itertools = { workspace = true } +portmatching = { workspace = true } diff --git a/tket2-py/src/circuit.rs b/tket2-py/src/circuit.rs index 2f9167dbc..7d2a039b9 100644 --- a/tket2-py/src/circuit.rs +++ b/tket2-py/src/circuit.rs @@ -1,27 +1,31 @@ //! Circuit-related functionality and utilities. #![allow(unused)] +pub mod convert; + +use derive_more::{From, Into}; use pyo3::prelude::*; +use std::fmt; use hugr::{Hugr, HugrView}; use tket2::extension::REGISTRY; use tket2::json::TKETDecode; -use tket2::passes::CircuitChunks; use tket2::rewrite::CircuitRewrite; use tket_json_rs::circuit_json::SerialCircuit; +pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, T2Circuit}; + /// The module definition pub fn module(py: Python) -> PyResult<&PyModule> { let m = PyModule::new(py, "_circuit")?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; m.add_function(wrap_pyfunction!(validate_hugr, m)?)?; m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?; m.add_function(wrap_pyfunction!(to_hugr, m)?)?; - m.add_function(wrap_pyfunction!(chunks, m)?)?; m.add("HugrError", py.get_type::())?; m.add("BuildError", py.get_type::())?; @@ -41,72 +45,50 @@ pub fn module(py: Python) -> PyResult<&PyModule> { Ok(m) } -/// Apply a fallible function expecting a hugr on a pytket circuit. -pub fn try_with_hugr(circ: Py, f: F) -> PyResult -where - E: Into, - F: FnOnce(Hugr) -> Result, -{ - let hugr = SerialCircuit::_from_tket1(circ).decode()?; - (f)(hugr).map_err(|e| e.into()) -} - -/// Apply a function expecting a hugr on a pytket circuit. -pub fn with_hugr T>(circ: Py, f: F) -> PyResult { - try_with_hugr(circ, |hugr| Ok::((f)(hugr))) -} - -/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit. -pub fn try_update_hugr, F: FnOnce(Hugr) -> Result>( - circ: Py, - f: F, -) -> PyResult> { - let hugr = try_with_hugr(circ, f)?; - SerialCircuit::encode(&hugr)?.to_tket1() -} - -/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit. -pub fn update_hugr Hugr>(circ: Py, f: F) -> PyResult> { - let hugr = with_hugr(circ, f)?; - SerialCircuit::encode(&hugr)?.to_tket1() -} - +/// Run the validation checks on a circuit. #[pyfunction] pub fn validate_hugr(c: Py) -> PyResult<()> { try_with_hugr(c, |hugr| hugr.validate(®ISTRY)) } +/// Return a Graphviz DOT string representation of the circuit. #[pyfunction] pub fn to_hugr_dot(c: Py) -> PyResult { with_hugr(c, |hugr| hugr.dot_string()) } +/// Downcast a python object to a [`Hugr`]. #[pyfunction] pub fn to_hugr(c: Py) -> PyResult { with_hugr(c, |hugr| hugr) } -#[pyfunction] -pub fn chunks(c: Py, max_chunk_size: usize) -> PyResult { - with_hugr(c, |hugr| CircuitChunks::split(&hugr, max_chunk_size)) -} - +/// A [`hugr::Node`] wrapper for Python. #[pyclass] -/// A manager for tket 2 operations on a tket 1 Circuit. -pub struct T2Circuit(pub Hugr); +#[pyo3(name = "Node")] +#[repr(transparent)] +#[derive(From, Into, PartialEq, Eq, Hash, Clone, Copy)] +pub struct PyNode { + /// Rust representation of the node + pub node: hugr::Node, +} -#[pymethods] -impl T2Circuit { - #[new] - fn from_circuit(circ: PyObject) -> PyResult { - Ok(Self(to_hugr(circ)?)) +impl fmt::Display for PyNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.node.fmt(f) } +} - fn finish(&self) -> PyResult { - SerialCircuit::encode(&self.0)?.to_tket1() +impl fmt::Debug for PyNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.node.fmt(f) } +} - fn apply_match(&mut self, rw: CircuitRewrite) { - rw.apply(&mut self.0).expect("Apply error."); +#[pymethods] +impl PyNode { + /// A string representation of the pattern. + pub fn __repr__(&self) -> String { + format!("{:?}", self) } } diff --git a/tket2-py/src/circuit/convert.rs b/tket2-py/src/circuit/convert.rs new file mode 100644 index 000000000..f1d866afe --- /dev/null +++ b/tket2-py/src/circuit/convert.rs @@ -0,0 +1,89 @@ +//! Utilities for calling Hugr functions on generic python objects. + +use pyo3::{prelude::*, PyTypeInfo}; + +use hugr::{Hugr, HugrView}; +use tket2::extension::REGISTRY; +use tket2::json::TKETDecode; +use tket2::passes::CircuitChunks; +use tket2::rewrite::CircuitRewrite; +use tket_json_rs::circuit_json::SerialCircuit; + +/// A manager for tket 2 operations on a tket 1 Circuit. +#[pyclass] +#[derive(Clone, Debug, PartialEq)] +pub struct T2Circuit { + /// Rust representation of the circuit. + pub hugr: Hugr, +} + +#[pymethods] +impl T2Circuit { + #[new] + fn from_circuit(circ: PyObject) -> PyResult { + Ok(Self { + hugr: super::to_hugr(circ)?, + }) + } + + fn finish(&self) -> PyResult { + SerialCircuit::encode(&self.hugr)?.to_tket1_with_gil() + } + + fn apply_match(&mut self, rw: CircuitRewrite) { + rw.apply(&mut self.hugr).expect("Apply error."); + } +} +impl T2Circuit { + /// Tries to extract a T2Circuit from a python object. + /// + /// Returns an error if the py object is not a T2Circuit. + pub fn try_extract(circ: Py) -> PyResult { + Python::with_gil(|py| circ.as_ref(py).extract::()) + } +} + +/// Apply a fallible function expecting a hugr on a pytket circuit. +pub fn try_with_hugr(circ: Py, f: F) -> PyResult +where + E: Into, + F: FnOnce(Hugr) -> Result, +{ + let hugr = Python::with_gil(|py| -> PyResult { + let circ = circ.as_ref(py); + match T2Circuit::extract(circ) { + // hugr circuit + Ok(t2circ) => Ok(t2circ.hugr), + // tket1 circuit + Err(_) => Ok(SerialCircuit::from_tket1(circ)?.decode()?), + } + })?; + (f)(hugr).map_err(|e| e.into()) +} + +/// Apply a function expecting a hugr on a pytket circuit. +pub fn with_hugr(circ: Py, f: F) -> PyResult +where + F: FnOnce(Hugr) -> T, +{ + try_with_hugr(circ, |hugr| Ok::((f)(hugr))) +} + +/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit. +pub fn try_update_hugr(circ: Py, f: F) -> PyResult> +where + E: Into, + F: FnOnce(Hugr) -> Result, +{ + let hugr = try_with_hugr(circ, f)?; + SerialCircuit::encode(&hugr)?.to_tket1_with_gil() +} + +/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit. +pub fn update_hugr(circ: Py, f: F) -> PyResult> +where + F: FnOnce(Hugr) -> Hugr, +{ + let hugr = with_hugr(circ, f)?; + SerialCircuit::encode(&hugr)?.to_tket1_with_gil() +} diff --git a/tket2-py/src/lib.rs b/tket2-py/src/lib.rs index 2245220f0..e055cd18a 100644 --- a/tket2-py/src/lib.rs +++ b/tket2-py/src/lib.rs @@ -1,10 +1,10 @@ //! Python bindings for TKET2. #![warn(missing_docs)] -mod circuit; -mod optimiser; -mod passes; -mod pattern; +pub mod circuit; +pub mod optimiser; +pub mod passes; +pub mod pattern; use pyo3::prelude::*; diff --git a/tket2-py/src/passes.rs b/tket2-py/src/passes.rs index c9aef2bb6..c3363c4c4 100644 --- a/tket2-py/src/passes.rs +++ b/tket2-py/src/passes.rs @@ -1,3 +1,7 @@ +//! Passes for optimising circuits. + +pub mod chunks; + use std::{cmp::min, convert::TryInto, fs, num::NonZeroUsize, path::PathBuf}; use pyo3::{prelude::*, types::IntoPyDict}; @@ -16,7 +20,8 @@ pub fn module(py: Python) -> PyResult<&PyModule> { let m = PyModule::new(py, "_passes")?; m.add_function(wrap_pyfunction!(greedy_depth_reduce, m)?)?; m.add_function(wrap_pyfunction!(badger_optimise, m)?)?; - m.add_class::()?; + m.add_class::()?; + m.add_function(wrap_pyfunction!(self::chunks::chunks, m)?)?; m.add( "PullForwardError", py.get_type::(), @@ -28,7 +33,7 @@ pub fn module(py: Python) -> PyResult<&PyModule> { fn greedy_depth_reduce(py_c: PyObject) -> PyResult<(PyObject, u32)> { try_with_hugr(py_c, |mut h| { let n_moves = apply_greedy_commutation(&mut h)?; - let py_c = SerialCircuit::encode(&h)?.to_tket1()?; + let py_c = SerialCircuit::encode(&h)?.to_tket1_with_gil()?; PyResult::Ok((py_c, n_moves)) }) } diff --git a/tket2-py/src/passes/chunks.rs b/tket2-py/src/passes/chunks.rs new file mode 100644 index 000000000..4642aa42a --- /dev/null +++ b/tket2-py/src/passes/chunks.rs @@ -0,0 +1,69 @@ +//! Circuit chunking utilities. + +use derive_more::From; +use pyo3::exceptions::PyAttributeError; +use pyo3::prelude::*; +use tket2::json::TKETDecode; +use tket2::passes::CircuitChunks; +use tket2::Circuit; +use tket_json_rs::circuit_json::SerialCircuit; + +use crate::circuit::{with_hugr, T2Circuit}; + +/// Split a circuit into chunks of a given size. +#[pyfunction] +pub fn chunks(c: Py, max_chunk_size: usize) -> PyResult { + with_hugr(c, |hugr| { + // TODO: Detect if the circuit is in tket1 format or T2Circuit. + let is_tket1 = true; + let chunks = CircuitChunks::split(&hugr, max_chunk_size); + (chunks, is_tket1).into() + }) +} + +/// A pattern that match a circuit exactly +/// +/// Python equivalent of [`CircuitChunks`]. +/// +/// [`CircuitChunks`]: tket2::passes::chunks::CircuitChunks +#[pyclass] +#[pyo3(name = "CircuitChunks")] +#[derive(Debug, Clone, From)] +pub struct PyCircuitChunks { + /// Rust representation of the circuit chunks. + pub chunks: CircuitChunks, + /// Whether to reassemble the circuit in the tket1 format. + pub in_tket1: bool, +} + +#[pymethods] +impl PyCircuitChunks { + /// Reassemble the chunks into a circuit. + fn reassemble(&self) -> PyResult> { + let hugr = self.clone().chunks.reassemble()?; + Python::with_gil(|py| match self.in_tket1 { + true => Ok(SerialCircuit::encode(&hugr)?.to_tket1(py)?.into_py(py)), + false => Ok(T2Circuit { hugr }.into_py(py)), + }) + } + + /// Returns clones of the split circuits. + fn circuits(&self) -> PyResult>> { + self.chunks + .iter() + .map(|hugr| SerialCircuit::encode(hugr)?.to_tket1_with_gil()) + .collect() + } + + /// Replaces a chunk's circuit with an updated version. + fn update_circuit(&mut self, index: usize, new_circ: Py) -> PyResult<()> { + let hugr = SerialCircuit::from_tket1_with_gil(new_circ)?.decode()?; + if hugr.circuit_signature() != self.chunks[index].circuit_signature() { + return Err(PyAttributeError::new_err( + "The new circuit has a different signature.", + )); + } + self.chunks[index] = hugr; + Ok(()) + } +} diff --git a/tket2-py/src/pattern.rs b/tket2-py/src/pattern.rs index 52a2ed610..1bbc9cb24 100644 --- a/tket2-py/src/pattern.rs +++ b/tket2-py/src/pattern.rs @@ -1,22 +1,28 @@ -//! +//! Pattern matching on circuits. + +pub mod portmatching; use crate::circuit::{to_hugr, T2Circuit}; use hugr::Hugr; use pyo3::prelude::*; -use tket2::portmatching::pyo3::PyPatternMatch; use tket2::portmatching::{CircuitPattern, PatternMatcher}; use tket2::rewrite::CircuitRewrite; /// The module definition pub fn module(py: Python) -> PyResult<&PyModule> { let m = PyModule::new(py, "_pattern")?; - m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add( + "InvalidReplacementError", + py.get_type::(), + )?; m.add( "InvalidPatternError", py.get_type::(), @@ -64,12 +70,10 @@ impl RuleMatcher { } pub fn find_match(&self, target: &T2Circuit) -> PyResult> { - let h = &target.0; - let p_match = self.matcher.find_matches_iter(h).next(); - if let Some(m) = p_match { - let py_match = PyPatternMatch::try_from_rust(m, h, &self.matcher)?; - let r = self.rights.get(py_match.pattern_id).unwrap().clone(); - let rw = py_match.to_rewrite(h, r)?; + let h = &target.hugr; + if let Some(p_match) = self.matcher.find_matches_iter(h).next() { + let r = self.rights.get(p_match.pattern_id().0).unwrap().clone(); + let rw = p_match.to_rewrite(h, r)?; Ok(Some(rw)) } else { Ok(None) diff --git a/tket2-py/src/pattern/portmatching.rs b/tket2-py/src/pattern/portmatching.rs new file mode 100644 index 000000000..75344ed3f --- /dev/null +++ b/tket2-py/src/pattern/portmatching.rs @@ -0,0 +1,144 @@ +//! Python bindings for portmatching features + +use std::fmt; + +use derive_more::{From, Into}; +use hugr::Node; +use itertools::Itertools; +use portmatching::PatternID; +use pyo3::{prelude::*, types::PyIterator}; + +use tket2::portmatching::{CircuitPattern, PatternMatch, PatternMatcher}; + +use crate::circuit::{try_with_hugr, with_hugr}; + +/// A pattern that match a circuit exactly +/// +/// Python equivalent of [`CircuitPattern`]. +/// +/// [`CircuitPattern`]: tket2::portmatching::matcher::CircuitPattern +#[pyclass] +#[pyo3(name = "CircuitPattern")] +#[repr(transparent)] +#[derive(Debug, Clone, From)] +pub struct PyCircuitPattern { + /// Rust representation of the pattern + pub pattern: CircuitPattern, +} + +#[pymethods] +impl PyCircuitPattern { + /// Construct a pattern from a TKET1 circuit + #[new] + pub fn from_circuit(circ: Py) -> PyResult { + let pattern = try_with_hugr(circ, |circ| CircuitPattern::try_from_circuit(&circ))?; + Ok(pattern.into()) + } + + /// A string representation of the pattern. + pub fn __repr__(&self) -> String { + format!("{:?}", self.pattern) + } +} + +/// A matcher object for fast pattern matching on circuits. +/// +/// This uses a state automaton internally to match against a set of patterns +/// simultaneously. +/// +/// Python equivalent of [`PatternMatcher`]. +/// +/// [`PatternMatcher`]: tket2::portmatching::matcher::PatternMatcher +#[pyclass] +#[pyo3(name = "PatternMatcher")] +#[repr(transparent)] +#[derive(Debug, Clone, From)] +pub struct PyPatternMatcher { + /// Rust representation of the matcher + pub matcher: PatternMatcher, +} + +#[pymethods] +impl PyPatternMatcher { + /// Construct a matcher from a list of patterns. + #[new] + pub fn py_from_patterns(patterns: &PyIterator) -> PyResult { + Ok(PatternMatcher::from_patterns( + patterns + .iter()? + .map(|p| { + let py_pattern = p?.extract::()?; + Ok(py_pattern.pattern) + }) + .collect::>>()?, + ) + .into()) + } + /// A string representation of the pattern. + pub fn __repr__(&self) -> PyResult { + Ok(format!("{:?}", self.matcher)) + } + + /// Find all convex matches in a circuit. + pub fn find_matches(&self, circ: PyObject) -> PyResult> { + with_hugr(circ, |circ| { + self.matcher + .find_matches(&circ) + .into_iter() + .map_into() + .collect() + }) + } +} + +/// A convex pattern match in a circuit, available from Python. +/// +/// Python equivalent of [`PatternMatch`]. +/// +/// [`PatternMatch`]: tket2::portmatching::matcher::PatternMatch +#[pyclass] +#[derive(Debug, Clone, From)] +#[pyo3(name = "PatternMatch")] +pub struct PyPatternMatch { + pmatch: PatternMatch, +} + +#[pymethods] +impl PyPatternMatch { + /// The matched pattern ID. + pub fn pattern_id(&self) -> PyPatternID { + self.pmatch.pattern_id().into() + } + + /// Returns the root of the pattern in the circuit. + pub fn root(&self) -> Node { + self.pmatch.root() + } + + /// A string representation of the pattern. + pub fn __repr__(&self) -> String { + format!("{:?}", self.pmatch) + } +} + +/// A [`hugr::Node`] wrapper for Python. +#[pyclass] +#[pyo3(name = "PatternID")] +#[repr(transparent)] +#[derive(From, Into, PartialEq, Eq, Hash, Clone, Copy)] +pub struct PyPatternID { + /// Rust representation of the pattern ID + pub id: PatternID, +} + +impl fmt::Display for PyPatternID { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.id.fmt(f) + } +} + +impl fmt::Debug for PyPatternID { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.id.fmt(f) + } +} diff --git a/tket2-py/test/test_bindings.py b/tket2-py/test/test_bindings.py index 8c7653468..9385b6d4e 100644 --- a/tket2-py/test/test_bindings.py +++ b/tket2-py/test/test_bindings.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from pytket.circuit import Circuit -from tket2 import circuit +from tket2 import passes from tket2.passes import greedy_depth_reduce from tket2.circuit import T2Circuit from tket2.pattern import Rule, RuleMatcher @@ -29,7 +29,7 @@ def test_chunks(): assert c.depth() == 3 - chunks = circuit.chunks(c, 2) + chunks = passes.chunks(c, 2) circuits = chunks.circuits() chunks.update_circuit(0, circuits[0]) c2 = chunks.reassemble() diff --git a/tket2/Cargo.toml b/tket2/Cargo.toml index 94328feb4..712655c15 100644 --- a/tket2/Cargo.toml +++ b/tket2/Cargo.toml @@ -37,10 +37,7 @@ typetag = "0.2.8" itertools = { workspace = true } petgraph = { version = "0.6.3", default-features = false } serde_yaml = "0.9.22" -# portmatching = { version = "0.2.0", optional = true, features = ["serde"]} -portmatching = { optional = true, git = "https://github.com/lmondada/portmatching", rev = "738c91c", features = [ - "serde", -] } +portmatching = { workspace = true, optional = true, features = ["serde"] } derive_more = "0.99.17" quantinuum-hugr = { workspace = true } portgraph = { workspace = true, features = ["serde"] } diff --git a/tket2/src/passes/chunks.rs b/tket2/src/passes/chunks.rs index bad258398..95d73b448 100644 --- a/tket2/src/passes/chunks.rs +++ b/tket2/src/passes/chunks.rs @@ -22,12 +22,6 @@ use itertools::Itertools; use crate::Circuit; use crate::circuit::cost::{CircuitCost, CostDelta}; -#[cfg(feature = "pyo3")] -use crate::json::TKETDecode; -#[cfg(feature = "pyo3")] -use pyo3::{exceptions::PyAttributeError, pyclass, pymethods, Py, PyAny, PyResult}; -#[cfg(feature = "pyo3")] -use tket_json_rs::circuit_json::SerialCircuit; /// An identifier for the connection between chunks. /// @@ -40,7 +34,6 @@ pub struct ChunkConnection(Wire); /// A chunk of a circuit. #[derive(Debug, Clone)] -#[cfg_attr(feature = "pyo3", pyclass)] pub struct Chunk { /// The extracted circuit. pub circ: Hugr, @@ -241,7 +234,6 @@ enum ConnectionTarget { /// or [`CircuitChunks::split_with_cost`], and reassembled with /// [`CircuitChunks::reassemble`]. #[derive(Debug, Clone)] -#[cfg_attr(feature = "pyo3", pyclass)] pub struct CircuitChunks { /// The original circuit's signature. signature: FunctionType, @@ -464,38 +456,6 @@ impl CircuitChunks { } } -#[cfg(feature = "pyo3")] -#[pymethods] -impl CircuitChunks { - /// Reassemble the chunks into a circuit. - #[pyo3(name = "reassemble")] - fn py_reassemble(&self) -> PyResult> { - let hugr = self.clone().reassemble()?; - SerialCircuit::encode(&hugr)?.to_tket1() - } - - /// Returns clones of the split circuits. - #[pyo3(name = "circuits")] - fn py_circuits(&self) -> PyResult>> { - self.iter() - .map(|hugr| SerialCircuit::encode(hugr)?.to_tket1()) - .collect() - } - - /// Replaces a chunk's circuit with an updated version. - #[pyo3(name = "update_circuit")] - fn py_update_circuit(&mut self, index: usize, new_circ: Py) -> PyResult<()> { - let hugr = SerialCircuit::_from_tket1(new_circ).decode()?; - if hugr.circuit_signature() != self.chunks[index].circ.circuit_signature() { - return Err(PyAttributeError::new_err( - "The new circuit has a different signature.", - )); - } - self.chunks[index].circ = hugr; - Ok(()) - } -} - impl Index for CircuitChunks { type Output = Hugr; diff --git a/tket2/src/portmatching.rs b/tket2/src/portmatching.rs index 74db117db..3a4d047dc 100644 --- a/tket2/src/portmatching.rs +++ b/tket2/src/portmatching.rs @@ -2,8 +2,6 @@ pub mod matcher; pub mod pattern; -#[cfg(feature = "pyo3")] -pub mod pyo3; use hugr::OutgoingPort; use itertools::Itertools; diff --git a/tket2/src/portmatching/matcher.rs b/tket2/src/portmatching/matcher.rs index 8df5f0e25..a864218f8 100644 --- a/tket2/src/portmatching/matcher.rs +++ b/tket2/src/portmatching/matcher.rs @@ -79,11 +79,16 @@ pub struct PatternMatch { } impl PatternMatch { - /// The matcher's pattern ID of the match. + /// The matched pattern ID. pub fn pattern_id(&self) -> PatternID { self.pattern } + /// Returns the root of the pattern in the circuit. + pub fn root(&self) -> Node { + self.root + } + /// Create a pattern match from the image of a pattern root. /// /// This checks at construction time that the match is convex. This will diff --git a/tket2/src/portmatching/pattern.rs b/tket2/src/portmatching/pattern.rs index 9eeddc8fd..d1320cc22 100644 --- a/tket2/src/portmatching/pattern.rs +++ b/tket2/src/portmatching/pattern.rs @@ -14,10 +14,9 @@ use super::{ use crate::{circuit::Circuit, portmatching::NodeID}; #[cfg(feature = "pyo3")] -use pyo3::{create_exception, exceptions::PyException, pyclass, PyErr}; +use pyo3::{create_exception, exceptions::PyException, PyErr}; /// A pattern that match a circuit exactly -#[cfg_attr(feature = "pyo3", pyclass)] #[derive(Clone, serde::Serialize, serde::Deserialize)] pub struct CircuitPattern { pub(super) pattern: Pattern, diff --git a/tket2/src/portmatching/pyo3.rs b/tket2/src/portmatching/pyo3.rs deleted file mode 100644 index 410898e28..000000000 --- a/tket2/src/portmatching/pyo3.rs +++ /dev/null @@ -1,203 +0,0 @@ -//! Python bindings for portmatching features - -use std::fmt; - -use derive_more::{From, Into}; -use hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError; -use hugr::{Hugr, IncomingPort, OutgoingPort}; -use itertools::Itertools; -use portmatching::{HashMap, PatternID}; -use pyo3::{prelude::*, types::PyIterator}; -use tket_json_rs::circuit_json::SerialCircuit; - -use super::{CircuitPattern, PatternMatch, PatternMatcher}; -use crate::circuit::Circuit; -use crate::json::TKETDecode; -use crate::rewrite::CircuitRewrite; - -#[pymethods] -impl CircuitPattern { - /// Construct a pattern from a TKET1 circuit - #[new] - pub fn py_from_circuit(circ: PyObject) -> PyResult { - let circ = pyobj_as_hugr(circ)?; - let pattern = CircuitPattern::try_from_circuit(&circ)?; - Ok(pattern) - } - - /// A string representation of the pattern. - pub fn __repr__(&self) -> String { - format!("{:?}", self) - } -} - -#[pymethods] -impl PatternMatcher { - /// Construct a matcher from a list of patterns. - #[new] - pub fn py_from_patterns(patterns: &PyIterator) -> PyResult { - Ok(PatternMatcher::from_patterns( - patterns - .iter()? - .map(|p| p?.extract::()) - .collect::>>()?, - )) - } - /// A string representation of the pattern. - pub fn __repr__(&self) -> PyResult { - Ok(format!("{:?}", self)) - } - - /// Find all convex matches in a circuit. - #[pyo3(name = "find_matches")] - pub fn py_find_matches(&self, circ: PyObject) -> PyResult> { - let circ = pyobj_as_hugr(circ)?; - self.find_matches(&circ) - .into_iter() - .map(|m| { - let pattern_id = m.pattern_id(); - PyPatternMatch::try_from_rust(m, &circ, self).map_err(|e| { - PyInvalidReplacementError::new_err(format!( - "Invalid match for pattern {:?}: {}", - pattern_id, e - )) - }) - }) - .collect() - } -} - -/// Python equivalent of [`PatternMatch`]. -/// -/// A convex pattern match in a circuit, available from Python. -/// -/// This object is semantically equivalent to Rust's [`PatternMatch`] but -/// stores data differently, and in particular removes the lifetime-bound -/// references of Rust. -/// -/// The data is stored in a way that favours a nice user-facing representation -/// over efficiency. It is provided for convenience and not recommended when -/// performance is a key concern. -/// -/// TODO: can this be a wrapper for a [`PatternMatch`] instead? -/// -/// [`PatternMatch`]: crate::portmatching::matcher::PatternMatch -#[pyclass] -#[derive(Debug, Clone)] -pub struct PyPatternMatch { - /// The ID of the pattern in the matcher. - pub pattern_id: usize, - /// The root of the pattern within the circuit. - pub root: Node, - /// The input ports of the subcircuit. - /// - /// This is the incoming boundary of a [`hugr::hugr::views::SiblingSubgraph`]. - /// The input ports are grouped together if they are connected to the same - /// source. - pub inputs: Vec>, - /// The output ports of the subcircuit. - /// - /// This is the outgoing boundary of a [`hugr::hugr::views::SiblingSubgraph`]. - pub outputs: Vec<(Node, OutgoingPort)>, - /// The node map from pattern to circuit. - pub node_map: HashMap, -} - -#[pymethods] -impl PyPatternMatch { - /// A string representation of the pattern. - pub fn __repr__(&self) -> String { - format!("CircuitMatch {:?}", self.node_map) - } -} - -impl PyPatternMatch { - /// Construct a [`PyPatternMatch`] from a [`PatternMatch`]. - /// - /// Requires references to the circuit and pattern to resolve indices - /// into these objects. - pub fn try_from_rust( - m: PatternMatch, - circ: &C, - matcher: &PatternMatcher, - ) -> PyResult { - let pattern_id = m.pattern_id(); - let pattern = matcher.get_pattern(pattern_id).unwrap(); - let root = Node(m.root); - - let node_map: HashMap = pattern - .get_match_map(root.0, circ) - .ok_or_else(|| PyInvalidReplacementError::new_err("Invalid match"))? - .into_iter() - .map(|(p, c)| (Node(p), Node(c))) - .collect(); - let inputs = pattern - .inputs - .iter() - .map(|ps| { - ps.iter() - .map(|&(n, p)| (node_map[&Node(n)], p.as_incoming().unwrap())) - .collect_vec() - }) - .collect_vec(); - let outputs = pattern - .outputs - .iter() - .map(|&(n, p)| (node_map[&Node(n)], p.as_outgoing().unwrap())) - .collect_vec(); - Ok(Self { - pattern_id: pattern_id.0, - inputs, - outputs, - node_map, - root, - }) - } - - /// Convert the pattern into a [`CircuitRewrite`]. - pub fn to_rewrite(&self, circ: &Hugr, replacement: Hugr) -> PyResult { - let inputs = self - .inputs - .iter() - .map(|p| p.iter().map(|&(n, p)| (n.0, p)).collect()) - .collect(); - let outputs = self.outputs.iter().map(|&(n, p)| (n.0, p)).collect(); - let rewrite = PatternMatch::try_from_io( - self.root.0, - PatternID(self.pattern_id), - circ, - inputs, - outputs, - ) - .expect("Invalid PyCircuitMatch object") - .to_rewrite(circ, replacement)?; - Ok(rewrite) - } -} - -/// A [`hugr::Node`] wrapper for Python. -/// -/// Note: this will probably be useful outside of portmatching -#[pyclass] -#[derive(From, Into, PartialEq, Eq, Hash, Clone, Copy)] -pub struct Node(hugr::Node); - -impl fmt::Debug for Node { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -#[pymethods] -impl Node { - /// A string representation of the pattern. - pub fn __repr__(&self) -> String { - format!("{:?}", self) - } -} - -fn pyobj_as_hugr(circ: PyObject) -> PyResult { - let ser_c = SerialCircuit::_from_tket1(circ); - let hugr: Hugr = ser_c.decode()?; - Ok(hugr) -}