From 067bf8da097389a7baeed77997165468991bcc76 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Jun 2024 08:39:49 +0100 Subject: [PATCH] tidying --- hugr-core/src/partial_value/test.rs | 8 +- hugr-passes/Cargo.toml | 2 + hugr-passes/src/const_fold2/datalog.rs | 50 +++--- hugr-passes/src/const_fold2/datalog/test.rs | 37 ++-- hugr-passes/src/const_fold2/datalog/utils.rs | 173 +++++++++++++++++-- 5 files changed, 216 insertions(+), 54 deletions(-) diff --git a/hugr-core/src/partial_value/test.rs b/hugr-core/src/partial_value/test.rs index f31c33642..35fbf5373 100644 --- a/hugr-core/src/partial_value/test.rs +++ b/hugr-core/src/partial_value/test.rs @@ -316,15 +316,19 @@ proptest! { #[test] fn bounded_lattice(v in any_partial_value()) { - prop_assert!(v <= PartialValue::Top); - prop_assert!(v >= PartialValue::Bottom); + prop_assert!(v <= PartialValue::top()); + prop_assert!(v >= PartialValue::bottom()); } #[test] fn meet_join_self_noop(v1 in any_partial_value()) { let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); assert!(!subject.join_mut(v1.clone())); assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); assert!(!subject.meet_mut(v1.clone())); assert_eq!(subject, v1); } diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 06d9975ea..92967f9c9 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -27,3 +27,5 @@ extension_inference = ["hugr-core/extension_inference"] [dev-dependencies] rstest = { workspace = true } +proptest = { workspace = true } +proptest-derive = { workspace = true } diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index b300fdcc0..98e80d0a9 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -58,6 +58,10 @@ ascent::ascent! { node_in_value_row(c, n, utils::singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v); + // Per node-type rules + // TODO do all leaf ops with a rule + // define `fn propagate_leaf_op(Context, Node, ValueRow) -> ValueRow + // LoadConstant relation load_constant_node(C, Node); load_constant_node(c, n) <-- node(c, n), if c.hugr().get_optype(*n).is_load_constant(); @@ -104,29 +108,35 @@ ascent::ascent! { io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v); // Output node of child region propagate to Input node of child region - out_wire_value(c, i, input_p, v) <-- tail_loop_node(c, tl), - io_node(c,tl,i, IO::Input), - io_node(c,tl,o, IO::Output), - in_wire_value(c, o, output_p, output_v), - if let Some(tailloop) = c.hugr().get_optype(*tl).as_tail_loop(), + out_wire_value(c, in_n, out_p, v) <-- tail_loop_node(c, tl_n), + io_node(c,tl_n,in_n, IO::Input), + io_node(c,tl_n,out_n, IO::Output), + node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node + if out_in_row[0].supports_tag(0), // if it is possible for tag to be 0 + if let Some(tailloop) = c.hugr().get_optype(*tl_n).as_tail_loop(), let variant_len = tailloop.just_inputs.len(), - for (input_p, v) in utils::tail_loop_worker(*output_p, 0, variant_len, output_v); + for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + |(input_p, v)| utils::outputs_for_variant(input_p, 0, variant_len, v) + ); // Output node of child region propagate to outputs of tail loop - out_wire_value(c, tl, p, v) <-- tail_loop_node(c, tl), - io_node(c,tl,o, IO::Output), - in_wire_value(c, o, output_p, output_v), - if let Some(tailloop) = c.hugr().get_optype(*tl).as_tail_loop(), + out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n), + io_node(c,tl_n,out_n, IO::Output), + node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node + if out_in_row[0].supports_tag(1), // if it is possible for the tag to be 1 + if let Some(tailloop) = c.hugr().get_optype(*tl_n).as_tail_loop(), let variant_len = tailloop.just_outputs.len(), - for (p, v) in utils::tail_loop_worker(*output_p, 1, variant_len, output_v); + for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + |(input_p, v)| utils::outputs_for_variant(input_p, 1, variant_len, v) + ); - lattice tail_loop_termination(C,Node,OrdLattice); - tail_loop_termination(c,tl,TailLoopTermination::NeverTerminates.into()) <-- - tail_loop_node(c,tl); - tail_loop_termination(c,tl,TailLoopTermination::from_control_value(v).into()) <-- - tail_loop_node(c,tl), - io_node(c,tl,o, IO::Output), - in_wire_value(c, o, Into::::into(0usize), v); + lattice tail_loop_termination(C,Node,TailLoopTermination); + tail_loop_termination(c,tl_n,TailLoopTermination::bottom()) <-- + tail_loop_node(c,tl_n); + tail_loop_termination(c,tl_n,TailLoopTermination::from_control_value(v)) <-- + tail_loop_node(c,tl_n), + io_node(c,tl,out_n, IO::Output), + in_wire_value(c, out_n, IncomingPort::from(0), v); // Conditional @@ -145,7 +155,7 @@ ascent::ascent! { in_wire_value(c, cond, cond_in_p, cond_in_v), if let Some(conditional) = c.hugr().get_optype(*cond).as_conditional(), let variant_len = conditional.sum_rows[*case_index].len(), - for (i_p, v) in utils::tail_loop_worker(*cond_in_p, *case_index, variant_len, cond_in_v); + for (i_p, v) in utils::outputs_for_variant(*cond_in_p, *case_index, variant_len, cond_in_v); // outputs of case nodes propagate to outputs of conditional out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- @@ -219,7 +229,7 @@ impl<'a, H: HugrView> Machine<'a, H> { self.program .tail_loop_termination .iter() - .find_map(|(c, n, v)| (c == context && n == &node).then_some(v.0.clone())) + .find_map(|(c, n, v)| (c == context && n == &node).then_some(*v)) .unwrap() } diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 563c83130..e35ee0a47 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -86,7 +86,7 @@ fn test_tail_loop_never_iterates() { let o_r = machine.read_out_wire_value(&c, tl_o).unwrap(); assert_eq!(o_r, r_v); assert_eq!( - TailLoopTermination::SingleIteration, + TailLoopTermination::ExactlyZeroContinues, machine.tail_loop_terminates(&c, tail_loop.node()) ) } @@ -96,22 +96,29 @@ fn test_tail_loop_always_iterates() { let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); let r_w = builder .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let true_w = builder.add_load_value(Value::true_val()); + let tlb = builder - .tail_loop_builder([], [], vec![BOOL_T].into()) + .tail_loop_builder([], [(BOOL_T,true_w)], vec![BOOL_T].into()) .unwrap(); - let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap(); - let [tl_o] = tail_loop.outputs_arr(); + + // r_w has tag 0, so we always continue; + // we put true in our "other_output", but we should not propagate this to + // output because r_w never supports 1. + let tail_loop = tlb.finish_with_outputs(r_w, [true_w]).unwrap(); + + let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); let c = machine.run_hugr(&hugr); - // dbg!(&machine.tail_loop_io_node); - // dbg!(&machine.out_wire_value); - let o_r = machine.read_out_wire_partial_value(&c, tl_o).unwrap(); - assert_eq!(o_r, PartialValue::Bottom); + let o_r1 = machine.read_out_wire_partial_value(&c, tl_o1).unwrap(); + assert_eq!(o_r1, PartialValue::bottom()); + let o_r2 = machine.read_out_wire_partial_value(&c, tl_o2).unwrap(); + assert_eq!(o_r2, PartialValue::bottom()); assert_eq!( - TailLoopTermination::NeverTerminates, + TailLoopTermination::bottom(), machine.tail_loop_terminates(&c, tail_loop.node()) ) } @@ -146,20 +153,20 @@ fn test_tail_loop_iterates_twice() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); // TODO once we can do conditionals put these wires inside `just_outputs` and // we should be able to propagate their values - // let [o_w1, o_w2, _] = tail_loop.outputs_arr(); + let [o_w1, o_w2, _] = tail_loop.outputs_arr(); let mut machine = Machine::new(); let c = machine.run_hugr(&hugr); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); - // TODO these hould be the propagated values - // let o_r1 = machine.read_out_wire_value(&c, o_w1).unwrap(); - // assert_eq!(o_r1, Value::false_val()); - // let o_r2 = machine.read_out_wire_value(&c, o_w2).unwrap(); + // TODO these hould be the propagated values for now they will bt join(true,false) + let o_r1 = machine.read_out_wire_partial_value(&c, o_w1).unwrap(); + // assert_eq!(o_r1, PartialValue::top()); + let o_r2 = machine.read_out_wire_partial_value(&c, o_w2).unwrap(); // assert_eq!(o_r2, Value::true_val()); assert_eq!( - TailLoopTermination::Terminates, + TailLoopTermination::Top, machine.tail_loop_terminates(&c, tail_loop.node()) ) } diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 9a63bab1e..00162e73b 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -1,4 +1,11 @@ -use ascent::lattice::{ord_lattice::OrdLattice, BoundedLattice, Dual, Lattice}; +// proptest-derive generates many of these warnings. +// https://github.com/rust-lang/rust/issues/120363 +// https://github.com/proptest-rs/proptest/issues/447 +#![cfg_attr(test, allow(non_local_definitions))] + +use std::{cmp::Ordering, ops::Index}; + +use ascent::lattice::{BoundedLattice, Lattice}; use either::Either; use hugr_core::{ ops::OpTrait as _, @@ -8,6 +15,9 @@ use hugr_core::{ }; use itertools::zip_eq; +#[cfg(test)] +use proptest_derive::Arbitrary; + use super::context::DFContext; #[derive(PartialEq, Eq, Hash, PartialOrd, Clone, Debug)] @@ -98,7 +108,7 @@ impl ValueRow { Self::new(r.len()) } - fn iter<'b>( + pub fn iter<'b>( &'b self, context: &'b impl DFContext, n: Node, @@ -151,6 +161,16 @@ impl IntoIterator for ValueRow { } } +impl Index for ValueRow where + Vec: Index +{ + type Output = as Index>::Output; + + fn index(&self, index: Idx) -> &Self::Output { + self.0.index(index) + } +} + pub(super) fn bottom_row(context: &impl DFContext, n: Node) -> ValueRow { if let Some(sig) = context.hugr().signature(n) { ValueRow::new(sig.input_count()) @@ -216,9 +236,26 @@ pub(super) fn value_outputs( context.hugr().out_value_types(n).map(|x| x.0) } -// TODO rename, this is about expanding input variants into output rows -// todo this should work for dataflowblocks too -pub(super) fn tail_loop_worker<'a>( +// We have several cases where sum types propagate to different places depending +// on their variant tag: +// - From the input of a conditional to the inputs of it's case nodes +// - From the input of the output node of a tail loop to the output of the input node of the tail loop +// - From the input of the output node of a tail loop to the output of tail loop node +// - From the input of a the output node of a dataflow block to the output of the input node of a dataflow block +// - From the input of a the output node of a dataflow block to the output of the cfg +// +// For a value `v` on an incoming porg `output_p`, compute the (out port,value) +// pairs that should be propagated for a given variant tag. We must also supply +// the length of this variant because it cannot always be deduced from the other +// inputs. +// +// If `v` does not support `variant_tag`, then all propagated values will be bottom.` +// +// If `output_p.index()` is 0 then the result is the contents of the variant. +// Otherwise, it is the single "other_output". +// +// TODO doctests +pub(super) fn outputs_for_variant<'a>( output_p: IncomingPort, variant_tag: usize, variant_len: usize, @@ -242,27 +279,129 @@ pub(super) fn tail_loop_worker<'a>( } } -#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone)] +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +#[cfg_attr(test, derive(Arbitrary))] pub enum TailLoopTermination { - NeverTerminates, - SingleIteration, - Terminates, + Bottom, + ExactlyZeroContinues, + Top, } impl TailLoopTermination { pub fn from_control_value(v: &PV) -> Self { - if v.supports_tag(1) && !v.supports_tag(0) { - Self::SingleIteration - } else if v.supports_tag(1) { - Self::Terminates + let (may_continue, may_break) = (v.supports_tag(0),v.supports_tag(1)); + if may_break && !may_continue { + Self::ExactlyZeroContinues + } else if may_break && may_continue { + Self::top() } else { - Self::NeverTerminates + Self::bottom() + } + } +} + +impl PartialOrd for TailLoopTermination { + fn partial_cmp(&self, other: &Self) -> Option { + if self == other { + return Some(std::cmp::Ordering::Equal); + }; + match (self, other) { + (Self::Bottom,_) => Some(Ordering::Less), + (_,Self::Bottom) => Some(Ordering::Greater), + (Self::Top,_) => Some(Ordering::Greater), + (_,Self::Top) => Some(Ordering::Less), + _ => None } } } -impl From for OrdLattice { - fn from(value: TailLoopTermination) -> Self { - Self(value) +impl Lattice for TailLoopTermination { + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn meet_mut(&mut self, other: Self) -> bool { + // let new_self = &mut self; + match (*self).partial_cmp(&other) { + Some(Ordering::Greater) => { + *self = other; + true + } + Some(_) => false, + _ => { + *self = Self::Bottom; + true + } + } + } + + fn join_mut(&mut self, other: Self) -> bool { + match (*self).partial_cmp(&other) { + Some(Ordering::Less) => { + *self = other; + true + } + Some(_) => false, + _ => { + *self = Self::Top; + true + } + } + } +} + +impl BoundedLattice for TailLoopTermination { + fn bottom() -> Self { + Self::Bottom + + } + + fn top() -> Self { + Self::Top + } +} + +#[cfg(test)] +#[cfg_attr(test, allow(non_local_definitions))] +mod test { + use super::*; + use proptest::prelude::*; + + proptest! { + #[test] + fn bounded_lattice(v: TailLoopTermination) { + prop_assert!(v <= TailLoopTermination::top()); + prop_assert!(v >= TailLoopTermination::bottom()); + } + + #[test] + fn meet_join_self_noop(v1: TailLoopTermination) { + let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); + assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); + } + + #[test] + fn lattice(v1: TailLoopTermination, v2: TailLoopTermination) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + + let join = v1.clone().join(v2.clone()); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); + } } }