diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index 1b547172e..24e580a2a 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -97,7 +97,8 @@ impl DFGBuilder { impl HugrBuilder for DFGBuilder { fn finish_hugr(mut self) -> Result { - self.base.infer_extensions()?; + let closure = self.base.infer_extensions()?; + self.base.validate_with_extension_closure(closure)?; Ok(self.base) } } diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 07014b561..c8e707eb3 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -28,9 +28,23 @@ use thiserror::Error; pub type ExtensionSolution = HashMap<(Node, Direction), ExtensionSet>; /// Infer extensions for a hugr. This is the main API exposed by this module -pub fn infer_extensions(hugr: &impl HugrView) -> Result { +/// +/// Return a tuple of the solutions found for locations on the graph, and a +/// closure: a solution which would be valid if all of the variables in the graph +/// were instantiated to an empty extension set. This is used (by validation) to +/// concretise the extension requirements of the whole hugr. +pub fn infer_extensions( + hugr: &impl HugrView, +) -> Result<(ExtensionSolution, ExtensionSolution), InferExtensionError> { let mut ctx = UnificationContext::new(hugr); - ctx.main_loop() + let solution = ctx.main_loop()?; + ctx.instantiate_variables(); + let closed_solution = ctx.main_loop()?; + let closure: HashMap<(Node, Direction), ExtensionSet> = closed_solution + .into_iter() + .filter(|(loc, _)| !solution.contains_key(loc)) + .collect(); + Ok((solution, closure)) } /// Metavariables don't need much @@ -599,6 +613,16 @@ impl UnificationContext { } self.results() } + + /// Instantiate all variables in the graph with the empty extension set. + /// This is done to solve metas which depend on variables, which allows + /// us to come up with a fully concrete solution to pass into validation. + pub fn instantiate_variables(&mut self) { + for m in self.variables.clone().into_iter() { + self.add_solution(m, ExtensionSet::new()); + } + self.variables = HashSet::new(); + } } #[cfg(test)] @@ -856,7 +880,8 @@ mod test { let hugr = builder.base; // TODO: when we put new extensions onto the graph after inference, we // can call `finish_hugr` and just look at the graph - let solution = infer_extensions(&hugr)?; + let (solution, extra) = infer_extensions(&hugr)?; + assert!(extra.is_empty()); assert_eq!( *solution.get(&(src.node(), Direction::Outgoing)).unwrap(), rs diff --git a/src/extension/validate.rs b/src/extension/validate.rs index 55cc38d08..732f76e0d 100644 --- a/src/extension/validate.rs +++ b/src/extension/validate.rs @@ -20,9 +20,12 @@ pub struct ExtensionValidator { impl ExtensionValidator { /// Initialise a new extension validator, pre-computing the extension /// requirements for each node in the Hugr. - pub fn new(hugr: &Hugr) -> Self { + /// + /// The `closure` argument is a set of extensions which doesn't actually + /// live on the graph, but is used to close the graph for validation + pub fn new(hugr: &Hugr, closure: HashMap<(Node, Direction), ExtensionSet>) -> Self { let mut validator = ExtensionValidator { - extensions: HashMap::new(), + extensions: closure, }; for node in hugr.nodes() { diff --git a/src/hugr.rs b/src/hugr.rs index fc7906fa0..4e1c52edf 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -7,7 +7,7 @@ pub mod serialize; pub mod validate; pub mod views; -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::iter; pub(crate) use self::hugrmut::HugrInternalsMut; @@ -194,10 +194,12 @@ impl Hugr { } /// Infer extension requirements - pub fn infer_extensions(&mut self) -> Result<(), InferExtensionError> { - let solution = infer_extensions(self)?; + pub fn infer_extensions( + &mut self, + ) -> Result, InferExtensionError> { + let (solution, extension_closure) = infer_extensions(self)?; self.instantiate_extensions(solution); - Ok(()) + Ok(extension_closure) } /// TODO: Write this diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index f6521c31b..f063ec2fd 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -14,7 +14,7 @@ use pyo3::prelude::*; use crate::extension::{ validate::{ExtensionError, ExtensionValidator}, - InferExtensionError, + ExtensionSet, InferExtensionError, }; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; use crate::ops::{OpTag, OpTrait, OpType, ValidateOp}; @@ -38,20 +38,35 @@ struct ValidationContext<'a> { } impl Hugr { - /// Check the validity of the HUGR. + /// Check the validity of the HUGR, assuming that it has no open extension + /// variables. + /// TODO: Add a version of validation which allows for open extension + /// variables (see github issue #457) pub fn validate(&self) -> Result<(), ValidationError> { - let mut validator = ValidationContext::new(self); + self.validate_with_extension_closure(HashMap::new()) + } + + /// Check the validity of a hugr, taking an argument of a closure for the + /// free extension variables + pub fn validate_with_extension_closure( + &self, + closure: HashMap<(Node, Direction), ExtensionSet>, + ) -> Result<(), ValidationError> { + let mut validator = ValidationContext::new(self, closure); validator.validate() } } impl<'a> ValidationContext<'a> { /// Create a new validation context. - pub fn new(hugr: &'a Hugr) -> Self { + pub fn new( + hugr: &'a Hugr, + extension_closure: HashMap<(Node, Direction), ExtensionSet>, + ) -> Self { Self { hugr, dominators: HashMap::new(), - extension_validator: ExtensionValidator::new(hugr), + extension_validator: ExtensionValidator::new(hugr, extension_closure), } }