From 36e71e0923de6cf5a776b5dabb722e5f171f4002 Mon Sep 17 00:00:00 2001 From: Douglas Wilson <141026920+doug-q@users.noreply.github.com> Date: Tue, 16 Jul 2024 09:57:24 +0100 Subject: [PATCH] fix!: force_order failing on Const nodes, add arg to rank. (#1300) BREAKING CHANGE: the `rank` argument of `force_order` takes an additional argument. --- hugr-passes/src/force_order.rs | 76 ++++++++++++++++++++++++---------- 1 file changed, 53 insertions(+), 23 deletions(-) diff --git a/hugr-passes/src/force_order.rs b/hugr-passes/src/force_order.rs index 522e40875..c2f7b9221 100644 --- a/hugr-passes/src/force_order.rs +++ b/hugr-passes/src/force_order.rs @@ -1,5 +1,5 @@ //! Provides [force_order], a tool for fixing the order of nodes in a Hugr. -use std::{cmp::Reverse, collections::BinaryHeap}; +use std::{cmp::Reverse, collections::BinaryHeap, iter}; use hugr_core::{ hugr::{ @@ -7,9 +7,9 @@ use hugr_core::{ views::{DescendantsGraph, HierarchyView, SiblingGraph}, HugrError, }, - ops::{OpTag, OpTrait}, + ops::{NamedOp, OpTag, OpTrait}, types::EdgeKind, - Direction, HugrView as _, Node, + HugrView as _, Node, }; use itertools::Itertools as _; use petgraph::{ @@ -36,45 +36,58 @@ use petgraph::{ /// there is no path from `n2` to `n1` (otherwise this would invalidate `hugr`). /// Nodes of equal rank will be ordered arbitrarily, although that arbitrary /// order is deterministic. -pub fn force_order( - hugr: &mut impl HugrMut, +pub fn force_order( + hugr: &mut H, root: Node, - rank: impl Fn(Node) -> i64, + rank: impl Fn(&H, Node) -> i64, ) -> Result<(), HugrError> { force_order_by_key(hugr, root, rank) } /// As [force_order], but allows a generic [Ord] choice for the result of the /// `rank` function. -pub fn force_order_by_key( - hugr: &mut impl HugrMut, +pub fn force_order_by_key( + hugr: &mut H, root: Node, - rank: impl Fn(Node) -> K, + rank: impl Fn(&H, Node) -> K, ) -> Result<(), HugrError> { let dataflow_parents = DescendantsGraph::::try_new(hugr, root)? .nodes() .filter(|n| hugr.get_optype(*n).tag() <= OpTag::DataflowParent) .collect_vec(); for dp in dataflow_parents { + // we filter out the input and output nodes from the topological sort + let [i, o] = hugr.get_io(dp).unwrap(); + let rank = |n| rank(hugr, n); let sg = SiblingGraph::::try_new(hugr, dp)?; - let petgraph = NodeFiltered::from_fn(sg.as_petgraph(), |x| x != dp); + let petgraph = NodeFiltered::from_fn(sg.as_petgraph(), |x| x != dp && x != i && x != o); let ordered_nodes = ForceOrder::new(&petgraph, &rank) .iter(&petgraph) - .filter(|&x| hugr.get_optype(x).tag() <= OpTag::DataflowChild) + .filter(|&x| { + let expected_edge = Some(EdgeKind::StateOrder); + let optype = hugr.get_optype(x); + if optype.other_input() == expected_edge || optype.other_output() == expected_edge { + assert_eq!( + optype.other_input(), + optype.other_output(), + "Optype does not have both input and output order edge: {}", + optype.name() + ); + true + } else { + false + } + }) .collect_vec(); - for (&n1, &n2) in ordered_nodes.iter().tuple_windows() { + // we iterate over the topologically sorted nodes, prepending the input + // node and suffixing the output node. + for (&n1, &n2) in iter::once(&i) + .chain(ordered_nodes.iter()) + .chain(iter::once(&o)) + .tuple_windows() + { let (n1_ot, n2_ot) = (hugr.get_optype(n1), hugr.get_optype(n2)); - assert_eq!( - Some(EdgeKind::StateOrder), - n1_ot.other_port_kind(Direction::Outgoing), - "Node {n1} does not support state order edges" - ); - assert_eq!( - Some(EdgeKind::StateOrder), - n2_ot.other_port_kind(Direction::Incoming), - "Node {n2} does not support state order edges" - ); if !hugr.output_neighbours(n1).contains(&n2) { hugr.connect( n1, @@ -192,10 +205,13 @@ mod test { use super::*; use hugr_core::builder::{endo_ft, BuildHandle, Dataflow, DataflowHugr}; + use hugr_core::extension::EMPTY_REG; use hugr_core::ops::handle::{DataflowOpID, NodeHandle}; + use hugr_core::ops::Value; use hugr_core::std_extensions::arithmetic::int_ops::{self, IntOpDef}; use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; + use hugr_core::types::{FunctionType, Type}; use hugr_core::{builder::DFGBuilder, hugr::Hugr}; use hugr_core::{HugrView, Wire}; @@ -257,7 +273,7 @@ mod test { type RankMap = HashMap; fn force_order_test_impl(hugr: &mut Hugr, rank_map: RankMap) -> Vec { - force_order(hugr, hugr.root(), |n| *rank_map.get(&n).unwrap_or(&0)).unwrap(); + force_order(hugr, hugr.root(), |_, n| *rank_map.get(&n).unwrap_or(&0)).unwrap(); let topo_sorted = Topo::new(&hugr.as_petgraph()) .iter(&hugr.as_petgraph()) @@ -303,4 +319,18 @@ mod test { let topo_sort = force_order_test_impl(&mut hugr, rank_map); assert_eq!(vec![v0, v1, v2, v3], topo_sort); } + + #[test] + fn test_force_order_const() { + let mut hugr = { + let mut builder = + DFGBuilder::new(FunctionType::new(Type::EMPTY_TYPEROW, Type::UNIT)).unwrap(); + let unit = builder.add_load_value(Value::unary_unit_sum()); + builder + .finish_hugr_with_outputs([unit], &EMPTY_REG) + .unwrap() + }; + let root = hugr.root(); + force_order(&mut hugr, root, |_, _| 0).unwrap(); + } }