Skip to content

Commit

Permalink
feat: python errors and some pyrs helpers (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q authored Sep 19, 2023
1 parent bfdd86b commit 41eaf72
Show file tree
Hide file tree
Showing 12 changed files with 184 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ members = ["pyrs", "compile-matcher", "taso-optimiser"]

[workspace.dependencies]

quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "bc9692b" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "660fef6e" }
portgraph = { version = "0.9", features = ["serde"] }
pyo3 = { version = "0.19" }
itertools = { version = "0.11.0" }
Expand Down
8 changes: 8 additions & 0 deletions DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ stable available.
cargo +nightly miri test
```

To run the python tests, run:

```bash
cd pyrs
maturin develop
pytest
```

## 💅 Coding Style

The rustfmt tool is used to enforce a consistent rust coding style. The CI will fail if the code is not formatted correctly. Python code is formatted with black.
Expand Down
54 changes: 54 additions & 0 deletions pyrs/src/circuit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//! Circuit-related functionality and utilities.
#![allow(unused)]

use pyo3::prelude::*;

use hugr::{Hugr, HugrView};
use tket2::extension::REGISTRY;
use tket2::json::TKETDecode;
use tket_json_rs::circuit_json::SerialCircuit;

/// Apply a fallible function expecting a hugr on a pytket circuit.
pub fn try_with_hugr<T, E, F>(circ: Py<PyAny>, f: F) -> PyResult<T>
where
E: Into<PyErr>,
F: FnOnce(Hugr) -> Result<T, E>,
{
let hugr = SerialCircuit::_from_tket1(circ).decode()?;
(f)(hugr).map_err(|e| e.into())
}

/// Apply a function expecting a hugr on a pytket circuit.
pub fn with_hugr<T, F: FnOnce(Hugr) -> T>(circ: Py<PyAny>, f: F) -> PyResult<T> {
try_with_hugr(circ, |hugr| Ok::<T, PyErr>((f)(hugr)))
}

/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit.
pub fn try_update_hugr<E: Into<PyErr>, F: FnOnce(Hugr) -> Result<Hugr, E>>(
circ: Py<PyAny>,
f: F,
) -> PyResult<Py<PyAny>> {
let hugr = try_with_hugr(circ, f)?;
SerialCircuit::encode(&hugr)?.to_tket1()
}

/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit.
pub fn update_hugr<F: FnOnce(Hugr) -> Hugr>(circ: Py<PyAny>, f: F) -> PyResult<Py<PyAny>> {
let hugr = with_hugr(circ, f)?;
SerialCircuit::encode(&hugr)?.to_tket1()
}

#[pyfunction]
pub fn validate_hugr(c: Py<PyAny>) -> PyResult<()> {
try_with_hugr(c, |hugr| hugr.validate(&REGISTRY))
}

#[pyfunction]
pub fn to_hugr_dot(c: Py<PyAny>) -> PyResult<String> {
with_hugr(c, |hugr| hugr.dot_string())
}

#[pyfunction]
pub fn to_hugr(c: Py<PyAny>) -> PyResult<Hugr> {
with_hugr(c, |hugr| hugr)
}
55 changes: 47 additions & 8 deletions pyrs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,56 @@
//! Python bindings for TKET2.
#![warn(missing_docs)]

mod circuit;

use pyo3::prelude::*;
use tket2::portmatching::{CircuitPattern, PatternMatcher};

/// The Python bindings to TKET2.
#[pymodule]
fn pyrs(py: Python, m: &PyModule) -> PyResult<()> {
add_patterns_module(py, m)?;
add_circuit_module(py, m)?;
add_pattern_module(py, m)?;
Ok(())
}

fn add_patterns_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "patterns")?;
m.add_class::<CircuitPattern>()?;
m.add_class::<PatternMatcher>()?;
parent.add_submodule(m)?;
Ok(())
/// circuit module
fn add_circuit_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "circuit")?;
m.add_class::<tket2::T2Op>()?;
m.add_class::<tket2::Pauli>()?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
m.add(
"ValidationError",
py.get_type::<hugr::hugr::validate::PyValidationError>(),
)?;
m.add(
"HUGRSerializationError",
py.get_type::<hugr::hugr::serialize::PyHUGRSerializationError>(),
)?;
m.add(
"OpConvertError",
py.get_type::<tket2::json::PyOpConvertError>(),
)?;

parent.add_submodule(m)
}

/// portmatching module
fn add_pattern_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "pattern")?;
m.add_class::<tket2::portmatching::CircuitPattern>()?;
m.add_class::<tket2::portmatching::PatternMatcher>()?;

m.add(
"InvalidPatternError",
py.get_type::<tket2::portmatching::pattern::PyInvalidPatternError>(),
)?;
m.add(
"InvalidReplacementError",
py.get_type::<hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError>(),
)?;

parent.add_submodule(m)
}
22 changes: 11 additions & 11 deletions pyrs/test/test_portmatching.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
from pytket import Circuit
from pytket.qasm import circuit_from_qasm
from pyrs.pyrs import patterns
from pyrs.pyrs import pattern


def test_simple_matching():
""" a simple circuit matching test """
c = Circuit(2).CX(0, 1).H(1).CX(0, 1)

p1 = patterns.CircuitPattern(Circuit(2).CX(0, 1).H(1))
p2 = patterns.CircuitPattern(Circuit(2).H(0).CX(1, 0))
p1 = pattern.CircuitPattern(Circuit(2).CX(0, 1).H(1))
p2 = pattern.CircuitPattern(Circuit(2).H(0).CX(1, 0))

matcher = patterns.PatternMatcher(iter([p1, p2]))
matcher = pattern.PatternMatcher(iter([p1, p2]))

assert len(matcher.find_matches(c)) == 2


def test_non_convex_pattern():
""" two-qubit circuits can't match three-qb ones """
p1 = patterns.CircuitPattern(Circuit(3).CX(0, 1).CX(1, 2))
matcher = patterns.PatternMatcher(iter([p1]))
p1 = pattern.CircuitPattern(Circuit(3).CX(0, 1).CX(1, 2))
matcher = pattern.PatternMatcher(iter([p1]))

c = Circuit(2).CX(0, 1).CX(1, 0)
assert len(matcher.find_matches(c)) == 0
Expand All @@ -34,11 +34,11 @@ def test_larger_matching():
""" a larger crafted circuit with matches WIP """
c = circuit_from_qasm("test/test_files/circ.qasm")

p1 = patterns.CircuitPattern(Circuit(2).CX(0, 1).H(1))
p2 = patterns.CircuitPattern(Circuit(2).H(0).CX(1, 0))
p3 = patterns.CircuitPattern(Circuit(2).CX(0, 1).CX(1, 0))
p4 = patterns.CircuitPattern(Circuit(3).CX(0, 1).CX(1, 2))
p1 = pattern.CircuitPattern(Circuit(2).CX(0, 1).H(1))
p2 = pattern.CircuitPattern(Circuit(2).H(0).CX(1, 0))
p3 = pattern.CircuitPattern(Circuit(2).CX(0, 1).CX(1, 0))
p4 = pattern.CircuitPattern(Circuit(3).CX(0, 1).CX(1, 2))

matcher = patterns.PatternMatcher(iter([p1, p2, p3, p4]))
matcher = pattern.PatternMatcher(iter([p1, p2, p3, p4]))

assert len(matcher.find_matches(c)) == 6
18 changes: 18 additions & 0 deletions src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ pub mod op;
#[cfg(test)]
mod tests;

#[cfg(feature = "pyo3")]
use pyo3::{create_exception, exceptions::PyException, PyErr};

use std::path::Path;
use std::{fs, io};

Expand Down Expand Up @@ -85,6 +88,21 @@ pub enum OpConvertError {
NonSerializableInputs(OpType),
}

#[cfg(feature = "pyo3")]
create_exception!(
pyrs,
PyOpConvertError,
PyException,
"Error type for conversion between tket2's `Op` and `OpType`"
);

#[cfg(feature = "pyo3")]
impl From<OpConvertError> for PyErr {
fn from(err: OpConvertError) -> Self {
PyOpConvertError::new_err(err.to_string())
}
}

/// Load a TKET1 circuit from a JSON file.
pub fn load_tk1_json_file(path: impl AsRef<Path>) -> Result<Hugr, TK1ConvertError> {
let file = fs::File::open(path)?;
Expand Down
7 changes: 5 additions & 2 deletions src/json/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use std::hash::{Hash, Hasher};
use std::mem;

use hugr::builder::{CircuitBuilder, Container, DFGBuilder, Dataflow, DataflowHugr};
use hugr::extension::prelude::QB_T;
use hugr::extension::prelude::{PRELUDE_ID, QB_T};
use hugr::extension::ExtensionSet;

use hugr::hugr::CircuitUnit;
use hugr::ops::Const;
Expand Down Expand Up @@ -143,7 +144,9 @@ impl JsonDecoder {
Some(c) => {
let const_type = FLOAT64_TYPE;
let const_op = Const::new(c, const_type).unwrap();
self.hugr.add_load_const(const_op).unwrap()
self.hugr
.add_load_const(const_op, ExtensionSet::singleton(&PRELUDE_ID))
.unwrap()
}
None => {
// store string in custom op.
Expand Down
6 changes: 6 additions & 0 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ use strum::IntoEnumIterator;
use strum_macros::{Display, EnumIter, EnumString, IntoStaticStr};
use thiserror::Error;

#[cfg(feature = "pyo3")]
use pyo3::pyclass;

/// Name of tket 2 extension.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("quantum.tket2");

Expand All @@ -41,6 +44,7 @@ pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("quantum.tket2"
IntoStaticStr,
EnumString,
)]
#[cfg_attr(feature = "pyo3", pyclass)]
#[allow(missing_docs)]
/// Simple enum of tket 2 quantum operations.
pub enum T2Op {
Expand All @@ -62,7 +66,9 @@ pub enum T2Op {
AngleAdd,
TK1,
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd)]
#[cfg_attr(feature = "pyo3", pyclass)]
#[allow(missing_docs)]
/// Simple enum representation of Pauli matrices.
pub enum Pauli {
Expand Down
2 changes: 1 addition & 1 deletion src/portmatching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
pub mod matcher;
pub mod pattern;
#[cfg(feature = "pyo3")]
mod pyo3;
pub mod pyo3;

pub use matcher::{PatternMatch, PatternMatcher};
pub use pattern::CircuitPattern;
Expand Down
17 changes: 16 additions & 1 deletion src/portmatching/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use super::{
use crate::circuit::Circuit;

#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
use pyo3::{create_exception, exceptions::PyException, pyclass, PyErr};

/// A pattern that match a circuit exactly
#[cfg_attr(feature = "pyo3", pyclass)]
Expand Down Expand Up @@ -119,6 +119,21 @@ impl From<NoRootFound> for InvalidPattern {
}
}

#[cfg(feature = "pyo3")]
create_exception!(
pyrs,
PyInvalidPatternError,
PyException,
"Invalid circuit pattern"
);

#[cfg(feature = "pyo3")]
impl From<InvalidPattern> for PyErr {
fn from(err: InvalidPattern) -> Self {
PyInvalidPatternError::new_err(err.to_string())
}
}

#[cfg(test)]
mod tests {
use hugr::Hugr;
Expand Down
Loading

0 comments on commit 41eaf72

Please sign in to comment.