Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: python errors and some pyrs helpers #114

Merged
merged 7 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ members = ["pyrs", "compile-matcher", "taso-optimiser"]

[workspace.dependencies]

quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "bc9692b" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "660fef6e" }
portgraph = { version = "0.9", features = ["serde"] }
pyo3 = { version = "0.19" }
itertools = { version = "0.11.0" }
Expand Down
54 changes: 54 additions & 0 deletions pyrs/src/circuit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//! Circuit-related functionality and utilities.
#![allow(unused)]

use pyo3::prelude::*;

use hugr::{Hugr, HugrView};
use tket2::extension::REGISTRY;
use tket2::json::TKETDecode;
use tket_json_rs::circuit_json::SerialCircuit;

/// Apply a fallible function expecting a hugr on a pytket circuit.
pub fn try_with_hugr<T, E, F>(circ: Py<PyAny>, f: F) -> PyResult<T>
where
E: Into<PyErr>,
F: FnOnce(Hugr) -> Result<T, E>,
{
let hugr = SerialCircuit::_from_tket1(circ).decode()?;
(f)(hugr).map_err(|e| e.into())
}

/// Apply a function expecting a hugr on a pytket circuit.
pub fn with_hugr<T, F: FnOnce(Hugr) -> T>(circ: Py<PyAny>, f: F) -> PyResult<T> {
try_with_hugr(circ, |hugr| Ok::<T, PyErr>((f)(hugr)))
}

/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit.
pub fn try_update_hugr<E: Into<PyErr>, F: FnOnce(Hugr) -> Result<Hugr, E>>(
circ: Py<PyAny>,
f: F,
) -> PyResult<Py<PyAny>> {
let hugr = try_with_hugr(circ, f)?;
SerialCircuit::encode(&hugr)?.to_tket1()
}

/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit.
pub fn update_hugr<F: FnOnce(Hugr) -> Hugr>(circ: Py<PyAny>, f: F) -> PyResult<Py<PyAny>> {
let hugr = with_hugr(circ, f)?;
SerialCircuit::encode(&hugr)?.to_tket1()
}

#[pyfunction]
pub fn validate_hugr(c: Py<PyAny>) -> PyResult<()> {
try_with_hugr(c, |hugr| hugr.validate(&REGISTRY))
}

#[pyfunction]
pub fn to_hugr_dot(c: Py<PyAny>) -> PyResult<String> {
with_hugr(c, |hugr| hugr.dot_string())
}

#[pyfunction]
pub fn to_hugr(c: Py<PyAny>) -> PyResult<Hugr> {
with_hugr(c, |hugr| hugr)
}
Comment on lines +42 to +54
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

55 changes: 47 additions & 8 deletions pyrs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,56 @@
//! Python bindings for TKET2.
#![warn(missing_docs)]

mod circuit;

use pyo3::prelude::*;
use tket2::portmatching::{CircuitPattern, PatternMatcher};

/// The Python bindings to TKET2.
#[pymodule]
fn pyrs(py: Python, m: &PyModule) -> PyResult<()> {
add_patterns_module(py, m)?;
add_circuit_module(py, m)?;
add_pattern_module(py, m)?;
Ok(())
}

fn add_patterns_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "patterns")?;
m.add_class::<CircuitPattern>()?;
m.add_class::<PatternMatcher>()?;
parent.add_submodule(m)?;
Ok(())
/// circuit module
fn add_circuit_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "circuit")?;
m.add_class::<tket2::T2Op>()?;
m.add_class::<tket2::Pauli>()?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
m.add(
"ValidationError",
py.get_type::<hugr::hugr::validate::PyValidationError>(),
)?;
m.add(
"HUGRSerializationError",
py.get_type::<hugr::hugr::serialize::PyHUGRSerializationError>(),
)?;
m.add(
"OpConvertError",
py.get_type::<tket2::json::PyOpConvertError>(),
)?;

parent.add_submodule(m)
}

/// portmatching module
fn add_pattern_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "pattern")?;
m.add_class::<tket2::portmatching::CircuitPattern>()?;
m.add_class::<tket2::portmatching::PatternMatcher>()?;

m.add(
"InvalidPatternError",
py.get_type::<tket2::portmatching::pattern::PyInvalidPatternError>(),
)?;
m.add(
"InvalidReplacementError",
py.get_type::<hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError>(),
)?;

parent.add_submodule(m)
}
18 changes: 18 additions & 0 deletions src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ pub mod op;
#[cfg(test)]
mod tests;

#[cfg(feature = "pyo3")]
use pyo3::{create_exception, exceptions::PyException, PyErr};

use std::path::Path;
use std::{fs, io};

