Skip to content

Commit

Permalink
[new + bugfix] replace_op checks bound, adding HugrError::InvalidTag (#…
Browse files Browse the repository at this point in the history
…581)

* Add HugrError::InvalidTag, use in HierarchyView::try_new and
SiblingMut::try_new
* In HugrMut::replace_op, check that changing the root-node's op does
not invalidate the Root: NodeHandle tag.
* Hence, make HugrMut::replace_op return a Result
  • Loading branch information
acl-cqc authored Oct 2, 2023
1 parent 554d658 commit bb09fd4
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 37 deletions.
2 changes: 1 addition & 1 deletion src/builder/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
name,
signature: signature.clone(),
}),
);
)?;

let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, signature, None)?;
Ok(FunctionBuilder::from_dfg_builder(db))
Expand Down
5 changes: 5 additions & 0 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,11 @@ pub enum HugrError {
/// The node doesn't exist.
#[error("Invalid node {0:?}.")]
InvalidNode(Node),
/// The node was not of the required [OpTag]
/// (e.g. to conform to a [HugrView::RootHandle])
#[error("Invalid tag: required a tag in {required} but found {actual}")]
#[allow(missing_docs)]
InvalidTag { required: OpTag, actual: OpTag },
/// An invalid port was specified.
#[error("Invalid port direction {0:?}.")]
InvalidPortDirection(Direction),
Expand Down
20 changes: 16 additions & 4 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ fn insert_subgraph_internal(

pub(crate) mod sealed {
use super::*;
use crate::ops::handle::NodeHandle;

/// Trait for accessing the mutable internals of a Hugr(Mut).
///
Expand Down Expand Up @@ -500,8 +501,18 @@ pub(crate) mod sealed {
/// In general this invalidates the ports, which may need to be resized to
/// match the OpType signature.
/// TODO: Add a version which ignores input extensions
fn replace_op(&mut self, node: Node, op: NodeType) -> NodeType {
self.valid_node(node).unwrap_or_else(|e| panic!("{}", e));
///
/// # Errors
/// Returns a [`HugrError::InvalidTag`] if this would break the bound
/// ([`Self::RootHandle`]) on the root node's [OpTag]
fn replace_op(&mut self, node: Node, op: NodeType) -> Result<NodeType, HugrError> {
self.valid_node(node)?;
if node == self.root() && !Self::RootHandle::TAG.is_superset(op.tag()) {
return Err(HugrError::InvalidTag {
required: Self::RootHandle::TAG,
actual: op.tag(),
});
}
self.hugr_mut().replace_op(node, op)
}
}
Expand Down Expand Up @@ -565,9 +576,10 @@ pub(crate) mod sealed {
Ok(())
}

fn replace_op(&mut self, node: Node, op: NodeType) -> NodeType {
fn replace_op(&mut self, node: Node, op: NodeType) -> Result<NodeType, HugrError> {
// No possibility of failure here since Self::RootHandle == Any
let cur = self.hugr_mut().op_types.get_mut(node.index);
std::mem::replace(cur, op)
Ok(std::mem::replace(cur, op))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/hugr/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl<R: Rewrite> Rewrite for Transactional<R> {
}
if r.is_err() {
// Try to restore backup.
h.replace_op(h.root(), backup.root_type().clone());
h.replace_op(h.root(), backup.root_type().clone()).unwrap();
while let Some(child) = first_child(h) {
h.remove_node(child).unwrap();
}
Expand Down
26 changes: 17 additions & 9 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ mod test {
Err(ValidationError::NoParent { node }) => assert_eq!(node, other)
);
b.set_parent(other, root).unwrap();
b.replace_op(other, NodeType::pure(declare_op));
b.replace_op(other, NodeType::pure(declare_op)).unwrap();
b.add_ports(other, Direction::Outgoing, 1);
assert_eq!(b.validate(&EMPTY_REG), Ok(()));

Expand Down Expand Up @@ -929,14 +929,16 @@ mod test {
.unwrap();

// Replace the output operation of the df subgraph with a copy
b.replace_op(output, NodeType::pure(LeafOp::Noop { ty: NAT }));
b.replace_op(output, NodeType::pure(LeafOp::Noop { ty: NAT }))
.unwrap();
assert_matches!(
b.validate(&EMPTY_REG),
Err(ValidationError::InvalidInitialChild { parent, .. }) => assert_eq!(parent, def)
);

// Revert it back to an output, but with the wrong number of ports
b.replace_op(output, NodeType::pure(ops::Output::new(type_row![BOOL_T])));
b.replace_op(output, NodeType::pure(ops::Output::new(type_row![BOOL_T])))
.unwrap();
assert_matches!(
b.validate(&EMPTY_REG),
Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. })
Expand All @@ -945,13 +947,15 @@ mod test {
b.replace_op(
output,
NodeType::pure(ops::Output::new(type_row![BOOL_T, BOOL_T])),
);
)
.unwrap();

// After fixing the output back, replace the copy with an output op
b.replace_op(
copy,
NodeType::pure(ops::Output::new(type_row![BOOL_T, BOOL_T])),
);
)
.unwrap();
assert_matches!(
b.validate(&EMPTY_REG),
Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalIOChildren { child, .. }, .. })
Expand All @@ -975,7 +979,8 @@ mod test {
NodeType::pure(ops::CFG {
signature: FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]),
}),
);
)
.unwrap();
assert_matches!(
b.validate(&EMPTY_REG),
Err(ValidationError::ContainerWithoutChildren { .. })
Expand Down Expand Up @@ -1033,18 +1038,21 @@ mod test {
other_outputs: type_row![Q],
extension_delta: ExtensionSet::new(),
}),
);
)
.unwrap();
let mut block_children = b.hierarchy.children(block.index);
let block_input = block_children.next().unwrap().into();
let block_output = block_children.next_back().unwrap().into();
b.replace_op(block_input, NodeType::pure(ops::Input::new(type_row![Q])));
b.replace_op(block_input, NodeType::pure(ops::Input::new(type_row![Q])))
.unwrap();
b.replace_op(
block_output,
NodeType::pure(ops::Output::new(type_row![
Type::new_simple_predicate(1),
Q
])),
);
)
.unwrap();
assert_matches!(
b.validate(&EMPTY_REG),
Err(ValidationError::InvalidEdges { parent, source: EdgeValidationError::CFGEdgeSignatureMismatch { .. }, .. })
Expand Down
14 changes: 12 additions & 2 deletions src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView};

