Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement PyErr conversion locally in tket2-py #258

Merged
merged 3 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tket2-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ tket2 = { workspace = true, features = ["pyo3", "portmatching"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tket-json-rs = { workspace = true, features = ["pyo3"] }
quantinuum-hugr = { workspace = true, features = ["pyo3"] }
portgraph = { workspace = true, features = ["pyo3", "serde"] }
quantinuum-hugr = { workspace = true }
portgraph = { workspace = true, features = ["serde"] }
pyo3 = { workspace = true, features = ["extension-module"] }
num_cpus = "1.16.0"
derive_more = "0.99.17"
Expand Down
48 changes: 37 additions & 11 deletions tket2-py/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use tket2::json::TKETDecode;
use tket2::rewrite::CircuitRewrite;
use tket_json_rs::circuit_json::SerialCircuit;

use crate::utils::create_py_exception;

pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, Tk2Circuit};
pub use self::cost::PyCircuitCost;
pub use tket2::{Pauli, Tk2Op};
Expand All @@ -30,24 +32,48 @@ pub fn module(py: Python) -> PyResult<&PyModule> {
m.add_function(wrap_pyfunction!(validate_hugr, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
m.add(
"ValidationError",
py.get_type::<hugr::hugr::validate::PyValidationError>(),
)?;
m.add("HugrError", py.get_type::<PyHugrError>())?;
m.add("BuildError", py.get_type::<PyBuildError>())?;
m.add("ValidationError", py.get_type::<PyValidationError>())?;
m.add(
"HUGRSerializationError",
py.get_type::<hugr::hugr::serialize::PyHUGRSerializationError>(),
)?;
m.add(
"OpConvertError",
py.get_type::<tket2::json::PyOpConvertError>(),
py.get_type::<PyHUGRSerializationError>(),
)?;
m.add("OpConvertError", py.get_type::<PyOpConvertError>())?;

Ok(m)
}

create_py_exception!(
hugr::hugr::HugrError,
PyHugrError,
"Errors that can occur while manipulating a HUGR."
);

create_py_exception!(
hugr::builder::BuildError,
PyBuildError,
"Error while building the HUGR."
);

create_py_exception!(
hugr::hugr::validate::ValidationError,
PyValidationError,
"Errors that can occur while validating a Hugr."
);

create_py_exception!(
hugr::hugr::serialize::HUGRSerializationError,
PyHUGRSerializationError,
"Errors that can occur while serializing a HUGR."
);

create_py_exception!(
tket2::json::OpConvertError,
PyOpConvertError,
"Error type for the conversion between tket2 and tket1 operations."
);

/// Run the validation checks on a circuit.
#[pyfunction]
pub fn validate_hugr(c: &PyAny) -> PyResult<()> {
Expand Down
21 changes: 12 additions & 9 deletions tket2-py/src/circuit/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use tket2::{Circuit, Tk2Op};
use tket_json_rs::circuit_json::SerialCircuit;

use crate::rewrite::PyCircuitRewrite;
use crate::utils::ConvertPyErr;

use super::{cost, PyCircuitCost};

Expand Down Expand Up @@ -57,7 +58,9 @@ impl Tk2Circuit {

/// Convert the [`Tk2Circuit`] to a tket1 circuit.
pub fn to_tket1<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
SerialCircuit::encode(&self.hugr)?.to_tket1(py)
SerialCircuit::encode(&self.hugr)
.convert_pyerrs()?
.to_tket1(py)
}

/// Apply a rewrite on the circuit.
Expand Down Expand Up @@ -85,7 +88,7 @@ impl Tk2Circuit {
/// FIXME: Currently the encoded circuit cannot be loaded back due to
/// [https://github.com/CQCL/hugr/issues/683]
pub fn to_tket1_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(&SerialCircuit::encode(&self.hugr)?).unwrap())
Ok(serde_json::to_string(&SerialCircuit::encode(&self.hugr).convert_pyerrs()?).unwrap())
}

/// Decode a tket1 json string to a circuit.
Expand All @@ -94,7 +97,7 @@ impl Tk2Circuit {
let tk1: SerialCircuit = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit {
hugr: tk1.decode()?,
hugr: tk1.decode().convert_pyerrs()?,
})
}

Expand Down Expand Up @@ -164,7 +167,7 @@ impl CircuitType {
/// Converts a `Hugr` into the format indicated by the flag.
pub fn convert(self, py: Python, hugr: Hugr) -> PyResult<&PyAny> {
match self {
CircuitType::Tket1 => SerialCircuit::encode(&hugr)?.to_tket1(py),
CircuitType::Tket1 => SerialCircuit::encode(&hugr).convert_pyerrs()?.to_tket1(py),
CircuitType::Tket2 => Ok(Py::new(py, Tk2Circuit { hugr })?.into_ref(py)),
}
}
Expand All @@ -175,19 +178,19 @@ impl CircuitType {
/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects.
pub fn try_with_hugr<T, E, F>(circ: &PyAny, f: F) -> PyResult<T>
where
E: Into<PyErr>,
E: ConvertPyErr<Output = PyErr>,
F: FnOnce(Hugr, CircuitType) -> Result<T, E>,
{
let (hugr, typ) = match Tk2Circuit::extract(circ) {
// hugr circuit
Ok(t2circ) => (t2circ.hugr, CircuitType::Tket2),
// tket1 circuit
Err(_) => (
SerialCircuit::from_tket1(circ)?.decode()?,
SerialCircuit::from_tket1(circ)?.decode().convert_pyerrs()?,
CircuitType::Tket1,
),
};
(f)(hugr, typ).map_err(|e| e.into())
(f)(hugr, typ).map_err(|e| e.convert_pyerrs())
}

/// Apply a function expecting a hugr on a python circuit.
Expand All @@ -206,12 +209,12 @@ where
/// The returned Hugr is converted to the matching python object.
pub fn try_update_hugr<E, F>(circ: &PyAny, f: F) -> PyResult<&PyAny>
where
E: Into<PyErr>,
E: ConvertPyErr<Output = PyErr>,
F: FnOnce(Hugr, CircuitType) -> Result<Hugr, E>,
{
let py = circ.py();
try_with_hugr(circ, |hugr, typ| {
let hugr = f(hugr, typ).map_err(|e| e.into())?;
let hugr = f(hugr, typ).map_err(|e| e.convert_pyerrs())?;
typ.convert(py, hugr)
})
}
Expand Down
1 change: 1 addition & 0 deletions tket2-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod optimiser;
pub mod passes;
pub mod pattern;
pub mod rewrite;
pub mod utils;

use pyo3::prelude::*;

Expand Down
10 changes: 5 additions & 5 deletions tket2-py/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{cmp::min, convert::TryInto, fs, num::NonZeroUsize, path::PathBuf};
use pyo3::{prelude::*, types::IntoPyDict};
use tket2::{op_matches, passes::apply_greedy_commutation, Circuit, Tk2Op};

use crate::utils::{create_py_exception, ConvertPyErr};
use crate::{
circuit::{try_update_hugr, try_with_hugr},
optimiser::PyBadgerOptimiser,
Expand All @@ -21,18 +22,17 @@ pub fn module(py: Python) -> PyResult<&PyModule> {
m.add_function(wrap_pyfunction!(badger_optimise, m)?)?;
m.add_class::<self::chunks::PyCircuitChunks>()?;
m.add_function(wrap_pyfunction!(self::chunks::chunks, m)?)?;
m.add(
"PullForwardError",
py.get_type::<tket2::passes::PyPullForwardError>(),
)?;
m.add("PullForwardError", py.get_type::<PyPullForwardError>())?;
Ok(m)
}

create_py_exception!(tket2::passes::PullForwardError, PyPullForwardError, "");

#[pyfunction]
fn greedy_depth_reduce(circ: &PyAny) -> PyResult<(&PyAny, u32)> {
let py = circ.py();
try_with_hugr(circ, |mut h, typ| {
let n_moves = apply_greedy_commutation(&mut h)?;
let n_moves = apply_greedy_commutation(&mut h).convert_pyerrs()?;
let circ = typ.convert(py, h)?;
PyResult::Ok((circ, n_moves))
})
Expand Down
3 changes: 2 additions & 1 deletion tket2-py/src/passes/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use tket2::Circuit;

use crate::circuit::convert::CircuitType;
use crate::circuit::{try_with_hugr, with_hugr};
use crate::utils::ConvertPyErr;

/// Split a circuit into chunks of a given size.
#[pyfunction]
Expand Down Expand Up @@ -38,7 +39,7 @@ pub struct PyCircuitChunks {
impl PyCircuitChunks {
/// Reassemble the chunks into a circuit.
fn reassemble<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
let hugr = self.clone().chunks.reassemble()?;
let hugr = self.clone().chunks.reassemble().convert_pyerrs()?;
self.original_type.convert(py, hugr)
}

Expand Down
25 changes: 17 additions & 8 deletions tket2-py/src/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod portmatching;

use crate::circuit::Tk2Circuit;
use crate::rewrite::PyCircuitRewrite;
use crate::utils::{create_py_exception, ConvertPyErr};

use hugr::Hugr;
use pyo3::prelude::*;
Expand All @@ -18,22 +19,30 @@ pub fn module(py: Python) -> PyResult<&PyModule> {
m.add_class::<self::portmatching::PyPatternMatcher>()?;
m.add_class::<self::portmatching::PyPatternMatch>()?;

m.add(
"InvalidReplacementError",
py.get_type::<hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError>(),
)?;
m.add(
"InvalidPatternError",
py.get_type::<tket2::portmatching::pattern::PyInvalidPatternError>(),
py.get_type::<PyInvalidPatternError>(),
)?;
m.add(
"InvalidReplacementError",
py.get_type::<hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError>(),
py.get_type::<PyInvalidReplacementError>(),
)?;

Ok(m)
}

create_py_exception!(
hugr::hugr::views::sibling_subgraph::InvalidReplacement,
PyInvalidReplacementError,
"Errors that can occur while constructing a HUGR replacement."
);

create_py_exception!(
tket2::portmatching::pattern::InvalidPattern,
PyInvalidPatternError,
"Conversion error from circuit to pattern."
);

#[derive(Clone)]
#[pyclass]
/// A rewrite rule defined by a left hand side and right hand side of an equation.
Expand Down Expand Up @@ -63,7 +72,7 @@ impl RuleMatcher {
rules.into_iter().map(|Rule([l, r])| (l, r)).unzip();
let patterns: Result<Vec<CircuitPattern>, _> =
lefts.iter().map(CircuitPattern::try_from_circuit).collect();
let matcher = PatternMatcher::from_patterns(patterns?);
let matcher = PatternMatcher::from_patterns(patterns.convert_pyerrs()?);

Ok(Self { matcher, rights })
}
Expand All @@ -72,7 +81,7 @@ impl RuleMatcher {
let h = &target.hugr;
if let Some(p_match) = self.matcher.find_matches_iter(h).next() {
let r = self.rights.get(p_match.pattern_id().0).unwrap().clone();
let rw = p_match.to_rewrite(h, r)?;
let rw = p_match.to_rewrite(h, r).convert_pyerrs()?;
Ok(Some(rw.into()))
} else {
Ok(None)
Expand Down
47 changes: 47 additions & 0 deletions tket2-py/src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//! Utility functions for the python interface.

/// A trait for types wrapping rust errors that may be converted into python exception.
///
/// In addition to raw errors, this is implemented for wrapper types such as `Result`.
/// [`ConvertPyErr::convert_errors`] will be called on the internal error type.
pub trait ConvertPyErr {
/// The output type after conversion.
type Output;

/// Convert any internal errors to python errors.
fn convert_pyerrs(self) -> Self::Output;
}

impl ConvertPyErr for pyo3::PyErr {
type Output = Self;

fn convert_pyerrs(self) -> Self::Output {
self
}
}

impl<T, E> ConvertPyErr for Result<T, E>
where
E: ConvertPyErr,
{
type Output = Result<T, E::Output>;

fn convert_pyerrs(self) -> Self::Output {
self.map_err(|e| e.convert_pyerrs())
}
}

macro_rules! create_py_exception {
($err:path, $py_err:ident, $doc:expr) => {
pyo3::create_exception!(tket2, $py_err, pyo3::exceptions::PyException, $doc);

impl $crate::utils::ConvertPyErr for $err {
type Output = pyo3::PyErr;

fn convert_pyerrs(self) -> Self::Output {
$py_err::new_err(<Self as std::string::ToString>::to_string(&self))
}
}
};
}
pub(crate) use create_py_exception;
17 changes: 0 additions & 17 deletions tket2/src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ pub mod op;
mod tests;

use hugr::CircuitUnit;
#[cfg(feature = "pyo3")]
use pyo3::{create_exception, exceptions::PyException, PyErr};

use std::path::Path;
use std::{fs, io};
Expand Down Expand Up @@ -104,21 +102,6 @@ pub enum OpConvertError {
NonSerializableInputs(OpType),
}

#[cfg(feature = "pyo3")]
create_exception!(
tket2,
PyOpConvertError,
PyException,
"Error type for conversion between tket2's `Op` and `OpType`"
);

#[cfg(feature = "pyo3")]
impl From<OpConvertError> for PyErr {
fn from(err: OpConvertError) -> Self {
PyOpConvertError::new_err(err.to_string())
}
}

/// Load a TKET1 circuit from a JSON file.
pub fn load_tk1_json_file(path: impl AsRef<Path>) -> Result<Hugr, TK1ConvertError> {
let file = fs::File::open(path)?;
Expand Down
4 changes: 1 addition & 3 deletions tket2/src/passes.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
//! Optimisation passes and related utilities for circuits.

mod commutation;
pub use commutation::apply_greedy_commutation;
#[cfg(feature = "pyo3")]
pub use commutation::PyPullForwardError;
pub use commutation::{apply_greedy_commutation, PullForwardError};

pub mod chunks;
pub use chunks::CircuitChunks;
Loading