Skip to content

Commit

Permalink
Merge branch 'main' into fix/reduce-hash-memory
Browse files Browse the repository at this point in the history
  • Loading branch information
lmondada committed Sep 27, 2023
2 parents a3d31b0 + be8b9a9 commit 7e3050f
Show file tree
Hide file tree
Showing 19 changed files with 991 additions and 189 deletions.
26 changes: 25 additions & 1 deletion pyrs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
//! Python bindings for TKET2.
#![warn(missing_docs)]
use circuit::try_with_hugr;
use pyo3::prelude::*;
use tket2::{json::TKETDecode, passes::apply_greedy_commutation};
use tket_json_rs::circuit_json::SerialCircuit;

mod circuit;

use pyo3::prelude::*;
#[pyfunction]
fn greedy_depth_reduce(py_c: PyObject) -> PyResult<(PyObject, u32)> {
try_with_hugr(py_c, |mut h| {
let n_moves = apply_greedy_commutation(&mut h)?;
let py_c = SerialCircuit::encode(&h)?.to_tket1()?;
PyResult::Ok((py_c, n_moves))
})
}

/// The Python bindings to TKET2.
#[pymodule]
fn pyrs(py: Python, m: &PyModule) -> PyResult<()> {
add_circuit_module(py, m)?;
add_pattern_module(py, m)?;
add_pass_module(py, m)?;
Ok(())
}

Expand Down Expand Up @@ -54,3 +66,15 @@ fn add_pattern_module(py: Python, parent: &PyModule) -> PyResult<()> {

parent.add_submodule(m)
}

fn add_pass_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "passes")?;
m.add_function(wrap_pyfunction!(greedy_depth_reduce, m)?)?;
m.add_class::<tket2::T2Op>()?;
m.add(
"PullForwardError",
py.get_type::<tket2::passes::PyPullForwardError>(),
)?;
parent.add_submodule(m)?;
Ok(())
}
22 changes: 22 additions & 0 deletions pyrs/test/test_bindings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,25 @@
from dataclasses import dataclass
from pyrs.pyrs import passes
from pytket.circuit import Circuit


@dataclass
class DepthOptimisePass:
def apply(self, circ: Circuit) -> Circuit:
(circ, n_moves) = passes.greedy_depth_reduce(circ)
return circ


def test_depth_optimise():
c = Circuit(4).CX(0, 2).CX(1, 2).CX(1, 3)

assert c.depth() == 3

c = DepthOptimisePass().apply(c)

assert c.depth() == 2


# from dataclasses import dataclass
# from typing import Callable, Iterable
# import time
Expand Down
153 changes: 0 additions & 153 deletions src/_passes.rs

This file was deleted.

4 changes: 2 additions & 2 deletions src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ where
// TODO: `with_wires` combinator for `Units`?
let wire_unit = circ
.linear_units()
.map(|(linear_unit, port, _)| (Wire::new(circ.input(), port), linear_unit))
.map(|(linear_unit, port, _)| (Wire::new(circ.input(), port), linear_unit.index()))
.collect();

let nodes = pv::Topo::new(&circ.as_petgraph());
Expand Down Expand Up @@ -311,7 +311,7 @@ where
// Update the map tracking the linear units
let new_wire = Wire::new(node, port);
self.wire_unit.insert(new_wire, linear_id);
linear_id
LinearUnit::new(linear_id)
})
.collect();

Expand Down
34 changes: 31 additions & 3 deletions src/circuit/units.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,36 @@ use super::Circuit;

/// A linear unit id, used in [`CircuitUnit::Linear`].
// TODO: Add this to hugr?
pub type LinearUnit = usize;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct LinearUnit(usize);

impl LinearUnit {
/// Creates a new [`LinearUnit`].
pub fn new(index: usize) -> Self {
Self(index)
}
/// Returns the index of this [`LinearUnit`].
pub fn index(&self) -> usize {
self.0
}
}

impl From<LinearUnit> for CircuitUnit {
fn from(lu: LinearUnit) -> Self {
CircuitUnit::Linear(lu.index())
}
}

