Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Combine ExtensionSolutions (no separate closure) #884

Merged
merged 20 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions quantinuum-hugr/src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,22 @@ use thiserror::Error;
/// been inferred for their inputs.
pub type ExtensionSolution = HashMap<Node, ExtensionSet>;

/// Infer extensions for a hugr. This is the main API exposed by this module
/// Infer extensions for a hugr. This is the main API exposed by this module.
///
/// Return a tuple of the solutions found for locations on the graph, and a
/// closure: a solution which would be valid if all of the variables in the graph
/// were instantiated to an empty extension set. This is used (by validation) to
/// concretise the extension requirements of the whole hugr.
pub fn infer_extensions(
hugr: &impl HugrView,
) -> Result<(ExtensionSolution, ExtensionSolution), InferExtensionError> {
/// Return all the solutions found for locations on the graph, these can be
/// passed to [`validate_with_extension_closure`]
///
/// [`validate_with_extension_closure`]: crate::Hugr::validate_with_extension_closure
pub fn infer_extensions(hugr: &impl HugrView) -> Result<ExtensionSolution, InferExtensionError> {
let mut ctx = UnificationContext::new(hugr);
let solution = ctx.main_loop()?;
ctx.main_loop()?;
ctx.instantiate_variables();
let closed_solution = ctx.main_loop()?;
let closure: ExtensionSolution = closed_solution
let all_results = ctx.main_loop()?;
let new_results = all_results
.into_iter()
.filter(|(node, _)| !solution.contains_key(node))
.filter(|(n, _sol)| hugr.get_nodetype(*n).input_extensions().is_none())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks reasonable 👍

.collect();
Ok((solution, closure))
Ok(new_results)
}

/// Metavariables don't need much
Expand Down
114 changes: 106 additions & 8 deletions quantinuum-hugr/src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use crate::ops::{LeafOp, OpType};
#[cfg(feature = "extension_inference")]
use crate::{
builder::test::closed_dfg_root_hugr,
hugr::validate::ValidationError,
extension::prelude::PRELUDE_ID,
hugr::{hugrmut::sealed::HugrMutInternals, validate::ValidationError},
ops::{dataflow::DataflowParent, handle::NodeHandle},
};

Expand Down Expand Up @@ -100,13 +101,13 @@ fn from_graph() -> Result<(), Box<dyn Error>> {

hugr.connect(mult_c, 0, output, 0);

let (_, closure) = infer_extensions(&hugr)?;
let solution = infer_extensions(&hugr)?;
let empty = ExtensionSet::new();
let ab = ExtensionSet::from_iter([A, B]);
assert_eq!(*closure.get(&(hugr.root())).unwrap(), empty);
assert_eq!(*closure.get(&(mult_c)).unwrap(), ab);
assert_eq!(*closure.get(&(add_ab)).unwrap(), empty);
assert_eq!(*closure.get(&add_b).unwrap(), ExtensionSet::singleton(&A));
assert_eq!(*solution.get(&(hugr.root())).unwrap(), empty);
assert_eq!(*solution.get(&(mult_c)).unwrap(), ab);
assert_eq!(*solution.get(&(add_ab)).unwrap(), empty);
assert_eq!(*solution.get(&add_b).unwrap(), ExtensionSet::singleton(&A));
Ok(())
}

Expand Down Expand Up @@ -249,8 +250,7 @@ fn dangling_src() -> Result<(), Box<dyn Error>> {
hugr.connect(src, 0, mult, 1);
hugr.connect(mult, 0, output, 0);

let closure = hugr.infer_extensions()?;
assert!(closure.is_empty());
hugr.infer_extensions()?;
assert_eq!(hugr.get_nodetype(src.node()).io_extensions().unwrap().1, rs);
assert_eq!(
hugr.get_nodetype(mult.node()).io_extensions().unwrap(),
Expand Down Expand Up @@ -795,6 +795,104 @@ fn test_cfg_loops() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[test]
#[cfg(feature = "extension_inference")]
fn test_validate_with_closure() -> Result<(), Box<dyn Error>> {
const EXT_ID: ExtensionId = ExtensionId::new_unchecked("foo");
let sig = FunctionType::new_endo(type_row![QB_T]);
let inner_open = {
let mut h = closed_dfg_root_hugr(sig.clone());
h.replace_op(h.root(), NodeType::new_open(h.get_optype(h.root()).clone()))?;
let [input, output] = h.get_io(h.root()).unwrap();
h.connect(input, 0, output, 0);
h
};

let inner_prelude = {
let mut h = inner_open.clone();
h.replace_op(
h.root(),
NodeType::new(
h.get_optype(h.root()).clone(),
ExtensionSet::singleton(&PRELUDE_ID),
),
)?;
h
};

let inner_other = {
let mut h = inner_open.clone();
h.replace_op(
h.root(),
NodeType::new(
h.get_optype(h.root()).clone(),
ExtensionSet::singleton(&EXT_ID),
),
)?;
h
};

// All three can be inferred and validated, without writing solutions in:
for inner in [&inner_open, &inner_prelude, &inner_other] {
assert_matches!(
inner.validate(&PRELUDE_REGISTRY),
Err(ValidationError::ExtensionError(_))
);

let soln = infer_extensions(inner)?;
inner.validate_with_extension_closure(soln, &PRELUDE_REGISTRY)?;
}

// Helper builds a Hugr with extensions {PRELUDE_ID}, around argument
let build_outer_prelude = |inner: Hugr| -> Hugr {
let mut h = closed_dfg_root_hugr(sig.clone());
h.replace_op(
h.root(),
NodeType::new(
h.get_optype(h.root()).clone(),
ExtensionSet::singleton(&PRELUDE_ID),
),
)
.unwrap();
let [input, output] = h.get_io(h.root()).unwrap();
let inner_node = h.insert_hugr(h.root(), inner).new_root;
h.connect(input, 0, inner_node, 0);
h.connect(inner_node, 0, output, 0);
h
};

// Building a Hugr around the inner DFG works if the inner DFG is open,
// or has the correct (prelude) extensions:
for inner in [&inner_open, &inner_prelude] {
let mut h = build_outer_prelude(inner.clone());
h.update_validate(&PRELUDE_REGISTRY)?;
}

// ...but fails if the inner DFG already has the 'wrong' extensions:
//let reg = ExtensionRegistry::try_new([Extension::new(EXT_ID), PRELUDE.to_owned()])?;
assert_matches!(
build_outer_prelude(inner_other.clone()).update_validate(&PRELUDE_REGISTRY),
Err(ValidationError::CantInfer(_))
);

// If we do inference on the inner Hugr first, this works if the
// inner DFG already had the correct input-extensions:
let mut inner_prelude = inner_prelude.clone();
inner_prelude.update_validate(&PRELUDE_REGISTRY)?;
build_outer_prelude(inner_prelude).update_validate(&PRELUDE_REGISTRY)?;

// But fails even for previously-open inner DFG as inference
// infers an incorrect (empty) solution:
let mut inner_inferred = inner_open;
inner_inferred.update_validate(&PRELUDE_REGISTRY)?;
assert_matches!(
build_outer_prelude(inner_inferred).update_validate(&PRELUDE_REGISTRY),
Err(ValidationError::CantInfer(_))
);

Ok(())
}

#[test]
/// A control flow graph consisting of an entry node and a single block
/// which adds a resource and links to both itself and the exit node.
Expand Down
26 changes: 11 additions & 15 deletions quantinuum-hugr/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ pub mod serialize;
pub mod validate;
pub mod views;

#[cfg(not(feature = "extension_inference"))]
use std::collections::HashMap;
use std::collections::VecDeque;
use std::iter;

Expand Down Expand Up @@ -198,29 +196,27 @@ impl Hugr {
extension_registry: &ExtensionRegistry,
) -> Result<(), ValidationError> {
resolve_extension_ops(self, extension_registry)?;
let closure = self.infer_extensions()?;
self.validate_with_extension_closure(closure, extension_registry)?;
self.infer_extensions()?;
self.validate(extension_registry)?;
Ok(())
}

/// Infer extension requirements and add new information to `op_types` field
/// (if the "extension_inference" feature is on; otherwise, do nothing)
///
/// See [`infer_extensions`] for details on the "closure" value
#[cfg(feature = "extension_inference")]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
let (solution, extension_closure) = infer_extensions(self)?;
self.instantiate_extensions(solution);
Ok(extension_closure)
}
/// Do nothing - this functionality is gated by the feature "extension_inference"
#[cfg(not(feature = "extension_inference"))]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
Ok(HashMap::new())
pub fn infer_extensions(&mut self) -> Result<(), InferExtensionError> {
#[cfg(feature = "extension_inference")]
{
let solution = infer_extensions(self)?;
self.instantiate_extensions(&solution);
}
Ok(())
}

#[allow(dead_code)]
/// Add extension requirement information to the hugr in place.
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
fn instantiate_extensions(&mut self, solution: &ExtensionSolution) {
// We only care about inferred _input_ extensions, because `NodeType`
// uses those to infer the output extensions
for (node, input_extensions) in solution.iter() {
Expand Down
7 changes: 3 additions & 4 deletions quantinuum-hugr/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,14 @@ fn children_restrictions() {
b.update_validate(&EMPTY_REG),
Err(ValidationError::NonContainerWithChildren { node, .. }) => assert_eq!(node, copy)
);
let closure = b.infer_extensions().unwrap();
b.infer_extensions().unwrap();
b.set_parent(new_def, root);

// After moving the previous definition to a valid place,
// add an input node to the module subgraph
let new_input = b.add_node_with_parent(root, ops::Input::new(type_row![]));
assert_matches!(
b.validate_with_extension_closure(closure, &EMPTY_REG),
b.validate(&EMPTY_REG),
Err(ValidationError::InvalidParentOp { parent, child, .. }) => {assert_eq!(parent, root); assert_eq!(child, new_input)}
);
}
Expand Down Expand Up @@ -590,8 +590,7 @@ mod extension_tests {
.unwrap();
// Write Extension annotations into the Hugr while it's still well-formed
// enough for us to compute them
let closure = b.infer_extensions().unwrap();
b.instantiate_extensions(closure);
b.infer_extensions().unwrap();
b.validate(&EMPTY_REG).unwrap();
b.replace_op(
copy,
Expand Down