diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 81e0cacff..ed291b614 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -40,7 +40,7 @@ use pyo3::{create_exception, exceptions::PyException, PyErr}; /// [`super::SiblingGraph`], not all nodes of the sibling graph must be /// included. A convex subgraph is always an induced subgraph, i.e. it is defined /// by a set of nodes and all edges between them. - +/// /// The incoming boundary (resp. outgoing boundary) is given by the input (resp. /// output) ports of the subgraph that are linked to nodes outside of the subgraph. /// The signature of the subgraph is then given by the types of the incoming @@ -485,6 +485,9 @@ fn validate_subgraph( inputs: &IncomingPorts, outputs: &OutgoingPorts, ) -> Result<(), InvalidSubgraph> { + // Copy of the nodes for fast lookup. + let node_set = nodes.iter().copied().collect::>(); + // Check nodes is not empty if nodes.is_empty() { return Err(InvalidSubgraph::EmptySubgraph); @@ -501,40 +504,48 @@ fn validate_subgraph( .chain(outputs) .any(|&(n, p)| is_order_edge(hugr, n, p)) { - unimplemented!("Linked other ports not supported at boundary") + unimplemented!("Connected order edges not supported at the boundary") } // Check inputs are incoming ports and outputs are outgoing ports - if inputs + if let Some(&(n, p)) = inputs .iter() .flatten() - .any(|(_, p)| p.direction() == Direction::Outgoing) + .find(|(_, p)| p.direction() == Direction::Outgoing) { - return Err(InvalidSubgraph::InvalidBoundary); - } - if outputs + Err(InvalidSubgraphBoundary::InputPortDirection(n, p))?; + }; + if let Some(&(n, p)) = outputs .iter() - .any(|(_, p)| p.direction() == Direction::Incoming) + .find(|(_, p)| p.direction() == Direction::Incoming) { - return Err(InvalidSubgraph::InvalidBoundary); - } + Err(InvalidSubgraphBoundary::OutputPortDirection(n, p))?; + }; - let mut ports_inside = inputs.iter().flatten().chain(outputs).copied(); - // Check incoming & outgoing ports have target resp. source inside - let nodes = nodes.iter().copied().collect::>(); - if ports_inside.any(|(n, _)| !nodes.contains(&n)) { - return Err(InvalidSubgraph::InvalidBoundary); - } + let boundary_ports = inputs + .iter() + .flatten() + .chain(outputs) + .copied() + .collect_vec(); + // Check that the boundary ports are all in the subgraph. + if let Some(&(n, p)) = boundary_ports.iter().find(|(n, _)| !node_set.contains(n)) { + Err(InvalidSubgraphBoundary::PortNodeNotInSet(n, p))?; + }; // Check that every inside port has at least one linked port outside. - if ports_inside.any(|(n, p)| hugr.linked_ports(n, p).all(|(n1, _)| nodes.contains(&n1))) { - return Err(InvalidSubgraph::InvalidBoundary); - } + if let Some(&(n, p)) = boundary_ports.iter().find(|&&(n, p)| { + hugr.linked_ports(n, p) + .all(|(n1, _)| node_set.contains(&n1)) + }) { + Err(InvalidSubgraphBoundary::DisconnectedBoundaryPort(n, p))?; + }; + // Check that every incoming port of a node in the subgraph whose source is not in the subgraph // belongs to inputs. - if nodes.clone().into_iter().any(|n| { + if nodes.iter().any(|&n| { hugr.node_inputs(n).any(|p| { hugr.linked_ports(n, p).any(|(n1, _)| { - !nodes.contains(&n1) && !inputs.iter().any(|nps| nps.contains(&(n, p))) + !node_set.contains(&n1) && !inputs.iter().any(|nps| nps.contains(&(n, p))) }) }) }) { @@ -542,10 +553,10 @@ fn validate_subgraph( } // Check that every outgoing port of a node in the subgraph whose target is not in the subgraph // belongs to outputs. - if nodes.clone().into_iter().any(|n| { + if nodes.iter().any(|&n| { hugr.node_outputs(n).any(|p| { hugr.linked_ports(n, p) - .any(|(n1, _)| !nodes.contains(&n1) && !outputs.contains(&(n, p))) + .any(|(n1, _)| !node_set.contains(&n1) && !outputs.contains(&(n, p))) }) }) { return Err(InvalidSubgraph::NotConvex); @@ -553,24 +564,24 @@ fn validate_subgraph( // Check inputs are unique if !inputs.iter().flatten().all_unique() { - return Err(InvalidSubgraph::InvalidBoundary); + return Err(InvalidSubgraphBoundary::NonUniqueInput.into()); } // Check no incoming partition is empty if inputs.iter().any(|p| p.is_empty()) { - return Err(InvalidSubgraph::InvalidBoundary); + return Err(InvalidSubgraphBoundary::EmptyPartition.into()); } // Check edge types are equal within partition and copyable if partition size > 1 - if !inputs.iter().all(|ports| { + if let Some((i, _)) = inputs.iter().enumerate().find(|(_, ports)| { let Some(edge_t) = get_edge_type(hugr, ports) else { - return false; + return true; }; let require_copy = ports.len() > 1; - !require_copy || edge_t.copyable() + require_copy && !edge_t.copyable() }) { - return Err(InvalidSubgraph::InvalidBoundary); - } + Err(InvalidSubgraphBoundary::MismatchedTypes(i))?; + }; Ok(()) } @@ -663,13 +674,41 @@ pub enum InvalidSubgraph { EmptySubgraph, /// An invalid boundary port was found. #[error("Invalid boundary port.")] - InvalidBoundary, + InvalidBoundary(#[from] InvalidSubgraphBoundary), +} + +/// Errors that can occur while constructing a [`SiblingSubgraph`]. +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum InvalidSubgraphBoundary { + /// A node in the input boundary is not Incoming. + #[error("Expected (node {0:?}, port {1:?}) in the input boundary to be an incoming port.")] + InputPortDirection(Node, Port), + /// A node in the output boundary is not Outgoing. + #[error("Expected (node {0:?}, port {1:?}) in the input boundary to be an outgoing port.")] + OutputPortDirection(Node, Port), + /// A boundary port's node is not in the set of nodes. + #[error("(node {0:?}, port {1:?}) is in the boundary, but node {0:?} is not in the set.")] + PortNodeNotInSet(Node, Port), + /// A boundary port has no connections outside the subgraph. + #[error("(node {0:?}, port {1:?}) is in the boundary, but the port is not connected to a node outside the subgraph.")] + DisconnectedBoundaryPort(Node, Port), + /// There's a non-unique input-boundary port. + #[error("A port in the input boundary is used multiple times.")] + NonUniqueInput, + /// There's an empty partition in the input boundary. + #[error("A partition in the input boundary is empty.")] + EmptyPartition, + /// Different types in a partition of the input boundary. + #[error("The partition {0} in the input boundary has ports with different types.")] + MismatchedTypes(usize), } #[cfg(test)] mod tests { use std::error::Error; + use cool_asserts::assert_matches; + use crate::extension::PRELUDE_REGISTRY; use crate::{ builder::{ @@ -883,14 +922,16 @@ mod tests { let (inp, _) = hugr.children(func_root).take(2).collect_tuple().unwrap(); let first_cx_edge = hugr.node_outputs(inp).next().unwrap(); // All graph but one edge - assert!(matches!( + assert_matches!( SiblingSubgraph::try_new( vec![hugr.linked_ports(inp, first_cx_edge).collect()], vec![(inp, first_cx_edge)], &func, ), - Err(InvalidSubgraph::NotConvex) - )); + Err(InvalidSubgraph::InvalidBoundary( + InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _) + )) + ); } #[test] @@ -905,14 +946,14 @@ mod tests { let not1_out = hugr.node_outputs(not1).next().unwrap(); let not3_inp = hugr.node_inputs(not3).next().unwrap(); let not3_out = hugr.node_outputs(not3).next().unwrap(); - assert!(matches!( + assert_matches!( SiblingSubgraph::try_new( vec![vec![(not1, not1_inp)], vec![(not3, not3_inp)]], vec![(not1, not1_out), (not3, not3_out)], &func ), Err(InvalidSubgraph::NotConvex) - )); + ); } #[test] @@ -923,14 +964,16 @@ mod tests { let cx_edges_in = hugr.node_outputs(inp); let cx_edges_out = hugr.node_inputs(out); // All graph but the CX - assert!(matches!( + assert_matches!( SiblingSubgraph::try_new( cx_edges_out.map(|p| vec![(out, p)]).collect(), cx_edges_in.map(|p| (inp, p)).collect(), &func, ), - Err(InvalidSubgraph::InvalidBoundary) - )); + Err(InvalidSubgraph::InvalidBoundary( + InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _) + )) + ); } #[test]