diff --git a/src/extension/infer/test.rs b/src/extension/infer/test.rs index 5714cd4e8..7653c7578 100644 --- a/src/extension/infer/test.rs +++ b/src/extension/infer/test.rs @@ -11,7 +11,7 @@ use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet}; use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType}; use crate::macros::const_extension_ids; use crate::ops::custom::{ExternalOp, OpaqueOp}; -use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait}; +use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle}; use crate::ops::{LeafOp, OpType}; use crate::type_row; @@ -314,12 +314,8 @@ fn test_conditional_inference() -> Result<(), Box> { first_ext: ExtensionId, second_ext: ExtensionId, ) -> Result> { - let [case, case_in, case_out] = create_with_io( - hugr, - conditional_node, - op.clone(), - Into::::into(op).dataflow_signature().unwrap(), - )?; + let [case, case_in, case_out] = + create_with_io(hugr, conditional_node, op.clone(), op.inner_signature())?; let lift1 = hugr.add_node_with_parent( case, diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 1ca02b4ac..017407c97 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -85,7 +85,7 @@ pub struct Replacement { } impl NewEdgeSpec { - fn check_src(&self, h: &impl HugrView) -> Result<(), ReplaceError> { + fn check_src(&self, h: &impl HugrView, err_spec: &NewEdgeSpec) -> Result<(), ReplaceError> { let optype = h.get_optype(self.src); let ok = match self.kind { NewEdgeKind::Order => optype.other_output() == Some(EdgeKind::StateOrder), @@ -100,9 +100,9 @@ impl NewEdgeSpec { } }; ok.then_some(()) - .ok_or(ReplaceError::BadEdgeKind(Direction::Outgoing, self.clone())) + .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Outgoing, err_spec.clone())) } - fn check_tgt(&self, h: &impl HugrView) -> Result<(), ReplaceError> { + fn check_tgt(&self, h: &impl HugrView, err_spec: &NewEdgeSpec) -> Result<(), ReplaceError> { let optype = h.get_optype(self.tgt); let ok = match self.kind { NewEdgeKind::Order => optype.other_input() == Some(EdgeKind::StateOrder), @@ -118,7 +118,7 @@ impl NewEdgeSpec { ), }; ok.then_some(()) - .ok_or(ReplaceError::BadEdgeKind(Direction::Incoming, self.clone())) + .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Incoming, err_spec.clone())) } fn check_existing_edge( @@ -233,20 +233,20 @@ impl Rewrite for Replacement { e.clone(), )); } - e.check_src(h)?; + e.check_src(h, e)?; } self.mu_out.iter().try_for_each(|e| { self.replacement.valid_non_root(e.src).map_err(|_| { ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Replacement, e.clone()) })?; - e.check_src(&self.replacement) + e.check_src(&self.replacement, e) })?; // Edge targets... self.mu_inp.iter().try_for_each(|e| { self.replacement.valid_non_root(e.tgt).map_err(|_| { ReplaceError::BadEdgeSpec(Direction::Incoming, WhichHugr::Replacement, e.clone()) })?; - e.check_tgt(&self.replacement) + e.check_tgt(&self.replacement, e) })?; for e in self.mu_out.iter().chain(self.mu_new.iter()) { if !h.contains_node(e.tgt) || removed.contains(&e.tgt) { @@ -256,7 +256,7 @@ impl Rewrite for Replacement { e.clone(), )); } - e.check_tgt(h)?; + e.check_tgt(h, e)?; // The descendant check is to allow the case where the old edge is nonlocal // from a part of the Hugr being moved (which may require changing source, // depending on where the transplanted portion ends up). While this subsumes @@ -353,8 +353,8 @@ fn transfer_edges<'a>( h.valid_node(e.tgt).map_err(|_| { ReplaceError::BadEdgeSpec(Direction::Incoming, WhichHugr::Retained, oe.clone()) })?; - e.check_src(h)?; - e.check_tgt(h)?; + e.check_src(h, oe)?; + e.check_tgt(h, oe)?; match e.kind { NewEdgeKind::Order => { h.add_other_edge(e.src, e.tgt).unwrap(); @@ -820,7 +820,7 @@ mod test { mu_out: vec![new_out_edge.clone()], ..r.clone() }), - Err(ReplaceError::NoRemovedEdge(new_out_edge)) + Err(ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge)) ); Ok(()) } diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index afd0ef5f8..4489cc4fe 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -235,10 +235,6 @@ impl OpTrait for Case { fn tag(&self) -> OpTag { ::TAG } - - fn dataflow_signature(&self) -> Option { - Some(self.signature.clone()) - } } impl Case { @@ -251,6 +247,11 @@ impl Case { pub fn dataflow_output(&self) -> &TypeRow { &self.signature.output } + + /// The signature of the dataflow sibling graph contained in the [`Case`] + pub fn inner_signature(&self) -> FunctionType { + self.signature.clone() + } } fn tuple_sum_first(tuple_sum_row: &TypeRow, rest: &TypeRow) -> TypeRow {