use super::{Hugr, HugrError, NodeMetadata, NodeType, DEFAULT_NODETYPE};
use crate::ops::handle::NodeHandle;
use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpType, DFG};
use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpTrait, OpType, DFG};
use crate::types::{EdgeKind, FunctionType};
use crate::{Direction, Node, Port};

Expand Down Expand Up @@ -308,10 +308,20 @@ pub trait HierarchyView<'a>: HugrView + Sized {
/// Create a hierarchical view of a HUGR given a root node.
///
/// # Errors
/// Returns [`HugrError::InvalidNode`] if the root isn't a node of the required [OpTag]
/// Returns [`HugrError::InvalidTag`] if the root isn't a node of the required [OpTag]
fn try_new(hugr: &'a impl HugrView, root: Node) -> Result<Self, HugrError>;
}

fn check_tag<Required: NodeHandle>(hugr: &impl HugrView, node: Node) -> Result<(), HugrError> {
hugr.valid_node(node)?;
let actual = hugr.get_optype(node).tag();
let required = Required::TAG;
if !required.is_superset(actual) {
return Err(HugrError::InvalidTag { required, actual });
}
Ok(())
}

impl<T> HugrView for T
where
T: AsRef<Hugr>,
Expand Down
9 changes: 2 additions & 7 deletions src/hugr/views/descendants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView};

use crate::hugr::HugrError;
use crate::ops::handle::NodeHandle;
use crate::ops::OpTrait;
use crate::{Direction, Hugr, Node, Port};

use super::{sealed::HugrInternals, HierarchyView, HugrView};
use super::{check_tag, sealed::HugrInternals, HierarchyView, HugrView};

type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>;

