From 969d46329a85980efceae77c9416a99ad2908d54 Mon Sep 17 00:00:00 2001 From: Luca Mondada <72734770+lmondada@users.noreply.github.com> Date: Tue, 26 Sep 2023 19:27:42 +0200 Subject: [PATCH 1/3] feat!: Use CX count as default cost function (#134) --- src/ops.rs | 34 ++++++++++++-- src/optimiser.rs | 3 ++ src/optimiser/taso.rs | 27 +++++++++-- src/rewrite/strategy.rs | 93 +++++++++++++++++++++++++++++++------- taso-optimiser/src/main.rs | 6 +-- 5 files changed, 135 insertions(+), 28 deletions(-) diff --git a/src/ops.rs b/src/ops.rs index 3d5bd9e8..1c1cb69a 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -71,6 +71,14 @@ pub enum T2Op { TK1, } +/// Whether an op is a given T2Op. +pub(crate) fn op_matches(op: &OpType, t2op: T2Op) -> bool { + let Ok(op) = T2Op::try_from(op) else { + return false; + }; + op == t2op +} + #[derive(Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd)] #[cfg_attr(feature = "pyo3", pyclass)] #[allow(missing_docs)] @@ -294,17 +302,27 @@ impl TryFrom<OpType> for T2Op { type Error = NotT2Op; fn try_from(op: OpType) -> Result<Self, Self::Error> { - let leaf: LeafOp = op.try_into().map_err(|_| NotT2Op)?; + Self::try_from(&op) + } +} + +impl TryFrom<&OpType> for T2Op { + type Error = NotT2Op; + + fn try_from(op: &OpType) -> Result<Self, Self::Error> { + let OpType::LeafOp(leaf) = op else { + return Err(NotT2Op); + }; leaf.try_into() } } -impl TryFrom<LeafOp> for T2Op { +impl TryFrom<&LeafOp> for T2Op { type Error = NotT2Op; - fn try_from(op: LeafOp) -> Result<Self, Self::Error> { + fn try_from(op: &LeafOp) -> Result<Self, Self::Error> { match op { - LeafOp::CustomOp(b) => match *b { + LeafOp::CustomOp(b) => match b.as_ref() { ExternalOp::Extension(e) => Self::try_from_op_def(e.def()), ExternalOp::Opaque(o) => from_extension_name(o.extension(), o.name()), }, @@ -313,6 +331,14 @@ impl TryFrom<LeafOp> for T2Op { } } +impl TryFrom<LeafOp> for T2Op { + type Error = NotT2Op; + + fn try_from(op: LeafOp) -> Result<Self, Self::Error> { + Self::try_from(&op) + } +} + /// load all variants of a `SimpleOpEnum` in to an extension as op defs. fn load_all_ops<T: SimpleOpEnum>(extension: &mut Extension) -> Result<(), ExtensionBuildError> { for op in T::all_variants() { diff --git a/src/optimiser.rs b/src/optimiser.rs index 711caab6..0760d0d3 100644 --- a/src/optimiser.rs +++ b/src/optimiser.rs @@ -3,4 +3,7 @@ //! Currently, the only optimiser is TASO pub mod taso; + +#[cfg(feature = "portmatching")] +pub use taso::DefaultTasoOptimiser; pub use taso::TasoOptimiser; diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index ad88bc51..2c7f4982 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -293,20 +293,39 @@ where #[cfg(feature = "portmatching")] mod taso_default { - use crate::circuit::Circuit; + use hugr::ops::OpType; + use hugr::HugrView; + + use crate::ops::op_matches; use crate::rewrite::strategy::ExhaustiveRewriteStrategy; use crate::rewrite::ECCRewriter; + use crate::T2Op; use super::*; - impl TasoOptimiser<ECCRewriter, ExhaustiveRewriteStrategy, fn(&Hugr) -> usize> { + /// The default TASO optimiser using ECC sets. + pub type DefaultTasoOptimiser = TasoOptimiser< + ECCRewriter, + ExhaustiveRewriteStrategy<fn(&OpType) -> bool>, + fn(&Hugr) -> usize, + >; + + impl DefaultTasoOptimiser { /// A sane default optimiser using the given ECC sets. pub fn default_with_eccs_json_file( eccs_path: impl AsRef<std::path::Path>, ) -> io::Result<Self> { let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; - let strategy = ExhaustiveRewriteStrategy::default(); - Ok(Self::new(rewriter, strategy, |c| c.num_gates())) + let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); + Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) } } + + fn num_cx_gates(circ: &Hugr) -> usize { + circ.nodes() + .filter(|&n| op_matches(circ.get_optype(n), T2Op::CX)) + .count() + } } +#[cfg(feature = "portmatching")] +pub use taso_default::DefaultTasoOptimiser; diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index 29d5a934..3df7ceeb 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -10,10 +10,10 @@ use std::collections::HashSet; -use hugr::Hugr; +use hugr::{ops::OpType, Hugr, HugrView, Node}; use itertools::Itertools; -use crate::circuit::Circuit; +use crate::{ops::op_matches, T2Op}; use super::CircuitRewrite; @@ -79,27 +79,62 @@ impl RewriteStrategy for GreedyRewriteStrategy { /// circuit. /// /// The parameter gamma controls how greedy the algorithm should be. It allows -/// a rewrite C1 -> C2 if C2 has at most gamma times as many gates as C1: +/// a rewrite C1 -> C2 if C2 has at most gamma times the cost of C1: /// -/// $|C2| < gamma * |C1|$ +/// $cost(C2) < gamma * cost(C1)$ +/// +/// The cost function is given by the number of operations in the circuit that +/// satisfy a given Op predicate. This allows for instance to use the total +/// number of gates (true predicate) or the number of CX gates as cost function. /// /// gamma = 1 is the greedy strategy where a rewrite is only allowed if it -/// strictly reduces the gate count. The default is gamma = 1.0001, as set -/// in the Quartz paper. This essentially allows rewrites that improve or leave -/// the number of nodes unchanged. +/// strictly reduces the gate count. The default is gamma = 1.0001 (as set in +/// the Quartz paper) and the number of CX gates. This essentially allows +/// rewrites that improve or leave the number of CX unchanged. #[derive(Debug, Clone)] -pub struct ExhaustiveRewriteStrategy { +pub struct ExhaustiveRewriteStrategy<P> { /// The gamma parameter. pub gamma: f64, + /// Ops to count for cost function. + pub op_predicate: P, +} + +impl<P> ExhaustiveRewriteStrategy<P> { + /// New exhaustive rewrite strategy with provided predicate. + /// + /// The gamma parameter is set to the default 1.0001. + pub fn with_predicate(op_predicate: P) -> Self { + Self { + gamma: 1.0001, + op_predicate, + } + } + + /// New exhaustive rewrite strategy with provided gamma and predicate. + pub fn new(gamma: f64, op_predicate: P) -> Self { + Self { + gamma, + op_predicate, + } + } } -impl Default for ExhaustiveRewriteStrategy { - fn default() -> Self { - Self { gamma: 1.0001 } +impl ExhaustiveRewriteStrategy<fn(&OpType) -> bool> { + /// Exhaustive rewrite strategy with CX count cost function. + /// + /// The gamma parameter is set to the default 1.0001. This is a good default + /// choice for NISQ-y circuits, where CX gates are the most expensive. + pub fn exhaustive_cx() -> Self { + ExhaustiveRewriteStrategy::with_predicate(is_cx) + } + + /// Exhaustive rewrite strategy with CX count cost function and provided gamma. + pub fn exhaustive_cx_with_gamma(gamma: f64) -> Self { + ExhaustiveRewriteStrategy::new(gamma, is_cx) } } -impl RewriteStrategy for ExhaustiveRewriteStrategy { +impl<P: Fn(&OpType) -> bool> RewriteStrategy for ExhaustiveRewriteStrategy<P> { #[tracing::instrument(skip_all)] fn apply_rewrites( &self, @@ -109,8 +144,8 @@ impl RewriteStrategy for ExhaustiveRewriteStrategy { rewrites .into_iter() .filter(|rw| { - let old_count = rw.subcircuit().node_count() as f64; - let new_count = rw.replacement().num_gates() as f64; + let old_count = pre_rewrite_cost(rw, circ, &self.op_predicate) as f64; + let new_count = post_rewrite_cost(rw, circ, &self.op_predicate) as f64; new_count < old_count * self.gamma }) .map(|rw| { @@ -122,6 +157,32 @@ impl RewriteStrategy for ExhaustiveRewriteStrategy { } } +fn is_cx(op: &OpType) -> bool { + op_matches(op, T2Op::CX) +} + +fn cost( + nodes: impl IntoIterator<Item = Node>, + circ: &Hugr, + pred: impl Fn(&OpType) -> bool, +) -> usize { + nodes + .into_iter() + .filter(|n| { + let op = circ.get_optype(*n); + pred(op) + }) + .count() +} + +fn pre_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize { + cost(rw.subcircuit().nodes().iter().copied(), circ, pred) +} + +fn post_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize { + cost(rw.replacement().nodes(), circ, pred) +} + #[cfg(test)] mod tests { use super::*; @@ -197,7 +258,7 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = ExhaustiveRewriteStrategy::default(); + let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 6, 9]); let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect(); @@ -219,7 +280,7 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = ExhaustiveRewriteStrategy { gamma: 10. }; + let strategy = ExhaustiveRewriteStrategy::exhaustive_cx_with_gamma(10.); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]); let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect(); diff --git a/taso-optimiser/src/main.rs b/taso-optimiser/src/main.rs index 95368318..af8b1a5d 100644 --- a/taso-optimiser/src/main.rs +++ b/taso-optimiser/src/main.rs @@ -10,11 +10,9 @@ use std::{fs, path::Path}; use clap::Parser; use hugr::Hugr; +use tket2::json::{load_tk1_json_file, TKETDecode}; use tket2::optimiser::taso::log::TasoLogger; -use tket2::{ - json::{load_tk1_json_file, TKETDecode}, - optimiser::TasoOptimiser, -}; +use tket2::optimiser::TasoOptimiser; use tket_json_rs::circuit_json::SerialCircuit; #[cfg(feature = "peak_alloc")] From 16824ac028de4e2fefc735f320a72d522e40f803 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah <seyon.sivarajah@cambridgequantum.com> Date: Wed, 27 Sep 2023 11:16:32 +0100 Subject: [PATCH 2/3] feat!: depth optimisation via commutation pass (#74) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Forms slices of topologically sorted commands, then greedily moves them forwards if a command can moved to an earlier slice through commutation. Can be used with pytket through pyo3 binding, does not change the operations and respects device connectivity so can be run as a final pass after all existing pytket optimisations. I benchmarked on all the circuits in [software_benchmarking](https://github.com/CQCL-DEV/software_benchmarking), metric is *CZ depth* so ~1100 circuits total 248 qiskit wins 752 pytket wins + tket-2 does not improve on pytket 106 tket-2 improves on pytket twice tket-2 + pytket outperforms qiskit even though qiskit was outperforming pytket *for total depth*: 111 qiskit wins 834 pytket wins + tket-2 does not improve on pytket 161 tket-2 improves on pytket BREAKING CHANGE: LinearUnit is now a newtype struct --------- Co-authored-by: AgustÃn Borgna <121866228+aborgna-q@users.noreply.github.com> Co-authored-by: Luca Mondada <72734770+lmondada@users.noreply.github.com> Co-authored-by: Luca Mondada <luca@mondada.net> --- pyrs/src/lib.rs | 26 +- pyrs/test/test_bindings.py | 22 + src/circuit/command.rs | 4 +- src/circuit/units.rs | 34 +- src/circuit/units/filter.rs | 4 +- src/lib.rs | 1 + src/{_passes.rs => passes.rs} | 7 +- src/{_passes => passes}/_classical.rs | 0 src/{_passes => passes}/_multi_search.rs | 0 src/{_passes => passes}/_redundancy.rs | 0 src/{_passes => passes}/_squash.rs | 0 src/passes/commutation.rs | 629 +++++++++++++++++++++++ 12 files changed, 717 insertions(+), 10 deletions(-) rename src/{_passes.rs => passes.rs} (97%) rename src/{_passes => passes}/_classical.rs (100%) rename src/{_passes => passes}/_multi_search.rs (100%) rename src/{_passes => passes}/_redundancy.rs (100%) rename src/{_passes => passes}/_squash.rs (100%) create mode 100644 src/passes/commutation.rs diff --git a/pyrs/src/lib.rs b/pyrs/src/lib.rs index 8d2ad205..146c36f7 100644 --- a/pyrs/src/lib.rs +++ b/pyrs/src/lib.rs @@ -1,15 +1,27 @@ //! Python bindings for TKET2. #![warn(missing_docs)] +use circuit::try_with_hugr; +use pyo3::prelude::*; +use tket2::{json::TKETDecode, passes::apply_greedy_commutation}; +use tket_json_rs::circuit_json::SerialCircuit; mod circuit; -use pyo3::prelude::*; +#[pyfunction] +fn greedy_depth_reduce(py_c: PyObject) -> PyResult<(PyObject, u32)> { + try_with_hugr(py_c, |mut h| { + let n_moves = apply_greedy_commutation(&mut h)?; + let py_c = SerialCircuit::encode(&h)?.to_tket1()?; + PyResult::Ok((py_c, n_moves)) + }) +} /// The Python bindings to TKET2. #[pymodule] fn pyrs(py: Python, m: &PyModule) -> PyResult<()> { add_circuit_module(py, m)?; add_pattern_module(py, m)?; + add_pass_module(py, m)?; Ok(()) } @@ -54,3 +66,15 @@ fn add_pattern_module(py: Python, parent: &PyModule) -> PyResult<()> { parent.add_submodule(m) } + +fn add_pass_module(py: Python, parent: &PyModule) -> PyResult<()> { + let m = PyModule::new(py, "passes")?; + m.add_function(wrap_pyfunction!(greedy_depth_reduce, m)?)?; + m.add_class::<tket2::T2Op>()?; + m.add( + "PullForwardError", + py.get_type::<tket2::passes::PyPullForwardError>(), + )?; + parent.add_submodule(m)?; + Ok(()) +} diff --git a/pyrs/test/test_bindings.py b/pyrs/test/test_bindings.py index 8a1d6abf..211b0e51 100644 --- a/pyrs/test/test_bindings.py +++ b/pyrs/test/test_bindings.py @@ -1,3 +1,25 @@ +from dataclasses import dataclass +from pyrs.pyrs import passes +from pytket.circuit import Circuit + + +@dataclass +class DepthOptimisePass: + def apply(self, circ: Circuit) -> Circuit: + (circ, n_moves) = passes.greedy_depth_reduce(circ) + return circ + + +def test_depth_optimise(): + c = Circuit(4).CX(0, 2).CX(1, 2).CX(1, 3) + + assert c.depth() == 3 + + c = DepthOptimisePass().apply(c) + + assert c.depth() == 2 + + # from dataclasses import dataclass # from typing import Callable, Iterable # import time diff --git a/src/circuit/command.rs b/src/circuit/command.rs index 094255e1..9bad28ac 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -256,7 +256,7 @@ where // TODO: `with_wires` combinator for `Units`? let wire_unit = circ .linear_units() - .map(|(linear_unit, port, _)| (Wire::new(circ.input(), port), linear_unit)) + .map(|(linear_unit, port, _)| (Wire::new(circ.input(), port), linear_unit.index())) .collect(); let nodes = pv::Topo::new(&circ.as_petgraph()); @@ -311,7 +311,7 @@ where // Update the map tracking the linear units let new_wire = Wire::new(node, port); self.wire_unit.insert(new_wire, linear_id); - linear_id + LinearUnit::new(linear_id) }) .collect(); diff --git a/src/circuit/units.rs b/src/circuit/units.rs index cd09ab55..07ead8c5 100644 --- a/src/circuit/units.rs +++ b/src/circuit/units.rs @@ -30,8 +30,36 @@ use super::Circuit; /// A linear unit id, used in [`CircuitUnit::Linear`]. // TODO: Add this to hugr? -pub type LinearUnit = usize; +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct LinearUnit(usize); +impl LinearUnit { + /// Creates a new [`LinearUnit`]. + pub fn new(index: usize) -> Self { + Self(index) + } + /// Returns the index of this [`LinearUnit`]. + pub fn index(&self) -> usize { + self.0 + } +} + +impl From<LinearUnit> for CircuitUnit { + fn from(lu: LinearUnit) -> Self { + CircuitUnit::Linear(lu.index()) + } +} + +impl TryFrom<CircuitUnit> for LinearUnit { + type Error = (); + + fn try_from(cu: CircuitUnit) -> Result<Self, Self::Error> { + match cu { + CircuitUnit::Wire(_) => Err(()), + CircuitUnit::Linear(i) => Ok(LinearUnit(i)), + } + } +} /// An iterator over the units in the input or output boundary of a [Node]. #[derive(Clone, Debug)] pub struct Units<UL = DefaultUnitLabeller> { @@ -134,7 +162,7 @@ where let linear_unit = self.unit_labeller .assign_linear(self.node, port, self.linear_count - 1); - CircuitUnit::Linear(linear_unit) + CircuitUnit::Linear(linear_unit.index()) } else { let wire = self.unit_labeller.assign_wire(self.node, port)?; CircuitUnit::Wire(wire) @@ -206,7 +234,7 @@ pub struct DefaultUnitLabeller; impl UnitLabeller for DefaultUnitLabeller { #[inline] fn assign_linear(&self, _: Node, _: Port, linear_count: usize) -> LinearUnit { - linear_count + LinearUnit(linear_count) } #[inline] diff --git a/src/circuit/units/filter.rs b/src/circuit/units/filter.rs index 04d6a6a6..559bb5d5 100644 --- a/src/circuit/units/filter.rs +++ b/src/circuit/units/filter.rs @@ -41,7 +41,7 @@ impl UnitFilter for Linear { fn accept(item: (CircuitUnit, Port, Type)) -> Option<Self::Item> { match item { - (CircuitUnit::Linear(unit), port, typ) => Some((unit, port, typ)), + (CircuitUnit::Linear(unit), port, typ) => Some((LinearUnit::new(unit), port, typ)), _ => None, } } @@ -53,7 +53,7 @@ impl UnitFilter for Qubits { fn accept(item: (CircuitUnit, Port, Type)) -> Option<Self::Item> { match item { (CircuitUnit::Linear(unit), port, typ) if typ == prelude::QB_T => { - Some((unit, port, typ)) + Some((LinearUnit::new(unit), port, typ)) } _ => None, } diff --git a/src/lib.rs b/src/lib.rs index 0a4eee4d..3e89fda1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ pub mod extension; pub mod json; pub(crate) mod ops; pub mod optimiser; +pub mod passes; pub mod rewrite; pub use ops::{symbolic_constant_op, Pauli, T2Op}; diff --git a/src/_passes.rs b/src/passes.rs similarity index 97% rename from src/_passes.rs rename to src/passes.rs index a1815187..b2339f09 100644 --- a/src/_passes.rs +++ b/src/passes.rs @@ -4,8 +4,11 @@ // pub mod redundancy; // pub mod pattern; // pub mod squash; -#[cfg(feature = "portmatching")] -pub mod taso; +mod commutation; +pub use commutation::apply_greedy_commutation; +#[cfg(feature = "pyo3")] +pub use commutation::PyPullForwardError; + // use rayon::prelude::*; // use crate::circuit::{ diff --git a/src/_passes/_classical.rs b/src/passes/_classical.rs similarity index 100% rename from src/_passes/_classical.rs rename to src/passes/_classical.rs diff --git a/src/_passes/_multi_search.rs b/src/passes/_multi_search.rs similarity index 100% rename from src/_passes/_multi_search.rs rename to src/passes/_multi_search.rs diff --git a/src/_passes/_redundancy.rs b/src/passes/_redundancy.rs similarity index 100% rename from src/_passes/_redundancy.rs rename to src/passes/_redundancy.rs diff --git a/src/_passes/_squash.rs b/src/passes/_squash.rs similarity index 100% rename from src/_passes/_squash.rs rename to src/passes/_squash.rs diff --git a/src/passes/commutation.rs b/src/passes/commutation.rs new file mode 100644 index 00000000..1d56d400 --- /dev/null +++ b/src/passes/commutation.rs @@ -0,0 +1,629 @@ +use std::{collections::HashMap, rc::Rc}; + +use hugr::{ + hugr::{hugrmut::HugrMut, CircuitUnit, HugrError, PortIndex, Rewrite}, + Direction, Hugr, HugrView, Node, Port, +}; +use itertools::Itertools; +use portgraph::PortOffset; + +#[cfg(feature = "pyo3")] +use pyo3::{create_exception, exceptions::PyException, PyErr}; + +use crate::{ + circuit::{command::Command, units::filter::Qubits, Circuit}, + ops::{Pauli, T2Op}, +}; + +use thiserror::Error; + +type Qb = crate::circuit::units::LinearUnit; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +// remove once https://github.com/CQCL-DEV/tket2/issues/126 is resolved +struct ComCommand { + /// The operation node. + node: Node, + /// An assignment of linear units to the node's ports. + // + // We'll need something more complex if `follow_linear_port` stops being a + // direct map from input to output. + inputs: Vec<CircuitUnit>, +} + +impl<'c, Circ> From<Command<'c, Circ>> for ComCommand +where + Circ: HugrView, +{ + fn from(com: Command<'c, Circ>) -> Self { + ComCommand { + node: com.node(), + inputs: com.inputs().map(|(c, _, _)| c).collect(), + } + } +} +impl ComCommand { + fn node(&self) -> Node { + self.node + } + fn qubits(&self) -> impl Iterator<Item = Qb> + '_ { + self.inputs.iter().filter_map(|u| { + let CircuitUnit::Linear(i) = u else { + return None; + }; + Some(Qb::new(*i)) + }) + } + fn port_of_qb(&self, qb: Qb, direction: Direction) -> Option<Port> { + self.inputs + .iter() + .position(|cu| { + let q_cu: CircuitUnit = qb.into(); + cu == &q_cu + }) + .map(|i| PortOffset::new(direction, i).into()) + } +} + +type Slice = Vec<Option<Rc<ComCommand>>>; +type SliceVec = Vec<Slice>; + +fn add_to_slice(slice: &mut Slice, com: Rc<ComCommand>) { + for q in com.qubits() { + slice[q.index()] = Some(com.clone()); + } +} + +fn load_slices(circ: &impl Circuit) -> SliceVec { + let mut slices = vec![]; + + let n_qbs = circ.units().filter_units::<Qubits>().count(); + let mut qubit_free_slice = vec![0; n_qbs]; + + for command in circ.commands().filter(|c| is_slice_op(circ, c.node())) { + let command: ComCommand = command.into(); + let free_slice = command + .qubits() + .map(|qb| qubit_free_slice[qb.index()]) + .max() + .unwrap(); + + for q in command.qubits() { + qubit_free_slice[q.index()] = free_slice + 1; + } + if free_slice >= slices.len() { + debug_assert!(free_slice == slices.len()); + slices.push(vec![None; n_qbs]); + } + let command = Rc::new(command); + add_to_slice(&mut slices[free_slice], command); + } + + slices +} + +/// check if node is one we want to put in to a slice. +fn is_slice_op(h: &impl HugrView, node: Node) -> bool { + let op: Result<T2Op, _> = h.get_optype(node).clone().try_into(); + op.is_ok() +} + +/// Starting from starting_index, work back along slices to check for the +/// earliest slice that can accommodate this command, if any. +fn available_slice( + circ: &impl HugrView, + slice_vec: &[Slice], + starting_index: usize, + command: &Rc<ComCommand>, +) -> Option<(usize, HashMap<Qb, Rc<ComCommand>>)> { + let mut available = None; + let mut prev_nodes: HashMap<Qb, Rc<ComCommand>> = HashMap::new(); + + for slice_index in (0..starting_index + 1).rev() { + // if all qubit slots are empty here the command can be moved here + if command + .qubits() + .all(|q| slice_vec[slice_index][q.index()].is_none()) + { + available = Some((slice_index, prev_nodes.clone())); + } else if slice_index == 0 { + break; + } else { + // if command commutes with all ports here it can be moved past, + // otherwise stop + if let Some(new_prev_nodes) = commutes_at_slice(command, &slice_vec[slice_index], circ) + { + prev_nodes.extend(new_prev_nodes); + } else { + break; + } + } + } + + available +} + +// If a command commutes back through this slice return a map from the qubits of +// the command to the commands in this slice acting on those qubits. +fn commutes_at_slice( + command: &Rc<ComCommand>, + slice: &Slice, + circ: &impl HugrView, +) -> Option<HashMap<Qb, Rc<ComCommand>>> { + // map from qubit to node it is connected to immediately after the free slice. + let mut prev_nodes: HashMap<Qb, Rc<ComCommand>> = + HashMap::from_iter(command.qubits().map(|q| (q, command.clone()))); + + for q in command.qubits() { + // if slot is empty, continue checking. + let Some(other_com) = &slice[q.index()] else { + continue; + }; + + let port = command.port_of_qb(q, Direction::Incoming)?; + + let op: T2Op = circ.get_optype(command.node()).clone().try_into().ok()?; + // TODO: if not t2op, might still have serialized commutation data we + // can use. + let pauli = commutation_on_port(&op.qubit_commutation(), port)?; + + let other_op: T2Op = circ.get_optype(other_com.node()).clone().try_into().ok()?; + let other_pauli = commutation_on_port( + &other_op.qubit_commutation(), + other_com.port_of_qb(q, Direction::Outgoing)?, + )?; + + if pauli.commutes_with(other_pauli) { + prev_nodes.insert(q, other_com.clone()); + } else { + return None; + } + } + + Some(prev_nodes) +} + +fn commutation_on_port(comms: &[(usize, Pauli)], port: Port) -> Option<Pauli> { + comms + .iter() + .find_map(|(i, p)| (*i == port.index()).then_some(*p)) +} + +/// Error from a [`PullForward`] operation. +#[derive(Debug, Clone, Error, PartialEq, Eq)] +#[allow(missing_docs)] +pub enum PullForwardError { + // Error in hugr mutation. + #[error("Hugr mutation error: {0:?}")] + HugrError(#[from] HugrError), + + #[error("Qubit {0} not found in command.")] + NoQbInCommand(usize), + + #[error("No subsequent command found for qubit {0}")] + NoCommandForQb(usize), +} + +#[cfg(feature = "pyo3")] +create_exception!( + pyrs, + PyPullForwardError, + PyException, + "Error in applying PullForward rewrite." +); + +#[cfg(feature = "pyo3")] +impl From<PullForwardError> for PyErr { + fn from(err: PullForwardError) -> Self { + PyPullForwardError::new_err(err.to_string()) + } +} +struct PullForward { + command: Rc<ComCommand>, + new_nexts: HashMap<Qb, Rc<ComCommand>>, +} + +impl Rewrite for PullForward { + type Error = PullForwardError; + + type ApplyResult = (); + + const UNCHANGED_ON_FAILURE: bool = false; + + fn verify(&self, _h: &impl HugrView) -> Result<(), Self::Error> { + unimplemented!() + } + + fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> { + let Self { command, new_nexts } = self; + + let qb_port = |command: &ComCommand, qb, direction| { + command + .port_of_qb(qb, direction) + .ok_or(PullForwardError::NoQbInCommand(qb.index())) + }; + // for each qubit, disconnect node and reconnect at destination. + for qb in command.qubits() { + let out_port = qb_port(&command, qb, Direction::Outgoing)?; + let in_port = qb_port(&command, qb, Direction::Incoming)?; + + let (src, src_port) = h + .linked_ports(command.node(), in_port) + .exactly_one() + .ok() // PortLinks don't implement Debug + .unwrap(); + let (dst, dst_port) = h + .linked_ports(command.node(), out_port) + .exactly_one() + .ok() + .unwrap(); + + let Some(new_neighbour_com) = new_nexts.get(&qb) else { + return Err(PullForwardError::NoCommandForQb(qb.index())); + }; + if new_neighbour_com == &command { + // do not need to commute along this qubit. + continue; + } + h.disconnect(command.node(), in_port)?; + h.disconnect(command.node(), out_port)?; + // connect old source and destination - identity operation. + h.connect(src, src_port.index(), dst, dst_port.index())?; + + let new_dst_port = qb_port(new_neighbour_com, qb, Direction::Incoming)?; + let (new_src, new_src_port) = h + .linked_ports(new_neighbour_com.node(), new_dst_port) + .exactly_one() + .ok() + .unwrap(); + // disconnect link which we will insert in to. + h.disconnect(new_neighbour_com.node(), new_dst_port)?; + + h.connect( + new_src, + new_src_port.index(), + command.node(), + in_port.index(), + )?; + h.connect( + command.node(), + out_port.index(), + new_neighbour_com.node(), + new_dst_port.index(), + )?; + } + Ok(()) + } +} + +/// Pass which greedily commutes operations forwards in order to reduce depth. +pub fn apply_greedy_commutation(circ: &mut Hugr) -> Result<u32, PullForwardError> { + let mut count = 0; + let mut slice_vec = load_slices(circ); + + for slice_index in 0..slice_vec.len() { + let slice_commands: Vec<_> = slice_vec[slice_index] + .iter() + .flatten() + .unique() + .cloned() + .collect(); + + for command in slice_commands { + if let Some((destination, new_nexts)) = + available_slice(&circ, &slice_vec, slice_index, &command) + { + debug_assert!( + destination < slice_index, + "Avoid mutating slices we haven't got to yet." + ); + for q in command.qubits() { + let com = slice_vec[slice_index][q.index()].take(); + slice_vec[destination][q.index()] = com; + } + let rewrite = PullForward { command, new_nexts }; + circ.apply_rewrite(rewrite)?; + count += 1; + } + } + } + + // TODO remove empty slices and return + // and full slices at start? + Ok(count) +} + +#[cfg(test)] +mod test { + + use crate::{extension::REGISTRY, ops::test::t2_bell_circuit, utils::build_simple_circuit}; + use hugr::{ + builder::{DFGBuilder, Dataflow, DataflowHugr}, + extension::prelude::{BOOL_T, QB_T}, + std_extensions::arithmetic::float_types::FLOAT64_TYPE, + type_row, + types::FunctionType, + Hugr, + }; + use itertools::Itertools; + use rstest::{fixture, rstest}; + + use super::*; + + #[fixture] + // example circuit from original task + fn example_cx() -> Hugr { + build_simple_circuit(4, |circ| { + circ.append(T2Op::CX, [0, 2])?; + circ.append(T2Op::CX, [1, 2])?; + circ.append(T2Op::CX, [1, 3])?; + Ok(()) + }) + .unwrap() + } + + #[fixture] + // example circuit from original task with lower depth + fn example_cx_better() -> Hugr { + build_simple_circuit(4, |circ| { + circ.append(T2Op::CX, [0, 2])?; + circ.append(T2Op::CX, [1, 3])?; + circ.append(T2Op::CX, [1, 2])?; + Ok(()) + }) + .unwrap() + } + + #[fixture] + // can't commute anything here + fn cant_commute() -> Hugr { + build_simple_circuit(4, |circ| { + circ.append(T2Op::Z, [1])?; + circ.append(T2Op::CX, [0, 1])?; + circ.append(T2Op::CX, [2, 1])?; + Ok(()) + }) + .unwrap() + } + + #[fixture] + fn big_example() -> Hugr { + build_simple_circuit(4, |circ| { + circ.append(T2Op::CX, [0, 3])?; + circ.append(T2Op::CX, [1, 2])?; + circ.append(T2Op::H, [0])?; + circ.append(T2Op::H, [3])?; + circ.append(T2Op::CX, [0, 1])?; + circ.append(T2Op::CX, [2, 3])?; + circ.append(T2Op::CX, [0, 1])?; + circ.append(T2Op::CX, [2, 3])?; + circ.append(T2Op::CX, [2, 1])?; + circ.append(T2Op::H, [1])?; + Ok(()) + }) + .unwrap() + } + + #[fixture] + // commute a single qubit gate + fn single_qb_commute() -> Hugr { + build_simple_circuit(3, |circ| { + circ.append(T2Op::H, [1])?; + circ.append(T2Op::CX, [0, 1])?; + circ.append(T2Op::Z, [0])?; + Ok(()) + }) + .unwrap() + } + #[fixture] + + // commute 2 single qubit gates + fn single_qb_commute_2() -> Hugr { + build_simple_circuit(4, |circ| { + circ.append(T2Op::CX, [1, 2])?; + circ.append(T2Op::CX, [1, 0])?; + circ.append(T2Op::CX, [3, 2])?; + circ.append(T2Op::X, [0])?; + circ.append(T2Op::Z, [3])?; + Ok(()) + }) + .unwrap() + } + + #[fixture] + // A commutation forward exists but depth doesn't change + fn commutes_but_same_depth() -> Hugr { + build_simple_circuit(3, |circ| { + circ.append(T2Op::H, [1])?; + circ.append(T2Op::CX, [0, 1])?; + circ.append(T2Op::Z, [0])?; + circ.append(T2Op::X, [1])?; + Ok(()) + }) + .unwrap() + } + + #[fixture] + // Gate being commuted has a non-linear input + fn non_linear_inputs() -> Hugr { + let build = || { + let mut dfg = DFGBuilder::new(FunctionType::new( + type_row![QB_T, QB_T, FLOAT64_TYPE], + type_row![QB_T, QB_T], + ))?; + + let [q0, q1, f] = dfg.input_wires_arr(); + + let mut circ = dfg.as_circuit(vec![q0, q1]); + + circ.append(T2Op::H, [1])?; + circ.append(T2Op::CX, [0, 1])?; + circ.append_and_consume(T2Op::RzF64, [CircuitUnit::Linear(0), CircuitUnit::Wire(f)])?; + let qbs = circ.finish(); + dfg.finish_hugr_with_outputs(qbs, ®ISTRY) + }; + build().unwrap() + } + + #[fixture] + // Gates being commuted have non-linear outputs + fn non_linear_outputs() -> Hugr { + let build = || { + let mut dfg = DFGBuilder::new(FunctionType::new( + type_row![QB_T, QB_T], + type_row![QB_T, QB_T, BOOL_T], + ))?; + + let [q0, q1] = dfg.input_wires_arr(); + + let mut circ = dfg.as_circuit(vec![q0, q1]); + + circ.append(T2Op::H, [1])?; + circ.append(T2Op::CX, [0, 1])?; + let measured = circ.append_with_outputs(T2Op::Measure, [0])?; + let mut outs = circ.finish(); + outs.extend(measured); + dfg.finish_hugr_with_outputs(outs, ®ISTRY) + }; + build().unwrap() + } + + fn slice_from_command( + commands: &[ComCommand], + n_qbs: usize, + slice_arr: &[&[usize]], + ) -> SliceVec { + slice_arr + .iter() + .map(|command_indices| { + let mut slice = vec![None; n_qbs]; + for ind in command_indices.iter() { + let com = commands[*ind].clone(); + add_to_slice(&mut slice, Rc::new(com)) + } + + slice + }) + .collect() + } + + #[rstest] + fn test_load_slices_cx(example_cx: Hugr) { + let circ = example_cx; + let commands: Vec<ComCommand> = circ.commands().map_into().collect(); + let slices = load_slices(&circ); + let correct = slice_from_command(&commands, 4, &[&[0], &[1], &[2]]); + + assert_eq!(slices, correct); + } + + #[rstest] + fn test_load_slices_cx_better(example_cx_better: Hugr) { + let circ = example_cx_better; + let commands: Vec<ComCommand> = circ.commands().map_into().collect(); + + let slices = load_slices(&circ); + let correct = slice_from_command(&commands, 4, &[&[0, 1], &[2]]); + + assert_eq!(slices, correct); + } + + #[rstest] + fn test_load_slices_bell(t2_bell_circuit: Hugr) { + let circ = t2_bell_circuit; + let commands: Vec<ComCommand> = circ.commands().map_into().collect(); + + let slices = load_slices(&circ); + let correct = slice_from_command(&commands, 2, &[&[0], &[1]]); + + assert_eq!(slices, correct); + } + + #[rstest] + fn test_available_slice(example_cx: Hugr) { + let circ = example_cx; + let slices = load_slices(&circ); + let (found, prev_nodes) = + available_slice(&circ, &slices, 1, slices[2][1].as_ref().unwrap()).unwrap(); + assert_eq!(found, 0); + + assert_eq!( + *prev_nodes.get(&Qb::new(1)).unwrap(), + slices[1][1].as_ref().unwrap().clone() + ); + + assert_eq!( + *prev_nodes.get(&Qb::new(3)).unwrap(), + slices[2][3].as_ref().unwrap().clone() + ); + } + + #[rstest] + fn big_test(big_example: Hugr) { + let circ = big_example; + let slices = load_slices(&circ); + assert_eq!(slices.len(), 6); + // can commute final cx to front + let (found, prev_nodes) = + available_slice(&circ, &slices, 3, slices[4][1].as_ref().unwrap()).unwrap(); + assert_eq!(found, 1); + assert_eq!( + *prev_nodes.get(&Qb::new(1)).unwrap(), + slices[2][1].as_ref().unwrap().clone() + ); + + assert_eq!( + *prev_nodes.get(&Qb::new(2)).unwrap(), + slices[2][2].as_ref().unwrap().clone() + ); + // hadamard can't commute past anything + assert!(available_slice(&circ, &slices, 4, slices[5][1].as_ref().unwrap()).is_none()); + } + + /// Calculate depth by placing commands in slices. + fn depth(h: &Hugr) -> usize { + load_slices(h).len() + } + #[rstest] + #[case(example_cx(), true, 1)] + #[case(example_cx_better(), false, 0)] + #[case(big_example(), true, 1)] + #[case(cant_commute(), false, 0)] + #[case(t2_bell_circuit(), false, 0)] + #[case(single_qb_commute(), true, 1)] + #[case(single_qb_commute_2(), true, 2)] + #[case(commutes_but_same_depth(), false, 1)] + #[case(non_linear_inputs(), true, 1)] + #[case(non_linear_outputs(), true, 1)] + fn commutation_example( + #[case] mut case: Hugr, + #[case] should_reduce: bool, + #[case] expected_moves: u32, + ) { + let node_count = case.node_count(); + let depth_before = depth(&case); + let move_count = apply_greedy_commutation(&mut case).unwrap(); + case.infer_and_validate(®ISTRY).unwrap(); + + assert_eq!( + move_count, expected_moves, + "Number of commutations did not match expected." + ); + let depth_after = depth(&case); + + if should_reduce { + assert!(depth_after < depth_before, "Depth should have decreased.."); + } else { + assert_eq!( + depth_before, depth_after, + "Depth should not have changed for this case." + ); + } + + assert_eq!( + case.node_count(), + node_count, + "depth optimisation should not change the number of nodes." + ) + } +} From be8b9a9e3491cad2b5a6accfc54cdce37c9e64fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Wed, 27 Sep 2023 12:59:59 +0200 Subject: [PATCH 3/3] chore: move unused code out of `passes.rs` (#136) Moves out the commented code. We could just drop it, but I'd leave that as an issue to analyse what can be salvaged from all the `_*` unused files. --- src/passes.rs | 150 ------------------------------------------- src/passes/_apply.rs | 138 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 150 deletions(-) create mode 100644 src/passes/_apply.rs diff --git a/src/passes.rs b/src/passes.rs index b2339f09..68ac2a44 100644 --- a/src/passes.rs +++ b/src/passes.rs @@ -1,156 +1,6 @@ //! Optimisation passes for circuits. -// pub mod classical; -// pub mod redundancy; -// pub mod pattern; -// pub mod squash; mod commutation; pub use commutation::apply_greedy_commutation; #[cfg(feature = "pyo3")] pub use commutation::PyPullForwardError; - -// use rayon::prelude::*; - -// use crate::circuit::{ -// circuit::{Circuit, CircuitRewrite}, -// dag::{EdgeProperties, VertexProperties}, -// operation::{Op, Param}, -// }; - -// use self::pattern::{FixedStructPattern, Match, NodeCompClosure, PatternMatcher}; -// use portgraph::{ -// graph::{NodeIndex, DIRECTIONS}, -// substitute::{BoundedSubgraph, RewriteError, SubgraphRef}, -// }; - -// pub trait RewriteGenerator<'s, T: Iterator<Item = CircuitRewrite> + 's> { -// fn rewrites<'a: 's>(&'s self, base_circ: &'a Circuit) -> T; -// fn into_rewrites(self, base_circ: &'s Circuit) -> T; -// } - -// /// Repeatedly apply all available rewrites reported by finder closure until no more are found. -// /// -// /// # Errors -// /// -// /// This function will return an error if rewrite application fails. -// pub fn apply_exhaustive<F>(mut circ: Circuit, finder: F) -> Result<(Circuit, bool), RewriteError> -// where -// F: Fn(&Circuit) -> Vec<CircuitRewrite>, -// { -// let mut success = false; -// loop { -// // assuming all the returned rewrites are non-overlapping -// // or filter to make them non-overlapping -// // then in theory, they can all be applied in parallel -// let rewrites = finder(&circ); -// if rewrites.is_empty() { -// break; -// } -// success = true; -// for rewrite in rewrites { -// circ.apply_rewrite(rewrite)?; -// } -// } - -// Ok((circ, success)) -// } - -// /// Repeatedly apply first reported rewrite -// /// -// /// # Errors -// /// -// /// This function will return an error if rewrite application fails. -// pub fn apply_greedy<F>(mut circ: Circuit, finder: F) -> Result<(Circuit, bool), RewriteError> -// where -// F: Fn(&Circuit) -> Option<CircuitRewrite>, -// { -// let mut success = false; -// while let Some(rewrite) = finder(&circ) { -// success |= true; -// circ.apply_rewrite(rewrite)?; -// } - -// Ok((circ, success)) -// } - -// pub type CircFixedStructPattern<F> = FixedStructPattern<VertexProperties, EdgeProperties, F>; - -// impl<F> CircFixedStructPattern<F> { -// pub fn from_circ(pattern_circ: Circuit, node_comp_closure: F) -> Self { -// Self { -// boundary: pattern_circ.boundary(), -// graph: pattern_circ.dag, -// node_comp_closure, -// } -// } -// } - -// pub struct PatternRewriter<F, G> { -// pattern: CircFixedStructPattern<F>, -// rewrite_closure: G, -// } - -// impl<F, G> PatternRewriter<F, G> { -// pub fn new(pattern: CircFixedStructPattern<F>, rewrite_closure: G) -> Self { -// Self { -// pattern, -// rewrite_closure, -// } -// } -// } - -// impl<'s, 'f: 's, F, G> RewriteGenerator<'s, CircRewriteIter<'s, F, G>> for PatternRewriter<F, G> -// where -// F: NodeCompClosure<VertexProperties, EdgeProperties> + Clone + Send + Sync + 'f, -// G: Fn(Match) -> (Circuit, Param) + 's + Clone, -// { -// fn into_rewrites(self, base_circ: &'s Circuit) -> CircRewriteIter<'s, F, G> { -// let ports = pattern_ports(&self.pattern); -// let matcher = PatternMatcher::new(self.pattern, base_circ.dag_ref()); - -// RewriteIter { -// match_iter: matcher.into_iter(), -// ports, -// rewrite_closure: self.rewrite_closure, -// circ: base_circ, -// } -// } - -// fn rewrites<'a: 's>(&'s self, base_circ: &'a Circuit) -> CircRewriteIter<'s, F, G> { -// let ports = pattern_ports(&self.pattern); -// let matcher = PatternMatcher::new(self.pattern.clone(), base_circ.dag_ref()); - -// RewriteIter { -// match_iter: matcher.into_iter(), -// ports, -// rewrite_closure: self.rewrite_closure.clone(), -// circ: base_circ, -// } -// } -// } - -// pub type CircRewriteIter<'a, F, G> = RewriteIter<'a, VertexProperties, EdgeProperties, F, G>; - -// pub fn decompose_custom(circ: &Circuit) -> impl Iterator<Item = CircuitRewrite> + '_ { -// circ.dag.node_indices().filter_map(|n| { -// let op = &circ.dag.node_weight(n).unwrap().op; -// if let Op::Custom(x) = op { -// Some(CircuitRewrite::new( -// BoundedSubgraph::from_node(&circ.dag, n), -// x.to_circuit().expect("Circuit generation failed.").into(), -// 0.0, -// )) -// } else { -// None -// } -// }) -// } - -// #[cfg(feature = "pyo3")] -// use pyo3::prelude::pyfunction; - -// #[cfg_attr(feature = "pyo3", pyfunction)] -// pub fn decompose_custom_pass(circ: Circuit) -> (Circuit, bool) { -// let (circ, suc) = apply_exhaustive(circ, |c| decompose_custom(c).collect()).unwrap(); -// (circ, suc) -// } diff --git a/src/passes/_apply.rs b/src/passes/_apply.rs new file mode 100644 index 00000000..ea87da04 --- /dev/null +++ b/src/passes/_apply.rs @@ -0,0 +1,138 @@ + +use self::pattern::{FixedStructPattern, Match, NodeCompClosure, PatternMatcher}; +use portgraph::{ + graph::{NodeIndex, DIRECTIONS}, + substitute::{BoundedSubgraph, RewriteError, SubgraphRef}, +}; + +pub trait RewriteGenerator<'s, T: Iterator<Item = CircuitRewrite> + 's> { + fn rewrites<'a: 's>(&'s self, base_circ: &'a Circuit) -> T; + fn into_rewrites(self, base_circ: &'s Circuit) -> T; +} + +/// Repeatedly apply all available rewrites reported by finder closure until no more are found. +/// +/// # Errors +/// +/// This function will return an error if rewrite application fails. +pub fn apply_exhaustive<F>(mut circ: Circuit, finder: F) -> Result<(Circuit, bool), RewriteError> +where + F: Fn(&Circuit) -> Vec<CircuitRewrite>, +{ + let mut success = false; + loop { + // assuming all the returned rewrites are non-overlapping + // or filter to make them non-overlapping + // then in theory, they can all be applied in parallel + let rewrites = finder(&circ); + if rewrites.is_empty() { + break; + } + success = true; + for rewrite in rewrites { + circ.apply_rewrite(rewrite)?; + } + } + + Ok((circ, success)) +} + +/// Repeatedly apply first reported rewrite +/// +/// # Errors +/// +/// This function will return an error if rewrite application fails. +pub fn apply_greedy<F>(mut circ: Circuit, finder: F) -> Result<(Circuit, bool), RewriteError> +where + F: Fn(&Circuit) -> Option<CircuitRewrite>, +{ + let mut success = false; + while let Some(rewrite) = finder(&circ) { + success |= true; + circ.apply_rewrite(rewrite)?; + } + + Ok((circ, success)) +} + +pub type CircFixedStructPattern<F> = FixedStructPattern<VertexProperties, EdgeProperties, F>; + +impl<F> CircFixedStructPattern<F> { + pub fn from_circ(pattern_circ: Circuit, node_comp_closure: F) -> Self { + Self { + boundary: pattern_circ.boundary(), + graph: pattern_circ.dag, + node_comp_closure, + } + } +} + +pub struct PatternRewriter<F, G> { + pattern: CircFixedStructPattern<F>, + rewrite_closure: G, +} + +impl<F, G> PatternRewriter<F, G> { + pub fn new(pattern: CircFixedStructPattern<F>, rewrite_closure: G) -> Self { + Self { + pattern, + rewrite_closure, + } + } +} + +impl<'s, 'f: 's, F, G> RewriteGenerator<'s, CircRewriteIter<'s, F, G>> for PatternRewriter<F, G> +where + F: NodeCompClosure<VertexProperties, EdgeProperties> + Clone + Send + Sync + 'f, + G: Fn(Match) -> (Circuit, Param) + 's + Clone, +{ + fn into_rewrites(self, base_circ: &'s Circuit) -> CircRewriteIter<'s, F, G> { + let ports = pattern_ports(&self.pattern); + let matcher = PatternMatcher::new(self.pattern, base_circ.dag_ref()); + + RewriteIter { + match_iter: matcher.into_iter(), + ports, + rewrite_closure: self.rewrite_closure, + circ: base_circ, + } + } + + fn rewrites<'a: 's>(&'s self, base_circ: &'a Circuit) -> CircRewriteIter<'s, F, G> { + let ports = pattern_ports(&self.pattern); + let matcher = PatternMatcher::new(self.pattern.clone(), base_circ.dag_ref()); + + RewriteIter { + match_iter: matcher.into_iter(), + ports, + rewrite_closure: self.rewrite_closure.clone(), + circ: base_circ, + } + } +} + +pub type CircRewriteIter<'a, F, G> = RewriteIter<'a, VertexProperties, EdgeProperties, F, G>; + +pub fn decompose_custom(circ: &Circuit) -> impl Iterator<Item = CircuitRewrite> + '_ { + circ.dag.node_indices().filter_map(|n| { + let op = &circ.dag.node_weight(n).unwrap().op; + if let Op::Custom(x) = op { + Some(CircuitRewrite::new( + BoundedSubgraph::from_node(&circ.dag, n), + x.to_circuit().expect("Circuit generation failed.").into(), + 0.0, + )) + } else { + None + } + }) +} + +#[cfg(feature = "pyo3")] +use pyo3::prelude::pyfunction; + +#[cfg_attr(feature = "pyo3", pyfunction)] +pub fn decompose_custom_pass(circ: Circuit) -> (Circuit, bool) { + let (circ, suc) = apply_exhaustive(circ, |c| decompose_custom(c).collect()).unwrap(); + (circ, suc) +} \ No newline at end of file