Skip to content

Commit

Permalink
feat: Implement PyErr conversion locally in tket2-py (#258)
Browse files Browse the repository at this point in the history
With this we can drop the `pyo3` features in hugr and portgraph.
`tket2` and `tket-json-rs` still require it, as they define some
non-error pyclasses.

This requires manually converting errors by calling `convert_pyerrs`
before `?`, but it is a small price to pay for not dealing with pyo3
versions across crates.
  • Loading branch information
aborgna-q authored Nov 23, 2023
1 parent 85ce5f9 commit 3e1a68d
Show file tree
Hide file tree
Showing 12 changed files with 125 additions and 92 deletions.
4 changes: 2 additions & 2 deletions tket2-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ tket2 = { workspace = true, features = ["pyo3", "portmatching"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tket-json-rs = { workspace = true, features = ["pyo3"] }
quantinuum-hugr = { workspace = true, features = ["pyo3"] }
portgraph = { workspace = true, features = ["pyo3", "serde"] }
quantinuum-hugr = { workspace = true }
portgraph = { workspace = true, features = ["serde"] }
pyo3 = { workspace = true, features = ["extension-module"] }
num_cpus = "1.16.0"
derive_more = "0.99.17"
Expand Down
48 changes: 37 additions & 11 deletions tket2-py/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use tket2::json::TKETDecode;
use tket2::rewrite::CircuitRewrite;
use tket_json_rs::circuit_json::SerialCircuit;

use crate::utils::create_py_exception;

pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, Tk2Circuit};
pub use self::cost::PyCircuitCost;
pub use tket2::{Pauli, Tk2Op};
Expand All @@ -30,24 +32,48 @@ pub fn module(py: Python) -> PyResult<&PyModule> {
m.add_function(wrap_pyfunction!(validate_hugr, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
m.add(
"ValidationError",
py.get_type::<hugr::hugr::validate::PyValidationError>(),
)?;
m.add("HugrError", py.get_type::<PyHugrError>())?;
m.add("BuildError", py.get_type::<PyBuildError>())?;
m.add("ValidationError", py.get_type::<PyValidationError>())?;
m.add(
"HUGRSerializationError",
py.get_type::<hugr::hugr::serialize::PyHUGRSerializationError>(),
)?;
m.add(
"OpConvertError",
py.get_type::<tket2::json::PyOpConvertError>(),
py.get_type::<PyHUGRSerializationError>(),
)?;
m.add("OpConvertError", py.get_type::<PyOpConvertError>())?;

Ok(m)
}

create_py_exception!(
hugr::hugr::HugrError,
PyHugrError,
"Errors that can occur while manipulating a HUGR."
);

create_py_exception!(
hugr::builder::BuildError,
PyBuildError,
"Error while building the HUGR."
);

create_py_exception!(
hugr::hugr::validate::ValidationError,
PyValidationError,
"Errors that can occur while validating a Hugr."
);

create_py_exception!(
hugr::hugr::serialize::HUGRSerializationError,
PyHUGRSerializationError,
"Errors that can occur while serializing a HUGR."
);

create_py_exception!(
tket2::json::OpConvertError,
PyOpConvertError,
"Error type for the conversion between tket2 and tket1 operations."
);

/// Run the validation checks on a circuit.
#[pyfunction]
pub fn validate_hugr(c: &PyAny) -> PyResult<()> {
Expand Down
21 changes: 12 additions & 9 deletions tket2-py/src/circuit/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use tket2::{Circuit, Tk2Op};
use tket_json_rs::circuit_json::SerialCircuit;

use crate::rewrite::PyCircuitRewrite;
use crate::utils::ConvertPyErr;

use super::{cost, PyCircuitCost};

Expand Down Expand Up @@ -57,7 +58,9 @@ impl Tk2Circuit {

/// Convert the [`Tk2Circuit`] to a tket1 circuit.
pub fn to_tket1<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
SerialCircuit::encode(&self.hugr)?.to_tket1(py)
SerialCircuit::encode(&self.hugr)
.convert_pyerrs()?
.to_tket1(py)
}

/// Apply a rewrite on the circuit.
Expand Down Expand Up @@ -85,7 +88,7 @@ impl Tk2Circuit {
/// FIXME: Currently the encoded circuit cannot be loaded back due to
/// [https://github.com/CQCL/hugr/issues/683]
pub fn to_tket1_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(&SerialCircuit::encode(&self.hugr)?).unwrap())
Ok(serde_json::to_string(&SerialCircuit::encode(&self.hugr).convert_pyerrs()?).unwrap())
}

/// Decode a tket1 json string to a circuit.
Expand All @@ -94,7 +97,7 @@ impl Tk2Circuit {
let tk1: SerialCircuit = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit {
hugr: tk1.decode()?,
hugr: tk1.decode().convert_pyerrs()?,
})
}

Expand Down Expand Up @@ -164,7 +167,7 @@ impl CircuitType {
/// Converts a `Hugr` into the format indicated by the flag.
pub fn convert(self, py: Python, hugr: Hugr) -> PyResult<&PyAny> {
match self {
CircuitType::Tket1 => SerialCircuit::encode(&hugr)?.to_tket1(py),
CircuitType::Tket1 => SerialCircuit::encode(&hugr).convert_pyerrs()?.to_tket1(py),
CircuitType::Tket2 => Ok(Py::new(py, Tk2Circuit { hugr })?.into_ref(py)),
}
}
Expand All @@ -175,19 +178,19 @@ impl CircuitType {
/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects.
pub fn try_with_hugr<T, E, F>(circ: &PyAny, f: F) -> PyResult<T>
where
E: Into<PyErr>,
E: ConvertPyErr<Output = PyErr>,
F: FnOnce(Hugr, CircuitType) -> Result<T, E>,
{
let (hugr, typ) = match Tk2Circuit::extract(circ) {
// hugr circuit
Ok(t2circ) => (t2circ.hugr, CircuitType::Tket2),
// tket1 circuit
Err(_) => (
SerialCircuit::from_tket1(circ)?.decode()?,
SerialCircuit::from_tket1(circ)?.decode().convert_pyerrs()?,
CircuitType::Tket1,
),
};
(f)(hugr, typ).map_err(|e| e.into())
(f)(hugr, typ).map_err(|e| e.convert_pyerrs())
}

/// Apply a function expecting a hugr on a python circuit.
Expand All @@ -206,12 +209,12 @@ where
/// The returned Hugr is converted to the matching python object.
pub fn try_update_hugr<E, F>(circ: &PyAny, f: F) -> PyResult<&PyAny>
where
E: Into<PyErr>,
E: ConvertPyErr<Output = PyErr>,
F: FnOnce(Hugr, CircuitType) -> Result<Hugr, E>,
{
let py = circ.py();
try_with_hugr(circ, |hugr, typ| {
let hugr = f(hugr, typ).map_err(|e| e.into())?;
let hugr = f(hugr, typ).map_err(|e| e.convert_pyerrs())?;
typ.convert(py, hugr)
})
}
Expand Down
1 change: 1 addition & 0 deletions tket2-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod optimiser;
pub mod passes;
pub mod pattern;
pub mod rewrite;
pub mod utils;

use pyo3::prelude::*;

Expand Down
10 changes: 5 additions & 5 deletions tket2-py/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{cmp::min, convert::TryInto, fs, num::NonZeroUsize, path::PathBuf};
use pyo3::{prelude::*, types::IntoPyDict};
use tket2::{op_matches, passes::apply_greedy_commutation, Circuit, Tk2Op};

use crate::utils::{create_py_exception, ConvertPyErr};
use crate::{
circuit::{try_update_hugr, try_with_hugr},
optimiser::PyBadgerOptimiser,
Expand All @@ -21,18 +22,17 @@ pub fn module(py: Python) -> PyResult<&PyModule> {
m.add_function(wrap_pyfunction!(badger_optimise, m)?)?;
m.add_class::<self::chunks::PyCircuitChunks>()?;
m.add_function(wrap_pyfunction!(self::chunks::chunks, m)?)?;
m.add(
"PullForwardError",
py.get_type::<tket2::passes::PyPullForwardError>(),
)?;
m.add("PullForwardError", py.get_type::<PyPullForwardError>())?;
Ok(m)
}

create_py_exception!(tket2::passes::PullForwardError, PyPullForwardError, "");

#[pyfunction]
fn greedy_depth_reduce(circ: &PyAny) -> PyResult<(&PyAny, u32)> {
let py = circ.py();
try_with_hugr(circ, |mut h, typ| {
let n_moves = apply_greedy_commutation(&mut h)?;
let n_moves = apply_greedy_commutation(&mut h).convert_pyerrs()?;
let circ = typ.convert(py, h)?;
PyResult::Ok((circ, n_moves))
})
Expand Down
3 changes: 2 additions & 1 deletion tket2-py/src/passes/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use tket2::Circuit;

use crate::circuit::convert::CircuitType;
use crate::circuit::{try_with_hugr, with_hugr};
use crate::utils::ConvertPyErr;

/// Split a circuit into chunks of a given size.
#[pyfunction]
Expand Down Expand Up @@ -38,7 +39,7 @@ pub struct PyCircuitChunks {
impl PyCircuitChunks {
/// Reassemble the chunks into a circuit.
fn reassemble<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
let hugr = self.clone().chunks.reassemble()?;
let hugr = self.clone().chunks.reassemble().convert_pyerrs()?;
self.original_type.convert(py, hugr)
}

Expand Down
25 changes: 17 additions & 8 deletions tket2-py/src/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod portmatching;

use crate::circuit::Tk2Circuit;
use crate::rewrite::PyCircuitRewrite;
use crate::utils::{create_py_exception, ConvertPyErr};

use hugr::Hugr;
use pyo3::prelude::*;
Expand All @@ -18,22 +19,30 @@ pub fn module(py: Python) -> PyResult<&PyModule> {
m.add_class::<self::portmatching::PyPatternMatcher>()?;
m.add_class::<self::portmatching::PyPatternMatch>()?;

m.add(
"InvalidReplacementError",
py.get_type::<hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError>(),
)?;
m.add(
"InvalidPatternError",
py.get_type::<tket2::portmatching::pattern::PyInvalidPatternError>(),
py.get_type::<PyInvalidPatternError>(),
)?;
m.add(
"InvalidReplacementError",
py.get_type::<hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError>(),
py.get_type::<PyInvalidReplacementError>(),
)?;

Ok(m)
}

create_py_exception!(
hugr::hugr::views::sibling_subgraph::InvalidReplacement,
PyInvalidReplacementError,
"Errors that can occur while constructing a HUGR replacement."
);

create_py_exception!(
tket2::portmatching::pattern::InvalidPattern,
PyInvalidPatternError,
"Conversion error from circuit to pattern."
);

#[derive(Clone)]
#[pyclass]
/// A rewrite rule defined by a left hand side and right hand side of an equation.
Expand Down Expand Up @@ -63,7 +72,7 @@ impl RuleMatcher {
rules.into_iter().map(|Rule([l, r])| (l, r)).unzip();
let patterns: Result<Vec<CircuitPattern>, _> =
lefts.iter().map(CircuitPattern::try_from_circuit).collect();
let matcher = PatternMatcher::from_patterns(patterns?);
let matcher = PatternMatcher::from_patterns(patterns.convert_pyerrs()?);

Ok(Self { matcher, rights })
}
Expand All @@ -72,7 +81,7 @@ impl RuleMatcher {
let h = &target.hugr;
if let Some(p_match) = self.matcher.find_matches_iter(h).next() {
let r = self.rights.get(p_match.pattern_id().0).unwrap().clone();
let rw = p_match.to_rewrite(h, r)?;
let rw = p_match.to_rewrite(h, r).convert_pyerrs()?;
Ok(Some(rw.into()))
} else {
Ok(None)
Expand Down
47 changes: 47 additions & 0 deletions tket2-py/src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//! Utility functions for the python interface.
/// A trait for types wrapping rust errors that may be converted into python exception.
///
/// In addition to raw errors, this is implemented for wrapper types such as `Result`.
/// [`ConvertPyErr::convert_errors`] will be called on the internal error type.
pub trait ConvertPyErr {
/// The output type after conversion.
type Output;

/// Convert any internal errors to python errors.
fn convert_pyerrs(self) -> Self::Output;
}

impl ConvertPyErr for pyo3::PyErr {
type Output = Self;

fn convert_pyerrs(self) -> Self::Output {
self
}
}

impl<T, E> ConvertPyErr for Result<T, E>
where
E: ConvertPyErr,
{
type Output = Result<T, E::Output>;

fn convert_pyerrs(self) -> Self::Output {
self.map_err(|e| e.convert_pyerrs())
}
}

macro_rules! create_py_exception {
($err:path, $py_err:ident, $doc:expr) => {
pyo3::create_exception!(tket2, $py_err, pyo3::exceptions::PyException, $doc);

impl $crate::utils::ConvertPyErr for $err {
type Output = pyo3::PyErr;

fn convert_pyerrs(self) -> Self::Output {
$py_err::new_err(<Self as std::string::ToString>::to_string(&self))
}
}
};
}
pub(crate) use create_py_exception;
17 changes: 0 additions & 17 deletions tket2/src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ pub mod op;
mod tests;

use hugr::CircuitUnit;
#[cfg(feature = "pyo3")]
use pyo3::{create_exception, exceptions::PyException, PyErr};

use std::path::Path;
use std::{fs, io};
Expand Down Expand Up @@ -104,21 +102,6 @@ pub enum OpConvertError {
NonSerializableInputs(OpType),
}

#[cfg(feature = "pyo3")]
create_exception!(
tket2,
PyOpConvertError,
PyException,
"Error type for conversion between tket2's `Op` and `OpType`"
);

#[cfg(feature = "pyo3")]
impl From<OpConvertError> for PyErr {
fn from(err: OpConvertError) -> Self {
PyOpConvertError::new_err(err.to_string())
}
}

/// Load a TKET1 circuit from a JSON file.
pub fn load_tk1_json_file(path: impl AsRef<Path>) -> Result<Hugr, TK1ConvertError> {
let file = fs::File::open(path)?;
Expand Down
4 changes: 1 addition & 3 deletions tket2/src/passes.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
//! Optimisation passes and related utilities for circuits.
mod commutation;
pub use commutation::apply_greedy_commutation;
#[cfg(feature = "pyo3")]
pub use commutation::PyPullForwardError;
pub use commutation::{apply_greedy_commutation, PullForwardError};

pub mod chunks;
pub use chunks::CircuitChunks;
Loading

0 comments on commit 3e1a68d

Please sign in to comment.