impl TryFrom<CircuitUnit> for LinearUnit {
type Error = ();

fn try_from(cu: CircuitUnit) -> Result<Self, Self::Error> {
match cu {
CircuitUnit::Wire(_) => Err(()),
CircuitUnit::Linear(i) => Ok(LinearUnit(i)),
}
}
}
/// An iterator over the units in the input or output boundary of a [Node].
#[derive(Clone, Debug)]
pub struct Units<UL = DefaultUnitLabeller> {
Expand Down Expand Up @@ -134,7 +162,7 @@ where
let linear_unit =
self.unit_labeller
.assign_linear(self.node, port, self.linear_count - 1);
CircuitUnit::Linear(linear_unit)
CircuitUnit::Linear(linear_unit.index())
} else {
let wire = self.unit_labeller.assign_wire(self.node, port)?;
CircuitUnit::Wire(wire)
Expand Down Expand Up @@ -206,7 +234,7 @@ pub struct DefaultUnitLabeller;
impl UnitLabeller for DefaultUnitLabeller {
#[inline]
fn assign_linear(&self, _: Node, _: Port, linear_count: usize) -> LinearUnit {
linear_count
LinearUnit(linear_count)
}

#[inline]
Expand Down
4 changes: 2 additions & 2 deletions src/circuit/units/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl UnitFilter for Linear {

fn accept(item: (CircuitUnit, Port, Type)) -> Option<Self::Item> {
match item {
(CircuitUnit::Linear(unit), port, typ) => Some((unit, port, typ)),
(CircuitUnit::Linear(unit), port, typ) => Some((LinearUnit::new(unit), port, typ)),
_ => None,
}
}
Expand All @@ -53,7 +53,7 @@ impl UnitFilter for Qubits {
fn accept(item: (CircuitUnit, Port, Type)) -> Option<Self::Item> {
match item {
(CircuitUnit::Linear(unit), port, typ) if typ == prelude::QB_T => {
Some((unit, port, typ))
Some((LinearUnit::new(unit), port, typ))
}
_ => None,
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod extension;
pub mod json;
pub(crate) mod ops;
pub mod optimiser;
pub mod passes;
pub mod rewrite;
pub use ops::{symbolic_constant_op, Pauli, T2Op};

Expand Down
34 changes: 30 additions & 4 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ pub enum T2Op {
TK1,
}

/// Whether an op is a given T2Op.
pub(crate) fn op_matches(op: &OpType, t2op: T2Op) -> bool {
let Ok(op) = T2Op::try_from(op) else {
return false;
};
op == t2op
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd)]
#[cfg_attr(feature = "pyo3", pyclass)]
#[allow(missing_docs)]
Expand Down Expand Up @@ -294,17 +302,27 @@ impl TryFrom<OpType> for T2Op {
type Error = NotT2Op;

fn try_from(op: OpType) -> Result<Self, Self::Error> {
let leaf: LeafOp = op.try_into().map_err(|_| NotT2Op)?;
Self::try_from(&op)
}
}

impl TryFrom<&OpType> for T2Op {
type Error = NotT2Op;

fn try_from(op: &OpType) -> Result<Self, Self::Error> {
let OpType::LeafOp(leaf) = op else {
return Err(NotT2Op);
};
leaf.try_into()
}
}

impl TryFrom<LeafOp> for T2Op {
impl TryFrom<&LeafOp> for T2Op {
type Error = NotT2Op;

fn try_from(op: LeafOp) -> Result<Self, Self::Error> {
fn try_from(op: &LeafOp) -> Result<Self, Self::Error> {
match op {
LeafOp::CustomOp(b) => match *b {
LeafOp::CustomOp(b) => match b.as_ref() {
ExternalOp::Extension(e) => Self::try_from_op_def(e.def()),
ExternalOp::Opaque(o) => from_extension_name(o.extension(), o.name()),
},
Expand All @@ -313,6 +331,14 @@ impl TryFrom<LeafOp> for T2Op {
}
}

impl TryFrom<LeafOp> for T2Op {
type Error = NotT2Op;

fn try_from(op: LeafOp) -> Result<Self, Self::Error> {
Self::try_from(&op)
}
}

/// load all variants of a `SimpleOpEnum` in to an extension as op defs.
fn load_all_ops<T: SimpleOpEnum>(extension: &mut Extension) -> Result<(), ExtensionBuildError> {
for op in T::all_variants() {
Expand Down
3 changes: 3 additions & 0 deletions src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@
//! Currently, the only optimiser is TASO
pub mod taso;

#[cfg(feature = "portmatching")]
pub use taso::DefaultTasoOptimiser;
pub use taso::TasoOptimiser;
Loading

0 comments on commit 7e3050f

Please sign in to comment.