Expand Down Expand Up @@ -85,6 +88,21 @@ pub enum OpConvertError {
NonSerializableInputs(OpType),
}

#[cfg(feature = "pyo3")]
create_exception!(
pyrs,
PyOpConvertError,
PyException,
"Error type for conversion between tket2's `Op` and `OpType`"
);

#[cfg(feature = "pyo3")]
impl From<OpConvertError> for PyErr {
fn from(err: OpConvertError) -> Self {
PyOpConvertError::new_err(err.to_string())
}
}

/// Load a TKET1 circuit from a JSON file.
pub fn load_tk1_json_file(path: impl AsRef<Path>) -> Result<Hugr, TK1ConvertError> {
let file = fs::File::open(path)?;
Expand Down
7 changes: 5 additions & 2 deletions src/json/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use std::hash::{Hash, Hasher};
use std::mem;

use hugr::builder::{CircuitBuilder, Container, DFGBuilder, Dataflow, DataflowHugr};
use hugr::extension::prelude::QB_T;
use hugr::extension::prelude::{PRELUDE_ID, QB_T};
use hugr::extension::ExtensionSet;

use hugr::hugr::CircuitUnit;
use hugr::ops::Const;
Expand Down Expand Up @@ -143,7 +144,9 @@ impl JsonDecoder {
Some(c) => {
let const_type = FLOAT64_TYPE;
let const_op = Const::new(c, const_type).unwrap();
self.hugr.add_load_const(const_op).unwrap()
self.hugr
.add_load_const(const_op, ExtensionSet::singleton(&PRELUDE_ID))
.unwrap()
}
None => {
// store string in custom op.
Expand Down
6 changes: 6 additions & 0 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ use strum::IntoEnumIterator;
use strum_macros::{Display, EnumIter, EnumString, IntoStaticStr};
use thiserror::Error;

#[cfg(feature = "pyo3")]
use pyo3::pyclass;

/// Name of tket 2 extension.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("quantum.tket2");

Expand All @@ -41,6 +44,7 @@ pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("quantum.tket2"
IntoStaticStr,
EnumString,
)]
#[cfg_attr(feature = "pyo3", pyclass)]
#[allow(missing_docs)]
/// Simple enum of tket 2 quantum operations.
pub enum T2Op {
Expand All @@ -62,7 +66,9 @@ pub enum T2Op {
AngleAdd,
TK1,
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd)]
#[cfg_attr(feature = "pyo3", pyclass)]
#[allow(missing_docs)]
/// Simple enum representation of Pauli matrices.
pub enum Pauli {
Expand Down
2 changes: 1 addition & 1 deletion src/portmatching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
pub mod matcher;
pub mod pattern;
#[cfg(feature = "pyo3")]
mod pyo3;
pub mod pyo3;

pub use matcher::{PatternMatch, PatternMatcher};
pub use pattern::CircuitPattern;
Expand Down
17 changes: 16 additions & 1 deletion src/portmatching/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use super::{
use crate::circuit::Circuit;

#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
use pyo3::{create_exception, exceptions::PyException, pyclass, PyErr};

/// A pattern that match a circuit exactly
#[cfg_attr(feature = "pyo3", pyclass)]
Expand Down Expand Up @@ -119,6 +119,21 @@ impl From<NoRootFound> for InvalidPattern {
}
}

#[cfg(feature = "pyo3")]
create_exception!(
pyrs,
PyInvalidPatternError,
PyException,
"Invalid circuit pattern"
);

#[cfg(feature = "pyo3")]
impl From<InvalidPattern> for PyErr {
fn from(err: InvalidPattern) -> Self {
PyInvalidPatternError::new_err(err.to_string())
}
}

#[cfg(test)]
mod tests {
use hugr::Hugr;
Expand Down
34 changes: 16 additions & 18 deletions src/portmatching/pyo3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,29 @@
use std::fmt;

use derive_more::{From, Into};
use hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError;
use hugr::hugr::views::{DescendantsGraph, HierarchyView};
use hugr::ops::handle::DfgID;
use hugr::{Hugr, HugrView, Port};
use itertools::Itertools;
use portmatching::{HashMap, PatternID};
use pyo3::{create_exception, exceptions::PyException, prelude::*, types::PyIterator};
use pyo3::{prelude::*, types::PyIterator};
use tket_json_rs::circuit_json::SerialCircuit;

use super::{CircuitPattern, PatternMatch, PatternMatcher};
use crate::circuit::Circuit;
use crate::json::TKETDecode;
use crate::rewrite::CircuitRewrite;

create_exception!(pyrs, PyValidateError, PyException);
create_exception!(pyrs, PyInvalidReplacement, PyException);
create_exception!(pyrs, PyInvalidPattern, PyException);

#[pymethods]
impl CircuitPattern {
/// Construct a pattern from a TKET1 circuit
#[new]
pub fn py_from_circuit(circ: PyObject) -> PyResult<CircuitPattern> {
let hugr = pyobj_as_hugr(circ)?;
let circ = hugr_as_view(&hugr);
CircuitPattern::try_from_circuit(&circ)
.map_err(|e| PyInvalidPattern::new_err(e.to_string()))
let pattern = CircuitPattern::try_from_circuit(&circ)?;
Ok(pattern)
}

/// A string representation of the pattern.
Expand Down Expand Up @@ -64,7 +61,7 @@ impl PatternMatcher {
.map(|m| {
let pattern_id = m.pattern_id();
PyPatternMatch::try_from_rust(m, &circ, self).map_err(|e| {
PyInvalidReplacement::new_err(format!(
PyInvalidReplacementError::new_err(format!(
"Invalid match for pattern {:?}: {}",
pattern_id, e
))
Expand All @@ -74,7 +71,7 @@ impl PatternMatcher {
}
}

/// Python equivalent of [`CircuitMatch`].
/// Python equivalent of [`PatternMatch`].
///
/// A convex pattern match in a circuit, available from Python.
///
Expand All @@ -86,7 +83,9 @@ impl PatternMatcher {
/// over efficiency. It is provided for convenience and not recommended when
/// performance is a key concern.
///
/// TODO: can this be a wrapper for a [`CircuitMatch`] instead?
/// TODO: can this be a wrapper for a [`PatternMatch`] instead?
///
/// [`PatternMatch`]: crate::portmatching::matcher::PatternMatch
#[pyclass]
#[derive(Debug, Clone)]
pub struct PyPatternMatch {
Expand Down Expand Up @@ -117,7 +116,7 @@ impl PyPatternMatch {
}

impl PyPatternMatch {
/// Construct a [`PyCircuitMatch`] from a [`PatternMatch`].
/// Construct a [`PyPatternMatch`] from a [`PatternMatch`].
///
/// Requires references to the circuit and pattern to resolve indices
/// into these objects.
Expand All @@ -132,7 +131,7 @@ impl PyPatternMatch {

let node_map: HashMap<Node, Node> = pattern
.get_match_map(root.0, circ)
.ok_or_else(|| PyInvalidReplacement::new_err("Invalid match"))?
.ok_or_else(|| PyInvalidReplacementError::new_err("Invalid match"))?
.into_iter()
.map(|(p, c)| (Node(p), Node(c)))
.collect();
Expand All @@ -159,6 +158,7 @@ impl PyPatternMatch {
})
}

/// Convert the pattern into a [`CircuitRewrite`].
pub fn to_rewrite(&self, circ: PyObject, replacement: PyObject) -> PyResult<CircuitRewrite> {
let hugr = pyobj_as_hugr(circ)?;
let circ = hugr_as_view(&hugr);
Expand All @@ -168,16 +168,16 @@ impl PyPatternMatch {
.map(|p| p.iter().map(|&(n, p)| (n.0, p)).collect())
.collect();
let outputs = self.outputs.iter().map(|&(n, p)| (n.0, p)).collect();
PatternMatch::try_from_io(
let rewrite = PatternMatch::try_from_io(
self.root.0,
PatternID(self.pattern_id),
&circ,
inputs,
outputs,
)
.expect("Invalid PyCircuitMatch object")
.to_rewrite(&hugr, pyobj_as_hugr(replacement)?)
.map_err(|e| PyInvalidReplacement::new_err(e.to_string()))
.to_rewrite(&hugr, pyobj_as_hugr(replacement)?)?;
Ok(rewrite)
}
}

Expand All @@ -204,9 +204,7 @@ impl Node {

fn pyobj_as_hugr(circ: PyObject) -> PyResult<Hugr> {
let ser_c = SerialCircuit::_from_tket1(circ);
let hugr: Hugr = ser_c
.decode()
.map_err(|e| PyValidateError::new_err(e.to_string()))?;
let hugr: Hugr = ser_c.decode()?;
Ok(hugr)
}

Expand Down
2 changes: 1 addition & 1 deletion taso-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use tket_json_rs::circuit_json::SerialCircuit;

/// Optimise circuits using Quartz-generated ECCs.
///
/// Quartz: https://github.com/quantum-compiler/quartz
/// Quartz: <https://github.com/quantum-compiler/quartz>
#[derive(Parser, Debug)]
#[clap(version = "1.0", long_about = None)]
#[clap(about = "Optimise circuits using Quartz-generated ECCs.")]
Expand Down