diff --git a/Cargo.toml b/Cargo.toml index fa4abbfc..f582ded7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ itertools = { workspace = true } petgraph = { version = "0.6.3", default-features = false } serde_yaml = "0.9.22" # portmatching = { version = "0.2.0", optional = true, features = ["serde"]} -portmatching = { optional = true, git = "https://github.com/lmondada/portmatching", rev = "c4ad0ec", features = [ +portmatching = { optional = true, git = "https://github.com/lmondada/portmatching", rev = "738c91c", features = [ "serde", ] } derive_more = "0.99.17" diff --git a/src/circuit/units.rs b/src/circuit/units.rs index 222705a1..cd09ab55 100644 --- a/src/circuit/units.rs +++ b/src/circuit/units.rs @@ -19,9 +19,11 @@ use std::iter::FusedIterator; use hugr::hugr::CircuitUnit; use hugr::ops::OpTrait; -use hugr::types::{EdgeKind, Type, TypeBound, TypeRow}; +use hugr::types::{EdgeKind, Type, TypeRow}; use hugr::{Direction, Node, Port, Wire}; +use crate::utils::type_is_linear; + use self::filter::UnitFilter; use super::Circuit; @@ -215,7 +217,3 @@ impl UnitLabeller for DefaultUnitLabeller { } } } - -fn type_is_linear(typ: &Type) -> bool { - !TypeBound::Copyable.contains(typ.least_upper_bound()) -} diff --git a/src/portmatching.rs b/src/portmatching.rs index 6cf40f42..423c612c 100644 --- a/src/portmatching.rs +++ b/src/portmatching.rs @@ -5,11 +5,137 @@ pub mod pattern; #[cfg(feature = "pyo3")] pub mod pyo3; +use itertools::Itertools; pub use matcher::{PatternMatch, PatternMatcher}; pub use pattern::CircuitPattern; -use hugr::Port; +use hugr::{ + ops::{OpTag, OpTrait}, + Node, Port, +}; use matcher::MatchOp; +use thiserror::Error; + +use crate::{circuit::Circuit, utils::type_is_linear}; -type PEdge = (Port, Port); type PNode = MatchOp; + +/// An edge property in a circuit pattern. +/// +/// Edges are +/// Edges are reversible if the edge type is linear. +#[derive( + Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize, +)] +enum PEdge { + /// A "normal" edge between src and dst within a pattern. + InternalEdge { + src: Port, + dst: Port, + is_reversible: bool, + }, + /// An edge from a copied input to src. + /// + /// Edges from inputs are typically not matched as part of the pattern, + /// unless a single input is copied multiple times. In this case, an + /// InputEdge is used to link the source port to the (usually hidden) + /// copy node. + /// + /// Input edges are always irreversible. + InputEdge { src: Port }, +} + +#[derive(Debug, Clone, Error)] +enum InvalidEdgeProperty { + /// The port is linked to multiple edges. + #[error("port {0:?} is linked to multiple edges")] + AmbiguousEdge(Port), + /// The port is not linked to any edge. + #[error("port {0:?} is not linked to any edge")] + NoLinkedEdge(Port), + /// The port does not have a type. + #[error("port {0:?} does not have a type")] + UntypedPort(Port), +} + +impl PEdge { + fn try_from_port( + node: Node, + port: Port, + circ: &impl Circuit, + ) -> Result { + let src = port; + let (dst_node, dst) = circ + .linked_ports(node, src) + .exactly_one() + .map_err(|mut e| { + if e.next().is_some() { + InvalidEdgeProperty::AmbiguousEdge(src) + } else { + InvalidEdgeProperty::NoLinkedEdge(src) + } + })?; + if circ.get_optype(dst_node).tag() == OpTag::Input { + return Ok(Self::InputEdge { src }); + } + let port_type = circ + .get_optype(node) + .signature() + .get(src) + .cloned() + .ok_or(InvalidEdgeProperty::UntypedPort(src))?; + let is_reversible = type_is_linear(&port_type); + Ok(Self::InternalEdge { + src, + dst, + is_reversible, + }) + } +} + +impl portmatching::EdgeProperty for PEdge { + type OffsetID = Port; + + fn reverse(&self) -> Option { + match *self { + Self::InternalEdge { + src, + dst, + is_reversible, + } => is_reversible.then_some(Self::InternalEdge { + src: dst, + dst: src, + is_reversible, + }), + Self::InputEdge { .. } => None, + } + } + + fn offset_id(&self) -> Self::OffsetID { + match *self { + Self::InternalEdge { src, .. } => src, + Self::InputEdge { src, .. } => src, + } + } +} + +/// A node in a pattern. +/// +/// A node is either a real node in the HUGR graph or a hidden copy node +/// that is identified by its node and outgoing port. +/// +/// A NodeID::CopyNode can only be found as a target of a PEdge::InputEdge +/// property. Furthermore, a NodeID::CopyNode never has a node property. +#[derive( + Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize, +)] +pub(super) enum NodeID { + HugrNode(Node), + CopyNode(Node, Port), +} + +impl From for NodeID { + fn from(node: Node) -> Self { + Self::HugrNode(node) + } +} diff --git a/src/portmatching/matcher.rs b/src/portmatching/matcher.rs index 30075593..960ea346 100644 --- a/src/portmatching/matcher.rs +++ b/src/portmatching/matcher.rs @@ -7,13 +7,13 @@ use std::{ path::{Path, PathBuf}, }; -use super::{CircuitPattern, PEdge, PNode}; +use super::{CircuitPattern, NodeID, PEdge, PNode}; use hugr::hugr::views::sibling_subgraph::{ConvexChecker, InvalidReplacement, InvalidSubgraph}; use hugr::{hugr::views::SiblingSubgraph, ops::OpType, Hugr, Node, Port}; use itertools::Itertools; use portmatching::{ automaton::{LineBuilder, ScopeAutomaton}, - PatternID, + EdgeProperty, PatternID, }; use thiserror::Error; @@ -30,6 +30,7 @@ use crate::{ /// Matchable operations in a circuit. /// /// We currently support [`T2Op`] and a the HUGR load constant operation. +// TODO: Support OpType::Const, but blocked by use of F64 (Eq support required) #[derive( Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize, )] @@ -257,11 +258,11 @@ impl PatternMatcher { ) -> Vec { self.automaton .run( - root, + root.into(), // Node weights (none) - validate_weighted_node(circ), + validate_circuit_node(circ), // Check edge exist - validate_unweighted_edge(circ), + validate_circuit_edge(circ), ) .filter_map(|pattern_id| { handle_match_error( @@ -380,28 +381,48 @@ impl From for InvalidPatternMatch { } } -fn compatible_offsets((_, pout): &(Port, Port), (pin, _): &(Port, Port)) -> bool { - pout.direction() != pin.direction() && pout.index() == pin.index() +fn compatible_offsets(e1: &PEdge, e2: &PEdge) -> bool { + let PEdge::InternalEdge { dst: dst1, .. } = e1 else { + return false; + }; + let src2 = e2.offset_id(); + dst1.direction() != src2.direction() && dst1.index() == src2.index() } -/// Check if an edge `e` is valid in a portgraph `g` without weights. -pub(crate) fn validate_unweighted_edge( +/// Returns a predicate checking that an edge at `src` satisfies `prop` in `circ`. +pub(super) fn validate_circuit_edge( circ: &impl Circuit, -) -> impl for<'a> Fn(Node, &'a PEdge) -> Option + '_ { - move |src, &(src_port, tgt_port)| { - let (next_node, _) = circ - .linked_ports(src, src_port) - .find(|&(_, tgt)| tgt == tgt_port)?; - Some(next_node) +) -> impl for<'a> Fn(NodeID, &'a PEdge) -> Option + '_ { + move |src, &prop| { + let NodeID::HugrNode(src) = src else { + return None; + }; + match prop { + PEdge::InternalEdge { + src: src_port, + dst: dst_port, + .. + } => { + let (next_node, next_port) = circ.linked_ports(src, src_port).exactly_one().ok()?; + (dst_port == next_port).then_some(NodeID::HugrNode(next_node)) + } + PEdge::InputEdge { src: src_port } => { + let (next_node, next_port) = circ.linked_ports(src, src_port).exactly_one().ok()?; + Some(NodeID::CopyNode(next_node, next_port)) + } + } } } -/// Check if a node `n` is valid in a weighted portgraph `g`. -pub(crate) fn validate_weighted_node( +/// Returns a predicate checking that `node` satisfies `prop` in `circ`. +pub(crate) fn validate_circuit_node( circ: &impl Circuit, -) -> impl for<'a> Fn(Node, &PNode) -> bool + '_ { - move |v, prop| { - let v_weight = MatchOp::try_from(circ.get_optype(v).clone()); +) -> impl for<'a> Fn(NodeID, &PNode) -> bool + '_ { + move |node, prop| { + let NodeID::HugrNode(node) = node else { + return false; + }; + let v_weight = MatchOp::try_from(circ.get_optype(node).clone()); v_weight.is_ok_and(|w| &w == prop) } } diff --git a/src/portmatching/pattern.rs b/src/portmatching/pattern.rs index cf97d739..60b0ab1c 100644 --- a/src/portmatching/pattern.rs +++ b/src/portmatching/pattern.rs @@ -7,10 +7,10 @@ use std::fmt::Debug; use thiserror::Error; use super::{ - matcher::{validate_unweighted_edge, validate_weighted_node}, + matcher::{validate_circuit_edge, validate_circuit_node}, PEdge, PNode, }; -use crate::circuit::Circuit; +use crate::{circuit::Circuit, portmatching::NodeID}; #[cfg(feature = "pyo3")] use pyo3::{create_exception, exceptions::PyException, pyclass, PyErr}; @@ -19,7 +19,7 @@ use pyo3::{create_exception, exceptions::PyException, pyclass, PyErr}; #[cfg_attr(feature = "pyo3", pyclass)] #[derive(Clone, serde::Serialize, serde::Deserialize)] pub struct CircuitPattern { - pub(super) pattern: Pattern, + pub(super) pattern: Pattern, /// The input ports pub(super) inputs: Vec>, /// The output ports @@ -40,14 +40,21 @@ impl CircuitPattern { let mut pattern = Pattern::new(); for cmd in circuit.commands() { let op = cmd.optype().clone(); - pattern.require(cmd.node(), op.try_into().unwrap()); - for out_offset in 0..cmd.output_count() { - let out_offset = Port::new_outgoing(out_offset); - for (next_node, in_offset) in circuit.linked_ports(cmd.node(), out_offset) { - if circuit.get_optype(next_node).tag() != hugr::ops::OpTag::Output { - pattern.add_edge(cmd.node(), next_node, (out_offset, in_offset)); - } - } + pattern.require(cmd.node().into(), op.try_into().unwrap()); + for in_offset in 0..cmd.input_count() { + let in_offset = Port::new_incoming(in_offset); + let edge_prop = + PEdge::try_from_port(cmd.node(), in_offset, circuit).expect("Invalid HUGR"); + let (prev_node, prev_port) = circuit + .linked_ports(cmd.node(), in_offset) + .exactly_one() + .ok() + .expect("invalid HUGR"); + let prev_node = match edge_prop { + PEdge::InternalEdge { .. } => NodeID::HugrNode(prev_node), + PEdge::InputEdge { .. } => NodeID::CopyNode(prev_node, prev_port), + }; + pattern.add_edge(cmd.node().into(), prev_node, edge_prop); } } pattern.set_any_root()?; @@ -83,15 +90,25 @@ impl CircuitPattern { } /// Compute the map from pattern nodes to circuit nodes in `circ`. - pub fn get_match_map(&self, root: Node, circ: &C) -> Option> { + pub fn get_match_map(&self, root: Node, circ: &impl Circuit) -> Option> { let single_matcher = SinglePatternMatcher::from_pattern(self.pattern.clone()); single_matcher .get_match_map( - root, - validate_weighted_node(circ), - validate_unweighted_edge(circ), + root.into(), + validate_circuit_node(circ), + validate_circuit_edge(circ), ) - .map(|m| m.into_iter().collect()) + .map(|m| { + m.into_iter() + .filter_map(|(node_p, node_c)| match (node_p, node_c) { + (NodeID::HugrNode(node_p), NodeID::HugrNode(node_c)) => { + Some((node_p, node_c)) + } + (NodeID::CopyNode(..), NodeID::CopyNode(..)) => None, + _ => panic!("Invalid match map"), + }) + .collect() + }) } } @@ -136,9 +153,17 @@ impl From for PyErr { #[cfg(test)] mod tests { + + use std::collections::HashSet; + + use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr}; + use hugr::extension::prelude::QB_T; + use hugr::ops::LeafOp; + use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; + use hugr::types::FunctionType; use hugr::Hugr; - use itertools::Itertools; + use crate::extension::REGISTRY; use crate::utils::build_simple_circuit; use crate::T2Op; @@ -153,23 +178,68 @@ mod tests { .unwrap() } + /// A circuit with two rotation gates in sequence, sharing a param + fn circ_with_copy() -> Hugr { + let input_t = vec![QB_T, FLOAT64_TYPE]; + let output_t = vec![QB_T]; + let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); + + let mut inps = h.input_wires(); + let qb = inps.next().unwrap(); + let f = inps.next().unwrap(); + + let res = h.add_dataflow_op(T2Op::RxF64, [qb, f]).unwrap(); + let qb = res.outputs().next().unwrap(); + let res = h.add_dataflow_op(T2Op::RxF64, [qb, f]).unwrap(); + let qb = res.outputs().next().unwrap(); + + h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap() + } + + /// A circuit with two rotation gates in parallel, sharing a param + fn circ_with_copy_disconnected() -> Hugr { + let input_t = vec![QB_T, QB_T, FLOAT64_TYPE]; + let output_t = vec![QB_T, QB_T]; + let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); + + let mut inps = h.input_wires(); + let qb1 = inps.next().unwrap(); + let qb2 = inps.next().unwrap(); + let f = inps.next().unwrap(); + + let res = h.add_dataflow_op(T2Op::RxF64, [qb1, f]).unwrap(); + let qb1 = res.outputs().next().unwrap(); + let res = h.add_dataflow_op(T2Op::RxF64, [qb2, f]).unwrap(); + let qb2 = res.outputs().next().unwrap(); + + h.finish_hugr_with_outputs([qb1, qb2], ®ISTRY).unwrap() + } + #[test] fn construct_pattern() { let hugr = h_cx(); let p = CircuitPattern::try_from_circuit(&hugr).unwrap(); - let edges = p + let edges: HashSet<_> = p .pattern .edges() .unwrap() .iter() .map(|e| (e.source.unwrap(), e.target.unwrap())) - .collect_vec(); + .collect(); + let inp = hugr.input(); + let cx_gate = NodeID::HugrNode(get_nodes_by_t2op(&hugr, T2Op::CX)[0]); + let h_gate = NodeID::HugrNode(get_nodes_by_t2op(&hugr, T2Op::H)[0]); assert_eq!( - // How would I construct hugr::Nodes for testing here? - edges.len(), - 1 + edges, + [ + (cx_gate, h_gate), + (cx_gate, NodeID::CopyNode(inp, Port::new_outgoing(0))), + (cx_gate, NodeID::CopyNode(inp, Port::new_outgoing(1))), + ] + .into_iter() + .collect() ) } @@ -199,4 +269,40 @@ mod tests { InvalidPattern::NotConnected ); } + + fn get_nodes_by_t2op(circ: &impl Circuit, t2_op: T2Op) -> Vec { + circ.nodes() + .filter(|n| { + let Ok(op): Result = circ.get_optype(*n).clone().try_into() else { + return false; + }; + op == t2_op.into() + }) + .collect() + } + + #[test] + fn pattern_with_copy() { + let circ = circ_with_copy(); + let pattern = CircuitPattern::try_from_circuit(&circ).unwrap(); + let edges = pattern.pattern.edges().unwrap(); + let rx_ns = get_nodes_by_t2op(&circ, T2Op::RxF64); + let inp = circ.input(); + for rx_n in rx_ns { + assert!(edges.iter().any(|e| { + e.reverse().is_none() + && e.source.unwrap() == rx_n.into() + && e.target.unwrap() == NodeID::CopyNode(inp, Port::new_outgoing(1)) + })); + } + } + + #[test] + fn pattern_with_copy_disconnected() { + let circ = circ_with_copy_disconnected(); + assert_eq!( + CircuitPattern::try_from_circuit(&circ).unwrap_err(), + InvalidPattern::NotConnected + ); + } } diff --git a/src/utils.rs b/src/utils.rs index c560d63e..124cda37 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,6 +1,7 @@ //! Utility functions for the library. use hugr::extension::PRELUDE_REGISTRY; +use hugr::types::{Type, TypeBound}; use hugr::{ builder::{BuildError, CircuitBuilder, DFGBuilder, Dataflow, DataflowHugr}, extension::prelude::QB_T, @@ -8,6 +9,10 @@ use hugr::{ Hugr, }; +pub(crate) fn type_is_linear(typ: &Type) -> bool { + !TypeBound::Copyable.contains(typ.least_upper_bound()) +} + // utility for building simple qubit-only circuits. #[allow(unused)] pub(crate) fn build_simple_circuit(