From bb09fd4d9cbe3ee54eba8a00de23feda4606e89c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Oct 2023 14:20:21 +0100 Subject: [PATCH] [new + bugfix] replace_op checks bound, adding HugrError::InvalidTag (#581) * Add HugrError::InvalidTag, use in HierarchyView::try_new and SiblingMut::try_new * In HugrMut::replace_op, check that changing the root-node's op does not invalidate the Root: NodeHandle tag. * Hence, make HugrMut::replace_op return a Result --- src/builder/module.rs | 2 +- src/hugr.rs | 5 ++++ src/hugr/hugrmut.rs | 20 ++++++++++--- src/hugr/rewrite.rs | 2 +- src/hugr/validate.rs | 26 +++++++++++------ src/hugr/views.rs | 14 +++++++-- src/hugr/views/descendants.rs | 9 ++---- src/hugr/views/sibling.rs | 53 +++++++++++++++++++++++++++-------- src/ops/custom.rs | 4 ++- 9 files changed, 98 insertions(+), 37 deletions(-) diff --git a/src/builder/module.rs b/src/builder/module.rs index d2fab41ca..2df7a1241 100644 --- a/src/builder/module.rs +++ b/src/builder/module.rs @@ -94,7 +94,7 @@ impl + AsRef> ModuleBuilder { name, signature: signature.clone(), }), - ); + )?; let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, signature, None)?; Ok(FunctionBuilder::from_dfg_builder(db)) diff --git a/src/hugr.rs b/src/hugr.rs index 294f99fea..379dade88 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -569,6 +569,11 @@ pub enum HugrError { /// The node doesn't exist. #[error("Invalid node {0:?}.")] InvalidNode(Node), + /// The node was not of the required [OpTag] + /// (e.g. to conform to a [HugrView::RootHandle]) + #[error("Invalid tag: required a tag in {required} but found {actual}")] + #[allow(missing_docs)] + InvalidTag { required: OpTag, actual: OpTag }, /// An invalid port was specified. #[error("Invalid port direction {0:?}.")] InvalidPortDirection(Direction), diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index 94f6a1305..15ca9f2c2 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -439,6 +439,7 @@ fn insert_subgraph_internal( pub(crate) mod sealed { use super::*; + use crate::ops::handle::NodeHandle; /// Trait for accessing the mutable internals of a Hugr(Mut). /// @@ -500,8 +501,18 @@ pub(crate) mod sealed { /// In general this invalidates the ports, which may need to be resized to /// match the OpType signature. /// TODO: Add a version which ignores input extensions - fn replace_op(&mut self, node: Node, op: NodeType) -> NodeType { - self.valid_node(node).unwrap_or_else(|e| panic!("{}", e)); + /// + /// # Errors + /// Returns a [`HugrError::InvalidTag`] if this would break the bound + /// ([`Self::RootHandle`]) on the root node's [OpTag] + fn replace_op(&mut self, node: Node, op: NodeType) -> Result { + self.valid_node(node)?; + if node == self.root() && !Self::RootHandle::TAG.is_superset(op.tag()) { + return Err(HugrError::InvalidTag { + required: Self::RootHandle::TAG, + actual: op.tag(), + }); + } self.hugr_mut().replace_op(node, op) } } @@ -565,9 +576,10 @@ pub(crate) mod sealed { Ok(()) } - fn replace_op(&mut self, node: Node, op: NodeType) -> NodeType { + fn replace_op(&mut self, node: Node, op: NodeType) -> Result { + // No possibility of failure here since Self::RootHandle == Any let cur = self.hugr_mut().op_types.get_mut(node.index); - std::mem::replace(cur, op) + Ok(std::mem::replace(cur, op)) } } } diff --git a/src/hugr/rewrite.rs b/src/hugr/rewrite.rs index c455ad655..970908273 100644 --- a/src/hugr/rewrite.rs +++ b/src/hugr/rewrite.rs @@ -64,7 +64,7 @@ impl Rewrite for Transactional { } if r.is_err() { // Try to restore backup. - h.replace_op(h.root(), backup.root_type().clone()); + h.replace_op(h.root(), backup.root_type().clone()).unwrap(); while let Some(child) = first_child(h) { h.remove_node(child).unwrap(); } diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index cda4a2082..9499cab2f 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -827,7 +827,7 @@ mod test { Err(ValidationError::NoParent { node }) => assert_eq!(node, other) ); b.set_parent(other, root).unwrap(); - b.replace_op(other, NodeType::pure(declare_op)); + b.replace_op(other, NodeType::pure(declare_op)).unwrap(); b.add_ports(other, Direction::Outgoing, 1); assert_eq!(b.validate(&EMPTY_REG), Ok(())); @@ -929,14 +929,16 @@ mod test { .unwrap(); // Replace the output operation of the df subgraph with a copy - b.replace_op(output, NodeType::pure(LeafOp::Noop { ty: NAT })); + b.replace_op(output, NodeType::pure(LeafOp::Noop { ty: NAT })) + .unwrap(); assert_matches!( b.validate(&EMPTY_REG), Err(ValidationError::InvalidInitialChild { parent, .. }) => assert_eq!(parent, def) ); // Revert it back to an output, but with the wrong number of ports - b.replace_op(output, NodeType::pure(ops::Output::new(type_row![BOOL_T]))); + b.replace_op(output, NodeType::pure(ops::Output::new(type_row![BOOL_T]))) + .unwrap(); assert_matches!( b.validate(&EMPTY_REG), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. }) @@ -945,13 +947,15 @@ mod test { b.replace_op( output, NodeType::pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), - ); + ) + .unwrap(); // After fixing the output back, replace the copy with an output op b.replace_op( copy, NodeType::pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), - ); + ) + .unwrap(); assert_matches!( b.validate(&EMPTY_REG), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalIOChildren { child, .. }, .. }) @@ -975,7 +979,8 @@ mod test { NodeType::pure(ops::CFG { signature: FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]), }), - ); + ) + .unwrap(); assert_matches!( b.validate(&EMPTY_REG), Err(ValidationError::ContainerWithoutChildren { .. }) @@ -1033,18 +1038,21 @@ mod test { other_outputs: type_row![Q], extension_delta: ExtensionSet::new(), }), - ); + ) + .unwrap(); let mut block_children = b.hierarchy.children(block.index); let block_input = block_children.next().unwrap().into(); let block_output = block_children.next_back().unwrap().into(); - b.replace_op(block_input, NodeType::pure(ops::Input::new(type_row![Q]))); + b.replace_op(block_input, NodeType::pure(ops::Input::new(type_row![Q]))) + .unwrap(); b.replace_op( block_output, NodeType::pure(ops::Output::new(type_row![ Type::new_simple_predicate(1), Q ])), - ); + ) + .unwrap(); assert_matches!( b.validate(&EMPTY_REG), Err(ValidationError::InvalidEdges { parent, source: EdgeValidationError::CFGEdgeSignatureMismatch { .. }, .. }) diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 1a19c873a..0971f16d2 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -20,7 +20,7 @@ use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView}; use super::{Hugr, HugrError, NodeMetadata, NodeType, DEFAULT_NODETYPE}; use crate::ops::handle::NodeHandle; -use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpType, DFG}; +use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpTrait, OpType, DFG}; use crate::types::{EdgeKind, FunctionType}; use crate::{Direction, Node, Port}; @@ -308,10 +308,20 @@ pub trait HierarchyView<'a>: HugrView + Sized { /// Create a hierarchical view of a HUGR given a root node. /// /// # Errors - /// Returns [`HugrError::InvalidNode`] if the root isn't a node of the required [OpTag] + /// Returns [`HugrError::InvalidTag`] if the root isn't a node of the required [OpTag] fn try_new(hugr: &'a impl HugrView, root: Node) -> Result; } +fn check_tag(hugr: &impl HugrView, node: Node) -> Result<(), HugrError> { + hugr.valid_node(node)?; + let actual = hugr.get_optype(node).tag(); + let required = Required::TAG; + if !required.is_superset(actual) { + return Err(HugrError::InvalidTag { required, actual }); + } + Ok(()) +} + impl HugrView for T where T: AsRef, diff --git a/src/hugr/views/descendants.rs b/src/hugr/views/descendants.rs index e140f1846..46029d8a4 100644 --- a/src/hugr/views/descendants.rs +++ b/src/hugr/views/descendants.rs @@ -7,10 +7,9 @@ use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView}; use crate::hugr::HugrError; use crate::ops::handle::NodeHandle; -use crate::ops::OpTrait; use crate::{Direction, Hugr, Node, Port}; -use super::{sealed::HugrInternals, HierarchyView, HugrView}; +use super::{check_tag, sealed::HugrInternals, HierarchyView, HugrView}; type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>; @@ -161,11 +160,7 @@ where Root: NodeHandle, { fn try_new(hugr: &'a impl HugrView, root: Node) -> Result { - hugr.valid_node(root)?; - let root_tag = hugr.get_optype(root).tag(); - if !Root::TAG.is_superset(root_tag) { - return Err(HugrError::InvalidNode(root)); - } + check_tag::(hugr, root)?; let hugr = hugr.base_hugr(); Ok(Self { root, diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index df39c484f..dcb773efb 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -9,10 +9,9 @@ use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrError, HugrMut}; use crate::ops::handle::NodeHandle; -use crate::ops::OpTrait; use crate::{Direction, Hugr, Node, Port}; -use super::{sealed::HugrInternals, HierarchyView, HugrView}; +use super::{check_tag, sealed::HugrInternals, HierarchyView, HugrView}; type FlatRegionGraph<'g> = portgraph::view::FlatRegion<'g, &'g MultiPortGraph>; @@ -195,10 +194,7 @@ where Root: NodeHandle, { fn try_new(hugr: &'a impl HugrView, root: Node) -> Result { - hugr.valid_node(root)?; - if !Root::TAG.is_superset(hugr.get_optype(root).tag()) { - return Err(HugrError::InvalidNode(root)); - } + check_tag::(hugr, root)?; Ok(Self::new_unchecked(hugr, root)) } } @@ -251,10 +247,7 @@ impl<'g, Root: NodeHandle> SiblingMut<'g, Root> { /// Create a new SiblingMut from a base. /// Equivalent to [HierarchyView::try_new] but takes a *mutable* reference. pub fn try_new(hugr: &'g mut impl HugrMut, root: Node) -> Result { - hugr.valid_node(root)?; - if !Root::TAG.is_superset(hugr.get_optype(root).tag()) { - return Err(HugrError::InvalidNode(root)); - } + check_tag::(hugr, root)?; Ok(Self { hugr: hugr.hugr_mut(), root, @@ -367,10 +360,14 @@ impl<'g, Root: NodeHandle> HugrMut for SiblingMut<'g, Root> {} #[cfg(test)] mod test { + use rstest::rstest; + + use crate::builder::test::simple_dfg_hugr; use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; use crate::extension::PRELUDE_REGISTRY; - use crate::ops::handle::{DfgID, FuncID, ModuleRootID}; - use crate::ops::{dataflow::IOTrait, Input, Output}; + use crate::hugr::NodeType; + use crate::ops::handle::{CfgID, DfgID, FuncID, ModuleRootID}; + use crate::ops::{dataflow::IOTrait, Input, OpTag, Output}; use crate::type_row; use crate::types::{FunctionType, Type}; @@ -433,4 +430,36 @@ mod test { Ok(()) } + + #[rstest] + fn flat_mut(mut simple_dfg_hugr: Hugr) { + simple_dfg_hugr + .infer_and_validate(&PRELUDE_REGISTRY) + .unwrap(); + let root = simple_dfg_hugr.root(); + let signature = simple_dfg_hugr.get_function_type().unwrap().clone(); + + let sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root); + assert_eq!( + sib_mut.err(), + Some(HugrError::InvalidTag { + required: OpTag::Cfg, + actual: OpTag::Dfg + }) + ); + + let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); + let bad_nodetype = NodeType::open_extensions(crate::ops::CFG { signature }); + assert_eq!( + sib_mut.replace_op(sib_mut.root(), bad_nodetype.clone()), + Err(HugrError::InvalidTag { + required: OpTag::Dfg, + actual: OpTag::Cfg + }) + ); + + // In contrast, performing this on the Hugr (where the allowed root type is 'Any') is only detected by validation + simple_dfg_hugr.replace_op(root, bad_nodetype).unwrap(); + assert!(simple_dfg_hugr.validate(&PRELUDE_REGISTRY).is_err()); + } } diff --git a/src/ops/custom.rs b/src/ops/custom.rs index 7a5945c84..3d892c515 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -263,7 +263,9 @@ pub fn resolve_extension_ops( for (n, op) in replacements { let leaf: LeafOp = op.into(); let node_type = NodeType::new(leaf, h.get_nodetype(n).input_extensions().cloned()); - h.replace_op(n, node_type); + debug_assert_eq!(h.get_optype(n).tag(), OpTag::Leaf); + debug_assert_eq!(node_type.tag(), OpTag::Leaf); + h.replace_op(n, node_type).unwrap(); } Ok(()) }