From 10f16f6c6150bd3172cab07ba02936e2f0dc5a15 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Wed, 22 May 2024 11:38:15 +0100 Subject: [PATCH 01/28] Create new hugr-passes crate, copying three files from hugr/src/algorithms. --- Cargo.toml | 2 +- hugr-passes/Cargo.toml | 13 + hugr-passes/src/const_fold.rs | 551 ++++++++++++++++++++ hugr-passes/src/lib.rs | 3 + hugr-passes/src/merge_bbs.rs | 398 ++++++++++++++ hugr-passes/src/nest_cfgs.rs | 946 ++++++++++++++++++++++++++++++++++ 6 files changed, 1912 insertions(+), 1 deletion(-) create mode 100644 hugr-passes/Cargo.toml create mode 100644 hugr-passes/src/const_fold.rs create mode 100644 hugr-passes/src/lib.rs create mode 100644 hugr-passes/src/merge_bbs.rs create mode 100644 hugr-passes/src/nest_cfgs.rs diff --git a/Cargo.toml b/Cargo.toml index 1d86879ab..3063ea616 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ lto = "thin" [workspace] resolver = "2" -members = ["hugr"] +members = ["hugr", "hugr-passes"] default-members = ["hugr"] [workspace.package] diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml new file mode 100644 index 000000000..6fe4b8607 --- /dev/null +++ b/hugr-passes/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "hugr-passes" +version = "0.1.0" +edition = "2021" + +[dependencies] +hugr = { path = "../hugr" } +itertools = "0.12.0" +paste = "1.0" +thiserror = "1.0.28" + +[dev-dependencies] +rstest = "0.19.0" diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs new file mode 100644 index 000000000..5d4cffed3 --- /dev/null +++ b/hugr-passes/src/const_fold.rs @@ -0,0 +1,551 @@ +//! Constant folding routines. + +use std::collections::{BTreeSet, HashMap}; + +use itertools::Itertools; +use thiserror::Error; + +use crate::hugr::{SimpleReplacementError, ValidationError}; +use crate::types::SumType; +use crate::Direction; +use crate::{ + builder::{DFGBuilder, Dataflow, DataflowHugr}, + extension::{ConstFoldResult, ExtensionRegistry}, + hugr::{ + rewrite::consts::{RemoveConst, RemoveLoadConstant}, + views::SiblingSubgraph, + HugrMut, + }, + ops::{OpType, Value}, + type_row, + types::FunctionType, + Hugr, HugrView, IncomingPort, Node, SimpleReplacement, +}; + +#[derive(Error, Debug)] +#[allow(missing_docs)] +pub enum ConstFoldError { + #[error("Failed to verify {label} HUGR: {err}")] + VerifyError { + label: String, + #[source] + err: ValidationError, + }, + #[error(transparent)] + SimpleReplaceError(#[from] SimpleReplacementError), +} + +/// Tag some output constants with [`OutgoingPort`] inferred from the ordering. +fn out_row(consts: impl IntoIterator) -> ConstFoldResult { + let vec = consts + .into_iter() + .enumerate() + .map(|(i, c)| (i.into(), c)) + .collect(); + Some(vec) +} + +/// Sort folding inputs with [`IncomingPort`] as key +fn sort_by_in_port(consts: &[(IncomingPort, Value)]) -> Vec<&(IncomingPort, Value)> { + let mut v: Vec<_> = consts.iter().collect(); + v.sort_by_key(|(i, _)| i); + v +} + +/// Sort some input constants by port and just return the constants. +pub(crate) fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> { + sort_by_in_port(consts) + .into_iter() + .map(|(_, c)| c) + .collect() +} + +/// For a given op and consts, attempt to evaluate the op. +pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldResult { + let fold_result = match op { + OpType::Noop { .. } => out_row([consts.first()?.1.clone()]), + OpType::MakeTuple { .. 
} => { + out_row([Value::tuple(sorted_consts(consts).into_iter().cloned())]) + } + OpType::UnpackTuple { .. } => { + let c = &consts.first()?.1; + let Value::Tuple { vs } = c else { + panic!("This op always takes a Tuple input."); + }; + out_row(vs.iter().cloned()) + } + + OpType::Tag(t) => out_row([Value::sum( + t.tag, + consts.iter().map(|(_, konst)| konst.clone()), + SumType::new(t.variants.clone()), + ) + .unwrap()]), + OpType::CustomOp(op) => { + let ext_op = op.as_extension_op()?; + ext_op.constant_fold(consts) + } + _ => None, + }; + debug_assert!(fold_result.as_ref().map_or(true, |x| x.len() + == op.value_port_count(Direction::Outgoing))); + fold_result +} + +/// Generate a graph that loads and outputs `consts` in order, validating +/// against `reg`. +fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { + let const_types = consts.iter().map(Value::get_type).collect_vec(); + let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap(); + + let outputs = consts + .into_iter() + .map(|c| b.add_load_const(c)) + .collect_vec(); + + b.finish_hugr_with_outputs(outputs, reg).unwrap() +} + +/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`, +/// return an iterator of possible constant folding rewrites. The +/// [`SimpleReplacement`] replaces an operation with constants that result from +/// evaluating it, the extension registry `reg` is used to validate the +/// replacement HUGR. The vector of [`RemoveLoadConstant`] refer to the +/// LoadConstant nodes that could be removed - they are not automatically +/// removed as they may be used by other operations. +pub fn find_consts<'a, 'r: 'a>( + hugr: &'a impl HugrView, + candidate_nodes: impl IntoIterator + 'a, + reg: &'r ExtensionRegistry, +) -> impl Iterator)> + 'a { + // track nodes for operations that have already been considered for folding + let mut used_neighbours = BTreeSet::new(); + + candidate_nodes + .into_iter() + .filter_map(move |n| { + // only look at LoadConstant + hugr.get_optype(n).is_load_constant().then_some(())?; + + let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?; + let neighbours = hugr + .linked_inputs(n, out_p) + .filter(|(n, _)| used_neighbours.insert(*n)) + .collect_vec(); + if neighbours.is_empty() { + // no uses of LoadConstant that haven't already been considered. + return None; + } + let fold_iter = neighbours + .into_iter() + .filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg)); + Some(fold_iter) + }) + .flatten() +} + +/// Attempt to evaluate and generate rewrites for the operation at `op_node` +fn fold_op( + hugr: &impl HugrView, + op_node: Node, + reg: &ExtensionRegistry, +) -> Option<(SimpleReplacement, Vec)> { + // only support leaf folding for now. + let neighbour_op = hugr.get_optype(op_node); + let (in_consts, removals): (Vec<_>, Vec<_>) = hugr + .node_inputs(op_node) + .filter_map(|in_p| { + let (con_op, load_n) = get_const(hugr, op_node, in_p)?; + Some(((in_p, con_op), RemoveLoadConstant(load_n))) + }) + .unzip(); + // attempt to evaluate op + let (nu_out, consts): (HashMap<_, _>, Vec<_>) = fold_leaf_op(neighbour_op, &in_consts)? 
+ .into_iter() + .enumerate() + .filter_map(|(i, (op_out, konst))| { + // for each used port of the op give the nu_out entry and the + // corresponding Value + hugr.single_linked_input(op_node, op_out) + .map(|np| ((np, i.into()), konst)) + }) + .unzip(); + let replacement = const_graph(consts, reg); + let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) + .expect("Operation should form valid subgraph."); + + let simple_replace = SimpleReplacement::new( + sibling_graph, + replacement, + // no inputs to replacement + HashMap::new(), + nu_out, + ); + Some((simple_replace, removals)) +} + +/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant +/// and the LoadConstant node +fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Value, Node)> { + let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; + let load_op = hugr.get_optype(load_n).as_load_constant()?; + let const_node = hugr + .single_linked_output(load_n, load_op.constant_port())? + .0; + let const_op = hugr.get_optype(const_node).as_const()?; + + // TODO avoid const clone here + Some((const_op.as_ref().clone(), load_n)) +} + +/// Exhaustively apply constant folding to a HUGR. +pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { + #[cfg(test)] + let verify = |label, h: &H| { + h.validate_no_extensions(reg).unwrap_or_else(|err| { + panic!( + "constant_fold_pass: failed to verify {label} HUGR: {err}\n{}", + h.mermaid_string() + ) + }) + }; + #[cfg(test)] + verify("input", h); + loop { + // We can only safely apply a single replacement. Applying a + // replacement removes nodes and edges which may be referenced by + // further replacements returned by find_consts. Even worse, if we + // attempted to apply those replacements, expecting them to fail if + // the nodes and edges they reference had been deleted, they may + // succeed because new nodes and edges reused the ids. + // + // We could be a lot smarter here, keeping track of `LoadConstant` + // nodes and only looking at their out neighbours. + let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else { + break; + }; + h.apply_rewrite(replace).unwrap(); + for rem in removes { + // We are optimistically applying these [RemoveLoadConstant] and + // [RemoveConst] rewrites without checking whether the nodes + // they attempt to remove have remaining uses. If they do, then + // the rewrite fails and we move on. + if let Ok(const_node) = h.apply_rewrite(rem) { + // if the LoadConst was removed, try removing the Const too. 
+ let _ = h.apply_rewrite(RemoveConst(const_node)); + } + } + } + #[cfg(test)] + verify("output", h); +} + +#[cfg(test)] +mod test { + + use super::*; + use crate::extension::prelude::{sum_with_error, BOOL_T}; + use crate::extension::{ExtensionRegistry, PRELUDE}; + use crate::ops::{OpType, UnpackTuple}; + use crate::std_extensions::arithmetic; + use crate::std_extensions::arithmetic::conversions::ConvertOpDef; + use crate::std_extensions::arithmetic::float_ops::FloatOps; + use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; + use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; + use crate::std_extensions::logic::{self, NaryLogic, NotOp}; + use crate::utils::test::{assert_fully_folded, assert_fully_folded_with}; + + use rstest::rstest; + + /// int to constant + fn i2c(b: u64) -> Value { + Value::extension(ConstInt::new_u(5, b).unwrap()) + } + + /// float to constant + fn f2c(f: f64) -> Value { + ConstF64::new(f).into() + } + + #[rstest] + #[case(0.0, 0.0, 0.0)] + #[case(0.0, 1.0, 1.0)] + #[case(23.5, 435.5, 459.0)] + // c = a + b + fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { + let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; + let add_op: OpType = FloatOps::fadd.into(); + let outs = fold_leaf_op(&add_op, &consts) + .unwrap() + .into_iter() + .map(|(p, v)| (p, v.get_custom_value::().unwrap().value())) + .collect_vec(); + + assert_eq!(outs.as_slice(), &[(0.into(), c)]); + } + #[test] + fn test_big() { + /* + Test approximately calculates + let x = (5.6, 3.2); + int(x.0 - x.1) == 2 + */ + let sum_type = sum_with_error(INT_TYPES[5].to_owned()); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![sum_type.clone().into()], + )) + .unwrap(); + + let tup = build.add_load_const(Value::tuple([f2c(5.6), f2c(3.2)])); + + let unpack = build + .add_dataflow_op( + UnpackTuple { + tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE], + }, + [tup], + ) + .unwrap(); + + let sub = build + .add_dataflow_op(FloatOps::fsub, unpack.outputs()) + .unwrap(); + let to_int = build + .add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs()) + .unwrap(); + + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + arithmetic::float_ops::EXTENSION.to_owned(), + arithmetic::conversions::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build + .finish_hugr_with_outputs(to_int.outputs(), ®) + .unwrap(); + assert_eq!(h.node_count(), 8); + + constant_fold_pass(&mut h, ®); + + let expected = Value::sum(0, [i2c(2).clone()], sum_type).unwrap(); + assert_fully_folded(&h, &expected); + } + + #[test] + #[cfg_attr( + feature = "extension_inference", + ignore = "inference fails for test graph, it shouldn't" + )] + fn test_list_ops() -> Result<(), Box> { + use crate::std_extensions::collections::{self, ListOp, ListValue}; + + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + logic::EXTENSION.to_owned(), + collections::EXTENSION.to_owned(), + ]) + .unwrap(); + let list: Value = ListValue::new(BOOL_T, [Value::unit_sum(0, 1).unwrap()]).into(); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![list.get_type().clone()], + )) + .unwrap(); + + let list_wire = build.add_load_const(list.clone()); + + let pop = build.add_dataflow_op( + ListOp::Pop.with_type(BOOL_T).to_extension_op(®).unwrap(), + [list_wire], + )?; + + let push = build.add_dataflow_op( + ListOp::Push + .with_type(BOOL_T) + 
.to_extension_op(®) + .unwrap(), + pop.outputs(), + )?; + let mut h = build.finish_hugr_with_outputs(push.outputs(), ®)?; + constant_fold_pass(&mut h, ®); + + assert_fully_folded(&h, &list); + Ok(()) + } + + #[test] + fn test_fold_and() { + // pseudocode: + // x0, x1 := bool(true), bool(true) + // x2 := and(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_load_const(Value::true_val()); + let x2 = build + .add_dataflow_op(NaryLogic::And.with_n_inputs(2), [x0, x1]) + .unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); + } + + #[test] + fn test_fold_or() { + // pseudocode: + // x0, x1 := bool(true), bool(false) + // x2 := or(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_load_const(Value::false_val()); + let x2 = build + .add_dataflow_op(NaryLogic::Or.with_n_inputs(2), [x0, x1]) + .unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); + } + + #[test] + fn test_fold_not() { + // pseudocode: + // x0 := bool(true) + // x1 := not(x0) + // output x1 == false; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_dataflow_op(NotOp, [x0]).unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::false_val(); + assert_fully_folded(&h, &expected); + } + + #[test] + fn orphan_output() { + // pseudocode: + // x0 := bool(true) + // x1 := not(x0) + // x2 := or(x0,x1) + // output x2 == true; + // + // We arange things so that the `or` folds away first, leaving the not + // with no outputs. + use crate::hugr::NodeType; + use crate::ops::handle::NodeHandle; + + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let true_wire = build.add_load_value(Value::true_val()); + // this Not will be manually replaced + let orig_not = build.add_dataflow_op(NotOp, [true_wire]).unwrap(); + let r = build + .add_dataflow_op( + NaryLogic::Or.with_n_inputs(2), + [true_wire, orig_not.out_wire(0)], + ) + .unwrap(); + let or_node = r.node(); + let parent = build.dfg_node; + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(r.outputs(), ®).unwrap(); + + // we delete the original Not and create a new One. This means it will be + // traversed by `constant_fold_pass` after the Or. 
+ let new_not = h.add_node_with_parent(parent, NodeType::new_auto(NotOp)); + h.connect(true_wire.node(), true_wire.source(), new_not, 0); + h.disconnect(or_node, IncomingPort::from(1)); + h.connect(new_not, 0, or_node, 1); + h.remove_node(orig_not.node()); + constant_fold_pass(&mut h, ®); + assert_fully_folded(&h, &Value::true_val()) + } + + #[test] + fn test_folding_pass_issue_996() { + // pseudocode: + // + // x0 := 3.0 + // x1 := 4.0 + // x2 := fne(x0, x1); // true + // x3 := flt(x0, x1); // true + // x4 := and(x2, x3); // true + // x5 := -10.0 + // x6 := flt(x0, x5) // false + // x7 := or(x4, x6) // true + // output x7 + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0))); + let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap(); + let x3 = build.add_dataflow_op(FloatOps::flt, [x0, x1]).unwrap(); + let x4 = build + .add_dataflow_op( + NaryLogic::And.with_n_inputs(2), + x2.outputs().chain(x3.outputs()), + ) + .unwrap(); + let x5 = build.add_load_const(Value::extension(ConstF64::new(-10.0))); + let x6 = build.add_dataflow_op(FloatOps::flt, [x0, x5]).unwrap(); + let x7 = build + .add_dataflow_op( + NaryLogic::Or.with_n_inputs(2), + x4.outputs().chain(x6.outputs()), + ) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + logic::EXTENSION.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); + } + + #[test] + fn test_const_fold_to_nonfinite() { + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + ]) + .unwrap(); + + // HUGR computing 1.0 / 1.0 + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); + let mut h0 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h0, ®); + assert_fully_folded_with(&h0, |v| { + v.get_custom_value::().unwrap().value() == 1.0 + }); + assert_eq!(h0.node_count(), 5); + + // HUGR computing 1.0 / 0.0 + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0))); + let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); + let mut h1 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h1, ®); + assert_eq!(h1.node_count(), 8); + } +} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs new file mode 100644 index 000000000..2ed5859f0 --- /dev/null +++ b/hugr-passes/src/lib.rs @@ -0,0 +1,3 @@ +pub mod const_fold; +pub mod merge_bbs; +pub mod nest_cfgs; diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs new file mode 100644 index 000000000..06d3b3bfe --- /dev/null +++ b/hugr-passes/src/merge_bbs.rs @@ -0,0 +1,398 @@ +//! Merge BBs along control-flow edges where the source BB has no other successors +//! and the target BB has no other predecessors. 
+use std::collections::HashMap; + +use itertools::Itertools; + +use crate::hugr::rewrite::inline_dfg::InlineDFG; +use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; +use crate::hugr::{HugrMut, RootTagged}; +use crate::ops::handle::CfgID; +use crate::ops::leaf::UnpackTuple; +use crate::ops::{DataflowBlock, DataflowParent, Input, Output, DFG}; +use crate::{Hugr, HugrView, Node}; + +/// Merge any basic blocks that are direct children of the specified CFG +/// i.e. where a basic block B has a single successor B' whose only predecessor +/// is B, B and B' can be combined. +pub fn merge_basic_blocks(cfg: &mut impl HugrMut) { + let mut worklist = cfg.nodes().collect::>(); + while let Some(n) = worklist.pop() { + // Consider merging n with its successor + let Ok(succ) = cfg.output_neighbours(n).exactly_one() else { + continue; + }; + if cfg.input_neighbours(succ).count() != 1 { + continue; + }; + if cfg.children(cfg.root()).take(2).contains(&succ) { + // If succ is... + // - the entry block, that has an implicit extra in-edge, so cannot merge with n. + // - the exit block, nodes in n should move *outside* the CFG - a separate pass. + continue; + }; + let (rep, merge_bb, dfgs) = mk_rep(cfg, n, succ); + let node_map = cfg.hugr_mut().apply_rewrite(rep).unwrap(); + let merged_bb = *node_map.get(&merge_bb).unwrap(); + for dfg_id in dfgs { + let n_id = *node_map.get(&dfg_id).unwrap(); + cfg.hugr_mut() + .apply_rewrite(InlineDFG(n_id.into())) + .unwrap(); + } + worklist.push(merged_bb); + } +} + +fn mk_rep( + cfg: &impl RootTagged, + pred: Node, + succ: Node, +) -> (Replacement, Node, [Node; 2]) { + let pred_ty = cfg.get_optype(pred).as_dataflow_block().unwrap(); + let succ_ty = cfg.get_optype(succ).as_dataflow_block().unwrap(); + let succ_sig = succ_ty.inner_signature(); + let mut replacement: Hugr = Hugr::new(cfg.root_type().clone()); + let merged = replacement.add_node_with_parent(replacement.root(), { + let mut merged_block = DataflowBlock { + inputs: pred_ty.inputs.clone(), + ..succ_ty.clone() + }; + merged_block.extension_delta = merged_block + .extension_delta + .union(pred_ty.extension_delta.clone()); + merged_block + }); + let input = replacement.add_node_with_parent( + merged, + Input { + types: pred_ty.inputs.clone(), + }, + ); + let output = replacement.add_node_with_parent( + merged, + Output { + types: succ_sig.output.clone(), + }, + ); + + let dfg1 = replacement.add_node_with_parent( + merged, + DFG { + signature: pred_ty.inner_signature().clone(), + }, + ); + for (i, _) in pred_ty.inputs.iter().enumerate() { + replacement.connect(input, i, dfg1, i) + } + + let dfg2 = replacement.add_node_with_parent( + merged, + DFG { + signature: succ_sig.clone(), + }, + ); + for (i, _) in succ_sig.output.iter().enumerate() { + replacement.connect(dfg2, i, output, i) + } + + // At the junction, must unpack the first (tuple, branch predicate) output + let tuple_elems = pred_ty.sum_rows.clone().into_iter().exactly_one().unwrap(); + let unp = replacement.add_node_with_parent( + merged, + UnpackTuple { + tys: tuple_elems.clone(), + }, + ); + replacement.connect(dfg1, 0, unp, 0); + let other_start = tuple_elems.len(); + for (i, _) in tuple_elems.iter().enumerate() { + replacement.connect(unp, i, dfg2, i) + } + for (i, _) in pred_ty.other_outputs.iter().enumerate() { + replacement.connect(dfg1, i + 1, dfg2, i + other_start) + } + // If there are edges from succ back to pred, we cannot do these via the mu_inp/out/new + // edge-maps as both source and target of the new edge are in the 
replacement Hugr + for (_, src_pos) in cfg.all_linked_outputs(pred).filter(|(src, _)| *src == succ) { + replacement.connect(merged, src_pos, merged, 0); + } + let rep = Replacement { + removal: vec![pred, succ], + replacement, + adoptions: HashMap::from([(dfg1, pred), (dfg2, succ)]), + mu_inp: cfg + .all_linked_outputs(pred) + .filter(|(src, _)| *src != succ) + .map(|(src, src_pos)| NewEdgeSpec { + src, + tgt: merged, + kind: NewEdgeKind::ControlFlow { src_pos }, + }) + .collect(), + mu_out: cfg + .node_outputs(succ) + .filter_map(|src_pos| { + let tgt = cfg + .linked_inputs(succ, src_pos) + .exactly_one() + .ok() + .unwrap() + .0; + if tgt == pred { + None + } else { + Some(NewEdgeSpec { + src: merged, + tgt, + kind: NewEdgeKind::ControlFlow { src_pos }, + }) + } + }) + .collect(), + mu_new: vec![], + }; + (rep, merged, [dfg1, dfg2]) +} + +#[cfg(test)] +mod test { + use std::collections::HashSet; + + use itertools::Itertools; + use rstest::rstest; + + use crate::builder::{CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; + use crate::extension::prelude::{ConstUsize, PRELUDE_ID, QB_T, USIZE_T}; + use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE, PRELUDE_REGISTRY}; + use crate::hugr::views::sibling::SiblingMut; + use crate::ops::constant::Value; + use crate::ops::handle::CfgID; + use crate::ops::{Lift, LoadConstant, Noop, OpTrait, OpType}; + use crate::types::{FunctionType, Type, TypeRow}; + use crate::{const_extension_ids, type_row, Extension, Hugr, HugrView, Wire}; + + use super::merge_basic_blocks; + + const_extension_ids! { + const EXT_ID: ExtensionId = "TestExt"; + } + + fn extension() -> Extension { + let mut e = Extension::new(EXT_ID); + e.add_op( + "Test".into(), + String::new(), + FunctionType::new( + type_row![QB_T, USIZE_T], + TypeRow::from(vec![Type::new_sum(vec![ + type_row![QB_T], + type_row![USIZE_T], + ])]), + ), + ) + .unwrap(); + e + } + + fn lifted_unary_unit_sum + AsRef, T>(b: &mut DFGWrapper) -> Wire { + let lc = b.add_load_value(Value::unary_unit_sum()); + let lift = b + .add_dataflow_op( + Lift { + type_row: Type::new_unit_sum(1).into(), + new_extension: PRELUDE_ID, + }, + [lc], + ) + .unwrap(); + let [w] = lift.outputs_arr(); + w + } + + #[rstest] + #[case(true)] + #[case(false)] + fn in_loop(#[case] self_loop: bool) -> Result<(), Box> { + /* self_loop==False: + -> Noop1 -----> Test -> Exit -> Noop1AndTest --> Exit + | | => / \ + \-<- Noop2 <-/ \-<- Noop2 <-/ + (Noop2 -> Noop1 cannot be merged because Noop1 is the entry node) + + self_loop==True: + -> Noop --> Test -> Exit -> NoopAndTest --> Exit + | | => / \ + \--<--<--/ \--<-----<--/ + */ + let loop_variants = type_row![QB_T]; + let exit_types = type_row![USIZE_T]; + let e = extension(); + let tst_op = e.instantiate_extension_op("Test", [], &PRELUDE_REGISTRY)?; + let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), e])?; + let mut h = CFGBuilder::new( + FunctionType::new(loop_variants.clone(), exit_types.clone()) + .with_extension_delta(ExtensionSet::singleton(&PRELUDE_ID)), + )?; + let mut no_b1 = h.simple_entry_builder(loop_variants.clone(), 1, ExtensionSet::new())?; + let n = no_b1.add_dataflow_op(Noop { ty: QB_T }, no_b1.input_wires())?; + let br = lifted_unary_unit_sum(&mut no_b1); + let no_b1 = no_b1.finish_with_outputs(br, n.outputs())?; + let mut test_block = h.block_builder( + loop_variants.clone(), + vec![loop_variants.clone(), exit_types], + ExtensionSet::singleton(&PRELUDE_ID), + type_row![], + )?; + let [test_input] = test_block.input_wires_arr(); + let usize_cst = 
test_block.add_load_value(ConstUsize::new(1)); + let [tst] = test_block + .add_dataflow_op(tst_op, [test_input, usize_cst])? + .outputs_arr(); + let test_block = test_block.finish_with_outputs(tst, [])?; + let loop_backedge_target = if self_loop { + no_b1 + } else { + let mut no_b2 = h.simple_block_builder(FunctionType::new_endo(loop_variants), 1)?; + let n = no_b2.add_dataflow_op(Noop { ty: QB_T }, no_b2.input_wires())?; + let br = lifted_unary_unit_sum(&mut no_b2); + let nid = no_b2.finish_with_outputs(br, n.outputs())?; + h.branch(&nid, 0, &no_b1)?; + nid + }; + h.branch(&no_b1, 0, &test_block)?; + h.branch(&test_block, 0, &loop_backedge_target)?; + h.branch(&test_block, 1, &h.exit_block())?; + + let mut h = h.finish_hugr(®)?; + let r = h.root(); + merge_basic_blocks(&mut SiblingMut::::try_new(&mut h, r)?); + h.update_validate(®).unwrap(); + assert_eq!(r, h.root()); + assert!(matches!(h.get_optype(r), OpType::CFG(_))); + let [entry, exit] = h + .children(r) + .take(2) + .collect::>() + .try_into() + .unwrap(); + // Check the Noop('s) is/are in the right block(s) + let nops = h + .nodes() + .filter(|n| matches!(h.get_optype(*n), OpType::Noop(_))); + let (entry_nop, expected_backedge_target) = if self_loop { + assert_eq!(h.children(r).len(), 2); + (nops.exactly_one().ok().unwrap(), entry) + } else { + let [_, _, no_b2] = h.children(r).collect::>().try_into().unwrap(); + let mut nops = nops.collect::>(); + let entry_nop_idx = nops + .iter() + .position(|n| h.get_parent(*n) == Some(entry)) + .unwrap(); + let entry_nop = nops[entry_nop_idx]; + nops.remove(entry_nop_idx); + let [n_op2] = nops.try_into().unwrap(); + assert_eq!(h.get_parent(n_op2), Some(no_b2)); + (entry_nop, no_b2) + }; + assert_eq!(h.get_parent(entry_nop), Some(entry)); + assert_eq!( + h.output_neighbours(entry).collect::>(), + HashSet::from([expected_backedge_target, exit]) + ); + // And the Noop in the entry block is consumed by the custom Test op + let tst = find_unique(h.nodes(), |n| { + matches!(h.get_optype(*n), OpType::CustomOp(_)) + }); + assert_eq!(h.get_parent(tst), Some(entry)); + assert_eq!( + h.output_neighbours(entry_nop).collect::>(), + vec![tst] + ); + Ok(()) + } + + #[test] + fn triple_with_permute() -> Result<(), Box> { + // Blocks are just BB1 -> BB2 -> BB3 --> Exit. + // CFG Normalization would move everything outside the CFG and elide the CFG altogether, + // but this is an easy-to-construct test of merge-basic-blocks only (no CFG normalization). + let e = extension(); + let tst_op: OpType = e + .instantiate_extension_op("Test", &[], &PRELUDE_REGISTRY)? 
+ .into(); + let [res_t] = tst_op + .dataflow_signature() + .unwrap() + .output + .into_owned() + .try_into() + .unwrap(); + let mut h = CFGBuilder::new( + FunctionType::new(QB_T, res_t.clone()) + .with_extension_delta(ExtensionSet::singleton(&PRELUDE_ID)), + )?; + let mut bb1 = h.entry_builder( + vec![type_row![]], + type_row![USIZE_T, QB_T], + ExtensionSet::singleton(&PRELUDE_ID), + )?; + let [inw] = bb1.input_wires_arr(); + let load_cst = bb1.add_load_value(ConstUsize::new(1)); + let pred = lifted_unary_unit_sum(&mut bb1); + let bb1 = bb1.finish_with_outputs(pred, [load_cst, inw])?; + + let mut bb2 = h.block_builder( + type_row![USIZE_T, QB_T], + vec![type_row![]], + ExtensionSet::new(), + type_row![QB_T, USIZE_T], + )?; + let [u, q] = bb2.input_wires_arr(); + let pred = lifted_unary_unit_sum(&mut bb2); + let bb2 = bb2.finish_with_outputs(pred, [q, u])?; + + let mut bb3 = h.block_builder( + type_row![QB_T, USIZE_T], + vec![type_row![]], + ExtensionSet::new(), + res_t.clone().into(), + )?; + let [q, u] = bb3.input_wires_arr(); + let tst = bb3.add_dataflow_op(tst_op, [q, u])?; + let pred = lifted_unary_unit_sum(&mut bb3); + let bb3 = bb3.finish_with_outputs(pred, tst.outputs())?; + // Now add control-flow edges between basic blocks + h.branch(&bb1, 0, &bb2)?; + h.branch(&bb2, 0, &bb3)?; + h.branch(&bb3, 0, &h.exit_block())?; + + let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()])?; + let mut h = h.finish_hugr(®)?; + let root = h.root(); + merge_basic_blocks(&mut SiblingMut::try_new(&mut h, root)?); + h.update_validate(®)?; + + // Should only be one BB left + let [bb, _exit] = h.children(h.root()).collect::>().try_into().unwrap(); + let tst = find_unique(h.nodes(), |n| { + matches!(h.get_optype(*n), OpType::CustomOp(_)) + }); + assert_eq!(h.get_parent(tst), Some(bb)); + + let inp = find_unique(h.nodes(), |n| matches!(h.get_optype(*n), OpType::Input(_))); + let mut tst_inputs = h.input_neighbours(tst).collect::>(); + tst_inputs.remove(tst_inputs.iter().find_position(|n| **n == inp).unwrap().0); + let [other_input] = tst_inputs.try_into().unwrap(); + assert_eq!( + h.get_optype(other_input), + &(LoadConstant { datatype: USIZE_T }.into()) + ); + Ok(()) + } + + fn find_unique(items: impl Iterator, pred: impl Fn(&T) -> bool) -> T { + items.filter(pred).exactly_one().ok().unwrap() + } +} diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs new file mode 100644 index 000000000..feae6470b --- /dev/null +++ b/hugr-passes/src/nest_cfgs.rs @@ -0,0 +1,946 @@ +//! # Nest CFGs +//! +//! Identify Single-Entry-Single-Exit (SESE) regions in the CFG. +//! These are pairs of edges (a,b) where +//! * a dominates b +//! * b postdominates a +//! * there are no other edges in/out of the nodes inbetween +//! (this last condition is necessary because loop backedges do not affect (post)dominance). +//! +//! # Algorithm +//! See paper: , approximately: +//! 1. those three conditions are equivalent to: +//! *a and b are cycle-equivalent in the CFG with an extra edge from the exit node to the entry* +//! where cycle-equivalent means every cycle has either both a and b, or neither +//! 2. cycle equivalence is unaffected if all edges are considered *un*directed +//! (not obvious, see paper for proof) +//! 3. take undirected CFG, perform depth-first traversal +//! => all edges are either *tree edges*, or *backedges* where one endpoint is an ancestor of the other +//! 4. identify the "bracketlist" of each tree edge - the set of backedges going from a descendant of that edge to an ancestor +//! 
-- post-order traversal, merging bracketlists of children, +//! then delete backedges from below to here, add backedges from here to above +//! => tree edges with the same bracketlist are cycle-equivalent; +//! + a tree edge with a single-element bracketlist is cycle-equivalent with that single element +//! 5. this would be expensive (comparing large sets of backedges) - so to optimize, +//! - the backedge most recently added (at the top) of the bracketlist, plus the size of the bracketlist, +//! is sufficient to identify the set *when the UDFS tree is linear*; +//! - when UDFS is treelike, any ancestor with brackets from >1 subtree cannot be cycle-equivalent with any descendant +//! (as the brackets of said descendant come from beneath it to its ancestors, not from any sibling/etc. in the other subtree). +//! So, add (onto top of bracketlist) a fake "capping" backedge from here to the highest ancestor reached by >1 subtree. +//! (Thus, edges from here up to that ancestor, cannot be cycle-equivalent with any edges elsewhere.) +//! +//! # Restrictions +//! * The paper assumes that all CFG nodes are on paths from entry to exit, i.e. no loops without exits. +//! HUGR assumes only that they are all reachable from entry, so we do a backward traversal from exit node +//! first and restrict to the CFG nodes in the reachable set. (This means we will not discover SESE regions +//! in exit-free loops, but that doesn't seem a major concern.) +//! * Multiple edges in the same direction between the same BBs will "confuse" the algorithm in the paper. +//! However it is straightforward for us to treat successors and predecessors as sets. (Two edges between +//! the same BBs but in opposite directions must be distinct!) + +use std::collections::{HashMap, HashSet, LinkedList, VecDeque}; +use std::hash::Hash; + +use itertools::Itertools; +use thiserror::Error; + +use crate::hugr::rewrite::outline_cfg::OutlineCfg; +use crate::hugr::views::sibling::SiblingMut; +use crate::hugr::views::{HierarchyView, HugrView, SiblingGraph}; +use crate::hugr::{HugrMut, Rewrite, RootTagged}; +use crate::ops::handle::{BasicBlockID, CfgID}; +use crate::ops::OpTag; +use crate::ops::OpTrait; +use crate::{Direction, Hugr, Node}; + +/// A "view" of a CFG in a Hugr which allows basic blocks in the underlying CFG to be split into +/// multiple blocks in the view (or merged together). +/// `T` is the type of basic block; this can just be a BasicBlock (e.g. [`Node`]) in the Hugr, +/// or an [IdentityCfgMap] if the extra level of indirection is not required. However, since +/// SESE regions are bounded by edges between pairs of such `T`, such splitting may allow the +/// algorithm to identify more regions than existed in the underlying CFG, without mutating the +/// underlying CFG just for the analysis - the splitting (and/or merging) can then be performed by +/// [CfgNester::nest_sese_region] only as necessary for regions actually nested. +pub trait CfgNodeMap { + /// The unique entry node of the CFG. It may any n>=0 of incoming edges; we assume control arrives here from "outside". + fn entry_node(&self) -> T; + /// The unique exit node of the CFG. The only node to have no successors. + fn exit_node(&self) -> T; + /// Allows the trait implementor to define a type of iterator it will return from + /// `successors` and `predecessors`. + type Iterator<'c>: Iterator + where + Self: 'c; + /// Returns an iterator over the successors of the specified basic block. 
+ fn successors(&self, node: T) -> Self::Iterator<'_>; + /// Returns an iterator over the predecessors of the specified basic block. + fn predecessors(&self, node: T) -> Self::Iterator<'_>; +} + +/// Extension of [CfgNodeMap] to that can perform (mutable/destructive) +/// nesting of regions detected. +pub trait CfgNester: CfgNodeMap { + /// Given an entry edge and exit edge defining a SESE region, mutates the + /// Hugr such that all nodes between these edges are placed in a nested CFG. + /// Returns the newly-constructed block (containing a nested CFG). + /// + /// # Panics + /// May panic if the two edges do not constitute a SESE region. + fn nest_sese_region(&mut self, entry_edge: (T, T), exit_edge: (T, T)) -> T; +} + +/// Transforms a CFG into as much-nested a form as possible. +pub fn transform_cfg_to_nested( + view: &mut impl CfgNester, +) { + let edge_classes = EdgeClassifier::get_edge_classes(view); + let mut rem_edges: HashMap> = HashMap::new(); + for (e, cls) in edge_classes.iter() { + rem_edges.entry(*cls).or_default().insert(*e); + } + + // Traverse. Any traversal will encounter edges in SESE-respecting order. + fn traverse( + view: &mut impl CfgNester, + n: T, + edge_classes: &HashMap<(T, T), usize>, + rem_edges: &mut HashMap>, + stop_at: Option, + ) -> Option<(T, T)> { + let mut seen = HashSet::new(); + let mut stack = Vec::new(); + let mut exit_edges = Vec::new(); + stack.push(n); + while let Some(n) = stack.pop() { + if !seen.insert(n) { + continue; + } + let (exit, rest): (Vec<_>, Vec<_>) = view + .successors(n) + .map(|s| (n, s)) + .partition(|e| stop_at.is_some() && edge_classes.get(e).copied() == stop_at); + exit_edges.extend(exit.into_iter().at_most_one().unwrap()); + for mut e in rest { + if let Some(cls) = edge_classes.get(&e) { + assert!(rem_edges.get_mut(cls).unwrap().remove(&e)); + // While there are more edges in that same class, we can traverse the entire + // subregion between pairs of edges in that class in a single step + // (as these are strictly nested in any outer region) + while !rem_edges.get_mut(cls).unwrap().is_empty() { + let prev_e = e; + // Traverse to the next edge in the same class - we know it exists in the set + e = traverse(view, e.1, edge_classes, rem_edges, Some(*cls)).unwrap(); + assert!(rem_edges.get_mut(cls).unwrap().remove(&e)); + // Skip trivial regions of a single node, unless the node has other edges + // (non-exiting, but e.g. a backedge to a loop header, ending that loop) + if prev_e.1 != e.0 || view.successors(e.0).count() > 1 { + // Traversal and nesting of the subregion's *contents* were completed in the + // recursive call above, so only processed nodes are moved into descendant CFGs + e = (view.nest_sese_region(prev_e, e), e.1) + }; + } + } + stack.push(e.1); + } + } + exit_edges.into_iter().unique().at_most_one().unwrap() + } + traverse(view, view.entry_node(), &edge_classes, &mut rem_edges, None); + // TODO we should probably now try to merge consecutive basic blocks + // (i.e. where a BB has a single successor, that has a single predecessor) + // and thus convert CF dependencies into (parallelizable) dataflow. +} + +/// Search the entire Hugr looking for CFGs, and transform each +/// into as deeply-nested form as possible (as per [transform_cfg_to_nested]). +/// This search may be expensive, although if it finds much/many CFGs, +/// the analysis/transformation on them is likely to be more expensive still! 
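+///
+/// A sketch of the intended call (assuming `h` is any already-built [Hugr]):
+///
+/// ```ignore
+/// // Nests every CFG found anywhere under the root, in place.
+/// transform_all_cfgs(&mut h);
+/// ```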
+pub fn transform_all_cfgs(h: &mut Hugr) { + let mut node_stack = Vec::from([h.root()]); + while let Some(n) = node_stack.pop() { + if let Ok(s) = SiblingMut::::try_new(h, n) { + transform_cfg_to_nested(&mut IdentityCfgMap::new(s)); + } + node_stack.extend(h.children(n)) + } +} + +/// Directed edges in a Cfg - i.e. along which control flows from first to second only. +type CfgEdge = (T, T); + +// The next enum + few functions allow to abstract over the edge directions +// in a CfgView. + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +enum EdgeDest { + Forward(T), + Backward(T), +} + +impl EdgeDest { + pub fn target(&self) -> T { + match self { + EdgeDest::Forward(i) => *i, + EdgeDest::Backward(i) => *i, + } + } +} + +fn all_edges<'a, T: Copy + Clone + PartialEq + Eq + Hash + 'a>( + cfg: &'a impl CfgNodeMap, + n: T, +) -> impl Iterator> + '_ { + let extra = if n == cfg.exit_node() { + vec![cfg.entry_node()] + } else { + vec![] + }; + cfg.successors(n) + .chain(extra) + .map(EdgeDest::Forward) + .chain(cfg.predecessors(n).map(EdgeDest::Backward)) + .unique() +} + +fn flip(src: T, d: EdgeDest) -> (T, EdgeDest) { + match d { + EdgeDest::Forward(tgt) => (tgt, EdgeDest::Backward(src)), + EdgeDest::Backward(tgt) => (tgt, EdgeDest::Forward(src)), + } +} + +fn cfg_edge(s: T, d: EdgeDest) -> CfgEdge { + match d { + EdgeDest::Forward(t) => (s, t), + EdgeDest::Backward(t) => (t, s), + } +} + +/// A straightforward view of a Cfg as it appears in a Hugr +pub struct IdentityCfgMap { + h: H, + entry: Node, + exit: Node, +} +impl> IdentityCfgMap { + /// Creates an [IdentityCfgMap] for the specified CFG + pub fn new(h: H) -> Self { + // Panic if malformed enough not to have two children + let (entry, exit) = h.children(h.root()).take(2).collect_tuple().unwrap(); + debug_assert_eq!(h.get_optype(exit).tag(), OpTag::BasicBlockExit); + Self { h, entry, exit } + } +} +impl CfgNodeMap for IdentityCfgMap { + fn entry_node(&self) -> Node { + self.entry + } + + fn exit_node(&self) -> Node { + self.exit + } + + type Iterator<'c> = ::Neighbours<'c> + where + Self: 'c; + + fn successors(&self, node: Node) -> Self::Iterator<'_> { + self.h.neighbours(node, Direction::Outgoing) + } + + fn predecessors(&self, node: Node) -> Self::Iterator<'_> { + self.h.neighbours(node, Direction::Incoming) + } +} + +impl CfgNester for IdentityCfgMap { + fn nest_sese_region(&mut self, entry_edge: (Node, Node), exit_edge: (Node, Node)) -> Node { + // The algorithm only calls with entry/exit edges for a SESE region; panic if they don't + let blocks = region_blocks(self, entry_edge, exit_edge).unwrap(); + assert!([entry_edge.0, entry_edge.1, exit_edge.0, exit_edge.1] + .iter() + .all(|n| self.h.get_parent(*n) == Some(self.h.root()))); + let (new_block, new_cfg) = OutlineCfg::new(blocks).apply(&mut self.h).unwrap(); + debug_assert!([entry_edge.0, exit_edge.1] + .iter() + .all(|n| self.h.get_parent(*n) == Some(self.h.root()))); + + debug_assert!({ + let new_block_view = SiblingGraph::::try_new(&self.h, new_block).unwrap(); + let new_cfg_view = SiblingGraph::::try_new(&new_block_view, new_cfg).unwrap(); + [entry_edge.1, exit_edge.0] + .iter() + .all(|n| new_cfg_view.get_parent(*n) == Some(new_cfg)) + }); + new_block + } +} + +/// An error trying to get the blocks of a SESE (single-entry-single-exit) region +#[derive(Clone, Debug, Error)] +#[non_exhaustive] +pub enum RegionBlocksError { + /// The specified exit edge did not exist in the CFG + ExitEdgeNotPresent(T, T), + /// The specified entry edge did not exist in the CFG + EntryEdgeNotPresent(T, 
T), + /// The source of the entry edge was in the region + /// (reachable from the target of the entry edge without using the exit edge) + EntryEdgeSourceInRegion(T), + /// The target of the entry edge had other predecessors (given) + /// that were outside the region (i.e. not reachable from the target) + UnexpectedEntryEdges(Vec), +} + +/// Given entry and exit edges for a SESE region, identify all the blocks in it. +pub fn region_blocks( + v: &impl CfgNodeMap, + entry_edge: (T, T), + exit_edge: (T, T), +) -> Result, RegionBlocksError> { + let mut blocks = HashSet::new(); + let mut queue = VecDeque::new(); + queue.push_back(entry_edge.1); + while let Some(n) = queue.pop_front() { + if blocks.insert(n) { + if n == exit_edge.0 { + let succs: Vec = v.successors(n).collect(); + let n_succs = succs.len(); + let internal_succs: Vec = + succs.into_iter().filter(|s| *s != exit_edge.1).collect(); + if internal_succs.len() == n_succs { + return Err(RegionBlocksError::ExitEdgeNotPresent( + exit_edge.0, + exit_edge.1, + )); + } + queue.extend(internal_succs) + } else { + queue.extend(v.successors(n)); + } + } + } + if blocks.contains(&entry_edge.0) { + return Err(RegionBlocksError::EntryEdgeSourceInRegion(entry_edge.0)); + } + + let ext_preds = v + .predecessors(entry_edge.1) + .unique() + .filter(|p| !blocks.contains(p)); + let (expected, extra): (Vec, Vec) = ext_preds.partition(|i| *i == entry_edge.0); + if expected != vec![entry_edge.0] { + return Err(RegionBlocksError::EntryEdgeNotPresent( + entry_edge.0, + entry_edge.1, + )); + }; + if !extra.is_empty() { + return Err(RegionBlocksError::UnexpectedEntryEdges(extra)); + } + // We could check for other nodes in the region having predecessors outside it, but that would be more expensive + Ok(blocks) +} + +/// Records an undirected Depth First Search over a CfgView, +/// restricted to nodes forwards-reachable from the entry. +/// That is, the DFS traversal goes both ways along the edges of the CFG. +/// *Undirected* DFS classifies all edges into *only two* categories +/// * tree edges, which on their own (with the nodes) form a tree (minimum spanning tree); +/// * backedges, i.e. those for which, when DFS tried to traverse them, the other endpoint was an ancestor +/// Moreover, we record *which way* along the underlying CFG edge we went. +struct UndirectedDFSTree { + /// Pre-order traversal numbering + dfs_num: HashMap, + /// For each node, the edge along which it was reached from its parent + dfs_parents: HashMap>, +} + +impl UndirectedDFSTree { + fn new(cfg: &impl CfgNodeMap) -> Self { + //1. Traverse backwards-only from exit building bitset of reachable nodes + let mut reachable = HashSet::new(); + { + let mut pending = VecDeque::new(); + pending.push_back(cfg.exit_node()); + while let Some(n) = pending.pop_front() { + if reachable.insert(n) { + pending.extend(cfg.predecessors(n)); + } + } + } + //2. 
Traverse undirected from entry node, building dfs_num and setting dfs_parents + let mut dfs_num = HashMap::new(); + let mut dfs_parents = HashMap::new(); + { + // Node, and directed edge along which reached + let mut pending = vec![(cfg.entry_node(), EdgeDest::Backward(cfg.exit_node()))]; + while let Some((n, p_edge)) = pending.pop() { + if !dfs_num.contains_key(&n) && reachable.contains(&n) { + dfs_num.insert(n, dfs_num.len()); + dfs_parents.insert(n, p_edge); + for e in all_edges(cfg, n) { + pending.push(flip(n, e)); + } + } + } + dfs_parents.remove(&cfg.entry_node()).unwrap(); + } + UndirectedDFSTree { + dfs_num, + dfs_parents, + } + } +} + +#[derive(Clone, PartialEq, Eq, Hash)] +enum Bracket { + Real(CfgEdge), + Capping(usize, T), +} + +/// Manages a list of brackets. The goal here is to allow constant-time deletion +/// out of the middle of the list - which isn't really possible, so instead we +/// track deleted items (in an external set) and the remaining number (here). +/// +/// Note - we could put the items deleted from *this* BracketList here, and merge in concat(). +/// That would be cleaner, but repeated set-merging would be slower than adding the +/// deleted items to a single set in the `TraversalState` +struct BracketList { + items: LinkedList>, // Allows O(1) `append` of two lists. + size: usize, // Not counting deleted items. +} + +impl BracketList { + pub fn new() -> Self { + BracketList { + items: LinkedList::new(), + size: 0, + } + } + + pub fn tag(&mut self, deleted: &HashSet>) -> Option<(Bracket, usize)> { + while let Some(e) = self.items.front() { + // Pop deleted elements to save time (and memory) + if deleted.contains(e) { + self.items.pop_front(); + //deleted.remove(e); // Would only save memory, so keep as immutable + } else { + return Some((e.clone(), self.size)); + } + } + None + } + + pub fn concat(&mut self, other: BracketList) { + let BracketList { mut items, size } = other; + self.items.append(&mut items); + assert!(items.is_empty()); + self.size += size; + } + + pub fn delete(&mut self, b: &Bracket, deleted: &mut HashSet>) { + // Ideally, here we would also assert that no *other* BracketList contains b. + debug_assert!(self.items.contains(b)); // Makes operation O(n), otherwise O(1) + let was_new = deleted.insert(b.clone()); + assert!(was_new); + self.size -= 1; + } + + pub fn push(&mut self, e: Bracket) { + self.items.push_back(e); + self.size += 1; + } +} + +/// Mutable state updated during traversal of the UndirectedDFSTree by the cycle equivalence algorithm. +pub struct EdgeClassifier { + /// Edges we have marked as deleted, allowing constant-time deletion without searching BracketList + deleted_backedges: HashSet>, + /// Key is DFS num of highest ancestor + /// to which backedges reached from >1 sibling subtree; + /// Value is the LCA i.e. parent of those siblings. + capping_edges: HashMap>, + /// Result of traversal - accumulated here, entries should never be overwritten + edge_classes: HashMap, Option<(Bracket, usize)>>, +} + +impl EdgeClassifier { + /// Computes equivalence class of each edge, i.e. two edges with the same value + /// are cycle-equivalent. Any two consecutive edges in the same class define a SESE region + /// (where "consecutive" means on any path in the original directed CFG, as the edges + /// in a class all dominate + postdominate each other as part of defn of cycle equivalence). 
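+    ///
+    /// A usage sketch, following the tests in this module (the `hugr` and its
+    /// CFG root are assumed to be built already):
+    ///
+    /// ```ignore
+    /// let view = IdentityCfgMap::new(SiblingGraph::try_new(&hugr, hugr.root())?);
+    /// let edge_classes = EdgeClassifier::get_edge_classes(&view);
+    /// // Two edges mapped to the same `usize` are cycle-equivalent.
+    /// ```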
+ pub fn get_edge_classes(cfg: &impl CfgNodeMap) -> HashMap, usize> { + let tree = UndirectedDFSTree::new(cfg); + let mut s = Self { + deleted_backedges: HashSet::new(), + capping_edges: HashMap::new(), + edge_classes: HashMap::new(), + }; + s.traverse(cfg, &tree, cfg.entry_node()); + assert!(s.capping_edges.is_empty()); + s.edge_classes.remove(&(cfg.exit_node(), cfg.entry_node())); + let mut cycle_class_idxs = HashMap::new(); + s.edge_classes + .into_iter() + .map(|(k, v)| { + let l = cycle_class_idxs.len(); + (k, *cycle_class_idxs.entry(v).or_insert(l)) + }) + .collect() + } + + /// Returns the lowest DFS num (highest ancestor) reached by any bracket leaving + /// the subtree, and the list of said brackets. + fn traverse( + &mut self, + cfg: &impl CfgNodeMap, + tree: &UndirectedDFSTree, + n: T, + ) -> (usize, BracketList) { + let n_dfs = *tree.dfs_num.get(&n).unwrap(); // should only be called for nodes on path to exit + let (children, non_capping_backedges): (Vec<_>, Vec<_>) = all_edges(cfg, n) + .filter(|e| tree.dfs_num.contains_key(&e.target())) + .partition(|e| { + // The tree edges are those whose *targets* list the edge as parent-edge + let (tgt, from) = flip(n, *e); + tree.dfs_parents.get(&tgt) == Some(&from) + }); + let child_results: Vec<_> = children + .iter() + .map(|c| self.traverse(cfg, tree, c.target())) + .collect(); + let mut min_dfs_target: [Option; 2] = [None, None]; // We want highest-but-one + let mut bs = BracketList::new(); + for (tgt, brs) in child_results { + if tgt < min_dfs_target[0].unwrap_or(usize::MAX) { + min_dfs_target = [Some(tgt), min_dfs_target[0]] + } else if tgt < min_dfs_target[1].unwrap_or(usize::MAX) { + min_dfs_target[1] = Some(tgt) + } + bs.concat(brs); + } + // Add capping backedge + if let Some(min1dfs) = min_dfs_target[1] { + if min1dfs < n_dfs { + bs.push(Bracket::Capping(min1dfs, n)); + // mark capping edge to be removed when we return out to the other end + self.capping_edges.entry(min1dfs).or_default().push(n); + } + } + + let parent_edge = tree.dfs_parents.get(&n); + let (be_up, be_down): (Vec<_>, Vec<_>) = non_capping_backedges + .into_iter() + .map(|e| (*tree.dfs_num.get(&e.target()).unwrap(), e)) + .partition(|(dfs, _)| *dfs < n_dfs); + + // Remove edges to here from beneath + for (_, e) in be_down { + let e = cfg_edge(n, e); + let b = Bracket::Real(e); + bs.delete(&b, &mut self.deleted_backedges); + // Last chance to assign an edge class! This will be a singleton class, + // but assign for consistency with other singletons. 
+ self.edge_classes.entry(e).or_insert_with(|| Some((b, 0))); + } + // And capping backedges + for src in self.capping_edges.remove(&n_dfs).unwrap_or_default() { + bs.delete(&Bracket::Capping(n_dfs, src), &mut self.deleted_backedges) + } + + // Add backedges from here to ancestors (not the parent edge, but perhaps other edges to the same node) + be_up + .iter() + .filter(|(_, e)| Some(e) != parent_edge) + .for_each(|(_, e)| bs.push(Bracket::Real(cfg_edge(n, *e)))); + + // Now calculate edge classes + let class = bs.tag(&self.deleted_backedges); + if let Some((Bracket::Real(e), 1)) = &class { + self.edge_classes.insert(*e, class.clone()); + } + if let Some(parent_edge) = tree.dfs_parents.get(&n) { + self.edge_classes.insert(cfg_edge(n, *parent_edge), class); + } + let highest_target = be_up + .into_iter() + .map(|(dfs, _)| dfs) + .chain(min_dfs_target[0]); + (highest_target.min().unwrap_or(usize::MAX), bs) + } +} + +#[cfg(test)] +pub(crate) mod test { + use super::*; + use crate::builder::{BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder}; + use crate::extension::PRELUDE_REGISTRY; + use crate::extension::{prelude::USIZE_T, ExtensionSet}; + + use crate::hugr::views::RootChecked; + use crate::ops::handle::{ConstID, NodeHandle}; + use crate::ops::Value; + use crate::type_row; + use crate::types::{FunctionType, Type}; + const NAT: Type = USIZE_T; + + pub fn group_by(h: HashMap) -> HashSet> { + let mut res = HashMap::new(); + for (k, v) in h.into_iter() { + res.entry(v).or_insert_with(Vec::new).push(k); + } + res.into_values().map(sorted).collect() + } + + pub fn sorted(items: impl IntoIterator) -> Vec { + let mut v: Vec<_> = items.into_iter().collect(); + v.sort(); + v + } + + #[test] + fn test_cond_then_loop_separate() -> Result<(), BuildError> { + // /-> left --\ + // entry -> split > merge -> head -> tail -> exit + // \-> right -/ \-<--<-/ + let mut cfg_builder = CFGBuilder::new(FunctionType::new_endo(NAT))?; + + let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2")); + let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); + + let entry = n_identity( + cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, + &const_unit, + )?; + let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?; + cfg_builder.branch(&entry, 0, &split)?; + let head = n_identity( + cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 1)?, + &const_unit, + )?; + let tail = n_identity( + cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 2)?, + &pred_const, + )?; + cfg_builder.branch(&tail, 1, &head)?; + cfg_builder.branch(&head, 0, &tail)?; // trivial "loop body" + cfg_builder.branch(&merge, 0, &head)?; + let exit = cfg_builder.exit_block(); + cfg_builder.branch(&tail, 0, &exit)?; + + let mut h = cfg_builder.finish_prelude_hugr()?; + let rc = RootChecked::<_, CfgID>::try_new(&mut h).unwrap(); + let (entry, exit) = (entry.node(), exit.node()); + let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node()); + let edge_classes = EdgeClassifier::get_edge_classes(&IdentityCfgMap::new(rc.borrow())); + let [&left, &right] = edge_classes + .keys() + .filter(|(s, _)| *s == split) + .map(|(_, t)| t) + .collect::>()[..] + else { + panic!("Split node should have two successors"); + }; + + let classes = group_by(edge_classes); + assert_eq!( + classes, + HashSet::from([ + sorted([(split, left), (left, merge)]), // Region containing single BB 'left'. 
+ sorted([(split, right), (right, merge)]), // Region containing single BB 'right'. + Vec::from([(head, tail)]), // Loop body and backedges are in their own classes because + Vec::from([(tail, head)]), // the path executing the loop exactly once skips the backedge. + sorted([(entry, split), (merge, head), (tail, exit)]), // Two regions, conditional and then loop. + ]) + ); + transform_cfg_to_nested(&mut IdentityCfgMap::new(rc)); + h.update_validate(&PRELUDE_REGISTRY).unwrap(); + assert_eq!(1, depth(&h, entry)); + assert_eq!(1, depth(&h, exit)); + for n in [split, left, right, merge, head, tail] { + assert_eq!(3, depth(&h, n)); + } + let first = [split, left, right, merge] + .iter() + .map(|n| h.get_parent(*n).unwrap()) + .unique() + .exactly_one() + .unwrap(); + let second = [head, tail] + .iter() + .map(|n| h.get_parent(*n).unwrap()) + .unique() + .exactly_one() + .unwrap(); + assert_ne!(first, second); + Ok(()) + } + + #[test] + fn test_cond_then_loop_combined() -> Result<(), BuildError> { + // Here we would like two consecutive regions, but there is no *edge* between + // the conditional and the loop to indicate the boundary, so we cannot separate them. + let (h, merge, tail) = build_cond_then_loop_cfg()?; + let (merge, tail) = (merge.node(), tail.node()); + let [entry, exit]: [Node; 2] = h + .children(h.root()) + .take(2) + .collect_vec() + .try_into() + .unwrap(); + + let v = IdentityCfgMap::new(RootChecked::try_new(&h).unwrap()); + let edge_classes = EdgeClassifier::get_edge_classes(&v); + let [&left, &right] = edge_classes + .keys() + .filter(|(s, _)| *s == entry) + .map(|(_, t)| t) + .collect::>()[..] + else { + panic!("Entry node should have two successors"); + }; + + let classes = group_by(edge_classes); + assert_eq!( + classes, + HashSet::from([ + sorted([(entry, left), (left, merge)]), // Region containing single BB 'left'. + sorted([(entry, right), (right, merge)]), // Region containing single BB 'right'. + Vec::from([(tail, exit)]), // The only edge in neither conditional nor loop. + Vec::from([(merge, tail)]), // Loop body (at least once per execution). + Vec::from([(tail, merge)]), // Loop backedge (0 or more times per execution). + ]) + ); + Ok(()) + } + + #[test] + fn test_cond_in_loop_separate_headers() -> Result<(), BuildError> { + let (mut h, head, tail) = build_conditional_in_loop_cfg(true)?; + let head = head.node(); + let tail = tail.node(); + // /-> left --\ + // entry -> head -> split > merge -> tail -> exit + // | \-> right -/ | + // \---<---<---<---<---<---<---<---<---/ + // split is unique successor of head + let split = h.output_neighbours(head).exactly_one().unwrap(); + // merge is unique predecessor of tail + let merge = h.input_neighbours(tail).exactly_one().unwrap(); + + // There's no need to use a view of a region here but we do so just to check + // that we *can* (as we'll need to for "real" module Hugr's) + let v = IdentityCfgMap::new(SiblingGraph::try_new(&h, h.root()).unwrap()); + let edge_classes = EdgeClassifier::get_edge_classes(&v); + let IdentityCfgMap { h: _, entry, exit } = v; + let [&left, &right] = edge_classes + .keys() + .filter(|(s, _)| *s == split) + .map(|(_, t)| t) + .collect::>()[..] 
+ else { + panic!("Split node should have two successors"); + }; + let classes = group_by(edge_classes); + assert_eq!( + classes, + HashSet::from([ + sorted([(split, left), (left, merge)]), // Region containing single BB 'left' + sorted([(split, right), (right, merge)]), // Region containing single BB 'right' + sorted([(head, split), (merge, tail)]), // "Conditional" region containing split+merge choosing between left/right + sorted([(entry, head), (tail, exit)]), // "Loop" region containing body (conditional) + back-edge + Vec::from([(tail, head)]) // The loop back-edge + ]) + ); + + // Again, there's no need for a view of a region here, but check that the + // transformation still works when we can only directly mutate the top level + let root = h.root(); + let m = SiblingMut::::try_new(&mut h, root).unwrap(); + transform_cfg_to_nested(&mut IdentityCfgMap::new(m)); + h.update_validate(&PRELUDE_REGISTRY).unwrap(); + assert_eq!(1, depth(&h, entry)); + assert_eq!(3, depth(&h, head)); + for n in [split, left, right, merge] { + assert_eq!(5, depth(&h, n)); + } + assert_eq!(3, depth(&h, tail)); + assert_eq!(1, depth(&h, exit)); + Ok(()) + } + + #[test] + fn test_cond_in_loop_combined_headers() -> Result<(), BuildError> { + let (h, head, tail) = build_conditional_in_loop_cfg(false)?; + let head = head.node(); + let tail = tail.node(); + // /-> left --\ + // entry -> head > merge -> tail -> exit + // | \-> right -/ | + // \---<---<---<---<---<--<---/ + // Here we would like an indication that we can make two nested regions, + // but there is no edge to act as entry to a region containing just the conditional :-(. + + let v = IdentityCfgMap::new(RootChecked::try_new(&h).unwrap()); + let edge_classes = EdgeClassifier::get_edge_classes(&v); + let IdentityCfgMap { h: _, entry, exit } = v; + // merge is unique predecessor of tail + let merge = *edge_classes + .keys() + .filter(|(_, t)| *t == tail) + .map(|(s, _)| s) + .exactly_one() + .unwrap(); + let [&left, &right] = edge_classes + .keys() + .filter(|(s, _)| *s == head) + .map(|(_, t)| t) + .collect::>()[..] 
+ else { + panic!("Loop header should have two successors"); + }; + let classes = group_by(edge_classes); + assert_eq!( + classes, + HashSet::from([ + sorted([(head, left), (left, merge)]), // Region containing single BB 'left' + sorted([(head, right), (right, merge)]), // Region containing single BB 'right' + Vec::from([(merge, tail)]), // The edge "in the loop", but no other edge in its class to define SESE region + sorted([(entry, head), (tail, exit)]), // "Loop" region containing body (conditional) + back-edge + Vec::from([(tail, head)]) // The loop back-edge + ]) + ); + Ok(()) + } + + fn n_identity( + mut dataflow_builder: T, + pred_const: &ConstID, + ) -> Result { + let w = dataflow_builder.input_wires(); + let u = dataflow_builder.load_const(pred_const); + dataflow_builder.finish_with_outputs([u].into_iter().chain(w)) + } + + fn build_if_then_else_merge + AsRef>( + cfg: &mut CFGBuilder, + const_pred: &ConstID, + unit_const: &ConstID, + ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { + let split = n_identity( + cfg.simple_block_builder(FunctionType::new_endo(NAT), 2)?, + const_pred, + )?; + let merge = build_then_else_merge_from_if(cfg, unit_const, split)?; + Ok((split, merge)) + } + + fn build_then_else_merge_from_if + AsRef>( + cfg: &mut CFGBuilder, + unit_const: &ConstID, + split: BasicBlockID, + ) -> Result { + let merge = n_identity( + cfg.simple_block_builder(FunctionType::new_endo(NAT), 1)?, + unit_const, + )?; + let left = n_identity( + cfg.simple_block_builder(FunctionType::new_endo(NAT), 1)?, + unit_const, + )?; + let right = n_identity( + cfg.simple_block_builder(FunctionType::new_endo(NAT), 1)?, + unit_const, + )?; + cfg.branch(&split, 0, &left)?; + cfg.branch(&split, 1, &right)?; + cfg.branch(&left, 0, &merge)?; + cfg.branch(&right, 0, &merge)?; + Ok(merge) + } + + // /-> left --\ + // entry > merge -> tail -> exit + // \-> right -/ \-<--<-/ + // Result is Hugr plus merge and tail blocks + fn build_cond_then_loop_cfg() -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { + let mut cfg_builder = CFGBuilder::new(FunctionType::new_endo(NAT))?; + let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2")); + let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); + + let entry = n_identity( + cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?, + &pred_const, + )?; + let merge = build_then_else_merge_from_if(&mut cfg_builder, &const_unit, entry)?; + // The merge block is also the loop header (so it merges three incoming control-flow edges) + let tail = n_identity( + cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 2)?, + &pred_const, + )?; + cfg_builder.branch(&tail, 1, &merge)?; + cfg_builder.branch(&merge, 0, &tail)?; // trivial "loop body" + let exit = cfg_builder.exit_block(); + cfg_builder.branch(&tail, 0, &exit)?; + + let h = cfg_builder.finish_prelude_hugr()?; + Ok((h, merge, tail)) + } + + // Build a CFG, returning the Hugr + pub(crate) fn build_conditional_in_loop_cfg( + separate_headers: bool, + ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { + let mut cfg_builder = CFGBuilder::new(FunctionType::new_endo(NAT))?; + let (head, tail) = build_conditional_in_loop(&mut cfg_builder, separate_headers)?; + let h = cfg_builder.finish_prelude_hugr()?; + Ok((h, head, tail)) + } + + pub(crate) fn build_conditional_in_loop + AsRef>( + cfg_builder: &mut CFGBuilder, + separate_headers: bool, + ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { + let pred_const = 
cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2")); + let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); + + let entry = n_identity( + cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, + &const_unit, + )?; + let (split, merge) = build_if_then_else_merge(cfg_builder, &pred_const, &const_unit)?; + + let head = if separate_headers { + let head = n_identity( + cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 1)?, + &const_unit, + )?; + cfg_builder.branch(&head, 0, &split)?; + head + } else { + // Combine loop header with split. + split + }; + let tail = n_identity( + cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 2)?, + &pred_const, + )?; + cfg_builder.branch(&tail, 1, &head)?; + cfg_builder.branch(&merge, 0, &tail)?; + + let exit = cfg_builder.exit_block(); + + cfg_builder.branch(&entry, 0, &head)?; + cfg_builder.branch(&tail, 0, &exit)?; + + Ok((head, tail)) + } + + pub fn depth(h: &Hugr, n: Node) -> u32 { + match h.get_parent(n) { + Some(p) => 1 + depth(h, p), + None => 0, + } + } +} From 7b1e4f4556c372a7cd2581399bbc43cf31c73a54 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Wed, 22 May 2024 11:51:36 +0100 Subject: [PATCH 02/28] Fix up files to build outside hugr crate. --- hugr-passes/src/const_fold.rs | 67 +++++++++++++++++++++++------------ hugr-passes/src/merge_bbs.rs | 62 ++++++++++++++++---------------- hugr-passes/src/nest_cfgs.rs | 34 +++++++++--------- 3 files changed, 92 insertions(+), 71 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 5d4cffed3..b75a48ab2 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -5,16 +5,16 @@ use std::collections::{BTreeSet, HashMap}; use itertools::Itertools; use thiserror::Error; -use crate::hugr::{SimpleReplacementError, ValidationError}; -use crate::types::SumType; -use crate::Direction; -use crate::{ +use hugr::hugr::{SimpleReplacementError, ValidationError}; +use hugr::types::SumType; +use hugr::Direction; +use hugr::{ builder::{DFGBuilder, Dataflow, DataflowHugr}, extension::{ConstFoldResult, ExtensionRegistry}, hugr::{ + hugrmut::HugrMut, rewrite::consts::{RemoveConst, RemoveLoadConstant}, views::SiblingSubgraph, - HugrMut, }, ops::{OpType, Value}, type_row, @@ -244,19 +244,44 @@ pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { mod test { use super::*; - use crate::extension::prelude::{sum_with_error, BOOL_T}; - use crate::extension::{ExtensionRegistry, PRELUDE}; - use crate::ops::{OpType, UnpackTuple}; - use crate::std_extensions::arithmetic; - use crate::std_extensions::arithmetic::conversions::ConvertOpDef; - use crate::std_extensions::arithmetic::float_ops::FloatOps; - use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; - use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; - use crate::std_extensions::logic::{self, NaryLogic, NotOp}; - use crate::utils::test::{assert_fully_folded, assert_fully_folded_with}; + use hugr::builder::Container; + use hugr::extension::prelude::{sum_with_error, BOOL_T}; + use hugr::extension::{ExtensionRegistry, PRELUDE}; + use hugr::ops::{OpType, UnpackTuple}; + use hugr::std_extensions::arithmetic; + use hugr::std_extensions::arithmetic::conversions::ConvertOpDef; + use hugr::std_extensions::arithmetic::float_ops::FloatOps; + use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; + use hugr::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; + use 
hugr::std_extensions::logic::{self, NaryLogic, NotOp}; use rstest::rstest; + /// Check that a hugr just loads and returns a single expected constant. + fn assert_fully_folded(h: &Hugr, expected_value: &Value) { + assert_fully_folded_with(h, |v| v == expected_value) + } + + /// Check that a hugr just loads and returns a single constant, and validate + /// that constant using `check_value`. + /// + /// [CustomConst::equals_const] is not required to be implemented. Use this + /// function for Values containing such a `CustomConst`. + fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) { + let mut node_count = 0; + + for node in h.children(h.root()) { + let op = h.get_optype(node); + match op { + OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, + OpType::Const(c) if check_value(c.value()) => node_count += 1, + _ => panic!("unexpected op: {:?}\n{}", op, h.mermaid_string()), + } + } + + assert_eq!(node_count, 4); + } + /// int to constant fn i2c(b: u64) -> Value { Value::extension(ConstInt::new_u(5, b).unwrap()) @@ -301,9 +326,7 @@ mod test { let unpack = build .add_dataflow_op( - UnpackTuple { - tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE], - }, + UnpackTuple::new(type_row![FLOAT64_TYPE, FLOAT64_TYPE]), [tup], ) .unwrap(); @@ -340,7 +363,7 @@ mod test { ignore = "inference fails for test graph, it shouldn't" )] fn test_list_ops() -> Result<(), Box> { - use crate::std_extensions::collections::{self, ListOp, ListValue}; + use hugr::std_extensions::collections::{self, ListOp, ListValue}; let reg = ExtensionRegistry::try_new([ PRELUDE.to_owned(), @@ -443,8 +466,8 @@ mod test { // // We arange things so that the `or` folds away first, leaving the not // with no outputs. - use crate::hugr::NodeType; - use crate::ops::handle::NodeHandle; + use hugr::hugr::NodeType; + use hugr::ops::handle::NodeHandle; let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); let true_wire = build.add_load_value(Value::true_val()); @@ -457,7 +480,7 @@ mod test { ) .unwrap(); let or_node = r.node(); - let parent = build.dfg_node; + let parent = build.container_node(); let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); let mut h = build.finish_hugr_with_outputs(r.outputs(), ®).unwrap(); diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index 06d3b3bfe..8d2b73c06 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -2,15 +2,17 @@ //! and the target BB has no other predecessors. use std::collections::HashMap; +use hugr::builder::{CFGBuilder, HugrBuilder}; +use hugr::hugr::hugrmut::HugrMut; use itertools::Itertools; -use crate::hugr::rewrite::inline_dfg::InlineDFG; -use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; -use crate::hugr::{HugrMut, RootTagged}; -use crate::ops::handle::CfgID; -use crate::ops::leaf::UnpackTuple; -use crate::ops::{DataflowBlock, DataflowParent, Input, Output, DFG}; -use crate::{Hugr, HugrView, Node}; +use hugr::hugr::rewrite::inline_dfg::InlineDFG; +use hugr::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; +use hugr::hugr::RootTagged; +use hugr::ops::handle::CfgID; +use hugr::ops::leaf::UnpackTuple; +use hugr::ops::{DataflowBlock, DataflowParent, Input, Output, DFG}; +use hugr::{Hugr, HugrView, Node}; /// Merge any basic blocks that are direct children of the specified CFG /// i.e. 
where a basic block B has a single successor B' whose only predecessor @@ -52,7 +54,14 @@ fn mk_rep( let pred_ty = cfg.get_optype(pred).as_dataflow_block().unwrap(); let succ_ty = cfg.get_optype(succ).as_dataflow_block().unwrap(); let succ_sig = succ_ty.inner_signature(); - let mut replacement: Hugr = Hugr::new(cfg.root_type().clone()); + + // Make a Hugr with just a single CFG root node having the same signature. + let mut replacement: Hugr = CFGBuilder::new(cfg.root_type().op_signature().unwrap()) + .unwrap() + .finish_prelude_hugr() + .unwrap(); + replacement.remove_node(replacement.children(replacement.root()).next().unwrap()); + let merged = replacement.add_node_with_parent(replacement.root(), { let mut merged_block = DataflowBlock { inputs: pred_ty.inputs.clone(), @@ -98,12 +107,7 @@ fn mk_rep( // At the junction, must unpack the first (tuple, branch predicate) output let tuple_elems = pred_ty.sum_rows.clone().into_iter().exactly_one().unwrap(); - let unp = replacement.add_node_with_parent( - merged, - UnpackTuple { - tys: tuple_elems.clone(), - }, - ); + let unp = replacement.add_node_with_parent(merged, UnpackTuple::new(tuple_elems.clone())); replacement.connect(dfg1, 0, unp, 0); let other_start = tuple_elems.len(); for (i, _) in tuple_elems.iter().enumerate() { @@ -162,15 +166,15 @@ mod test { use itertools::Itertools; use rstest::rstest; - use crate::builder::{CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; - use crate::extension::prelude::{ConstUsize, PRELUDE_ID, QB_T, USIZE_T}; - use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE, PRELUDE_REGISTRY}; - use crate::hugr::views::sibling::SiblingMut; - use crate::ops::constant::Value; - use crate::ops::handle::CfgID; - use crate::ops::{Lift, LoadConstant, Noop, OpTrait, OpType}; - use crate::types::{FunctionType, Type, TypeRow}; - use crate::{const_extension_ids, type_row, Extension, Hugr, HugrView, Wire}; + use hugr::builder::{CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; + use hugr::extension::prelude::{ConstUsize, PRELUDE_ID, QB_T, USIZE_T}; + use hugr::extension::{ExtensionRegistry, ExtensionSet, PRELUDE, PRELUDE_REGISTRY}; + use hugr::hugr::views::sibling::SiblingMut; + use hugr::ops::constant::Value; + use hugr::ops::handle::CfgID; + use hugr::ops::{Lift, LoadConstant, Noop, OpTrait, OpType}; + use hugr::types::{FunctionType, Type, TypeRow}; + use hugr::{const_extension_ids, type_row, Extension, Hugr, HugrView, Wire}; use super::merge_basic_blocks; @@ -198,13 +202,7 @@ mod test { fn lifted_unary_unit_sum + AsRef, T>(b: &mut DFGWrapper) -> Wire { let lc = b.add_load_value(Value::unary_unit_sum()); let lift = b - .add_dataflow_op( - Lift { - type_row: Type::new_unit_sum(1).into(), - new_extension: PRELUDE_ID, - }, - [lc], - ) + .add_dataflow_op(Lift::new(Type::new_unit_sum(1).into(), PRELUDE_ID), [lc]) .unwrap(); let [w] = lift.outputs_arr(); w @@ -235,7 +233,7 @@ mod test { .with_extension_delta(ExtensionSet::singleton(&PRELUDE_ID)), )?; let mut no_b1 = h.simple_entry_builder(loop_variants.clone(), 1, ExtensionSet::new())?; - let n = no_b1.add_dataflow_op(Noop { ty: QB_T }, no_b1.input_wires())?; + let n = no_b1.add_dataflow_op(Noop::new(QB_T), no_b1.input_wires())?; let br = lifted_unary_unit_sum(&mut no_b1); let no_b1 = no_b1.finish_with_outputs(br, n.outputs())?; let mut test_block = h.block_builder( @@ -254,7 +252,7 @@ mod test { no_b1 } else { let mut no_b2 = h.simple_block_builder(FunctionType::new_endo(loop_variants), 1)?; - let n = no_b2.add_dataflow_op(Noop { ty: QB_T }, no_b2.input_wires())?; + let n = 
no_b2.add_dataflow_op(Noop::new(QB_T), no_b2.input_wires())?; let br = lifted_unary_unit_sum(&mut no_b2); let nid = no_b2.finish_with_outputs(br, n.outputs())?; h.branch(&nid, 0, &no_b1)?; diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index feae6470b..10ce7c2ac 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -44,14 +44,14 @@ use std::hash::Hash; use itertools::Itertools; use thiserror::Error; -use crate::hugr::rewrite::outline_cfg::OutlineCfg; -use crate::hugr::views::sibling::SiblingMut; -use crate::hugr::views::{HierarchyView, HugrView, SiblingGraph}; -use crate::hugr::{HugrMut, Rewrite, RootTagged}; -use crate::ops::handle::{BasicBlockID, CfgID}; -use crate::ops::OpTag; -use crate::ops::OpTrait; -use crate::{Direction, Hugr, Node}; +use hugr::hugr::rewrite::outline_cfg::OutlineCfg; +use hugr::hugr::views::sibling::SiblingMut; +use hugr::hugr::views::{HierarchyView, HugrView, SiblingGraph}; +use hugr::hugr::{hugrmut::HugrMut, Rewrite, RootTagged}; +use hugr::ops::handle::{BasicBlockID, CfgID}; +use hugr::ops::OpTag; +use hugr::ops::OpTrait; +use hugr::{Direction, Hugr, Node}; /// A "view" of a CFG in a Hugr which allows basic blocks in the underlying CFG to be split into /// multiple blocks in the view (or merged together). @@ -574,15 +574,15 @@ impl EdgeClassifier { #[cfg(test)] pub(crate) mod test { use super::*; - use crate::builder::{BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder}; - use crate::extension::PRELUDE_REGISTRY; - use crate::extension::{prelude::USIZE_T, ExtensionSet}; - - use crate::hugr::views::RootChecked; - use crate::ops::handle::{ConstID, NodeHandle}; - use crate::ops::Value; - use crate::type_row; - use crate::types::{FunctionType, Type}; + use hugr::builder::{BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder}; + use hugr::extension::PRELUDE_REGISTRY; + use hugr::extension::{prelude::USIZE_T, ExtensionSet}; + + use hugr::hugr::views::RootChecked; + use hugr::ops::handle::{ConstID, NodeHandle}; + use hugr::ops::Value; + use hugr::type_row; + use hugr::types::{FunctionType, Type}; const NAT: Type = USIZE_T; pub fn group_by(h: HashMap) -> HashSet> { From e94d5945d3e544bc19743868616477fd54896337 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Wed, 22 May 2024 11:57:22 +0100 Subject: [PATCH 03/28] Include hugr-passes in default workspace members. --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 3063ea616..80b2a79f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ lto = "thin" [workspace] resolver = "2" members = ["hugr", "hugr-passes"] -default-members = ["hugr"] +default-members = ["hugr", "hugr-passes"] [workspace.package] rust-version = "1.75" From 4762a82b9f758751d2958157c64190e4674ab41c Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Wed, 22 May 2024 13:15:18 +0100 Subject: [PATCH 04/28] Copy `half_node` into hugr-passes. 
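The module is copied as-is; the next commit fixes its imports so it builds
outside the `hugr` crate.

`half_node` provides a view of a CFG in which any basic block with multiple
predecessors *and* multiple successors is split into two `HalfNode`s joined by
a single edge, so that edge can act as a region boundary that does not exist in
the underlying CFG. As a rough sketch of how the view plugs into the edge
classifier from `nest_cfgs` (mirroring the test in the copied file; `h` is
assumed to be a Hugr whose root is a CFG, built by the shared test helpers):

    use hugr::hugr::views::RootChecked;

    // Wrap the CFG in the half-node view, then classify its edges.
    let v = HalfNodeView::new(RootChecked::try_new(&h).unwrap());
    // Keys are (HalfNode, HalfNode) control-flow edges; edges mapped to the
    // same class are cycle-equivalent, i.e. candidate SESE region boundaries.
    let edge_classes = EdgeClassifier::get_edge_classes(&v);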
--- hugr-passes/src/half_node.rs | 162 +++++++++++++++++++++++++++++++++++ hugr-passes/src/lib.rs | 1 + 2 files changed, 163 insertions(+) create mode 100644 hugr-passes/src/half_node.rs diff --git a/hugr-passes/src/half_node.rs b/hugr-passes/src/half_node.rs new file mode 100644 index 000000000..4edd35d34 --- /dev/null +++ b/hugr-passes/src/half_node.rs @@ -0,0 +1,162 @@ +use std::hash::Hash; + +use super::nest_cfgs::CfgNodeMap; + +use crate::hugr::RootTagged; + +use crate::ops::handle::CfgID; +use crate::ops::{OpTag, OpTrait}; + +use crate::{Direction, Node}; + +/// We provide a view of a cfg where every node has at most one of +/// (multiple predecessors, multiple successors). +/// So for BBs with multiple preds + succs, we generate TWO HalfNode's with a single edge between +/// them; that single edge can then be a region boundary that did not exist before. +/// TODO: this unfortunately doesn't capture all cases: when a node has multiple preds and succs, +/// we could "merge" *any subset* of the in-edges into a single in-edge via an extra empty BB; +/// the in-edge from that extra/empty BB, might be the endpoint of a useful SESE region, +/// but we don't have a way to identify *which subset* to select. (Here we say *all preds* if >1 succ) +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +enum HalfNode { + /// All predecessors of original BB; successors if this does not break rule, else the X + N(Node), + // Exists only for BBs with multiple preds _and_ succs; has a single pred (the N), plus original succs + X(Node), +} + +struct HalfNodeView { + h: H, + entry: Node, + exit: Node, +} + +impl> HalfNodeView { + #[allow(unused)] + pub(crate) fn new(h: H) -> Self { + let (entry, exit) = { + let mut children = h.children(h.root()); + (children.next().unwrap(), children.next().unwrap()) + }; + assert_eq!(h.get_optype(exit).tag(), OpTag::BasicBlockExit); + Self { h, entry, exit } + } + + fn is_multi_node(&self, n: Node) -> bool { + // TODO if is the entry-node, should we pretend there's an extra predecessor? 
(The "outside") + // We could also setify here before counting, but never + self.bb_preds(n).take(2).count() + self.bb_succs(n).take(2).count() == 4 + } + fn resolve_out(&self, n: Node) -> HalfNode { + if self.is_multi_node(n) { + HalfNode::X(n) + } else { + HalfNode::N(n) + } + } + + fn bb_succs(&self, n: Node) -> impl Iterator + '_ { + self.h.neighbours(n, Direction::Outgoing) + } + fn bb_preds(&self, n: Node) -> impl Iterator + '_ { + self.h.neighbours(n, Direction::Incoming) + } +} + +impl> CfgNodeMap for HalfNodeView { + type Iterator<'c> = as IntoIterator>::IntoIter where Self: 'c; + fn entry_node(&self) -> HalfNode { + HalfNode::N(self.entry) + } + fn exit_node(&self) -> HalfNode { + assert!(self.bb_succs(self.exit).count() == 0); + HalfNode::N(self.exit) + } + fn predecessors(&self, h: HalfNode) -> Self::Iterator<'_> { + let mut ps = Vec::new(); + match h { + HalfNode::N(ni) => ps.extend(self.bb_preds(ni).map(|n| self.resolve_out(n))), + HalfNode::X(ni) => ps.push(HalfNode::N(ni)), + }; + if h == self.entry_node() { + ps.push(self.exit_node()); + } + ps.into_iter() + } + fn successors(&self, n: HalfNode) -> Self::Iterator<'_> { + let mut succs = Vec::new(); + match n { + HalfNode::N(ni) if self.is_multi_node(ni) => succs.push(HalfNode::X(ni)), + HalfNode::N(ni) | HalfNode::X(ni) => succs.extend(self.bb_succs(ni).map(HalfNode::N)), + }; + succs.into_iter() + } +} + +#[cfg(test)] +mod test { + use super::super::nest_cfgs::{test::*, EdgeClassifier}; + use super::{HalfNode, HalfNodeView}; + use crate::builder::BuildError; + use crate::hugr::views::RootChecked; + use crate::ops::handle::NodeHandle; + + use itertools::Itertools; + use std::collections::HashSet; + #[test] + fn test_cond_in_loop_combined_headers() -> Result<(), BuildError> { + let (h, main, tail) = build_conditional_in_loop_cfg(false)?; + // /-> left --\ + // entry -> main > merge -> tail -> exit + // | \-> right -/ | + // \---<---<---<---<---<--<---/ + // The "main" has two predecessors (entry and tail) and two successors (left and right) so + // we get HalfNode::N(main) aka "head" and HalfNode::X(main) aka "split" in this form: + // /-> left --\ + // N(entry) -> head -> split > N(merge) -> N(tail) -> N(exit) + // | \-> right -/ | + // \---<---<---<---<---<---<---<---<---<---/ + // Allowing to identify two nested regions (and fixing the problem with an IdentityCfgMap on the same example) + + let v = HalfNodeView::new(RootChecked::try_new(&h).unwrap()); + + let edge_classes = EdgeClassifier::get_edge_classes(&v); + let HalfNodeView { h: _, entry, exit } = v; + + let head = HalfNode::N(main.node()); + let tail = HalfNode::N(tail.node()); + let split = HalfNode::X(main.node()); + let (entry, exit) = (HalfNode::N(entry), HalfNode::N(exit)); + // merge is unique predecessor of tail + let merge = *edge_classes + .keys() + .filter(|(_, t)| *t == tail) + .map(|(s, _)| s) + .exactly_one() + .unwrap(); + let [&left, &right] = edge_classes + .keys() + .filter(|(s, _)| *s == split) + .map(|(_, t)| t) + .collect::>()[..] + else { + panic!("Split node should have two successors"); + }; + let classes = group_by(edge_classes); + assert_eq!( + classes, + HashSet::from([ + sorted([(split, left), (left, merge)]), // Region containing single BB 'left'. + sorted([(split, right), (right, merge)]), // Region containing single BB 'right'. + sorted([(head, split), (merge, tail)]), // The inner "conditional" region. + sorted([(entry, head), (tail, exit)]), // "Loop" region containing body (conditional) + back-edge. 
+ Vec::from([(tail, head)]) // The loop back-edge. + ]) + ); + Ok(()) + } + + // Sadly this HalfNode logic is too simple to fix the test_cond_then_loop_combined case + // (The "merge" node is not split, but needs to be split with the tail->merge edge incoming + // to the *second* node after splitting). +} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 2ed5859f0..1670995e8 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,3 +1,4 @@ pub mod const_fold; +mod half_node; pub mod merge_bbs; pub mod nest_cfgs; From 7bff2486c5f47cc9111f49c5bc9a2db571cd8a48 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Wed, 22 May 2024 13:17:26 +0100 Subject: [PATCH 05/28] Fix up `half_node` to build outside hugr crate. --- hugr-passes/src/half_node.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/half_node.rs b/hugr-passes/src/half_node.rs index 4edd35d34..cb9c6e55e 100644 --- a/hugr-passes/src/half_node.rs +++ b/hugr-passes/src/half_node.rs @@ -2,12 +2,12 @@ use std::hash::Hash; use super::nest_cfgs::CfgNodeMap; -use crate::hugr::RootTagged; +use hugr::hugr::RootTagged; -use crate::ops::handle::CfgID; -use crate::ops::{OpTag, OpTrait}; +use hugr::ops::handle::CfgID; +use hugr::ops::{OpTag, OpTrait}; -use crate::{Direction, Node}; +use hugr::{Direction, Node}; /// We provide a view of a cfg where every node has at most one of /// (multiple predecessors, multiple successors). @@ -97,9 +97,9 @@ impl> CfgNodeMap for HalfNodeView mod test { use super::super::nest_cfgs::{test::*, EdgeClassifier}; use super::{HalfNode, HalfNodeView}; - use crate::builder::BuildError; - use crate::hugr::views::RootChecked; - use crate::ops::handle::NodeHandle; + use hugr::builder::BuildError; + use hugr::hugr::views::RootChecked; + use hugr::ops::handle::NodeHandle; use itertools::Itertools; use std::collections::HashSet; From 880d4ab5395da56d414f40671470a61ebdce4ba3 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Wed, 22 May 2024 16:36:56 +0100 Subject: [PATCH 06/28] Make `utils` public and move `sorted_consts` into it. --- hugr-passes/src/const_fold.rs | 16 +--------------- hugr/src/algorithm/const_fold.rs | 16 +--------------- hugr/src/lib.rs | 2 +- .../arithmetic/float_ops/const_fold.rs | 2 +- hugr/src/std_extensions/collections.rs | 2 +- hugr/src/std_extensions/logic.rs | 2 +- hugr/src/utils.rs | 19 +++++++++++++++++++ 7 files changed, 25 insertions(+), 34 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index b75a48ab2..33fe8e367 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -19,6 +19,7 @@ use hugr::{ ops::{OpType, Value}, type_row, types::FunctionType, + utils::sorted_consts, Hugr, HugrView, IncomingPort, Node, SimpleReplacement, }; @@ -45,21 +46,6 @@ fn out_row(consts: impl IntoIterator) -> ConstFoldResult { Some(vec) } -/// Sort folding inputs with [`IncomingPort`] as key -fn sort_by_in_port(consts: &[(IncomingPort, Value)]) -> Vec<&(IncomingPort, Value)> { - let mut v: Vec<_> = consts.iter().collect(); - v.sort_by_key(|(i, _)| i); - v -} - -/// Sort some input constants by port and just return the constants. -pub(crate) fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> { - sort_by_in_port(consts) - .into_iter() - .map(|(_, c)| c) - .collect() -} - /// For a given op and consts, attempt to evaluate the op. 
pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldResult { let fold_result = match op { diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs index 5d4cffed3..c76514cd4 100644 --- a/hugr/src/algorithm/const_fold.rs +++ b/hugr/src/algorithm/const_fold.rs @@ -7,6 +7,7 @@ use thiserror::Error; use crate::hugr::{SimpleReplacementError, ValidationError}; use crate::types::SumType; +use crate::utils::sorted_consts; use crate::Direction; use crate::{ builder::{DFGBuilder, Dataflow, DataflowHugr}, @@ -45,21 +46,6 @@ fn out_row(consts: impl IntoIterator) -> ConstFoldResult { Some(vec) } -/// Sort folding inputs with [`IncomingPort`] as key -fn sort_by_in_port(consts: &[(IncomingPort, Value)]) -> Vec<&(IncomingPort, Value)> { - let mut v: Vec<_> = consts.iter().collect(); - v.sort_by_key(|(i, _)| i); - v -} - -/// Sort some input constants by port and just return the constants. -pub(crate) fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> { - sort_by_in_port(consts) - .into_iter() - .map(|(_, c)| c) - .collect() -} - /// For a given op and consts, attempt to evaluate the op. pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldResult { let fold_result = match op { diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index 7c1cdca5f..f2859aacc 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -149,7 +149,7 @@ pub mod macros; pub mod ops; pub mod std_extensions; pub mod types; -mod utils; +pub mod utils; pub use crate::core::{ CircuitUnit, Direction, IncomingPort, Node, NodeIndex, OutgoingPort, Port, PortIndex, Wire, diff --git a/hugr/src/std_extensions/arithmetic/float_ops/const_fold.rs b/hugr/src/std_extensions/arithmetic/float_ops/const_fold.rs index a97a2d1c7..974dbe9b6 100644 --- a/hugr/src/std_extensions/arithmetic/float_ops/const_fold.rs +++ b/hugr/src/std_extensions/arithmetic/float_ops/const_fold.rs @@ -1,8 +1,8 @@ use crate::{ - algorithm::const_fold::sorted_consts, extension::{prelude::ConstString, ConstFold, ConstFoldResult, OpDef}, ops, std_extensions::arithmetic::float_types::ConstF64, + utils::sorted_consts, IncomingPort, }; diff --git a/hugr/src/std_extensions/collections.rs b/hugr/src/std_extensions/collections.rs index 747841a6b..392417dff 100644 --- a/hugr/src/std_extensions/collections.rs +++ b/hugr/src/std_extensions/collections.rs @@ -8,7 +8,6 @@ use crate::ops::constant::ValueName; use crate::ops::{OpName, Value}; use crate::types::TypeName; use crate::{ - algorithm::const_fold::sorted_consts, extension::{ simple_op::{MakeExtensionOp, OpLoadError}, ConstFold, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, @@ -20,6 +19,7 @@ use crate::{ type_param::{TypeArg, TypeParam}, CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeBound, }, + utils::sorted_consts, Extension, }; diff --git a/hugr/src/std_extensions/logic.rs b/hugr/src/std_extensions/logic.rs index f747428c8..d6e51811f 100644 --- a/hugr/src/std_extensions/logic.rs +++ b/hugr/src/std_extensions/logic.rs @@ -6,7 +6,6 @@ use crate::extension::{ConstFold, ConstFoldResult}; use crate::ops::constant::ValueName; use crate::ops::{OpName, Value}; use crate::{ - algorithm::const_fold::sorted_consts, extension::{ prelude::BOOL_T, simple_op::{try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}, @@ -18,6 +17,7 @@ use crate::{ type_param::{TypeArg, TypeParam}, FunctionType, }, + utils::sorted_consts, Extension, IncomingPort, }; use lazy_static::lazy_static; diff --git 
a/hugr/src/utils.rs b/hugr/src/utils.rs index d7ce0e101..693eb2f7b 100644 --- a/hugr/src/utils.rs +++ b/hugr/src/utils.rs @@ -1,7 +1,11 @@ +//! General utilities. + use std::fmt::{self, Debug, Display}; use itertools::Itertools; +use crate::{ops::Value, IncomingPort}; + /// Write a comma separated list of of some types. /// Like debug_list, but using the Display instance rather than Debug, /// and not adding surrounding square brackets. @@ -205,6 +209,21 @@ pub(crate) mod test_quantum_extension { } } +/// Sort folding inputs with [`IncomingPort`] as key +fn sort_by_in_port(consts: &[(IncomingPort, Value)]) -> Vec<&(IncomingPort, Value)> { + let mut v: Vec<_> = consts.iter().collect(); + v.sort_by_key(|(i, _)| i); + v +} + +/// Sort some input constants by port and just return the constants. +pub fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> { + sort_by_in_port(consts) + .into_iter() + .map(|(_, c)| c) + .collect() +} + #[allow(dead_code)] // Test only utils #[cfg(test)] From 0fd72112315f0374d796a8c477a56e6285a80314 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 08:29:11 +0100 Subject: [PATCH 07/28] Inherit `extension_inference` feature from hugr crate. --- hugr-passes/Cargo.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 6fe4b8607..e25a4cb52 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -9,5 +9,8 @@ itertools = "0.12.0" paste = "1.0" thiserror = "1.0.28" +[features] +extension_inference = ["hugr/extension_inference"] + [dev-dependencies] rstest = "0.19.0" From 04c41757968e343550b101ae5d55a84ccf6b7a59 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 08:41:51 +0100 Subject: [PATCH 08/28] Make `Hugr::new()` public. Otherwise extension inference fails in some tests. --- hugr-passes/src/merge_bbs.rs | 7 +------ hugr/src/hugr.rs | 10 +++++----- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index 8d2b73c06..17adc4e57 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -2,7 +2,6 @@ //! and the target BB has no other predecessors. use std::collections::HashMap; -use hugr::builder::{CFGBuilder, HugrBuilder}; use hugr::hugr::hugrmut::HugrMut; use itertools::Itertools; @@ -56,11 +55,7 @@ fn mk_rep( let succ_sig = succ_ty.inner_signature(); // Make a Hugr with just a single CFG root node having the same signature. - let mut replacement: Hugr = CFGBuilder::new(cfg.root_type().op_signature().unwrap()) - .unwrap() - .finish_prelude_hugr() - .unwrap(); - replacement.remove_node(replacement.children(replacement.root()).next().unwrap()); + let mut replacement: Hugr = Hugr::new(cfg.root_type().clone()); let merged = replacement.add_node_with_parent(replacement.root(), { let mut merged_block = DataflowBlock { diff --git a/hugr/src/hugr.rs b/hugr/src/hugr.rs index 4d25807b3..f53dfe482 100644 --- a/hugr/src/hugr.rs +++ b/hugr/src/hugr.rs @@ -192,6 +192,11 @@ pub type NodeMetadataMap = serde_json::Map; /// Public API for HUGRs. impl Hugr { + /// Create a new Hugr, with a single root node. + pub fn new(root_node: NodeType) -> Self { + Self::with_capacity(root_node, 0, 0) + } + /// Resolve extension ops, infer extensions used, and pass the closure into validation pub fn update_validate( &mut self, @@ -237,11 +242,6 @@ impl Hugr { /// Internal API for HUGRs, not intended for use by users. impl Hugr { - /// Create a new Hugr, with a single root node. 
- pub(crate) fn new(root_node: NodeType) -> Self { - Self::with_capacity(root_node, 0, 0) - } - /// Create a new Hugr, with a single root node and preallocated capacity. // TODO: Make this take a NodeType pub(crate) fn with_capacity(root_node: NodeType, nodes: usize, ports: usize) -> Self { From 24f1df25c9ea06321a3b530b6ca660064c209092 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 08:44:40 +0100 Subject: [PATCH 09/28] Remove `merge_bbs` from hugr crate. --- hugr/src/algorithm.rs | 1 - hugr/src/algorithm/merge_bbs.rs | 398 -------------------------------- 2 files changed, 399 deletions(-) delete mode 100644 hugr/src/algorithm/merge_bbs.rs diff --git a/hugr/src/algorithm.rs b/hugr/src/algorithm.rs index 585e25a01..633231504 100644 --- a/hugr/src/algorithm.rs +++ b/hugr/src/algorithm.rs @@ -2,5 +2,4 @@ pub mod const_fold; mod half_node; -pub mod merge_bbs; pub mod nest_cfgs; diff --git a/hugr/src/algorithm/merge_bbs.rs b/hugr/src/algorithm/merge_bbs.rs deleted file mode 100644 index 06d3b3bfe..000000000 --- a/hugr/src/algorithm/merge_bbs.rs +++ /dev/null @@ -1,398 +0,0 @@ -//! Merge BBs along control-flow edges where the source BB has no other successors -//! and the target BB has no other predecessors. -use std::collections::HashMap; - -use itertools::Itertools; - -use crate::hugr::rewrite::inline_dfg::InlineDFG; -use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; -use crate::hugr::{HugrMut, RootTagged}; -use crate::ops::handle::CfgID; -use crate::ops::leaf::UnpackTuple; -use crate::ops::{DataflowBlock, DataflowParent, Input, Output, DFG}; -use crate::{Hugr, HugrView, Node}; - -/// Merge any basic blocks that are direct children of the specified CFG -/// i.e. where a basic block B has a single successor B' whose only predecessor -/// is B, B and B' can be combined. -pub fn merge_basic_blocks(cfg: &mut impl HugrMut) { - let mut worklist = cfg.nodes().collect::>(); - while let Some(n) = worklist.pop() { - // Consider merging n with its successor - let Ok(succ) = cfg.output_neighbours(n).exactly_one() else { - continue; - }; - if cfg.input_neighbours(succ).count() != 1 { - continue; - }; - if cfg.children(cfg.root()).take(2).contains(&succ) { - // If succ is... - // - the entry block, that has an implicit extra in-edge, so cannot merge with n. - // - the exit block, nodes in n should move *outside* the CFG - a separate pass. 
- continue; - }; - let (rep, merge_bb, dfgs) = mk_rep(cfg, n, succ); - let node_map = cfg.hugr_mut().apply_rewrite(rep).unwrap(); - let merged_bb = *node_map.get(&merge_bb).unwrap(); - for dfg_id in dfgs { - let n_id = *node_map.get(&dfg_id).unwrap(); - cfg.hugr_mut() - .apply_rewrite(InlineDFG(n_id.into())) - .unwrap(); - } - worklist.push(merged_bb); - } -} - -fn mk_rep( - cfg: &impl RootTagged, - pred: Node, - succ: Node, -) -> (Replacement, Node, [Node; 2]) { - let pred_ty = cfg.get_optype(pred).as_dataflow_block().unwrap(); - let succ_ty = cfg.get_optype(succ).as_dataflow_block().unwrap(); - let succ_sig = succ_ty.inner_signature(); - let mut replacement: Hugr = Hugr::new(cfg.root_type().clone()); - let merged = replacement.add_node_with_parent(replacement.root(), { - let mut merged_block = DataflowBlock { - inputs: pred_ty.inputs.clone(), - ..succ_ty.clone() - }; - merged_block.extension_delta = merged_block - .extension_delta - .union(pred_ty.extension_delta.clone()); - merged_block - }); - let input = replacement.add_node_with_parent( - merged, - Input { - types: pred_ty.inputs.clone(), - }, - ); - let output = replacement.add_node_with_parent( - merged, - Output { - types: succ_sig.output.clone(), - }, - ); - - let dfg1 = replacement.add_node_with_parent( - merged, - DFG { - signature: pred_ty.inner_signature().clone(), - }, - ); - for (i, _) in pred_ty.inputs.iter().enumerate() { - replacement.connect(input, i, dfg1, i) - } - - let dfg2 = replacement.add_node_with_parent( - merged, - DFG { - signature: succ_sig.clone(), - }, - ); - for (i, _) in succ_sig.output.iter().enumerate() { - replacement.connect(dfg2, i, output, i) - } - - // At the junction, must unpack the first (tuple, branch predicate) output - let tuple_elems = pred_ty.sum_rows.clone().into_iter().exactly_one().unwrap(); - let unp = replacement.add_node_with_parent( - merged, - UnpackTuple { - tys: tuple_elems.clone(), - }, - ); - replacement.connect(dfg1, 0, unp, 0); - let other_start = tuple_elems.len(); - for (i, _) in tuple_elems.iter().enumerate() { - replacement.connect(unp, i, dfg2, i) - } - for (i, _) in pred_ty.other_outputs.iter().enumerate() { - replacement.connect(dfg1, i + 1, dfg2, i + other_start) - } - // If there are edges from succ back to pred, we cannot do these via the mu_inp/out/new - // edge-maps as both source and target of the new edge are in the replacement Hugr - for (_, src_pos) in cfg.all_linked_outputs(pred).filter(|(src, _)| *src == succ) { - replacement.connect(merged, src_pos, merged, 0); - } - let rep = Replacement { - removal: vec![pred, succ], - replacement, - adoptions: HashMap::from([(dfg1, pred), (dfg2, succ)]), - mu_inp: cfg - .all_linked_outputs(pred) - .filter(|(src, _)| *src != succ) - .map(|(src, src_pos)| NewEdgeSpec { - src, - tgt: merged, - kind: NewEdgeKind::ControlFlow { src_pos }, - }) - .collect(), - mu_out: cfg - .node_outputs(succ) - .filter_map(|src_pos| { - let tgt = cfg - .linked_inputs(succ, src_pos) - .exactly_one() - .ok() - .unwrap() - .0; - if tgt == pred { - None - } else { - Some(NewEdgeSpec { - src: merged, - tgt, - kind: NewEdgeKind::ControlFlow { src_pos }, - }) - } - }) - .collect(), - mu_new: vec![], - }; - (rep, merged, [dfg1, dfg2]) -} - -#[cfg(test)] -mod test { - use std::collections::HashSet; - - use itertools::Itertools; - use rstest::rstest; - - use crate::builder::{CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; - use crate::extension::prelude::{ConstUsize, PRELUDE_ID, QB_T, USIZE_T}; - use crate::extension::{ExtensionRegistry, ExtensionSet, 
PRELUDE, PRELUDE_REGISTRY}; - use crate::hugr::views::sibling::SiblingMut; - use crate::ops::constant::Value; - use crate::ops::handle::CfgID; - use crate::ops::{Lift, LoadConstant, Noop, OpTrait, OpType}; - use crate::types::{FunctionType, Type, TypeRow}; - use crate::{const_extension_ids, type_row, Extension, Hugr, HugrView, Wire}; - - use super::merge_basic_blocks; - - const_extension_ids! { - const EXT_ID: ExtensionId = "TestExt"; - } - - fn extension() -> Extension { - let mut e = Extension::new(EXT_ID); - e.add_op( - "Test".into(), - String::new(), - FunctionType::new( - type_row![QB_T, USIZE_T], - TypeRow::from(vec![Type::new_sum(vec![ - type_row![QB_T], - type_row![USIZE_T], - ])]), - ), - ) - .unwrap(); - e - } - - fn lifted_unary_unit_sum + AsRef, T>(b: &mut DFGWrapper) -> Wire { - let lc = b.add_load_value(Value::unary_unit_sum()); - let lift = b - .add_dataflow_op( - Lift { - type_row: Type::new_unit_sum(1).into(), - new_extension: PRELUDE_ID, - }, - [lc], - ) - .unwrap(); - let [w] = lift.outputs_arr(); - w - } - - #[rstest] - #[case(true)] - #[case(false)] - fn in_loop(#[case] self_loop: bool) -> Result<(), Box> { - /* self_loop==False: - -> Noop1 -----> Test -> Exit -> Noop1AndTest --> Exit - | | => / \ - \-<- Noop2 <-/ \-<- Noop2 <-/ - (Noop2 -> Noop1 cannot be merged because Noop1 is the entry node) - - self_loop==True: - -> Noop --> Test -> Exit -> NoopAndTest --> Exit - | | => / \ - \--<--<--/ \--<-----<--/ - */ - let loop_variants = type_row![QB_T]; - let exit_types = type_row![USIZE_T]; - let e = extension(); - let tst_op = e.instantiate_extension_op("Test", [], &PRELUDE_REGISTRY)?; - let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), e])?; - let mut h = CFGBuilder::new( - FunctionType::new(loop_variants.clone(), exit_types.clone()) - .with_extension_delta(ExtensionSet::singleton(&PRELUDE_ID)), - )?; - let mut no_b1 = h.simple_entry_builder(loop_variants.clone(), 1, ExtensionSet::new())?; - let n = no_b1.add_dataflow_op(Noop { ty: QB_T }, no_b1.input_wires())?; - let br = lifted_unary_unit_sum(&mut no_b1); - let no_b1 = no_b1.finish_with_outputs(br, n.outputs())?; - let mut test_block = h.block_builder( - loop_variants.clone(), - vec![loop_variants.clone(), exit_types], - ExtensionSet::singleton(&PRELUDE_ID), - type_row![], - )?; - let [test_input] = test_block.input_wires_arr(); - let usize_cst = test_block.add_load_value(ConstUsize::new(1)); - let [tst] = test_block - .add_dataflow_op(tst_op, [test_input, usize_cst])? 
- .outputs_arr(); - let test_block = test_block.finish_with_outputs(tst, [])?; - let loop_backedge_target = if self_loop { - no_b1 - } else { - let mut no_b2 = h.simple_block_builder(FunctionType::new_endo(loop_variants), 1)?; - let n = no_b2.add_dataflow_op(Noop { ty: QB_T }, no_b2.input_wires())?; - let br = lifted_unary_unit_sum(&mut no_b2); - let nid = no_b2.finish_with_outputs(br, n.outputs())?; - h.branch(&nid, 0, &no_b1)?; - nid - }; - h.branch(&no_b1, 0, &test_block)?; - h.branch(&test_block, 0, &loop_backedge_target)?; - h.branch(&test_block, 1, &h.exit_block())?; - - let mut h = h.finish_hugr(®)?; - let r = h.root(); - merge_basic_blocks(&mut SiblingMut::::try_new(&mut h, r)?); - h.update_validate(®).unwrap(); - assert_eq!(r, h.root()); - assert!(matches!(h.get_optype(r), OpType::CFG(_))); - let [entry, exit] = h - .children(r) - .take(2) - .collect::>() - .try_into() - .unwrap(); - // Check the Noop('s) is/are in the right block(s) - let nops = h - .nodes() - .filter(|n| matches!(h.get_optype(*n), OpType::Noop(_))); - let (entry_nop, expected_backedge_target) = if self_loop { - assert_eq!(h.children(r).len(), 2); - (nops.exactly_one().ok().unwrap(), entry) - } else { - let [_, _, no_b2] = h.children(r).collect::>().try_into().unwrap(); - let mut nops = nops.collect::>(); - let entry_nop_idx = nops - .iter() - .position(|n| h.get_parent(*n) == Some(entry)) - .unwrap(); - let entry_nop = nops[entry_nop_idx]; - nops.remove(entry_nop_idx); - let [n_op2] = nops.try_into().unwrap(); - assert_eq!(h.get_parent(n_op2), Some(no_b2)); - (entry_nop, no_b2) - }; - assert_eq!(h.get_parent(entry_nop), Some(entry)); - assert_eq!( - h.output_neighbours(entry).collect::>(), - HashSet::from([expected_backedge_target, exit]) - ); - // And the Noop in the entry block is consumed by the custom Test op - let tst = find_unique(h.nodes(), |n| { - matches!(h.get_optype(*n), OpType::CustomOp(_)) - }); - assert_eq!(h.get_parent(tst), Some(entry)); - assert_eq!( - h.output_neighbours(entry_nop).collect::>(), - vec![tst] - ); - Ok(()) - } - - #[test] - fn triple_with_permute() -> Result<(), Box> { - // Blocks are just BB1 -> BB2 -> BB3 --> Exit. - // CFG Normalization would move everything outside the CFG and elide the CFG altogether, - // but this is an easy-to-construct test of merge-basic-blocks only (no CFG normalization). - let e = extension(); - let tst_op: OpType = e - .instantiate_extension_op("Test", &[], &PRELUDE_REGISTRY)? 
- .into(); - let [res_t] = tst_op - .dataflow_signature() - .unwrap() - .output - .into_owned() - .try_into() - .unwrap(); - let mut h = CFGBuilder::new( - FunctionType::new(QB_T, res_t.clone()) - .with_extension_delta(ExtensionSet::singleton(&PRELUDE_ID)), - )?; - let mut bb1 = h.entry_builder( - vec![type_row![]], - type_row![USIZE_T, QB_T], - ExtensionSet::singleton(&PRELUDE_ID), - )?; - let [inw] = bb1.input_wires_arr(); - let load_cst = bb1.add_load_value(ConstUsize::new(1)); - let pred = lifted_unary_unit_sum(&mut bb1); - let bb1 = bb1.finish_with_outputs(pred, [load_cst, inw])?; - - let mut bb2 = h.block_builder( - type_row![USIZE_T, QB_T], - vec![type_row![]], - ExtensionSet::new(), - type_row![QB_T, USIZE_T], - )?; - let [u, q] = bb2.input_wires_arr(); - let pred = lifted_unary_unit_sum(&mut bb2); - let bb2 = bb2.finish_with_outputs(pred, [q, u])?; - - let mut bb3 = h.block_builder( - type_row![QB_T, USIZE_T], - vec![type_row![]], - ExtensionSet::new(), - res_t.clone().into(), - )?; - let [q, u] = bb3.input_wires_arr(); - let tst = bb3.add_dataflow_op(tst_op, [q, u])?; - let pred = lifted_unary_unit_sum(&mut bb3); - let bb3 = bb3.finish_with_outputs(pred, tst.outputs())?; - // Now add control-flow edges between basic blocks - h.branch(&bb1, 0, &bb2)?; - h.branch(&bb2, 0, &bb3)?; - h.branch(&bb3, 0, &h.exit_block())?; - - let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()])?; - let mut h = h.finish_hugr(®)?; - let root = h.root(); - merge_basic_blocks(&mut SiblingMut::try_new(&mut h, root)?); - h.update_validate(®)?; - - // Should only be one BB left - let [bb, _exit] = h.children(h.root()).collect::>().try_into().unwrap(); - let tst = find_unique(h.nodes(), |n| { - matches!(h.get_optype(*n), OpType::CustomOp(_)) - }); - assert_eq!(h.get_parent(tst), Some(bb)); - - let inp = find_unique(h.nodes(), |n| matches!(h.get_optype(*n), OpType::Input(_))); - let mut tst_inputs = h.input_neighbours(tst).collect::>(); - tst_inputs.remove(tst_inputs.iter().find_position(|n| **n == inp).unwrap().0); - let [other_input] = tst_inputs.try_into().unwrap(); - assert_eq!( - h.get_optype(other_input), - &(LoadConstant { datatype: USIZE_T }.into()) - ); - Ok(()) - } - - fn find_unique(items: impl Iterator, pred: impl Fn(&T) -> bool) -> T { - items.filter(pred).exactly_one().ok().unwrap() - } -} From 2a1ca98529db9a6d00ecc11c6a43e53c98229015 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 09:03:46 +0100 Subject: [PATCH 10/28] Remove `half_node` and `nest_cfgs` from hugr crate. Move one test into the hugr-passes crate, and refactor `depth()` into `utils`, to avoid code duplication. 
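The moved test is `incorrect_insertion`, which checks that `IdentityInsertion`
is rejected on a control-flow port; it uses the CFG-building helpers that now
live only in `hugr-passes`, so it moves with them.

For reference, `depth()` keeps the shape it had in the `nest_cfgs` test module
(the `hugr/src/utils.rs` hunk is expected to match this; shown here only as a
reminder of what the new `use hugr::utils::depth` import refers to):

    /// Nesting depth of `n` in the Hugr hierarchy; the root is at depth 0.
    pub fn depth(h: &Hugr, n: Node) -> u32 {
        match h.get_parent(n) {
            Some(p) => 1 + depth(h, p),
            None => 0,
        }
    }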
--- hugr-passes/src/nest_cfgs.rs | 30 +- hugr/src/algorithm.rs | 2 - hugr/src/algorithm/half_node.rs | 162 ---- hugr/src/algorithm/nest_cfgs.rs | 946 ----------------------- hugr/src/hugr/rewrite/insert_identity.rs | 21 - hugr/src/hugr/rewrite/replace.rs | 2 +- hugr/src/utils.rs | 10 +- 7 files changed, 32 insertions(+), 1141 deletions(-) delete mode 100644 hugr/src/algorithm/half_node.rs delete mode 100644 hugr/src/algorithm/nest_cfgs.rs diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index 10ce7c2ac..4231e2583 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -578,11 +578,13 @@ pub(crate) mod test { use hugr::extension::PRELUDE_REGISTRY; use hugr::extension::{prelude::USIZE_T, ExtensionSet}; + use hugr::hugr::rewrite::insert_identity::{IdentityInsertion, IdentityInsertionError}; use hugr::hugr::views::RootChecked; use hugr::ops::handle::{ConstID, NodeHandle}; use hugr::ops::Value; use hugr::type_row; - use hugr::types::{FunctionType, Type}; + use hugr::types::{EdgeKind, FunctionType, Type}; + use hugr::utils::depth; const NAT: Type = USIZE_T; pub fn group_by(h: HashMap) -> HashSet> { @@ -814,6 +816,25 @@ pub(crate) mod test { Ok(()) } + #[test] + fn incorrect_insertion() { + let (mut h, _, tail) = build_conditional_in_loop_cfg(false).unwrap(); + + let final_node = tail.node(); + + let final_node_input = h.node_inputs(final_node).next().unwrap(); + + let rw = IdentityInsertion::new(final_node, final_node_input); + + let apply_result = h.apply_rewrite(rw); + assert_eq!( + apply_result, + Err(IdentityInsertionError::InvalidPortKind(Some( + EdgeKind::ControlFlow + ))) + ); + } + fn n_identity( mut dataflow_builder: T, pred_const: &ConstID, @@ -936,11 +957,4 @@ pub(crate) mod test { Ok((head, tail)) } - - pub fn depth(h: &Hugr, n: Node) -> u32 { - match h.get_parent(n) { - Some(p) => 1 + depth(h, p), - None => 0, - } - } } diff --git a/hugr/src/algorithm.rs b/hugr/src/algorithm.rs index 633231504..685d827c8 100644 --- a/hugr/src/algorithm.rs +++ b/hugr/src/algorithm.rs @@ -1,5 +1,3 @@ //! Algorithms using the Hugr. pub mod const_fold; -mod half_node; -pub mod nest_cfgs; diff --git a/hugr/src/algorithm/half_node.rs b/hugr/src/algorithm/half_node.rs deleted file mode 100644 index 4edd35d34..000000000 --- a/hugr/src/algorithm/half_node.rs +++ /dev/null @@ -1,162 +0,0 @@ -use std::hash::Hash; - -use super::nest_cfgs::CfgNodeMap; - -use crate::hugr::RootTagged; - -use crate::ops::handle::CfgID; -use crate::ops::{OpTag, OpTrait}; - -use crate::{Direction, Node}; - -/// We provide a view of a cfg where every node has at most one of -/// (multiple predecessors, multiple successors). -/// So for BBs with multiple preds + succs, we generate TWO HalfNode's with a single edge between -/// them; that single edge can then be a region boundary that did not exist before. -/// TODO: this unfortunately doesn't capture all cases: when a node has multiple preds and succs, -/// we could "merge" *any subset* of the in-edges into a single in-edge via an extra empty BB; -/// the in-edge from that extra/empty BB, might be the endpoint of a useful SESE region, -/// but we don't have a way to identify *which subset* to select. 
(Here we say *all preds* if >1 succ) -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] -enum HalfNode { - /// All predecessors of original BB; successors if this does not break rule, else the X - N(Node), - // Exists only for BBs with multiple preds _and_ succs; has a single pred (the N), plus original succs - X(Node), -} - -struct HalfNodeView { - h: H, - entry: Node, - exit: Node, -} - -impl> HalfNodeView { - #[allow(unused)] - pub(crate) fn new(h: H) -> Self { - let (entry, exit) = { - let mut children = h.children(h.root()); - (children.next().unwrap(), children.next().unwrap()) - }; - assert_eq!(h.get_optype(exit).tag(), OpTag::BasicBlockExit); - Self { h, entry, exit } - } - - fn is_multi_node(&self, n: Node) -> bool { - // TODO if is the entry-node, should we pretend there's an extra predecessor? (The "outside") - // We could also setify here before counting, but never - self.bb_preds(n).take(2).count() + self.bb_succs(n).take(2).count() == 4 - } - fn resolve_out(&self, n: Node) -> HalfNode { - if self.is_multi_node(n) { - HalfNode::X(n) - } else { - HalfNode::N(n) - } - } - - fn bb_succs(&self, n: Node) -> impl Iterator + '_ { - self.h.neighbours(n, Direction::Outgoing) - } - fn bb_preds(&self, n: Node) -> impl Iterator + '_ { - self.h.neighbours(n, Direction::Incoming) - } -} - -impl> CfgNodeMap for HalfNodeView { - type Iterator<'c> = as IntoIterator>::IntoIter where Self: 'c; - fn entry_node(&self) -> HalfNode { - HalfNode::N(self.entry) - } - fn exit_node(&self) -> HalfNode { - assert!(self.bb_succs(self.exit).count() == 0); - HalfNode::N(self.exit) - } - fn predecessors(&self, h: HalfNode) -> Self::Iterator<'_> { - let mut ps = Vec::new(); - match h { - HalfNode::N(ni) => ps.extend(self.bb_preds(ni).map(|n| self.resolve_out(n))), - HalfNode::X(ni) => ps.push(HalfNode::N(ni)), - }; - if h == self.entry_node() { - ps.push(self.exit_node()); - } - ps.into_iter() - } - fn successors(&self, n: HalfNode) -> Self::Iterator<'_> { - let mut succs = Vec::new(); - match n { - HalfNode::N(ni) if self.is_multi_node(ni) => succs.push(HalfNode::X(ni)), - HalfNode::N(ni) | HalfNode::X(ni) => succs.extend(self.bb_succs(ni).map(HalfNode::N)), - }; - succs.into_iter() - } -} - -#[cfg(test)] -mod test { - use super::super::nest_cfgs::{test::*, EdgeClassifier}; - use super::{HalfNode, HalfNodeView}; - use crate::builder::BuildError; - use crate::hugr::views::RootChecked; - use crate::ops::handle::NodeHandle; - - use itertools::Itertools; - use std::collections::HashSet; - #[test] - fn test_cond_in_loop_combined_headers() -> Result<(), BuildError> { - let (h, main, tail) = build_conditional_in_loop_cfg(false)?; - // /-> left --\ - // entry -> main > merge -> tail -> exit - // | \-> right -/ | - // \---<---<---<---<---<--<---/ - // The "main" has two predecessors (entry and tail) and two successors (left and right) so - // we get HalfNode::N(main) aka "head" and HalfNode::X(main) aka "split" in this form: - // /-> left --\ - // N(entry) -> head -> split > N(merge) -> N(tail) -> N(exit) - // | \-> right -/ | - // \---<---<---<---<---<---<---<---<---<---/ - // Allowing to identify two nested regions (and fixing the problem with an IdentityCfgMap on the same example) - - let v = HalfNodeView::new(RootChecked::try_new(&h).unwrap()); - - let edge_classes = EdgeClassifier::get_edge_classes(&v); - let HalfNodeView { h: _, entry, exit } = v; - - let head = HalfNode::N(main.node()); - let tail = HalfNode::N(tail.node()); - let split = HalfNode::X(main.node()); - let (entry, exit) = 
(HalfNode::N(entry), HalfNode::N(exit)); - // merge is unique predecessor of tail - let merge = *edge_classes - .keys() - .filter(|(_, t)| *t == tail) - .map(|(s, _)| s) - .exactly_one() - .unwrap(); - let [&left, &right] = edge_classes - .keys() - .filter(|(s, _)| *s == split) - .map(|(_, t)| t) - .collect::>()[..] - else { - panic!("Split node should have two successors"); - }; - let classes = group_by(edge_classes); - assert_eq!( - classes, - HashSet::from([ - sorted([(split, left), (left, merge)]), // Region containing single BB 'left'. - sorted([(split, right), (right, merge)]), // Region containing single BB 'right'. - sorted([(head, split), (merge, tail)]), // The inner "conditional" region. - sorted([(entry, head), (tail, exit)]), // "Loop" region containing body (conditional) + back-edge. - Vec::from([(tail, head)]) // The loop back-edge. - ]) - ); - Ok(()) - } - - // Sadly this HalfNode logic is too simple to fix the test_cond_then_loop_combined case - // (The "merge" node is not split, but needs to be split with the tail->merge edge incoming - // to the *second* node after splitting). -} diff --git a/hugr/src/algorithm/nest_cfgs.rs b/hugr/src/algorithm/nest_cfgs.rs deleted file mode 100644 index feae6470b..000000000 --- a/hugr/src/algorithm/nest_cfgs.rs +++ /dev/null @@ -1,946 +0,0 @@ -//! # Nest CFGs -//! -//! Identify Single-Entry-Single-Exit (SESE) regions in the CFG. -//! These are pairs of edges (a,b) where -//! * a dominates b -//! * b postdominates a -//! * there are no other edges in/out of the nodes inbetween -//! (this last condition is necessary because loop backedges do not affect (post)dominance). -//! -//! # Algorithm -//! See paper: , approximately: -//! 1. those three conditions are equivalent to: -//! *a and b are cycle-equivalent in the CFG with an extra edge from the exit node to the entry* -//! where cycle-equivalent means every cycle has either both a and b, or neither -//! 2. cycle equivalence is unaffected if all edges are considered *un*directed -//! (not obvious, see paper for proof) -//! 3. take undirected CFG, perform depth-first traversal -//! => all edges are either *tree edges*, or *backedges* where one endpoint is an ancestor of the other -//! 4. identify the "bracketlist" of each tree edge - the set of backedges going from a descendant of that edge to an ancestor -//! -- post-order traversal, merging bracketlists of children, -//! then delete backedges from below to here, add backedges from here to above -//! => tree edges with the same bracketlist are cycle-equivalent; -//! + a tree edge with a single-element bracketlist is cycle-equivalent with that single element -//! 5. this would be expensive (comparing large sets of backedges) - so to optimize, -//! - the backedge most recently added (at the top) of the bracketlist, plus the size of the bracketlist, -//! is sufficient to identify the set *when the UDFS tree is linear*; -//! - when UDFS is treelike, any ancestor with brackets from >1 subtree cannot be cycle-equivalent with any descendant -//! (as the brackets of said descendant come from beneath it to its ancestors, not from any sibling/etc. in the other subtree). -//! So, add (onto top of bracketlist) a fake "capping" backedge from here to the highest ancestor reached by >1 subtree. -//! (Thus, edges from here up to that ancestor, cannot be cycle-equivalent with any edges elsewhere.) -//! -//! # Restrictions -//! * The paper assumes that all CFG nodes are on paths from entry to exit, i.e. no loops without exits. -//! 
HUGR assumes only that they are all reachable from entry, so we do a backward traversal from exit node -//! first and restrict to the CFG nodes in the reachable set. (This means we will not discover SESE regions -//! in exit-free loops, but that doesn't seem a major concern.) -//! * Multiple edges in the same direction between the same BBs will "confuse" the algorithm in the paper. -//! However it is straightforward for us to treat successors and predecessors as sets. (Two edges between -//! the same BBs but in opposite directions must be distinct!) - -use std::collections::{HashMap, HashSet, LinkedList, VecDeque}; -use std::hash::Hash; - -use itertools::Itertools; -use thiserror::Error; - -use crate::hugr::rewrite::outline_cfg::OutlineCfg; -use crate::hugr::views::sibling::SiblingMut; -use crate::hugr::views::{HierarchyView, HugrView, SiblingGraph}; -use crate::hugr::{HugrMut, Rewrite, RootTagged}; -use crate::ops::handle::{BasicBlockID, CfgID}; -use crate::ops::OpTag; -use crate::ops::OpTrait; -use crate::{Direction, Hugr, Node}; - -/// A "view" of a CFG in a Hugr which allows basic blocks in the underlying CFG to be split into -/// multiple blocks in the view (or merged together). -/// `T` is the type of basic block; this can just be a BasicBlock (e.g. [`Node`]) in the Hugr, -/// or an [IdentityCfgMap] if the extra level of indirection is not required. However, since -/// SESE regions are bounded by edges between pairs of such `T`, such splitting may allow the -/// algorithm to identify more regions than existed in the underlying CFG, without mutating the -/// underlying CFG just for the analysis - the splitting (and/or merging) can then be performed by -/// [CfgNester::nest_sese_region] only as necessary for regions actually nested. -pub trait CfgNodeMap { - /// The unique entry node of the CFG. It may any n>=0 of incoming edges; we assume control arrives here from "outside". - fn entry_node(&self) -> T; - /// The unique exit node of the CFG. The only node to have no successors. - fn exit_node(&self) -> T; - /// Allows the trait implementor to define a type of iterator it will return from - /// `successors` and `predecessors`. - type Iterator<'c>: Iterator - where - Self: 'c; - /// Returns an iterator over the successors of the specified basic block. - fn successors(&self, node: T) -> Self::Iterator<'_>; - /// Returns an iterator over the predecessors of the specified basic block. - fn predecessors(&self, node: T) -> Self::Iterator<'_>; -} - -/// Extension of [CfgNodeMap] to that can perform (mutable/destructive) -/// nesting of regions detected. -pub trait CfgNester: CfgNodeMap { - /// Given an entry edge and exit edge defining a SESE region, mutates the - /// Hugr such that all nodes between these edges are placed in a nested CFG. - /// Returns the newly-constructed block (containing a nested CFG). - /// - /// # Panics - /// May panic if the two edges do not constitute a SESE region. - fn nest_sese_region(&mut self, entry_edge: (T, T), exit_edge: (T, T)) -> T; -} - -/// Transforms a CFG into as much-nested a form as possible. -pub fn transform_cfg_to_nested( - view: &mut impl CfgNester, -) { - let edge_classes = EdgeClassifier::get_edge_classes(view); - let mut rem_edges: HashMap> = HashMap::new(); - for (e, cls) in edge_classes.iter() { - rem_edges.entry(*cls).or_default().insert(*e); - } - - // Traverse. Any traversal will encounter edges in SESE-respecting order. 
- fn traverse( - view: &mut impl CfgNester, - n: T, - edge_classes: &HashMap<(T, T), usize>, - rem_edges: &mut HashMap>, - stop_at: Option, - ) -> Option<(T, T)> { - let mut seen = HashSet::new(); - let mut stack = Vec::new(); - let mut exit_edges = Vec::new(); - stack.push(n); - while let Some(n) = stack.pop() { - if !seen.insert(n) { - continue; - } - let (exit, rest): (Vec<_>, Vec<_>) = view - .successors(n) - .map(|s| (n, s)) - .partition(|e| stop_at.is_some() && edge_classes.get(e).copied() == stop_at); - exit_edges.extend(exit.into_iter().at_most_one().unwrap()); - for mut e in rest { - if let Some(cls) = edge_classes.get(&e) { - assert!(rem_edges.get_mut(cls).unwrap().remove(&e)); - // While there are more edges in that same class, we can traverse the entire - // subregion between pairs of edges in that class in a single step - // (as these are strictly nested in any outer region) - while !rem_edges.get_mut(cls).unwrap().is_empty() { - let prev_e = e; - // Traverse to the next edge in the same class - we know it exists in the set - e = traverse(view, e.1, edge_classes, rem_edges, Some(*cls)).unwrap(); - assert!(rem_edges.get_mut(cls).unwrap().remove(&e)); - // Skip trivial regions of a single node, unless the node has other edges - // (non-exiting, but e.g. a backedge to a loop header, ending that loop) - if prev_e.1 != e.0 || view.successors(e.0).count() > 1 { - // Traversal and nesting of the subregion's *contents* were completed in the - // recursive call above, so only processed nodes are moved into descendant CFGs - e = (view.nest_sese_region(prev_e, e), e.1) - }; - } - } - stack.push(e.1); - } - } - exit_edges.into_iter().unique().at_most_one().unwrap() - } - traverse(view, view.entry_node(), &edge_classes, &mut rem_edges, None); - // TODO we should probably now try to merge consecutive basic blocks - // (i.e. where a BB has a single successor, that has a single predecessor) - // and thus convert CF dependencies into (parallelizable) dataflow. -} - -/// Search the entire Hugr looking for CFGs, and transform each -/// into as deeply-nested form as possible (as per [transform_cfg_to_nested]). -/// This search may be expensive, although if it finds much/many CFGs, -/// the analysis/transformation on them is likely to be more expensive still! -pub fn transform_all_cfgs(h: &mut Hugr) { - let mut node_stack = Vec::from([h.root()]); - while let Some(n) = node_stack.pop() { - if let Ok(s) = SiblingMut::::try_new(h, n) { - transform_cfg_to_nested(&mut IdentityCfgMap::new(s)); - } - node_stack.extend(h.children(n)) - } -} - -/// Directed edges in a Cfg - i.e. along which control flows from first to second only. -type CfgEdge = (T, T); - -// The next enum + few functions allow to abstract over the edge directions -// in a CfgView. 
- -#[derive(Copy, Clone, PartialEq, Eq, Hash)] -enum EdgeDest { - Forward(T), - Backward(T), -} - -impl EdgeDest { - pub fn target(&self) -> T { - match self { - EdgeDest::Forward(i) => *i, - EdgeDest::Backward(i) => *i, - } - } -} - -fn all_edges<'a, T: Copy + Clone + PartialEq + Eq + Hash + 'a>( - cfg: &'a impl CfgNodeMap, - n: T, -) -> impl Iterator> + '_ { - let extra = if n == cfg.exit_node() { - vec![cfg.entry_node()] - } else { - vec![] - }; - cfg.successors(n) - .chain(extra) - .map(EdgeDest::Forward) - .chain(cfg.predecessors(n).map(EdgeDest::Backward)) - .unique() -} - -fn flip(src: T, d: EdgeDest) -> (T, EdgeDest) { - match d { - EdgeDest::Forward(tgt) => (tgt, EdgeDest::Backward(src)), - EdgeDest::Backward(tgt) => (tgt, EdgeDest::Forward(src)), - } -} - -fn cfg_edge(s: T, d: EdgeDest) -> CfgEdge { - match d { - EdgeDest::Forward(t) => (s, t), - EdgeDest::Backward(t) => (t, s), - } -} - -/// A straightforward view of a Cfg as it appears in a Hugr -pub struct IdentityCfgMap { - h: H, - entry: Node, - exit: Node, -} -impl> IdentityCfgMap { - /// Creates an [IdentityCfgMap] for the specified CFG - pub fn new(h: H) -> Self { - // Panic if malformed enough not to have two children - let (entry, exit) = h.children(h.root()).take(2).collect_tuple().unwrap(); - debug_assert_eq!(h.get_optype(exit).tag(), OpTag::BasicBlockExit); - Self { h, entry, exit } - } -} -impl CfgNodeMap for IdentityCfgMap { - fn entry_node(&self) -> Node { - self.entry - } - - fn exit_node(&self) -> Node { - self.exit - } - - type Iterator<'c> = ::Neighbours<'c> - where - Self: 'c; - - fn successors(&self, node: Node) -> Self::Iterator<'_> { - self.h.neighbours(node, Direction::Outgoing) - } - - fn predecessors(&self, node: Node) -> Self::Iterator<'_> { - self.h.neighbours(node, Direction::Incoming) - } -} - -impl CfgNester for IdentityCfgMap { - fn nest_sese_region(&mut self, entry_edge: (Node, Node), exit_edge: (Node, Node)) -> Node { - // The algorithm only calls with entry/exit edges for a SESE region; panic if they don't - let blocks = region_blocks(self, entry_edge, exit_edge).unwrap(); - assert!([entry_edge.0, entry_edge.1, exit_edge.0, exit_edge.1] - .iter() - .all(|n| self.h.get_parent(*n) == Some(self.h.root()))); - let (new_block, new_cfg) = OutlineCfg::new(blocks).apply(&mut self.h).unwrap(); - debug_assert!([entry_edge.0, exit_edge.1] - .iter() - .all(|n| self.h.get_parent(*n) == Some(self.h.root()))); - - debug_assert!({ - let new_block_view = SiblingGraph::::try_new(&self.h, new_block).unwrap(); - let new_cfg_view = SiblingGraph::::try_new(&new_block_view, new_cfg).unwrap(); - [entry_edge.1, exit_edge.0] - .iter() - .all(|n| new_cfg_view.get_parent(*n) == Some(new_cfg)) - }); - new_block - } -} - -/// An error trying to get the blocks of a SESE (single-entry-single-exit) region -#[derive(Clone, Debug, Error)] -#[non_exhaustive] -pub enum RegionBlocksError { - /// The specified exit edge did not exist in the CFG - ExitEdgeNotPresent(T, T), - /// The specified entry edge did not exist in the CFG - EntryEdgeNotPresent(T, T), - /// The source of the entry edge was in the region - /// (reachable from the target of the entry edge without using the exit edge) - EntryEdgeSourceInRegion(T), - /// The target of the entry edge had other predecessors (given) - /// that were outside the region (i.e. not reachable from the target) - UnexpectedEntryEdges(Vec), -} - -/// Given entry and exit edges for a SESE region, identify all the blocks in it. 
-pub fn region_blocks( - v: &impl CfgNodeMap, - entry_edge: (T, T), - exit_edge: (T, T), -) -> Result, RegionBlocksError> { - let mut blocks = HashSet::new(); - let mut queue = VecDeque::new(); - queue.push_back(entry_edge.1); - while let Some(n) = queue.pop_front() { - if blocks.insert(n) { - if n == exit_edge.0 { - let succs: Vec = v.successors(n).collect(); - let n_succs = succs.len(); - let internal_succs: Vec = - succs.into_iter().filter(|s| *s != exit_edge.1).collect(); - if internal_succs.len() == n_succs { - return Err(RegionBlocksError::ExitEdgeNotPresent( - exit_edge.0, - exit_edge.1, - )); - } - queue.extend(internal_succs) - } else { - queue.extend(v.successors(n)); - } - } - } - if blocks.contains(&entry_edge.0) { - return Err(RegionBlocksError::EntryEdgeSourceInRegion(entry_edge.0)); - } - - let ext_preds = v - .predecessors(entry_edge.1) - .unique() - .filter(|p| !blocks.contains(p)); - let (expected, extra): (Vec, Vec) = ext_preds.partition(|i| *i == entry_edge.0); - if expected != vec![entry_edge.0] { - return Err(RegionBlocksError::EntryEdgeNotPresent( - entry_edge.0, - entry_edge.1, - )); - }; - if !extra.is_empty() { - return Err(RegionBlocksError::UnexpectedEntryEdges(extra)); - } - // We could check for other nodes in the region having predecessors outside it, but that would be more expensive - Ok(blocks) -} - -/// Records an undirected Depth First Search over a CfgView, -/// restricted to nodes forwards-reachable from the entry. -/// That is, the DFS traversal goes both ways along the edges of the CFG. -/// *Undirected* DFS classifies all edges into *only two* categories -/// * tree edges, which on their own (with the nodes) form a tree (minimum spanning tree); -/// * backedges, i.e. those for which, when DFS tried to traverse them, the other endpoint was an ancestor -/// Moreover, we record *which way* along the underlying CFG edge we went. -struct UndirectedDFSTree { - /// Pre-order traversal numbering - dfs_num: HashMap, - /// For each node, the edge along which it was reached from its parent - dfs_parents: HashMap>, -} - -impl UndirectedDFSTree { - fn new(cfg: &impl CfgNodeMap) -> Self { - //1. Traverse backwards-only from exit building bitset of reachable nodes - let mut reachable = HashSet::new(); - { - let mut pending = VecDeque::new(); - pending.push_back(cfg.exit_node()); - while let Some(n) = pending.pop_front() { - if reachable.insert(n) { - pending.extend(cfg.predecessors(n)); - } - } - } - //2. Traverse undirected from entry node, building dfs_num and setting dfs_parents - let mut dfs_num = HashMap::new(); - let mut dfs_parents = HashMap::new(); - { - // Node, and directed edge along which reached - let mut pending = vec![(cfg.entry_node(), EdgeDest::Backward(cfg.exit_node()))]; - while let Some((n, p_edge)) = pending.pop() { - if !dfs_num.contains_key(&n) && reachable.contains(&n) { - dfs_num.insert(n, dfs_num.len()); - dfs_parents.insert(n, p_edge); - for e in all_edges(cfg, n) { - pending.push(flip(n, e)); - } - } - } - dfs_parents.remove(&cfg.entry_node()).unwrap(); - } - UndirectedDFSTree { - dfs_num, - dfs_parents, - } - } -} - -#[derive(Clone, PartialEq, Eq, Hash)] -enum Bracket { - Real(CfgEdge), - Capping(usize, T), -} - -/// Manages a list of brackets. The goal here is to allow constant-time deletion -/// out of the middle of the list - which isn't really possible, so instead we -/// track deleted items (in an external set) and the remaining number (here). 
-/// -/// Note - we could put the items deleted from *this* BracketList here, and merge in concat(). -/// That would be cleaner, but repeated set-merging would be slower than adding the -/// deleted items to a single set in the `TraversalState` -struct BracketList { - items: LinkedList>, // Allows O(1) `append` of two lists. - size: usize, // Not counting deleted items. -} - -impl BracketList { - pub fn new() -> Self { - BracketList { - items: LinkedList::new(), - size: 0, - } - } - - pub fn tag(&mut self, deleted: &HashSet>) -> Option<(Bracket, usize)> { - while let Some(e) = self.items.front() { - // Pop deleted elements to save time (and memory) - if deleted.contains(e) { - self.items.pop_front(); - //deleted.remove(e); // Would only save memory, so keep as immutable - } else { - return Some((e.clone(), self.size)); - } - } - None - } - - pub fn concat(&mut self, other: BracketList) { - let BracketList { mut items, size } = other; - self.items.append(&mut items); - assert!(items.is_empty()); - self.size += size; - } - - pub fn delete(&mut self, b: &Bracket, deleted: &mut HashSet>) { - // Ideally, here we would also assert that no *other* BracketList contains b. - debug_assert!(self.items.contains(b)); // Makes operation O(n), otherwise O(1) - let was_new = deleted.insert(b.clone()); - assert!(was_new); - self.size -= 1; - } - - pub fn push(&mut self, e: Bracket) { - self.items.push_back(e); - self.size += 1; - } -} - -/// Mutable state updated during traversal of the UndirectedDFSTree by the cycle equivalence algorithm. -pub struct EdgeClassifier { - /// Edges we have marked as deleted, allowing constant-time deletion without searching BracketList - deleted_backedges: HashSet>, - /// Key is DFS num of highest ancestor - /// to which backedges reached from >1 sibling subtree; - /// Value is the LCA i.e. parent of those siblings. - capping_edges: HashMap>, - /// Result of traversal - accumulated here, entries should never be overwritten - edge_classes: HashMap, Option<(Bracket, usize)>>, -} - -impl EdgeClassifier { - /// Computes equivalence class of each edge, i.e. two edges with the same value - /// are cycle-equivalent. Any two consecutive edges in the same class define a SESE region - /// (where "consecutive" means on any path in the original directed CFG, as the edges - /// in a class all dominate + postdominate each other as part of defn of cycle equivalence). - pub fn get_edge_classes(cfg: &impl CfgNodeMap) -> HashMap, usize> { - let tree = UndirectedDFSTree::new(cfg); - let mut s = Self { - deleted_backedges: HashSet::new(), - capping_edges: HashMap::new(), - edge_classes: HashMap::new(), - }; - s.traverse(cfg, &tree, cfg.entry_node()); - assert!(s.capping_edges.is_empty()); - s.edge_classes.remove(&(cfg.exit_node(), cfg.entry_node())); - let mut cycle_class_idxs = HashMap::new(); - s.edge_classes - .into_iter() - .map(|(k, v)| { - let l = cycle_class_idxs.len(); - (k, *cycle_class_idxs.entry(v).or_insert(l)) - }) - .collect() - } - - /// Returns the lowest DFS num (highest ancestor) reached by any bracket leaving - /// the subtree, and the list of said brackets. 
- fn traverse( - &mut self, - cfg: &impl CfgNodeMap, - tree: &UndirectedDFSTree, - n: T, - ) -> (usize, BracketList) { - let n_dfs = *tree.dfs_num.get(&n).unwrap(); // should only be called for nodes on path to exit - let (children, non_capping_backedges): (Vec<_>, Vec<_>) = all_edges(cfg, n) - .filter(|e| tree.dfs_num.contains_key(&e.target())) - .partition(|e| { - // The tree edges are those whose *targets* list the edge as parent-edge - let (tgt, from) = flip(n, *e); - tree.dfs_parents.get(&tgt) == Some(&from) - }); - let child_results: Vec<_> = children - .iter() - .map(|c| self.traverse(cfg, tree, c.target())) - .collect(); - let mut min_dfs_target: [Option; 2] = [None, None]; // We want highest-but-one - let mut bs = BracketList::new(); - for (tgt, brs) in child_results { - if tgt < min_dfs_target[0].unwrap_or(usize::MAX) { - min_dfs_target = [Some(tgt), min_dfs_target[0]] - } else if tgt < min_dfs_target[1].unwrap_or(usize::MAX) { - min_dfs_target[1] = Some(tgt) - } - bs.concat(brs); - } - // Add capping backedge - if let Some(min1dfs) = min_dfs_target[1] { - if min1dfs < n_dfs { - bs.push(Bracket::Capping(min1dfs, n)); - // mark capping edge to be removed when we return out to the other end - self.capping_edges.entry(min1dfs).or_default().push(n); - } - } - - let parent_edge = tree.dfs_parents.get(&n); - let (be_up, be_down): (Vec<_>, Vec<_>) = non_capping_backedges - .into_iter() - .map(|e| (*tree.dfs_num.get(&e.target()).unwrap(), e)) - .partition(|(dfs, _)| *dfs < n_dfs); - - // Remove edges to here from beneath - for (_, e) in be_down { - let e = cfg_edge(n, e); - let b = Bracket::Real(e); - bs.delete(&b, &mut self.deleted_backedges); - // Last chance to assign an edge class! This will be a singleton class, - // but assign for consistency with other singletons. 
- self.edge_classes.entry(e).or_insert_with(|| Some((b, 0))); - } - // And capping backedges - for src in self.capping_edges.remove(&n_dfs).unwrap_or_default() { - bs.delete(&Bracket::Capping(n_dfs, src), &mut self.deleted_backedges) - } - - // Add backedges from here to ancestors (not the parent edge, but perhaps other edges to the same node) - be_up - .iter() - .filter(|(_, e)| Some(e) != parent_edge) - .for_each(|(_, e)| bs.push(Bracket::Real(cfg_edge(n, *e)))); - - // Now calculate edge classes - let class = bs.tag(&self.deleted_backedges); - if let Some((Bracket::Real(e), 1)) = &class { - self.edge_classes.insert(*e, class.clone()); - } - if let Some(parent_edge) = tree.dfs_parents.get(&n) { - self.edge_classes.insert(cfg_edge(n, *parent_edge), class); - } - let highest_target = be_up - .into_iter() - .map(|(dfs, _)| dfs) - .chain(min_dfs_target[0]); - (highest_target.min().unwrap_or(usize::MAX), bs) - } -} - -#[cfg(test)] -pub(crate) mod test { - use super::*; - use crate::builder::{BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder}; - use crate::extension::PRELUDE_REGISTRY; - use crate::extension::{prelude::USIZE_T, ExtensionSet}; - - use crate::hugr::views::RootChecked; - use crate::ops::handle::{ConstID, NodeHandle}; - use crate::ops::Value; - use crate::type_row; - use crate::types::{FunctionType, Type}; - const NAT: Type = USIZE_T; - - pub fn group_by(h: HashMap) -> HashSet> { - let mut res = HashMap::new(); - for (k, v) in h.into_iter() { - res.entry(v).or_insert_with(Vec::new).push(k); - } - res.into_values().map(sorted).collect() - } - - pub fn sorted(items: impl IntoIterator) -> Vec { - let mut v: Vec<_> = items.into_iter().collect(); - v.sort(); - v - } - - #[test] - fn test_cond_then_loop_separate() -> Result<(), BuildError> { - // /-> left --\ - // entry -> split > merge -> head -> tail -> exit - // \-> right -/ \-<--<-/ - let mut cfg_builder = CFGBuilder::new(FunctionType::new_endo(NAT))?; - - let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2")); - let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); - - let entry = n_identity( - cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, - &const_unit, - )?; - let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?; - cfg_builder.branch(&entry, 0, &split)?; - let head = n_identity( - cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 1)?, - &const_unit, - )?; - let tail = n_identity( - cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 2)?, - &pred_const, - )?; - cfg_builder.branch(&tail, 1, &head)?; - cfg_builder.branch(&head, 0, &tail)?; // trivial "loop body" - cfg_builder.branch(&merge, 0, &head)?; - let exit = cfg_builder.exit_block(); - cfg_builder.branch(&tail, 0, &exit)?; - - let mut h = cfg_builder.finish_prelude_hugr()?; - let rc = RootChecked::<_, CfgID>::try_new(&mut h).unwrap(); - let (entry, exit) = (entry.node(), exit.node()); - let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node()); - let edge_classes = EdgeClassifier::get_edge_classes(&IdentityCfgMap::new(rc.borrow())); - let [&left, &right] = edge_classes - .keys() - .filter(|(s, _)| *s == split) - .map(|(_, t)| t) - .collect::>()[..] - else { - panic!("Split node should have two successors"); - }; - - let classes = group_by(edge_classes); - assert_eq!( - classes, - HashSet::from([ - sorted([(split, left), (left, merge)]), // Region containing single BB 'left'. 
- sorted([(split, right), (right, merge)]), // Region containing single BB 'right'. - Vec::from([(head, tail)]), // Loop body and backedges are in their own classes because - Vec::from([(tail, head)]), // the path executing the loop exactly once skips the backedge. - sorted([(entry, split), (merge, head), (tail, exit)]), // Two regions, conditional and then loop. - ]) - ); - transform_cfg_to_nested(&mut IdentityCfgMap::new(rc)); - h.update_validate(&PRELUDE_REGISTRY).unwrap(); - assert_eq!(1, depth(&h, entry)); - assert_eq!(1, depth(&h, exit)); - for n in [split, left, right, merge, head, tail] { - assert_eq!(3, depth(&h, n)); - } - let first = [split, left, right, merge] - .iter() - .map(|n| h.get_parent(*n).unwrap()) - .unique() - .exactly_one() - .unwrap(); - let second = [head, tail] - .iter() - .map(|n| h.get_parent(*n).unwrap()) - .unique() - .exactly_one() - .unwrap(); - assert_ne!(first, second); - Ok(()) - } - - #[test] - fn test_cond_then_loop_combined() -> Result<(), BuildError> { - // Here we would like two consecutive regions, but there is no *edge* between - // the conditional and the loop to indicate the boundary, so we cannot separate them. - let (h, merge, tail) = build_cond_then_loop_cfg()?; - let (merge, tail) = (merge.node(), tail.node()); - let [entry, exit]: [Node; 2] = h - .children(h.root()) - .take(2) - .collect_vec() - .try_into() - .unwrap(); - - let v = IdentityCfgMap::new(RootChecked::try_new(&h).unwrap()); - let edge_classes = EdgeClassifier::get_edge_classes(&v); - let [&left, &right] = edge_classes - .keys() - .filter(|(s, _)| *s == entry) - .map(|(_, t)| t) - .collect::>()[..] - else { - panic!("Entry node should have two successors"); - }; - - let classes = group_by(edge_classes); - assert_eq!( - classes, - HashSet::from([ - sorted([(entry, left), (left, merge)]), // Region containing single BB 'left'. - sorted([(entry, right), (right, merge)]), // Region containing single BB 'right'. - Vec::from([(tail, exit)]), // The only edge in neither conditional nor loop. - Vec::from([(merge, tail)]), // Loop body (at least once per execution). - Vec::from([(tail, merge)]), // Loop backedge (0 or more times per execution). - ]) - ); - Ok(()) - } - - #[test] - fn test_cond_in_loop_separate_headers() -> Result<(), BuildError> { - let (mut h, head, tail) = build_conditional_in_loop_cfg(true)?; - let head = head.node(); - let tail = tail.node(); - // /-> left --\ - // entry -> head -> split > merge -> tail -> exit - // | \-> right -/ | - // \---<---<---<---<---<---<---<---<---/ - // split is unique successor of head - let split = h.output_neighbours(head).exactly_one().unwrap(); - // merge is unique predecessor of tail - let merge = h.input_neighbours(tail).exactly_one().unwrap(); - - // There's no need to use a view of a region here but we do so just to check - // that we *can* (as we'll need to for "real" module Hugr's) - let v = IdentityCfgMap::new(SiblingGraph::try_new(&h, h.root()).unwrap()); - let edge_classes = EdgeClassifier::get_edge_classes(&v); - let IdentityCfgMap { h: _, entry, exit } = v; - let [&left, &right] = edge_classes - .keys() - .filter(|(s, _)| *s == split) - .map(|(_, t)| t) - .collect::>()[..] 
- else { - panic!("Split node should have two successors"); - }; - let classes = group_by(edge_classes); - assert_eq!( - classes, - HashSet::from([ - sorted([(split, left), (left, merge)]), // Region containing single BB 'left' - sorted([(split, right), (right, merge)]), // Region containing single BB 'right' - sorted([(head, split), (merge, tail)]), // "Conditional" region containing split+merge choosing between left/right - sorted([(entry, head), (tail, exit)]), // "Loop" region containing body (conditional) + back-edge - Vec::from([(tail, head)]) // The loop back-edge - ]) - ); - - // Again, there's no need for a view of a region here, but check that the - // transformation still works when we can only directly mutate the top level - let root = h.root(); - let m = SiblingMut::::try_new(&mut h, root).unwrap(); - transform_cfg_to_nested(&mut IdentityCfgMap::new(m)); - h.update_validate(&PRELUDE_REGISTRY).unwrap(); - assert_eq!(1, depth(&h, entry)); - assert_eq!(3, depth(&h, head)); - for n in [split, left, right, merge] { - assert_eq!(5, depth(&h, n)); - } - assert_eq!(3, depth(&h, tail)); - assert_eq!(1, depth(&h, exit)); - Ok(()) - } - - #[test] - fn test_cond_in_loop_combined_headers() -> Result<(), BuildError> { - let (h, head, tail) = build_conditional_in_loop_cfg(false)?; - let head = head.node(); - let tail = tail.node(); - // /-> left --\ - // entry -> head > merge -> tail -> exit - // | \-> right -/ | - // \---<---<---<---<---<--<---/ - // Here we would like an indication that we can make two nested regions, - // but there is no edge to act as entry to a region containing just the conditional :-(. - - let v = IdentityCfgMap::new(RootChecked::try_new(&h).unwrap()); - let edge_classes = EdgeClassifier::get_edge_classes(&v); - let IdentityCfgMap { h: _, entry, exit } = v; - // merge is unique predecessor of tail - let merge = *edge_classes - .keys() - .filter(|(_, t)| *t == tail) - .map(|(s, _)| s) - .exactly_one() - .unwrap(); - let [&left, &right] = edge_classes - .keys() - .filter(|(s, _)| *s == head) - .map(|(_, t)| t) - .collect::>()[..] 
- else { - panic!("Loop header should have two successors"); - }; - let classes = group_by(edge_classes); - assert_eq!( - classes, - HashSet::from([ - sorted([(head, left), (left, merge)]), // Region containing single BB 'left' - sorted([(head, right), (right, merge)]), // Region containing single BB 'right' - Vec::from([(merge, tail)]), // The edge "in the loop", but no other edge in its class to define SESE region - sorted([(entry, head), (tail, exit)]), // "Loop" region containing body (conditional) + back-edge - Vec::from([(tail, head)]) // The loop back-edge - ]) - ); - Ok(()) - } - - fn n_identity( - mut dataflow_builder: T, - pred_const: &ConstID, - ) -> Result { - let w = dataflow_builder.input_wires(); - let u = dataflow_builder.load_const(pred_const); - dataflow_builder.finish_with_outputs([u].into_iter().chain(w)) - } - - fn build_if_then_else_merge + AsRef>( - cfg: &mut CFGBuilder, - const_pred: &ConstID, - unit_const: &ConstID, - ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { - let split = n_identity( - cfg.simple_block_builder(FunctionType::new_endo(NAT), 2)?, - const_pred, - )?; - let merge = build_then_else_merge_from_if(cfg, unit_const, split)?; - Ok((split, merge)) - } - - fn build_then_else_merge_from_if + AsRef>( - cfg: &mut CFGBuilder, - unit_const: &ConstID, - split: BasicBlockID, - ) -> Result { - let merge = n_identity( - cfg.simple_block_builder(FunctionType::new_endo(NAT), 1)?, - unit_const, - )?; - let left = n_identity( - cfg.simple_block_builder(FunctionType::new_endo(NAT), 1)?, - unit_const, - )?; - let right = n_identity( - cfg.simple_block_builder(FunctionType::new_endo(NAT), 1)?, - unit_const, - )?; - cfg.branch(&split, 0, &left)?; - cfg.branch(&split, 1, &right)?; - cfg.branch(&left, 0, &merge)?; - cfg.branch(&right, 0, &merge)?; - Ok(merge) - } - - // /-> left --\ - // entry > merge -> tail -> exit - // \-> right -/ \-<--<-/ - // Result is Hugr plus merge and tail blocks - fn build_cond_then_loop_cfg() -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { - let mut cfg_builder = CFGBuilder::new(FunctionType::new_endo(NAT))?; - let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2")); - let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); - - let entry = n_identity( - cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?, - &pred_const, - )?; - let merge = build_then_else_merge_from_if(&mut cfg_builder, &const_unit, entry)?; - // The merge block is also the loop header (so it merges three incoming control-flow edges) - let tail = n_identity( - cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 2)?, - &pred_const, - )?; - cfg_builder.branch(&tail, 1, &merge)?; - cfg_builder.branch(&merge, 0, &tail)?; // trivial "loop body" - let exit = cfg_builder.exit_block(); - cfg_builder.branch(&tail, 0, &exit)?; - - let h = cfg_builder.finish_prelude_hugr()?; - Ok((h, merge, tail)) - } - - // Build a CFG, returning the Hugr - pub(crate) fn build_conditional_in_loop_cfg( - separate_headers: bool, - ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { - let mut cfg_builder = CFGBuilder::new(FunctionType::new_endo(NAT))?; - let (head, tail) = build_conditional_in_loop(&mut cfg_builder, separate_headers)?; - let h = cfg_builder.finish_prelude_hugr()?; - Ok((h, head, tail)) - } - - pub(crate) fn build_conditional_in_loop + AsRef>( - cfg_builder: &mut CFGBuilder, - separate_headers: bool, - ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { - let pred_const = 
cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2")); - let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); - - let entry = n_identity( - cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, - &const_unit, - )?; - let (split, merge) = build_if_then_else_merge(cfg_builder, &pred_const, &const_unit)?; - - let head = if separate_headers { - let head = n_identity( - cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 1)?, - &const_unit, - )?; - cfg_builder.branch(&head, 0, &split)?; - head - } else { - // Combine loop header with split. - split - }; - let tail = n_identity( - cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 2)?, - &pred_const, - )?; - cfg_builder.branch(&tail, 1, &head)?; - cfg_builder.branch(&merge, 0, &tail)?; - - let exit = cfg_builder.exit_block(); - - cfg_builder.branch(&entry, 0, &head)?; - cfg_builder.branch(&tail, 0, &exit)?; - - Ok((head, tail)) - } - - pub fn depth(h: &Hugr, n: Node) -> u32 { - match h.get_parent(n) { - Some(p) => 1 + depth(h, p), - None => 0, - } - } -} diff --git a/hugr/src/hugr/rewrite/insert_identity.rs b/hugr/src/hugr/rewrite/insert_identity.rs index 43cb35142..7dbde932b 100644 --- a/hugr/src/hugr/rewrite/insert_identity.rs +++ b/hugr/src/hugr/rewrite/insert_identity.rs @@ -100,9 +100,7 @@ mod tests { use super::super::simple_replace::test::dfg_hugr; use super::*; use crate::{ - algorithm::nest_cfgs::test::build_conditional_in_loop_cfg, extension::{prelude::QB_T, PRELUDE_REGISTRY}, - ops::handle::NodeHandle, Hugr, }; @@ -131,23 +129,4 @@ mod tests { h.update_validate(&PRELUDE_REGISTRY).unwrap(); } - - #[test] - fn incorrect_insertion() { - let (mut h, _, tail) = build_conditional_in_loop_cfg(false).unwrap(); - - let final_node = tail.node(); - - let final_node_input = h.node_inputs(final_node).next().unwrap(); - - let rw = IdentityInsertion::new(final_node, final_node_input); - - let apply_result = h.apply_rewrite(rw); - assert_eq!( - apply_result, - Err(IdentityInsertionError::InvalidPortKind(Some( - EdgeKind::ControlFlow - ))) - ); - } } diff --git a/hugr/src/hugr/rewrite/replace.rs b/hugr/src/hugr/rewrite/replace.rs index b06fcfe62..5b2ed2948 100644 --- a/hugr/src/hugr/rewrite/replace.rs +++ b/hugr/src/hugr/rewrite/replace.rs @@ -445,7 +445,6 @@ mod test { use cool_asserts::assert_matches; use itertools::Itertools; - use crate::algorithm::nest_cfgs::test::depth; use crate::builder::{ BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, SubContainer, @@ -463,6 +462,7 @@ mod test { use crate::ops::{self, Case, DataflowBlock, OpTag, OpType, DFG}; use crate::std_extensions::collections; use crate::types::{FunctionType, Type, TypeArg, TypeRow}; + use crate::utils::depth; use crate::{type_row, Direction, Hugr, HugrView, OutgoingPort}; use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement}; diff --git a/hugr/src/utils.rs b/hugr/src/utils.rs index 693eb2f7b..863cfa7a9 100644 --- a/hugr/src/utils.rs +++ b/hugr/src/utils.rs @@ -4,7 +4,7 @@ use std::fmt::{self, Debug, Display}; use itertools::Itertools; -use crate::{ops::Value, IncomingPort}; +use crate::{ops::Value, Hugr, HugrView, IncomingPort, Node}; /// Write a comma separated list of of some types. /// Like debug_list, but using the Display instance rather than Debug, @@ -224,6 +224,14 @@ pub fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> { .collect() } +/// Calculate the depth of a node in the hierarchy. 
+pub fn depth(h: &Hugr, n: Node) -> u32 { + match h.get_parent(n) { + Some(p) => 1 + depth(h, p), + None => 0, + } +} + #[allow(dead_code)] // Test only utils #[cfg(test)] From 9ee77bb8495b36cf676bacc90e4f706287a78695 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 10:14:46 +0100 Subject: [PATCH 11/28] Remove `algorithms::const_fold` (and hence `algorithms`) from hugr crate. --- hugr/src/algorithm.rs | 3 - hugr/src/algorithm/const_fold.rs | 537 ------------------------------- hugr/src/lib.rs | 1 - 3 files changed, 541 deletions(-) delete mode 100644 hugr/src/algorithm.rs delete mode 100644 hugr/src/algorithm/const_fold.rs diff --git a/hugr/src/algorithm.rs b/hugr/src/algorithm.rs deleted file mode 100644 index 685d827c8..000000000 --- a/hugr/src/algorithm.rs +++ /dev/null @@ -1,3 +0,0 @@ -//! Algorithms using the Hugr. - -pub mod const_fold; diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs deleted file mode 100644 index c76514cd4..000000000 --- a/hugr/src/algorithm/const_fold.rs +++ /dev/null @@ -1,537 +0,0 @@ -//! Constant folding routines. - -use std::collections::{BTreeSet, HashMap}; - -use itertools::Itertools; -use thiserror::Error; - -use crate::hugr::{SimpleReplacementError, ValidationError}; -use crate::types::SumType; -use crate::utils::sorted_consts; -use crate::Direction; -use crate::{ - builder::{DFGBuilder, Dataflow, DataflowHugr}, - extension::{ConstFoldResult, ExtensionRegistry}, - hugr::{ - rewrite::consts::{RemoveConst, RemoveLoadConstant}, - views::SiblingSubgraph, - HugrMut, - }, - ops::{OpType, Value}, - type_row, - types::FunctionType, - Hugr, HugrView, IncomingPort, Node, SimpleReplacement, -}; - -#[derive(Error, Debug)] -#[allow(missing_docs)] -pub enum ConstFoldError { - #[error("Failed to verify {label} HUGR: {err}")] - VerifyError { - label: String, - #[source] - err: ValidationError, - }, - #[error(transparent)] - SimpleReplaceError(#[from] SimpleReplacementError), -} - -/// Tag some output constants with [`OutgoingPort`] inferred from the ordering. -fn out_row(consts: impl IntoIterator) -> ConstFoldResult { - let vec = consts - .into_iter() - .enumerate() - .map(|(i, c)| (i.into(), c)) - .collect(); - Some(vec) -} - -/// For a given op and consts, attempt to evaluate the op. -pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldResult { - let fold_result = match op { - OpType::Noop { .. } => out_row([consts.first()?.1.clone()]), - OpType::MakeTuple { .. } => { - out_row([Value::tuple(sorted_consts(consts).into_iter().cloned())]) - } - OpType::UnpackTuple { .. } => { - let c = &consts.first()?.1; - let Value::Tuple { vs } = c else { - panic!("This op always takes a Tuple input."); - }; - out_row(vs.iter().cloned()) - } - - OpType::Tag(t) => out_row([Value::sum( - t.tag, - consts.iter().map(|(_, konst)| konst.clone()), - SumType::new(t.variants.clone()), - ) - .unwrap()]), - OpType::CustomOp(op) => { - let ext_op = op.as_extension_op()?; - ext_op.constant_fold(consts) - } - _ => None, - }; - debug_assert!(fold_result.as_ref().map_or(true, |x| x.len() - == op.value_port_count(Direction::Outgoing))); - fold_result -} - -/// Generate a graph that loads and outputs `consts` in order, validating -/// against `reg`. 
-fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { - let const_types = consts.iter().map(Value::get_type).collect_vec(); - let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap(); - - let outputs = consts - .into_iter() - .map(|c| b.add_load_const(c)) - .collect_vec(); - - b.finish_hugr_with_outputs(outputs, reg).unwrap() -} - -/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`, -/// return an iterator of possible constant folding rewrites. The -/// [`SimpleReplacement`] replaces an operation with constants that result from -/// evaluating it, the extension registry `reg` is used to validate the -/// replacement HUGR. The vector of [`RemoveLoadConstant`] refer to the -/// LoadConstant nodes that could be removed - they are not automatically -/// removed as they may be used by other operations. -pub fn find_consts<'a, 'r: 'a>( - hugr: &'a impl HugrView, - candidate_nodes: impl IntoIterator + 'a, - reg: &'r ExtensionRegistry, -) -> impl Iterator)> + 'a { - // track nodes for operations that have already been considered for folding - let mut used_neighbours = BTreeSet::new(); - - candidate_nodes - .into_iter() - .filter_map(move |n| { - // only look at LoadConstant - hugr.get_optype(n).is_load_constant().then_some(())?; - - let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?; - let neighbours = hugr - .linked_inputs(n, out_p) - .filter(|(n, _)| used_neighbours.insert(*n)) - .collect_vec(); - if neighbours.is_empty() { - // no uses of LoadConstant that haven't already been considered. - return None; - } - let fold_iter = neighbours - .into_iter() - .filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg)); - Some(fold_iter) - }) - .flatten() -} - -/// Attempt to evaluate and generate rewrites for the operation at `op_node` -fn fold_op( - hugr: &impl HugrView, - op_node: Node, - reg: &ExtensionRegistry, -) -> Option<(SimpleReplacement, Vec)> { - // only support leaf folding for now. - let neighbour_op = hugr.get_optype(op_node); - let (in_consts, removals): (Vec<_>, Vec<_>) = hugr - .node_inputs(op_node) - .filter_map(|in_p| { - let (con_op, load_n) = get_const(hugr, op_node, in_p)?; - Some(((in_p, con_op), RemoveLoadConstant(load_n))) - }) - .unzip(); - // attempt to evaluate op - let (nu_out, consts): (HashMap<_, _>, Vec<_>) = fold_leaf_op(neighbour_op, &in_consts)? - .into_iter() - .enumerate() - .filter_map(|(i, (op_out, konst))| { - // for each used port of the op give the nu_out entry and the - // corresponding Value - hugr.single_linked_input(op_node, op_out) - .map(|np| ((np, i.into()), konst)) - }) - .unzip(); - let replacement = const_graph(consts, reg); - let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) - .expect("Operation should form valid subgraph."); - - let simple_replace = SimpleReplacement::new( - sibling_graph, - replacement, - // no inputs to replacement - HashMap::new(), - nu_out, - ); - Some((simple_replace, removals)) -} - -/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant -/// and the LoadConstant node -fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Value, Node)> { - let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; - let load_op = hugr.get_optype(load_n).as_load_constant()?; - let const_node = hugr - .single_linked_output(load_n, load_op.constant_port())? 
- .0; - let const_op = hugr.get_optype(const_node).as_const()?; - - // TODO avoid const clone here - Some((const_op.as_ref().clone(), load_n)) -} - -/// Exhaustively apply constant folding to a HUGR. -pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { - #[cfg(test)] - let verify = |label, h: &H| { - h.validate_no_extensions(reg).unwrap_or_else(|err| { - panic!( - "constant_fold_pass: failed to verify {label} HUGR: {err}\n{}", - h.mermaid_string() - ) - }) - }; - #[cfg(test)] - verify("input", h); - loop { - // We can only safely apply a single replacement. Applying a - // replacement removes nodes and edges which may be referenced by - // further replacements returned by find_consts. Even worse, if we - // attempted to apply those replacements, expecting them to fail if - // the nodes and edges they reference had been deleted, they may - // succeed because new nodes and edges reused the ids. - // - // We could be a lot smarter here, keeping track of `LoadConstant` - // nodes and only looking at their out neighbours. - let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else { - break; - }; - h.apply_rewrite(replace).unwrap(); - for rem in removes { - // We are optimistically applying these [RemoveLoadConstant] and - // [RemoveConst] rewrites without checking whether the nodes - // they attempt to remove have remaining uses. If they do, then - // the rewrite fails and we move on. - if let Ok(const_node) = h.apply_rewrite(rem) { - // if the LoadConst was removed, try removing the Const too. - let _ = h.apply_rewrite(RemoveConst(const_node)); - } - } - } - #[cfg(test)] - verify("output", h); -} - -#[cfg(test)] -mod test { - - use super::*; - use crate::extension::prelude::{sum_with_error, BOOL_T}; - use crate::extension::{ExtensionRegistry, PRELUDE}; - use crate::ops::{OpType, UnpackTuple}; - use crate::std_extensions::arithmetic; - use crate::std_extensions::arithmetic::conversions::ConvertOpDef; - use crate::std_extensions::arithmetic::float_ops::FloatOps; - use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; - use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; - use crate::std_extensions::logic::{self, NaryLogic, NotOp}; - use crate::utils::test::{assert_fully_folded, assert_fully_folded_with}; - - use rstest::rstest; - - /// int to constant - fn i2c(b: u64) -> Value { - Value::extension(ConstInt::new_u(5, b).unwrap()) - } - - /// float to constant - fn f2c(f: f64) -> Value { - ConstF64::new(f).into() - } - - #[rstest] - #[case(0.0, 0.0, 0.0)] - #[case(0.0, 1.0, 1.0)] - #[case(23.5, 435.5, 459.0)] - // c = a + b - fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { - let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; - let add_op: OpType = FloatOps::fadd.into(); - let outs = fold_leaf_op(&add_op, &consts) - .unwrap() - .into_iter() - .map(|(p, v)| (p, v.get_custom_value::().unwrap().value())) - .collect_vec(); - - assert_eq!(outs.as_slice(), &[(0.into(), c)]); - } - #[test] - fn test_big() { - /* - Test approximately calculates - let x = (5.6, 3.2); - int(x.0 - x.1) == 2 - */ - let sum_type = sum_with_error(INT_TYPES[5].to_owned()); - let mut build = DFGBuilder::new(FunctionType::new( - type_row![], - vec![sum_type.clone().into()], - )) - .unwrap(); - - let tup = build.add_load_const(Value::tuple([f2c(5.6), f2c(3.2)])); - - let unpack = build - .add_dataflow_op( - UnpackTuple { - tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE], - }, - [tup], - ) - .unwrap(); - - let sub = build - 
.add_dataflow_op(FloatOps::fsub, unpack.outputs()) - .unwrap(); - let to_int = build - .add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs()) - .unwrap(); - - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - arithmetic::float_ops::EXTENSION.to_owned(), - arithmetic::conversions::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build - .finish_hugr_with_outputs(to_int.outputs(), ®) - .unwrap(); - assert_eq!(h.node_count(), 8); - - constant_fold_pass(&mut h, ®); - - let expected = Value::sum(0, [i2c(2).clone()], sum_type).unwrap(); - assert_fully_folded(&h, &expected); - } - - #[test] - #[cfg_attr( - feature = "extension_inference", - ignore = "inference fails for test graph, it shouldn't" - )] - fn test_list_ops() -> Result<(), Box> { - use crate::std_extensions::collections::{self, ListOp, ListValue}; - - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - collections::EXTENSION.to_owned(), - ]) - .unwrap(); - let list: Value = ListValue::new(BOOL_T, [Value::unit_sum(0, 1).unwrap()]).into(); - let mut build = DFGBuilder::new(FunctionType::new( - type_row![], - vec![list.get_type().clone()], - )) - .unwrap(); - - let list_wire = build.add_load_const(list.clone()); - - let pop = build.add_dataflow_op( - ListOp::Pop.with_type(BOOL_T).to_extension_op(®).unwrap(), - [list_wire], - )?; - - let push = build.add_dataflow_op( - ListOp::Push - .with_type(BOOL_T) - .to_extension_op(®) - .unwrap(), - pop.outputs(), - )?; - let mut h = build.finish_hugr_with_outputs(push.outputs(), ®)?; - constant_fold_pass(&mut h, ®); - - assert_fully_folded(&h, &list); - Ok(()) - } - - #[test] - fn test_fold_and() { - // pseudocode: - // x0, x1 := bool(true), bool(true) - // x2 := and(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::true_val()); - let x1 = build.add_load_const(Value::true_val()); - let x2 = build - .add_dataflow_op(NaryLogic::And.with_n_inputs(2), [x0, x1]) - .unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); - } - - #[test] - fn test_fold_or() { - // pseudocode: - // x0, x1 := bool(true), bool(false) - // x2 := or(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::true_val()); - let x1 = build.add_load_const(Value::false_val()); - let x2 = build - .add_dataflow_op(NaryLogic::Or.with_n_inputs(2), [x0, x1]) - .unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); - } - - #[test] - fn test_fold_not() { - // pseudocode: - // x0 := bool(true) - // x1 := not(x0) - // output x1 == false; - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::true_val()); - let x1 = build.add_dataflow_op(NotOp, [x0]).unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), 
logic::EXTENSION.to_owned()]).unwrap();
-        let mut h = build.finish_hugr_with_outputs(x1.outputs(), &reg).unwrap();
-        constant_fold_pass(&mut h, &reg);
-        let expected = Value::false_val();
-        assert_fully_folded(&h, &expected);
-    }
-
-    #[test]
-    fn orphan_output() {
-        // pseudocode:
-        // x0 := bool(true)
-        // x1 := not(x0)
-        // x2 := or(x0,x1)
-        // output x2 == true;
-        //
-        // We arrange things so that the `or` folds away first, leaving the not
-        // with no outputs.
-        use crate::hugr::NodeType;
-        use crate::ops::handle::NodeHandle;
-
-        let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
-        let true_wire = build.add_load_value(Value::true_val());
-        // this Not will be manually replaced
-        let orig_not = build.add_dataflow_op(NotOp, [true_wire]).unwrap();
-        let r = build
-            .add_dataflow_op(
-                NaryLogic::Or.with_n_inputs(2),
-                [true_wire, orig_not.out_wire(0)],
-            )
-            .unwrap();
-        let or_node = r.node();
-        let parent = build.dfg_node;
-        let reg =
-            ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
-        let mut h = build.finish_hugr_with_outputs(r.outputs(), &reg).unwrap();
-
-        // we delete the original Not and create a new one. This means it will be
-        // traversed by `constant_fold_pass` after the Or.
-        let new_not = h.add_node_with_parent(parent, NodeType::new_auto(NotOp));
-        h.connect(true_wire.node(), true_wire.source(), new_not, 0);
-        h.disconnect(or_node, IncomingPort::from(1));
-        h.connect(new_not, 0, or_node, 1);
-        h.remove_node(orig_not.node());
-        constant_fold_pass(&mut h, &reg);
-        assert_fully_folded(&h, &Value::true_val())
-    }
-
-    #[test]
-    fn test_folding_pass_issue_996() {
-        // pseudocode:
-        //
-        // x0 := 3.0
-        // x1 := 4.0
-        // x2 := fne(x0, x1); // true
-        // x3 := flt(x0, x1); // true
-        // x4 := and(x2, x3); // true
-        // x5 := -10.0
-        // x6 := flt(x0, x5) // false
-        // x7 := or(x4, x6) // true
-        // output x7
-        let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
-        let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0)));
-        let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0)));
-        let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap();
-        let x3 = build.add_dataflow_op(FloatOps::flt, [x0, x1]).unwrap();
-        let x4 = build
-            .add_dataflow_op(
-                NaryLogic::And.with_n_inputs(2),
-                x2.outputs().chain(x3.outputs()),
-            )
-            .unwrap();
-        let x5 = build.add_load_const(Value::extension(ConstF64::new(-10.0)));
-        let x6 = build.add_dataflow_op(FloatOps::flt, [x0, x5]).unwrap();
-        let x7 = build
-            .add_dataflow_op(
-                NaryLogic::Or.with_n_inputs(2),
-                x4.outputs().chain(x6.outputs()),
-            )
-            .unwrap();
-        let reg = ExtensionRegistry::try_new([
-            PRELUDE.to_owned(),
-            logic::EXTENSION.to_owned(),
-            arithmetic::float_types::EXTENSION.to_owned(),
-        ])
-        .unwrap();
-        let mut h = build.finish_hugr_with_outputs(x7.outputs(), &reg).unwrap();
-        constant_fold_pass(&mut h, &reg);
-        let expected = Value::true_val();
-        assert_fully_folded(&h, &expected);
-    }
-
-    #[test]
-    fn test_const_fold_to_nonfinite() {
-        let reg = ExtensionRegistry::try_new([
-            PRELUDE.to_owned(),
-            arithmetic::float_types::EXTENSION.to_owned(),
-        ])
-        .unwrap();
-
-        // HUGR computing 1.0 / 1.0
-        let mut build =
-            DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap();
-        let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0)));
-        let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0)));
-        let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap();
-        let mut h0 = 
build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h0, ®); - assert_fully_folded_with(&h0, |v| { - v.get_custom_value::().unwrap().value() == 1.0 - }); - assert_eq!(h0.node_count(), 5); - - // HUGR computing 1.0 / 0.0 - let mut build = - DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); - let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0))); - let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); - let mut h1 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h1, ®); - assert_eq!(h1.node_count(), 8); - } -} diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index f2859aacc..c093ba2aa 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -140,7 +140,6 @@ // https://github.com/proptest-rs/proptest/issues/447 #![cfg_attr(test, allow(non_local_definitions))] -pub mod algorithm; pub mod builder; pub mod core; pub mod extension; From 6971110ab5c73dcd836d9a794559321a1f140d4e Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 10:22:52 +0100 Subject: [PATCH 12/28] Move const-folding tests for int ops into hugr-passes crate. --- hugr-passes/src/const_fold.rs | 3 +++ .../src/const_fold/int_ops_const_fold_test.rs | 0 hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs | 3 --- 3 files changed, 3 insertions(+), 3 deletions(-) rename hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs => hugr-passes/src/const_fold/int_ops_const_fold_test.rs (100%) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 33fe8e367..92e8b4feb 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -558,3 +558,6 @@ mod test { assert_eq!(h1.node_count(), 8); } } + +#[cfg(test)] +mod int_ops_const_fold_test; diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs b/hugr-passes/src/const_fold/int_ops_const_fold_test.rs similarity index 100% rename from hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs rename to hugr-passes/src/const_fold/int_ops_const_fold_test.rs diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs b/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs index 8738e1872..4c520963d 100644 --- a/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs +++ b/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs @@ -1196,6 +1196,3 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) { }, }); } - -#[cfg(test)] -mod test; From 6ff82e809639ea11425cbf9ef3e856c82755f4c1 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 10:32:48 +0100 Subject: [PATCH 13/28] Fix up moved tests. --- hugr-passes/Cargo.toml | 1 + hugr-passes/src/const_fold.rs | 2 +- .../src/const_fold/int_ops_const_fold_test.rs | 36 ++++++++++++------- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index e25a4cb52..f1bad2e3e 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] hugr = { path = "../hugr" } itertools = "0.12.0" +lazy_static = "1.4.0" paste = "1.0" thiserror = "1.0.28" diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 92e8b4feb..51426dc6d 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -244,7 +244,7 @@ mod test { use rstest::rstest; /// Check that a hugr just loads and returns a single expected constant. 
- fn assert_fully_folded(h: &Hugr, expected_value: &Value) { + pub fn assert_fully_folded(h: &Hugr, expected_value: &Value) { assert_fully_folded_with(h, |v| v == expected_value) } diff --git a/hugr-passes/src/const_fold/int_ops_const_fold_test.rs b/hugr-passes/src/const_fold/int_ops_const_fold_test.rs index 959240a51..f696f5231 100644 --- a/hugr-passes/src/const_fold/int_ops_const_fold_test.rs +++ b/hugr-passes/src/const_fold/int_ops_const_fold_test.rs @@ -1,18 +1,21 @@ -use crate::algorithm::const_fold::constant_fold_pass; -use crate::builder::{DFGBuilder, Dataflow, DataflowHugr}; -use crate::extension::prelude::{sum_with_error, ConstError, ConstString, BOOL_T, STRING_TYPE}; -use crate::extension::{ExtensionRegistry, PRELUDE}; -use crate::ops::Value; -use crate::std_extensions::arithmetic; -use crate::std_extensions::arithmetic::int_ops::IntOpDef; -use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; -use crate::std_extensions::logic::{self, NaryLogic}; -use crate::type_row; -use crate::types::{FunctionType, Type, TypeRow}; -use crate::utils::test::assert_fully_folded; +use crate::const_fold::constant_fold_pass; +use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr}; +use hugr::extension::prelude::{sum_with_error, ConstError, ConstString, BOOL_T, STRING_TYPE}; +use hugr::extension::{ExtensionRegistry, PRELUDE}; +use hugr::ops::Value; +use hugr::std_extensions::arithmetic; +use hugr::std_extensions::arithmetic::int_ops::IntOpDef; +use hugr::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; +use hugr::std_extensions::logic::{self, NaryLogic}; +use hugr::type_row; +use hugr::types::{FunctionType, Type, TypeRow}; use rstest::rstest; +use super::test::assert_fully_folded; + +use lazy_static::lazy_static; + #[test] fn test_fold_iwiden_u() { // pseudocode: @@ -109,10 +112,17 @@ fn test_fold_inarrow, E: std::fmt::Debug>( .unwrap(); let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); constant_fold_pass(&mut h, ®); + lazy_static! { + static ref INARROW_ERROR_VALUE: Value = ConstError { + signal: 0, + message: "Integer too large to narrow".to_string(), + } + .into(); + } let expected = if succeeds { Value::sum(0, [mk_const(to_log_width, val).unwrap().into()], sum_type).unwrap() } else { - Value::sum(1, [super::INARROW_ERROR_VALUE.clone()], sum_type).unwrap() + Value::sum(1, [INARROW_ERROR_VALUE.clone()], sum_type).unwrap() }; assert_fully_folded(&h, &expected); } From c6cddfd5839799937cba4f069943891408d1228b Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 13:02:00 +0100 Subject: [PATCH 14/28] Use common workspace dependencies. --- hugr-passes/Cargo.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index f1bad2e3e..727e605a3 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -5,10 +5,10 @@ edition = "2021" [dependencies] hugr = { path = "../hugr" } -itertools = "0.12.0" -lazy_static = "1.4.0" -paste = "1.0" -thiserror = "1.0.28" +itertools = { workspace = true } +lazy_static = { workspace = true } +paste = { workspace = true } +thiserror = { workspace = true } [features] extension_inference = ["hugr/extension_inference"] From 9dab1a7eb52ae32e83c04d11b35a01246976d6ed Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 13:10:43 +0100 Subject: [PATCH 15/28] Move all test code into one file. 
--- hugr-passes/src/const_fold.rs | 333 ------------------ .../src/const_fold/int_ops_const_fold_test.rs | 324 ++++++++++++++++- 2 files changed, 321 insertions(+), 336 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 51426dc6d..f42e1a1b6 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -226,338 +226,5 @@ pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { verify("output", h); } -#[cfg(test)] -mod test { - - use super::*; - use hugr::builder::Container; - use hugr::extension::prelude::{sum_with_error, BOOL_T}; - use hugr::extension::{ExtensionRegistry, PRELUDE}; - use hugr::ops::{OpType, UnpackTuple}; - use hugr::std_extensions::arithmetic; - use hugr::std_extensions::arithmetic::conversions::ConvertOpDef; - use hugr::std_extensions::arithmetic::float_ops::FloatOps; - use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; - use hugr::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; - use hugr::std_extensions::logic::{self, NaryLogic, NotOp}; - - use rstest::rstest; - - /// Check that a hugr just loads and returns a single expected constant. - pub fn assert_fully_folded(h: &Hugr, expected_value: &Value) { - assert_fully_folded_with(h, |v| v == expected_value) - } - - /// Check that a hugr just loads and returns a single constant, and validate - /// that constant using `check_value`. - /// - /// [CustomConst::equals_const] is not required to be implemented. Use this - /// function for Values containing such a `CustomConst`. - fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) { - let mut node_count = 0; - - for node in h.children(h.root()) { - let op = h.get_optype(node); - match op { - OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, - OpType::Const(c) if check_value(c.value()) => node_count += 1, - _ => panic!("unexpected op: {:?}\n{}", op, h.mermaid_string()), - } - } - - assert_eq!(node_count, 4); - } - - /// int to constant - fn i2c(b: u64) -> Value { - Value::extension(ConstInt::new_u(5, b).unwrap()) - } - - /// float to constant - fn f2c(f: f64) -> Value { - ConstF64::new(f).into() - } - - #[rstest] - #[case(0.0, 0.0, 0.0)] - #[case(0.0, 1.0, 1.0)] - #[case(23.5, 435.5, 459.0)] - // c = a + b - fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { - let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; - let add_op: OpType = FloatOps::fadd.into(); - let outs = fold_leaf_op(&add_op, &consts) - .unwrap() - .into_iter() - .map(|(p, v)| (p, v.get_custom_value::().unwrap().value())) - .collect_vec(); - - assert_eq!(outs.as_slice(), &[(0.into(), c)]); - } - #[test] - fn test_big() { - /* - Test approximately calculates - let x = (5.6, 3.2); - int(x.0 - x.1) == 2 - */ - let sum_type = sum_with_error(INT_TYPES[5].to_owned()); - let mut build = DFGBuilder::new(FunctionType::new( - type_row![], - vec![sum_type.clone().into()], - )) - .unwrap(); - - let tup = build.add_load_const(Value::tuple([f2c(5.6), f2c(3.2)])); - - let unpack = build - .add_dataflow_op( - UnpackTuple::new(type_row![FLOAT64_TYPE, FLOAT64_TYPE]), - [tup], - ) - .unwrap(); - - let sub = build - .add_dataflow_op(FloatOps::fsub, unpack.outputs()) - .unwrap(); - let to_int = build - .add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs()) - .unwrap(); - - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - 
arithmetic::float_ops::EXTENSION.to_owned(), - arithmetic::conversions::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build - .finish_hugr_with_outputs(to_int.outputs(), ®) - .unwrap(); - assert_eq!(h.node_count(), 8); - - constant_fold_pass(&mut h, ®); - - let expected = Value::sum(0, [i2c(2).clone()], sum_type).unwrap(); - assert_fully_folded(&h, &expected); - } - - #[test] - #[cfg_attr( - feature = "extension_inference", - ignore = "inference fails for test graph, it shouldn't" - )] - fn test_list_ops() -> Result<(), Box> { - use hugr::std_extensions::collections::{self, ListOp, ListValue}; - - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - collections::EXTENSION.to_owned(), - ]) - .unwrap(); - let list: Value = ListValue::new(BOOL_T, [Value::unit_sum(0, 1).unwrap()]).into(); - let mut build = DFGBuilder::new(FunctionType::new( - type_row![], - vec![list.get_type().clone()], - )) - .unwrap(); - - let list_wire = build.add_load_const(list.clone()); - - let pop = build.add_dataflow_op( - ListOp::Pop.with_type(BOOL_T).to_extension_op(®).unwrap(), - [list_wire], - )?; - - let push = build.add_dataflow_op( - ListOp::Push - .with_type(BOOL_T) - .to_extension_op(®) - .unwrap(), - pop.outputs(), - )?; - let mut h = build.finish_hugr_with_outputs(push.outputs(), ®)?; - constant_fold_pass(&mut h, ®); - - assert_fully_folded(&h, &list); - Ok(()) - } - - #[test] - fn test_fold_and() { - // pseudocode: - // x0, x1 := bool(true), bool(true) - // x2 := and(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::true_val()); - let x1 = build.add_load_const(Value::true_val()); - let x2 = build - .add_dataflow_op(NaryLogic::And.with_n_inputs(2), [x0, x1]) - .unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); - } - - #[test] - fn test_fold_or() { - // pseudocode: - // x0, x1 := bool(true), bool(false) - // x2 := or(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::true_val()); - let x1 = build.add_load_const(Value::false_val()); - let x2 = build - .add_dataflow_op(NaryLogic::Or.with_n_inputs(2), [x0, x1]) - .unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); - } - - #[test] - fn test_fold_not() { - // pseudocode: - // x0 := bool(true) - // x1 := not(x0) - // output x1 == false; - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::true_val()); - let x1 = build.add_dataflow_op(NotOp, [x0]).unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::false_val(); - assert_fully_folded(&h, &expected); - } - - #[test] - fn orphan_output() { - // pseudocode: - // x0 := bool(true) - // x1 := not(x0) - // x2 := or(x0,x1) - // 
output x2 == true; - // - // We arange things so that the `or` folds away first, leaving the not - // with no outputs. - use hugr::hugr::NodeType; - use hugr::ops::handle::NodeHandle; - - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let true_wire = build.add_load_value(Value::true_val()); - // this Not will be manually replaced - let orig_not = build.add_dataflow_op(NotOp, [true_wire]).unwrap(); - let r = build - .add_dataflow_op( - NaryLogic::Or.with_n_inputs(2), - [true_wire, orig_not.out_wire(0)], - ) - .unwrap(); - let or_node = r.node(); - let parent = build.container_node(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(r.outputs(), ®).unwrap(); - - // we delete the original Not and create a new One. This means it will be - // traversed by `constant_fold_pass` after the Or. - let new_not = h.add_node_with_parent(parent, NodeType::new_auto(NotOp)); - h.connect(true_wire.node(), true_wire.source(), new_not, 0); - h.disconnect(or_node, IncomingPort::from(1)); - h.connect(new_not, 0, or_node, 1); - h.remove_node(orig_not.node()); - constant_fold_pass(&mut h, ®); - assert_fully_folded(&h, &Value::true_val()) - } - - #[test] - fn test_folding_pass_issue_996() { - // pseudocode: - // - // x0 := 3.0 - // x1 := 4.0 - // x2 := fne(x0, x1); // true - // x3 := flt(x0, x1); // true - // x4 := and(x2, x3); // true - // x5 := -10.0 - // x6 := flt(x0, x5) // false - // x7 := or(x4, x6) // true - // output x7 - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0))); - let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0))); - let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap(); - let x3 = build.add_dataflow_op(FloatOps::flt, [x0, x1]).unwrap(); - let x4 = build - .add_dataflow_op( - NaryLogic::And.with_n_inputs(2), - x2.outputs().chain(x3.outputs()), - ) - .unwrap(); - let x5 = build.add_load_const(Value::extension(ConstF64::new(-10.0))); - let x6 = build.add_dataflow_op(FloatOps::flt, [x0, x5]).unwrap(); - let x7 = build - .add_dataflow_op( - NaryLogic::Or.with_n_inputs(2), - x4.outputs().chain(x6.outputs()), - ) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); - } - - #[test] - fn test_const_fold_to_nonfinite() { - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - ]) - .unwrap(); - - // HUGR computing 1.0 / 1.0 - let mut build = - DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); - let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0))); - let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); - let mut h0 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h0, ®); - assert_fully_folded_with(&h0, |v| { - v.get_custom_value::().unwrap().value() == 1.0 - }); - assert_eq!(h0.node_count(), 5); - - // HUGR computing 1.0 / 0.0 - let mut build = - DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); 
- let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); - let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0))); - let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); - let mut h1 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h1, ®); - assert_eq!(h1.node_count(), 8); - } -} - #[cfg(test)] mod int_ops_const_fold_test; diff --git a/hugr-passes/src/const_fold/int_ops_const_fold_test.rs b/hugr-passes/src/const_fold/int_ops_const_fold_test.rs index f696f5231..07b881bc7 100644 --- a/hugr-passes/src/const_fold/int_ops_const_fold_test.rs +++ b/hugr-passes/src/const_fold/int_ops_const_fold_test.rs @@ -6,16 +6,334 @@ use hugr::ops::Value; use hugr::std_extensions::arithmetic; use hugr::std_extensions::arithmetic::int_ops::IntOpDef; use hugr::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; -use hugr::std_extensions::logic::{self, NaryLogic}; +use hugr::std_extensions::logic::{self, NaryLogic, NotOp}; use hugr::type_row; use hugr::types::{FunctionType, Type, TypeRow}; use rstest::rstest; -use super::test::assert_fully_folded; - use lazy_static::lazy_static; +use super::*; +use hugr::builder::Container; +use hugr::ops::{OpType, UnpackTuple}; +use hugr::std_extensions::arithmetic::conversions::ConvertOpDef; +use hugr::std_extensions::arithmetic::float_ops::FloatOps; +use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; + +/// Check that a hugr just loads and returns a single expected constant. +pub fn assert_fully_folded(h: &Hugr, expected_value: &Value) { + assert_fully_folded_with(h, |v| v == expected_value) +} + +/// Check that a hugr just loads and returns a single constant, and validate +/// that constant using `check_value`. +/// +/// [CustomConst::equals_const] is not required to be implemented. Use this +/// function for Values containing such a `CustomConst`. 
+fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) { + let mut node_count = 0; + + for node in h.children(h.root()) { + let op = h.get_optype(node); + match op { + OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, + OpType::Const(c) if check_value(c.value()) => node_count += 1, + _ => panic!("unexpected op: {:?}\n{}", op, h.mermaid_string()), + } + } + + assert_eq!(node_count, 4); +} + +/// int to constant +fn i2c(b: u64) -> Value { + Value::extension(ConstInt::new_u(5, b).unwrap()) +} + +/// float to constant +fn f2c(f: f64) -> Value { + ConstF64::new(f).into() +} + +#[rstest] +#[case(0.0, 0.0, 0.0)] +#[case(0.0, 1.0, 1.0)] +#[case(23.5, 435.5, 459.0)] +// c = a + b +fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { + let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; + let add_op: OpType = FloatOps::fadd.into(); + let outs = fold_leaf_op(&add_op, &consts) + .unwrap() + .into_iter() + .map(|(p, v)| (p, v.get_custom_value::().unwrap().value())) + .collect_vec(); + + assert_eq!(outs.as_slice(), &[(0.into(), c)]); +} +#[test] +fn test_big() { + /* + Test approximately calculates + let x = (5.6, 3.2); + int(x.0 - x.1) == 2 + */ + let sum_type = sum_with_error(INT_TYPES[5].to_owned()); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![sum_type.clone().into()], + )) + .unwrap(); + + let tup = build.add_load_const(Value::tuple([f2c(5.6), f2c(3.2)])); + + let unpack = build + .add_dataflow_op( + UnpackTuple::new(type_row![FLOAT64_TYPE, FLOAT64_TYPE]), + [tup], + ) + .unwrap(); + + let sub = build + .add_dataflow_op(FloatOps::fsub, unpack.outputs()) + .unwrap(); + let to_int = build + .add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs()) + .unwrap(); + + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + arithmetic::float_ops::EXTENSION.to_owned(), + arithmetic::conversions::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build + .finish_hugr_with_outputs(to_int.outputs(), ®) + .unwrap(); + assert_eq!(h.node_count(), 8); + + constant_fold_pass(&mut h, ®); + + let expected = Value::sum(0, [i2c(2).clone()], sum_type).unwrap(); + assert_fully_folded(&h, &expected); +} + +#[test] +#[cfg_attr( + feature = "extension_inference", + ignore = "inference fails for test graph, it shouldn't" +)] +fn test_list_ops() -> Result<(), Box> { + use hugr::std_extensions::collections::{self, ListOp, ListValue}; + + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + logic::EXTENSION.to_owned(), + collections::EXTENSION.to_owned(), + ]) + .unwrap(); + let list: Value = ListValue::new(BOOL_T, [Value::unit_sum(0, 1).unwrap()]).into(); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![list.get_type().clone()], + )) + .unwrap(); + + let list_wire = build.add_load_const(list.clone()); + + let pop = build.add_dataflow_op( + ListOp::Pop.with_type(BOOL_T).to_extension_op(®).unwrap(), + [list_wire], + )?; + + let push = build.add_dataflow_op( + ListOp::Push + .with_type(BOOL_T) + .to_extension_op(®) + .unwrap(), + pop.outputs(), + )?; + let mut h = build.finish_hugr_with_outputs(push.outputs(), ®)?; + constant_fold_pass(&mut h, ®); + + assert_fully_folded(&h, &list); + Ok(()) +} + +#[test] +fn test_fold_and() { + // pseudocode: + // x0, x1 := bool(true), bool(true) + // x2 := and(x0, x1) + // output x2 == true; + let mut build = 
DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_load_const(Value::true_val()); + let x2 = build + .add_dataflow_op(NaryLogic::And.with_n_inputs(2), [x0, x1]) + .unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_or() { + // pseudocode: + // x0, x1 := bool(true), bool(false) + // x2 := or(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_load_const(Value::false_val()); + let x2 = build + .add_dataflow_op(NaryLogic::Or.with_n_inputs(2), [x0, x1]) + .unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_not() { + // pseudocode: + // x0 := bool(true) + // x1 := not(x0) + // output x1 == false; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_dataflow_op(NotOp, [x0]).unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::false_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn orphan_output() { + // pseudocode: + // x0 := bool(true) + // x1 := not(x0) + // x2 := or(x0,x1) + // output x2 == true; + // + // We arange things so that the `or` folds away first, leaving the not + // with no outputs. + use hugr::hugr::NodeType; + use hugr::ops::handle::NodeHandle; + + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let true_wire = build.add_load_value(Value::true_val()); + // this Not will be manually replaced + let orig_not = build.add_dataflow_op(NotOp, [true_wire]).unwrap(); + let r = build + .add_dataflow_op( + NaryLogic::Or.with_n_inputs(2), + [true_wire, orig_not.out_wire(0)], + ) + .unwrap(); + let or_node = r.node(); + let parent = build.container_node(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(r.outputs(), ®).unwrap(); + + // we delete the original Not and create a new One. This means it will be + // traversed by `constant_fold_pass` after the Or. 
+ let new_not = h.add_node_with_parent(parent, NodeType::new_auto(NotOp)); + h.connect(true_wire.node(), true_wire.source(), new_not, 0); + h.disconnect(or_node, IncomingPort::from(1)); + h.connect(new_not, 0, or_node, 1); + h.remove_node(orig_not.node()); + constant_fold_pass(&mut h, ®); + assert_fully_folded(&h, &Value::true_val()) +} + +#[test] +fn test_folding_pass_issue_996() { + // pseudocode: + // + // x0 := 3.0 + // x1 := 4.0 + // x2 := fne(x0, x1); // true + // x3 := flt(x0, x1); // true + // x4 := and(x2, x3); // true + // x5 := -10.0 + // x6 := flt(x0, x5) // false + // x7 := or(x4, x6) // true + // output x7 + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0))); + let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap(); + let x3 = build.add_dataflow_op(FloatOps::flt, [x0, x1]).unwrap(); + let x4 = build + .add_dataflow_op( + NaryLogic::And.with_n_inputs(2), + x2.outputs().chain(x3.outputs()), + ) + .unwrap(); + let x5 = build.add_load_const(Value::extension(ConstF64::new(-10.0))); + let x6 = build.add_dataflow_op(FloatOps::flt, [x0, x5]).unwrap(); + let x7 = build + .add_dataflow_op( + NaryLogic::Or.with_n_inputs(2), + x4.outputs().chain(x6.outputs()), + ) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + logic::EXTENSION.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_const_fold_to_nonfinite() { + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + ]) + .unwrap(); + + // HUGR computing 1.0 / 1.0 + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); + let mut h0 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h0, ®); + assert_fully_folded_with(&h0, |v| { + v.get_custom_value::().unwrap().value() == 1.0 + }); + assert_eq!(h0.node_count(), 5); + + // HUGR computing 1.0 / 0.0 + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0))); + let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); + let mut h1 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h1, ®); + assert_eq!(h1.node_count(), 8); +} + #[test] fn test_fold_iwiden_u() { // pseudocode: From 81cdc198cb91a80945528dc14899ea0e09716f62 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 13:12:23 +0100 Subject: [PATCH 16/28] Rename test file. 
--- hugr-passes/src/const_fold.rs | 2 +- .../src/const_fold/{int_ops_const_fold_test.rs => test.rs} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename hugr-passes/src/const_fold/{int_ops_const_fold_test.rs => test.rs} (100%) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index f42e1a1b6..4a124c1a4 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -227,4 +227,4 @@ pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { } #[cfg(test)] -mod int_ops_const_fold_test; +mod test; diff --git a/hugr-passes/src/const_fold/int_ops_const_fold_test.rs b/hugr-passes/src/const_fold/test.rs similarity index 100% rename from hugr-passes/src/const_fold/int_ops_const_fold_test.rs rename to hugr-passes/src/const_fold/test.rs From ffa7d9fdf7cf0cb13e9547f6bdfa36d477b959fd Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 13:14:13 +0100 Subject: [PATCH 17/28] Restore module documentation. --- hugr-passes/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 1670995e8..585e25a01 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,3 +1,5 @@ +//! Algorithms using the Hugr. + pub mod const_fold; mod half_node; pub mod merge_bbs; From 617e57197f74afbb4a47e2435f0f9a90a2770d83 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 13:16:46 +0100 Subject: [PATCH 18/28] Improve module description. --- hugr-passes/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 585e25a01..803196144 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,4 +1,4 @@ -//! Algorithms using the Hugr. +//! Compilation passes acting on the HUGR program representation. pub mod const_fold; mod half_node; From beb1d4d69651cbb4dc49163ac7163f5fb02245a5 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 13:28:32 +0100 Subject: [PATCH 19/28] Add README and CHANGELOG. --- hugr-passes/CHANGELOG.md | 5 ++++ hugr-passes/README.md | 58 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 hugr-passes/CHANGELOG.md create mode 100644 hugr-passes/README.md diff --git a/hugr-passes/CHANGELOG.md b/hugr-passes/CHANGELOG.md new file mode 100644 index 000000000..4818de914 --- /dev/null +++ b/hugr-passes/CHANGELOG.md @@ -0,0 +1,5 @@ +# Changelog + +## 0.1.0 (2024-05-23) + +Initial release, with functions ported from the `hugr::algorithms` module. diff --git a/hugr-passes/README.md b/hugr-passes/README.md new file mode 100644 index 000000000..29b17e2af --- /dev/null +++ b/hugr-passes/README.md @@ -0,0 +1,58 @@ +![](/hugr/assets/hugr_logo.svg) + +hugr-passes +=============== + +[![build_status][]](https://github.com/CQCL/hugr/actions) +[![crates][]](https://crates.io/crates/hugr-passes) +[![msrv][]](https://github.com/CQCL/hugr) +[![codecov][]](https://codecov.io/gh/CQCL/hugr) + +The Hierarchical Unified Graph Representation (HUGR, pronounced _hugger_) is the +common representation of quantum circuits and operations in the Quantinuum +ecosystem. + +It provides a high-fidelity representation of operations, that facilitates +compilation and encodes runnable programs. + +The HUGR specification is [here](https://github.com/CQCL/hugr/blob/main/specification/hugr.md). + +This crate provides compilation passes that act on HUGR programs. 
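+
+For example, constant folding can be applied to a freshly built HUGR roughly as
+follows (a minimal sketch adapted from this crate's own tests; the `hugr`
+builder and extension APIs shown here may change between releases):
+
+```rust
+use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr};
+use hugr::extension::prelude::BOOL_T;
+use hugr::extension::{ExtensionRegistry, PRELUDE};
+use hugr::ops::Value;
+use hugr::std_extensions::logic::{self, NaryLogic};
+use hugr::type_row;
+use hugr::types::FunctionType;
+
+use hugr_passes::const_fold::constant_fold_pass;
+
+fn main() {
+    // Build a tiny dataflow graph computing `and(true, true)`.
+    let mut build =
+        DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+    let x0 = build.add_load_const(Value::true_val());
+    let x1 = build.add_load_const(Value::true_val());
+    let x2 = build
+        .add_dataflow_op(NaryLogic::And.with_n_inputs(2), [x0, x1])
+        .unwrap();
+    let reg =
+        ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
+    let mut h = build.finish_hugr_with_outputs(x2.outputs(), &reg).unwrap();
+
+    // Fold the constants away: the graph is left just loading and returning `true`.
+    constant_fold_pass(&mut h, &reg);
+}
+```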
+ +## Usage + +Add the dependency to your project: + +```bash +cargo add hugr-passes +``` + +Please read the [API documentation here][]. + +## Experimental Features + +- `extension_inference`: + Experimental feature which allows automatic inference of extension usages and + requirements in a HUGR and validation that extensions are correctly specified. + Not enabled by default. + +## Recent Changes + +See [CHANGELOG][] for a list of changes. The minimum supported rust +version will only change on major releases. + +## Development + +See [DEVELOPMENT.md](https://github.com/CQCL/hugr/blob/main/DEVELOPMENT.md) for instructions on setting up the development environment. + +## License + +This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). + + [API documentation here]: https://docs.rs/hugr/ + [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main + [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [crates]: https://img.shields.io/crates/v/hugr-passes + [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov + [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE + [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/CHANGELOG.md From c56a57492bffc665d19bd8681c1c37d30c164754 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 13:31:10 +0100 Subject: [PATCH 20/28] Add hugr-passes to release-plz config. --- release-plz.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/release-plz.toml b/release-plz.toml index f6cfe39de..0fe9dd7e5 100644 --- a/release-plz.toml +++ b/release-plz.toml @@ -19,3 +19,9 @@ name = "hugr" # Enable the changelog for this package changelog_update = true + +[[package]] +name = "hugr-passes" + +# Enable the changelog for this package +changelog_update = true From e0991219f32b6a64b89343cf3e0cb2953a40a175 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 13:32:11 +0100 Subject: [PATCH 21/28] Add `git_release_name` to release-plz config. --- release-plz.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/release-plz.toml b/release-plz.toml index 0fe9dd7e5..c17abd90a 100644 --- a/release-plz.toml +++ b/release-plz.toml @@ -14,6 +14,8 @@ changelog_update = false # (This would normally only be enabled once there are multiple packages in the workspace) git_tag_name = "{{ package }}-v{{ version }}" +git_release_name = "{{ package }}: v{{ version }}" + [[package]] name = "hugr" From f27c71fabe9acbb911f878aec09893b873a117fb Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 13:41:11 +0100 Subject: [PATCH 22/28] Extend package metadata. 
--- hugr-passes/Cargo.toml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 727e605a3..f336b5061 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -2,6 +2,15 @@ name = "hugr-passes" version = "0.1.0" edition = "2021" +rust-version = { workspace = true } +license = { workspace = true } +readme = "README.md" +documentation = "https://docs.rs/hugr-passes/" +homepage = { workspace = true } +repository = { workspace = true } +description = "Compiler passes for Quantinuum's HUGR" +keywords = ["Quantum", "Quantinuum"] +categories = ["compilers"] [dependencies] hugr = { path = "../hugr" } From ad95736913bdac73e61cc69699bd21ae6e15bc22 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 13:45:21 +0100 Subject: [PATCH 23/28] Specify hugr version requirement for publishing. --- hugr-passes/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index f336b5061..f62882235 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -13,7 +13,7 @@ keywords = ["Quantum", "Quantinuum"] categories = ["compilers"] [dependencies] -hugr = { path = "../hugr" } +hugr = { path = "../hugr", version = "0.4.0" } itertools = { workspace = true } lazy_static = { workspace = true } paste = { workspace = true } From e0ebb9fa5fa938768a590b241d680c36619fa68a Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 13:46:49 +0100 Subject: [PATCH 24/28] Add `CHANGELOG.md` to `CODEOWNERS`. --- .github/CODEOWNERS | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index e260a89b9..1071de1ca 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,4 +9,5 @@ # The release PRs that trigger publication to crates.io or PyPI always modify the changelog. # We require those PRs to be approved by someone with release permissions. hugr/CHANGELOG.md @aborgna-q @ss2165 +hugr-passes/CHANGELOG.md @aborgna-q @ss2165 hugr-py/CHANGELOG.md @aborgna-q @ss2165 From 11798866f33eb82ed4f6d14c4f265bfef908f97b Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 13:47:45 +0100 Subject: [PATCH 25/28] Add new directory to `rust` filter. --- .github/change-filters.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/change-filters.yml b/.github/change-filters.yml index 9fb256c3d..ebf6b45e7 100644 --- a/.github/change-filters.yml +++ b/.github/change-filters.yml @@ -3,6 +3,7 @@ rust: - "hugr/**" + - "hugr-passes/**" - "Cargo.toml" - "specification/schema/**" From be0569973f75bdf408f9ccce02a283c78ead9ea3 Mon Sep 17 00:00:00 2001 From: Alec Edgington <54802828+cqc-alec@users.noreply.github.com> Date: Thu, 23 May 2024 13:57:05 +0100 Subject: [PATCH 26/28] Update hugr-passes/README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Agustín Borgna <121866228+aborgna-q@users.noreply.github.com> --- hugr-passes/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/README.md b/hugr-passes/README.md index 29b17e2af..bdb566181 100644 --- a/hugr-passes/README.md +++ b/hugr-passes/README.md @@ -49,10 +49,10 @@ See [DEVELOPMENT.md](https://github.com/CQCL/hugr/blob/main/DEVELOPMENT.md) for This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). 
- [API documentation here]: https://docs.rs/hugr/ + [API documentation here]: https://docs.rs/hugr-passes/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg [crates]: https://img.shields.io/crates/v/hugr-passes [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE - [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/CHANGELOG.md + [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-passes/CHANGELOG.md From 8105f02726695784781376b7818a77f8a72ae4eb Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 14:09:06 +0100 Subject: [PATCH 27/28] Remove top-level CHANGELOG symlink and fix up link from README. --- CHANGELOG.md | 1 - hugr/README.md | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) delete mode 120000 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 120000 index 729dc652a..000000000 --- a/CHANGELOG.md +++ /dev/null @@ -1 +0,0 @@ -hugr/CHANGELOG.md \ No newline at end of file diff --git a/hugr/README.md b/hugr/README.md index 96c823016..88a61035b 100644 --- a/hugr/README.md +++ b/hugr/README.md @@ -53,4 +53,4 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [crates]: https://img.shields.io/crates/v/hugr [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE - [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/CHANGELOG.md + [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr/CHANGELOG.md From 8840073dc3330746f9be5d318ff5230c3dc439c9 Mon Sep 17 00:00:00 2001 From: Alec Edgington Date: Thu, 23 May 2024 14:11:26 +0100 Subject: [PATCH 28/28] Get package.edition from workspace. --- hugr-passes/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index f62882235..688d32e48 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "hugr-passes" version = "0.1.0" -edition = "2021" +edition = { workspace = true } rust-version = { workspace = true } license = { workspace = true } readme = "README.md"
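The `{ workspace = true }` entries above are resolved from tables defined once in the root `Cargo.toml`. A rough sketch of the corresponding workspace keys, using only values that appear earlier in this series (the actual root manifest may define more):

```toml
[workspace.package]
edition = "2021"
rust-version = "1.75"    # assumption: taken from the MSRV badge in the README
license = "Apache-2.0"   # assumption: SPDX id for the Apache-2.0 licence noted in the README

[workspace.dependencies]
itertools = "0.12.0"
lazy_static = "1.4.0"
paste = "1.0"
thiserror = "1.0.28"
```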