Skip to content

Commit

Permalink
refactor: Combine ExtensionSolutions (no separate closure) (#884)
Browse files Browse the repository at this point in the history
* `infer::infer_extensions` returns only a combined solution (for
previously-open locations), after variables instantiated
* `Hugr::infer_extensions` writes (all parts of) the solution into place
*and* returns it
* `validate_with_extension_closure` left in-place, with test
demonstrating usage w/ sub-DFGs
* This should open the way (in future PRs) to changing implementation of
`infer::infer_extensions`
  • Loading branch information
acl-cqc authored Apr 9, 2024
1 parent b05dd6b commit d98fb79
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 40 deletions.
24 changes: 11 additions & 13 deletions hugr/src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,22 @@ use thiserror::Error;
/// been inferred for their inputs.
pub type ExtensionSolution = HashMap<Node, ExtensionSet>;

/// Infer extensions for a hugr. This is the main API exposed by this module
/// Infer extensions for a hugr. This is the main API exposed by this module.
///
/// Return a tuple of the solutions found for locations on the graph, and a
/// closure: a solution which would be valid if all of the variables in the graph
/// were instantiated to an empty extension set. This is used (by validation) to
/// concretise the extension requirements of the whole hugr.
pub fn infer_extensions(
hugr: &impl HugrView,
) -> Result<(ExtensionSolution, ExtensionSolution), InferExtensionError> {
/// Return all the solutions found for locations on the graph, these can be
/// passed to [`validate_with_extension_closure`]
///
/// [`validate_with_extension_closure`]: crate::Hugr::validate_with_extension_closure
pub fn infer_extensions(hugr: &impl HugrView) -> Result<ExtensionSolution, InferExtensionError> {
let mut ctx = UnificationContext::new(hugr);
let solution = ctx.main_loop()?;
ctx.main_loop()?;
ctx.instantiate_variables();
let closed_solution = ctx.main_loop()?;
let closure: ExtensionSolution = closed_solution
let all_results = ctx.main_loop()?;
let new_results = all_results
.into_iter()
.filter(|(node, _)| !solution.contains_key(node))
.filter(|(n, _sol)| hugr.get_nodetype(*n).input_extensions().is_none())
.collect();
Ok((solution, closure))
Ok(new_results)
}

/// Metavariables don't need much
Expand Down
91 changes: 83 additions & 8 deletions hugr/src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use crate::ops::{LeafOp, OpType};
#[cfg(feature = "extension_inference")]
use crate::{
builder::test::closed_dfg_root_hugr,
hugr::validate::ValidationError,
extension::prelude::PRELUDE_ID,
hugr::{hugrmut::sealed::HugrMutInternals, validate::ValidationError},
ops::{dataflow::DataflowParent, handle::NodeHandle},
};

Expand Down Expand Up @@ -100,13 +101,13 @@ fn from_graph() -> Result<(), Box<dyn Error>> {

hugr.connect(mult_c, 0, output, 0);

let (_, closure) = infer_extensions(&hugr)?;
let solution = infer_extensions(&hugr)?;
let empty = ExtensionSet::new();
let ab = ExtensionSet::from_iter([A, B]);
assert_eq!(*closure.get(&(hugr.root())).unwrap(), empty);
assert_eq!(*closure.get(&(mult_c)).unwrap(), ab);
assert_eq!(*closure.get(&(add_ab)).unwrap(), empty);
assert_eq!(*closure.get(&add_b).unwrap(), ExtensionSet::singleton(&A));
assert_eq!(*solution.get(&(hugr.root())).unwrap(), empty);
assert_eq!(*solution.get(&(mult_c)).unwrap(), ab);
assert_eq!(*solution.get(&(add_ab)).unwrap(), empty);
assert_eq!(*solution.get(&add_b).unwrap(), ExtensionSet::singleton(&A));
Ok(())
}

Expand Down Expand Up @@ -249,8 +250,7 @@ fn dangling_src() -> Result<(), Box<dyn Error>> {
hugr.connect(src, 0, mult, 1);
hugr.connect(mult, 0, output, 0);

let closure = hugr.infer_extensions()?;
assert!(closure.is_empty());
hugr.infer_extensions()?;
assert_eq!(hugr.get_nodetype(src.node()).io_extensions().unwrap().1, rs);
assert_eq!(
hugr.get_nodetype(mult.node()).io_extensions().unwrap(),
Expand Down Expand Up @@ -795,6 +795,81 @@ fn test_cfg_loops() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[test]
#[cfg(feature = "extension_inference")]
fn test_validate_with_closure() -> Result<(), Box<dyn Error>> {
fn dfg_hugr_with_exts(e: Option<ExtensionSet>) -> (Hugr, Node, Node) {
let mut h = closed_dfg_root_hugr(FunctionType::new_endo(type_row![QB_T]));
h.replace_op(h.root(), NodeType::new(h.get_optype(h.root()).clone(), e))
.unwrap();
let [input, output] = h.get_io(h.root()).unwrap();
(h, input, output)
}
fn identity_hugr_with_exts(e: Option<ExtensionSet>) -> Hugr {
let (mut h, input, output) = dfg_hugr_with_exts(e);
h.connect(input, 0, output, 0);
h
}

const EXT_ID: ExtensionId = ExtensionId::new_unchecked("foo");

let inner_open = identity_hugr_with_exts(None);

let inner_prelude = identity_hugr_with_exts(Some(ExtensionSet::singleton(&PRELUDE_ID)));

let inner_other = identity_hugr_with_exts(Some(ExtensionSet::singleton(&EXT_ID)));

// All three can be inferred and validated, without writing solutions in:
for inner in [&inner_open, &inner_prelude, &inner_other] {
assert_matches!(
inner.validate(&PRELUDE_REGISTRY),
Err(ValidationError::ExtensionError(_))
);

let soln = infer_extensions(inner)?;
inner.validate_with_extension_closure(soln, &PRELUDE_REGISTRY)?;
}

// Helper builds a Hugr with extensions {PRELUDE_ID}, around argument
let build_outer_prelude = |inner: Hugr| -> Hugr {
let (mut h, input, output) = dfg_hugr_with_exts(Some(ExtensionSet::singleton(&PRELUDE_ID)));
let inner_node = h.insert_hugr(h.root(), inner).new_root;
h.connect(input, 0, inner_node, 0);
h.connect(inner_node, 0, output, 0);
h
};

// Building a Hugr around the inner DFG works if the inner DFG is open,
// or has the correct (prelude) extensions:
for inner in [&inner_open, &inner_prelude] {
let mut h = build_outer_prelude(inner.clone());
h.update_validate(&PRELUDE_REGISTRY)?;
}

// ...but fails if the inner DFG already has the 'wrong' extensions:
assert_matches!(
build_outer_prelude(inner_other.clone()).update_validate(&PRELUDE_REGISTRY),
Err(ValidationError::CantInfer(_))
);

// If we do inference on the inner Hugr first, this (still) works if the
// inner DFG already had the correct input-extensions:
let mut inner_prelude_inferred = inner_prelude;
inner_prelude_inferred.update_validate(&PRELUDE_REGISTRY)?;
build_outer_prelude(inner_prelude_inferred).update_validate(&PRELUDE_REGISTRY)?;

// But fails for previously-open inner DFG as inference
// infers an incorrect (empty) solution:
let mut inner_inferred = inner_open;
inner_inferred.update_validate(&PRELUDE_REGISTRY)?;
assert_matches!(
build_outer_prelude(inner_inferred).update_validate(&PRELUDE_REGISTRY),
Err(ValidationError::CantInfer(_))
);

Ok(())
}

#[test]
/// A control flow graph consisting of an entry node and a single block
/// which adds a resource and links to both itself and the exit node.
Expand Down
26 changes: 11 additions & 15 deletions hugr/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ pub mod serialize;
pub mod validate;
pub mod views;

#[cfg(not(feature = "extension_inference"))]
use std::collections::HashMap;
use std::collections::VecDeque;
use std::iter;

Expand Down Expand Up @@ -198,29 +196,27 @@ impl Hugr {
extension_registry: &ExtensionRegistry,
) -> Result<(), ValidationError> {
resolve_extension_ops(self, extension_registry)?;
let closure = self.infer_extensions()?;
self.validate_with_extension_closure(closure, extension_registry)?;
self.infer_extensions()?;
self.validate(extension_registry)?;
Ok(())
}

/// Infer extension requirements and add new information to `op_types` field
/// (if the "extension_inference" feature is on; otherwise, do nothing)
///
/// See [`infer_extensions`] for details on the "closure" value
#[cfg(feature = "extension_inference")]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
let (solution, extension_closure) = infer_extensions(self)?;
self.instantiate_extensions(solution);
Ok(extension_closure)
}
/// Do nothing - this functionality is gated by the feature "extension_inference"
#[cfg(not(feature = "extension_inference"))]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
Ok(HashMap::new())
pub fn infer_extensions(&mut self) -> Result<(), InferExtensionError> {
#[cfg(feature = "extension_inference")]
{
let solution = infer_extensions(self)?;
self.instantiate_extensions(&solution);
}
Ok(())
}

#[allow(dead_code)]
/// Add extension requirement information to the hugr in place.
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
fn instantiate_extensions(&mut self, solution: &ExtensionSolution) {
// We only care about inferred _input_ extensions, because `NodeType`
// uses those to infer the output extensions
for (node, input_extensions) in solution.iter() {
Expand Down
7 changes: 3 additions & 4 deletions hugr/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,14 @@ fn children_restrictions() {
b.update_validate(&EMPTY_REG),
Err(ValidationError::NonContainerWithChildren { node, .. }) => assert_eq!(node, copy)
);
let closure = b.infer_extensions().unwrap();
b.infer_extensions().unwrap();
b.set_parent(new_def, root);

// After moving the previous definition to a valid place,
// add an input node to the module subgraph
let new_input = b.add_node_with_parent(root, ops::Input::new(type_row![]));
assert_matches!(
b.validate_with_extension_closure(closure, &EMPTY_REG),
b.validate(&EMPTY_REG),
Err(ValidationError::InvalidParentOp { parent, child, .. }) => {assert_eq!(parent, root); assert_eq!(child, new_input)}
);
}
Expand Down Expand Up @@ -608,8 +608,7 @@ mod extension_tests {
.unwrap();
// Write Extension annotations into the Hugr while it's still well-formed
// enough for us to compute them
let closure = b.infer_extensions().unwrap();
b.instantiate_extensions(closure);
b.infer_extensions().unwrap();
b.validate(&EMPTY_REG).unwrap();
b.replace_op(
copy,
Expand Down

0 comments on commit d98fb79

Please sign in to comment.