diff --git a/hugr/src/hugr.rs b/hugr/src/hugr.rs index 02b5a25b5..4e817d2e0 100644 --- a/hugr/src/hugr.rs +++ b/hugr/src/hugr.rs @@ -196,6 +196,7 @@ impl Hugr { extension_registry: &ExtensionRegistry, ) -> Result<(), ValidationError> { resolve_extension_ops(self, extension_registry)?; + self.validate_no_extensions(extension_registry)?; self.infer_extensions()?; self.validate(extension_registry)?; Ok(()) diff --git a/hugr/src/hugr/validate.rs b/hugr/src/hugr/validate.rs index 26857495a..2290824cd 100644 --- a/hugr/src/hugr/validate.rs +++ b/hugr/src/hugr/validate.rs @@ -37,7 +37,7 @@ struct ValidationContext<'a, 'b> { dominators: HashMap>, /// Context for the extension validation. #[allow(dead_code)] - extension_validator: ExtensionValidator, + extension_validator: Option, /// Registry of available Extensions extension_registry: &'b ExtensionRegistry, } @@ -51,6 +51,16 @@ impl Hugr { self.validate_with_extension_closure(HashMap::new(), extension_registry) } + /// Check the validity of the HUGR, but don't check consistency of extension + /// requirements between connected nodes or between parents and children. + pub fn validate_no_extensions( + &self, + extension_registry: &ExtensionRegistry, + ) -> Result<(), ValidationError> { + let mut validator = ValidationContext::new(self, HashMap::new(), extension_registry, false); + validator.validate() + } + /// Check the validity of a hugr, taking an argument of a closure for the /// free extension variables pub fn validate_with_extension_closure( @@ -58,7 +68,7 @@ impl Hugr { closure: ExtensionSolution, extension_registry: &ExtensionRegistry, ) -> Result<(), ValidationError> { - let mut validator = ValidationContext::new(self, closure, extension_registry); + let mut validator = ValidationContext::new(self, closure, extension_registry, true); validator.validate() } } @@ -72,11 +82,16 @@ impl<'a, 'b> ValidationContext<'a, 'b> { hugr: &'a Hugr, extension_closure: ExtensionSolution, extension_registry: &'b ExtensionRegistry, + validate_extensions: bool, ) -> Self { Self { hugr, dominators: HashMap::new(), - extension_validator: ExtensionValidator::new(hugr, extension_closure), + extension_validator: if validate_extensions { + Some(ExtensionValidator::new(hugr, extension_closure)) + } else { + None + }, extension_registry, } } @@ -183,8 +198,9 @@ impl<'a, 'b> ValidationContext<'a, 'b> { // If this is a container with I/O nodes, check that the extension they // define match the extensions of the container. if let Some([input, output]) = self.hugr.get_io(node) { - self.extension_validator - .validate_io_extensions(node, input, output)?; + if let Some(validator) = &self.extension_validator { + validator.validate_io_extensions(node, input, output)?; + } } } @@ -248,8 +264,10 @@ impl<'a, 'b> ValidationContext<'a, 'b> { let other_offset = self.hugr.graph.port_offset(link).unwrap().into(); #[cfg(feature = "extension_inference")] - self.extension_validator - .check_extensions_compatible(&(node, port), &(other_node, other_offset))?; + if let Some(validator) = &self.extension_validator { + validator + .check_extensions_compatible(&(node, port), &(other_node, other_offset))?; + } let other_op = self.hugr.get_optype(other_node); let Some(other_kind) = other_op.port_kind(other_offset) else {