diff --git a/hugr/src/algorithm/nest_cfgs.rs b/hugr/src/algorithm/nest_cfgs.rs index eadd9e5ae..feae6470b 100644 --- a/hugr/src/algorithm/nest_cfgs.rs +++ b/hugr/src/algorithm/nest_cfgs.rs @@ -604,7 +604,7 @@ pub(crate) mod test { // /-> left --\ // entry -> split > merge -> head -> tail -> exit // \-> right -/ \-<--<-/ - let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; + let mut cfg_builder = CFGBuilder::new(FunctionType::new_endo(NAT))?; let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2")); let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); @@ -615,7 +615,15 @@ pub(crate) mod test { )?; let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?; cfg_builder.branch(&entry, 0, &split)?; - let (head, tail) = build_loop(&mut cfg_builder, &pred_const, &const_unit)?; + let head = n_identity( + cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 1)?, + &const_unit, + )?; + let tail = n_identity( + cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 2)?, + &pred_const, + )?; + cfg_builder.branch(&tail, 1, &head)?; cfg_builder.branch(&head, 0, &tail)?; // trivial "loop body" cfg_builder.branch(&merge, 0, &head)?; let exit = cfg_builder.exit_block(); @@ -671,12 +679,9 @@ pub(crate) mod test { #[test] fn test_cond_then_loop_combined() -> Result<(), BuildError> { - // /-> left --\ - // entry > merge -> tail -> exit - // \-> right -/ \-<--<-/ // Here we would like two consecutive regions, but there is no *edge* between // the conditional and the loop to indicate the boundary, so we cannot separate them. - let (h, merge, tail) = build_cond_then_loop_cfg(false)?; + let (h, merge, tail) = build_cond_then_loop_cfg()?; let (merge, tail) = (merge.node(), tail.node()); let [entry, exit]: [Node; 2] = h .children(h.root()) @@ -824,7 +829,7 @@ pub(crate) mod test { unit_const: &ConstID, ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { let split = n_identity( - cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 2)?, + cfg.simple_block_builder(FunctionType::new_endo(NAT), 2)?, const_pred, )?; let merge = build_then_else_merge_from_if(cfg, unit_const, split)?; @@ -837,15 +842,15 @@ pub(crate) mod test { split: BasicBlockID, ) -> Result { let merge = n_identity( - cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?, + cfg.simple_block_builder(FunctionType::new_endo(NAT), 1)?, unit_const, )?; let left = n_identity( - cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?, + cfg.simple_block_builder(FunctionType::new_endo(NAT), 1)?, unit_const, )?; let right = n_identity( - cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?, + cfg.simple_block_builder(FunctionType::new_endo(NAT), 1)?, unit_const, )?; cfg.branch(&split, 0, &left)?; @@ -855,39 +860,12 @@ pub(crate) mod test { Ok(merge) } - // Returns loop tail - caller must link header to tail, and provide 0th successor of tail - fn build_loop_from_header + AsRef>( - cfg: &mut CFGBuilder, - const_pred: &ConstID, - header: BasicBlockID, - ) -> Result { - let tail = n_identity( - cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 2)?, - const_pred, - )?; - cfg.branch(&tail, 1, &header)?; - Ok(tail) - } - - // Result is header and tail. Caller must provide 0th successor of header (linking to tail), and 0th successor of tail. - fn build_loop + AsRef>( - cfg: &mut CFGBuilder, - const_pred: &ConstID, - unit_const: &ConstID, - ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { - let header = n_identity( - cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?, - unit_const, - )?; - let tail = build_loop_from_header(cfg, const_pred, header)?; - Ok((header, tail)) - } - - // Result is merge and tail; loop header is (merge, if separate==true; unique successor of merge, if separate==false) - pub fn build_cond_then_loop_cfg( - separate: bool, - ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { - let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; + // /-> left --\ + // entry > merge -> tail -> exit + // \-> right -/ \-<--<-/ + // Result is Hugr plus merge and tail blocks + fn build_cond_then_loop_cfg() -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { + let mut cfg_builder = CFGBuilder::new(FunctionType::new_endo(NAT))?; let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2")); let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); @@ -896,19 +874,13 @@ pub(crate) mod test { &pred_const, )?; let merge = build_then_else_merge_from_if(&mut cfg_builder, &const_unit, entry)?; - let head = if separate { - let h = n_identity( - cfg_builder - .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?, - &const_unit, - )?; - cfg_builder.branch(&merge, 0, &h)?; - h - } else { - merge - }; - let tail = build_loop_from_header(&mut cfg_builder, &pred_const, head)?; - cfg_builder.branch(&head, 0, &tail)?; // trivial "loop body" + // The merge block is also the loop header (so it merges three incoming control-flow edges) + let tail = n_identity( + cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 2)?, + &pred_const, + )?; + cfg_builder.branch(&tail, 1, &merge)?; + cfg_builder.branch(&merge, 0, &tail)?; // trivial "loop body" let exit = cfg_builder.exit_block(); cfg_builder.branch(&tail, 0, &exit)?; @@ -920,7 +892,7 @@ pub(crate) mod test { pub(crate) fn build_conditional_in_loop_cfg( separate_headers: bool, ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { - let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; + let mut cfg_builder = CFGBuilder::new(FunctionType::new_endo(NAT))?; let (head, tail) = build_conditional_in_loop(&mut cfg_builder, separate_headers)?; let h = cfg_builder.finish_prelude_hugr()?; Ok((h, head, tail)) @@ -939,15 +911,22 @@ pub(crate) mod test { )?; let (split, merge) = build_if_then_else_merge(cfg_builder, &pred_const, &const_unit)?; - let (head, tail) = if separate_headers { - let (head, tail) = build_loop(cfg_builder, &pred_const, &const_unit)?; + let head = if separate_headers { + let head = n_identity( + cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 1)?, + &const_unit, + )?; cfg_builder.branch(&head, 0, &split)?; - (head, tail) + head } else { // Combine loop header with split. - let tail = build_loop_from_header(cfg_builder, &pred_const, split)?; - (split, tail) + split }; + let tail = n_identity( + cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 2)?, + &pred_const, + )?; + cfg_builder.branch(&tail, 1, &head)?; cfg_builder.branch(&merge, 0, &tail)?; let exit = cfg_builder.exit_block(); diff --git a/hugr/src/hugr/rewrite/outline_cfg.rs b/hugr/src/hugr/rewrite/outline_cfg.rs index 5cbe0f1b3..a6badb72d 100644 --- a/hugr/src/hugr/rewrite/outline_cfg.rs +++ b/hugr/src/hugr/rewrite/outline_cfg.rs @@ -230,7 +230,8 @@ pub enum OutlineCfgError { /// Multiple blocks had incoming edges #[error("Multiple blocks had predecessors outside the set - at least {0:?} and {1:?}")] MultipleEntryNodes(Node, Node), - /// Multiple blocks had outgoing edegs + /// Multiple blocks had outgoing edges + // Note possible TODO: straightforward if all outgoing edges target the same BB #[error("Multiple blocks had edges leaving the set - at least {0:?} and {1:?}")] MultipleExitNodes(Node, Node), /// One block had multiple outgoing edges @@ -239,7 +240,7 @@ pub enum OutlineCfgError { /// No block was identified as an entry block #[error("No block had predecessors outside the set")] NoEntryNode, - /// No block was identified as an exit block + /// No block was found with an edge leaving the set (so, must be an infinite loop) #[error("No block had a successor outside the set")] NoExitNode, } @@ -248,106 +249,196 @@ pub enum OutlineCfgError { mod test { use std::collections::HashSet; - use crate::algorithm::nest_cfgs::test::{ - build_cond_then_loop_cfg, build_conditional_in_loop, build_conditional_in_loop_cfg, depth, - }; use crate::builder::{ - Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer, + BlockBuilder, BuildError, CFGBuilder, Container, Dataflow, DataflowSubContainer, + HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::USIZE_T; - use crate::extension::PRELUDE_REGISTRY; + use crate::extension::{ExtensionSet, PRELUDE_REGISTRY}; use crate::hugr::views::sibling::SiblingMut; use crate::hugr::HugrMut; - use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle}; + use crate::ops::constant::Value; + use crate::ops::handle::{BasicBlockID, CfgID, ConstID, NodeHandle}; use crate::types::FunctionType; - use crate::{type_row, HugrView, Node}; + use crate::{type_row, Hugr, HugrView, Node}; use cool_asserts::assert_matches; use itertools::Itertools; + use rstest::rstest; use super::{OutlineCfg, OutlineCfgError}; - #[test] - fn test_outline_cfg_errors() { - let (mut h, head, tail) = build_conditional_in_loop_cfg(false).unwrap(); - let head = head.node(); - let tail = tail.node(); - // /-> left --\ - // entry -> head > merge -> tail -> exit - // | \-> right -/ | - // \---<---<---<---<---<--<---/ - // merge is unique predecessor of tail - let merge = h.input_neighbours(tail).exactly_one().unwrap(); - h.validate(&PRELUDE_REGISTRY).unwrap(); + /// /-> left --\ + /// entry > merge -> head -> tail -> exit + /// \-> right -/ \-<--<-/ + struct CondThenLoopCfg { + h: Hugr, + left: Node, + right: Node, + merge: Node, + head: Node, + tail: Node, + } + impl CondThenLoopCfg { + fn new() -> Result { + let block_ty = FunctionType::new_endo(USIZE_T); + let mut cfg_builder = CFGBuilder::new(block_ty.clone())?; + let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2")); + let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); + fn n_identity( + mut bbldr: BlockBuilder<&mut Hugr>, + cst: &ConstID, + ) -> Result { + let pred = bbldr.load_const(cst); + let vals = bbldr.input_wires(); + bbldr.finish_with_outputs(pred, vals) + } + let id_block = |c: &mut CFGBuilder<_>| { + n_identity(c.simple_block_builder(block_ty.clone(), 1)?, &const_unit) + }; + + let entry = n_identity( + cfg_builder.simple_entry_builder(USIZE_T.into(), 2, ExtensionSet::new())?, + &pred_const, + )?; + + let left = id_block(&mut cfg_builder)?; + let right = id_block(&mut cfg_builder)?; + cfg_builder.branch(&entry, 0, &left)?; + cfg_builder.branch(&entry, 1, &right)?; + + let merge = id_block(&mut cfg_builder)?; + cfg_builder.branch(&left, 0, &merge)?; + cfg_builder.branch(&right, 0, &merge)?; + + let head = id_block(&mut cfg_builder)?; + cfg_builder.branch(&merge, 0, &head)?; + let tail = n_identity( + cfg_builder.simple_block_builder(FunctionType::new_endo(USIZE_T), 2)?, + &pred_const, + )?; + cfg_builder.branch(&tail, 1, &head)?; + cfg_builder.branch(&head, 0, &tail)?; // trivial "loop body" + let exit = cfg_builder.exit_block(); + cfg_builder.branch(&tail, 0, &exit)?; + + let h = cfg_builder.finish_prelude_hugr()?; + let (left, right) = (left.node(), right.node()); + let (merge, head, tail) = (merge.node(), head.node(), tail.node()); + Ok(Self { + h, + left, + right, + merge, + head, + tail, + }) + } + fn entry_exit(&self) -> (Node, Node) { + self.h + .children(self.h.root()) + .take(2) + .collect_tuple() + .unwrap() + } + } + + #[rstest::fixture] + fn cond_then_loop_cfg() -> CondThenLoopCfg { + CondThenLoopCfg::new().unwrap() + } + + #[rstest] + fn test_outline_cfg_errors(cond_then_loop_cfg: CondThenLoopCfg) { + let (entry, _) = cond_then_loop_cfg.entry_exit(); + let CondThenLoopCfg { + mut h, + left, + right, + merge, + head, + tail, + } = cond_then_loop_cfg; let backup = h.clone(); - let r = h.apply_rewrite(OutlineCfg::new([merge, tail])); + + let r = h.apply_rewrite(OutlineCfg::new([tail])); assert_matches!(r, Err(OutlineCfgError::MultipleExitEdges(_, _))); assert_eq!(h, backup); - let [left, right]: [Node; 2] = h.output_neighbours(head).collect_vec().try_into().unwrap(); - let r = h.apply_rewrite(OutlineCfg::new([left, right, head])); - assert_matches!(r, Err(OutlineCfgError::MultipleExitNodes(a,b)) => assert_eq!(HashSet::from([a,b]), HashSet::from_iter([left, right]))); + let r = h.apply_rewrite(OutlineCfg::new([entry, left, right])); + assert_matches!(r, Err(OutlineCfgError::MultipleExitNodes(a,b)) + => assert_eq!(HashSet::from([a,b]), HashSet::from_iter([left, right]))); assert_eq!(h, backup); let r = h.apply_rewrite(OutlineCfg::new([left, right, merge])); - assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b)) => assert_eq!(HashSet::from([a,b]), HashSet::from([left, right]))); + assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b)) + => assert_eq!(HashSet::from([a,b]), HashSet::from([left, right]))); assert_eq!(h, backup); - } - #[test] - fn test_outline_cfg() { - let (mut h, head, tail) = build_conditional_in_loop_cfg(false).unwrap(); - h.update_validate(&PRELUDE_REGISTRY).unwrap(); - do_outline_cfg_test(&mut h, head, tail, 1); - h.update_validate(&PRELUDE_REGISTRY).unwrap(); + // The entry node implicitly has an extra incoming edge + let r = h.apply_rewrite(OutlineCfg::new([entry, left, right, merge, head])); + assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b)) + => assert_eq!(HashSet::from([a,b]), HashSet::from([entry, head]))); + assert_eq!(h, backup); } - fn do_outline_cfg_test( - h: &mut impl HugrMut, - head: BasicBlockID, - tail: BasicBlockID, - expected_depth: u32, - ) { - let head = head.node(); - let tail = tail.node(); - let parent = h.get_parent(head).unwrap(); - let [entry, exit]: [Node; 2] = h.children(parent).take(2).collect_vec().try_into().unwrap(); - // /-> left --\ - // entry -> head > merge -> tail -> exit - // | \-> right -/ | - // \---<---<---<---<---<--<---/ - // merge is unique predecessor of tail - let merge = h.input_neighbours(tail).exactly_one().ok().unwrap(); - let [left, right]: [Node; 2] = h.output_neighbours(head).collect_vec().try_into().unwrap(); - for n in [head, tail, merge] { - assert_eq!(depth(h.base_hugr(), n), expected_depth); - } - let blocks = [head, left, right, merge]; - let (new_block, new_cfg) = h.apply_rewrite(OutlineCfg::new(blocks)).unwrap(); - for n in blocks { - assert_eq!(depth(h.base_hugr(), n), expected_depth + 2); - } + #[rstest::rstest] + fn test_outline_cfg(cond_then_loop_cfg: CondThenLoopCfg) { + // Outline the loop, producing: + // /-> left -->\ + // entry merge -> newblock -> exit + // \-> right ->/ + let (_, exit) = cond_then_loop_cfg.entry_exit(); + let CondThenLoopCfg { + mut h, + merge, + head, + tail, + .. + } = cond_then_loop_cfg; + let root = h.root(); + let (new_block, _, exit_block) = outline_cfg_check_parents(&mut h, root, vec![head, tail]); + assert_eq!(h.output_neighbours(merge).collect_vec(), vec![new_block]); + assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]); assert_eq!( - new_block, - h.output_neighbours(entry).exactly_one().ok().unwrap() + h.output_neighbours(tail).collect::>(), + HashSet::from([head, exit_block]) ); - for n in [entry, exit, tail, new_block] { - assert_eq!(depth(h.base_hugr(), n), expected_depth); - } + } + + #[rstest] + fn test_outline_cfg_multiple_in_edges(cond_then_loop_cfg: CondThenLoopCfg) { + // Outline merge, head and tail, producing + // /-> left -->\ + // entry newblock -> exit + // \-> right ->/ + let (_, exit) = cond_then_loop_cfg.entry_exit(); + let CondThenLoopCfg { + mut h, + left, + right, + merge, + head, + tail, + } = cond_then_loop_cfg; + + let root = h.root(); + let (new_block, _, inner_exit) = + outline_cfg_check_parents(&mut h, root, vec![merge, head, tail]); + assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]); assert_eq!( - h.input_neighbours(tail).exactly_one().ok().unwrap(), - new_block + h.input_neighbours(new_block).collect::>(), + HashSet::from([left, right]) ); assert_eq!( - h.output_neighbours(tail).take(2).collect::>(), - HashSet::from([exit, new_block]) + h.output_neighbours(tail).collect::>(), + HashSet::from([head, inner_exit]) ); - assert!(h.get_optype(new_block).is_dataflow_block()); - assert_eq!(h.base_hugr().get_parent(new_cfg), Some(new_block)); - assert!(h.base_hugr().get_optype(new_cfg).is_cfg()); } - #[test] - fn test_outline_cfg_subregion() { + #[rstest] + fn test_outline_cfg_subregion(cond_then_loop_cfg: CondThenLoopCfg) { + // Outline the loop, as above, but with the CFG inside a Function + Module, + // operating via a SiblingMut let mut module_builder = ModuleBuilder::new(); let mut fbuild = module_builder .define_function( @@ -356,57 +447,77 @@ mod test { ) .unwrap(); let [i1] = fbuild.input_wires_arr(); - let mut cfg_builder = fbuild - .cfg_builder( - [(USIZE_T, i1)], - None, - type_row![USIZE_T], - Default::default(), - ) + let cfg = fbuild + .add_hugr_with_wires(cond_then_loop_cfg.h, [i1]) .unwrap(); - let (head, tail) = build_conditional_in_loop(&mut cfg_builder, false).unwrap(); - let cfg = cfg_builder.finish_sub_container().unwrap(); fbuild.finish_with_outputs(cfg.outputs()).unwrap(); let mut h = module_builder.finish_prelude_hugr().unwrap(); - do_outline_cfg_test( - &mut SiblingMut::<'_, CfgID>::try_new(&mut h, cfg.node()).unwrap(), - head, - tail, - 3, + // `add_hugr_with_wires` does not return an InsertionResult, so recover the nodes manually: + let cfg = cfg.node(); + let exit_node = h.children(cfg).nth(1).unwrap(); + let tail = h.input_neighbours(exit_node).exactly_one().unwrap(); + let head = h.input_neighbours(tail).exactly_one().unwrap(); + // Just sanity-check we have the correct nodes + assert!(h.get_optype(exit_node).is_exit_block()); + assert_eq!( + h.output_neighbours(tail).collect::>(), + HashSet::from([head, exit_node]) ); + outline_cfg_check_parents( + &mut SiblingMut::<'_, CfgID>::try_new(&mut h, cfg).unwrap(), + cfg, + vec![head, tail], + ); + h.update_validate(&PRELUDE_REGISTRY).unwrap(); } - #[test] - fn test_outline_cfg_move_entry() { - // /-> left --\ - // entry > merge -> head -> tail -> exit - // \-> right -/ \-<--<-/ - let (mut h, merge, tail) = build_cond_then_loop_cfg(true).unwrap(); - - let (entry, exit) = h.children(h.root()).take(2).collect_tuple().unwrap(); - let (left, right) = h.output_neighbours(entry).take(2).collect_tuple().unwrap(); - let (merge, tail) = (merge.node(), tail.node()); - let head = h.output_neighbours(merge).exactly_one().unwrap(); - - h.validate(&PRELUDE_REGISTRY).unwrap(); - let blocks_to_move = [entry, left, right, merge]; - let other_blocks = [head, tail, exit]; - for &n in blocks_to_move.iter().chain(other_blocks.iter()) { - assert_eq!(depth(&h, n), 1); - } - let (new_block, new_cfg) = h - .apply_rewrite(OutlineCfg::new(blocks_to_move.iter().copied())) - .unwrap(); + #[rstest] + fn test_outline_cfg_move_entry(cond_then_loop_cfg: CondThenLoopCfg) { + // Outline the conditional, producing + // + // newblock -> head -> tail -> exit + // \<--, + ) -> (Node, Node, Node) { + let mut other_blocks = h.children(cfg).collect::>(); + assert!(blocks.iter().all(|b| other_blocks.remove(b))); + let (new_block, new_cfg) = h.apply_rewrite(OutlineCfg::new(blocks.clone())).unwrap(); + for n in other_blocks { - assert_eq!(depth(&h, n), 1); + assert_eq!(h.get_parent(n), Some(cfg)) } - for n in blocks_to_move { - assert_eq!(h.get_parent(n).unwrap(), new_cfg); + assert_eq!(h.get_parent(new_block), Some(cfg)); + assert!(h.get_optype(new_block).is_dataflow_block()); + let b = h.base_hugr(); // To cope with `h` potentially being a SiblingMut + assert_eq!(b.get_parent(new_cfg), Some(new_block)); + for n in blocks { + assert_eq!(b.get_parent(n), Some(new_cfg)); } + assert!(b.get_optype(new_cfg).is_cfg()); + let exit_block = b.children(new_cfg).nth(1).unwrap(); + assert!(b.get_optype(exit_block).is_exit_block()); + (new_block, new_cfg, exit_block) } }