Expand Down Expand Up @@ -161,11 +160,7 @@ where
Root: NodeHandle,
{
fn try_new(hugr: &'a impl HugrView, root: Node) -> Result<Self, HugrError> {
hugr.valid_node(root)?;
let root_tag = hugr.get_optype(root).tag();
if !Root::TAG.is_superset(root_tag) {
return Err(HugrError::InvalidNode(root));
}
check_tag::<Root>(hugr, root)?;
let hugr = hugr.base_hugr();
Ok(Self {
root,
Expand Down
53 changes: 41 additions & 12 deletions src/hugr/views/sibling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::{HugrError, HugrMut};
use crate::ops::handle::NodeHandle;
use crate::ops::OpTrait;
use crate::{Direction, Hugr, Node, Port};

use super::{sealed::HugrInternals, HierarchyView, HugrView};
use super::{check_tag, sealed::HugrInternals, HierarchyView, HugrView};

type FlatRegionGraph<'g> = portgraph::view::FlatRegion<'g, &'g MultiPortGraph>;

Expand Down Expand Up @@ -195,10 +194,7 @@ where
Root: NodeHandle,
{
fn try_new(hugr: &'a impl HugrView, root: Node) -> Result<Self, HugrError> {
hugr.valid_node(root)?;
if !Root::TAG.is_superset(hugr.get_optype(root).tag()) {
return Err(HugrError::InvalidNode(root));
}
check_tag::<Root>(hugr, root)?;
Ok(Self::new_unchecked(hugr, root))
}
}
Expand Down Expand Up @@ -251,10 +247,7 @@ impl<'g, Root: NodeHandle> SiblingMut<'g, Root> {
/// Create a new SiblingMut from a base.
/// Equivalent to [HierarchyView::try_new] but takes a *mutable* reference.
pub fn try_new(hugr: &'g mut impl HugrMut, root: Node) -> Result<Self, HugrError> {
hugr.valid_node(root)?;
if !Root::TAG.is_superset(hugr.get_optype(root).tag()) {
return Err(HugrError::InvalidNode(root));
}
check_tag::<Root>(hugr, root)?;
Ok(Self {
hugr: hugr.hugr_mut(),
root,
Expand Down Expand Up @@ -367,10 +360,14 @@ impl<'g, Root: NodeHandle> HugrMut for SiblingMut<'g, Root> {}

#[cfg(test)]
mod test {
use rstest::rstest;

use crate::builder::test::simple_dfg_hugr;
use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
use crate::extension::PRELUDE_REGISTRY;
use crate::ops::handle::{DfgID, FuncID, ModuleRootID};
use crate::ops::{dataflow::IOTrait, Input, Output};
use crate::hugr::NodeType;
use crate::ops::handle::{CfgID, DfgID, FuncID, ModuleRootID};
use crate::ops::{dataflow::IOTrait, Input, OpTag, Output};
use crate::type_row;
use crate::types::{FunctionType, Type};

Expand Down Expand Up @@ -433,4 +430,36 @@ mod test {

Ok(())
}

#[rstest]
fn flat_mut(mut simple_dfg_hugr: Hugr) {
simple_dfg_hugr
.infer_and_validate(&PRELUDE_REGISTRY)
.unwrap();
let root = simple_dfg_hugr.root();
let signature = simple_dfg_hugr.get_function_type().unwrap().clone();

let sib_mut = SiblingMut::<CfgID>::try_new(&mut simple_dfg_hugr, root);
assert_eq!(
sib_mut.err(),
Some(HugrError::InvalidTag {
required: OpTag::Cfg,
actual: OpTag::Dfg
})
);

let mut sib_mut = SiblingMut::<DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap();
let bad_nodetype = NodeType::open_extensions(crate::ops::CFG { signature });
assert_eq!(
sib_mut.replace_op(sib_mut.root(), bad_nodetype.clone()),
Err(HugrError::InvalidTag {
required: OpTag::Dfg,
actual: OpTag::Cfg
})
);

// In contrast, performing this on the Hugr (where the allowed root type is 'Any') is only detected by validation
simple_dfg_hugr.replace_op(root, bad_nodetype).unwrap();
assert!(simple_dfg_hugr.validate(&PRELUDE_REGISTRY).is_err());
}
}
4 changes: 3 additions & 1 deletion src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ pub fn resolve_extension_ops(
for (n, op) in replacements {
let leaf: LeafOp = op.into();
let node_type = NodeType::new(leaf, h.get_nodetype(n).input_extensions().cloned());
h.replace_op(n, node_type);
debug_assert_eq!(h.get_optype(n).tag(), OpTag::Leaf);
debug_assert_eq!(node_type.tag(), OpTag::Leaf);
h.replace_op(n, node_type).unwrap();
}
Ok(())
}
Expand Down

0 comments on commit bb09fd4

Please sign in to comment.