diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index e4e721de6..b0c725830 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -73,6 +73,42 @@ impl<'a> ValidationContext<'a> { Ok(()) } + /// Check that input and output nodes match up with the signature of their parents + /// This must be done after the `gather_resources` step + fn validate_io_resources( + &self, + hugr: &impl HugrView, + parent: Node, + ) -> Result<(), ValidationError> { + if let Some([input, output]) = hugr.get_io(parent) { + let parent_input_resources = + self.resources.get(&(parent, Direction::Incoming)).unwrap(); + let parent_output_resources = + self.resources.get(&(parent, Direction::Outgoing)).unwrap(); + for dir in Direction::BOTH { + let input_resources = self.resources.get(&(input, dir)).unwrap(); + let output_resources = self.resources.get(&(output, dir)).unwrap(); + if parent_input_resources != input_resources { + return Err(ValidationError::ParentIOResourceMismatch { + parent, + parent_resources: parent_input_resources.clone(), + child: input, + child_resources: input_resources.clone(), + }); + }; + if parent_output_resources != output_resources { + return Err(ValidationError::ParentIOResourceMismatch { + parent, + parent_resources: parent_output_resources.clone(), + child: output, + child_resources: output_resources.clone(), + }); + }; + } + }; + Ok(()) + } + /// Use the signature supplied by a dataflow node to work out the /// resource requirements for all of its input and output edges, then put /// those requirements in the ValidationContext @@ -161,6 +197,10 @@ impl<'a> ValidationContext<'a> { // Check operation-specific constraints self.validate_operation(node, op_type)?; + // If this is a DataflowParent, check that the I/O nodes + // match the parent's signature + self.validate_io_resources(&self.hugr, node)?; + Ok(()) } @@ -669,6 +709,13 @@ pub enum ValidationError { }, #[error("Missing input resources for node {0:?}")] MissingInputResources(Node), + #[error("Resources of I/O node ({child:?}) {child_resources:?} don't match those expected by parent node ({parent:?}): {parent_resources:?}")] + ParentIOResourceMismatch { + parent: Node, + parent_resources: ResourceSet, + child: Node, + child_resources: ResourceSet, + }, } #[cfg(feature = "pyo3")]